In [52]:
from fastai.vision.all import *
from fastai.vision.gan import *
from functools import partial
import math
import matplotlib.pyplot as plt
import numpy as np 
import pandas as pd
from pathlib import Path
import PIL
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.vgg import vgg19
from typing import List, Tuple

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = 'cpu' if not torch.cuda.is_available() else 'cuda'
# Image side length used when resizing, number of image channels and batch size.
img_size, n_channels, bs = 224, 3, 32

# Data

## Target ds

animecharacterfaces, by Kaggle user *aadilmalik94*

In [None]:
# Root folder of the target-domain (anime faces) images.
# Alternative dataset tried before: '/kaggle/input/anime-faces-safebooru/anime-faces'
anime_faces_path = Path(
    '/kaggle/input/animecharacterfaces/animeface-character-dataset/data'
).resolve()

## Input ds

In [None]:
# Folder holding the input-domain (CelebA) images.
celeba_path = Path('/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba')
#input_fns = get_image_files(celeba_path)
# get_image_files is too slow, there's no need to check the extension here
# `ls()` simply lists the folder contents; the result is read by `get_random_input`.
input_fns = celeba_path.ls()
# Display the listing as the cell output.
input_fns

The ds path passed to `dblock.dataloaders()` or `ImageDataLoaders.from_dblock()` will be forwarded
to `get_items`, which will return a list of items, usually a list of image paths if `get_items=get_image_files`.

So, for each item, we are expected to receive a filename `fn` and be able to
derive x and y from it, with `get_x(fn)` and `get_y(fn)`.

For unpaired image to image translation, we can:
* Use the target images ds path as the DataBlock `source`. Then, `get_y` can just return the path received.
* Load independently the filenames of the input images ds; let's call it `input_fns`. Then, `get_x` would need to return a random item from `input_fns`. `get_x` is called every time a data item is used; so, by using random, we can be sure every x is not tied to a fixed y; i.e., they won't be together in the same (x, y) batch every epoch for loss calculation.

---

In [None]:
def get_random_input(fn):
    """Return a uniformly random item of the module-level `input_fns`.

    Used as `get_x` of the DataBlock: `fn` (the target item the DataBlock passes in)
    is deliberately ignored, so that inputs and targets are never paired in a fixed way.

    Raises:
        IndexError: if `input_fns` is empty.
    """
    return random.choice(input_fns)
    
# Unpaired image-to-image data: the items are the target images found under `main_path`,
# while the input of every sample is drawn at random by `get_random_input`.
dblock = DataBlock(
    blocks=(ImageBlock, ImageBlock),
    get_items=get_image_files,
    get_x=get_random_input,
    # An empty index list is given to the splitter (presumably: no validation items).
    splitter=IndexSplitter([]),
    item_tfms=Resize(img_size, method=ResizeMethod.Crop),
    # Normalise every channel with mean 0.5 and std 0.5.
    batch_tfms=Normalize.from_stats(torch.tensor([0.5, 0.5, 0.5]),
                                    torch.tensor([0.5, 0.5, 0.5])),
)
main_path = anime_faces_path
dls = dblock.dataloaders(main_path, path=main_path, bs=bs)

In [None]:
# Visual sanity check of one batch produced by the DataLoaders.
dls.show_batch()

# Loss function

## Content loss

