## Stable Model Training

#### NOTES:  
* This is "NoGAN" based training, described in the DeOldify readme.
* This model prioritizes stable and reliable renderings.  It does particularly well on portraits and landscapes.  It's not as colorful as the artistic model.

In [1]:
import os
# Pin this process to GPU 0; set before torch/fastai are imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks.tensorboard import *
from fastai.vision.gan import *
from fasterai.generators import *
from fasterai.critics import *
from fasterai.dataset import *
from fasterai.loss import *
from fasterai.save import *
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageFile
from torch.utils.data.sampler import RandomSampler, SequentialSampler
import wandb
from wandb.fastai import WandbCallback

## Setup

In [3]:
# Data layout: colour target images are read from path_hr, black-and-white
# inputs are written to / read from path_lr (a sub-folder of the same root).
path = Path('/data/Open_Images')
path_hr, path_lr = path, path / 'bandw'

proj_id = 'StableModel'

# Checkpoint names derived from the project id.
gen_name = f'{proj_id}_gen'
pre_gen_name = f'{gen_name}_0'
crit_name = f'{proj_id}_crit'

# Folder that receives generator predictions (used as a critic class later).
name_gen = f'{proj_id}_image_gen'
path_gen = path / name_gen

TENSORBOARD_PATH = path / 'tensorboard' / proj_id

In [4]:
def get_data(bs:int, sz:int, keep_pct:float, random_seed=None, samplers=None):
    """Build the colourisation DataBunch via fasterai's ``get_colorize_data``.

    Inputs come from ``path_lr`` (black-and-white) and targets from ``path_hr``.
    ``keep_pct``, ``random_seed`` and ``samplers`` are forwarded unchanged.
    """
    folders = dict(crappy_path=path_lr, good_path=path_hr)
    return get_colorize_data(sz=sz, bs=bs, random_seed=random_seed,
                             keep_pct=keep_pct, samplers=samplers, **folders)

def get_crit_data(classes, bs, sz):
    """Build the critic's DataBunch from the sub-folders of ``path`` listed in ``classes``.

    Every image is labelled by the folder it lives in; 10% of the items are
    held out for validation with a fixed seed (42). Images are augmented with
    ``get_transforms(max_zoom=2.)``, resized to ``sz`` and ImageNet-normalised.
    """
    items = ImageList.from_folder(path, include=classes, recurse=True)
    labelled = items.random_split_by_pct(0.1, seed=42).label_from_folder(classes=classes)
    augment = get_transforms(max_zoom=2.)
    return (labelled.transform(augment, size=sz)
                    .databunch(bs=bs)
                    .normalize(imagenet_stats))

def create_training_images(fn,i):
    """Write a black-and-white copy of image ``fn`` into ``path_lr``.

    The copy keeps ``fn``'s location relative to ``path_hr``. The image is
    converted to 'LA' (luminance + alpha) and back to 'RGB', so the output
    stays 3-channel while carrying only grey levels.

    Args:
        fn: path of a source image located under ``path_hr``.
        i: element index passed by fastai's ``parallel``; unused here.
    """
    dest = path_lr/fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Close the source file deterministically: this runs once per image over the
    # whole dataset, so leaving handles to the garbage collector lets them pile up.
    with PIL.Image.open(fn) as src:
        src.convert('LA').convert('RGB').save(dest)
    
def save_preds(dl):
    """Save ``learn_gen``'s reconstructed predictions for every batch of ``dl``.

    Each prediction is written to ``path_gen`` under the file name of the
    dataset item at the same running position. This assumes ``dl`` yields its
    dataset in item order.
    """
    items = dl.dataset.items
    pos = 0
    for batch in dl:
        for pred in learn_gen.pred_batch(batch=batch, reconstruct=True):
            pred.save(path_gen/items[pos].name)
            pos += 1
    
