## Initialize the Models

In [None]:
from torch import optim
import torch

from discriminator import Discriminator
from generator import Generator, OUTPUT_CHANNELS
from utils import device, lr, beta1, beta2

In [None]:
import gc

# Free memory left over from a previous run of this cell. Collect unreachable
# Python objects first so that any tensors they own are returned to the CUDA
# caching allocator, then ask the allocator to release its unused blocks.
gc.collect()
torch.cuda.empty_cache()

# Two generators and two discriminators, one pair per domain.
gen_x = Generator(OUTPUT_CHANNELS).to(device)  # Generates Xs from Ys
gen_y = Generator(OUTPUT_CHANNELS).to(device)  # Generates Ys from Xs
disc_x = Discriminator().to(device)  # Discriminates real Xs from fake Xs
disc_y = Discriminator().to(device)  # Discriminates real Ys from fake Ys

# Every network gets its own Adam optimizer with the shared hyper-parameters
# imported from utils (lr, beta1, beta2).
gen_x_opt = optim.Adam(gen_x.parameters(), lr=lr, betas=(beta1, beta2))
gen_y_opt = optim.Adam(gen_y.parameters(), lr=lr, betas=(beta1, beta2))
disc_x_opt = optim.Adam(disc_x.parameters(), lr=lr, betas=(beta1, beta2))
disc_y_opt = optim.Adam(disc_y.parameters(), lr=lr, betas=(beta1, beta2))

## Load Checkpoints

In [None]:
import numpy as np
import os
from utils import checkpoint_dir, save_checkpoint, load_checkpoint

In [None]:
# Collect every saved checkpoint ('.pth' file) in the checkpoint directory.
checkpoint_files = [
    name for name in os.listdir(checkpoint_dir) if name.endswith('.pth')
]

# Resume only when at least one checkpoint exists. `last_checkpoint` is always
# defined (None when there is nothing to resume from) so later cells can test
# it instead of hitting a NameError.
LOAD_MODEL = bool(checkpoint_files)
last_checkpoint = None
if LOAD_MODEL:
    # Oldest first, newest last (by modification time). os.path.join keeps the
    # path valid whether or not checkpoint_dir ends with a separator.
    checkpoint_files.sort(
        key=lambda name: os.path.getmtime(os.path.join(checkpoint_dir, name))
    )
    last_checkpoint = checkpoint_files[-1]

SAVE_MODEL = True

if LOAD_MODEL:
    # load_checkpoint's return value is used as the last completed epoch
    # (presumably it also restores model/optimizer state in place -- confirm
    # in utils); training continues with the following epoch.
    curr_epoch = load_checkpoint(
        last_checkpoint,
        (gen_x, gen_y, disc_x, disc_y),
        (gen_x_opt, gen_y_opt, disc_x_opt, disc_y_opt),
        lr=lr,
    ) + 1
else:
    curr_epoch = 1

## Training Loop

In [None]:
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from dataset import dataloader
from train import gen_x_step, gen_y_step, disc_x_step, disc_y_step, train_step
from utils import show_images, device

# Number of epochs between checkpoint saves. The training loop also runs for
# 100 * EPOCHS_PER_CHECKPOINT epochs in total.
EPOCHS_PER_CHECKPOINT = 1

In [None]:
from dataset import Dataset, DataLoader, BATCH_SIZE, dataset_folder_name
import torch.nn as nn

# Error measures reported during validation.
loss_l1 = nn.L1Loss()
loss_mse = nn.MSELoss()

# The held-out split lives in the 'test' sub-folder of the dataset directory.
test_data_path = '../{}/test/'.format(dataset_folder_name)

ds = Dataset(test_data_path, train=False, size=256)
# Disabled workaround, kept for reference (removes a stray notebook
# checkpoint folder from the dataset's item list):
# ds.items.remove('../aux_channel_two/test/.ipynb_checkpoints')
dataloader_test = DataLoader(ds, shuffle=True, batch_size=BATCH_SIZE)

