In [None]:
from dataclasses import dataclass
from fastai.vision.all import *
from fastai.vision.gan import *
from functools import partial
import numpy as np 
import pandas as pd
from pathlib import Path
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable

In [None]:
# True: clone the face2anime repo (if missing) and add it to sys.path;
# False: rely on the local `local_lib_import` module instead (see next cell).
run_as_standalone_nb = True

In [None]:
# Make the `face2anime` package importable.
if run_as_standalone_nb:
    # Clone the repo into the working directory on first run and put its root first on sys.path.
    root_lib_path = Path('face2anime').resolve()
    if not root_lib_path.exists():
        !git clone https://github.com/davidleonfdez/face2anime.git
    # NOTE(review): `sys` is not imported explicitly; presumably it comes from
    # `from fastai.vision.all import *` — consider importing it in the first cell.
    if str(root_lib_path) not in sys.path:
        sys.path.insert(0, str(root_lib_path))
else:
    # presumably a local helper that sets up sys.path for a local checkout — TODO confirm
    import local_lib_import

In [None]:
from face2anime.gen_utils import is_iterable
from face2anime.layers import ConcatPoolHalfDownsamplingOp2d, ConvHalfDownsamplingOp2d, TransformsLayer
from face2anime.losses import ContentLossCallback, CritPredsTracker, LossWrapper, R1GANGPCallback
from face2anime.misc import FeaturesCalculator
from face2anime.networks import Img2ImgGenerator, patch_res_critic, res_critic
from face2anime.train_utils import (add_ema_to_gan_learner, custom_load_model,
                                    custom_save_model)
from face2anime.transforms import AdaptiveAugmentsCallback, ADATransforms

In [None]:
# Shared settings: compute device, image geometry, batch size and checkpoint frequency.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
img_size, n_channels = 64, 3   # square RGB images of img_size x img_size
bs = 64                        # batch size
save_cycle_len = 5             # save a checkpoint every `save_cycle_len` epochs

# Data

## Target ds

animecharacterfaces, by Kaggle user *aadilmalik94*

In [None]:
# Target domain (y): anime character faces (Kaggle dataset).
anime_ds_path = Path('/kaggle/input/animecharacterfaces/animeface-character-dataset/data').resolve()

## Input ds

In [None]:
# Input domain (x): CelebA aligned faces.
celeba_path = Path('/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba')
#input_fns = get_image_files(celeba_path)
# get_image_files is too slow, there's no need to check the extension here
# (assumes the folder contains only image files).
input_fns = celeba_path.ls()
input_fns

The ds path passed to `dblock.dataloaders()` or `ImageDataLoaders.from_dblock()` will be forwarded
to `get_items`, which will return a list of items, usually a list of image paths if `get_items=get_image_files`.

So, for each item, we are expected to receive a filename `fn` and be able to
derive x and y from it, with `get_x(fn)` and `get_y(fn)`.

For unpaired image to image translation, we can:
* Use the target image ds path as the DataBlock `source`. Then, `get_y` can just return the received path (which is what happens when no `get_y` is given).
* Independently load the filenames of the input image ds; let's call them `input_fns`. Then, `get_x` just needs to return a random item from `input_fns`. `get_x` is called every time a data item is used, so thanks to the random sampling no x is tied to a fixed y; i.e., the same (x, y) pairs won't be put together in a batch every epoch for loss calculation.


---

In [None]:
def get_random_input(fn):
    """`get_x` for unpaired translation: ignore the target item `fn` and return a random input filename."""
    rand_pos = random.randrange(len(input_fns))
    return input_fns[rand_pos]


# Per-channel (x - 0.5) / 0.5, i.e. maps float images in [0, 1] to [-1, 1] (assuming [0, 1] inputs).
normalize_tf = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5]))


def get_dblock(extra_batch_tfms=None):
    """Build the unpaired image-to-image DataBlock.

    Items are the image files found under the source path and are used as targets (y);
    inputs (x) are random input-domain files chosen by `get_random_input`. There is no
    validation split. Items are resized to `img_size` with `ResizeMethod.Crop` and
    batches are normalized with `normalize_tf`, followed by `extra_batch_tfms` (a list).
    """
    batch_tfms = [normalize_tf]
    if extra_batch_tfms is not None:
        batch_tfms = batch_tfms + extra_batch_tfms
    return DataBlock(blocks=(ImageBlock, ImageBlock),
                     get_items=get_image_files,
                     get_x=get_random_input,
                     splitter=IndexSplitter([]),
                     item_tfms=Resize(img_size, method=ResizeMethod.Crop),
                     batch_tfms=batch_tfms)


# Unpaired dataloaders: items (and targets) are the anime images under `main_path`;
# inputs are random CelebA files.
dblock = get_dblock()
main_path = anime_ds_path
dls = dblock.dataloaders(main_path, path=main_path, bs=bs)

In [None]:
# Visual check of random (input, target) pairs.
dls.show_batch()

# Loss function utils

In [None]:
# Feature extractor for a content loss. Currently unused: the `ContentLossCallback`
# line in `create_learner` is commented out. Index 22 is presumably a VGG layer — see
# `FeaturesCalculator` for the meaning of its positional args.
vgg_content_layers_idx = [22]
ftrs_calc = FeaturesCalculator([], vgg_content_layers_idx, device=device,
                               input_norm_tf=normalize_tf)

## Reconstruction loss

In [None]:
@dataclass
class ReconstructionLossWeights:
    """Multipliers of the terms added by `ReconstructionLossCallback` (all default to 1)."""
    # weight of loss(G(y), y)
    real_to_real: float = 1.
    # weight of loss(latent(x), latent(G(x)))
    latent_a_to_latent: float = 1.
    # weight of loss(latent(y), latent(G(y)))
    latent_b_to_latent: float = 1.

