# Adversarial Networks using TensorFlow and the MNIST Dataset

In [1]:
import numpy as np
import sys
import tensorflow as tf
import time

import load_mnist

from sklearn.ensemble import RandomForestClassifier

%matplotlib inline
from IPython.core.pylabtools import figsize
import matplotlib.pyplot as plt
from matplotlib import cm

## Fetch and Read In Data

In [2]:
# load_mnist is a local helper module whose source is not shown here. Judging
# by the names, Datasets() returns training images, test images, training
# labels and test labels, in that order — TODO confirm against load_mnist.
train_im, test_im, train_labels, test_labels = load_mnist.Datasets()

## Some Utility Functions

In [3]:
def Weight(shape, init=0.01):
    """Create a trainable weight tensor drawn uniformly from [-init, init).

    Args:
      shape: shape of the weight tensor.
      init: half-width of the uniform initialization interval.

    Returns:
      A tf.Variable named "weights".
    """
    return tf.Variable(
        tf.random_uniform(shape, minval=-init, maxval=init),
        name="weights")

def Bias(shape, init=0.01):
    """Create a trainable bias tensor with every entry set to `init`.

    Args:
      shape: shape of the bias tensor.
      init: constant initial value for every entry.

    Returns:
      A tf.Variable named "bias".
    """
    return tf.Variable(tf.constant(init, shape=shape), name="bias")

def StridedConv(x, W):
    """Convolve `x` with filter `W` at stride 2 in height and width.

    SAME padding, so each spatial dimension is halved (rounding up). Used in
    place of max pooling for downsampling.
    """
    stride_two = [1, 2, 2, 1]
    return tf.nn.conv2d(x, W, strides=stride_two, padding="SAME",
                        name="strided_conv")

def FractionallyStridedConv(x, W, output_channels=None):
    """Upsample `x` 2x in height and width with a transposed convolution.

    This is a "fractionally strided convolution": stride 2, with
    tf.nn.deconv2d's default padding.

    Args:
      x: 4-D input tensor laid out as [batch, height, width, channels].
        Height, width and channels must be known when the graph is built.
      W: filter tensor laid out as [kernel_h, kernel_w, output_channels,
        input_channels], the layout tf.nn.deconv2d expects.
      output_channels: number of channels in the result. Defaults to half
        the input channel count.

    Returns:
      Tensor of shape [batch, 2 * height, 2 * width, output_channels].
    """
    batchsize, height, width, kernels = x.get_shape().as_list()
    if output_channels is None:
        # Floor division keeps this an int on Python 3 as well as Python 2.
        output_channels = kernels // 2
    # The output shape given to deconv2d must be fully concrete when the op
    # runs; a -1 batch entry is rejected. Read the batch size from the
    # runtime shape of x instead, so any minibatch size works.
    stack = tf.stack if hasattr(tf, "stack") else tf.pack  # tf.pack on old TF
    output_shape = stack(
        [tf.shape(x)[0], height * 2, width * 2, output_channels])
    y = tf.nn.deconv2d(
        x, W,
        output_shape,
        [1, 2, 2, 1],
        name="fractionally_strided_conv")
    # A runtime output_shape can hide the static shape from graph
    # construction. Restore it so the next layer's get_shape() sees
    # concrete dimensions.
    y.set_shape([batchsize, height * 2, width * 2, output_channels])
    return y

def ELU(x):
    """Exponential linear unit, element-wise: x if x >= 0, else exp(x) - 1.

    Built from a 0/1 float32 mask, so it works on tensors of any shape.
    """
    pos = tf.cast(tf.greater_equal(x, 0), tf.float32)
    # Clamp exp's argument to <= 0. Otherwise exp(x) overflows to inf for
    # large positive x, and the masked-out branch evaluates 0 * inf = NaN in
    # both the forward value and the gradient.
    neg_part = tf.exp(tf.minimum(x, 0.0)) - 1
    return (pos * x) + ((1 - pos) * neg_part)

In [4]:
# Start from a fresh InteractiveSession, closing one left over from a
# previous run. `del` alone only drops this name: the InteractiveSession is
# still installed as the default session, so close it explicitly. Only a
# missing `sess` (NameError) is expected here; other errors propagate.
try:
    sess.close()
    del sess
    print("deleted session")
except NameError:
    print("no existing session to delete")
sess = tf.InteractiveSession()

# Sanity-check tf.nn.deconv2d on a tiny example: a 4x4 single-channel ramp,
# upsampled with a 2x2 kernel at stride 2, becomes an 8x8 image. Each input
# pixel is replaced by a 2x2 block equal to that pixel times the kernel.
m = np.arange(16).reshape((1, 4, 4, 1))
w = np.array([[1, 10], [0.5, 1]]).reshape((2, 2, 1, 1))

t1 = tf.placeholder(tf.float32, [1, 4, 4, 1])
W1 = tf.placeholder(tf.float32, [2, 2, 1, 1])
t2 = tf.nn.deconv2d(t1, W1, [1, 8, 8, 1], [1, 2, 2, 1])

t2.eval(feed_dict={t1: m, W1: w}).reshape(8, 8)

no existing session to delete


array([[   0. ,    0. ,    1. ,   10. ,    2. ,   20. ,    3. ,   30. ],
       [   0. ,    0. ,    0.5,    1. ,    1. ,    2. ,    1.5,    3. ],
       [   4. ,   40. ,    5. ,   50. ,    6. ,   60. ,    7. ,   70. ],
       [   2. ,    4. ,    2.5,    5. ,    3. ,    6. ,    3.5,    7. ],
       [   8. ,   80. ,    9. ,   90. ,   10. ,  100. ,   11. ,  110. ],
       [   4. ,    8. ,    4.5,    9. ,    5. ,   10. ,    5.5,   11. ],
       [  12. ,  120. ,   13. ,  130. ,   14. ,  140. ,   15. ,  150. ],
       [   6. ,   12. ,    6.5,   13. ,    7. ,   14. ,    7.5,   15. ]], dtype=float32)

## Generative Adversarial Networks
See Goodfellow et al (http://arxiv.org/pdf/1406.2661v1.pdf) for a description of GANs.

In essence, we pit two networks against each other in a game. A discriminative network attempts to determine whether an input image comes from the training distribution or is a forgery. A generative network attempts to produce forgeries that fool the discriminative network.

See Radford, Metz, & Chintala (http://arxiv.org/pdf/1511.06434v1.pdf) for useful constraints on GAN architecture. The authors make the following recommendations:

- Use strided convolutions rather than max pooling.
- Use batchnorm everywhere other than generator output and discriminator input.
- No dense hidden layers.
- ReLU units in the generator (except output, which uses Tanh) and Leaky ReLU in the discriminator.

I'll be using exponential linear units (ELU) instead of ReLU / Leaky ReLU, and see if I can get away without batchnorm, since ELUs address some of the same issues: namely, pushing each unit's mean activation toward 0, and allowing gradients to propagate further into deep networks. Otherwise, I'll attempt to adhere to their suggestions.

The generator architecture that Radford et al. recommend (their DCGAN, "deep convolutional GAN") works as follows. A vector of uniformly distributed noise is projected onto a small 2D image with a large number of channels. Repeated "fractionally strided convolutions" (sometimes erroneously called "deconvolutions") successively reduce the number of channels and increase the size of the image.

In [5]:
# Start from a fresh InteractiveSession, closing one left over from a
# previous cell or run. `del` alone only drops this name: the
# InteractiveSession is still installed as the default session, so close it
# explicitly. Only a missing `sess` (NameError) is expected here; other
# errors propagate.
try:
    sess.close()
    del sess
    print("deleted session")
except NameError:
    print("no existing session to delete")
sess = tf.InteractiveSession()

deleted session


#### Generative Network
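The generative network maps a vector of uniform random noise to a 28 x 28 1-channel image. The noise is first linearly projected to a 4 x 4 image with 256 channels. A stack of three fractionally strided convolutional layers then upsamples this to a 32 x 32 image, which is center-cropped to 28 x 28.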

In [6]:
# Nonlinearity for the generator's hidden layers (the last layer uses tanh).
G_nonlin = ELU

# Input a stack of 100-dimensional noise vectors, one per image to generate:
# shape [batch, 100], with the batch size chosen at feed time.
G_x = tf.placeholder(tf.float32, [None, 100])

# Project noise into a 4x4 image with 256 channels. This is a purely linear
# map: no nonlinearity is applied to the projection.
# We won't update the weight that controls the projection during training:
# G_W_proj and G_b_proj are left out of the generator optimizer's var_list,
# so they keep their random initial values.
with tf.name_scope("G_projection") as scope:
    G_W_proj = Weight([100, 4 * 4 * 256])
    G_b_proj = Bias([4 * 4 * 256])
    G_h_proj = tf.reshape(tf.matmul(G_x, G_W_proj) + G_b_proj, [-1, 4, 4, 256])

# Apply fractionally strided convolutions to decrease number of channels and increase image size.
# Filters are laid out [kernel_h, kernel_w, output_channels, input_channels],
# as tf.nn.deconv2d expects. With output_channels omitted,
# FractionallyStridedConv halves the channel count (256 -> 128 here).
# 4x4 image w/ 256 channels => 8x8 image w/ 128 channels
with tf.name_scope("G_fsc1") as scope:
    G_W_fsc1 = Weight([2, 2, 128, 256])
    G_b_fsc1 = Bias([128])
    G_h_fsc1 = G_nonlin(FractionallyStridedConv(G_h_proj, G_W_fsc1) + G_b_fsc1)

# 8x8 image w/ 128 channels to 16x16 image with 64 channels
with tf.name_scope("G_fsc2") as scope:
    G_W_fsc2 = Weight([2, 2, 64, 128])
    G_b_fsc2 = Bias([64])
    G_h_fsc2 = G_nonlin(FractionallyStridedConv(G_h_fsc1, G_W_fsc2) + G_b_fsc2)

# 16x16 image w/ 64 channels to 32x32 image with 1 channel.
# tanh squashes the output into [-1, 1].
with tf.name_scope("G_fsc3") as scope:
    G_W_fsc3 = Weight([2, 2, 1, 64])
    G_b_fsc3 = Bias([1])
    G_h_fsc3 = tf.nn.tanh(FractionallyStridedConv(G_h_fsc2, G_W_fsc3, output_channels=1) + G_b_fsc3)

# Pretty output in the same scale as images.
# G_y: center-crop the 32x32 output to MNIST's 28x28 (2-pixel border off each
# side); values stay in [-1, 1].
# G_image: rescale G_y from [-1, 1] to [0, 255] for display.
with tf.name_scope("G_out") as scope:
    G_y = tf.slice(G_h_fsc3, [0, 2, 2, 0], [-1, 28, 28, -1])
    G_image = (G_y + 1) * 127.5

#### Discriminative Network

The discriminative network maps a 28 x 28 1-channel image to a single float between 0 and 1, representing the probability that the image came from the training data distribution rather than from the generative network.

In [7]:
# Nonlinearity for the discriminator's convolutional layers (the output layer
# uses a sigmoid).
D_nonlin = ELU

# Input: D_x is wired directly to the generator output G_y, a stack of
# 28 x 28 single-channel images in the tanh range [-1, 1]. To score real
# images, feed D_x itself in feed_dict, which bypasses the generator. Real
# images must first be scaled to the same [-1, 1] range.
D_x = G_y

# Target vector is a 1 (real) or 0 (forgery) for each input; shape [batch].
D_y_ = tf.placeholder(tf.float32, [None])

# Stack strided convolutional layers (3x3 kernels, stride 2, SAME padding).
# Input should be 28x28, output 14x14
with tf.name_scope("D_sc1") as scope:
    D_W_conv1 = Weight([3, 3, 1, 16])
    D_b_conv1 = Bias([16])
    D_h_conv1 = D_nonlin(StridedConv(D_x, D_W_conv1) + D_b_conv1)

# Input should be 14x14, output 7x7
with tf.name_scope("D_sc2") as scope:
    D_W_conv2 = Weight([3, 3, 16, 32])
    D_b_conv2 = Bias([32])
    D_h_conv2 = D_nonlin(StridedConv(D_h_conv1, D_W_conv2) + D_b_conv2)

# Output a single float between 0 and 1 per image, via a dense layer over the
# flattened 7x7x32 feature map followed by a sigmoid.
# NOTE: D_y has shape [batch, 1], while D_y_ has shape [batch]. Flatten D_y
# before combining the two element-wise, or they broadcast to [batch, batch].
with tf.name_scope("D_output") as scope:
    D_h_flat = tf.reshape(D_h_conv2, [-1, 7 * 7 * 32])
    D_W_out = Weight([7 * 7 * 32, 1])
    D_b_out = Bias([1])
    D_y = tf.nn.sigmoid(tf.matmul(D_h_flat, D_W_out) + D_b_out)

#### Adversarial Training
0. Start with a batch of noise vectors.
1. Feed the noise vectors to the generator to create a batch of forgeries.
2. Take a batch of real training images.
3. Train the discriminator on the real images (target 1) and on the forgeries (target 0), as two separate updates.
4. Train the generator on a fresh batch of noise vectors, using the output of the discriminator as the error signal.

In [None]:
# D_y has shape [batch, 1], but the label placeholder D_y_ has shape [batch].
# Combining them directly would broadcast to a [batch, batch] matrix, so every
# loss and metric below uses this flattened [batch] view instead.
D_prob = tf.reshape(D_y, [-1])

# Keeps log() arguments strictly positive. Without it, a saturated
# discriminator (D_y exactly 0 or 1 in float32) gives inf/NaN losses and
# gradients.
log_eps = 1e-8

# Generator
# When we evaluate the generator training step, we feed noise into G_x.
# Never evaluate the generator train step when feeding real images into D_x.

# Non-saturating generator loss: push D(G(z)) toward 1 by minimizing
# -log D(G(z)).
G_err = -tf.reduce_mean(tf.log(D_prob + log_eps))
# The projection weights (G_W_proj, G_b_proj) are intentionally not trained.
G_train_step = tf.train.AdamOptimizer(0.001).minimize(
    G_err,
    var_list=[
        G_W_fsc1, G_b_fsc1,
        G_W_fsc2, G_b_fsc2,
        G_W_fsc3, G_b_fsc3])

# Discriminator
# When we evaluate the discriminator train step, we should alternate between
# minibatches wherein we provide real images fed into D_x and noise fed into G_x.

# Binary cross-entropy between the targets D_y_ and the predicted probabilities.
D_xent = -tf.reduce_mean(
    D_y_ * tf.log(D_prob + log_eps)
    + (1 - D_y_) * tf.log(1 - D_prob + log_eps))
# Fraction of inputs classified correctly, thresholding the probability at 0.5.
D_prediction = tf.cast(tf.greater_equal(D_prob, 0.5), tf.float32)
D_accuracy = tf.reduce_mean(tf.cast(tf.equal(D_prediction, D_y_), tf.float32))
D_train_step = tf.train.AdamOptimizer(0.00001).minimize(
    D_xent,
    var_list=[
        D_W_conv1, D_b_conv1,
        D_W_conv2, D_b_conv2,
        D_W_out, D_b_out])

# Initialize every variable created so far, including those the optimizers add.

sess.run(tf.initialize_all_variables())

#### Main loop

In [None]:
# Training schedule: passes over the training set, images per minibatch, and
# minibatches between progress reports.
num_epochs, batch_size, test_every_n_batches = 10, 100, 30

# Figure size (width, height in inches) for the sample-image plots.
figsize(6, 1.5)

def PrepBatch(ims, start, stop):
    """Slice ims[start:stop] into an NHWC batch scaled to [-1, 1].

    Args:
      ims: array of images indexed by image along the first axis; axes 1 and
        2 are treated as rows and columns. Pixel values are presumably
        0..255, which the scaling below assumes.
      start, stop: slice bounds along the first axis. `stop` may run past the
        end of `ims`; the batch then holds only the images that exist.

    Returns:
      Float array of shape [batch, rows, cols, 1], in the tanh range the
      generator output uses.
    """
    images = ims[start:stop]
    # Size the batch from the actual slice rather than stop - start, which
    # over-counts (and makes reshape fail) when the slice is truncated.
    batch = images.reshape(images.shape[:3] + (1,))
    return (batch / 127.5) - 1
    

batches_per_epoch = train_im.shape[0] // batch_size
# Width of the generator's noise input, read from the G_x placeholder so the
# fed noise always matches it (a mismatched width fails at feed time).
noise_dim = G_x.get_shape().as_list()[1]


def SampleNoise(n):
    """Draw an [n, noise_dim] batch of uniform [0, 1) noise to feed G_x."""
    return np.random.random((n, noise_dim))


mark = time.time()
train_gen = True
train_disc = True
for ep in xrange(num_epochs):
    for i in xrange(batches_per_epoch):
        sys.stdout.write(".")
        sys.stdout.flush()
        start_offset = i * batch_size
        stop_offset = start_offset + batch_size

        # Train discriminator
        if train_disc:
            # ...on real images (feeding D_x directly bypasses the generator)
            D_train_step.run(feed_dict={
                D_x: PrepBatch(train_im, start_offset, stop_offset),
                D_y_: np.ones(batch_size)})

            # ...on forgeries generated from fresh noise
            D_train_step.run(feed_dict={
                G_x: SampleNoise(batch_size),
                D_y_: np.zeros(batch_size)})

        # Train the generator
        if train_gen:
            G_train_step.run(feed_dict={
                G_x: SampleNoise(batch_size)})

        if (i + 1) % test_every_n_batches == 0:
            ac_real, xent_real = sess.run(
                [D_accuracy, D_xent], feed_dict={
                    D_x: PrepBatch(test_im, 0, 100),
                    D_y_: np.ones(100)})
            ac_forged, xent_forged, im_forged, ger = sess.run(
                [D_accuracy, D_xent, G_image, G_err],
                feed_dict={
                    G_x: SampleNoise(100),
                    D_y_: np.zeros(100)})
            print(("\nEpoch {ep}, batch {ba} ({t:.1f} seconds since last report)"
                   "\nDISCRIMINATOR: real {acr:.2f}% / {xr:.5f}, "
                   "forged accuracy {acf:.1f}% / {xf:.5f}"
                   "\nGENERATOR: error = {ger:.2f}").format(
                ep=ep, ba=i, t=time.time() - mark,
                acr=ac_real * 100, xr=xent_real, acf=ac_forged * 100,
                xf=xent_forged, ger=ger))
            _, axes = plt.subplots(1, 3)
            for j in xrange(3):
                # G_image is [batch, 28, 28, 1]; imshow needs a 2-D array.
                axes[j].imshow(im_forged[j, :, :, 0], cmap=cm.Blues)
            plt.show()
            mark = time.time()

.