def save_gen_images():
    """Regenerate ``path_gen`` with generator predictions and return a preview.

    Deletes and recreates ``path_gen``, then builds a DataBunch from the
    notebook globals ``bs``/``sz`` with ``keep_pct=0.085``. It saves
    ``learn_gen``'s predictions for ``data_gen.fix_dl`` via ``save_preds``.

    Returns:
        An in-memory copy of the first file in ``path_gen`` (so a notebook cell
        ending in ``save_gen_images()`` displays it), or ``None`` if no
        prediction was written.
    """
    if path_gen.exists(): shutil.rmtree(path_gen)
    path_gen.mkdir(exist_ok=True)
    data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)
    save_preds(data_gen.fix_dl)
    saved = path_gen.ls()
    if not saved:
        return None
    # A bare expression inside a function is discarded (nothing is displayed)
    # and leaves the file open; return a detached copy instead.
    with PIL.Image.open(saved[0]) as img:
        return img.copy()

In [5]:
# Reduce quantity of samples per training epoch
# Adapted from https://forums.fast.ai/t/epochs-of-arbitrary-length/27777/10

@classmethod
def create(cls, train_ds:Dataset, valid_ds:Dataset, test_ds:Optional[Dataset]=None, path:PathOrStr='.', bs:int=64,
            val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None,
            device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False, sampler=None, **dl_kwargs)->'DataBunch':
    """Create a `DataBunch` from `train_ds`, `valid_ds` and maybe `test_ds`, with one sampler per loader.

    Drop-in replacement for fastai's ``DataBunch.create``. ``sampler`` is a
    sequence of callables, matched positionally with the datasets returned by
    ``cls._init_ds``. Each callable is called with its dataset to build that
    loader's sampler. The default is ``RandomSampler`` for the first loader and
    ``SequentialSampler`` for the rest. Only the first loader uses ``bs`` and
    drops its incomplete last batch; the others use ``val_bs`` (default ``bs``).
    ``**dl_kwargs`` is forwarded to ``DataLoader()``.
    """
    datasets = cls._init_ds(train_ds, valid_ds, test_ds)
    val_bs = ifnone(val_bs, bs)
    sampler_fns = sampler if sampler is not None else [RandomSampler] + [SequentialSampler]*3
    batch_sizes = (bs, val_bs, val_bs, val_bs)
    drop_lasts = (True, False, False, False)
    loaders = []
    for ds, size, drop, make_sampler in zip(datasets, batch_sizes, drop_lasts, sampler_fns):
        if ds is None:
            continue
        loaders.append(DataLoader(ds, size, sampler=make_sampler(ds), drop_last=drop,
                                  num_workers=num_workers, **dl_kwargs))
    return cls(*loaders, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)

ImageDataBunch.create = create
ImageImageList._bunch = ImageDataBunch

class FixedLenRandomSampler(RandomSampler):
    """Random sampler that yields exactly ``epoch_size`` indices per epoch.

    Indices are drawn without replacement across consecutive epochs, so every
    item of ``data_source`` is visited once per pass before any item repeats.
    When fewer than ``epoch_size`` unvisited items remain, all of them are
    emitted first. The rest of the epoch is then drawn from a fresh pass,
    avoiding the leftovers so no index repeats within one epoch (unless
    ``epoch_size`` exceeds the dataset size, where repeats are unavoidable).
    """

    def __init__(self, data_source, epoch_size):
        super().__init__(data_source)
        self.epoch_size = epoch_size
        # True for every index not yet emitted during the current pass.
        self.not_sampled = np.ones(len(data_source), dtype=bool)

    def reset(self):
        """Start a new pass: mark every index as not yet sampled."""
        self.not_sampled[:] = True

    @property
    def reset_state(self):
        # Legacy spelling: plain attribute access resets the pass. Kept so
        # existing callers that read ``sampler.reset_state`` keep working.
        self.reset()

    def __iter__(self):
        size = len(self)
        remaining = np.flatnonzero(self.not_sampled)
        if remaining.size >= size:
            idx = np.random.choice(remaining, size=size, replace=False)
            self.not_sampled[idx] = False
            # Mark first, then reset: if the pass is now exhausted the next epoch
            # must start from a fully fresh pool (resetting before marking would
            # pin the same epoch partition forever).
            if not self.not_sampled.any():
                self.reset()
            return iter(idx.tolist())

        # Too few unvisited items left: emit them all, then top up from a new pass.
        leftovers = remaining
        need = size - leftovers.size
        self.reset()
        candidates = self.not_sampled.copy()
        candidates[leftovers] = False  # don't repeat a leftover within this epoch
        pool = np.flatnonzero(candidates)
        if pool.size < need:
            # epoch_size is larger than the dataset: repeats cannot be avoided.
            pool = np.flatnonzero(self.not_sampled)
        idx = np.random.choice(pool, size=need, replace=False)
        self.not_sampled[idx] = False
        return iter([*leftovers.tolist(), *idx.tolist()])

    def __len__(self):
        return self.epoch_size

