# Super Resolution Network


In [None]:
import torch
import matplotlib.pyplot as plt

import numpy as np
from PIL import Image
from skimage import metrics

# custom modules
from utils import convert_img, upscale_img

SEED = 42
# Upscaling factor of the super-resolution model. Must match the checkpoint
# loaded later in this notebook ('2x-sresnet.pt').
SCALE = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the CPU and CUDA random number generators for reproducibility.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

## Obtain dataset

In [None]:
# Run download.py in a separate process (IPython `!` shell escape).
# Presumably fetches the DIV2K images read below — confirm in download.py.
!python download.py

## Train model based on configuration

In [None]:
# Run the training script in a separate process (IPython `!` shell escape);
# it does not see this notebook's variables such as SEED or SCALE.
# Presumably writes the '2x-sresnet.pt' checkpoint loaded below — confirm
# in train_sresnet.py.
!python train_sresnet.py

## Load trained model

In [None]:
# The checkpoint's 'model' entry is a whole model object (it is moved with
# .to() and put into inference mode below), not just a state_dict, so full
# unpickling is needed:
#   * weights_only=False: PyTorch >= 2.6 defaults to weights_only=True,
#     which refuses to unpickle arbitrary objects such as a full model
#     (the keyword itself requires PyTorch >= 1.13).
#     Unpickling can execute code, so only load checkpoints you trust.
#   * map_location=DEVICE: allows a checkpoint saved on a GPU to be loaded
#     on a CPU-only machine.
checkpoint = torch.load('2x-sresnet.pt', map_location=DEVICE, weights_only=False)
model = checkpoint['model'].to(DEVICE)

# Inference mode; equivalent to model.train(False). This affects layers
# such as dropout or batch norm, if the model has any.
model.eval()

print('model loaded')

In [None]:
# Downscaling factor; must match the upscaling factor of the loaded
# checkpoint ('2x-sresnet.pt').
scale = 2

# Take a high-res crop and downscale it bicubically to synthesize the
# low-res input that the model then upscales back to full size.
# The context manager closes the file; convert() returns an independent,
# fully loaded copy, so `hr` stays usable after the file is closed.
with Image.open('DIV2K/0040.png') as src:
    hr = src.convert('RGB').crop((0, 0, 1300, 1300))
lr = hr.resize((hr.width // scale, hr.height // scale), Image.BICUBIC)
sr = upscale_img(lr, model)

# Show each image in its own figure. The loop variable is named `title`
# rather than `type` to avoid shadowing the built-in.
for img, title in [(lr, 'low res'), (hr, 'high res'), (sr, 'super res')]:
    plt.title(title)
    plt.imshow(img)
    plt.axis('off')
    plt.show()

## Metrics

In [None]:

# Convert both PIL images to arrays with values scaled to [0, 1].
# NOTE(review): channel_axis=0 below assumes convert_img returns
# channel-first (C, H, W) data, and np.array() needs it on the CPU —
# confirm against utils.convert_img.
original_img = np.array(convert_img(hr, 'pil', '[0, 1]'))
superres_img = np.array(convert_img(sr, 'pil', '[0, 1]'))

# SSIM lies in [-1, 1]; 1 means the images are structurally identical.
ssim = metrics.structural_similarity(
    original_img, superres_img, channel_axis=0, data_range=1
)
print('SSIM:', ssim)

# PSNR is in dB; higher means more similar. Identical images have zero MSE,
# which makes PSNR infinite (division by zero).
# data_range=1 is passed explicitly, matching the SSIM call, instead of
# letting scikit-image infer the range from the array dtype.
psnr = metrics.peak_signal_noise_ratio(original_img, superres_img, data_range=1)
print('PSNR:', psnr)