
## Sources
- https://paperswithcode.com/dataset/div2k
- https://github.com/sgrvinod/Deep-Tutorials-for-PyTorch
- https://jonathan-hui.medium.com/gan-super-resolution-gan-srgan-b471da7270ec
- https://github.com/labmlai/annotated_deep_learning_paper_implementations
- https://gitlab.fit.cvut.cz/dufekja4/bi-ml2-2023-dufekja4/-/blob/hw02/02/homework_02_B222.ipynb?ref_type=heads

- SRCNN paper (Dong et al., "Image Super-Resolution Using Deep Convolutional Networks"): https://arxiv.org/pdf/1501.00092.pdf
- structural similarity: https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf
- loss functions for SR: https://arxiv.org/pdf/1511.08861.pdf


Dataset used for training: https://data.vision.ee.ethz.ch/cvl/DIV2K/

Specifically, the `High Resolution Images` sets: both the training data and the validation data.


## Concepts
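- GAN: one network (the generator) produces super-resolved images, while a second network (the discriminator) rates whether an image is real or generated; the two are trained against each other in a zero-sum game
- residual (skip) connections for training deep networks
- sub-pixel convolution for upsampling
- pretrain the SR model on its own first, then use the pretrained model as the generator in the GAN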



In [None]:
# Download the DIV2K dataset by running the project's download.py script.
# The leading `!` is an IPython shell escape, so this cell only runs inside
# Jupyter/IPython, not as a plain Python script.
!python download.py

In [None]:
import torch

from tqdm import tqdm
from torch.utils.data import DataLoader
from torch import nn
from torch.optim import Adam

from dataset import Div2kDataset
from utils import ImgTransform
from models import SResNet

# --- run configuration -------------------------------------------------------
SEED = 42
BATCH_SIZE = 16
SCALE = 2  # scale factor handed to both ImgTransform and SResNet

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Seed the global torch RNG, then the CUDA RNG explicitly, for repeatable runs.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [None]:
# Training-time preprocessing: ImgTransform with crop=64, the global SCALE,
# training mode enabled and tensor output. The exact crop/downscale behaviour
# lives in utils.py (presumably a 64px random HR crop plus a SCALE-times
# downsampled LR copy — confirm there).
transform = ImgTransform(
    crop=64,
    scale=SCALE,
    is_train=True,
    output='Tensor',
)

# Dataset rooted at 'DIV2K' (presumably the directory filled by download.py),
# served in shuffled mini-batches of BATCH_SIZE.
train_dataset = Div2kDataset('DIV2K', transform=transform)
train_dataloader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

In [None]:
# Generator network with 10 residual blocks and the configured SCALE; the two
# leading positional 3s are presumably input/output channel counts (RGB) —
# confirm in models.py. Model and loss both live on DEVICE.
gen = SResNet(3, 3, res_block_cnt=10, scale=SCALE).to(DEVICE)
criterion = nn.MSELoss().to(DEVICE)

# Adam with PyTorch's default hyper-parameters (lr=1e-3).
optimizer = Adam(gen.parameters())

# Smoke test: push one batch through the untrained generator to check shapes.
# The batch must be on the same device as the model (otherwise the forward
# pass fails on CUDA), and no gradients are needed, so skip building the
# autograd graph.
lr, hr = next(iter(train_dataloader))
lr, hr = lr.to(DEVICE), hr.to(DEVICE)
with torch.no_grad():
    sr = gen(lr)


display(len(gen.res_blocks))

In [None]:
# Show the shapes of the LR input batch, HR target batch and generator output side by side.
display(*(t.shape for t in (lr, hr, sr)))

In [None]:
def train_epoch(loader, model, criterion, optimizer, device=None):
    """Train ``model`` for one pass over ``loader`` and report the mean loss.

    Args:
        loader: DataLoader yielding ``(lr, hr)`` batch pairs; ``lr`` is fed to
            the model and ``hr`` is the target the output is compared with.
        model: network mapping an ``lr`` batch to a prediction.
        criterion: loss called as ``criterion(prediction, hr)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to; defaults to the global ``DEVICE``.

    Returns:
        The per-sample average training loss for this epoch (NaN if the loader
        yielded no samples). It is also printed.
    """
    if device is None:
        device = DEVICE

    # Make sure training-mode behaviour (e.g. dropout / batch-norm statistics)
    # is active even if the model was left in eval mode elsewhere.
    model.train()

    total_loss = 0.0
    n_samples = 0

    for lr, hr in tqdm(loader, total=len(loader)):
        lr, hr = lr.to(device), hr.to(device)

        # zero gradients and forward the batch through the model
        optimizer.zero_grad()
        sr = model(lr)

        # compute and backpropagate the loss, then update model parameters
        loss = criterion(sr, hr)
        loss.backward()
        optimizer.step()

        # Assumes criterion averages over the batch (true for nn.MSELoss's
        # default reduction); weight by batch size so the epoch average is per
        # sample even when the last batch is smaller.
        batch_size = lr.shape[0]
        total_loss += loss.item() * batch_size
        n_samples += batch_size

    # Divide by the number of samples, not batches: the accumulator is a sum
    # over samples, so dividing by len(loader) would inflate it ~batch_size times.
    mean_loss = total_loss / n_samples if n_samples else float("nan")
    print(f"loss: {mean_loss}")
    return mean_loss


In [None]:
# Run a single training epoch of the generator against the pixel-wise MSE loss.
train_epoch(loader=train_dataloader, model=gen, criterion=criterion, optimizer=optimizer)