# Model Testing Script

This script is used to evaluate models at various stages. At each stage, the PSNR and SSIM functions are used to evaluate the output at each stage.

## Importing Libraries

In [1]:
!pip install torch-enhance torchmetrics lpips -q

In [2]:
import gc,os,cv2
from glob import glob
import numpy as np
import pandas as pd 
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
import lpips
import pywt
import shutil,time

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
import torchvision.transforms as tt 
from torch_enhance.losses import VGG as PerceptualLoss
from torchmetrics.image import PeakSignalNoiseRatio,StructuralSimilarityIndexMeasure

## Dataset Preparation

In [3]:
test_sample=pd.read_csv('/kaggle/input/dehazing-dataset-thesis/dehazing_dataset_sample_test.csv')
test_data=pd.read_csv('/kaggle/input/dehazing-dataset-thesis/dehazing_dataset_test.csv')

In [4]:
class DehazingDataset_RGB(Dataset):
    def __init__(self,dataset,in_transforms=None,out_transforms=None):
        self.dataset=dataset
        self.in_transforms=in_transforms
        self.out_transforms=out_transforms
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self,idx):
        hazy_img_path=self.dataset.iloc[idx,1]
        clear_img_path=self.dataset.iloc[idx,0]
        if self.in_transforms:
            hazy_img=self.in_transforms(Image.open(str(hazy_img_path)))
        if self.out_transforms:
            clear_img=self.out_transforms(Image.open(str(clear_img_path)))
        return {'hazy':hazy_img,
               'gt':clear_img}

    
    
class DehazingDataset_YCBCR(Dataset):
    def __init__(self,dataset,in_transforms=None,out_transforms=None):
        self.dataset=dataset
        self.in_transforms=in_transforms
        self.out_transforms=out_transforms
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self,idx):
        hazy_img_path=self.dataset.iloc[idx,1]
        clear_img_path=self.dataset.iloc[idx,0]
        if self.in_transforms:
            hazy_img=self.in_transforms(Image.open(str(hazy_img_path)).convert('YCbCr'))
        if self.out_transforms:
            clear_img=self.out_transforms(Image.open(str(clear_img_path)).convert('YCbCr'))
        return {'hazy':hazy_img,
               'gt':clear_img} 



In [5]:
input_transforms_rgb=tt.Compose([
    tt.transforms.Resize((256,256)),
    tt.ToTensor(),
    tt.Normalize(mean=(0.6344,0.5955,0.5857),std=(0.1742,0.1798,0.1871))
])
output_transforms_rgb=tt.Compose([
    tt.transforms.Resize((256,256)),
    tt.ToTensor(),
    tt.Normalize(mean=(0.4556,0.3837,0.3642),std=(0.2689,0.2691,0.2828))
])

test_dataset_rgb=DehazingDataset_RGB(test_data,input_transforms_rgb,output_transforms_rgb)
test_sample_dataset_rgb=DehazingDataset_RGB(test_sample,input_transforms_rgb,output_transforms_rgb)


In [6]:
input_transforms_ycbcr=tt.Compose([
    tt.transforms.Resize((256,256)),
    tt.ToTensor(),
    tt.Normalize(mean=(0.6041,0.4889,0.5205),std=(0.1769,0.0279,0.0251))
])
output_transforms_ycbcr=tt.Compose([
    tt.transforms.Resize((256,256)),
    tt.ToTensor(),
    tt.Normalize(mean=(0.4011,0.4784,0.5378),std=(0.2667,0.0479,0.0414))
])

test_dataset_ycbcr=DehazingDataset_YCBCR(test_data,input_transforms_ycbcr,output_transforms_ycbcr)
test_sample_dataset_ycbcr=DehazingDataset_YCBCR(test_sample,input_transforms_ycbcr,output_transforms_ycbcr)


In [7]:
test_sample_loader_rgb=DataLoader(test_sample_dataset_rgb,batch_size=1)
test_loader_rgb=DataLoader(test_dataset_rgb,batch_size=1)
test_sample_loader_ycbcr=DataLoader(test_sample_dataset_ycbcr,batch_size=1)
test_loader_ycbcr=DataLoader(test_dataset_ycbcr,batch_size=1)

## Model Initilaisations