class ReconstructionLossCallback(Callback):
    """Adds reconstruction losses to the generator loss of a fastai `GANLearner`.

    The generator is treated as an encoder-decoder. Its latent code is the output of
    the child layer found by `_get_latent_layer_idx`. Only during generator steps, up to
    three terms are added to `learn.loss_grad`:

    * real_to_real: `loss_func(G(y), y)`.
    * latent_b_to_latent: `loss_func(latent(y), latent(G(y)))`.
    * latent_a_to_latent: `loss_func(latent(x), latent(G(x)))`.

    Each term is scaled by the matching field of `weights`. It is also stored in
    `learn.loss_func` (`real_rec_loss`, `latent_b_rec_loss`, `latent_a_rec_loss`) so
    that loss metrics can display it.

    Args:
        enc_dec_generator: generator whose child layers can be iterated and indexed
            (`generator[i]`); hooks are attached to them.
        weights: multipliers of each term.
        n_ch: number of channels of the probe input used to find the latent layer.
        eval_real_b_to_real, eval_latent_a_to_latent, eval_latent_b_to_latent:
            enable each term.
        loss_func: distance between two tensors; `F.l1_loss` by default.
    """
    def __init__(self, enc_dec_generator:nn.Module, weights:ReconstructionLossWeights,
                 n_ch=3, eval_real_b_to_real=True, eval_latent_a_to_latent=True, 
                 eval_latent_b_to_latent=True, loss_func:Callable=None):
        self.generator = enc_dec_generator
        self.weights = weights
        self.eval_real_b_to_real = eval_real_b_to_real
        self.eval_latent_a_to_latent = eval_latent_a_to_latent
        self.eval_latent_b_to_latent = eval_latent_b_to_latent
        self.latent_layer_idx = 0
        if eval_latent_a_to_latent or eval_latent_b_to_latent:
            self.latent_layer_idx = self._get_latent_layer_idx(n_ch)
        self.loss_func = F.l1_loss if loss_func is None else loss_func
        
    def _get_latent_layer_idx(self, n_ch:int):
        """Index of the generator child that outputs the latent code.

        Runs a random 64x64 probe through the generator and returns the index of the
        last child layer before the output spatial size first grows, i.e. before the
        first upsampling layer.
        """
        probe_sz = 64
        param = next(self.generator.parameters(), None)
        device = param.device if param is not None else None
        was_training = self.generator.training
        # eval + no_grad: the probe only inspects output shapes. It must not update
        # BatchNorm running stats with random noise or build an autograd graph.
        self.generator.eval()
        try:
            with hook_outputs(self.generator) as all_hooks, torch.no_grad():
                self.generator(torch.rand(2, n_ch, probe_sz, probe_sz, device=device))
                last_spatial_size = probe_sz
                latent_code_layer_idx = 0
                for i, h in enumerate(all_hooks):
                    if h.stored.shape[-1] > last_spatial_size:
                        break
                    last_spatial_size = h.stored.shape[-1]
                    latent_code_layer_idx = i
        finally:
            self.generator.train(was_training)
        return latent_code_layer_idx
        
    def after_loss(self):
        "Add the enabled reconstruction terms to `learn.loss_grad` (generator steps only)."
        if not self.gan_trainer.gen_mode: return
        eval_any_latent = self.eval_latent_a_to_latent or self.eval_latent_b_to_latent
        # detach=False: fastai's `hook_output` detaches the stored activations by
        # default, which would stop the latent losses from producing any gradient.
        mid_hook = (hook_output(self.generator[self.latent_layer_idx], detach=False)
                    if eval_any_latent else None)
        try:
            fake_b = None
            if self.eval_real_b_to_real:
                fake_b = self.generator(self.y)
                real_to_real_loss = self.loss_func(fake_b, self.y) * self.weights.real_to_real
                self.learn.loss_grad += real_to_real_loss
                # Store result inside learn.loss_func to make it visible to metrics display
                self.learn.loss_func.real_rec_loss = real_to_real_loss

            # Order matters: this term must run before the `latent_a` one because it
            # reuses G(y), and the latent captured while computing it, when the
            # real_to_real term already produced them.
            if self.eval_latent_b_to_latent:
                if fake_b is None:
                    fake_b = self.generator(self.y)
                latent_b = mid_hook.stored
                # TODO: running only the encoder on `fake_b` would be cheaper.
                self.generator(fake_b)
                latent_b_rec = mid_hook.stored
                latent_b_rec_loss = self.loss_func(latent_b, latent_b_rec) * self.weights.latent_b_to_latent
                self.learn.loss_grad += latent_b_rec_loss
                # Store result inside learn.loss_func to make it visible to metrics display
                self.learn.loss_func.latent_b_rec_loss = latent_b_rec_loss

            if self.eval_latent_a_to_latent:
                fake_a = self.generator(self.x)
                latent_a = mid_hook.stored
                self.generator(fake_a)
                latent_a_rec = mid_hook.stored
                latent_a_rec_loss = self.loss_func(latent_a, latent_a_rec) * self.weights.latent_a_to_latent
                self.learn.loss_grad += latent_a_rec_loss
                # Store result inside learn.loss_func to make it visible to metrics display
                self.learn.loss_func.latent_a_rec_loss = latent_a_rec_loss
        finally:
            # Always remove the hook, even if one of the forward passes raises.
            if mid_hook is not None:
                mid_hook.remove()


