In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load (or reload) the autoreload extension. Mode 2 re-imports all modules
# before each cell executes, so edits to imported source files take effect
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
import multiprocessing
from torch import autograd
from fastai.conv_learner import *
from fastai.transforms import TfmType
from fasterai.transforms import *
from fasterai.images import *
from fasterai.dataset import *
from fasterai.visualize import *
from fasterai.callbacks import *
from fasterai.loss import *
from fasterai.modules import *
from fasterai.training import *
from fasterai.generators import *
from fastai.torch_imports import *
from pathlib import Path
from itertools import repeat
import tensorboardX
# Let cuDNN benchmark its convolution algorithms and reuse the fastest one.
torch.backends.cudnn.benchmark = True

# Make GPU index 1 the current CUDA device.
# NOTE(review): the index is hard-coded; a machine with a single GPU needs 0.
torch.cuda.set_device(1)

# Dark theme for every matplotlib figure drawn in this notebook.
plt.style.use('dark_background')


In [3]:
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')
proj_id = 'colorize'
TENSORBOARD_PATH = Path('data/tensorboard/')

# Generator / critic weight files, stored next to the training folder.
# In this notebook they are only referenced by the commented-out load_model
# calls in the model cell.
gpath = IMAGENET.parent/(proj_id + '_gen_192.h5')
dpath = IMAGENET.parent/(proj_id + '_critic_192.h5')

# ---------------------------------------------------------------------------
# Learning rates
# ---------------------------------------------------------------------------
# Critic: the same rate repeated three times.
# Presumably one entry per layer group -- confirm against GANTrainSchedule.
c_lr = 5e-4
c_lrs = np.array([c_lr, c_lr, c_lr])

# Generator: base rate is 1/5 of the critic's; the first two entries are
# scaled down by 1000x and 100x respectively.
g_lr = c_lr/5
g_lrs = np.array([g_lr/1000, g_lr/100, g_lr])

# ---------------------------------------------------------------------------
# Defaults shared by the schedule-building cell
# ---------------------------------------------------------------------------
# Forwarded unchanged to GANTrainSchedule.generate_schedules; one entry per
# schedule. NOTE(review): exact semantics live in fasterai.training -- confirm.
keep_pcts = [0.25, 0.25]
gen_freeze_tos = [-1, 0]
lrs_unfreeze_factor = 1.0

# Transforms: black-and-white conversion for the inputs, plus random lighting
# augmentation (the PIXEL tfm_y type is passed for the targets).
x_tfms = [BlackAndWhiteTransform()]
extra_aug_tfms = [RandomLighting(0.1, 0.1, tfm_y=TfmType.PIXEL)]

# torch.backends.cudnn.benchmark is already enabled in the imports/setup cell,
# which must run before this one (it provides Path and np), so the duplicate
# assignment that used to sit here has been dropped.

## Training

In [4]:
# Generator network: Unet34 built with nf_factor=2, placed on the GPU via .cuda().
netG = Unet34(nf_factor=2).cuda()
# Disabled: TensorBoard visualization hook for the generator.
#netGVis = ModelVisualizationHook(TENSORBOARD_PATH, netG, 'netG')
# Disabled: uncomment to load previously saved generator weights from gpath.
#load_model(netG, gpath)

# Critic network: DCCritic built with ni=3, nf=512, placed on the GPU via .cuda().
netD = DCCritic(ni=3, nf=512).cuda()
# Disabled: TensorBoard visualization hook for the critic.
#netDVis = ModelVisualizationHook(TENSORBOARD_PATH, netD, 'netD')
# Disabled: uncomment to load previously saved critic weights from dpath.
#load_model(netD, dpath)

In [5]:
# GAN trainer wiring the critic and generator together; the generator is
# trained with a single FeatureLoss term (multiplier 1e2).
trainer = GANTrainer(
    netD=netD,
    netG=netG,
    genloss_fns=[FeatureLoss(multiplier=1e2)],
)

# Visualization hook attached to the trainer, writing under TENSORBOARD_PATH
# with the name 'trainer' (jupyter output off, visual_iters=100).
trainerVis = GANVisualizationHook(
    TENSORBOARD_PATH,
    trainer,
    'trainer',
    jupyter=False,
    visual_iters=100,
)

In [6]:
def stage_schedules(szs, bss, stage_keep_pcts, lr_div, stage_freeze_tos):
    """Build the GAN training schedules for one stage of the size progression.

    Only the arguments that differ between stages are parameters. The dataset
    path, transforms, save name, base learning rates and unfreeze factor are
    read from the configuration cell (IMAGENET, x_tfms, extra_aug_tfms,
    proj_id, c_lrs, g_lrs, lrs_unfreeze_factor).

    Args:
        szs: forwarded as ``szs`` (sizes, going by the name), one per schedule.
        bss: forwarded as ``bss`` (batch sizes, going by the name).
        stage_keep_pcts: forwarded as ``keep_pcts``.
        lr_div: divisor applied to both ``c_lrs`` and ``g_lrs`` for this stage.
        stage_freeze_tos: forwarded as ``gen_freeze_tos``.

    Returns:
        Whatever ``GANTrainSchedule.generate_schedules`` returns; the caller
        passes it to ``list.extend``.
    """
    return GANTrainSchedule.generate_schedules(
        szs=szs,
        bss=bss,
        path=IMAGENET,
        x_tfms=x_tfms,
        extra_aug_tfms=extra_aug_tfms,
        keep_pcts=stage_keep_pcts,
        save_base_name=proj_id,
        c_lrs=c_lrs/lr_div,
        g_lrs=g_lrs/lr_div,
        lrs_unfreeze_factor=lrs_unfreeze_factor,
        gen_freeze_tos=stage_freeze_tos,
    )


# One row per generate_schedules call, in training order:
#   (szs, bss, keep_pcts, learning-rate divisor, gen_freeze_tos)
#
# Rows marked "unshock" come first at each new size and use keep_pcts=[0.1]
# with a 10x larger divisor than the row that follows them. Presumably a short
# low-learning-rate warm-up so the networks are not shocked by the size
# change -- confirm with the author.
#
# A divisor of 1 leaves the learning-rate values unchanged.
stage_specs = [
    # size 64
    ([64, 64],   [32, 32], keep_pcts, 1,   gen_freeze_tos),

    # size 96
    ([96],       [16],     [0.1],     20,  [-1]),            # unshock
    ([96, 96],   [16, 16], keep_pcts, 2,   gen_freeze_tos),

    # size 128
    ([128],      [8],      [0.1],     30,  [-1]),            # unshock
    ([128, 128], [8, 8],   keep_pcts, 3,   gen_freeze_tos),

    # size 192
    ([192],      [4],      [0.1],     80,  [-1]),            # unshock
    ([192],      [4],      [0.25],    8,   [-1]),
    ([192],      [4],      [0.25],    16,  [0]),

    # size 256
    ([256],      [2],      [0.1],     160, [-1]),            # unshock
    ([256],      [2],      [0.25],    16,  [-1]),
    ([256],      [2],      [0.25],    32,  [0]),
]

# Rebuilt from scratch on every run of this cell, so re-running it never
# appends duplicate schedules.
scheds = []
for stage in stage_specs:
    scheds.extend(stage_schedules(*stage))


In [None]:
# Start training, handing the trainer the full list of schedules in `scheds`.
# NOTE(review): presumably long-running, since the list covers every size stage.
trainer.train(scheds=scheds)