In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import torch.nn as nn
import torch.nn.functional as F
import fastai
from fastai.vision import *
from fastai.vision.gan import *

Set the following option to `True` if this notebook is running on its own, i.e., it isn't located inside a local clone of the git repo it belongs to (with the required Python modules available). In that case, the repo is cloned into the current working directory.

In [None]:
# True -> clone the project repo next to the notebook; False -> use the local clone via local_lib_import.
run_as_standalone_nb: bool = False

In [None]:
# This cell needs to be executed before importing local project modules, like import core.gan.
# `os`, `sys` and `subprocess` are imported explicitly instead of relying on names that
# happen to leak from `from fastai.vision import *`.
import os
import subprocess
import sys

if run_as_standalone_nb:
    root_lib_path = os.path.abspath('generative-lab')
    if not os.path.exists(root_lib_path):
        # check=True stops the notebook right here if the clone fails, instead of the failure
        # surfacing later as a confusing ModuleNotFoundError when importing `core.*`.
        subprocess.run(['git', 'clone', 'https://github.com/davidleonfdez/generative-lab.git', root_lib_path],
                       check=True)
    if root_lib_path not in sys.path:
        # Prepend so the cloned project modules are found before any same-named installed package.
        sys.path.insert(0, root_lib_path)
else:
    # Presumably makes the repo root importable when running inside a local clone — TODO confirm.
    import local_lib_import

In [None]:
# Local project modules. Must be imported after local_lib_import or cloning git repo.
from core.biggan import biggan_disc_64, biggan_gen_64, BigGANItemList, BigGANGenImagesSampler
from core.gan import GANGPLearner, GANLossArgs, GeneratorFuncStateLoader, save_gan_learner
from core.layers import AvgFlatten
from core.losses import hinge_adversarial_losses, loss_func_with_kernel_regularizer, OrthogonalRegularizer
from core.gan_metrics import evaluate_models_fid, EvaluationItem, FIDCalculator
from core.gen_utils import PrinterProgressTracker

In [None]:
# Point this variable to the path where you want to save your models.
models_root = Path('./')
# Create the folder up front, so that saving a model after a long training run
# doesn't fail just because the destination directory doesn't exist yet.
models_root.mkdir(parents=True, exist_ok=True)

In [None]:
# Image geometry: square side length and number of color channels.
img_size, img_n_channels = 64, 3
# Model/batch scale, reduced with respect to the standard values noted alongside.
batch_size = 128  # Std is 512-2048
ch_mult = 32      # Std is 64-96

In [None]:
# Disable occasional annoying warnings produced by libraries using pytorch, which 
# may collapse the output during data loading or training
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional")

# DATA

Set `real_images_path` to the location of the dataset you want to work with. If you need to download it first, fastai provides the function `untar_data` to download and extract a dataset from a remote URL.

In [None]:
# Folder with the aligned CelebA images, using the Kaggle input layout.
real_images_path = Path('/kaggle/input') / 'celeba-dataset' / 'img_align_celeba'
real_images_path

In [None]:
def get_data(path, bs, size, noise_sz=100):
    """Build the DataBunch used to train the BigGAN.

    Args:
        path: folder containing the real images.
        bs: batch size.
        size: side length of the square images produced by the transforms.
        noise_sz: noise size forwarded to `BigGANItemList.from_folder`
            (presumably the generator's latent size — TODO confirm).

    Returns:
        A DataBunch with no validation split. Every item is labelled with itself
        (`noop`). Transforms (a center crop/pad to `size`) are also applied to the
        targets, and only the targets are normalized, with mean 0.5 and std 0.5 per channel.
    """
    item_lists = BigGANItemList.from_folder(path, noise_sz=noise_sz).split_none()
    label_lists = item_lists.label_from_func(noop)
    center_crop = crop_pad(size=size, row_pct=0.5, col_pct=0.5)
    train_tfms, valid_tfms = [center_crop], []
    label_lists = label_lists.transform(tfms=[train_tfms, valid_tfms], size=size, tfm_y=True)
    data_bunch = label_lists.databunch(bs=bs)
    half = [0.5, 0.5, 0.5]
    stats = [torch.tensor(half), torch.tensor(half)]
    return data_bunch.normalize(stats=stats, do_x=False, do_y=True)

In [None]:
data = get_data(real_images_path, bs=batch_size, size=img_size)
# Quick visual check of one batch.
data.show_batch()

# CRITIC

The simplest method to build a BigGAN discriminator is:

```
biggan_disc_64(in_n_channels: int=3, ch_mult: int=96, **disc_kwargs)
```

It creates an architecture with the same depth and number of feature maps as the one described in the BigGAN paper.

