# Super Resolution Network

This notebook is a basic guide to training and using the superres model.

It also includes visual comparisons and quality metrics for upscaled images.

## Setup
This cell imports the necessary libraries and sets global variables.

In [47]:
%matplotlib inline

import random
import torch
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from skimage import metrics
from torch.utils.data import DataLoader

# custom modules
from utils import convert_img, upscale_img, evaluate

# select model scale
SCALE = 2 
SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# seed every random number generator used in this notebook
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

## Obtain dataset
Run the next cell to download the `DIV2K` dataset, or use any other high-resolution image data.

In [None]:
!python download.py

## Train model based on configuration

Configure `train_sresnet.py` and run the next cell to start training. Skip this step if you already have a trained model checkpoint (for example, `2x-sresnet.pt`).

In [None]:
!python train_sresnet.py

## Load trained model

Load the trained model from the checkpoint file named in `PT_SAVED`.

In [None]:
PT_SAVED = '2x-sresnet.pt'

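# map_location lets a checkpoint saved on GPU load on a CPU-only machine
model = torch.load(PT_SAVED, map_location=DEVICE)['model'].to(DEVICE)
# switch to inference mode (equivalent to model.train(False))
model.eval()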

print('Model loaded successfully')

## Upscale image

Either choose a high-resolution image from the dataset and downscale it, or load a low-resolution image directly without downscaling.

In the next cell, the image is selected manually, cropped, and downscaled with `Image.BICUBIC` resampling.

In [None]:
IMAGE = 'DIV2K/0420.png'

# get high-resolution image
hr = Image.open(IMAGE).convert('RGB').crop((200, 400, 800, 1000))

# resize high-resolution image using BICUBIC
lr = hr.resize((hr.width // SCALE, hr.height // SCALE), Image.BICUBIC)

hr.size, lr.size
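
The crop above is 600 x 600, which divides evenly by both 2 and 4. If a crop's sides are not multiples of `SCALE`, the integer division in `hr.width // SCALE` drops the remainder. Assuming the model outputs exactly `SCALE` times its input size, the upscaled image would then be smaller than `hr`, and the two could not be compared pixel by pixel. A minimal sketch of one way to guard against this, using a hypothetical helper `snap_box` that is not part of `utils`:

```python
def snap_box(box, scale):
    """Shrink a (left, upper, right, lower) crop box so that its width and
    height are multiples of `scale`. Hypothetical helper, not part of utils."""
    left, upper, right, lower = box
    width = (right - left) // scale * scale
    height = (lower - upper) // scale * scale
    return (left, upper, left + width, upper + height)


# the crop used in this notebook is already aligned for 2x and 4x
print(snap_box((200, 400, 800, 1000), 4))   # (200, 400, 800, 1000)
# a 601 x 603 crop is trimmed to 600 x 600
print(snap_box((200, 400, 801, 1003), 4))   # (200, 400, 800, 1000)
```

The returned box can be passed straight to `Image.crop`.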

Upscaling can be done manually with the model, or by calling `upscale_img` with the appropriate parameters.

In [None]:
# upscale low-resolution image in PIL format using given trained model
sr = upscale_img(lr, model, input_format='pil', output_format='pil')
sr

In [None]:
# plot high, low and super resolution images

fig = plt.figure(figsize=(20, 10))
for i, (img, label) in enumerate([(hr, 'high res'), (lr, 'low res'), (sr, 'super res')]):
    fig.add_subplot(1, 3, i + 1)
    plt.title(label)
    plt.imshow(img)
    plt.axis('off')

plt.show()

We can use the `structural similarity index` (SSIM) or the `peak signal-to-noise ratio` (PSNR) to compare upscaled images against the original high-resolution image.

In [None]:
# convert to [0, 1] np arrays
original_img, bicubic_img, superres_img = (
    np.array(convert_img(hr, 'pil', '[0, 1]')),
    np.array(convert_img(lr.resize(hr.size, Image.BICUBIC), 'pil', '[0, 1]')),
    np.array(convert_img(sr, 'pil', '[0, 1]'))
)

# SSIM lies in [-1, 1]; 1 means the images are identical
print(f'SSIM bicubic:  {metrics.structural_similarity(original_img, bicubic_img, channel_axis=0, data_range=1):.3f}')
print(f'SSIM superres: {metrics.structural_similarity(original_img, superres_img, channel_axis=0, data_range=1):.3f}')

print()

# PSNR in dB: higher value means higher similarity (identical images have zero MSE, which causes a division by zero)
print(f'PSNR bicubic:  {metrics.peak_signal_noise_ratio(original_img, bicubic_img, data_range=1):.3f}')
print(f'PSNR superres: {metrics.peak_signal_noise_ratio(original_img, superres_img, data_range=1):.3f}')

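`SSIM` returns a coefficient in the range [-1, 1] that measures the structural similarity of the two images; 1 means they are identical.

`PSNR` returns a value in decibels that measures the quality of the reconstructed image relative to the original; higher is better.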
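
For reference, PSNR follows directly from the mean squared error: PSNR = 10 * log10(data_range^2 / MSE). Below is a minimal NumPy sketch of that formula, assuming images scaled to [0, 1]. The function name `psnr` is illustrative only; `metrics.peak_signal_noise_ratio` remains the implementation used in this notebook.

```python
import numpy as np


def psnr(reference, test, data_range=1.0):
    """PSNR in dB between two arrays of the same shape (illustrative sketch)."""
    reference = np.asarray(reference, dtype=np.float64)
    test = np.asarray(test, dtype=np.float64)
    mse = np.mean((reference - test) ** 2)
    if mse == 0:
        # identical images: the ratio is unbounded
        return float('inf')
    return float(10 * np.log10(data_range ** 2 / mse))


reference = np.zeros((3, 8, 8))
test = np.full((3, 8, 8), 0.1)   # constant error of 0.1 -> MSE = 0.01
print(f'{psnr(reference, test):.2f} dB')   # 20.00 dB
```

Because the scale is logarithmic, every tenfold reduction in MSE adds 10 dB.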

## Metrics and loss
This section shows loss plots and metrics for the 2x and 4x sresnets.

### 2x sresnet

In [None]:
INVALID = 1000

# load model state and count epochs without validation
model_state = torch.load('2x-sresnet.pt')
valid_div = model_state['vloss'].count(INVALID)

# get tloss, vloss (without the INVALID placeholders) and the epoch indices for the x axis
tloss, vloss = model_state['tloss'], model_state['vloss'][valid_div:]
epochs = np.arange(len(tloss))

# figure plot
fig = plt.figure(figsize=(12, 6))
ax = fig.subplots()

# plot losses; vloss is aligned with the epochs at which validation ran
ax.plot(epochs, tloss, 'b-', label='train loss')
ax.plot(epochs[len(tloss) - len(vloss):], vloss, 'r-', label='valid loss')
ax.axvline(valid_div, c='green', linestyle=':', label='epoch validation start')

ax.set_title('2x sresnet')
ax.set_xticks(np.arange(0, len(tloss) + 1, 5))
ax.legend()
plt.show()

print(f'Best train loss: {min(tloss):.4f}')
print(f'Best valid loss: {min(vloss):.4f}')

Validation for this model started at epoch 25. Both the training and validation losses improve over the course of training.
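
The plotting cell above relies on `vloss` holding the placeholder value `INVALID` for every epoch that ran without validation. Here is a small sketch of that bookkeeping on made-up loss values. The numbers and the helper name `split_validation` are illustrative only, and the sketch assumes the placeholders all sit at the start of the list, as the slicing in the cell does.

```python
INVALID = 1000


def split_validation(vloss, sentinel=INVALID):
    """Return (first validated epoch, real validation losses).

    Assumes the sentinel entries form a prefix of `vloss`.
    """
    start = vloss.count(sentinel)
    return start, vloss[start:]


# made-up values: three epochs without validation, then two with
vloss = [INVALID, INVALID, INVALID, 0.42, 0.37]
start, valid = split_validation(vloss)
print(start, valid)   # 3 [0.42, 0.37]
```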

Next, we run the model on N random image crops and report the average `ssim` and `psnr`.

In [None]:
%%time

N = 100
CROP = 512
SCALE = 2

ssim, psnr = evaluate(N, model_state['model'], SCALE, CROP)

print(f'avg SSIM bicubic:  {sum(ssim["bic"]) / N:.4f}')
print(f'avg SSIM superres: {sum(ssim["sres"]) / N:.4f}')

print()

print(f'avg PSNR bicubic:  {sum(psnr["bic"]) / N:.4f}') 
print(f'avg PSNR superres: {sum(psnr["sres"]) / N:.4f}') 
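
`evaluate` is a custom function from `utils`, so its internals are not shown here. As a rough sketch of what sampling random crops can look like, the hypothetical helper below draws a `CROP` x `CROP` box that fits inside an image of a given size. It is an assumption about the approach, not the code in `utils`, and the image size is just an example.

```python
import random


def random_crop_box(width, height, crop, rng=random):
    """Pick a random (left, upper, right, lower) box of size crop x crop
    that lies fully inside a width x height image (hypothetical helper)."""
    if crop > width or crop > height:
        raise ValueError('crop is larger than the image')
    left = rng.randint(0, width - crop)    # randint includes both endpoints
    upper = rng.randint(0, height - crop)
    return (left, upper, left + crop, upper + crop)


rng = random.Random(42)   # separate generator so the example is repeatable
box = random_crop_box(2040, 1356, 512, rng)
print(box[2] - box[0], box[3] - box[1])   # 512 512
```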

### 4x sresnet

In [None]:
model_state = torch.load('4x-sresnet.pt')
tloss, vloss = model_state['tloss'], model_state['vloss']

fig = plt.figure(figsize=(12, 6))
ax = fig.subplots(1, 1)

ax.plot(tloss, 'b-', label='train loss')
ax.plot(vloss, 'r-', label='valid loss')

ax.set_title('4x sresnet')
ax.set_xticks(np.arange(0, len(tloss) + 1, 5))
ax.legend()
plt.show()

print(f'Best train loss: {min(tloss):.4f}')
print(f'Best valid loss: {min(vloss):.4f}')

In [None]:
%%time

N = 100
CROP = 512
SCALE = 4

ssim, psnr = evaluate(N, model_state['model'], SCALE, CROP)

print(f'avg SSIM bicubic:  {sum(ssim["bic"]) / N:.4f}')
print(f'avg SSIM superres: {sum(ssim["sres"]) / N:.4f}')

print()

print(f'avg PSNR bicubic:  {sum(psnr["bic"]) / N:.4f}') 
print(f'avg PSNR superres: {sum(psnr["sres"]) / N:.4f}') 
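
The two evaluation cells repeat the same averaging. One possible way to factor it out is sketched below, assuming `evaluate` returns dictionaries with `'bic'` and `'sres'` lists as used above. The helper `summarize` and the sample numbers are made up for illustration and are not real results.

```python
from statistics import mean


def summarize(name, scores):
    """Print the average bicubic and super-resolution score and return the gain."""
    bic, sres = mean(scores['bic']), mean(scores['sres'])
    print(f'avg {name} bicubic:  {bic:.4f}')
    print(f'avg {name} superres: {sres:.4f}')
    return sres - bic


# made-up per-crop PSNR values, not real results
gain = summarize('PSNR', {'bic': [28.0, 30.0], 'sres': [30.0, 32.0]})
print(f'gain: {gain:.4f}')   # gain: 2.0000
```

Returning the difference makes it easy to compare how much each scale gains over the bicubic baseline.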