In [None]:
import math
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pywt
import skimage
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader

from src.data import SuperResolutionImageDataset
from src.srgan import Generator

# 1. PSNR (Peak Signal-to-Noise Ratio)

In [None]:
def compute_psnr(hr, sr, max_val=255.0):
  """Peak signal-to-noise ratio between two images, in decibels.

  Args:
    hr: reference (high-resolution) image, any array-like.
    sr: reconstructed image with the same shape as ``hr``.
    max_val: largest possible pixel value; 255 for 8-bit data (default),
      1.0 for images scaled to [0, 1].

  Returns:
    ``10 * log10(max_val**2 / MSE)``; ``float('inf')`` when the images are
    identical (MSE == 0).
  """
  # Work in float64: subtracting/squaring unsigned integer arrays (e.g. uint8)
  # would silently wrap around.
  diff = np.asarray(hr, dtype=np.float64) - np.asarray(sr, dtype=np.float64)
  mse = np.mean(diff ** 2)
  if mse == 0:
    return float('inf')
  return 10 * math.log10(max_val ** 2 / mse)

# 2. SSIM (Structural Similarity Index)

TODO: not implemented yet — this section has no code.

# 3. Histogram-Based Similarity (Intersection and Chi-Squared Distances)

In [None]:
def downgrade_step(img, downgrade):
  """Degrade ``img`` by processing each channel in square blocks.

  Every channel is split into non-overlapping ``downgrade`` x ``downgrade``
  blocks with ``skimage.util.view_as_blocks`` and recombined by
  ``join_blocks``.
  NOTE(review): ``join_blocks`` is not defined in the cells shown here; it is
  presumably expected to return an array with the channel's original
  (rows, cols) shape — confirm and define/import it before use.

  Args:
    img: array of shape (rows, cols, channels).
    downgrade: positive integer block size that divides both rows and cols.

  Returns:
    An integer array with the same shape as ``img``, or ``None`` (after
    printing a message) when ``downgrade`` is not a valid block size.
  """
  rows, cols, chan = img.shape
  # Check the type and sign first: a non-numeric value would make the modulo
  # raise TypeError, and 0 would raise ZeroDivisionError.
  if (not isinstance(downgrade, (int, np.integer)) or downgrade <= 0
      or rows % downgrade != 0 or cols % downgrade != 0):
    print('Not a valid degradation!')
    return None

  new_img = np.zeros(img.shape, dtype=int)
  for c in range(chan):
    new_img[:, :, c] = join_blocks(
        skimage.util.view_as_blocks(img[:, :, c], (downgrade, downgrade)))
  return new_img

In [None]:
def get_color_histogram(img, bins=8):
    """3-D HSV colour histogram of an RGB image.

    Args:
        img: RGB image of shape (H, W, 3); uint8 or float32 (float64 is cast
            to float32 because ``cv2.cvtColor`` has no float64 path).
        bins: number of bins per HSV channel. The default of 8 gives an
            8 x 8 x 8 = 512-bin histogram; 256 bins per channel would allocate
            256**3 (~16.7M) bins per call.

    Returns:
        float32 array of shape (bins, bins, bins) holding raw pixel counts.
    """
    img = np.asarray(img)
    if img.dtype == np.float64:
        img = img.astype(np.float32)
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    # OpenCV's HSV value ranges depend on the input depth: 8-bit images give
    # H in [0, 180) and S, V in [0, 255]; float32 images give H in [0, 360)
    # and S, V in [0, 1].
    if hsv.dtype == np.uint8:
        ranges = [(0, 180), (0, 255), (0, 255)]
    else:
        ranges = [(0, 360), (0, 1), (0, 1)]
    # np.histogramdd takes one (N, channels) sample array and includes the
    # upper edge in the last bin, so pixels at the maximum S or V are counted.
    hist, _ = np.histogramdd(hsv.reshape(-1, hsv.shape[2]), bins=bins, range=ranges)
    return hist.astype(np.float32)

