# Exercise 6

## SETUP

In [None]:
import os
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.nn import MSELoss, L1Loss
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter

import torchvision
from torchvision.io import read_image
from torchvision.transforms import Compose, RandomCrop, ColorJitter, Resize
from torchvision.io import write_png

from pytorch_msssim import SSIM

In [None]:
# Select the compute device: the first CUDA GPU when one is visible,
# otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Report which device the rest of the notebook will run on.
print("Device:", device)

In [None]:
# Define class SRDataset
class SRDataset(Dataset):
    """Paired (low-res, high-res) training patches cut from a folder of images.

    Each item is a random ``crop_size`` x ``crop_size`` crop of one image (the
    high-resolution target) together with a bilinearly downscaled
    ``lr_size`` x ``lr_size`` copy of that crop (the low-resolution input).
    Images are read as 3-channel RGB and scaled from uint8 to floats in [0, 1].

    Args:
        folder_path: directory whose entries are all readable image files.
        augment: if True, apply a random ColorJitter to the crop *before* the
            low-resolution copy is made, so input and target share the same colours.
        crop_size: side length of the high-resolution crop (default 64).
        lr_size: side length of the low-resolution input (default 32).
    """

    def __init__(self, folder_path, augment, crop_size=64, lr_size=32):
        self.folder_path = folder_path
        # Sorted so the index -> file mapping is reproducible across runs/platforms.
        self.image_filenames = sorted(os.listdir(folder_path))
        self.t_crop = Compose([RandomCrop(crop_size)])
        self.t_colorjitter = Compose([ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)])
        self.t_downscale = Compose([Resize((lr_size, lr_size), interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)])
        self.augment = augment

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, index):
        image_path = os.path.join(self.folder_path, self.image_filenames[index])
        # Force RGB: grayscale or RGBA files would otherwise give 1 or 4 channels
        # and break the 3-channel model.
        image = read_image(image_path, mode=torchvision.io.ImageReadMode.RGB)
        hr_image = self.t_crop(image / 255.0)  # Convert to float between 0 and 1
        if self.augment:
            # Jitter the target itself; the LR input is derived from it below, so
            # the (input, target) pair stays photometrically consistent. Jittering
            # only the input would ask the model to undo an unknowable colour shift.
            hr_image = self.t_colorjitter(hr_image)
        lr_image = self.t_downscale(hr_image)

        return lr_image, hr_image


In [None]:
# # Define Model
# class BasicSRModel(nn.Module):
#     def __init__(self, num_inter_blocks):
#         super(BasicSRModel, self).__init__()
        
#         self.conv_blocks = nn.Sequential(nn.ConvTranspose2d(3, 64, kernel_size=4, stride=2, padding=1))
            
#         for i in range(num_inter_blocks):  # Number of intermediate blocks
#             self.conv_blocks.add_module(
#                 f"conv_{i+1}",
#                 nn.Sequential(
#                     nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
#                     nn.LeakyReLU(inplace=True),
#                 )
#             )
        
#         self.conv_blocks.add_module(
#             "last_conv",
#             nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)
#         )
    
#     def forward(self, x):
#         x = self.conv_blocks(x)
#         return x

In [None]:
# Define Model
class BasicSRModel(nn.Module):
    """Upsample-first super-resolution CNN.

    Pipeline: bilinear Upsample(x scale_factor) -> Conv(in_channels -> F)
    -> num_inter_blocks x [Conv(F -> F) + LeakyReLU] -> Conv(F -> in_channels).
    All convolutions are 3x3, stride 1, padding 1, so the output has shape
    (N, in_channels, H * scale_factor, W * scale_factor).

    Submodule names ("0" for the upsampler, "first_conv", "conv_1".."conv_k",
    "last_conv") are kept stable so previously saved state_dicts still load.

    Args:
        num_inter_blocks: number of intermediate Conv+LeakyReLU blocks.
        num_features: hidden channel width F (default 64).
        scale_factor: spatial upsampling factor (default 2).
        in_channels: channels of both the input and the output image (default 3).
    """

    def __init__(self, num_inter_blocks, num_features=64, scale_factor=2, in_channels=3):
        super().__init__()

        self.conv_blocks = nn.Sequential(
            nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)
        )
        self.conv_blocks.add_module(
            "first_conv",
            nn.Conv2d(in_channels, num_features, kernel_size=3, stride=1, padding=1),
        )

        for i in range(num_inter_blocks):  # Number of intermediate blocks
            self.conv_blocks.add_module(
                f"conv_{i + 1}",
                nn.Sequential(
                    nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1),
                    nn.LeakyReLU(inplace=True),
                ),
            )

        self.conv_blocks.add_module(
            "last_conv",
            nn.Conv2d(num_features, in_channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        """Upscale a batch of images ``x`` of shape (N, C, H, W)."""
        return self.conv_blocks(x)
    
# Build the network with 10 intermediate conv blocks.
model = BasicSRModel(10)

# Total number of scalar weights/biases across every layer.
num_params = sum(map(torch.numel, model.parameters()))
print(num_params)


## TRAINING & EVALUATION

In [None]:
# Load and initialize the train_dataset (random crops + colour augmentation).
train_datapath = os.path.join(os.path.abspath(''), 'data', 'train')
train_dataset = SRDataset(train_datapath, augment=True)
train_batch_size = 4
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=0,
    drop_last=True,  # keep every training batch the same size
    # Pinned host memory only speeds up host->GPU copies; on a CPU-only
    # machine PyTorch just warns and ignores it.
    pin_memory=device.type == "cuda",
)

In [None]:
# Load and initialize the test_dataset (no augmentation).
test_datapath = os.path.join(os.path.abspath(''), 'data', 'eval')
test_dataset = SRDataset(test_datapath, augment=False)
test_batch_size = 9
test_dataloader = DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    # Evaluation must see every image, in a fixed order: shuffling together with
    # drop_last would silently skip a different subset of images each epoch.
    shuffle=False,
    num_workers=0,
    drop_last=False,
    pin_memory=device.type == "cuda",  # only useful when copying to a GPU
)

