In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import model

In [2]:
mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)

# Seed the RNGs *after* loading: depending on the TF version, read_data_sets
# may reseed numpy's global RNG while it builds its DataSet objects.
# Seeding makes batch order, noise draws and weight init repeatable
# (GPU kernels may still introduce some nondeterminism).
SEED = 42
np.random.seed(SEED)
tf.set_random_seed(SEED)

# define inputs
batch_size = 128
# Length of the generator's random noise vector. Despite the name this is NOT
# the width of the condition; that comes from the one-hot label width below.
cond_dim = 128

with tf.variable_scope("inputs"):
    # discriminator input: a batch of real images, one flattened image per row
    img = tf.placeholder(tf.float32, shape=[None, mnist.train.images.shape[1]])
    # generator input: a batch of random noise vectors of length cond_dim
    rand_num = tf.placeholder(tf.float32, shape=[None, cond_dim])
    # condition: one-hot class labels, fed to both the generator and the discriminator
    cond = tf.placeholder(tf.float32, shape=[None, mnist.train.labels.shape[1]])

Extracting ./MNIST_data\train-images-idx3-ubyte.gz
Extracting ./MNIST_data\train-labels-idx1-ubyte.gz
Extracting ./MNIST_data\t10k-images-idx3-ubyte.gz
Extracting ./MNIST_data\t10k-labels-idx1-ubyte.gz


In [3]:
def plot(samples, nrows=4, ncols=4, img_shape=(28, 28)):
    """Draw images as a borderless grid and return the figure.

    Args:
        samples: iterable of images; each one is reshaped to `img_shape`
            before drawing (e.g. flattened 784-vectors for MNIST).
        nrows, ncols: grid layout. The defaults reproduce the original 4x4 grid.
        img_shape: (height, width) that every sample is reshaped to.

    Returns:
        The matplotlib Figure. The caller is responsible for saving and closing it.

    Raises:
        ValueError: if there are more samples than grid cells. This is checked
            before any figure is created, so nothing is leaked.
    """
    samples = list(samples)
    n_cells = nrows * ncols
    if len(samples) > n_cells:
        raise ValueError('got {} samples but the {}x{} grid only has {} cells'
                         .format(len(samples), nrows, ncols, n_cells))

    # One inch per cell; (4, 4) with the default grid, as before.
    fig = plt.figure(figsize=(ncols, nrows))
    gs = gridspec.GridSpec(nrows, ncols)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = fig.add_subplot(gs[i])
        ax.axis('off')  # hides ticks, tick labels and spines in one call
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(img_shape), cmap='Greys_r')
    return fig

In [4]:
def gen_rand_num(m, n):
    """Return an (m, n) array of noise drawn uniformly from [-1, 1).

    Draws from numpy's global RNG, so the values depend on np.random's state.
    """
    shape = (m, n)
    return np.random.uniform(low=-1.0, high=1.0, size=shape)

In [5]:
# Generator output for the noise/condition placeholders.
G_infer = model.generator(rand_num, cond)
# Discriminator on real images; reuse=None on this first call.
D_real, D_logit_real = model.discriminator(img, cond, reuse=None)
# Discriminator on generated images; reuse=True so it presumably shares the
# weights created above (assumes model.discriminator forwards `reuse` to its
# variable scope; TODO confirm in model.py).
D_fake, D_logit_fake = model.discriminator(G_infer, cond, reuse=True)

In [6]:
t_vars = tf.trainable_variables()
# Partition trainable variables by name prefix so each optimizer only touches
# its own network. Assumes model.py creates them under scopes named
# "generator..." / "discriminator..." (TODO confirm).
g_vars = list(filter(lambda v: v.name.startswith("generator"), t_vars))
d_vars = list(filter(lambda v: v.name.startswith("discriminator"), t_vars))