class DummyGen(nn.Module):
    """Test generator: a stack of bias-free convolutions whose weights are all ones.

    Layer `i` maps `n_ftrs[i]` to `n_ftrs[i+1]` channels with kernel size `k_szs[i]`,
    stride `strides[i]` and padding `paddings[i]`. It is an `nn.ConvTranspose2d` when
    `transpose[i]` is truthy and an `nn.Conv2d` otherwise. Iterating or indexing the
    generator gives access to the individual convolutions.
    """
    def __init__(self, n_ftrs, k_szs, strides, paddings, transpose):
        super().__init__()
        layers = []
        layer_specs = zip(n_ftrs[:-1], n_ftrs[1:], k_szs, strides, paddings, transpose)
        for in_ch, out_ch, ks, stride, pad, is_transposed in layer_specs:
            conv_cls = nn.ConvTranspose2d if is_transposed else nn.Conv2d
            layer = conv_cls(in_ch, out_ch, kernel_size=ks, stride=stride, padding=pad, bias=False)
            layers.append(layer)
            nn.init.constant_(layer.weight, 1)
        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        return self.convs(x)

    def __iter__(self):
        return iter(self.convs)

    def __getitem__(self, idx):
        return self.convs[idx]

class DummyCritic(nn.Module):
    """Test critic that ignores its input and predicts 0 for every sample."""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        batch_sz = x.shape[0]
        return torch.zeros(batch_sz, 1).requires_grad_(True)
    

def test_rec_loss():        
    """Sanity check of `ReconstructionLossCallback` with all-ones convolutional generators.

    Fits a WGAN learner for one epoch (generator step first). Its critic always predicts
    0, and its data are constant 4x4 images (x full of 1.0, y full of 0.2). The first
    recorded loss is compared with hand-computed reconstruction terms. The test also
    checks latent layer detection on a shallow and a deeper `DummyGen`.
    """
    n_ch = 3
    mid_ftrs = 6
    resample_ks = 4
    dblock = DataBlock(blocks=(ImageBlock, ImageBlock),
                       # float tensor full of 1.'s
                       get_x=lambda fn: np.full((4, 4, n_ch), 255, dtype=np.uint8),
                       splitter = IndexSplitter([]),
                       # float tensor full of 0.2's
                       get_y=lambda fn: np.full((4, 4, n_ch), 51, dtype=np.uint8))
    # Two dummy items with bs=2 -> a single training batch. Local `dls` shadows the global one.
    dls = dblock.dataloaders(['', ''], bs=2)
    # Conv 4x4/stride 2 (4x4 -> 1x1 latent) followed by transposed conv back to 4x4.
    gen = DummyGen([n_ch, mid_ftrs, n_ch], [resample_ks, resample_ks], 
                   [2, 2], [0, 0], [False, True])
    crit = DummyCritic()    
    weights = ReconstructionLossWeights(1, 2, 3)
    rec_loss_cb = ReconstructionLossCallback(gen, weights)
    learn = GANLearner.wgan(dls, gen, crit, cbs=[rec_loss_cb], gen_first=True) #metrics = [...]
    learn.fit(1)
    
    # Latent layer should be the 3rd conv (idx 2): the spatial size only grows at the last one.
    deeper_gen = DummyGen([n_ch, mid_ftrs, mid_ftrs, mid_ftrs, n_ch], [resample_ks, 3, 3, resample_ks], 
                          [2, 1, 1, 2], [0, 0, 0, 0], [False, False, False, True])
    rec_loss_cb_deeper_gen = ReconstructionLossCallback(deeper_gen, weights)
    
    
    # Expected results are almost hardcoded in order to avoid repeating potential coding errors
    # from test code, although we are unnecessarily testing DummyGen at the same time.
    # If input=torch.full((bs, n_ch, 4, 4), item_val) ...
    #   -After first conv, out=torch.full((bs, 6, 1, 1), item_val*n_ch*(resample_ks**2))
    #   -After second conv, out=torch.full((bs, 3, 4, 4), (item_val*n_ch*(resample_ks**2))*mid_ftrs
    #   -After first conv, second forward, out=torch.full((bs, 6, 1, 1), ((item_val*n_ch*(resample_ks**2))*mid_ftrs)*n_ch*(resample_ks**2))
    real_out_values = 0.2 * mid_ftrs * n_ch * resample_ks**2
    expected_real_rec_loss = weights.real_to_real * abs(real_out_values - 0.2)    
    
    # Latent code obtained passing real target (y) as input
    latent_b_values = 0.2 * n_ch * resample_ks**2
    latent_b_rec_values = 0.2 * mid_ftrs * n_ch**2 * resample_ks**4
    expected_latent_b_rec_loss = weights.latent_b_to_latent * abs(latent_b_rec_values - latent_b_values)
    
    # Latent code obtained passing real input (x) as input
    latent_a_values = n_ch * resample_ks**2
    latent_a_rec_values = mid_ftrs * n_ch**2 * resample_ks**4
    expected_latent_a_rec_loss = weights.latent_a_to_latent * abs(latent_a_rec_values - latent_a_values)

    # A more concise but also error prone form would be:
