In [None]:
from dataclasses import dataclass
from fastai.vision.all import *
from fastai.vision.gan import *
from functools import partial
import numpy as np 
import pandas as pd
from pathlib import Path
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable

In [None]:
# When True, clone the face2anime repo next to this notebook (if missing) and put it on
# sys.path; when False, rely on the local `local_lib_import` helper module instead.
run_as_standalone_nb = True

In [None]:
# Make the `face2anime` package importable (it is imported in the next cell).
if run_as_standalone_nb:
    root_lib_path = Path('face2anime').resolve()
    # Clone only once; later runs reuse the existing checkout.
    if not root_lib_path.exists():
        !git clone https://github.com/davidleonfdez/face2anime.git
    # Prepend the checkout so it takes precedence; the membership check keeps the
    # cell re-runnable without growing sys.path.
    if str(root_lib_path) not in sys.path:
        sys.path.insert(0, str(root_lib_path))
else:
    # NOTE(review): `local_lib_import` is not shown here; presumably it sets up
    # sys.path for a local checkout — confirm.
    import local_lib_import

In [None]:
from face2anime.gen_utils import is_iterable
from face2anime.layers import ConcatPoolHalfDownsamplingOp2d, ConvHalfDownsamplingOp2d, TransformsLayer
from face2anime.losses import (ContentLossCallback, CritPredsTracker, CycleConsistencyLossCallback, 
                               CycleGANLoss, IdentityLossCallback, LossWrapper, MultiCritPredsTracker,
                               R1GANGPCallback)
from face2anime.misc import FeaturesCalculator
from face2anime.networks import (CycleCritic, CycleGenerator, default_decoder, default_encoder, Img2ImgGenerator, 
                                 patch_res_critic, res_critic)
from face2anime.train_utils import (add_ema_to_gan_learner, custom_load_model,
                                    custom_save_model)
from face2anime.transforms import AdaptiveAugmentsCallback, ADATransforms

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
img_size = 64  # images are resized (cropping) to img_size x img_size in get_dblock
n_channels = 3  # number of image channels (normalize_tf uses 3-element stats)
bs = 64  # batch size of the DataLoaders
save_cycle_len = 5  # save a checkpoint every `save_cycle_len` epochs

# Data

## Target ds

animecharacterfaces, by Kaggle user *aadilmalik94*

In [None]:
# Domain B (target): anime faces. Assumes the Kaggle dataset is attached at this path.
anime_ds_path = Path('/kaggle/input/animecharacterfaces/animeface-character-dataset/data').resolve()

## Input ds

In [None]:
# Domain A (input): CelebA aligned faces. Assumes the Kaggle dataset is attached at this path.
celeba_path = Path('/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba')
# `get_image_files(celeba_path)` would also work but is too slow for a folder this big.
# There's no need to filter by extension here (this assumes the folder holds only
# images), so the folder is listed directly.
input_a_fns = celeba_path.ls()
input_a_fns

The ds path passed to `dblock.dataloaders()` or `ImageDataLoaders.from_dblock()` will be forwarded
to `get_items`, which will return a list of items, usually a list of image paths if `get_items=get_image_files`.

So, for each item, we are expected to receive a filename `fn` and be able to
derive x and y from it, with `get_x(fn)` and `get_y(fn)`.

For bidirectional unpaired image to image translation, we can:
* Use the domain B ds path as the DataBlock `source`. 
* Load independently the filenames of the domain A ds; let's call it `input_a_fns`
* `get_y` needs two functions:
  * The first one can just return the path received (domain B).
  * The second returns a random item from `input_a_fns` (domain A).
* `get_x` also needs two functions:
  * The first one returns a random item from `input_a_fns` (domain A).
  * The second can just return the path received (domain B).
* `get_x` and `get_y` are called every time a data item is loaded, so, thanks to the random choice, no x is tied to a fixed y; i.e., a given x and y won't end up together in the same (x, y) sample every epoch for the loss calculation.
* In this case, domain A: human faces; domain B: anime faces.


