In [1]:
%pip install kornia
import os
import torch
import yaml
import glob
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image
import numpy as np
from kornia.metrics import psnr, ssim
from torch.utils.data import Dataset, DataLoader

Collecting kornia
  Downloading kornia-0.7.3-py2.py3-none-any.whl (833 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m833.3/833.3 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting kornia-rs>=0.1.0 (from kornia)
  Downloading kornia_rs-0.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.9.1->kornia)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.9.1->kornia)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.9.1->kornia)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.9

In [2]:
def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)

    # for cuda
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

In [3]:
set_seed(0)

In [4]:
def extract_files():
    import google.colab
    import zipfile

    google.colab.drive.mount('/content/drive')
    PROJECT_DIR = "/content/drive/MyDrive/thesis/data/"

    zip_ref = zipfile.ZipFile(PROJECT_DIR + "fiveK.zip", 'r')
    zip_ref.extractall(".")
    zip_ref.close()

In [5]:
if 'google.colab' in str(get_ipython()):
  extract_files()
  config_path = "/content/drive/MyDrive/thesis/config.yaml"
else:
  config_path = "../../config.yaml"

Mounted at /content/drive


In [6]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [7]:
try:
    # Load configuration
    with open(config_path, 'r') as config_file:
        config = yaml.safe_load(config_file)
except:
    raise FileNotFoundError(f"Config file not found at path: {config_path}")

In [8]:
loss_type = config['unetmodel']['loss']
depth = config['unetmodel']['depth']
lambda_ = config['unetmodel']['contrastive_lambda']
base_checkpoint_path = f"{config['paths']['unetcheckpoints']}_five_classes_contrastive_{loss_type}_{depth}_{lambda_}"

In [9]:
def load_best_checkpoint(checkpoint_dir):
    # Check if the directory exists
    if not os.path.exists(base_checkpoint_path):
        print(f"No directory found: {checkpoint_dir}")
        return None
      # Get a list of all checkpoint files in the directory
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, f'unet_*.pth'))

    # sort the checkpoint files according to the epoch number
    checkpoint_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))

    # Check if any checkpoint files are present
    if not checkpoint_files:
        print(f"No checkpoints found in the directory: {checkpoint_dir}")
        return None

    best_val_loss = float('inf')
    epochs = []
    train_losses = []
    val_losses = []
    for checkpoint_file in checkpoint_files:
        checkpoint = torch.load(checkpoint_file, map_location=torch.device(device))
        epochs.append(checkpoint['epoch'])
        train_losses.append(checkpoint['train_loss'])
        val_losses.append(checkpoint['val_loss'])
        if checkpoint['val_loss'] < best_val_loss:
            best_val_loss = checkpoint['val_loss']
            best_checkpoint = checkpoint

    return best_checkpoint, epochs, train_losses, val_losses

In [11]:
checkpoint, epochs, train_losses, val_losses = load_best_checkpoint(base_checkpoint_path)

In [12]:
class ConvBlock(torch.nn.Module):
    def __init__(self, inchannels, outchannels, downscale=False, upscale=False):
        super(ConvBlock, self).__init__()
        self.down = torch.nn.MaxPool2d(2) if downscale else torch.nn.Identity()
        self.conv1 = torch.nn.Conv2d(inchannels, outchannels, 3, padding=1)
        self.bnorm1 = torch.nn.InstanceNorm2d(outchannels)
        self.conv2 = torch.nn.Conv2d(outchannels, outchannels, 3, padding=1)
        self.bnorm2 = torch.nn.InstanceNorm2d(outchannels)
        self.up = torch.nn.Upsample(scale_factor=2) if upscale else torch.nn.Identity()

    def forward(self, x):
        x = self.down(x)
        x = torch.nn.functional.relu(self.bnorm1(self.conv1(x)))
        x = torch.nn.functional.relu(self.bnorm2(self.conv2(x)))
        x = self.up(x)
        return x

