In [None]:
%matplotlib inline

import datetime
from enum import Enum
from pathlib import Path
from time import time, perf_counter as pcounter, process_time as ptime

from IPython.display import HTML
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch._jit_internal import weak_module, weak_script_method
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.modules.conv import _ConvNd as ConvBase
from torch.nn.modules.batchnorm import _BatchNorm as BatchNormBase
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data as tdata
import torchvision.datasets as dsets
import torchvision.transforms as tfs
import torchvision.utils as tvutils

class Norm(Enum):
    """Normalization schemes selectable through the `normalization` parameter.

    `layer_with_norm` decides which layers are built for each member.
    """
    # file name parameter interpolation
    # NOTE(review): the string values are presumably meant to be interpolated
    # into file names -- confirm against the code that uses them.
    BATCH           = 'batch'
    VIRTUAL_BATCH   = 'virtualbatch'
    SPECTRAL        = 'spectral'
    INSTANCE        = 'instance'
    AFFINE_INSTANCE = 'affineinstance'
    NONE            = 'none'

# Parameters

In [None]:
# Paths are written in UNIX-like notation!
# So write `C:\Users\user\GANerator` as `C:/Users/user/GANerator` or `~/GANerator`.

# All parameters that take classes also accept strings of the class.

# Only parameters in the 'Data and Models' section will be saved and loaded!

# Central configuration of the notebook. The processing cell below unpacks these
# values into notebook-level variables.
params = {
    # Experiment specific
    # ===================
    'exp_name':    None,  # File names for this experiment. If `None` or `''`, `append_time` is always `True`.
    'append_time': True,  # Append the current time to the file names (to prevent overwriting).
    'load_dir':    '.',  # Directory to load saved files from. If `save_dir` is `None`, this also acts as `save_dir`.
    'save_dir':    '.',  # Directory to save to. If `None`, use the value of `load_dir`.

    # Load the models and parameters from this experiment (previous `exp_name`).
    # Also insert the optionally appended time (WIP: if this value is otherwise ambiguous).
    # Set the parameters `models_file` or `params_file` below to use file names.
    # If set to `True`, use `exp_name`. If `False` or `None`, do not load.
    'load_exp':          False,
    # Load parameters from this path. Set to `False` to not load. Priority over `load_exp`.
    # Set to `True` to ignore this so it does not override `load_exp`.
    'params_file':       True,
    # Load models from this path. Set to `False` to not load. Priority over `load_exp`.
    # Set to `True` to ignore this so it does not override `load_exp`.
    'models_file':       True,
    'load_weights_only': False,  # Load only the models' weights. To continue training, set this to `False`.

    'save_params':       False,  # Save the parameters in the 'Data and Models' section to a file.
    'save_weights_only': False,  # Save only the models' weights. To continue training later, set this to `False`.
    'checkpoint_period': 100,  # After how many steps to save a model checkpoint. Set to `0` to only save when finished.

    'num_eval_imgs': 64,  # How many images to generate for (temporal) evaluation.


    # Hardware and Multiprocessing
    # ============================
    # Number of workers passed to the `DataLoader` (`num_workers`). Set to `0` to use the CPU count.
    'num_workers':    0,
    'num_gpus':       None,  # Number of GPUs to use. `None` to use all available ones. Set to `0` to run on CPU only.
    'cuda_device_id': 0,  # ID of CUDA device. In most cases, this should be left at `0`.


    # Reproducibility
    # ===============
    # Seed for NumPy and PyTorch. If `None`, a random seed is drawn.
    # The seed that was actually used is always stored in `saved_seed`.
    'seed':                   0,
    'ensure_reproducibility': False,  # If using cuDNN: Set to `True` to ensure reproducibility in favor of performance.
    'flush_denormals':        True,  # Whether to set denormals to zero. Some architectures do not support this.


    # Data and Models
    # ===============
    # Only parameters in this section will be saved and updated when loading.

    # Path to the root folder of the data set. This value is only loaded if set to `None`!
    'dataset_root':  '~/datasets/ffhq',
    # Set this to the torchvision.datasets class (module `dsets`).
    # This value is only loaded if set to `None`!
    'dataset_class': dsets.ImageFolder,
    'epochs':        5,  # Number of training epochs.
    'batch_size':    128,  # Size of each training batch. Strongly depends on other parameters.
    'img_channels':  3,  # Number of channels in the input images. Normally 3 for RGB and 1 for grayscale.
    # Shape of the output images (excluding channel dimension). Can be an integer to get squares.
    # At the moment, an image can only be square sized and a power of two.
    'img_shape':     64,
    'resize':        True,  # If `True`, resize images; if `False`, crop (to the center).

    # NOTE(review): the generator ends in `nn.Tanh`, so its outputs lie in [-1, 1].
    # With mean 0 and std 1 the real images presumably stay in [0, 1] after `ToTensor` --
    # confirm whether 0.5 / 0.5 was intended.
    'data_mean':     0.0,  # Data is normalized to this mean (per channel).
    'data_std':      1.0,  # Data is normalized to this standard deviation (per channel).
    'float_dtype':   torch.float32,  # Float precision as `torch.dtype`.
    'g_input':       128,  # Size of the generator's random input vectors (`z` vector).

    # GAN hacks
    'g_flip_labels':       False,  # Switch labels for the generator's training step.
    'd_noisy_labels_prob': 0.0,  # Probability to switch labels when training the discriminator.
    'smooth_labels':       False,  # Replace discrete labels with slightly different continuous ones.


    # Values in this paragraph can be either a single value (e.g. an `int`) or a 2-`tuple` of the same type.
    # If a single value, that value will be applied to both the discriminator and generator network.
    # If a 2-`tuple`, the first value will be applied to the discriminator, the second to the generator.
    # A tuple or list with a single entry is also accepted; its entry is used for both networks.
    'features':      64,  # Relative size of the network's internal features.
    'optimizer':     optim.Adam,  # Optimizer class. GAN hacks recommends `(optim.SGD, optim.Adam)`.
    # Optimizer learning rate. (Second optimizer argument, so not necessarily learning rate.)
    'lr':            0.0002,
    # Third optimizer argument. (For example, `betas` for `Adam` or `momentum` for `SGD`.)
    # A tuple-valued argument such as `betas` must be wrapped in an extra 1-tuple (as done here);
    # otherwise it would be read as a (discriminator, generator) pair.
    'optim_param':   ((0.5, 0.999),),
    # Any further optimizer keyword arguments as a dictionary.
    'optim_kwargs':  {},
    # Kind of normalization. Must be a `Norm` or in `('b', 'v', 's', 'i', 'a', 'n')`.
    # The string value of a `Norm` member (e.g. `'batch'`) is accepted as well.
    # Usually, spectral normalization is used in the discriminator while
    # virtual batch normalization is used in the generator.
    'normalization': Norm.BATCH,
    'activation':    (nn.LeakyReLU, nn.ReLU),  # Activation between hidden layers. GAN hacks recommends `nn.LeakyReLU`.
    # Activation keyword arguments.
    'activation_kwargs': ({
            'negative_slope': 0.2,
            'inplace': True
    }, {
            'inplace': True
    }),
}

