In [1]:
import cv2
import time
import torch
import matplotlib.pyplot as plt
import math
import numpy as np
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.local_visualizer import LocalVisualizer

In [2]:
NAME = 'gen_new5'
DS_NAME = ''

defaults = {
    'dataroot': f'../../generator_data/{DS_NAME}',
    'model': 'my',
    'dataset_mode': 'my',
    'dataset_root': f'../../generator_data/{DS_NAME}',
    'embedding_save_dir': './checkpoints/emb_vidit4',
    'name': NAME
}

# defaults = {
#     'dataroot': '../../embedding_data/vidit/',
#     'model': 'emb',
#     'dataset_mode': 'emb',
#     'dataset_root': '../../embedding_data/vidit/',
#     'name': NAME
# }


opt = TrainOptions(defaults=defaults).parse()

----------------- Options ---------------
               batch_size: 1                             
                    beta1: 0.5                           
          checkpoints_dir: ./checkpoints                 
           continue_train: False                         
                crop_size: 256                           
               d_lr_ratio: 1.0                           
                 dataroot: ../../generator_data/         
             dataset_mode: my                            
             dataset_root: ../../generator_data/         
                direction: AtoB                          
              display_env: main                          
             display_freq: 400                           
               display_id: 1                             
            display_ncols: 4                             
             display_port: 9333                          
           display_server: http://localhost              
          display_winsize: 256

In [3]:
# opt.save_epoch_freq = 10
# opt.display_freq = 1000
# opt.print_freq = 200
# opt.save_latest_freq = 10000
# opt.batch_size = 8

opt.beta1 = 0.9

opt.save_epoch_freq = 10
opt.display_freq = 40000
opt.print_freq = 8000
opt.save_latest_freq = 60000
opt.batch_size = 32

assert opt.isTrain == True

In [4]:
datasets = {}
dataset_sizes = {}

datasets['train'] = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
dataset_sizes['train'] = len(datasets['train'])    # get the number of images in the dataset.
print('The number of training images = %d' % dataset_sizes['train'])

model = create_model(opt)      # create a model given opt.model and other options
model.setup(opt)               # regular setup: load and print networks; create schedulers
visualizer = LocalVisualizer(opt)   # create a visualizer that display/save images and plots

loading train file
dataset [MyDataset] was created
The number of training images = 65699
initialize network with normal
initialize network with normal
loading the model from ./checkpoints/emb_vidit4/150_net_G.pth
---------- Networks initialized -------------
[Network G] Total number of parameters : 54.415 M
-----------------------------------------------
model [MyModel] was created
---------- Networks initialized -------------
[Network G] Total number of parameters : 73.946 M
-----------------------------------------------


In [5]:
opt.isTrain = False

train_batch_size = opt.batch_size
test_batch_size = 1

opt.num_threads = 0   # test code only supports num_threads = 0
opt.batch_size = test_batch_size    # test code only supports batch_size = 1
opt.serial_batches = True  # disable data shuffling;

datasets['test'] = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
dataset_sizes['test'] = len(datasets['test'])    # get the number of images in the dataset.
print('The number of test images = %d' % dataset_sizes['test'])
ds_ratio = dataset_sizes['train'] / dataset_sizes['test']

opt.batch_size = train_batch_size
opt.isTrain = True

loading test file
dataset [MyDataset] was created
The number of test images = 7393


In [6]:
from skimage.metrics import mean_squared_error as mse
import torchvision.transforms as tf


class NormalizeInverse(tf.Normalize):
    """
    Undoes the normalization and returns the reconstructed images in the input domain.
    """

    def __init__(self, mean, std):
        mean = torch.as_tensor(mean)
        std = torch.as_tensor(std)
        std_inv = 1 / (std + 1e-7)
        mean_inv = -mean * std_inv
        super().__init__(mean=mean_inv, std=std_inv)

    def __call__(self, tensor):
        return super().__call__(tensor)
    
