In [None]:
# https://github.com/leftthomas/SRGAN

import os
from math import log10

import pandas as pd
import torch
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision
from tqdm.notebook import tqdm

from pathlib import Path
from PIL import Image, ImageFile

# import pytorch_ssim
import data_utils
from data_utils import CustumDataset, display_transform
from loss import GeneratorLoss
from model import Generator, Discriminator

from skimage.metrics import structural_similarity as compare_ssim
import imutils
import cv2
import numpy as np

# Resolution settings shared by every later cell: both values are passed to
# the dataset and embedded in the checkpoint / output file names, and
# upscale_factor also configures the Generator and the upsampling baselines.
high_res, upscale_factor = 128, 8

In [None]:
def compute_ssim(A, B):
    """Return the SSIM score between the first images of two batches.

    Each image is converted to an 8-bit PIL image, reduced to a single
    grayscale channel with OpenCV, and compared with scikit-image's
    ``structural_similarity``.

    Args:
        A: image batch; only ``A[0]`` is compared. Presumably a (C, H, W)
            float tensor in [0, 1] after indexing -- TODO confirm for callers.
        B: second image batch with the same spatial size; only ``B[0]`` is
            compared.

    Returns:
        The SSIM score (1.0 for identical images).
    """
    # Built locally (equivalent to the ``pil_transform`` defined in a later
    # cell) so this function does not depend on that cell having been run.
    to_pil = torchvision.transforms.ToPILImage()

    def to_gray(batch):
        # np.array(PIL image) is in RGB channel order, so RGB2GRAY is the
        # correct conversion; BGR2GRAY would swap the red/blue luma weights.
        return cv2.cvtColor(np.array(to_pil(batch[0])), cv2.COLOR_RGB2GRAY)

    # Both grayscale images are uint8, so the dynamic range is 255. Only the
    # score is needed, so the full SSIM difference map is not requested.
    return compare_ssim(to_gray(A), to_gray(B), data_range=255)

In [None]:
# Build the evaluation data pipeline and networks, then restore the trained
# weights. Everything is placed on the GPU when one is available, so the
# notebook also runs on a CPU-only machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

test_set = CustumDataset('data/valid/test', high_res, upscale_factor)
test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)

netG = Generator(upscale_factor).to(device)
netD = Discriminator().to(device)

generator_criterion = GeneratorLoss().to(device)

optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

# map_location lets a checkpoint that was saved from GPU tensors be loaded
# on a CPU-only host.
checkpoint = torch.load(f'./res{high_res}_uf{upscale_factor}.tar', map_location=device)

netG.load_state_dict(checkpoint['G_model_state_dict'])
netD.load_state_dict(checkpoint['D_model_state_dict'])
optimizerG.load_state_dict(checkpoint['G_optimizer_state_dict'])
optimizerD.load_state_dict(checkpoint['D_optimizer_state_dict'])

g_loss = checkpoint['G_loss']
d_loss = checkpoint['D_loss']

# The networks are only used for inference below.
netG.eval()
netD.eval()

# Non-learned upsampling baselines to compare the generator against.
# align_corners=False is PyTorch's default; it is spelled out to make the
# bilinear sampling convention explicit.
upsample_nearest = torch.nn.Upsample(scale_factor=upscale_factor, mode='nearest').to(device)
upsample_bilinear = torch.nn.Upsample(
    scale_factor=upscale_factor, mode='bilinear', align_corners=False
).to(device)

pil_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToPILImage(),
])

print(f'loaded checkpoint on {device} (G_loss={g_loss}, D_loss={d_loss})')


In [None]:
# Evaluate the generator against nearest-neighbour and bilinear upsampling on
# the test set, then save a comparison grid to ``out_path``. Each grid row is
# nearest | bilinear | SR | HR for one test image.
out_path = f'results/res_{high_res}_srf_{upscale_factor}/'
os.makedirs(out_path, exist_ok=True)

# display_transform() builds a new transform pipeline on every call; build it once.
to_display = display_transform()


def image_metrics(pred, target):
    """Return ``(mse, psnr, ssim)`` of ``pred`` measured against ``target``.

    PSNR uses the maximum value of ``target`` as the peak signal, i.e.
    ``10 * log10(max(target)**2 / mse)``. A perfect match (mse == 0) gives
    ``inf``; a target whose maximum is 0 gives ``-inf`` instead of raising
    a math-domain error.
    """
    mse = ((pred - target) ** 2).mean().item()
    peak = target.max().item()
    if mse == 0:
        psnr = float('inf')
    else:
        ratio = peak ** 2 / mse
        psnr = 10 * log10(ratio) if ratio > 0 else float('-inf')
    return mse, psnr, compute_ssim(pred, target)


images = []
metric_rows = []
# Feed inputs to whatever device the generator was placed on.
eval_device = next(netG.parameters()).device

with torch.no_grad():
    for index, (lr, hr) in enumerate(tqdm(test_loader, desc='evaluating')):
        lr = lr.to(eval_device)
        hr = hr.to(eval_device)

        sr = netG(lr)
        # Insertion order fixes both the grid column order and the metric rows.
        candidates = {
            'nearest': upsample_nearest(lr),
            'bilinear': upsample_bilinear(lr),
            'srgan': sr,
        }

        images.extend(to_display(t.cpu().squeeze(0)) for t in (*candidates.values(), hr))

        for method, pred in candidates.items():
            mse, psnr, ssim = image_metrics(pred, hr)
            metric_rows.append(
                {'image': index, 'method': method, 'mse': mse, 'psnr': psnr, 'ssim': ssim}
            )

metrics = pd.DataFrame(metric_rows, columns=['image', 'method', 'mse', 'psnr', 'ssim'])

if metrics.empty:
    print('test set is empty; nothing to evaluate')
else:
    # Mean of each metric per method over the whole test set.
    print(metrics.groupby('method')[['mse', 'psnr', 'ssim']].mean())

    image_stack = torch.stack(images)
    print(image_stack.size())
    saving_image = utils.make_grid(image_stack, nrow=4, padding=10)
    utils.save_image(saving_image, os.path.join(out_path, 'result.png'), padding=5)
