# Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torchvision.models import vgg16
import torchvision.models as models
from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import glob
import random
import numbers
import cv2
from torch.autograd import Variable
import torch.utils.data as data
from torch.utils.data import random_split
import skimage
data_dir = '/kaggle/input/euvp-dataset/EUVP Dataset/Paired'

In [2]:
output_dir = 'output_images'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Model

In [4]:
class UnderwaterCNN(nn.Module):
    def __init__(self):
        super(UnderwaterCNN, self).__init__()
        
        # First stage - Dehazing
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = nn.ReLU(inplace=True)
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        
        # Second stage - Color Correction
        self.conv5 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.relu5 = nn.ReLU(inplace=True)
        self.conv6 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(64)
        self.relu6 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv7 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(128)
        self.relu7 = nn.ReLU(inplace=True)
        self.conv8 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
        self.bn8 = nn.BatchNorm2d(64)
        self.relu8 = nn.ReLU(inplace=True)
        self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest')
        
        # Third stage - Image Enhancement
        self.conv9 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1)
        self.bn9 = nn.BatchNorm2d(32)
        self.relu9 = nn.ReLU(inplace=True)
        self.conv10 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn10 = nn.BatchNorm2d(64)
        self.relu10 = nn.ReLU(inplace=True)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv11 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm2d(128)
        self.relu11 = nn.ReLU(inplace=True)
        self.conv12 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm2d(64)
        self.relu12 = nn.ReLU(inplace=True)
        self.conv13 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn13 = nn.BatchNorm2d(128)
        self.relu13 = nn.ReLU(inplace=True)
        self.conv14 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
        self.bn14 = nn.BatchNorm2d(64)
        self.relu14 = nn.ReLU(inplace=True)
        self.upsample3 = nn.Upsample(scale_factor=2, mode='nearest')
        
        # Output
        self.conv15 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
           
        # First stage - Dehazing
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.maxpool1(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu4(x)
        x = self.upsample1(x)
        
        # Second stage - Color Correction
        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu5(x)
        x = self.conv6(x)
        x = self.bn6(x)
        x = self.relu6(x)
        x = self.maxpool2(x)
        x = self.conv7(x)
        x = self.bn7(x)
        x = self.relu7(x)
        x = self.conv8(x)
        x = self.bn8(x)
        x = self.relu8(x)
        x = self.upsample2(x)
        
        # Third stage - Image Enhancement
        x = self.conv9(x)
        x = self.bn9(x)
        x = self.relu9(x)
        x = self.conv10(x)
        x = self.bn10(x)
        x = self.relu10(x)
        x = self.maxpool3(x)
        x = self.conv11(x)
        x = self.bn11(x)
        x = self.relu11(x)
        x = self.conv12(x)
        x = self.bn12(x)
        x = self.relu12(x)
        x = self.conv13(x)
        x = self.bn13(x)
        x = self.relu13(x)
        x = self.conv14(x)
        x = self.bn14(x)
        x = self.relu14(x)
        x = self.upsample3(x)
        # Output
        x = self.conv15(x)
        x = self.sigmoid(x)
        
        return x

# Dataset Processing

In [5]:
class ToTensor(object):
    def __call__(self, sample):
        hazy_image, clean_image = sample['hazy'], sample['clean']
        hazy_image = np.array(hazy_image)
        hazy_image = torch.from_numpy(hazy_image.astype(np.float32))
        hazy_image = torch.transpose(torch.transpose(hazy_image, 2, 0), 1, 2)
        clean_image = np.array(clean_image)
        clean_image = torch.from_numpy(clean_image.astype(np.float32))
        clean_image = torch.transpose(torch.transpose(clean_image, 2, 0), 1, 2)
        return {'hazy': hazy_image,
                'clean': clean_image}



class EUVP_Dataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.filesA, self.filesB = self.get_file_paths(self.data_dir)
        self.len = min(len(self.filesA), len(self.filesB))
        self.transform = transform
      
    def __len__(self):
        return self.len

    def __getitem__(self, index):
        hazy_im = cv2.resize(cv2.imread(self.filesA[index % self.len]), (256,256),
                                 interpolation=cv2.INTER_AREA)

        hazy_im = hazy_im[:, :, ::-1]    
        hazy_im = np.float32(hazy_im) / 255.0


        clean_im = cv2.resize(cv2.imread(self.filesB[index % self.len]), (256,256),
                                  interpolation=cv2.INTER_AREA)

        clean_im = clean_im[:, :, ::-1]  
        clean_im = np.float32(clean_im) / 255.0

        sample = {'hazy': hazy_im, 
                  'clean': clean_im}    
        if self.transform != None:
            sample = self.transform(sample)
    
        return sample


    def get_file_paths(self, data_dir):
        sub_dirs = ['underwater_imagenet', 'underwater_dark', 'underwater_scenes']
        filesA, filesB = [], []
        for sd in sub_dirs:
            filesA += sorted(glob.glob(os.path.join(data_dir, sd, 'trainA') + "/*.*"))
            filesB += sorted(glob.glob(os.path.join(data_dir, sd, 'trainB') + "/*.*"))
        return filesA, filesB 

# Training

In [7]:
def train_model(model, dataloader, optimizer, criterion, device):
    model.train()
    train_loss = 0
    for i, data in enumerate(dataloader, 0):
        input_img, target = data['input_img'].to(device), data['target'].to(device)
        optimizer.zero_grad()
        output = model(input_img)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        
    avg_loss = train_loss / len(dataloader)
    return avg_loss

# Testing

In [8]:
def test_model(model, dataloader, device):
    model.eval()
    ssim_score = 0
    psnr_score = 0
    with torch.no_grad():
        for i, data in enumerate(dataloader, 0):
            input_img, target = data['input_img'].to(device), data['target'].to(device)
            output = model(input_img)
            ssim_score += ssim(output, target, data_range=1.0, size_average=False)
            psnr_score += 10 * log10(1.0 / mean_squared_error(output, target))
    
    avg_ssim = ssim_score / len(dataloader)
    avg_psnr = psnr_score / len(dataloader)
    
    return avg_ssim, avg_psnr

# Main Function

In [9]:
def main(data_dir):
    epochs = 30
    batch_size = 16
    learning_rate = 0.00001
    

    transform = transforms.Compose([ToTensor()])
    data = EUVP_Dataset(data_dir, transform=transform)
    train_data, val_data = random_split(data, [int(0.8 * len(data)), int(0.2 * len(data))])
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=False)
    val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UnderwaterCNN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = torch.nn.MSELoss()
    best_psnr = 0
    best_ssim = 0
    for epoch in range(epochs):
        model.train()
        for i, sample in enumerate(train_dataloader):
            hazy = sample['hazy'].to(device)
            clean = sample['clean'].to(device)
            output = model(hazy)

            loss = criterion(output, clean)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        
    model.eval()  
    val_loss = 0
    with torch.no_grad():
        avg_psnr = 0
        avg_ssim = 0
        epoch_dir = os.path.join(output_dir, f'epoch_{epoch}')
        os.makedirs(epoch_dir, exist_ok=True)
        for j, sample in enumerate(val_dataloader):
            hazy = sample['hazy'].to(device)
            clean = sample['clean'].to(device)
            output = model(hazy)

            for k in range(output.shape[0]):
                output_img = output[k].cpu().numpy().transpose(1, 2, 0)
                output_img = np.clip(output_img, 0, 1)
                output_path = os.path.join(epoch_dir, f'batch_{j}_output_{k}.png')
                skimage.io.imsave(output_path, output_img)

                input_img = hazy[k].cpu().numpy().transpose(1, 2, 0)
                input_img = np.clip(input_img, 0, 1)
                input_path = os.path.join(output_dir, f'batch_{j}_input_{k}.png')
                skimage.io.imsave(input_path, input_img)

            val_loss += criterion(output, clean).item()
            avg_psnr += compare_psnr(clean.cpu().numpy(), output.cpu().numpy())
            avg_ssim += compare_ssim(clean.cpu().numpy(), output.cpu().numpy(), multichannel=True, win_size=3)


        avg_psnr = avg_psnr / len(val_dataloader)
        avg_ssim = avg_ssim / len(val_dataloader)
        val_loss = val_loss / len(val_dataloader)

    if avg_psnr > best_psnr:
        best_psnr = avg_psnr
        torch.save(model.state_dict(), 'best_psnr.pth')
    if avg_ssim > best_ssim:
        best_ssim = avg_ssim
        torch.save(model.state_dict(), 'best_ssim.pth')

    print("Epoch: {}, Training Loss: {:.4f}, Validation Loss: {:.4f}, PSNR: {:.4f}, SSIM: {:.4f}".format(epoch+1, loss.item(), val_loss, avg_psnr, avg_ssim))


