In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

  from ._conv import register_converters as _register_converters


Instructions for updating:
Use the retry module or similar alternatives.


In [48]:
# Start from an empty default graph so that re-running the cells below does not
# stack duplicate placeholders, layers and optimizers onto a graph left over
# from a previous run of this notebook.
tf.reset_default_graph()

In [49]:
def xavier_init(size):
    """Return a random-normal tensor of shape ``size`` for weight initialization.

    The standard deviation is ``1 / sqrt(size[0] / 2)``, i.e. ``sqrt(2 / fan_in)``
    with ``fan_in = size[0]``.

    NOTE(review): despite the name, this is the He/Kaiming normal scale, not
    Glorot/Xavier's ``sqrt(2 / (fan_in + fan_out))``. It is also not called by
    any cell visible in this notebook (the networks use ``tf.layers.dense``
    defaults) — confirm whether it is still needed.
    """
    return tf.random_normal(shape=size, stddev=1. / tf.sqrt(size[0] / 2.))

def sample_z(n, m):
    """Draw an ``(n, m)`` float64 array of latent noise, i.i.d. uniform on [-1, 1).

    Parameters
    ----------
    n : int
        Number of noise vectors (rows), i.e. the batch size.
    m : int
        Dimensionality of each noise vector (columns).

    Draws from numpy's global RNG, so results are reproducible only after
    ``np.random.seed``.
    """
    bounds = dict(low=-1.0, high=1.0)
    return np.random.uniform(size=(n, m), **bounds)

In [50]:
# Graph inputs. Z: a batch of 100-d latent noise vectors (must match z_size
# used when sampling). X: a batch of flattened 28x28 = 784-pixel images.
# The batch dimension is left open (None) so any minibatch size can be fed.
Z = tf.placeholder(dtype=tf.float32, shape=(None, 100))
X = tf.placeholder(dtype=tf.float32, shape=(None, 784))

In [51]:
def generator(z):
    """Map a batch of latent noise to images with pixel values in (0, 1).

    Architecture: Dense(128, ReLU) -> Dense(784) -> sigmoid. All variables are
    created under the ``generator`` variable scope with ``tf.AUTO_REUSE``, so
    repeated calls share one set of weights.

    Parameters
    ----------
    z : tf.Tensor
        Noise batch; presumably (batch, 100) as fed through the ``Z``
        placeholder — TODO confirm for any other caller.

    Returns
    -------
    tf.Tensor
        Generated images, shape (batch, 784).
    """
    with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
        hidden = tf.layers.dense(z, 128, activation=tf.nn.relu)
        logits = tf.layers.dense(hidden, 784)
        return tf.nn.sigmoid(logits)

In [52]:
def discriminator(x):
    """Scalar energy of a batch under an autoencoder discriminator.

    ``x`` is encoded by Dense(128, ReLU) and decoded by a linear Dense(784).
    The energy is the squared reconstruction error, summed over the 784
    features of each sample and then averaged over the batch. Low energy
    means the autoencoder reconstructs ``x`` well. Variables live under the
    ``discriminator`` scope with ``tf.AUTO_REUSE``, so scoring real and
    generated batches uses the same weights.

    Parameters
    ----------
    x : tf.Tensor
        Image batch; assumed (batch, 784) so that ``x - reconstruction``
        lines up — TODO confirm for other callers.

    Returns
    -------
    tf.Tensor
        0-d tensor: the batch-mean energy.

    NOTE(review): because the batch mean is taken here, the hinge
    ``max(0, m - D_fake)`` in the loss acts on the mean energy rather than
    per sample as in the EBGAN paper — confirm this is intended.
    """
    with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE):
        code = tf.layers.dense(x, 128, activation=tf.nn.relu)
        reconstruction = tf.layers.dense(code, 784)
        per_sample_sq_err = tf.reduce_sum((x - reconstruction) ** 2, axis=1)
        return tf.reduce_mean(per_sample_sq_err)

In [53]:
def plot_grid_side(n):
    """Side length of the smallest square grid that holds ``n`` images (at least 1)."""
    return max(1, int(np.ceil(np.sqrt(n))))


def plot(samples, img_shape=(28, 28)):
    """Tile a batch of flattened images onto a square grid in a new figure.

    Parameters
    ----------
    samples : sequence of array-like
        Images to draw. Each is reshaped to ``img_shape`` before display.
    img_shape : tuple of int, optional
        Display shape of each image; defaults to (28, 28) for MNIST.

    Returns
    -------
    matplotlib.figure.Figure
        The new figure. The caller is responsible for saving and closing it.

    The grid is ``ceil(sqrt(len(samples)))`` cells on a side, and the figure is
    that many inches square. 16 samples give a 4x4 grid in a 4x4-inch figure.
    Any number of samples is accepted.
    """
    side = plot_grid_side(len(samples))
    fig = plt.figure(figsize=(side, side))
    gs = gridspec.GridSpec(side, side, wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        # Attach to this figure explicitly instead of pyplot's "current" figure.
        ax = fig.add_subplot(gs[i])
        ax.axis('off')  # hides ticks, tick labels and spines in one call
        ax.set_aspect('equal')
        ax.imshow(np.reshape(sample, img_shape), cmap='Greys_r')

    return fig

In [54]:
# Build the model graph. generator() creates its variables on this first call.
# discriminator() creates its variables on the call with real images X; its
# scope uses tf.AUTO_REUSE, so the second call on G_sample reuses the same
# weights and real and generated batches are scored by one network.
G_sample = generator(Z)
D_real = discriminator(X)   # scalar: batch-mean energy of real images
D_fake = discriminator(G_sample)   # scalar: batch-mean energy of generated images

In [55]:
# EBGAN objectives. m is the energy margin. The discriminator lowers the energy
# of real data (D_real) and raises the energy of generated data until it
# reaches m; beyond that the hinge term contributes nothing. The generator
# minimizes the energy its samples receive.
# NOTE(review): D_real/D_fake are batch means, so the hinge is applied to the
# mean energy rather than per sample — confirm this is intended.
m=5
D_loss = D_real + tf.maximum(0., m-D_fake)
G_loss = D_fake

In [56]:
# One Adam optimizer per player, each restricted to its own network's weights.
# Match on the full variable-scope prefix ('discriminator/', 'generator/') so
# only variables created inside those scopes are selected. A bare 'disc'/'gen'
# prefix would also pick up any unrelated variable whose name merely starts
# with those letters.
D_vars = [v for v in tf.trainable_variables() if v.name.startswith('discriminator/')]
G_vars = [v for v in tf.trainable_variables() if v.name.startswith('generator/')]

D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=D_vars)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=G_vars)

