## Pretrained GAN

In [None]:
import os
# Expose only GPU index 2 to this process. This cell sits before the fastai
# imports so the variable is already set when they run.
os.environ['CUDA_VISIBLE_DEVICES']='2' 

In [None]:
import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.gan import *
from fasterai.generators import *
from fasterai.critics import *
from fasterai.tensorboard import *
from fasterai.dataset import *
from fasterai.loss import *
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageFile

## Setup

In [None]:
# Experiment identifier; the checkpoint names, the generated-image folder name
# and the TensorBoard directory below are all derived from it.
proj_id = 'ColorizeNew11'
gen_name = f'{proj_id}_gen'
crit_name = f'{proj_id}_crit'
name_gen = f'{proj_id}_image_gen'

# Dataset root. The "good" images are read from the root itself, the
# crappified copies live in its `bandw` sub-folder, and generator output
# goes to a sub-folder named after `name_gen`.
path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')
path_hr = path
path_lr = path.joinpath('bandw')
path_gen = path.joinpath(name_gen)

# Base directory handed to the TensorBoard writer callbacks.
TENSORBOARD_PATH = Path(f'data/tensorboard/{proj_id}')

# Passed as `nf_factor` to colorize_gen_learner.
nf_factor = 1.25

In [None]:
# Loss object handed to colorize_gen_learner as `gen_loss` wherever the
# generator learner is built in this notebook.
loss_gen = FeatureLoss()

In [None]:
def save_all():
    """Save the generator and critic learners.

    Each checkpoint name is the base name (`gen_name` / `crit_name`) followed by
    the current value of the notebook-global `sz`.
    """
    learn_gen.save(f'{gen_name}{sz}')
    learn_crit.save(f'{crit_name}{sz}')

In [None]:
def load_all():
    """Load the generator and critic checkpoints written by `save_all`.

    The checkpoint names are built from the current value of the
    notebook-global `sz`, so `sz` must match the value used when saving.
    """
    learn_gen.load(f'{gen_name}{sz}')
    learn_crit.load(f'{crit_name}{sz}')

In [None]:
def get_data(bs:int, sz:int, keep_pct:float):
    """Return the result of `get_colorize_data` for the given batch size, image size and keep_pct.

    The call always uses `path_lr` as the crappy path, `path_hr` as the good
    path, and passes `random_seed=None`.
    """
    return get_colorize_data(
        sz=sz,
        bs=bs,
        keep_pct=keep_pct,
        crappy_path=path_lr,
        good_path=path_hr,
        random_seed=None,
    )

In [None]:
def get_crit_data(classes, bs, sz):
    """Build the normalized data bunch used for the critic.

    Images are gathered recursively from the folders under `path` listed in
    `classes`, split at random with a 0.1 validation fraction (seed 42),
    labelled by folder using `classes`, transformed at size `sz`, batched with
    `bs` and normalized with `imagenet_stats`.
    """
    items = ImageItemList.from_folder(path, include=classes, recurse=True)
    labelled = items.random_split_by_pct(0.1, seed=42).label_from_folder(classes=classes)
    augmented = labelled.transform(get_transforms(max_zoom=2.), size=sz)
    return augmented.databunch(bs=bs).normalize(imagenet_stats)

In [None]:
def crappify(fn,i):
    """Write a desaturated copy of image `fn` under `path_lr`.

    The output keeps the location `fn` has relative to `path_hr`; missing
    parent folders are created. The image is converted to mode 'LA' and then
    back to 'RGB' before saving. `i` is not used (presumably the item index
    supplied by `parallel` -- confirm against the caller).
    """
    out_file = path_lr/fn.relative_to(path_hr)
    out_file.parent.mkdir(parents=True, exist_ok=True)
    gray = PIL.Image.open(fn).convert('LA')
    gray.convert('RGB').save(out_file)

In [None]:
def save_preds(dl):
    """Run `learn_gen` over every batch of `dl` and save each prediction into `path_gen`.

    Predictions are matched to `dl.dataset.items` by running position, so the
    n-th prediction is saved under the file name of the n-th dataset item.
    This assumes `dl` yields items in dataset order -- confirm for the loader used.
    """
    names = dl.dataset.items
    written = 0

    for batch in dl:
        for image in learn_gen.pred_batch(batch=batch, reconstruct=True):
            target = path_gen/names[written].name
            image.save(target)
            written += 1

## Crappified data

Prepare the input data by crappifying images.

Uncomment the cell below the first time you run this notebook.

In [None]:
#il = ImageItemList.from_folder(path_hr)
#parallel(crappify, il.items)

## Pre-train generator

Now let's pretrain the generator.

## 128px

In [None]:
# Batch size, image size and keep_pct used for generator pretraining.
bs=32
sz=128
keep_pct=0.1

In [None]:
data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)

In [None]:
learn_gen = colorize_gen_learner(data=data_gen, gen_loss=loss_gen, nf_factor=nf_factor)

In [None]:
# Log generator pretraining to TensorBoard under the name 'GenPre'.
learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))

In [None]:
# First pass, run before unfreeze() is called below.
learn_gen.fit_one_cycle(8, pct_start=0.8)

In [None]:
learn_gen.save(gen_name)

In [None]:
learn_gen.unfreeze()

In [None]:
# Reloads the checkpoint saved above, so the notebook can be resumed from this cell.
learn_gen.load(gen_name)

