In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from fastai.vision import *
from fastai.vision.gan import *
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

Set the following option to `True` if the notebook is running standalone, i.e., it isn't located inside a clone of the git repo it belongs to (where the needed Python modules are available).

In [None]:
# When True, the setup cell below clones the project repo and adds it to sys.path;
# when False, it imports `local_lib_import` instead.
run_as_standalone_nb = False

In [None]:
# This cell needs to be executed before importing local project modules, like import core.gan
if run_as_standalone_nb:
    root_lib_path = os.path.abspath('generative-lab')
    if not os.path.exists(root_lib_path):
        !git clone https://github.com/davidleonfdez/generative-lab.git
    if root_lib_path not in sys.path:
        sys.path.insert(0, root_lib_path)
else:
    import local_lib_import

In [None]:
# Local project modules. Must be imported after local_lib_import or cloning git repo.
from core.gan import (CustomGANLearner, GANLossArgs, gan_loss_from_func, gan_loss_from_func_std, 
                      load_gan_learner, save_gan_learner, train_checkpoint_gan)
# from core.gen_utils import RandomProbability, SingleProbability
from core.nb_utils import mount_gdrive
from core.net_builders import interpolation_generator

`models_root` is used as the base path to save models. The next cell assumes the notebook is being executed from Google Colab and that you have an "ML" directory in Google Drive. Alternatively, you could set it manually to something like './' to save and load models to/from the current directory.

In [None]:
# Optional, allows saving parameters in gdrive
# `mount_gdrive` comes from core.nb_utils; its return value is used as a path prefix.
# NOTE(review): presumably it is a string ending with a path separator, since 'ML/'
# is appended directly — confirm against core.nb_utils.
root_gdrive = mount_gdrive()
models_root = root_gdrive + 'ML/'

In [None]:
# Global settings shared by every experiment in this notebook.
img_size, img_n_channels = 64, 3       # side length and channel count of the images
batch_size = 64
use_cuda = torch.cuda.is_available()   # True when a CUDA device can be used

# DATA

In [None]:
# Dataset location handed to `untar_data` in the next cell
# (face images, judging by the sample path used further down).
ds_url = "http://vis-www.cs.umass.edu/lfw/lfw"

In [None]:
# Resolve the dataset through `untar_data` and display the resulting local path.
realImagesPath = untar_data(ds_url)
realImagesPath

In [None]:
# One known image of the dataset, used below to inspect its size and appearance.
sampleImg1Path = realImagesPath/'Aaron_Eckhart/Aaron_Eckhart_0001.jpg'

In [None]:
# Open the sample with PIL (`Image` is imported from PIL in the first cell)
# and display its size, a (width, height) tuple in pixels.
im = Image.open(sampleImg1Path)
im.size

In [None]:
# Render the sample image inline. IPython's `Image` is imported under an alias so it
# doesn't rebind the notebook-level name `Image` (PIL's module, imported in the first
# cell), which keeps cells that call `Image.open` re-runnable.
from IPython.display import Image as IPyImage
IPyImage(filename=str(sampleImg1Path))

In [None]:
def get_data(path, bs, size, noise_sz=100):
    """Build the data bunch used to train the GAN from a folder of real images.

    Args:
        path: folder passed to `GANItemList.from_folder`.
        bs: batch size of the resulting data bunch.
        size: target size, used both for the crop/pad transform and for `transform`.
        noise_sz: noise size forwarded to `GANItemList.from_folder`. Defaults to 100,
            the value that was previously hard-coded, so existing calls are unaffected.
            It should match the `noise_sz` the generator is built with.

    Returns:
        The data bunch produced by the chain below: no validation split (`split_none`),
        crop/pad applied to the training set only, and normalization applied to the
        targets only (`do_x=False, do_y=True`).
    """
    crop_tfm = crop_pad(size=size, row_pct=(0,1), col_pct=(0,1))
    # Per-channel stats of 0.5 for 3-channel images.
    # NOTE(review): presumably (mean, std), which would map [0, 1] pixels to [-1, 1] — confirm.
    stats = [torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])]
    return (GANItemList.from_folder(path, noise_sz=noise_sz)
               .split_none()
               .label_from_func(noop)
               .transform(tfms=[[crop_tfm], []], size=size, tfm_y=True)
               .databunch(bs=bs)
               .normalize(stats=stats, do_x=False, do_y=True))

