In [None]:
import os
import time
import glob
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid

try:
    from torchsummary import summary
except:
    !pip install torchsummary
    from torchsummary import summary

In [None]:
# random crop the hazy image and its clear image
def random_crop(gt_image, hazy_image, target_shape = (224, 224)):
    """Cut the same random window out of a ground-truth / hazy image pair.

    Args:
        gt_image: ground-truth array indexed as ``[row, col, ...]``.
        hazy_image: hazy array with the same height and width as ``gt_image``.
        target_shape: ``(height, width)`` of the crop.

    Returns:
        ``(hazy_crop, gt_crop)`` -- note the order is the reverse of the
        arguments. Both are slices of the inputs taken at the same offset.

    Raises:
        AssertionError: if the images are smaller than ``target_shape`` or
            their heights/widths differ.
    """
    crop_h, crop_w = target_shape[0], target_shape[1]
    img_h, img_w = gt_image.shape[0], gt_image.shape[1]
    assert img_h >= crop_h
    assert img_w >= crop_w
    assert img_h == hazy_image.shape[0]
    assert img_w == hazy_image.shape[1]

    # x is a column offset (bounded by the width), y a row offset (bounded by
    # the height). randint's upper bound is exclusive, so +1 makes the last
    # valid offset reachable and allows a crop that covers the whole image.
    x = np.random.randint(0, img_w - crop_w + 1)
    y = np.random.randint(0, img_h - crop_h + 1)

    gt_image = gt_image[y : y + crop_h, x : x + crop_w]
    hazy_image = hazy_image[y : y + crop_h, x : x + crop_w]
    return hazy_image, gt_image


# load image
def load_image(hazy_image_path, gt_image_path, target_shape = (224, 224)):
    """Load a hazy / ground-truth image pair as CHW float32 tensors in [0, 1].

    Both files are decoded as 3-channel RGB and cropped at the same random
    window with ``random_crop``. Pixel values are then divided by 255.

    Args:
        hazy_image_path: path of the hazy input image.
        gt_image_path: path of the matching clear (ground-truth) image.
        target_shape: ``(height, width)`` of the random crop.

    Returns:
        ``(hazy_tensor, gt_tensor)``, each of shape ``(3, height, width)``.
    """
    # `with` releases the file handle PIL holds open after Image.open.
    # convert('RGB') guarantees 3 channels, so grayscale/RGBA files neither
    # break permute(2, 0, 1) below nor the 3-channel network.
    with Image.open(gt_image_path) as img:
        gt_image = np.array(img.convert('RGB'), dtype=np.float32)
    with Image.open(hazy_image_path) as img:
        hazy_image = np.array(img.convert('RGB'), dtype=np.float32)

    # Random Crop the image as suggested in paper
    hazy_image, gt_image = random_crop(gt_image, hazy_image, target_shape=target_shape)

    # Out-of-place division yields fresh contiguous float32 arrays.
    gt_img_tensor = torch.from_numpy(gt_image / 255)
    hazy_img_tensor = torch.from_numpy(hazy_image / 255)

    # HWC -> CHW, the layout nn.Conv2d expects.
    return hazy_img_tensor.permute(2, 0, 1), gt_img_tensor.permute(2, 0, 1)


In [None]:
gt_path = '../input/dehaze/clear_images/'
hazy_path = '../input/dehaze/haze/'


