In [1]:
## import modules to be used
import numpy as np
from util.data_preparation import VariedSizedImagesCollate, LoadingImgs_to_ListOfListOfTensors, VariedSizedImagesDataset, data_aug_preprocessing_HR
import time
import torch
from options.train_options import TrainOptions
# from data import create_dataset
from models import create_model
from util.visualizer import Visualizer
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from collections import OrderedDict
from util.metrics_utils import metrics_names
def linear_normalize(tmp):
    """Min-max rescale an array or tensor to the range [0, 1].

    Accepts anything supporting ``.min()``, ``.max()`` and broadcasting arithmetic
    (numpy arrays, torch tensors). A constant input has zero range; rather than
    dividing by zero (which yields NaN/inf), it is mapped to all zeros.

    Args:
        tmp: array-like of numbers.

    Returns:
        ``(tmp - tmp.min()) / (tmp.max() - tmp.min())``, or zeros when the range is 0.
    """
    lo = tmp.min()
    span = tmp.max() - lo
    if span == 0:
        span = 1  # constant input: (tmp - lo) is already all zeros
    return (tmp - lo) / span

## Training configuration: assembled as a command line and parsed by TrainOptions
name_arg = '--name CUT+REG+Pseudo_9spl_4x_NCE20_REG100_TV600_MAE1_MAEpseudo10_3channel'
model_arg = '--model cutreg_twostage'  # current proposed framework that uses CUT loss and registration loss
# the switch for using plain CUT framework or our CUT+Registration framework
mode = '--train_R_with_G True --only_train_R False --only_train_G False --train_G_pseudo True'
hyperparameters = ('--lambda_GAN 1.0 --lambda_NCE 20.0 --lambda_REG 100.0 --lambda_TV 600.0 '
                   '--lambda_MAE 1.0 --lambda_MAE_pseudo 10.0 --CUT_mode FastCUT')
lr_schedule = '--n_epochs 50 --n_epochs_decay 450'
cmd_line = ' '.join((name_arg, model_arg, mode, hyperparameters, lr_schedule))
opt = TrainOptions(cmd_line).parse()   # get training options
model = create_model(opt)              # create a model given opt.model and other options

## loading the image data (full FOV, HR meaning original, non-downsampled)
data_folder = '../Data/Data_1212_22/'
good_image_index = [10, 11, 13, 14, 16, 22, 24, 25, 30]
list_of_list_of_tensors = LoadingImgs_to_ListOfListOfTensors(data_folder, good_image_index)
dataset = VariedSizedImagesDataset(list_of_list_of_tensors)
dataset_size = len(dataset)  # get the number of images in the dataset.

train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,  # can only be 1 (presumably because the images differ in size)
    collate_fn=VariedSizedImagesCollate,
    shuffle=not opt.serial_batches,
    num_workers=int(opt.num_threads),
    drop_last=bool(opt.isTrain),
)

# Evaluation loader used for the per-epoch metrics. The indices are drawn ONCE from a
# seeded generator, so every evaluation pass sees the same distinct images and metric
# curves are comparable across epochs. (An unseeded RandomSampler with replacement would
# re-draw different, possibly duplicate, images on every pass.)
N_EVAL_IMAGES = 2
EVAL_SEED = 0
eval_generator = torch.Generator().manual_seed(EVAL_SEED)
eval_indices = torch.randperm(dataset_size, generator=eval_generator)[:N_EVAL_IMAGES].tolist()
test_dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, sampler=eval_indices)

writer = SummaryWriter('runs/' + str(opt.name))

----------------- Options ---------------
                 CUT_mode: FastCUT                       
                LAB_space: False                         
               batch_size: 1                             
                    beta1: 0.5                           
                    beta2: 0.999                         
          checkpoints_dir: ./checkpoints                 
           continue_train: False                         
                crop_size: 256                           
                 dataroot: placeholder                   
             dataset_mode: unaligned                     
                direction: AtoB                          
              display_env: main                          
             display_freq: 400                           
               display_id: None                          
            display_ncols: 4                             
             display_port: 8097                          
           display_server: htt

In [None]:
## Warm-start from pretrained weights, then run the two-stage training loop.
# netR checkpoint: epoch 500 of the REG_35spl_4x_MAE_REG100TV600_ResUnet_3channel run.
PRETRAINED_R_CKPT = 'checkpoints/REG_35spl_4x_MAE_REG100TV600_ResUnet_3channel/500_net_R.pth'
# netG checkpoint: epoch 30 of the CUT+REG+Pseudo_..._MAEpseudo10 run. The path is a
# literal string; concatenating the bare tokens CUT+REG+Pseudo_... raises NameError.
PRETRAINED_G_CKPT = 'checkpoints/CUT+REG+Pseudo_9spl_4x_NCE20_REG100_TV600_MAE1_MAEpseudo10/30_net_G.pth'
model.netR.load_state_dict(torch.load(PRETRAINED_R_CKPT))
model.netG.load_state_dict(torch.load(PRETRAINED_G_CKPT))


def accumulate_losses(running, losses, loss_names):
    """Add the current value of each string-named loss in `loss_names` into `running` in place."""
    for loss_name in loss_names:
        if isinstance(loss_name, str):
            running[loss_name] += losses[loss_name]


def log_metrics(metrics, step):
    """Write each evaluation metric (ordered like `metrics_names`) to TensorBoard at `step`."""
    for k, metric_name in enumerate(metrics_names):
        writer.add_scalar('metric_' + str(metric_name), metrics[k], step)


visualizer = Visualizer(opt)   # create a visualizer that display/save images and plots
opt.visualizer = visualizer
total_iters = 0                # the total number of training iterations
optimize_time = 0.1            # running (EMA) estimate of optimisation time, seeded with a guess
n_epochs_total = opt.n_epochs + opt.n_epochs_decay

