<a href="https://colab.research.google.com/github/jarrydmartinx/generative-models/blob/master/generative_adversarial_networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Generative Adversarial Networks

## Install/Imports

In [1]:
# @title Imports

import matplotlib.pyplot as plt
import numpy as np
import sonnet as snt
import pandas as pd
import plotnine as gg
import tensorflow as tf
import tensorflow_datasets as tfds





In [0]:
# @title Get the MNIST data.
mnist = tfds.load('mnist')

## GAN training

In [0]:
# @title Hyperparameters.

BATCH_SIZE = 32
MIXING_RATIO = 0.5  # Probability that a discriminator input is a real MNIST image rather than a generated one.

LATENT_SIZE = 12  # Size of the generator's latent space.
GENERATOR_HIDDENS = [50, 50, 50]
DISCRIMINATOR_HIDDENS = [50, 50]

NUM_TRAINING_STEPS = 2000
NUM_GENERATOR_STEPS_PER_DISCRIMINATOR_STEP = 10

DISC_LEARNING_RATE = 1e-3
GEN_LEARNING_RATE = 1e-3
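
The step counts above determine how often each network is updated. As a minimal sketch, assuming the generator is updated on every training step and the discriminator only on steps divisible by `NUM_GENERATOR_STEPS_PER_DISCRIMINATOR_STEP` (the `count_updates` helper is hypothetical, not part of the notebook), the number of updates per network works out as follows:

```python
NUM_TRAINING_STEPS = 2000
NUM_GENERATOR_STEPS_PER_DISCRIMINATOR_STEP = 10


def count_updates(num_steps, gen_steps_per_disc_step):
    """Count updates per network under the assumed schedule (hypothetical helper)."""
    gen_updates = 0
    disc_updates = 0
    for step in range(num_steps):
        gen_updates += 1  # Generator is updated on every step.
        if step % gen_steps_per_disc_step == 0:
            disc_updates += 1  # Discriminator is updated only every k-th step.
    return gen_updates, disc_updates


gen_updates, disc_updates = count_updates(
    NUM_TRAINING_STEPS, NUM_GENERATOR_STEPS_PER_DISCRIMINATOR_STEP)
print(gen_updates, disc_updates)  # 2000 200
```

With these settings the discriminator sees a tenth as many updates as the generator, which is worth keeping in mind when tuning the two learning rates.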

In [0]:
# @title Make & train a GAN.

tf.reset_default_graph()

mnist = tfds.load('mnist')

dataset = mnist['train']

# Given the dataset, build a pipeline that yields batches of MNIST images
# scaled to [0, 1].
real_images = (
    dataset
    .map(lambda x: tf.cast(x['image'], tf.float32) / tf.uint8.max)
    .batch(BATCH_SIZE)
    .repeat()
    .make_one_shot_iterator()
    .get_next()
)

# Get image shape (e.g. for MNIST: (28, 28, 1)), set static batch size.
image_shape = real_images.shape.as_list()[1:]  # [H, W, C]
real_images.set_shape([BATCH_SIZE] + image_shape)  # [B, H, W, C]

# Create the latent (noise) Tensor.
latent = tf.random_normal(shape=(BATCH_SIZE, LATENT_SIZE))  # [B, L]

# Create the generator network.
generator = snt.Sequential([
    snt.nets.MLP(GENERATOR_HIDDENS + [np.prod(image_shape)], 
                 name='generator'),
    lambda x: tf.reshape(x, [BATCH_SIZE] + image_shape),
    # Uncomment to squash outputs into [0, 1], the range of the real images.
    # tf.nn.sigmoid,
])

# This Tensor holds a batch of fake images from the generator.
fake_images = generator(latent)  # [B, H, W, C]

# Mix real and fake images: each example in the batch is real with
# probability MIXING_RATIO, otherwise it comes from the generator.
real = tf.random_uniform(shape=(BATCH_SIZE,)) < MIXING_RATIO
input_images = tf.where(real, real_images, fake_images)

# Create the discriminator network.
discriminator = snt.Sequential([
    snt.BatchFlatten(),
    snt.nets.MLP(DISCRIMINATOR_HIDDENS + [2], name='discriminator'),
])

# The discriminator tries to classify inputs as either 'real' or 'fake'.
logits = discriminator(input_images)

# The discriminator loss is the binary cross-entropy, computed here as a
# two-class softmax cross-entropy (label 1 = real, label 0 = fake).
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logits, labels=tf.cast(real, tf.int32))
loss = tf.reduce_mean(loss)

# Create an op to track the discriminator's accuracy.
predicted = tf.argmax(logits, axis=-1, output_type=tf.int32)
acc = tf.cast(tf.equal(predicted, tf.cast(real, tf.int32)), tf.float32)
acc = tf.reduce_mean(acc)

# Create separate optimizers for the discriminator and generator.
disc_optimizer = tf.train.AdamOptimizer(DISC_LEARNING_RATE)
gen_optimizer = tf.train.AdamOptimizer(GEN_LEARNING_RATE)

# Get the discriminator and generator network variables.
disc_vars = discriminator.get_all_variables()
gen_vars = generator.get_all_variables()

# The discriminator seeks to *minimize* the discrimination loss.
disc_grads_and_vars = disc_optimizer.compute_gradients(loss, disc_vars)
disc_sgd_op = disc_optimizer.apply_gradients(disc_grads_and_vars)

# The generator seeks to *maximize* the discrimination loss. Only the loss
# terms for fake images depend on the generator's variables.
gen_grads_and_vars = gen_optimizer.compute_gradients(-loss, gen_vars)
gen_sgd_op = gen_optimizer.apply_gradients(gen_grads_and_vars)

# Now we run our training loop.
with tf.train.MonitoredSession() as sess:
  results = []
  for step in range(NUM_TRAINING_STEPS):
    # Take one Adam step on the generator network.
    sess.run(gen_sgd_op)
    
    # Periodically take one Adam step on the discriminator network.
    if step % NUM_GENERATOR_STEPS_PER_DISCRIMINATOR_STEP == 0:
      sess.run(disc_sgd_op)
    
    # Log the loss and discriminator accuracy.
    loss_val, acc_val = sess.run([loss, acc])
    results.append({'step': step,
                    'loss': loss_val, 
                    'accuracy': acc_val})
    if step % max(1, NUM_TRAINING_STEPS // 10) == 0:
      print('Step {}/{}. Loss: {:.2f}. Acc: {:.2f}'.format(
          step, NUM_TRAINING_STEPS, loss_val, acc_val))
  
  # When training is finished, generate a batch of samples.
  samples = sess.run(fake_images)
  
  # Collect results into a Pandas dataframe for plotting.
  df = pd.DataFrame(results)
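
The mixing step in the cell above relies on the boolean mask `real` selecting whole images. The following is a minimal NumPy sketch of the same idea, not the notebook's TensorFlow code; the stand-in arrays, the small batch size and the seed are made up for illustration. Under NumPy-style broadcasting, a `[B]` mask has to be reshaped to `[B, 1, 1, 1]` before it can select along the batch axis of `[B, H, W, C]` images.

```python
import numpy as np

BATCH_SIZE = 4
MIXING_RATIO = 0.5
IMAGE_SHAPE = (28, 28, 1)  # [H, W, C], as for MNIST.

rng = np.random.default_rng(0)  # Arbitrary seed, for repeatability only.

# Stand-ins for the two batches: all ones for "real", all zeros for "fake".
real_images = np.ones((BATCH_SIZE,) + IMAGE_SHAPE, dtype=np.float32)
fake_images = np.zeros((BATCH_SIZE,) + IMAGE_SHAPE, dtype=np.float32)

# One boolean per example: True means "take the real image".
real = rng.uniform(size=BATCH_SIZE) < MIXING_RATIO  # [B]

# Reshape the mask to [B, 1, 1, 1] so it broadcasts over H, W and C.
mask = real.reshape(BATCH_SIZE, 1, 1, 1)
input_images = np.where(mask, real_images, fake_images)  # [B, H, W, C]

# The same mask doubles as the discriminator's labels (1 = real, 0 = fake).
labels = real.astype(np.int32)
```

Because the mask is also the label vector, every image handed to the discriminator is paired with the correct real/fake label by construction.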







Step 0/2000. Loss: 0.65. Acc: 0.88
Step 200/2000. Loss: 1153.76. Acc: 0.44
Step 400/2000. Loss: 0.32. Acc: 0.62
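
A reference point helps when reading the logged losses. The sketch below, in plain Python rather than TensorFlow, computes the two-class softmax cross-entropy for a single example and compares it with the equivalent sigmoid form on the logit difference; both function names are hypothetical. A discriminator that outputs equal logits for the two classes incurs a loss of ln 2 ≈ 0.69 per example, regardless of the label.

```python
import math


def softmax_cross_entropy(logits, label):
    """Cross-entropy of one example given [fake_logit, real_logit] and a 0/1 label."""
    max_logit = max(logits)  # Subtract the max for numerical stability.
    log_norm = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_norm - logits[label]


def sigmoid_cross_entropy(logit_diff, label):
    """Binary cross-entropy on one logit, logit_diff = real_logit - fake_logit."""
    # softplus(x) = log(1 + exp(x)), written in a numerically stable form.
    softplus = max(logit_diff, 0.0) + math.log1p(math.exp(-abs(logit_diff)))
    return softplus - label * logit_diff


chance_loss = softmax_cross_entropy([0.0, 0.0], label=1)
print(round(chance_loss, 4))  # 0.6931
```

The two functions agree whenever `logit_diff` equals `logits[1] - logits[0]`, which is why a two-output softmax classifier can stand in for a single-output sigmoid one.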


In [0]:
# @title Plot

p = (gg.ggplot(df)
     + gg.aes(x='step', y='accuracy')
     + gg.geom_line(color='Navy')
     + gg.ggtitle('Discriminator Accuracy'))
p
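
With `BATCH_SIZE = 32`, each logged accuracy is a multiple of 1/32, so the raw curve tends to look jagged. One option, sketched here in plain Python under the assumption that a trailing moving average is an acceptable smoother (the `moving_average` helper and the window size are hypothetical choices, not part of the notebook), is to smooth the series before plotting:

```python
def moving_average(values, window):
    """Trailing mean over up to `window` most recent values (hypothetical helper)."""
    smoothed = []
    running_sum = 0.0
    for i, value in enumerate(values):
        running_sum += value
        if i >= window:
            running_sum -= values[i - window]  # Drop the value leaving the window.
        smoothed.append(running_sum / min(i + 1, window))
    return smoothed


# Toy accuracy values in multiples of 1/32, standing in for df['accuracy'].
accuracies = [16 / 32, 20 / 32, 12 / 32, 24 / 32]
print(moving_average(accuracies, window=2))  # [0.5, 0.5625, 0.5, 0.5625]
```

The result could be stored in a new column, for example `df['smoothed'] = moving_average(list(df['accuracy']), window=50)`, and plotted by setting `y='smoothed'` in `gg.aes`.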

In [0]:
# @title Show some samples (needs tuning/improvement!)
plt.imshow(samples[4].squeeze())
plt.grid(False)