In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

In [2]:
# Hyperparameters: width of the latent noise vector and minibatch size.
Z_dim = 128
mb_size = 64

# MNIST through the TF1 tutorial loader, with one-hot encoded labels.
mnist = input_data.read_data_sets('../../dataset/MNIST', one_hot=True)

# Graph inputs, both with a variable-length batch dimension:
# X takes 784-wide rows, Z takes Z_dim-wide noise rows.
X = tf.placeholder(dtype=tf.float32, shape=(None, 784))
Z = tf.placeholder(dtype=tf.float32, shape=(None, Z_dim))

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../../dataset/MNIST\train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../../dataset/MNIST\train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting ../../dataset/MNIST\t10k-images-idx3-ubyte.gz
Extracting ../../dataset/MNIST\t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


In [3]:
def plot(samples):
    """Draw samples as 28x28 grayscale tiles on a 4x4 grid.

    Args:
        samples: iterable of arrays, each of which is reshaped to (28, 28).
            The grid only has 16 cells, so pass at most 16 samples.

    Returns:
        The matplotlib figure containing the tiles.
    """
    fig = plt.figure(figsize=(4, 4))
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)  # thin gutters between tiles

    for cell_idx, flat_image in enumerate(samples):
        tile = plt.subplot(grid[cell_idx])
        # Strip all axis decoration so only the image is visible.
        plt.axis('off')
        tile.set_xticklabels([])
        tile.set_yticklabels([])
        tile.set_aspect('equal')
        pixels = flat_image.reshape(28, 28)
        plt.imshow(pixels, cmap='Greys_r')

    return fig

In [4]:
def sample_Z(m, n):
    """Draw an (m, n) array of noise, uniformly distributed on [-1, 1).

    Args:
        m: number of rows (one per sample).
        n: number of columns (noise values per sample).

    Returns:
        numpy.ndarray of shape (m, n) drawn with ``np.random.uniform``.
    """
    noise_shape = (m, n)
    return np.random.uniform(low=-1.0, high=1.0, size=noise_shape)

def generator(z):
    """Map latent noise to generated samples with a two-layer dense network.

    Args:
        z: tensor of latent vectors, one row per sample.

    Returns:
        Tensor with 784 sigmoid-activated values per row of ``z``.
    """
    with tf.variable_scope("generator"):
        # Build on the argument `z`, not the module-level placeholder `Z`,
        # so the network is wired to whatever tensor the caller passes in.
        G_h1 = tf.layers.dense(z,
                               128,
                               activation=tf.nn.relu,
                               kernel_initializer=tf.glorot_normal_initializer())
        # Sigmoid keeps every output value in (0, 1).
        G_prob = tf.layers.dense(G_h1,
                                 784,
                                 activation=tf.nn.sigmoid,
                                 kernel_initializer=tf.glorot_normal_initializer())

        return G_prob

def discriminator(x, reuse):
    """Score samples with a two-layer dense network.

    Args:
        x: input tensor, one sample per row.
        reuse: forwarded unchanged to both dense layers ("L1" and "L2");
            the notebook passes None on the first call and True on the
            second so both calls share one set of weights.

    Returns:
        Tuple ``(probability, logit)``: the sigmoid of the final layer's
        output, and that single-unit output itself.
    """
    with tf.variable_scope("discriminator"):
        hidden = tf.layers.dense(
            inputs=x,
            units=128,
            activation=tf.nn.relu,
            kernel_initializer=tf.glorot_normal_initializer(),
            name="L1",
            reuse=reuse,
        )
        logit = tf.layers.dense(
            inputs=hidden,
            units=1,
            kernel_initializer=tf.glorot_normal_initializer(),
            name="L2",
            reuse=reuse,
        )
        probability = tf.nn.sigmoid(logit)
        return probability, logit

In [5]:
# Wire the graph together. Statement order matters: the first discriminator
# call (on the real inputs X) creates the "L1"/"L2" layers, and the second
# call (on the generated samples) is made with reuse=True so it reuses them.
G_sample = generator(Z)
D_real, D_logit_real = discriminator(X, reuse=None)
D_fake, D_logit_fake = discriminator(G_sample, reuse=True)