---

In [None]:
def get_random_fn_a(fn):
    "Ignore `fn` and return a uniformly random domain-A filename from `input_a_fns`."
    n_fns = len(input_a_fns)
    rand_idx = random.randint(0, n_fns - 1)
    return input_a_fns[rand_idx]


# Normalize with mean=0.5 and std=0.5 per channel; e.g. inputs in [0, 1] are mapped to [-1, 1].
normalize_tf = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5]))


def get_dblock(extra_batch_tfms=None):
    """Build the 4-block DataBlock used for bidirectional unpaired translation.

    Items are domain-B paths (from `get_image_files`). Each item yields, in order:
    x1 = a random domain-A image, x2 = the domain-B image (the 2 inputs, `n_inp=2`),
    y1 = the domain-B image, y2 = another random domain-A image (the targets).

    Args:
        extra_batch_tfms: optional iterable (list, tuple, ...) of batch transforms
            applied after `normalize_tf`.
    """
    # Accept any iterable, not just lists (`[...] + tuple` would raise TypeError).
    batch_tfms = [normalize_tf, *(extra_batch_tfms or [])]
    return DataBlock(blocks=(ImageBlock, ImageBlock, ImageBlock, ImageBlock),
                     get_x=[get_random_fn_a, noop],
                     get_y=[noop, get_random_fn_a],
                     get_items=get_image_files,
                     # No validation split: every item goes to the training set.
                     splitter=IndexSplitter([]),
                     item_tfms=Resize(img_size, method=ResizeMethod.Crop),
                     batch_tfms=batch_tfms,
                     n_inp=2)


dblock = get_dblock()
# The domain-B (anime) folder is the DataBlock source; domain-A files are sampled
# inside get_x/get_y.
main_path = anime_ds_path
dls = dblock.dataloaders(main_path, path=main_path, bs=bs)

In [None]:
# It doesn't work for now
# dls.show_batch()

In [None]:
# Show the first element of one batch for each of the 4 blocks (2 inputs, 2 targets),
# undoing the normalization so the images display correctly.
sample_batch = dls.one_batch()
titles = ['x1 (A)', 'x2 (B)', 'y1 (B)', 'y2 (A)']
_, axs = plt.subplots(1, 4)
for t, ax, title in zip(sample_batch, axs, titles):
    normalize_tf.decode(t)[0].show(ax=ax, title=title)

# Loss function utils

In [None]:
# Feature extractor over layer index 22 (presumably of VGG, per the variable name),
# presumably meant for a content loss.
# NOTE(review): `ftrs_calc` is not used by any visible cell (ContentLossCallback is
# imported but never added) — confirm it is still needed.
vgg_content_layers_idx = [22]
ftrs_calc = FeaturesCalculator([], vgg_content_layers_idx, device=device,
                               input_norm_tf=normalize_tf)

------------------

# Training

In [None]:
def predict_n(learner, n_imgs, max_bs=64):
    dummy_path = Path('.')
    items = learner.dls.train.items
    if len(items) < n_imgs: 
        items = list(itertools.islice(itertools.cycle(items), n_imgs))
    dl = learner.dls.test_dl(items[:n_imgs], bs=max_bs)   
    inp, imgs_t, _, dec_imgs_t = learner.get_preds(dl=dl, with_input=True, with_decoded=True)
    dec_batch = dls.decode_batch(inp + dec_imgs_t, max_n=n_imgs)
    return dec_batch

