# DehazingNet Model 
The DehazingNet model is trained using the Mean Squared error and Perceptual Loss. 

## Installing Additional Libraries 

In [1]:
!pip install wandb torch-enhance torchmetrics lpips -q

## Importing Libraries

In [2]:
import gc,os
import numpy as np
import pandas as pd 
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
import lpips
import pywt

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
import torchvision.transforms as tt 
from torch_enhance.losses import VGG as PerceptualLoss
from torchmetrics.image import PeakSignalNoiseRatio,StructuralSimilarityIndexMeasure

import wandb
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value = user_secrets.get_secret("WANDB")
wandb.login(key=secret_value)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

## Setting the Configuration
Various hyperparameters are set here which helps to keep the associated settings uniform across the script. To monitor and track the model Weights and Biases has been used. No artifacts are being logged into the server.  

In [3]:
class CFG:
    lr=1e-4
    epochs=15
    train=True
    stats=False
    image_shape=(256,256)
    train_bs=32
    val_bs=4
    test_bs=4
    es_patience=5
    device="cuda" if torch.cuda.is_available() else "cpu"


wandb.init(
    project="CT5129-Image Dehazing",
    config={
    "learning_rate": CFG.lr,
    "architecture": "CNN",
    "dataset": "Image Dehazing Dataset",
    "epochs": CFG.epochs,
     "training_bs":CFG.train_bs,
        "validation_bs":CFG.val_bs,
        "test_bs":CFG.test_bs,
        "device":CFG.device,
        "optimizer":"Adam",
        "es_patience":CFG.es_patience,
    }
)


