In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# (Re)load the autoreload extension and use mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
import multiprocessing
from torch import autograd
from fastai.conv_learner import *
from fastai.transforms import TfmType
from fasterai.transforms import *
from fasterai.images import *
from fasterai.dataset import *
from fasterai.visualize import *
from fasterai.callbacks import *
from fasterai.loss import *
from fasterai.modules import *
from fasterai.wgan import *
from fasterai.generators import *
from fastai.torch_imports import *
from pathlib import Path
from itertools import repeat
import tensorboardX
# Make CUDA device index 0 the current device for this process.
torch.cuda.set_device(0)
# Use matplotlib's 'dark_background' style for all figures.
plt.style.use('dark_background')
# Turn on the cuDNN benchmark flag (per the PyTorch docs this lets cuDNN
# benchmark convolution algorithms and pick the fastest one).
torch.backends.cudnn.benchmark=True


In [3]:
#TEST NOTES:  Replacing batchnorm with instance norm; Adding "shock absorbing" training sessions between size changes.

# Dataset roots (relative to the notebook's working directory).
IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')
CIFAR10 = Path('data/cifar10/train')

# Run identifier: names the TensorBoard log directory and is passed as
# `save_base_name` when the training schedules are generated.
proj_id = 'bwc_withattn_sn_small'
TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)

# Critic learning rates: the same base rate repeated three times.
# NOTE(review): presumably one entry per layer group -- confirm against WGANTrainSchedule.
c_lr=1e-3
c_lrs = np.array([c_lr,c_lr,c_lr])

# Generator learning rates: base rate is 1/5 of the critic's, spread over
# three entries at 1/100, 1/10 and 1x of that base.
g_lr=c_lr/5
g_lrs = np.array([g_lr/100,g_lr/10,g_lr])

# Per-schedule settings used by every two-schedule "main" stage.
# NOTE(review): keep_pcts looks like the fraction of the dataset used per
# schedule and gen_freeze_tos like the generator freeze point -- confirm in fasterai.
keep_pcts=[0.50,0.50]
gen_freeze_tos=[-1,0]

# Self-attention is deliberately tied to the `sn` switch.
sn=True
self_attention=sn
lrs_unfreeze_factor=0.25

# Input-side transforms: a single black-and-white transform.
x_tfms = [BlackAndWhiteTransform()]
# cuDNN benchmark mode is already enabled in the imports/setup cell, so the
# duplicate `torch.backends.cudnn.benchmark=True` that used to sit here was dropped.

## Training

In [4]:
# Generator: Unet34 configured from the flags set in the config cell, moved to the GPU.
netG = Unet34(
    nf_factor=1,
    self_attention=self_attention,
    sn=sn,
    leakyReLu=False,
).cuda()
# Disabled: TensorBoard model hook and checkpoint load for the generator
# (`gpath` is not defined in this notebook as shown).
#netGVis = ModelVisualizationHook(TENSORBOARD_PATH, netG, 'netG')
#load_model(netG, gpath)

# Critic: DCCritic sharing the same self-attention / sn flags, moved to the GPU.
netD = DCCritic(
    ni=3,
    nf=128,
    scale=16,
    self_attention=self_attention,
    sn=sn,
).cuda()
# Disabled: TensorBoard model hook and checkpoint load for the critic
# (`dpath` is not defined in this notebook as shown).
#netDVis = ModelVisualizationHook(TENSORBOARD_PATH, netD, 'netD')
#load_model(netD, dpath)

In [5]:
# WGAN trainer wiring the critic and generator together; the generator gets
# one additional loss, a FeatureLoss with multiplier 1e2.
trainer = WGANTrainer(
    netD=netD,
    netG=netG,
    genloss_fns=[FeatureLoss(multiplier=1e2)],
    sn=sn,
)
# Visualization hook attached to the trainer, writing under TENSORBOARD_PATH.
trainerVis = WganVisualizationHook(
    TENSORBOARD_PATH,
    trainer,
    'trainer',
    jupyter=False,
    visual_iters=100,
)