for epoch in range(opt.epoch_count, n_epochs_total + 1):
    epoch_start_time = time.time()  # timer for entire epoch
    iter_data_time = time.time()    # timer for data loading per iteration
    epoch_iter = 0                  # training iterations in the current epoch, reset every epoch
    visualizer.reset()              # make sure results are saved to HTML at least once every epoch

    model.netG.train()  # switching on the training mode for netG
    running_losses = OrderedDict((loss_name, 0) for loss_name in model.loss_names)
    n_loss_updates = 0  # number of loss snapshots summed into running_losses this epoch

    for i, data in enumerate(train_dataloader):  # inner loop within one epoch
        iter_start_time = time.time()  # timer for computation per iteration
        # losses are printed every iteration, so the data-loading time is measured every iteration
        t_data = iter_start_time - iter_data_time

        batch_size = len(data[0])
        total_iters += batch_size
        epoch_iter += batch_size

        if epoch == opt.epoch_count and i == 0:
            model.data_dependent_initialize(data)
            model.setup(opt)               # regular setup: load and print networks; create schedulers
            model.parallelize()
            # Baseline metrics before any update. Logged one step before the first epoch so
            # they do not collide with the metrics logged at the end of that epoch.
            log_metrics(model.save_images_FFOV(0, test_dataloader), opt.epoch_count - 1)

        ## stage 1: train image translation net with CUT loss with small FOV images, register first then translate
        ## netG, netF and netD are trainable, netR is fixed
        model.set_input(data)
        model.define_fold(data)
        model.forward_stage2()

        # patchify_WSI presumably crops the registered full-FOV image into small patches
        # (the role previously played by data_aug_preprocessing_HR) — TODO confirm
        mini_dataloader = model.patchify_WSI(real_B_reg=model.real_B_reg)
        for mini_data in mini_dataloader:
            if len(opt.gpu_ids) > 0:
                torch.cuda.synchronize()
            optimize_start_time = time.time()
            model.set_mini_input(mini_data)
            model.optimize_parameters()

            total_iters += model.mini_bs
            epoch_iter += model.mini_bs
            if len(opt.gpu_ids) > 0:
                torch.cuda.synchronize()
            optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time
            accumulate_losses(running_losses, model.get_current_losses(), model.loss_names)
            n_loss_updates += 1

        ## stage 2: train image registration across full FOV, translate first then register
        ## netR is trainable, netG, netF, netD are fixed
        if not model.opt.only_train_G:
            model.set_input(data)
            fake_B = model.patch_wise_predict(input_data=model.real_A_eq, stride_ratio=1).permute(2, 0, 1).unsqueeze(0)

            if len(opt.gpu_ids) > 0:
                torch.cuda.synchronize()
            optimize_start_time = time.time()
            model.set_input(data)
            ## gradient descent on netR
            model.optimize_parameters_stage2(fake_B=fake_B)

            model.netG.train()
            if len(opt.gpu_ids) > 0:
                torch.cuda.synchronize()
            optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time

        losses = model.get_current_losses()
        accumulate_losses(running_losses, losses, model.loss_names)
        n_loss_updates += 1
        visualizer.print_current_losses(epoch, epoch_iter, losses, optimize_time, t_data)
        if opt.display_id is None or opt.display_id > 0:
            visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

        iter_data_time = time.time()

    if epoch % opt.save_epoch_freq == 0:  # cache our model every <save_epoch_freq> epochs
        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
        model.save_networks('latest')
        model.save_networks(epoch)
        log_metrics(model.save_images_FFOV(epoch, test_dataloader), epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, n_epochs_total, time.time() - epoch_start_time))
    model.update_learning_rate()  # update learning rates at the end of every epoch.
    # Average over the snapshots actually summed (one per mini-batch plus one per image),
    # not over the number of images.
    for loss_name in model.loss_names:
        if isinstance(loss_name, str):
            writer.add_scalar('loss_' + str(loss_name), running_losses[loss_name] / max(n_loss_updates, 1), epoch)

Setting up a new session...


create web directory ./checkpoints/CUT+REG+Pseudo_9spl_4x_NCE20_REG100_TV600_MAE1_MAEpseudo10_3channel/web...


  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


---------- Networks initialized -------------
[Network G] Total number of parameters : 11.389 M
[Network F] Total number of parameters : 0.560 M
[Network D] Total number of parameters : 2.766 M
[Network R] Total number of parameters : 2.059 M
-----------------------------------------------
Directory Results/CUT+REG+Pseudo_9spl_4x_NCE20_REG100_TV600_MAE1_MAEpseudo10_3channel/images/ already exists


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  L_mean, A_mean, B_mean = L[mask_WM].mean(), A[mask_WM].mean(), B[mask_WM].mean()
  ret = ret.dtype.type(ret / rcount)
  return n/db/n.sum(), bin_edges
  L_mean, A_mean, B_mean = L[mask_WM].mean(), A[mask_WM].mean(), B[mask_WM].mean()
  ret = ret.dtype.type(ret / rcount)
  return n/db/n.sum(), bin_edges
  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 1, iters: 25, time: 0.124, data: 0.167) G_GAN: 0.604 D_real: 0.032 D_fake: 0.059 G: 116.919 NCE: 110.904 REG: 8.067 TV: 0.897 MAE: 0.480 MAE_pseudo: 4.932 LAB: 0.000 
