In [4]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from torchvision.utils import make_grid

import os,sys
import random
from PIL import Image

from metrics import *

In [None]:
# Dataset for RESIDE.
# Each dataset root contains two sub-folders: 'hazy' and 'clear'.

class RESIDE(torch.utils.data.Dataset):
    """Paired hazy/clear image dataset using the RESIDE file-naming convention.

    ``root`` must contain two sub-folders, ``hazy`` and ``clear``. Each hazy
    file is paired with the clear file named by the hazy file's stem up to the
    first ``_`` plus ``format`` (e.g. ``hazy/0001_0.8_0.2.png`` ->
    ``clear/0001.png``).

    Args:
        root: dataset root directory.
        train: stored as ``self.train``; not otherwise used by this class.
        format: file extension (including the dot) used both to select image
            files and to build clear-image file names.
    """

    def __init__(self, root, train, format='.png'):
        super(RESIDE, self).__init__()
        self.root = root
        self.train = train
        self.format = format
        self.hazy_dir = os.path.join(root, 'hazy')
        self.clear_dir = os.path.join(root, 'clear')
        # Sorted so sample indices are stable across runs/platforms
        # (os.listdir returns entries in arbitrary order).
        self.hazy_imgs = sorted(
            os.path.join(self.hazy_dir, x)
            for x in os.listdir(self.hazy_dir) if x.endswith(format)
        )
        self.clear_imgs = sorted(
            os.path.join(self.clear_dir, x)
            for x in os.listdir(self.clear_dir) if x.endswith(format)
        )

    def _clear_path(self, hazy_path):
        """Return the path of the clear image paired with ``hazy_path``."""
        # Hazy and clear files use different naming conventions, so the clear
        # id is recovered from the hazy file name. The extension is stripped
        # first so a hazy name without '_' does not become 'id.png.png'.
        stem = os.path.splitext(os.path.basename(hazy_path))[0]
        clear_img_id = stem.split('_')[0]
        return os.path.join(self.clear_dir, clear_img_id + self.format)

    @staticmethod
    def _load_rgb_tensor(path):
        """Load an image file as an RGB tensor via ``T.ToTensor``."""
        # Image.open is lazy and holds the file open; converting inside the
        # ``with`` forces the pixel read, then the file is closed.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return T.ToTensor()(rgb)

    def __getitem__(self, index):
        """Return the ``(hazy_img, clear_img)`` tensor pair for ``index``."""
        hazy_path = self.hazy_imgs[index]
        hazy_img = self._load_rgb_tensor(hazy_path)
        clear_img = self._load_rgb_tensor(self._clear_path(hazy_path))
        return hazy_img, clear_img

    def __len__(self):
        """Number of samples, i.e. the number of hazy images found."""
        return len(self.hazy_imgs)

In [None]:
# Dataset for HAZE100.
# Each dataset root contains two sub-folders: 'hazy' and 'clear'.

class HAZE100(torch.utils.data.Dataset):
    """Paired hazy/clear image dataset for HAZE100.

    Previously this cell re-declared ``RESIDE`` verbatim, silently shadowing
    the RESIDE dataset class; it now has its own name.

    ``root`` must contain two sub-folders, ``hazy`` and ``clear``. Each hazy
    file is paired with the clear file named by the hazy file's stem up to the
    first ``_`` plus ``format``.
    NOTE(review): this pairing rule mirrors RESIDE's; confirm it matches the
    HAZE100 file-naming scheme.

    Args:
        root: dataset root directory.
        train: stored as ``self.train``; not otherwise used by this class.
        format: file extension (including the dot) used both to select image
            files and to build clear-image file names.
    """

    def __init__(self, root, train, format='.png'):
        super(HAZE100, self).__init__()
        self.root = root
        self.train = train
        self.format = format
        self.hazy_dir = os.path.join(root, 'hazy')
        self.clear_dir = os.path.join(root, 'clear')
        # Sorted so sample indices are stable across runs/platforms
        # (os.listdir returns entries in arbitrary order).
        self.hazy_imgs = sorted(
            os.path.join(self.hazy_dir, x)
            for x in os.listdir(self.hazy_dir) if x.endswith(format)
        )
        self.clear_imgs = sorted(
            os.path.join(self.clear_dir, x)
            for x in os.listdir(self.clear_dir) if x.endswith(format)
        )

    def _clear_path(self, hazy_path):
        """Return the path of the clear image paired with ``hazy_path``."""
        # The extension is stripped first so a hazy name without '_' does not
        # become 'id.png.png'.
        stem = os.path.splitext(os.path.basename(hazy_path))[0]
        clear_img_id = stem.split('_')[0]
        return os.path.join(self.clear_dir, clear_img_id + self.format)

    @staticmethod
    def _load_rgb_tensor(path):
        """Load an image file as an RGB tensor via ``T.ToTensor``."""
        # Image.open is lazy and holds the file open; converting inside the
        # ``with`` forces the pixel read, then the file is closed.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return T.ToTensor()(rgb)

    def __getitem__(self, index):
        """Return the ``(hazy_img, clear_img)`` tensor pair for ``index``."""
        hazy_path = self.hazy_imgs[index]
        hazy_img = self._load_rgb_tensor(hazy_path)
        clear_img = self._load_rgb_tensor(self._clear_path(hazy_path))
        return hazy_img, clear_img

    def __len__(self):
        """Number of samples, i.e. the number of hazy images found."""
        return len(self.hazy_imgs)

