Image super-resolution datasets: [list on Papers with Code](https://paperswithcode.com/datasets?q=&v=lst&o=match&mod=images&task=image-super-resolution)

## Sources
- https://paperswithcode.com/dataset/div2k
- https://github.com/sgrvinod/Deep-Tutorials-for-PyTorch
- https://jonathan-hui.medium.com/gan-super-resolution-gan-srgan-b471da7270ec

## Concepts
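- GAN (generative adversarial network) – one network (the generator) produces the super-resolved image, and a second network (the discriminator) rates whether an image is a real high-resolution image or a generated one; the two are trained against each other in a zero-sum game.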




In [None]:
import os
import random
import numpy as np
import pandas as pd
import scipy
import torch
import torchvision
from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.transforms.functional import pil_to_tensor
from torch import nn

from PIL import Image

SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook may draw from. `random` picks the random HR
# crops in ImgTransform, so without seeding it the training patches are not
# reproducible even though torch is seeded. numpy is seeded as well so any
# later np.random use is reproducible too.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [None]:
class ImgTransform(object):
    """Turn a PIL image into a (low-resolution, high-resolution) pair.

    In training mode a random ``crop`` x ``crop`` patch is cut out of the
    input. Otherwise the whole image is trimmed at the right/bottom edge so
    both sides are multiples of ``scale``. The HR image is converted to RGB,
    and the LR image is the HR image downscaled by ``scale`` with bicubic
    resampling.

    Args:
        crop: Side length in pixels of the random HR patch (training mode
            only). Must be divisible by ``scale`` when ``is_train`` is true.
        scale: Integer factor between the HR and LR image sizes.
        is_train: Random crop (True) or deterministic full-image trim (False).
        output: ``'Tensor'`` returns ``(lr, hr)`` as tensors built with
            ``pil_to_tensor`` (raw pixel values, not rescaled); any other
            value returns the two PIL images.

    Raises:
        ValueError: At construction, if ``is_train`` and ``crop`` is not a
            multiple of ``scale``. When called in training mode, if the image
            is smaller than ``crop`` on either side.
    """

    def __init__(self, crop, scale, is_train=True, output='Tensor'):
        # Only training mode uses `crop`; validating it there (instead of with
        # an assert per sample) reports a bad config once and survives `-O`.
        if is_train and crop % scale != 0:
            raise ValueError(
                f'crop ({crop}) must be divisible by scale ({scale})'
            )
        self.crop = crop
        self.scale = scale
        self.is_train = is_train
        self.output = output

    def __call__(self, img):
        if self.is_train:
            if img.width < self.crop or img.height < self.crop:
                raise ValueError(
                    f'image is {img.width}x{img.height}, '
                    f'smaller than crop size {self.crop}'
                )
            # Random top-left corner such that the patch stays inside the image.
            left = random.randint(0, img.width - self.crop)
            top = random.randint(0, img.height - self.crop)
            hr_img = img.crop((left, top, left + self.crop, top + self.crop))
        else:
            # Trim so both sides divide evenly by `scale`; LR * scale == HR exactly.
            right = (img.width // self.scale) * self.scale
            bottom = (img.height // self.scale) * self.scale
            hr_img = img.crop((0, 0, right, bottom))

        # Force 3 channels: grayscale/RGBA inputs would otherwise yield tensors
        # that cannot be batched with RGB ones or fed to a 3-channel conv.
        hr_img = hr_img.convert('RGB')

        # HR sides are multiples of `scale`, so this division is exact.
        lr_img = hr_img.resize(
            (hr_img.width // self.scale, hr_img.height // self.scale),
            Image.BICUBIC,
        )

        if self.output == 'Tensor':
            return pil_to_tensor(lr_img), pil_to_tensor(hr_img)
        return lr_img, hr_img

class Div2kDataset(Dataset):
    """Dataset over every entry of one directory, indexed in sorted name order.

    Args:
        dir: Directory holding the images. Every entry returned by
            ``os.listdir`` is treated as an image file.
        transform: Optional callable applied to the loaded PIL image; its
            return value becomes the sample. When omitted, the PIL image
            itself is returned.
    """

    def __init__(self, dir='./data', transform=None):
        self.dir = dir
        self.transform = transform

        # os.listdir order is arbitrary; sort so index -> file is stable.
        self.images = sorted(os.listdir(self.dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        path = os.path.join(self.dir, self.images[idx])

        # Image.open is lazy and keeps the file handle open until the image is
        # closed. Copy the pixel data out and close the file right away so
        # iterating many samples does not accumulate open descriptors.
        with Image.open(path) as src:
            img = src.copy()

        if self.transform:
            return self.transform(img)
        # Previously this fell through and returned None without a transform.
        return img
            
# Training pairs: random 128x128 HR crops from DIV2K/HR, each paired with a
# copy downscaled by a factor of 4.
train_dataset = Div2kDataset(
    dir='DIV2K/HR',
    transform=ImgTransform(crop=128, scale=4, is_train=True),
)

In [None]:
BATCH_SIZE = 64

# Shuffled mini-batches of (lr, hr) pairs drawn from train_dataset.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
class Generator(nn.Module):
    """Placeholder generator: two 3x3 convolution + ReLU stages, then flatten.

    NOTE(review): this does not yet perform super-resolution, since it never
    upsamples. Each unpadded 3x3 convolution trims 2 pixels from height and
    width, so a batched input of shape (B, 3, H, W) comes out as a flat
    tensor of shape (B, 64 * (H - 4) * (W - 4)).
    """

    def __init__(self):
        super().__init__()
        # No padding: every stage shrinks both spatial dimensions by 2.
        self.conv1 = nn.Sequential(nn.Conv2d(3, 64, 3), nn.ReLU())
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, 3), nn.ReLU())
        self.flatten = nn.Flatten()

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        return self.flatten(features)

In [None]:
# Smoke test: push one batch of LR images through the (untrained) generator.
gen = Generator().to(DEVICE)

loader = iter(train_dataloader)
lr, hr = next(loader)

# pil_to_tensor yields raw uint8 pixel values in [0, 255]; convert to float
# in [0, 1]. Call the module itself (not .forward) so hooks run, and skip
# autograd bookkeeping since nothing is trained here.
with torch.no_grad():
    out = gen(lr.to(DEVICE).float() / 255)

out