
## Sources
- https://paperswithcode.com/dataset/div2k
- https://github.com/sgrvinod/Deep-Tutorials-for-PyTorch
- https://jonathan-hui.medium.com/gan-super-resolution-gan-srgan-b471da7270ec
- https://github.com/labmlai/annotated_deep_learning_paper_implementations
- https://gitlab.fit.cvut.cz/dufekja4/bi-ml2-2023-dufekja4/-/blob/hw02/02/homework_02_B222.ipynb?ref_type=heads

- SRCNN paper: https://arxiv.org/pdf/1501.00092.pdf
- structural similarity: https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf
- loss functions for SR: https://arxiv.org/pdf/1511.08861.pdf


Dataset used for training: https://data.vision.ee.ethz.ch/cvl/DIV2K/

Specifically, the `High Resolution Images` downloads: train data and validation data.


## Concepts
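- GAN: one network performs super-resolution (the generator) and a second one rates images as real or generated (the discriminator); the two are trained against each other in a zero-sum game
- residual connections in deep networks
- sub-pixel convolution
- pretrain the SR model first, then continue training it as the generator in a GAN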
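
To make the residual-connection and sub-pixel-convolution items concrete, here is a minimal PyTorch sketch. The class names `ResidualBlock` and `SubPixelUpsample`, the channel count of 64 and the layer choices are assumptions for illustration; they are not taken from the `models` module used later in this notebook.

```python
import torch
from torch import nn


class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a skip connection around them (illustrative)."""

    def __init__(self, channels: int = 64):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        # The block learns a correction that is added back onto its input.
        return x + self.body(x)


class SubPixelUpsample(nn.Module):
    """Upscale by `scale` with a convolution followed by PixelShuffle (illustrative)."""

    def __init__(self, channels: int = 64, scale: int = 2):
        super().__init__()
        # The convolution produces scale**2 times as many channels ...
        self.conv = nn.Conv2d(channels, channels * scale ** 2, kernel_size=3, padding=1)
        # ... which PixelShuffle rearranges into a feature map that is `scale` times larger.
        self.shuffle = nn.PixelShuffle(scale)
        self.act = nn.PReLU()

    def forward(self, x):
        return self.act(self.shuffle(self.conv(x)))


x = torch.randn(2, 64, 32, 32)
y = SubPixelUpsample(scale=2)(ResidualBlock()(x))
print(y.shape)  # torch.Size([2, 64, 64, 64])
```

The residual block keeps the spatial size and channel count unchanged, so several of them can be stacked; two `scale=2` upsampling stages in a row would give the 4x factor used below.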



In [None]:
import numpy as np
import pandas as pd
import torch
import torchvision
import tqdm

from PIL import Image
from matplotlib import pyplot as plt

from torch.utils.data import DataLoader
from torch import nn
from torch.optim import Adam

SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(SEED) 
torch.cuda.manual_seed(SEED)

In [None]:
from dataset import Div2kDataset, ImgTransform

train_dataset = Div2kDataset('DIV2K', transform=ImgTransform(crop=128, scale=4))
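
`Div2kDataset` and `ImgTransform` come from a local `dataset` module that is not shown here. As a rough sketch of what a transform with `crop=128, scale=4` might do, the hypothetical helper `make_lr_hr_pair` below takes a random high-resolution crop and downsamples it bicubically to obtain the low-resolution input. The function name, the bicubic choice and the `[0, 1]` value range are assumptions, not the module's actual behaviour.

```python
import torch
import torch.nn.functional as F


def make_lr_hr_pair(img: torch.Tensor, crop: int = 128, scale: int = 4):
    """Return (lr, hr) from a float image tensor of shape (C, H, W) in [0, 1].

    Hypothetical stand-in for the notebook's ImgTransform; requires H, W >= crop.
    """
    _, h, w = img.shape
    # Pick the top-left corner of a random crop x crop patch.
    top = torch.randint(0, h - crop + 1, (1,)).item()
    left = torch.randint(0, w - crop + 1, (1,)).item()
    hr = img[:, top:top + crop, left:left + crop]

    # Downsample the patch by `scale`; interpolate expects a batch dimension.
    lr = F.interpolate(
        hr.unsqueeze(0),
        size=(crop // scale, crop // scale),
        mode="bicubic",
        align_corners=False,
    ).squeeze(0)

    # Bicubic interpolation can overshoot slightly, so clamp back to [0, 1].
    return lr.clamp(0, 1), hr


lr_patch, hr_patch = make_lr_hr_pair(torch.rand(3, 200, 300))
print(lr_patch.shape, hr_patch.shape)  # torch.Size([3, 32, 32]) torch.Size([3, 128, 128])
```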

In [None]:
BATCH_SIZE = 64

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
from models import SResNet

gen = SResNet()

it = iter(train_dataloader)
lr, hr = next(it)

In [None]:
lr.shape, gen(lr).shape, hr.shape

In [None]:
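def train_epoch(loader, model, optimizer):
    """Train `model` for one pass over `loader` and print the mean per-sample MSE."""
    model.train()
    criterion = nn.MSELoss()  # created once instead of on every batch

    train_loss = 0.0
    n_samples = 0

    for lr, hr in loader:
        lr, hr = lr.to(DEVICE), hr.to(DEVICE)

        optimizer.zero_grad()

        sr = model(lr)
        loss = criterion(sr, hr)

        loss.backward()
        optimizer.step()

        # loss is a batch mean, so weight it by the batch size
        train_loss += loss.item() * lr.shape[0]
        n_samples += lr.shape[0]

    # divide by the number of samples seen, not by the number of batches
    print(f"loss: {train_loss / n_samples}")

gen.to(DEVICE)  # the model must be on the same device as the batches
opt = Adam(gen.parameters())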

In [None]:
train_epoch(train_dataloader, gen, opt)
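
The Concepts section lists a GAN phase that follows this MSE pretraining. A minimal sketch of one adversarial update is given below, assuming a discriminator that outputs one logit per image. The function `gan_step`, the toy networks and the weight `adv_weight=1e-3` are illustrative assumptions; in practice the pretrained `gen` and a proper discriminator would take their place.

```python
import torch
from torch import nn


def gan_step(gen, disc, gen_opt, disc_opt, lr_img, hr_img, adv_weight=1e-3):
    """One adversarial update: discriminator first, then generator (illustrative)."""
    bce = nn.BCEWithLogitsLoss()
    mse = nn.MSELoss()

    # Discriminator: real HR images should score 1, generated images 0.
    sr_img = gen(lr_img)
    real_logits = disc(hr_img)
    fake_logits = disc(sr_img.detach())  # detach: no generator gradients here
    disc_loss = (bce(real_logits, torch.ones_like(real_logits))
                 + bce(fake_logits, torch.zeros_like(fake_logits)))
    disc_opt.zero_grad()
    disc_loss.backward()
    disc_opt.step()

    # Generator: content (MSE) loss plus a small reward for fooling the discriminator.
    fake_logits = disc(sr_img)
    gen_loss = mse(sr_img, hr_img) + adv_weight * bce(fake_logits, torch.ones_like(fake_logits))
    gen_opt.zero_grad()
    gen_loss.backward()
    gen_opt.step()

    return disc_loss.item(), gen_loss.item()


# Toy stand-ins so the sketch runs on its own (not the notebook's models).
toy_gen = nn.Sequential(nn.Conv2d(3, 3 * 16, 3, padding=1), nn.PixelShuffle(4))
toy_disc = nn.Sequential(
    nn.Conv2d(3, 8, 3, stride=2, padding=1),
    nn.LeakyReLU(0.2),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 1),
)
g_opt = torch.optim.Adam(toy_gen.parameters())
d_opt = torch.optim.Adam(toy_disc.parameters())

d_loss, g_loss = gan_step(toy_gen, toy_disc, g_opt, d_opt,
                          torch.rand(2, 3, 8, 8), torch.rand(2, 3, 32, 32))
print(f"disc loss: {d_loss:.4f}, gen loss: {g_loss:.4f}")
```

Updating the discriminator on detached generator outputs keeps the two optimizers from interfering: each network only receives gradients from its own loss.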