## Create black and white training images

This cell only runs if the black-and-white image directory does not already exist.

In [6]:
# Build the black-and-white inputs once: every image found under path_hr is
# converted in parallel by create_training_images.
# NOTE(review): the guard only checks that the folder exists, so an interrupted
# run leaves a partial path_lr that is not completed on re-run — delete it to regenerate.
if not path_lr.exists():
    il = ImageList.from_folder(path_hr)
    parallel(create_training_images, il.items)

## Pre-train generator

In [7]:
# W&B config
wandb.init(project="DeOldify")
config = wandb.config  # for shortening
config.epoch_size = 1000
config.nf_factor = 2
config.pct_start = 0.3
config.step1_bs = 12
config.step1_sz = 64
config.step1a_epochs = 10
config.step1a_pct_start = 0.8
config.step1a_lr = 1e-3
config.step1b_epochs = 10
config.step1b_pct_start = 0.3
config.step1b_lr_min = 3e-7
config.step1b_lr_max = 3e-4
config.step2_bs = 2
config.step2_sz = 128
config.step2_epochs = 10
config.step2_pct_start = 0.3
config.step2_lr_min = 1e-7
config.step2_lr_max = 1e-4
config.step3_bs = 1
config.step3_sz = 192
config.step3_epochs = 10
config.step3_pct_start = 0.3
config.step3_lr_min = 5e-8
config.step3_lr_max = 5e-5
random_seed = 1

# Load data
train_sampler = partial(FixedLenRandomSampler, epoch_size=config.epoch_size // config.step1_bs * config.step1_bs)
samplers = [train_sampler, SequentialSampler, SequentialSampler, SequentialSampler]

#### NOTE
Most of the training takes place here in pretraining for NoGAN.  The goal here is to take the generator as far as possible with conventional training, as that is much easier to control and obtain glitch-free results compared to GAN training.

### 64px

In [8]:
# 64px stage data: the whole dataset (keep_pct=1.), with the fixed-length-epoch
# `samplers` defined above and `random_seed` forwarded to get_colorize_data.
data_gen = get_data(bs=config.step1_bs, sz=config.step1_sz, keep_pct=1., random_seed=random_seed, samplers=samplers)
print(data_gen)

ImageDataBunch;

Train: LabelList (907950 items)
x: ImageImageList
Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64)
y: ImageList
Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64)
Path: /data/Open_Images/bandw;

Valid: LabelList (4562 items)
x: ImageImageList
Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64)
y: ImageList
Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64)
Path: /data/Open_Images/bandw;

Test: None


In [9]:
# Generator learner (fasterai gen_learner_wide) trained with FeatureLoss;
# nf_factor (presumably a filter-width multiplier) comes from the W&B config.
learn_gen=gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=config.nf_factor)
learn_gen.callback_fns.append(partial(WandbCallback, input_type='images'))  # log prediction samples

In [None]:
%%wandb
# Cell magic: display this W&B run inline while the cell trains.
# Stage 1a at 64px: one-cycle schedule from config.step1a_*.
learn_gen.fit_one_cycle(config.step1a_epochs, pct_start=config.step1a_pct_start, max_lr=slice(config.step1a_lr))

epoch,train_loss,valid_loss,time
0,5.096366,4.463193,06:07


Better model found at epoch 0 with valid_loss value: 4.463193416595459.


In [14]:
# Checkpoint after stage 1a (overwritten by later stages under the same name).
learn_gen.save(pre_gen_name)

In [15]:
# Make every layer group trainable for stage 1b.
learn_gen.unfreeze()

In [16]:
# Stage 1b at 64px: one-cycle schedule from the W&B config (step1b_*: 10 epochs,
# pct_start 0.3, max_lr slice(3e-7, 3e-4)). The bare `pct_start` used previously
# is never defined in this notebook (only `config.pct_start` is) -> NameError.
learn_gen.fit_one_cycle(config.step1b_epochs, pct_start=config.step1b_pct_start,
                        max_lr=slice(config.step1b_lr_min, config.step1b_lr_max))

