## Initialize the Models

In [None]:
from torch import optim

from discriminator import Discriminator
from generator import Generator
from utils import device, lr, beta1, beta2

In [None]:
# Build the generator / discriminator pair on the configured device and give
# each network its own Adam optimizer using the shared hyperparameters from utils.
adam_betas = (beta1, beta2)

gen = Generator().to(device)
disc = Discriminator().to(device)

gen_opt, disc_opt = (
    optim.Adam(net.parameters(), lr=lr, betas=adam_betas) for net in (gen, disc)
)

## Load Checkpoints

In [None]:
import numpy as np
import os
from utils import checkpoint_dir, save_checkpoint, load_checkpoint

In [None]:
# Look for previously saved '.pth' checkpoints. If any exist, restore the models
# and optimizers from the most recently modified one and resume training from
# the epoch after the one it returns; otherwise start fresh at epoch 1.
checkpoint_files = [name for name in os.listdir(checkpoint_dir) if name.endswith('.pth')]

LOAD_MODEL = bool(checkpoint_files)
if LOAD_MODEL:
    # os.path.join is correct whether or not checkpoint_dir ends with a path
    # separator; plain string concatenation builds a wrong path when it doesn't.
    checkpoint_files.sort(
        key=lambda name: os.path.getmtime(os.path.join(checkpoint_dir, name))
    )
    last_checkpoint = checkpoint_files[-1]

SAVE_MODEL = True
if LOAD_MODEL:
    # load_checkpoint's return value is treated as an epoch number; presumably
    # the last completed epoch, so training resumes at the next one.
    curr_epoch = load_checkpoint(last_checkpoint, (gen, disc), (gen_opt, disc_opt), lr=lr) + 1
else:
    curr_epoch = 1

## Training Loop

In [None]:
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from dataset import dataloader
from train import gen_step, disc_step
from utils import show_images

# A checkpoint is saved whenever epoch % EPOCHS_PER_CHECKPOINT == 0; the
# training loop also runs for 100 * EPOCHS_PER_CHECKPOINT epochs in total.
EPOCHS_PER_CHECKPOINT: int = 1

In [None]:
from dataset import Dataset, DataLoader, BATCH_SIZE, data_folder_name
import torch.nn as nn

# Per-batch validation losses comparing generator output with the targets.
loss_l1 = nn.L1Loss()
loss_mse = nn.MSELoss()

# Held-out test split, evaluated after every training epoch.
test_data_path = f'../{data_folder_name}/test/'
ds = Dataset(test_data_path, size=256, train=False)
# NOTE: if the test folder picks up a stray '.ipynb_checkpoints' directory, it
# previously had to be removed by hand via ds.items.remove(<path>) before
# building the loader.
dataloader_test = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
import torch

# Train for 100 * EPOCHS_PER_CHECKPOINT epochs starting at curr_epoch. Each epoch:
# one discriminator step and two generator steps per batch, then validation
# (MSE, L1 and a critic score gap on the test split), then an optional checkpoint.
for epoch in range(curr_epoch, curr_epoch + 100 * EPOCHS_PER_CHECKPOINT):
    print(f"Training epoch {epoch}")
    step = 0

    ### Training
    # Module.train()/eval() propagate the mode to every submodule (BatchNorm,
    # Dropout, ...). Assigning `.training` directly only flips the flag on the
    # top-level module and leaves the layers that depend on it unchanged.
    gen.train()
    disc.train()
    for (inputs, targets), _ in tqdm(dataloader):
        step += 1
        # Use the configured device rather than hard-coding CUDA.
        inputs, targets = inputs.to(device), targets.to(device)
        Disc_Loss = disc_step(inputs, targets, gen, disc, disc_opt)

        # Two generator updates per discriminator update.
        for _ in range(2):
            Gen_Loss, generator_image = gen_step(inputs, targets, gen, disc, gen_opt)

    ### Validation
    gen.eval()
    disc.eval()
    mse_errors = []
    l1_errors = []
    wass_error = []
    # Evaluation needs no gradients; skipping autograd avoids building and
    # holding a computation graph for every test batch.
    with torch.no_grad():
        for (inputs, targets), _ in dataloader_test:
            inputs_dev, targets_dev = inputs.to(device), targets.to(device)
            generated = gen(inputs_dev)

            generated_cpu = generated.cpu()
            mse_errors.append(loss_mse(targets, generated_cpu).item())
            l1_errors.append(loss_l1(targets, generated_cpu).item())

            # Mean discriminator score on real pairs minus mean score on
            # generated pairs (presumably a Wasserstein estimate if disc is a
            # WGAN critic — confirm against discriminator.py).
            real_score = torch.mean(disc(inputs_dev, targets_dev))
            fake_score = torch.mean(disc(inputs_dev, generated))
            wass_error.append((real_score - fake_score).item())

    ### Saving checkpoints
    if SAVE_MODEL and epoch % EPOCHS_PER_CHECKPOINT == 0:
        # Newest '.pth' checkpoint that exists before this save (None if none),
        # passed through to save_checkpoint.
        checkpoint_files = [a for a in os.listdir(checkpoint_dir) if a.endswith('.pth')]
        checkpoint_files.sort(key=lambda a: os.path.getmtime(os.path.join(checkpoint_dir, a)))
        last_checkpoint = checkpoint_files[-1] if checkpoint_files else None

        filename = "checkpoint_epoch" + str(epoch)
        save_checkpoint(
            (gen, disc),
            (gen_opt, disc_opt),
            curr_epoch=epoch,
            losses=(Gen_Loss, Disc_Loss),
            wass=np.mean(wass_error),
            mse=np.mean(mse_errors),
            l1=np.mean(l1_errors),
            last_checkpoint=last_checkpoint,
            filename=filename,
        )

## Obtain the evaluation metrics

In [None]:
import torch

# Load the newest checkpoint so its stored metrics can be inspected.
# The directory is re-scanned because `last_checkpoint` as left by the training
# loop names the checkpoint that existed *before* the final save, not the
# newest one. checkpoint_dir is used instead of a hard-coded './save_states/'
# so this matches where the rest of the notebook lists checkpoints.
_saved_checkpoints = [a for a in os.listdir(checkpoint_dir) if a.endswith('.pth')]
if not _saved_checkpoints:
    raise FileNotFoundError(f"No '.pth' checkpoints found in {checkpoint_dir}")
last_checkpoint = max(
    _saved_checkpoints,
    key=lambda a: os.path.getmtime(os.path.join(checkpoint_dir, a)),
)
# map_location lets a checkpoint written on a GPU be opened on the current device.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which may
# reject non-tensor entries (e.g. NumPy scalars from np.mean); for these trusted,
# locally written files, pass weights_only=False if loading fails.
dictionary = torch.load(os.path.join(checkpoint_dir, last_checkpoint), map_location=device)

In [None]:
# Display the 'wass_error' entry of the loaded checkpoint (presumably the mean
# validation critic gap passed to save_checkpoint as `wass=` — confirm in utils).
# Other stored keys are listed below.
dictionary['wass_error']

<strong>Keys you can look up in <code>dictionary</code>:</strong><br>
<ul>
<li>'l1_error'</li>
<li>'mse_error'</li>
<li>'loss_gen'</li>
<li>'loss_disc'</li>
<li>'wass_error'</li>
</ul>