def _image_stem(path):
    """Return the file name of ``path`` without its directory and final extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _separator_prefixes(stem):
    """Yield each prefix of ``stem`` that ends right before a non-alphanumeric
    character (e.g. '_' or '.'), then ``stem`` itself."""
    for i, ch in enumerate(stem):
        if i and not ch.isalnum():
            yield stem[:i]
    yield stem


# Split dataset into train and validation splits
def get_data_splits(gt_images_path, hazy_images_path, train_fraction = 0.9, seed = None):
    """Pair clear and hazy ``.jpg`` images and split the pairs into train / val.

    A hazy image is paired with a clear image when the hazy file name (without
    extension) equals the clear file name, or starts with it followed by a
    non-alphanumeric separator, e.g. ``0001.jpg`` <-> ``0001_0.8_0.2.jpg``.
    One clear image may pair with several hazy images.
    NOTE(review): assumes hazy files are named ``<clear name><separator>...``;
    confirm against the dataset. A plain substring test would also pair
    ``1.jpg`` with ``10_x.jpg`` or ``21_x.jpg``.

    Args:
        gt_images_path: directory of clear images (trailing separator optional).
        hazy_images_path: directory of hazy images (trailing separator optional).
        train_fraction: fraction of pairs (rounded down) used for training.
        seed: if given, shuffle with a private ``RandomState(seed)``; otherwise
            numpy's global RNG is used, so ``np.random.seed`` controls the split.

    Returns:
        dict with keys 'train_gt', 'train_hazy', 'val_gt', 'val_hazy'. Each
        value is a list of paths; index i of a ``*_gt`` list pairs with index i
        of the matching ``*_hazy`` list.

    Raises:
        ValueError: if no clear/hazy pair is found.
    """
    # Sorting removes glob's filesystem-dependent order, so the split depends
    # only on the RNG state.
    gt_image_paths = sorted(glob.glob(os.path.join(glob.escape(gt_images_path), '*.jpg')))
    hazy_image_paths = sorted(glob.glob(os.path.join(glob.escape(hazy_images_path), '*.jpg')))

    # Index clear images by stem so each hazy file is matched with a few dict
    # lookups instead of scanning every clear image.
    gt_by_stem = {_image_stem(p): p for p in gt_image_paths}
    hazy_by_stem = {stem: [] for stem in gt_by_stem}
    for h_path in hazy_image_paths:
        for prefix in _separator_prefixes(_image_stem(h_path)):
            if prefix in hazy_by_stem:
                hazy_by_stem[prefix].append(h_path)

    pairs = [(gt_by_stem[stem], h_path)
             for stem in gt_by_stem
             for h_path in hazy_by_stem[stem]]
    if not pairs:
        raise ValueError(
            f'no clear/hazy .jpg pairs found in {gt_images_path!r} and {hazy_images_path!r}'
        )

    rng = np.random if seed is None else np.random.RandomState(seed)
    rng.shuffle(pairs)

    n_train = int(len(pairs) * train_fraction)
    train_pairs, val_pairs = pairs[:n_train], pairs[n_train:]

    return {
        'train_gt': [gt for gt, _ in train_pairs],
        'train_hazy': [hz for _, hz in train_pairs],
        'val_gt': [gt for gt, _ in val_pairs],
        'val_hazy': [hz for _, hz in val_pairs],
    }



In [None]:
# Custom Dataset with Hazy and Ground Truth Images
class CustomDataset(Dataset):
    """Map-style dataset yielding ``(hazy, gt)`` image-tensor pairs.

    ``hazy_image_paths[i]`` and ``gt_image_paths[i]`` must refer to the same
    scene. Items are produced by ``load_image``, which takes a new random crop
    each time an index is read.

    Raises:
        ValueError: if the two path lists have different lengths.
    """

    def __init__(self, hazy_image_paths, gt_image_paths):
        # A length mismatch would otherwise silently drop hazy images or
        # raise IndexError in the middle of an epoch.
        if len(hazy_image_paths) != len(gt_image_paths):
            raise ValueError(
                f'hazy/gt path lists differ in length: '
                f'{len(hazy_image_paths)} != {len(gt_image_paths)}'
            )
        self.gt_image_paths = gt_image_paths
        self.hazy_image_paths = hazy_image_paths

    def __getitem__(self, index):
        return load_image(self.hazy_image_paths[index], self.gt_image_paths[index])

    def __len__(self):
        return len(self.gt_image_paths)

In [None]:
batch_size = 35
seed = 42
# Seed the global numpy and torch RNGs. They drive the train/val split
# shuffle, the random crops and the weight initialisation.
np.random.seed(seed)
torch.manual_seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

data_splits = get_data_splits(gt_path, hazy_path)
train_dataset = CustomDataset(data_splits['train_hazy'], data_splits['train_gt'])
val_dataset = CustomDataset(data_splits['val_hazy'], data_splits['val_gt'])

# Training and Validation Data Loaders.
# Batch order does not change the validation loss, so the validation loader
# is not shuffled; the same scenes are previewed every epoch (crops still vary).
train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
val_dataloader = DataLoader(val_dataset, batch_size = batch_size, shuffle = False)

In [None]:
class ResidualBlock(nn.Module):
    """Four stacked residual units operating at a constant channel width.

    The units contain 3, 3, 4 and 5 convolutions (``block_size + 1`` each,
    built by ``get_res_block``). Each unit's output is added to its input.
    The first three sums pass through ReLU; the last sum is returned as-is.
    All convs are 3x3 with padding 1, so spatial size and channel count are
    preserved.

    Args:
        channels: number of input/output channels. The default of 128 equals
            ``hidden_dim * 2`` for the default ``GMAN(hidden_dim=64)``.
    """

    def __init__(self, channels = 128):
        super(ResidualBlock, self).__init__()
        # Attribute names (and hence state_dict keys such as 'res_b1.0.weight')
        # are kept so previously saved checkpoints still load.
        self.res_b1 = self.get_res_block(2, channels, channels)
        self.res_b2 = self.get_res_block(2, channels, channels)
        self.res_b3 = self.get_res_block(3, channels, channels)
        self.res_b4 = self.get_res_block(4, channels, channels)
        self.relu = nn.ReLU(inplace = True)

    def get_res_block(self, block_size = 1, in_dim = 128, out_dim = 128):
        """Build ``block_size + 1`` 3x3 convs (padding 1) with a ReLU between
        consecutive convs and none after the last one.

        The first conv maps ``in_dim -> out_dim`` and the rest ``out_dim -> out_dim``.
        """
        layers = []
        for i in range(block_size + 1):
            # Only the first conv sees `in_dim` channels; later convs consume
            # the previous conv's `out_dim` outputs.
            layers.append(nn.Conv2d(in_dim if i == 0 else out_dim, out_dim, 3, padding = 1))
            if i != block_size:
                layers.append(nn.ReLU(inplace = True))
        return nn.Sequential(*layers)

    def forward(self, image):
        """Apply the four residual units to a feature map with ``channels`` channels."""
        output = self.res_b1(image)
        res_b1_image = self.relu(image + output)

        output = self.res_b2(res_b1_image)
        res_b2_image = self.relu(res_b1_image + output)

        output = self.res_b3(res_b2_image)
        res_b3_image = self.relu(res_b2_image + output)

        # No ReLU after the last skip connection.
        output = self.res_b4(res_b3_image)
        res_b4_image = res_b3_image + output

        return res_b4_image

In [None]:
class GMAN(nn.Module):
    """Encoder / residual-bottleneck / decoder dehazing network with a global skip.

    The encoder applies two stride-1 and two stride-2 3x3 convs, reaching
    ``hidden_dim * 2`` channels at 1/4 resolution. The bottleneck is
    ``ResidualBlock()`` followed by ReLU. The decoder applies two kernel-2 /
    stride-2 transposed convs and then two 3x3 convs back to ``in_dim``
    channels. ``forward`` returns ``relu(image + residual)``.

    Shape notes, from the layer parameters:
      * each stride-2 conv maps a side of length L to ceil(L / 2), and each
        transposed conv doubles it, so the residual only matches the input
        when H and W are divisible by 4;
      * ``ResidualBlock()`` is built with its default width (128 channels),
        so ``hidden_dim`` must be 64 for the shapes to line up.
    """

    def __init__(self, in_dim = 3, hidden_dim = 64):
        super().__init__()
        wide_dim = hidden_dim * 2
        self.relu = nn.ReLU(inplace = True)

        # Sequential indices define the state_dict keys ('gman.0.weight', ...),
        # so the layer order must stay exactly as listed.
        encoder = [
            nn.Conv2d(in_dim, hidden_dim, 3, padding = 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding = 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(hidden_dim, wide_dim, 3, padding = 1, stride = 2),
            nn.ReLU(inplace = True),
            nn.Conv2d(wide_dim, wide_dim, 3, padding = 1, stride = 2),
            nn.ReLU(inplace = True),
        ]
        bottleneck = [
            ResidualBlock(),
            nn.ReLU(inplace = True),
        ]
        decoder = [
            nn.ConvTranspose2d(wide_dim, hidden_dim, 2, stride = 2),
            nn.ConvTranspose2d(hidden_dim, hidden_dim, 2, stride = 2),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding = 1),
            nn.Conv2d(hidden_dim, in_dim, 3, padding = 1),
        ]
        self.gman = nn.Sequential(*encoder, *bottleneck, *decoder)

    def forward(self, image):
        # Global residual: the network predicts a correction added to the input.
        residual = self.gman(image)
        return self.relu(image + residual)
    
        

In [None]:
def init_weights(m):
    """Re-initialise conv and transposed-conv layers in place (use with ``net.apply``).

    Weights are drawn from a normal distribution with mean 0 and std 0.008,
    and biases are set to 0.01. Layers built with ``bias=False`` only get
    their weights initialised. Other module types are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.008)
        # bias is None for layers created with bias=False
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)

