In [None]:
!wget https://vedas.sac.gov.in/static/pdf/SIH_2022/SS594_Multispectral_Dehazing.zip
!unzip SS594_Multispectral_Dehazing.zip
!mv SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/valid SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/val
!mv SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thick/dataset/valid SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thick/dataset/val

In [3]:
!pip install torch torchvision

Collecting torch
  Downloading torch-1.10.2-cp36-cp36m-manylinux1_x86_64.whl (881.9 MB)
     |██████████████                  | 386.3 MB 136.6 MB/s eta 0:00:04

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



     |███████████████████████         | 635.4 MB 137.6 MB/s eta 0:00:02

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



     |█████████████████████████████▏  | 803.2 MB 126.2 MB/s eta 0:00:01

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



     |████████████████████████████████| 881.9 MB 11 kB/s               
[?25hCollecting torchvision
  Downloading torchvision-0.11.2-cp36-cp36m-manylinux1_x86_64.whl (23.3 MB)
     |████████████████████████████████| 23.3 MB 39 kB/s              
Collecting torch
  Downloading torch-1.10.1-cp36-cp36m-manylinux1_x86_64.whl (881.9 MB)
     |████████████████▋               | 458.6 MB 82.2 MB/s eta 0:00:06 

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



     |███████████████████████████▉    | 766.1 MB 123.4 MB/s eta 0:00:01

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



     |████████████████████████████████| 881.9 MB 13 kB/s               
Installing collected packages: torch, torchvision
Successfully installed torch-1.10.1 torchvision-0.11.2


In [1]:
# import cv2
from PIL.Image import Image

import torch
import torchvision.transforms as transforms

def Preprocess(image: Image) -> torch.Tensor:
    # Contrast Enhancement
    transform = transforms.Compose([
        transforms.PILToTensor(),
        transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05)
        # transforms.functional.equalize
    ])
    transformedImage = transform(image)

    # Gamma Correction
    gammaCorrectedImage = transforms.functional.adjust_gamma(transformedImage, 2.2)

    # Histogram Stretching
    min_val = gammaCorrectedImage.min()
    max_val = gammaCorrectedImage.max()
    stretchedImage = (gammaCorrectedImage - min_val) / (max_val - min_val)

    # for x in stretchedImage:
    #     for y in x:
    #         print(y)

    # Guided Filtering
    # gFilter = cv2.ximgproc.createGuidedFilter(guide=stretchedImage.permute(1, 2, 0).numpy(), radius=3, eps=0.01)
    # filteredImage = gFilter.filter(src=stretchedImage.permute(1, 2, 0).numpy())
    # return torch.from_numpy(stretchedImage).permute(2, 0, 1)
    return stretchedImage

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.has_mps else 'cpu')
device

device(type='cuda')

In [3]:
import os
import pathlib
from enum import Enum

from torch.utils.data import Dataset
from PIL import Image

class DatasetType(Enum):
    Train = 0,
    Test = 1,
    Validation = 2

    def ToString(self) -> str:
        if self == DatasetType.Train:
            return 'train'
        elif self == DatasetType.Test:
            return 'test'
        elif self == DatasetType.Validation:
            return 'val'

