In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# (Re)load the autoreload extension and use mode 2, which reloads imported
# modules before each cell executes, so edits to library sources are picked up
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
import multiprocessing
from torch import autograd
from fastai.conv_learner import *
from fastai.transforms import TfmType
from fasterai.transforms import *
from fasterai.images import *
from fasterai.dataset import *
from fasterai.visualize import *
from fasterai.callbacks import *
from fasterai.loss import *
from fasterai.modules import *
from fasterai.wgan import *
from fasterai.generators import *
from fastai.torch_imports import *
from pathlib import Path
from itertools import repeat
import tensorboardX
# Select CUDA device 0 as the current GPU.
torch.cuda.set_device(0)
# Use matplotlib's dark theme for any figures drawn in this notebook.
plt.style.use('dark_background')
# Let cuDNN benchmark its convolution algorithms and pick the fastest one.
torch.backends.cudnn.benchmark=True


In [3]:
# TEST NOTES: Replacing batchnorm with instance norm; adding "shock absorbing"
# training sessions between size changes.

# Dataset roots, relative to the notebook's working directory.
IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')
OPENIMAGES = Path('data/openimages')
CIFAR10 = Path('data/cifar10/train')

# Run identifier: names the TensorBoard log directory here and is passed as
# save_base_name when the training schedules are generated.
proj_id = 'bwc_withattn_sn_supertrain_retry'
TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)

# Optional checkpoints to resume from; uncomment together with the matching
# load_model(...) calls in the model cell.
#gpath = IMAGENET.parent/('bwc_withattn_sn_supertrain_retry_gen_96.h5')
#dpath = IMAGENET.parent/('bwc_withattn_sn_supertrain_retry_critic_96.h5')

# Critic learning rates: three identical entries.
# NOTE(review): presumably one rate per layer group -- confirm against WGANTrainSchedule.
c_lr=1e-3
c_lrs = np.array([c_lr,c_lr,c_lr])

# Generator learning rates: base rate is 1/5 of the critic's, with the first
# two entries scaled down by 100x and 10x respectively.
g_lr=c_lr/5
g_lrs = np.array([g_lr/100,g_lr/10,g_lr])

# Defaults for the two-schedule stages; forwarded as keep_pcts / gen_freeze_tos.
# NOTE(review): keep_pcts looks like the fraction of the dataset used per
# schedule and gen_freeze_tos like a generator freeze index -- confirm in fasterai.
keep_pcts=[1.0,1.0]
gen_freeze_tos=[-1,0]

# `sn` is forwarded to the generator, critic and trainer (presumably spectral
# normalization -- confirm). Self-attention is tied to the same switch so the
# two are always enabled or disabled together.
sn=True
self_attention=sn
lrs_unfreeze_factor=0.25

# Input-side transform and extra augmentation forwarded to every schedule.
x_tfms = [BlackAndWhiteTransform()]
extra_aug_tfms = [RandomLighting(0.1, 0.1, tfm_y=TfmType.PIXEL)]
# torch.backends.cudnn.benchmark is already enabled right after the imports,
# so it is not set a second time here.

## Training

In [4]:
# Generator network, moved onto the current CUDA device.
netG = Unet34(
    nf_factor=1,
    self_attention=self_attention,
    sn=sn,
    leakyReLu=False,
).cuda()
# Optional: TensorBoard model hook / resume from a saved generator checkpoint.
#netGVis = ModelVisualizationHook(TENSORBOARD_PATH, netG, 'netG')
#load_model(netG, gpath)

# Critic network over 3-channel input, moved onto the current CUDA device.
netD = DCCritic(
    ni=3,
    nf=128,
    scale=16,
    self_attention=self_attention,
    sn=sn,
).cuda()
# Optional: TensorBoard model hook / resume from a saved critic checkpoint.
#netDVis = ModelVisualizationHook(TENSORBOARD_PATH, netD, 'netD')
#load_model(netD, dpath)

In [5]:
# WGAN trainer pairing the critic and generator; the generator additionally
# gets a feature loss weighted by 1e2.
trainer = WGANTrainer(
    netD=netD,
    netG=netG,
    genloss_fns=[FeatureLoss(multiplier=1e2)],
    sn=sn,
)
# Visualization hook writing under TENSORBOARD_PATH, with jupyter display off
# and visual_iters set to 100.
trainerVis = WganVisualizationHook(
    TENSORBOARD_PATH,
    trainer,
    'trainer',
    jupyter=False,
    visual_iters=100,
)

