In [None]:
from dataclasses import dataclass
from fastai.vision.all import *
from fastai.vision.gan import *
from functools import partial
import gc
import itertools
import numpy as np 
import pandas as pd
from pathlib import Path
import seaborn as sns
import subprocess
from sys import platform
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable, Tuple, Union

In [None]:
# Set to True when running outside the repo (e.g. on Kaggle): the library is then cloned next to the notebook.
run_as_standalone_nb: bool = False

In [None]:
if run_as_standalone_nb:
    root_lib_path = Path('face2anime').resolve()
    if not root_lib_path.exists():
        !git clone https://github.com/davidleonfdez/face2anime.git
    if str(root_lib_path) not in sys.path:
        sys.path.insert(0, str(root_lib_path))
else:
    import local_lib_import

In [None]:
from face2anime.gen_utils import is_iterable
from face2anime.layers import (ConcatPoolHalfDownsamplingOp2d, ConvHalfDownsamplingOp2d, FeatureStatType, 
                               TransformsLayer, ParentNetSource, ResBlockDown)
from face2anime.losses import (ContentLossCallback, CritPredsTracker, CycleConsistencyLossCallback, 
                               CrossIdentityLossCallback, CycleGANLoss, IdentityLossCallback, LossWrapper, 
                               MultiCritPredsTracker, R1GANGPCallback)
from face2anime.misc import FeaturesCalculator
from face2anime.networks import (CycleCritic, CycleGenerator, default_decoder, default_encoder, 
                                 Img2ImgGenerator, PatchResCritic, res_critic)
from face2anime.plot import plot_c_preds, plot_multi_c_preds
from face2anime.train_utils import (add_ema_to_gan_learner, clean_mem, custom_load_model,
                                    custom_save_model, EpochFilterAll, SaveCheckpointsCallback,
                                    UpdateEMAPreSaveAction)
from face2anime.transforms import AdaptiveAugmentsCallback, ADATransforms

In [None]:
# Global settings shared by the cells below.
img_size, n_channels = 64, 3
bs = 64
# Number of epochs between two saved checkpoints.
save_cycle_len = 5
device = 'cpu' if not torch.cuda.is_available() else 'cuda'

# Data

You need to use two different datasets, with each of them containing images from one specific domain.

## Domain A ds

`input_a_fns` must be set to a list that contains the paths of the training images that belong to domain A. For example:

In [None]:
celeba_path = Path('/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba')
# The folder is listed directly instead of using `get_image_files`, which is too slow here;
# no extension check is needed for this folder.
input_a_fns = celeba_path.ls()
input_a_fns

## Domain B ds

The value of `anime_ds_path` must be set to the path of the parent folder of the training images that belong to domain B. 

In this example, the dataset used is "animecharacterfaces", by Kaggle user *aadilmalik94*.

In [None]:
anime_ds_path = (Path('/kaggle/input/animecharacterfaces') / 'animeface-character-dataset' / 'data').resolve()

The dataset path passed to `dblock.dataloaders()` or `ImageDataLoaders.from_dblock()` will be forwarded
to `get_items`, which will return a list of items, usually a list of image paths if `get_items=get_image_files`.

So, for each item, we are expected to receive a filename `fn` and be able to
derive x and y from it, with `get_x(fn)` and `get_y(fn)`.

For bidirectional unpaired image to image translation, we can:
* Use the domain B ds path as the DataBlock `source`. 
* Load independently the filenames of the domain A ds; let's call it `input_a_fns`
* `get_y` needs two functions:
  * The first one can just return the path received (domain B).
  * The second returns a random item from `input_a_fns` (domain A).
* `get_x` also needs two functions:
  * The first one returns a random item from `input_a_fns` (domain A).
  * The second can just return the path received (domain B).
* `get_x` and `get_y` are called every time a data item is used; so, by picking the domain A image at random, we ensure that no x is tied to a fixed y, i.e., the same (x, y) pairs won't be used together for loss calculation in every epoch.
* In this case, domain A: human faces; domain B: anime faces.


In [None]:
class Domain(Enum):
    "The two image domains of the translation task (in this notebook: A = human faces, B = anime faces)."
    A = 0
    B = 1


def reverse_domain(domain):
    "Return the opposite domain; any value other than `Domain.A` maps to `Domain.A`."
    if domain == Domain.A:
        return Domain.B
    return Domain.A


def get_random_fn_a(fn):
    "Ignore `fn` and return a uniformly random item of `input_a_fns` (domain A)."
    n_fns = len(input_a_fns)
    picked_idx = random.randrange(n_fns)
    return input_a_fns[picked_idx]


def get_input_dependant_fn_a(fn):
    """Deterministically map `fn` to an item of `input_a_fns` (domain A).

    A CRC32 of `str(fn)` is used instead of the builtin `hash`, because `hash` of str/Path
    objects is salted per process (PYTHONHASHSEED). With `hash`, the fn -> domain A pairing
    would change in every session.
    """
    import zlib
    stable_hash = zlib.crc32(str(fn).encode('utf-8'))
    return input_a_fns[stable_hash % len(input_a_fns)]


# Per-channel normalization with mean 0.5 and std 0.5; its `decode` is used to display images.
normalize_tf = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5]))


def get_dblock(extra_batch_tfms=None):
    """Build the `DataBlock` for bidirectional unpaired image-to-image translation.

    Every item is a domain B filename (collected by `get_image_files`) and yields the tuple
    (x1, x2, y1, y2) = (random domain A image, the domain B image, the domain B image,
    another random domain A image). `n_inp=2`, so x1 and x2 are the model inputs. There is no
    validation set (`IndexSplitter([])`).

    Args:
        extra_batch_tfms: optional iterable (list, tuple, `L`...) of batch transforms,
            applied after `normalize_tf`.
    """
    batch_tfms = [normalize_tf]
    if extra_batch_tfms is not None:
        # `list()` accepts any iterable and never mutates the caller's collection.
        batch_tfms += list(extra_batch_tfms)
    return DataBlock(blocks=(ImageBlock, ImageBlock, ImageBlock, ImageBlock),
                     get_x=[get_random_fn_a, noop],
                     get_y=[noop, get_random_fn_a],
                     get_items=get_image_files,
                     splitter=IndexSplitter([]),
                     item_tfms=Resize(img_size, method=ResizeMethod.Crop),
                     batch_tfms=batch_tfms,
                     n_inp=2)


main_path = anime_ds_path
dblock = get_dblock()
# Items are listed from the domain B folder; domain A images are drawn inside the data block.
dls = dblock.dataloaders(main_path, path=main_path, bs=bs)

In [None]:
sample_batch = dls.one_batch()
titles = ['x1 (A)', 'x2 (B)', 'y1 (B)', 'y2 (A)']
_, axs = plt.subplots(1, 4)
# Undo the normalization and show the first image of each of the four batch elements.
for ax, title, batch_t in zip(axs, titles, sample_batch):
    first_img = normalize_tf.decode(batch_t)[0]
    first_img.show(ax=ax, title=title)

------------------

# Helper methods

In [None]:
def predict_n(learner, inputs_idxs:Union[int, Tuple[int, int]], max_bs=64):
    dummy_path = Path('.')
    items = learner.dls.train.items
    if isinstance(inputs_idxs, int):
        n_imgs = inputs_idxs
        ini_idx = 0
        end_idx = n_imgs
    else:
        ini_idx, end_idx = inputs_idxs
        n_imgs = end_idx - ini_idx
    if len(items) < end_idx: items = list(itertools.islice(itertools.cycle(items), ini_idx, end_idx))
    dl = learner.dls.test_dl(items[:n_imgs], bs=max_bs)   
    inp, _, _, dec_imgs_t = learner.get_preds(dl=dl, with_input=True, with_decoded=True)
    dec_batch = dls.decode_batch(inp + dec_imgs_t, max_n=n_imgs)
    return dec_batch


def predict_show_n(learner, n_imgs, **predict_n_kwargs):
    """Predict `n_imgs` items with `learner` and plot them in an `n_imgs` x 4 grid.

    Each row shows: input A, its A->B translation, input B and its B->A translation.
    Extra keyword arguments are forwarded to `predict_n` (e.g. `max_bs`).
    """
    preds_batch = predict_n(learner, n_imgs, **predict_n_kwargs)
    # squeeze=False keeps `axs` 2-D even when n_imgs == 1.
    _, axs = plt.subplots(n_imgs, 4, figsize=(12, n_imgs * 3), squeeze=False)
    col_titles = ('In A', 'Out A->B', 'In B', 'Out B->A')
    for row_axs, (in_a, in_b, pred_a2b, pred_b2a) in zip(axs, preds_batch):
        for ax, img, title in zip(row_axs, (in_a, pred_a2b, in_b, pred_b2a), col_titles):
            img.show(ax=ax, title=title)
        
        
def save_preds(c_preds_tracker, filepaths):
    """Write each DataFrame returned by `c_preds_tracker.to_dfs()` to the matching path in `filepaths`.

    Returns the list of `DataFrame.to_csv` results, one per file written.
    """
    results = []
    for df, filepath in zip(c_preds_tracker.to_dfs(), filepaths):
        results.append(df.to_csv(filepath))
    return results

In [None]:
def _forward_batch(model, batch, device):
    input = batch[:2]
    if device is not None:
        for i in range(2): input[i] = input[i].to(device)
    model(*input)


def create_learner(for_inference=False, dblock=dblock, dls=dls, gp_w=10., latent_sz=100, 
                   mid_mlp_depth=2, g_norm=NormType.Instance, n_crit_iters=3,
                   n_extra_convs_by_c_res_block=1, cycle_cons_w=1., id_loss_w=1., 
                   cross_id_loss_w=0, use_patch_critic=True, ftrs_stats=FeatureStatType.MEAN, 
                   ftrs_stats_source=None):
    """Build the fastai `GANLearner` used for bidirectional unpaired translation (A <-> B).

    Two critics (wrapped in `CycleCritic`) and two `Img2ImgGenerator`s (wrapped in
    `CycleGenerator`, exposed as `g_a2b` / `g_b2a`) are created. The critic loss is
    `nn.BCEWithLogitsLoss`. The generator "content" loss passed to `gan_loss_from_func`
    always returns 0; the extra generator terms come from the loss callbacks added below.

    Args:
        for_inference: if True, the training-only callbacks (gradient penalty, cycle /
            identity / cross identity losses) and the critic predictions tracker are not
            created, and `learn.crit_preds_tracker` is not set.
        dblock, dls: defaults are bound, at definition time, to the module-level objects.
            `dblock` is forwarded to `add_ema_to_gan_learner`.
        gp_w: weight of the R1 gradient penalty (`R1GANGPCallback`).
        latent_sz: latent size used by the generators' encoder/decoder.
        mid_mlp_depth: forwarded to `Img2ImgGenerator`.
        g_norm: normalization type of the generators' encoder and decoder.
        n_crit_iters: critic iterations per generator iteration (`FixedGANSwitcher`).
        n_extra_convs_by_c_res_block: forwarded to the critics as `n_extra_convs_by_res_block`.
        cycle_cons_w, id_loss_w, cross_id_loss_w: loss weights; a value <= 0 disables the
            corresponding callback and its metric.
        use_patch_critic: use `PatchResCritic` if True, `res_critic` otherwise.
        ftrs_stats, ftrs_stats_source: only forwarded to the patch critic.

    Returns:
        The learner, with an EMA of the generator attached by `add_ema_to_gan_learner`
        (accessible as `learn.ema_model`).
    """
    leakyReLU02 = partial(nn.LeakyReLU, negative_slope=0.2)
    # Downsampling ops handed to the critic builder (main path and, judging by the name,
    # the identity/shortcut path of the residual blocks -- confirm in face2anime.layers).
    down_op = ConvHalfDownsamplingOp2d(ks=4, act_cls=leakyReLU02, bn_1st=False,
                                       norm_type=NormType.Batch)
    id_down_op = ConcatPoolHalfDownsamplingOp2d(conv_ks=3, act_cls=None, norm_type=None)
    crit_args = [img_size, n_channels, down_op, id_down_op]
    # The patch critic takes an extra 3rd positional argument (img_size//8; presumably the
    # patch grid size -- TODO confirm).
    if use_patch_critic: crit_args.insert(2, img_size//8)
    crit_kwargs = dict(n_extra_convs_by_res_block=n_extra_convs_by_c_res_block, 
                       act_cls=leakyReLU02, bn_1st=False, n_features=128, 
                       flatten_full=True)
    if use_patch_critic:
        # Options only passed to the patch critic.
        crit_kwargs['ftrs_stats'] = ftrs_stats
        crit_kwargs['ftrs_stats_source'] = ftrs_stats_source
        crit_kwargs['input_norm_tf'] = normalize_tf
        crit_kwargs['device'] = device
    crit_builder = PatchResCritic if use_patch_critic else res_critic  
    # Two critics with identical configuration, one per translation direction.
    base_critics = [crit_builder(*crit_args, **crit_kwargs) for _ in range(2)]
    base_critics = CycleCritic(*base_critics)
    critic = base_critics
    
    # Decoder factory handed to Img2ImgGenerator, so that the decoder uses `g_norm`.
    def _decoder_builder(imsz, nch, latsz, hooks_by_sz=None): 
        return default_decoder(imsz, nch, latsz, norm_type=g_norm, hooks_by_sz=hooks_by_sz)
    generators = [Img2ImgGenerator(img_size, n_channels, mid_mlp_depth=mid_mlp_depth, skip_connect=True,
                                   encoder=default_encoder(img_size, n_channels, latent_sz, norm_type=g_norm),
                                   decoder_builder=_decoder_builder)
                  for _ in range(2)]
    generator = CycleGenerator(*generators)
    
    cbs = []
    c_loss_interceptors = []
    metrics = []
    if not for_inference:
        # Pass base_critic to avoid grid_sample 2nd order derivative issue with ada critic
        cbs.append(R1GANGPCallback(weight=gp_w, critic=base_critics))
        if cycle_cons_w > 0: 
            cbs.append(CycleConsistencyLossCallback(generator.g_a2b, 
                                                    generator.g_b2a, 
                                                    weight=cycle_cons_w))
            metrics.append('cycle_loss')
        if id_loss_w > 0:
            cbs.append(IdentityLossCallback(generator.g_a2b, 
                                            generator.g_b2a, 
                                            weight=id_loss_w))
            metrics.append('identity_loss')
        if cross_id_loss_w > 0:
            cbs.append(CrossIdentityLossCallback(generator.g_a2b, 
                                                 generator.g_b2a, 
                                                 weight=cross_id_loss_w))
            metrics.append('cross_identity_loss')
        # Records the critic predictions seen by the critic loss (averaged per batch).
        overall_crit_preds_tracker = MultiCritPredsTracker(reduce_batch=True)
        c_loss_interceptors.append(overall_crit_preds_tracker)
        
    # Generator content loss is intentionally zero; see the docstring.
    def gen_loss_func(*args): return 0
    crit_loss_func = nn.BCEWithLogitsLoss()
    loss_G, loss_C = gan_loss_from_func(gen_loss_func, crit_loss_func)
    loss_C = LossWrapper(loss_C, c_loss_interceptors)
    
    learn = GANLearner(dls, generator, critic, loss_G, loss_C,
                       opt_func=partial(Adam, mom=0., sqr_mom=0.99, wd=0.),
                       cbs=cbs, switcher=FixedGANSwitcher(n_crit=n_crit_iters, n_gen=1),
                       switch_eval=False, metrics=LossMetrics(metrics) or None)
    learn.loss_func = CycleGANLoss(learn.loss_func)
    # There is no validation set (see get_dblock), so metrics are only recorded for training.
    learn.recorder.train_metrics=True
    learn.recorder.valid_metrics=False
    add_ema_to_gan_learner(learn, dblock, decay=0.999, forward_batch=_forward_batch)
    if not for_inference: learn.crit_preds_tracker = overall_crit_preds_tracker
    return learn

# Training

We begin by creating an instance of a fastai GANLearner object with the help of our custom `create_learner` method.

The main parameters you may need to tweak are:

* `use_patch_critic`: if False, a global critic is used instead of a patch critic. A global critic could be preferable when there's not a clear pixel correspondence between an image from one domain and its expected translation from the other domain.
* `gp_w`: gradient penalty strength.
* `g_norm`: type of normalization performed by the generator. The most common choice is NormType.Instance but NormType.Batch could produce even better results for EMA models.

In [None]:
def create_learner_1(*args, **kwargs):
    """`create_learner` with BatchNorm in the generators.

    The trailing `_1` of the name is the model id that `eval_models` uses to locate checkpoints.
    """
    return create_learner(*args, g_norm=NormType.Batch, **kwargs)


learn = create_learner_1()

An alternative model (a PyTorch nn.Module) that contains an EMA of the weights of the generator is stored in `learn.ema_model` and updated after every optimizer step. To handle it more easily, we should create a learner for it:

In [None]:
def _placeholder_loss(*args):
    # The EMA generator is never trained through this learner, so a constant dummy loss is supplied.
    return torch.tensor(0.)

ema_g_learn = Learner(dls, learn.ema_model, loss_func=_placeholder_loss)

This model will (almost always) generate higher quality images than the trained generator.

If you wish to automatically store checkpoints of your model in the middle of a training run, it can be done with `SaveCheckpointsCallback`. The parameter `save_cycle_len` lets you choose the number of epochs between checkpoints.

Given that the EMA model is not trained, the running statistics of its BN layers need to be updated manually after training. This update is actually done automatically by an `EMACallback` (attached by `create_learner`), but only at the end of each `learn.fit()` call, because it's expensive. As a consequence, the checkpoints saved during a single `fit` execution would have outdated BN running stats. To avoid this, you can pass an additional parameter `pre_save_actions = [UpdateEMAPreSaveAction(learn, epoch_filter=EpochFilterAll())]`, which ensures that the aforementioned update is performed right before the creation of every checkpoint.

In [None]:
# Refresh the EMA generator's BN running stats right before every checkpoint is written.
pre_save_actions = [UpdateEMAPreSaveAction(learn, epoch_filter=EpochFilterAll())]
save_checkpoints_cb = SaveCheckpointsCallback('face2anime_bidir_tr1', initial_epoch=1,
                                              save_cycle_len=save_cycle_len,
                                              pre_save_actions=pre_save_actions)
learn.add_cb(save_checkpoints_cb)

If "before the creation of every checkpoint" is too often for your needs, a different epoch filter can be used:
* `EpochFilterMultipleOfN`
* `EpochFilterAfterN`
* `ComposedEpochFilter`

For instance, if `epoch_filter = EpochFilterMultipleOfN(3)`, the BN update of the EMA generator will only be executed before saving the checkpoints at epochs that are common multiples of `save_cycle_len` and 3.

Finally, we need to choose a learning rate. At least at the beginning, any value in the range [1e-4, 5e-4] should be reasonable. We then call `learn.fit` with the number of epochs and the lr.

In [None]:
lr = 2e-4
n_train_epochs = 200
learn.fit(n_train_epochs, lr)

To inspect a sample of results, one should call `predict_show_n` with the learner and the number of images to show. Before, we need to temporarily disable some callbacks only needed for training that would otherwise be called when obtaining the predictions:

In [None]:
# Training-only callbacks, temporarily detached while predicting.
training_only_cb_names = ('save_checkpoints', 'cycle_consistency_loss', 'identity_loss')
cbs_to_remove_for_display = [getattr(learn, cb_name) for cb_name in training_only_cb_names]
with learn.removed_cbs(cbs_to_remove_for_display) as displayable_learn:
    predict_show_n(displayable_learn, 6)

To show some images produced by the EMA generator, we just need to call `predict_show_n` with `ema_g_learn` as its first parameter:

In [None]:
predict_show_n(ema_g_learn, n_imgs=6)

The learner also tracks a history of the logits of the predictions output by each of the two critics in the attribute `crit_preds_tracker`. It actually stores the mean by batch, not every prediction. To plot them, execute:

In [None]:
crit_labels = ['A->B', 'B->A']
plot_multi_c_preds(learn.crit_preds_tracker, crit_labels)

If you desire to preserve the aforementioned history, you should call `save_preds` and provide two paths with .csv extension: the first one for the critic of the transformation A->B (i.e. the critic of domain B) and the second for the critic of the transformation B->A (i.e. the critic of domain A).

In [None]:
# Order matters: A->B critic (domain B) first, then B->A critic (domain A).
crit_preds_csv_paths = [Path(f'crit_preds_face2anime_bidir_{direction}_tr1_200ep.csv')
                        for direction in ('a2b', 'b2a')]
save_preds(learn.crit_preds_tracker, crit_preds_csv_paths)

## Resuming training

In order to resume a training run from a past session, first of all, we create the learners as always:

In [None]:
# Same configuration as the first training run (create_learner_1 == BatchNorm generators).
learn = create_learner_1()
ema_g_learn = Learner(dls, learn.ema_model,
                      loss_func=lambda *args: torch.tensor(0.))

When attaching `SaveCheckpointsCallback`, remember to set `initial_epoch` to the number of epochs already completed plus one.

In [None]:
# Continue the numbering after the 200 epochs already completed. As in the first run, the EMA
# generator's BN running stats are refreshed before every checkpoint; otherwise checkpoints
# saved during this `fit` call would contain outdated stats.
pre_save_actions = [UpdateEMAPreSaveAction(learn, epoch_filter=EpochFilterAll())]
learn.add_cb(SaveCheckpointsCallback('face2anime_bidir_tr1', initial_epoch=201,
                                     save_cycle_len=save_cycle_len,
                                     pre_save_actions=pre_save_actions))

Then, to load a saved model, you need at least two files:
* A .pth file containing the weights of the trained generator.
* A .pth file containing the weights of the EMA generator.

If you want to continue tracking the history of critic predictions, so that the new ones are appended to the old history, you must also pass the two .csv files indicated in the previous section.

`custom_load_model` assumes the same naming convention as `custom_save_model`, which is the method used internally by `SaveCheckpointsCallback`, so only one filename is required for the two PyTorch files. 

For instance, the example in the following cell assumes these locations of the input files:

* ../input/face2anime-bidir/face2anime_bidir_tr1_200ep.pth
* ../input/face2anime-bidir/face2anime_bidir_tr1_200ep_ema.pth
* ./crit_preds_face2anime_bidir_b2a_tr1_200ep.csv
* ./crit_preds_face2anime_bidir_a2b_tr1_200ep.csv

In [None]:
custom_load_model(learn, 'face2anime_bidir_tr1_200ep', base_path='../input/face2anime-bidir/', with_ema=True)
# Same order used by `save_preds`: A->B critic (domain B) first, then B->A critic (domain A).
preds_df_b, preds_df_a = [pd.read_csv(Path(f'crit_preds_face2anime_bidir_{direction}_tr1_200ep.csv'), index_col=0)
                          for direction in ('a2b', 'b2a')]
learn.crit_preds_tracker.load_from_dfs([preds_df_b, preds_df_a], device)

From here, we can go on exactly like in the previous section.

# Performing evaluation

To evaluate our trained models, we are going to use FID as the metric. We'll rely on an external Python package called pytorch-fid. This package just needs the paths of two directories as inputs, one for each set of images to be compared; so, before invoking the FID process, we need to store on disk a set of real images and a set of fake images generated by the model we are evaluating.

Most papers use 50000 images for each set, but 10000 images should be enough for a relative comparison between our own models.

In [None]:
# Root of the FID samples tree, and number of images per set.
base_fid_samples_path = Path('.') / 'fid_samples'
n_fid_imgs = 10_000


def download_pytorch_fid_calculator():        
    if platform == 'win32':
        # As of 08/21, installing from PyPI on Windows doesn't work
        !pip install git+https://github.com/mseitzer/pytorch-fid.git
    else:
        !pip install pytorch-fid

    
def create_fid_dirs(base_fid_samples_path):
    """Create the FID samples tree: `{fake,real,input}/{A,B}` under `base_fid_samples_path`.

    Existing directories are reused, so the cell can be re-run without a FileExistsError.
    `base_fid_samples_path` may be a str or a Path.
    """
    base = Path(base_fid_samples_path)
    for set_kind in ('fake', 'real', 'input'):
        for domain_name in ('A', 'B'):
            (base/set_kind/domain_name).mkdir(parents=True, exist_ok=True)

    
def save_real_imgs(dls, base_path, n_imgs=10000):
    """Save `n_imgs` real images of each domain as JPEGs under `base_path/real/{A,B}`.

    Batches are obtained by repeatedly calling `dls.one_batch()`. For each saved sample, the
    first input (x1, domain A) goes to `real/A/{k}.jpg` and the second input (x2, domain B)
    to `real/B/{k}.jpg`. Indexes `k` are assigned from `n_imgs - 1` down to 0.
    """
    remaining = n_imgs
    while remaining > 0:
        batch = dls.one_batch()
        batch_len = batch[1].size()[0]
        decoded = dls.decode_batch(batch, max_n=batch_len)
        for i in range(min(batch_len, remaining)):
            img_a_t, img_b_t, _, _ = decoded[i]
            img_a, img_b = PILImage.create(img_a_t), PILImage.create(img_b_t)
            img_idx = remaining - 1
            img_a.save(base_path/f'real/A/{img_idx}.jpg')
            img_b.save(base_path/f'real/B/{img_idx}.jpg')
            remaining -= 1

            
def save_fake_imgs(learner, base_path, n_imgs=10000, max_pred_sz=5000, save_inputs=False, 
                   **predict_n_kwargs):
    """Generate and save `n_imgs` translations per direction as JPEGs under `base_path`.

    Predictions are computed in chunks of at most `max_pred_sz` items to bound memory.
    A->B outputs go to `fake/B/{idx}.jpg` and B->A outputs to `fake/A/{idx}.jpg`. When
    `save_inputs` is True, the corresponding inputs are saved to `input/{A,B}/{idx}.jpg`.
    Extra keyword arguments are forwarded to `predict_n`.
    """
    n_chunks = math.ceil(n_imgs/max_pred_sz)
    for chunk_idx in range(n_chunks):
        chunk_start = chunk_idx * max_pred_sz
        chunk_end = min(n_imgs, chunk_start + max_pred_sz)
        preds_batch = predict_n(learner, (chunk_start, chunk_end), **predict_n_kwargs)
        for idx, (in_a, in_b, img_t_a2b, img_t_b2a) in enumerate(preds_batch, start=chunk_start):
            PILImage.create(img_t_a2b).save(base_path/f'fake/B/{idx}.jpg')
            PILImage.create(img_t_b2a).save(base_path/f'fake/A/{idx}.jpg')
            if save_inputs:
                PILImage.create(in_a).save(base_path/f'input/A/{idx}.jpg')
                PILImage.create(in_b).save(base_path/f'input/B/{idx}.jpg')
        # Drop the reference before freeing memory, so the chunk can actually be released.
        del preds_batch
        clean_mem()

We must begin by setting up the environment:
* Install pytorch-fid Python package
* Create the directories where the output images will be placed.

In [None]:
# Install pytorch-fid, then create the fake/real/input directory tree for the FID samples.
download_pytorch_fid_calculator()
create_fid_dirs(base_fid_samples_path)

A set of real images from each domain must be saved into the corresponding directories:

In [None]:
save_real_imgs(dls, base_fid_samples_path, n_imgs=n_fid_imgs)

In [None]:
class FIDEvalType(Enum):
    "Which two image sets an FID run compares."
    # Generated images vs real images of the same (target) domain.
    FAKE_VS_TARGET = 1
    # Generated images vs the source-domain inputs they were generated from.
    INPUT_VS_FAKE = 2


def exec_fid_proc(domain, eval_type, base_fid_samples_path):
    """Run `python -m pytorch_fid` on the fake images of `domain` and return the `CompletedProcess`.

    The second set is `real/{domain}` for `FAKE_VS_TARGET`, and `input/{opposite domain}`
    otherwise. stdout is captured as bytes; stderr is not captured.
    """
    import sys
    base = Path(base_fid_samples_path)
    fake_path = base/'fake'/domain.name
    if eval_type == FIDEvalType.FAKE_VS_TARGET:
        second_set_path = base/'real'/domain.name
    else:
        second_set_path = base/'input'/reverse_domain(domain).name
    # Use the kernel's own interpreter: a bare "python" on PATH may belong to another
    # environment, one where pytorch-fid is not installed.
    return subprocess.run([sys.executable, "-m", "pytorch_fid", str(fake_path), str(second_set_path)],
                          stdout=subprocess.PIPE)


def fid_out_to_arr(fid_proc_out):
    """Extract the FID values printed by `python -m pytorch_fid`.

    Args:
        fid_proc_out: captured stdout of the FID process, as bytes or str.

    Returns:
        A list of floats rounded to 1 decimal, one per output line starting with 'FID'
        (e.g. 'FID:  12.345' -> 12.3). The list is empty if there is no such line.
    """
    import sys
    if isinstance(fid_proc_out, bytes):
        # sys.stdout.encoding can be None (e.g. with redirected streams); fall back to UTF-8.
        encoding = getattr(sys.stdout, 'encoding', None) or 'utf-8'
        fid_proc_out = fid_proc_out.decode(encoding, errors='replace')
    float_fids = []
    # splitlines() also copes with '\r\n' line endings (Windows).
    for line in fid_proc_out.splitlines():
        if line.startswith('FID'):
            # Strip the 'FID' label plus any ':' / whitespace after it. A fixed-width slice
            # would cut digits off outputs like 'FID:12.3'.
            value = line[3:].lstrip(': \t')
            float_fids.append(round(float(value.strip()), ndigits=1))
    return float_fids


def eval_models(builders, n_epochs, eval_type=FIDEvalType.FAKE_VS_TARGET, 
                base_fid_samples_path='./fid_samples', base_models_path='./models', 
                fn_suffix='', ema=False, n_imgs=None):
    """Compute FID scores of saved checkpoints for both translation directions.

    For every (builder, epoch) pair, this function:
    - builds an inference learner;
    - loads `face2anime_bidir_tr{model_id}{fn_suffix}_{n_ep}ep` from `base_models_path`
      (`model_id` is the part of the builder's `__name__` after its last '_');
    - saves the fake images under `base_fid_samples_path`;
    - runs `pytorch_fid` once per domain.

    Args:
        builders: a learner builder or an iterable of them; each must accept `for_inference=True`.
        n_epochs: an epoch number or an iterable of them. At least one of `builders` and
            `n_epochs` must be iterable; a single value is repeated to match the other's length.
        eval_type: `FIDEvalType.FAKE_VS_TARGET` or `FIDEvalType.INPUT_VS_FAKE`.
        base_fid_samples_path: root of the tree created by `create_fid_dirs` (str or Path).
        base_models_path: directory that contains the .pth files.
        fn_suffix: text expected after the model id in the checkpoint filenames.
        ema: if True, evaluate the EMA generators instead of the trained ones.
        n_imgs: number of fake images per domain; defaults to the global `n_fid_imgs`.

    Returns:
        A dict mapping each domain name to the list of parsed FID values, in evaluation order.
    """
    import warnings
    many_builders, many_epochs = is_iterable(builders), is_iterable(n_epochs)
    assert many_builders or many_epochs
    # Materialize once: computing `len(list(gen))` first would exhaust a generator before the loop.
    if many_builders: builders = list(builders)
    if many_epochs: n_epochs = list(n_epochs)
    if not many_builders: builders = [builders] * len(n_epochs)
    if not many_epochs: n_epochs = [n_epochs] * len(builders)
    # The helpers join paths with `/`, which fails on the str default.
    base_fid_samples_path = Path(base_fid_samples_path)
    if n_imgs is None: n_imgs = n_fid_imgs
    result = {d.name: [] for d in Domain}
    for builder, n_ep in zip(builders, n_epochs):
        model_id = builder.__name__.split('_')[-1]
        learner = builder(for_inference=True)
        custom_load_model(learner, f'face2anime_bidir_tr{model_id}{fn_suffix}_{n_ep}ep', with_opt=False,
                          base_path=base_models_path, with_ema=ema)
        if ema: 
            # Wrap the EMA generator in a plain Learner; the loss is a placeholder (never trained).
            learner = Learner(learner.dls, learner.ema_model,
                              loss_func=lambda *args: torch.tensor(0.))
        save_fake_imgs(learner, base_fid_samples_path, n_imgs=n_imgs, 
                       save_inputs=(eval_type==FIDEvalType.INPUT_VS_FAKE))
        for domain in Domain:
            completed_proc = exec_fid_proc(domain, eval_type, base_fid_samples_path)
            fid_values = fid_out_to_arr(completed_proc.stdout)
            if not fid_values:
                # Otherwise the result lists silently lose their alignment with the evaluated epochs.
                warnings.warn(f'No FID value found for model {model_id} ({domain.name}) after {n_ep} epochs '
                              f'(return code {completed_proc.returncode})')
            result[domain.name].extend(fid_values)
            print(f'---- {model_id} ({domain.name}), after {n_ep} epochs ----')
            proc_out = completed_proc.stdout
            print(proc_out.decode(errors='replace') if isinstance(proc_out, bytes) else proc_out)
    return result

The method `eval_models` does everything else needed to get the FID measurements requested: it loads the models, generates the fake images, saves them and runs the FID evaluation process.

Its parameters are:
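* `builders`: a method, or a list of methods, that receive, at least, an optional parameter `for_inference` and return a learner.
* `n_epochs`: the index of an epoch, or an iterable containing the indexes of the epochs, at which we want to evaluate the learners built by calling the methods passed as `builders`. At least one of `builders` and `n_epochs` must be iterable; a single value is repeated to match the length of the other one.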
* `eval_type`: the default is FIDEvalType.FAKE_VS_TARGET, which is the usual FID measurement that compares a set of generated images against a set of images of the target domain. If you pass FIDEvalType.INPUT_VS_FAKE, the set of input images is compared with the set of output images; it can serve as a measurement of content preservation.
* `base_models_path`: path or str that points to the directory that contains all the .pth files.
* `base_fid_samples_path`: path of the root folder of the FID samples directory tree. It must be the same path passed to `create_fid_dirs` and `save_real_imgs`.
* `fn_suffix`: substring expected to appear in the model filenames after the model id.
* `ema`: indicates if the EMA generators must be used to create the fake images, instead of the trained generators.

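The filenames of the .pth files are inferred from the builder name: the model id is the part of the function name after its last underscore (e.g. `1` for `create_learner_1`). For instance, if we called `eval_models` with this combination of parameters: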

```
eval_models([create_learner_1, create_learner_2], range(5, 201, 5), fn_suffix='rerun', 
            base_models_path='./models', base_fid_samples_path=base_fid_samples_path)
```

the models files would be expected to be located in:

* ./models/face2anime_bidir_tr1rerun_5ep.pth
* ./models/face2anime_bidir_tr1rerun_5ep_ema.pth
* ./models/face2anime_bidir_tr1rerun_10ep.pth
* ./models/face2anime_bidir_tr1rerun_10ep_ema.pth
* ...
* ./models/face2anime_bidir_tr1rerun_200ep.pth
* ./models/face2anime_bidir_tr1rerun_200ep_ema.pth
* ./models/face2anime_bidir_tr2rerun_5ep.pth
* ./models/face2anime_bidir_tr2rerun_5ep_ema.pth
* ./models/face2anime_bidir_tr2rerun_10ep.pth
* ./models/face2anime_bidir_tr2rerun_10ep_ema.pth
* ...
* ./models/face2anime_bidir_tr2rerun_200ep.pth
* ./models/face2anime_bidir_tr2rerun_200ep_ema.pth



In [None]:
# FID of the trained generators every 5 epochs (5, 10, ..., 200).
eval_models(create_learner_1, range(5, 201, 5),
            base_fid_samples_path=base_fid_samples_path, base_models_path='./models', fn_suffix='')

In [None]:
# `eval_models` has no `base_path` parameter (TypeError); the models directory is `base_models_path`.
eval_models(create_learner_1, range(5, 201, 5), base_models_path='./models', fn_suffix='', ema=True,
            base_fid_samples_path=base_fid_samples_path)

In [None]:
# Content preservation of the EMA generators: compare their outputs with the inputs they came from.
eval_models(create_learner_1, range(5, 201, 5), base_models_path='./models', fn_suffix='',
            ema=True, eval_type=FIDEvalType.INPUT_VS_FAKE, 
            base_fid_samples_path=base_fid_samples_path)

If you wish to show some saved images:

In [None]:
PILImage.create(base_fid_samples_path/'fake'/'A'/'10.jpg')

In [None]:
PILImage.create(base_fid_samples_path/'fake'/'B'/'10.jpg')

In [None]:
PILImage.create(base_fid_samples_path/'real'/'A'/'10.jpg')

In [None]:
PILImage.create(base_fid_samples_path/'real'/'B'/'10.jpg')