In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
import multiprocessing
from torch import autograd
from fastai.conv_learner import *
from fastai.transforms import TfmType
from fasterai.transforms import *
from fasterai.images import *
from fasterai.dataset import *
from fasterai.visualize import *
from fasterai.callbacks import *
from fasterai.loss import *
from fasterai.modules import *
from fasterai.wgan import *
from fasterai.generators import *
from fastai.torch_imports import *
from pathlib import Path
from itertools import repeat
import tensorboardX
# Session-wide runtime settings: dark matplotlib theme, let cuDNN benchmark and pick
# convolution algorithms, and make CUDA device 0 the current device for `.cuda()` calls.
plt.style.use('dark_background')
torch.backends.cudnn.benchmark = True
torch.cuda.set_device(0)


In [None]:
# TEST NOTES: Replacing batchnorm with instance norm; adding "shock absorbing" training sessions between size changes.

# Dataset roots.
IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')
OPENIMAGES = Path('data/openimages')
CIFAR10 = Path('data/cifar10/train')

# Run identity: names the TensorBoard log directory and is passed as save_base_name
# to every training schedule.
proj_id = 'bwc_withattn_sn_supertrain'
TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id + '_cont')

# Checkpoints to resume from. Both file names are derived from one id so the generator
# and critic checkpoints cannot drift apart; set resume_id to another run's id to
# warm-start from that run's 64px weights instead.
resume_id = proj_id
gpath = IMAGENET.parent/(resume_id + '_gen_64.h5')
dpath = IMAGENET.parent/(resume_id + '_critic_64.h5')

# Critic learning rates: the same rate for each of the three entries
# (presumably one per layer group of the critic -- confirm in fasterai.wgan).
c_lr = 1e-3
c_lrs = np.full(3, c_lr)

# Generator learning rates: base rate is 1/5 of the critic's, with the first two
# entries scaled down 100x and 10x (discriminative learning rates; presumably the
# earlier entries are the pretrained encoder groups -- confirm).
g_lr = c_lr/5
g_lrs = np.array([g_lr/100, g_lr/10, g_lr])

# Defaults for the two-sub-stage schedules: parallel lists, one entry per sub-stage.
# Their exact semantics are defined by WGANTrainSchedule.generate_schedules.
keep_pcts = [0.50, 0.50]
gen_freeze_tos = [-1, 0]

# Spectral normalization flag; self-attention is switched on/off together with it.
sn = True
self_attention = sn
# Passed unchanged to every schedule (interpretation owned by WGANTrainSchedule).
lrs_unfreeze_factor = 0.25

# Input transform applied to the x images (colour targets are the untransformed images
# -- presumably; verify in fasterai.transforms).
x_tfms = [BlackAndWhiteTransform()]

## Training

In [None]:
# Build the generator (U-Net) and the critic on the current CUDA device, then restore
# both from the checkpoint paths set in the config cell. Optional TensorBoard graph hooks:
#   netGVis = ModelVisualizationHook(TENSORBOARD_PATH, netG, 'netG')
#   netDVis = ModelVisualizationHook(TENSORBOARD_PATH, netD, 'netD')
netG, netD = (
    Unet34(nf_factor=1, self_attention=self_attention, sn=sn, leakyReLu=False).cuda(),
    DCCritic(ni=3, nf=128, scale=16, self_attention=self_attention, sn=sn).cuda(),
)
load_model(netG, gpath)
load_model(netD, dpath)

In [None]:
# Wrap both networks in the WGAN trainer. The generator gets one extra loss function,
# a FeatureLoss with multiplier 1e2 (reported as GAddlLoss in the training log).
trainer = WGANTrainer(
    netG=netG,
    netD=netD,
    sn=sn,
    genloss_fns=[FeatureLoss(multiplier=1e2)],
)
# Log trainer progress to TENSORBOARD_PATH; visual_iters=100 presumably controls how
# often sample images are rendered -- confirm in fasterai.visualize.
trainerVis = WganVisualizationHook(
    TENSORBOARD_PATH, trainer, 'trainer',
    visual_iters=100,
    jupyter=False,
)

In [None]:
# Training curriculum with progressively larger image sizes (64 -> 96 -> 128 -> 192 -> 256).
# Before each size jump, a short "unshock" stage runs on ImageNet with keep_pcts=[0.25]
# and gen_freeze_tos=[-1], at one tenth of the learning rate used by that size's main
# stages. At 192 and 256 the curriculum then drops the batch size and switches
# gen_freeze_tos from [-1] to [0], with a further-reduced learning rate.