In [None]:
# # Check Dataset initialization
# print(f" * Dataset contains {len(train_dataset)} image(s).")
# for _, batch in enumerate(train_dataloader, 0):
#     lr_image, hr_image = batch
#     write_png(lr_image[0, ...].mul(255).byte(), "image-outputs/lr_image.png")
#     write_png(hr_image[0, ...].mul(255).byte(), "image-outputs/hr_image.png")
#     break # we deliberately break after one batch as this is just a test

In [None]:
# Move the already-constructed model onto the selected device.
model.to(device)

# Adam over every parameter that requires gradients.
learning_rate = 1e-4
trainable_params = (p for p in model.parameters() if p.requires_grad)
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)

# Pixel-wise L1 reconstruction loss (.to() returns the same module).
loss_function = L1Loss().to(device)

# Report the total parameter count of the model.
num_params = sum(param.numel() for param in model.parameters())
print(f"num_params: {num_params}")


In [None]:
# Train for `number_of_epochs` epochs. After each epoch, evaluate on the test set
# and log epoch-averaged metrics to TensorBoard.
writer = SummaryWriter()
number_of_epochs = 500

# Metric modules do not depend on the batch, so build them once, not per batch.
l1_metric = L1Loss().to(device)
mse_metric = MSELoss().to(device)
ssim_metric = SSIM(data_range=1.0).to(device)

for epoch in range(number_of_epochs):
    # TRAIN
    model.train()
    cum_loss = 0.0
    num_train_batches = 0
    with tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{number_of_epochs}', unit='batch') as tqdm_train_dataloader:
        for lr_image, hr_image in tqdm_train_dataloader:
            lr_image, hr_image = lr_image.to(device), hr_image.to(device)
            # reset the gradient
            optimizer.zero_grad()
            # forward pass through the model
            hr_prediction = model(lr_image)
            # compute the loss
            loss = loss_function(hr_prediction, hr_image)
            # backpropagation
            loss.backward()
            # update the model parameters
            optimizer.step()
            cum_loss += loss.item()
            num_train_batches += 1
            tqdm_train_dataloader.set_postfix(loss=cum_loss / num_train_batches)
    # loss.item() is already a per-batch mean, so average over batches,
    # not over the batch size.
    if num_train_batches:
        writer.add_scalar('loss/train', cum_loss / num_train_batches, epoch)

    # EVALUATE
    model.eval()
    cum_l1 = 0.0
    cum_psnr = 0.0
    cum_ssim = 0.0
    num_test_images = 0
    with torch.no_grad():
        for lr_image, hr_image in test_dataloader:
            lr_image, hr_image = lr_image.to(device), hr_image.to(device)
            hr_prediction = model(lr_image)
            batch_size = hr_image.shape[0]
            # Each metric is a batch mean; weight by batch size so a smaller
            # final batch does not count as much as a full one.
            cum_l1 += l1_metric(hr_prediction, hr_image).item() * batch_size
            # PSNR for data in [0, 1]: -10 * log10(MSE)
            cum_psnr += (-10 * torch.log10(mse_metric(hr_prediction, hr_image))).item() * batch_size
            cum_ssim += ssim_metric(hr_prediction, hr_image).item() * batch_size
            num_test_images += batch_size
    # Log test metrics averaged over all evaluated images.
    if num_test_images:
        writer.add_scalar('loss/test-l1', cum_l1 / num_test_images, epoch)
        writer.add_scalar('loss/test-psnr', cum_psnr / num_test_images, epoch)
        writer.add_scalar('loss/test-ssim', cum_ssim / num_test_images, epoch)

# Flush pending events to disk and release the event file.
writer.close()
    

In [None]:
# Save model weights (state_dict only). Create the target folder first so
# torch.save does not fail on a fresh checkout.
os.makedirs("saved-models", exist_ok=True)
torch.save(model.state_dict(), "saved-models/modelv2-500.pt")

In [None]:
# #Load model
# model = BasicSRModel(10)
# model.load_state_dict(torch.load("saved-models/model-upscale-first-500.pt"))
# model.to(device); # Suppress output

In [None]:
# Write one sample (low-res input, high-res target, model prediction) as PNGs
# into image-outputs/.
os.makedirs("image-outputs", exist_ok=True)
model.eval()
with torch.no_grad():
    for lr_image, hr_image in test_dataloader:
        # DataLoader batches are already on the CPU; only the model input moves.
        hr_prediction = model(lr_image.to(device)).cpu()
        write_png(lr_image[0, ...].mul(255).byte(), "image-outputs/lr_image.png")
        write_png(hr_image[0, ...].mul(255).byte(), "image-outputs/hr_image.png")
        # The network output is unbounded. Clamp to [0, 1] so out-of-range pixels
        # saturate instead of wrapping around in the uint8 cast.
        write_png(hr_prediction[0, ...].clamp(0, 1).mul(255).byte(), "image-outputs/hr_prediction.png")
        # Every batch would overwrite the same three files, so one batch is enough.
        break