In [1]:
import time
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer
import torch
import itertools

In [2]:
class BaseOptions():
    pass

In [3]:
opt = BaseOptions()
opt.max_dataset_size = 100000
opt.load_size = 112
opt.crop_size = 112
opt.no_flip = False
opt.batch_size = 32
opt.serial_batches = False
opt.num_threads = 0
opt.preprocess = 'resize_and_crop'

In [4]:
opt.dataroot = '/Projects/NIR_FR_PTH/data/Oulu_ALIGN'
opt.dataset_mode = 'paired_nir'
opt.dataset_name = 'oulu'
dataset_paired = create_dataset(opt)
dataset_paired_size = len(dataset_paired)    # get the number of images in the dataset.
opt.dataroot = '/Projects/NIR_FR_PTH/data/CASIA_ALIGN'
opt.dataset_mode = 'unpaired_nir'
opt.dataset_name = 'casia'
dataset_unpaired = create_dataset(opt)
dataset_unpaired_size = len(dataset_unpaired)    # get the number of images in the dataset.
print('The number of training images: paired = %d, unpaired = %d' % (dataset_paired_size, dataset_unpaired_size))
if dataset_paired_size >= dataset_unpaired_size:
    dl_more = dataset_paired
    dl_less = dataset_unpaired
    dl_flag = 'pair'
else:
    dl_more = dataset_unpaired
    dl_less = dataset_paired
    dl_flag = 'unpair'

dataset_size = len(dl_more)

dataset [PairedNirDataset] was created
dataset [UnpairedNirDataset] was created
The number of training images: paired = 9687, unpaired = 12487


In [5]:
opt.isTrain = True
opt.ngf = 64
opt.ndf = 64
opt.checkpoints_dir = './checkpoints'
opt.netD = 'basic'
opt.netG = 'resnet_9blocks'
opt.input_nc = 3
opt.output_nc = 3
opt.norm = 'batch'
opt.init_type = 'kaiming'
opt.init_gain = 0.02
opt.no_dropout = True
opt.n_layers_D = 3
opt.gan_mode = 'vanilla'
opt.lr = 0.0002
opt.beta1 = 0.5
opt.pool_size = 50
opt.name = 'first_try'
opt.model = 'nir2vis'
opt.lambda_L1 = 100.0

# set gpu ids
opt.gpu_ids = '1'
str_ids = opt.gpu_ids.split(',')
opt.gpu_ids = []
for str_id in str_ids:
    id = int(str_id)
    if id >= 0:
        opt.gpu_ids.append(id)
if len(opt.gpu_ids) > 0:
    torch.cuda.set_device(opt.gpu_ids[0])
model = create_model(opt)

initialize network with kaiming
initialize network with kaiming
initialize network with kaiming
model [Nir2VisModel] was created


In [6]:
opt.lr_policy = 'linear'
opt.epoch_count = 1
opt.niter = 10
opt.niter_decay = 10
opt.continue_train = False
opt.epoch_count = 1
opt.verbose = True
model.setup(opt)