In [6]:
def _stage_schedules(szs, bss, stage_keep_pcts, stage_freeze_tos, lr_div=None):
    """Generate the WGANTrainSchedule entries for one training stage.

    Every stage shares the dataset path, transforms, save name and unfreeze
    factor from the configuration cell; only the arguments below vary.

    Args:
        szs: image sizes, forwarded unchanged as ``szs``.
        bss: batch sizes, forwarded unchanged as ``bss``.
        stage_keep_pcts: forwarded as ``keep_pcts``.
        stage_freeze_tos: forwarded as ``gen_freeze_tos``.
        lr_div: if given, both ``c_lrs`` and ``g_lrs`` are divided by it;
            if None, the configured arrays are passed through untouched.

    Returns:
        Whatever ``WGANTrainSchedule.generate_schedules`` returns (an iterable
        of schedules suitable for ``list.extend``).
    """
    # Pass the configured arrays themselves (not copies) when no scaling is asked for.
    stage_c_lrs = c_lrs if lr_div is None else c_lrs/lr_div
    stage_g_lrs = g_lrs if lr_div is None else g_lrs/lr_div
    return WGANTrainSchedule.generate_schedules(
        szs=szs, bss=bss, path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms,
        keep_pcts=stage_keep_pcts, save_base_name=proj_id,
        c_lrs=stage_c_lrs, g_lrs=stage_g_lrs,
        lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=stage_freeze_tos)


# Progressive resizing: 64 -> 96 -> 128 -> 192 -> 256. Each size increase is
# preceded by an "unshock" stage (the "shock absorbing" sessions from the test
# notes): keep_pcts=[0.3], gen_freeze_tos=[-1] and a reduced learning rate.
scheds=[]

# --- 64px ---
scheds.extend(_stage_schedules(szs=[64,64], bss=[32,32], stage_keep_pcts=keep_pcts, stage_freeze_tos=gen_freeze_tos))

# --- 96px ---
#unshock
scheds.extend(_stage_schedules(szs=[96], bss=[32], stage_keep_pcts=[0.3], stage_freeze_tos=[-1], lr_div=10))
scheds.extend(_stage_schedules(szs=[96,96], bss=[32,32], stage_keep_pcts=keep_pcts, stage_freeze_tos=gen_freeze_tos))

# --- 128px ---
#unshock
scheds.extend(_stage_schedules(szs=[128], bss=[32], stage_keep_pcts=[0.3], stage_freeze_tos=[-1], lr_div=10))
scheds.extend(_stage_schedules(szs=[128,128], bss=[32,32], stage_keep_pcts=keep_pcts, stage_freeze_tos=gen_freeze_tos))

# --- 192px (batch size halved to 16; learning rates reduced) ---
#unshock
scheds.extend(_stage_schedules(szs=[192], bss=[16], stage_keep_pcts=[0.3], stage_freeze_tos=[-1], lr_div=20))
scheds.extend(_stage_schedules(szs=[192], bss=[16], stage_keep_pcts=[1.0], stage_freeze_tos=[-1], lr_div=2))
scheds.extend(_stage_schedules(szs=[192], bss=[16], stage_keep_pcts=[1.0], stage_freeze_tos=[0], lr_div=6))

# --- 256px (batch size 8; learning rates reduced further) ---
#unshock
scheds.extend(_stage_schedules(szs=[256], bss=[8], stage_keep_pcts=[0.3], stage_freeze_tos=[-1], lr_div=120))
scheds.extend(_stage_schedules(szs=[256], bss=[8], stage_keep_pcts=[1.0], stage_freeze_tos=[-1], lr_div=12))
scheds.extend(_stage_schedules(szs=[256], bss=[8], stage_keep_pcts=[1.0], stage_freeze_tos=[0], lr_div=24))


In [None]:
# Run every schedule built above, in order. This is a long-running cell: the
# captured output shows 4:35:20 elapsed at 67% of a 39637-step progress bar.
trainer.train(scheds=scheds)

 67%|██████▋   | 26360/39637 [4:35:20<2:45:30,  1.34it/s]
WDist 1.9326139688491821; RScore 1.2068735361099243; FScore 0.7257404327392578; GAddlLoss [2.9874]; Iters: 13180; GCost: 0.35216811299324036; GPenalty: [0]; ConPenalty: [0]
 67%|██████▋   | 26380/39637 [4:35:27<1:03:07,  3.50it/s]
WDist 1.8195528984069824; RScore 1.1342308521270752; FScore 0.6853220462799072; GAddlLoss [2.8341]; Iters: 13190; GCost: 0.35640522837638855; GPenalty: [0]; ConPenalty: [0]
 67%|██████▋   | 26400/39637 [4:35:34<41:32,  5.31it/s]
WDist 1.6058340072631836; RScore 1.0532474517822266; FScore 0.552586555480957; GAddlLoss [2.88399]; Iters: 13200; GCost: 0.46259281039237976; GPenalty: [0]; ConPenalty: [0]
 67%|██████▋   | 26420/39637 [4:35:48<3:23:30,  1.08it/s]
WDist 1.8115403652191162; RScore 1.0562310218811035; FScore 0.7553094029426575; GAddlLoss [2.84043]; Iters: 13210; GCost: 0.1511613130569458; GPenalty: [0]; ConPenalty: [0]
 67%|██████▋   | 26440/39637 [4:35:55<1:20:00,  2.75it/s]
WDist 1.800656318664