[34m[1mwandb[0m: Currently logged in as: [33mhemanthh17[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.17.3 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.16.6
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20240629_083327-fss23oih[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mfirm-smoke-87[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/hemanthh17/CT5129-Image%20Dehazing[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/hemanthh17/CT5129-Image%20Dehazing/runs/fss23oih[0m


In [4]:
train_data=pd.read_csv('/kaggle/input/dehazing-dataset-thesis/dehaing_dataset_train.csv')
val_data=pd.read_csv('/kaggle/input/dehazing-dataset-thesis/dehazing_dataset_val.csv')
test_data=pd.read_csv('/kaggle/input/dehazing-dataset-thesis/dehazing_dataset_test.csv')

## Computing the Image Statistics

In this step we compute the channel mean and standard deviation. This will help normalize the image and pass into the model. 

In [5]:
if CFG.stats:
    img_transformer=tt.transforms.Compose([
        tt.transforms.Resize(CFG.image_shape),
        tt.transforms.ToTensor()
    ])
    hazy_img_list=[img_transformer(Image.open(img)).to(CFG.device) for img in train_data.Hazy.values]
    hazy_images_stack=torch.stack(hazy_img_list,dim=1)
    hazy_images_stack=hazy_images_stack.permute(1,0,2,3)
    print("Hazy Images Stack Dimension:",hazy_images_stack.shape)
    print("Mean of the hazy images per channel:",torch.mean(hazy_images_stack,dim=(0,2,3)))
    print("Standard Deviation of the hazy images per channel:",torch.std(hazy_images_stack,dim=(0,2,3)))
    del hazy_img_list,hazy_images_stack
    gc.collect()
    clear_img_list=[img_transformer(Image.open(img)).to(CFG.device) for img in train_data.GT.values]
    clear_images_stack=torch.stack(clear_img_list,dim=1)
    clear_images_stack=clear_images_stack.permute(1,0,2,3)
    print("Clear Images Stack Dimension:",clear_images_stack.shape)
    print("Mean of the hazy images per channel:",torch.mean(clear_images_stack,dim=(0,2,3)))
    print("Standard Deviation of the hazy images per channel:",torch.std(clear_images_stack,dim=(0,2,3)))
    del clear_img_list,clear_images_stack
    gc.collect()

## Creating the Pytorch Dataset

In [6]:
input_transforms=tt.Compose([
    tt.transforms.Resize(CFG.image_shape),
    tt.ToTensor(),
    tt.Normalize(mean=(0.6344,0.5955,0.5857),std=(0.1742,0.1798,0.1871))
])
output_transforms=tt.Compose([
    tt.transforms.Resize(CFG.image_shape),
    tt.ToTensor(),
    tt.Normalize(mean=(0.4556,0.3837,0.3642),std=(0.2689,0.2691,0.2828))
])

In [7]:
class DehazingDataset(Dataset):
    def __init__(self,dataset,in_transforms=None,out_transforms=None):
        self.dataset=dataset
        self.in_transforms=in_transforms
        self.out_transforms=out_transforms
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self,idx):
        hazy_img_path=self.dataset.iloc[idx,1]
        clear_img_path=self.dataset.iloc[idx,0]
        if self.in_transforms:
            hazy_img=self.in_transforms(Image.open(str(hazy_img_path)))
        if self.out_transforms:
            clear_img=self.out_transforms(Image.open(str(clear_img_path)))
        return {'hazy':hazy_img,
               'gt':clear_img}

        

In [8]:
train_dataset=DehazingDataset(train_data,input_transforms,output_transforms)
val_dataset=DehazingDataset(val_data,input_transforms,output_transforms)
test_dataset=DehazingDataset(test_data,input_transforms,output_transforms)

train_loader=DataLoader(train_dataset,batch_size=CFG.train_bs)
val_loader=DataLoader(val_dataset,batch_size=CFG.val_bs)
test_loader=DataLoader(test_dataset,batch_size=CFG.test_bs)

## Model Architecture

In [9]:
class PixelAttention(nn.Module):
    def __init__(self,channel,reduct_ratio=8):
        super(PixelAttention,self).__init__()
        reduced_channel=max(1,channel//reduct_ratio)
        self.pixel_attention=nn.Sequential(
            nn.Conv2d(channel,channel//reduced_channel,kernel_size=1,padding=0,bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel//reduced_channel,1,kernel_size=1,padding=0,bias=True),
            nn.Sigmoid()
        )
    def forward(self,feature):
        x=self.pixel_attention(feature)
        return x*feature

class ChannelAttention(nn.Module):
    def __init__(self,input_channels,reduct_ratio=8):
        super(ChannelAttention,self).__init__()
        reduced_channel=max(1,input_channels//reduct_ratio)
        self.avg_pooler=nn.AdaptiveAvgPool2d(1)
        self.fcn=nn.Sequential(
            nn.Linear(input_channels,reduced_channel),
            nn.ReLU(inplace=True),
            nn.Linear(reduced_channel,input_channels)
        )

    def forward(self,input_feature):
        n,c,_,_=input_feature.size()
        x=self.avg_pooler(input_feature).view(n,c)
        x=F.sigmoid(self.fcn(x).view(n,c,1,1))
        return input_feature*x

class AttentionBlock(nn.Module):
    def __init__(self,dims,kernel_size=1):
        super(AttentionBlock,self).__init__()
        self.conv1=nn.Conv2d(dims,dims,kernel_size,padding=(kernel_size//2),bias=True)
        self.conv2=nn.Conv2d(dims,dims,kernel_size,padding=(kernel_size//2),bias=True)
        self.ca=ChannelAttention(dims)
        self.pa=PixelAttention(dims)
    def forward(self,img):
        feat=F.relu(self.conv1(img),inplace=True)
        feat=feat+img
        feat=F.relu(self.conv1(feat),inplace=True)
        feat=self.ca(feat)
        feat=self.pa(feat)
        feat+=img
        return feat
class DehazingNet(nn.Module):
    def __init__(self):
        super(DehazingNet,self).__init__()
        self.conv1=nn.Conv2d(in_channels=3,out_channels=3,kernel_size=1,stride=1,padding=0)
        self.conv2=nn.Conv2d(in_channels=3,out_channels=3,kernel_size=3,stride=1,padding=1)
        self.attn1=AttentionBlock(3)
        self.conv3=nn.Conv2d(in_channels=6,out_channels=3,kernel_size=5,stride=1,padding=2)
        self.conv4=nn.Conv2d(in_channels=6,out_channels=3,kernel_size=7,stride=1,padding=3)
        self.attn2=AttentionBlock(3)
        self.conv5=nn.Conv2d(in_channels=12,out_channels=3,kernel_size=3,stride=1,padding=1)
        self.conv_dwt=nn.Conv2d(in_channels=6,out_channels=3,kernel_size=3,stride=1,padding=1)
        self.b=1

    def forward(self,x):
        x1=F.relu(self.conv1(x))
        x2=F.relu(self.conv2(x1))
        x2=self.attn1(x2)
        cat1=torch.cat((x1,x2),1)
        x3=F.relu(self.conv3(cat1))
        cat2=torch.cat((x2,x3),1)
        x4=F.relu(self.conv4(cat2))
        x4=self.attn2(x4)
        cat3=torch.cat((x1,x2,x3,x4),1)
        k=F.relu(self.conv5(cat3))
        output=k*x-k+self.b
        return F.relu(output)



## Loss Definition

The loss functions required to train or observe have been defined here.

In [10]:
class TotalVariationLoss(nn.Module):
    def __init__(self,wt=1):
        super(TotalVariationLoss,self).__init__()
        self.wt=wt
    def forward(self,x):
        wid_var=torch.sum(torch.pow(x[:,:,:,:-1]-x[:,:,:,1:],2))
        ht_var=torch.sum(torch.pow(x[:,:,:-1,:]-x[:,:,1:,:],2))
        return self.wt*(ht_var+wid_var)
class FFTLoss(nn.Module):
    def __init__(self):
        super(FFTLoss,self).__init__()
        self.l1_loss=nn.L1Loss()
    def forward(self,out,gt):
        fft_out=torch.fft.fftn(out,dim=(-2,-1))
        fft_gt=torch.fft.fftn(gt,dim=(-2,-1))
        amp_out=torch.abs(fft_out)
        ph_out=torch.angle(fft_out)
        amp_gt=torch.abs(fft_gt)
        ph_gt=torch.angle(fft_gt)
        amp_loss=self.l1_loss(amp_out,amp_gt)
        ph_loss=self.l1_loss(ph_out,ph_gt)
        return amp_loss+ph_loss
    
perceptual_loss=lpips.LPIPS(net='vgg').to(CFG.device)
tv_loss=TotalVariationLoss().to(CFG.device)
mse_loss=nn.MSELoss().to(CFG.device)
fft_loss=FFTLoss().to(CFG.device)



Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:03<00:00, 152MB/s]


Loading model from: /opt/conda/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth


# Training the Model
Here the model is being trained on the image pairs and as an intial step Xavier weight initialisation is being used. The model is being parallelised using the DataParallel function as 2 NVidia T4 GPUs have been used on Kaggle to train.

In [11]:
def weight_init(m):
    if isinstance(m,torch.nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias,0)
dehaze_model=DehazingNet()
dehaze_model.apply(weight_init)
dehaze_model=nn.DataParallel(dehaze_model)
dehaze_model=dehaze_model.to(CFG.device)
print("Model Loaded to GPU..")
optimizer=optim.Adam(dehaze_model.parameters(),lr=CFG.lr,weight_decay=1e-4)
scheduler=optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',factor=0.1,patience=3,verbose=True)


Model Loaded to GPU..


### Training Loop
The training loop iteratively loops through the training data and trains the model based on the loss functions used. To efficiently optimise the memory usage. Garbage Collection and cache clearence is done in a periodic basis.

In [12]:
if CFG.train:
    best_val_loss=float('inf')
    gc.collect()
    torch.cuda.empty_cache()
    dehaze_model.train()
    patience=CFG.es_patience
    for epoch in tqdm(range(CFG.epochs)):
        total_train_loss=0
        total_val_loss=0
        for pair in train_loader:
            hazy,clear=pair['hazy'].to(CFG.device),pair['gt'].to(CFG.device)
            model_out=dehaze_model(hazy)
            loss_tv=tv_loss(model_out)
            loss_mse=mse_loss(model_out,clear)
            loss_perceptual=perceptual_loss(model_out,clear).mean()
            loss_fft=fft_loss(model_out,clear)
            train_loss=(0.5)*(loss_mse+loss_perceptual)
            total_train_loss+=train_loss.item()
            optimizer.zero_grad()
            train_loss.backward()
            torch.nn.utils.clip_grad_norm_(dehaze_model.parameters(),1e-2)
            optimizer.step()
            wandb.log({"Training FFT Loss":loss_fft,"Training TV Loss":loss_tv,"Training MSE Loss":loss_mse,"Training Perceptual Loss":loss_perceptual})
            del hazy,clear,model_out,loss_tv,loss_mse,loss_fft
            torch.cuda.empty_cache()
        training_loss=total_train_loss/len(train_loader.dataset)
        wandb.log({"Average Training Loss":training_loss,"Epoch": epoch+1})

        dehaze_model.eval()
        for val_pair in val_loader:
            val_hazy,val_clear=val_pair['hazy'].to(CFG.device),val_pair['gt'].to(CFG.device)
            val_output=dehaze_model(val_hazy)
            val_loss_mse=mse_loss(val_output,val_clear)
            val_loss_tv=tv_loss(val_output)
            val_loss_perceptual=perceptual_loss(val_output,val_clear).mean()
            val_loss_fft=fft_loss(val_output,val_clear)
            val_loss=(0.5)*(val_loss_mse+val_loss_perceptual)
            total_val_loss+=val_loss.item()
            wandb.log({"Validation FFT Loss":val_loss_fft,"Validation TV Loss":val_loss_tv,"Validation MSE Loss":val_loss_mse,"Validation Perceptual Loss":val_loss_perceptual})
            del val_hazy,val_clear,val_output,val_loss_mse,val_loss_tv,val_loss_fft,val_loss
        validation_loss=total_val_loss/len(val_loader.dataset)
        wandb.log({"Average Validation Loss":validation_loss})
        if (best_val_loss-validation_loss)>1e-4:
            best_val_loss=validation_loss
            print('Saving Model..')
            torch.save(dehaze_model.state_dict(),'dehazer-model-trained-best.pth')    
            patience=CFG.es_patience
        else:
            patience-=1
            print(f'Patience decreased to {patience}..')
            if patience == 0:
                print('Early stopping triggered...')
                break
        
        scheduler.step(validation_loss)
        gc.collect()
        torch.cuda.empty_cache()
    torch.save(dehaze_model.state_dict(),'dehazer-model-trained.pth')
  

            

  0%|          | 0/15 [00:00<?, ?it/s]

Saving Model..


  7%|▋         | 1/15 [17:56<4:11:13, 1076.67s/it]

Saving Model..


 13%|█▎        | 2/15 [33:05<3:31:54, 978.03s/it] 

Saving Model..


 20%|██        | 3/15 [48:11<3:09:00, 945.02s/it]

Saving Model..


 27%|██▋       | 4/15 [1:03:16<2:50:20, 929.12s/it]

Saving Model..


 33%|███▎      | 5/15 [1:18:22<2:33:27, 920.79s/it]

Saving Model..


 40%|████      | 6/15 [1:33:32<2:17:34, 917.22s/it]

Saving Model..


 47%|████▋     | 7/15 [1:48:38<2:01:48, 913.61s/it]

Saving Model..


 53%|█████▎    | 8/15 [2:03:57<1:46:46, 915.17s/it]

Saving Model..


 60%|██████    | 9/15 [2:19:18<1:31:42, 917.03s/it]

Saving Model..


 73%|███████▎  | 11/15 [2:49:54<1:01:10, 917.68s/it]

Patience decreased to 4..


 80%|████████  | 12/15 [3:05:06<45:47, 915.81s/it]  

Patience decreased to 3..
Saving Model..


 87%|████████▋ | 13/15 [3:20:16<30:28, 914.13s/it]

Saving Model..


 93%|█████████▎| 14/15 [3:35:33<15:15, 915.02s/it]

Saving Model..


100%|██████████| 15/15 [3:50:37<00:00, 922.53s/it]