(epoch: 1, iters: 50, time: 0.146, data: 0.167) G_GAN: 0.382 D_real: 0.368 D_fake: 0.304 G: 100.089 NCE: 97.276 REG: 7.860 TV: 0.767 MAE: 0.235 MAE_pseudo: 2.195 LAB: 0.000 
(epoch: 1, iters: 75, time: 0.166, data: 0.167) G_GAN: 0.280 D_real: 0.306 D_fake: 0.241 G: 97.468 NCE: 95.624 REG: 5.275 TV: 0.402 MAE: 0.147 MAE_pseudo: 1.418 LAB: 0.000 
(epoch: 1, iters: 100, time: 0.181, data: 0.167) G_GAN: 0.274 D_real: 0.166 D_fake: 0.371 G: 100.412 NCE: 98.322 REG: 7.655 TV: 0.794 MAE: 0.169 MAE_pseudo: 1.646 LAB: 0.000 
(epoch: 1, iters: 131, time: 0.201, data: 0.005) G_GAN: 0.299 D_real: 0.265 D_fake: 0.215 G: 98.320 NCE: 96.059 REG: 6.494 TV: 0.580 MAE: 0.208 MAE_pseudo: 1.754 LAB: 0.000 
(epoch: 1, iters: 156, time: 0.214, data: 0.005) G_GAN: 0.567 D_real: 0.350 D_fake: 0.290 G: 93.838 NCE: 91.175 REG: 6.015 TV: 1.

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 6, iters: 31, time: 0.295, data: 0.005) G_GAN: 0.758 D_real: 0.107 D_fake: 0.393 G: 94.346 NCE: 91.790 REG: 5.583 TV: 0.568 MAE: 0.172 MAE_pseudo: 1.626 LAB: 0.000 
(epoch: 6, iters: 56, time: 0.295, data: 0.005) G_GAN: 0.354 D_real: 0.161 D_fake: 0.194 G: 88.406 NCE: 86.076 REG: 6.522 TV: 1.060 MAE: 0.169 MAE_pseudo: 1.807 LAB: 0.000 
(epoch: 6, iters: 87, time: 0.295, data: 0.005) G_GAN: 0.328 D_real: 0.299 D_fake: 0.212 G: 95.049 NCE: 93.648 REG: 5.446 TV: 0.574 MAE: 0.098 MAE_pseudo: 0.976 LAB: 0.000 
(epoch: 6, iters: 118, time: 0.296, data: 0.005) G_GAN: 0.615 D_real: 0.082 D_fake: 0.191 G: 93.006 NCE: 90.501 REG: 6.036 TV: 0.547 MAE: 0.195 MAE_pseudo: 1.696 LAB: 0.000 
(epoch: 6, iters: 143, time: 0.295, data: 0.005) G_GAN: 0.424 D_real: 0.181 D_fake: 0.094 G: 90.012 NCE: 87.886 REG: 5.236 TV: 0.360 MAE: 0.162 MAE_pseudo: 1.539 LAB: 0.000 
(epoch: 6, iters: 168, time: 0.295, data: 0.005) G_GAN: 0.383 D_real: 0.305 D_fake: 0.131 G: 90.653 NCE: 88.306 REG: 7.203 TV: 0.843 

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 11, iters: 31, time: 0.295, data: 0.005) G_GAN: 0.295 D_real: 0.320 D_fake: 0.184 G: 91.197 NCE: 89.605 REG: 5.054 TV: 0.613 MAE: 0.118 MAE_pseudo: 1.179 LAB: 0.000 
(epoch: 11, iters: 56, time: 0.295, data: 0.005) G_GAN: 0.465 D_real: 0.176 D_fake: 0.314 G: 79.774 NCE: 78.083 REG: 7.081 TV: 0.848 MAE: 0.112 MAE_pseudo: 1.114 LAB: 0.000 
(epoch: 11, iters: 81, time: 0.295, data: 0.005) G_GAN: 0.408 D_real: 0.291 D_fake: 0.164 G: 83.377 NCE: 81.709 REG: 6.651 TV: 0.826 MAE: 0.131 MAE_pseudo: 1.129 LAB: 0.000 
(epoch: 11, iters: 106, time: 0.295, data: 0.005) G_GAN: 0.481 D_real: 0.407 D_fake: 0.399 G: 86.829 NCE: 85.188 REG: 4.821 TV: 0.388 MAE: 0.108 MAE_pseudo: 1.051 LAB: 0.000 
(epoch: 11, iters: 131, time: 0.295, data: 0.005) G_GAN: 0.368 D_real: 0.313 D_fake: 0.216 G: 78.950 NCE: 77.338 REG: 5.674 TV: 0.950 MAE: 0.126 MAE_pseudo: 1.118 LAB: 0.000 
(epoch: 11, iters: 162, time: 0.295, data: 0.005) G_GAN: 0.221 D_real: 0.114 D_fake: 0.462 G: 86.997 NCE: 85.495 REG: 5.700 TV: 

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 16, iters: 31, time: 0.295, data: 0.005) G_GAN: 0.262 D_real: 0.084 D_fake: 0.373 G: 82.248 NCE: 81.198 REG: 5.675 TV: 0.570 MAE: 0.091 MAE_pseudo: 0.697 LAB: 0.000 
(epoch: 16, iters: 62, time: 0.295, data: 0.005) G_GAN: 0.351 D_real: 0.225 D_fake: 0.300 G: 82.436 NCE: 81.251 REG: 5.203 TV: 0.519 MAE: 0.082 MAE_pseudo: 0.751 LAB: 0.000 
(epoch: 16, iters: 87, time: 0.295, data: 0.005) G_GAN: 0.526 D_real: 0.190 D_fake: 0.482 G: 86.754 NCE: 85.622 REG: 6.858 TV: 0.844 MAE: 0.058 MAE_pseudo: 0.547 LAB: 0.000 
(epoch: 16, iters: 118, time: 0.295, data: 0.005) G_GAN: 0.374 D_real: 0.281 D_fake: 0.147 G: 85.637 NCE: 84.569 REG: 4.924 TV: 0.627 MAE: 0.063 MAE_pseudo: 0.631 LAB: 0.000 
(epoch: 16, iters: 143, time: 0.295, data: 0.005) G_GAN: 0.278 D_real: 0.231 D_fake: 0.391 G: 81.931 NCE: 80.873 REG: 4.753 TV: 0.407 MAE: 0.066 MAE_pseudo: 0.713 LAB: 0.000 
(epoch: 16, iters: 168, time: 0.295, data: 0.005) G_GAN: 0.473 D_real: 0.153 D_fake: 0.278 G: 71.901 NCE: 70.895 REG: 5.601 TV: 

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 21, iters: 31, time: 0.295, data: 0.003) G_GAN: 0.416 D_real: 0.026 D_fake: 0.356 G: 82.015 NCE: 80.557 REG: 5.433 TV: 0.631 MAE: 0.125 MAE_pseudo: 0.916 LAB: 0.000 
(epoch: 21, iters: 56, time: 0.295, data: 0.003) G_GAN: 0.467 D_real: 0.187 D_fake: 0.237 G: 70.151 NCE: 68.915 REG: 5.788 TV: 1.287 MAE: 0.076 MAE_pseudo: 0.694 LAB: 0.000 
(epoch: 21, iters: 87, time: 0.295, data: 0.003) G_GAN: 0.451 D_real: 0.253 D_fake: 0.179 G: 88.669 NCE: 87.370 REG: 5.054 TV: 0.644 MAE: 0.077 MAE_pseudo: 0.771 LAB: 0.000 
(epoch: 21, iters: 112, time: 0.295, data: 0.003) G_GAN: 0.368 D_real: 0.211 D_fake: 0.268 G: 87.539 NCE: 85.905 REG: 6.681 TV: 0.708 MAE: 0.111 MAE_pseudo: 1.155 LAB: 0.000 
(epoch: 21, iters: 143, time: 0.296, data: 0.003) G_GAN: 0.252 D_real: 0.083 D_fake: 0.457 G: 85.319 NCE: 84.023 REG: 5.175 TV: 0.443 MAE: 0.102 MAE_pseudo: 0.943 LAB: 0.000 
(epoch: 21, iters: 168, time: 0.296, data: 0.003) G_GAN: 0.504 D_real: 0.245 D_fake: 0.105 G: 72.676 NCE: 71.499 REG: 7.171 TV: 

  return n/db/n.sum(), bin_edges