def _stage_schedules(szs, bss, path, pcts, lr_div=1, freeze_tos=None):
    """Return WGANTrainSchedule.generate_schedules(...) for one curriculum stage.

    Every stage shares x_tfms, proj_id (as save_base_name) and lrs_unfreeze_factor from
    the config cell; only the arguments below vary between stages.

    Args:
        szs: image sizes (parallel list, one entry per sub-stage).
        bss: batch sizes, parallel to szs.
        path: dataset root (IMAGENET or OPENIMAGES).
        pcts: keep_pcts, parallel to szs.
        lr_div: divisor applied to both c_lrs and g_lrs; 1 passes the arrays unchanged.
        freeze_tos: gen_freeze_tos, parallel to szs; None uses the config-cell value.

    Returns:
        Whatever generate_schedules returns (iterated to collect the schedules).
    """
    # Pass the original arrays through untouched when no scaling is requested.
    c_stage_lrs = c_lrs if lr_div == 1 else c_lrs/lr_div
    g_stage_lrs = g_lrs if lr_div == 1 else g_lrs/lr_div
    return WGANTrainSchedule.generate_schedules(
        szs=szs, bss=bss, path=path, x_tfms=x_tfms, keep_pcts=pcts,
        save_base_name=proj_id, c_lrs=c_stage_lrs, g_lrs=g_stage_lrs,
        lrs_unfreeze_factor=lrs_unfreeze_factor,
        gen_freeze_tos=gen_freeze_tos if freeze_tos is None else freeze_tos)


# One row per stage, in training order:
# (szs, bss, dataset, keep_pcts, lr divisor, gen_freeze_tos or None for the config default)
_stage_specs = [
    # ([64, 64], [32, 32], IMAGENET, keep_pcts, 1, None),  # disabled
    ([64, 64], [32, 32], OPENIMAGES, [1.0, 1.0], 1, None),

    ([96], [32], IMAGENET, [0.25], 10, [-1]),  # unshock
    ([96, 96], [32, 32], IMAGENET, keep_pcts, 1, None),
    ([96, 96], [32, 32], OPENIMAGES, [1.0, 1.0], 1, None),

    ([128], [32], IMAGENET, [0.25], 10, [-1]),  # unshock
    ([128, 128], [32, 32], IMAGENET, keep_pcts, 1, None),
    ([128, 128], [32, 32], OPENIMAGES, [1.0, 1.0], 1, None),

    ([192], [16], IMAGENET, [0.25], 20, [-1]),  # unshock
    ([192], [16], IMAGENET, [0.5], 2, [-1]),
    ([192], [16], OPENIMAGES, [1.0], 2, [-1]),
    ([192], [12], IMAGENET, [0.5], 6, [0]),
    ([192], [12], OPENIMAGES, [1.0], 6, [0]),

    ([256], [6], IMAGENET, [0.25], 120, [-1]),  # unshock
    ([256], [6], IMAGENET, [0.5], 12, [-1]),
    ([256], [6], OPENIMAGES, [1.0], 12, [-1]),
    ([256], [4], IMAGENET, [0.5], 24, [0]),
    ([256], [4], OPENIMAGES, [1.0], 24, [0]),
]

# Flatten all stages into a single ordered list of schedules.
scheds = [sched for spec in _stage_specs for sched in _stage_schedules(*spec)]


In [None]:
# Run the WGAN trainer over every schedule in `scheds`, in list order. This is the
# long-running cell; the saved output below shows its periodic WDist/RScore/FScore/
# GAddlLoss progress lines.
trainer.train(scheds=scheds)

 12%|█▏        | 3600/30938 [12:06<1:07:50,  6.72it/s]
WDist 1.6886177062988281; RScore 1.1205977201461792; FScore 0.5680199265480042; GAddlLoss [3.88616]; Iters: 17270; GCost: 0.20145948231220245; GPenalty: [0]; ConPenalty: [0]
 12%|█▏        | 3620/30938 [12:11<1:19:42,  5.71it/s]
WDist 2.0627074241638184; RScore 1.41167414188385; FScore 0.6510331630706787; GAddlLoss [3.76019]; Iters: 17280; GCost: 0.3240566551685333; GPenalty: [0]; ConPenalty: [0]
 12%|█▏        | 3640/30938 [12:14<1:13:57,  6.15it/s]
WDist 1.9444940090179443; RScore 1.1564162969589233; FScore 0.7880776524543762; GAddlLoss [3.82534]; Iters: 17290; GCost: 0.4683200716972351; GPenalty: [0]; ConPenalty: [0]
 12%|█▏        | 3660/30938 [12:18<1:14:23,  6.11it/s]
WDist 1.7119567394256592; RScore 0.9997768402099609; FScore 0.7121798396110535; GAddlLoss [3.38173]; Iters: 17300; GCost: 0.4247523546218872; GPenalty: [0]; ConPenalty: [0]
 12%|█▏        | 3680/30938 [12:22<57:27,  7.91it/s]  
WDist 1.8268730640411377; RScore 1