---------- Networks initialized -------------
DataParallel(
  (module): ResnetGenerator(
    (model): Sequential(
      (0): ReflectionPad2d((3, 3, 3, 3))
      (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), bias=False)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace)
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU(inplace)
      (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ReLU(inplace)
      (10): ResnetBlock(
        (conv_block): Sequential(
          (0): ReflectionPad2d((1, 1, 1, 1))
          (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
          (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, af

In [9]:
opt.display_id = 1
opt.display_freq = 100
opt.display_ncols = 3
opt.display_server = "http://localhost"
opt.display_env = 'main'
opt.display_port = 8097
opt.display_winsize = 256
opt.update_html_freq = 100
opt.no_html = True
opt.print_freq = 100
opt.save_latest_freq = 10000
visualizer = Visualizer(opt)



In [10]:
total_iters = 0                # the total number of training iterations

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):    # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
    epoch_start_time = time.time()  # timer for entire epoch
    iter_data_time = time.time()    # timer for data loading per iteration
    epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch

    #for i, data in enumerate(dataset):  # inner loop within one epoch
    for i, data in enumerate(zip(itertools.cycle(dl_less), dl_more)):  
        iter_start_time = time.time()  # timer for computation per iteration
        if total_iters % opt.print_freq == 0:
            t_data = iter_start_time - iter_data_time
        visualizer.reset()
        total_iters += opt.batch_size
        epoch_iter += opt.batch_size
        if dl_flag == 'pair':
            model.set_input(data[1], data[0])         # unpack data from dataset and apply preprocessing
        else:
            model.set_input(data[0], data[1])         # unpack data from dataset and apply preprocessing

        model.optimize_parameters()   # calculate loss functions, get gradients, update network weights

        if total_iters % opt.display_freq == 0:   # display images on visdom and save images to a HTML file
            save_result = total_iters % opt.update_html_freq == 0
            model.compute_visuals()
            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

        if total_iters % opt.print_freq == 0:    # print training losses and save logging information to the disk
            losses = model.get_current_losses()
            t_comp = (time.time() - iter_start_time) / opt.batch_size
            visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
            if opt.display_id > 0:
                visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

        if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
            print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
            save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
            model.save_networks(save_suffix)

        iter_data_time = time.time()
    if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
        model.save_networks('latest')
        model.save_networks(epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
    model.update_learning_rate()                     # update learning rates at the end of every epoch.

(epoch: 1, iters: 800, time: 0.014, data: 0.205) G_GAN_P: 1.266 G_GAN_U: 3.469 G_L1: 29.111 D_P_real: 0.460 D_P_fake: 0.518 D_U_real: 0.044 D_U_fake: 0.073 
(epoch: 1, iters: 1600, time: 0.014, data: 0.151) G_GAN_P: 1.200 G_GAN_U: 3.384 G_L1: 26.555 D_P_real: 0.332 D_P_fake: 0.859 D_U_real: 0.097 D_U_fake: 0.027 
(epoch: 1, iters: 2400, time: 0.015, data: 0.200) G_GAN_P: 1.216 G_GAN_U: 4.351 G_L1: 24.372 D_P_real: 0.485 D_P_fake: 0.440 D_U_real: 0.009 D_U_fake: 0.033 
(epoch: 1, iters: 3200, time: 0.015, data: 0.233) G_GAN_P: 1.603 G_GAN_U: 3.069 G_L1: 24.057 D_P_real: 0.545 D_P_fake: 0.252 D_U_real: 0.126 D_U_fake: 0.430 
(epoch: 1, iters: 4000, time: 0.015, data: 0.195) G_GAN_P: 1.138 G_GAN_U: 0.762 G_L1: 24.486 D_P_real: 0.454 D_P_fake: 0.398 D_U_real: 1.276 D_U_fake: 0.030 


KeyboardInterrupt: 

In [1]:
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
from models import networks
from models.pthresnet import PthResNetSimple
import sys
sys.path.append('/Projects/mk_utils/')

In [2]:
class BaseOptions():
    pass

In [3]:
ModelGan = BaseOptions()
ModelGan.pretrainModel = '/Projects/mk_utils/Convert_Mxnet_to_Pytorch/Pytorch_NewModel_state_dict.pth'
ModelGan.ftExtractorCutNum = -3

In [4]:
preModel_weights = torch.load(ModelGan.pretrainModel)
myModel = PthResNetSimple(3,112,512,[1,2,5,2],[64, 64, 128, 256, 512],res_ver='v3')
myModel.load_state_dict(preModel_weights)

In [5]:
preModel = list(myModel.children())[0]
encoder = nn.Sequential(*list(preModel.children())[:ModelGan.ftExtractorCutNum])

In [6]:
encoder

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (1): BatchNorm2d(64, eps=2e-05, momentum=0.9, affine=True, track_running_stats=True)
  (2): PReLU(num_parameters=64)
  (3): ResnetBlock_V3(
    (bn_stem): BatchNorm2d(64, eps=2e-05, momentum=0.9, affine=True, track_running_stats=True)
    (conv_red): Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (bn_red): BatchNorm2d(64, eps=2e-05, momentum=0.9, affine=True, track_running_stats=True)
    (conv_block): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=2e-05, momentum=0.9, affine=True, track_running_stats=True)
      (2): PReLU(num_parameters=64)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=2e-05, momentum=0.9, affine=True, track_running_stats=True)
    )
  )
  (4): ResnetBlock_V3(
    (bn_stem): BatchNorm2d(64, eps=2

In [15]:
preModel

OrderedDict([('model.0.weight',
              tensor([[[[-3.1602e-32, -2.7851e-32, -3.5874e-32],
                        [-2.2606e-32, -1.7628e-32, -2.1528e-32],
                        [-2.4243e-32, -1.3120e-32, -1.4962e-32]],
              
                       [[ 4.2243e-33,  5.9648e-33,  3.1711e-33],
                        [ 1.3723e-34, -4.7007e-34,  1.3039e-33],
                        [-3.6474e-33, -1.3004e-33,  1.8638e-34]],
              
                       [[ 6.0295e-33,  6.2360e-33,  9.3637e-33],
                        [ 3.8564e-34, -1.3620e-33,  7.0870e-33],
                        [ 2.1905e-33,  4.1704e-33,  1.1369e-32]]],
              
              
                      [[[-8.7749e-42, -6.8902e-42, -1.1292e-41],
                        [-4.2726e-42, -1.6816e-42, -7.3652e-42],
                        [-1.0532e-41, -8.2817e-42, -1.1712e-41]],
              
                       [[ 1.3765e-41,  4.6972e-42,  3.2987e-42],
                        [-4.5977e-42,  4.74

NameError: name 'fastai' is not defined