Candidate datasets: [Papers with Code – image super-resolution datasets](https://paperswithcode.com/datasets?q=&v=lst&o=match&mod=images&task=image-super-resolution)

## Sources
- https://paperswithcode.com/dataset/div2k
- https://github.com/sgrvinod/Deep-Tutorials-for-PyTorch
- https://jonathan-hui.medium.com/gan-super-resolution-gan-srgan-b471da7270ec
- https://github.com/labmlai/annotated_deep_learning_paper_implementations
- https://gitlab.fit.cvut.cz/dufekja4/bi-ml2-2023-dufekja4/-/blob/hw02/02/homework_02_B222.ipynb?ref_type=heads

## Concepts
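- GAN – a generator network performs super-resolution while a second (discriminator) network rates whether an image is real or generated; the two are trained against each other in a zero-sum game
- residual (skip) connections, which make deep networks trainable
- sub-pixel convolution (pixel shuffle) for learned upsampling
- pretrain the SR model on its own first, then use the pretrained model as the starting point for the GAN generator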



In [None]:
import os
import random
import numpy as np
import pandas as pd
import scipy
import torch
import torchvision
import tqdm

from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.transforms.functional import pil_to_tensor
from torch import nn

from PIL import Image

SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook draws from: Python's `random` (used for the
# random HR crops in ImgTransform), NumPy, and torch (weight init, DataLoader
# shuffling) on CPU and all CUDA devices.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [None]:
class ImgTransform(object):
    """Turn one high-resolution PIL image into an ``(lr, hr)`` training pair.

    In training mode a random ``crop`` x ``crop`` patch is cut from the image.
    Otherwise the whole image is used, trimmed on the right/bottom so both
    sides are multiples of ``scale``. The LR image is the HR image bicubically
    downscaled by ``scale``.

    Args:
        crop: side length of the random HR patch (training mode only); must be
            a multiple of ``scale`` so the LR patch size is exact.
        scale: integer downscaling factor between HR and LR.
        is_train: random-crop (True) or trim-to-multiple-of-scale (False).
        output: ``'Tensor'`` returns float tensors scaled to ``[-1, 1]``; any
            other value returns the two PIL images.

    Raises:
        ValueError: if ``crop`` is not a multiple of ``scale`` (training mode),
            or if an image is smaller than ``crop``.
    """

    def __init__(self, crop, scale, is_train=True, output='Tensor'):
        if is_train and crop % scale != 0:
            raise ValueError(f'crop ({crop}) must be a multiple of scale ({scale})')
        self.crop = crop
        self.scale = scale
        self.is_train = is_train
        self.output = output

    @staticmethod
    def _to_tensor(img):
        # uint8 [0, 255] -> float [-1, 1]: the same range as the generator's
        # final tanh, so a pixel loss between output and target is meaningful
        return pil_to_tensor(img).type(torch.float) / 127.5 - 1.0

    def __call__(self, img):
        if self.is_train:
            if img.width < self.crop or img.height < self.crop:
                raise ValueError(
                    f'image {img.width}x{img.height} is smaller than crop size {self.crop}'
                )
            # random crop x crop HR patch
            left = random.randint(0, img.width - self.crop)
            top = random.randint(0, img.height - self.crop)
            hr_img = img.crop((left, top, left + self.crop, top + self.crop))
        else:
            # trim so both sides divide evenly by scale
            right = (img.width // self.scale) * self.scale
            bottom = (img.height // self.scale) * self.scale
            hr_img = img.crop((0, 0, right, bottom))

        # downscale HR image
        lr_img = hr_img.resize((hr_img.width // self.scale, hr_img.height // self.scale), Image.BICUBIC)

        # the crop/trim above guarantees exact integer scaling
        assert (lr_img.width * self.scale, lr_img.height * self.scale) == hr_img.size

        if self.output == 'Tensor':
            return self._to_tensor(lr_img), self._to_tensor(hr_img)

        return lr_img, hr_img

class Div2kDataset(Dataset):
    """Dataset over the image files in a single directory (e.g. DIV2K HR images).

    Args:
        dir: directory containing the images; only files whose extension is in
            ``IMG_EXTENSIONS`` are used, in sorted filename order.
        transform: optional callable applied to each loaded RGB PIL image; its
            return value is the dataset item. Without a transform the PIL
            image itself is returned.
    """

    # file extensions (case-insensitive) treated as images; other files are ignored
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp')

    def __init__(self, dir='./data', transform=None):
        self.dir = dir
        self.transform = transform

        self.images = sorted(
            name for name in os.listdir(self.dir)
            if name.lower().endswith(self.IMG_EXTENSIONS)
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        path = os.path.join(self.dir, self.images[idx])
        # convert() returns a fully loaded copy, so the file can be closed on
        # leaving the with-block; RGB keeps every sample at 3 channels
        with Image.open(path) as img:
            img = img.convert('RGB')

        if self.transform is not None:
            return self.transform(img)
        return img
            
# DIV2K HR images -> random 128x128 HR patches paired with 4x-downscaled LR inputs
train_dataset = Div2kDataset(dir='DIV2K/HR', transform=ImgTransform(scale=4, crop=128))

In [None]:
BATCH_SIZE = 64

# batches of (lr, hr) pairs, reshuffled every epoch
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
class SubPixelBlock(nn.Module):
    """Upsample feature maps ``scale``-fold with a sub-pixel convolution.

    A 'same'-padded ``k`` x ``k`` convolution expands the channels by
    ``scale ** 2``; ``nn.PixelShuffle`` then folds those extra channels into a
    ``scale``-times larger spatial grid, followed by a PReLU. The channel count
    (``n_channels``) is the same on input and output.
    """

    def __init__(self, scale=4, k=3, n_channels=64):
        super().__init__()
        self.scale, self.k, self.n_channels = scale, k, n_channels

        expanded_channels = n_channels * scale ** 2
        conv = nn.Conv2d(n_channels, expanded_channels, k, padding=k // 2)
        self.layers = nn.Sequential(conv, nn.PixelShuffle(scale), nn.PReLU())

    def forward(self, x):
        upsampled = self.layers(x)
        return upsampled

class ConvBlock(nn.Module):
    """'Same'-padded Conv2d, optionally followed by BatchNorm2d and an activation.

    Padding is ``k // 2``, so spatial size is preserved for odd ``k``.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution (and of the BatchNorm).
        k: kernel size.
        norm: append ``nn.BatchNorm2d(out_channels)`` after the convolution.
        activation: ``None`` or one of ``'prelu'``, ``'tanh'``, ``'leakyrelu'``
            (case-insensitive); appended last.

    Raises:
        ValueError: if ``activation`` is not a supported name.
    """

    # factories, so only the requested activation module is instantiated
    _ACTIVATIONS = {
        'prelu': nn.PReLU,
        'tanh': nn.Tanh,
        'leakyrelu': lambda: nn.LeakyReLU(0.2),
    }

    def __init__(self, in_channels, out_channels, k=3, norm=False, activation=None):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = k
        self.norm = norm

        # conv layer
        self.layers = [nn.Conv2d(self.in_channels, self.out_channels, self.k, padding=self.k // 2)]

        # batch norm over the conv's output channels
        if norm:
            self.layers.append(nn.BatchNorm2d(self.out_channels))

        # activation func
        if activation is not None:
            try:
                make_activation = self._ACTIVATIONS[activation.lower()]
            except KeyError:
                raise ValueError(
                    f'unknown activation {activation!r}; expected one of {sorted(self._ACTIVATIONS)}'
                ) from None
            self.layers.append(make_activation())

        self.conv = nn.Sequential(*self.layers)

    def forward(self, x):
        return self.conv(x)
        

class ResidualConvBlock(nn.Module):
    """Two ``ConvBlock``s wrapped in an identity skip connection: ``x + F(x)``.

    The first ConvBlock carries ``activation``; the second has none, so its
    output is added to the input directly.

    Args:
        in_channels: channels of the input; must equal ``out_channels``.
        out_channels: channels of both convolutions' outputs.
        k: kernel size of both convolutions.
        norm: add BatchNorm inside both ConvBlocks.
        activation: activation name for the first ConvBlock (or ``None``).

    Raises:
        ValueError: if ``in_channels != out_channels``.
    """

    def __init__(self, in_channels, out_channels, k=3, norm=False, activation=None):
        super().__init__()
        # the identity skip adds input to output, so channel counts must agree
        if in_channels != out_channels:
            raise ValueError(
                f'residual block needs in_channels == out_channels, got {in_channels} != {out_channels}'
            )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = k
        self.norm = norm
        self.activation = activation

        self.conv_blocks = nn.Sequential(
            ConvBlock(self.in_channels, self.out_channels, self.k, self.norm, self.activation),
            # second conv consumes the first conv's output channels
            ConvBlock(self.out_channels, self.out_channels, self.k, self.norm),
        )

    def forward(self, x):
        return x + self.conv_blocks(x)

class SResNet(nn.Module):
    """SRResNet-style generator mapping an LR image batch to a ``scale``-times larger one.

    Pipeline: 9x9 conv + PReLU -> ``residual_cnt`` residual blocks -> 9x9 conv,
    with a long skip from the first conv's output -> ``sub_pix_cnt``
    sub-pixel upsampling blocks -> 9x9 conv + tanh.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output image.
        scale: total upscaling factor. It is split evenly across the sub-pixel
            blocks, so it must equal ``f ** sub_pix_cnt`` for some integer
            ``f`` (e.g. scale=4 with one x4 block or two x2 blocks).
        residual_cnt: number of residual blocks.
        sub_pix_cnt: number of sub-pixel upsampling blocks (>= 1).
        norm: add BatchNorm to the residual blocks and to the post-residual conv.
        n_channels: width of the internal feature maps.

    Raises:
        ValueError: if ``sub_pix_cnt < 1`` or ``scale`` cannot be split evenly.
    """

    def __init__(self, in_channels=3, out_channels=3, scale=4, residual_cnt=1, sub_pix_cnt=1, norm=False,
                 n_channels=64):
        super().__init__()
        if sub_pix_cnt < 1:
            raise ValueError(f'sub_pix_cnt must be >= 1, got {sub_pix_cnt}')
        # each sub-pixel block upsamples by block_scale; together they must give `scale`
        block_scale = round(scale ** (1 / sub_pix_cnt))
        if block_scale ** sub_pix_cnt != scale:
            raise ValueError(
                f'scale {scale} is not an integer power of sub_pix_cnt={sub_pix_cnt} per-block factors'
            )

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale = scale
        self.residual_cnt = residual_cnt
        self.sub_pix_cnt = sub_pix_cnt
        self.norm = norm
        self.n_channels = n_channels
        self.block_scale = block_scale

        self.conv1 = ConvBlock(self.in_channels, n_channels, 9, activation='prelu')

        self.residual_blocks = nn.Sequential(
            *[ResidualConvBlock(n_channels, n_channels, 3, norm=self.norm, activation='prelu')
              for _ in range(self.residual_cnt)]
        )

        self.conv2 = ConvBlock(n_channels, n_channels, 9, norm=self.norm)

        self.sub_pixel_blocks = nn.Sequential(
            *[SubPixelBlock(block_scale, n_channels=n_channels) for _ in range(self.sub_pix_cnt)]
        )

        self.conv3 = ConvBlock(n_channels, self.out_channels, 9, activation='tanh')

    def forward(self, x):
        # first conv with prelu
        x = self.conv1(x)

        # residual blocks, then conv, plus the long skip connection
        # (out-of-place add so autograd never sees a modified saved tensor)
        skip = x
        x = self.conv2(self.residual_blocks(x))
        x = x + skip

        # sub-pixel upsampling
        x = self.sub_pixel_blocks(x)

        # tanh conv block
        return self.conv3(x)

In [None]:
gen = SResNet()

# pull a single (lr, hr) batch to sanity-check tensor shapes
it = iter(train_dataloader)
lr, hr = next(it)

In [None]:
# the generator's output shape should line up with the HR batch shape
(lr.shape, gen(lr).shape, hr.shape)

In [None]:
from torch.optim import Adam

def train_epoch(loader, model, optimizer, criterion=None, max_batches=None):
    """Run one optimisation pass of ``model`` over ``loader`` and report the loss.

    Args:
        loader: iterable of ``(lr, hr)`` tensor pairs, e.g. a DataLoader.
        model: module mapping an LR batch to an SR batch comparable to ``hr``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss ``criterion(sr, hr)``; defaults to ``nn.MSELoss()``.
        max_batches: stop after this many batches (handy for quick smoke
            tests); ``None`` runs the whole epoch.

    Returns:
        Mean per-sample loss over the batches processed (``nan`` if none were).
        The value is also printed.
    """
    import itertools

    if criterion is None:
        criterion = nn.MSELoss()

    # send batches to wherever the model's weights live, so data and model
    # placement can never disagree
    device = next(model.parameters()).device
    model.train()

    batches = loader if max_batches is None else itertools.islice(loader, max_batches)

    total_loss, n_seen = 0.0, 0
    for lr, hr in batches:
        lr, hr = lr.to(device), hr.to(device)

        optimizer.zero_grad()
        sr = model(lr)
        loss = criterion(sr, hr)
        loss.backward()
        optimizer.step()

        # loss is a batch mean; weight by batch size so a short final batch
        # contributes proportionally
        batch_size = lr.shape[0]
        total_loss += loss.item() * batch_size
        n_seen += batch_size

    mean_loss = total_loss / n_seen if n_seen else float('nan')
    print(f"loss: {mean_loss}")
    return mean_loss

# Move the generator to DEVICE first (Module.to moves parameters in place), then
# build the optimizer over the on-device parameters; without this, CUDA batches
# would be fed to a CPU model.
gen = gen.to(DEVICE)
opt = Adam(gen.parameters())

In [None]:
# one pass over the training crops
train_epoch(loader=train_dataloader, model=gen, optimizer=opt)