unnorm = NormalizeInverse((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    
def to_img(tensor):
    tensor = torch.squeeze(tensor.detach().to('cpu'))
    if unnorm is not None:
        tensor = unnorm(tensor)
    np_img = tensor.numpy()
    np_img = np_img.transpose((1, 2, 0))
    return (np_img.clip(0, 1) * 255).astype(np.uint8)

In [7]:
total_iters = 0
train_iters = 0
min_test_loss = math.inf

phases = ['train']
# model.train()
# phases.append('test')

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    visualizer.set_epoch(epoch)

    if epoch == 40:
        phases.append('test')
    
    for phase in phases:
        epoch_start_time = time.time()  # timer for entire epoch
#         iter_data_time = time.time()    # timer for data loading per iteration
        epoch_iter = 0                  # the number of training iterations in current epoch
        epoch_loss = 0.0
        dataset = datasets[phase]
        visualizer.set_phase(phase)
        
        if phase == 'train':
            isTrain = True
            batch_size = train_batch_size
            model.train()  # Set model to training mode
        else:
            isTrain = False
            batch_size = test_batch_size
            model.eval()   # Set model to evaluate mode
        
        for data in dataset:  # inner loop within one epoch
#             iter_start_time = time.time()  # timer for computation per iteration
#             if total_iters % opt.print_freq == 0:
#                 t_data = iter_start_time - iter_data_time
            
#             if isTrain:
#                 train_iters += batch_size
#             total_iters += batch_size
#             epoch_iter += batch_size
            
            model.set_input(data)         # unpack data from dataset and apply preprocessing

            if phase == 'train':
                model.optimize_parameters()   # calculate loss functions, get gradients, update network weights
            else:
                model.test()
                
                real = to_img(model.real)
                harmonized = to_img(model.harmonized)
                epoch_loss += mse(harmonized, real)

#             losses = model.get_current_losses()
#             epoch_loss += losses['G_L1']
                
#             if total_iters % opt.display_freq == 0:   # display images 
#                 model.compute_visuals()
#                 visualizer.display_visuals(model.get_current_visuals())

#             if total_iters % opt.print_freq == 0:    
#                 losses = model.get_current_losses()
#                 t_comp = (time.time() - iter_start_time) / opt.batch_size
#                 visualizer.print_current_losses(epoch_iter, losses, t_comp, t_data)

#             if isTrain and train_iters % opt.save_latest_freq == 0:   
#                 print('saving the latest model (epoch %d, train_iters %d)' % (epoch, train_iters))
#                 save_suffix = 'iter_%d' % train_iters if opt.save_by_iter else 'latest'
#                 model.save_networks(save_suffix)

#             iter_data_time = time.time()

        if phase == 'train' epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
            model.save_networks('latest')
            model.save_networks(epoch)

        print(f'End of {phase.capitalize()} epoch {epoch} / {opt.niter+opt.niter_decay} \t ',
              f'Time Taken: {time.time() - epoch_start_time:.3f} sec', sep='')
        
        epoch_loss /= dataset_sizes[phase]
#         visualizer.add_epoch_loss(epoch_loss)
        if phase == 'test' and epoch_loss < min_test_loss:
            model.save_networks('best')
            min_test_loss = epoch_loss
            print(f'Updating best model at epoch {epoch}, average test loss: {epoch_loss:.4f}')
    
#     visualizer.plot_epoch_losses()
    model.update_learning_rate()

End of Train epoch 1 / 160 	 Time Taken: 632.500 sec
End of Train epoch 2 / 160 	 Time Taken: 640.488 sec
End of Train epoch 3 / 160 	 Time Taken: 596.464 sec
End of Train epoch 4 / 160 	 Time Taken: 596.667 sec
End of Train epoch 6 / 160 	 Time Taken: 596.773 sec
End of Train epoch 7 / 160 	 Time Taken: 596.393 sec
End of Train epoch 8 / 160 	 Time Taken: 610.868 sec
End of Train epoch 9 / 160 	 Time Taken: 659.157 sec
saving the model at the end of epoch 10, iters 0
End of Train epoch 10 / 160 	 Time Taken: 764.822 sec
End of Train epoch 11 / 160 	 Time Taken: 596.125 sec
End of Train epoch 12 / 160 	 Time Taken: 596.696 sec
End of Train epoch 13 / 160 	 Time Taken: 596.125 sec
End of Train epoch 14 / 160 	 Time Taken: 596.490 sec
End of Train epoch 15 / 160 	 Time Taken: 596.669 sec
End of Train epoch 16 / 160 	 Time Taken: 671.992 sec
End of Train epoch 17 / 160 	 Time Taken: 595.984 sec
End of Train epoch 18 / 160 	 Time Taken: 595.905 sec
End of Train epoch 19 / 160 	 Time Taken:

KeyboardInterrupt: 