In [None]:
import cv2
import time
import torch
import matplotlib.pyplot as plt
import math
import numpy as np
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
import torchvision.transforms as tf
from util.local_visualizer import LocalVisualizer

In [None]:
from PIL import Image
from skimage import data, img_as_float
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as psnr

In [None]:
# Experiment / checkpoint name to evaluate.
NAME = 'gen_new4'

# Sub-dataset directory under the data root. Known choices:
#   'HFlickr', 'Hday2night', 'HAdobe5k', 'HCOCO'
# An empty string points the data root at the top-level directory itself.
DS_NAME = ''

# Alternative configuration for the embedding model, kept for reference:
#   dataroot / dataset_root = f'../../embedding_data/{DS_NAME}',
#   model = 'emb', dataset_mode = 'emb', name = NAME

# Option overrides handed to TestOptions for the generator model.
defaults = dict(
    dataroot=f'../../generator_data/{DS_NAME}',
    model='my',
    dataset_mode='my',
    dataset_root=f'../../generator_data/{DS_NAME}',
    embedding_save_dir='./checkpoints/emb_vidit4',
    name=NAME,
)

In [None]:
# Parse the test-time options, seeded with the notebook's `defaults` overrides.
opt = TestOptions(
    defaults=defaults,
).parse()

In [None]:
# Hard-coded test-time overrides applied on top of the parsed options.
# Deterministic, one-image-at-a-time evaluation:
opt.batch_size = 1          # the evaluation loop assumes exactly one image per batch
opt.serial_batches = True   # keep dataset order (no shuffling); comment out to evaluate randomly chosen images
opt.no_flip = True          # no flip augmentation; comment out to evaluate flipped images
# Loading / display:
opt.num_threads = 0         # presumably forwarded to the DataLoader as num_workers (0 = load in the main process) — TODO confirm
opt.display_id = -1         # no visdom display; per-image metrics are written to mse.txt / fmse.txt instead

# Optional overrides:
# opt.max_dataset_size = 200   # evaluate only a subset of the data
# opt.gpu_ids = []             # run on CPU

In [None]:
# Sanity check: the options must be test options, not training options.
assert not opt.isTrain, 'expected test options (opt.isTrain must be False) for evaluation'

In [None]:
# Build the evaluation dataset and the model from the parsed options.
dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
dataset_size = len(dataset)    # get the number of images in the dataset.
print('The number of test images = %d' % dataset_size)

model = create_model(opt)      # create a model given opt.model and other options
# model.setup(opt) is bypassed: the 'latest' checkpoint is loaded and the networks
# printed explicitly below (presumably to skip scheduler creation — TODO confirm).
# model.setup(opt)               # regular setup: load and print networks; create schedulers
model.load_networks('latest')
model.print_networks(verbose=False)
model.eval()

In [None]:
# assert model.device == torch.device('cpu')

In [None]:
class NormalizeInverse(tf.Normalize):
    """Inverse of ``torchvision.transforms.Normalize``.

    ``Normalize(mean, std)`` computes ``(x - mean) / std``. This transform
    undoes it and returns ``x * std + mean``, which reconstructs the image in
    its input domain. It does so by configuring the parent ``Normalize`` with

        mean' = -mean / std,    std' = 1 / std

    because ``(x - mean') / std' == x * std + mean``. A 1e-7 term is added to
    ``std`` before inverting, to avoid division by zero; the result is
    therefore ``x * (std + 1e-7) + mean``.

    Calling an instance uses the inherited ``Normalize`` call path unchanged.
    """

    def __init__(self, mean, std):
        mean_t = torch.as_tensor(mean)
        std_t = torch.as_tensor(std)
        inv_std = 1 / (std_t + 1e-7)
        super().__init__(mean=-mean_t * inv_std, std=inv_std)

In [None]:
# Running sums of the per-image metrics, averaged after the evaluation loop.
loss_psnr, loss_mse, loss_fmse = 0.0, 0.0, 0.0

# Per-image MSE / fMSE values, later dumped one-per-line to mse.txt / fmse.txt.
mse_scores, fmse_scores = [], []

# Nominal image size (H, W).
image_size = (256, 256)

