In [2]:

# In[1]:


import time
import sys
import torch
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from util.metrics import PSNR, SSIM
from skimage.measure import compare_psnr


# ## Parse training options and build the dataloader

# In[2]:

# TrainOptions reads sys.argv, so inside the notebook we fake a command line
# carrying the intended training configuration before parsing it.
TRAIN_ARGS = [
    '--dataroot', '/scratch/user/jiangziyu/train/',
    '--learn_residual',
    '--resize_or_crop', 'scale_width',
    '--fineSize', '256',
    '--batchSize', '4',
    '--name', 'fullModelSupervised',
    '--model', 'pix2pix',
]
sys.argv = ['train.py'] + TRAIN_ARGS

opt = TrainOptions().parse()


# In[3]:


#get_ipython().magic('matplotlib inline')
import numpy as np
import matplotlib.pyplot as plt
from data.full_model_dataset import fullModelDataSet 

# Project dataset class; it is configured from the parsed options.
dataset = fullModelDataSet()
dataset.initialize(opt)


# In[4]:


# Batch the dataset; --serial_batches turns shuffling off.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=opt.batchSize,
    shuffle=(not opt.serial_batches),
    num_workers=int(opt.nThreads),
)


# ## Define the models and load pretrained weights

# In[5]:


import os
from torch.autograd import Variable
from collections import OrderedDict
from models import networks
from models import multi_in_networks

def load_network(network, network_label, epoch_label, checkpoints_dir=None, name=None):
    """Load a checkpoint into ``network`` in place.

    The file read is ``<checkpoints_dir>/<name>/<epoch_label>_net_<network_label>.pth``.

    Args:
        network: torch module whose ``load_state_dict`` receives the weights.
        network_label: label embedded in the filename (e.g. 'deblur_G', 'D').
        epoch_label: epoch tag embedded in the filename (e.g. opt.which_epoch).
        checkpoints_dir: root checkpoint directory; defaults to the global
            ``opt.checkpoints_dir``.
        name: experiment sub-directory; defaults to the global ``opt.name``.
    """
    if checkpoints_dir is None:
        checkpoints_dir = opt.checkpoints_dir
    if name is None:
        name = opt.name
    save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
    save_path = os.path.join(checkpoints_dir, name, save_filename)
    # Deserialize onto the CPU so a checkpoint written on a GPU machine can
    # also be restored without CUDA; load_state_dict then copies each tensor
    # onto whatever device the network's parameters already live on.
    state_dict = torch.load(save_path, map_location='cpu')
    network.load_state_dict(state_dict)
def save_network(network, network_label, epoch_label, checkpoints_dir=None, name=None):
    """Save ``network``'s weights to a checkpoint file.

    The file written is ``<checkpoints_dir>/<name>/<epoch_label>_net_<network_label>.pth``.
    The state dict is taken after moving the module to the CPU, so the file
    contains CPU tensors. Afterwards the module is moved back to the device
    its first parameter was on before the call, rather than unconditionally
    to the default CUDA device. A CPU model therefore stays on the CPU, and
    a model on another GPU returns there.

    Args:
        network: torch module to save (moved to CPU and back as a side effect).
        network_label: label embedded in the filename (e.g. 'deblur_G', 'D').
        epoch_label: epoch tag embedded in the filename.
        checkpoints_dir: root checkpoint directory; defaults to the global
            ``opt.checkpoints_dir``.
        name: experiment sub-directory; defaults to the global ``opt.name``.
    """
    if checkpoints_dir is None:
        checkpoints_dir = opt.checkpoints_dir
    if name is None:
        name = opt.name
    save_dir = os.path.join(checkpoints_dir, name)
    # Create the experiment directory on first save instead of failing.
    os.makedirs(save_dir, exist_ok=True)
    save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
    save_path = os.path.join(save_dir, save_filename)

    # Remember where the module lives so it can be restored after saving.
    # NOTE(review): a model split across several devices would be gathered
    # onto the first parameter's device; not expected for these networks.
    first_param = next(network.parameters(), None)
    original_device = first_param.device if first_param is not None else None

    torch.save(network.cpu().state_dict(), save_path)
    if original_device is not None:
        network.to(original_device)

# The deblurring and blurring generators take identical constructor
# arguments; only the factory module differs.
_generator_args = (opt.input_nc, opt.output_nc, opt.ngf,
                   opt.which_model_netG, opt.norm, not opt.no_dropout,
                   opt.gpu_ids, False, opt.learn_residual)
netG_deblur = networks.define_G(*_generator_args)
netG_blur = multi_in_networks.define_G(*_generator_args)

# The discriminator gets a sigmoid output only when gan_type is 'gan'
# (with the logged options, gan_type: wgan-gp, this is False).
use_sigmoid = (opt.gan_type == 'gan')
netD = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD,
                         opt.n_layers_D, opt.norm, use_sigmoid,
                         opt.gpu_ids, False)