if __name__ == '__main__':
  main(data_dir)



Epoch: 30, Training Loss: 0.0139, Validation Loss: 0.0102, PSNR: 20.0025, SSIM: 0.8552


In [10]:
!zip -r file.zip /kaggle/working/

  adding: kaggle/working/ (stored 0%)
  adding: kaggle/working/.virtual_documents/ (stored 0%)
  adding: kaggle/working/best_ssim.pth (deflated 9%)
  adding: kaggle/working/output_images/ (stored 0%)
  adding: kaggle/working/output_images/batch_16_input_2.png (deflated 0%)
  adding: kaggle/working/output_images/batch_129_input_13.png (deflated 0%)
  adding: kaggle/working/output_images/batch_119_input_9.png (deflated 0%)
  adding: kaggle/working/output_images/batch_61_input_7.png (deflated 0%)
  adding: kaggle/working/output_images/batch_92_input_2.png (deflated 0%)
  adding: kaggle/working/output_images/batch_83_input_6.png (deflated 0%)
  adding: kaggle/working/output_images/batch_78_input_15.png (deflated 0%)
  adding: kaggle/working/output_images/batch_107_input_14.png (deflated 0%)
  adding: kaggle/working/output_images/batch_105_input_15.png (deflated 0%)
  adding: kaggle/working/output_images/batch_89_input_10.png (deflated 0%)
  adding: kaggle/working/output_images/batch_102_in

In [11]:
os.chdir(r'/kaggle/working') 

In [12]:
from IPython.display import FileLink

In [13]:
FileLink(r'file.zip')