# Enhancer CNN training 
## Training a custom CNN architecture

In [1]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from glob import glob
from tqdm import tqdm
import cv2
from torchmetrics.image import PeakSignalNoiseRatio,StructuralSimilarityIndexMeasure

torch.manual_seed(42)
random.seed(42)

input_folder_list=list(glob('/kaggle/input/combination-dataset/Dehazenet_2L_RGB_TestDWT_Dehazenet_RGB_Test/*.png'))
target_folder_list=list(glob('/kaggle/input/combination-dataset/DWT_Dehazenet_RGB_Test/clear_image_*.png'))
target_folder_list.sort()
input_folder_list.sort()
batch_size=16
num_epochs=40
learning_rate=1e-4

class DehazingDataset(Dataset):
    def __init__(self,input_folder,target_folder,transform=None):
        self.input_files=[file for file in input_folder_list]
        self.target_files=[file for file in target_folder_list]
        self.transform=transform

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        input_img=Image.open(self.input_files[idx]).convert('RGB')
        target_img=Image.open(self.target_files[idx]).convert('RGB')
        if self.transform:
            input_img=self.transform(input_img)
            target_img=self.transform(target_img)
        return input_img, target_img
    
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1=nn.Conv2d(in_channels=3,out_channels=16,kernel_size=3,stride=1,padding=1)
        self.conv2=nn.Conv2d(in_channels=16,out_channels=32,kernel_size=3,stride=1,padding=1)
        self.conv3=nn.Conv2d(in_channels=32,out_channels=3,kernel_size=3,stride=1,padding=1)
        self.relu1=nn.ReLU()
        self.relu2=nn.ReLU()
        self.relu3=nn.ReLU()

    def forward(self,x):
        x=self.relu1(self.conv1(x))
        x=self.relu2(self.conv2(x))
        x=self.relu3(self.conv3(x))
        return x

transform=transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

dataset=DehazingDataset(input_folder_list[:int(0.1*len(input_folder_list))],
                        target_folder_list[:int(0.1*len(target_folder_list))],transform=transform)
dataloader=DataLoader(dataset,batch_size=batch_size,shuffle=True)

model=SimpleCNN()
criterion=nn.MSELoss()
optimizer=optim.Adam(model.parameters(),lr=learning_rate)

for epoch in tqdm(range(num_epochs)):
    running_loss=0.0
    for i,data in enumerate(dataloader,0):
        inputs,targets=data
        optimizer.zero_grad()
        outputs=model(inputs)
        loss=criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss+=loss.item()
        if (i+1) % 10==0:
            running_loss=0.0

print('Finished Training...')
torch.save(model.state_dict(),'end_cnn.pth')


100%|██████████| 40/40 [33:01<00:00, 49.53s/it]

Finished Training





## Enhancer CNN Inference 

In [2]:
psnr_fn,ssim_fn=PeakSignalNoiseRatio(),StructuralSimilarityIndexMeasure()
psnr_li,ssim_li=[],[]
for inp,clear in zip(input_folder_list[int(0.1*len(input_folder_list)):],target_folder_list[int(0.1*len(target_folder_list)):]):
    clear_image_path=clear
    merged_image_path=inp
    clear=transforms.ToTensor()(cv2.cvtColor(cv2.imread(clear_image_path),cv2.COLOR_BGR2RGB))
    proc=transforms.ToTensor()(cv2.cvtColor(cv2.imread(merged_image_path),cv2.COLOR_BGR2RGB))
    model_proc=model(proc.unsqueeze(0))
    psnr_value=psnr_fn(model_proc,clear.unsqueeze(0))
    ssim_value=ssim_fn(model_proc,clear.unsqueeze(0))
    psnr_li.append(psnr_value)
    ssim_li.append(ssim_value)
print(f"PSNR: {sum(psnr_li)/len(psnr_li):.4f}, SSIM: {sum(ssim_li)/len(ssim_li):.4f}")

PSNR: 18.8249, SSIM: 0.7565