In [None]:
""" 
    Model for dehazing;
"""
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias)

# pixel attention
class PALayer(nn.Module):
    """Pixel attention: scales every spatial location by a learned gate in (0, 1).

    A two-layer 1x1-conv bottleneck maps the input to a single-channel map,
    which goes through a sigmoid and is multiplied (broadcast) over all
    channels of the input. Expects NCHW input.

    Args:
        channel: number of input (and output) channels.
        reduction: bottleneck reduction factor. The hidden width is
            ``max(1, channel // reduction)`` so that ``channel < reduction``
            does not collapse the bottleneck to zero channels, which would
            make the gate a constant independent of the input.
    """

    def __init__(self, channel, reduction=8):
        super(PALayer, self).__init__()
        hidden = max(1, channel // reduction)
        self.pa = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, 1, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # The gate has a single channel, so it broadcasts across all of x's channels.
        y = self.pa(x)
        return x * y

# channel attention
class CALayer(nn.Module):
    """Channel attention: scales each channel by a learned gate in (0, 1).

    The input is globally average-pooled to one value per channel, passed
    through a two-layer 1x1-conv bottleneck and a sigmoid, and the resulting
    per-channel gate is multiplied (broadcast over H and W) into the input.
    Expects NCHW input.

    Args:
        channel: number of input (and output) channels.
        reduction: bottleneck reduction factor. The hidden width is
            ``max(1, channel // reduction)`` so that ``channel < reduction``
            does not collapse the bottleneck to zero channels, which would
            make the gate a constant independent of the input.
    """

    def __init__(self, channel, reduction=8):
        super(CALayer, self).__init__()
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Pooled descriptor has spatial size 1x1, so the gate broadcasts over H and W.
        y = self.avg_pool(x)
        y = self.ca(y)
        return x * y

class DehazeBlock(nn.Module):
    """Residual block: conv -> ReLU -> conv -> pixel attention -> channel attention, plus skip.

    Args:
        conv: factory called as ``conv(in_ch, out_ch, kernel_size, bias=True)``,
            e.g. ``default_conv``.
        dim: number of feature channels (input and output).
        kernel_size: convolution kernel size.
    """

    def __init__(self, conv, dim, kernel_size):
        super(DehazeBlock, self).__init__()
        # Stored for introspection; the attention layers below are built
        # from ``dim`` directly.
        self.dim = dim
        self.conv1 = conv(dim, dim, kernel_size, bias=True)
        self.conv2 = conv(dim, dim, kernel_size, bias=True)
        self.act1 = nn.ReLU(inplace=True)
        ## attention layers
        self.pa = PALayer(dim)
        self.ca = CALayer(dim)

    def forward(self, x):
        # The skip connection needs ``conv`` to preserve spatial size
        # (``default_conv`` does for odd ``kernel_size``).
        residual = x
        out = self.conv1(x)
        out = self.act1(out)
        out = self.conv2(out)
        out = self.pa(out)
        out = self.ca(out)
        out = out + residual
        return out
    
class  
