#### Copyright 2019 Google LLC.

In [0]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Generative Adversarial Networks

This Colab is a heavily modified version of [this blog post](https://medium.com/datadriveninvestor/generative-adversarial-network-gan-using-keras-ce1c05cfdfd3) by Renu Khandelwal. For a beautiful illustration of how GANs work, see also this [interactive demonstration](https://poloclub.github.io/ganlab/).

## Overview

### Learning Objectives

* Create a Generative Adversarial Network

### Prerequisites

* Embeddings
* Recurrent Neural Networks

### Estimated Duration

60 minutes

### Grading Criteria

Each exercise is worth 3 points. The rubric for calculating those points is:

| Points | Description |
|--------|-------------|
| 0      | No attempt at exercise |
| 1      | Attempted exercise, but code does not run |
| 2      | Attempted exercise, code runs, but produces incorrect answer |
| 3      | Exercise completed successfully |

There are 4 exercises in this Colab so there are 12 points available. The grading scale will be 6 points.

## Understanding GANs

A Generative Adversarial Network (GAN) is an unsupervised deep learning algorithm in which a "Generator" network is pitted against an adversarial network called the "Discriminator."

The role of the Generator is to create new objects similar to those in a set of training data. In this Colab, the training data are the MNIST images of handwritten digits, and the Generator we build will learn to produce images similar to them.

The role of the Discriminator is to distinguish images that come from our training set from images that do not (including random noise and the Generator's output).

The strategy of a GAN is to put these two networks together and train them alternately. We ask the Generator to produce an image and train the Discriminator to recognize that this image is artificially constructed (i.e. it did not come from the training set). We then ask the Generator to refine its output so that the Discriminator is fooled. Then we go back to training the Discriminator, and repeat.
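Both networks in this Colab are compiled with binary cross-entropy, so each alternating step just minimizes that loss against different labels. The NumPy sketch below illustrates this with made-up Discriminator scores (the scores and the helper name `binary_crossentropy` are assumptions for illustration, not Keras code). For simplicity it labels real images 1, whereas the training loop later in this Colab uses 0.9.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy, the loss both networks are compiled with."""
    y_pred = np.clip(y_pred, eps, 1 - eps)  # avoid log(0)
    return float(np.mean(-(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))))

# Hypothetical Discriminator outputs: confidence that each input is genuine.
d_real = np.array([0.8, 0.9])  # scores on two training images
d_fake = np.array([0.3, 0.1])  # scores on two generated images

# Discriminator step: real images labeled 1, generated images labeled 0.
d_loss = binary_crossentropy(np.array([1, 1, 0, 0]), np.concatenate([d_real, d_fake]))

# Generator step: the same generated images, but labeled 1 ("genuine").
# This loss is low only when the Discriminator is fooled into scoring them near 1.
g_loss = binary_crossentropy(np.ones(2), d_fake)

print(f"discriminator loss: {d_loss:.3f}")
print(f"generator loss:     {g_loss:.3f}")
```

Here the Discriminator is doing well, so its loss is small while the Generator's loss is large; training the Generator pushes `d_fake` upward, which raises the Discriminator's loss on the next step.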

**Before starting this Colab, enable the GPU runtime.**

We begin by importing all required libraries.


In [0]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import keras
from keras.layers import Dense, Dropout, Input
from keras.models import Model,Sequential
from keras.datasets import mnist
from tqdm import tqdm
from keras.optimizers import Adam

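We now load the MNIST dataset and rescale the pixel values from [0, 255] to [-1, 1], which matches the output range of the Generator's final tanh activation.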

In [0]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = (x_train.astype(np.float32) - 127.5)/127.5

# convert shape of x_train from (60000, 28, 28) to (60000, 784) 
# 784 columns per row
X_train = x_train.reshape(60000, 784)

X_train.shape
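The scaling above is a simple affine map, and its inverse is handy whenever generated images need to be stored as ordinary 8-bit pixels (for example, when writing video frames). A minimal sketch; `to_model_range` and `to_pixel_range` are hypothetical helper names, not part of Keras or NumPy.

```python
import numpy as np

def to_model_range(pixels):
    """Map uint8 pixels in [0, 255] to floats in [-1, 1], as done for x_train."""
    return (pixels.astype(np.float32) - 127.5) / 127.5

def to_pixel_range(values):
    """Hypothetical inverse: map values in [-1, 1] (e.g. tanh output) back to uint8."""
    return np.clip(np.round(values * 127.5 + 127.5), 0, 255).astype(np.uint8)

pixels = np.array([0, 127, 128, 255], dtype=np.uint8)
scaled = to_model_range(pixels)
print(scaled)                  # the endpoints 0 and 255 map to -1.0 and 1.0
print(to_pixel_range(scaled))  # round-trips back to the original pixels
```

The clipping in the inverse guards against values that drift slightly outside [-1, 1] due to floating-point rounding.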

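We create the Generator as a sequence of dense layers. It maps a noise vector of length 100 to a vector of length 784 (a flattened 28x28 image), and its final tanh activation keeps every output value in [-1, 1].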

In [0]:
generator=Sequential([
    Dense(units=256,input_dim=100,activation='relu'),
    Dense(units=512, activation='relu'),
    Dense(units=1024, activation='relu'),
    Dense(units=784, activation='tanh'),
])

generator.summary()
    
generator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
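To make the shapes concrete, here is a NumPy sketch of the same forward pass: dense layers of widths 100 → 256 → 512 → 1024 → 784 with ReLU in between and tanh at the end. The weights are random stand-ins (an assumption for illustration); the real Generator learns them during training.

```python
import numpy as np

rng = np.random.default_rng(0)
layer_sizes = [100, 256, 512, 1024, 784]  # same widths as the Generator above

# Hypothetical random weights; the real Generator learns these during training.
weights = [rng.normal(0, 0.05, size=(n_in, n_out))
           for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]
biases = [np.zeros(n_out) for n_out in layer_sizes[1:]]

def generator_forward(z):
    """Dense layers with ReLU, then tanh on the last layer, as in the Sequential model."""
    h = z
    for i, (W, b) in enumerate(zip(weights, biases)):
        h = h @ W + b
        h = np.tanh(h) if i == len(weights) - 1 else np.maximum(h, 0.0)
    return h

z = rng.normal(0, 1, size=(3, 100))  # a batch of 3 noise vectors
fake = generator_forward(z)
print(fake.shape)                    # one flattened 28x28 image per row
print(fake.reshape(3, 28, 28).shape)
```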


The Discriminator is also a sequence of dense layers. The input will have shape (784,) (originally coming from a 28x28 image), and the output will be a number between 0 and 1, representing the confidence of the Discriminator that its input is "genuine" (i.e. from the training set).

In [0]:
##model building
discriminator = Sequential([
    Dense(1024, activation='relu',input_dim=784),
    Dropout(0.3),
    Dense(512, activation='relu'),
    Dropout(0.3),
    Dense(256, activation='relu'),
    Dropout(0.2),
    Dense(1, activation='sigmoid')
])


discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))

discriminator.summary()
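The Discriminator interleaves `Dropout` layers, a common regularizer. Keras documents `Dropout` as zeroing each input with probability `rate` during training and scaling the surviving inputs by `1 / (1 - rate)`, while passing inputs through unchanged at inference. The NumPy sketch below reimplements that behavior for illustration; the function `dropout` is hypothetical, not the Keras layer itself.

```python
import numpy as np

def dropout(x, rate, training, rng):
    """Inverted dropout, following the behavior described in the Keras Dropout docs."""
    if not training:
        return x  # at inference time, dropout passes inputs through unchanged
    keep = rng.random(x.shape) >= rate  # keep each unit with probability 1 - rate
    return np.where(keep, x / (1.0 - rate), 0.0)

rng = np.random.default_rng(0)
activations = np.ones(100_000)
dropped = dropout(activations, rate=0.3, training=True, rng=rng)
print((dropped == 0).mean())  # roughly 0.3 of the units are zeroed
print(dropped.mean())         # roughly 1.0: the expected activation is preserved
```

Scaling the survivors is what lets the same weights be used at inference time without any correction.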

We now create the GAN by chaining the Generator and the Discriminator. When we train the Generator, we will freeze the Discriminator.

The input to the GAN is a random noise vector of length 100, which is passed to the Generator. The Generator's output is then fed to the Discriminator.

In [0]:
# Freeze the Discriminator before compiling the combined model, so that training the GAN
# updates only the Generator's weights. This also makes gan.summary() list the
# Discriminator's parameters as non-trainable.
discriminator.trainable=False 

gan_input = Input(shape=(100,))
x = generator(gan_input)
gan_output= discriminator(x)
gan= Model(inputs=gan_input, outputs=gan_output)

gan.summary()

gan.compile(loss='binary_crossentropy', optimizer='adam')
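The parameter counts in the summaries can be checked by hand: a `Dense` layer from `n_in` to `n_out` units has `n_in * n_out` weights plus `n_out` biases, and `Dropout` adds none. A small sketch (the helper `dense_params` is hypothetical):

```python
def dense_params(sizes):
    """Weights plus biases for a chain of Dense layers with the given widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))

gen_params = dense_params([100, 256, 512, 1024, 784])
disc_params = dense_params([784, 1024, 512, 256, 1])  # Dropout layers add no parameters

print("Generator:    ", gen_params)
print("Discriminator:", disc_params)
print("GAN total:    ", gen_params + disc_params)
```

Because the Discriminator was frozen before `gan` was built and compiled, `gan.summary()` should report the Generator's parameters as trainable and the Discriminator's as non-trainable.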

Before we start training, we write a function, plot_generated_images, that feeds random noise to the Generator and plots the resulting images in a grid, so we can see how the generated images change as training progresses. It also saves each grid to a PNG file that we can view later.

In [0]:
def plot_generated_images(epoch, generator, examples=100, dim=(10,10), figsize=(10,10)):
    noise= np.random.normal(loc=0, scale=1, size=[examples, 100])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(examples,28,28)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], cmap='gray', interpolation='nearest')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('gan_generated_image %d.png' %epoch)

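Finally, we train the GAN. Each "epoch" here consists of 100 training steps, each using a batch of 128 real and 128 generated images, so an epoch samples only part of the 60,000 training images rather than making a full pass over them.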

In [0]:
epochs=200
batch_size=128

for e in range(1, epochs + 1):
  print("Epoch %d" %e)
  for _ in tqdm(range(100)):  # tqdm just displays a progress bar

    # Generate random noise as input to the Generator
    noise= np.random.normal(0,1, [batch_size, 100])

    # Generate fake MNIST images from the noise
    generated_images = generator.predict(noise, verbose=0)

    # Get a random batch of real images
    image_batch =X_train[np.random.randint(low=0,high=X_train.shape[0],size=batch_size)]

    # Stack the real and fake images into one batch for the Discriminator
    X= np.concatenate([image_batch, generated_images])

    # Labels: 0.9 for real images (one-sided label smoothing), 0 for generated images
    y_dis=np.zeros(2*batch_size)
    y_dis[:batch_size]=0.9

    # Train the Discriminator on this mixed batch of real and fake data
    discriminator.trainable=True
    discriminator.train_on_batch(X, y_dis)

    # Label the generated images as real, so the Generator is rewarded for fooling the Discriminator
    y_gen = np.ones(batch_size)

    # While training the combined GAN, the Discriminator's weights should stay fixed.
    # We enforce that by setting the trainable flag.
    discriminator.trainable=False

    # Train the chained GAN model, which updates only the Generator.
    # Together with the step above, this alternates Discriminator and Generator training.
    gan.train_on_batch(noise, y_gen)

  # Plot and save sample images after the first epoch and every 20th epoch
  if e == 1 or e % 20 == 0:
    plot_generated_images(e, generator)
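The batch and label layout used in each step of the loop can be seen on a tiny scale. In this sketch, random arrays stand in for real and generated images (an assumption for illustration), and `batch_size` is shrunk to 4.

```python
import numpy as np

batch_size = 4
rng = np.random.default_rng(0)

# Stand-ins for real and generated images (hypothetical random data, 784 pixels each).
image_batch = rng.uniform(-1, 1, size=(batch_size, 784))
generated_images = rng.uniform(-1, 1, size=(batch_size, 784))

# Real images come first, then generated ones, as in the training loop.
X = np.concatenate([image_batch, generated_images])

# Discriminator targets: 0.9 for real (one-sided label smoothing), 0 for generated.
y_dis = np.zeros(2 * batch_size)
y_dis[:batch_size] = 0.9

# Generator targets: every generated image labeled 1 ("genuine").
y_gen = np.ones(batch_size)

print(X.shape)
print(y_dis)
print(y_gen)
```

Using 0.9 instead of 1 for real images keeps the Discriminator from becoming overconfident on real data; the fake labels stay at exactly 0, which is why the technique is called one-sided.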



As an illustration of the power of our Generator, we choose two random noise vectors and generate a sequence of 60 images whose inputs interpolate linearly between them.

In [0]:
noise= np.random.normal(loc=0, scale=1, size=[2,100])

noise_interp=np.zeros((60,100))
for i in range(60):
  t=i/59
  noise_interp[i]=t*noise[0]+(1-t)*noise[1]
  
images = generator.predict(noise_interp)
images = images.reshape(60,28,28)
  

In [0]:
plt.imshow(images[0],cmap='gray')

In [0]:
plt.imshow(images[59],cmap='gray')

In [0]:
dim=(10,6)
plt.figure(figsize=(10,10))
for i in range(60):
  plt.subplot(dim[0], dim[1], i+1)
  plt.imshow(images[i], cmap='gray')
  plt.axis('off')
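The interpolation loop can also be written with NumPy broadcasting. Because `t = i/59` multiplies `noise[0]`, the first row equals `noise[1]` and the last row equals `noise[0]`, so `images[0]` comes from `noise[1]` and `images[59]` from `noise[0]`. A sketch comparing both forms (the name `noise_interp_vec` is hypothetical, and a seeded generator replaces `np.random.normal` for reproducibility):

```python
import numpy as np

rng = np.random.default_rng(0)
noise = rng.normal(loc=0, scale=1, size=[2, 100])

# t runs from 0 to 1 in 60 evenly spaced steps, the same values as i/59 in the loop.
t = np.linspace(0.0, 1.0, 60)[:, None]  # shape (60, 1), broadcasts over the 100 columns
noise_interp_vec = t * noise[0] + (1 - t) * noise[1]

# Loop version from the cell above, for comparison.
noise_interp = np.zeros((60, 100))
for i in range(60):
    t_i = i / 59
    noise_interp[i] = t_i * noise[0] + (1 - t_i) * noise[1]

print(np.allclose(noise_interp, noise_interp_vec))
```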


# Exercises

## Exercise 1

Import the Sign Language MNIST dataset from Kaggle, and train the above GAN on it. When it is done, use OpenCV to output the frames of the final image interpolation to a video file.

### Student Solution

In [0]:
## Your code here.

### Answer Key

**Solution**

In [0]:
# TODO(joshmcadams)

**Validation**

In [0]:
# TODO(joshmcadams)