---
WGAN with gradient penalty to satisfy Lipschitz constraint.
DCGAN is used as base architecture of the networks.

---

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import torch.nn as nn
import torch.nn.functional as F
from fastai.vision import *
from fastai.vision.gan import *

Set the following option to True if the notebook is running standalone, i.e. it isn't located in the file system inside a clone of the git repo it belongs to (where the needed Python modules are available).

In [None]:
# False: project modules are made importable through `local_lib_import`.
# True: the project repo is cloned (if missing) and added to sys.path by the next cell.
run_as_standalone_nb = False

In [None]:
# This cell needs to be executed before importing local project modules, like import core.gan
# NOTE(review): `os` and `sys` are never imported explicitly in this notebook; presumably
# they come in through the `from fastai.vision import *` star import — confirm.
if run_as_standalone_nb:
    # Standalone mode: the repo is expected as `generative-lab` under the current working dir.
    root_lib_path = os.path.abspath('generative-lab')
    if not os.path.exists(root_lib_path):
        !git clone https://github.com/davidleonfdez/generative-lab.git
    # Put the clone first on sys.path (only once) so project imports resolve to it.
    if root_lib_path not in sys.path:
        sys.path.insert(0, root_lib_path)
else:
    # Imported for its side effects only; presumably it sets up the import path
    # when the notebook lives inside the repo — verify against local_lib_import.
    import local_lib_import

In [None]:
# Local project modules. Must be imported after local_lib_import or cloning git repo.
from core.gan import GANGPLearner, load_gan_learner, save_gan_learner
from core.nb_utils import mount_gdrive
from core.net_builders import custom_critic

`models_root` is used as the base path to save models. The next cell assumes the notebook is being executed from Google Colab and that you have an "ML" dir in Google Drive. Alternatively, you could set it manually to something like './' to save and load models to/from the current directory.

In [None]:
# Optional, allows saving parameters in gdrive
# mount_gdrive is a project helper (core.nb_utils); presumably it mounts Google Drive
# and returns its root path — verify.
root_gdrive = mount_gdrive()
# Plain string concatenation: assumes root_gdrive ends with a path separator — TODO confirm.
models_root = root_gdrive + 'ML/'

In [None]:
# Image side length: used both to build the networks and as the target size of the data.
img_size = 64
# Number of image channels the critic reads and the generator outputs.
img_n_channels = 3
batch_size = 128
# NOTE(review): use_cuda is not referenced anywhere else in the visible notebook.
use_cuda = torch.cuda.is_available()
# Gradient penalty coefficient
plambda = 10

# DATA

In [None]:
# URL handed to untar_data to fetch the real images (the LFW dataset, judging by the URL).
ds_url = "http://vis-www.cs.umass.edu/lfw/lfw"

In [None]:
# Fetch the dataset through fastai's untar_data and keep the returned path;
# the bare expression shows that path as the cell output.
realImagesPath = untar_data(ds_url)
realImagesPath

In [None]:
def get_data(path, bs, size, noise_sz=100):
    """Build the data bunch used to train the GAN from the images under `path`.

    Args:
        path: folder with the real images (read with `GANItemList.from_folder`).
        bs: batch size of the resulting data bunch.
        size: side length the images are cropped/padded and resized to.
        noise_sz: size of the noise vectors used as generator input. It must match
            the `noise_sz` the generator is built with; 100 is kept as the default
            for backward compatibility (it is also `basic_generator`'s default).

    Returns:
        The data bunch produced by the pipeline below: no validation split
        (`split_none`), transforms also applied to the targets (`tfm_y=True`),
        and only the targets normalized (`do_x=False, do_y=True`).

    Note:
        The normalization stats are hard-coded for 3 channels (0.5 for both
        tensors), so this helper assumes 3-channel images.
    """
    return (GANItemList.from_folder(path, noise_sz=noise_sz)
               .split_none()
               .label_from_func(noop)
               .transform(tfms=[[crop_pad(size=size, row_pct=(0,1), col_pct=(0,1))], []], size=size, tfm_y=True)
               .databunch(bs=bs)
               .normalize(stats = [torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])], do_x=False, do_y=True))

In [None]:
# Build the data bunch with the configured batch size and image size, then preview one batch.
data = get_data(realImagesPath, batch_size, img_size)
data.show_batch()

# CRITIC

`basic_critic(in_size: int, n_channels: int, n_features: int=64, n_extra_layers: int=0, **conv_kwargs)`

In [None]:
# Critic for img_size inputs with img_n_channels channels and one extra layer.
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)

# GENERATOR

`basic_generator(in_size: int, n_channels: int, noise_sz: int=100, n_features: int=64, n_extra_layers=0, **conv_kwargs)`

In [None]:
# Generator producing img_size images with img_n_channels channels, one extra layer;
# noise_sz is left at its default (100, per the signature above).
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)

# TRAINING