# Restore both generators from the checkpoint chosen by --which_epoch.
load_network(netG_deblur, 'deblur_G', opt.which_epoch)
load_network(netG_blur, 'blur_G', opt.which_epoch)

print('------- Networks deblur_G initialized ---------')
networks.print_network(netG_deblur)
print('-----------------------------------------------')

print('------- Networks deblur_D initialized ---------')
networks.print_network(netD)
print('-----------------------------------------------')

# The discriminator checkpoint is stored under the label 'D'.
load_network(netD, 'D', opt.which_epoch)
# ### Freeze layers

# In[6]:


def freeze_single_input(model, num_layers_frozen=19):
    """Freeze the first ``num_layers_frozen`` layers of a single-input network.

    The "layers" are the items of the model's first child, which is iterated
    directly. NOTE(review): this presumably is an ``nn.Sequential`` holding
    the network body; confirm for each architecture passed in.

    Freezing sets ``requires_grad = False`` on every parameter of a layer.
    The model is modified in place and returned for convenience.

    Args:
        model: torch module whose first child is an iterable of layers.
        num_layers_frozen: number of leading layers to freeze. Values <= 0
            freeze nothing; values above the layer count freeze every layer.

    Returns:
        The same ``model`` object.
    """
    layers = list(list(model.children())[0])
    # Clamp so the reported count matches what was actually frozen.
    num_to_freeze = max(0, min(num_layers_frozen, len(layers)))
    for layer in layers[:num_to_freeze]:
        for param in layer.parameters():
            param.requires_grad = False

    print("Total number of layers are:", len(layers),
          ",number of layers frozen are:", num_to_freeze)
    return model

def freeze_multi_input(model, num_layers_frozen=19):
    """Freeze the leading layers of a multi-input network.

    The model's first child is expected to have at least four sub-modules:
    - Sub-modules 0-2 are frozen together and counted as 2 layers.
      NOTE(review): these are presumably the per-input branches; confirm
      against multi_in_networks.
    - Sub-module 3 is iterated as a sequence of trunk layers.

    The branch group is frozen only when at least 2 layers are requested,
    then the first ``num_layers_frozen - 2`` trunk layers are frozen.
    Freezing sets ``requires_grad = False`` on every parameter; the model is
    modified in place and returned.

    Args:
        model: torch module with the structure described above.
        num_layers_frozen: total number of leading layers to freeze, counting
            the branch group as 2. Values < 2 freeze nothing; values above the
            layer count freeze everything.

    Returns:
        The same ``model`` object.
    """
    sub_modules = list(list(model.children())[0].children())
    trunk = list(sub_modules[3])
    branch_depth = 2  # the three leading sub-modules count as two layers

    num_frozen = 0
    if num_layers_frozen >= branch_depth:
        for branch in sub_modules[:3]:
            for param in branch.parameters():
                param.requires_grad = False
        num_frozen = branch_depth

        num_trunk = min(num_layers_frozen - branch_depth, len(trunk))
        for layer in trunk[:num_trunk]:
            for param in layer.parameters():
                param.requires_grad = False
        num_frozen += num_trunk

    print("Total number of layers are:", len(trunk) + branch_depth,
          ",number of layers frozen are:", num_frozen)
    return model

# The freeze helpers flip requires_grad in place and return the module they
# were given, so each *_frozen name is an alias of the original network.
# num_layers_frozen=0 leaves the deblurring generator fully trainable.
netG_frozen_deblur = freeze_single_input(netG_deblur, num_layers_frozen=0)
# 50 is presumably larger than either network's layer count, i.e. "freeze
# everything" -- TODO confirm against the printed layer totals.
netG_frozen_blur = freeze_multi_input(netG_blur, num_layers_frozen=50)
netD_frozen = freeze_single_input(netD, num_layers_frozen=50)

['train.py', '--dataroot', '/scratch/user/jiangziyu/train/', '--learn_residual', '--resize_or_crop', 'scale_width', '--fineSize', '256', '--batchSize', '4', '--name', 'fullModelSupervised', '--model', 'pix2pix']
------------ Options -------------
batchSize: 4
beta1: 0.5
checkpoints_dir: ./checkpoints
continue_train: False
dataroot: /scratch/user/jiangziyu/train/
dataset_mode: aligned
display_freq: 100
display_id: 1
display_port: 8097
display_single_pane_ncols: 0
display_winsize: 256
epoch_count: 1
fineSize: 256
gan_type: wgan-gp
gpu_ids: [0]
identity: 0.0
input_nc: 3
isTrain: True
lambda_A: 100.0
lambda_B: 10.0
learn_residual: True
loadSizeX: 640
loadSizeY: 360
lr: 0.0001
max_dataset_size: inf
model: pix2pix
nThreads: 2
n_layers_D: 3
name: fullModelSupervised
ndf: 64
ngf: 64
niter: 150
niter_decay: 150
no_dropout: False
no_flip: False
no_html: False
norm: instance
output_nc: 3
phase: train
pool_size: 50
print_freq: 100
resize_or_crop: scale_width
save_epoch_freq: 5
save_latest_freq: 50