def predict_show_n(learner, n_imgs, **predict_n_kwargs):
    """Predict `n_imgs` items with `predict_n` and plot them, one row per item.

    Columns: input A, its A->B translation, input B, its B->A translation.
    Extra keyword arguments are forwarded to `predict_n`.
    """
    preds_batch = predict_n(learner, n_imgs, **predict_n_kwargs)
    # squeeze=False keeps `axs` 2-D even when n_imgs == 1 (otherwise axs[0] is a
    # single Axes and indexing it again fails).
    _, axs = plt.subplots(n_imgs, 4, figsize=(12, n_imgs * 3), squeeze=False)
    for row, (in_a, in_b, pred_a2b, pred_b2a) in zip(axs, preds_batch):
        in_a.show(ax=row[0], title='In A')
        pred_a2b.show(ax=row[1], title='Out A->B')
        in_b.show(ax=row[2], title='In B')
        pred_b2a.show(ax=row[3], title='Out B->A')

class SaveCheckpointsCallback(Callback):
    """Callback that saves the model every `save_cycle_len` epochs.

    Checkpoints are written with `custom_save_model` under `base_path`, named
    `f'{fn_prefix}_{epoch}ep'`. The callback keeps its own epoch counter, starting
    at `initial_epoch`, so numbering can continue from a previous run and across
    successive `fit` calls on the same learner.
    """
    def __init__(self, fn_prefix, base_path=Path('.'), initial_epoch=1,
                 save_cycle_len=1):
        if save_cycle_len < 1:
            raise ValueError(f'save_cycle_len must be >= 1, got {save_cycle_len}')
        self.fn_prefix = fn_prefix
        self.base_path = base_path
        # Stored under a private name: fastai warns whenever a callback sets an
        # attribute that also exists on the Learner (like `epoch`).
        self._ckpt_epoch = initial_epoch
        self.save_cycle_len = save_cycle_len

    @property
    def epoch(self):
        "Number assigned to the epoch that ends next (kept for backward compatibility)."
        return self._ckpt_epoch

    @epoch.setter
    def epoch(self, value):
        self._ckpt_epoch = value

    def after_epoch(self):
        "Save a checkpoint when the current epoch number is a multiple of `save_cycle_len`."
        if self._ckpt_epoch % self.save_cycle_len == 0:
            fn = f'{self.fn_prefix}_{self._ckpt_epoch}ep'
            custom_save_model(self.learn, fn, base_path=self.base_path)
        self._ckpt_epoch += 1
        
def save_preds(c_preds_tracker, filepaths):
    "Write each DataFrame of `c_preds_tracker.to_dfs()` to the matching path of `filepaths`."
    results = []
    for df, filepath in zip(c_preds_tracker.to_dfs(), filepaths):
        results.append(df.to_csv(filepath))
    return results
        
def plot_c_preds(c_preds_tracker):
    "Plot fake and real critic predictions per batch on the current axes and return them."
    n_batches = len(c_preds_tracker.real_preds)
    batch_idxs = range(n_batches)
    ax = sns.lineplot(x=batch_idxs, y=c_preds_tracker.fake_preds.cpu(), label='Fake preds')
    sns.lineplot(x=batch_idxs, y=c_preds_tracker.real_preds.cpu(), label='Real preds', ax=ax)
    ax.set_xlabel('Number of batches')
    ax.set_ylabel('Critic preds')
    return ax

def plot_multi_c_preds(multi_c_preds_tracker, titles):
    """Plot each tracker of `multi_c_preds_tracker.trackers` in its own figure.

    `titles` are used as legend titles, in the same order as the trackers.
    """
    for c_preds_tracker, title in zip(multi_c_preds_tracker.trackers, titles):
        # Open the figure *before* plotting, so no empty figure is left at the end.
        plt.figure()
        ax = plot_c_preds(c_preds_tracker)
        ax.legend(title=title)

In [None]:
def _forward_batch(model, batch, device):
    input = batch[:2]
    if device is not None:
        for i in range(2): input[i] = input[i].to(device)
    model(*input)