* The parameters of a trained model can be saved with `save_gan_learner`.
* A training run can be resumed (using weights saved during a previous session) with `load_gan_learner`. For example:
        load_gan_learner(learner, models_root + 'wgan-gpTr1a_60ep.pth')
    This must be executed after instantiating the learner and BEFORE running `learner.fit()`.

* Another alternative to launch a long training run is the method `save_checkpoint_gan`. It will automatically save the weights every `n_epochs_save_split` epochs.

## TR 1: lambda = 10

In [None]:
# Gradient penalty coefficient for the TR 1 runs (passed as `plambda` to GANGPLearner.wgan).
plambda = 10

### TR 1a: lr=2e-4

In [None]:
# TR 1a: learning rate later passed to learner.fit.
lr = 2e-4
# Build fresh networks here, like every other experiment cell does, so that re-running
# this cell (or running it after another experiment) never reuses already-trained weights.
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 1a: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr1a_30ep.pth')

In [None]:
# TR 1a: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr1a_60ep.pth')

### TR 1b: lr=5e-4

In [None]:
# TR 1b: higher learning rate, later passed to learner.fit.
lr = 5e-4
# New critic and generator, so this run does not continue from a previous experiment's networks.
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 1b: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr1b_30ep.pth')

In [None]:
# TR 1b: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr1b_60ep.pth')

### TR 1C: lr=5e-4, without batch norm in discriminator

In [None]:
# TR 1c: same learning rate as TR 1b.
lr = 5e-4
# Critic built with the project's custom_critic and norm_type=None (the no-batch-norm variant of this run).
critic = custom_critic(img_size, img_n_channels, n_extra_layers=1, norm_type=None)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 1c: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr1c_30ep.pth')

In [None]:
# TR 1c: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr1c_60ep.pth')

## TR 2: lambda = 50

In [None]:
# Gradient penalty coefficient for the TR 2 runs (passed as `plambda` to GANGPLearner.wgan).
plambda = 50

### TR 2a: lr=2e-4

In [None]:
# TR 2a: learning rate later passed to learner.fit.
lr = 2e-4
# New critic and generator, so this run does not continue from a previous experiment's networks.
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 2a: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr2a_30ep.pth')

In [None]:
# TR 2a: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr2a_60ep.pth')

### TR 2b: lr=5e-4

In [None]:
# TR 2b: higher learning rate, later passed to learner.fit.
lr = 5e-4
# New critic and generator, so this run does not continue from a previous experiment's networks.
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 2b: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr2b_30ep.pth')

In [None]:
# TR 2b: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr2b_60ep.pth')

### TR 2C: lr=5e-4, without batch norm in discriminator

In [None]:
# TR 2c: same learning rate as TR 2b.
lr = 5e-4
# Critic built with the project's custom_critic and norm_type=None (the no-batch-norm variant of this run).
critic = custom_critic(img_size, img_n_channels, n_extra_layers=1, norm_type=None)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 2c: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr2c_30ep.pth')

In [None]:
# TR 2c: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr2c_60ep.pth')

## TR 3: lambda = 2

In [None]:
# Gradient penalty coefficient for the TR 3 runs (passed as `plambda` to GANGPLearner.wgan).
plambda = 2

### TR 3a: lr=2e-4

In [None]:
# TR 3a: learning rate later passed to learner.fit.
lr = 2e-4
# New critic and generator, so this run does not continue from a previous experiment's networks.
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 3a: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr3a_30ep.pth')

In [None]:
# TR 3a: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr3a_60ep.pth')

### TR 3b: lr=5e-4

In [None]:
# TR 3b: higher learning rate, later passed to learner.fit.
lr = 5e-4
# New critic and generator, so this run does not continue from a previous experiment's networks.
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 3b: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr3b_30ep.pth')

In [None]:
# TR 3b: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr3b_60ep.pth')

### TR 3C: lr=5e-4, without batch norm in discriminator

In [None]:
# TR 3c: same learning rate as TR 3b.
lr = 5e-4
# Critic built with the project's custom_critic and norm_type=None (the no-batch-norm variant of this run).
critic = custom_critic(img_size, img_n_channels, n_extra_layers=1, norm_type=None)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 3c: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr3c_30ep.pth')

In [None]:
# TR 3c: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr3c_60ep.pth')

## TR 4: lambda = 0.5

In [None]:
# Gradient penalty coefficient for the TR 4 runs (passed as `plambda` to GANGPLearner.wgan).
plambda = 0.5

### TR 4a: lr=2e-4

In [None]:
# TR 4a: learning rate later passed to learner.fit.
lr = 2e-4
# New critic and generator, so this run does not continue from a previous experiment's networks.
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 4a: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr4a_30ep.pth')

In [None]:
# TR 4a: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr4a_60ep.pth')

### TR 4b: lr=5e-4

In [None]:
# TR 4b: higher learning rate, later passed to learner.fit.
lr = 5e-4
# New critic and generator, so this run does not continue from a previous experiment's networks.
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 4b: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr4b_30ep.pth')

In [None]:
# TR 4b: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr4b_60ep.pth')