In [57]:
# Run configuration: latent size (must equal the 100 used for the Z
# placeholder) and training minibatch size.
z_size, mb_size = 100, 16

# Load MNIST with the TF1 tutorial helper. The one-hot labels are loaded but
# the training loop discards them.
mnist = input_data.read_data_sets('../data/MNIST_data', one_hot=True)

Extracting ../data/MNIST_data/train-images-idx3-ubyte.gz
Extracting ../data/MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../data/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../data/MNIST_data/t10k-labels-idx1-ubyte.gz


In [None]:
sess = tf.Session()
# tf.initialize_all_variables() is deprecated; global_variables_initializer()
# is its documented replacement. Run it after the solvers are built so the
# optimizers' own variables are initialized too.
sess.run(tf.global_variables_initializer())

In [None]:
# Training loop. Each iteration takes one discriminator step on a real
# minibatch plus fresh noise, then one generator step on fresh noise.
# Every SAMPLE_EVERY iterations, a grid of N_PREVIEW generated images is saved
# to OUT_DIR (before that iteration's update) and the losses are printed.
OUT_DIR = '../out/'
N_ITERS = 1000000
SAMPLE_EVERY = 100
N_PREVIEW = 16  # plot() lays these out on a square grid (16 -> 4x4)

# Pad snapshot indices to enough digits for every snapshot this run writes
# (never fewer than 3), so lexicographic file order matches training order.
# With a fixed width of 3, files misorder once the index passes 999.
snapshot_digits = max(3, len(str((N_ITERS - 1) // SAMPLE_EVERY)))

# exist_ok makes directory creation idempotent and free of check-then-create races.
os.makedirs(OUT_DIR, exist_ok=True)

i = 0  # index of the next snapshot image

for it in range(N_ITERS):
    if it % SAMPLE_EVERY == 0:
        samples = sess.run(G_sample, feed_dict={Z: sample_z(N_PREVIEW, z_size)})

        fig = plot(samples)
        # Save this specific figure rather than whatever pyplot considers current.
        fig.savefig(os.path.join(OUT_DIR, '{}.png'.format(str(i).zfill(snapshot_digits))),
                    bbox_inches='tight')
        i += 1
        plt.close(fig)  # release memory; a full run creates thousands of figures

    X_mb, _ = mnist.train.next_batch(mb_size)  # labels are not needed

    _, D_loss_curr = sess.run([D_solver, D_loss],
                              feed_dict={X: X_mb, Z: sample_z(mb_size, z_size)})
    _, G_loss_curr = sess.run([G_solver, G_loss],
                              feed_dict={Z: sample_z(mb_size, z_size)})

    if it % SAMPLE_EVERY == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'.format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

Iter: 0
D loss: 105.2
G_loss: 208.5

Iter: 100
D loss: 30.29
G_loss: 6.579

Iter: 200
D loss: 19.77
G_loss: 7.052

Iter: 300
D loss: 14.99
G_loss: 5.152

Iter: 400
D loss: 15.0
G_loss: 5.455

Iter: 500
D loss: 13.6
G_loss: 5.254

Iter: 600
D loss: 9.509
G_loss: 5.618

Iter: 700
D loss: 11.62
G_loss: 5.128

Iter: 800
D loss: 9.901
G_loss: 5.079

Iter: 900
D loss: 10.42
G_loss: 5.429

Iter: 1000
D loss: 9.347
G_loss: 5.057

Iter: 1100
D loss: 9.592
G_loss: 5.535

Iter: 1200
D loss: 8.553
G_loss: 5.776

Iter: 1300
D loss: 9.086
G_loss: 5.757

Iter: 1400
D loss: 9.623
G_loss: 6.629

Iter: 1500
D loss: 8.636
G_loss: 6.17

Iter: 1600
D loss: 8.166
G_loss: 5.188

Iter: 1700
D loss: 7.875
G_loss: 5.159

Iter: 1800
D loss: 8.469
G_loss: 6.269

Iter: 1900
D loss: 7.626
G_loss: 5.429

Iter: 2000
D loss: 8.065
G_loss: 5.691

Iter: 2100
D loss: 7.315
G_loss: 5.83

Iter: 2200
D loss: 7.367
G_loss: 6.426

Iter: 2300
D loss: 7.506
G_loss: 6.161

Iter: 2400
D loss: 7.148
G_loss: 5.166

Iter: 2500
D los