![MIT Deep Learning](https://deeplearning.mit.edu/files/images/github/mit_deep_learning.png)

<table align="center">
  <td align="center"><a target="_blank" href="https://deeplearning.mit.edu">
        <img src="https://deeplearning.mit.edu/files/images/github/icon_mit.png" style="padding-bottom:5px;" />
      Visit MIT Deep Learning</a></td>
  <td align="center"><a target="_blank" href="http://colab.research.google.com/github/lexfridman/mit-deep-learning/blob/master/tutorial_gans/tutorial_gans.ipynb">
        <img src="https://deeplearning.mit.edu/files/images/github/icon_google_colab.png" style="padding-bottom:5px;" />Run in Google Colab</a></td>
  <td align="center"><a target="_blank" href="https://github.com/lexfridman/mit-deep-learning/blob/master/tutorial_gans/tutorial_gans.ipynb">
        <img src="https://deeplearning.mit.edu/files/images/github/icon_github.png" style="padding-bottom:5px;"  />View Source on GitHub</a></td>
  <td align="center"><a target="_blank" align="center" href="https://www.youtube.com/watch?v=O5xeyoRL95U&list=PLrAXtmErZgOeiKm4sgNOknGvNjby9efdf">
        <img src="https://deeplearning.mit.edu/files/images/github/icon_youtube.png" style="padding-bottom:5px;" />Watch YouTube Videos</a></td>
</table>

# Generative Adversarial Networks (GANs)

This tutorial accompanies the lectures of the [MIT Deep Learning](https://deeplearning.mit.edu) series. Acknowledgements of the people whose work it builds on appear throughout the tutorial and at the end. Introductory lectures on GANs include the following (with more coming soon):

<table>
  <td align="center" style="text-align: center;">    
    <a target="_blank" href="https://www.youtube.com/watch?list=PLrAXtmErZgOeiKm4sgNOknGvNjby9efdf&v=O5xeyoRL95U">
        <img src="https://i.imgur.com/FfQVV8q.png" style="padding-bottom:5px;" />
        (Lecture) Deep Learning Basics: Intro and Overview
    </a>
  </td>
  <td align="center" style="text-align: center;">
      <a target="_blank" href="https://www.youtube.com/watch?list=PLrAXtmErZgOeiKm4sgNOknGvNjby9efdf&v=53YvP6gdD7U">
        <img src="https://i.imgur.com/vbNjF3N.png" style="padding-bottom:5px;" />
          (Lecture) Deep Learning State of the Art 2019
      </a>
  </td>
</table>

Generative Adversarial Networks (GANs) are a framework for training networks to generate new, realistic samples that resemble a training dataset. In its simplest form, the training process involves two networks. One network, called the generator, produces new data instances and tries to fool the other network, the discriminator, which classifies each instance as real (drawn from the training data) or fake (produced by the generator). This original form is illustrated as follows (where #6 refers to one of 7 architectures described in the [Deep Learning Basics tutorial](https://github.com/lexfridman/mit-deep-learning/blob/master/tutorial_deep_learning_basics/deep_learning_basics.ipynb)):

<img src="https://i.imgur.com/LweaD1s.png" width="600px">
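The adversarial setup can be made concrete with the two loss functions involved. The following is a minimal NumPy sketch, not code from this tutorial: it assumes the discriminator outputs a probability that its input is real, uses binary cross-entropy for the discriminator, and uses the commonly used non-saturating loss for the generator. The function names and the example scores are hypothetical.

```python
import numpy as np

def discriminator_loss(d_real, d_fake, eps=1e-8):
    """Binary cross-entropy: push D(real) toward 1 and D(fake) toward 0."""
    return -np.mean(np.log(d_real + eps)) - np.mean(np.log(1.0 - d_fake + eps))

def generator_loss(d_fake, eps=1e-8):
    """Non-saturating generator loss: push D(fake) toward 1."""
    return -np.mean(np.log(d_fake + eps))

# Hypothetical discriminator outputs (probability that an input is real).
d_real = np.array([0.9, 0.8, 0.95])  # scores on real images
d_fake = np.array([0.1, 0.2, 0.05])  # scores on generated images

print(discriminator_loss(d_real, d_fake))  # small: the discriminator is winning
print(generator_loss(d_fake))              # large: the generator is not fooling it yet
```

During training the two losses are minimized in alternation, each with respect to its own network's parameters, so an improvement for one network raises the loss of the other.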

There are broadly 3 categories of GANs:

1. **Unsupervised GANs**: The generator network takes random noise as input and produces a photo-realistic image that closely resembles the images in the training dataset. Examples include the [original version of GAN](https://arxiv.org/abs/1406.2661), [DC-GAN](https://arxiv.org/abs/1511.06434), [pg-GAN](https://arxiv.org/abs/1710.10196), etc.
2. **Conditional GANs**: Jointly learn from features along with images in order to generate images conditioned on those features (e.g., generating an instance of a particular class). Examples include [Conditional GAN](https://arxiv.org/abs/1411.1784), [AC-GAN](https://arxiv.org/abs/1610.09585), [Stack-GAN](https://github.com/hanzhanggit/StackGAN), and [BigGAN](https://arxiv.org/abs/1809.11096).
3. **Style-Transfer GANs**: Translate images from one domain to another (e.g., from horse to zebra, or from sketch to colored image). Examples include [CycleGAN](https://junyanz.github.io/CycleGAN/) and [pix2pix](https://phillipi.github.io/pix2pix/).
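To make the conditional category concrete, here is a small sketch of one simple way to condition a generator: append a one-hot class vector to the noise vector, so the same noise can be paired with different classes. This is an illustration under assumptions, not how every model in the list works; other models inject the label in different ways. All sizes and names below are hypothetical.

```python
import numpy as np

rng = np.random.RandomState(0)

num_classes = 10  # hypothetical number of classes
dim_z = 16        # hypothetical noise dimension
batch = 4

z = rng.randn(batch, dim_z).astype(np.float32)     # random noise, one row per sample
labels = np.array([3, 3, 7, 0])                    # desired class for each sample
y = np.eye(num_classes, dtype=np.float32)[labels]  # one-hot class vectors

# An unconditional generator would receive only `z`.
# A conditional generator also receives the class each sample should belong to.
g_input = np.concatenate([z, y], axis=1)
print(g_input.shape)  # (4, 26)
```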

First, we illustrate BigGAN, a state-of-the-art conditional GAN from DeepMind. This illustration is based on the [BigGAN TF Hub Demo](https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/biggan_generation_with_tf_hub.ipynb) and the BigGAN generators on [TF Hub](https://tfhub.dev/deepmind/biggan-256). See the [BigGAN paper on arXiv](https://arxiv.org/abs/1809.11096) [1] for more information about these models.

We'll be adding more parts to this tutorial as additional lectures come out.

## Part 1: BigGAN

We recommend that you run this notebook in the cloud on Google Colab. If you have not done so yet, consider following the setup steps in the [Deep Learning Basics tutorial](https://github.com/lexfridman/mit-deep-learning) and reading the [Deep Learning Basics: Introduction and Overview with TensorFlow](https://medium.com/tensorflow/mit-deep-learning-basics-introduction-and-overview-with-tensorflow-355bcd26baf0) blog post.

In [6]:
# basics
import io
import os
import numpy as np

# deep learning
from scipy.stats import truncnorm
import tensorflow as tf
if tf.__version__.split(".")[0] == '2':
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()
import tensorflow_hub as hub

# visualization
from IPython.core.display import HTML
#!pip install imageio
import imageio
import base64

# check that tensorflow GPU is enabled
tf.test.gpu_device_name() # returns empty string if using CPU

''

### Load BigGAN generator module from TF Hub

In [7]:
# uncomment the TF Hub module path you would like to use (and comment out the others)
# module_path = 'https://tfhub.dev/deepmind/biggan-128/1'  # 128x128 BigGAN
# module_path = 'https://tfhub.dev/deepmind/biggan-256/1'  # 256x256 BigGAN
module_path = 'https://tfhub.dev/deepmind/biggan-512/1'  # 512x512 BigGAN

tf.reset_default_graph()
print('Loading BigGAN module from:', module_path)
module = hub.Module(module_path)
inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
          for k, v in module.get_input_info_dict().items()}
output = module(inputs)

Loading BigGAN module from: https://tfhub.dev/deepmind/biggan-512/1
INFO:tensorflow:Saver not created because there are no variables in the graph to restore


### Functions for Sampling and Interpolating the Generator

In [8]:
input_z = inputs['z']
input_y = inputs['y']
input_trunc = inputs['truncation']

dim_z = input_z.shape.as_list()[1]
vocab_size = input_y.shape.as_list()[1]

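# sample z from a standard normal truncated to [-2, 2], then scale by `truncation`,
# so every component lies in [-2 * truncation, 2 * truncation]; `seed` makes the draw reproducible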
def truncated_z_sample(truncation=1., seed=None):
    state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(1, dim_z), random_state=state)
    return truncation * values

# convert `index` value to a vector of all zeros except for a 1 at `index`
def one_hot(index, vocab_size=vocab_size):
    index = np.asarray(index)
    if len(index.shape) == 0: # when it's a scalar, convert to a vector of size 1
        index = np.asarray([index])
    assert len(index.shape) == 1
    num = index.shape[0]
    output = np.zeros((num, vocab_size), dtype=np.float32)
    output[np.arange(num), index] = 1
    return output

def one_hot_if_needed(label, vocab_size=vocab_size):
    label = np.asarray(label)
    if len(label.shape) <= 1:
        label = one_hot(label, vocab_size)
    assert len(label.shape) == 2
    return label

# generate images from latent noise vectors and category labels, in batches of `batch_size`
def sample(sess, noise, label, truncation=1., batch_size=8, vocab_size=vocab_size):
    noise = np.asarray(noise)
    label = np.asarray(label)
    num = noise.shape[0]
    if len(label.shape) == 0:
        label = np.asarray([label] * num)
    if label.shape[0] != num:
        raise ValueError('Got # noise samples ({}) != # label samples ({})'
                         .format(noise.shape[0], label.shape[0]))
    label = one_hot_if_needed(label, vocab_size)
    ims = []
    for batch_start in range(0, num, batch_size):
        s = slice(batch_start, min(num, batch_start + batch_size))
        feed_dict = {input_z: noise[s], input_y: label[s], input_trunc: truncation}
        ims.append(sess.run(output, feed_dict=feed_dict))
    ims = np.concatenate(ims, axis=0)
    assert ims.shape[0] == num
    # map generator output from [-1, 1] to uint8 pixel values in [0, 255]
    ims = np.clip(((ims + 1) / 2.0) * 256, 0, 255)
    ims = np.uint8(ims)
    return ims

# linearly interpolate from `a` to `b` in `num_interps` evenly spaced steps (endpoints included)
def interpolate(a, b, num_interps):
    alphas = np.linspace(0, 1, num_interps)
    assert a.shape == b.shape, 'A and B must have the same shape to interpolate.'
    return np.array([(1-x)*a + x*b for x in alphas])

# interpolate, then flatten to `steps` rows so that each row is one input for `sample`
def interpolate_and_shape(a, b, steps):
    interps = interpolate(a, b, steps)
    return (interps.transpose(1, 0, *range(2, len(interps.shape))).reshape(steps, -1))
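To see what the interpolation helpers return, the sketch below reproduces `interpolate` and `interpolate_and_shape` from the cell above so that it runs on its own, and applies them to toy inputs. The toy vectors and their sizes are assumptions chosen for readability; the real latent vectors have shape `(1, dim_z)` and the real labels have shape `(1, vocab_size)`.

```python
import numpy as np

def interpolate(a, b, num_interps):
    alphas = np.linspace(0, 1, num_interps)
    return np.array([(1 - x) * a + x * b for x in alphas])

def interpolate_and_shape(a, b, steps):
    interps = interpolate(a, b, steps)  # shape: (steps, 1, dim)
    return interps.transpose(1, 0, *range(2, len(interps.shape))).reshape(steps, -1)

# Toy latent vectors standing in for two truncated noise samples.
z1 = np.zeros((1, 4))
z2 = np.ones((1, 4))
z_interp = interpolate_and_shape(z1, z2, 5)
print(z_interp.shape)  # (5, 4): one latent vector per frame

# Toy one-hot labels for two different classes out of three.
y1 = np.array([[1.0, 0.0, 0.0]])
y2 = np.array([[0.0, 0.0, 1.0]])
y_interp = interpolate_and_shape(y1, y2, 5)
print(y_interp[2])  # midpoint: equal weight on both classes
```

Because the labels are interpolated in the same way as the noise, passing two different categories yields soft labels that blend one class into the other, while passing the same category twice keeps the label fixed and varies only the noise.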

### Create a TensorFlow session and initialize variables

In [9]:
initializer = tf.global_variables_initializer()
sess = tf.Session()
sess.run(initializer)

### Create video of interpolated BigGAN generator samples

In [11]:
# category options: https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a
category = 947 # mushroom

# truncation trades variety for fidelity: lower values give less varied but more typical-looking samples
truncation = 0.2 # reasonable range: [0.02, 1]

seed_count = 10
clip_secs = 36

seed_step = int(100 / seed_count)
interp_frames = int(clip_secs * 30 / seed_count)  # interpolation frames

cat1 = category
cat2 = category
all_imgs = []

for i in range(seed_count):
    seed1 = i * seed_step # good range for seed is [0, 100]
    seed2 = ((i+1) % seed_count) * seed_step
    
    z1, z2 = [truncated_z_sample(truncation, seed) for seed in [seed1, seed2]]
    y1, y2 = [one_hot([c]) for c in [cat1, cat2]]

    z_interp = interpolate_and_shape(z1, z2, interp_frames)
    y_interp = interpolate_and_shape(y1, y2, interp_frames)

    imgs = sample(sess, z_interp, y_interp, truncation=truncation)
    
    all_imgs.extend(imgs[:-1])

# save as an MP4 for display in the next cell; this is far more space-efficient than a GIF animation
imageio.mimsave('gan.mp4', all_imgs, fps=30)
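The arithmetic behind the loop above can be checked without running the generator. This sketch reuses the values of `seed_count` and `clip_secs` from the cell and introduces `fps` as a name for the literal 30 used there; it only counts frames and lists the seed pairs.

```python
seed_count = 10
clip_secs = 36
fps = 30  # name introduced here for the frame rate used in the cell above

seed_step = int(100 / seed_count)                  # 10
interp_frames = int(clip_secs * fps / seed_count)  # 108 frames per segment

# Seed pairs visited by the loop; the last pair wraps back to seed 0,
# so the video ends where it began and loops smoothly.
pairs = [(i * seed_step, ((i + 1) % seed_count) * seed_step)
         for i in range(seed_count)]

# Each segment drops its final frame (imgs[:-1]) because the next
# segment starts on that same image.
total_frames = seed_count * (interp_frames - 1)

print(pairs[0], pairs[-1])  # (0, 10) (90, 0)
print(total_frames)         # 1070
print(total_frames / fps)   # about 35.67 seconds
```

The resulting clip is therefore slightly shorter than `clip_secs`, since one duplicate frame is dropped at each of the ten segment boundaries.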

In [None]:
%%HTML
<video autoplay loop>
  <source src="gan.mp4" type="video/mp4">
</video>

The above code should generate a 512x512 video version of the following:

![BigGAN mushroom](https://i.imgur.com/TA9uh1a.gif)

# Acknowledgements

The content of this tutorial is based on and inspired by the work of the [TensorFlow team](https://www.tensorflow.org) (see their [Colab notebooks](https://www.tensorflow.org/tutorials/)), [Google DeepMind](https://deepmind.com/), our [MIT Human-Centered AI team](https://hcai.mit.edu), and individual pieces referenced in the [MIT Deep Learning](https://deeplearning.mit.edu) course slides.

TF Colab and TF Hub content is copyrighted to The TensorFlow Authors (2018). Licensed under the Apache License, Version 2.0 (the "License"); http://www.apache.org/licenses/LICENSE-2.0