In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import tqdm
from PIL import Image
from google.colab import drive
from pathlib import Path

# IPython/Colab line magic: notebook-only syntax, not valid in a plain .py file.
# NOTE(review): this extension presumably only changes how tabular data is
# displayed; nothing visible in this cell appears to use it — confirm before removing.
%load_ext google.colab.data_table

# configuration setup
class Config():
    """Training configuration: paths, device, hyper-parameters and transforms.

    Every setting is a class attribute, so ``Config()`` and ``Config`` expose
    the same values.
    """
    # Folder on the mounted Google Drive that holds the datasets and models.
    SRCNN_path = 'SRCNN/'
    content_path = Path(f'/content/drive/MyDrive/{SRCNN_path}')
    data_path = './data/'

    # Use CUDA when requested and available, otherwise fall back to the CPU.
    GPU = True
    device = torch.device("cuda" if torch.cuda.is_available() and GPU else "cpu")

    batch_size = 16
    num_epochs = 100

    # Side length (pixels) of the square high-resolution crops, and the
    # down-scaling factor used to synthesise the low-resolution inputs.
    image_size = 224
    scale_factor = 3

    # Checkpoint to resume from. Change this path to resume a different
    # checkpoint; set it to None (not the string "None") to train from scratch.
    checkpoint_path = content_path / 'Models/SRCNN_checkpoint.pth'

    # Transforms the high-resolution images: resize the shorter side,
    # center-crop to a square, convert to a tensor in [0, 1], then
    # normalise each channel to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    @classmethod
    def low_res_transform(cls):
        """Return the transform that degrades an HR tensor into its LR input.

        The image is Gaussian-blurred, down-scaled by ``scale_factor`` and
        bicubically up-scaled back to ``image_size``, so the LR input and the
        HR target have the same spatial size.
        """
        return transforms.Compose([
            transforms.GaussianBlur(kernel_size=(5, 5), sigma=(1.5, 1.5)),
            transforms.Resize(cls.image_size // cls.scale_factor,
                              interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.Resize(cls.image_size,
                              interpolation=transforms.InterpolationMode.BICUBIC),
        ])

    # DIV2K high-resolution image folders on Drive.
    hr_train_dir = content_path / 'DIV2K_train_HR'
    hr_val_dir = content_path / 'DIV2K_valid_HR'

# Environment setup: mount Google Drive, create output folders and seed RNGs
class Setup():
    """Prepare the Colab runtime: mount Drive, create folders, seed RNGs.

    All work happens in ``__init__``; the instance only keeps ``config``.
    """

    def __init__(self, config):
        self.config = config
        self.mount()
        self.make_dir()
        self.seed()

    def mount(self):
        """Mount Google Drive at ``/content/drive/`` (Colab only)."""
        drive.mount('/content/drive/')

    def make_dir(self):
        """Create ``<content_path>/Models`` and the local data folder.

        ``exist_ok=True`` makes this idempotent and avoids the race of
        checking for existence before creating.
        """
        (Path(self.config.content_path) / 'Models').mkdir(parents=True, exist_ok=True)
        Path(self.config.data_path).mkdir(parents=True, exist_ok=True)

    def seed(self, seed=0):
        """Seed torch and NumPy so runs are reproducible on CPU and GPU.

        Args:
            seed: value passed to ``torch.manual_seed`` and ``np.random.seed``.
        """
        # torch.manual_seed also seeds every CUDA device, so no CUDA check is needed.
        torch.manual_seed(seed)
        np.random.seed(seed)
        # Force deterministic cuDNN kernels and disable autotuning, which may
        # pick different algorithms from run to run. Harmless without CUDA.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Setup dataset
class SRDataset(Dataset):
    """(low-res, high-res) image pairs built from one folder of HR images.

    Each item loads an HR image as RGB and applies ``transform`` to it. The LR
    input is then synthesised by applying ``low_res_transform`` to the
    transformed HR image; without it, the HR image is returned for both.

    Args:
        hr_image_dir: folder containing the high-resolution image files.
        transform: optional transform applied to each loaded HR image.
        low_res_transform: optional transform mapping the HR image to the LR input.
    """

    def __init__(self, hr_image_dir, transform=None, low_res_transform=None):
        self.hr_image_dir = hr_image_dir
        # Sort so the index -> file mapping is stable (os.listdir order is
        # arbitrary), and keep only regular files so sub-directories are never
        # handed to Image.open.
        self.hr_image_filenames = sorted(
            name for name in os.listdir(hr_image_dir)
            if os.path.isfile(os.path.join(hr_image_dir, name))
        )
        self.transform = transform
        self.low_res_transform = low_res_transform

    def __len__(self):
        """Number of HR image files found in ``hr_image_dir``."""
        return len(self.hr_image_filenames)

    def __getitem__(self, idx):
        """Return ``(lr_image, hr_image)`` for the ``idx``-th file."""
        hr_image_path = os.path.join(self.hr_image_dir, self.hr_image_filenames[idx])
        # convert() returns an independent copy, so the file can be closed here.
        with Image.open(hr_image_path) as img:
            hr_image = img.convert('RGB')

        if self.transform:
            hr_image = self.transform(hr_image)

        lr_image = self.low_res_transform(hr_image) if self.low_res_transform else hr_image

        return lr_image, hr_image

# SRCNN model
class SRCNN(nn.Module):
    """Three-stage SRCNN: patch extraction, non-linear mapping, reconstruction.

    Every convolution is stride 1 with padding of half its kernel size, so the
    output has the same height and width as the input.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 9x9 convolution, 3 -> 64 channels, then ReLU.
        self.patch_extraction = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.ReLU(),
        )
        # Stage 2: 5x5 convolution, 64 -> 32 channels, then ReLU.
        self.non_linear_mapping = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.ReLU(),
        )
        # Stage 3: 5x5 convolution, 32 -> 3 channels, linear output.
        self.reconstruction = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=5, padding=2),
        )
        self.initialise_weights()

    def initialise_weights(self):
        """Draw each conv weight from N(0, 0.001^2) and zero each conv bias."""
        conv_layers = [layer for layer in self.modules() if isinstance(layer, nn.Conv2d)]
        for conv in conv_layers:
            nn.init.normal_(conv.weight, mean=0.0, std=0.001)
            nn.init.constant_(conv.bias, 0)

    def forward(self, x):
        """Apply the three stages in order and return the reconstructed image."""
        for stage in (self.patch_extraction, self.non_linear_mapping, self.reconstruction):
            x = stage(x)
        return x

# SRCNN Training process
class SRCNNTrainer():
    """Train and validate an SRCNN model, with checkpointing under ``content_path``.

    Per-epoch mean MSE loss and PSNR are kept in ``self.loss`` and
    ``self.psnr`` under the keys ``'Train'`` and ``'Validate'``. For a run
    started from scratch, list index ``i`` holds the value for epoch ``i``.
    """

    def __init__(self, config):
        self.config = config
        self.device = config.device
        self.batch_size = config.batch_size
        self.num_epochs = config.num_epochs
        self.transform = config.transform
        self.low_res_transform = config.low_res_transform()
        self.psnr = {'Train': [], 'Validate': []}
        self.loss = {'Train': [], 'Validate': []}
        self.model = SRCNN().to(self.device)
        self.optimizer, self.scheduler = self.initialise_optimizer()
        self.train_dataloader, self.val_dataloader = self.initialise_dataset()

    def initialise_optimizer(self):
        """Build Adam with per-stage learning rates and a StepLR schedule.

        The reconstruction stage uses a 10x smaller learning rate than the
        first two stages. StepLR multiplies every rate by 0.1 every 50
        ``scheduler.step()`` calls; ``training`` steps it once per epoch.
        """
        params_to_optimize = [
            {"params": self.model.patch_extraction.parameters(), "lr": 1e-4, "weight_decay": 1e-6},
            {"params": self.model.non_linear_mapping.parameters(), "lr": 1e-4, "weight_decay": 1e-6},
            {"params": self.model.reconstruction.parameters(), "lr": 1e-5, "weight_decay": 1e-6},
        ]
        optimizer = torch.optim.Adam(params_to_optimize)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
        return optimizer, scheduler

    def initialise_dataset(self):
        """Return ``(train_dataloader, val_dataloader)``; only training data is shuffled."""
        train_dataset = SRDataset(self.config.hr_train_dir, transform=self.transform, low_res_transform=self.low_res_transform)
        train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_dataset = SRDataset(self.config.hr_val_dir, transform=self.transform, low_res_transform=self.low_res_transform)
        val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
        return train_dataloader, val_dataloader

    def model_output(self):
        """Print the number of trainable parameters and the model structure."""
        params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print("Total number of parameters is: {}".format(params))
        print(self.model)

    def loss_MSE(self, input, target, max_pixel=2.0):
        """Return ``(mse_loss, psnr)`` for a batch.

        Args:
            input: reconstructed images.
            target: ground-truth images, same shape as ``input``.
            max_pixel: peak-to-peak range of the pixel values, used for PSNR.
                ``Config.transform`` normalises images to [-1, 1], so the range
                is 2.0. Using 1.0 would understate PSNR by about 6 dB.

        Returns:
            mse_loss: scalar mean-squared error, still attached to the autograd graph.
            psnr: scalar PSNR in dB, computed from a detached copy of the MSE
                (``inf`` when the MSE is zero).
        """
        mse_loss = nn.functional.mse_loss(input, target)
        psnr = 10 * torch.log10(max_pixel ** 2 / mse_loss.detach())
        return mse_loss, psnr

    def save_checkpoint(self, epoch):
        """Write model/optimiser/scheduler state and metric history for ``epoch``."""
        checkpoint_path = self.config.content_path / f'Models/SRCNN_checkpoint_epoch_{epoch}.pth'
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'loss': self.loss,
            'psnr': self.psnr,
        }, checkpoint_path)

    def load_checkpoint(self, checkpoint_path):
        """Restore training state from a file written by ``save_checkpoint``.

        Args:
            checkpoint_path: path (str or Path) to the checkpoint, or ``None``
                to start from scratch.

        Returns:
            The epoch to resume from: the saved epoch + 1, or 0 when
            ``checkpoint_path`` is ``None`` or the file does not exist.
        """
        if checkpoint_path is None:
            return 0
        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            print(f"No checkpoint found at {checkpoint_path}; training from scratch.")
            return 0
        # map_location lets a checkpoint saved on GPU be resumed on CPU.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.loss = checkpoint['loss']
        self.psnr = checkpoint['psnr']
        return checkpoint['epoch'] + 1

    def save_metrics(self):
        """Write losses/PSNRs for every 20th epoch and the last epoch to metrics.txt."""
        metrics_path = self.config.content_path / 'Models/metrics.txt'
        # Only report epochs for which every metric was recorded.
        n_epochs = min(len(values) for history in (self.loss, self.psnr)
                       for values in history.values())
        with open(metrics_path, 'w') as f:
            for epoch in range(n_epochs):
                if epoch % 20 == 0 or epoch == n_epochs - 1:
                    f.write(f"Epoch {epoch}:\n")
                    f.write(f"Train Loss: {self.loss['Train'][epoch]}\n")
                    f.write(f"Train PSNR: {self.psnr['Train'][epoch]}\n")
                    f.write(f"Validate Loss: {self.loss['Validate'][epoch]}\n")
                    f.write(f"Validate PSNR: {self.psnr['Validate'][epoch]}\n")

    def _run_epoch(self, dataloader, train, desc):
        """Run one pass over ``dataloader``; return mean (loss, psnr) over batches.

        When ``train`` is true, gradients are computed and the optimiser is
        stepped after every batch. Otherwise the pass runs under
        ``torch.no_grad()``.

        Raises:
            ValueError: if ``dataloader`` yields no batches.
        """
        n_batches = len(dataloader)
        if n_batches == 0:
            raise ValueError(f"{desc}: dataloader is empty")
        total_loss = 0.0
        total_psnr = 0.0
        grad_ctx = torch.enable_grad() if train else torch.no_grad()
        with grad_ctx, tqdm.tqdm(dataloader, unit="batch") as tepoch:
            tepoch.set_description(desc)
            for batch_idx, (low, high) in enumerate(tepoch):
                low_res = low.to(self.device)
                high_res = high.to(self.device)
                if train:
                    self.optimizer.zero_grad()
                reconstructed_images = self.model(low_res)
                loss, psnr = self.loss_MSE(reconstructed_images, high_res)
                if train:
                    loss.backward()
                    self.optimizer.step()
                total_loss += loss.item()
                total_psnr += psnr.item()
                if batch_idx % 20 == 0:
                    # MSE is already averaged over the batch; no extra division.
                    tepoch.set_postfix(loss=loss.item(), psnr=psnr.item())
        return total_loss / n_batches, total_psnr / n_batches

    def _truncate_metrics(self, n_epochs):
        """Keep only the first ``n_epochs`` entries of every metric list."""
        for history in (self.loss, self.psnr):
            for values in history.values():
                del values[n_epochs:]

    def training(self, checkpoint_path=None):
        """Train for ``num_epochs`` epochs, optionally resuming from a checkpoint.

        Args:
            checkpoint_path: checkpoint to resume from. ``None`` falls back to
                ``config.checkpoint_path``, which may itself be ``None`` to
                start from scratch.

        A checkpoint is written every 20 epochs and after the final epoch.
        On any exception or Ctrl+C, metrics of the unfinished epoch are
        discarded so the Train/Validate histories stay aligned. If at least one
        epoch finished in this run, a checkpoint for the last finished epoch is
        written, so resuming redoes the interrupted epoch. Metrics are saved
        in every case.
        """
        if checkpoint_path is None:
            checkpoint_path = getattr(self.config, 'checkpoint_path', None)
        start_epoch = self.load_checkpoint(checkpoint_path)
        end_epoch = start_epoch + self.num_epochs
        last_completed = start_epoch - 1
        try:
            for epoch in range(start_epoch, end_epoch):
                # validate() switches to eval mode; make sure each epoch trains.
                self.model.train()
                train_loss, train_psnr = self._run_epoch(self.train_dataloader, True, f"Epoch {epoch}")
                self.loss['Train'].append(train_loss)
                self.psnr['Train'].append(train_psnr)
                self.validate()
                self.scheduler.step()
                last_completed = epoch
                if epoch % 20 == 0 or epoch == end_epoch - 1:
                    self.save_checkpoint(epoch)
        except (Exception, KeyboardInterrupt) as e:
            import traceback
            traceback.print_exc()
            print(f"Training interrupted; last completed epoch: {last_completed}: {e!r}")
            self._truncate_metrics(last_completed + 1)
            if last_completed >= start_epoch:
                self.save_checkpoint(last_completed)
        self.save_metrics()

    def validate(self):
        """Evaluate on the validation set and append the mean loss and PSNR.

        The model is put in eval mode for the pass and afterwards restored to
        the mode it was in before the call.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            valid_loss, valid_psnr = self._run_epoch(self.val_dataloader, False, "Test")
        finally:
            self.model.train(was_training)
        self.psnr['Validate'].append(valid_psnr)
        self.loss['Validate'].append(valid_loss)

def main():
    """Entry point: prepare the runtime, build the trainer and run training."""
    config = Config()
    # Constructed for its side effects: mounts Drive, creates folders, seeds RNGs.
    Setup(config)
    trainer = SRCNNTrainer(config)
    # training() requires the checkpoint path; pass the configured one.
    trainer.training(config.checkpoint_path)


if __name__ == "__main__":
    main()