In [1]:
!wget https://vedas.sac.gov.in/static/pdf/SIH_2022/SS594_Multispectral_Dehazing.zip
!unzip SS594_Multispectral_Dehazing.zip
!mv SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/valid SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/val
!mv SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thick/dataset/valid SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thick/dataset/val

--2023-12-22 07:37:16--  https://vedas.sac.gov.in/static/pdf/SIH_2022/SS594_Multispectral_Dehazing.zip
Resolving vedas.sac.gov.in (vedas.sac.gov.in)... 103.99.192.69, 2001:df0:4840::69
Connecting to vedas.sac.gov.in (vedas.sac.gov.in)|103.99.192.69|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1481111113 (1.4G) [application/x-zip-compressed]
Saving to: ‘SS594_Multispectral_Dehazing.zip.1’

.zip.1                3%[                    ]  44.41M  3.35MB/s    eta 10m 54s^C
Archive:  SS594_Multispectral_Dehazing.zip
replace SS594_Multispectral_Dehazing/GT/01_GT.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: ^C
mv: cannot stat 'SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/valid': No such file or directory
mv: cannot stat 'SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thick/dataset/valid': No such file or directory


In [2]:
%pip install torch

^C
[31mERROR: Operation cancelled by user[0m
Note: you may need to restart the kernel to use updated packages.


In [3]:
%pip install torchvision

UnboundLocalError: local variable 'child' referenced before assignment

In [1]:
from PIL.Image import Image

import torch
import torchvision.transforms as transforms

def Preprocess(image: Image) -> torch.Tensor:
    # Contrast Enhancement
    transform = transforms.Compose([
        transforms.PILToTensor(),
        transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05)
        # transforms.functional.equalize
    ])
    transformedImage = transform(image)

    # Gamma Correction
    gammaCorrectedImage = transforms.functional.adjust_gamma(transformedImage, 2.2)

    # Histogram Stretching
    min_val = gammaCorrectedImage.min()
    max_val = gammaCorrectedImage.max()
    stretchedImage = (gammaCorrectedImage - min_val) / (max_val - min_val)
    return stretchedImage

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.has_mps else 'cpu')
device

device(type='cuda')

In [3]:
import os
import pathlib
from enum import Enum

from torch.utils.data import Dataset
from PIL import Image

class DatasetType(Enum):
    Train = 0,
    Test = 1,
    Validation = 2

    def ToString(self) -> str:
        if self == DatasetType.Train:
            return 'train'
        elif self == DatasetType.Test:
            return 'test'
        elif self == DatasetType.Validation:
            return 'val'

