In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as grd
import os
from tensorflow.examples.tutorials.mnist import input_data

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
def xavier_init(size):
    """Create a randomly initialised weight tensor of the given shape.

    Values are drawn with ``tf.random_normal`` using a standard deviation of
    ``1 / sqrt(size[0] / 2)``; only the first entry of ``size`` influences
    the spread.

    Args:
        size: Shape of the tensor to create. ``size[0]`` is treated as the
            input dimension.

    Returns:
        Whatever ``tf.random_normal`` returns for ``shape=size`` and the
        computed standard deviation.
    """
    fan_in = size[0]
    # This quantity is passed as `stddev`, so it is a standard deviation
    # rather than a variance.
    stddev = 1. / tf.sqrt(fan_in / 2.)
    return tf.random_normal(shape=size, stddev=stddev)

def plot(samples, n_rows=4, n_cols=4, image_shape=(28, 28)):
    """Draw samples as a tight grid of greyscale images.

    With the default arguments this produces the same 4x4 grid of 28x28
    images on a 4x4 inch figure as before; the extra parameters only lift
    the previously hard-coded sizes.

    Args:
        samples: Iterable of arrays; each one is reshaped to ``image_shape``
            before being drawn. Samples fill the grid in iteration order, so
            at most ``n_rows * n_cols`` of them fit.
        n_rows: Number of rows in the grid.
        n_cols: Number of columns in the grid.
        image_shape: Shape every sample is reshaped to for display.

    Returns:
        The matplotlib figure containing the grid. The figure is neither
        saved nor closed here; that is left to the caller.
    """
    # One inch per cell keeps the default at the original 4x4 inch figure.
    fig = plt.figure(figsize=(n_cols, n_rows))
    gs = grd.GridSpec(n_rows, n_cols)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        # Work on the Axes object directly instead of relying on pyplot's
        # implicit "current axes" state.
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(image_shape), cmap='Greys_r')
    return fig

In [3]:
# Generator input: float32 placeholder with 100 columns and an unspecified
# number of rows.
Z = tf.placeholder(tf.float32, shape=[None, 100], name='Z')

# Generator parameters: two weight/bias pairs sized 100 -> 128 -> 784.
# Weights come from xavier_init; biases start at zero.
G_W1 = tf.Variable(xavier_init([100, 128]), name='G_W1')
G_b1 = tf.Variable(tf.zeros(shape=[128]), name='G_b1')
G_W2 = tf.Variable(xavier_init([128, 784]), name='G_W2')
G_b2 = tf.Variable(tf.zeros(shape=[784]), name='G_b2')
# All generator variables collected in one list.
theta_G = [G_W1, G_W2, G_b1, G_b2]

# Discriminator input: float32 placeholder with 784 columns (the same width
# as the generator's output layer) and an unspecified number of rows.
X = tf.placeholder(tf.float32, shape=[None, 784], name='X')

# Discriminator parameters: two weight/bias pairs sized 784 -> 128 -> 1.
D_W1 = tf.Variable(xavier_init([784, 128]), name='D_W1')
D_b1 = tf.Variable(tf.zeros(shape=[128]), name='D_b1')
D_W2 = tf.Variable(xavier_init([128, 1]), name='D_W2')
D_b2 = tf.Variable(tf.zeros(shape=[1]), name='D_b2')

# All discriminator variables collected in one list.
theta_D = [D_W1, D_W2, D_b1, D_b2]

def generator(z):
    """Run the generator network on ``z``.

    Applies the first layer (``G_W1``, ``G_b1``) followed by ReLU, then the
    second layer (``G_W2``, ``G_b2``) followed by a sigmoid.

    Args:
        z: Input to the first matrix multiplication with ``G_W1``.

    Returns:
        The sigmoid of the second layer's output.
    """
    hidden = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    # Pre-sigmoid activations of the output layer.
    logits = tf.matmul(hidden, G_W2) + G_b2
    return tf.nn.sigmoid(logits)

def descriminator(x):
    """Run the discriminator network on ``x``.

    Applies the first layer (``D_W1``, ``D_b1``) followed by ReLU, then the
    second layer (``D_W2``, ``D_b2``).

    Args:
        x: Input to the first matrix multiplication with ``D_W1``.

    Returns:
        A ``(probability, logit)`` pair: the second layer's raw output and
        its sigmoid, with the sigmoid first.
    """
    pre_activation = tf.matmul(x, D_W1) + D_b1
    hidden = tf.nn.relu(pre_activation)
    logit = tf.matmul(hidden, D_W2) + D_b2
    return tf.nn.sigmoid(logit), logit

