Importing all necessary packages

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import os

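Model architecture (AOD-Net)

AOD-Net (Li et al., "AOD-Net: All-in-One Dehazing Network", ICCV 2017) does not estimate the transmission map and the atmospheric light separately. A small CNN predicts a single per-pixel map $K(x)$, and the clean image is recovered as $J(x) = K(x)\,I(x) - K(x) + b$, where $I(x)$ is the hazy input and $b$ is a constant bias.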

In [None]:
class AODNet(nn.Module):
    """AOD-Net single-image dehazing network (Li et al., ICCV 2017).

    Five small convolutions estimate a per-pixel map ``K``. The clean image is
    then recovered with the reformulated atmospheric scattering model

        J = K * I - K + b

    where ``I`` is the hazy input and ``b`` is a constant bias.

    The intermediate feature maps are concatenated along the channel axis, so
    the convolutions that consume them take 6, 6 and 12 input channels.
    """

    def __init__(self):
        super(AODNet, self).__init__()
        # Every conv uses "same" padding, so K keeps the spatial size of the input.
        self.conv1 = nn.Conv2d(3, 3, kernel_size=1, bias=True)
        self.conv2 = nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=True)
        # Input is cat(x1, x2): 3 + 3 = 6 channels.
        self.conv3 = nn.Conv2d(6, 3, kernel_size=5, padding=2, bias=True)
        # Input is cat(x2, x3): 6 channels. The 7x7 kernel follows the AOD-Net paper.
        self.conv4 = nn.Conv2d(6, 3, kernel_size=7, padding=3, bias=True)
        # Input is cat(x1, x2, x3, x4): 12 channels. The 3x3 kernel follows the paper.
        self.conv5 = nn.Conv2d(12, 3, kernel_size=3, padding=1, bias=True)

        # Constant bias b in J = K * I - K + b.
        self.b = 1

    def forward(self, x):
        """Dehaze a batch of images.

        Args:
            x: hazy images with 3 channels (``conv1`` expects 3 input channels).
                Presumably ``(N, 3, H, W)`` floats in ``[0, 1]``, as produced by
                ``ToTensor`` -- TODO confirm for other callers.

        Returns:
            Dehazed images with the same shape as ``x``, clamped to ``[0, 1]``.

        Raises:
            Exception: if the estimated ``K`` map does not match ``x``'s size.
        """
        # The ReLU after each conv makes K a non-linear function of the input.
        x1 = torch.relu(self.conv1(x))
        x2 = torch.relu(self.conv2(x1))
        cat1 = torch.cat((x1, x2), 1)
        x3 = torch.relu(self.conv3(cat1))
        cat2 = torch.cat((x2, x3), 1)
        x4 = torch.relu(self.conv4(cat2))
        cat3 = torch.cat((x1, x2, x3, x4), 1)
        k = torch.relu(self.conv5(cat3))

        if k.size() != x.size():
            raise Exception("k, haze image are different size!")

        output = k * x - k + self.b
        # Clamp so the result is a valid image for ToPILImage / plotting.
        return torch.clamp(output, 0, 1)

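Setting up the dataset

Each hazy image is paired with the clear (ground-truth) image that has the same file name in a parallel directory.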

In [None]:
class DehazeDataset(Dataset):
    """Paired hazy / clear image dataset.

    Each file in ``hazy_dir`` is paired with the file of the same name in
    ``clear_dir``. ``__getitem__`` returns ``(hazy_image, clear_image)``.

    Args:
        hazy_dir: directory of hazy input images.
        clear_dir: directory of the matching haze-free ground-truth images.
        transform: optional callable applied separately to both images (e.g.
            resize + ``ToTensor``). It should be deterministic. Because it is
            called twice, a random augmentation would be sampled independently
            for each image and break the pixel alignment of the pair.
    """

    # FIXME(review): ``Image`` below is Pillow's ``PIL.Image``, and the imports
    # cell does not import it. Add ``from PIL import Image`` there, otherwise
    # ``__getitem__`` raises NameError.

    def __init__(self, hazy_dir, clear_dir, transform=None):
        self.hazy_dir = hazy_dir
        self.clear_dir = clear_dir
        self.transform = transform
        # os.listdir order is arbitrary, so sort to get a reproducible
        # index -> file mapping. Also skip sub-directories and hidden files
        # (e.g. ``.ipynb_checkpoints``, ``.DS_Store``), which are not images.
        self.image_names = sorted(
            name
            for name in os.listdir(hazy_dir)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(hazy_dir, name))
        )

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        name = self.image_names[idx]
        hazy_image = self._load_rgb(os.path.join(self.hazy_dir, name))
        clear_image = self._load_rgb(os.path.join(self.clear_dir, name))

        if self.transform:
            hazy_image = self.transform(hazy_image)
            clear_image = self.transform(clear_image)

        return hazy_image, clear_image

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, convert it to an RGB image, and close the file handle."""
        # convert() loads the pixel data, so the returned image stays valid
        # after the ``with`` block closes the file.
        with Image.open(path) as img:
            return img.convert('RGB')

In [None]:
# Data configuration. TODO: point these at your own hazy / ground-truth folders.
# Images are paired by file name, so every file in HAZY_DIR needs a
# same-named twin in CLEAR_DIR.
HAZY_DIR = 'path/to/hazy/images'
CLEAR_DIR = 'path/to/clear/images'
IMAGE_SIZE = (256, 256)  # (height, width) every image is resized to so it can be batched
BATCH_SIZE = 4
# NOTE(review): where DataLoader workers are started with "spawn" (Windows,
# macOS), a Dataset class defined inside a notebook may fail to load in the
# worker processes. Set this to 0 if you hit such errors.
NUM_WORKERS = 2
SHUFFLE_SEED = 0

# Resize to a fixed size, then convert to float tensors.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

dataset = DehazeDataset(HAZY_DIR, CLEAR_DIR, transform=transform)

# A seeded generator makes the shuffle order reproducible across Restart & Run All.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

Training loop

In [None]:
# Training configuration.
INIT_SEED = 0         # seeds weight initialisation so runs are reproducible
LEARNING_RATE = 1e-3

torch.manual_seed(INIT_SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AODNet().to(device)
# Pixel-wise mean-squared error between the dehazed output and the clear target.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

def train_model(model, dataloader, criterion, optimizer, num_epochs=25,
                checkpoint_every=5, checkpoint_dir='.'):
    """Train ``model`` on ``(hazy, clear)`` batches and return it.

    Args:
        model: network to optimise. It is trained in place on whatever device
            its parameters already live on.
        dataloader: iterable yielding ``(hazy_inputs, clear_targets)`` tensor batches.
        criterion: loss called as ``criterion(outputs, clear_targets)``.
        optimizer: optimizer already bound to ``model.parameters()``.
        num_epochs: number of full passes over ``dataloader``.
        checkpoint_every: save ``model.state_dict()`` every this many epochs,
            as ``aodnet_epoch_<n>.pth``. ``None`` or ``0`` disables checkpoints.
        checkpoint_dir: directory the periodic checkpoints are written to.

    Returns:
        The same ``model`` object, after training.

    Raises:
        ValueError: if ``dataloader`` yields no samples during an epoch.
    """
    # Use the model's own device instead of relying on a notebook-global ``device``.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        samples_seen = 0

        for hazy_inputs, clear_targets in dataloader:
            hazy_inputs = hazy_inputs.to(device)
            clear_targets = clear_targets.to(device)

            optimizer.zero_grad()
            outputs = model(hazy_inputs)
            loss = criterion(outputs, clear_targets)
            loss.backward()
            optimizer.step()

            # Assumes the criterion averages over the batch (MSELoss default).
            # Weighting by batch size makes the epoch figure a per-sample mean.
            batch_size = hazy_inputs.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

        if samples_seen == 0:
            raise ValueError('dataloader yielded no samples; check the dataset paths')

        # Divide by the samples actually seen (correct even with drop_last or custom samplers).
        epoch_loss = running_loss / samples_seen
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            torch.save(model.state_dict(),
                       os.path.join(checkpoint_dir, f'aodnet_epoch_{epoch+1}.pth'))

    print('Training complete')
    return model

In [None]:
# Train from scratch; checkpoints are written along the way by train_model.
NUM_EPOCHS = 25
trained_model = train_model(
    model=model,
    dataloader=dataloader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=NUM_EPOCHS,
)

In [None]:
# Persist only the learned weights (the state_dict), not the whole module object.
FINAL_WEIGHTS_PATH = 'aodnet_final.pth'
final_state = trained_model.state_dict()
torch.save(final_state, FINAL_WEIGHTS_PATH)


Inference helper: dehaze a single image and show it next to the hazy input

In [None]:
def dehaze_image(model, image_path, image_size=(256, 256)):
    """Dehaze one image with ``model`` and show the input and result side by side.

    Args:
        model: trained dehazing network (e.g. ``AODNet``). It is switched to eval mode.
        image_path: path of the hazy image to process.
        image_size: ``(height, width)`` the image is resized to before inference.
            The default ``(256, 256)`` matches the training transform. ``None``
            keeps the image's original resolution.
    """
    # FIXME(review): ``Image`` is Pillow's ``PIL.Image``, and the imports cell
    # does not import it. Add ``from PIL import Image`` there.
    model.eval()
    # Run on whichever device the model's weights live on, not a global ``device``.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    steps = [transforms.ToTensor()]
    if image_size is not None:
        steps.insert(0, transforms.Resize(image_size))
    transform = transforms.Compose(steps)

    # convert() loads the pixels, so the image remains usable after the file is closed.
    with Image.open(image_path) as raw_image:
        hazy_image = raw_image.convert('RGB')

    with torch.no_grad():
        hazy_tensor = transform(hazy_image).unsqueeze(0).to(device)
        dehazed_tensor = model(hazy_tensor)

    # Drop only the batch dimension added above, then move to CPU for PIL conversion.
    dehazed_image = transforms.ToPILImage()(dehazed_tensor.squeeze(0).cpu())

    # Display images
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    ax1.imshow(hazy_image)
    ax1.set_title('Hazy Image')
    ax1.axis('off')
    ax2.imshow(dehazed_image)
    ax2.set_title('Dehazed Image')
    ax2.axis('off')
    fig.tight_layout()
    plt.show()

Load the trained weights and dehaze a test image

In [None]:
# Load the trained weights. map_location lets a checkpoint saved on a GPU be
# restored on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
# On PyTorch >= 1.13, consider passing weights_only=True.
model.load_state_dict(torch.load('aodnet_final.pth', map_location=device))

# Dehaze an image. TODO: replace with the path of your own hazy test image.
dehaze_image(model, 'path/to/test/hazy_image.jpg')