# PyTorch GANs

An example script for creating and running a GAN is in the accompanying file `mnist_gan.py`.

https://github.com/cmudeeplearning11785/deep-learning-tutorials/blob/master/recitation-10/mnist_gan.py

- Inferno is used for training and logging
- The main Inferno training loop feeds batches of real images through the network and trains the discriminator
- A callback periodically trains the generator
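The schedule in the list above can be sketched without any framework. This is a minimal sketch that assumes a simple per-iteration counter. It does not use Inferno's actual callback API, and the names `GeneratorCallback`, `after_discriminator_step`, `train_generator`, and `frequency` are hypothetical.

```python
# Framework-free sketch of the schedule: the main loop updates the discriminator
# on every batch, and a callback updates the generator every `frequency` iterations.
# Hypothetical names; this only mirrors the structure, not Inferno's API.

class GeneratorCallback:
    """Runs after each discriminator step; trains the generator every `frequency` steps."""

    def __init__(self, frequency):
        self.frequency = frequency
        self.generator_steps = 0

    def after_discriminator_step(self, iteration):
        # Only every `frequency`-th iteration triggers a generator update.
        if iteration % self.frequency == 0:
            self.train_generator()

    def train_generator(self):
        # Placeholder for a real generator update.
        self.generator_steps += 1


def train(num_batches, callback):
    discriminator_steps = 0
    for iteration in range(1, num_batches + 1):
        # Main loop: each batch of real images updates the discriminator.
        discriminator_steps += 1
        callback.after_discriminator_step(iteration)
    return discriminator_steps, callback.generator_steps


d_steps, g_steps = train(num_batches=100, callback=GeneratorCallback(frequency=5))
print(d_steps, g_steps)  # 100 20
```

Keeping the generator update in a callback lets the main loop stay a plain discriminator loop, while the ratio between the two networks is controlled by a single number.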

If you want the full experience, run TensorBoard while the script trains.

- Live images will be drawn to the webpage as your network trains.
- A video will be rendered when training completes (if you install ffmpeg).


![tensorboard](https://github.com/cmudeeplearning11785/deep-learning-tutorials/raw/master/recitation-10/images/tensorboard1.png)

![tensorboard](https://github.com/cmudeeplearning11785/deep-learning-tutorials/raw/master/recitation-10/images/tensorboard2.png)

```python
# Import the example scripts from the repository linked above
import mnist_gan
import mnist_wgangp
import mnist_cwgangp
import cifar10_wgangp
from IPython.display import HTML
```

# Successful GAN

First we train the GAN with settings that converge (found through trial and error). The generator and discriminator both use a learning rate of 3e-4, and the generator is trained once for every 5 discriminator updates.
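One way these settings could fit into a training loop is shown in the minimal PyTorch sketch below. It assumes the standard non-saturating binary cross-entropy GAN loss and the Adam optimizer. The tiny MLP architectures, `LATENT_DIM`, and the random stand-in batches are illustrative assumptions and do not reproduce `mnist_gan.py`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

LATENT_DIM = 32          # assumption: latent size chosen for illustration
LEARNING_RATE = 3e-4     # same rate for both networks, as described above
GENERATOR_FREQUENCY = 5  # one generator step per 5 discriminator steps

# Hypothetical tiny MLPs operating on flattened 28x28 MNIST images.
generator = nn.Sequential(nn.Linear(LATENT_DIM, 128), nn.ReLU(), nn.Linear(128, 784), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(784, 128), nn.LeakyReLU(0.2), nn.Linear(128, 1))

g_opt = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE)
d_opt = torch.optim.Adam(discriminator.parameters(), lr=LEARNING_RATE)
bce = nn.BCEWithLogitsLoss()  # the discriminator outputs raw logits


def discriminator_step(real):
    batch = real.size(0)
    fake = generator(torch.randn(batch, LATENT_DIM)).detach()  # no gradient into G
    # Real images are labelled 1, generated images 0.
    loss = bce(discriminator(real), torch.ones(batch, 1)) + \
        bce(discriminator(fake), torch.zeros(batch, 1))
    d_opt.zero_grad()
    loss.backward()
    d_opt.step()
    return loss.item()


def generator_step(batch):
    fake = generator(torch.randn(batch, LATENT_DIM))
    # Non-saturating loss: G tries to make D classify its samples as real.
    loss = bce(discriminator(fake), torch.ones(batch, 1))
    g_opt.zero_grad()
    loss.backward()
    g_opt.step()
    return loss.item()


g_updates = 0
for step in range(1, 21):
    real = torch.rand(16, 784) * 2 - 1  # stand-in for a real MNIST batch in [-1, 1]
    d_loss = discriminator_step(real)
    if step % GENERATOR_FREQUENCY == 0:
        g_loss = generator_step(16)
        g_updates += 1
```

Over 20 discriminator steps this schedule performs 4 generator updates. The generator step also leaves gradients on the discriminator's parameters, but they are cleared by `d_opt.zero_grad()` before the next discriminator update.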

```python
mnist_gan.main([])
```

https://www.youtube.com/embed/IUi0REAWj2c?rel=0

# Failed GAN

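Here we see what happens if the generator is trained too much or too little compared to the discriminator. The run below sets `--generator-frequency=1`, which trains the generator far more often than the 1-in-5 schedule used above.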

```python
mnist_gan.main(['--generator-frequency=1', '--save-directory=output/mnist_gan/frequency-1', '--epochs=50'])
```

https://www.youtube.com/embed/J8m1NXLwSKw

# Wasserstein GAN with Gradient Penalty

Here we train a Wasserstein GAN with gradient penalty (WGAN-GP), an improvement on the traditional GAN that we discussed in lecture.
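WGAN-GP replaces the discriminator with a critic that outputs an unbounded score. It then adds a penalty, weight * E[(||grad D(x_hat)||_2 - 1)^2], evaluated at points x_hat interpolated between real and fake samples. The minimal PyTorch sketch below shows this penalty and the resulting critic loss. The weight of 10 is the value used in the WGAN-GP paper. Whether `mnist_wgangp.py` uses the same value, or helpers with these names, is an assumption.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic, real, fake, weight=10.0):
    """WGAN-GP penalty: weight * mean((||grad_x critic(x_hat)||_2 - 1)^2) on random interpolates."""
    batch = real.size(0)
    # One interpolation coefficient per sample, broadcast over the remaining dims.
    eps = torch.rand(batch, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(x_hat)
    # create_graph=True keeps the graph so the penalty itself can be backpropagated.
    grads, = torch.autograd.grad(outputs=scores.sum(), inputs=x_hat, create_graph=True)
    grad_norm = grads.reshape(batch, -1).norm(2, dim=1)
    return weight * ((grad_norm - 1) ** 2).mean()


def critic_loss(critic, real, fake):
    # The critic maximizes D(real) - D(fake), written here as a loss to minimize.
    return critic(fake).mean() - critic(real).mean() + gradient_penalty(critic, real, fake)


torch.manual_seed(0)
# Hypothetical critic for 1x28x28 images (no batch norm, since the penalty is per sample).
critic = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1))
real = torch.rand(8, 1, 28, 28) * 2 - 1  # stand-in for real images in [-1, 1]
fake = torch.rand(8, 1, 28, 28) * 2 - 1  # stand-in for detached generator output
loss = critic_loss(critic, real, fake)
loss.backward()  # gradients flow through both the Wasserstein term and the penalty
```

A linear critic whose weight vector has unit norm has a gradient norm of exactly 1 everywhere, so its penalty is zero. This makes it a convenient sanity check for an implementation like this one.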

```python
mnist_wgangp.main([])
```

https://www.youtube.com/watch?v=unXILX2wp1A

# WGAN-GP on CIFAR-10

This example applies WGAN-GP to CIFAR-10, a slightly more complex dataset of 32x32 color images in 10 classes.
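Moving from MNIST (1x28x28 grayscale) to CIFAR-10 (3x32x32 color) mainly changes the shape the generator must produce. The sketch below is a hypothetical DCGAN-style generator that upsamples from 4x4 to 32x32 and outputs 3 channels. It is an illustration under these assumptions, not the architecture in `cifar10_wgangp.py`. The latent size of 128 and the channel widths are arbitrary choices.

```python
import torch
import torch.nn as nn


class Cifar10Generator(nn.Module):
    """Hypothetical DCGAN-style generator: latent vector -> 3x32x32 image in [-1, 1]."""

    def __init__(self, latent_dim=128, base=64):
        super().__init__()
        self.base = base
        # Project the latent vector to a 4x4 feature map with 4*base channels.
        self.project = nn.Linear(latent_dim, 4 * base * 4 * 4)
        self.upsample = nn.Sequential(
            nn.BatchNorm2d(4 * base), nn.ReLU(),
            nn.ConvTranspose2d(4 * base, 2 * base, kernel_size=4, stride=2, padding=1),  # 4x4 -> 8x8
            nn.BatchNorm2d(2 * base), nn.ReLU(),
            nn.ConvTranspose2d(2 * base, base, kernel_size=4, stride=2, padding=1),      # 8x8 -> 16x16
            nn.BatchNorm2d(base), nn.ReLU(),
            nn.ConvTranspose2d(base, 3, kernel_size=4, stride=2, padding=1),             # 16x16 -> 32x32, RGB
            nn.Tanh(),  # match images normalized to [-1, 1]
        )

    def forward(self, z):
        x = self.project(z).view(-1, 4 * self.base, 4, 4)
        return self.upsample(x)


generator = Cifar10Generator()
images = generator(torch.randn(16, 128))
print(images.shape)  # torch.Size([16, 3, 32, 32])
```

Each transposed convolution with kernel 4, stride 2, and padding 1 exactly doubles the spatial size, so three of them take 4x4 to 32x32.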

```python
cifar10_wgangp.main([])
```

https://youtu.be/dAe-UcOfywE

# Conditional WGAN-GP on MNIST

Here we use a conditional GAN, which receives the digit label as an additional input, so that it learns to generate each digit on request.
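One common way to condition both networks is to concatenate a one-hot encoding of the digit label to the generator's noise vector and to the critic's flattened input. The minimal PyTorch sketch below illustrates this approach. It is an assumption: `mnist_cwgangp.py` may inject the label differently (for example through an embedding layer), and `LATENT_DIM` and the layer sizes are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

NUM_CLASSES = 10  # digits 0-9
LATENT_DIM = 64   # assumption: latent size chosen for illustration


class ConditionalGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(LATENT_DIM + NUM_CLASSES, 256), nn.ReLU(),
            nn.Linear(256, 28 * 28), nn.Tanh())

    def forward(self, z, labels):
        # Append the one-hot label to the noise so the output depends on the digit.
        onehot = F.one_hot(labels, NUM_CLASSES).float()
        return self.net(torch.cat([z, onehot], dim=1)).view(-1, 1, 28, 28)


class ConditionalCritic(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(28 * 28 + NUM_CLASSES, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 1))

    def forward(self, images, labels):
        # The critic scores an (image, label) pair, so mismatched pairs can be penalized.
        onehot = F.one_hot(labels, NUM_CLASSES).float()
        return self.net(torch.cat([images.flatten(1), onehot], dim=1))


generator = ConditionalGenerator()
critic = ConditionalCritic()
labels = torch.arange(NUM_CLASSES)  # one sample per digit 0-9
images = generator(torch.randn(NUM_CLASSES, LATENT_DIM), labels)
scores = critic(images, labels)
```

Passing `torch.arange(10)` as the labels requests one sample of each digit. Once the model is trained, this is a simple way to render a row of digits 0 through 9.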

```python
mnist_cwgangp.main([])
```

https://youtu.be/_wuRRwujeHc