Any extra keyword arguments (`disc_kwargs`) are passed through to the `BigGANDiscriminator` constructor:

```
BigGANDiscriminator(in_sz: int, res_blocks_n_ftrs: List[Tuple[int, int]], 
                    idx_block_self_att: int, n_classes: int=1, 
                    down_op: core.layers.DownsamplingOperation2d=None, 
                    activ: nn.Module=None)
```

In [None]:
critic = biggan_disc_64(in_n_channels=img_n_channels, ch_mult=ch_mult)  # BigGAN critic for 64x64 inputs

In [None]:
# Return just one element per batch, as required by GAN loss management.
# Guarded so that re-running this cell doesn't wrap the critic a second time
# (which would nest the Sequential and apply AvgFlatten twice).
already_wrapped = (isinstance(critic, nn.Sequential) and len(critic) > 0
                   and isinstance(critic[-1], AvgFlatten))
if not already_wrapped:
    critic = nn.Sequential(critic, AvgFlatten())

# GENERATOR

The simplest method to build a BigGAN generator is:

```
biggan_gen_64(out_n_channels: int=3, ch_mult: int=96, **gen_kwargs)
```

It creates an architecture with the same depth and number of feature maps as the one described in the BigGAN paper.

Any extra keyword arguments (`gen_kwargs`) are passed through to the `BigGANGenerator` constructor:

```
BigGANGenerator(out_sz: int, out_n_channels: int, up_blocks_n_ftrs: List[Tuple[int, int]], 
                z_split_sz: int=20, n_classes: int=1, class_embedding_sz: int=128, 
                up_op: core.layers.UpsamplingOperation2d=None)
```

In [None]:
generator = biggan_gen_64(out_n_channels=img_n_channels, ch_mult=ch_mult)  # BigGAN generator for 64x64 outputs

# LEARNER

In [None]:
# Hinge adversarial losses for the generator and the critic.
g_loss, d_loss = hinge_adversarial_losses()
# Generator loss combined with an orthogonal regularizer built from `generator`
# (presumably penalizing non-orthogonality of its weight kernels — TODO confirm).
ortho_reg = OrthogonalRegularizer(generator)
g_loss_reg = loss_func_with_kernel_regularizer(g_loss, ortho_reg)

In [None]:
# Weight passed as `plambda` (presumably the gradient-penalty coefficient of GANGPLearner).
gp_lambda = 0.1
# The generator is trained with `g_loss_reg` (hinge loss + orthogonal regularization) from the
# previous cell. Passing the plain `g_loss` would leave `g_loss_reg` unused and silently
# drop the regularization.
# Optimizer: Adam with betas=(0, 0.999) and no weight decay.
learner = GANGPLearner(data, generator, critic, GANLossArgs(g_loss_reg, d_loss),
                       opt_func=partial(optim.Adam, betas=(0., 0.999)), wd=0.,
                       switch_eval=False, plambda=gp_lambda)

# TRAINING

In [None]:
lr: float = 5e-4  # learning rate passed to learner.fit

In [None]:
# Train, then save under a name encoding the run id ("tr1") and the number of epochs;
# the evaluation section rebuilds this same file name from the run id and epoch count.
n_train_epochs = 10
learner.fit(n_train_epochs, lr)
save_gan_learner(learner, models_root/f'biggan-celeba-tr1-{n_train_epochs}ep.pth')

In [None]:
# Show sample results on the training set (get_data uses split_none, so there is no validation set).
learner.show_results(ds_type=DatasetType.Train)

# RESULTS EVALUATION

In [None]:
# FID calculator passed to evaluate_models_fid in the evaluation cell.
calculator = FIDCalculator()

In [None]:
# Run ids to evaluate, and the epoch count embedded in their saved file names.
model_ids, n_epochs = ['1'], 10

In [None]:
n_total_imgs = 10000
n_imgs_by_group = 500

def resolve_state_path(model_id:str):
    """Return the path of the learner saved for run `model_id` after `n_epochs` epochs."""
    file_name = f'biggan-celeba-tr{model_id}-{n_epochs}ep.pth'
    return models_root/file_name

# One evaluation item per run; the generator is rebuilt through biggan_gen_64 with
# [img_n_channels, ch_mult] (presumably positional args) and {} (presumably kwargs).
models = [EvaluationItem(m_id, biggan_gen_64, [img_n_channels, ch_mult], {})
          for m_id in model_ids]

state_loader = GeneratorFuncStateLoader(resolve_state_path)
results = evaluate_models_fid(models, data, state_loader,
                              n_total_imgs, n_imgs_by_group, calculator, PrinterProgressTracker(),
                              BigGANGenImagesSampler)