In [None]:
from torch import optim

from discriminator import Discriminator
from generator import Generator, OUTPUT_CHANNELS
from utils import device, lr, beta1, beta2

# Build both networks and move them to the device imported from utils.
gen, disc = Generator().to(device), Discriminator().to(device)

# The two Adam optimizers are configured identically, so the shared
# hyper-parameters are spelled out once.
_adam_kwargs = {'lr': lr, 'betas': (beta1, beta2)}
gen_opt = optim.Adam(gen.parameters(), **_adam_kwargs)
disc_opt = optim.Adam(disc.parameters(), **_adam_kwargs)

In [None]:
import numpy as np
import os
from utils import checkpoint_dir, save_checkpoint, load_checkpoint

# Collect the '.pth' files found in checkpoint_dir.
checkpoint_files = [name for name in os.listdir(checkpoint_dir) if name.endswith('.pth')]

# Resume from the most recently modified checkpoint when at least one exists.
LOAD_MODEL = len(checkpoint_files) != 0
# Always bind the name, so it is never left undefined when the directory
# contains no checkpoint.
last_checkpoint = None
if LOAD_MODEL:
    # Sort oldest -> newest by modification time; os.path.join works whether or
    # not checkpoint_dir ends with a path separator.
    checkpoint_files.sort(key=lambda name: os.path.getmtime(os.path.join(checkpoint_dir, name)))
    last_checkpoint = checkpoint_files[-1]

SAVE_MODEL = True

if LOAD_MODEL:
    # NOTE(review): load_checkpoint's return value is treated as the last
    # completed epoch (hence the +1) -- confirm against utils.load_checkpoint.
    curr_epoch = load_checkpoint(last_checkpoint, (gen, disc), (gen_opt, disc_opt), lr=lr) + 1
else:
    # Nothing to resume from: start at the first epoch.
    curr_epoch = 1

In [None]:
import torch

# Re-open the latest checkpoint to read the losses that were stored with it.
if LOAD_MODEL:
    # map_location remaps the saved tensors onto the current device, so a
    # checkpoint written on a GPU can still be read on a CPU-only machine.
    # NOTE: torch.load unpickles the file -- only open checkpoints you trust.
    my_checkpoint = torch.load(os.path.join(checkpoint_dir, last_checkpoint), map_location=device)

    gen_loss = my_checkpoint['loss_gen']
    disc_loss = my_checkpoint['loss_disc']
else:
    # No checkpoint was found, so there are no stored losses to read.
    my_checkpoint = None
    gen_loss = None
    disc_loss = None

In [None]:
from dataset import Dataset, DataLoader, BATCH_SIZE, data_folder_name

# Pick which sub-folder of the data directory to visualise: 'test' or 'train'.
use_test = True

test_data_path = f'../{data_folder_name}/{"test" if use_test else "train"}/'

ds = Dataset(test_data_path, size=256, train=False)
dataloader_test = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)

# Put the generator in evaluation mode. `gen.train = False` would only replace
# the `train` *method* with a bool: the module would stay in training mode and
# any later `gen.train()` call would fail.
# NOTE(review): if keeping training-mode dropout / batch-norm statistics during
# evaluation was intentional, call `gen.train()` here instead.
gen.eval()

In [None]:
import torch.nn as nn

# Criteria used to compare generated images against their targets:
# mean absolute error (L1) and mean squared error (MSE).
loss_l1, loss_mse = nn.L1Loss(), nn.MSELoss()

In [None]:
import matplotlib.pyplot as plt

# Draw one batch and run the generator on it.
(inp, tar), _ = next(iter(dataloader_test))

# Inference only: skip building the autograd graph.
with torch.no_grad():
    generated = gen(inp.to(device))

# Show the first sample of the batch: input, ground truth and generator output.
panels = (
    ('Observed', inp[0]),
    ('Real', tar[0]),
    ('Generated', generated[0].cpu()),
)
fig, axs = plt.subplots(1, 3, figsize=(20, 30))
for ax, (title, img) in zip(axs, panels):
    # (C, H, W) -> (H, W, C) for imshow; `* 0.5 + 0.5` rescales the values
    # (assumes images are normalised to [-1, 1] -- TODO confirm in dataset).
    ax.imshow(img.detach().permute(1, 2, 0) * 0.5 + 0.5)
    ax.set_title(title)