#     y = torch.full((bs, n_ch, 4, 4), 0.2)
#     real_out = gen(y)
#     expected_real_rec_loss = weights.real_to_real * F.l1_loss(real_out, y)
#     latent_b = gen[0](y)
#     latent_b_rec = gen[0](gen[1](latent_b)) 
#     expected_latent_b_rec_loss = weights.latent_b_to_latent * F.l1_loss(latent_b_rec, latent_b)
#     x = torch.ones(bs, n_ch, 4, 4)
#     latent_a = gen[0](x)
#     latent_a_rec = gen[0](gen[1](latent_a))
#     expected_latent_a_rec_loss = weights.latent_a_to_latent * F.l1_loss(latent_a_rec, latent_a)
    
    # The critic always predicts 0, so the adversarial part of the generator loss is
    # presumably 0 and only the reconstruction terms remain.
    # NOTE(review): the recorder tracks `learn.loss`; confirm that the callback's additions
    # to `learn.loss_grad` are reflected there in the fastai version used.
    expected_loss = expected_real_rec_loss + expected_latent_b_rec_loss + expected_latent_a_rec_loss
        
    assert rec_loss_cb._get_latent_layer_idx(n_ch) == 0
    assert rec_loss_cb_deeper_gen._get_latent_layer_idx(n_ch) == 2
    assert math.isclose(learn.recorder.losses[0], expected_loss, rel_tol=1e-5)

In [None]:
# Sanity check of the reconstruction loss callback (runs a 1-epoch fit on dummy data).
test_rec_loss()

------------------

# Training

In [None]:
def predict_n(learner, n_imgs, max_bs=64):
    """Generate `n_imgs` predictions with `learner` and return them decoded.

    The test items are dummy paths. The inputs come from the dataloaders' `get_x`
    (`get_random_input` in this notebook), so the item itself is ignored.

    Args:
        learner: learner whose `dls` provide the input pipeline and the decoding.
        n_imgs: number of images to generate.
        max_bs: batch size of the prediction DataLoader.

    Returns:
        The output of `learner.dls.decode_batch`: up to `n_imgs` `(input, prediction)`
        pairs of decoded images.
    """
    dummy_path = Path('')
    dl = learner.dls.test_dl([dummy_path] * n_imgs, bs=max_bs)
    inp, _, _, dec_imgs_t = learner.get_preds(dl=dl, with_input=True, with_decoded=True)
    # Decode with the learner's own dataloaders, not the notebook-global `dls`, so that
    # learners built on other DataLoaders (e.g. `aug_dls`) use their own pipeline.
    return learner.dls.decode_batch((inp,) + tuplify(dec_imgs_t), max_n=n_imgs)
    
def predict_show_n(learner, n_imgs, **predict_n_kwargs):
    """Plot `n_imgs` (input, prediction) pairs generated by `learner`, one pair per row.

    Extra keyword arguments are forwarded to `predict_n` (e.g. `max_bs`).
    """
    preds_batch = predict_n(learner, n_imgs, **predict_n_kwargs)
    # squeeze=False keeps `axs` 2-D; otherwise `axs[i][0]` fails when n_imgs == 1.
    _, axs = plt.subplots(n_imgs, 2, figsize=(6, n_imgs * 3), squeeze=False)
    for row_axs, (inp, pred_img) in zip(axs, preds_batch):
        inp.show(ax=row_axs[0])
        pred_img.show(ax=row_axs[1])
        
class SaveCheckpointsCallback(Callback):
    """Callback that saves the model every `save_cycle_len` epochs.

    Keeps its own epoch counter, which starts at `initial_epoch` and grows by one after
    every epoch. When the counter is a multiple of `save_cycle_len`, the learner is saved
    with `custom_save_model` as `f'{fn_prefix}_{counter}ep'` under `base_path`.

    Args:
        learn: unused, kept for backward compatibility. The learner that gets saved is
            `self.learn`, the one the callback is attached to with `add_cb`.
        fn_prefix: prefix of the checkpoint names.
        base_path: folder passed to `custom_save_model`.
        initial_epoch: counter value for the first epoch run with this callback.
        save_cycle_len: save every `save_cycle_len` epochs.
    """
    def __init__(self, learn, fn_prefix, base_path=Path('.'), initial_epoch=1,
                 save_cycle_len=1):
        self.fn_prefix = fn_prefix
        self.base_path = base_path
        # Not called `epoch`: assigning `self.epoch` would shadow `Learner.epoch`
        # (fastai warns about shadowing learner attributes).
        self.ckpt_epoch = initial_epoch
        self.save_cycle_len = save_cycle_len
        
    def after_epoch(self):
        if (self.ckpt_epoch % self.save_cycle_len) == 0:
            fn = f'{self.fn_prefix}_{self.ckpt_epoch}ep'
            # Save the learner this callback is attached to, not the notebook-global `learn`.
            custom_save_model(self.learn, fn, base_path=self.base_path)
        self.ckpt_epoch += 1
        
def save_preds(c_preds_tracker, filepath):
    """Write the tracked critic predictions to CSV at `filepath` and return the result of `to_csv`."""
    preds_df = c_preds_tracker.to_df()
    return preds_df.to_csv(filepath)
        
def plot_c_preds(c_preds_tracker):
    """Line plot of the tracked critic predictions for fake and real images, indexed by batch."""
    batch_idxs = range(len(c_preds_tracker.real_preds))
    series = (('Fake preds', c_preds_tracker.fake_preds),
              ('Real preds', c_preds_tracker.real_preds))
    for label, preds in series:
        ax = sns.lineplot(x=batch_idxs, y=preds.cpu(), label=label)
    ax.set_xlabel('Number of batches')
    ax.set_ylabel('Critic preds')