End of epoch 25 / 500 	 Time Taken: 475 sec
learning rate = 0.0002000


  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 26, iters: 31, time: 0.295, data: 0.002) G_GAN: 0.187 D_real: 0.157 D_fake: 0.410 G: 80.620 NCE: 79.885 REG: 6.136 TV: 0.515 MAE: 0.071 MAE_pseudo: 0.477 LAB: 0.000 
(epoch: 26, iters: 56, time: 0.295, data: 0.002) G_GAN: 0.175 D_real: 0.054 D_fake: 0.661 G: 74.613 NCE: 73.683 REG: 5.619 TV: 0.847 MAE: 0.064 MAE_pseudo: 0.691 LAB: 0.000 
(epoch: 26, iters: 81, time: 0.295, data: 0.002) G_GAN: 0.188 D_real: 0.062 D_fake: 0.579 G: 83.731 NCE: 83.088 REG: 7.314 TV: 0.729 MAE: 0.046 MAE_pseudo: 0.409 LAB: 0.000 
(epoch: 26, iters: 106, time: 0.295, data: 0.002) G_GAN: 0.592 D_real: 0.365 D_fake: 0.074 G: 75.171 NCE: 73.917 REG: 6.985 TV: 0.936 MAE: 0.076 MAE_pseudo: 0.586 LAB: 0.000 
(epoch: 26, iters: 137, time: 0.295, data: 0.002) G_GAN: 0.482 D_real: 0.433 D_fake: 0.177 G: 85.576 NCE: 84.657 REG: 4.961 TV: 0.717 MAE: 0.040 MAE_pseudo: 0.398 LAB: 0.000 
(epoch: 26, iters: 162, time: 0.295, data: 0.002) G_GAN: 0.359 D_real: 0.140 D_fake: 0.214 G: 70.336 NCE: 69.549 REG: 7.371 TV: 

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 31, iters: 25, time: 0.295, data: 0.002) G_GAN: 0.357 D_real: 0.217 D_fake: 0.265 G: 78.800 NCE: 78.181 REG: 4.653 TV: 0.468 MAE: 0.024 MAE_pseudo: 0.237 LAB: 0.000 
(epoch: 31, iters: 50, time: 0.295, data: 0.002) G_GAN: 0.243 D_real: 0.157 D_fake: 0.335 G: 81.922 NCE: 81.363 REG: 7.289 TV: 0.744 MAE: 0.031 MAE_pseudo: 0.286 LAB: 0.000 
(epoch: 31, iters: 75, time: 0.295, data: 0.002) G_GAN: 0.312 D_real: 0.291 D_fake: 0.208 G: 69.423 NCE: 68.776 REG: 6.971 TV: 1.012 MAE: 0.044 MAE_pseudo: 0.291 LAB: 0.000 
(epoch: 31, iters: 106, time: 0.295, data: 0.002) G_GAN: 0.328 D_real: 0.237 D_fake: 0.152 G: 85.483 NCE: 84.705 REG: 5.201 TV: 0.612 MAE: 0.041 MAE_pseudo: 0.409 LAB: 0.000 
(epoch: 31, iters: 131, time: 0.295, data: 0.002) G_GAN: 0.446 D_real: 0.210 D_fake: 0.112 G: 67.461 NCE: 66.729 REG: 5.611 TV: 1.142 MAE: 0.038 MAE_pseudo: 0.247 LAB: 0.000 
(epoch: 31, iters: 156, time: 0.295, data: 0.002) G_GAN: 0.346 D_real: 0.124 D_fake: 0.502 G: 72.540 NCE: 71.728 REG: 5.500 TV: 

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 36, iters: 31, time: 0.296, data: 0.002) G_GAN: 0.220 D_real: 0.286 D_fake: 0.225 G: 83.338 NCE: 83.001 REG: 4.934 TV: 0.626 MAE: 0.011 MAE_pseudo: 0.106 LAB: 0.000 
(epoch: 36, iters: 56, time: 0.295, data: 0.002) G_GAN: 0.235 D_real: 0.082 D_fake: 0.519 G: 68.446 NCE: 68.009 REG: 6.894 TV: 0.893 MAE: 0.031 MAE_pseudo: 0.171 LAB: 0.000 
(epoch: 36, iters: 81, time: 0.295, data: 0.002) G_GAN: 0.265 D_real: 0.177 D_fake: 0.521 G: 77.837 NCE: 77.477 REG: 4.655 TV: 0.430 MAE: 0.010 MAE_pseudo: 0.084 LAB: 0.000 
(epoch: 36, iters: 112, time: 0.296, data: 0.002) G_GAN: 0.146 D_real: 0.056 D_fake: 0.603 G: 78.611 NCE: 78.273 REG: 5.455 TV: 0.596 MAE: 0.021 MAE_pseudo: 0.170 LAB: 0.000 
(epoch: 36, iters: 137, time: 0.295, data: 0.002) G_GAN: 0.142 D_real: 0.101 D_fake: 0.521 G: 70.179 NCE: 69.839 REG: 5.372 TV: 0.973 MAE: 0.026 MAE_pseudo: 0.172 LAB: 0.000 
(epoch: 36, iters: 168, time: 0.296, data: 0.002) G_GAN: 0.185 D_real: 0.230 D_fake: 0.357 G: 80.685 NCE: 80.313 REG: 5.433 TV: 

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 41, iters: 31, time: 0.295, data: 0.002) G_GAN: 0.341 D_real: 0.319 D_fake: 0.373 G: 83.994 NCE: 83.583 REG: 4.811 TV: 0.690 MAE: 0.006 MAE_pseudo: 0.063 LAB: 0.000 
(epoch: 41, iters: 62, time: 0.296, data: 0.002) G_GAN: 0.216 D_real: 0.235 D_fake: 0.331 G: 79.247 NCE: 78.865 REG: 4.946 TV: 0.592 MAE: 0.016 MAE_pseudo: 0.150 LAB: 0.000 
(epoch: 41, iters: 87, time: 0.295, data: 0.002) G_GAN: 0.168 D_real: 0.141 D_fake: 0.509 G: 72.961 NCE: 72.594 REG: 6.285 TV: 1.190 MAE: 0.026 MAE_pseudo: 0.173 LAB: 0.000 
(epoch: 41, iters: 118, time: 0.296, data: 0.002) G_GAN: 0.207 D_real: 0.061 D_fake: 0.624 G: 79.504 NCE: 79.127 REG: 5.651 TV: 0.572 MAE: 0.022 MAE_pseudo: 0.149 LAB: 0.000 
(epoch: 41, iters: 143, time: 0.296, data: 0.002) G_GAN: 0.363 D_real: 0.223 D_fake: 0.147 G: 73.172 NCE: 72.665 REG: 6.818 TV: 1.056 MAE: 0.021 MAE_pseudo: 0.123 LAB: 0.000 
(epoch: 41, iters: 168, time: 0.295, data: 0.002) G_GAN: 0.148 D_real: 0.034 D_fake: 0.599 G: 65.207 NCE: 64.949 REG: 5.989 TV: 

  return n/db/n.sum(), bin_edges