In [None]:
# Process parameters

# Model parameters as tuples
# Keys of `params` that may be given either as one value (shared by both networks)
# or as a (discriminator, generator) pair; they are expanded with `param_as_ntuple`.
tuple_params = (
    'features',
    'optimizer',
    'lr',
    'optim_param',
    'optim_kwargs',
    'normalization',
    'activation',
    'activation_kwargs',
)

# Parameters that we do *not* want to save (or load).
# We list these instead of the model parameters as those should be easier to extend.
# This is a list (not a tuple) because `dataset_root` and `dataset_class` are
# appended later when they were set explicitly.
static_params = [
    'exp_name',
    'append_time',
    'load_dir',
    'save_dir',

    'load_exp',
    'params_file',
    'models_file',
    'load_weights_only',

    'save_params',
    'save_weights_only',
    'checkpoint_period',

    # Belongs to the 'Experiment specific' section, so it must not be saved either.
    'num_eval_imgs',

    'num_workers',
    'num_gpus',
    'cuda_device_id',

    'seed',
    'ensure_reproducibility',
    'flush_denormals',
]


def string_to_class(string):
    """Resolve a (dotted) name given as a string to the object it names.

    A name without dots is looked up among the Python builtins (``'int'``).
    A dotted name is resolved by looking up its first part in this module's
    globals and following the remaining parts as attributes (``'optim.Adam'``).

    Args:
        string: Name to resolve. Anything that is not a `str` is returned
            unchanged, so callers can pass either a class or its name.

    Returns:
        The resolved object, or `string` itself if it is not a `str`.

    Raises:
        KeyError: If the first part of a dotted name is not a global name.
        AttributeError: If a builtin or an attribute does not exist.
    """
    # `__builtins__` is only a module in `__main__`; elsewhere it is a plain
    # dict, on which `getattr` fails. The `builtins` module works everywhere.
    import builtins

    if not isinstance(string, str):
        return string

    parts = string.split('.')
    if len(parts) == 1:
        return getattr(builtins, parts[0])

    resolved = globals()[parts[0]]
    for part in parts[1:]:
        resolved = getattr(resolved, part)
    return resolved


# Experiment name