In [6]:
# Partition the trainable variables by the scope prefix of their name, giving
# one list per network.
t_vars = tf.trainable_variables()
g_vars = [
    v for v in t_vars
    if v.name.startswith("generator")
]
d_vars = [
    v for v in t_vars
    if v.name.startswith("discriminator")
]

In [7]:
# GAN losses, computed from the discriminator logits. For reference, the same
# objectives written on probabilities would be:
#   D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
#   G_loss = -tf.reduce_mean(tf.log(D_fake))

# Discriminator: label real inputs 1 and generated samples 0.
D_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_logit_real,
        labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_logit_fake,
        labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake

# Generator: its samples are scored against the label 1.
G_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_logit_fake,
        labels=tf.ones_like(D_logit_fake)))

# One Adam optimizer per network, each limited to that network's variables.
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=d_vars)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=g_vars)

In [8]:
# Train the GAN: each iteration runs one discriminator step and one generator
# step; every 1000 iterations the losses are printed and a 16-sample grid is
# saved as a PNG.

# Create the directory the figures are written to. Creating only 'out/'
# would leave savefig without its target folder on a fresh run.
out_dir = 'out/ex03_VanillaGAN/'
os.makedirs(out_dir, exist_ok=True)

# The context manager closes the session even if training raises or is
# interrupted part-way through.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for it in range(50000):
        # Labels are not used by this GAN, so they are discarded.
        X_mb, _ = mnist.train.next_batch(mb_size)

        # Discriminator step: a real minibatch plus fresh noise.
        _, D_loss_curr = sess.run([D_solver, D_loss],
                                  feed_dict={X: X_mb,
                                             Z: sample_Z(mb_size, Z_dim)})
        # Generator step: a new, independent batch of noise.
        _, G_loss_curr = sess.run([G_solver, G_loss],
                                  feed_dict={Z: sample_Z(mb_size, Z_dim)})

        if it % 1000 == 0:
            print('Iter: {}, D loss: {:.4}, G_loss: {:.4}'.format(it, D_loss_curr, G_loss_curr))

            # 16 samples fill the 4x4 grid drawn by plot().
            samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})

            fig = plot(samples)
            plt.savefig(out_dir + '{}.png'.format(str(it).zfill(5)), bbox_inches='tight')
            # Close each figure so they do not accumulate in memory.
            plt.close(fig)

Iter: 0, D loss: 1.315, G_loss: 2.649
Iter: 1000, D loss: 0.01144, G_loss: 8.789
Iter: 2000, D loss: 0.03487, G_loss: 4.683
Iter: 3000, D loss: 0.1774, G_loss: 4.937
Iter: 4000, D loss: 0.1987, G_loss: 4.301
Iter: 5000, D loss: 0.2622, G_loss: 4.43
Iter: 6000, D loss: 0.3079, G_loss: 3.489
Iter: 7000, D loss: 0.3133, G_loss: 3.42
Iter: 8000, D loss: 0.4203, G_loss: 2.889
Iter: 9000, D loss: 0.6059, G_loss: 3.232
Iter: 10000, D loss: 0.4173, G_loss: 3.055
Iter: 11000, D loss: 0.663, G_loss: 2.597
Iter: 12000, D loss: 1.101, G_loss: 2.835
Iter: 13000, D loss: 0.6209, G_loss: 3.134
Iter: 14000, D loss: 0.6785, G_loss: 2.287
Iter: 15000, D loss: 0.8266, G_loss: 1.668
Iter: 16000, D loss: 0.6918, G_loss: 2.318
Iter: 17000, D loss: 0.8854, G_loss: 1.905
Iter: 18000, D loss: 0.5246, G_loss: 2.361
Iter: 19000, D loss: 0.7041, G_loss: 1.945
Iter: 20000, D loss: 0.45, G_loss: 2.313
Iter: 21000, D loss: 0.5956, G_loss: 2.095
Iter: 22000, D loss: 0.5057, G_loss: 2.458
Iter: 23000, D loss: 0.7316, 