def to_img(tensor, unnorm=None):
    """Convert an image tensor into a ``uint8`` numpy image in HWC layout.

    Args:
        tensor: image tensor of shape (C, H, W), optionally preceded by
            size-1 batch dimensions (e.g. (1, C, H, W)). A bare (H, W)
            tensor is treated as a single-channel image.
        unnorm: optional callable applied to the (C, H, W) tensor before
            conversion, e.g. a transform mapping values back to [0, 1].

    Returns:
        ``np.uint8`` array of shape (H, W, C), or (H, W) when C == 1. Values
        are ``clip(x, 0, 1) * 255``, truncated toward zero.
    """
    tensor = tensor.detach().to('cpu')
    # Drop only leading size-1 (batch) dims. A blanket squeeze() would also
    # remove a size-1 channel dim and break the CHW -> HWC transpose below.
    while tensor.dim() > 3 and tensor.shape[0] == 1:
        tensor = tensor[0]
    if tensor.dim() == 2:
        tensor = tensor.unsqueeze(0)  # (H, W) -> (1, H, W)
    if unnorm is not None:
        tensor = unnorm(tensor)
    np_img = tensor.numpy().transpose((1, 2, 0))  # CHW -> HWC
    if np_img.shape[2] == 1:
        np_img = np_img[..., 0]  # grayscale: (H, W, 1) -> (H, W)
    return (np_img.clip(0, 1) * 255).astype(np.uint8)
    
# Inverse of Normalize(mean=0.5, std=0.5): computes x * 0.5 + 0.5.
# `unnorm1` is the single-channel variant.
unnorm = NormalizeInverse((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
unnorm1 = NormalizeInverse((0.5,), (0.5,))

print(f'Testing {DS_NAME}, {dataset_size} images')
n_fmse_images = 0  # images whose mask has a non-empty foreground (fMSE is defined)
# NOTE: the loop variable `data` shadows `skimage.data` imported at the top.
for i, data in enumerate(dataset):
    model.set_input(data)
    model.test()

    # Foreground mask; assumed to hold {0, 1} values (not normalized) — TODO confirm.
    mask = data['mask'].numpy()
    fore_area = np.sum(mask)      # foreground pixel count for a binary mask
    mask = mask[..., np.newaxis]  # trailing channel axis so it broadcasts over RGB

    real = to_img(model.real, unnorm=unnorm)
    harmonized = to_img(model.harmonized, unnorm=unnorm)

    mse_score = mse(harmonized, real)
    # NOTE(review): data_range is the dynamic range of this particular harmonized
    # image, not the fixed 255 of uint8 images, so PSNR is not on an absolute
    # scale. Kept as-is so numbers stay comparable with earlier runs.
    psnr_score = psnr(real, harmonized, data_range=harmonized.max() - harmonized.min())

    if fore_area > 0:
        # Foreground MSE: masked-image MSE rescaled from "all pixels" to
        # "foreground pixels". Uses the real image size rather than a
        # hard-coded 256*256 (identical for 256x256 images).
        num_pixels = real.shape[0] * real.shape[1]
        fmse_score = mse(harmonized * mask, real * mask) * num_pixels / fore_area
        loss_fmse += fmse_score
        n_fmse_images += 1
    else:
        # Empty foreground: fMSE is undefined; record NaN rather than dividing by zero.
        fmse_score = float('nan')

    loss_psnr += psnr_score
    loss_mse += mse_score

    mse_scores.append(mse_score)
    fmse_scores.append(fmse_score)

    if i % 100 == 0:
        print(f'Done {i+1} / {dataset_size}: MSE={mse_score:.2f}, PSNR={psnr_score:.2f}, FMSE={fmse_score:.2f}')

with open('mse.txt', 'w') as f:
    for mse_score in mse_scores:
        f.write(f'{mse_score:.2f}\n')

with open('fmse.txt', 'w') as f:
    for fmse_score in fmse_scores:
        f.write(f'{fmse_score:.2f}\n')

# Average over the images actually evaluated (equal to dataset_size for a full
# pass); fMSE is averaged only over images with a non-empty foreground.
n_images = len(mse_scores)
if n_images == 0:
    print('No images were evaluated.')
else:
    avg_fmse = loss_fmse / n_fmse_images if n_fmse_images else float('nan')
    print(f'Average MSE loss: {(loss_mse / n_images):.2f}')
    print(f'Average PSNR loss: {(loss_psnr / n_images):.2f}')
    print(f'Average FMSE loss: {avg_fmse:.2f}')