End of epoch 45 / 500 	 Time Taken: 472 sec
learning rate = 0.0002000


  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 46, iters: 31, time: 0.296, data: 0.002) G_GAN: 0.204 D_real: 0.251 D_fake: 0.232 G: 83.433 NCE: 83.158 REG: 5.036 TV: 0.585 MAE: 0.006 MAE_pseudo: 0.064 LAB: 0.000 
(epoch: 46, iters: 62, time: 0.296, data: 0.002) G_GAN: 0.212 D_real: 0.271 D_fake: 0.284 G: 80.934 NCE: 80.530 REG: 5.079 TV: 0.523 MAE: 0.017 MAE_pseudo: 0.175 LAB: 0.000 
(epoch: 46, iters: 87, time: 0.296, data: 0.002) G_GAN: 0.411 D_real: 0.204 D_fake: 0.141 G: 72.426 NCE: 71.860 REG: 7.441 TV: 0.787 MAE: 0.023 MAE_pseudo: 0.133 LAB: 0.000 
(epoch: 46, iters: 112, time: 0.296, data: 0.002) G_GAN: 0.287 D_real: 0.117 D_fake: 0.460 G: 70.329 NCE: 69.771 REG: 7.907 TV: 0.807 MAE: 0.036 MAE_pseudo: 0.236 LAB: 0.000 
(epoch: 46, iters: 137, time: 0.295, data: 0.002) G_GAN: 0.161 D_real: 0.098 D_fake: 0.537 G: 75.017 NCE: 74.740 REG: 4.629 TV: 0.407 MAE: 0.010 MAE_pseudo: 0.107 LAB: 0.000 
(epoch: 46, iters: 162, time: 0.295, data: 0.002) G_GAN: 0.341 D_real: 0.351 D_fake: 0.146 G: 81.539 NCE: 81.064 REG: 6.779 TV: 

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 51, iters: 25, time: 0.295, data: 0.002) G_GAN: 0.369 D_real: 0.318 D_fake: 0.163 G: 82.839 NCE: 82.270 REG: 6.847 TV: 0.719 MAE: 0.014 MAE_pseudo: 0.186 LAB: 0.000 
(epoch: 51, iters: 50, time: 0.295, data: 0.002) G_GAN: 0.245 D_real: 0.080 D_fake: 0.717 G: 70.670 NCE: 70.210 REG: 7.469 TV: 0.851 MAE: 0.033 MAE_pseudo: 0.181 LAB: 0.000 
(epoch: 51, iters: 81, time: 0.295, data: 0.003) G_GAN: 0.191 D_real: 0.036 D_fake: 0.524 G: 79.253 NCE: 78.880 REG: 5.620 TV: 0.534 MAE: 0.020 MAE_pseudo: 0.162 LAB: 0.000 
(epoch: 51, iters: 106, time: 0.295, data: 0.003) G_GAN: 0.279 D_real: 0.135 D_fake: 0.368 G: 71.211 NCE: 70.780 REG: 5.224 TV: 0.879 MAE: 0.026 MAE_pseudo: 0.125 LAB: 0.000 
(epoch: 51, iters: 131, time: 0.295, data: 0.003) G_GAN: 0.201 D_real: 0.095 D_fake: 0.618 G: 75.735 NCE: 75.430 REG: 4.759 TV: 0.444 MAE: 0.007 MAE_pseudo: 0.096 LAB: 0.000 
(epoch: 51, iters: 156, time: 0.295, data: 0.003) G_GAN: 0.218 D_real: 0.079 D_fake: 0.456 G: 66.361 NCE: 66.007 REG: 5.550 TV: 

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 56, iters: 31, time: 0.295, data: 0.003) G_GAN: 0.273 D_real: 0.169 D_fake: 0.419 G: 88.505 NCE: 88.100 REG: 5.233 TV: 0.603 MAE: 0.012 MAE_pseudo: 0.120 LAB: 0.000 
(epoch: 56, iters: 62, time: 0.295, data: 0.003) G_GAN: 0.284 D_real: 0.321 D_fake: 0.191 G: 81.450 NCE: 80.772 REG: 4.947 TV: 0.502 MAE: 0.033 MAE_pseudo: 0.361 LAB: 0.000 
(epoch: 56, iters: 87, time: 0.295, data: 0.003) G_GAN: 0.372 D_real: 0.316 D_fake: 0.149 G: 75.150 NCE: 74.568 REG: 6.881 TV: 0.917 MAE: 0.024 MAE_pseudo: 0.186 LAB: 0.000 
(epoch: 56, iters: 112, time: 0.295, data: 0.003) G_GAN: 0.201 D_real: 0.052 D_fake: 0.546 G: 69.786 NCE: 69.297 REG: 7.209 TV: 0.800 MAE: 0.040 MAE_pseudo: 0.248 LAB: 0.000 
(epoch: 56, iters: 137, time: 0.295, data: 0.003) G_GAN: 0.397 D_real: 0.316 D_fake: 0.331 G: 83.614 NCE: 83.032 REG: 6.731 TV: 0.798 MAE: 0.015 MAE_pseudo: 0.170 LAB: 0.000 
(epoch: 56, iters: 168, time: 0.296, data: 0.003) G_GAN: 0.137 D_real: 0.023 D_fake: 0.627 G: 78.305 NCE: 77.970 REG: 5.845 TV: 

  return n/db/n.sum(), bin_edges