In [None]:
# Build the data bunch with the notebook-level settings and preview a batch.
data = get_data(realImagesPath, batch_size, img_size)
data.show_batch()

# GENERATOR

The generator uses interpolation followed by a regular convolution to upsample, instead of the traditional transposed convolution, in order to avoid checkerboard artifacts, as proposed here: https://distill.pub/2016/deconv-checkerboard/.

Input is bs x noise_sz x 1 x 1<br>
**n_features** is the number of feature maps (so kernels) generated after penultimate layer (the last layer of course outputs n_channels) if n_extra_layers = 0 . At the beginning there will be n_features * 2^(n_intermediate_conv_blocks), and this number will be reduced by half in any subsequent layer.

```
interpolation_generator(in_size:int, n_channels:int, noise_sz:int=100,  
                        n_features:int=64, n_extra_layers:int=0, 
                        dense:bool=False, upsample_mode:str='bilinear', 
                        **conv_kwargs) -> nn.Module:
```

In [None]:
# Generator with the default options (see the signature above), built for the configured image size and channels.
generator = interpolation_generator(img_size, img_n_channels)

Test the generator. Should return 2 x img_n_channels x img_size x img_size

In [None]:
# Feed a random batch of 2 noise vectors (100 = default `noise_sz`) and display the output size.
generator(torch.rand(2, 100, 1, 1)).size()

# CRITIC

Basic critic

**n_features** is the number of feature maps (so kernels) generated after first layer (from the n_channels of the input). This number will be doubled in any subsequent layer.

`basic_critic(in_size:int, n_channels:int, n_features:int=64, n_extra_layers:int=0, **conv_kwargs)`

In [None]:
# Critic with the default options (see the signature above), built for the configured image size and channels.
critic = basic_critic(img_size, img_n_channels)

Test the critic. Should return [1].

In [None]:
# Feed a random batch of 2 images and display the output size. The input shape is
# taken from the notebook configuration instead of being hard-coded, so the check
# stays valid if `img_n_channels` or `img_size` are changed.
critic(torch.rand(2, img_n_channels, img_size, img_size)).size()

# GAN LEARNER

In [None]:
def gen_loss_func(*args):
    """Generator-side loss term: ignores every argument and always returns 0."""
    return 0


# Binary cross-entropy on raw logits for the critic side.
crit_loss_func = nn.BCEWithLogitsLoss()

# Loss functions derived from the two above; they are unpacked into `GANLossArgs`
# here and by every experiment below.
losses = gan_loss_from_func_std(gen_loss_func, crit_loss_func)

learner = CustomGANLearner(data, generator, critic, GANLossArgs(*losses))

# TRAINING

* The parameters of a trained model can be saved with `save_gan_learner`.
* A training run can be resumed (using weights saved during a previous session) with `load_gan_learner`. For example:
        load_gan_learner(learner, models_root + 'interpBilinearGANTr1_40ep.pth')
    This must be executed after instantiating the learner and BEFORE running `learner.fit()`.

* Another alternative to launch a long training run is the function `train_checkpoint_gan`. It will automatically save the weights every `n_epochs_save_split` epochs.

## Bilinear interpolation

### TRAINING 1: lr=1e-4, wd=0, Adam(beta1=0, beta2=0.99)

In [None]:
# Bilinear, training 1: rebuild data, generator, critic and learner from scratch so
# nothing carries over from a previous experiment.
lr = 1e-4
data = get_data(realImagesPath, batch_size, img_size)
generator = interpolation_generator(img_size, img_n_channels)
critic = basic_critic(img_size, img_n_channels)
learner = CustomGANLearner(
    data, generator, critic, GANLossArgs(*losses),
    switch_eval=False,
    opt_func=partial(optim.Adam, betas=(0., 0.99)),  # Adam with beta1=0, beta2=0.99
    wd=0.,
)

