In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import model

In [2]:
# Load MNIST with one-hot labels (prints the "Extracting ..." lines below).
mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)

# Hyper-parameters.
batch_size = 128
# NOTE(review): despite its name, this is the width of the generator's noise
# input (`rand_num`), not of the condition; the condition width comes from the labels.
cond_dim = 128

# Graph inputs. Creation order is kept fixed so the auto-generated placeholder
# names inside the "inputs" scope stay the same.
with tf.variable_scope("inputs"):
    n_pixels = mnist.train.images.shape[1]
    n_labels = mnist.train.labels.shape[1]
    # Real (flattened) images fed to the discriminator.
    img = tf.placeholder(tf.float32, shape=[None, n_pixels])
    # Noise vectors fed to the generator.
    rand_num = tf.placeholder(tf.float32, shape=[None, cond_dim])
    # One-hot class labels used as the condition.
    cond = tf.placeholder(tf.float32, shape=[None, n_labels])

Extracting ./MNIST_data\train-images-idx3-ubyte.gz
Extracting ./MNIST_data\train-labels-idx1-ubyte.gz
Extracting ./MNIST_data\t10k-images-idx3-ubyte.gz
Extracting ./MNIST_data\t10k-labels-idx1-ubyte.gz


In [3]:
def plot(samples):
    """Draw flattened images on a 4x4 grid and return the figure.

    Args:
        samples: iterable of at most 16 arrays, each reshapeable to (28, 28);
            a 17th sample would index past the 4x4 GridSpec.

    Returns:
        The matplotlib Figure holding the grid (caller is responsible for
        saving/closing it).
    """
    figure = plt.figure(figsize=(4, 4))
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for cell_idx, flat_img in enumerate(samples):
        cell_ax = plt.subplot(grid[cell_idx])
        plt.axis('off')
        # Hide tick labels and keep pixels square.
        cell_ax.set_xticklabels([])
        cell_ax.set_yticklabels([])
        cell_ax.set_aspect('equal')
        plt.imshow(flat_img.reshape(28, 28), cmap='Greys_r')
    return figure

In [4]:
def gen_rand_num(m, n):
    """Return an (m, n) array of noise drawn uniformly from [-1, 1).

    Uses NumPy's global RNG, so results follow ``np.random.seed``.
    """
    low, high = -1.0, 1.0
    return np.random.uniform(low=low, high=high, size=(m, n))

In [5]:
# Build the GAN graph: one generator and two discriminator calls on the same
# condition `cond`, one on real images and one on generated images.
# The first discriminator call passes reuse=None and the second reuse=True;
# presumably model.discriminator forwards `reuse` to a tf.variable_scope so both
# calls share one set of weights — TODO confirm in model.py.
# Keep this statement order: the reuse=True call must come after the reuse=None one.
G_infer = model.generator(rand_num, cond)
# NOTE(review): tuple appears to be (output, logit) judging by the names — confirm.
D_real, D_logit_real = model.discriminator(img, cond, reuse = None)
D_fake, D_logit_fake = model.discriminator(G_infer, cond, reuse =True)

In [6]:
# Split the trainable variables by name prefix so each optimizer only updates
# its own network.
# NOTE(review): assumes model.py names its variables with "generator" /
# "discriminator" prefixes; a plain startswith() would also match e.g.
# "generator_2/..." — confirm.
t_vars = tf.trainable_variables()
g_vars, d_vars = (
    [var for var in t_vars if var.name.startswith(prefix)]
    for prefix in ("generator", "discriminator")
)

In [7]:
# NOTE(review): unpacking order assumes model.loss returns (G_loss, D_loss) — confirm.
G_loss, D_loss = model.loss(D_logit_real, D_logit_fake)

# One Adam optimizer (default hyper-parameters) per network, each restricted to
# that network's variables. The discriminator's is built first, as before.
d_optimizer = tf.train.AdamOptimizer()
D_solver = d_optimizer.minimize(D_loss, var_list=d_vars)
g_optimizer = tf.train.AdamOptimizer()
G_solver = g_optimizer.minimize(G_loss, var_list=g_vars)

In [8]:
# Train the conditional GAN: alternate one discriminator step and one generator
# step per minibatch. Every `sample_every` iterations, save a 4x4 preview grid to
# out/NNN.png and print the current losses.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

n_iters = 50000       # total training iterations
sample_every = 1000   # preview / logging interval (iterations)
n_preview = 16        # matches the 4x4 grid drawn by plot()

os.makedirs('out/', exist_ok=True)

# Fixed one-hot labels for the preview grid (classes 0..9, then 0..5).
# G_infer is built as model.generator(rand_num, cond), so the condition must be
# fed when sampling; feeding it is harmless if the generator ignores it.
n_classes = mnist.train.labels.shape[1]
sample_cond = np.eye(n_classes)[np.arange(n_preview) % n_classes]

i = 0  # index of the next preview image written to out/

for it in range(n_iters):
    if it % sample_every == 0:
        samples = sess.run(G_infer, feed_dict={rand_num: gen_rand_num(n_preview, cond_dim),
                                               cond: sample_cond})

        fig = plot(samples)
        fig.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)  # avoid accumulating figures across the long loop

    img_batch, lab_batch = mnist.train.next_batch(batch_size)

    # Discriminator step, then generator step, each with fresh noise but the
    # same real labels as the condition.
    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={img: img_batch,
                                                             rand_num: gen_rand_num(batch_size, cond_dim),
                                                             cond: lab_batch})
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={rand_num: gen_rand_num(batch_size, cond_dim),
                                                             cond: lab_batch})

    if it % sample_every == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'.format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

Iter: 0
D loss: 1.301
G_loss: 1.969

Iter: 1000
D loss: 0.006551
G_loss: 6.72

Iter: 2000
D loss: 0.03962
G_loss: 5.092

Iter: 3000
D loss: 0.07021
G_loss: 6.32

Iter: 4000
D loss: 0.1594
G_loss: 5.683

Iter: 5000
D loss: 0.1754
G_loss: 4.731

Iter: 6000
D loss: 0.2947
G_loss: 5.038

Iter: 7000
D loss: 0.3887
G_loss: 3.698

Iter: 8000
D loss: 0.3818
G_loss: 3.771

Iter: 9000
D loss: 0.3744
G_loss: 4.23

Iter: 10000
D loss: 0.4463
G_loss: 3.479

Iter: 11000
D loss: 0.7283
G_loss: 3.402

Iter: 12000
D loss: 0.751
G_loss: 2.339

Iter: 13000
D loss: 0.609
G_loss: 2.89

Iter: 14000
D loss: 0.6679
G_loss: 2.333

Iter: 15000
D loss: 0.6427
G_loss: 2.669

Iter: 16000
D loss: 0.7028
G_loss: 2.514

Iter: 17000
D loss: 0.6326
G_loss: 2.634

Iter: 18000
D loss: 0.8207
G_loss: 2.197

Iter: 19000
D loss: 0.9323
G_loss: 2.093

Iter: 20000
D loss: 0.7565
G_loss: 2.505

Iter: 21000
D loss: 0.8801
G_loss: 2.044

Iter: 22000
D loss: 0.8105
G_loss: 1.856

Iter: 23000
D loss: 0.6272
G_loss: 2.255

Iter: 24