End of epoch 60 / 500 	 Time Taken: 472 sec
learning rate = 0.0001951


  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 61, iters: 31, time: 0.296, data: 0.003) G_GAN: 0.214 D_real: 0.024 D_fake: 0.496 G: 75.794 NCE: 75.448 REG: 5.328 TV: 0.618 MAE: 0.013 MAE_pseudo: 0.119 LAB: 0.000 
(epoch: 61, iters: 62, time: 0.296, data: 0.003) G_GAN: 0.214 D_real: 0.261 D_fake: 0.232 G: 82.217 NCE: 81.965 REG: 4.868 TV: 0.562 MAE: 0.004 MAE_pseudo: 0.035 LAB: 0.000 
(epoch: 61, iters: 87, time: 0.295, data: 0.003) G_GAN: 0.216 D_real: 0.067 D_fake: 0.506 G: 67.886 NCE: 67.566 REG: 6.930 TV: 0.771 MAE: 0.025 MAE_pseudo: 0.079 LAB: 0.000 
(epoch: 61, iters: 112, time: 0.295, data: 0.003) G_GAN: 0.236 D_real: 0.165 D_fake: 0.430 G: 72.101 NCE: 71.763 REG: 5.325 TV: 0.912 MAE: 0.021 MAE_pseudo: 0.080 LAB: 0.000 
(epoch: 61, iters: 137, time: 0.295, data: 0.003) G_GAN: 0.186 D_real: 0.106 D_fake: 0.492 G: 66.091 NCE: 65.816 REG: 5.422 TV: 1.000 MAE: 0.031 MAE_pseudo: 0.059 LAB: 0.000 
(epoch: 61, iters: 162, time: 0.295, data: 0.003) G_GAN: 0.166 D_real: 0.064 D_fake: 0.546 G: 78.314 NCE: 78.065 REG: 4.827 TV: 

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 66, iters: 31, time: 0.295, data: 0.003) G_GAN: 0.831 D_real: 0.259 D_fake: 0.025 G: 84.319 NCE: 83.129 REG: 4.970 TV: 0.590 MAE: 0.033 MAE_pseudo: 0.326 LAB: 0.000 
(epoch: 66, iters: 56, time: 0.295, data: 0.003) G_GAN: 0.195 D_real: 0.173 D_fake: 0.387 G: 64.897 NCE: 64.595 REG: 5.434 TV: 1.022 MAE: 0.031 MAE_pseudo: 0.076 LAB: 0.000 
(epoch: 66, iters: 87, time: 0.296, data: 0.003) G_GAN: 0.332 D_real: 0.341 D_fake: 0.152 G: 81.092 NCE: 80.587 REG: 4.857 TV: 0.546 MAE: 0.020 MAE_pseudo: 0.153 LAB: 0.000 
(epoch: 66, iters: 112, time: 0.295, data: 0.003) G_GAN: 0.159 D_real: 0.026 D_fake: 0.693 G: 69.684 NCE: 69.404 REG: 6.988 TV: 0.857 MAE: 0.030 MAE_pseudo: 0.090 LAB: 0.000 
(epoch: 66, iters: 137, time: 0.295, data: 0.003) G_GAN: 0.322 D_real: 0.344 D_fake: 0.229 G: 71.453 NCE: 70.996 REG: 6.275 TV: 0.903 MAE: 0.021 MAE_pseudo: 0.113 LAB: 0.000 
(epoch: 66, iters: 162, time: 0.295, data: 0.003) G_GAN: 0.176 D_real: 0.111 D_fake: 0.438 G: 74.725 NCE: 74.482 REG: 4.661 TV: 

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 71, iters: 25, time: 0.295, data: 0.002) G_GAN: 0.260 D_real: 0.134 D_fake: 0.347 G: 63.542 NCE: 63.223 REG: 5.410 TV: 0.919 MAE: 0.019 MAE_pseudo: 0.041 LAB: 0.000 
(epoch: 71, iters: 50, time: 0.295, data: 0.002) G_GAN: 0.234 D_real: 0.168 D_fake: 0.406 G: 70.783 NCE: 70.450 REG: 6.636 TV: 0.866 MAE: 0.017 MAE_pseudo: 0.082 LAB: 0.000 
(epoch: 71, iters: 81, time: 0.295, data: 0.002) G_GAN: 0.221 D_real: 0.281 D_fake: 0.234 G: 82.729 NCE: 82.488 REG: 4.701 TV: 0.660 MAE: 0.002 MAE_pseudo: 0.018 LAB: 0.000 
(epoch: 71, iters: 106, time: 0.295, data: 0.002) G_GAN: 0.236 D_real: 0.145 D_fake: 0.489 G: 82.313 NCE: 81.995 REG: 6.420 TV: 0.699 MAE: 0.003 MAE_pseudo: 0.080 LAB: 0.000 
(epoch: 71, iters: 137, time: 0.296, data: 0.002) G_GAN: 0.251 D_real: 0.055 D_fake: 0.440 G: 76.678 NCE: 76.333 REG: 5.283 TV: 0.629 MAE: 0.008 MAE_pseudo: 0.087 LAB: 0.000 
(epoch: 71, iters: 162, time: 0.296, data: 0.002) G_GAN: 0.230 D_real: 0.115 D_fake: 0.468 G: 72.605 NCE: 72.190 REG: 7.973 TV: 

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 76, iters: 31, time: 0.295, data: 0.002) G_GAN: 0.225 D_real: 0.042 D_fake: 0.450 G: 78.419 NCE: 78.089 REG: 5.151 TV: 0.607 MAE: 0.010 MAE_pseudo: 0.094 LAB: 0.000 
(epoch: 76, iters: 62, time: 0.296, data: 0.002) G_GAN: 0.335 D_real: 0.412 D_fake: 0.143 G: 81.239 NCE: 80.760 REG: 4.756 TV: 0.557 MAE: 0.013 MAE_pseudo: 0.130 LAB: 0.000 
(epoch: 76, iters: 87, time: 0.295, data: 0.002) G_GAN: 0.212 D_real: 0.100 D_fake: 0.408 G: 76.068 NCE: 75.754 REG: 4.709 TV: 0.394 MAE: 0.006 MAE_pseudo: 0.096 LAB: 0.000 
(epoch: 76, iters: 112, time: 0.295, data: 0.002) G_GAN: 0.314 D_real: 0.193 D_fake: 0.216 G: 81.568 NCE: 81.143 REG: 7.483 TV: 0.723 MAE: 0.003 MAE_pseudo: 0.108 LAB: 0.000 
(epoch: 76, iters: 137, time: 0.295, data: 0.002) G_GAN: 0.235 D_real: 0.158 D_fake: 0.321 G: 63.203 NCE: 62.877 REG: 5.352 TV: 0.947 MAE: 0.033 MAE_pseudo: 0.059 LAB: 0.000 
(epoch: 76, iters: 162, time: 0.295, data: 0.002) G_GAN: 0.284 D_real: 0.247 D_fake: 0.314 G: 70.586 NCE: 70.180 REG: 4.941 TV: 

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 81, iters: 25, time: 0.295, data: 0.002) G_GAN: 0.245 D_real: 0.103 D_fake: 0.408 G: 78.197 NCE: 77.872 REG: 4.607 TV: 0.370 MAE: 0.006 MAE_pseudo: 0.073 LAB: 0.000 
(epoch: 81, iters: 50, time: 0.295, data: 0.002) G_GAN: 0.276 D_real: 0.287 D_fake: 0.255 G: 71.024 NCE: 70.658 REG: 6.174 TV: 0.837 MAE: 0.015 MAE_pseudo: 0.075 LAB: 0.000 
(epoch: 81, iters: 81, time: 0.295, data: 0.002) G_GAN: 0.248 D_real: 0.310 D_fake: 0.224 G: 81.965 NCE: 81.702 REG: 5.142 TV: 0.566 MAE: 0.001 MAE_pseudo: 0.013 LAB: 0.000 
(epoch: 81, iters: 106, time: 0.295, data: 0.002) G_GAN: 0.126 D_real: 0.102 D_fake: 0.616 G: 64.485 NCE: 64.293 REG: 5.471 TV: 0.936 MAE: 0.029 MAE_pseudo: 0.037 LAB: 0.000 
(epoch: 81, iters: 131, time: 0.295, data: 0.002) G_GAN: 0.276 D_real: 0.277 D_fake: 0.322 G: 68.229 NCE: 67.860 REG: 6.787 TV: 0.881 MAE: 0.023 MAE_pseudo: 0.070 LAB: 0.000 
(epoch: 81, iters: 162, time: 0.295, data: 0.002) G_GAN: 0.207 D_real: 0.059 D_fake: 0.438 G: 78.500 NCE: 78.205 REG: 5.323 TV: 

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 86, iters: 25, time: 0.295, data: 0.002) G_GAN: 0.312 D_real: 0.087 D_fake: 0.353 G: 73.622 NCE: 73.246 REG: 4.513 TV: 0.389 MAE: 0.006 MAE_pseudo: 0.058 LAB: 0.000 
(epoch: 86, iters: 56, time: 0.295, data: 0.002) G_GAN: 0.638 D_real: 0.333 D_fake: 0.428 G: 94.354 NCE: 93.381 REG: 9.350 TV: 0.609 MAE: 0.031 MAE_pseudo: 0.304 LAB: 0.000 
(epoch: 86, iters: 81, time: 0.295, data: 0.002) G_GAN: 0.403 D_real: 0.465 D_fake: 0.135 G: 95.376 NCE: 94.710 REG: 6.771 TV: 0.822 MAE: 0.022 MAE_pseudo: 0.242 LAB: 0.000 
(epoch: 86, iters: 112, time: 0.295, data: 0.002) G_GAN: 0.291 D_real: 0.024 D_fake: 0.319 G: 84.013 NCE: 83.276 REG: 6.017 TV: 0.637 MAE: 0.031 MAE_pseudo: 0.416 LAB: 0.000 
(epoch: 86, iters: 137, time: 0.295, data: 0.002) G_GAN: 0.403 D_real: 0.221 D_fake: 0.145 G: 71.379 NCE: 70.788 REG: 6.137 TV: 0.979 MAE: 0.026 MAE_pseudo: 0.162 LAB: 0.000 
(epoch: 86, iters: 168, time: 0.295, data: 0.002) G_GAN: 0.385 D_real: 0.379 D_fake: 0.150 G: 79.647 NCE: 78.702 REG: 5.032 TV: 

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 91, iters: 31, time: 0.295, data: 0.002) G_GAN: 0.735 D_real: 0.390 D_fake: 0.512 G: 82.663 NCE: 81.907 REG: 4.752 TV: 0.578 MAE: 0.002 MAE_pseudo: 0.020 LAB: 0.000 
(epoch: 91, iters: 62, time: 0.295, data: 0.002) G_GAN: 0.185 D_real: 0.085 D_fake: 0.507 G: 78.299 NCE: 78.023 REG: 5.063 TV: 0.616 MAE: 0.007 MAE_pseudo: 0.084 LAB: 0.000 
(epoch: 91, iters: 93, time: 0.295, data: 0.002) G_GAN: 0.312 D_real: 0.406 D_fake: 0.182 G: 80.198 NCE: 79.812 REG: 4.727 TV: 0.541 MAE: 0.009 MAE_pseudo: 0.065 LAB: 0.000 
(epoch: 91, iters: 118, time: 0.295, data: 0.002) G_GAN: 0.213 D_real: 0.123 D_fake: 0.411 G: 76.009 NCE: 75.738 REG: 4.446 TV: 0.399 MAE: 0.005 MAE_pseudo: 0.052 LAB: 0.000 
(epoch: 91, iters: 143, time: 0.295, data: 0.002) G_GAN: 0.278 D_real: 0.437 D_fake: 0.228 G: 71.073 NCE: 70.645 REG: 7.118 TV: 0.815 MAE: 0.025 MAE_pseudo: 0.125 LAB: 0.000 
(epoch: 91, iters: 168, time: 0.295, data: 0.002) G_GAN: 0.187 D_real: 0.138 D_fake: 0.402 G: 82.679 NCE: 82.406 REG: 6.618 TV: 

  patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)