In [None]:
def get_texture_histogram(img):
    """HSV histogram of a single-level Haar-wavelet "texture" image.

    Each of the three colour channels is decomposed with
    ``pywt.dwt2(..., 'haar')``. The approximation (cA), horizontal-detail (cH)
    and vertical-detail (cV) sub-bands are blended per channel as
    ``0.8 * (0.75 * cA + 0.25 * cH) + 0.2 * cV``, which equals
    ``0.6 * cA + 0.2 * cH + 0.2 * cV``; the diagonal detail is discarded.
    The blended (roughly half-resolution) image is cast to float32 and passed
    to ``get_color_histogram``.

    Args:
        img: RGB image of shape (H, W, >=3); only the first three channels
            are used.

    Returns:
        Whatever ``get_color_histogram`` returns for the blended image.
    """
    img = np.asarray(img)
    blended = []
    # Keep R, G, B order: get_color_histogram converts with COLOR_RGB2HSV, so
    # the stacked image must be RGB for hue to be computed correctly.
    for c in range(3):
        cA, (cH, cV, _) = pywt.dwt2(img[:, :, c], 'haar')
        blended.append(0.8 * (0.75 * cA + 0.25 * cH) + 0.2 * cV)
    new_img = np.stack(blended, axis=-1)
    return get_color_histogram(new_img.astype('float32'))

In [None]:
def get_distance(histA,histB):
    """Histogram-intersection distance between two histograms.

    Both histograms are flattened and normalised to unit mass, then
    ``1 - sum(min(a_i, b_i))`` is returned: 0 for identical distributions and
    1 for distributions with no bin in common. Every bin is used, whatever
    the histogram size.

    Args:
        histA, histB: array-likes with the same number of elements.

    Returns:
        The distance as a Python float, or ``nan`` when either histogram sums
        to zero (it cannot be normalised).

    Raises:
        ValueError: if the histograms have different numbers of elements.
    """
    a = np.asarray(histA, dtype=np.float64).ravel()
    b = np.asarray(histB, dtype=np.float64).ravel()
    if a.size != b.size:
        raise ValueError(f'histograms differ in size: {a.size} vs {b.size}')
    total_a, total_b = a.sum(), b.sum()
    if total_a == 0 or total_b == 0:
        return float('nan')
    # Out-of-place division: the arrays may alias the caller's histograms.
    a = a / total_a
    b = b / total_b
    return float(1.0 - np.minimum(a, b).sum())

def get_chi_distance(histA,histB):
    """Chi-squared distance between two histograms.

    Both histograms are flattened and normalised to unit mass, then
    ``0.5 * sum((a_i - b_i) ** 2 / (a_i + b_i))`` is returned, skipping bins
    that are empty in both. The result is 0 for identical distributions and
    1 for distributions with no bin in common. Every bin is used, whatever
    the histogram size.

    Args:
        histA, histB: array-likes with the same number of elements.

    Returns:
        The distance as a Python float, or ``nan`` when either histogram sums
        to zero (it cannot be normalised).

    Raises:
        ValueError: if the histograms have different numbers of elements.
    """
    a = np.asarray(histA, dtype=np.float64).ravel()
    b = np.asarray(histB, dtype=np.float64).ravel()
    if a.size != b.size:
        raise ValueError(f'histograms differ in size: {a.size} vs {b.size}')
    total_a, total_b = a.sum(), b.sum()
    if total_a == 0 or total_b == 0:
        return float('nan')
    # Out-of-place division: the arrays may alias the caller's histograms.
    a = a / total_a
    b = b / total_b
    denom = a + b
    used = denom != 0  # bins empty in both histograms contribute nothing
    return float(0.5 * np.sum((a[used] - b[used]) ** 2 / denom[used]))

