In [None]:
# Render matplotlib figures inline, and auto-reload edited modules (e.g. fasterai) before each cell executes.
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.gan import *
from fasterai.dataset import *
from fasterai.visualize import *
from fasterai.tensorboard import *
from fasterai.loss import *
from fasterai.critics import *
from fasterai.generators import *
from pathlib import Path
from itertools import repeat
# Select the CUDA device. Defaults to GPU 2 (the previously hard-coded choice); set the
# GPU_ID environment variable to run on a machine with a different GPU layout.
import os
GPU_ID = int(os.environ.get('GPU_ID', '2'))
torch.cuda.set_device(GPU_ID)
plt.style.use('dark_background')
# Let cuDNN autotune convolution algorithms for the input sizes in use.
torch.backends.cudnn.benchmark=True


In [None]:
# Dataset locations. BWIMAGENET holds the grayscale copies written by `decolorize`, which
# mirrors IMAGENET's directory layout; it lives inside IMAGENET under 'bandw'.
IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC')
BWIMAGENET = IMAGENET/'bandw'

# Project tag used for TensorBoard logs and for checkpoint names in save()/load().
proj_id = 'colorizeV5o'
TENSORBOARD_PATH = Path('data/tensorboard')/proj_id

# NOTE(review): gpath/dpath are not referenced anywhere in the visible notebook, and dpath uses
# '_critic_' while save()/load() use '_crit_' - confirm whether these are still needed.
gpath = IMAGENET.parent/(proj_id + '_gen_64.h5')
dpath = IMAGENET.parent/(proj_id + '_critic_64.h5')

In [None]:
def decolorize(fn:Path, i:int):
    """Save a black-and-white copy of the image `fn` at the mirrored path under BWIMAGENET.

    fn: path of a source image; must be located under IMAGENET (it is made relative to it).
    i:  index argument supplied by fastai's `parallel` helper; unused here.

    The image is converted to mode 'LA' (luminance + alpha) to drop colour, then back to
    'RGB' so the saved file is a 3-channel image.
    """
    dest = BWIMAGENET/fn.relative_to(IMAGENET)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Close the source file deterministically: this runs over all of ImageNet with many
    # parallel workers, and relying on garbage collection can exhaust file handles.
    with PIL.Image.open(fn) as src:
        src.convert('LA').convert('RGB').save(dest)

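The first time you run this notebook, uncomment and run the two cells below. They write black-and-white copies of the ImageNet `val` and `train` images under `BWIMAGENET`.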

In [None]:
#il = ImageItemList.from_folder(IMAGENET/'val')
#parallel(decolorize, il.items, max_workers=16)

In [None]:
#il = ImageItemList.from_folder(IMAGENET/'train')
#parallel(decolorize, il.items, max_workers=16)

In [None]:
def get_data(sz:int, bs:int, keep_pct:float, num_workers:int=16, random_seed=None):
    """Build the colorization data bunch via `get_colorize_data`.

    The grayscale copies (BWIMAGENET) are passed as `crappy_path` and the originals
    (IMAGENET) as `good_path`.

    sz:          image size for this training stage.
    bs:          batch size.
    keep_pct:    fraction of the dataset to keep (this notebook uses 1.0, 0.25 and 0.1).
    num_workers: data-loader worker count; defaults to the previously hard-coded 16.
    random_seed: forwarded to `get_colorize_data`; the default None keeps the previous behaviour.
                 NOTE(review): with None the kept subset / split may be redrawn on every call,
                 i.e. at each resize stage. Pass a fixed seed if a stable split is needed
                 (confirm against `get_colorize_data`).
    """
    return get_colorize_data(sz=sz, bs=bs, crappy_path=BWIMAGENET, good_path=IMAGENET,
                             random_seed=random_seed, keep_pct=keep_pct, num_workers=num_workers)

In [None]:
def save():
    """Checkpoint generator and critic weights, tagged with the current global `sz`.

    Reads the notebook globals `learn_gen`, `learn_crit`, `proj_id` and `sz`, and calls each
    learner's `.save` with the names `<proj_id>_gen_<sz>` and `<proj_id>_crit_<sz>`.
    """
    size_tag = str(sz)
    for learner, role in ((learn_gen, '_gen_'), (learn_crit, '_crit_')):
        learner.save(proj_id + role + size_tag)

In [None]:
def load():
    """Restore generator and critic weights saved by `save()` for the current global `sz`.

    Reads the notebook globals `learn_gen`, `learn_crit`, `proj_id` and `sz`, and calls each
    learner's `.load` with the names `<proj_id>_gen_<sz>` and `<proj_id>_crit_<sz>`.
    """
    size_tag = str(sz)
    for learner, role in ((learn_gen, '_gen_'), (learn_crit, '_crit_')):
        learner.load(proj_id + role + size_tag)

In [None]:
def colorize_gen_learner_exp(data:ImageDataBunch, gen_loss=None, arch=models.resnet34):
    """Build the experimental U-Net generator learner for colorization.

    data:     data bunch to train on.
    gen_loss: generator loss function; when None (the default) a fresh `FeatureLoss()` is
              created for this call.
    arch:     encoder architecture passed to `unet_learner2` (default resnet34).

    Returns the learner built by `unet_learner2` (spectral norm, blur, self-attention,
    outputs in y_range (-3, 3), weight decay 1e-3).
    """
    # Build the default loss lazily. A `FeatureLoss()` default in the signature is constructed
    # once, when this cell runs, even if never used, and is then shared by every learner built
    # here. NOTE(review): it presumably wraps a pretrained network, so that is not free - confirm.
    if gen_loss is None:
        gen_loss = FeatureLoss()
    return unet_learner2(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,
                        self_attention=True, y_range=(-3.,3.), loss_func=gen_loss)

## Training

In [None]:
# Build the critic, the generator, and the GAN learner that alternates between them.
# NOTE(review): the original comment here said this data is "needed to instantiate critic but
# not actually used", yet `learn.data` is not replaced before the 64px fits below, so they
# appear to train on it (keep_pct=1.0) - confirm which is intended.
sz=64
bs=32

data = get_data(sz=sz, bs=bs, keep_pct=1.0)
learn_crit = colorize_crit_learner(data=data, nf=256)
learn_crit.unfreeze()

gen_loss = FeatureLoss2(gram_wgt=5e3)
# Pass the loss explicitly. Previously `gen_loss` was built but never handed to the
# generator, which therefore trained with the default FeatureLoss() instead of FeatureLoss2.
learn_gen = colorize_gen_learner_exp(data=data, gen_loss=gen_loss)

# Presumably switches between generator and critic training based on a critic-loss threshold.
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.0), show_img=False, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=1e-3)

# Presumably scales the critic's learning rate by mult_lr relative to the generator's.
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))

# Base learning rate for the 64px stage; later stages divide it down.
lr=1e-4
# Multiplier applied to lr when the whole generator is unfrozen (`lr*unfreeze_fctr`).
unfreeze_fctr=0.05
unfreeze_fctr=0.05

## 64px

In [None]:
# Freeze all but the generator's last layer group, then train one epoch at the base lr.
learn_gen.freeze_to(-1)
learn.fit(1,lr)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()

In [None]:
# Train the whole generator for one epoch at the reduced lr (lr * unfreeze_fctr).
learn_gen.unfreeze()
learn.fit(1,lr*unfreeze_fctr)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()

## 96px

In [None]:
# Reload the checkpoint for the current sz (64 on a straight run), then switch to the 96px stage.
# The lr is halved; the batch size is deliberately left unchanged at this stage.
# NOTE(review): `lr=lr/2` is not idempotent - re-running this cell halves lr again.
load()
lr=lr/2
sz=96
#bs=bs//2

In [None]:
# One epoch on a 10% subset at the new size with lr/10 (presumably a warm-up after resizing).
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.1)
learn_gen.freeze_to(-1)
learn.fit(1,lr/10)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()

In [None]:
# Switch the GAN learner to a 25% subset of the data at the current size.
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.25)

In [None]:
# Freeze all but the generator's last layer group, then train one epoch at this stage's lr.
learn_gen.freeze_to(-1)
learn.fit(1,lr)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()

In [None]:
# Train the whole generator for one epoch at the reduced lr (lr * unfreeze_fctr).
learn_gen.unfreeze()
learn.fit(1,lr*unfreeze_fctr)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()

## 128px

In [None]:
# 128px stage: halve lr and halve the batch size (presumably to fit the larger images in GPU memory).
# NOTE(review): not idempotent - re-running this cell compounds the lr/bs reductions.
lr=lr/2
sz=128
bs=bs//2

In [None]:
# One epoch on a 10% subset at the new size with lr/10 (presumably a warm-up after resizing).
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.1)
learn_gen.freeze_to(-1)
learn.fit(1,lr/10)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()

In [None]:
# Switch the GAN learner to a 25% subset of the data at the current size.
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.25)

In [None]:
# Freeze all but the generator's last layer group, then train one epoch at this stage's lr.
learn_gen.freeze_to(-1)
learn.fit(1,lr)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()

In [None]:
# Train the whole generator for one epoch at the reduced lr (lr * unfreeze_fctr).
learn_gen.unfreeze()
learn.fit(1,lr*unfreeze_fctr)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()

## 160px

In [None]:
# 160px stage: divide lr by 1.5 and shrink the batch size by a factor of 1.5 (floored to an int).
# NOTE(review): not idempotent - re-running this cell compounds the lr/bs reductions.
lr=lr/1.5
sz=160
bs=int(bs//1.5)

In [None]:
# Hard-coded batch-size override for the 160px stage.
# NOTE(review): on a straight top-to-bottom run bs is 16 before the previous cell, and
# int(16//1.5) is already 10, so this override is redundant - confirm whether it is still needed.
bs=10

In [None]:
# One epoch on a 10% subset at the new size with lr/10 (presumably a warm-up after resizing).
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.1)
learn_gen.freeze_to(-1)
learn.fit(1,lr/10)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()

In [None]:
# Switch the GAN learner to a 25% subset of the data at the current size.
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.25)

In [None]:
# Freeze all but the generator's last layer group, then train one epoch at this stage's lr.
learn_gen.freeze_to(-1)
learn.fit(1,lr)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()

In [None]:
# Train the whole generator for one epoch at the reduced lr (lr * unfreeze_fctr).
learn_gen.unfreeze()
learn.fit(1,lr*unfreeze_fctr)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()

## 192px

In [None]:
# 192px stage: divide lr by 1.5 and shrink the batch size by a factor of 1.5 (floored to an int).
# NOTE(review): not idempotent - re-running this cell compounds the lr/bs reductions.
lr=lr/1.5
sz=192
bs=int(bs//1.5)

In [None]:
# One epoch on a 10% subset at the new size with lr/10 (presumably a warm-up after resizing).
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.1)
learn_gen.freeze_to(-1)
learn.fit(1,lr/10)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()

In [None]:
# Switch the GAN learner to a 25% subset of the data at the current size.
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.25)

In [None]:
# Freeze all but the generator's last layer group, then train one epoch at this stage's lr.
learn_gen.freeze_to(-1)
learn.fit(1,lr)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()

In [None]:
# Train the whole generator for one epoch at the reduced lr (lr * unfreeze_fctr).
learn_gen.unfreeze()
learn.fit(1,lr*unfreeze_fctr)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()

## 224px

In [None]:
# 224px stage: divide lr by 1.5 and shrink the batch size by a factor of 1.5 (floored to an int).
# NOTE(review): not idempotent - re-running this cell compounds the lr/bs reductions.
lr=lr/1.5
sz=224
bs=int(bs//1.5)

In [None]:
# One epoch on a 10% subset at the new size with lr/10 (presumably a warm-up after resizing).
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.1)
learn_gen.freeze_to(-1)
learn.fit(1,lr/10)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()

In [None]:
# Switch the GAN learner to a 25% subset of the data at the current size.
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.25)

In [None]:
# Freeze all but the generator's last layer group, then train one epoch at this stage's lr.
learn_gen.freeze_to(-1)
learn.fit(1,lr)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()

In [None]:
# Train the whole generator for one epoch at the reduced lr (lr * unfreeze_fctr).
learn_gen.unfreeze()
learn.fit(1,lr*unfreeze_fctr)

In [None]:
# Checkpoint generator + critic weights tagged with the current sz.
save()