(epoch: 96, iters: 25, time: 0.295, data: 0.001) G_GAN: 0.313 D_real: 0.137 D_fake: 0.267 G: 67.306 NCE: 66.919 REG: 5.289 TV: 1.043 MAE: 0.025 MAE_pseudo: 0.049 LAB: 0.000 
(epoch: 96, iters: 56, time: 0.295, data: 0.001) G_GAN: 0.322 D_real: 0.309 D_fake: 0.223 G: 81.491 NCE: 81.099 REG: 4.813 TV: 0.560 MAE: 0.008 MAE_pseudo: 0.062 LAB: 0.000 
(epoch: 96, iters: 81, time: 0.295, data: 0.001) G_GAN: 0.291 D_real: 0.269 D_fake: 0.225 G: 82.554 NCE: 82.186 REG: 6.314 TV: 0.787 MAE: 0.001 MAE_pseudo: 0.075 LAB: 0.000 
(epoch: 96, iters: 112, time: 0.295, data: 0.001) G_GAN: 0.921 D_real: 0.130 D_fake: 0.025 G: 98.391 NCE: 97.173 REG: 5.558 TV: 0.605 MAE: 0.027 MAE_pseudo: 0.270 LAB: 0.000 
(epoch: 96, iters: 137, time: 0.295, data: 0.001) G_GAN: 0.301 D_real: 0.200 D_fake: 0.222 G: 90.665 NCE: 90.078 REG: 5.151 TV: 0.972 MAE: 0.030 MAE_pseudo: 0.255 LAB: 0.000 
(epoch: 96, iters: 168, time: 0.295, data: 0.001) G_GAN: 0.220 D_real: 0.031 D_fake: 0.616 G: 82.945 NCE: 82.322 REG: 5.663 TV: 

0.19.3
