## Pretrained GAN

In [None]:
import os
# Pin this process to GPU 0. CUDA reads CUDA_VISIBLE_DEVICES when it initialises,
# so this cell has to run before anything touches the GPU.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

In [None]:
import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks.tensorboard import *
from fastai.vision.gan import *
from fasterai.generators import *
from fasterai.critics import *
from fasterai.dataset import *
from fasterai.loss import *
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageFile

## Setup

In [None]:
# Dataset locations: originals are read from `path`; `crappify` writes their
# black-and-white copies under `path/'bandw'`.
path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')
path_hr = path
path_lr = path / 'bandw'

# Every checkpoint and log name below is derived from this project id.
proj_id = 'ColorizeNew73'
gen_name = f'{proj_id}_gen'
crit_name = f'{proj_id}_crit'

# Folder that receives generator outputs (used as one of the critic's classes).
name_gen = f'{proj_id}_image_gen'
path_gen = path / name_gen

TENSORBOARD_PATH = Path('data/tensorboard') / proj_id

# Forwarded to `gen_learner_wide` (presumably a width/filter multiplier — confirm in fasterai).
nf_factor = 2

In [None]:
def save_all(suffix=''):
    """Checkpoint the generator and critic learners.

    Files are named ``<gen_name><sz><suffix>`` and ``<crit_name><sz><suffix>``.
    `learn_gen`, `learn_crit` and `sz` are read from the notebook's globals.
    """
    tag = str(sz) + suffix
    learn_gen.save(gen_name + tag)
    learn_crit.save(crit_name + tag)

In [None]:
def load_all(suffix=''):
    """Reload generator and critic weights saved under the `save_all` naming scheme.

    Loads ``<gen_name><sz><suffix>`` and ``<crit_name><sz><suffix>`` without
    optimizer state. `learn_gen`, `learn_crit` and `sz` are notebook globals.
    """
    tag = str(sz) + suffix
    learn_gen.load(gen_name + tag, with_opt=False)
    learn_crit.load(crit_name + tag, with_opt=False)

In [None]:
def get_data(bs: int, sz: int, keep_pct: float):
    """Build the colourisation DataBunch: B&W inputs from `path_lr`, targets from `path_hr`.

    `bs`, `sz` and `keep_pct` are forwarded unchanged to fasterai's
    `get_colorize_data`. `random_seed` is passed as None
    (NOTE(review): presumably so each call samples a different subset — confirm).
    """
    options = dict(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr,
                   random_seed=None, keep_pct=keep_pct)
    return get_colorize_data(**options)

In [None]:
def get_crit_data(classes, bs, sz):
    """Build the critic's DataBunch from the sub-folders of `path` listed in `classes`.

    Images are labelled by folder name and split 90/10 with a fixed seed (42).
    They use fastai's default transforms with `max_zoom=2.`, are resized to `sz`,
    batched by `bs`, and normalised with ImageNet statistics.
    """
    split = (ImageList.from_folder(path, include=classes, recurse=True)
             .random_split_by_pct(0.1, seed=42))
    labelled = split.label_from_folder(classes=classes)
    tfms = get_transforms(max_zoom=2.)
    return (labelled.transform(tfms, size=sz)
            .databunch(bs=bs)
            .normalize(imagenet_stats))

In [None]:
def crappify(fn, i):
    """Write a black-and-white copy of image `fn` to the mirrored location under `path_lr`.

    The copy is converted to greyscale ('LA') and back to 'RGB', so it keeps
    three channels. Parent directories are created as needed.

    Args:
        fn: path of a source image under `path_hr`.
        i: unused. It exists because this function is handed to fastai's
            `parallel` (see the commented crappify cell), which presumably
            calls it as (item, index).
    """
    dest = path_lr / fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Close the source file handle deterministically; over a whole-dataset
    # `parallel` run the unclosed handles from a bare `Image.open` pile up.
    with PIL.Image.open(fn) as src:
        img = src.convert('LA').convert('RGB')
    img.save(dest)

In [None]:
def save_preds(dl, learn=None, dest=None):
    """Predict on every batch of `dl` and save each output under its source file name.

    Args:
        dl: data loader whose `dataset.items` are the source image paths. It must
            yield items in the same order as `dataset.items` (the callers pass
            `fix_dl`); otherwise outputs are saved under the wrong names.
        learn: learner used for `pred_batch`. Defaults to the global `learn_gen`.
        dest: output directory. Defaults to the global `path_gen`.

    Raises:
        StopIteration: if the learner yields more predictions than `dataset.items`
            has names (the original raised IndexError in the same situation).
    """
    if learn is None:
        learn = learn_gen
    if dest is None:
        dest = path_gen
    # One iterator shared across batches keeps predictions aligned with names.
    names = iter(dl.dataset.items)
    for batch in dl:
        for img in learn.pred_batch(batch=batch, reconstruct=True):
            img.save(dest / next(names).name)

In [None]:
def save_gen_images(learn_gen):
    """Regenerate the folder of generator outputs used to train the critic.

    Deletes and recreates `path_gen`, builds a DataBunch with 8.5% of the data at
    the global `bs`/`sz`, and saves predictions for its `fix_dl` via `save_preds`.

    Returns:
        The first image saved to `path_gen` (so a notebook cell ending in this
        call displays it), or None if nothing was written.

    NOTE(review): the `learn_gen` argument is not used. `save_preds` predicts with
    the global `learn_gen`. The parameter is kept for signature compatibility.
    """
    import shutil

    if path_gen.exists():
        shutil.rmtree(path_gen)
    path_gen.mkdir(exist_ok=True)
    data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)
    save_preds(data_gen.fix_dl)
    written = path_gen.ls()
    # Return the preview instead of discarding it, and avoid IndexError when empty.
    return PIL.Image.open(written[0]) if written else None

## Crappified data

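Prepare the input data by "crappifying" the images: `crappify` writes a black-and-white copy of every image under `path_lr` (the `bandw` folder).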

Uncomment the cell below the first time you run this notebook; on later runs the black-and-white copies already exist on disk.

In [None]:
# Writes a B&W copy of every image under path_hr into path_lr (run once).
# Uses ImageList, the same class get_crit_data uses (ImageItemList is the older name).
#il = ImageList.from_folder(path_hr)
#parallel(crappify, il.items)

# Pre-training

### Pre-train generator

Now let's pretrain the generator with the feature loss, progressively increasing the image size from 64 to 128 to 192 px.

In [None]:
# Stage 1: 64px images, batch size 88, keep_pct=1.0.
bs=88
sz=64
keep_pct=1.0

In [None]:
data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)

In [None]:
# Generator trained with fasterai's FeatureLoss; nf_factor comes from the setup cell.
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)

In [None]:
# TensorBoard writer for this phase, logging under TENSORBOARD_PATH as 'GenPre'.
learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))

In [None]:
# 2 epochs at max_lr=1e-3 before the explicit unfreeze below.
# NOTE(review): assumes gen_learner_wide starts with its pretrained body frozen — confirm.
learn_gen.fit_one_cycle(2, pct_start=0.8, max_lr=slice(1e-3))

In [None]:
# Every stage in this section overwrites this same checkpoint name.
learn_gen.save(gen_name)

In [None]:
# Reload weights only (no optimizer state); lets the notebook resume from here.
learn_gen.load(gen_name, with_opt=False)

In [None]:
learn_gen.unfreeze()

In [None]:
# Fine-tune all layer groups with learning rates spread from 3e-7 to 3e-4.
learn_gen.fit_one_cycle(2, pct_start=0.01,  max_lr=slice(3e-7, 3e-4))

In [None]:
learn_gen.save(gen_name)

In [None]:
# Stage 2: 128px images, batch size 20, keep_pct=1.0.
bs=20
sz=128
keep_pct=1.0

In [None]:
# Swap in 128px data; the same model is reused.
learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)

In [None]:
learn_gen.unfreeze()

In [None]:
# Reload the checkpoint written at the end of stage 1.
learn_gen.load(gen_name, with_opt=False)

In [None]:
learn_gen.fit_one_cycle(2, pct_start=0.01, max_lr=slice(1e-7,1e-4))

In [None]:
learn_gen.save(gen_name)

In [None]:
learn_gen.load(gen_name, with_opt=False)

In [None]:
# Stage 3: 192px images, batch size 8, keep_pct=0.50 (presumably half of the
# images — see get_colorize_data).
bs=8
sz=192
keep_pct=0.50

In [None]:
learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)

In [None]:
# Redundant if the earlier unfreeze cells ran; harmless.
learn_gen.unfreeze()

In [None]:
# Single epoch with learning rates spread from 5e-8 to 5e-5.
learn_gen.fit_one_cycle(1, pct_start=0.01, max_lr=slice(5e-8,5e-5))

In [None]:
learn_gen.save(gen_name)

### Save generated images

In [None]:
# Write generator outputs to path_gen for critic pretraining. Pass the learner
# itself (the parameter is named `learn_gen`), not its checkpoint name.
save_gen_images(learn_gen)

### Train critic

Pretrain the critic to tell the generator's outputs (the `name_gen` folder) apart from real images (the `test` folder).

In [None]:
# Critic pretraining, first at 128px with batch size 64.
bs=64
sz=128

In [None]:
# Drop the generator learner before building the critic.
learn_gen=None
gc.collect()

In [None]:
# NOTE(review): loss_critic is never passed to colorize_crit_learner or used
# anywhere else in this notebook — looks like dead code.
loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())

In [None]:
# Two classes: the generated-image folder (name_gen) and the 'test' folder under path.
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)

In [None]:
# Visual sanity check of a training batch.
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)

In [None]:
learn_critic = colorize_crit_learner(data=data_crit, nf=256)

In [None]:
learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))

In [None]:
# 6 epochs at lr 1e-3.
learn_critic.fit_one_cycle(6, 1e-3)

In [None]:
learn_critic.save(crit_name)

In [None]:
# Continue at 192px with a smaller batch.
bs=16
sz=192

In [None]:
# Swap in 192px critic data; the same model is reused.
learn_critic.data=get_crit_data([name_gen, 'test'], bs=bs, sz=sz)

In [None]:
learn_critic.fit_one_cycle(4, 1e-4)

In [None]:
# Overwrites the checkpoint saved after the 128px phase.
learn_critic.save(crit_name)

## GAN

Now we'll combine the pretrained generator and critic in a GAN.

In [None]:
# Release the pretraining learners before building the GAN. The critic was
# pretrained as `learn_critic`; clearing only `learn_crit` left that model referenced.
learn_critic=None
learn_crit=None
learn_gen=None
gc.collect()

In [None]:
# GAN phase: lr 1e-5 at 192px, batch size 5.
lr=1e-5
sz=192
bs=5

In [None]:
# Placeholder: colorize_crit_learner needs a DataBunch to be built, but the GAN
# loop below replaces learn.data before every fit.
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)

In [None]:
learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_name, with_opt=False)

In [None]:
# data_gen is still the 64px / bs=88 DataBunch from generator pretraining; it is
# replaced by learn.data in the loop below.
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_name, with_opt=False)

In [None]:
# fasterai's AdaptiveGANSwitcher decides when to alternate generator/critic
# training, using critic_thresh=0.65 (exact semantics live in fasterai).
# NOTE(review): per fastai's GANLearner.from_learners, weights_gen weights the
# generator's own loss vs. the critic-based loss — confirm against the docs.
# Adam uses beta1=0, and GANDiscriminativeLR scales the critic's LR by mult_lr=5.
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))

In [None]:
# 100 short rounds. Each round fits one epoch on a fresh small subset
# (keep_pct=0.001), then writes a checkpoint pair with suffix '_03_<i>'.
for i in range(1,101):
    learn.data = get_data(sz=sz, bs=bs, keep_pct=0.001)
    # NOTE(review): re-applied every round; presumably leaves only the last
    # layer group of the generator trainable — confirm why it is repeated.
    learn_gen.freeze_to(-1)
    learn.fit(1,lr)
    save_all('_03_' + str(i))

In [None]:
save_all('_01')

### Save Generated Images Again

In [None]:
bs=8
sz=192

In [None]:
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load('ColorizeNew73_gen192_05_7', with_opt=False)

In [None]:
save_gen_images(gen_name)

### Train Critic Again

In [None]:
bs=16
sz=192

In [None]:
learn_gen=None
gc.collect()

In [None]:
# NOTE(review): loss_critic is never used in this notebook — looks like dead code.
loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())

In [None]:
# Critic data rebuilt from the freshly regenerated name_gen images plus 'test'.
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)

In [None]:
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)

In [None]:
# NOTE(review): crit_name + '5' is not saved by any cell in this notebook
# (only crit_name and crit_name + '6' are); it must come from an earlier run.
learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_name + '5', with_opt=False)

In [None]:
# NOTE(review): reuses the 'CriticPre' run name from the first critic pretraining,
# so logs from both phases share one name — consider a distinct name.
learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))

In [None]:
learn_critic.fit_one_cycle(4, 1e-4)

In [None]:
learn_critic.save(crit_name + '6')

In [None]:
learn_critic.load(crit_name + '6', with_opt=False)

In [None]:
# Second pass at a 10x lower learning rate.
learn_critic.fit_one_cycle(4, 1e-5)

In [None]:
# Overwrites the crit_name + '6' checkpoint from the first pass.
learn_critic.save(crit_name + '6')

### GAN Again

In [None]:
# Free the retrained critic (`learn_critic`) and any earlier GANLearner (`learn`)
# before building the second GAN; clearing `learn_crit`/`learn_gen` alone misses them.
learn=None
learn_critic=None
learn_crit=None
learn_gen=None
gc.collect()

In [None]:
# Second GAN phase: same hyper-parameters as the first (lr 1e-5, 192px, batch size 5).
lr=1e-5
sz=192
bs=5

In [None]:
# Placeholder data for building the critic; the loop below replaces learn.data.
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)

In [None]:
# Critic from the retraining phase above.
learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_name + '6', with_opt=False)

In [None]:
# NOTE(review): hard-coded checkpoint name, equivalent to gen_name + str(sz) + '_05_7'
# for sz=192; no cell in this notebook writes the '_05_7' suffix.
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load('ColorizeNew73_gen192_05_7', with_opt=False)

In [None]:
# Same GAN configuration as the first GAN phase.
# NOTE(review): reuses the 'GanLearner' TensorBoard run name from that phase.
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))

In [None]:
# 100 short rounds on fresh small subsets; checkpoints use suffix '_06b_<i>'.
for i in range(1,101):
    learn.data = get_data(sz=sz, bs=bs, keep_pct=0.001)
    learn_gen.freeze_to(-1)
    learn.fit(1,lr)
    save_all('_06b_' + str(i))

## fin