In [None]:
def get_similarity_pwh(hr,sr,n_bins):
  """Patch-wise histogram distance between two RGB images.

  ``n_bins`` sets the patch size to ``(H // n_bins, W // n_bins)``; it is a
  patches-per-side count, not a histogram bin count. The image is covered by
  as many whole patches as fit, and leftover rows/columns at the bottom/right
  edge are ignored. When H and W are divisible by ``n_bins`` this is an
  ``n_bins`` x ``n_bins`` grid.

  For each pair of corresponding patches, ``get_chi_distance`` is computed
  between their colour histograms and between their texture histograms. The
  per-patch results are averaged.

  Args:
    hr, sr: images of identical shape (H, W, 3).
    n_bins: positive int, at most ``min(H, W)``.

  Returns:
    ``[color_distance, texture_distance, distance]``. ``distance`` is the mean
    of the per-patch ``0.5 * color + 0.5 * texture`` values.

  Raises:
    ValueError: if the shapes differ or ``n_bins`` is out of range.
  """
  if hr.shape != sr.shape:
    raise ValueError(
        f'hr and sr must have the same shape, got {hr.shape} and {sr.shape}')
  rows, cols = hr.shape[:2]
  if not 0 < n_bins <= min(rows, cols):
    raise ValueError(f'n_bins must be in [1, {min(rows, cols)}], got {n_bins}')
  patch_rows, patch_cols = rows // n_bins, cols // n_bins

  color_distances = []
  texture_distances = []
  combined_distances = []
  for i in range(rows // patch_rows):
    for j in range(cols // patch_cols):
      window = (slice(i * patch_rows, (i + 1) * patch_rows),
                slice(j * patch_cols, (j + 1) * patch_cols))
      sub_img_hr = hr[window]
      sub_img_sr = sr[window]
      # Each distance is computed once and reused for the combined score.
      color = get_chi_distance(get_color_histogram(sub_img_hr),
                               get_color_histogram(sub_img_sr))
      texture = get_chi_distance(get_texture_histogram(sub_img_hr),
                                 get_texture_histogram(sub_img_sr))
      color_distances.append(color)
      texture_distances.append(texture)
      combined_distances.append(color * 0.5 + texture * 0.5)

  n_patches = len(combined_distances)
  return [sum(color_distances) / n_patches,
          sum(texture_distances) / n_patches,
          sum(combined_distances) / n_patches]

In [None]:
def get_similarity_pwh2(hr,sr,n_bins):
  """Whole-image histogram distance between two RGB images.

  This is the global (single-patch) counterpart of ``get_similarity_pwh``.
  It computes ``get_chi_distance`` between the colour histograms and between
  the texture histograms of the full images.

  Args:
    hr, sr: images of shape (H, W, 3).
    n_bins: unused. It is accepted only so this function can be swapped in
      for ``get_similarity_pwh`` without changing call sites.

  Returns:
    ``[color_distance, texture_distance, distance]`` with
    ``distance = 0.5 * color_distance + 0.5 * texture_distance``.
  """
  color_distance = get_chi_distance(get_color_histogram(hr),
                                    get_color_histogram(sr))
  texture_distance = get_chi_distance(get_texture_histogram(hr),
                                      get_texture_histogram(sr))
  distance = color_distance * 0.5 + texture_distance * 0.5
  return [color_distance, texture_distance, distance]

# 4. Evaluate Model

In [None]:
# Pick the fastest available backend: CUDA GPU, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Input parameters
# -- data and splitting
data_path = "./data/Renders/"
train_test_val_split = [0.7, 0.15, 0.15]  # fractions handed to random_split
seed = 1317                               # seeds the train/test/val split

# -- model geometry
r = 4           # upscaling factor: each LR side is the HR side // r
n_channels = 3
B = 1           # second argument to Generator(lr_dimension, B); see src.srgan

# -- data loading
batch_size_train = 128
batch_size_validation = 128
workers = 1

In [None]:
# Image geometry: HR crops are hr_size, and each LR side is r times smaller.
hr_size = (128, 128)
lr_size = tuple(side // r for side in hr_size)
hr_dimension = hr_size + (n_channels,)
lr_dimension = lr_size + (n_channels,)

In [None]:
# HR samples are random hr_size crops. target_transform resizes to lr_size.
# NOTE(review): assumes SuperResolutionImageDataset applies target_transform
# to the cropped HR image to build "lr_sample" — confirm in src/data.py.
dataset = SuperResolutionImageDataset(
    root = data_path,
    transform = transforms.Compose([
        transforms.RandomCrop(hr_size),
    ]),
    target_transform = transforms.Compose([
        # transforms.GaussianBlur(3,1),  # optional blur before downsampling
        transforms.Resize(lr_size),
    ])
)

# A fixed seed reproduces the same train/test/validation partition on every run.
random_generator = torch.Generator().manual_seed(seed)
train_dataset, test_dataset, validation_dataset = torch.utils.data.random_split(dataset,train_test_val_split,random_generator)

# Evaluation uses a fixed order. The DataLoader derives its workers' base seed
# from `generator`, so the RandomCrop applied inside the workers (and every
# metric and exported example below) is reproducible across runs.
validation_dataloader = DataLoader(
    validation_dataset,
    batch_size = batch_size_validation,
    shuffle = False,
    num_workers = workers,
    generator = torch.Generator().manual_seed(seed),
)

In [None]:
# Load the trained generator and switch it to inference mode.
# The path is relative to the project root, like data_path above.
generator_checkpoint = './checkpoints_perc_disc/generator'
netG = Generator(lr_dimension,B)
netG.to(device)
# map_location lets a checkpoint saved on another device (CUDA/MPS/CPU) load here.
gen_load = torch.load(generator_checkpoint, map_location=device)
netG.load_state_dict(gen_load['model_state_dict'])
# eval() turns off training-only behaviour (e.g. dropout, batch-norm updates).
netG.eval();

In [None]:
# Split each validation batch into per-image tensors (one entry per sample;
# presumably (C, H, W)) so the cells below can process images one by one.
img_list_lr = []
img_list_hr = []

for batch in validation_dataloader:
    img_list_lr.extend(batch["lr_sample"].unbind(0))
    img_list_hr.extend(batch["hr_sample"].unbind(0))

In [None]:
#https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/2
class UnNormalize(object):
    """Invert ``transforms.Normalize(mean, std)`` by computing ``x * std + mean`` per channel."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): normalized image of shape (C, H, W), or a batch
                of shape (N, C, H, W).
        Returns:
            Tensor: a new un-normalized tensor with the same shape. The input
            tensor is not modified.
        """
        # Reshape mean/std to (C, 1, 1) so they broadcast along the channel
        # axis (dim -3) for both single images and batches.
        mean = torch.as_tensor(self.mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
        std = torch.as_tensor(self.std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
        # Out-of-place on purpose: callers pass views of tensors they keep
        # (e.g. stored validation images), and those must stay unchanged.
        return tensor * std + mean


unorm = UnNormalize(mean = [0.5,0.5,0.5],
            std = [0.5,0.5,0.5])

In [None]:
# Per-image validation metrics: colour / texture / combined histogram
# distances and PSNR between the HR target and the generator's SR output.
cdist_list = []
tdist_list = []
dist_list = []
psnr_list = []

netG.eval()
with torch.no_grad():  # inference only, no autograd graph needed
    for hr_tensor, lr_tensor in zip(img_list_hr, img_list_lr):
        # clone() keeps the stored validation tensors intact even if the
        # un-normalisation works in place.
        hr_batch = unorm(hr_tensor.to(device).unsqueeze(0).clone())
        # Clamp the generator output to the valid [0, 1] image range.
        sr_batch = unorm(netG(lr_tensor.to(device).unsqueeze(0))).clamp(0.0, 1.0)
        # (1, C, H, W) tensor -> contiguous (H, W, C) numpy image.
        hr = np.ascontiguousarray(hr_batch.cpu().numpy()[0].transpose(1, 2, 0))
        sr = np.ascontiguousarray(sr_batch.cpu().numpy()[0].transpose(1, 2, 0))
        cdist, tdist, dist = get_similarity_pwh2(hr, sr, 16)
        cdist_list.append(cdist)
        tdist_list.append(tdist)
        dist_list.append(dist)
        # hr/sr are (presumably) in [0, 1] after un-normalisation, while
        # compute_psnr uses a peak value of 255, so rescale to the 0-255 range.
        psnr_list.append(compute_psnr(hr * 255.0, sr * 255.0))

In [None]:
def list_mean(values):
    """Arithmetic mean of a non-empty sequence of numbers."""
    return sum(values) / len(values)

# Average of each metric over all validation images.
print(f"Mean colour distance:   {list_mean(cdist_list)}")
print(f"Mean texture distance:  {list_mean(tdist_list)}")
print(f"Mean combined distance: {list_mean(dist_list)}")
print(f"Mean PSNR (dB):         {list_mean(psnr_list)}")

# Save Example Images

In [None]:
# Re-create the generator from its checkpoint before exporting example images.
# The path is relative to the project root, like data_path.
generator_checkpoint = './checkpoints_perc_disc/generator'
netG = Generator(lr_dimension,B)
netG.to(device)
# map_location lets a checkpoint saved on another device (CUDA/MPS/CPU) load here.
gen_load = torch.load(generator_checkpoint, map_location=device)
netG.load_state_dict(gen_load['model_state_dict'])
# eval() turns off training-only behaviour (e.g. dropout, batch-norm updates).
netG.eval();

In [None]:
# Export the first few validation examples (HR target, LR input, SR output)
# as 8-bit PNGs.
output_dir = Path('memory_images')
output_dir.mkdir(exist_ok=True)

netG.eval()
with torch.no_grad():
    for i in range(min(4, len(img_list_hr))):
        hr = img_list_hr[i]
        lr = img_list_lr[i]
        # Run the generator before any un-normalisation so it sees the original input.
        sr = netG(lr.to(device).unsqueeze(0))
        for suffix, batch in (('hr', hr.unsqueeze(0)),
                              ('lr', lr.unsqueeze(0)),
                              ('sr_perc_disc', sr)):
            # clone() so un-normalising never touches the stored tensors.
            pixels = unorm(batch.detach().cpu().clone()).numpy()[0].transpose(1, 2, 0)
            # Clip before the uint8 cast; out-of-range values would otherwise wrap.
            pixels = (np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)
            Image.fromarray(pixels).save(output_dir / f'ex{i+1}_{suffix}.png')

# Estimate Resolution

In [None]:
def _load_rgb(path):
    """Read an image with OpenCV and return it in RGB channel order.

    cv2.imread returns BGR data and signals a missing/unreadable file by
    returning None, so both cases are handled here.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f'Could not read image: {path}')
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

# Paths are relative to the project root, where the export cell writes them.
# NOTE(review): the export cell writes ex1_sr_perc_disc.png, not ex1_sr.png —
# confirm which SR result is meant to be compared here.
hr = _load_rgb('memory_images/ex1_hr.png')
sr = _load_rgb('memory_images/ex1_sr.png')

In [None]:
# Compare the SR image with block-degraded versions of the HR image, one
# degradation factor at a time.
# NOTE(review): downgrade_step calls join_blocks, which is not defined in the
# cells shown here; define or import it before running this cell.
distance_list = []
for factor in [2, 4, 8, 16]:
    degraded = downgrade_step(hr, factor)
    if degraded is None:
        # downgrade_step rejected this factor (it prints why). Record None
        # rather than 0, because a 0 distance would read as "identical".
        distance_list.append((factor, None))
        continue
    distance_list.append((factor, get_similarity_pwh(degraded.astype('uint8'), sr, 8)))
distance_list