Importing all necessary packages

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader

Model Architecture

In [None]:
class AODNet(nn.Module):
    """Small fully-convolutional network mapping a 3-channel image to a 3-channel image.

    Five convolutions are applied one after another. Each keeps three
    channels and uses padding that preserves height and width, so the output
    has the same shape as the input. A ReLU follows each of the first four
    convolutions; the fifth is returned without an activation.

    NOTE(review): this is a plain sequential stack. The published AOD-Net
    concatenates intermediate features and recovers the image from an
    estimated K(x) map -- confirm the simplification is intentional.
    """

    def __init__(self):
        super().__init__()
        # Kernel sizes 1-3-5-3-1; padding == kernel_size // 2 keeps H x W fixed.
        # Layers are created in this order so parameter initialisation and
        # state_dict keys stay stable.
        self.conv1 = nn.Conv2d(3, 3, kernel_size=1)
        self.conv2 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(3, 3, kernel_size=5, padding=2)
        self.conv4 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(3, 3, kernel_size=1)

    def forward(self, x):
        """Apply the conv stack to ``x`` and return the un-activated output of ``conv5``."""
        features = x
        for hidden in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = torch.relu(hidden(features))
        return self.conv5(features)

Setting up dataset

In [None]:
class DehazingDataset(Dataset):
    """Dataset serving every JPEG found directly inside one directory.

    Each item is a single image (no label and no paired target). The image
    is decoded, converted to 3-channel RGB and passed through ``transform``.

    Args:
        root_dir: Directory scanned (non-recursively) for image files.
        transform: Callable applied to the loaded ``PIL.Image``. Defaults to
            ``transforms.ToTensor()``.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``root_dir`` cannot be
            listed (for example, it does not exist).
    """

    # File extensions treated as JPEG; matched case-insensitively.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg')

    def __init__(self, root_dir, transform=transforms.ToTensor()):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable between runs.
        self.images = sorted(
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of image files discovered in ``root_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, RGB-convert and transform the image at position ``idx``."""
        image_path = self.images[idx]
        # The context manager closes the file handle; convert() returns an
        # independent, fully loaded copy, and forces 3 channels so grayscale
        # or CMYK files still match the model's 3-channel input.
        with Image.open(image_path) as raw:
            image = raw.convert('RGB')
        image = self.transform(image)
        return image

# TODO: replace the placeholder path with the directory holding the training images.
dataset = DehazingDataset(root_dir='path/to/dataset')

In [None]:
# Shuffled mini-batches of four items drawn from `dataset`.
data_loader = DataLoader(
    dataset=dataset,
    batch_size=4,
    shuffle=True,
)

Training loop

In [None]:
# Training configuration -- tune here.
NUM_EPOCHS = 10
LEARNING_RATE = 0.001

# Prefer the GPU when one is available; model and batches must share a device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AODNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

model.train()
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    # DehazingDataset yields one image tensor per item, so every batch is a
    # single tensor -- there is no (input, label) pair to unpack.
    for inputs in data_loader:
        inputs = inputs.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        # NOTE(review): the target is the input itself, so this optimises
        # reconstruction of the input; dehazing presumably needs paired clear
        # images as targets -- confirm and extend the dataset accordingly.
        loss = criterion(outputs, inputs)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Mean per-batch loss for the epoch.
    print(f'Epoch {epoch+1}, Loss: {running_loss / len(data_loader)}')

Dehaze the image 

In [None]:
def dehaze_image(image_path):
    """Run the notebook-level ``model`` on one image file and return a plottable array.

    Args:
        image_path: Path of the image file to process.

    Returns:
        ``numpy.ndarray`` in channel-last layout (the model output with its
        batch axis removed and axes reordered from (C, H, W) to (H, W, C)),
        with values clamped to [0, 1] so it can be displayed directly.
    """
    # Follow whichever device the model's weights live on, so the input and
    # the weights never end up on different devices.
    target_device = next(model.parameters()).device

    # Close the file handle promptly and force 3 channels to match the model.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')
    batch = transforms.ToTensor()(image).unsqueeze(0).to(target_device)

    # Inference only: switch to eval mode and skip autograd bookkeeping,
    # then restore whatever mode the model was in before the call.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(batch)
    finally:
        model.train(was_training)

    # The last layer is linear, so values may leave [0, 1]; clamp to a valid
    # image range. Move to CPU before converting -- .numpy() fails on CUDA tensors.
    output = output.squeeze(0).clamp(0.0, 1.0).cpu().numpy()
    return np.transpose(output, (1, 2, 0))

# Placeholder path -- point this at a real test image before running.
image_path = 'path/to/test/image.jpg'
dehazed_image = dehaze_image(image_path)

# Draw the returned array on an explicit axes object and display it.
fig, ax = plt.subplots()
ax.imshow(dehazed_image)
plt.show()