In [None]:
# import libraries
import os

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision import datasets

from tqdm.notebook import tqdm
from matplotlib import pyplot as plt
import numpy as np
import random
import lpips
from IPython.display import clear_output
import warnings
warnings.filterwarnings('ignore')

from utils.utils import (downsample, upsample)
from utils.utils import unfreeze, freeze, forward_chop
from utils.fid_score import (get_generated_inception_stats, get_hr_inception_stats, 
                             calculate_frechet_distance)

from models.edsr_G import EDSR
from models.upsample_plus_unet import UNet

from dataset_utils.aim19_datasets import AugDataset, TestDataset

%matplotlib inline

In [None]:
# --- Evaluation configuration ---
DATASET = 'AIM19'                 # checked by the metric cells before they run
G_ARCH = 'EDSR'
SCALE_FACTOR = 4                  # passed to both the generator and the test dataset
CROP_SIZE = None  # 128
BATCH_SIZE, NUM_WORKERS = 1, 1    # DataLoader settings for the test set

## Load pretrained model

In [None]:
# Order CUDA devices by PCI bus ID and expose only device '0' to this process.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '0',
})

In [None]:
# Build the EDSR generator for the configured upscaling factor and move it to the GPU.
G = EDSR(scale_factor=SCALE_FACTOR, device='cuda').cuda()

In [None]:
# Restore the pretrained generator weights.
# map_location='cuda' remaps the stored tensors onto the currently visible GPU,
# so a checkpoint saved from a different device index still loads even though
# CUDA_VISIBLE_DEVICES exposes a single device.
G.load_state_dict(torch.load('path_to_state_dict', map_location='cuda'))

G.cuda();
# NOTE(review): freeze() comes from utils.utils; presumably it disables
# gradients for the generator's parameters — confirm.
freeze(G);

## Load Datasets

In [None]:
# Only the AIM19 test split is supported by this notebook.
assert DATASET == 'AIM19'
dataset = TestDataset(
    hr_dir='path_to_hr_test',
    lr_dir='path_to_lr_test',
    scale_factor=SCALE_FACTOR,
    crop_size=CROP_SIZE,
)

In [None]:
# Fixed-order loader over the test set; the metric loops below index results by batch position.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

## LPIPS

In [None]:
# IMPORTANT: Inputs must be in [-1, 1]

In [None]:
# LPIPS metric with the 'alex' backbone, placed on the GPU.
loss_fn_alex = lpips.LPIPS(net='alex').cuda()

In [None]:
print("===> Calculate LPIPS.")

assert DATASET == 'AIM19'
# `losses` holds one slot per image and each loop step writes a single scalar,
# so the bookkeeping is only valid for single-image batches.
assert dataloader.batch_size == 1, 'LPIPS evaluation expects batch_size == 1'

losses = np.zeros((len(dataset)))
# Pure inference: disable autograd so no graph is kept alive on the GPU.
with torch.no_grad():
    for i, (X, Y) in tqdm(enumerate(dataloader), total=len(dataloader)):
        # X is the reference image and Y the generator input
        # (presumably HR / LR, given TestDataset's hr_dir / lr_dir arguments).
        X = X.cuda()
        Y = Y.cuda()
        G_Y = G(Y)
        # LPIPS expects inputs in [-1, 1]; clamp away any overshoot.
        X = torch.clamp(X, -1, 1)
        G_Y = torch.clamp(G_Y, -1, 1)

        loss = loss_fn_alex(X, G_Y).squeeze()
        losses[i] = loss.item()
        # Release GPU memory eagerly between images.
        del X, Y, G_Y, loss
        torch.cuda.empty_cache();
out = np.mean(losses)

print('mean LPIPS = %f'%out)

## FID

For **CelebA**:

In [None]:
# CelebA path: statistics (mu1, sigma1) for the reference images, as returned by
# get_hr_inception_stats from utils.fid_score.
# NOTE(review): unlike the AIM19 cells, this cell has no `assert DATASET == ...`
# guard, so it also executes on "Run All" when DATASET == 'AIM19' — confirm this is intended.
mu1, sigma1 = get_hr_inception_stats(verbose=True, batch_size=50)

In [None]:
# CelebA path: the same kind of statistics (mu2, sigma2), computed for the outputs of generator G.
mu2, sigma2 = get_generated_inception_stats(G, verbose=True, batch_size=50)

In [None]:
# Frechet distance between the two sets of statistics; shown as the cell output.
calculate_frechet_distance(mu1, sigma1, mu2, sigma2)

For **AIM19**:
- Datasets of random crops for FID calculation are stored as h5 files, prepared with the `utils/aim19_prepare_data.py` script.
- `h5dataset` from `dataset_utils/aim19_h5_datasets.py` is used to extract the images from the h5 files.
- `h5dataset` outputs LR in $[-1, 1]$ and HR in $[0, 1]$, channels first.