In [5]:
# Generated samples, and the discriminator's view of real and generated data.
G_sample = generator(Z)

D_real, D_logit_real = descriminator(X)
D_fake, D_logit_fake = descriminator(G_sample)

# The losses are computed from the discriminator's logits instead of taking
# tf.log of its sigmoid output: once the sigmoid saturates to exactly 0 or 1,
# tf.log(D) / tf.log(1 - D) yields -inf and training fills with NaNs.
# Sigmoid cross-entropy on logits is the same objective:
#   label 1 -> -log(sigmoid(logit)),  label 0 -> -log(1 - sigmoid(logit)).
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.ones_like(D_logit_real), logits=D_logit_real))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.zeros_like(D_logit_fake), logits=D_logit_fake))

# Discriminator: classify real inputs as 1 and generated inputs as 0.
D_loss = D_loss_real + D_loss_fake
# Generator: make the discriminator classify generated inputs as 1.
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.ones_like(D_logit_fake), logits=D_logit_fake))

# Each optimizer only updates its own network's variables.
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)

G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)

def sample_Z(m, n):
    """Return an ``m`` x ``n`` array of noise drawn via ``np.random.uniform(-1, 1)``.

    Args:
        m: Number of rows.
        n: Number of columns.

    Returns:
        A numpy array of shape ``(m, n)``.
    """
    noise_shape = [m, n]
    return np.random.uniform(low=-1., high=1., size=noise_shape)

batch_size = 12
Z_dim = 100

# Loop settings, named here instead of being repeated as inline literals.
num_iterations = 1000000   # upper bound on training steps
report_every = 1000        # save a sample grid and print losses every N steps
num_preview_samples = 16   # number of generated samples per saved grid

# NOTE(review): the session is never closed in this cell; presumably it is
# kept open on purpose for interactive use. Call sess.close() when finished.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

mnist = input_data.read_data_sets('MNIST/', one_hot=True)

# exist_ok avoids the separate existence check (and the race between the
# check and the creation).
os.makedirs('output/', exist_ok=True)

# Running index used to name the saved sample grids (000.png, 001.png, ...).
i = 0

for itr in range(num_iterations):
    is_report_step = itr % report_every == 0

    if is_report_step:
        # Snapshot of the generator taken before this step's updates.
        samples = sess.run(G_sample, feed_dict={Z: sample_Z(num_preview_samples, Z_dim)})

        fig = plot(samples)
        # Save through the figure object rather than pyplot's current figure.
        fig.savefig('output/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        # Close the figure so saved grids do not accumulate in memory.
        plt.close(fig)

    # Labels are not used; only the images are fed to the discriminator.
    X_mb, _ = mnist.train.next_batch(batch_size)

    # One discriminator update followed by one generator update, each with
    # freshly drawn noise.
    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(batch_size, Z_dim)})
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(batch_size, Z_dim)})

    if is_report_step:
        print('ITER: {}'.format(itr))
        print('D_loss: {:.4}'.format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

Extracting MNIST/train-images-idx3-ubyte.gz
Extracting MNIST/train-labels-idx1-ubyte.gz
Extracting MNIST/t10k-images-idx3-ubyte.gz
Extracting MNIST/t10k-labels-idx1-ubyte.gz
ITER: 0
D_loss: 1.289
G_loss: 2.335

ITER: 1000
D_loss: 0.01849
G_loss: 6.35

ITER: 2000
D_loss: 0.003767
G_loss: 6.106

ITER: 3000
D_loss: 0.1
G_loss: 4.159

ITER: 4000
D_loss: 0.2406
G_loss: 4.782

ITER: 5000
D_loss: 0.2427
G_loss: 3.349

ITER: 6000
D_loss: 0.04012
G_loss: 3.779

ITER: 7000
D_loss: 0.3373
G_loss: 2.89

ITER: 8000
D_loss: 0.9332
G_loss: 3.058

ITER: 9000
D_loss: 0.9536
G_loss: 2.002

ITER: 10000
D_loss: 0.6368
G_loss: 2.095



KeyboardInterrupt: 