In [None]:
def set_inn_options(net, *, norm_cls=None, **inn_kwargs):
    """Replace, in place, every `nn.InstanceNorm2d` inside `net` by a new one built with `inn_kwargs`.

    Args:
        net: module to modify. If `net` itself is an `nn.InstanceNorm2d` it has no
            parent to be replaced in, so it is left untouched.
        norm_cls: factory called as `norm_cls(num_features, **inn_kwargs)`;
            defaults to fastai's `InstanceNorm`.
        **inn_kwargs: options forwarded to `norm_cls` (e.g. `affine=True`).
    """
    if norm_cls is None:
        norm_cls = InstanceNorm
    # Collect the targets first instead of replacing children while
    # `named_modules()` is still walking the module tree.
    targets = [(name, module) for name, module in net.named_modules()
               if name and isinstance(module, nn.InstanceNorm2d)]
    for name, module in targets:
        *parent_path, child_name = name.split('.')
        parent = net
        for accessor in parent_path:
            # Numeric names are positions inside containers such as nn.Sequential.
            parent = parent[int(accessor)] if accessor.isnumeric() else getattr(parent, accessor)
        new_module = norm_cls(module.num_features, **inn_kwargs)
        if child_name.isnumeric():
            parent[int(child_name)] = new_module
        else:
            setattr(parent, child_name, new_module)

In [None]:
@dataclass
class ADAConfig:
    """Options for adaptive discriminator augmentation (ADA) in `create_learner`."""
    # Forwarded to `AdaptiveAugmentsCallback` as `preds_above_0_overfit_threshold`.
    p_change_thresh: float = 0.6
    # Optional `f(ada_tfms) -> list of tfms` that selects the augments to use;
    # None means all of them (`ada_tfms.to_array()`).
    filter_tfms_to_array: Callable = None
    # Padding mode passed to `ADATransforms`.
    pad_mode: PadMode = PadMode.Reflection
    