In [None]:
class FeaturesCalculator:
    """Extract intermediate VGG activations to be used as style/content features.

    Forward hooks are registered on the requested `vgg.features` layers. Every
    `calc_*` method runs a forward pass of the network and returns the activations
    captured by the hooks. The hooks are created with `detach=False`, so the
    returned tensors remain attached to the autograd graph.

    Args:
        vgg_style_layers_idx: indexes (into `vgg.features`) of the style layers.
        vgg_content_layers_idx: indexes (into `vgg.features`) of the content layers.
        vgg: network to hook; a pretrained `vgg19` is created when `None`.
        normalize_inputs: if True, inputs are normalised with the ImageNet statistics
            before being fed to the network (assumes NCHW inputs - TODO confirm the
            expected value range with the caller).
        device: if given, the network is moved to this device.
    """
    # ImageNet per-channel statistics (same values as fastai's `imagenet_stats`).
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, vgg_style_layers_idx:List[int], vgg_content_layers_idx:List[int],
                 vgg:nn.Module=None, normalize_inputs=False, device:torch.device=None):
        self.vgg = vgg19(pretrained=True) if vgg is None else vgg
        self.vgg.eval()
        if device is not None: self.vgg.to(device)
        modules_to_hook = [self.vgg.features[idx] for idx in (*vgg_style_layers_idx, *vgg_content_layers_idx)]
        self.hooks = hook_outputs(modules_to_hook, detach=False)
        # Style hooks come first, content hooks afterwards (same order as `modules_to_hook`).
        n_style = len(vgg_style_layers_idx)
        self.style_ftrs_hooks = self.hooks[:n_style]
        self.content_ftrs_hooks = self.hooks[n_style:]
        self.normalize_inputs = normalize_inputs

    def clean(self):
        "Remove the forward hooks from the network; call it once the calculator is not needed anymore."
        self.hooks.remove()

    def _get_hooks_out(self, hooks):
        "Return the activations stored by `hooks` during the last forward pass."
        return [h.stored for h in hooks]

    def _forward(self, img_t:torch.Tensor):
        "Run the network on `img_t` (normalising it first if requested) so the hooks get populated."
        if self.normalize_inputs:
            # Reshape the stats to (1, C, 1, 1) so they broadcast over batch and spatial dims.
            mean = torch.tensor(self.IMAGENET_MEAN, dtype=img_t.dtype, device=img_t.device).view(1, -1, 1, 1)
            std = torch.tensor(self.IMAGENET_STD, dtype=img_t.dtype, device=img_t.device).view(1, -1, 1, 1)
            img_t = (img_t - mean) / std
        self.vgg(img_t)

    def calc_style(self, img_t:torch.Tensor) -> List[torch.Tensor]:
        "Return the style layers activations for `img_t`."
        self._forward(img_t)
        return self._get_hooks_out(self.style_ftrs_hooks)

    def calc_content(self, img_t:torch.Tensor) -> List[torch.Tensor]:
        "Return the content layers activations for `img_t`."
        self._forward(img_t)
        return self._get_hooks_out(self.content_ftrs_hooks)

    def calc_style_and_content(self, img_t:torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        "Return `(style_features, content_features)` for `img_t` using a single forward pass."
        self._forward(img_t)
        style_ftrs = self._get_hooks_out(self.style_ftrs_hooks)
        content_ftrs = self._get_hooks_out(self.content_ftrs_hooks)
        return style_ftrs, content_ftrs

In [None]:
# Index (within `vgg.features`) of the layer whose activations are used as content features.
vgg_content_layers_idx = [22]
# No style layers are hooked here: only content features are requested.
ftrs_calc = FeaturesCalculator(
    vgg_style_layers_idx=[],
    vgg_content_layers_idx=vgg_content_layers_idx,
    device=device,
)

In the next cell, there are two versions of content loss, with the same behaviour but different implementation:
* A functional version, returned by `get_content_loss`. It needs a callback as a parameter, from which the last input batch is read (`last_input_cb.x`).
* A callback version. Although less intuitive, this may be preferable because it stores the loss value inside `learner.loss_func.content_loss`, making it accessible to the metrics display system. If you want the learner to display the content loss every epoch, you only need to pass `metrics=['content_loss', ...]` to `GANLearner.__init__`.

In [None]:
# Distance used to compare content features: mean squared error averaged over all elements.
content_loss_func = nn.MSELoss(reduction='mean')


def get_content_loss(last_input_cb):
    """Build a content loss function `f(output, target)`.

    The returned function compares, with `content_loss_func`, the content features of
    `output` against those of the current input batch, read from `last_input_cb.x`.
    `target` is accepted for signature compatibility but not used.
    """
    def _content_loss(output, target):
        # Only the first content layer returned by the features calculator is used.
        content_ftrs_of = lambda img_t: ftrs_calc.calc_content(img_t)[0]
        source_ftrs = content_ftrs_of(last_input_cb.x)
        generated_ftrs = content_ftrs_of(output)
        return content_loss_func(generated_ftrs, source_ftrs)

    return _content_loss


class ContentLossCallback(Callback):
    """Add a weighted content loss to the generator loss of a GAN learner.

    Args:
        weight: factor applied to the content loss before adding it to the learner loss.
        ftrs_calc: `FeaturesCalculator` used to extract content features; when `None`,
            one hooking VGG layer 22 is created.
        device: device passed to the default `FeaturesCalculator`.
    """
    def __init__(self, weight=1., ftrs_calc=None, device=None):
        self.weight = weight
        if ftrs_calc is None:
            vgg_content_layers_idx = [22]
            ftrs_calc = FeaturesCalculator([], vgg_content_layers_idx, device=device)
        self.ftrs_calc = ftrs_calc

    def _add_to_loss(self, extra_loss):
        "Add `extra_loss` to the learner loss in a way that reaches `backward`."
        learn = self.learn
        if getattr(learn, 'loss_grad', None) is not None:
            # Recent fastai versions back-propagate `loss_grad`; `loss` is just a clone of it.
            learn.loss_grad = learn.loss_grad + extra_loss
            learn.loss = learn.loss_grad.clone()
        else:
            # Older fastai versions call `backward` directly on `loss`.
            learn.loss = learn.loss + extra_loss

    def after_loss(self):
        "In generator mode, compute the content loss between input and prediction and add it to the loss."
        if not self.gan_trainer.gen_mode: return
        input_content_ftrs = self.ftrs_calc.calc_content(self.x)[0]
        output_content_ftrs = self.ftrs_calc.calc_content(self.pred)[0]
        loss_val = content_loss_func(output_content_ftrs, input_content_ftrs)
        # Store the (unweighted) result inside learn.loss_func to make it visible to metrics display
        self.learn.loss_func.content_loss = loss_val
        self._add_to_loss(self.weight * loss_val)

The following functions are only useful if you decide to use the functional version of content loss and want to combine it with Wasserstein loss:

In [None]:
def gen_wgan_loss(fake_pred, output, target):
    """Wasserstein generator loss: the mean of the critic predictions on generated images.

    `output` and `target` are accepted for signature compatibility but not used.
    """
    mean_fake_score = fake_pred.mean()
    return mean_fake_score
def crit_wgan_loss(real_pred, fake_pred):
    "Wasserstein critic loss: mean prediction on real images minus mean prediction on fake ones."
    mean_real_score = real_pred.mean()
    mean_fake_score = fake_pred.mean()
    return mean_real_score - mean_fake_score


def get_gen_wgan_content_loss(last_input_cb, content_loss_w=1.):
    """Build a generator loss combining Wasserstein loss and weighted content loss.

    Args:
        last_input_cb: callback forwarded to `get_content_loss`.
        content_loss_w: factor applied to the content loss term.

    Returns:
        A function `f(fake_pred, output, target)` returning
        `gen_wgan_loss(...) + content_loss_w * content_loss(output, target)`.
    """
    content_term = get_content_loss(last_input_cb)

    def _gen_wgan_content_loss(fake_pred, output, target):
        adversarial_part = gen_wgan_loss(fake_pred, output, target)
        content_part = content_term(output, target)
        return adversarial_part + content_loss_w * content_part

    return _gen_wgan_content_loss


def create_wgan_w_content_loss_learner(dls, generator, critic, cbs=None, content_loss_w=1., **kwargs):
    """Create a `GANLearner` trained with Wasserstein loss plus content loss (functional version).

    Args:
        dls: DataLoaders passed to `GANLearner`.
        generator: generator model.
        critic: critic model.
        cbs: optional callback or collection of callbacks; it is copied, never modified.
        content_loss_w: weight of the content loss term of the generator loss.
        **kwargs: forwarded to `GANLearner`.
    """
    # Work on a copy so that the caller's list isn't altered; also accept a single callback.
    if cbs is None: cbs = []
    elif isinstance(cbs, Callback): cbs = [cbs]
    else: cbs = list(cbs)
    # A bare callback, registered in the learner, gives the loss function access to the
    # current input batch (`last_input_cb.x`).
    last_input_cb = Callback()
    cbs.append(last_input_cb)
    gen_loss = get_gen_wgan_content_loss(last_input_cb, content_loss_w=content_loss_w)
    return GANLearner(dls, generator, critic, gen_loss, crit_wgan_loss,
                      cbs=cbs, **kwargs)

## Gradient penalty

In [None]:
class GANGPCallback(Callback):
    """Add a gradient penalty term (as in WGAN-GP) to the critic loss.

    Args:
        plambda: weight of the gradient penalty.
        epsilon_sampler: function `(real, fake) -> epsilon` giving the interpolation
            coefficients; defaults to `random_epsilon_gp_sampler`.
    """
    def __init__(self, plambda=10., epsilon_sampler=None):
        self.plambda = plambda
        if epsilon_sampler is None: epsilon_sampler = random_epsilon_gp_sampler
        self.epsilon_sampler = epsilon_sampler

    def _gradient_penalty(self, real, fake, plambda, epsilon_sampler):
        "Return `plambda * mean((||grad critic(x_hat)|| - 1)^2)`, with one gradient norm per sample."
        epsilon = epsilon_sampler(real, fake)
        x_hat = epsilon * real + (1 - epsilon) * fake
        # Sum (not mean) the predictions, so the gradient of every sample isn't scaled by 1/batch_size.
        x_hat_pred = self.model.critic(x_hat).sum()

        grads = torch.autograd.grad(outputs=x_hat_pred, inputs=x_hat, create_graph=True)[0]
        # One L2 norm per batch element, computed over all its remaining dimensions.
        grads_norm = grads.reshape(grads.shape[0], -1).norm(2, dim=1)
        return plambda * ((grads_norm - 1)**2).mean()

    def _add_to_loss(self, extra_loss):
        "Add `extra_loss` to the learner loss in a way that reaches `backward`."
        learn = self.learn
        if getattr(learn, 'loss_grad', None) is not None:
            # Recent fastai versions back-propagate `loss_grad`; `loss` is just a clone of it.
            learn.loss_grad = learn.loss_grad + extra_loss
            learn.loss = learn.loss_grad.clone()
        else:
            # Older fastai versions call `backward` directly on `loss`.
            learn.loss = learn.loss + extra_loss

    def after_loss(self):
        "In critic mode, while training, add the gradient penalty to the loss."
        # The penalty needs autograd, so it is only computed while training.
        if self.gan_trainer.gen_mode or not self.training: return
        # In critic mode, GANTrainer swaps x and y; so, here x is original y (real target)
        real = self.x
        assert not self.y.requires_grad
        fake = self.model.generator(self.y).requires_grad_(True)
        self._add_to_loss(self._gradient_penalty(real, fake, self.plambda, self.epsilon_sampler))


def random_epsilon_gp_sampler(real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    """Sample the interpolation coefficients used by the gradient penalty.

    A single uniform random value in [0, 1) is drawn for each element of the batch
    (first dimension of `real`) and broadcast to the shape of `real`, which must be 4D.
    `fake` is not used.
    """
    n_samples = real.shape[0]
    per_sample_eps = torch.rand((n_samples, 1, 1, 1), dtype=torch.float,
                                device=real.device, requires_grad=False)
    return per_sample_eps.expand_as(real)

------------------

# TRAINING

In [None]:
def custom_save_model(learner, filename, base_path='/kaggle/working'):
    """Save the model and optimizer of `learner` to `base_path/learner.model_dir/filename.pth`.

    Raises:
        Exception: if `base_path` is neither a `str` nor a `Path`.
    """
    root = Path(base_path) if isinstance(base_path, str) else base_path
    if not isinstance(root, Path):
        raise Exception('Invalid base_path')
    target_file = join_path_file(filename, root/learner.model_dir, ext='.pth')
    save_model(target_file, learner.model, learner.opt)
    
def custom_load_model(learner, filename, with_opt=True, device=None,
                      base_path='/kaggle/input/face2anime-weights', **kwargs):
    """Load into `learner` the weights (and optionally the optimizer state) of `base_path/filename.pth`.

    Unlike `custom_save_model`, the file is looked up directly inside `base_path`,
    without appending `learner.model_dir`.

    Args:
        learner: learner whose model/optimizer are loaded; its optimizer is created if missing.
        filename: file name without the `.pth` extension.
        with_opt: forwarded to `load_model`.
        device: forwarded to `load_model`; defaults to `learner.dls.device` when available.
        base_path: folder (`str` or `Path`) containing the weights file.
        **kwargs: forwarded to `load_model`.

    Raises:
        Exception: if `base_path` is neither a `str` nor a `Path`.
    """
    root = Path(base_path) if isinstance(base_path, str) else base_path
    if not isinstance(root, Path):
        raise Exception('Invalid base_path')
    if device is None and hasattr(learner.dls, 'device'):
        device = learner.dls.device
    if learner.opt is None:
        learner.create_opt()
    weights_file = root/f'{filename}.pth'
    load_model(weights_file, learner.model, learner.opt, with_opt=with_opt, device=device, **kwargs)
    
def predict_n(learner, n_imgs, max_bs=64):
    dummy_path = Path('')
    dl = learner.dls.test_dl([dummy_path]*n_imgs, bs=max_bs)   
    inp, imgs_t, _, dec_imgs_t = learner.get_preds(dl=dl, with_input=True, with_decoded=True)
    dec_batch = dls.decode_batch((inp,) + tuplify(dec_imgs_t), max_n=n_imgs)
    return dec_batch
    
def predict_show_n(learner, n_imgs, **predict_n_kwargs):
    """Generate `n_imgs` predictions and show every input next to its prediction.

    Each row of the figure contains the input (left) and the generated image (right).
    `predict_n_kwargs` are forwarded to `predict_n`.
    """
    preds_batch = predict_n(learner, n_imgs, **predict_n_kwargs)
    # squeeze=False keeps `axs` two-dimensional even when there's a single row (n_imgs == 1).
    _, axs = plt.subplots(n_imgs, 2, figsize=(6, n_imgs * 3), squeeze=False)
    for i, (inp, pred_img) in enumerate(preds_batch):
        inp.show(ax=axs[i][0])
        pred_img.show(ax=axs[i][1])

In [None]:
# The generator is a U-Net with a non-pretrained xresnet18 encoder. `unet_learner` is only
# used to build the model, hence the dummy loss function.
generator_learner = unet_learner(
    dls, xresnet18,
    normalize=True,
    n_out=n_channels,
    pretrained=False,
    loss_func=lambda *args: 0,
)
generator = generator_learner.model
# The critic outputs a single value per image.
critic = xresnet18(n_out=1)
learn = GANLearner.wgan(dls, generator, critic, opt_func=RMSProp)
# Record metrics for the training set only.
learn.recorder.train_metrics = True
learn.recorder.valid_metrics = False
learn.fit(1, 2e-4, wd=0.)

In [None]:
# Train for 4 more epochs with the same learning rate and no weight decay.
learn.fit(4, 2e-4, wd=0.)

In [None]:
# Show results on the dataset with index 0.
learn.show_results(ds_idx=0)

In [None]:
# Train for another 4 epochs with the same hyperparameters.
learn.fit(4, 2e-4, wd=0.)

In [None]:
# Show results again (dataset with index 0) after the additional training.
learn.show_results(ds_idx=0)

# EVALUATION

In [None]:
# Number of images saved for each of the two sets (fake/real) compared by FID.
n_fid_imgs = 10_000
# Folder under which the `fake` and `real` sample folders are created.
base_fid_samples_path = Path('/kaggle/working/fid_samples')

def download_pytorch_fid_calculator():        
    """Install the `pytorch-fid` package with pip.

    Uses an IPython shell escape (`!`), so it only works inside a notebook/IPython session.
    """
    #!git clone https://github.com/mseitzer/pytorch-fid.git
    !pip install pytorch-fid

def create_fid_dirs(base_fid_samples_path):
    """Create `base_fid_samples_path` and its `fake` and `real` subfolders.

    Missing parents are created and already existing folders are accepted, so the
    function can be safely called again. Existing files are NOT removed: delete the
    folder first if samples from a previous run must not be mixed with new ones.
    """
    base = Path(base_fid_samples_path)
    for subfolder in ('fake', 'real'):
        (base/subfolder).mkdir(parents=True, exist_ok=True)

def save_real_imgs(dls, n_imgs=10000, base_path=None):
    """Save `n_imgs` decoded target (real) images as `<base_path>/real/<idx>.jpg`.

    The training DataLoader is iterated epoch by epoch, so an image can only be
    repeated once a whole epoch has been consumed (instead of drawing independent
    random batches, which produces duplicates much earlier).

    Args:
        dls: DataLoaders whose batches are `(input, target)` tuples.
        n_imgs: number of images to save.
        base_path: destination root; defaults to the module-level `base_fid_samples_path`.

    Raises:
        ValueError: if a full pass over the training DataLoader yields no image.
    """
    if base_path is None: base_path = base_fid_samples_path
    n_saved = 0
    while n_saved < n_imgs:
        n_saved_before_epoch = n_saved
        for b in dls.train:
            if n_saved >= n_imgs: return
            bs = b[1].size()[0]
            dec_b = dls.decode_batch(b, max_n=bs)
            for i in range(bs):
                if n_saved >= n_imgs: return
                # Every decoded item is an (input, target) pair; only the target is saved.
                target_img_t = dec_b[i][1]
                PILImage.create(target_img_t).save(base_path/f'real/{n_saved}.jpg')
                n_saved += 1
        # Avoid looping forever over an empty DataLoader.
        if n_saved == n_saved_before_epoch:
            raise ValueError('The training DataLoader yielded no images')

def save_fake_imgs(learner, n_imgs=10000, chunk_size=1000, **predict_n_kwargs):
    """Generate `n_imgs` images with `learner` and save them as `<base_fid_samples_path>/fake/<idx>.jpg`.

    Images are generated in chunks of at most `chunk_size`, so that inputs and
    predictions of all `n_imgs` images never need to be held in memory at once.

    Args:
        learner: learner passed to `predict_n`.
        n_imgs: total number of images to generate.
        chunk_size: maximum number of images generated per `predict_n` call.
        **predict_n_kwargs: forwarded to `predict_n`.

    Raises:
        ValueError: if `chunk_size` is smaller than 1.
    """
    if chunk_size < 1: raise ValueError('chunk_size must be at least 1')
    base_path = base_fid_samples_path
    for start in range(0, n_imgs, chunk_size):
        n_chunk = min(chunk_size, n_imgs - start)
        preds_batch = predict_n(learner, n_chunk, **predict_n_kwargs)
        # Every item is an (input, prediction) pair; only the prediction is saved.
        for offset, (inp, img) in enumerate(preds_batch):
            PILImage.create(img).save(base_path/f'fake/{start + offset}.jpg')

In [None]:
!rm -R fid_samples/

In [None]:
create_fid_dirs(base_fid_samples_path)
save_fake_imgs(learn, n_imgs=n_fid_imgs)
save_real_imgs(dls, n_fid_imgs)
!ls -R fid_samples/

In [None]:
# Install fid calculator if it isn't yet
#download_pytorch_fid_calculator()

In [None]:
!python -m pytorch_fid --device cuda {base_fid_samples_path/'fake'} {base_fid_samples_path/'real'}

# TODO

* Once decent results appear, compare them with FID.
* UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
* Don't forget get_image_files can make training slower at the beginning
* Maybe, add cycle consistency loss (this would slow down training), another generator, ... (cycle-GAN)
* If no good results are obtained:
  * use bigger dataset
  * use more specific anime faces dataset (it could be worse for production)
  * think about initializations
* Add/vary transforms
  