append_time = params['append_time']
exp_name    = params['exp_name']
# The time stamp is appended on request and always when no name was given.
if not exp_name or append_time:
    # The previous check `exp_name is not str` compared the value with the *type*
    # `str`, which is always true, so a given name was always thrown away.
    if not isinstance(exp_name, str):
        exp_name = ''
    exp_name = ''.join((exp_name, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')))


# Load parameters

load_dir = params['load_dir']
save_dir = params['save_dir']
if save_dir is None:
    save_dir = load_dir

load_exp = params['load_exp']

params_file = params['params_file']
# Truthy when a parameter file has to be loaded: either an explicit path was given
# or `params_file` is `True` and `load_exp` names an experiment.
load_params = params_file and (load_exp or type(params_file) is str)

dataset_root  = params['dataset_root']
dataset_class = string_to_class(params['dataset_class'])

# Check whether these parameters are `None`.
# If yes, check that parameters loading is enabled. Otherwise do not update them.
if dataset_root is None:
    assert load_params, '`dataset_root` cannot be `None` if not loading parameters.'
else:
    static_params.append('dataset_root')
if dataset_class is None:
    assert load_params, '`dataset_class` cannot be `None` if not loading parameters.'
else:
    static_params.append('dataset_class')


if load_params:
    if type(params_file) is str:
        params_path = Path(params_file)
    elif type(load_exp) is bool:
        # `load_exp` is `True`: load the file belonging to the current `exp_name`.
        params_path = Path('{}/params_{}.pt'.format(load_dir, exp_name))
    else:
        params_path = Path('{}/params_{}.pt'.format(load_dir, load_exp))

    params_path = params_path.expanduser()
    # NOTE: `torch.load` unpickles the file; only load parameter files you trust.
    upd_params = torch.load(params_path)
    # Never let a file override parameters that are declared as not loadable.
    params.update({key: value for key, value in upd_params.items()
                   if key not in static_params})
    del upd_params

    # The two values below were read before loading. If they were left at `None`,
    # they are supposed to come from the file, so read them again now.
    if dataset_root is None:
        dataset_root = params['dataset_root']
    if dataset_class is None:
        dataset_class = string_to_class(params['dataset_class'])
elif params_file == '':
    print("`params_file` is an empty string (`''`). Parameters were not loaded. "
          'Set to `False` to suppress this warning or to `True` to let `load_exp` handle loading.')

# Hardware and multiprocessing

num_gpus       = params['num_gpus']
cuda_device_id = params['cuda_device_id']
# `None` means "use every GPU PyTorch can see".
if num_gpus is None:
    num_gpus = torch.cuda.device_count()
    print('Using {} GPUs.'.format(num_gpus))
use_gpus = num_gpus > 0
# More than one GPU makes the models get wrapped in `nn.DataParallel` later on.
multiple_gpus = num_gpus > 1
if use_gpus:
    assert torch.cuda.is_available(), 'CUDA is not available. ' \
            'Check what is wrong or set `num_gpus` to `0` to run on CPU.'  # Never check for this again
    device = torch.device('cuda:' + str(cuda_device_id))
else:
    device = torch.device('cpu')
    
# `0` (or any falsy value) selects one worker per CPU reported by `mp.cpu_count()`.
num_workers = params['num_workers']
if not num_workers:
    num_workers = mp.cpu_count()
    print('Using {} worker threads.'.format(num_workers))


# Load model

# Load the model checkpoint (if requested) into `models_cp`; it stays `None`
# when nothing is loaded. The path is chosen like the parameter file's path:
# an explicit path wins, then `exp_name` (`load_exp` is `True`), then `load_exp`.
models_file = params['models_file']
models_cp = None
if models_file and (load_exp or type(models_file) is str):
    if type(models_file) is str:
        models_path = Path(models_file)
    elif type(load_exp) is bool:
        models_path = Path('{}/models_{}.tar'.format(load_dir, exp_name))
    else:
        models_path = Path('{}/models_{}.tar'.format(load_dir, load_exp))
    models_path = models_path.expanduser()
    # `map_location` places the loaded tensors on the device selected above.
    models_cp = torch.load(models_path, map_location=device)
elif models_file == '':
    print("`models_file` is an empty string (`''`). Models were not loaded. "
          'Set to `False` to suppress this warning or to `True` to let `load_exp` handle loading.')


# Reproducibility

# Seed NumPy and PyTorch. Without a given seed, one is drawn from [0, 10000).
seed = params['seed']
if seed is None:
    seed = np.random.randint(10000)
print('Seed: {}.'.format(seed))
# Remember the seed actually used so that it is stored with the saved parameters.
params['saved_seed'] = seed
np.random.seed(seed)
torch.manual_seed(seed)

ensure_reproducibility = params['ensure_reproducibility']
torch.backends.cudnn.deterministic = ensure_reproducibility
if ensure_reproducibility:
    torch.backends.cudnn.benchmark = False  # This is the default but do it anyway
    
flush_denormals = params['flush_denormals']
# `set_flush_denormal` reports whether the mode could be set; the check below
# treats a falsy result as "not supported".
set_flush_success = torch.set_flush_denormal(flush_denormals)
if flush_denormals and not set_flush_success:
    print('Not able to flush denormals. `flush_denormals` set to `False`.')
    flush_denormals = False


# Dataset root

# Expand `~` so that the UNIX-like notation from the parameters works.
dataset_root = Path(dataset_root).expanduser()


# Floating point precision

# `float_dtype` may be given as a `torch.dtype` or as its dotted name.
float_dtype = string_to_class(params['float_dtype'])
if float_dtype is torch.float16:
    print('PyTorch does not support half precision well yet. Be careful and assume errors.')
# Makes newly created floating point tensors use this precision by default.
torch.set_default_dtype(float_dtype)


# Parameters we do not need to process

load_weights_only = params['load_weights_only']
save_weights_only = params['save_weights_only']
checkpoint_period = params['checkpoint_period']
num_eval_imgs     = params['num_eval_imgs']

epochs       = params['epochs']
batch_size   = params['batch_size']
img_channels = params['img_channels']
resize       = params['resize']

data_mean   = params['data_mean']
data_std    = params['data_std']
g_input     = params['g_input']

g_flip_labels       = params['g_flip_labels']
d_noisy_labels_prob = params['d_noisy_labels_prob']
smooth_labels       = params['smooth_labels']

# The value is compared against `np.random.rand()` later, so it has to be a probability.
assert 0.0 <= d_noisy_labels_prob <= 1.0, \
        'Invalid probability for `d_noisy_labels`. Must be between 0 and 1 inclusively.'

# Single or tuple parameters

def param_as_ntuple(key, n=2):
    """Return the parameter `key` from `params` expanded to an `n`-tuple.

    A value that is neither a `tuple` nor a `list` is repeated `n` times.
    A tuple or list must hold between 1 and `n` entries; if it is shorter
    than `n`, its last entry is repeated until the length fits (with a
    notice when more than one entry was given).

    Args:
        key: Key of the parameter in the global `params` dictionary.
        n: Length of the returned tuple.

    Returns:
        A `tuple` of length `n`.

    Raises:
        AssertionError: If a tuple or list is empty or longer than `n`.
    """
    value = params[key]

    # Scalars (anything but tuple/list) are simply replicated.
    if type(value) not in (tuple, list):
        return (value,) * n

    length = len(value)
    assert 0 < length <= n, 'Tuples should have length {} (`{}` is `{}`).'.format(n, key, value)

    if length == n:
        return tuple(value)

    if length > 1:
        print('`{}` is `{}`. Length is less than {}; '.format(key, value, n)
              + 'last entry has been repeated to fit length.')
    padding = (value[-1],) * (n - length)
    return tuple(value) + padding

def ispow2(x):
    """Return whether `x` is a positive integer power of two (1, 2, 4, ...).

    Integral floats such as `64.0` are accepted. Zero, negative numbers,
    fractions and values that cannot be converted to `int` yield `False`
    instead of raising.
    """
    try:
        as_int = int(x)
    except (TypeError, ValueError, OverflowError):
        return False
    if as_int != x or as_int <= 0:
        return False
    # A power of two has exactly one bit set, so clearing the lowest set bit gives zero.
    return as_int & (as_int - 1) == 0


# `img_shape` may be a single integer; expand it to (height, width) and check the
# restrictions stated in the parameters (square, power of two).
img_shape = param_as_ntuple('img_shape')
assert img_shape[0] == img_shape[1], '`img_shape` must be square (same width and height).'
assert ispow2(img_shape[0]), '`img_shape` must be a power of two (2^n).'

# Split every per-network parameter into its discriminator and generator value.
d_params = {}
g_params = {}
for key in tuple_params:
    d_params[key], g_params[key] = param_as_ntuple(key)


# Normalization and class parameters

# Turn the normalization, optimizer and activation of both networks into
# the actual `Norm` member / classes, whatever notation was used for them.
for p in d_params, g_params:
    normalization = p['normalization']
    # One-letter shortcuts, case-insensitive. The lookup has to use the lowered
    # key as well; otherwise an upper case letter passes the test but fails the lookup.
    if isinstance(normalization, str) and normalization.lower() in ('b', 'v', 's', 'i', 'a', 'n'):
        normalization = {'b': Norm.BATCH, 'v': Norm.VIRTUAL_BATCH,
                         's': Norm.SPECTRAL, 'i': Norm.INSTANCE,
                         'a': Norm.AFFINE_INSTANCE, 'n': Norm.NONE}[normalization.lower()]
    if not isinstance(normalization, Norm):
        try:
            # Value of a member, e.g. `'batch'`.
            normalization = Norm(normalization)
        except ValueError:
            try:
                # Dotted name of a member, e.g. `'Norm.BATCH'`.
                normalization = string_to_class(normalization)
            except (KeyError, AttributeError):
                pass  # Not resolvable; reported by the assertion below.
        assert isinstance(normalization, Norm), \
                "Unknown normalization. Must be a `Norm` or in `('b', 'v', 's', 'i', 'a', 'n')`."
    p['normalization'] = normalization

    p['optimizer'] = string_to_class(p['optimizer'])
    p['activation'] = string_to_class(p['activation'])


# Path template for model checkpoints. The doubled braces survive this first
# `format` call as `{}`, which `save_checkpoint` fills with the step count.
save_models_path_str = '{}/models_{}_{{}}_steps.tar'.format(save_dir, exp_name)


# Save parameters

save_params = params['save_params']
if save_params:
    # We save even if we load to associate the parameters with the experiment
    save_params_path = Path('{}/params_{}.pt'.format(save_dir, exp_name)).expanduser()
    # Work on a shallow copy so that `params` itself keeps the static entries.
    save_params_ = params.copy()
    for key in static_params:
        del save_params_[key]
    torch.save(save_params_, save_params_path)
    del save_params_

In [None]:
# Prepare dataset

plt.rcParams['figure.figsize'] = (12, 9)  # Larger plots by default

# Preprocessing pipeline: bring images to `img_shape`, convert them to tensors
# and normalize each channel with `data_mean` and `data_std`.
tfs_list = [
    tfs.Resize(img_shape),
    tfs.ToTensor(),
    tfs.Normalize((data_mean,) * img_channels, (data_std,) * img_channels)
]
# Without resizing, the first step is replaced by a crop to the image center.
if not resize:
    tfs_list[0] = tfs.CenterCrop(img_shape)
transform = tfs.Compose(tfs_list)
dataset = dataset_class(dataset_root, transform=transform)

dataloader = tdata.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)

In [None]:
# Show an example batch of training images

show_example = True

# Number of images per row of the evaluation grid.
example_cols = int(np.sqrt(num_eval_imgs))

# The bare name lookup raises `NameError` on the first run of this cell only,
# so re-running the cell keeps the batch and noise that already exist.
try:
    example_batch
except NameError:
    # Only define `example_batch` once
    example_batch = next(iter(dataloader))
    # Fixed generator input, `num_eval_imgs` vectors of size `g_input`.
    static_noise = torch.randn(num_eval_imgs, g_input, 1, 1, device=device).to(device, float_dtype)
    # One full batch of generator input; passed to the generator as reference batch below.
    example_noise = torch.randn(batch_size, g_input, 1, 1, device=device)

if show_example:
    plt.figure()
    plt.axis('off')
    plt.title('Example Training Batch')
    # `make_grid` returns a channel-first image; the transpose moves channels last for `imshow`.
    plt.imshow(np.transpose(tvutils.make_grid(example_batch[0].to(device, float_dtype)[:num_eval_imgs],
                                              nrow=example_cols, padding=2, normalize=True).cpu(),  # show
                            (1, 2, 0)))  # show

In [None]:
# Model helper methods

@weak_module
class VirtualBatchNorm2d(nn.Module):
    """Virtual batch normalization for 4D inputs.

    Each forward pass normalizes two batches: the input batch and a reference
    batch. The input batch is normalized with statistics blended from both
    batches, the reference batch with its own statistics only. All means are
    taken over dimension 0 only, so statistics are kept separately for every
    remaining position of the input.

    Args:
        num_features: Size of dimension 1 of the inputs; determines the shape
            `(1, num_features, 1, 1)` of the affine parameters.
        eps: Value added to the variance before taking the square root.
        affine: Whether to apply a learnable scale (`weight`) and shift (`bias`).
    """

    def __init__(self, num_features, eps=1e-5, affine=True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
            self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters(True)

    def reset_parameters(self, all=False):
        """Re-initialize the affine parameters.

        With `all` set, the blending coefficients are cleared as well; they are
        then derived again from the next reference batch.
        """
        if self.affine:
            nn.init.uniform_(self.weight)
            nn.init.zeros_(self.bias)
        if all:
            self.in_coef = None
            self.ref_coef = None

    @weak_script_method
    def forward(self, input, ref_batch):
        """Normalize `input` and `ref_batch`.

        Returns:
            A pair `(normalized input, normalized reference batch)`.

        Raises:
            ValueError: If `input` is not 4D (or, on the first call, if
                `ref_batch` is not 4D).
        """
        self._check_input_dim(input)
        if self.in_coef is None:
            # First call: fix the weights of both batches' statistics based on
            # the size of the reference batch.
            self._check_input_dim(ref_batch)
            self.in_coef = 1 / (len(ref_batch) + 1)
            self.ref_coef = 1 - self.in_coef

        mean, std, ref_mean, ref_std = self.calculate_statistics(input, ref_batch)
        return self.normalize(input, mean, std), self.normalize(ref_batch, ref_mean, ref_std)

    @weak_script_method
    def calculate_statistics(self, input, ref_batch):
        """Return `(mean, std, ref_mean, ref_std)` for the two batches.

        `mean` and `std` blend the statistics of both batches with the weights
        `in_coef` and `ref_coef`; `ref_mean` and `ref_std` use the reference
        batch alone.
        """
        in_mean,  in_sqmean  = self.calculate_means(input)
        ref_mean, ref_sqmean = self.calculate_means(ref_batch)

        mean   = self.in_coef * in_mean   + self.ref_coef * ref_mean
        sqmean = self.in_coef * in_sqmean + self.ref_coef * ref_sqmean

        # E[x^2] - E[x]^2 can become slightly negative through rounding, which
        # would turn the square root into NaN; clamp the variance at zero first.
        std     = torch.sqrt(torch.clamp(sqmean     - mean**2,     min=0) + self.eps)
        ref_std = torch.sqrt(torch.clamp(ref_sqmean - ref_mean**2, min=0) + self.eps)
        return mean, std, ref_mean, ref_std

    # TODO could be @staticmethod, but check @weak_script_method first
    @weak_script_method
    def calculate_means(self, batch):
        """Return the mean and the mean of squares of `batch` over dimension 0."""
        mean   = torch.mean(batch,    0, keepdim=True)
        sqmean = torch.mean(batch**2, 0, keepdim=True)
        return mean, sqmean

    @weak_script_method
    def normalize(self, batch, mean, std):
        """Standardize `batch` and apply the affine transform if enabled."""
        normalized = (batch - mean) / std
        # Without `affine`, `weight` and `bias` are `None` and must not be applied.
        if self.affine:
            normalized = normalized * self.weight + self.bias
        return normalized

    @weak_script_method
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))