def create_learner(for_inference=False, ada_conf=None, dblock=dblock, dls=dls, gp_w=1.,
                   latent_sz=100, g_norm=NormType.Instance, mid_mlp_depth=0,
                   n_extra_convs_by_c_res_block=0, g_skips=False, n_crit_iters=1, 
                   metrics=None, use_patch_critic=False):
    """Build the face2anime `GANLearner` (`Img2ImgGenerator` + residual critic) with an EMA generator.

    Args:
        for_inference: if True, skip the training-only callbacks (R1 penalty, ADA,
            critic prediction trackers). With ADA, the critic is still wrapped in a
            (transform-free) `TransformsLayer`, presumably so that checkpoints trained
            with ADA keep the same module structure when loaded.
        ada_conf: `ADAConfig` that enables ADA on the critic input; None disables it.
        dblock, dls: DataBlock/DataLoaders of the learner; `dblock` is also passed to
            `add_ema_to_gan_learner`. The defaults are bound when this cell runs.
        gp_w: weight of the R1 gradient penalty.
        latent_sz: latent size passed to `default_encoder`.
        g_norm: normalization of the generator encoder and decoder.
        mid_mlp_depth, g_skips: forwarded to `Img2ImgGenerator` (`mid_mlp_depth`, `skip_connect`).
        n_extra_convs_by_c_res_block: extra convs per residual block of the critic.
        n_crit_iters: critic steps per generator step (`FixedGANSwitcher`).
        metrics: forwarded to `GANLearner`; only computed on the training set.
        use_patch_critic: use `patch_res_critic` instead of `res_critic`.

    Returns:
        The `GANLearner`. When `for_inference` is False it also carries
        `crit_preds_tracker`, the tracker of all critic predictions.
    """
    use_ada = ada_conf is not None
    leakyReLU02 = partial(nn.LeakyReLU, negative_slope=0.2)
    down_op = ConvHalfDownsamplingOp2d(ks=4, act_cls=leakyReLU02, bn_1st=False,
                                       norm_type=NormType.Batch)
    # presumably the downsampling op of the identity/shortcut path of the critic res blocks
    id_down_op = ConcatPoolHalfDownsamplingOp2d(conv_ks=3, act_cls=None, norm_type=None)
    crit_args = [img_size, n_channels, down_op, id_down_op]
    # patch_res_critic takes an extra 3rd positional arg (presumably the output patch size — TODO confirm)
    if use_patch_critic: crit_args.insert(2, img_size//8)
    crit_kwargs = dict(n_extra_convs_by_res_block=n_extra_convs_by_c_res_block, 
                       act_cls=leakyReLU02, bn_1st=False, n_features=128, 
                       flatten_full=True)
    crit_builder = patch_res_critic if use_patch_critic else res_critic
    base_critic = crit_builder(*crit_args, **crit_kwargs)
    # With ADA, `critic` is assigned further down, once the augment transforms are known.
    if not use_ada: critic = base_critic
    
    # NOTE(review): `default_encoder` / `default_decoder` are not imported in any visible
    # cell (only Img2ImgGenerator, patch_res_critic and res_critic come from
    # face2anime.networks); confirm where they come from, otherwise this raises NameError.
    def _decoder_builder(imsz, nch, latsz, hooks_by_sz=None): 
        return default_decoder(imsz, nch, latsz, norm_type=g_norm, hooks_by_sz=hooks_by_sz)
    generator = Img2ImgGenerator(img_size, n_channels, mid_mlp_depth=mid_mlp_depth, skip_connect=g_skips,
                                 encoder=default_encoder(img_size, n_channels, latent_sz, norm_type=g_norm),
                                 decoder_builder=_decoder_builder)
    
    cbs = []
    # Objects that intercept the critic loss calls (wrapped in `LossWrapper` below).
    c_loss_interceptors = []
    tfms_array = []
    
    if not for_inference:
        # Pass base_critic to avoid grid_sample 2nd order derivative issue with ADA critic
        cbs.append(R1GANGPCallback(weight=gp_w, critic=base_critic))
        #cbs.append(ContentLossCallback(weight=content_loss_w, ftrs_calc=ftrs_calc, device=device))
        #cbs.append(ReconstructionLossCallback(generator, rec_loss_weights, **rec_loss_cb_kwargs))
        if use_ada:
            # First arg presumably the initial augmentation probability (0), adapted by the callback.
            ada_tfms = ADATransforms(0., (img_size, img_size), pad_mode=ada_conf.pad_mode)
            tfms_array = (ada_conf.filter_tfms_to_array(ada_tfms) if ada_conf.filter_tfms_to_array is not None
                          else ada_tfms.to_array())
            ada_crit_preds_tracker = CritPredsTracker(reduce_batch=False)
            ada_cb = AdaptiveAugmentsCallback(ada_tfms, ada_crit_preds_tracker,
                                              preds_above_0_overfit_threshold=ada_conf.p_change_thresh)
            cbs.append(ada_cb)
            c_loss_interceptors.append(ada_crit_preds_tracker)
        overall_crit_preds_tracker = CritPredsTracker(reduce_batch=True)
        c_loss_interceptors.append(overall_crit_preds_tracker)
       
    # Also done for inference (with an empty `tfms_array`) so the critic has the same structure.
    if use_ada:
        critic = nn.Sequential(TransformsLayer(setup_aug_tfms(tfms_array)),
                               base_critic)       
    
    # The generator gets no extra (non-adversarial) loss term here.
    def gen_loss_func(*args): return 0
    crit_loss_func = nn.BCEWithLogitsLoss()
    loss_G, loss_C = gan_loss_from_func(gen_loss_func, crit_loss_func)
    loss_C = LossWrapper(loss_C, c_loss_interceptors)
    
    # fastai's Adam with mom=0 / sqr_mom=0.99 (beta1=0, beta2=0.99) and no weight decay.
    learn = GANLearner(dls, generator, critic, loss_G, loss_C,
                       opt_func=partial(Adam, mom=0., sqr_mom=0.99, wd=0.),
                       cbs=cbs, switcher=FixedGANSwitcher(n_crit=n_crit_iters, n_gen=1),
                       switch_eval=False, metrics=metrics)
    #metrics=LossMetric('content_loss')
    #metrics=LossMetrics(['real_rec_loss', 'latent_a_rec_loss', 'latent_b_rec_loss'])
    learn.recorder.train_metrics=True
    learn.recorder.valid_metrics=False
    # Uses the notebook-global `main_path`.
    add_ema_to_gan_learner(learn, dblock, main_path, decay=0.999)
    if not for_inference: learn.crit_preds_tracker = overall_crit_preds_tracker
    return learn

## TR 1: NSGAN-R1GP loss, SN+BN critic, BN+SN gen (both encoder and decoder) with skip connections, mid MLP of depth 2

### TR 1a: global critic

In [None]:
def create_learner_1(*args, **kwargs):
    """TR 1a learner: global critic, BatchNorm generator with mid MLP (depth 2) and skips, R1 weight 10."""
    tr1_opts = dict(gp_w=10., g_norm=NormType.Batch, mid_mlp_depth=2, g_skips=True,
                    n_extra_convs_by_c_res_block=1, n_crit_iters=3)
    return create_learner(*args, **kwargs, **tr1_opts)

In [None]:
# TR 1a: build the learner, save a checkpoint every `save_cycle_len` epochs and wrap the
# EMA generator in a plain Learner (dummy loss) so that `show_results` can be called on it.
learn = create_learner_1()
learn.add_cb(SaveCheckpointsCallback(learn, 'refined_arch_face2anime_tr1', initial_epoch=1,
                                     save_cycle_len=save_cycle_len))
ema_g_learn = Learner(dls, learn.ema_model, loss_func=lambda *args: torch.tensor(0.))
lr = 2e-4

In [None]:
# custom_load_model(learn, 'refined_arch_face2anime_tr1_100ep', base_path='../input/refined-arch-face2anime/', with_ema=True)
# #preds_df = pd.read_csv(io.StringIO(preds_csv_str), index_col=0)
# #learn.crit_preds_tracker.load_from_df(preds_df, device)
# with learn.removed_cbs([learn.save_checkpoints]) as displayable_learn:
#     displayable_learn.show_results(ds_idx=0)

In [None]:
learn.fit(100, lr)
# The checkpoint callback is removed while showing results (presumably so no checkpoint is saved there).
with learn.removed_cbs([learn.save_checkpoints]) as displayable_learn:
    displayable_learn.show_results(ds_idx=0)

In [None]:
# Sample results of the EMA generator.
ema_g_learn.show_results(ds_idx=0)

In [None]:
# Critic predictions on real vs fake images during training.
plot_c_preds(learn.crit_preds_tracker)

In [None]:
# Keep the critic predictions for later analysis.
save_preds(learn.crit_preds_tracker, Path('crit_preds_face2anime_refined_tr1_100ep.csv'))

### TR 1b: patch critic

In [None]:
def create_learner_1b(*args, **kwargs):
    """TR 1b learner: TR 1 generator settings with a patch critic without extra convs per res block."""
    tr1b_opts = dict(gp_w=10., g_norm=NormType.Batch, mid_mlp_depth=2, g_skips=True,
                     n_extra_convs_by_c_res_block=0, n_crit_iters=3, use_patch_critic=True)
    return create_learner(*args, **kwargs, **tr1b_opts)

In [None]:
# TR 1b: same setup as TR 1a, with a patch critic.
learn = create_learner_1b()
learn.add_cb(SaveCheckpointsCallback(learn, 'refined_arch_face2anime_tr1b', initial_epoch=1,
                                     save_cycle_len=save_cycle_len))
ema_g_learn = Learner(dls, learn.ema_model, loss_func=lambda *args: torch.tensor(0.))
lr = 2e-4

In [None]:
learn.fit(100, lr)
# Show results without the checkpoint callback attached.
with learn.removed_cbs([learn.save_checkpoints]) as displayable_learn:
    displayable_learn.show_results(ds_idx=0)

In [None]:
# Sample results of the EMA generator.
ema_g_learn.show_results(ds_idx=0)

In [None]:
# Critic predictions on real vs fake images during training.
plot_c_preds(learn.crit_preds_tracker)

In [None]:
# Keep the critic predictions for later analysis.
save_preds(learn.crit_preds_tracker, Path('crit_preds_face2anime_refined_tr1b_100ep.csv'))

## TR2 [TR1 + ADA]: NSGAN-R1GP loss, SN+BN critic, BN+SN gen (both encoder and decoder) with skip connections, mid MLP of depth 2, ADA, random flips of the training batches

In [None]:
def filter_spatial_minus_flip_rot_small(ada_tfms):
    """Select the ADA affine transforms except `Rotate` and `Flip` ones, plus `ada_tfms.rotate_90x`.

    Intended as `ADAConfig.filter_tfms_to_array`; returns a list.
    """
    def _keep(tfm):
        is_other_affine = isinstance(tfm, AffineCoordTfm) and not isinstance(tfm, (Rotate, Flip))
        return is_other_affine or tfm == ada_tfms.rotate_90x
    return [tfm for tfm in ada_tfms.to_array() if _keep(tfm)]

# Same pipeline plus random batch flips (Flip(0.5): presumably flip probability 0.5), used by TR 2.
aug_dblock = get_dblock(extra_batch_tfms=[Flip(0.5)])
aug_dls = aug_dblock.dataloaders(main_path, path=main_path, bs=bs)


def create_learner_2(*args, **kwargs):
    """TR 2a learner: TR 1a settings plus ADA (selected affine augments) and flip-augmented data."""
    ada_conf = ADAConfig(filter_tfms_to_array=filter_spatial_minus_flip_rot_small,
                         p_change_thresh=0.8)
    tr2_opts = dict(ada_conf=ada_conf, dblock=aug_dblock, dls=aug_dls, gp_w=10.,
                    g_norm=NormType.Batch, mid_mlp_depth=2, g_skips=True,
                    n_extra_convs_by_c_res_block=1, n_crit_iters=3)
    return create_learner(*args, **tr2_opts, **kwargs)

### TR 2a: global critic

In [None]:
# TR 2a. The EMA learner is built on the non-augmented `dls`, so its shown inputs are not flipped.
learn = create_learner_2()
learn.add_cb(SaveCheckpointsCallback(learn, 'refined_arch_face2anime_tr2', initial_epoch=1,
                                     save_cycle_len=save_cycle_len))
ema_g_learn = Learner(dls, learn.ema_model, loss_func=lambda *args: torch.tensor(0.))
lr = 2e-4

In [None]:
learn.fit(100, lr)
# NOTE(review): no callback named `loss_store` is added in any visible cell; presumably it is
# added by a face2anime helper (e.g. `add_ema_to_gan_learner`) — confirm.
with learn.removed_cbs([learn.save_checkpoints, learn.loss_store]) as displayable_learn:
    displayable_learn.show_results(ds_idx=0)

In [None]:
# Critic predictions on real vs fake images during training.
plot_c_preds(learn.crit_preds_tracker)

In [None]:
# History of the ADA augmentation probability (presumably one value per update of `p`).
sns.lineplot(x=range(len(learn.adaptive_augments.p_history)), y=learn.adaptive_augments.p_history)

In [None]:
# Keep the critic predictions and the ADA probability history for later analysis.
save_preds(learn.crit_preds_tracker, Path('crit_preds_face2anime_refined_tr2_100ep.csv'))
# Written from Python rather than with `!echo {...}`: pasting the list repr into a shell
# command breaks on characters such as parentheses and on very long histories.
Path('p_history_face2anime_refined_tr2_100ep.txt').write_text(f'{learn.adaptive_augments.p_history}\n')

### TR 2b: patch critic

In [None]:
def create_learner_2b(*args, **kwargs):
    """TR 2b learner: TR 2a settings (ADA, flip-augmented data) with a patch critic.

    Named `create_learner_2b`, not `create_learner_2`, so that it does not shadow the
    TR 2a builder and so that `create_learner_2b()` below resolves. `eval_models`
    derives the checkpoint id from the name suffix, so it looks for
    `refined_arch_face2anime_tr2b_*` checkpoints.
    """
    return create_learner(*args, 
                          ada_conf=ADAConfig(filter_tfms_to_array=filter_spatial_minus_flip_rot_small, 
                                             p_change_thresh=0.8), 
                          dblock=aug_dblock,
                          dls=aug_dls,
                          gp_w=10., 
                          g_norm=NormType.Batch, 
                          mid_mlp_depth=2,
                          g_skips=True, 
                          n_extra_convs_by_c_res_block=1, 
                          n_crit_iters=3,
                          use_patch_critic=True,
                          **kwargs)

In [None]:
# TR 2b: TR 2 with a patch critic.
learn = create_learner_2b()
learn.add_cb(SaveCheckpointsCallback(learn, 'refined_arch_face2anime_tr2b', initial_epoch=1,
                                     save_cycle_len=save_cycle_len))
ema_g_learn = Learner(dls, learn.ema_model, loss_func=lambda *args: torch.tensor(0.))
lr = 2e-4

In [None]:
learn.fit(100, lr)
# NOTE(review): `learn.loss_store` is not added in any visible cell — confirm its origin.
with learn.removed_cbs([learn.save_checkpoints, learn.loss_store]) as displayable_learn:
    displayable_learn.show_results(ds_idx=0)

In [None]:
# Critic predictions on real vs fake images during training.
plot_c_preds(learn.crit_preds_tracker)

In [None]:
# History of the ADA augmentation probability.
sns.lineplot(x=range(len(learn.adaptive_augments.p_history)), y=learn.adaptive_augments.p_history)

In [None]:
# Keep the critic predictions and the ADA probability history for later analysis.
save_preds(learn.crit_preds_tracker, Path('crit_preds_face2anime_refined_tr2b_100ep.csv'))
# Written from Python rather than with `!echo {...}`: pasting the list repr into a shell
# command breaks on characters such as parentheses and on very long histories.
Path('p_history_face2anime_refined_tr2b_100ep.txt').write_text(f'{learn.adaptive_augments.p_history}\n')

# Evaluation

In [None]:
# Folder where real/ and fake/ images are written for pytorch-fid, and how many images of each.
base_fid_samples_path = Path('/kaggle/working/fid_samples')
n_fid_imgs = 10000

def download_pytorch_fid_calculator():        
    "Install the `pytorch-fid` package (used below through `python -m pytorch_fid`)."
    #!git clone https://github.com/mseitzer/pytorch-fid.git
    # NOTE(review): unpinned install; consider pinning a version and using `%pip`.
    !pip install pytorch-fid

def create_fid_dirs(base_fid_samples_path):
    """Create `base_fid_samples_path` with empty `fake/` and `real/` subfolders.

    Raises `FileExistsError` if the folder already exists (the notebook deletes it beforehand).
    """
    base_fid_samples_path.mkdir()
    for subdir_name in ('fake', 'real'):
        (base_fid_samples_path / subdir_name).mkdir()
        
def save_real_imgs(dls, n_imgs=10000, use_input_ds=False):
    """Save `n_imgs` decoded real images to `base_fid_samples_path/'real'` as JPEG files.

    Calls `dls.one_batch()` repeatedly until `n_imgs` images have been written. Files
    are named from `{n_imgs-1}.jpg` down to `0.jpg`.

    Args:
        dls: DataLoaders returning (input, target) batches.
        n_imgs: number of images to write.
        use_input_ds: write the inputs (x) instead of the targets (y).
    """
    img_pos_in_sample = 0 if use_input_ds else 1
    n_left = n_imgs
    while n_left > 0:
        batch = dls.one_batch()
        batch_sz = batch[1].size()[0]
        decoded_batch = dls.decode_batch(batch, max_n=batch_sz)
        for sample_idx in range(min(batch_sz, n_left)):
            img = PILImage.create(decoded_batch[sample_idx][img_pos_in_sample])
            n_left -= 1
            img.save(base_fid_samples_path/f'real/{n_left}.jpg')
            
def save_fake_imgs(learner, n_imgs=10000, **predict_n_kwargs):
    """Generate `n_imgs` images with `learner` and save them as `{i}.jpg` in `base_fid_samples_path/'fake'`.

    Extra keyword arguments are forwarded to `predict_n`.
    """
    fake_dir = base_fid_samples_path/'fake'
    generated = predict_n(learner, n_imgs, **predict_n_kwargs)
    for img_idx, (_, fake_img) in enumerate(generated):
        PILImage.create(fake_img).save(fake_dir/f'{img_idx}.jpg')

In [None]:
# Start from a clean FID folder (rm prints an error on the first run, when it doesn't exist yet).
!rm -R $base_fid_samples_path

In [None]:
# Install pytorch-fid and create the empty fid_samples/{fake,real} folders.
download_pytorch_fid_calculator()
create_fid_dirs(base_fid_samples_path)

In [None]:
def eval_models(builders, n_epochs, base_path='/kaggle/input/new-face2anime-weights', ema=False):
    assert is_iterable(builders) or is_iterable(n_epochs)
    if not is_iterable(builders): 
        builders = [builders] * len(list(n_epochs))
    if not is_iterable(n_epochs): 
        n_epochs = [n_epochs] * len(list(builders))
    for builder, n_ep in zip(builders, n_epochs):
        model_id = builder.__name__.split('_')[-1]
        learner = builder(for_inference=True)
        custom_load_model(learner, f'refined_arch_face2anime_tr{model_id}_{n_ep}ep', with_opt=False,
                          base_path=base_path)
        if ema: 
            learner = Learner(learner.dls, learner.ema_model,
                              loss_func=lambda *args: torch.tensor(0.))
        save_fake_imgs(learner, n_imgs=n_fid_imgs)
        print(f'---- {model_id}, after {n_ep} epochs ----')
        !python -m pytorch_fid {base_fid_samples_path/'fake'} {base_fid_samples_path/'real'}

In [None]:
# Real reference images for FID: targets (anime faces) by default.
save_real_imgs(dls, n_fid_imgs)

In [None]:
# FID of the TR 1a checkpoints saved every 5 epochs (5..100), loaded from ./models.
eval_models(create_learner_1, range(5, 101, 5), base_path='./models')

In [None]:
# Same checkpoints, evaluating the EMA generator.
eval_models(create_learner_1, range(5, 101, 5), base_path='./models', ema=True)

In [None]:
# Inspect one generated sample (the last file written when n_fid_imgs == 10000).
PILImage.create(base_fid_samples_path/'fake/9999.jpg')

## Reference FID

With 10000 images:

* FID input ds vs itself (CelebA vs CelebA) ~ 2.8
* FID target ds vs itself (Animecharacterfaces vs Animecharacterfaces) ~ 4.0