In [None]:
assert DATASET == 'AIM19'
# NOTE: these imports rebind get_hr_inception_stats, get_generated_inception_stats and
# calculate_frechet_distance, replacing the versions imported from utils.fid_score at the
# top of the notebook. Cells below this point use the AIM19 variants.
from dataset_utils.aim19_h5_datasets import h5dataset
from utils.aim19_fid_score import (get_hr_inception_stats, get_generated_inception_stats, 
                                   calculate_frechet_distance)

In [None]:
print('Prepare datasets for test partition.')
try:
    # Reuse cached statistics when the .npz archive is available; the context
    # manager closes the archive once both arrays have been read.
    with np.load('path_to_hr_test_inception_stats in .npz format') as stats:
        mu, sigma = stats['mu'], stats['sigma']
except (OSError, KeyError, ValueError) as err:
    # Cache missing, unreadable or incomplete: recompute from the HR test crops.
    # Only cache-related errors are handled here, so an interrupt or an
    # unrelated bug is no longer silently turned into a full recomputation.
    print('Cached inception stats unavailable (%s); recomputing.' % err)
    d = h5dataset(partition='test', mode='hr')
    mu, sigma = get_hr_inception_stats(dataset=d, batch_size=50)

In [None]:
assert DATASET == 'AIM19'
# 'lr' crops of the test partition are handed to the generator-side statistics helper.
d = h5dataset(partition='test', mode='lr')
# m, s: statistics for G's outputs; mu, sigma come from the HR statistics cell.
m, s = get_generated_inception_stats(G=G, dataset=d, batch_size=50, verbose=True)
fid = calculate_frechet_distance(m, s, mu, sigma)
print('Test FID = %f'%fid)

## SSIM

In [None]:
from skimage.metrics import structural_similarity as compare_ssim

**IMPORTANT:** Inputs must be in [0, 255]

In [None]:
print("===> Calculate SSIM.")

assert DATASET == 'AIM19'
# Only the first image of every batch (X[0] / G_Y[0]) is scored and `losses`
# has one slot per image, so larger batches would silently skew the mean.
assert dataloader.batch_size == 1, 'SSIM evaluation expects batch_size == 1'

losses = np.zeros(len(dataset))
# Pure inference: disable autograd so no graph is kept alive on the GPU.
with torch.no_grad():
    for i, (X, Y) in tqdm(enumerate(dataloader), total=len(dataloader)):
        Y = Y.cuda()
        G_Y = G(Y)

        # Rescale from [-1, 1] to [0, 1] and clamp away any overshoot.
        X = X[0].mul(0.5).add(0.5)
        G_Y = G_Y[0].mul(0.5).add(0.5)
        X = torch.clamp(X, 0, 1)
        G_Y = torch.clamp(G_Y, 0, 1)

        # Channels-last uint8 images in [0, 255], as required for SSIM here.
        X = (X.permute(1, 2, 0).detach().cpu().numpy() * 255.0).round().astype(np.uint8)
        G_Y = (G_Y.permute(1, 2, 0).detach().cpu().numpy() * 255.0).round().astype(np.uint8)
        # NOTE(review): newer scikit-image releases replace `multichannel=True`
        # with `channel_axis=-1` — confirm against the installed version.
        loss = compare_ssim(X, G_Y, multichannel=True).squeeze()

        losses[i] = loss.item()
        # Release GPU memory eagerly between images.
        del X, Y, G_Y, loss
        torch.cuda.empty_cache();

out = losses.mean()
print('mean SSIM = %f'%out)

## PSNR

In [None]:
from piq import psnr

**IMPORTANT:** Inputs must be in [0, 1]

In [None]:
print("===> Calculate PSNR.")

assert DATASET == 'AIM19'
# `losses` has one slot per image but is filled once per batch, so with larger
# batches the unfilled zero entries would silently drag the mean down.
assert dataloader.batch_size == 1, 'PSNR evaluation expects batch_size == 1'

losses = np.zeros((len(dataset)))
# Pure inference: disable autograd so no graph is kept alive on the GPU.
with torch.no_grad():
    for i, (X, Y) in tqdm(enumerate(dataloader), total=len(dataloader)):
        X = X.cuda()
        Y = Y.cuda()
        G_Y = G(Y)

        # Rescale from [-1, 1] to [0, 1] and clamp away any overshoot.
        X = torch.clamp(X.mul(0.5).add(0.5), 0, 1)
        G_Y = torch.clamp(G_Y.mul(0.5).add(0.5), 0, 1)

        loss = psnr(X, G_Y).squeeze()
        losses[i] = loss.item()
        # Release GPU memory eagerly between images.
        del X, Y, G_Y, loss
        torch.cuda.empty_cache();

out = losses.mean()
print('mean PSNR = %f'%out)

# AIM19 Color Palettes (visualization & variance)

In [None]:
def cycle(iterable):
    """Yield the items of `iterable` endlessly.

    Each pass re-iterates `iterable` from scratch, so a re-iterable object
    (e.g. a list or a DataLoader) is restarted once it is exhausted.
    """
    while True:
        yield from iterable

