In [None]:
import numpy as np
from util.data_preparation import VariedSizedImagesCollate, LoadingImgs_to_ListOfListOfTensors, VariedSizedImagesDataset, data_aug_preprocessing_HR
import time
import torch
from options.train_options import TrainOptions
from models import create_model
from util.visualizer import Visualizer
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from collections import OrderedDict
from util.metrics_utils import metrics_names
def linear_normalize(tmp):
    """Linearly rescale ``tmp`` so its minimum maps to 0 and its maximum to 1.

    Works on any array-like supporting ``.min()``, ``.max()`` and elementwise
    arithmetic (e.g. NumPy arrays, torch tensors). A constant input has zero
    range; instead of dividing by zero (which yields NaN/inf), it is mapped to
    all zeros.

    Args:
        tmp: array or tensor to normalize.

    Returns:
        An array/tensor of the same shape with values in [0, 1].
    """
    shifted = tmp - tmp.min()
    value_range = tmp.max() - tmp.min()
    if value_range == 0:
        # Multiply by 0.0 (rather than returning ``shifted`` as-is) so integer
        # inputs still come back as floats, matching the true-division result
        # of the non-constant path.
        return shifted * 0.0
    return shifted / value_range

# Build the command line handed to TrainOptions from named fragments. The
# fragments use *_args names so that `model` is only ever bound to the created
# model object (it used to hold the '--model ...' string first, then be rebound).
name_args = '--name CUT+REG+Pseudo_4x_NCE20_REG100_TV600_MAE1_MAEpseudo10_3channel'
model_args = '--model cutreg_twostage'  # current proposed framework that uses CUT loss and registration loss
mode_args = '--train_R_with_G True --only_train_R False --only_train_G False --train_G_pseudo True'
hyperparameters = '--lambda_GAN 1.0 --lambda_NCE 20.0 --lambda_REG 100.0 --lambda_TV 600.0 --lambda_MAE 1.0 --lambda_MAE_pseudo 10.0 --CUT_mode FastCUT'
lr_schedule = '--n_epochs 50 --n_epochs_decay 450'
cmd_line = ' '.join((name_args, model_args, mode_args, hyperparameters, lr_schedule))
opt = TrainOptions(cmd_line).parse()  # get training options
model = create_model(opt)  # create a model given opt.model and other options

## Load the image data (full FOV; "HR" means original, non-downsampled resolution).
data_folder = 'Dataset/example_training_data/'
image_indices = [1, 2, 3]
list_of_list_of_tensors = LoadingImgs_to_ListOfListOfTensors(data_folder, image_indices)
dataset = VariedSizedImagesDataset(list_of_list_of_tensors)
dataset_size = len(dataset)  # number of images in the dataset

train_dataloader = torch.utils.data.DataLoader(
    dataset,
    # Can only be 1. Presumably because the images have different sizes
    # (see VariedSizedImagesCollate) -- TODO confirm.
    batch_size=1,
    collate_fn=VariedSizedImagesCollate,
    shuffle=not opt.serial_batches,
    num_workers=int(opt.num_threads),
    drop_last=bool(opt.isTrain),
)

# Evaluation loader: each pass draws 2 images at random, with replacement.
# NOTE(review): RandomSampler re-draws its indices on every iteration of the
# loader, so successive evaluations may score different images -- confirm this
# is intended for the per-epoch metric curves.
test_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=2),
)
writer = SummaryWriter('runs/' + str(opt.name))

In [None]:
def log_metrics(metrics, step):
    """Write each evaluation metric to TensorBoard as ``metric_<name>`` at ``step``."""
    for metric_name, metric_value in zip(metrics_names, metrics):
        writer.add_scalar('metric_' + str(metric_name), metric_value, step)


def sync_cuda():
    """Wait for queued CUDA work so wall-clock timings are meaningful (no-op without GPUs)."""
    if len(opt.gpu_ids) > 0:
        torch.cuda.synchronize()


def accumulate_losses(running, current):
    """Add the current value of every string-named loss in ``model.loss_names`` into ``running`` (in place)."""
    for loss_name in model.loss_names:
        if isinstance(loss_name, str):
            running[loss_name] += current[loss_name]


# Initialise the registration net from pretrained weights. map_location='cpu'
# lets a GPU-saved checkpoint load on a CPU-only machine; load_state_dict then
# copies the tensors onto whatever device netR's parameters live on.
model.netR.load_state_dict(torch.load('model_weights/registration_net_pretrained.pth', map_location='cpu'))
visualizer = Visualizer(opt)  # create a visualizer that displays/saves images and plots
opt.visualizer = visualizer
total_iters = 0  # total number of training iterations (images + patches)
optimize_time = 0.1  # exponential moving average of optimisation time
times = []  # not used in this cell; kept in case later cells rely on it
n_total_epochs = opt.n_epochs + opt.n_epochs_decay