epoch,train_loss,valid_loss,time
0,2.898474,2.964967,19:23
1,2.934219,2.969269,19:27
2,2.868092,2.972064,19:33
3,2.847696,3.041619,19:34
4,2.869638,3.048059,19:42
5,2.898197,3.054916,19:39
6,2.900482,2.987764,19:39
7,2.847276,3.004486,19:35
8,2.829326,3.002153,19:41
9,2.848644,2.992971,19:40




In [17]:
# Checkpoint after the 64px stage (overwrites the stage-1a checkpoint).
learn_gen.save(pre_gen_name)

### 128px

In [18]:
# 128px stage settings.
# NOTE(review): config.step2_bs is 2 but this stage trains with bs=1, and W&B
# logs the config value — reconcile one or the other.
bs, sz, keep_pct = 1, 128, 1.0

In [19]:
# Swap in 128px data. NOTE(review): unlike the 64px stage, no `samplers`/`random_seed`
# are passed, so this presumably uses get_colorize_data's default (full-length) samplers.
learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)

In [20]:
# Keep every layer group trainable for the 128px stage.
learn_gen.unfreeze()

In [21]:
# 128px stage: one-cycle schedule from the W&B config (step2_*: 10 epochs,
# pct_start 0.3, max_lr slice(1e-7, 1e-4)). The bare `pct_start` used previously
# is never defined in this notebook (only `config.pct_start` is) -> NameError.
learn_gen.fit_one_cycle(config.step2_epochs, pct_start=config.step2_pct_start,
                        max_lr=slice(config.step2_lr_min, config.step2_lr_max))

epoch,train_loss,valid_loss,time
0,2.235295,5.112829,1:59:08
1,2.18263,4.481052,1:59:54
2,2.306794,4.355889,2:00:04
3,2.189272,4.601309,2:00:06
4,2.229708,5.483531,2:00:29
5,2.273671,6.196134,2:00:19
6,2.298797,5.458822,2:00:34
7,2.278257,4.206965,2:00:50
8,2.098526,4.536943,2:01:05
9,2.202603,4.300625,2:00:48




In [22]:
# Checkpoint after the 128px stage (same name, overwritten).
learn_gen.save(pre_gen_name)

### 192px

In [23]:
# 192px stage settings (bs and sz agree with config.step3_bs / config.step3_sz).
bs, sz, keep_pct = 1, 192, 1.0

In [24]:
# Swap in 192px data (no custom samplers or seed passed here).
learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)

In [25]:
# Keep every layer group trainable for the 192px stage.
learn_gen.unfreeze()

In [26]:
# 192px stage: one-cycle schedule from the W&B config (step3_*: 10 epochs,
# pct_start 0.3, max_lr slice(5e-8, 5e-5)). The bare `pct_start` used previously
# is never defined in this notebook (only `config.pct_start` is) -> NameError.
learn_gen.fit_one_cycle(config.step3_epochs, pct_start=config.step3_pct_start,
                        max_lr=slice(config.step3_lr_min, config.step3_lr_max))

epoch,train_loss,valid_loss,time
0,1.977535,3.874527,2:24:57
1,2.005705,2.650965,2:25:58
2,1.898938,2.315888,2:25:52
3,2.039177,3.111342,2:26:17




KeyboardInterrupt: 

In [None]:
# Checkpoint after the 192px stage; this is the generator the GAN cycle loads as
# checkpoint 0 (pre_gen_name == gen_name + '_0').
learn_gen.save(pre_gen_name)

## Repeatable GAN Cycle

#### NOTE
Best results so far have come from repeating the cycle below a few times (perhaps 5-8), until diminishing returns are hit (no further improvement in image quality). Each time you repeat the cycle, increment `old_checkpoint_num` by 1 so that new checkpoints don't overwrite the old ones.

In [None]:
# Increase old_checkpoint_num by 1 on every repeat of the GAN cycle so that
# the new checkpoints never overwrite the previous ones.
old_checkpoint_num = 0
checkpoint_num = old_checkpoint_num + 1
gen_old_checkpoint_name = f'{gen_name}_{old_checkpoint_num}'
gen_new_checkpoint_name = f'{gen_name}_{checkpoint_num}'
crit_old_checkpoint_name = f'{crit_name}_{old_checkpoint_num}'
crit_new_checkpoint_name = f'{crit_name}_{checkpoint_num}'