In [None]:
# Build the model, move it to the target device, then re-initialise its conv
# layers there (the last expression's value is shown as the cell output).
net = GMAN()
net = net.to(device)
net.apply(init_weights)

In [None]:
def show_image(hazy_image, gt_image, predicted_image):
    """Plot a hazy input, its ground truth and the prediction side by side.

    Args:
        hazy_image, gt_image, predicted_image: single CHW image tensors on any
            device. Values are clamped to [0, 1] for display, because network
            outputs are not guaranteed to stay <= 1.
    """
    titles = ['Hazy Image', 'Ground Truth Image', 'Predicted']
    images = [hazy_image, gt_image, predicted_image]

    fig, axes = plt.subplots(1, 3, figsize=(15, 15))
    for ax, title, img in zip(axes, titles, images):
        # detach() drops any autograd history; permute converts CHW -> HWC for imshow.
        array = img.detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy()
        ax.imshow(array)
        ax.set_title(title)
        ax.axis('off')

    plt.show()
    # Release the figure so repeated calls during training don't accumulate figures.
    plt.close(fig)
    

In [None]:
n_epochs = 10
val_every = 2  # validate after every 2nd epoch (1-based epochs 2, 4, ...)
opt = torch.optim.Adam(net.parameters(), lr = 0.0001)
criterion = nn.MSELoss()


