In [None]:
from torch import optim

from discriminator import Discriminator
from generator import Generator, OUTPUT_CHANNELS
from utils import device, lr, beta1, beta2

# Two generator/discriminator pairs, one per translation direction.
# Construction order (gen_x, gen_y, disc_x, disc_y) is kept as-is: if the
# modules draw their initial weights from the global RNG, reordering would
# change which initial weights each one receives.
gen_x = Generator(OUTPUT_CHANNELS).to(device)  # Y -> X
gen_y = Generator(OUTPUT_CHANNELS).to(device)  # X -> Y
disc_x = Discriminator().to(device)  # real X vs. generated X
disc_y = Discriminator().to(device)  # real Y vs. generated Y


def _make_adam(module):
    """Return an Adam optimizer over ``module``'s parameters with the shared lr/betas."""
    return optim.Adam(module.parameters(), lr=lr, betas=(beta1, beta2))


gen_x_opt = _make_adam(gen_x)
gen_y_opt = _make_adam(gen_y)
disc_x_opt = _make_adam(disc_x)
disc_y_opt = _make_adam(disc_y)

In [None]:
import numpy as np
import os
from utils import checkpoint_dir, save_checkpoint, load_checkpoint

# Resume from the most recently modified .pth file in checkpoint_dir, if any.
checkpoint_files = [
    name for name in os.listdir(checkpoint_dir) if name.endswith('.pth')
]

LOAD_MODEL = False
if checkpoint_files:
    LOAD_MODEL = True
    # os.path.join works whether or not checkpoint_dir ends with a path
    # separator; plain string concatenation silently builds a wrong path
    # when the trailing separator is missing.
    checkpoint_files.sort(
        key=lambda name: os.path.getmtime(os.path.join(checkpoint_dir, name))
    )
    last_checkpoint = checkpoint_files[-1]  # newest by modification time

SAVE_MODEL = True

if LOAD_MODEL:
    # load_checkpoint is given the bare file name (not the joined path) and the
    # four models/optimizers; NOTE(review): it presumably returns the last
    # completed epoch, so training/evaluation resumes at the next one.
    curr_epoch = load_checkpoint(
        last_checkpoint,
        (gen_x, gen_y, disc_x, disc_y),
        (gen_x_opt, gen_y_opt, disc_x_opt, disc_y_opt),
        lr=lr,
    ) + 1
else:
    curr_epoch = 1

In [None]:
from dataset import Dataset, DataLoader, BATCH_SIZE, dataset_folder_name

# Which split to visualise: the "test" folder when True, "train" otherwise.
use_test = True

test_data_path = f'../{dataset_folder_name}/{"test" if use_test else "train"}/'

# NOTE(review): `train=False` presumably disables training-time augmentation in
# Dataset — confirm. shuffle=True so each new iterator yields a random order.
ds = Dataset(test_data_path, size=256, train=False)
dataloader_test = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)

# Switch the generators to inference mode with .eval(). Assigning
# `.train = False` does NOT do this: `nn.Module.train` is a method, so the
# assignment only shadows it with a bool, leaves dropout/normalization layers
# in training mode, and makes any later `.train()`/`.eval()` call raise
# TypeError.
gen_x.eval()
gen_y.eval()

In [None]:
# Draw one shuffled batch from the test loader; each batch unpacks into two
# (input, counterpart) pairs.
(
    (strain, real_aux),
    (aux, real_strain),
) = next(iter(dataloader_test))

In [None]:
import matplotlib.pyplot as plt

import torch

# Draw a fresh shuffled batch each time this cell runs, translate it in both
# directions, and show the first sample of each tensor side by side.
(strain, real_aux), (aux, real_strain) = next(iter(dataloader_test))

# Inference only: no_grad avoids building an autograd graph for the forward
# passes, which saves memory (the outputs are only displayed).
with torch.no_grad():
    gen_aux = gen_y(strain.to(device))
    gen_strain = gen_x(aux.to(device))


def _to_display_image(batch):
    """Return the first sample of ``batch`` as a channels-last CPU tensor in [0, 1].

    NOTE(review): assumes samples are (C, H, W) and normalised to [-1, 1] —
    confirm against the Dataset transform. Values are clamped so imshow does
    not have to clip (and warn about) out-of-range RGB floats.
    """
    return (batch[0].detach().permute(1, 2, 0).cpu() * 0.5 + 0.5).clamp(0, 1)


panels = [
    (real_aux, 'Real Aux Image'),
    (gen_aux, 'Generated Aux'),
    (real_strain, 'Real Strain Image'),
    (gen_strain, 'Generated Strain'),
]

# Four images in a single row: a wide, short figure instead of a 20x40 canvas
# that would be mostly empty vertical space.
fig, axs = plt.subplots(1, len(panels), figsize=(20, 5))
for ax, (images, title) in zip(axs, panels):
    ax.imshow(_to_display_image(images))
    ax.set_title(title)
    ax.axis('off')
plt.show()