### Save Generated Images

In [None]:
# Batch size / image size used by save_gen_images.
bs, sz = 1, 192

In [None]:
# Rebuild the generator and load the previous cycle's weights (optimizer state not restored).
# `nf_factor` alone is never defined in this notebook; the value lives in config.nf_factor.
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=config.nf_factor).load(gen_old_checkpoint_name, with_opt=False)

In [None]:
# Refill path_gen with the current generator's outputs; the critic data below
# uses that folder (name_gen) as one of its two classes.
save_gen_images()

### Pretrain Critic

##### The critic only needs full pretraining when starting from scratch. Otherwise, just fine-tune it!

In [None]:
# Full critic pretraining, only on the first cycle: it produces the
# crit_old_checkpoint_name checkpoint that the fine-tuning step loads.
if old_checkpoint_num == 0:
    # bs/sz are notebook globals; the next cell sets them again.
    bs=12
    sz=128
    # Drop the generator learner so its memory can be reclaimed before critic training.
    learn_gen=None
    gc.collect()
    # Two classes labelled by folder: generated images (name_gen) and 'test'.
    data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)
    data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)
    learn_critic = colorize_crit_learner(data=data_crit, nf=256)
    learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))
    # 6 one-cycle epochs at max LR 1e-3.
    learn_critic.fit_one_cycle(6, 1e-3)
    learn_critic.save(crit_old_checkpoint_name)

In [None]:
# Critic fine-tuning settings.
bs, sz = 4, 192

In [None]:
# Critic data: generated images (name_gen) vs 'test' images, at the bs/sz above.
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)

In [None]:
# Visual sanity check of a few labelled training samples.
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)

In [None]:
# Critic initialised from the previous cycle's checkpoint (optimizer state not restored).
learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)

In [None]:
# NOTE(review): logs under the same TensorBoard name ('CriticPre') as the from-scratch pretraining.
learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))

In [None]:
# Fine-tune the critic: 4 one-cycle epochs at max LR 1e-4.
learn_critic.fit_one_cycle(4, 1e-4)

In [None]:
# Save as this cycle's critic checkpoint (loaded again for GAN training).
learn_critic.save(crit_new_checkpoint_name)

### GAN

In [None]:
# Release the learners from the previous steps so their memory can be reclaimed
# before the GAN is built. The fine-tuned critic is bound to `learn_critic`;
# previously only `learn_crit` was cleared, so the critic stayed alive.
learn_critic=None
learn_crit=None
learn_gen=None
gc.collect()

In [None]:
# GAN training settings: learning rate, image size, batch size.
lr, sz, bs = 2e-5, 192, 1

In [None]:
# Critic data for GAN training at the bs/sz above.
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)

In [None]:
# Critic starts from this cycle's fine-tuned checkpoint (optimizer state not restored).
learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)

In [None]:
# Generator starts from the previous cycle's checkpoint (optimizer state not restored).
# `nf_factor` alone is never defined in this notebook; the value lives in config.nf_factor.
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=config.nf_factor).load(gen_old_checkpoint_name, with_opt=False)

In [None]:
# GAN learner combining the generator and critic: AdaptiveGANSwitcher with
# critic_thresh=0.65, Adam with betas=(0., 0.9), weight decay 1e-3.
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(
    learn_gen, learn_crit,
    weights_gen=(1.0,1.5), show_img=False, switcher=switcher,
    opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3,
)
# Callbacks: discriminative LR (mult_lr=5), TensorBoard images every 100 iterations,
# and a generator checkpoint under gen_new_checkpoint_name every 100 iterations.
learn.callback_fns.extend([
    partial(GANDiscriminativeLR, mult_lr=5.),
    partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100),
    partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100),
])

#### Instructions:  
Find the last checkpoint saved before glitches start to appear. This process is still experimental, so you may need to experiment with how far to train here (via keep_pct).

In [None]:
# Train the GAN for one epoch at `lr` on a small slice of the data (keep_pct=0.03).
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)
# Only the generator's last layer group stays trainable.
learn_gen.freeze_to(-1)
learn.fit(1,lr)