## Libraries

In [None]:
import time
from utils.data_load import DataLoad
from models.model import create_model
import torch
import os
import torchvision
import torch.nn.functional as F
from torch.utils import data
import torchvision.transforms as transforms

from piqa import psnr, ssim 
from piqa.utils.functional import gaussian_kernel

## Settings

In [None]:
class Opion():
    """Fixed option set for evaluating the CSA inpainting model.

    Stands in for a parsed command-line namespace: every setting is a plain
    instance attribute, and the whole object is handed to ``create_model``.
    The (misspelt) class name is kept because the notebook refers to it.
    """

    def __init__(self):
        # -- Paths and run identity -------------------------------------------
        self.dataroot = r'/home/klaudiaplk/Magisterka/Datasets/Paris_street_view/paris_street_view_fragment'  # test image directory
        self.maskroot = r'/home/klaudiaplk/Magisterka/Datasets/Irregular_Masks/test_mask/mask/testing_mask_dataset'  # test mask directory
        self.checkpoints_dir = r'/home/klaudiaplk/Magisterka/ImageInpainting/checkpoints'
        self.name = 'CSA_inpainting'
        self.model = 'csa_net'
        self.gpu_ids = [0]

        # -- Input / output geometry ------------------------------------------
        self.batchSize = 1    # must stay 1: the test loop only uses the first mask of a batch
        self.fineSize = 256   # images and masks are resized to fineSize x fineSize
        self.input_nc = 3     # input channels of the first-stage network
        self.input_nc_g = 6   # input channels of the second-stage network
        self.output_nc = 3    # output channels

        # -- Network selection ------------------------------------------------
        self.which_model_netP = 'unet_256'  # first-stage network
        self.which_model_netG = 'unet_csa'  # second-stage network
        self.which_model_netD = 'basic'     # patch discriminator
        self.which_model_netF = 'feature'   # feature patch discriminator
        self.ngf = 64                       # inner channel count (generator side)
        self.ndf = 64                       # inner channel count (discriminator side)
        self.n_layers_D = '3'               # network depth; deliberately kept as a string

        # -- Other model options; their meaning is defined by the model code,
        #    not by this notebook.
        self.norm = 'instance'
        self.use_dropout = False
        self.init_type = 'normal'
        self.init_gain = 0.02
        self.bottleneck = 512
        self.fixed_mask = 1
        self.mask_type = 'random'
        self.mask_thred = 1
        self.threshold = 5/16.0
        self.stride = 1
        self.shift_sz = 1  # size of feature patch
        self.overlap = 4
        self.skip = 0
        self.triple_weight = 1
        self.constrain = 'MSE'
        self.strength = 1
        self.cosis = 1
        self.gan_type = 'lsgan'
        self.gan_weight = 0.2
        self.lambda_A = 100
        self.gp_lambda = 10.0
        self.ncritic = 5

        # -- Run mode and training-side options --------------------------------
        # NOTE(review): isTrain is True even though this notebook only
        # evaluates; presumably create_model / model.load expect it — confirm.
        self.isTrain = True
        self.phase = 'train'
        self.continue_train = False
        self.which_epoch = ''
        self.epoch_count = 1
        self.niter = 2
        self.niter_decay = 4
        self.beta1 = 0.5
        self.lr = 0.0002
        self.lr_policy = 'lambda'
        self.lr_decay_iters = 50
        self.display_freq = 10
        self.print_freq = 2
        self.save_latest_freq = 5
        self.save_epoch_freq = 2


## Test Dataset

In [None]:
opt = Opion()

# Masks are only resized and converted to tensors.
# NOTE(review): Resize defaults to bilinear interpolation, so mask edges can
# become fractional after resizing, and the test loop later turns every
# non-zero value into True (slightly growing the hole). Confirm whether
# nearest-neighbour resizing is intended.
transform_mask = transforms.Compose([
    transforms.Resize((opt.fineSize, opt.fineSize)),
    transforms.ToTensor(),
])

# Images are resized, converted to tensors and normalised per RGB channel with
# mean 0.5 / std 0.5, which maps [0, 1] tensors to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((opt.fineSize, opt.fineSize)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

dataset_test = DataLoad(opt.dataroot, opt.maskroot, transform, transform_mask)
# Evaluation needs no shuffling. With a fixed order (and batchSize == 1) the
# i-th batch is dataset item i, so the saved "Image_(i)" files can be traced
# back to their source images.
iterator_test = data.DataLoader(dataset_test, batch_size=opt.batchSize, shuffle=False)
print(len(dataset_test))

model = create_model(opt)
total_steps = 0

## Metrics
Metrics used to evaluate the quality of the inpainting results:

