## Libraries

In [1]:
import time
from utils.data_load import DataLoad
from models.model import create_model
import torch
import os
import torchvision
from torch.utils import data
import torchvision.transforms as transforms

from piqa import psnr, ssim 
from piqa.utils.functional import gaussian_kernel

## Settings

In [2]:
class Opion():
    
    def __init__(self):
            
        self.dataroot = r'/path/to/images/'  # image dataroot
        self.maskroot = r'/path/to/masks/'  # mask dataroot
        self.batchSize = 1   # Need to be set to 1
        self.fineSize = 256  # image size
        self.input_nc = 3  # input channel size for first stage
        self.input_nc_g = 6  # input channel size for second stage
        self.output_nc = 3  # output channel size
        self.ngf = 64  # inner channel
        self.ndf = 64  # inner channel
        self.which_model_netD = 'basic'  # patch discriminator
        self.which_model_netF = 'feature'  # feature patch discriminator
        self.which_model_netG = 'unet_csa'  # seconde stage network
        self.which_model_netP = 'unet_256'  # first stage network
        self.triple_weight = 1
        self.name = 'CSA_inpainting'
        self.n_layers_D = '3'  # network depth
        self.gpu_ids = [0]
        self.model = 'csa_net'
        self.checkpoints_dir = r'/path/to/checkpoints/'
        self.norm = 'instance'
        self.fixed_mask = 1
        self.use_dropout = False
        self.init_type = 'normal'
        self.mask_type = 'random'
        self.lambda_A = 100
        self.threshold = 5/16.0
        self.stride = 1
        self.shift_sz = 1  # size of feature patch
        self.mask_thred = 1
        self.bottleneck = 512
        self.gp_lambda = 10.0
        self.ncritic = 5
        self.constrain = 'MSE'
        self.strength = 1
        self.init_gain = 0.02
        self.cosis = 1
        self.gan_type = 'lsgan'
        self.gan_weight = 0.2
        self.ssim_weight = 100
        self.lorentzian_weight = 10
        self.overlap = 4
        self.skip = 0
        self.display_freq = 10
        self.print_freq = 2
        self.save_latest_freq = 5
        self.save_epoch_freq = 2
        self.continue_train = False
        self.epoch_count = 1
        self.phase = 'test'  # or train
        self.which_epoch = '120'
        self.niter = 2
        self.niter_decay = 4
        self.beta1 = 0.5
        self.lr = 0.0002
        self.lr_policy = 'lambda'
        self.lr_decay_iters = 50
        self.isTrain = False
        self.ssim_loss = True  # True or False if we want to use / don't use SSIM loss additionally
        self.lorentzian_loss = False  # True or False if we want to use / don't use Lorentzian loss additionally
        self.l1_weight = 0.1

## Test Dataset

In [3]:
opt = Opion()
transform_mask = transforms.Compose(
    [transforms.Resize((opt.fineSize, opt.fineSize)),
     transforms.ToTensor(),
    ])
transform = transforms.Compose(
    [
     transforms.Resize((opt.fineSize, opt.fineSize)),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)])

dataset_test = DataLoad(opt.dataroot, opt.maskroot, transform, transform_mask)
iterator_test = (data.DataLoader(dataset_test, batch_size=opt.batchSize, shuffle=True))
print(len(dataset_test))
model = create_model(opt)
total_steps = 0

100
csa_net
initialize network with normal


  init.normal(m.weight.data, 0.0, gain)
  init.constant(m.bias.data, 0.0)


initialize network with normal
Loading pre-trained network!
model [CSAModel] was created


## Metrics
Metrics that were used to measure the correctness of the algorithm:

PSNR - (Peak Signal-to-Noise Ratio)

SSIM - (Structure Similarity Index Measure)

MSE - (Mean Square Error)

In [4]:
criterionMSE = torch.nn.MSELoss()

## Model Loading

In [5]:
load_epoch = 120
model.load(load_epoch)

## Model Testing

In [6]:
psnr_results = []
ssim_results = []
mse_results = []

In [7]:
save_dir = '/path/to/save/'
if os.path.exists(save_dir) is False:
    os.makedirs(save_dir)

