# Super Resolution Network

This notebook is a basic guide to training and using a super-resolution model.

In [None]:
import torch
import matplotlib.pyplot as plt

import numpy as np
from PIL import Image
from skimage import metrics
from models import Generator, SResNet, Discriminator

# custom modules
from utils import convert_img, upscale_img

# factor by which the low-resolution image is smaller than the high-resolution one
SCALE = 2
# seed handed to PyTorch's random number generators below
SEED = 42

# run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# seed the default and the CUDA generators so repeated runs behave the same
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

## Obtain dataset
Run the next cell to download the `DIV2K` dataset, or use any other high-resolution image data.

In [None]:
# shell escape: runs the project's download.py script with the `python` found on PATH
# (presumably this fetches the DIV2K images into the working directory — confirm in download.py)
!python download.py

## Train model based on configuration

Configure `train_sresnet.py` and run the next cell to start training, or skip this step if you already have a trained model.

In [None]:
# shell escape: runs the project's train_sresnet.py script with the `python` found on PATH
# (presumably this writes the checkpoint that is loaded further down — confirm in train_sresnet.py)
!python train_sresnet.py

## Load trained model

Load the trained model from the checkpoint file named in `PT_SAVED`.

In [None]:
# path of the checkpoint to load (presumably produced by train_sresnet.py)
PT_SAVED = '2x-sresnet.pt'

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location remaps the stored tensors onto DEVICE while deserializing; without it a
# checkpoint saved on a GPU machine cannot be loaded on a CPU-only one.
checkpoint = torch.load(PT_SAVED, map_location=DEVICE)

# the checkpoint is a mapping whose 'model' entry holds the whole model object
model = checkpoint['model'].to(DEVICE)

# switch to inference mode (equivalent to model.train(False))
model.eval()

print('Model loaded successfully')

## Upscale image

Either choose a high-resolution image from the image data and downscale it, or load a low-resolution image directly without downscaling.

In [None]:
IMAGE = 'DIV2K/0420.png'

# open the image as RGB and keep a fixed 600x600 window as the high-resolution reference
crop_box = (200, 400, 800, 1000)
hr = Image.open(IMAGE).convert('RGB').crop(crop_box)

# shrink the reference by SCALE with bicubic resampling to get the low-resolution input
lr_size = (hr.width // SCALE, hr.height // SCALE)
lr = hr.resize(lr_size, Image.BICUBIC)

# show both images as the cell output
hr, lr

In [None]:
# upscale low-resolution image in PIL format using given trained model
# (upscale_img comes from the project's utils module; presumably it returns the
# super-resolved image in a form plt.imshow and convert_img accept — TODO confirm)
sr = upscale_img(lr, model) 

In [None]:

# plot high, low and super resolution images side by side
panels = [(hr, 'high res'), (lr, 'low res'), (sr, 'super res')]

fig, axes = plt.subplots(1, len(panels), figsize=(20, 10))

# the loop variable is named `label`, not `type`: loop variables in a notebook cell are
# globals, so `type` would shadow the builtin for every cell executed afterwards
for ax, (img, label) in zip(axes, panels):
    ax.set_title(label)
    ax.imshow(img)
    ax.axis('off')

plt.show()

## Metrics

We can use the structural similarity index (SSIM) or the peak signal-to-noise ratio (PSNR) to compare the upscaled images against the original high-resolution image.

In [None]:
# convert the three images to arrays; the low-resolution image is first resized back to
# the high-resolution size with bicubic interpolation so it serves as the baseline
original_img, bicubic_img, superres_img = (
    np.array(convert_img(hr, 'pil', '[0, 1]')),
    np.array(convert_img(lr.resize(hr.size, Image.BICUBIC), 'pil', '[0, 1]')),
    np.array(convert_img(sr, 'pil', '[0, 1]'))
)

# SSIM lies in [-1, 1]; higher means more similar
# NOTE(review): channel_axis=0 assumes convert_img returns channel-first data and
# data_range=1 assumes values in [0, 1] — confirm against utils.convert_img
ssim_bicubic = metrics.structural_similarity(
    original_img, bicubic_img, channel_axis=0, data_range=1
)
# compare against superres_img; the previous code reused bicubic_img here, so both
# SSIM lines always printed the same number
ssim_superres = metrics.structural_similarity(
    original_img, superres_img, channel_axis=0, data_range=1
)
print(f'SSIM bicubic:  {ssim_bicubic:.3f}')
print(f'SSIM superres: {ssim_superres:.3f}')

print()

# higher value means higher similarity (identical images have zero error, i.e. infinite PSNR)
# data_range is passed explicitly so PSNR uses the same value range as the SSIM calls
psnr_bicubic = metrics.peak_signal_noise_ratio(original_img, bicubic_img, data_range=1)
psnr_superres = metrics.peak_signal_noise_ratio(original_img, superres_img, data_range=1)
print(f'PSNR bicubic:  {psnr_bicubic:.3f}')
print(f'PSNR superres: {psnr_superres:.3f}')