def powers(n, b=2):
    """Yield `n` pairs of consecutive powers of `b`.

    The pairs are `(b**k, b**(k + 1))` for `k = 0, ..., n - 1`, i.e.
    `(1, b), (b, b**2), ...`. Nothing is yielded for `n <= 0`.
    """
    current = 1
    for _ in range(n):
        following = current * b
        yield current, following
        current = following


def layer_with_norm(layer, norm, features):
    """Combine `layer` with the normalization selected by `norm`.

    Args:
        layer: The layer to normalize.
        norm: Member of `Norm` choosing the kind of normalization.
        features: Number of features passed to the normalization layer.

    Returns:
        A tuple of modules: `(layer, normalization layer)` for the batch,
        virtual batch and instance variants, or a 1-tuple holding the
        (spectrally normalized or untouched) layer otherwise.

    Raises:
        ValueError: If `norm` is not a known `Norm` member.
    """
    # Variants that do not add a separate normalization layer.
    if norm is Norm.SPECTRAL:
        return (nn.utils.spectral_norm(layer),)
    if norm is Norm.NONE:
        return (layer,)

    # Variants that append a normalization layer behind `layer`.
    if norm is Norm.BATCH:
        norm_layer = nn.BatchNorm2d(features)
    elif norm is Norm.VIRTUAL_BATCH:
        norm_layer = VirtualBatchNorm2d(features)
    elif norm is Norm.INSTANCE:
        norm_layer = nn.InstanceNorm2d(features)
    elif norm is Norm.AFFINE_INSTANCE:
        norm_layer = nn.InstanceNorm2d(features, affine=True)
    else:
        raise ValueError("Unknown normalization `'{}'`".format(norm))
    return (layer, norm_layer)