In [None]:
# First 20 epochs at learning rate `lr`.
learner.fit(20, lr)

In [None]:
# 20 more epochs, for 40 in total.
learner.fit(20, lr)

In [None]:
# Save the learner after the 40 epochs above, under `models_root`.
save_gan_learner(learner, models_root + 'interpBilinearGANTr1_40ep.pth')

### TRAINING 2: lr=2e-4, wd=0, Adam(beta1=0, beta2=0.99)

In [None]:
# Bilinear, training 2: same setup as training 1 but with a higher learning rate.
# Everything is rebuilt from scratch so nothing carries over.
lr = 2e-4
data = get_data(realImagesPath, batch_size, img_size)
generator = interpolation_generator(img_size, img_n_channels)
critic = basic_critic(img_size, img_n_channels)
learner = CustomGANLearner(
    data, generator, critic, GANLossArgs(*losses),
    switch_eval=False,
    opt_func=partial(optim.Adam, betas=(0., 0.99)),  # Adam with beta1=0, beta2=0.99
    wd=0.,
)

In [None]:
# First 20 epochs at learning rate `lr`.
learner.fit(20, lr)

In [None]:
# 20 more epochs, for 40 in total.
learner.fit(20, lr)

In [None]:
# Save the learner after the 40 epochs above, under `models_root`.
save_gan_learner(learner, models_root + 'interpBilinearGANTrB2_40ep.pth')

### TRAINING 3: lr=2e-4, wd=0, Adam(beta1=0, beta2=0.99), 1 extra layer

In [None]:
# Bilinear, training 3: generator and critic each get one extra layer.
# Everything is rebuilt from scratch so nothing carries over.
lr = 2e-4
data = get_data(realImagesPath, batch_size, img_size)
generator = interpolation_generator(img_size, img_n_channels, n_extra_layers=1)
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
learner = CustomGANLearner(
    data, generator, critic, GANLossArgs(*losses),
    switch_eval=False,
    opt_func=partial(optim.Adam, betas=(0., 0.99)),  # Adam with beta1=0, beta2=0.99
    wd=0.,
)

In [None]:
# First 20 epochs at learning rate `lr`.
learner.fit(20, lr)

In [None]:
# 20 more epochs, for 40 in total.
learner.fit(20, lr)

In [None]:
# Save the learner after the 40 epochs above, under `models_root`.
save_gan_learner(learner, models_root + 'interpBilinearGANTrB3_40ep.pth')

In [None]:
# Long run: 360 epochs with the epoch counter starting at 40, saving weights periodically.
# The checkpoint prefix includes `models_root` so the files are stored alongside the
# manually saved weights instead of in the current working directory.
# NOTE(review): assumes `filename_start` is used as a plain path prefix — confirm in core.gan.
train_checkpoint_gan(learner, 360, initial_epoch=40, filename_start=models_root + 'interpBilinearGANTrB3_', lr=lr)

In [None]:
# Show results on the training set (`get_data` uses `split_none`, so there is no validation split).
learner.show_results(ds_type=DatasetType.Train)

### TRAINING 4: WGAN lr=2e-4, wd=0, Adam(beta1=0, beta2=0.99), 1 extra layer

In [None]:
# Bilinear, training 4: same networks as training 3, but the learner is created
# through the `wgan` constructor. Everything is rebuilt from scratch.
lr = 2e-4
data = get_data(realImagesPath, batch_size, img_size)
generator = interpolation_generator(img_size, img_n_channels, n_extra_layers=1)
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
learner = CustomGANLearner.wgan(
    data, generator, critic,
    switch_eval=False,
    opt_func=partial(optim.Adam, betas=(0., 0.99)),  # Adam with beta1=0, beta2=0.99
    wd=0.,
)

In [None]:
# First 20 epochs at learning rate `lr`.
learner.fit(20, lr)

In [None]:
# 20 more epochs, for 40 in total.
learner.fit(20, lr)

In [None]:
# Save the learner after the 40 epochs above, under `models_root`.
save_gan_learner(learner, models_root + 'interpBilinearGANTrB4_40ep.pth')