for epoch in range(n_epochs):
    start_time = time.time()
    print(f'Epoch {epoch + 1} started...')

    # ---- training pass ----
    net.train()
    total_train_loss = 0.0
    n_train_batches = 0
    for hazy_images, gt_images in train_dataloader:
        hazy_images = hazy_images.to(device)
        gt_images = gt_images.to(device)

        outputs = net(hazy_images)
        train_loss = criterion(outputs, gt_images)

        opt.zero_grad()
        train_loss.backward()
        opt.step()

        total_train_loss += train_loss.item()
        n_train_batches += 1

    # MSELoss (default reduction='mean') averages within a batch. Dividing by
    # the batch count makes train and validation numbers comparable.
    mean_train_loss = total_train_loss / max(n_train_batches, 1)
    print(f'Total train loss: {total_train_loss} (mean per batch: {mean_train_loss:.6f})')

    # ---- validation pass ----
    if (epoch + 1) % val_every == 0:
        net.eval()
        total_val_loss = 0.0
        n_val_batches = 0
        with torch.no_grad():
            for batch_idx, (hazy_images, gt_images) in enumerate(val_dataloader):
                hazy_images = hazy_images.to(device)
                gt_images = gt_images.to(device)
                outputs = net(hazy_images)

                # Preview one sample per validation pass rather than one per batch.
                if batch_idx == 0:
                    show_image(hazy_images[0], gt_images[0], outputs[0])

                total_val_loss += criterion(outputs, gt_images).item()
                n_val_batches += 1

        mean_val_loss = total_val_loss / max(n_val_batches, 1)
        print(f'Total validation loss: {total_val_loss} (mean per batch: {mean_val_loss:.6f})')

    end_time = time.time()
    print(f'Epoch {epoch + 1} ended, time taken: {end_time - start_time}s')
        
        

In [None]:
# Persist only the learned parameters (state_dict), not the module object.
model_state = net.state_dict()
torch.save(model_state, 'state_dict_model.pt')