class DehazingDataset(Dataset):
    def __init__(self, dehazingDatasetPath: pathlib.Path, _type: DatasetType, transformFn=None, verbose: bool = False):
        self.__DehazingDatasetPath = dehazingDatasetPath
        self.__TransformFn = transformFn

        self.__HazyImages = []
        self.__ClearImages = []

        for variant in ('Haze1k_thin', 'Haze1k_moderate', 'Haze1k_thick'):
            inputPath = self.__DehazingDatasetPath / variant / 'dataset' / _type.ToString() / 'input'
            targetPath = self.__DehazingDatasetPath / variant / 'dataset' / _type.ToString() / 'target'

            self.__HazyImages += [inputPath / filename for filename in sorted(os.listdir(inputPath)) if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp'))]
            self.__ClearImages += [targetPath / filename for filename in sorted(os.listdir(targetPath)) if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp'))]

        # Filtering the mismatching (input, target) image pair
        assert len(self.__HazyImages) == len(self.__ClearImages)
        for hazyPath, clearPath in zip(self.__HazyImages, self.__ClearImages):
            hazyImage = Image.open(hazyPath)
            clearImage = Image.open(clearPath)
            if hazyImage.size != clearImage.size:
                self.__HazyImages.remove(hazyPath)
                self.__ClearImages.remove(clearPath)
            elif verbose:
                print(hazyPath)
                print(clearPath)

        self.__Size = len(self.__HazyImages)

    def __len__(self):
        return self.__Size

    def __getitem__(self, index) -> torch.Tensor:
        hazyImage = None
        clearImage = None
        try:
            hazyImage = torch.Tensor(self.__TransformFn(Image.open(self.__HazyImages[index]).convert('RGB')))
            clearImage = torch.Tensor(self.__TransformFn(Image.open(self.__ClearImages[index]).convert('RGB')))
        except OSError:
            print(f'Error Loading: {self.__HazyImages[index]}')
            print(f'Error Loading: {self.__ClearImages[index]}')

            # Handle the case of empty images, e.g., skip the sample
            # You can also replace the empty images with placeholder images if needed
            # For now, let's just return a placeholder tensor
            placeholder_image = torch.zeros((3, 512, 512), dtype=torch.float32)
            return placeholder_image, placeholder_image

        return hazyImage, clearImage

In [4]:
datasetPath = pathlib.Path('SS594_Multispectral_Dehazing/Haze1k/Haze1k')
trainingDataset = DehazingDataset(dehazingDatasetPath=datasetPath, _type=DatasetType.Train, transformFn=Preprocess, verbose=False)
validationDataset = DehazingDataset(dehazingDatasetPath=datasetPath, _type=DatasetType.Validation, transformFn=Preprocess, verbose=False)

In [16]:
import torch.utils.data as tu_data

BATCH_SIZE = 64

trainingDataLoader = tu_data.DataLoader(trainingDataset, batch_size=BATCH_SIZE, shuffle=True)
validationDataLoader = tu_data.DataLoader(validationDataset, batch_size=BATCH_SIZE, shuffle=True)

print(len(trainingDataset), len(validationDataset))

960 105


In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as tn_functional

class AODnet(nn.Module):
    def __init__(self):
        super(AODnet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1, stride=1, padding=0)
        
        # Initialize the weights of the convolutional layer with a Gaussian distribution
        nn.init.normal_(self.conv1.weight, mean=0.0, std=0.01)
        
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=5, stride=1, padding=2)
        self.conv4 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=7, stride=1, padding=3)
        self.conv5 = nn.Conv2d(in_channels=12, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.b = 1

    def forward(self, x):
        x1 = tn_functional.relu(self.conv1(x))
        x2 = tn_functional.relu(self.conv2(x1))
        cat1 = torch.cat((x1, x2), 1)
        x3 = tn_functional.relu(self.conv3(cat1))
        cat2 = torch.cat((x2, x3), 1)
        x4 = tn_functional.relu(self.conv4(cat2))
        cat3 = torch.cat((x1, x2, x3, x4), 1)
        k = tn_functional.relu(self.conv5(cat3))

        if k.size() != x.size():
            raise Exception("k, haze image are different size!")

        output = k * x - k + self.b
        return tn_functional.relu(output)

In [21]:
LEARNING_RATE = 1e-2
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0001
EPOCHS = 10
GRADIENT_CLIP_VALUE = 0.5
STEPS = len(trainingDataLoader)

In [8]:
lre = torch.linspace(-2, -6, EPOCHS * STEPS)
lrs = 10 ** lre

In [22]:
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

model = AODnet().to(device)
print(model)

criterion = nn.MSELoss().to(device=device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# def lambdaLR(epoch):
#     return lrs[epoch - 1]
# scheduler = LambdaLR(optimizer, lr_lambda=lambdaLR)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=3, verbose=True, min_lr=1e-6)

AODnet(
  (conv1): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (conv2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(6, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv4): Conv2d(6, 3, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
  (conv5): Conv2d(12, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)


In [12]:
!pip install torchmetrics



In [None]:
import torchvision
from torchmetrics.image import StructuralSimilarityIndexMeasure

train_number = len(trainingDataLoader)
os.makedirs("output", exist_ok=True)

lre

print("Started Training...")
model.train()
for epoch in range(EPOCHS):
    # -------------------------------------------------------------------
    # start training
    for step, (haze_image, ori_image) in enumerate(trainingDataLoader):
        ori_image, haze_image = ori_image.to(device), haze_image.to(device)
        # Forward Pass
        dehaze_image = model(haze_image)
        # Loss Calculation
        loss = criterion(dehaze_image, ori_image)
        # Setting the gradients to zero to avoid accumulation across steps
        optimizer.zero_grad()
        # Backward Propagation
        loss.backward()
        # Setting the clipping value of the gradients
        torch.nn.utils.clip_grad_value_(model.parameters(), GRADIENT_CLIP_VALUE)
        # Updating the gradients
        optimizer.step()
        scheduler.step(loss)

        print(
            "Epoch: {}/{}  |  Step: {}/{}  |  lr: {:.6f}  | Loss: {:.6f}".format(
                epoch + 1, EPOCHS, step + 1, train_number, optimizer.param_groups[0]["lr"], loss.item()
            )
        )
    # -------------------------------------------------------------------
    # start validation
    print("Epoch: {}/{} | Validation Model Saving Images".format(epoch + 1, EPOCHS))
    model.eval()
    for step, (haze_image, ori_image) in enumerate(validationDataLoader):
        if step > 10:  # only save image 10 times
            break
        ori_image, haze_image = ori_image.to(device), haze_image.to(device)
        dehaze_image = model(haze_image)

        # ssim = StructuralSimilarityIndexMeasure().to(device)
        # ssim_val = ssim(dehaze_image, ori_image)
        # ssim_fake_val = ssim(haze_image, ori_image)
        # print(f"SSIM: {ssim_val}, SSIM_Fake: {ssim_fake_val}")
        # perc = (ssim_val - ssim_fake_val) * 100.0 / (1.0 - ssim_fake_val)
        # print(f"Percentage Improvement: {perc} %")

        torchvision.utils.save_image(
            torchvision.utils.make_grid(torch.cat((haze_image, dehaze_image, ori_image), 0), nrow=ori_image.shape[0]),
            os.path.join("output", "{}_{}.jpg".format(epoch + 1, step)),
        )

    model.train()

Started Training...
Epoch: 1/10  |  Step: 1/15  |  lr: 0.010000  | Loss: 0.467733
Epoch: 1/10  |  Step: 2/15  |  lr: 0.010000  | Loss: 0.284729
Epoch: 1/10  |  Step: 3/15  |  lr: 0.010000  | Loss: 0.115446
Epoch: 1/10  |  Step: 4/15  |  lr: 0.010000  | Loss: 0.054747
Epoch: 1/10  |  Step: 5/15  |  lr: 0.010000  | Loss: 0.073421
Error Loading: SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_moderate/dataset/train/input/265.png
Error Loading: SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_moderate/dataset/train/target/265.png
Epoch: 1/10  |  Step: 6/15  |  lr: 0.010000  | Loss: 0.084567
Epoch: 1/10  |  Step: 7/15  |  lr: 0.010000  | Loss: 0.089532
Epoch     8: reducing learning rate of group 0 to 8.0000e-03.
Epoch: 1/10  |  Step: 8/15  |  lr: 0.008000  | Loss: 0.096852
Epoch: 1/10  |  Step: 9/15  |  lr: 0.008000  | Loss: 0.105473
Epoch: 1/10  |  Step: 10/15  |  lr: 0.008000  | Loss: 0.103228
Epoch: 1/10  |  Step: 11/15  |  lr: 0.008000  | Loss: 0.097101
Error Loading: SS594_Multispe

In [None]:
torch.save(model, 'saved_model.pth')

In [None]:
!tar -zcvf output.tar.gz output