In [7]:
# model.loss returns the generator and discriminator losses from the logits.
# Each network gets its own Adam optimizer (default hyper-parameters),
# restricted to that network's variables.
G_loss, D_loss = model.loss(D_logit_real, D_logit_fake)
D_solver = tf.train.AdamOptimizer().minimize(
    loss=D_loss,
    var_list=d_vars,
)
G_solver = tf.train.AdamOptimizer().minimize(
    loss=G_loss,
    var_list=g_vars,
)

In [8]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Training / snapshot schedule.
N_ITERS = 50000    # number of (D step, G step) pairs
LOG_EVERY = 1000   # print losses and save a sample grid this often
N_SAMPLES = 16     # images per snapshot; plot() lays them out on a 4x4 grid by default
OUT_DIR = 'out/'

os.makedirs(OUT_DIR, exist_ok=True)

# Fixed noise and fixed one-hot conditions (classes 0, 1, 2, ... cycling) for
# every snapshot. Successive images then show the *same* requests and progress
# is comparable, rather than depending on whatever labels the current batch has.
n_classes = mnist.train.labels.shape[1]
sample_noise = gen_rand_num(N_SAMPLES, cond_dim)
sample_cond = np.eye(n_classes)[np.arange(N_SAMPLES) % n_classes]

i = 0  # index of the next snapshot file

for it in range(N_ITERS):

    img_batch, lab_batch = mnist.train.next_batch(batch_size)

    # One discriminator step on real images vs. freshly generated ones ...
    _, D_loss_curr = sess.run([D_solver, D_loss],
                              feed_dict={img: img_batch,
                                         rand_num: gen_rand_num(batch_size, cond_dim),
                                         cond: lab_batch})
    # ... then one generator step with new noise and the same conditions.
    _, G_loss_curr = sess.run([G_solver, G_loss],
                              feed_dict={rand_num: gen_rand_num(batch_size, cond_dim),
                                         cond: lab_batch})

    if it % LOG_EVERY == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'.format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

        samples = sess.run(G_infer, feed_dict={rand_num: sample_noise, cond: sample_cond})

        fig = plot(samples)
        fig.savefig(os.path.join(OUT_DIR, '{:03d}.png'.format(i)), bbox_inches='tight')
        i += 1
        plt.close(fig)  # avoid accumulating open figures across snapshots

Iter: 0
D loss: 1.427
G_loss: 2.564

Iter: 1000
D loss: 0.01495
G_loss: 5.669

Iter: 2000
D loss: 0.02025
G_loss: 5.948

Iter: 3000
D loss: 0.1063
G_loss: 6.247

Iter: 4000
D loss: 0.1075
G_loss: 6.23

Iter: 5000
D loss: 0.3573
G_loss: 5.664

Iter: 6000
D loss: 0.4948
G_loss: 4.398

Iter: 7000
D loss: 0.3705
G_loss: 4.478

Iter: 8000
D loss: 0.4734
G_loss: 3.574

Iter: 9000
D loss: 0.4031
G_loss: 3.88

Iter: 10000
D loss: 0.7183
G_loss: 2.541

Iter: 11000
D loss: 0.7249
G_loss: 2.364

Iter: 12000
D loss: 0.7387
G_loss: 3.363

Iter: 13000
D loss: 0.6071
G_loss: 2.861

Iter: 14000
D loss: 0.7015
G_loss: 2.266

Iter: 15000
D loss: 0.7693
G_loss: 2.268

Iter: 16000
D loss: 0.6508
G_loss: 2.26

Iter: 17000
D loss: 0.6397
G_loss: 2.229

Iter: 18000
D loss: 1.006
G_loss: 2.391

Iter: 19000
D loss: 0.9215
G_loss: 1.887

Iter: 20000
D loss: 0.7909
G_loss: 1.987

Iter: 21000
D loss: 0.7099
G_loss: 2.017

Iter: 22000
D loss: 0.7891
G_loss: 1.961

Iter: 23000
D loss: 0.7039
G_loss: 2.218

Iter: 24