In [8]:
class PixelAttention(nn.Module):
    def __init__(self,channel,reduct_ratio=8):
        super(PixelAttention,self).__init__()
        reduced_channel=max(1,channel//reduct_ratio)
        self.pixel_attention=nn.Sequential(
            nn.Conv2d(channel,channel//reduced_channel,kernel_size=1,padding=0,bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel//reduced_channel,1,kernel_size=1,padding=0,bias=True),
            nn.Sigmoid()
        )
    def forward(self,feature):
        x=self.pixel_attention(feature)
        return x*feature

class ChannelAttention(nn.Module):
    def __init__(self,input_channels,reduct_ratio=8):
        super(ChannelAttention,self).__init__()
        reduced_channel=max(1,input_channels//reduct_ratio)
        self.avg_pooler=nn.AdaptiveAvgPool2d(1)
        self.fcn=nn.Sequential(
            nn.Linear(input_channels,reduced_channel),
            nn.ReLU(inplace=True),
            nn.Linear(reduced_channel,input_channels)
        )

    def forward(self,input_feature):
        n,c,_,_=input_feature.size()
        x=self.avg_pooler(input_feature).view(n,c)
        x=F.sigmoid(self.fcn(x).view(n,c,1,1))
        return input_feature*x

class AttentionBlock(nn.Module):
    def __init__(self,dims,kernel_size=1):
        super(AttentionBlock,self).__init__()
        self.conv1=nn.Conv2d(dims,dims,kernel_size,padding=(kernel_size//2),bias=True)
        self.conv2=nn.Conv2d(dims,dims,kernel_size,padding=(kernel_size//2),bias=True)
        self.ca=ChannelAttention(dims)
        self.pa=PixelAttention(dims)
    def forward(self,img):
        feat=F.relu(self.conv1(img),inplace=True)
        feat=feat+img
        feat=F.relu(self.conv1(feat),inplace=True)
        feat=self.ca(feat)
        feat=self.pa(feat)
        feat+=img
        return feat
class DWT_DehazingNet(nn.Module):
    def __init__(self):
        super(DWT_DehazingNet,self).__init__()
        self.conv1=nn.Conv2d(in_channels=3,out_channels=3,kernel_size=1,stride=1,padding=0)
        self.conv2=nn.Conv2d(in_channels=3,out_channels=3,kernel_size=3,stride=1,padding=1)
        self.attn1=AttentionBlock(3)
        self.conv3=nn.Conv2d(in_channels=9,out_channels=3,kernel_size=5,stride=1,padding=2)
        self.conv4=nn.Conv2d(in_channels=6,out_channels=3,kernel_size=7,stride=1,padding=3)
        self.attn2=AttentionBlock(3)
        self.conv5=nn.Conv2d(in_channels=15,out_channels=3,kernel_size=3,stride=1,padding=1)
        self.conv_dwt=nn.Conv2d(in_channels=12,out_channels=3,kernel_size=3,stride=1,padding=1)
        self.b=1

    def forward(self,x):
        dwt_coeffs=pywt.dwt2(x.cpu(),wavelet='db4')
        LL,(LH,HL,HH)=dwt_coeffs
        dwt_out=torch.concat([torch.from_numpy(LL),torch.from_numpy(LH),torch.from_numpy(HL),torch.from_numpy(HH)],dim=1)
        x1=F.relu(self.conv1(x))
        dwt_out=tt.Resize((256,256))(dwt_out)
        dwt_in=self.conv_dwt(dwt_out)
        x2=F.relu(self.conv2(x1))
        x2=self.attn1(x2)
        cat1=torch.cat((x1,x2,dwt_in),1)
        x3=F.relu(self.conv3(cat1))
        cat2=torch.cat((x2,x3),1)
        x4=F.relu(self.conv4(cat2))
        x4=self.attn2(x4)
        cat3=torch.cat((x1,x2,x3,x4,dwt_in),1)
        k=F.relu(self.conv5(cat3))
        return F.relu(k*x-k+self.b)
class DehazingNet(nn.Module):
    def __init__(self):
        super(DehazingNet,self).__init__()
        self.conv1=nn.Conv2d(in_channels=3,out_channels=3,kernel_size=1,stride=1,padding=0)
        self.conv2=nn.Conv2d(in_channels=3,out_channels=3,kernel_size=3,stride=1,padding=1)
        self.attn1=AttentionBlock(3)
        self.conv3=nn.Conv2d(in_channels=6,out_channels=3,kernel_size=5,stride=1,padding=2)
        self.conv4=nn.Conv2d(in_channels=6,out_channels=3,kernel_size=7,stride=1,padding=3)
        self.attn2=AttentionBlock(3)
        self.conv5=nn.Conv2d(in_channels=12,out_channels=3,kernel_size=3,stride=1,padding=1)
        self.conv_dwt=nn.Conv2d(in_channels=6,out_channels=3,kernel_size=3,stride=1,padding=1)
        self.b=1

    def forward(self,x):
        x1=F.relu(self.conv1(x))
        x2=F.relu(self.conv2(x1))
        x2=self.attn1(x2)
        cat1=torch.cat((x1,x2),1)
        x3=F.relu(self.conv3(cat1))
        cat2=torch.cat((x2,x3),1)
        x4=F.relu(self.conv4(cat2))
        x4=self.attn2(x4)
        cat3=torch.cat((x1,x2,x3,x4),1)
        k=F.relu(self.conv5(cat3))
        return F.relu(k*x-k+self.b)
class AODnet(nn.Module):
    def __init__(self):
        super(AODnet,self).__init__()
        self.conv1=nn.Conv2d(in_channels=3,out_channels=3,kernel_size=1,stride=1,padding=0)
        self.conv2=nn.Conv2d(in_channels=3,out_channels=3,kernel_size=3,stride=1,padding=1)
        self.conv3=nn.Conv2d(in_channels=6,out_channels=3,kernel_size=5,stride=1,padding=2)
        self.conv4=nn.Conv2d(in_channels=6,out_channels=3,kernel_size=7,stride=1,padding=3)
        self.conv5=nn.Conv2d(in_channels=12,out_channels=3,kernel_size=3,stride=1,padding=1)
        self.b=1

    def forward(self, x):
        x1=F.relu(self.conv1(x))
        x2=F.relu(self.conv2(x1))
        cat1=torch.cat((x1,x2),1)
        x3=F.relu(self.conv3(cat1))
        cat2=torch.cat((x2,x3),1)
        x4=F.relu(self.conv4(cat2))
        cat3=torch.cat((x1,x2,x3,x4),1)
        k=F.relu(self.conv5(cat3))
        return F.relu(k*x-k+self.b)

class AOD_net_pretrained(nn.Module):
	def __init__(self):
		super(AOD_net_pretrained, self).__init__()
		self.relu=nn.ReLU(inplace=True)
		self.e_conv1=nn.Conv2d(3,3,1,1,0,bias=True) 
		self.e_conv2=nn.Conv2d(3,3,3,1,1,bias=True) 
		self.e_conv3=nn.Conv2d(6,3,5,1,2,bias=True) 
		self.e_conv4=nn.Conv2d(6,3,7,1,3,bias=True) 
		self.e_conv5=nn.Conv2d(12,3,3,1,1,bias=True) 		
	def forward(self,x):
		source=[]
		source.append(x)
		x1=self.relu(self.e_conv1(x))
		x2=self.relu(self.e_conv2(x1))
		concat1=torch.cat((x1,x2),1)
		x3=self.relu(self.e_conv3(concat1))
		concat2=torch.cat((x2,x3),1)
		x4=self.relu(self.e_conv4(concat2))
		concat3=torch.cat((x1,x2,x3,x4),1)
		x5=self.relu(self.e_conv5(concat3))
		return self.relu((x5*x)-x5+1) 


In [9]:
dwt_dehazenet_rgb=nn.DataParallel(DWT_DehazingNet())
dwt_dehazenet_rgb_new=nn.DataParallel(DWT_DehazingNet())
dwt_dehazenet_ycbcr_3l=nn.DataParallel(DWT_DehazingNet())
dwt_dehazenet_ycbcr=nn.DataParallel(DWT_DehazingNet())
dehazenet_ycbcr=nn.DataParallel(DehazingNet())
aod_ycbcr_l2=nn.DataParallel(AODnet())
aod_ycbcr_mse=nn.DataParallel(AODnet())
aod_ycbcr_3l=nn.DataParallel(AODnet())
aod_pretrained=AOD_net_pretrained()
dehazenet_rgb=nn.DataParallel(DehazingNet())
dehazenet_rgb_fft=nn.DataParallel(DehazingNet())
dwt_dehazenet_rgb_fft=nn.DataParallel(DWT_DehazingNet())


dwt_dehazenet_ycbcr_3l.load_state_dict(torch.load(r'/kaggle/input/dehazing-models-ct5129/dwt-dehazenet-ycbcr-3l.pth',map_location=torch.device('cpu')))
dwt_dehazenet_rgb_fft.load_state_dict(torch.load(r'/kaggle/input/dehazing-models-ct5129/dwt-dehazenet-fft-rgb.pth',map_location=torch.device('cpu')))
dwt_dehazenet_rgb.load_state_dict(torch.load(r'/kaggle/input/dehazing-models-ct5129/dehazing-rgb-dwt-2l.pth',map_location=torch.device('cpu')))
dwt_dehazenet_ycbcr.load_state_dict(torch.load(r'/kaggle/input/dehazing-models-ct5129/dehazing-ycbcr-dwt-2l.pth',map_location=torch.device('cpu')))
dehazenet_ycbcr.load_state_dict(torch.load(r'/kaggle/input/dehazing-models-ct5129/dehazing-ycbcr-2l.pth',map_location=torch.device('cpu')))
aod_ycbcr_l2.load_state_dict(torch.load('/kaggle/input/dehazing-models-ct5129/aodnet-ycbcr-2l.pth',map_location=torch.device('cpu')))
aod_ycbcr_3l.load_state_dict(torch.load('/kaggle/input/dehazing-models-ct5129/aodnet-ycbcr-3l.pth',map_location=torch.device('cpu')))
aod_ycbcr_mse.load_state_dict(torch.load('/kaggle/input/dehazing-models-ct5129/aodnet-ycbcr-mse.pth',map_location=torch.device('cpu')))
aod_pretrained.load_state_dict(torch.load('/kaggle/input/dehazing-models-ct5129/aodnet-pretrained.pth',map_location=torch.device('cpu')))
dehazenet_rgb.load_state_dict(torch.load(r'/kaggle/input/dehazing-models-ct5129/dehazenet-rgb-2l.pth',map_location=torch.device('cpu')))
dehazenet_rgb_fft.load_state_dict(torch.load(r'/kaggle/input/dehazing-models-ct5129/dehazenet-mse_per_fft-rgb.pth',map_location=torch.device('cpu')))


<All keys matched successfully>

In [10]:
class PixelAttention(nn.Module):
    def __init__(self,channel,reduct_ratio=8):
        super(PixelAttention,self).__init__()
        reduced_channel=max(1,channel//reduct_ratio)
        self.pixel_attention=nn.Sequential(
            nn.Conv2d(channel,channel//reduced_channel,kernel_size=1,padding=0,bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel//reduced_channel,1,kernel_size=1,padding=0,bias=True),
            nn.Sigmoid()
        )
    def forward(self,feature):
        x=self.pixel_attention(feature)
        return x*feature

class ChannelAttention(nn.Module):
    def __init__(self,input_channels,reduct_ratio=8):
        super(ChannelAttention,self).__init__()
        reduced_channel=max(1,input_channels//reduct_ratio)
        self.avg_pooler=nn.AdaptiveAvgPool2d(1)
        self.fcn=nn.Sequential(
            nn.Linear(input_channels,reduced_channel),
            nn.ReLU(inplace=True),
            nn.Linear(reduced_channel,input_channels)
        )

    def forward(self,input_feature):
        n,c,_,_=input_feature.size()
        x=self.avg_pooler(input_feature).view(n,c)
        x=F.sigmoid(self.fcn(x).view(n,c,1,1))
        return input_feature*x

class AttentionBlock(nn.Module):
    def __init__(self,dims,kernel_size=1):
        super(AttentionBlock,self).__init__()
        self.conv1=nn.Conv2d(dims,dims,kernel_size,padding=(kernel_size//2),bias=True)
        self.conv2=nn.Conv2d(dims,dims,kernel_size,padding=(kernel_size//2),bias=True)
        self.ca=ChannelAttention(dims)
        self.pa=PixelAttention(dims)
    def forward(self,img):
        feat=F.relu(self.conv1(img),inplace=True)
        feat=feat+img
        feat=F.relu(self.conv1(feat),inplace=True)
        feat=self.ca(feat)
        feat=self.pa(feat)
        feat+=img
        return feat
class DehazingNet(nn.Module):
    def __init__(self):
        super(DehazingNet, self).__init__()
        self.conv1=nn.Conv2d(in_channels=3,out_channels=3,kernel_size=1,stride=1,padding=0)
        self.conv2=nn.Conv2d(in_channels=3,out_channels=3,kernel_size=3,stride=1,padding=1)
        self.attn1=AttentionBlock(3)
        self.conv3=nn.Conv2d(in_channels=6,out_channels=3,kernel_size=5,stride=1,padding=2)
        self.conv4=nn.Conv2d(in_channels=6,out_channels=3,kernel_size=7,stride=1,padding=3)
        self.attn2=AttentionBlock(3)
        self.conv5=nn.Conv2d(in_channels=12,out_channels=3, kernel_size=3, stride=1,padding=1)
        self.b=1
        self.pono=PONO(affine=False)
        self.ms=MS()

    def forward(self, x):
        x1=F.relu(self.conv1(x))
        x2=self.attn1(F.relu(self.conv2(x1)))
        cat1=torch.cat((x1,x2),1)
        x1,mean1,std1=self.pono(x1)
        x2, mean2,std2=self.pono(x2)
        x3=F.relu(self.conv3(cat1))
        cat2=torch.cat((x2,x3),1)
        x3=self.ms(x3,mean1,std1)
        x4=self.attn2(F.relu(self.conv4(cat2)))
        x4=self.ms(x4,mean2,std2)
        cat3=torch.cat((x1,x2,x3,x4),1)
        k=F.relu(self.conv5(cat3))
        return F.relu(k*x-k+self.b)

class PONO(nn.Module):
    def __init__(self,input_size=None,stats=False,affine=True,eps=1e-5):
        super(PONO, self).__init__()
        self.return_stats=stats
        self.input_size=input_size
        self.eps=eps
        self.affine=affine
        if affine:
            self.beta=nn.Parameter(torch.zeros(1,1,*input_size))
            self.gamma=nn.Parameter(torch.ones(1,1,*input_size))
        else:
            self.beta,self.gamma=None,None
    def forward(self, x):
        mean=x.mean(dim=1,keepdim=True)
        std=(x.var(dim=1,keepdim=True)+self.eps).sqrt()
        x=(x-mean)/std
        if self.affine:
            x=x*self.gamma+self.beta
        return x,mean,std

class MS(nn.Module):
    def __init__(self,beta=None,gamma=None):
        super(MS,self).__init__()
        self.gamma,self.beta=gamma,beta
    def forward(self,x,beta=None,gamma=None):
        beta=self.beta if beta is None else beta
        gamma=self.gamma if gamma is None else gamma
        if gamma is not None:
            x*=gamma
        if beta is not None:
            x+=beta
        return x
dehazenet_pono_fft=nn.DataParallel(DehazingNet())
dehazenet_pono_fft.load_state_dict(torch.load(r'/kaggle/input/dehazing-models-ct5129/dehazenet-pono-rgb-fft.pth',map_location=torch.device('cpu')))


<All keys matched successfully>

In [11]:
class PixelAttention(nn.Module):
    def __init__(self,channel,reduct_ratio=8):
        super(PixelAttention,self).__init__()
        self.pixel_attention=nn.Sequential(
            nn.Conv2d(channel,channel//reduct_ratio,kernel_size=1,padding=0,bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel//reduct_ratio,1,kernel_size=1,padding=0,bias=True),
            nn.Sigmoid()
        )
    def forward(self,feature):
        x=self.pixel_attention(feature)
        return x*feature

class ChannelAttention(nn.Module):
    def __init__(self,input_channels,reduct_ratio=8):
        super(ChannelAttention,self).__init__()
        self.avg_pooler=nn.AdaptiveAvgPool2d(1)
        self.fcn=nn.Sequential(
            nn.Linear(input_channels,input_channels//reduct_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(input_channels//reduct_ratio,input_channels)
        )

    def forward(self,input_feature):
        n,c,_,_=input_feature.size()
        x=self.avg_pooler(input_feature).view(n,c)
        x=F.sigmoid(self.fcn(x).view(n,c,1,1))
        return input_feature*x

class AttentionBlock(nn.Module):
    def __init__(self,dims,kernel_size=1):
        super(AttentionBlock,self).__init__()
        self.conv1=nn.Conv2d(dims,dims,kernel_size,padding=(kernel_size//2),bias=True)
        self.conv2=nn.Conv2d(dims,dims,kernel_size,padding=(kernel_size//2),bias=True)
        self.ca=ChannelAttention(dims)
        self.pa=PixelAttention(dims)
    def forward(self,img):
        feat=F.relu(self.conv1(img),inplace=True)
        feat=feat+img
        feat=F.relu(self.conv1(feat),inplace=True)
        feat=self.ca(feat)
        feat=self.pa(feat)
        feat+=img
        return feat

class ComplexNet_RGB(nn.Module):
    def __init__(self,input_channels=3,output_channels=3,features=128):
        super(ComplexNet_RGB,self).__init__()
        self.encoder_c1=nn.Conv2d(input_channels,features,kernel_size=3,padding=1,padding_mode='reflect',stride=2)
        self.encoder_c2=nn.Conv2d(features+(output_channels*4),features*2,kernel_size=3,padding=1,padding_mode='reflect',stride=2)
        self.enc_chan_attn1=AttentionBlock(features*2)
        self.encoder_n1=nn.InstanceNorm2d(features*2,affine=True)
        self.encoder_c3=nn.Conv2d(features*2,features*4,kernel_size=3,padding=1,stride=2)
        self.encoder_c4=nn.Conv2d(features*4,features*8,kernel_size=3,padding=1,stride=2)
        self.enc_chan_attn2=AttentionBlock(features*8)
        self.encoder_n2=nn.InstanceNorm2d(features*8,affine=True)

        self.decoder_c1=nn.ConvTranspose2d(features*8,features*4,kernel_size=4,stride=2,padding=1)
        self.dec_chan_attn1=AttentionBlock(features*4)
        self.decoder_c2=nn.ConvTranspose2d(features*4,features*2,kernel_size=4,stride=2,padding=1)
        self.decoder_c3=nn.ConvTranspose2d(features*2,features,kernel_size=4,stride=2,padding=1)
        self.decoder_c4=nn.ConvTranspose2d(features+(output_channels*4),features,kernel_size=4,stride=2,padding=1)
        self.dec_chan_attn2=AttentionBlock(features)
        self.decoder_c5=nn.Conv2d(features,features,kernel_size=3,padding=1)
        self.decoder_c6=nn.Conv2d(features,output_channels,kernel_size=3,padding=1)


    def forward(self, hazy):
        dwt_coeffs=pywt.dwt2(hazy.cpu(),wavelet='db4')
        LL,(LH,HL,HH)=dwt_coeffs
        dwt_out=torch.concat([torch.from_numpy(LL),torch.from_numpy(LH),torch.from_numpy(HL),torch.from_numpy(HH)],dim=1)
        dwt_out=tt.Resize((128,128))(dwt_out)
        x1=F.relu(self.encoder_c1(hazy),inplace=True)
        x1=torch.concat([x1,dwt_out],dim=1)
        x2=self.encoder_n1(self.enc_chan_attn1(F.relu(self.encoder_c2(x1),inplace=True)))
        x3=F.relu(self.encoder_c3(x2),inplace=True)
        x4=self.encoder_n2(self.enc_chan_attn2(F.relu(self.encoder_c4(x3),inplace=True)))

        x5=F.relu(self.dec_chan_attn1(F.relu(self.decoder_c1(x4),inplace=True)+x3),inplace=True)
        x6=F.relu(self.decoder_c2(x5),inplace=True)+x2
        x7=F.relu(self.decoder_c3(x6),inplace=True)
        x7=torch.concat([x7,dwt_out],dim=1)
        x8=F.relu(self.decoder_c4(x7),inplace=True)
        x9=F.relu(self.decoder_c5(x8),inplace=True)
        x10=F.relu(self.decoder_c6(x9),inplace=True)
        return x10
complexnet_rgb=nn.DataParallel(ComplexNet_RGB())
complexnet_rgb.load_state_dict(torch.load(r'/kaggle/input/dehazing-models-ct5129/complexnet-rgb-3l.pth',map_location=torch.device('cpu')))


<All keys matched successfully>

In [12]:
class PixelAttention(nn.Module):
    def __init__(self,channel,reduct_ratio=8):
        super(PixelAttention,self).__init__()
        self.pixel_attention=nn.Sequential(
            nn.Conv2d(channel,channel//reduct_ratio,kernel_size=1,padding=0,bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel//reduct_ratio,1,kernel_size=1,padding=0,bias=True),
            nn.Sigmoid()
        )
    def forward(self,feature):
        x=self.pixel_attention(feature)
        return x*feature

class ChannelAttention(nn.Module):
    def __init__(self,input_channels,reduct_ratio=8):
        super(ChannelAttention,self).__init__()
        self.avg_pooler=nn.AdaptiveAvgPool2d(1)
        self.fcn=nn.Sequential(
            nn.Linear(input_channels,input_channels//reduct_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(input_channels//reduct_ratio,input_channels)
        )

    def forward(self,input_feature):
        n,c,_,_=input_feature.size()
        x=self.avg_pooler(input_feature).view(n,c)
        x=F.sigmoid(self.fcn(x).view(n,c,1,1))
        return input_feature*x

class AttentionBlock(nn.Module):
    def __init__(self,dims,kernel_size=1):
        super(AttentionBlock,self).__init__()
        self.conv1=nn.Conv2d(dims,dims,kernel_size,padding=(kernel_size//2),bias=True)
        self.conv2=nn.Conv2d(dims,dims,kernel_size,padding=(kernel_size//2),bias=True)
        self.ca=ChannelAttention(dims)
        self.pa=PixelAttention(dims)
    def forward(self,img):
        feat=F.relu(self.conv1(img),inplace=True)
        feat=feat+img
        feat=F.relu(self.conv1(feat),inplace=True)
        feat=self.ca(feat)
        feat=self.pa(feat)
        feat+=img
        return feat


class ComplexNet_YCRCB(nn.Module):
    def __init__(self,input_channels=3,output_channels=3,features=128):
        super(ComplexNet_YCRCB,self).__init__()
        self.encoder_c1=nn.Conv2d(input_channels,features,kernel_size=3,padding=1,padding_mode='reflect',stride=2)
        self.encoder_c2=nn.Conv2d(features+(output_channels*4),features*2,kernel_size=3,padding=1,padding_mode='reflect',stride=2)
        self.enc_chan_attn1=AttentionBlock(features*2)
        self.encoder_n1=nn.InstanceNorm2d(features*2,affine=True)
        self.encoder_c3=nn.Conv2d(features*2,features*4,kernel_size=3,padding=1,stride=2)
        self.encoder_c4=nn.Conv2d(features*4,features*8,kernel_size=3,padding=1,stride=2)
        self.enc_chan_attn2=AttentionBlock(features*8)
        self.encoder_n2=nn.InstanceNorm2d(features*8,affine=True)

        self.decoder_c1=nn.ConvTranspose2d(features*8,features*4,kernel_size=4,stride=2,padding=1)
        self.dec_chan_attn1=AttentionBlock(features*4)
        self.decoder_c2=nn.ConvTranspose2d(features*4,features*2,kernel_size=4,stride=2,padding=1)
        self.decoder_c3=nn.ConvTranspose2d(features*2,features,kernel_size=4,stride=2,padding=1)
        self.decoder_c4=nn.ConvTranspose2d(features+(output_channels*4),features,kernel_size=4,stride=2,padding=1)
        self.dec_chan_attn2=AttentionBlock(features)
        self.decoder_c5=nn.Conv2d(features,features,kernel_size=3,padding=1)
        self.decoder_c6=nn.Conv2d(features,output_channels,kernel_size=3,padding=1)


    def forward(self, hazy):
        dwt_coeffs=pywt.dwt2(hazy.cpu(),wavelet='db4')
        LL,(LH,HL,HH)=dwt_coeffs
        dwt_out=torch.concat([torch.from_numpy(LL),torch.from_numpy(LH),torch.from_numpy(HL),torch.from_numpy(HH)],dim=1)
        dwt_out=tt.Resize((128,128))(dwt_out)
        x1=F.relu(self.encoder_c1(hazy),inplace=True)
        x1=torch.concat([x1,dwt_out],dim=1)
        x2=self.encoder_n1(self.enc_chan_attn1(F.relu(self.encoder_c2(x1),inplace=True)))
        x3=F.relu(self.encoder_c3(x2),inplace=True)
        x4=self.encoder_n2(self.enc_chan_attn2(F.relu(self.encoder_c4(x3),inplace=True)))

        x5=F.relu(self.dec_chan_attn1(F.relu(self.decoder_c1(x4),inplace=True)+x3),inplace=True)
        x6=F.relu(self.decoder_c2(x5),inplace=True)+x2
        x7=F.relu(self.decoder_c3(x6),inplace=True)
        x7=torch.concat([x7,dwt_out],dim=1)
        x8=F.relu(self.decoder_c4(x7),inplace=True)
        x9=F.relu(self.decoder_c5(x8),inplace=True)
        x10=F.relu(self.decoder_c6(x9),inplace=True)
        return x10
complexnet_ycbcr=nn.DataParallel(ComplexNet_YCRCB())
complexnet_ycbcr.load_state_dict(torch.load(r'/kaggle/input/dehazing-models-ct5129/complexnet-ycbcr-3l.pth',map_location=torch.device('cpu')))



<All keys matched successfully>

## Helper Functions

In [13]:
def save_image_ycbcr(img_tensor,file_path):
    if img_tensor.shape[0]!=3:
        raise ValueError("Input tensor must have 3 channels only...")
    ycbcr_array=tensor_denormalize_ycbcr(img_tensor).permute(1,2,0).cpu().detach().numpy()
    ycbcr_image=Image.fromarray((np.clip(ycbcr_array,0,1)*255).astype(np.uint8),mode='YCbCr')
    rgb_image=ycbcr_image.convert('RGB')
    rgb_image.save(file_path)
    
def tensor_denormalize_ycbcr(out_tensor,mean=[0.4011,0.4784,0.5378],std=[0.2667,0.0479,0.0414]):
    if len(out_tensor.shape)==3:
        out_tensor=out_tensor.unsqueeze(0)
    mean=torch.tensor(mean).unsqueeze(0).unsqueeze(2).unsqueeze(3)
    std=torch.tensor(std).unsqueeze(0).unsqueeze(2).unsqueeze(3)    
    denorm_tensor=(out_tensor*std)+mean
    return denorm_tensor.squeeze(0) 
def save_image_rgb(img_tensor,file_path):
    if img_tensor.shape[0]!=3:
        raise ValueError("Input tensor must have 3 channels only...")
    rgb_array=tensor_denormalize_rgb(img_tensor).permute(1,2,0).cpu().detach().numpy()
    rgb_image=Image.fromarray((np.clip(rgb_array,0,1)*255).astype(np.uint8),mode='RGB')
    rgb_image.save(file_path)

def tensor_denormalize_rgb(out_tensor,mean=[0.4556,0.3837,0.3642],std=[0.2689,0.2691,0.2828]):
    if len(out_tensor.shape)==3:
        out_tensor=out_tensor.unsqueeze(0)
    mean=torch.tensor(mean).unsqueeze(0).unsqueeze(2).unsqueeze(3)
    std=torch.tensor(std).unsqueeze(0).unsqueeze(2).unsqueeze(3)    
    denorm_tensor=(out_tensor*std)+mean
    return denorm_tensor.squeeze(0)

ssim_fn=StructuralSimilarityIndexMeasure()
psnr_fn=PeakSignalNoiseRatio()

def metrics_calculator(out_path,clear_path):
    out_tensor=tt.ToTensor()(cv2.imread(out_path))
    clear_tensor=tt.ToTensor()(cv2.imread(clear_path))
    return psnr_fn(out_tensor,clear_tensor),ssim_fn(out_tensor.unsqueeze(0),clear_tensor.unsqueeze(0))
    
    

In [14]:
def model_evaluator(dehaze_model,exp_name,data_loader,img_format='RGB',mode='sample'):
    psnr_data,ssim_data,time_data=[],[],[]
    os.makedirs(f'/kaggle/working/{exp_name}',exist_ok=True)
    i=0
    dehaze_model.eval()
    for pair in data_loader:
        inp=pair['hazy']
        clear=pair['gt']        
        img_tensor=inp.cpu()
        start_time=time.time()
        model_out=dehaze_model(img_tensor)
        time_data.append(time.time()-start_time)
        out_tensor=model_out.cpu() 
        if img_format=="RGB":
            save_image_rgb(img_tensor.squeeze(),f'/kaggle/working/{exp_name}/input_image_{i}.png')
            save_image_rgb(out_tensor.squeeze(),f'/kaggle/working/{exp_name}/output_image_{i}.png')
            save_image_rgb(clear.squeeze(),f'/kaggle/working/{exp_name}/clear_image_{i}.png')
        else:
            save_image_ycbcr(img_tensor.squeeze(),f'/kaggle/working/{exp_name}/input_image_{i}.png')
            save_image_ycbcr(out_tensor.squeeze(),f'/kaggle/working/{exp_name}/output_image_{i}.png')
            save_image_ycbcr(clear.squeeze(),f'/kaggle/working/{exp_name}/clear_image_{i}.png')
        output_calc_metric=metrics_calculator(f'/kaggle/working/{exp_name}/output_image_{i}.png',
                                             f'/kaggle/working/{exp_name}/clear_image_{i}.png')
        psnr_data.append(output_calc_metric[0].detach().numpy())
        ssim_data.append(output_calc_metric[1].detach().numpy())
        i+=1
        gc.collect()
    
    print(f"Avg Time for {exp_name}: {sum(time_data)/len(time_data)}")
    

## Primary Postprocessing

In [15]:
def unsharp_mask(image,kernel_size=(5,5),sigma=0.4,amount=1.0,threshold=1):
    blurred=cv2.GaussianBlur(image,kernel_size,sigma)
    sharpened=float(amount+1)*image-float(amount)*blurred
    sharpened=np.maximum(sharpened,np.zeros(sharpened.shape))
    sharpened=np.minimum(sharpened,255*np.ones(sharpened.shape))
    sharpened=sharpened.round().astype(np.uint8)
    if threshold>0:
        low_contrast_mask=np.absolute(image-blurred)<threshold
        np.copyto(sharpened,image,where=low_contrast_mask)
    return sharpened

def clahe(image):
    clahe=cv2.createCLAHE(clipLimit=1,tileGridSize=(2,2))
    lab=cv2.cvtColor(image,cv2.COLOR_BGR2LAB)
    lab[:,:,0]=clahe.apply(lab[:,:,0])
    return cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)

def enhance_image(image_path):
    image=cv2.imread(image_path)
    img=cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_sharpened=unsharp_mask(img)
    image_clahe=clahe(image_sharpened)
    image_tensor=tt.ToTensor()(image_clahe)   
    return image_tensor

In [16]:
def processed_tester(path,exp_name):
    psnr_data_raw,ssim_data_raw=[],[]
    psnr_data_proc,ssim_data_proc=[],[]
    for i in range(len(glob(path+'*.png'))//3):
        proc_out_tensor=enhance_image(path+f'output_image_{i}.png')
        proc_arr=proc_out_tensor.permute(1,2,0).cpu().detach().numpy()
        proc_image=Image.fromarray((np.clip(proc_arr,0,1)*255).astype(np.uint8),mode='RGB')
        proc_image.save(path+f'processed_image_{i}.png')       
        out_tensor=tt.ToTensor()(cv2.imread(path+f'output_image_{i}.png'))
        clear_tensor=tt.ToTensor()(cv2.imread(path+f'clear_image_{i}.png'))
        psnr_data_raw.append(psnr_fn(out_tensor,clear_tensor))
        psnr_data_proc.append(psnr_fn(proc_out_tensor,clear_tensor))
        ssim_data_raw.append(ssim_fn(out_tensor.unsqueeze(0),clear_tensor.unsqueeze(0)))
        ssim_data_proc.append(ssim_fn(proc_out_tensor.unsqueeze(0),clear_tensor.unsqueeze(0)))
        gc.collect()
    print(f"Avg PSNR Data for {exp_name} Raw: {sum(psnr_data_raw)/len(psnr_data_raw)}")
    print(f"Avg SSIM Data for {exp_name} Raw: {sum(ssim_data_raw)/len(ssim_data_raw)}")
    print(f"Avg PSNR Data for {exp_name} Processed: {sum(psnr_data_proc)/len(psnr_data_proc)}")
    print(f"Avg SSIM Data for {exp_name} Processed: {sum(ssim_data_proc)/len(ssim_data_proc)}")
def evaluation_printer(dehaze_model,exp_name,data_loader,img_format='RGB',mode='sample'):
    print(f'----------------{exp_name}-----------------------------')
    model_evaluator(dehaze_model,exp_name,data_loader,img_format,mode)
    processed_tester(f'/kaggle/working/{exp_name}/',exp_name)
    gc.collect()
    shutil.make_archive(f'/kaggle/working/{exp_name}','zip',f'/kaggle/working/{exp_name}')
    print('\n')

    

In [17]:
evaluation_printer(dehazenet_rgb_fft,'Dehazenet_3L_RGB_FFT',test_sample_loader_rgb,'RGB')
evaluation_printer(dwt_dehazenet_rgb_fft,'DWT_Dehazenet_RGB_FFT',test_sample_loader_rgb,'RGB')
evaluation_printer(aod_pretrained,'AODNet_Pretrained',test_sample_loader_rgb,'RGB')
evaluation_printer(dwt_dehazenet_rgb,'DWT_Dehazenet_RGB',test_sample_loader_rgb,'RGB')
evaluation_printer(dehazenet_rgb,'Dehazenet_2L_RGB',test_sample_loader_rgb,'RGB')
evaluation_printer(dehazenet_pono_fft,'PONO_Dehazenet_RGB_FFT',test_sample_loader_rgb,'RGB')
evaluation_printer(complexnet_rgb,'ComplexNet_3L_RGB',test_sample_loader_rgb,'RGB')


evaluation_printer(dwt_dehazenet_ycbcr_3l,'DWT_Dehazenet_YCbCr_3L',test_sample_loader_ycbcr,'YCBCR')
evaluation_printer(dwt_dehazenet_ycbcr,'DWT_Dehazenet_YCbCr',test_sample_loader_ycbcr,'YCBCR')
evaluation_printer(aod_ycbcr_mse,'AODNet_MSE_YCbCr',test_sample_loader_ycbcr,'YCBCR')
evaluation_printer(aod_ycbcr_l2,'AODNet_2L_YCbCr',test_sample_loader_ycbcr,'YCBCR')
evaluation_printer(aod_ycbcr_3l,'AODNet_3L_YCbCr',test_sample_loader_ycbcr,'YCBCR')
evaluation_printer(dehazenet_ycbcr,'Dehazenet_2L_YCbCr',test_sample_loader_ycbcr,'YCBCR')
evaluation_printer(complexnet_ycbcr,'ComplexNet_3L_YCbCr',test_sample_loader_ycbcr,'YCBCR')


----------------Dehazenet_3L_RGB_FFT-----------------------------
Avg Time for Dehazenet_3L_RGB_FFT: 0.04374547004699707
Avg PSNR Data for Dehazenet_3L_RGB_FFT Raw: 13.652524948120117
Avg SSIM Data for Dehazenet_3L_RGB_FFT Raw: 0.6173181533813477
Avg PSNR Data for Dehazenet_3L_RGB_FFT Processed: 14.763494491577148
Avg SSIM Data for Dehazenet_3L_RGB_FFT Processed: 0.6328158974647522


----------------DWT_Dehazenet_RGB_FFT-----------------------------
Avg Time for DWT_Dehazenet_RGB_FFT: 0.04293990135192871
Avg PSNR Data for DWT_Dehazenet_RGB_FFT Raw: 13.52770709991455
Avg SSIM Data for DWT_Dehazenet_RGB_FFT Raw: 0.6198670268058777
Avg PSNR Data for DWT_Dehazenet_RGB_FFT Processed: 14.515462875366211
Avg SSIM Data for DWT_Dehazenet_RGB_FFT Processed: 0.6338726282119751


----------------AODNet_Pretrained-----------------------------
Avg Time for AODNet_Pretrained: 0.3241743564605713
Avg PSNR Data for AODNet_Pretrained Raw: 12.435722351074219
Avg SSIM Data for AODNet_Pretrained Raw: 0.4474

In [18]:
evaluation_printer(dehazenet_rgb_fft,'Dehazenet_3L_RGB_FFT_Test',test_loader_rgb,'RGB','Test')
evaluation_printer(dwt_dehazenet_rgb_fft,'DWT_Dehazenet_RGB_FFT_Test',test_loader_rgb,'RGB','Test')
evaluation_printer(aod_pretrained,'AODNet_Pretrained_Test',test_loader_rgb,'RGB','Test')
evaluation_printer(dwt_dehazenet_rgb,'DWT_Dehazenet_RGB_Test',test_loader_rgb,'RGB','Test')
evaluation_printer(dehazenet_rgb,'Dehazenet_2L_RGB_Test',test_loader_rgb,'RGB','Test')
evaluation_printer(dehazenet_pono_fft,'PONO_Dehazenet_RGB_FFT_Test',test_loader_rgb,'RGB','Test')
evaluation_printer(complexnet_rgb,'ComplexNet_3L_RGB_Test',test_loader_rgb,'RGB','Test')


evaluation_printer(dwt_dehazenet_ycbcr_3l,'DWT_Dehazenet_YCbCr_3L_Test',test_loader_ycbcr,'YCBCR','Test')
evaluation_printer(dwt_dehazenet_ycbcr,'DWT_Dehazenet_YCbCr_Test',test_loader_ycbcr,'YCBCR','Test')
evaluation_printer(aod_ycbcr_mse,'AODNet_MSE_YCbCr_Test',test_loader_ycbcr,'YCBCR','Test')
evaluation_printer(aod_ycbcr_l2,'AODNet_2L_YCbCr_Test',test_loader_ycbcr,'YCBCR','Test')
evaluation_printer(aod_ycbcr_3l,'AODNet_3L_YCbCr_Test',test_loader_ycbcr,'YCBCR','Test')
evaluation_printer(dehazenet_ycbcr,'Dehazenet_2L_YCbCr_Test',test_loader_ycbcr,'YCBCR','Test')
evaluation_printer(complexnet_ycbcr,'ComplexNet_3L_YCbCr_Test',test_loader_ycbcr,'YCBCR','Test')


----------------Dehazenet_3L_RGB_FFT_Test-----------------------------
Avg Time for Dehazenet_3L_RGB_FFT_Test: 0.027318309820615328
Avg PSNR Data for Dehazenet_3L_RGB_FFT_Test Raw: 12.593730926513672
Avg SSIM Data for Dehazenet_3L_RGB_FFT_Test Raw: 0.6547240614891052
Avg PSNR Data for Dehazenet_3L_RGB_FFT_Test Processed: 12.936782836914062
Avg SSIM Data for Dehazenet_3L_RGB_FFT_Test Processed: 0.6025488376617432


----------------DWT_Dehazenet_RGB_FFT_Test-----------------------------
Avg Time for DWT_Dehazenet_RGB_FFT_Test: 0.03966274123925429
Avg PSNR Data for DWT_Dehazenet_RGB_FFT_Test Raw: 12.429469108581543
Avg SSIM Data for DWT_Dehazenet_RGB_FFT_Test Raw: 0.6580450534820557
Avg PSNR Data for DWT_Dehazenet_RGB_FFT_Test Processed: 12.837409019470215
Avg SSIM Data for DWT_Dehazenet_RGB_FFT_Test Processed: 0.6074617505073547


----------------AODNet_Pretrained_Test-----------------------------
Avg Time for AODNet_Pretrained_Test: 0.3293459919782785
Avg PSNR Data for AODNet_Pretrained

## Image weighted addition

In [19]:
def alpha_blending_fn(image1,image2,alpha=0.6):
    blended=cv2.addWeighted(image1, alpha, image2, 1 - alpha, 4)
    return blended

def image_addition(coeff,img_path1,img_path2,clear_path):
    img1=cv2.cvtColor(cv2.imread(img_path1),cv2.COLOR_BGR2RGB)
    img2=cv2.cvtColor(cv2.imread(img_path2),cv2.COLOR_BGR2RGB)
    clear_img=tt.ToTensor()(cv2.cvtColor(cv2.imread(clear_path),cv2.COLOR_BGR2RGB))
    img_f=tt.ToTensor()(alpha_blending_fn(img1,img2))    
    psnr_value=psnr_fn(img_f,clear_img)
    ssim_value=ssim_fn(img_f.unsqueeze(0),clear_img.unsqueeze(0))
    return psnr_value,ssim_value,img_f
def list_avg(li):
    return sum(li)/len(li)
def addn_exp(exp_name1,exp_name2):
    coeffs=[i/10 for i in range(10)]
    coeff_dict={}
    path1=f'/kaggle/working/{exp_name1}/'
    path2=f'/kaggle/working/{exp_name2}/'
    for coeff in coeffs:
        psnr_data_proc,ssim_data_proc=[],[]
        psnr_data_raw,ssim_data_raw=[],[]
        psnr_data_c1,ssim_data_c1=[],[]
        psnr_data_c2,ssim_data_c2=[],[]
        for i in range(len(glob(path1+'*.png'))//4):
            output_img_path1=path1+f'output_image_{i}.png'
            clear_img_path=path1+f'clear_image_{i}.png'
            processed_img_path1=path1+f'processed_image_{i}.png'
            output_img_path2=path2+f'output_image_{i}.png'
            processed_img_path2=path2+f'processed_image_{i}.png'
            
            raw_metrics=image_addition(coeff,output_img_path1,output_img_path2,clear_img_path)
            proc_metrics=image_addition(coeff,processed_img_path1,processed_img_path2,clear_img_path)
            c1_metrics=image_addition(coeff,processed_img_path1,output_img_path2,clear_img_path)
            c2_metrics=image_addition(coeff,output_img_path1,processed_img_path2,clear_img_path)
            
            psnr_data_proc.append(proc_metrics[0])
            ssim_data_proc.append(proc_metrics[1])
            psnr_data_raw.append(raw_metrics[0])
            ssim_data_raw.append(raw_metrics[1])
            psnr_data_c1.append(c1_metrics[0])
            ssim_data_c1.append(c1_metrics[1])
            psnr_data_c2.append(c2_metrics[0])
            ssim_data_c2.append(c2_metrics[1])
        coeff_dict[coeff]=[[list_avg(psnr_data_raw),list_avg(ssim_data_raw),],
                           [list_avg(psnr_data_proc),list_avg(ssim_data_proc)],
                          [list_avg(psnr_data_c1),list_avg(ssim_data_c1)],
                          [list_avg(psnr_data_c2),list_avg(ssim_data_c2)]]
        gc.collect()
    return coeff_dict

#addn_exp('Dehazenet_2L_RGB_Test','DWT_Dehazenet_RGB_Test')
        
            
            

In [20]:
def save_image_final(img_tensor,file_path):
    if img_tensor.shape[0]!=3:
        raise ValueError("Input tensor must have 3 channels only...")
    rgb_array=img_tensor.permute(1,2,0).cpu().detach().numpy()
    rgb_image=Image.fromarray((np.clip(rgb_array,0,1)*255).astype(np.uint8),mode='RGB')
    rgb_image.save(file_path)

## Performing Alpha Blending

In [21]:
def final_merge(exp_name1,exp_name2,coeff=0.9):
    os.makedirs(f'/kaggle/working/{exp_name1+exp_name2}',exist_ok=True)
    path1=f'/kaggle/working/{exp_name1}/'
    path2=f'/kaggle/working/{exp_name2}/'
    final_psnr,final_ssim=[],[]
    for i in range(len(glob(path1+'*.png'))//4):
            output_img_path1=path1+f'output_image_{i}.png'
            clear_img_path=path1+f'clear_image_{i}.png'
            processed_img_path1=path1+f'processed_image_{i}.png'
            output_img_path2=path2+f'output_image_{i}.png'
            processed_img_path2=path2+f'processed_image_{i}.png'
            exp_outputs=image_addition(coeff,processed_img_path1,processed_img_path2,clear_img_path)
            final_psnr.append(exp_outputs[0])
            final_ssim.append(exp_outputs[1])
            clear_img_tensor=tt.ToTensor()(cv2.cvtColor(cv2.imread(clear_img_path),cv2.COLOR_BGR2RGB))
            save_image_final(exp_outputs[2].squeeze(),f'/kaggle/working/{exp_name1+exp_name2}/merged_image_{i}.png')
            save_image_final(clear_img_tensor.squeeze(),f'/kaggle/working/{exp_name1+exp_name2}/clear_image_{i}.png')
            shutil.make_archive(f'/kaggle/working/{exp_name1+exp_name2}','zip',f'/kaggle/working/{exp_name1+exp_name2}')
            gc.collect()
    print(list_avg(final_psnr),list_avg(final_ssim))

final_merge('Dehazenet_2L_RGB_Test','DWT_Dehazenet_RGB_Test',0.6)
final_merge('Dehazenet_2L_RGB','DWT_Dehazenet_RGB',0.6)

    

tensor(15.1650) tensor(0.7176)
tensor(15.1012) tensor(0.6353)


## Postprocessing after Blending

In [22]:
def enhance_image_merged(image_path):
    image=cv2.imread(image_path)
    img=cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img2=cv2.fastNlMeansDenoisingColored(img,None,10,10,3,21)
    image_sharpened=unsharp_mask(img2)
    image_clahe=clahe(image_sharpened)
    image_tensor=tt.ToTensor()(image_clahe)   
    return image_tensor
def final_image_processor(img_path,clear_img_path):
    img_tensor=enhance_image_merged(img_path)
    clear_img=tt.ToTensor()(cv2.cvtColor(cv2.imread(clear_img_path),cv2.COLOR_BGR2RGB))
    psnr_value=psnr_fn(img_tensor,clear_img)
    ssim_value=ssim_fn(img_tensor.unsqueeze(0),clear_img.unsqueeze(0))
    return psnr_value,ssim_value,img_tensor

def proc_final_merge(exp_name,proc_path):
    os.makedirs(f'/kaggle/working/{exp_name}',exist_ok=True)
    path1=f'/kaggle/working/{proc_path}/'
    final_psnr,final_ssim=[],[]
    for i in range(len(glob(path1+'*.png'))//2):
            inp_img_path1=path1+f'merged_image_{i}.png'
            clear_img_path=path1+f'clear_image_{i}.png'
            exp_outputs=final_image_processor(inp_img_path1,clear_img_path)
            final_psnr.append(exp_outputs[0])
            final_ssim.append(exp_outputs[1])
            clear_img_tensor=tt.ToTensor()(cv2.cvtColor(cv2.imread(clear_img_path),cv2.COLOR_BGR2RGB))
            save_image_final(exp_outputs[2].squeeze(),f'/kaggle/working/{exp_name}/final_merged_image_{i}.png')
            save_image_final(clear_img_tensor.squeeze(),f'/kaggle/working/{exp_name}/clear_image_{i}.png')
            gc.collect()
    shutil.make_archive(f'/kaggle/working/{exp_name}','zip',f'/kaggle/working/{exp_name}')
    print(list_avg(final_psnr),list_avg(final_ssim))
proc_final_merge('Final_Processed_Test','Dehazenet_2L_RGB_TestDWT_Dehazenet_RGB_Test')
proc_final_merge('Final_Processed','Dehazenet_2L_RGBDWT_Dehazenet_RGB')
    

tensor(15.8543) tensor(0.6613)
tensor(15.3281) tensor(0.5469)
