# Training DCGAN
**Adapted from Alex Andonian's notebook on GAN training**

## Introduction

This tutorial gives a brief introduction to training Generative Adversarial Networks (GANs), a rapidly improving field. GANs are still fairly young: they were first introduced in 2014. Their rapid success in image synthesis is particularly impressive.

### Intro to GANs

A **GAN** (Generative Adversarial Network) is a type of **machine learning** model built from two competing (adversarial) networks. The objective of the first, the **generator network**, is to learn to generate images realistic enough to trick the second, the **discriminator network**, into classifying them as real. At the beginning, an untrained generator produces essentially random outputs, but as training continues its images improve and look more like those in the original dataset. The discriminator is sometimes pretrained on the dataset, but it also keeps training over time alongside the generator.

Each time the generator produces new images, the discriminator 'gives feedback' on where the generator went wrong and confirms what it is doing well, which teaches the generator how to make better outputs.

The generator is trained to produce images that are nearly indistinguishable from real ones in order to "trick" the discriminator, while the discriminator is simultaneously trained to get better at telling generated images from real images in the dataset. This competition lets both networks keep improving with each training iteration.

##### What does an image look like to a computer and how does it learn things about an image?
To understand how GANs work, we first have to understand how a computer interprets an image. To a computer, an image is a grid of numbers: every pixel is stored as one or more numbers (for example, one value each for red, green, and blue), arranged according to the image's dimensions. A pixel's place in the grid gives its position, and its values give its color. Training a GAN teaches the generator the patterns in these numbers, and it stores what it learns in the weights of the connections between its neurons.
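
To make this concrete, the sketch below builds a tiny 2x2 RGB image by hand as a NumPy array. The pixel values are made up for illustration; real images are stored the same way, only with many more pixels.

```python
import numpy as np

# A hypothetical 2x2 RGB image: height x width x 3 color channels,
# with each value in the range 0-255.
image = np.array([
    [[255, 0, 0], [0, 255, 0]],      # top row: red, green
    [[0, 0, 255], [255, 255, 255]],  # bottom row: blue, white
], dtype=np.uint8)

print(image.shape)   # (height, width, channels)
print(image[0, 1])   # the pixel in row 0, column 1

# Networks usually see the same numbers rescaled to floats.
scaled = image.astype(np.float32) / 255.0
flat = scaled.reshape(-1)  # the "list of numbers" view of the image
print(flat.size)
```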

#### How Does A GAN Work?

The GAN generator is a function z -> x that transforms random noise z into a realistic image x. The noise Z is a vector of random numbers; in this notebook it has `dim_z = 100` entries, far fewer than the number of pixel values in the generated image. The diagram below shows the general flow of the system.

In this diagram:
* Z - The random noise as input
* X - The real images fed into the discriminator
* X* - Generated images fed into the discriminator


![gan.png](assets/gan-diagram.png) 
First, the generator is fed random noise, Z, which it passes through its layers. Each neuron is activated by the input data and/or by the activations of other neurons, and its activation influences the image pixels it is connected to.

The output of the generator, X*, is then mixed with images from the real dataset, X, and the combined set is passed to the discriminator.

The discriminator attempts to separate the images back into real and generated. The percentage of images that the discriminator classifies incorrectly is the **classification error**. The worst possible classification error is 50%, because a classifier with a higher error could simply swap its two labels to get an error below 50%.

* A high classification error means that the generator has created real enough looking images to trick the discriminator
* A low classification error means that the discriminator could easily tell the real images from the fake ones
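
As a rough illustration of the classification error described above, the sketch below scores a handful of made-up discriminator guesses. The labels, the guesses, and the `classification_error` helper are all hypothetical and are not part of the notebook's code.

```python
def classification_error(predictions, labels):
    """Fraction of images the discriminator labels incorrectly."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    wrong = sum(p != t for p, t in zip(predictions, labels))
    return wrong / len(labels)

# 1 = real, 0 = generated (a convention assumed for this sketch).
labels      = [1, 1, 1, 1, 0, 0, 0, 0]
predictions = [0, 0, 0, 1, 1, 1, 1, 0]  # hypothetical discriminator guesses

error = classification_error(predictions, labels)
print(f"classification error: {error:.1%}")

# An error above 50% can always be turned into one below 50%
# by swapping the two labels on every guess.
flipped = [1 - p for p in predictions]
print(f"after swapping labels: {classification_error(flipped, labels):.1%}")
```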


After the discriminator has classified the images, the resulting error is sent back through the system via backpropagation to adjust the weights of both the generator and discriminator networks.

### Examples of the different types of GANs:

##### [Style GANs](https://towardsdatascience.com/explained-a-style-based-generator-architecture-for-gans-generating-and-tuning-realistic-6cb2be0f431)
Style GANs can control the style of generated images at different levels of detail.

This first example is a style GAN used to merge two people together to create a new "person" combining features from each:
[![style_gan.gif](assets/style_gan.jpg)](https://towardsdatascience.com/explained-a-style-based-generator-architecture-for-gans-generating-and-tuning-realistic-6cb2be0f431)

This example converts an image into a painting of the style of an artist:
[![style_gan2.gif](assets/style_gan2.jpeg)](https://deepart.io/)

##### [Progressive GANs](https://towardsdatascience.com/progressively-growing-gans-9cb795caebee)
ProGANs are a type of GAN in which the networks are first trained on low-resolution images and then gradually moved to higher resolutions, with new layers added as the resolution increases.
[![progan.gif](assets/progan.gif)](https://towardsdatascience.com/progan-how-nvidia-generated-images-of-unprecedented-quality-51c98ec2cbd2)

##### [BIGGANs](https://medium.com/syncedreview/biggan-a-new-state-of-the-art-in-image-synthesis-cf2ec5694024)
The goal of BIGGANs is to successfully generate high resolution images at a large scale using a lot of iterations and model parameters. The result is the generation of both high-resolution (large) and high-quality (high-fidelity) images.

[![biggan.png](assets/biggan.png)](https://medium.com/syncedreview/biggan-a-new-state-of-the-art-in-image-synthesis-cf2ec5694024)


Here's a timeline of types of GANs that have developed since 2014:
![timeline1.png](assets/timeline.png)
![timeline2.png](assets/timeline1.png)

### What we will be using
In this notebook we will be training Deep Convolutional GANs, or [**DCGANs**](https://towardsdatascience.com/dcgans-deep-convolutional-generative-adversarial-networks-c7f392c2c8f8), which set the stage for the success we are seeing today. In brief, DCGANs are GANs whose generator and discriminator are [convolutional networks](https://towardsdatascience.com/a-comprehensive-guide-to-convolutional-neural-networks-the-eli5-way-3bd2b1164a53) with a multi-layered architecture. In the generator, successive layers increase the spatial dimensions of the data, building up more detail about the dataset at each stage. Developed in 2015, DCGANs were one of the first major improvements over the original GAN, and they remain highly relevant and foundational to GANs.

### Difficulties in training a GAN
GANs can end up being [very difficult to train](https://medium.com/@jonathan_hui/gan-why-it-is-so-hard-to-train-generative-advisory-networks-819a86b3750b) because of the finicky nature of having two networks compete against each other. Here are some of the errors that can occur while training.
#### Non-converging Solutions
If training does not settle on an optimal solution, the GAN is considered non-converging. Ideally, training finishes at a Nash equilibrium. Since both sides want to beat each other, a Nash equilibrium is reached when neither player can improve its outcome by changing its own action while the opponent's action stays the same.

Think of a non-converging case as two networks competing in a game of rock paper scissors. Because all options are perfectly balanced, no single move is optimal, so any strategy that one network comes up with will be met by a counter-strategy from the other. This new strategy will then be countered in turn, and so on, creating a cycle that never converges.

We can also consider an example where two players A and B control the value of x and y respectively. In the diagram below Player A wants to maximize the value xy while B wants to minimize it.

[![non-converging.png](assets/non_converging.png)](https://medium.com/@jonathan_hui/gan-why-it-is-so-hard-to-train-generative-advisory-networks-819a86b3750b)

This game will never converge on a solution, because whatever value one player picks, the other can counteract it by flipping the sign of its own variable. Therefore x and y stay stuck in a loop of positive and negative values forever.

Therefore, when designing GANs, we must check whether the model is converging; otherwise we could waste time and money trying to train it.
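
The xy game above can be simulated in a few lines. This is a minimal sketch that assumes both players take simultaneous gradient steps with the same learning rate; the function name, starting point, and step size are illustrative choices, not something taken from the original example.

```python
def simulate_xy_game(x=1.0, y=1.0, lr=0.1, steps=100):
    """Simultaneous gradient steps on the value function V(x, y) = x * y.

    Player A controls x and ascends V (dV/dx = y).
    Player B controls y and descends V (dV/dy = x).
    """
    history = [(x, y)]
    for _ in range(steps):
        grad_x, grad_y = y, x
        x, y = x + lr * grad_x, y - lr * grad_y
        history.append((x, y))
    return history

history = simulate_xy_game()
for step in (0, 25, 50, 75, 100):
    x, y = history[step]
    radius = (x * x + y * y) ** 0.5
    print(f"step {step:3d}: x = {x:+.3f}, y = {y:+.3f}, distance from (0, 0) = {radius:.3f}")
```

Under these assumptions, x and y keep changing sign while their distance from the equilibrium at (0, 0) grows with every step, so the players circle the solution instead of settling on it.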

#### Vanishing gradients

Another problem with training GANs is that it's easy for one network to overpower the other, inhibiting the training of both. When that happens, one network receives no positive feedback and the other receives no negative feedback, so neither knows how to improve.

This is because networks improve by learning from both their successes and their failures. They use the outcome of each step to shift their weights toward the behavior that most resembles the successes. This process is called gradient descent, shown in the picture below: it attempts to minimize the error, or cost, on a training set.

[![gradient.png](assets/Gradient_descent.png)](https://hackernoon.com/gradient-descent-aynk-7cbe95a778da)

By adjusting the weights of the neurons in our network, we move down the gradient of the cost function, in the direction of lower error.
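
Here is a minimal sketch of gradient descent on a toy cost with a single weight. The cost function, learning rate, and step count are assumptions chosen for illustration. The last lines show what happens when the cost surface is flat: the gradient is zero, so the weight never moves.

```python
def gradient_descent(grad_fn, w, lr=0.1, steps=50):
    """Repeatedly step against the gradient to reduce the cost."""
    for _ in range(steps):
        w = w - lr * grad_fn(w)
    return w

# Toy cost with a single weight: cost(w) = (w - 3)^2, minimized at w = 3.
cost = lambda w: (w - 3.0) ** 2
grad = lambda w: 2.0 * (w - 3.0)

w_start = -4.0
w_final = gradient_descent(grad, w_start)
print(f"cost before: {cost(w_start):.4f}, cost after: {cost(w_final):.6f}")

# A flat cost surface has zero gradient everywhere, so the weight never moves.
flat_grad = lambda w: 0.0
print(gradient_descent(flat_grad, w_start))
```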

However, when one network (either generator or discriminator) overpowers the other, the cost function looks more like a horizontal line. Our networks then cannot tell which way to adjust their weights to reduce the cost. This is the problem of vanishing gradients, and in most cases it occurs because the discriminator trains faster and overpowers the generator.

Therefore one of our goals in this notebook is to monitor the training of both networks so that they can improve simultaneously.

#### Mode Collapse

One of the most common errors in training a GAN is mode collapse. This occurs when the generator starts producing nearly identical images.

Recall that the goal of our generator is to produce images that cause the most classification error. Suppose we pause training of the discriminator so that the generator can catch up (for example, to keep the discriminator from overpowering it). The generator will then begin to converge on the single image that causes the most classification error.

In an extreme case, the generator converges on one image (let's call it "A"), making its output completely independent of the random input Z.

When we resume training the discriminator, it quickly learns that any image "A" is generated, so the classification error drops to 0%. The discriminator has now become specialized: it recognizes only the one image, "A", as generated. The generator can then change the image slightly from "A" to completely throw off the discriminator and push the classification error back to 50%. The discriminator catches up with this new image, and the process repeats.

This destroys the functionality of both the generator and discriminator, and the generator no longer produces new images based on the random input.

An example of partial mode collapse is shown below. Images with the same color underline are very similar.
[![modecollapse.png](assets/mode-collapse.png)](https://medium.com/@jonathan_hui/gan-why-it-is-so-hard-to-train-generative-advisory-networks-819a86b3750b)
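
One simple way to put a number on "nearly identical images" is the average distance between samples in a generated batch. The sketch below is an assumption for illustration, not the statistic this notebook uses; the random vectors stand in for flattened generated images.

```python
import math
import random

def mean_pairwise_distance(samples):
    """Average Euclidean distance between all pairs of samples."""
    total, pairs = 0.0, 0
    for i in range(len(samples)):
        for j in range(i + 1, len(samples)):
            total += math.dist(samples[i], samples[j])
            pairs += 1
    return total / pairs

rng = random.Random(0)

# Stand-ins for flattened generated images (8 samples, 16 "pixels" each).
diverse = [[rng.uniform(-1, 1) for _ in range(16)] for _ in range(8)]

# A collapsed generator: one image plus tiny perturbations.
base = [rng.uniform(-1, 1) for _ in range(16)]
collapsed = [[v + rng.gauss(0, 0.01) for v in base] for _ in range(8)]

print(f"diverse batch:   {mean_pairwise_distance(diverse):.3f}")
print(f"collapsed batch: {mean_pairwise_distance(collapsed):.3f}")
```

A value that falls toward zero over the course of training would be one hint that the generator is ignoring its random input.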

#### What we can do

We will be using various statistics in order to detect these problems in our GAN as shown later in this notebook.

#### Now let's take a look at how to set up our GAN

## Preliminaries
First, let's import all of the Python packages we will use throughout the tutorial. In addition to standard imports and PyTorch, we provide a small package called ganocracy, which aggregates a host of useful GAN-specific functions and utilities in one place, with everything needed to get started with GANs.
Almost all code shown in this notebook can be found inside the ganocracy package. We house reference implementations outside of the notebook so that bits and pieces can be conveniently incorporated into other projects (without having to copy and paste from this notebook).
Note: One major disadvantage of GANs is that it is very difficult to "take a look inside" and make smart adjustments. We attempt to do this a little later in the GANdissect notebook.

We start with the imports:

In [1]:
# IMPORTS
import os 
import sys
import numpy as np
import random
from PIL import Image

These are standard imports for interacting with the operating system and for creating and manipulating matrices.

In [2]:
# PYTORCH
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

PyTorch is the library we will use for machine learning. It handles much of the mathematics involved in training our models.

In [3]:
# GANOCRACY LIB
import ganocracy
from ganocracy.data import datasets as dset
from ganocracy.data import transforms
from ganocracy import metrics, models
from ganocracy.models import utils as mutils
from ganocracy.utils import visualizer as vutils

These imports pull in the ganocracy datasets, transforms, metrics, models, and utilities used in the rest of the notebook.

In [4]:
# NOTEBOOK-SPECIFIC IMPORTS
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML, Video
try:
    import moviepy.editor as mpy
except ImportError:
    print('WARNING: Could not import moviepy. Some cells may not work.')
    print('You can install it with `pip install moviepy`')
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

These additional libraries let us display the output we create during training.

In [5]:
# Set random seed for reproducibility.
manualSeed = 999
random.seed(manualSeed)
torch.manual_seed(manualSeed)

# Use this command to make a subset of
# GPUS visible to the jupyter notebook.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5'

Below, we configure how the notebook will run and set it up with our default settings:

In [6]:
TRAIN = True

# ---------------------
# DATASET/Loader CONFIG
# ---------------------
dataset_name = 'CelebA' # Name of dataset.
dataroot = 'data'       # Root directory for the dataset.
download = True         # If data is not found, download and cache it.
split = 'train'         # Dataset split (train or val), if applicable.
num_workers = 1         # Number of workers for dataloader
batch_size = 96         # Batch size per forward/backward pass.

Note: The number of workers sets how many parallel processes load data for us.
Batch size is the number of samples we propagate through the network at one time. A larger batch size decreases training time; however, it takes more memory and can degrade the quality of the output.
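
To see how batch size translates into work per epoch, here is a small worked calculation. It assumes the CelebA train split size of 162,770 images that the data-loading cell reports later in this notebook.

```python
import math

dataset_size = 162770  # CelebA train split, as reported in the Data Preparation output
batch_size = 96

steps_per_epoch = math.ceil(dataset_size / batch_size)
last_batch_size = dataset_size - (steps_per_epoch - 1) * batch_size

print(f"{steps_per_epoch} batches per epoch, the last one holding {last_batch_size} images")

# Doubling the batch size roughly halves the number of steps per epoch.
print(math.ceil(dataset_size / (2 * batch_size)))
```

The final batch is smaller than the rest, which is why the training loop later slices its label tensors to `output.size(0)`.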

[![Batch Size Graph](assets/optimum-batch-size.png)](https://www.boost.co.nz/blog/2018/11/reduce-batch-size-agile-software-development)

In [7]:
resolution = 128  # Image size (H, W) in pixels. Choices: [32, 64, 128, ...] Resized if needed.

# ------------------
# Model Architecture
# ------------------
dim_z = 100 # Size of z latent vector (i.e. size of generator input)
G_ch = 64   # Channel multiplier - scales number of features per conv in G and D
D_ch = 64

The dimension of z sets the size of the random noise vector that acts as the starting point for our generator. Changing the channel multipliers changes the size of the generator and discriminator.

In [8]:
# ---------------
# TRAINING CONFIG
# ---------------
num_epochs = 50 # Number of training epochs (1 epoch = 1 presentation of full dataset).

num_D_steps = 1         # Number of updates to Discriminator per 1 step of Generator.
num_D_accumulations = 1 # Number of gradient accumulations per step of G and D.
num_G_accumulations = 1 # Technique to spoof larger "mega-batch" sizes.
D_batch_size = batch_size * num_D_steps * num_D_accumulations

The number of epochs is how many times we run through all of our data. 
By changing the number of discriminator updates per generator step, we can fine-tune the training of the discriminator so it doesn't overpower our generator.
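
The gradient accumulation settings in the cell above deserve a short illustration. The sketch below uses a tiny linear model and a mean-squared-error loss as stand-ins (both are assumptions, not the notebook's networks). It shows that summing the gradients of several small batches, each loss divided by the number of accumulations, matches the gradient of one large batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 1)   # stand-in for G or D
loss_fn = nn.MSELoss()    # stand-in loss, chosen only for this sketch
inputs, targets = torch.randn(12, 8), torch.randn(12, 1)

# Reference: gradient from one large batch of 12 samples.
model.zero_grad()
loss_fn(model(inputs), targets).backward()
full_batch_grad = model.weight.grad.clone()

# Accumulation: 3 micro-batches of 4 samples, each loss divided by 3.
num_accumulations = 3
model.zero_grad()
for x, y in zip(torch.split(inputs, 4), torch.split(targets, 4)):
    loss = loss_fn(model(x), y) / num_accumulations
    loss.backward()       # gradients add up in .grad until zero_grad()
accumulated_grad = model.weight.grad.clone()

print(torch.allclose(full_batch_grad, accumulated_grad, atol=1e-6))
```

The match is exact only for layers that treat each sample independently. Layers that depend on batch statistics, such as batch normalization, see the smaller micro-batches, so the equivalence is only approximate there.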

In [9]:
test_every = 500   # How frequently to evaluate our model. Use None to skip testing.
sample_every = 100 # How frequently we save fixed noise samples.

In [10]:
# ----------------
# OPTIMIZER CONFIG
# ----------------
G_lr = 2e-4      # Learning rate for optimizers
D_lr = 2e-4      

# Betas hyperparams for Adam optimizers
G_betas = (0.5, 0.999)
D_betas = (0.5, 0.999)

# ---------------
# HARDWARE CONFIG
# ---------------
ngpu = 6 # Number of GPUs available. Use 0 for CPU mode

# Determine whether or not GPUs are available.
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")


# -----------
# SAVE CONFIG
# -----------
name = 'LIVE_DEMO'         # Experiment specific prefix.
model_dir = 'checkpoints'  # Directory to store model checkpoints.
samples_dir = 'samples'    # Directory to store samples.

# Generate model name based on config.
model_name = '_'.join(map(str, [
    name,
    dataset_name, 
    resolution,
    'bs{}'.format(batch_size),
    'dim_z{}'.format(dim_z),
    'G_ch{}'.format(G_ch),
    'D_ch{}'.format(D_ch),
    'G_lr{}'.format(G_lr),
    'D_lr{}'.format(D_lr),
    'G_betas{}'.format('_'.join(map(str, G_betas))),
    'D_betas{}'.format('_'.join(map(str, D_betas))),
]))

# Prepare dirs to store samples and checkpoints.
save_name = os.path.join(model_dir, model_name)
os.makedirs(os.path.join(model_dir, model_name), exist_ok=True)
os.makedirs(os.path.join(samples_dir, model_name), exist_ok=True)
print(f'Starting experiment {model_name}')

Starting experiment LIVE_DEMO_CelebA_128_bs96_dim_z100_G_ch64_D_ch64_G_lr0.0002_D_lr0.0002_G_betas0.5_0.999_D_betas0.5_0.999


## Data Preparation
In this tutorial, we will use the CelebFaces Attributes Dataset (CelebA), a large-scale face attributes dataset with more than 200K celebrity images. Each image is annotated with 40 attributes (e.g. smiling, eyeglasses, etc.). 

We can download this dataset (along with a host of other commonly used datasets) with PyTorch's [`torchvision`](https://pytorch.org/docs/stable/torchvision/index.html) package. 
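
The data-loading cell below chains `ToTensor()` with `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`. As a quick check of what that does, this sketch applies the same arithmetic to single pixel values; the helper names are hypothetical.

```python
def normalize(value, mean=0.5, std=0.5):
    """Per-channel operation applied by Normalize: (value - mean) / std."""
    return (value - mean) / std

def denormalize(value, mean=0.5, std=0.5):
    """Inverse mapping, handy before displaying generated images."""
    return value * std + mean

for pixel in (0.0, 0.25, 0.5, 1.0):   # ToTensor() outputs values in [0, 1]
    print(pixel, "->", normalize(pixel))
```

Pixel values therefore end up in the range [-1, 1]. That range is convenient if the generator ends in a tanh, as in the original DCGAN design, though whether this implementation does so is an assumption here.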

In [11]:
# Create the dataset with transforms.
dataset = torchvision.datasets.CelebA('data',
                                      download=download,
                                      transform=transforms.Compose([
                                          transforms.Resize(resolution),
                                          transforms.CenterCrop(resolution),
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.5, 0.5, 0.5),  # Normalize between -1 and 1.
                                                               (0.5, 0.5, 0.5))]),
                                       target_transform=lambda x: 0)
# Create the dataloader.
dataloader = torch.utils.data.DataLoader(dataset, 
                                         shuffle=True,
                                         batch_size=D_batch_size,
                                         num_workers=num_workers)
# Let's take a look at some samples
vutils.visualize_data(dataloader)

Files already downloaded and verified
Dataset CelebA
    Number of datapoints: 162770
    Root location: data
    Target type: ['attr']
    Split: train
    StandardTransform
Transform: Compose(
               Resize(size=128, interpolation=PIL.Image.BILINEAR)
               CenterCrop(size=(128, 128))
               ToTensor()
               Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
           )
Target transform: <function <lambda> at 0x00000202D2108400>


PicklingError: Can't pickle <function <lambda> at 0x00000202D2108400>: attribute lookup <lambda> on __main__ failed
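
The `PicklingError` above is what Python raises when a `lambda` has to be pickled. This most likely happens here because the `DataLoader` starts worker processes that each receive a pickled copy of the dataset, including its `target_transform`. The sketch below reproduces the difference with the standard `pickle` module; `ignore_target` and `is_picklable` are hypothetical names.

```python
import pickle

def ignore_target(target):
    """Named, module-level replacement for `lambda x: 0`."""
    return 0

def is_picklable(obj):
    """Return True if `obj` survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        return False

print("lambda picklable:        ", is_picklable(lambda x: 0))
print("named function picklable:", is_picklable(ignore_target))
```

Two likely workarounds are passing `target_transform=ignore_target` instead of the lambda, or setting `num_workers = 0` so that no worker processes are started.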

### Using your own data

If you would like to train a GAN on your own custom dataset, subclassing [`torch.utils.data.Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) is a sensible approach, as it allows you to make use of PyTorch's data-loading utilities, including the multi-process `DataLoader` and the transforms from above.

If your training data consist of image files, the [`torchvision.datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder) class facilitates convenient dataloading. Simply arrange your files in the following way:

    root/dogball/xxx.png
    root/dogball/xxy.png
    root/dogball/xxz.png

    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/asd932_.png
    
where each subdirectory of `root` is considered an image category containing examples of that category. Then, you can create your dataset with:

```python
dataroot = '/path/to/data/root'
dataset = torchvision.datasets.ImageFolder(dataroot, transform=...)
```
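
For data that does not fit the folder layout, a custom `Dataset` subclass only needs `__len__` and `__getitem__`. The following is a minimal sketch: `RandomImageDataset` is a hypothetical class that serves random tensors in place of real images, so that the example runs without any files on disk.

```python
import torch
from torch.utils.data import Dataset, DataLoader

def to_minus_one_one(image):
    """Rescale a tensor from [0, 1] to [-1, 1]."""
    return image * 2 - 1

class RandomImageDataset(Dataset):
    """Hypothetical dataset that serves random tensors in place of real images."""

    def __init__(self, num_images=10, resolution=32, transform=None):
        # Fixed generator so the same "images" come back on every run.
        generator = torch.Generator().manual_seed(0)
        self.images = torch.rand(num_images, 3, resolution, resolution,
                                 generator=generator)
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, 0  # dummy label; an unconditional GAN ignores it

dataset = RandomImageDataset(transform=to_minus_one_one)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

for images, labels in loader:
    print(images.shape, labels.shape)
```

In a real project, `__getitem__` would load and decode an image file instead of indexing into a tensor.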

## Measuring a sample's quality during training
This step is optional and not always possible, but it shows how the quality of samples is measured during training.
Typically, when training any sort of neural network, it is standard practice to monitor the value of the objective function (loss) throughout the course of the experiment; adversarial losses measure the competition between the generator and discriminator. However, the adversarial losses do not necessarily reflect the image quality of generated samples.
Objectively evaluating the "realness" and "diversity" of generated images is not an easy task. Deciding what counts as a quality image is an ongoing research problem. However, two popular metrics are the Inception Score and the Fréchet Inception Distance: 
- **[Inception Score](https://medium.com/octavian-ai/a-simple-explanation-of-the-inception-score-372dff6a8c7a)** (IS): Score based on how confidently an ImageNet-pretrained InceptionV3 network can classify generated samples and on the diversity of its predictions over a large collection of samples. If a model produces samples that InceptionV3 can confidently classify, this contributes to a higher IS. Essentially, images with more recognizable concepts in them correspond to images of greater quality. A high diversity of classifications also contributes to a higher IS.
[![Inception Score](assets/inception-score.png)](https://medium.com/octavian-ai/a-simple-explanation-of-the-inception-score-372dff6a8c7a)
- **[Fréchet Distance](https://www.mathworks.com/matlabcentral/fileexchange/31922-discrete-frechet-distance)** (FID): Here, a pretrained Inception network is used to generate feature representations for both the real images from the dataset of interest and generated samples from the model. These feature distributions are modeled by multivariate Gaussian distributions. The shorter the Fréchet Distance between these two distributions, the more closely the fake images resemble the real ones.
[![Fréchet Distance](assets/frechet-distance.jpg)](https://www.mathworks.com/matlabcentral/fileexchange/31922-discrete-frechet-distance)
**Summary:**

- Higher IS values mean better image quality and diversity (usually).
- Lower FID values mean better image quality and diversity.
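
The Inception Score is commonly defined as the exponential of the average KL divergence between each image's predicted class distribution p(y|x) and the marginal distribution p(y) over all images. The sketch below computes it for hand-written probability rows, which stand in for InceptionV3 outputs; it is an illustration, not the `ganocracy.metrics` implementation.

```python
import math

def inception_score(class_probs):
    """exp of the mean KL divergence between each p(y|x) and the marginal p(y).

    `class_probs` holds one row of class probabilities per generated image,
    standing in for the softmax output of a classifier such as InceptionV3.
    """
    num_images = len(class_probs)
    num_classes = len(class_probs[0])
    marginal = [sum(row[c] for row in class_probs) / num_images
                for c in range(num_classes)]

    total_kl = 0.0
    for row in class_probs:
        # Terms with p = 0 contribute nothing to the KL divergence.
        total_kl += sum(p * math.log(p / q) for p, q in zip(row, marginal) if p > 0)
    return math.exp(total_kl / num_images)

confident_and_diverse = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
confident_but_repetitive = [[1.0, 0.0, 0.0]] * 3
unsure = [[1 / 3, 1 / 3, 1 / 3]] * 3

for name, probs in [("confident and diverse", confident_and_diverse),
                    ("confident but repetitive", confident_but_repetitive),
                    ("unsure", unsure)]:
    print(f"{name}: {inception_score(probs):.3f}")
```

Only the batch that is both confidently classified and spread across classes scores above the minimum of 1, which matches the two ingredients described above.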

#### Some Caveats
The InceptionV3 model is trained to classify ImageNet categories. ImageNet is a large database of labeled photographs of objects, built for training and evaluating image classifiers.
Because of this, Inception Score can be a very poor measure of quality on datasets other than ImageNet. Unless your dataset is very ImageNet-like, FID will likely be a better estimate of sample quality.
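
For two Gaussians, the Fréchet distance has a closed form: the squared distance between the means plus a trace term that compares the covariances. The sketch below covers only the simplified case of diagonal covariances, with made-up feature statistics, so it can be written without a matrix square root. The "Inception moments" precomputed in the next cell are presumably the mean and covariance of real-image features used in the full version.

```python
import math

def frechet_distance_diagonal(mean1, var1, mean2, var2):
    """Squared Fréchet distance between two Gaussians with diagonal covariance.

    General form: |m1 - m2|^2 + Tr(C1 + C2 - 2 * sqrt(C1 @ C2)).
    With diagonal covariances the trace term reduces to a per-dimension sum.
    """
    mean_term = sum((a - b) ** 2 for a, b in zip(mean1, mean2))
    cov_term = sum(a + b - 2.0 * math.sqrt(a * b) for a, b in zip(var1, var2))
    return mean_term + cov_term

real_mean, real_var = [0.0, 1.0], [1.0, 4.0]   # hypothetical feature statistics

# Identical statistics, shifted means, and different spread, in that order.
print(frechet_distance_diagonal(real_mean, real_var, real_mean, real_var))
print(frechet_distance_diagonal(real_mean, real_var, [3.0, 5.0], real_var))
print(frechet_distance_diagonal(real_mean, real_var, real_mean, [4.0, 9.0]))
```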

In [None]:
if TRAIN and test_every is not None:
    # If we are going to evaluate IS and FID, we need to precompute Inception moments.
    inception_moments_file = metrics.calculate_inception_moments(dataloader, dataroot,
                                                                 '-'.join([dataset_name, str(resolution)]),
                                                                 device=device)
    compute_inception_metrics = metrics.prepare_inception_metrics(inception_moments_file)

## Quick Review of Nets
### Overfitting
If we put too much emphasis on the individual data points of a specific training set, we are very likely to overfit. Overfitting means that the model bends to fit specific outlier points that don't follow the overall trend. In the figure below, the green line overfits the data; we would prefer the black line as our decision boundary, even though the green line technically splits our training set perfectly.
[![Overfitting](assets/overfitting.png)](https://en.wikipedia.org/wiki/Overfitting)
For example, suppose we add a new yellow point. Our intuition tells us it is more likely to be blue, but the green line classifies it as red because our network has overfit.
### Regularization
Regularization is the process of discouraging overfitting in a neural network. Typically, a regularization term measures the complexity of the model and is added to the objective, so that it factors into backpropagation when determining the best fit. A good regularizer would produce a decision boundary more like the black line than the green line in the overfitting example above.
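
One common concrete form of such a complexity term is an L2 penalty on the weights. The sketch below is an illustration with made-up numbers, not a technique this notebook applies; the function names and the penalty strength are assumptions.

```python
def l2_penalty(weights, strength):
    """Complexity term: strength * sum of squared weights."""
    return strength * sum(w * w for w in weights)

def regularized_loss(data_loss, weights, strength=0.01):
    """Total objective = fit to the training data + penalty on large weights."""
    return data_loss + l2_penalty(weights, strength)

def l2_penalty_gradient(weights, strength):
    """Gradient of the penalty; backpropagation adds this to each weight's gradient."""
    return [2.0 * strength * w for w in weights]

small_weights = [0.1, -0.2, 0.05]
large_weights = [3.0, -4.0, 2.5]

# Same data loss, but the model with larger weights is charged a bigger penalty.
print(regularized_loss(0.30, small_weights))
print(regularized_loss(0.30, large_weights))
print(l2_penalty_gradient(large_weights, 0.01))
```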
### Multilayer perceptron
In machine learning, a multilayer perceptron is a network in which every neuron is connected to all neurons of the next layer.

[![Multilayer Perceptron](assets/Multilayer-Perceptron.png)](https://www.researchgate.net/figure/A-hypothetical-example-of-Multilayer-Perceptron-Network_fig4_303875065)
The problem with this structure is that it overfits very easily, since it contains so many connections, many of them possibly unnecessary.
### Convolutional Neural Networks
Convolutional neural networks are based on the idea of a multilayer perceptron but include a form of structural regularization to prevent overfitting. The theory behind them is to build more complicated patterns out of simpler ones. They are particularly good at image classification, and their connectivity is loosely modeled on the organization of the animal visual cortex. They also need relatively little pre-processing to classify images, which is a major advantage over other image classification approaches.
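
To get a feel for how much structure a convolution removes, the sketch below counts parameters for one fully connected layer and one convolutional layer applied to a 128x128 RGB image. The resolution and the 64 output features come from this notebook's config; the 4x4 kernel size is an assumption.

```python
def dense_params(num_inputs, num_outputs):
    """Weights plus biases of a fully connected layer."""
    return num_inputs * num_outputs + num_outputs

def conv_params(in_channels, out_channels, kernel_size):
    """Weights plus biases of a 2D convolution; independent of image size."""
    return kernel_size * kernel_size * in_channels * out_channels + out_channels

resolution, channels, features = 128, 3, 64   # values from this notebook's config

dense = dense_params(resolution * resolution * channels, features)
conv = conv_params(channels, features, kernel_size=4)  # 4x4 kernel is an assumption

print(f"fully connected: {dense:,} parameters")
print(f"convolutional:   {conv:,} parameters")
```

Because the same small kernel is reused at every image position, the convolutional layer needs far fewer parameters, which is the structural regularization mentioned above.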

[![DCGAN Generator Architecture](assets/Convolutional.jpeg)](https://towardsdatascience.com/dcgans-deep-convolutional-generative-adversarial-networks-c7f392c2c8f8)

## Convolutional GANs 

With all that background out of the way, we can now introduce convolutional GANs. DCGANs work by modeling the G (generator) and D (discriminator) functions with convolutional neural networks (CNNs). DCGAN stands as one of the most popular and successful baseline GAN architectures because CNNs are adept at working with images. CNNs also tend to be more robust in the GAN training process.

**Challenge:** Design a neural network architecture for efficient and stable image generation. 

**Solution:** A fully convolutional network that does away with max pooling and with the fully connected layers that a typical CNN ends with, so neurons are no longer connected to every neuron of the next layer. Convolutions had already proven successful for discriminative computer vision tasks, since they are well suited to handling the spatial structure of images, and the DCGAN paper introduced them to GANs.

Below is a figure depicting the design of the Generator:
![DCGAN Generator Architecture](assets/dcgan_generator.png)

Almost all GANs used for image synthesis build on ideas introduced by DCGAN, although some modern architectures are starting to deviate from these early design decisions. Nonetheless, due to its simplicity and success, DCGAN remains a good starting point for a new project.

### Implementation

In this implementation, we build the DCGAN generator out of `GBlocks`, each a small block of weighted layers. Each block increases the spatial dimensions of the data while decreasing the depth of the feature volume. 

Essentially, we can increase the desired output resolution by simply adding more blocks to our model.
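
The notebook's own `GBlock` lives in an external file, so its exact contents are not shown here. As a rough sketch of the idea, the hypothetical `SketchGBlock` below uses a transposed convolution, batch normalization, and a ReLU (a common DCGAN-style choice, assumed rather than taken from the ganocracy code) to double the spatial size while halving the channel count.

```python
import torch
import torch.nn as nn

class SketchGBlock(nn.Module):
    """Illustrative upsampling block: doubles H and W, changes channel count."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            # kernel 4, stride 2, padding 1 maps H -> 2 * H exactly.
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
                               stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)

# Stack blocks: each one doubles the resolution and halves the channels.
blocks = nn.Sequential(SketchGBlock(256, 128),
                       SketchGBlock(128, 64),
                       SketchGBlock(64, 32))

x = torch.randn(2, 256, 4, 4)
for block in blocks:
    x = block(x)
    print(tuple(x.shape))
```

Adding one more block to the stack would double the output resolution again.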

In [21]:
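# Module names are written without the .py extension. GBlock.py must itself
# import torch.nn as nn; otherwise this import raises the NameError shown below.
from GBlock import Gblock
from GBlock import Generator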

NameError: name 'nn' is not defined

`DBlocks`, found in the discriminator, are near inverses of `GBlocks`: they trade spatial dimension for feature depth and use LeakyReLUs instead of ReLUs for their nonlinearities.
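
As with the generator, the real `DBlock` is defined in an external file. The hypothetical `SketchDBlock` below is a sketch under the same DCGAN-style assumptions: a strided convolution halves the spatial size, the channel count grows, and a LeakyReLU with a slope of 0.2 (an assumed value) replaces the ReLU.

```python
import torch
import torch.nn as nn

class SketchDBlock(nn.Module):
    """Illustrative downsampling block: halves H and W, changes channel count."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            # kernel 4, stride 2, padding 1 maps H -> H / 2 for even H.
            nn.Conv2d(in_channels, out_channels, kernel_size=4,
                      stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        return self.block(x)

x = torch.randn(2, 3, 32, 32)  # a batch of two 32x32 RGB "images"
for block in [SketchDBlock(3, 64), SketchDBlock(64, 128), SketchDBlock(128, 256)]:
    x = block(x)
    print(tuple(x.shape))

# Unlike ReLU, LeakyReLU lets a scaled-down negative signal through.
print(nn.LeakyReLU(0.2)(torch.tensor(-1.0)).item())
```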

In [20]:
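# Module names are written without the .py extension. DBlock.py must itself
# import torch.nn as nn; otherwise this import raises the NameError shown below.
from DBlock import DBlock
from DBlock import Generator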

NameError: name 'nn' is not defined

When building a new GAN, it can be really useful to see exactly how the shape of the data changes as it is transformed from noise to an RGB image. We provide a small utility for capturing the intermediate sizes and use it below: 

In [None]:
# Instantiate DCGAN Generator instance.
G = Generator(dim_z=dim_z, resolution=resolution).to(device)
print(G)  # View model architecture.

# Create some input tensors.
z = torch.randn(4, G.dim_z, device=device)

# Generate samples while capturing the sizes of intermediate outputs.
Gz, names, sizes = mutils.hook_sizes(G, inputs=(z,), verbose=True)
vutils.visualize_samples(Gz, title='Samples from untrained Generator')

In [None]:
# Create the Discriminator
D = models.dcgan.Discriminator(resolution=resolution).to(device)
print(D)  # View discriminator architecture.

# Let's pass the output of the generator to the discriminator.
D_Gz, names, sizes = mutils.hook_sizes(D, inputs=(Gz,), verbose=True)
print('D(G(z)) should be probabilities: {}'.format(D_Gz))

## Loss Functions
Recall that the adversarial training strategy defines a game between two competing networks:
- **generator:** a function $G$ that spawns 'fake' images by mapping a source of noise to the data (image) space.
- **discriminator:** a function $D$ that must distinguish between a generated sample and a true data sample.

In this zero-sum, non-cooperative game, the generator is trained to fool the discriminator, while the discriminator is trained to avoid being fooled.

In practice, it's better to start with the non-saturating variant of the generator loss. The intuition is that it's easier to optimize the discriminator than the generator, especially early on in training. 

If the generator is not doing a good job yet, the discriminator will continually tell the generator that its images are not real, and the generator gets no positive feedback with which to improve itself. This is a case of vanishing gradients in $G$. The non-saturating variant avoids it by having the generator maximize $\log D(G(z))$ instead of minimizing $\log(1 - D(G(z)))$.

These are the functions that each network uses to judge its performance.

Discriminator maximizes: $\mathbb{E}_{x\sim p_{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[\log(1-D(G(z)))\big]$

Generator maximizes: $\mathbb{E}_{z\sim p_{z}(z)}\big[\log D(G(z))\big]$
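
To see why the non-saturating form helps, the sketch below compares the gradient each generator loss sends back through the discriminator's output. It assumes $D$ ends in a sigmoid applied to a raw score (the logit) and differentiates with respect to that score; the function names are hypothetical.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def saturating_grad(logit):
    """d/d(logit) of log(1 - D), the term G minimizes in the minimax game."""
    return -sigmoid(logit)

def non_saturating_grad(logit):
    """d/d(logit) of -log(D), the non-saturating loss G minimizes instead."""
    return -(1.0 - sigmoid(logit))

# logit is D's pre-sigmoid output on a generated image; very negative
# values mean D is confident the image is fake.
for logit in (-6.0, -2.0, 0.0, 2.0):
    print(f"D(G(z)) = {sigmoid(logit):.4f}  "
          f"saturating: {saturating_grad(logit):+.4f}  "
          f"non-saturating: {non_saturating_grad(logit):+.4f}")
```

When the discriminator confidently rejects a sample, the saturating gradient is close to zero while the non-saturating gradient stays close to its maximum magnitude, so the generator keeps receiving a useful training signal early on.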



In [None]:
# Initialize BCELoss function.
criterion = nn.BCELoss()
# Note: criterion(x, 1) = -log(x)
#       criterion(x, 0) = -log(1 - x)

# Create batch of latent vectors that we will use to visualize
# the progression of the generator.
fixed_noise = torch.randn(batch_size, dim_z, device=device)

# Establish convention for real and fake labels during training
real_label = 1
fake_label = 0

# Setup Adam optimizers for both G and D
G_optim = optim.Adam(G.parameters(), lr=G_lr, betas=G_betas)
D_optim = optim.Adam(D.parameters(), lr=D_lr, betas=D_betas)

Before we begin training, let's prepare to monitor important quantities. In practice, you would also want to log these to disk, tensorboard, etc.:

In [None]:
# Keep a running list of various quantities:
img_list = []
G_losses = []
D_losses = []
IS_scores = []
FIDs = []
iters = 0
best_IS = 0
best_FID = 9999

## Now we will start our training!

In [None]:
# Training Loop
if TRAIN:
    print("Starting Training Loop...")
    
    # For each epoch
    for epoch in range(num_epochs):
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):
            
            counter = 0 # Keep track of "mini-batches"
            data = [d.to(device) for d in data]
            x, y = [torch.split(d, batch_size) for d in data]
            
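            # Use float targets: nn.BCELoss expects targets with the same dtype as D's output
            real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
            fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)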
            
            #################################################################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            #################################################################

            D.zero_grad()            
            for step_index in range(num_D_steps):
                D.zero_grad()
                for accumulation_index in range(num_D_accumulations):
                    
                    ## ** Train with all-real batch ** ##

                    # Forward real batch through D
                    output = D(x[counter]).view(-1)
                    
                    # Calculate loss on all-real batch
                    D_loss_real = criterion(output, real_labels[:output.size(0)]) / float(num_D_accumulations)
                    D_loss_real.backward()
                    D_x = output.mean().item()

                    ## ** Train with all-fake batch ** ##
                
                    # Generate batch of latent vectors and targets
                    noise = torch.randn(batch_size, dim_z, device=device)
                    
                    # Generate fake image batch with G
                    fake_image = G(noise)

                    # Classify all fake batch with D
                    output = D(fake_image.detach()).view(-1)

                    # Calculate D's loss on the all-fake batch
                    D_loss_fake = criterion(output, fake_labels[:output.size(0)]) / float(num_D_accumulations)
                    D_loss_fake.backward()
                    D_G_z1 = output.mean().item()
                    
                    # The two backward() calls above have already accumulated the gradients;
                    # the summed loss is kept only for logging.
                    D_loss = (D_loss_real + D_loss_fake)
                    counter += 1
                    
                # Update D
                D_optim.step()

            #############################################
            # (2) Update G network: maximize log(D(G(z)))
            #############################################
            
            G.zero_grad()
            for accumulation_index in range(num_G_accumulations):
                
                # Generate batch of latent vectors.
                noise = torch.randn(batch_size, dim_z, device=device)
                fake_image = G(noise)
                output = D(fake_image).view(-1)
                
                # Calculate G's loss based on this output
                G_loss = criterion(output, real_labels) / float(num_G_accumulations)
                    
                # Calculate gradients for G
                G_loss.backward()
                D_G_z2 = output.mean().item()
                
            # Update G
            G_optim.step()
            
            # Output training stats
            if i % 10 == 0:
                print('[{}/{}][{}/{}]\tLoss_D: {:.4f}\tLoss_G: {:.4f}\tD(x): {:.4f}\tD(G(z)): {:.4f} / {:.4f}'.
                     format(epoch, num_epochs, i, len(dataloader),
                            D_loss.item(), G_loss.item(), D_x, D_G_z1, D_G_z2))

            # Save Losses for plotting later
            G_losses.append(G_loss.item())
            D_losses.append(D_loss.item())

            # Check how the generator is doing by saving G's samples on fixed_noise
            if (iters % sample_every == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake_image = G(fixed_noise).detach().cpu()
                    
                # Save to disk and keep copy in list
                fname = os.path.join(samples_dir, model_name, f'{iters:06d}.jpg')
                torchvision.utils.save_image(fake_image, fname, padding=2, normalize=True )
                img_list.append(torchvision.utils.make_grid(fake_image, padding=2, normalize=True))
            
            if test_every is not None:
                if ((iters % test_every == 0)
                    or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1))):

                    with torch.no_grad():
                        IS, IS_std,  FID = compute_inception_metrics(G, batch_size, dim_z)

                    print('Itr {}: PYTORCH Inception Score is {:3.3f} +/- {:3.3f} '
                          'PYTORCH FID is {:5.4f}'.format(iters, IS, IS_std, FID))
                    IS_scores.append(IS)
                    FIDs.append(FID)
                    
                    # Remember best IS and FID and save checkpoint.
                    is_best_IS = IS > best_IS
                    is_best_FID = FID < best_FID
                    best_IS = max(IS, best_IS)
                    best_FID = min(FID, best_FID)
                    
                    # Keep a running checkpoint
                    mutils.save_checkpoint({
                        'G': G.state_dict(),
                        'D': D.state_dict(),
                        'iters': iters,
                        'epoch': epoch,
                        'IS_scores': IS_scores,
                        'FIDs': FIDs,
                        'best_IS': best_IS,
                        'best_FID': best_FID,
                    }, is_best_IS, is_best_FID, 
                        filename=save_name + '.pth.tar')
                    
            iters += 1
        # Finally, save a checkpoint every epoch.
        torch.save({
            'G': G.state_dict(),
            'D': D.state_dict(),
            'G_losses': G_losses,
            'D_losses': D_losses,
            'iters': iters,
            'epoch': epoch,
            'IS_scores': IS_scores,
            'FIDs': FIDs,
            'best_IS': best_IS,
            'best_FID': best_FID,
            }, f'{save_name}_epoch{epoch}.pth.tar')
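The loop above divides each loss by `num_D_accumulations` (or `num_G_accumulations`) before calling `backward()`. The sketch below illustrates why, using a toy one-parameter regression problem rather than a GAN: summing the scaled gradients of equally sized chunks reproduces the gradient of the mean loss over the full batch. All names here are illustrative assumptions.

```python
import math

def grad_mse(w, xs, ys):
    """Gradient w.r.t. w of mean((w * x - y)^2) over the given samples."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0]

# Gradient computed on the whole batch at once.
full_batch_grad = grad_mse(w, xs, ys)

# Same gradient, accumulated over equally sized chunks.
num_accumulations = 3
chunk = len(xs) // num_accumulations
accumulated_grad = 0.0
for k in range(num_accumulations):
    cx = xs[k * chunk:(k + 1) * chunk]
    cy = ys[k * chunk:(k + 1) * chunk]
    accumulated_grad += grad_mse(w, cx, cy) / num_accumulations

print(full_batch_grad, accumulated_grad)
assert math.isclose(full_batch_grad, accumulated_grad)
```

Accumulation therefore lets you emulate a larger effective batch size than fits in GPU memory, at the cost of more forward and backward passes per optimizer step.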

### And we're off!

... And now we wait. At lower resolutions or with fewer classes, it is possible to obtain fairly respectable results in a short time frame. However, achieving the eye-catching results commonly advertised in papers and the media still takes quite a while, potentially on the order of weeks.

**Note on hardware utilization**: If you are using GPUs, make sure that you are using them to the fullest. The command `watch -n 1 nvidia-smi` will show utilization and memory consumption continuously (refreshed every second). Could you use a larger batch size? If you see your GPU usage cycle between 100% and 0%, it could be evidence of a data-loading bottleneck.
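One rough way to check for a data-loading bottleneck is to time how long the loop waits for each batch versus how long it spends in the training step. The sketch below is a generic helper under stated assumptions: `profile_loader` and `train_step` are hypothetical names, and the loader can be any iterable. Because CUDA kernels launch asynchronously, a real `train_step` on a GPU would need to call `torch.cuda.synchronize()` before returning for the compute timing to be meaningful.

```python
import time

def profile_loader(loader, train_step, max_batches=50):
    """Return (seconds spent waiting for data, seconds spent in train_step)."""
    data_time = 0.0
    compute_time = 0.0
    batches = iter(loader)
    for _ in range(max_batches):
        t0 = time.perf_counter()
        try:
            batch = next(batches)
        except StopIteration:
            break
        t1 = time.perf_counter()
        train_step(batch)
        t2 = time.perf_counter()
        data_time += t1 - t0
        compute_time += t2 - t1
    return data_time, compute_time

# Toy demonstration with a deliberately slow "loader" and a no-op step.
def slow_loader(num_batches=5, delay=0.01):
    for i in range(num_batches):
        time.sleep(delay)  # stands in for disk reads / augmentation
        yield i

data_s, compute_s = profile_loader(slow_loader(), lambda batch: None)
print(f"data: {data_s:.3f}s  compute: {compute_s:.3f}s")
```

If the data time dominates, increasing the number of loader workers or caching preprocessed data are the usual first things to try.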

### "Babysitting" the learning process

Given that training these models can be an investment in time and resources, it's wise to continuously monitor training in order to catch and address anomalies if and when they occur. Here are some things to look out for:

#### What should the losses look like?

GAN losses come in all shapes and sizes and depend on numerous factors, including architecture, dataset, and loss function. The adversarial learning process is highly dynamic, and high-frequency oscillations are quite common:

![Loss Montage](assets/loss_montage.jpg)

**Recommendation:** Make sure losses fall within a reasonable range and *catch failures early!* If either loss (D or G) skyrockets to huge values, plunges to 0, or gets stuck on a single value, there is likely an issue somewhere. If you are training a common architecture, consult the literature and other implementations to ground your expectations. One of the hardest parts of re-implementing a paper can be checking whether the logs line up early in training, especially if training takes multiple weeks.
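The three failure patterns above (exploding, collapsing to zero, stuck) can be checked automatically on the running loss lists. The following is a minimal sketch; the function name and every threshold are arbitrary assumptions that you would need to tune for your own architecture and loss.

```python
import math

def check_loss_health(losses, window=100, max_value=50.0,
                      min_value=1e-4, min_spread=1e-6):
    """Return a list of warning strings for the most recent `window` losses."""
    recent = losses[-window:]
    warnings = []
    if not recent:
        return warnings
    if any(math.isnan(v) or math.isinf(v) for v in recent):
        return ["non-finite loss"]
    if max(recent) > max_value:
        warnings.append("loss exploded")
    if max(recent) < min_value:
        warnings.append("loss collapsed to ~0")
    if len(recent) > 1 and max(recent) - min(recent) < min_spread:
        warnings.append("loss stuck on a single value")
    return warnings

print(check_loss_health([0.70, 0.69, 0.71]))  # healthy-looking values
print(check_loss_health([0.70, 120.0]))       # sudden spike
print(check_loss_health([0.0, 0.0, 0.0]))     # collapsed and stuck
```

Calling something like this on `G_losses` and `D_losses` every few hundred iterations is a cheap way to stop a doomed run early.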

**Is my model learning?**
- Monitor IS and FID metrics and other image quality metrics (if applicable) - are they following the expected trajectories? 
- How do the samples look? Are they improving over time? Do you see evidence of mode collapse?

*Mode Collapse*: When the generator produces an extremely limited set of output patterns ("modes") despite diverse input noise. Here are samples generated by a model undergoing mode collapse:

![Mode Collapse](assets/collapse.jpg)
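Visual inspection is the most reliable check, but a crude numeric proxy can help flag collapse between inspections: if samples generated from different noise vectors are nearly identical, their average pairwise distance drops toward zero. The sketch below works on flattened samples given as plain lists of floats. The helper name is an assumption, and pixel-space distance is only a rough heuristic, not an established metric.

```python
import itertools
import math

def mean_pairwise_distance(samples):
    """Average Euclidean distance between all pairs of flattened samples."""
    pairs = list(itertools.combinations(samples, 2))
    if not pairs:
        return 0.0
    return sum(math.dist(a, b) for a, b in pairs) / len(pairs)

diverse = [[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]]
collapsed = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]

print(mean_pairwise_distance(diverse))
print(mean_pairwise_distance(collapsed))
```

Tracking this value on the `fixed_noise` samples over time, and watching for a sudden drop, would be one way to use it.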

**How do I know when to stop?**
- Most importantly, do the samples meet your expectations?
- Did the metrics improve sharply and then collapse?
- Have the metrics and samples stopped improving?
- Explore your model!
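The "no longer improving" criterion can be turned into a simple patience rule on the logged FID values (lower is better). This is a sketch under assumptions: `should_stop`, `patience`, and `min_delta` are illustrative names, and a rule like this should complement, not replace, looking at the samples.

```python
def should_stop(fids, patience=5, min_delta=0.0):
    """True if the best FID has not improved during the last `patience` evaluations."""
    if len(fids) <= patience:
        return False  # not enough history yet
    best_before = min(fids[:-patience])
    best_recent = min(fids[-patience:])
    return best_recent >= best_before - min_delta

print(should_stop([50.0, 40.0, 30.0, 31.0, 32.0, 33.0], patience=3))  # plateaued
print(should_stop([50.0, 40.0, 30.0, 25.0, 20.0, 15.0], patience=3))  # still improving
```

In the training loop, this could be evaluated on `FIDs` right after each call to `compute_inception_metrics`.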

## Results

Now that we have finished training, let's find out how we did. We will analyze our model in several ways:
1. Examine how D and G's losses changed during training.
2. Visualize G's output on the fixed_noise batch for every epoch and create a video.
3. Explore what the Generator has learned in its latent space.

In [None]:
# If you did not train, but want to continue with 
# a pretrained model + logs.
if not TRAIN:
    url = 'http://ganocracy.csail.mit.edu/models/DCGAN_CelebA_128_dim_z100_G_ch64_D_ch64_G_lr0.0002_D_lr0.0002_G_betas0.5_0.999_D_betas0.5_0.999-93ba4eb0.pth'
    checkpoint = torch.hub.load_state_dict_from_url(url, map_location='cpu')
    
    # PyTorch Tip: DataParallel prepends 'module.' to state dict names.
    # If you try to load a model trained with DataParallel into the original,
    # you may have to strip away the prefix like so:
    G.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['G'].items()})
    D.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['D'].items()})
    iters = checkpoint['iters']
    G_losses, D_losses = checkpoint['G_losses'], checkpoint['D_losses']
    IS_scores, FIDs = checkpoint['IS_scores'], checkpoint['FIDs']
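One caveat with `k.replace('module.', '')` is that it removes the substring wherever it occurs in a key, not only at the start. If your layer names could contain `module.` elsewhere, a stricter variant strips only a leading prefix. The helper below is a hypothetical sketch, not part of PyTorch or of this tutorial's utilities.

```python
def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Toy state dict: the values would normally be tensors.
toy = {"module.conv.weight": 1, "block.module.bias": 2}
print(strip_module_prefix(toy))
print({k.replace("module.", ""): v for k, v in toy.items()})
```

The first call leaves `block.module.bias` intact, whereas the plain `replace` rewrites it to `block.bias`.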

In [None]:
def plot_loss_logs(G_loss, D_loss, figsize=(15, 5), smoothing=0.001):
    """Utility for plotting losses with smoothing."""
    G_loss = vutils.smooth_data(G_loss, amount=smoothing)
    D_loss = vutils.smooth_data(D_loss, amount=smoothing)
    plt.figure(figsize=figsize)
    plt.plot(D_loss, label='D_loss')
    plt.plot(G_loss, label='G_loss')
    plt.legend(loc='lower right', fontsize='medium')
    plt.xlabel('Iteration', fontsize='x-large')
    plt.ylabel('Losses', fontsize='x-large')
    plt.title('Training History', fontsize='xx-large')
    plt.show()

plot_loss_logs(G_losses, D_losses, figsize=(15, 5), smoothing=0.01)
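The implementation of `vutils.smooth_data` is not shown in this notebook. If you do not have that helper available, one common way to smooth a noisy loss curve is an exponential moving average. The sketch below is an assumption about what such a helper might do, not a description of the actual one, and its `alpha` parameter is not necessarily equivalent to `amount`.

```python
def ema_smooth(values, alpha=0.1):
    """Exponential moving average; a smaller alpha gives heavier smoothing."""
    smoothed = []
    running = None
    for value in values:
        if running is None:
            running = value  # initialize with the first observation
        else:
            running = alpha * value + (1 - alpha) * running
        smoothed.append(running)
    return smoothed

noisy = [1.0, 3.0, 1.0, 3.0, 1.0, 3.0]
print(ema_smooth(noisy, alpha=0.5))
```

The output has the same length as the input, so it can be passed straight to `plt.plot` in place of the raw losses.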

In [None]:
def plot_metrics(num_iters, IS_scores, FIDs):
    fig, axs = plt.subplots(2, sharex=True, figsize=(15, 5))
    xscale = test_every if test_every is not None else (num_iters // len(IS_scores))
    itrs = np.arange(0, len(IS_scores)) * xscale
    axs[0].plot(itrs, IS_scores)
    axs[1].plot(itrs, FIDs)
    
    for label, ax in zip(['Inception Score', 'FID'], axs.flat):
        ax.set(ylabel=label)
    
    plt.xlabel('Iteration', fontsize='x-large')
    axs[0].set_title('Training History', fontsize='xx-large')
    fig.tight_layout()
    plt.show()
    
plot_metrics(iters, IS_scores, FIDs)

**Visualization of G's progression**

Recall that we saved the generator's output on the fixed_noise batch
after every `sample_every` iterations of training. Now, we can visualize the training
progression of G with a video. Press the play button to start the video.


In [None]:
def make_training_video(samples_dir, resolution, num_rows=8, fps=10):
    files = sorted([os.path.join(samples_dir, f) for f in os.listdir(samples_dir) if f.endswith('.jpg')])
    frames = [np.array(Image.open(f).convert('RGB'))[:num_rows * resolution] for f in files]
    clip = mpy.ImageSequenceClip(frames, fps=fps)  # honor the fps argument instead of a hard-coded 10
    video_outfile = os.path.join(samples_dir, 'progress.mp4')
    clip.write_videofile(video_outfile)
    return video_outfile

if not TRAIN:
    video_outfile = 'assets/training_progression.mp4'
else:
    video_outfile = make_training_video(os.path.join(samples_dir, model_name), resolution)    

In [None]:
# Show the video
Video(video_outfile)

### Latent space exploration

The original DCGAN paper showed that the latent space learned by the generator maintains smooth transitions: as you walk through z space, the resulting output transitions naturally. Let's see if this holds for our generator.

In [None]:
# Intra-class (z only) Latent space interpolation
num_samples = 8
num_midpoints = 8
minibatch_size = 8

# First, choose two random coordinates in z space.
z0 = torch.randn(num_samples, dim_z).to(device)
z1 = torch.randn(num_samples, dim_z).to(device)

# Interpolate between z0 and z1.
zs = vutils.interp(z0, z1, num_midpoints, device=device)
zs = zs.view(-1, zs.size(-1))

# Generate samples.
with torch.no_grad():
    samples = G(zs)

# Show
vutils.visualize_samples(samples, nrow=num_midpoints + 2)
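The `vutils.interp` helper is not shown here. Judging by `nrow=num_midpoints + 2`, it presumably returns the two endpoints plus `num_midpoints` evenly spaced points between them, but that is an inference. The sketch below shows plain linear interpolation between two latent vectors under that assumption, using Python lists in place of tensors.

```python
def lerp_path(z0, z1, num_midpoints):
    """Return num_midpoints + 2 vectors from z0 to z1, endpoints included."""
    steps = num_midpoints + 1
    path = []
    for i in range(steps + 1):
        t = i / steps  # goes from 0.0 (at z0) to 1.0 (at z1)
        path.append([(1 - t) * a + t * b for a, b in zip(z0, z1)])
    return path

for point in lerp_path([0.0, 0.0], [4.0, 8.0], num_midpoints=3):
    print(point)
```

Feeding each point on such a path through G is what produces one row of the interpolation grid above.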

## Cool implementations of GANs
### Open source implementations
For some awesome open-source PyTorch implementations of recent GANs, check out:
- [BigGAN-PyTorch](https://github.com/ajbrock/BigGAN-PyTorch): Train BigGANs from *Large Scale GAN Training for High Fidelity Natural Image Synthesis* on 4-8 GPUs
- [pytorch_GAN_zoo](https://github.com/facebookresearch/pytorch_GAN_zoo): Implementations of DCGAN and Progressive Growing of GAN by Facebook Research
- [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix): Image-to-image translation in PyTorch (e.g., horse2zebra, edges2cats, and more)
- [PyTorch-GAN](https://github.com/eriklindernoren/PyTorch-GAN): Library of reference implementations (good for educational purposes)

Some background reading on the machine learning concepts behind GANs:
- [Machine Learning](https://medium.com/machine-learning-for-humans/why-machine-learning-matters-6164faf1df12)
- [Types of Machine Learning](https://towardsdatascience.com/types-of-machine-learning-algorithms-you-should-know-953a08248861)
- [Multi Layer Perceptrons](http://deeplearning.net/tutorial/mlp.html)
- [Reinforcement Learning](https://deepsense.ai/what-is-reinforcement-learning-the-complete-guide/)
- [Convolutional Neural Networks](https://towardsdatascience.com/a-comprehensive-guide-to-convolutional-neural-networks-the-eli5-way-3bd2b1164a53)

And for the TensorFlow and Keras folks:
- [StyleGAN](https://github.com/NVlabs/stylegan): Official TensorFlow Implementation
- [Progressive Growing of GANs](https://github.com/tkarras/progressive_growing_of_gans): Official TensorFlow Implementation.
- [Keras-GAN](https://github.com/eriklindernoren/Keras-GAN): Library of reference implementations (good for educational purposes).