PSNR (Peak Signal-to-Noise Ratio)

SSIM (Structural Similarity Index Measure)

MSE (Mean Squared Error)

In [None]:
# Mean squared error between two tensors: element-wise squared difference,
# averaged over all elements ('mean' is MSELoss's default reduction).
criterionMSE = torch.nn.MSELoss(reduction='mean')

## Model Loading

In [None]:
# Restore checkpoint number `load_epoch` via the model's own `load` method.
# NOTE(review): where the checkpoint is read from is decided inside
# model.load (presumably under opt.checkpoints_dir) — confirm.
load_epoch = 6
model.load(load_epoch)

## Model Testing

In [None]:
# Per-image metric values, appended to by the evaluation loop and averaged
# at the end. Three separate list objects, not aliases of one another.
psnr_results, ssim_results, mse_results = [], [], []

In [None]:
# Run the restored model over the whole test set. For every sample, save a
# comparison image (real_A / real_B / fake_B stacked vertically) and record
# MSE, PSNR and SSIM between real_B and fake_B.
save_dir = '/home/klaudiaplk/Magisterka/ImageInpainting/checkpoints/true'
os.makedirs(save_dir, exist_ok=True)

# 11-tap Gaussian SSIM window, repeated once per RGB channel. It is identical
# for every image, so build it (and copy it to the GPU) once, not per iteration.
kernel = gaussian_kernel(11).repeat(3, 1, 1).cuda()


def _to_unit_range(x):
    """Return ``x`` mapped from the normalised [-1, 1] range to [0, 1].

    Images are normalised with mean 0.5 / std 0.5 (i.e. to [-1, 1]); the saved
    preview undoes this with ``(x + 1) / 2``. piqa's ``psnr`` and ``ssim``
    default to ``value_range=1.0``, so passing them [-1, 1] tensors would
    under-report PSNR by about 6 dB and mis-scale SSIM's stabilising
    constants. The result is clamped so out-of-range generator outputs cannot
    skew the metrics.
    """
    return ((x + 1) / 2.0).clamp(0.0, 1.0)


# Inference only: no autograd graph is needed, and not building one keeps GPU
# memory from growing across iterations.
with torch.no_grad():
    for count, (image, mask) in enumerate(iterator_test):
        image = image.cuda()
        # Keep the first channel of the first mask in the batch as a
        # (1, 1, H, W) boolean tensor (every non-zero value becomes True).
        # NOTE(review): assumes the loader yields masks as (B, C, H, W) — confirm.
        mask = mask.cuda()[:1, :1].bool()

        model.set_input(image, mask)
        model.set_gt_latent()
        model.test()
        real_A, real_B, fake_B = model.get_current_visuals()

        # Save the three visuals in a single column (nrow=1), mapped back to [0, 1].
        pic = (torch.cat([real_A, real_B, fake_B], dim=0) + 1) / 2.0
        torchvision.utils.save_image(pic, '%s/Image_(%d)_(%dof%d).jpg' % (
            save_dir, count, count + 1, len(dataset_test)), nrow=1)

        # All metrics are computed on [0, 1] images (see _to_unit_range), so
        # MSE and PSNR are consistent: PSNR = 10 * log10(1 / MSE).
        target = _to_unit_range(real_B)
        output = _to_unit_range(fake_B)

        # Calculate MSE
        acc_mse = criterionMSE(target, output).item()
        print("For {} batch:".format(count + 1))
        print("MSE ==> {}".format(acc_mse))
        mse_results.append(acc_mse)
        # Calculate PSNR
        acc_psnr = psnr.psnr(target, output).item()
        print("{} ==> {}".format("PSNR", acc_psnr))
        psnr_results.append(acc_psnr)
        # Calculate SSIM; the first element of ssim.ssim's result is the SSIM score.
        acc_ssim = ssim.ssim(target, output, kernel)[0].item()
        print("{} ==> {}".format("SSIM", acc_ssim))
        ssim_results.append(acc_ssim)

# An empty loader would otherwise surface as an unhelpful ZeroDivisionError
# in the averages below.
if not mse_results:
    raise RuntimeError(
        "No test images were evaluated; check opt.dataroot and opt.maskroot.")

# MSE result for test data
mse_test_data = sum(mse_results) / len(mse_results)
print("MSE for test data is: {}".format(mse_test_data))

# PSNR result for test data
psnr_test_data = sum(psnr_results) / len(psnr_results)
print("PSNR for test data is: {}".format(psnr_test_data))

# SSIM result for test data
ssim_test_data = sum(ssim_results) / len(ssim_results)
print("SSIM for test data is: {}".format(ssim_test_data))