def create_learner(for_inference=False, dblock=dblock, dls=dls, gp_w=10., latent_sz=100, 
                   mid_mlp_depth=2, g_norm=NormType.Instance, n_crit_iters=3,
                   cycle_cons_w=0, id_loss_w=0, use_patch_critic=False):
    """Build the bidirectional (A<->B) GANLearner used in this notebook.

    Args:
        for_inference: if True, skip the training-only callbacks (R1 penalty,
            cycle/identity losses) and the critic predictions tracker.
        dblock: DataBlock forwarded to `add_ema_to_gan_learner`. NOTE: the default
            is bound to the global `dblock` when this cell is executed.
        dls: DataLoaders of the learner (default bound to the global `dls` likewise).
        gp_w: weight of the R1 gradient penalty (`R1GANGPCallback`).
        latent_sz: latent size passed to `default_encoder`.
        mid_mlp_depth: forwarded to `Img2ImgGenerator`.
        g_norm: normalization type of the generators' encoder and decoder.
        n_crit_iters: critic steps per generator step (`FixedGANSwitcher`).
        cycle_cons_w: cycle-consistency loss weight; 0 disables that loss.
        id_loss_w: identity loss weight; 0 disables that loss.
        use_patch_critic: build the critics with `patch_res_critic` instead of
            `res_critic`.

    Returns:
        The GANLearner. `add_ema_to_gan_learner` is applied to it (the notebook
        later reads `learn.ema_model`), and when not `for_inference` it also gets a
        `crit_preds_tracker` attribute.
    """
    leakyReLU02 = partial(nn.LeakyReLU, negative_slope=0.2)
    # Critic downsampling ops (presumably halving the resolution, per their names).
    down_op = ConvHalfDownsamplingOp2d(ks=4, act_cls=leakyReLU02, bn_1st=False,
                                       norm_type=NormType.Batch)
    id_down_op = ConcatPoolHalfDownsamplingOp2d(conv_ks=3, act_cls=None, norm_type=None)   
    crit_args = [img_size, n_channels, down_op, id_down_op]
    # patch_res_critic takes an extra positional arg at index 2.
    # NOTE(review): presumably the patch size — confirm against patch_res_critic.
    if use_patch_critic: crit_args.insert(2, img_size//8)
    crit_kwargs = dict(n_extra_convs_by_res_block=1, act_cls=leakyReLU02, 
                       bn_1st=False, n_features=128, flatten_full=True)
    crit_builder = patch_res_critic if use_patch_critic else res_critic    
    # Two identical critics (one per translation direction) combined in a CycleCritic.
    base_critics = [crit_builder(*crit_args, **crit_kwargs) for _ in range(2)]
    base_critics = CycleCritic(*base_critics)
    # No wrapper is applied to the critic here.
    critic = base_critics
    
    def _decoder_builder(imsz, nch, latsz, hooks_by_sz=None): 
        # Default decoder using the generator normalization chosen by `g_norm`.
        return default_decoder(imsz, nch, latsz, norm_type=g_norm, hooks_by_sz=hooks_by_sz)
    # Two generators with skip connections, combined in a CycleGenerator
    # (it exposes them as `g_a2b` and `g_b2a`).
    generators = [Img2ImgGenerator(img_size, n_channels, mid_mlp_depth=mid_mlp_depth, skip_connect=True,
                                   encoder=default_encoder(img_size, n_channels, latent_sz, norm_type=g_norm),
                                   decoder_builder=_decoder_builder)
                  for _ in range(2)]
    generator = CycleGenerator(*generators)
    
    cbs = []
    c_loss_interceptors = []
    metrics = []
    if not for_inference:
        # The R1 penalty receives the unwrapped critics explicitly. In this notebook
        # `critic` is the same object; the distinction only matters if the critic
        # gets wrapped (e.g. by an ADA critic, whose grid_sample op has 2nd-order
        # derivative issues).
        cbs.append(R1GANGPCallback(weight=gp_w, critic=base_critics))
        if cycle_cons_w > 0: 
            cbs.append(CycleConsistencyLossCallback(generator.g_a2b, 
                                                    generator.g_b2a, 
                                                    weight=cycle_cons_w))
            metrics.append('cycle_loss')
        if id_loss_w > 0:
            cbs.append(IdentityLossCallback(generator.g_a2b, 
                                            generator.g_b2a, 
                                            weight=id_loss_w))
            metrics.append('identity_loss')
        # Intercepts the critic loss to record critic predictions (see plot_multi_c_preds).
        overall_crit_preds_tracker = MultiCritPredsTracker(reduce_batch=True)
        c_loss_interceptors.append(overall_crit_preds_tracker)
        
    # The generator's extra loss term is 0.
    # NOTE(review): presumably gan_loss_from_func still adds the adversarial term
    # computed from the critic's output — confirm.
    def gen_loss_func(*args): return 0
    crit_loss_func = nn.BCEWithLogitsLoss()
    loss_G, loss_C = gan_loss_from_func(gen_loss_func, crit_loss_func)
    loss_C = LossWrapper(loss_C, c_loss_interceptors)
    
    learn = GANLearner(dls, generator, critic, loss_G, loss_C,
                       opt_func=partial(Adam, mom=0., sqr_mom=0.99, wd=0.),
                       cbs=cbs, switcher=FixedGANSwitcher(n_crit=n_crit_iters, n_gen=1),
                       switch_eval=False, metrics=LossMetrics(metrics) or None)
    # NOTE(review): presumably adapts the GAN loss to the two translation directions.
    learn.loss_func = CycleGANLoss(learn.loss_func)
    learn.recorder.train_metrics=True
    # There is no validation set (get_dblock uses IndexSplitter([])).
    learn.recorder.valid_metrics=False
    add_ema_to_gan_learner(learn, dblock, decay=0.999, forward_batch=_forward_batch)
    if not for_inference: learn.crit_preds_tracker = overall_crit_preds_tracker
    return learn

## TR 1: SN+BN critic, BN+SN gen (both encoder and decoder), no extra losses

In [None]:
def create_learner_1(*args, **kwargs):
    "TR 1 learner: `create_learner` with BatchNorm generators, R1 weight 10 and no extra losses."
    tr1_overrides = dict(gp_w=10., g_norm=NormType.Batch)
    return create_learner(*args, **kwargs, **tr1_overrides)

In [None]:
learn = create_learner_1()
# Save a checkpoint named face2anime_bidir_tr1_<epoch>ep every `save_cycle_len` epochs.
learn.add_cb(SaveCheckpointsCallback('face2anime_bidir_tr1', initial_epoch=1,
                                     save_cycle_len=save_cycle_len))
# Learner around the EMA model, only used below to show its results; the loss is a dummy.
ema_g_learn = Learner(dls, learn.ema_model, loss_func=lambda *args: torch.tensor(0.))
lr = 2e-4

In [None]:
# Inspect both architectures.
learn.model.critic, learn.model.generator

In [None]:
# custom_load_model(learn, 'face2anime_bidir_tr1_50ep', base_path='./models', with_ema=True)
# #preds_df = pd.read_csv(io.StringIO(preds_csv_str), index_col=0)
# #learn.crit_preds_tracker.load_from_df(preds_df, device)
# with learn.removed_cbs([learn.save_checkpoints]) as displayable_learn:
#     displayable_learn.show_results(ds_idx=0)

In [None]:
learn.fit(100, lr)
# Show results with the checkpoint callback temporarily removed.
with learn.removed_cbs([learn.save_checkpoints]) as displayable_learn:
    displayable_learn.show_results(ds_idx=0)

In [None]:
# Results of the EMA model.
ema_g_learn.show_results(ds_idx=0)

In [None]:
# Per-batch critic predictions, one figure per translation direction.
plot_multi_c_preds(learn.crit_preds_tracker, ['A->B', 'B->A'])

In [None]:
# Persist the critic predictions history of each direction to CSV.
save_preds(learn.crit_preds_tracker, 
           [Path('crit_preds_face2anime_bidir_a2b_tr1_100ep.csv'),
            Path('crit_preds_face2anime_bidir_b2a_tr1_100ep.csv')])

## TR 2: SN+BN critic, BN+SN gen (both encoder and decoder), cycle consistency loss, identity loss

In [None]:
def create_learner_2(*args, **kwargs):
    "TR 2 learner: BatchNorm generators, R1 weight 10, plus cycle-consistency and identity losses (weight 1 each)."
    tr2_overrides = dict(gp_w=10., g_norm=NormType.Batch, cycle_cons_w=1., id_loss_w=1.)
    return create_learner(*args, **kwargs, **tr2_overrides)

In [None]:
learn = create_learner_2()
# Save a checkpoint named face2anime_bidir_tr2_<epoch>ep every `save_cycle_len` epochs.
learn.add_cb(SaveCheckpointsCallback('face2anime_bidir_tr2', initial_epoch=1,
                                     save_cycle_len=save_cycle_len))
# Learner around the EMA model, only used below to show its results; the loss is a dummy.
ema_g_learn = Learner(dls, learn.ema_model, loss_func=lambda *args: torch.tensor(0.))
lr = 2e-4

In [None]:
learn.fit(100, lr)
# Show results with the checkpoint callback temporarily removed.
with learn.removed_cbs([learn.save_checkpoints]) as displayable_learn:
    displayable_learn.show_results(ds_idx=0)

In [None]:
# Results of the EMA model.
ema_g_learn.show_results(ds_idx=0)

In [None]:
# Per-batch critic predictions, one figure per translation direction.
plot_multi_c_preds(learn.crit_preds_tracker, ['A->B', 'B->A'])

In [None]:
# Persist the critic predictions history of each direction to CSV.
save_preds(learn.crit_preds_tracker, 
           [Path('crit_preds_face2anime_bidir_a2b_tr2_100ep.csv'),
            Path('crit_preds_face2anime_bidir_b2a_tr2_100ep.csv')])

# Evaluation

In [None]:
# Root folder where the real and fake samples for FID are written.
base_fid_samples_path = Path('./fid_samples')
# Number of images per domain used to compute FID.
n_fid_imgs = 10000

def download_pytorch_fid_calculator():        
    "Install `pytorch-fid`, which `eval_models` runs as `python -m pytorch_fid`."
    # NOTE(review): unpinned install; consider pinning a version, and `%pip` targets
    # the kernel's environment more reliably than `!pip`.
    !pip install pytorch-fid

def create_fid_dirs(base_fid_samples_path):
    """Create `base_fid_samples_path` and its `fake/{A,B}` and `real/{A,B}` subfolders.

    The base folder must not exist yet (FileExistsError otherwise) and its parent
    must exist (FileNotFoundError otherwise).
    """
    base_fid_samples_path.mkdir()
    for kind in ('fake', 'real'):
        kind_path = base_fid_samples_path/kind
        kind_path.mkdir()
        for domain in ('A', 'B'):
            (kind_path/domain).mkdir()
         
def save_real_imgs(dls, n_imgs=10000, base_path=None):
    """Save `n_imgs` real images of each domain as JPEGs for the FID computation.

    Domain-A inputs go to `base_path/'real/A'` and domain-B inputs to
    `base_path/'real/B'`, named `0.jpg` ... `{n_imgs-1}.jpg`.

    Batches come from iterating `dls.train`, restarting it as many times as needed.
    Calling `dls.one_batch()` repeatedly would reshuffle and keep only the first
    batch each time, so some domain-B images could be saved several times while
    others are never used.

    Args:
        dls: DataLoaders built from `get_dblock` (items unpack as in_a, in_b, _, _).
        n_imgs: number of images to save per domain.
        base_path: root folder of the FID samples; defaults to the global
            `base_fid_samples_path`.

    Raises:
        ValueError: if `dls.train` yields no images, which would otherwise loop forever.
    """
    if base_path is None: base_path = base_fid_samples_path
    n_saved = 0
    while n_saved < n_imgs:
        n_saved_at_start = n_saved
        for b in dls.train:
            batch_sz = b[1].size()[0]
            dec_b = dls.decode_batch(b, max_n=batch_sz)
            for in_a_t, in_b_t, _, _ in dec_b[:n_imgs - n_saved]:
                PILImage.create(in_a_t).save(base_path/f'real/A/{n_saved}.jpg')
                PILImage.create(in_b_t).save(base_path/f'real/B/{n_saved}.jpg')
                n_saved += 1
            if n_saved >= n_imgs: break
        if n_saved == n_saved_at_start:
            raise ValueError("`dls.train` yielded no images; can't save real samples")

def save_fake_imgs(learner, n_imgs=10000, base_path=None, **predict_n_kwargs):
    """Translate `n_imgs` items with `learner` and save the outputs as JPEGs for FID.

    A->B outputs go to `base_path/'fake/B'` and B->A outputs to `base_path/'fake/A'`,
    so each fake image is compared with the real images of the domain it imitates.

    Args:
        learner: learner accepted by `predict_n`.
        n_imgs: number of items to translate.
        base_path: root folder of the FID samples; defaults to the global
            `base_fid_samples_path`.
        **predict_n_kwargs: forwarded to `predict_n`.
    """
    if base_path is None: base_path = base_fid_samples_path
    preds_batch = predict_n(learner, n_imgs, **predict_n_kwargs)
    for i, (_, _, img_a2b, img_b2a) in enumerate(preds_batch):
        PILImage.create(img_a2b).save(base_path/f'fake/B/{i}.jpg')
        PILImage.create(img_b2a).save(base_path/f'fake/A/{i}.jpg')

In [None]:
# Start from a clean FID samples folder (create_fid_dirs requires that it doesn't exist).
!rm -R $base_fid_samples_path

In [None]:
# Install pytorch-fid and create the real/fake sample folders.
download_pytorch_fid_calculator()
create_fid_dirs(base_fid_samples_path)

In [None]:
def eval_models(builders, n_epochs, base_path='./models', fn_suffix='', ema=False):
    assert is_iterable(builders) or is_iterable(n_epochs)
    if not is_iterable(builders): 
        builders = [builders] * len(list(n_epochs))
    if not is_iterable(n_epochs): 
        n_epochs = [n_epochs] * len(list(builders))
    for builder, n_ep in zip(builders, n_epochs):
        model_id = builder.__name__.split('_')[-1]
        learner = builder(for_inference=True)
        custom_load_model(learner, f'face2anime_bidir_tr{model_id}{fn_suffix}_{n_ep}ep', with_opt=False,
                          base_path=base_path, with_ema=ema)
        if ema: 
            learner = Learner(learner.dls, learner.ema_model,
                              loss_func=lambda *args: torch.tensor(0.))
        save_fake_imgs(learner, n_imgs=n_fid_imgs)
        for domain in ('A', 'B'):
            print(f'---- {model_id} ({domain}), after {n_ep} epochs ----')
            !python -m pytorch_fid {base_fid_samples_path/'fake'/domain} {base_fid_samples_path/'real'/domain}    

In [None]:
# Reference (real) images for FID; only needs to run once per FID samples folder.
save_real_imgs(dls, n_fid_imgs)

In [None]:
# FID of TR 1 checkpoints every 5 epochs, regular generator weights.
eval_models(create_learner_1, range(5, 101, 5), fn_suffix='', base_path='./models')

In [None]:
# Same checkpoints, evaluating the EMA model instead.
eval_models(create_learner_1, range(5, 101, 5), base_path='./models', fn_suffix='', ema=True)

In [None]:
# Show one generated A->B (anime) sample. save_fake_imgs writes into fake/A and fake/B,
# so the path must include the domain folder (`fake/10.jpg` never exists).
PILImage.create(base_fid_samples_path/'fake/B/10.jpg')

## Reference FID

With 10000 images:

* FID domain A ds vs itself (CelebA vs CelebA) ~ 2.8
* FID domain B ds vs itself (Animecharacterfaces vs Animecharacterfaces) ~ 4.0