In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
import sys
sys.path.append(str(Path.cwd() / 'src'))

from src.configs import TrainConfig
import numpy as np
from tqdm import tqdm
import torch
from src.test import get_fid_score
from multiprocessing import cpu_count
import matplotlib.pyplot as plt
from datetime import datetime

# Build the training configuration with no constructor arguments; the cells below
# read their settings (data loader, model, scheduler, network) from this object.
config = TrainConfig()

# Create dataloaders

In [None]:
from dataset import getDataLoader
from utils.deep import NetPhase
# Settings shared by both loaders; only the is_train flag differs between them.
loader_kwargs = dict(
    bandwidth=config.wl,
    batch_size=config.data_loader.batch_size,
    path_to_data='data',
)
train_loader = getDataLoader(is_train=True, **loader_kwargs)  # iterated during the training phase
val_loader = getDataLoader(is_train=False, **loader_kwargs)  # iterated during the validation phase

# NOTE(review): assuming len(train_loader) counts batches, this product matches the
# true image count only when every batch is full - confirm.
print(f"The number of training images = {len(train_loader) * config.data_loader.batch_size}")

# Model Setup
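Configure and initialize the model to be trained according to the desired specifications (each option can be overridden on `config` before the model is built):
1. Model backbone, `config.model` (default is CUT).
2. Use the camera's intrinsic (FPA) temperature, `config.thermal.is_fpa_input` (default is True).
3. Use the physical model, `config.thermal.is_physical_model` (default is True).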

In [None]:
from cycle_gan_model import CycleGANModel
from cut_model import CUTModel

# config.model = "CUT" # "CycleGan"
# config.thermal.is_fpa_input = True
# config.thermal.is_physical_model = True


# Pick the backbone class: "CUT" selects CUTModel, any other value of
# config.model falls back to CycleGANModel.
if config.model == "CUT":
    backbone = CUTModel
else:
    backbone = CycleGANModel

# Build the model from the configuration, then run its setup step with the
# same configuration object.
model = backbone(config)
model.setup(config)

# Train
This cell runs the model's training loop. During the validation pass of every epoch, an auxiliary visualization of images is displayed (for the first validation batch and for one randomly chosen batch), with the following labels:
1. "pan_real" - the real panchromatic input image.
2. "mono_phys" - the physical model's prediction of the monochromatic (9000nm) output.
3. "mono_fake" - the fusion (physical + deep backbone) model's prediction of the monochromatic output.
4. "mono_real" - an unpaired real monochromatic image for reference.

In [None]:
# TODO: update order of visualized images (mono_phys before mono_fake)
# Each run writes into its own timestamped directory under results/train.
run_stamp = datetime.now().strftime("%Y%m%d_h%Hm%Ms%S")
path_to_save = Path.cwd() / 'results' / 'train' / run_stamp
path_to_save.mkdir(parents=True, exist_ok=True)
checkpoints_dir = path_to_save / "checkpoints"  # target passed to every save_networks call

# Epochs are numbered from 1, so this is the exclusive upper bound of the epoch range.
tot_epochs = config.scheduler.n_epochs + config.scheduler.n_epochs_decay + 1
best_fid_score = np.inf  # lowest FID seen so far; inf so the first finite score wins
# Besides batch 0, one validation batch index is drawn once here and visualized every epoch.
rand_val_idx = np.random.randint(low=0, high=len(val_loader))
# Models whose class string contains 'cut' get data_dependent_initialize on the first batch.
is_cut_backbone = 'cut' in str(model.__class__).lower()

for epoch in range(1, tot_epochs):
    epoch_label = f"Epoch {epoch}|{tot_epochs-1}"  # progress-bar prefix shared by both phases

    # ---- Train ----
    model.set_phase(NetPhase.train)
    for i, data in enumerate(tqdm(train_loader, postfix="Train", desc=epoch_label)):
        if is_cut_backbone and epoch == 1 and i == 0:  # very first training batch only
            model.data_dependent_initialize(data)
        model.set_input(data)
        model.forward()

        # With gan_mode "wgangp" the flag is True only when i is a multiple of n_critic;
        # for every other gan_mode it is True on each iteration.
        skip_gen = config.network.gan_mode == "wgangp" and i % config.network.n_critic
        train_gen = not skip_gen
        model.optimize_parameters(train_gen)
    model.update_loss(epoch, len(train_loader))

    # ---- Validate ----
    model.set_phase(NetPhase.val)
    # NOTE(review): the loss calls below run under inference_mode; if calc_loss_D needs
    # autograd (e.g. a gradient penalty) it would not work here - confirm.
    with torch.inference_mode():
        for i, data in enumerate(tqdm(val_loader, postfix="Validate", desc=epoch_label)):
            model.set_input(data)
            model.forward()

            # losses are computed here for visualization/tracking purposes only:
            model.calc_loss_D()
            model.calc_loss_G()
            model.update_agg_loss()

            # track visual performance on the first batch and on the randomly chosen one:
            if i == 0 or i == rand_val_idx:
                plt.figure(figsize=(20, 10))
                vis_grid = model.gen_vis_grid()
                # permute moves the first axis last - presumably channels-first to
                # channels-last for imshow; TODO confirm
                plt.imshow(vis_grid.permute(1, 2, 0))
                plt.show()
                plt.close()

    model.update_loss(epoch, len(val_loader))
    model.update_learning_rate()  # called once the epoch's train and validation passes are done

    # calculate FID. TODO: remove after asserting the correlation with loss components
    fid_score = get_fid_score(config, model, batch_size=cpu_count())
    if fid_score < best_fid_score:
        # a strictly lower FID than any previous epoch: keep these weights as "best"
        model.save_networks("best", checkpoints_dir)
        best_fid_score = fid_score
        print(f'Best FID score: {best_fid_score}')

# After the final epoch, always store the latest weights as "last".
model.save_networks("last", checkpoints_dir)