In [None]:
# Define and initialize models

# Discriminator

class Discriminator(nn.Module):
    """Convolutional discriminator mapping an image batch to one value per image.

    Args:
        normalization: `Norm` member applied to the hidden layers.
        activation: Class of the activation between hidden layers.
        activation_kwargs: Keyword arguments for `activation`.
        img_channels: Number of channels of the input images.
        img_shape: (height, width) of the input images; only `img_shape[0]`
            is used and it must be a power of two of at least 8.
        features: Base number of feature maps.
        reference_batch: Batch of images used as reference batch; required
            (and only used) with virtual batch normalization.

    Raises:
        ValueError: If virtual batch normalization is selected without a
            `reference_batch`, or if `img_shape` is too small.
    """

    def __init__(self, normalization, activation, activation_kwargs,
                 img_channels, img_shape, features, reference_batch=None):
        super().__init__()
        layers = self.build_layers(normalization, activation, activation_kwargs,
                                   img_channels, img_shape, features)
        if normalization is not Norm.VIRTUAL_BATCH:
            self.reference_batch = None  # we can test for VBN with this invariant
            self.layers = nn.Sequential(*layers)
        elif reference_batch is None:
            raise ValueError('Normalization is virtual batch norm, but '
                    '`reference_batch` is `None` or missing.')
        else:
            # NOTE(review): stored as a plain attribute, so `.to(...)` and
            # `nn.DataParallel` presumably do not move or split it -- confirm.
            self.reference_batch = reference_batch  # never `None`
            # VBN layers take two inputs, so the layers are applied manually in `forward`.
            self.layers = nn.ModuleList(layers)

    @staticmethod
    def build_layers(norm, activation, activation_kwargs, img_channels, img_shape, features):
        """
        Return a list of the layers for the discriminator network.

        Raises `ValueError` if `img_shape[0]` is smaller than 8.

        Example for a 64 x 64 image:
        >>> Discriminator.build_layers(Norm.BATCH, nn.LeakyReLU, {'negative_slope': 0.2, 'inplace': True},
                                       img_channels=3, img_shape=(64, 64), features=64)
        [
            # input size is 3 x 64 x 64 (given by `img_channels` and `img_shape`)
            nn.Conv2d(img_channels, features, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, True),
            # state size is (features) x 32 x 32
            nn.Conv2d(features, features * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 2),
            nn.LeakyReLU(0.2, True),
            # state size is (features * 2) x 16 x 16
            nn.Conv2d(features * 2, features * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 4),
            nn.LeakyReLU(0.2, True),
            # state size is (features * 4) x 8 x 8
            nn.Conv2d(features * 4, features * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 8),
            nn.LeakyReLU(0.2, True),
            # state size is (features * 8) x 4 x 4
            nn.Conv2d(features * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # output size is 1 (scalar value)
        ]
        """
        # The first layer halves the image once and the last one expects 4 x 4,
        # which leaves log2(size) - 3 halving layers in between.
        hidden_layers = int(np.log2(img_shape[0])) - 3
        if hidden_layers < 0:
            raise ValueError('`img_shape` must be at least 8 (got {}).'.format(img_shape[0]))

        # input size is (img_channels) x (img_shape[0]) x (img_shape[1])
        layers = [
            nn.Conv2d(img_channels, features, 4, 2, 1, bias=False),
            activation(**activation_kwargs)
        ]
        # state size is (features) x (img_shape[0] / 2) x (img_shape[1] / 2)
        # each further layer doubles feature size and halves image size
        # Start at 1 so that the multiplier is defined when there is no hidden
        # layer at all (8 x 8 images); it used to be unbound in that case.
        j = 1
        for i, j in powers(hidden_layers):
            layers.extend((
                *layer_with_norm(nn.Conv2d(features * i, features * j, 4, 2, 1, bias=False),
                                 norm, features * j),
                activation(**activation_kwargs)
            ))
        # state size is (features * 2^n) x 4 x 4
        layers.extend((
            nn.Conv2d(features * j, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        ))
        # output size is 1 (scalar value)
        return layers

    @weak_script_method
    def forward(self, input):
        """Classify `input`; with VBN the reference batch is propagated alongside."""
        # Separation is for performance reasons
        if self.reference_batch is None:
            return self.layers(input)
        else:
            # VBN
            ref_batch = self.reference_batch
            for layer in self.layers:
                if not isinstance(layer, VirtualBatchNorm2d):
                    # Ordinary layers process both batches independently.
                    input     = layer(input)
                    ref_batch = layer(ref_batch)
                else:
                    input, ref_batch = layer(input, ref_batch)
            return input


# Generator

class Generator(nn.Module):
    """Transposed-convolutional generator mapping noise vectors to images.

    Args:
        normalization: `Norm` member applied to the hidden layers.
        activation: Class of the activation between hidden layers.
        activation_kwargs: Keyword arguments for `activation`.
        img_channels: Number of channels of the generated images.
        img_shape: (height, width) of the generated images; only
            `img_shape[0]` is used and it must be a power of two of at least 8.
        features: Base number of feature maps.
        g_input: Size of the input noise vectors.
        reference_batch: Batch of noise used as reference batch; required
            (and only used) with virtual batch normalization.

    Raises:
        ValueError: If virtual batch normalization is selected without a
            `reference_batch`, or if `img_shape` is too small.
    """

    def __init__(self, normalization, activation, activation_kwargs,
                 img_channels, img_shape, features, g_input, reference_batch=None):
        super().__init__()
        layers = self.build_layers(normalization, activation, activation_kwargs,
                                   img_channels, img_shape, features, g_input)
        if normalization is not Norm.VIRTUAL_BATCH:
            self.reference_batch = None  # we can test for VBN with this invariant
            self.layers = nn.Sequential(*layers)
        elif reference_batch is None:
            raise ValueError('Normalization is virtual batch norm, but '
                    '`reference_batch` is `None` or missing.')
        else:
            # NOTE(review): stored as a plain attribute, so `.to(...)` and
            # `nn.DataParallel` presumably do not move or split it -- confirm.
            self.reference_batch = reference_batch  # never `None`
            # VBN layers take two inputs, so the layers are applied manually in `forward`.
            self.layers = nn.ModuleList(layers)

    @staticmethod
    def build_layers(norm, activation, activation_kwargs, img_channels, img_shape, features, g_input):
        """
        Return a list of the layers for the generator network.

        Raises `ValueError` if `img_shape[0]` is smaller than 8.

        Example for a 64 x 64 image:
        >>> Generator.build_layers(Norm.BATCH, nn.ReLU, {'inplace': True},
                                   img_channels=3, img_shape=(64, 64), features=64, g_input=128)
        [
            # input size is 128 (given by `g_input`)
            nn.ConvTranspose2d(g_input, features * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(features * 8),
            nn.ReLU(True),
            # state size is (features * 8) x 4 x 4
            nn.ConvTranspose2d(features * 8, features * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 4),
            nn.ReLU(True),
            # state size is (features * 4) x 8 x 8
            nn.ConvTranspose2d(features * 4, features * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 2),
            nn.ReLU(True),
            # state size is (features * 2) x 16 x 16
            nn.ConvTranspose2d(features * 2, features, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features),
            nn.ReLU(True),
            # state size is (features) x 32 x 32
            nn.ConvTranspose2d(features, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # output size is 3 x 64 x 64 (given by `img_channels` and `img_shape`)
        ]
        """
        # The first layer produces 4 x 4 and the last one doubles once more,
        # which leaves log2(size) - 3 doubling layers in between.
        hidden_layers = int(np.log2(img_shape[0])) - 3
        if hidden_layers < 0:
            # A negative exponent would give a fractional feature multiplier below.
            raise ValueError('`img_shape` must be at least 8 (got {}).'.format(img_shape[0]))

        j = 2 ** hidden_layers
        # input size is (g_input)
        layers = [
            *layer_with_norm(nn.ConvTranspose2d(g_input, features * j, 4, 1, 0, bias=False),
                             norm, features * j),
            activation(**activation_kwargs)
        ]
        # state size is (features * 2^n) x 4 x 4
        # each further layer halves feature size and doubles image size
        while j > 1:
            i = j
            j //= 2
            layers.extend((
                *layer_with_norm(nn.ConvTranspose2d(features * i, features * j, 4, 2, 1, bias=False),
                                 norm, features * j),
                activation(**activation_kwargs)
            ))
        # state size is (features) x (img_shape[0] / 2) x (img_shape[1] / 2)
        layers.extend((
            nn.ConvTranspose2d(features, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        ))
        # output size is (img_channels) x (img_shape[0]) x (img_shape[1])
        return layers

    @weak_script_method
    def forward(self, input):
        """Generate images from `input`; with VBN the reference batch is propagated alongside."""
        # Separation is for performance reasons
        if self.reference_batch is None:
            return self.layers(input)
        else:
            # VBN
            ref_batch = self.reference_batch
            for layer in self.layers:
                if not isinstance(layer, VirtualBatchNorm2d):
                    # Ordinary layers process both batches independently.
                    input     = layer(input)
                    ref_batch = layer(ref_batch)
                else:
                    input, ref_batch = layer(input, ref_batch)
            return input


# Initialization
    
def init_weights(module):
    """Initialize `module` in place; meant to be passed to `nn.Module.apply`.

    Convolution weights are drawn from N(0, 0.02). Batch normalization weights
    are drawn from N(1, 0.02) and their biases are set to zero. Modules of any
    other kind are left untouched.
    """
    if isinstance(module, ConvBase):
        nn.init.normal_(module.weight.data, 0.0, 0.02)
    elif isinstance(module, BatchNormBase):
        # Batch norm layers created with `affine=False` have neither weight nor bias.
        if module.weight is not None:
            nn.init.normal_(module.weight.data, 1.0, 0.02)
        if module.bias is not None:
            nn.init.constant_(module.bias.data, 0)


# Both networks receive a reference batch (example images / example noise);
# the networks only keep it when virtual batch normalization is selected.
d_net = Discriminator(d_params['normalization'], d_params['activation'], d_params['activation_kwargs'],
                      img_channels, img_shape, d_params['features'],
                      example_batch[0].to(device, float_dtype)).to(device, float_dtype)
g_net = Generator(g_params['normalization'], g_params['activation'], g_params['activation_kwargs'],
                  img_channels, img_shape, g_params['features'], g_input,
                  example_noise.to(device, float_dtype)).to(device, float_dtype)

# Load models' checkpoints

# Done before the `DataParallel` wrapping because `save_checkpoint` stores the
# state dicts of the unwrapped modules.
if models_cp is not None:
    d_net.load_state_dict(models_cp['d_net_state_dict'])
    g_net.load_state_dict(models_cp['g_net_state_dict'])

if multiple_gpus:
    d_net = nn.DataParallel(d_net, list(range(num_gpus)))
    g_net = nn.DataParallel(g_net, list(range(num_gpus)))

# Fresh models (nothing loaded) get the custom weight initialization.
if models_cp is None:
    d_net.apply(init_weights)
    g_net.apply(init_weights)

In [None]:
# Optimizers

criterion = nn.BCELoss()

real_label = 1
fake_label = 0

d_optim_cls = d_params['optimizer']
g_optim_cls = g_params['optimizer']

# `lr` and `optim_param` are passed positionally as the second and third
# optimizer argument; what they mean depends on the optimizer class.
d_optimizer = d_optim_cls(d_net.parameters(), d_params['lr'], d_params['optim_param'], **d_params['optim_kwargs'])
g_optimizer = g_optim_cls(g_net.parameters(), g_params['lr'], g_params['optim_param'], **g_params['optim_kwargs'])

# Load optimizers' checkpoints

if models_cp is not None:
    if not load_weights_only:
        try:
            d_optim_state_dict = models_cp['d_optim_state_dict']
            g_optim_state_dict = models_cp['g_optim_state_dict']
        except KeyError as err:
            # Stop here with the explanation. Continuing would only fail a few
            # lines later with a `NameError` for the missing state dicts.
            raise KeyError("One of the optimizers' state dicts was not found; probably because "
                           "only the models' weights were saved. Set `load_weights_only` to `True`.") from err
        d_optimizer.load_state_dict(d_optim_state_dict)
        g_optimizer.load_state_dict(g_optim_state_dict)
        # Full checkpoint: training is continued.
        d_net.train()
        g_net.train()
    else:
        # Weights only: the models are put into evaluation mode.
        d_net.eval()
        g_net.eval()

In [None]:
# Helpers for training

# Training progress: taken from a full checkpoint when one was loaded,
# otherwise started from scratch.
if models_cp is not None and not load_weights_only:
    start_epoch = models_cp['epoch']
    steps       = models_cp['steps']
    # Index of the batch within the epoch at which training resumes.
    start_i     = models_cp['i']
    d_losses    = models_cp['d_losses']
    g_losses    = models_cp['g_losses']
    eval_imgs   = models_cp['eval_imgs']
else:
    start_epoch = 0
    steps       = 0
    start_i     = 0
    d_losses    = []
    g_losses    = []
    eval_imgs   = []


def save_checkpoint(save_models_path_str, save_weights_only, multiple_gpus,
                    epoch, steps, i, d_losses, g_losses, eval_imgs,
                    d_optimizer, g_optimizer, d_net, g_net):
    """Write a checkpoint of both networks (and optionally the training state).

    Args:
        save_models_path_str: Path template with one `{}` placeholder that is
            filled with `steps + 1`.
        save_weights_only: If true, only the networks' state dicts are stored.
        multiple_gpus: If true, the networks are wrappers whose wrapped
            network is available as `.module`; that one is stored.
        epoch, steps, i: Current epoch, step count and batch index. `steps`
            and `i` are stored incremented by one.
        d_losses, g_losses, eval_imgs: Training history stored as given.
        d_optimizer, g_optimizer: Optimizers whose state dicts are stored.
        d_net, g_net: Discriminator and generator network.
    """
    checkpoint = {}
    if not save_weights_only:
        # Everything needed to continue training later on.
        checkpoint.update(
            epoch=epoch,
            steps=steps + 1,
            i=i + 1,
            d_losses=d_losses,
            g_losses=g_losses,
            eval_imgs=eval_imgs,
            d_optim_state_dict=d_optimizer.state_dict(),
            g_optim_state_dict=g_optimizer.state_dict(),
        )

    # Store the wrapped networks so that the keys carry no wrapper prefix.
    d_module, g_module = (d_net.module, g_net.module) if multiple_gpus else (d_net, g_net)
    checkpoint['d_net_state_dict'] = d_module.state_dict()
    checkpoint['g_net_state_dict'] = g_module.state_dict()

    target = Path(save_models_path_str.format(steps + 1)).expanduser()
    torch.save(checkpoint, target)


def generate_labels(curr_batch_size, actual_label, switch_label,
                    device, float_dtype, smooth_labels, noisy_labels_prob):
    """Create the label tensor for one batch.

    With probability `noisy_labels_prob` the whole batch is labeled with
    `switch_label` instead of `actual_label`. With `smooth_labels`, uniform
    noise from [-0.2, 0.2) is added to every label.

    Args:
        curr_batch_size: Number of labels to create.
        actual_label: The label the batch should normally get.
        switch_label: The label used when the labels are flipped.
        device: Device of the created tensors.
        float_dtype: Data type of the created tensors.
        smooth_labels: Whether to add noise to the labels.
        noisy_labels_prob: Probability of flipping the labels.

    Returns:
        A pair `(labels, label_noise)`. `label_noise` is the noise tensor that
        was added, or `None` without `smooth_labels`.
    """
    if np.random.rand() >= noisy_labels_prob:
        label = actual_label
    else:
        # Flip labels
        label = switch_label
    # Create the tensor with the right dtype directly: the labels are integers,
    # and `torch.full` with an integer fill value but no `dtype` is rejected by
    # some PyTorch releases and yields an integer tensor in others.
    labels = torch.full((curr_batch_size,), label, dtype=float_dtype, device=device)
    if smooth_labels:
        label_noise = torch.empty_like(labels).uniform_(-0.2, 0.2)
        return labels + label_noise, label_noise
    return labels, None


def fill_labels(labels, label_noise, actual_label, switch_label,
                smooth_labels, noisy_labels_prob):
    """Overwrite an existing label tensor in place with new (noisy) labels.

    One random draw decides for the whole batch whether `actual_label` or
    (with probability `noisy_labels_prob`) the flipped `switch_label` is
    written. If `smooth_labels` is true, `label_noise` is re-sampled in place
    from a uniform distribution over [-0.2, 0.2) and added to `labels`.

    All modifications happen in place, so callers that ignore the return
    value still see the final (optionally smoothed) labels in `labels`.

    Parameters:
        labels: Tensor to overwrite.
        label_noise: Tensor that receives the smoothing noise; only accessed
            when `smooth_labels` is true.
        actual_label: Label value used when the labels are not flipped.
        switch_label: Label value used when the labels are flipped.
        smooth_labels: Whether to add smoothing noise.
        noisy_labels_prob: Probability of flipping the labels.

    Returns:
        `(labels, label_noise)` where `labels` is the tensor passed in;
        `label_noise` is `None` without smoothing.
    """
    if np.random.rand() >= noisy_labels_prob:
        label = actual_label
    else:
        # Flip labels
        label = switch_label
    labels.fill_(label)
    if smooth_labels:
        label_noise.uniform_(-0.2, 0.2)
        # Add the noise in place: previously the smoothed labels only existed
        # in the return value, which was lost whenever a caller ignored it.
        labels.add_(label_noise)
        return labels, label_noise
    return labels, None

In [None]:
# Training

# Reference points for timing: the `training_*` pair spans the whole run,
# the `cp_*` pair is reset after every stats printout.
training_start         = time()
training_start_process = ptime()
cp_start               = pcounter()
cp_start_process       = training_start_process

# Step count at the start of this session, so the final report counts only
# the steps that were trained in this run.
start_steps = steps
# Fallbacks so the final checkpoint can still be written (reproducing the
# loaded position) when the loops below never execute their body.
epoch = start_epoch
i     = start_i - 1

for epoch in range(start_epoch, epochs):
    for i, data in enumerate(dataloader):
        # Skip until at resume point. If the checkpoint was written after the
        # last batch of an epoch, `start_i == len(dataloader)` and the whole
        # (already finished) epoch is skipped.
        if i < start_i:
            continue


        # Discriminator training step
        # ---------------------------

        # Train with all-real batch
        d_net.zero_grad()
        # Format batch
        reals = data[0].to(device, float_dtype)
        curr_batch_size = reals.size(0)
        labels, label_noise = generate_labels(curr_batch_size, real_label, fake_label,
                                              device, float_dtype, smooth_labels, d_noisy_labels_prob)
        # Classify
        outputs = d_net(reals).view(-1)
        # Calculate loss
        d_loss_reals = criterion(outputs, labels)
        # Calculate gradients
        d_loss_reals.backward()
        D_x = outputs.mean().item()

        # Train with all-fake batch
        # Generate fakes
        noise = torch.randn(curr_batch_size, g_input, 1, 1, device=device).to(device, float_dtype)
        fakes = g_net(noise)
        # Use the returned labels: they carry the label smoothing, if enabled.
        labels, label_noise = fill_labels(labels, label_noise, fake_label, real_label,
                                          smooth_labels, d_noisy_labels_prob)
        # Classify
        outputs = d_net(fakes.detach()).view(-1)
        # Calculate loss
        d_loss_fakes = criterion(outputs, labels)
        # Calculate gradients
        d_loss_fakes.backward()
        D_G_z1 = outputs.mean().item()

        # Calculate total loss
        d_loss = d_loss_reals + d_loss_fakes
        # Update
        d_optimizer.step()


        # Generator training step
        # -----------------------

        g_net.zero_grad()
        # `real_label` as the actual label since the fakes are "real" to the generator
        # TODO is it correct to do this here or do we do it for the discriminator output instead?
        # `int(g_flip_labels)` is a flip probability of 0 or 1, i.e. never or always flip.
        labels, label_noise = fill_labels(labels, label_noise, real_label, fake_label,
                                          smooth_labels, int(g_flip_labels))
        # Use updated D for fake classification
        outputs = d_net(fakes).view(-1)
        # Calculate loss
        g_loss = criterion(outputs, labels)
        # Calculate gradients
        g_loss.backward()
        D_G_z2 = outputs.mean().item()
        # Update
        g_optimizer.step()

        # Store losses
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())


        # Output training stats
        if i % 50 == 0:
            cp_end = pcounter()
            cp_end_process = ptime()
            print('[{}/{}][{}/{}]\tLoss_D: {:.4f}\tLoss_G: {:.4f}\t'
                  'D(x): {:.4f}\tD(G(z)): {:.4f} / {:.4f}\tTime: {:.1f} s\tPtime: {:.1f} s'.format(
                  epoch, epochs, i, len(dataloader),
                  d_loss.item(), g_loss.item(), D_x, D_G_z1, D_G_z2,
                  cp_end - cp_start, cp_end_process - cp_start_process))
            cp_start         = cp_end
            cp_start_process = cp_end_process

        # Check the generator's progress by saving its output(s) on `static_noise`
        # (every 500 steps and after the very last batch of the last epoch)
        if steps % 500 == 0 or (epoch == epochs-1 and i == len(dataloader)-1):
            with torch.no_grad():
                fakes = g_net(static_noise).detach().cpu()
            eval_imgs.append(tvutils.make_grid(fakes, nrow=example_cols, padding=2, normalize=True))

        # Save checkpoint
        if checkpoint_period != 0 and steps % checkpoint_period == 0:
            save_checkpoint(save_models_path_str, save_weights_only, multiple_gpus,
                            epoch, steps, i, d_losses, g_losses, eval_imgs,
                            d_optimizer, g_optimizer, d_net, g_net)

        steps += 1

    # The resume offset only applies to the first epoch of this session.
    start_i = 0

# Save after training is finished.
# `steps` has already been incremented past the last completed step while
# `save_checkpoint` expects the index of the completed step (it adds one itself).
save_checkpoint(save_models_path_str, save_weights_only, multiple_gpus,
                epoch, steps - 1, i, d_losses, g_losses, eval_imgs,
                d_optimizer, g_optimizer, d_net, g_net)


training_end         = time()
training_end_process = ptime()

total_time         = int(training_end - training_start)
total_time_process = int(training_end_process - training_start_process)
mins         = total_time // 60
mins_process = total_time_process // 60

print('{} training steps finished after {} minutes and {} seconds of real time and '
      '{} minutes and {} seconds of process time.'.format(
      steps - start_steps, mins, total_time - mins * 60,
      mins_process, total_time_process - mins_process * 60))

# Results

In [None]:
# Loss plot

# One figure with both per-step loss curves recorded during training.
loss_fig, loss_ax = plt.subplots()
loss_ax.set_title('Discriminator and Generator Losses During Training')
loss_ax.plot(d_losses, label='D')
loss_ax.plot(g_losses, label='G')
loss_ax.set_xlabel('steps')
loss_ax.set_ylabel('loss')
loss_ax.legend()
plt.show()

In [None]:
# Compare real and fake

# Use example batch from above
comparison_fig, (real_ax, fake_ax) = plt.subplots(1, 2)

# Left panel: a grid built from the first `num_eval_imgs` images of the example batch.
real_grid = tvutils.make_grid(example_batch[0].to(device)[:num_eval_imgs],
                              nrow=example_cols, padding=2, normalize=True).cpu()
real_ax.axis('off')
real_ax.set_title('Real Images')
# Move the first axis to the end before handing the grid to `imshow`.
real_ax.imshow(np.transpose(real_grid, (1, 2, 0)))

# Right panel: fake images from the last epoch (most recent evaluation grid)
fake_grid = eval_imgs[-1].cpu()
fake_ax.axis('off')
fake_ax.set_title('Fake Images')
fake_ax.imshow(np.transpose(fake_grid, (1, 2, 0)))
plt.show()

In [None]:
# Target file for the animation: `<save_dir>/eval_animation_<exp_name>.gif`
animation_file = '{}/eval_animation_{}.gif'.format(save_dir, exp_name)
save_animation_path = Path(animation_file).expanduser()

print('{} images to be animated.'.format(len(eval_imgs)))

# Animated evaluation of images on static noise
fig_side = (13 * img_shape[0]) // 64  # Heuristically chosen size
fig = plt.figure(figsize=(fig_side, fig_side))
plt.axis('off')
plt.title('Example Images During Training')

# One animation frame (a single-artist list) per stored evaluation grid.
ims = []
for eval_frame_img in eval_imgs:
    frame_artist = plt.imshow(np.transpose(eval_frame_img.cpu(), (1, 2, 0)), animated=True)
    ims.append([frame_artist])

ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
ani.save(save_animation_path)

In [None]:
# WARNING! Can heavily slow down the browser!
# Convert the animation to its JavaScript/HTML representation; the `HTML`
# object is the cell's last expression and is therefore displayed.
animation_html = ani.to_jshtml()
HTML(animation_html)