In [None]:
# Main loop: train for 100 checkpoint intervals starting at `curr_epoch`.
# Each epoch does one optimisation pass over `dataloader`, one validation pass
# over `dataloader_test`, and (optionally) saves a checkpoint.
for epoch in range(curr_epoch, curr_epoch + 100 * EPOCHS_PER_CHECKPOINT):
    print(f"Training epoch {epoch}")

    ### Training
    step = 0  # number of batches processed in this epoch
    for (strain, _), (aux, _) in tqdm(dataloader):
        step += 1
        strain, aux = strain.to(device), aux.to(device)

        # One update per network; the step helpers come from train.py.
        gen_x_loss, _ = gen_x_step(strain, aux, gen_x, gen_x_opt, gen_y, disc_x)
        gen_y_loss, _ = gen_y_step(strain, aux, gen_y, gen_y_opt, gen_x, disc_y)
        disc_x_loss = disc_x_step(strain, aux, disc_x, disc_x_opt, gen_x)
        disc_y_loss = disc_y_step(strain, aux, disc_y, disc_y_opt, gen_y)

    ### Validation
    mse_errors_x = []
    mse_errors_y = []
    l1_errors_x = []
    l1_errors_y = []
    # Only scalar metrics are needed here, so disable autograd: the forward
    # values are unchanged but no graph is recorded, which saves memory.
    with torch.no_grad():
        # Each test item is unpacked as (input, paired target) for both
        # domains -- presumably the dataset yields aligned pairs when
        # train=False; confirm in dataset.py.
        for (strain, real_aux), (aux, real_strain) in dataloader_test:
            generated_strain = gen_x(aux.to(device))
            generated_aux = gen_y(strain.to(device))

            # Copy each output to the CPU once and reuse it for both metrics
            # (the reference tensors are used as loaded, without a device move).
            fake_strain = generated_strain.cpu()
            fake_aux = generated_aux.cpu()

            mse_errors_x.append(loss_mse(real_strain, fake_strain).item())
            mse_errors_y.append(loss_mse(real_aux, fake_aux).item())
            l1_errors_x.append(loss_l1(real_strain, fake_strain).item())
            l1_errors_y.append(loss_l1(real_aux, fake_aux).item())

    ### Checkpoint
    if SAVE_MODEL and epoch % EPOCHS_PER_CHECKPOINT == 0:
        # Newest checkpoint already on disk (None if there is none yet). It is
        # looked up *before* saving and handed to save_checkpoint.
        checkpoint_files = [
            name for name in os.listdir(checkpoint_dir) if name.endswith('.pth')
        ]
        checkpoint_files.sort(
            key=lambda name: os.path.getmtime(os.path.join(checkpoint_dir, name))
        )
        last_checkpoint = checkpoint_files[-1] if checkpoint_files else None

        filename = "checkpoint_epoch" + str(epoch)
        save_checkpoint(
            (gen_x, gen_y, disc_x, disc_y),
            (gen_x_opt, gen_y_opt, disc_x_opt, disc_y_opt),
            curr_epoch=epoch,
            # Losses of the last training batch of this epoch.
            losses=((gen_x_loss, gen_y_loss), (disc_x_loss, disc_y_loss)),
            # Validation errors averaged over all test batches.
            mse=(np.mean(mse_errors_x), np.mean(mse_errors_y)),
            l1=(np.mean(l1_errors_x), np.mean(l1_errors_y)),
            last_checkpoint=last_checkpoint,
            filename=filename,
        )

## Obtain the Evaluation Metrics

In [None]:
import os

import torch

# Re-resolve the newest checkpoint on disk instead of reusing the
# `last_checkpoint` left behind by the training loop: that variable is set
# before each save, so after training it names the checkpoint that preceded
# the final one.
save_states_dir = './save_states/'
saved_checkpoints = sorted(
    (name for name in os.listdir(save_states_dir) if name.endswith('.pth')),
    key=lambda name: os.path.getmtime(os.path.join(save_states_dir, name)),
)
if not saved_checkpoints:
    raise FileNotFoundError(f"No '.pth' checkpoint found in {save_states_dir}")
last_checkpoint = saved_checkpoints[-1]

# The checkpoint is loaded as a dictionary and indexed by key in the next cell.
dictionary = torch.load(os.path.join(save_states_dir, last_checkpoint))

In [None]:
# Display the value stored under the 'l1_error_x' key of the loaded checkpoint.
dictionary['l1_error_x']

<strong>You can index <code>dictionary</code> with any of the following keys:</strong><br>
<ul>
<li>'l1_error_x'</li>
<li>'l1_error_y'</li>
<li>'mse_error_x'</li>
<li>'mse_error_y'</li>
<li>'loss_gen_x'</li>
<li>'loss_gen_y'</li>
<li>'loss_disc_x'</li>
<li>'loss_disc_y'</li>
</ul>