In [None]:
# Randomly augmented crops of the AIM19 test images.
# The HR crop side (128) is 4x the LR crop side (32).
HR_dataset = AugDataset(
    datadir='path_to_aim19_hr_test',
    crop_size=128,
    flips=True,
    rotations=True,
)
LR_dataset = AugDataset(
    datadir='path_to_aim19_lr_test',
    crop_size=32,
    flips=True,
    rotations=True,
)

In [None]:
# Fixed-order loaders over the augmented HR (X) and LR (Y) crops.
X_dataloader = DataLoader(HR_dataset, shuffle=False, batch_size=100, num_workers=20)
Y_dataloader = DataLoader(LR_dataset, shuffle=False, batch_size=100, num_workers=20)
# cycle() returns a generator, which is already its own iterator.
X_iter = cycle(X_dataloader)
Y_iter = cycle(Y_dataloader)

In [None]:
def plot_rgb_cloud(cloud, ax):
    """Scatter an RGB point cloud on a 3D axes, colouring each point by its own value.

    Args:
        cloud: 2-D array indexed as ``cloud[:, 0]``, ``cloud[:, 1]``, ``cloud[:, 2]``
            for the red, green and blue coordinates.
        ax: axes object providing ``scatter`` and the x/y/z tick and label setters.
    """
    # Marker colours must be valid RGB values, so they are clipped to [0, 1];
    # the plotted coordinates themselves are left unclipped.
    point_colors = np.clip(cloud, 0, 1)

    # Hide tick marks on all three axes.
    for clear_ticks in (ax.set_yticks, ax.set_xticks, ax.set_zticks):
        clear_ticks([])

    red, green, blue = cloud[:, 0], cloud[:, 1], cloud[:, 2]
    ax.scatter(red, green, blue, c=point_colors)

    # Negative padding pulls the labels in towards the tick-less axes.
    for set_label, text in ((ax.set_xlabel, 'Red'),
                            (ax.set_ylabel, 'Green'),
                            (ax.set_zlabel, 'Blue')):
        set_label(text, labelpad=-10)

In [None]:
SIZE = 128*8   # number of pixels sampled for each colour cloud
s = 100        # number of repetitions used for the variance estimate
CHUNK = 20     # images pushed through the generator per forward pass
pc_var_OTS = np.zeros((s))

# The weights are identical for every repetition, so build and load the
# generator once instead of re-reading the checkpoint on each iteration.
# map_location='cuda' keeps the load working for checkpoints saved on a
# different GPU index than the one currently visible.
G = EDSR(scale_factor=SCALE_FACTOR, device='cuda').cuda()
G.load_state_dict(torch.load('path_to_state_dict', map_location='cuda'))
freeze(G);

for k in tqdm(range(s)):

    fig = plt.figure(figsize=(4, 4), dpi=100)
    ax = fig.add_subplot(111, projection='3d')

    # NOTE(review): a fresh iterator is created on every repetition, so with
    # shuffle=False this always draws the first batch of the loader (the
    # dataset's augmentations presumably still differ between draws) — confirm
    # whether the cycling `Y_iter` was meant to be used instead.
    Y = next(iter(Y_dataloader))

    # Push the batch through G in chunks; each chunk is rescaled from [-1, 1]
    # to [0, 1] and flattened into a (pixels, 3) table of RGB rows.
    pushed_chunks = []
    with torch.no_grad():  # inference only: no autograd graph is needed
        for i in range(Y.shape[0] // CHUNK):
            Y_push = G(Y[i*CHUNK:(i+1)*CHUNK].to(device='cuda', dtype=torch.float32))
            Y_push = Y_push.add(1).div(2).permute(0, 2, 3, 1).flatten(start_dim=0, end_dim=2)
            pushed_chunks.append(Y_push.cpu())
            del Y_push
            torch.cuda.empty_cache()

    # Stack chunks row-wise (dim=0) so that every generated pixel stays one
    # (R, G, B) row. Concatenating along dim=1 would append later chunks as
    # extra columns, which the channel indexing below would never read.
    pixels = torch.cat(pushed_chunks, dim=0).numpy()

    # Draw whole pixels with a single set of indices, so each cloud point is
    # an actual RGB triple rather than three unrelated channel values.
    idx = np.random.choice(pixels.shape[0], size=SIZE)
    Y_pushed = pixels[idx]

    # Total colour variance = sum of the per-channel variances of the cloud.
    pc_var_OTS[k] = np.sum(np.var(Y_pushed, axis=0))
    plot_rgb_cloud(Y_pushed, ax)
    ax.set_xlim(0, 1); ax.set_ylim(0, 1); ax.set_zlim(0, 1); ax.title.set_text('OTS (ours)')

    # Replace the previous figure in the cell output and free it afterwards.
    clear_output(wait=True)
    plt.show(); plt.close(fig)

In [None]:
# Report mean and standard deviation of the per-repetition colour variances.
print('Variance of OTS (ours) color palette = %.2f +- %.2f'%(pc_var_OTS.mean(), pc_var_OTS.std()))