### Single Image Dehazing with Convolutional Vision Transformer (CVT)
This project implements a Single Image Dehazing model using a Convolutional Vision Transformer (CVT). The model is designed to remove haze from images, thereby enhancing the clarity and quality of the visual content.

### Mustafa Shabazi Dill - TMU
- [Github]([link](https://github.com/mrdjango/CVT-Dehazer.git))

### Dataset: https://drive.google.com/file/d/1yaY_trqGn-SNoy7mrX040KUU7xhTPvas/view?usp=drive_link

Importing the necessary libraries

In [57]:
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
  from einops import rearrange
except:
  !pip install einops
  from einops import rearrange
import os
import pandas as pd
import cv2
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.optim as optim
import math
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

In [33]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## unzip the dataset

In [34]:
!unzip /content/drive/MyDrive/dehaze_train.zip -d data/

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: data/train/hazy/19.jpg  
  inflating: data/train/hazy/190.jpg  
  inflating: data/train/hazy/1900.jpg  
  inflating: data/train/hazy/1901.jpg  
  inflating: data/train/hazy/1902.jpg  
  inflating: data/train/hazy/1903.jpg  
  inflating: data/train/hazy/1904.jpg  
  inflating: data/train/hazy/1905.jpg  
  inflating: data/train/hazy/1906.jpg  
  inflating: data/train/hazy/1907.jpg  
  inflating: data/train/hazy/1908.jpg  
  inflating: data/train/hazy/1909.jpg  
  inflating: data/train/hazy/191.jpg  
  inflating: data/train/hazy/1910.jpg  
  inflating: data/train/hazy/1911.jpg  
  inflating: data/train/hazy/1912.jpg  
  inflating: data/train/hazy/1913.jpg  
  inflating: data/train/hazy/1914.jpg  
  inflating: data/train/hazy/1915.jpg  
  inflating: data/train/hazy/1916.jpg  
  inflating: data/train/hazy/1917.jpg  
  inflating: data/train/hazy/1918.jpg  
  inflating: data/train/hazy/1919.jpg  
  inflating: data/t

## Implementing the model and blocks of the CVT

In [47]:
class ConvolutionalAttention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=False)
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        B, C, H, W = x.shape
        qkv = self.qkv(x).reshape(B, 3, self.num_heads, C // self.num_heads, H, W)
        q, k, v = qkv.unbind(1)

        attn = (q.transpose(-2, -1) @ k).softmax(dim=-1)
        x = (attn @ v.transpose(-2, -1)).transpose(-2, -1)
        x = x.reshape(B, C, H, W)
        x = self.proj(x)
        return x

class CVTBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(dim)
        self.attn = ConvolutionalAttention(dim, num_heads)
        self.norm2 = nn.BatchNorm2d(dim)
        self.mlp = nn.Sequential(
            nn.Conv2d(dim, int(dim * mlp_ratio), 1),
            nn.GELU(),
            nn.Conv2d(int(dim * mlp_ratio), dim, 1)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class CVTLayer(nn.Module):
    def __init__(self, dim, depth, num_heads, mlp_ratio=4.):
        super().__init__()
        self.blocks = nn.ModuleList([
            CVTBlock(dim, num_heads, mlp_ratio) for _ in range(depth)
        ])

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x

class PatchEmbed(nn.Module):
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        return x

class SimpleCVT(nn.Module):
    def __init__(self, img_size=256, patch_size=4, embed_dim=96, depth=3, num_heads=4, mlp_ratio=4.):
        super().__init__()
        self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
        self.encoder = CVTLayer(dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio)
        self.decoder = nn.ConvTranspose2d(embed_dim, 3, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.encoder(x)
        x = self.decoder(x)
        return x


#### data augmentation

In [48]:
def read_img(path):
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32) / 255.0
    return img

def augment(imgs, size=256, edge_decay=0., only_h_flip=False):
    H, W, _ = imgs[0].shape
    Hc, Wc = [size, size]

    if random.random() < Hc / H * edge_decay:
        Hs = 0 if random.randint(0, 1) == 0 else H - Hc
    else:
        Hs = random.randint(0, H - Hc)

    if random.random() < Wc / W * edge_decay:
        Ws = 0 if random.randint(0, 1) == 0 else W - Wc
    else:
        Ws = random.randint(0, W - Wc)

    for i in range(len(imgs)):
        imgs[i] = imgs[i][Hs:(Hs + Hc), Ws:(Ws + Wc), :].copy()

    if random.randint(0, 1) == 1:
        for i in range(len(imgs)):
            imgs[i] = np.flip(imgs[i], axis=1).copy()

    if not only_h_flip:
        rot_deg = random.randint(0, 3)
        for i in range(len(imgs)):
            imgs[i] = np.rot90(imgs[i], rot_deg, (0, 1)).copy()

    return imgs


#### PSNR and SSIM metrics

In [50]:
def calculate_psnr(img1, img2):
    psnr = peak_signal_noise_ratio(img1.cpu().detach().numpy(), img2.cpu().detach().numpy(), data_range=1.0) # Call detach() before converting to NumPy array to avoid the error
    return psnr

def calculate_ssim(img1, img2):
    ssim = structural_similarity(img1.cpu().detach().numpy().transpose(1, 2, 0), img2.cpu().detach().numpy().transpose(1, 2, 0), multichannel=True, data_range=1.0) # Call detach() before converting to NumPy array to avoid the error
    return ssim

#### Training the model

In [59]:
# Hyperparameters
img_size = 256
batch_size = 4
learning_rate = 1e-4
num_epochs = 10

# Dataset and DataLoader
data_dir = 'data/'
# PairLoader: is a colab built-in class that loads paired images from a directory
train_dataset = PairLoader(data_dir, mode='train', size=img_size)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Model, Loss, Optimizer
model = SimpleCVT(img_size=img_size).cuda()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Lists to store metrics
epoch_list = []
loss_list = []
psnr_list = []
ssim_list = []

In [62]:
# Training Loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    psnr_total = 0.0
    ssim_total = 0.0
    num_batches = 0

    for i, data in enumerate(train_loader):
        inputs, targets = data
        inputs, targets = inputs.cuda(), targets.cuda()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        psnr = calculate_psnr(outputs[0], targets[0])
        ssim = calculate_ssim(outputs[0], targets[0])
        psnr_total += psnr
        ssim_total += ssim
        num_batches += 1

        if i % 10 == 9:  # print every 10 mini-batches
            avg_loss = running_loss / 10
            avg_psnr = psnr_total / 10
            avg_ssim = ssim_total / 10
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {avg_loss:.4f}, PSNR: {avg_psnr:.4f}, SSIM: {avg_ssim:.4f}')
            running_loss = 0.0
            psnr_total = 0.0
            ssim_total = 0.0

    # Calculate average PSNR and SSIM for the epoch
    avg_loss_epoch = running_loss / num_batches
    avg_psnr_epoch = psnr_total / num_batches
    avg_ssim_epoch = ssim_total / num_batches

    # Store metrics
    epoch_list.append(epoch + 1)
    loss_list.append(avg_loss_epoch)
    psnr_list.append(avg_psnr_epoch)
    ssim_list.append(avg_ssim_epoch)

    # # Save metrics for the epoch
    # with open(os.path.join(results_dir, f'epoch_{epoch + 1}.txt'), 'w') as f:
    #     f.write(f'Epoch: {epoch + 1}, Loss: {avg_loss_epoch:.4f}, PSNR: {avg_psnr_epoch:.4f}, SSIM: {avg_ssim_epoch:.4f}\n')

    # # Save model checkpoint
    # torch.save(model.state_dict(), os.path.join(results_dir, f'cvt_epoch_{epoch + 1}.pth'))

print('Finished Training')

  ssim = structural_similarity(img1.cpu().detach().numpy().transpose(1, 2, 0), img2.cpu().detach().numpy().transpose(1, 2, 0), multichannel=True, data_range=1.0) # Call detach() before converting to NumPy array to avoid the error


Epoch [1/10], Step [10/1500], Loss: 0.4977, PSNR: 3.5083, SSIM: 0.0018
Epoch [1/10], Step [20/1500], Loss: 0.3380, PSNR: 4.7677, SSIM: 0.0076
Epoch [1/10], Step [30/1500], Loss: 0.2577, PSNR: 6.1430, SSIM: 0.0098
Epoch [1/10], Step [40/1500], Loss: 0.1949, PSNR: 7.6146, SSIM: 0.0165
Epoch [1/10], Step [50/1500], Loss: 0.1714, PSNR: 7.8512, SSIM: 0.0232
Epoch [1/10], Step [60/1500], Loss: 0.1353, PSNR: 9.9113, SSIM: 0.0310
Epoch [1/10], Step [70/1500], Loss: 0.1274, PSNR: 8.7124, SSIM: 0.0323
Epoch [1/10], Step [80/1500], Loss: 0.1047, PSNR: 9.2663, SSIM: 0.0351
Epoch [1/10], Step [90/1500], Loss: 0.0972, PSNR: 10.7135, SSIM: 0.0414
Epoch [1/10], Step [100/1500], Loss: 0.0905, PSNR: 11.1125, SSIM: 0.0606
Epoch [1/10], Step [110/1500], Loss: 0.0962, PSNR: 10.7622, SSIM: 0.0711
Epoch [1/10], Step [120/1500], Loss: 0.0767, PSNR: 13.3496, SSIM: 0.0970
Epoch [1/10], Step [130/1500], Loss: 0.0635, PSNR: 12.4351, SSIM: 0.0713
Epoch [1/10], Step [140/1500], Loss: 0.0702, PSNR: 12.3006, SSIM: 0.

In [63]:
# Create a DataFrame to store the metrics
metrics_df = pd.DataFrame({
    'Epoch': epoch_list,
    'Loss': loss_list,
    'PSNR': psnr_list,
    'SSIM': ssim_list
})

# Save the DataFrame to a CSV file
metrics_df.to_csv(os.path.join("", 'training_metrics.csv'), index=False)


### Testing the model with a sample image

In [65]:
# Load the trained model
# model = SimpleCVT(img_size=img_size).cuda()
# model.load_state_dict(torch.load('results/cvt_epoch_10.pth'))
model.eval()

def preprocess_image(image_path, img_size):
    img = cv2.imread(image_path, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (img_size, img_size))
    img = img.astype(np.float32) / 255.0
    # Convert to tensor and add batch dimension
    img = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).cuda()
    return img

def postprocess_and_save_image(output_tensor, output_path):
    output_image = output_tensor.squeeze().permute(1, 2, 0).cpu().numpy()
    output_image = (output_image * 255).astype(np.uint8)
    output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)
    cv2.imwrite(output_path, output_image)


In [66]:
# Preprocess the input image
input_image_path = 'data/train/hazy/8.jpg'
input_image = preprocess_image(input_image_path, img_size)

# Run the model for dehazing
with torch.no_grad():
    output_image = model(input_image)

# Save the dehazed image
output_image_path = 'dehazed_image.jpg'
postprocess_and_save_image(output_image, output_image_path)

print('Dehazing completed and image saved!')


Dehazing completed and image saved!