for count, (image, mask) in enumerate(iterator_test):
    iter_start_time = time.time()
    image=image.cuda()
    mask=mask.cuda()
    mask=mask[0][0]
    mask=torch.unsqueeze(mask,0)
    mask=torch.unsqueeze(mask,1)
    mask=mask.bool()

    model.set_input(image,mask)
    model.set_gt_latent()
    model.test()
    real_A,real_B,fake_B=model.get_current_visuals()
    pic = (torch.cat([real_A, real_B,fake_B], dim=0) + 1) / 2.0
    torchvision.utils.save_image(pic, '%s/Image_(%d)_(%dof%d).jpg' % (
    save_dir, count, count + 1, len(dataset_test)), nrow=1)
    
    # Calculate MSE
    acc_mse = criterionMSE(real_B, fake_B).item()
    print("For {} batch:".format(count + 1)) 
    print("MSE ==> {}".format(acc_mse))
    mse_results.append(acc_mse)
    # Calculate PSNR
    acc_psnr = psnr.psnr(real_B, fake_B).item()
    print("{} ==> {}".format("PSNR", acc_psnr))
    psnr_results.append(acc_psnr)
    # Calculate SSIM
    kernel = gaussian_kernel(11).repeat(3, 1, 1).cuda()
    acc_ssim = ssim.ssim(real_B, fake_B, kernel)
    print("{} ==> {}".format("SSIM", acc_ssim[0].item()))
    ssim_results.append(acc_ssim[0].item())

# MSE result for test data
mse_test_data = sum(mse_results) / len(mse_results)
print("MSE for test data is: {}".format(mse_test_data))

# PSNR result for test data
psnr_test_data = sum(psnr_results) / len(psnr_results)
print("PSNR for test data is: {}".format(psnr_test_data))

# SSIM result for test data
ssim_test_data = sum(ssim_results) / len(ssim_results)
print("SSIM for test data is: {}".format(ssim_test_data))


For 1 batch:
MSE ==> 0.053971774876117706
PSNR ==> 12.678332328796387
SSIM ==> 0.4671025574207306
For 2 batch:
MSE ==> 0.048840295523405075
PSNR ==> 13.112215995788574
SSIM ==> 0.4886021912097931
For 3 batch:
MSE ==> 0.04842561483383179
PSNR ==> 13.149248123168945
SSIM ==> 0.49265050888061523
For 4 batch:
MSE ==> 0.04408326745033264
PSNR ==> 13.55726146697998
SSIM ==> 0.5383450984954834
For 5 batch:
MSE ==> 0.04372094199061394
PSNR ==> 13.593103408813477
SSIM ==> 0.5345590114593506
For 6 batch:
MSE ==> 0.03730323165655136
PSNR ==> 14.282533645629883
SSIM ==> 0.651111900806427
For 7 batch:
MSE ==> 0.123477503657341
PSNR ==> 9.084121704101562
SSIM ==> 0.38831886649131775
For 8 batch:
MSE ==> 0.06932743638753891
PSNR ==> 11.590948104858398
SSIM ==> 0.5666568875312805
For 9 batch:
MSE ==> 0.027314377948641777
PSNR ==> 15.636085510253906
SSIM ==> 0.5340876579284668
For 10 batch:
MSE ==> 0.055225010961294174
PSNR ==> 12.578640937805176
SSIM ==> 0.5490313768386841
For 11 batch:
MSE ==> 0.0358

For 85 batch:
MSE ==> 0.04478468745946884
PSNR ==> 13.488703727722168
SSIM ==> 0.5400516390800476
For 86 batch:
MSE ==> 0.04954593628644943
PSNR ==> 13.049918174743652
SSIM ==> 0.5176152586936951
For 87 batch:
MSE ==> 0.05682101473212242
PSNR ==> 12.454910278320312
SSIM ==> 0.5732230544090271
For 88 batch:
MSE ==> 0.07027792930603027
PSNR ==> 11.53180980682373
SSIM ==> 0.5103942155838013
For 89 batch:
MSE ==> 0.030983787029981613
PSNR ==> 15.088653564453125
SSIM ==> 0.5940161347389221
For 90 batch:
MSE ==> 0.04736264795064926
PSNR ==> 13.24563980102539
SSIM ==> 0.4395633935928345
For 91 batch:
MSE ==> 0.03922116011381149
PSNR ==> 14.064794540405273
SSIM ==> 0.5799131393432617
For 92 batch:
MSE ==> 0.03827095031738281
PSNR ==> 14.171306610107422
SSIM ==> 0.5147023797035217
For 93 batch:
MSE ==> 0.1423562616109848
PSNR ==> 8.46623420715332
SSIM ==> 0.584256649017334
For 94 batch:
MSE ==> 0.06188260763883591
PSNR ==> 12.08431339263916
SSIM ==> 0.5168940424919128
For 95 batch:
MSE ==> 0.02