### TR 4C: lr=5e-4, without batch norm in discriminator

In [None]:
# TR 4c: same learning rate as TR 4b.
lr = 5e-4
# Critic built with the project's custom_critic and norm_type=None (the no-batch-norm variant of this run).
critic = custom_critic(img_size, img_n_channels, n_extra_layers=1, norm_type=None)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 4c: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr4c_30ep.pth')

In [None]:
# TR 4c: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr4c_60ep.pth')

## TR 5: lambda = 0.1

In [None]:
# Gradient penalty coefficient for the TR 5 runs (passed as `plambda` to GANGPLearner.wgan).
plambda = 0.1

### TR 5a: lr=2e-4

In [None]:
# TR 5a: learning rate later passed to learner.fit.
lr = 2e-4
# New critic and generator, so this run does not continue from a previous experiment's networks.
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 5a: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr5a_30ep.pth')

In [None]:
# TR 5a: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr5a_60ep.pth')

### TR 5b: lr=5e-4

In [None]:
# TR 5b: higher learning rate, later passed to learner.fit.
lr = 5e-4
# New critic and generator, so this run does not continue from a previous experiment's networks.
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 5b: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr5b_30ep.pth')

In [None]:
# TR 5b: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr5b_60ep.pth')

## TR 6: lambda = 250

In [None]:
# Gradient penalty coefficient for the TR 6 runs (passed as `plambda` to GANGPLearner.wgan).
plambda = 250

### TR 6a: lr=2e-4

In [None]:
# TR 6a: learning rate later passed to learner.fit.
lr = 2e-4
# New critic and generator, so this run does not continue from a previous experiment's networks.
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 6a: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr6a_30ep.pth')

In [None]:
# TR 6a: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr6a_60ep.pth')

### TR 6b: lr=5e-4

In [None]:
# TR 6b: higher learning rate, later passed to learner.fit.
lr = 5e-4
# New critic and generator, so this run does not continue from a previous experiment's networks.
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 6b: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr6b_30ep.pth')

In [None]:
# TR 6b: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr6b_60ep.pth')

### TR 6C: lr=5e-4, without batch norm in discriminator

In [None]:
# TR 6c: same learning rate as TR 6b.
lr = 5e-4
# Critic built with the project's custom_critic and norm_type=None (the no-batch-norm variant of this run).
critic = custom_critic(img_size, img_n_channels, n_extra_layers=1, norm_type=None)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 6c: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr6c_30ep.pth')

In [None]:
# TR 6c: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr6c_60ep.pth')

## TR 7: lambda = 50000

In [None]:
# Gradient penalty coefficient for the TR 7 runs (passed as `plambda` to GANGPLearner.wgan).
plambda = 50000

### TR 7a: lr=2e-4

In [None]:
# TR 7a: learning rate later passed to learner.fit.
lr = 2e-4
# New critic and generator, so this run does not continue from a previous experiment's networks.
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 7a: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr7a_30ep.pth')

In [None]:
# TR 7a: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr7a_60ep.pth')

### TR 7b: lr=5e-4

In [None]:
# TR 7b: higher learning rate, later passed to learner.fit.
lr = 5e-4
# New critic and generator, so this run does not continue from a previous experiment's networks.
critic = basic_critic(img_size, img_n_channels, n_extra_layers=1)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 7b: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr7b_30ep.pth')

In [None]:
# TR 7b: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr7b_60ep.pth')

### TR 7C: lr=5e-4, without batch norm in discriminator

In [None]:
# TR 7c: same learning rate as TR 7b.
lr = 5e-4
# Critic built with the project's custom_critic and norm_type=None (the no-batch-norm variant of this run).
critic = custom_critic(img_size, img_n_channels, n_extra_layers=1, norm_type=None)
generator = basic_generator(img_size, img_n_channels, n_extra_layers=1)
# Adam with betas (0, 0.99) and no weight decay; `plambda` is the gradient penalty coefficient.
learner = GANGPLearner.wgan(data, generator, critic, switch_eval=False, 
                            opt_func = partial(optim.Adam, betas = (0.,0.99)), 
                            wd=0., plambda=plambda)

In [None]:
# TR 7c: train for 30 epochs, then save the learner's parameters under models_root.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr7c_30ep.pth')

In [None]:
# TR 7c: train the same learner for 30 more epochs and save under the 60-epoch name.
learner.fit(30, lr)
save_gan_learner(learner, models_root + 'wgan-gpTr7c_60ep.pth')

# FINDINGS AND FACTS

*   Recommended value of lambda = 10 works pretty well. Probably anything between 1-50 is ok.
*   As stated in the paper, batch norm shouldn't be used in the critic network. So, executions tagged with TR [number]A or TR [number]B (like TR 2A, TR 3B, ...) aren't theoretically right; only those ending with C and subsequent letters are OK.
  * With batch norm, training still produces results, but convergence is less smooth.



# POSSIBLE IMPROVEMENTS


* Optimize GANGPLoss._gradient_penalty(). Maybe expand_as is not needed.

