In [1]:
import cv2
from PIL.Image import Image

import torch
import torchvision.transforms as transforms

def Preprocess(image: Image) -> torch.Tensor:
    # Contrast Enhancement
    transform = transforms.Compose([
        transforms.PILToTensor(),
        transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05)
        # transforms.functional.equalize
    ])
    transformedImage = transform(image)

    # Gamma Correction
    gammaCorrectedImage = transforms.functional.adjust_gamma(transformedImage, 2.2)

    # Histogram Stretching
    min_val = gammaCorrectedImage.min()
    max_val = gammaCorrectedImage.max()
    stretchedImage = (gammaCorrectedImage - min_val) / (max_val - min_val)

    # for x in stretchedImage:
    #     for y in x:
    #         print(y)

    # Guided Filtering
    gFilter = cv2.ximgproc.createGuidedFilter(guide=stretchedImage.permute(1, 2, 0).numpy(), radius=3, eps=0.01)
    filteredImage = gFilter.filter(src=stretchedImage.permute(1, 2, 0).numpy())
    return torch.from_numpy(filteredImage).permute(2, 0, 1)

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.has_mps else 'cpu')

In [4]:
import os
import pathlib
from enum import Enum

from torch.utils.data import Dataset
from PIL import Image

class DatasetType(Enum):
    Train = 0,
    Test = 1,
    Validation = 2

    def ToString(self) -> str:
        if self == DatasetType.Train:
            return 'train'
        elif self == DatasetType.Test:
            return 'test'
        elif self == DatasetType.Validation:
            return 'val'