In [None]:
# Second pass on the unfrozen learner, with learning rates given as slice(1e-6,1e-3).
learn_gen.fit_one_cycle(8, slice(1e-6,1e-3))

In [None]:
# Same checkpoint name as after the first pass, so that checkpoint is overwritten.
learn_gen.save(gen_name)

## Save generated images

In [None]:
# Start from the pretrained generator checkpoint.
learn_gen.load(gen_name)

In [None]:
# Uncomment to delete previously generated images before regenerating:
# shutil.rmtree(path_gen)

In [None]:
path_gen.mkdir(exist_ok=True)

In [None]:
# Write one generated image per item of data_gen.fix_dl into path_gen.
save_preds(data_gen.fix_dl)

In [None]:
# Display the first file listed in path_gen as a quick visual check.
PIL.Image.open(path_gen.ls()[0])

## Train critic

Pretrain the critic to tell generated images (the `name_gen` folder) apart from real ones (the `test` folder).

In [None]:
# Drop the reference to the generator learner and run the garbage collector
# before building the critic.
learn_gen=None
gc.collect()

In [None]:
# NOTE(review): loss_critic is not referenced anywhere else in the visible
# notebook -- confirm whether colorize_crit_learner was meant to receive it.
loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())

In [None]:
# Two classes: the generated-image folder (name_gen) and the 'test' folder.
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)

In [None]:
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)

In [None]:
# Note the variable name: the pretraining learner is `learn_critic`, while the
# GAN section and save_all/load_all use `learn_crit`.
learn_critic = colorize_crit_learner(data=data_crit, nf=256)

In [None]:
# Log critic pretraining to TensorBoard under the name 'CriticPre'.
learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))

In [None]:
learn_critic.fit_one_cycle(10, 1e-3)

In [None]:
# Checkpoint that the GAN section loads via crit_name.
learn_critic.save(crit_name)

## GAN

Now we'll combine those pretrained models in a GAN.

In [None]:
# Drop every learner reference left over from the pretraining sections so the
# objects can be garbage-collected before the GAN learners are built.
# The critic pretraining cells bind their learner to `learn_critic` (not
# `learn_crit`), so that name must be cleared too; otherwise the pretrained
# critic stays referenced and gc.collect() cannot reclaim it.
learn_critic=None
learn_crit=None
learn_gen=None
gc.collect()

In [None]:
# Starting batch size, image size and learning rate for GAN training.
bs=24
sz=128
lr=8e-5

In [None]:
# Placeholder - not actually used; it only provides the `data` argument
# needed to construct the critic learner below.
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)

In [None]:
# Rebuild the critic and load the pretrained checkpoint.
learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_name)

In [None]:
# Rebuild the generator and load the pretrained checkpoint.
# NOTE: `data_gen` is the variable created in the generator-pretraining section,
# so that cell must have been run in this session. learn.data is reassigned
# before each GAN fit below.
learn_gen = colorize_gen_learner(data=data_gen, gen_loss=loss_gen, nf_factor=nf_factor).load(gen_name)

In [None]:
# Combine both learners into a GAN learner; the callbacks appended afterwards
# add GANDiscriminativeLR (mult_lr=5.) and TensorBoard logging named 'GanLearner'.
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,0.75), show_img=False, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))

In [None]:
# Stage 1: uses sz, bs and lr as set at the start of the GAN section;
# keep_pct is 0.25 here and 0.05 in the later stages.
learn.data=get_data(sz=sz, bs=bs, keep_pct=0.25)

In [None]:
learn_gen.freeze_to(-1)

In [None]:
learn.fit(1,lr)

In [None]:
# Checkpoints both learners; the names carry the current sz as suffix.
save_all()

In [None]:
# Stage 2: lower the learning rate, raise the image size and shrink the batch.
# bs//1.5 is a float floor division, hence the int(); when the cells are run
# once in order this turns 24 into 16.
lr=lr/1.5
sz=160
bs=int(bs//1.5)

In [None]:
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.05)
learn_gen.freeze_to(-1)
learn.fit(1,lr)

In [None]:
save_all()

In [None]:
# NOTE(review): reloads the checkpoints saved just above; it looks like a
# convenience cell for resuming from this point -- confirm it is intentional.
load_all()

In [None]:
# Stage 3: same schedule step as before (lr divided by 1.5, bs floor-divided by 1.5).
# When the cells are run once in order, bs goes from 16 to 10.
lr=lr/1.5
sz=192
bs=int(bs//1.5)

In [None]:
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.05)
learn_gen.freeze_to(-1)
learn.fit(1,lr)

In [None]:
save_all()

In [None]:
# Stage 4: same schedule step again. When the cells are run once in order,
# bs goes from 10 to 6.
lr=lr/1.5
sz=224
bs=int(bs//1.5)

In [None]:
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.05)
learn_gen.freeze_to(-1)
learn.fit(1,lr)

In [None]:
save_all()

In [None]:
# Final stage. Note that the learning rate is divided by 1.75 here, not 1.5 as
# in the earlier stages. When the cells are run once in order, bs goes from 6 to 4.
lr=lr/1.75
sz=256
bs=int(bs//1.5)

In [None]:
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.05)
learn_gen.freeze_to(-1)
learn.fit(1,lr)

In [None]:
save_all()

In [None]:
# Show results with `rows` set to the current batch size.
learn.show_results(rows=bs)

## fin