In [None]:
# Long run: 360 epochs with the epoch counter starting at 40, saving weights periodically.
# Uses the experiment's `lr` variable (2e-4, set in this experiment's setup cell) instead
# of repeating the literal, and prefixes the checkpoints with `models_root` so they are
# stored alongside the manually saved weights instead of in the working directory.
# NOTE(review): assumes `filename_start` is used as a plain path prefix — confirm in core.gan.
train_checkpoint_gan(learner, 360, initial_epoch=40, filename_start=models_root + 'interpBilinearGANTrB4_', lr=lr)

In [None]:
# Show results on the training set (`get_data` uses `split_none`, so there is no validation split).
learner.show_results(ds_type=DatasetType.Train)

## Bicubic interpolation

### TRAINING 1: lr=2e-4, wd=0, Adam(beta1=0, beta2=0.99), 1 extra layer

In [None]:
# Bicubic, training 1: the generator upsamples with 'bicubic' interpolation; generator
# and critic each have one extra layer. Everything is rebuilt from scratch.
lr = 2e-4
data = get_data(realImagesPath, batch_size, img_size)
generator = interpolation_generator(
    img_size, img_n_channels, n_extra_layers=1, upsample_mode='bicubic',
)
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
learner = CustomGANLearner(
    data, generator, critic, GANLossArgs(*losses),
    switch_eval=False,
    opt_func=partial(optim.Adam, betas=(0., 0.99)),  # Adam with beta1=0, beta2=0.99
    wd=0.,
)

In [None]:
# Train for 40 epochs at learning rate `lr`.
learner.fit(40, lr)

In [None]:
# Save the learner after the 40 epochs above, under `models_root`.
save_gan_learner(learner, models_root + 'interpBicubicGANTrB1_40ep.pth')

In [None]:
# Long run: 360 epochs with the epoch counter starting at 40, saving weights periodically.
# The checkpoint prefix includes `models_root` so the files are stored alongside the
# manually saved weights instead of in the current working directory.
# NOTE(review): assumes `filename_start` is used as a plain path prefix — confirm in core.gan.
train_checkpoint_gan(learner, 360, initial_epoch=40, filename_start=models_root + 'interpBicubicGANTrB1_', lr=lr)

In [None]:
# Show results on the training set (`get_data` uses `split_none`, so there is no validation split).
learner.show_results(ds_type=DatasetType.Train)

### TRAINING 2: WGAN lr=2e-4, wd=0, Adam(beta1=0, beta2=0.99), 1 extra layer

In [None]:
# Bicubic, training 2: same networks as bicubic training 1, but the learner is created
# through the `wgan` constructor. Everything is rebuilt from scratch.
lr = 2e-4
data = get_data(realImagesPath, batch_size, img_size)
generator = interpolation_generator(
    img_size, img_n_channels, n_extra_layers=1, upsample_mode='bicubic',
)
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
learner = CustomGANLearner.wgan(
    data, generator, critic,
    switch_eval=False,
    opt_func=partial(optim.Adam, betas=(0., 0.99)),  # Adam with beta1=0, beta2=0.99
    wd=0.,
)

In [None]:
# First 20 epochs at learning rate `lr`.
learner.fit(20, lr)

In [None]:
# 20 more epochs, for 40 in total.
learner.fit(20, lr)

In [None]:
# Save the learner after the 40 epochs above, under `models_root`.
save_gan_learner(learner, models_root + 'interpBicubicGANTrB2_40ep.pth')

In [None]:
# Long run: 360 epochs with the epoch counter starting at 40, saving weights periodically.
# The checkpoint prefix includes `models_root` so the files are stored alongside the
# manually saved weights instead of in the current working directory.
# NOTE(review): assumes `filename_start` is used as a plain path prefix — confirm in core.gan.
train_checkpoint_gan(learner, 360, initial_epoch=40, filename_start=models_root + 'interpBicubicGANTrB2_', lr=lr)

In [None]:
# Show results on the training set (`get_data` uses `split_none`, so there is no validation split).
learner.show_results(ds_type=DatasetType.Train)