class DehazingDataset(Dataset):
    def __init__(self, dehazingDatasetPath: pathlib.Path, _type: DatasetType, transformFn=None, verbose: bool = False):
        self.__DehazingDatasetPath = dehazingDatasetPath
        self.__TransformFn = transformFn

        self.__HazyImages = []
        self.__ClearImages = []

        for variant in ('Haze1k_thin', 'Haze1k_moderate', 'Haze1k_thick'):
            inputPath = self.__DehazingDatasetPath / variant / 'dataset' / _type.ToString() / 'input'
            targetPath = self.__DehazingDatasetPath / variant / 'dataset' / _type.ToString() / 'target'

            self.__HazyImages += [inputPath / filename for filename in sorted(os.listdir(inputPath)) if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp'))]
            self.__ClearImages += [targetPath / filename for filename in sorted(os.listdir(targetPath)) if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp'))]

        # Filtering the mismatching (input, target) image pair
        assert len(self.__HazyImages) == len(self.__ClearImages)
        for hazyPath, clearPath in zip(self.__HazyImages, self.__ClearImages):
            hazyImage = Image.open(hazyPath)
            clearImage = Image.open(clearPath)
            if hazyImage.size != clearImage.size:
                self.__HazyImages.remove(hazyPath)
                self.__ClearImages.remove(clearPath)
            elif verbose:
                print(hazyPath)
                print(clearPath)

        self.__Size = len(self.__HazyImages)

    def __len__(self):
        return self.__Size

    def __getitem__(self, index) -> torch.Tensor:
        hazyImage = None
        clearImage = None
        try:
            hazyImage = torch.Tensor(self.__TransformFn(Image.open(self.__HazyImages[index]).convert('RGB')))
            clearImage = torch.Tensor(self.__TransformFn(Image.open(self.__ClearImages[index]).convert('RGB')))
        except OSError:
            print(f'Error Loading: {self.__HazyImages[index]}')
            print(f'Error Loading: {self.__ClearImages[index]}')

            # Handle the case of empty images, e.g., skip the sample
            # You can also replace the empty images with placeholder images if needed
            # For now, let's just return a placeholder tensor
            placeholder_image = torch.zeros((3, 512, 512), dtype=torch.float32)
            return placeholder_image, placeholder_image

        return hazyImage, clearImage

In [6]:
datasetPath = pathlib.Path('../../dataset/SS594_Multispectral_Dehazing/Haze1k/Haze1k')
trainingDataset = DehazingDataset(dehazingDatasetPath=datasetPath, _type=DatasetType.Train, transformFn=Preprocess, verbose=True)
validationDataset = DehazingDataset(dehazingDatasetPath=datasetPath, _type=DatasetType.Validation, transformFn=Preprocess, verbose=True)

../../dataset/SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/input/1-inputs.png
../../dataset/SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/target/1-targets.png
../../dataset/SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/input/10-inputs.png
../../dataset/SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/target/10-targets.png
../../dataset/SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/input/100-inputs.png
../../dataset/SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/target/100-targets.png
../../dataset/SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/input/101-inputs.png
../../dataset/SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/target/101-targets.png
../../dataset/SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1k_thin/dataset/train/input/102-inputs.png
../../dataset/SS594_Multispectral_Dehazing/Haze1k/Haze1k/Haze1

In [9]:
import torch.utils.data as tu_data

trainingDataLoader = tu_data.DataLoader(trainingDataset, batch_size=32, shuffle=True, num_workers=3)
validationDataLoader = tu_data.DataLoader(validationDataset, batch_size=32, shuffle=True, num_workers=3)

print(len(trainingDataset), len(validationDataset))

960 105


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as tn_functional

class AODnet(nn.Module):
    def __init__(self):
        super(AODnet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=5, stride=1, padding=2)
        self.conv4 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=7, stride=1, padding=3)
        self.conv5 = nn.Conv2d(in_channels=12, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.b = 1

    def forward(self, x):
        x1 = tn_functional.relu(self.conv1(x))
        x2 = tn_functional.relu(self.conv2(x1))
        cat1 = torch.cat((x1, x2), 1)
        x3 = tn_functional.relu(self.conv3(cat1))
        cat2 = torch.cat((x2, x3), 1)
        x4 = tn_functional.relu(self.conv4(cat2))
        cat3 = torch.cat((x1, x2, x3, x4), 1)
        k = tn_functional.relu(self.conv5(cat3))

        if k.size() != x.size():
            raise Exception("k, haze image are different size!")

        output = k * x - k + self.b
        return tn_functional.relu(output)

In [11]:
import torch.optim as optim

model = AODnet().to(device)
print(model)

criterion = nn.MSELoss().to(device=device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

EPOCHS = 10

AODnet(
  (conv1): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (conv2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(6, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv4): Conv2d(6, 3, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
  (conv5): Conv2d(12, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)


In [14]:
import torchvision
from torchmetrics.image import StructuralSimilarityIndexMeasure

train_number = len(trainingDataLoader)

print("Started Training...")
model.train()
for epoch in range(EPOCHS):
    # -------------------------------------------------------------------
    # start training
    for step, (haze_image, ori_image) in enumerate(trainingDataLoader):
        ori_image, haze_image = ori_image.to(device), haze_image.to(device)
        dehaze_image = model(haze_image)
        loss = criterion(dehaze_image, ori_image)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        print(
            "Epoch: {}/{}  |  Step: {}/{}  |  lr: {:.6f}  | Loss: {:.6f}".format(
                epoch + 1, EPOCHS, step + 1, train_number, optimizer.param_groups[0]["lr"], loss.item()
            )
        )
    # -------------------------------------------------------------------
    # start validation
    print("Epoch: {}/{} | Validation Model Saving Images".format(epoch + 1, EPOCHS))
    model.eval()
    for step, (haze_image, ori_image) in enumerate(validationDataLoader):
        if step > 10:  # only save image 10 times
            break
        ori_image, haze_image = ori_image.to(device), haze_image.to(device)
        dehaze_image = model(haze_image)

        ssim = StructuralSimilarityIndexMeasure().to(device)
        ssim_val = ssim(dehaze_image, ori_image)
        ssim_fake_val = ssim(haze_image, ori_image)
        print(f"SSIM: {ssim_val}, SSIM_Fake: {ssim_fake_val}")
        perc = (ssim_val - ssim_fake_val) * 100.0 / (1.0 - ssim_fake_val)
        print(f"Percentage Improvement: {perc} %")

        torchvision.utils.save_image(
            torchvision.utils.make_grid(torch.cat((haze_image, dehaze_image, ori_image), 0), nrow=ori_image.shape[0]),
            os.path.join("output", "{}_{}.jpg".format(epoch + 1, step)),
        )

    model.train()

Started Training...


Traceback (most recent call last):
  File "<string>", line 1, in <module>
Traceback (most recent call last):
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "<string>", line 1, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    exitcode = _main(fd, parent_sentinel)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    exitcode = _main(fd, parent_sentinel)
  File "/Library/Frameworks/Python.fr

RuntimeError: DataLoader worker (pid(s) 89147) exited unexpectedly