In [5]:
import os
import argparse

import time
import datetime

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.autograd as autograd
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
#import encoding
from torchvision import transforms

import pytorch_ssim
import dataset
import utils


Trainer

In [13]:
def str2bool(v):
    """Parse a command-line value into a bool; intended as an argparse ``type``.

    Accepted spellings (case-insensitive, surrounding whitespace ignored):
      True  <- 'yes', 'true', 't', 'y', '1'
      False <- 'no', 'false', 'f', 'n', '0'

    Args:
        v: The raw value. A value that is already a ``bool`` is returned
            unchanged; anything else is converted with ``str()`` before
            matching, so e.g. the int ``1`` is accepted as True.

    Returns:
        bool: The parsed flag.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not one of the spellings above.
    """
    # Pass real booleans straight through instead of crashing on `.lower()`.
    if isinstance(v, bool):
        return v
    # Normalise once rather than lower-casing in every comparison.
    value = str(v).strip().lower()
    if value in ('yes', 'true', 't', 'y', '1'):
        return True
    if value in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')

# ----------------------------------------
#        Initialize the parameters
# ----------------------------------------
# Every training option is declared on `parser` and collected into `opt`
# (an argparse.Namespace) at the end of this cell.
parser = argparse.ArgumentParser()
# Pre-train, saving, and loading parameters
parser.add_argument('--save_path', type = str, default = './models/models_k3_d4_ssimloss_SavePredictedKernel', help = 'saving path that is a folder')  #often changed
parser.add_argument('--sample_path', type = str, default = './samples', help = 'training samples path that is a folder')  #often changed
parser.add_argument('--save_mode', type = str, default = 'epoch', help = 'saving mode, and by_epoch saving is recommended')
parser.add_argument('--save_by_epoch', type = int, default = 10, help = 'interval between model checkpoints (by epochs)')
parser.add_argument('--save_by_iter', type = int, default = 100000, help = 'interval between model checkpoints (by iterations)')
parser.add_argument('--load_name', type = str, default = '', help = 'load the pre-trained model with certain epoch')
# GPU parameters
parser.add_argument('--no_gpu', type = str2bool, default = False, help = 'True for CPU')
parser.add_argument('--multi_gpu', type = str2bool, default = False, help = 'True for more than 1 GPU')
parser.add_argument('--gpu_ids', type = str, default = '0, 1, 2, 3', help = 'gpu_ids: e.g. 0  0,1  0,1,2  use -1 for CPU')
parser.add_argument('--cudnn_benchmark', type = str2bool, default = True, help = 'True for unchanged input data type')
# Training parameters
parser.add_argument('--epochs', type = int, default = 1, help = 'number of epochs of training')  #often changed
parser.add_argument('--train_batch_size', type = int, default = 16, help = 'size of the batches')
parser.add_argument('--lr_g', type = float, default = 0.0002, help = 'Adam: learning rate for G / D')
parser.add_argument('--b1', type = float, default = 0.5, help = 'Adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type = float, default = 0.999, help = 'Adam: decay of second order momentum of gradient')
parser.add_argument('--weight_decay', type = float, default = 0, help = 'weight decay for optimizer')
parser.add_argument('--lr_decrease_epoch', type = int, default = 50, help = 'lr decrease at certain epoch and its multiple')
parser.add_argument('--num_workers', type = int, default = 1, help = 'number of cpu threads to use during batch generation')
# Initialization parameters
parser.add_argument('--color', type = str2bool, default = True, help = 'input type')
parser.add_argument('--burst_length', type = int, default = 1, help = 'number of photos used in burst setting')
parser.add_argument('--blind_est', type = str2bool, default = True, help = 'variance map')
# A list of ints (e.g. `--kernel_size 3 5 7`), matching the list default [3].
# str2bool cannot parse kernel sizes: it rejects e.g. '5' and turns '1' into True.
parser.add_argument('--kernel_size', type = int, nargs = '+', default = [3], help = 'kernel size')
parser.add_argument('--sep_conv', type = str2bool, default = False, help = 'simple output type')
parser.add_argument('--channel_att', type = str2bool, default = False, help = 'channel wise attention')
parser.add_argument('--spatial_att', type = str2bool, default = False, help = 'spatial wise attention')
parser.add_argument('--upMode', type = str, default = 'bilinear', help = 'upMode')
parser.add_argument('--core_bias', type = str2bool, default = False, help = 'core_bias')
parser.add_argument('--init_type', type = str, default = 'xavier', help = 'initialization type of generator')
parser.add_argument('--init_gain', type = float, default = 0.02, help = 'initialization gain of generator')
# Dataset parameters
parser.add_argument('--baseroot', type = str, default = './rainy_image_dataset/rain100H/train/', help = 'images baseroot')
parser.add_argument('--rainaug', type = str2bool, default = False, help = 'true for using rainaug')
parser.add_argument('--crop_size', type = int, default = 256, help = 'single patch size')
parser.add_argument('--geometry_aug', type = str2bool, default = False, help = 'geometry augmentation (scaling)')
parser.add_argument('--angle_aug', type = str2bool, default = False, help = 'geometry augmentation (rotation, flipping)')
parser.add_argument('--scale_min', type = float, default = 1, help = 'min scaling factor')
parser.add_argument('--scale_max', type = float, default = 1, help = 'max scaling factor')
parser.add_argument('--mu', type = int, default = 0, help = 'Gaussian noise mean')
parser.add_argument('--sigma', type = int, default = 30, help = 'Gaussian noise variance: 30 | 50 | 70')

# args=[] makes argparse ignore sys.argv (inside a Jupyter kernel it holds the
# kernel's own launch arguments), so only the defaults above take effect.
# To override options here, pass a list such as ['--epochs', '100'].
opt = parser.parse_args(args=[])
print(opt)

Namespace(angle_aug=False, b1=0.5, b2=0.999, baseroot='./rainy_image_dataset/training', blind_est=True, burst_length=1, channel_att=False, color=True, core_bias=False, crop_size=256, cudnn_benchmark=True, epochs=100, geometry_aug=False, gpu_ids='0, 1, 2, 3', init_gain=0.02, init_type='xavier', kernel_size=[3], load_name='', lr_decrease_epoch=20, lr_g=0.0002, mu=0, multi_gpu=False, no_gpu=False, num_workers=8, rainaug=False, sample_path='./samples', save_by_epoch=10, save_by_iter=100000, save_mode='epoch', save_path='./models_k9_loss14_ft', scale_max=1, scale_min=1, sep_conv=False, sigma=30, spatial_att=False, train_batch_size=16, upMode='bilinear', weight_decay=0)


In [None]:


# Forward the --cudnn_benchmark option to PyTorch. When enabled, cuDNN
# benchmarks convolution algorithms and keeps the fastest; this pays off when
# input sizes stay the same from batch to batch.
cudnn.benchmark = opt.cudnn_benchmark

# Output locations: checkpoints go to save_folder, training samples to sample_folder.
save_folder, sample_folder = opt.save_path, opt.sample_path
# NOTE(review): utils.check_path is project-local; presumably it creates the
# directory when missing — confirm.
utils.check_path(save_folder)
utils.check_path(sample_folder)