In [1]:

# In[1]:


import time
import sys
import torch
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from util.metrics import PSNR, SSIM
from skimage.measure import compare_psnr


In [3]:


# ## Import dataloader and show it

# In[2]:

sys.argv = ['fullModel/.py', '--dataroot', '/scratch/user/jiangziyu/train/',
             '--learn_residual', '--resize_or_crop', 'scale_width',
             '--fineSize', '256','--batchSize','4','--name','fullModel']

opt = TrainOptions().parse()




['fullModel/.py', '--dataroot', '/scratch/user/jiangziyu/train/', '--learn_residual', '--resize_or_crop', 'scale_width', '--fineSize', '256', '--batchSize', '4', '--name', 'fullModel']
------------ Options -------------
batchSize: 4
beta1: 0.5
checkpoints_dir: ./checkpoints
continue_train: False
dataroot: /scratch/user/jiangziyu/train/
dataset_mode: aligned
display_freq: 100
display_id: 1
display_port: 8097
display_single_pane_ncols: 0
display_winsize: 256
epoch_count: 1
fineSize: 256
gan_type: wgan-gp
gpu_ids: [0]
identity: 0.0
input_nc: 3
isTrain: True
lambda_A: 100.0
lambda_B: 10.0
learn_residual: True
loadSizeX: 640
loadSizeY: 360
lr: 0.0001
max_dataset_size: inf
model: content_gan
nThreads: 2
n_layers_D: 3
name: fullModel
ndf: 64
ngf: 64
niter: 150
niter_decay: 150
no_dropout: False
no_flip: False
no_html: False
norm: instance
output_nc: 3
phase: train
pool_size: 50
print_freq: 100
resize_or_crop: scale_width
save_epoch_freq: 5
save_latest_freq: 5000
serial_batches: False
which_di

In [4]:
# In[3]:


#get_ipython().magic('matplotlib inline')
import numpy as np
import matplotlib.pyplot as plt
from data.full_model_dataset import fullModelDataSet 

dataset = fullModelDataSet()
dataset.initialize(opt)


In [5]:
# In[4]:


dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))


# ## define model and load pretrained weights

In [6]:
# In[5]:


import os
from torch.autograd import Variable
from collections import OrderedDict
from models import networks
from models import multi_in_networks

def load_network(network, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
        network.load_state_dict(torch.load(save_path))
def save_network(network, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if torch.cuda.is_available():
            network.cuda()

netG_deblur = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.gpu_ids, False,
                                      opt.learn_residual)
netG_blur = multi_in_networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.gpu_ids, False,
                                      opt.learn_residual)

load_network(netG_deblur, 'deblur_G', opt.which_epoch)
load_network(netG_blur, 'blur_G', opt.which_epoch)
print('------- Networks deblur_G initialized ---------')
networks.print_network(netG_deblur)
print('-----------------------------------------------')

------- Networks deblur_G initialized ---------
ResnetGenerator(
  (model): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False)
    (3): ReLU(inplace)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False)
    (6): ReLU(inplace)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False)
    (9): ReLU(inplace)
    (10): ResnetBlock(
      (conv_block): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
        (2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False)
        (3): ReLU(inplace)
        (4): Dropout(p=0.5)
        (5): ReflectionPad2d((1, 1, 1, 1))
        (6): Conv2d(256, 256, kernel_size=(3, 3), stride=

In [23]:
# ### Freeze layers

# In[6]:


def freeze_single_input(model,num_layers_frozen=19):

    ct=0
    for child in list(netG_deblur.children())[0]:
        ct+=1
        if ct<num_layers_frozen:
            for param in child.parameters():
                param.requires_grad=False


    print("Total number of layers are:",ct,",number of layers frozen are:", num_layers_frozen)
    return model

def freeze_multi_input(model,num_layers_frozen=19):
    if num_layers_frozen < 2:
        pass
    for i,child in enumerate(list(netG_deblur.children())[0].children()):
        if i == 3:
            break
        for param in child.parameters():
            param.requires_grad=False
    ct=0
    for child in list(list(netG_deblur.children())[0].children())[0]:
        ct+=1
        if ct<num_layers_frozen-2:
            for param in child.parameters():
                param.requires_grad=False


    print("Total number of layers are:",ct,",number of layers frozen are:", num_layers_frozen)
    return model

# netG_frozen_deblur= fine_tune_existing_layers(netG_deblur, num_layers_frozen=21)
# netG_frozen_blur= fine_tune_existing_layers(netG_blur, num_layers_frozen=21)

print(list(netG_blur.children())[0])

ResnetGenerator(
  (input0Prc): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
  )
  (input1Prc): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), bias=False)
  )
  (input2Prc): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), bias=False)
  )
  (model): Sequential(
    (0): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False)
    (1): ReLU(inplace)
    (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False)
    (4): ReLU(inplace)
    (5): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False)
    (7): ReLU(inplace)
    (8): ResnetBlock(
      (conv_block): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(256, 256, 

In [None]:
# ### Net training parameters

# In[8]:


num_epoch=100
num_workers=2
learning_rate=0.0002
transforms=None       #make data augmentation. For now using only the transforms defined above
results_file_path="./results/experiment_name/full_model_results/"

# ### Cycle consistency loss

# In[9]:


import itertools
import util.util as util
import numpy as np

"""Quote from the paper about the loss function: For all the experiments, we set λ = 10 in Equation 3.
We use the Adam solver [24] with a batch size of 1"""

cycle_consistency_criterion= torch.nn.L1Loss()

#criterion= forward_cycle_consistency_criterion+backward_cycle_consistency_criterion()

#lambda_cycle is irrelevant for the moment as we use only cycle consistency loss as of now

optimizer = torch.optim.Adam(itertools.chain(filter(lambda p: p.requires_grad, netG_frozen_deblur.parameters()),
filter(lambda p: p.requires_grad, netG_frozen_blur.parameters())), lr=learning_rate)