plt.show()

## Display some average images (L1 loss between 0.16 and 0.18)

In [None]:
# Show up to MAX_FIGURES batches whose L1 loss lies inside [L1_MIN, L1_MAX].
L1_MIN, L1_MAX = 0.16, 0.18
MAX_FIGURES = 5

i = 0  # number of figures displayed so far

with torch.no_grad():  # inference only: no autograd graph needed
    for (inp, tar), _ in dataloader_test:
        # Test the stop condition first, so no forward pass is spent on the
        # batch the loop breaks on.
        if i >= MAX_FIGURES:
            break

        generated = gen(inp.to(device))

        # The L1 loss is computed over the whole batch; evaluate it once and
        # reuse the value for both bounds.
        batch_l1 = loss_l1(tar, generated.cpu()).item()

        # Alternative filter kept from the original notebook: `batch_l1 <= 0.12`.
        if L1_MIN <= batch_l1 <= L1_MAX:
            i += 1

            # Only the first sample of the batch is displayed.
            panels = (
                ('Observed', inp[0]),
                ('Real', tar[0]),
                ('Generated', generated[0].cpu()),
            )
            fig, axs = plt.subplots(1, 3, figsize=(20, 30))
            for ax, (title, img) in zip(axs, panels):
                # (C, H, W) -> (H, W, C) for imshow; `* 0.5 + 0.5` rescales the
                # values (assumes images normalised to [-1, 1] -- TODO confirm).
                ax.imshow(img.detach().permute(1, 2, 0) * 0.5 + 0.5)
                ax.set_title(title)

            # Optional debug output (disabled in the original notebook):
            # print(f'DISCRIMINATOR EVALUATION ON REAL IMAGE: {torch.mean(disc(inp.to(device), tar.to(device)))}')
            # print(f'DISCRIMINATOR EVALUATION ON GENERATED IMAGE: {torch.mean(disc(inp.to(device), generated))}')

            plt.show()
            plt.close(fig)  # release the figure; many may be created in this loop

## Display bad or good images using the Wasserstein distance

In [None]:
# Scan the first batches of the loader and display those whose discriminator
# score gap exceeds WASS_THRESHOLD.
MAX_BATCHES = 50        # the loop stops when the counter reaches this value
WASS_THRESHOLD = 20.0

i = 0  # 1-based index of the current batch

with torch.no_grad():  # inference only: no autograd graph needed
    for (inp, tar), _ in dataloader_test:
        i += 1
        # Stops *on* batch number MAX_BATCHES, so at most MAX_BATCHES - 1
        # batches are actually evaluated (same as the original loop).
        if i == MAX_BATCHES:
            break

        # Move the batch once, to the same device as the models. Using `device`
        # instead of a hard-coded `.cuda()` keeps this cell consistent with the
        # rest of the notebook and lets it run on CPU-only machines.
        inp_dev, tar_dev = inp.to(device), tar.to(device)

        generated = gen(inp_dev)

        # Difference between the discriminator's mean output on the real pair
        # and on the generated pair.
        real_score = torch.mean(disc(inp_dev, tar_dev))
        fake_score = torch.mean(disc(inp_dev, generated))
        wass_loss = (real_score - fake_score).item()

        # With wass_loss > 20.0 we obtain bad images, with wass_loss < 10.0 we obtain good images
        if wass_loss > WASS_THRESHOLD:
            # Only the first sample of the batch is displayed.
            panels = (
                ('Observed', inp[0]),
                ('Real', tar[0]),
                ('Generated', generated[0].cpu()),
            )
            fig, axs = plt.subplots(1, 3, figsize=(20, 30))
            for ax, (title, img) in zip(axs, panels):
                # (C, H, W) -> (H, W, C) for imshow; `* 0.5 + 0.5` rescales the
                # values (assumes images normalised to [-1, 1] -- TODO confirm).
                ax.imshow(img.detach().permute(1, 2, 0) * 0.5 + 0.5)
                ax.set_title(title)

            plt.show()
            plt.close(fig)  # release the figure; many may be created in this loop
            print(f'WASS LOSS: {wass_loss}')
    