for epoch in range(opt.epoch_count, n_total_epochs + 1):
    epoch_start_time = time.time()  # timer for the entire epoch
    iter_data_time = time.time()  # timer for data loading per iteration
    epoch_iter = 0  # training iterations in the current epoch
    visualizer.reset()  # make sure results are saved to HTML at least once every epoch

    model.netG.train()  # switch netG to training mode
    running_losses = OrderedDict((loss_name, 0) for loss_name in model.loss_names)
    n_loss_updates = 0  # how many times running_losses was accumulated this epoch

    for i, data in enumerate(train_dataloader):  # inner loop within one epoch
        iter_start_time = time.time()  # timer for computation per iteration
        # Losses are printed every iteration (below), so measure data-loading
        # time every iteration too; otherwise a stale value gets printed.
        t_data = iter_start_time - iter_data_time

        batch_size = len(data[0])
        total_iters += batch_size
        epoch_iter += batch_size

        if epoch == opt.epoch_count and i == 0:
            model.data_dependent_initialize(data)
            model.setup(opt)  # regular setup: load and print networks; create schedulers
            model.parallelize()
            # Baseline evaluation before any training. Logged one step before
            # the first epoch so it cannot collide with that epoch's own
            # end-of-epoch evaluation.
            metrics = model.save_images_FFOV(0, test_dataloader)
            log_metrics(metrics, epoch - 1)

        ## Stage 1: train the image translation net with CUT loss on small-FOV
        ## patches (register first, then translate).
        ## netG, netF and netD are trainable; netR is fixed.
        model.set_input(data)
        model.define_fold(data)
        model.forward_stage2()

        # patchify_WSI crops the large registered image into small patches
        # (reportedly via data_aug_preprocessing_HR -- TODO confirm).
        mini_dataloader = model.patchify_WSI(real_B_reg=model.real_B_reg)
        for mini_data in mini_dataloader:
            sync_cuda()
            optimize_start_time = time.time()
            model.set_mini_input(mini_data)
            model.optimize_parameters()

            total_iters += model.mini_bs
            epoch_iter += model.mini_bs
            sync_cuda()
            optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time
            accumulate_losses(running_losses, model.get_current_losses())
            n_loss_updates += 1

        ## Stage 2: train image registration across the full FOV (translate
        ## first, then register). netR is trainable; netG, netF, netD are fixed.
        if not model.opt.only_train_G:
            model.set_input(data)
            # permute(2, 0, 1).unsqueeze(0) turns the 3-D prediction into a
            # batched 4-D tensor; presumably HWC -> 1xCxHxW (TODO confirm).
            fake_B = model.patch_wise_predict(input_data=model.real_A_eq, stride_ratio=1).permute(2, 0, 1).unsqueeze(0)

            sync_cuda()
            optimize_start_time = time.time()
            model.set_input(data)
            model.optimize_parameters_stage2(fake_B=fake_B)  # gradient descent on netR

            model.netG.train()
            sync_cuda()
            optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time

        losses = model.get_current_losses()
        accumulate_losses(running_losses, losses)
        n_loss_updates += 1
        visualizer.print_current_losses(epoch, epoch_iter, losses, optimize_time, t_data)
        if opt.display_id is None or opt.display_id > 0:
            visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

        iter_data_time = time.time()

    if epoch % opt.save_epoch_freq == 0:  # evaluate every <save_epoch_freq> epochs
        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
        # NOTE(review): checkpoint saving is commented out, so despite the
        # message above no network weights are written here.
        # model.save_networks('latest')
        # model.save_networks(epoch)
        metrics = model.save_images_FFOV(epoch, test_dataloader)
        log_metrics(metrics, epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, n_total_epochs, time.time() - epoch_start_time))
    model.update_learning_rate()  # update learning rates at the end of every epoch

    # Average over every accumulation this epoch: once per stage-1 mini-batch
    # plus once per image after stage 2. Dividing by len(train_dataloader)
    # counted images only, inflating the logged means by roughly the number of
    # patches per image.
    # NOTE(review): get_current_losses() may report a loss last updated in the
    # other stage; such repeated values are averaged in as well.
    for loss_name in model.loss_names:
        if isinstance(loss_name, str):
            writer.add_scalar('loss_' + str(loss_name), running_losses[loss_name] / max(n_loss_updates, 1), epoch)