In [13]:
class UNet(torch.nn.Module):
    def __init__(self, classes, depth):
        super(UNet, self).__init__()
        self.encoder = torch.nn.ModuleList()
        channels = [3] + [64 * (2 ** i) for i in range(depth)]
        for i in range(depth):
            self.encoder.append(ConvBlock(channels[i], channels[i + 1], downscale=(i > 0)))

        self.embedding = torch.nn.Embedding(classes, channels[-1])
        self.bottleneck = ConvBlock(channels[-1], channels[-1], downscale=True, upscale=True)

        self.decoder = torch.nn.ModuleList()
        self.linear = torch.nn.ModuleList()
        channels[0] = 64
        for i in range(depth - 1, -1, -1):
            self.decoder.append(ConvBlock(2 * channels[i + 1], channels[i], upscale=(i > 0)))
            self.linear.append(torch.nn.Linear(channels[-1], 2 * channels[i] if i > 0 else channels[i], bias=False))

        self.output = torch.nn.Sequential(
            torch.nn.Conv2d(channels[0], 3, 1),
            torch.nn.Sigmoid()
        )

    def forward(self, x, label):
        skip = []
        for mod in self.encoder:
            x = mod(x)
            skip.append(x)
        emb = self.embedding(label)
        x = x + emb.unsqueeze(-1).unsqueeze(-1)
        x = self.bottleneck(x)
        for mod, linear in zip(self.decoder, self.linear):
            y = skip.pop()
            # add embedding with the decoder input
            x = x + linear(emb).unsqueeze(-1).unsqueeze(-1)
            x = torch.cat([x, y], 1)
            x = mod(x)
        x = self.output(x)
        return x

In [14]:
depth = config['unetmodel']['depth']
net = UNet(classes=5, depth=3)
net = net.to(device)

In [15]:
net.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [16]:
data_folder = config['paths']['data']
train_file = config['paths']['train']
test_file = config['paths']['test']

In [17]:
test_tr = transforms.Compose([
        transforms.ToTensor(),
        transforms.CenterCrop(224),
    ])

In [18]:
# List of class directories
class_directories = ['expA','expB', 'expC', 'expD', 'expE']
# raw data directory
raw_dir = "raw"

In [19]:
class FiveK(Dataset):
    def __init__(self, data_dir, raw_data_dir, filename, transform=None):
        super().__init__()
        self.filename = filename
        self.transform = transform

        self.classname = self._extract_class_name(data_dir)
        self.encode = {k: i for i, k in enumerate(class_directories)}


        # Read the train.txt file and store the image paths
        with open(self.filename) as f:
            img_paths= []
            raw_img_paths = []
            for line in f:
                line = line.strip()
                img_paths.append(os.path.join(data_dir, line))
                raw_img_paths.append(os.path.join(raw_data_dir, line))

            self.image_paths = img_paths
            self.raw_image_paths = raw_img_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        raw_image_path = self.raw_image_paths[index]
        image = Image.open(image_path)
        raw_image = Image.open(raw_image_path)
        image = np.dstack((np.array(raw_image), np.array(image)))
        label = self.encode[self.classname]
        if self.transform is not None:
            image = self.transform(image)

        tr_raw_image = image[:3]
        normalize = transforms.Normalize(mean=[0.2279, 0.2017, 0.1825], std=[0.1191, 0.1092, 0.1088])
        tr_raw_image = normalize(tr_raw_image)
        tr_image = image[3:]
        tr_final_image = torch.cat((tr_raw_image, tr_image), 0)
        return tr_final_image, label

    def _extract_class_name(self, root_dir):
        # Extract the class name from the root directory
        class_name = os.path.basename(root_dir)
        return class_name

In [20]:
def read_dataset(data_folder, txt_file, trasform=None):
    # Create separate datasets for each class
    datasets = []

    for class_dir in class_directories:
        class_train_dataset = FiveK(
            data_dir=os.path.join(data_folder, class_dir),
            raw_data_dir=os.path.join(data_folder, raw_dir),
            filename=os.path.join(txt_file),
            transform=trasform
        )
        datasets.append(class_train_dataset)
    return datasets

In [21]:
val_dataset = torch.utils.data.ConcatDataset(read_dataset(data_folder, test_file, test_tr))

In [22]:
bs = 16

In [23]:
val_dataloader = DataLoader(val_dataset, batch_size=bs, shuffle=False)

In [24]:
print(checkpoint['epoch'])

167


In [25]:
psnrs = []
ssims = []
# calculate psnr for the validation dataset
for inputs, labels in val_dataloader:
    raw = inputs[:, :3]
    gt = inputs[:, 3:]
    raw = raw.to(device)
    gt = gt.to(device)
    labels = labels.to(device)
    outputs = net(raw, labels)
    psnr_val = psnr(outputs, gt, 1)
    ssim_val = ssim(outputs, gt, 5).mean()

    ssims.append(ssim_val.item())
    psnrs.append(psnr_val.item())

print(f"Average PSNR: {np.mean(psnrs)}")
print(f"Average SSIM: {np.mean(ssims)}")


Average PSNR: 20.329298348472523
Average SSIM: 0.8565619309870199


In [27]:
print(len(val_dataset))

5000