In [6]:
def _size_schedules(sz, bs, lr_div, unshock=False):
    """Generate the training schedules for one image size.

    Args:
        sz: Image size for the stage.
        bs: Batch size for the stage.
        lr_div: Divisor applied to both `c_lrs` and `g_lrs` for this stage.
        unshock: When True, build the single low-learning-rate "unshock"
            schedule that precedes a size change (keep_pcts=[0.25],
            gen_freeze_tos=[-1]). When False, build the regular
            two-schedule stage using the notebook-level `keep_pcts` and
            `gen_freeze_tos`.

    Returns:
        Whatever `WGANTrainSchedule.generate_schedules` returns for the
        stage (consumed via `scheds.extend`).
    """
    if unshock:
        szs, bss = [sz], [bs]
        # Fresh lists per call, matching the literals the original calls passed.
        stage_keep_pcts, stage_freeze_tos = [0.25], [-1]
    else:
        szs, bss = [sz, sz], [bs, bs]
        stage_keep_pcts, stage_freeze_tos = keep_pcts, gen_freeze_tos
    return WGANTrainSchedule.generate_schedules(
        szs=szs, bss=bss, path=IMAGENET, x_tfms=x_tfms, keep_pcts=stage_keep_pcts,
        save_base_name=proj_id, c_lrs=c_lrs/lr_div, g_lrs=g_lrs/lr_div,
        lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=stage_freeze_tos)


# Progressive-resizing plan, one row per image size:
# (image size, batch size, main-stage LR divisor, unshock LR divisor or None).
# The first size has no unshock stage because there is no preceding size change.
# A divisor of 1 leaves the learning rates numerically unchanged.
_RESIZE_PLAN = [
    (64,  32, 1, None),
    (96,  32, 1, 10),
    (128, 32, 1, 10),
    (192, 16, 4, 40),
    (256, 8,  4, 40),
]

scheds=[]
for _sz, _bs, _main_div, _unshock_div in _RESIZE_PLAN:
    if _unshock_div is not None:
        # "Unshock": a short low-LR session right after the size change.
        scheds.extend(_size_schedules(_sz, _bs, _unshock_div, unshock=True))
    scheds.extend(_size_schedules(_sz, _bs, _main_div))


In [None]:
# Start training with the full list of schedules assembled in `scheds`.
trainer.train(scheds=scheds)

 83%|████████▎ | 16302/19618 [1:31:27<10:44,  5.14it/s]
WDist 2.0228359699249268; RScore 1.2061773538589478; FScore 0.816658616065979; GAddlLoss [2.95351]; Iters: 17970; GCost: 0.32711493968963623; GPenalty: [0]; ConPenalty: [0]
 83%|████████▎ | 16322/19618 [1:31:37<1:08:29,  1.25s/it]
WDist 1.9103670120239258; RScore 1.1418077945709229; FScore 0.7685591578483582; GAddlLoss [3.07316]; Iters: 17980; GCost: 0.4024626612663269; GPenalty: [0]; ConPenalty: [0]
 83%|████████▎ | 16342/19618 [1:31:43<20:21,  2.68it/s]
WDist 2.034581184387207; RScore 1.1565803289413452; FScore 0.8780007362365723; GAddlLoss [2.8762]; Iters: 17990; GCost: 0.32540178298950195; GPenalty: [0]; ConPenalty: [0]
 83%|████████▎ | 16362/19618 [1:31:48<12:33,  4.32it/s]
WDist 2.0167198181152344; RScore 1.164120078086853; FScore 0.8525998592376709; GAddlLoss [3.14497]; Iters: 18000; GCost: 0.21575692296028137; GPenalty: [0]; ConPenalty: [0]
 84%|████████▎ | 16382/19618 [1:31:59<09:01,  5.97it/s]
WDist 1.828202724456787; RS