class DehazingDataset(Dataset):
    def __init__(self, dehazingDatasetPath: pathlib.Path, _type: DatasetType, transformFn=None, verbose: bool = False):
        self.__DehazingDatasetPath = dehazingDatasetPath
        self.__TransformFn = transformFn

        self.__HazyImages = []
        self.__ClearImages = []

        for variant in ('Haze1k_thin', 'Haze1k_moderate', 'Haze1k_thick'):
            inputPath = self.__DehazingDatasetPath / variant / 'dataset' / _type.ToString() / 'input'
            targetPath = self.__DehazingDatasetPath / variant / 'dataset' / _type.ToString() / 'target'

            self.__HazyImages += [inputPath / filename for filename in sorted(os.listdir(inputPath)) if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp'))]
            self.__ClearImages += [targetPath / filename for filename in sorted(os.listdir(targetPath)) if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp'))]

        # Filtering the mismatching (input, target) image pair
        assert len(self.__HazyImages) == len(self.__ClearImages)
        for hazyPath, clearPath in zip(self.__HazyImages, self.__ClearImages):
            hazyImage = Image.open(hazyPath)
            clearImage = Image.open(clearPath)
            if hazyImage.size != clearImage.size:
                self.__HazyImages.remove(hazyPath)
                self.__ClearImages.remove(clearPath)
            elif verbose:
                print(hazyPath)
                print(clearPath)

        self.__Size = len(self.__HazyImages)

    def __len__(self):
        return self.__Size

    def __getitem__(self, index) -> torch.Tensor:
        hazyImage = None
        clearImage = None
        try:
            hazyImage = torch.Tensor(self.__TransformFn(Image.open(self.__HazyImages[index]).convert('RGB')))
            clearImage = torch.Tensor(self.__TransformFn(Image.open(self.__ClearImages[index]).convert('RGB')))
        except OSError:
            print(f'Error Loading: {self.__HazyImages[index]}')
            print(f'Error Loading: {self.__ClearImages[index]}')

            # Handle the case of empty images, e.g., skip the sample
            # You can also replace the empty images with placeholder images if needed
            # For now, let's just return a placeholder tensor
            placeholder_image = torch.zeros((3, 512, 512), dtype=torch.float32)
            return placeholder_image, placeholder_image

        return hazyImage, clearImage

In [4]:
datasetPath = pathlib.Path('SS594_Multispectral_Dehazing/Haze1k/Haze1k')
trainingDataset = DehazingDataset(dehazingDatasetPath=datasetPath, _type=DatasetType.Train, transformFn=Preprocess, verbose=True)
validationDataset = DehazingDataset(dehazingDatasetPath=datasetPath, _type=DatasetType.Validation, transformFn=Preprocess, verbose=True)

SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/input/1-inputs.png
SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/target/1-targets.png
SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/input/10-inputs.png
SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/target/10-targets.png
SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/input/100-inputs.png
SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/target/100-targets.png
SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/input/101-inputs.png
SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/target/101-targets.png
SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/input/102-inputs.png
SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/target/102-targets.png
SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/input/103-inputs.png
SS594_

In [5]:
import torch.utils.data as tu_data

trainingDataLoader = tu_data.DataLoader(trainingDataset, batch_size=16, shuffle=True)
validationDataLoader = tu_data.DataLoader(validationDataset, batch_size=16, shuffle=True)

print(len(trainingDataset), len(validationDataset))

960 105


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as tn_functional

class AODnet(nn.Module):
    def __init__(self):
        super(AODnet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=5, stride=1, padding=2)
        self.conv4 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=7, stride=1, padding=3)
        self.conv5 = nn.Conv2d(in_channels=12, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.b = 1

    def forward(self, x):
        x1 = tn_functional.relu(self.conv1(x))
        x2 = tn_functional.relu(self.conv2(x1))
        cat1 = torch.cat((x1, x2), 1)
        x3 = tn_functional.relu(self.conv3(cat1))
        cat2 = torch.cat((x2, x3), 1)
        x4 = tn_functional.relu(self.conv4(cat2))
        cat3 = torch.cat((x1, x2, x3, x4), 1)
        k = tn_functional.relu(self.conv5(cat3))

        if k.size() != x.size():
            raise Exception("k, haze image are different size!")

        output = k * x - k + self.b
        return tn_functional.relu(output)

In [7]:
import torch.nn.functional as F

class AOD_pono_net(nn.Module):
    def __init__(self):
        super(AOD_pono_net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=5, stride=1, padding=2)
        self.conv4 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=7, stride=1, padding=3)
        self.conv5 = nn.Conv2d(in_channels=12, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.b = 1

        self.pono = PONO(affine=False)
        self.ms = MS()

    def forward(self, x):
        x1 = F.relu(self.conv1(x))
        x2 = F.relu(self.conv2(x1))
        cat1 = torch.cat((x1, x2), 1)
        x1, mean1, std1 = self.pono(x1)
        x2, mean2, std2 = self.pono(x2)
        x3 = F.relu(self.conv3(cat1))
        cat2 = torch.cat((x2, x3), 1)
        x3 = self.ms(x3, mean1, std1)
        x4 = F.relu(self.conv4(cat2))
        x4 = self.ms(x4, mean2, std2)
        cat3 = torch.cat((x1, x2, x3, x4), 1)
        k = F.relu(self.conv5(cat3))

        if k.size() != x.size():
            raise Exception("k, haze image are different size!")

        output = k * x - k + self.b
        return F.relu(output)

class PONO(nn.Module):
    def __init__(self, input_size=None, return_stats=False, affine=True, eps=1e-5):
        super(PONO, self).__init__()
        self.return_stats = return_stats
        self.input_size = input_size
        self.eps = eps
        self.affine = affine

        if affine:
            self.beta = nn.Parameter(torch.zeros(1, 1, *input_size))
            self.gamma = nn.Parameter(torch.ones(1, 1, *input_size))
        else:
            self.beta, self.gamma = None, None

    def forward(self, x):
        mean = x.mean(dim=1, keepdim=True)
        std = (x.var(dim=1, keepdim=True) + self.eps).sqrt()
        x = (x - mean) / std
        if self.affine:
            x = x * self.gamma + self.beta
        return x, mean, std

class MS(nn.Module):
    def __init__(self, beta=None, gamma=None):
        super(MS, self).__init__()
        self.gamma, self.beta = gamma, beta

    def forward(self, x, beta=None, gamma=None):
        beta = self.beta if beta is None else beta
        gamma = self.gamma if gamma is None else gamma
        if gamma is not None:
            x.mul_(gamma)
        if beta is not None:
            x.add_(beta)
        return x

In [8]:
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

class SSIMLoss(StructuralSimilarityIndexMeasure):
    def forward(self, x, y):
        return 1. - super().forward(x, y)

In [9]:
import torch.optim as optim

model = AODnet().to(device)
print(model)

criterion = nn.MSELoss().to(device=device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

EPOCHS = 10

AODnet(
  (conv1): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (conv2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(6, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv4): Conv2d(6, 3, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
  (conv5): Conv2d(12, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)


In [10]:
!pip install torchmetrics



In [11]:
import torchvision
from torchmetrics.image import StructuralSimilarityIndexMeasure
from torch.optim.lr_scheduler import ReduceLROnPlateau

train_number = len(trainingDataLoader)

ssim = StructuralSimilarityIndexMeasure().to(device)

scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True, min_lr=1e-6)

print("Started Training...")
model.train()
for epoch in range(EPOCHS):
    # -------------------------------------------------------------------
    # start training
    for step, (haze_image, ori_image) in enumerate(trainingDataLoader):
        ori_image, haze_image = ori_image.to(device), haze_image.to(device)
        dehaze_image = model(haze_image)
        loss = criterion(dehaze_image, ori_image)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        print(
            "Epoch: {}/{}  |  Step: {}/{}  |  lr: {:.6f}  | Loss: {:.6f}".format(
                epoch + 1, EPOCHS, step + 1, train_number, optimizer.param_groups[0]["lr"], loss.item()
            )
        )
    # -------------------------------------------------------------------
    # start validation
    print("Epoch: {}/{} | Validation Model Saving Images".format(epoch + 1, EPOCHS))
    model.eval()
    with torch.no_grad():
        for step, (haze_image, ori_image) in enumerate(validationDataLoader):
            ori_image, haze_image = ori_image.to(device), haze_image.to(device)
            dehaze_image = model(haze_image)

            # ssim_val = ssim(dehaze_image, ori_image)
            # ssim_fake_val = ssim(haze_image, ori_image)
            # print(f"SSIM: {ssim_val}, SSIM_Fake: {ssim_fake_val}")
            # perc = (ssim_val - ssim_fake_val) * 100.0 / (1.0 - ssim_fake_val)
            # print(f"Percentage Improvement: {perc} %")

            val_loss = 0
            loss = criterion(dehaze_image, ori_image)
            val_loss += loss.item()
            
            if step > 10:  # only save image 10 times
                break
            
            print(f'Validation Loss: {val_loss}')

            torchvision.utils.save_image(
                torchvision.utils.make_grid(torch.cat((haze_image, dehaze_image, ori_image), 0), nrow=ori_image.shape[0]),
                os.path.join("output", "{}_{}.jpg".format(epoch + 1, step)),
            )

        val_loss /= len(validationDataLoader)

        # Update the learning rate based on validation loss
        scheduler.step(val_loss)

    model.train()

Started Training...




Epoch: 1/10  |  Step: 1/60  |  lr: 0.001000  | Loss: 0.488064
Epoch: 1/10  |  Step: 2/60  |  lr: 0.001000  | Loss: 0.458054
Epoch: 1/10  |  Step: 3/60  |  lr: 0.001000  | Loss: 0.429903
Epoch: 1/10  |  Step: 4/60  |  lr: 0.001000  | Loss: 0.365717
Epoch: 1/10  |  Step: 5/60  |  lr: 0.001000  | Loss: 0.383836
Epoch: 1/10  |  Step: 6/60  |  lr: 0.001000  | Loss: 0.356100
Epoch: 1/10  |  Step: 7/60  |  lr: 0.001000  | Loss: 0.304983
Epoch: 1/10  |  Step: 8/60  |  lr: 0.001000  | Loss: 0.293587
Epoch: 1/10  |  Step: 9/60  |  lr: 0.001000  | Loss: 0.233136
Epoch: 1/10  |  Step: 10/60  |  lr: 0.001000  | Loss: 0.210471
Epoch: 1/10  |  Step: 11/60  |  lr: 0.001000  | Loss: 0.157414
Epoch: 1/10  |  Step: 12/60  |  lr: 0.001000  | Loss: 0.179051
Epoch: 1/10  |  Step: 13/60  |  lr: 0.001000  | Loss: 0.111424
Epoch: 1/10  |  Step: 14/60  |  lr: 0.001000  | Loss: 0.099890
Epoch: 1/10  |  Step: 15/60  |  lr: 0.001000  | Loss: 0.064558
Epoch: 1/10  |  Step: 16/60  |  lr: 0.001000  | Loss: 0.045172
E

KeyboardInterrupt: 

In [12]:
torch.save(model, 'saved_model_MSE_Nvidia_Scheduled_LR.pth')

In [13]:
!tar -zcvf output_MSE_Nvidia_Scheduled_LR.tar.gz output

output/
output/4_0.jpg
output/10_4.jpg
output/6_5.jpg
output/7_0.jpg
output/6_1.jpg
output/5_5.jpg
output/7_2.jpg
output/1_1.jpg
output/6_3.jpg
output/7_3.jpg
output/1_3.jpg
output/8_3.jpg
output/2_0.jpg
output/5_1.jpg
output/1_6.jpg
output/3_0.jpg
output/5_6.jpg
output/4_6.jpg
output/4_2.jpg
output/9_0.jpg
output/10_3.jpg
output/8_2.jpg
output/2_2.jpg
output/.ipynb_checkpoints/
output/.ipynb_checkpoints/9_4-checkpoint.jpg
output/.ipynb_checkpoints/3_4-checkpoint.jpg
output/.ipynb_checkpoints/1_3-checkpoint.jpg
output/.ipynb_checkpoints/10_6-checkpoint.jpg
output/.ipynb_checkpoints/4_6-checkpoint.jpg
output/.ipynb_checkpoints/3_6-checkpoint.jpg
output/.ipynb_checkpoints/1_5-checkpoint.jpg
output/.ipynb_checkpoints/6_6-checkpoint.jpg
output/.ipynb_checkpoints/1_6-checkpoint.jpg
output/1_0.jpg
output/5_4.jpg
output/9_2.jpg
output/3_4.jpg
output/2_3.jpg
output/6_6.jpg
output/7_5.jpg
output/9_4.jpg
output/1_4.jpg
output/1_2.jpg
output/4_4.jpg
output/10_6.jpg
output/2_4.jpg
output/5_2.jpg
o

In [23]:
import IPython
app = IPython.Application.instance()
app.kernel.do_shutdown(True)

{'status': 'ok', 'restart': True}