In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import model

In [2]:
# Load MNIST with one-hot labels. The labels are not fed to the model while
# label conditioning is disabled (see the note inside the "inputs" scope).
mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)

# Graph inputs.
batch_size = 64   # images / noise vectors per training step
cond_dim = 128    # length of each generator noise vector

with tf.variable_scope("inputs"):
    # Real images for the discriminator: NHWC, one channel.
    # NOTE(review): these are 28x28, while the printed generator output is
    # (?, 32, 32, 1); confirm model.discriminator copes with both sizes.
    img = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
    # Noise vectors for the generator.
    rand_num = tf.placeholder(tf.float32, shape=(None, cond_dim))
    # A label-conditioning placeholder (`cond`, one column per MNIST class) was
    # sketched here but is disabled, so the GAN is trained unconditionally.

Extracting ./MNIST_data\train-images-idx3-ubyte.gz
Extracting ./MNIST_data\train-labels-idx1-ubyte.gz
Extracting ./MNIST_data\t10k-images-idx3-ubyte.gz
Extracting ./MNIST_data\t10k-labels-idx1-ubyte.gz


In [3]:
def plot(samples):
    """Draw a batch of single-channel square images on a borderless grid.

    Parameters
    ----------
    samples : sequence of array-like
        Images to draw. Each one must contain a perfect-square number of
        pixels, e.g. a (32, 32, 1) generator output or a flat 784-vector
        from MNIST; it is reshaped to (side, side) before drawing.

    Returns
    -------
    matplotlib.figure.Figure
        The new figure. The caller is responsible for saving and closing it.

    Raises
    ------
    ValueError
        If a sample's pixel count is not a perfect square.
    """
    images = [np.asarray(sample) for sample in samples]
    # Keep the original 4x4 layout for up to 16 images; grow the grid instead
    # of failing with IndexError when more images are passed.
    grid = max(4, int(np.ceil(np.sqrt(len(images)))))

    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(grid, grid)
    gs.update(wspace=0.05, hspace=0.05)

    for i, image in enumerate(images):
        # Infer the side length rather than hard-coding 32, so real 28x28
        # MNIST digits can be plotted with the same helper.
        side = int(round(np.sqrt(image.size)))
        if side * side != image.size:
            raise ValueError(
                "plot() expects square images; got a sample with {} pixels".format(image.size))
        # Use the explicit figure/axes handles instead of pyplot's implicit
        # "current" figure and axes.
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(image.reshape(side, side), cmap='Greys_r')
    return fig

In [4]:
def gen_rand_num(m, n):
    """Sample an (m, n) matrix of generator noise.

    Every entry is drawn independently from the uniform distribution on
    [-1, 1), using NumPy's global random state.
    """
    low, high = -1.0, 1.0
    return np.random.uniform(low=low, high=high, size=(m, n))

In [5]:
# Build the GAN graph: one generator pass and two discriminator passes.
# G_infer: generator output for the noise fed through `rand_num`; the printed
# shapes under this cell show a (?, 32, 32, 1) tanh output.
G_infer = model.generator(rand_num)
# Discriminator on real images. The two returned tensors are presumably the
# probability and the raw logit (D / D_logit) — confirm in model.py.
D, D_logit = model.discriminator(img)
# Discriminator on generated images; reuse=True presumably makes this call
# share the variables created by the call above — confirm in model.py.
# NOTE(review): the printed h4 shape is (64, 1), so the discriminator may
# hard-code a batch of 64; other batch sizes may fail here — TODO confirm.
D_, D_logit_ = model.discriminator(G_infer, reuse =True)

generator  h0_BA  Tensor("generator/Relu:0", shape=(?, 2, 2, 256), dtype=float32)
generator  h1_BA  Tensor("generator/Relu_1:0", shape=(?, 4, 4, 128), dtype=float32)
generator  h2_BA  Tensor("generator/Relu_2:0", shape=(?, 8, 8, 64), dtype=float32)
generator  h3_BA  Tensor("generator/Relu_3:0", shape=(?, 16, 16, 32), dtype=float32)
generator  h4_A  Tensor("generator/Tanh:0", shape=(?, 32, 32, 1), dtype=float32)
discriminator  h0_A  Tensor("discriminator/Relu:0", shape=(?, 16, 16, 32), dtype=float32)
discriminator  h1_BA  Tensor("discriminator/Relu_1:0", shape=(?, 8, 8, 64), dtype=float32)
discriminator  h2_BA  Tensor("discriminator/Relu_2:0", shape=(?, 4, 4, 128), dtype=float32)
discriminator  h3_BA  Tensor("discriminator/bn3/batchnorm/add_1:0", shape=(?, 2, 2, 256), dtype=float32)
discriminator  h4  Tensor("discriminator/h4/BiasAdd:0", shape=(64, 1), dtype=float32)
discriminator  h0_A  Tensor("discriminator_1/Relu:0", shape=(?, 16, 16, 32), dtype=float32)
discriminator  h1_BA  Tensor(

In [6]:
# Partition the trainable variables by the scope prefix of their names, so each
# optimizer only updates its own network.
t_vars = tf.trainable_variables()
g_vars = list(filter(lambda v: v.name.startswith("generator"), t_vars))
d_vars = list(filter(lambda v: v.name.startswith("discriminator"), t_vars))

In [None]:
# Adam settings shared by both networks (lr=2e-4, beta1=0.5, the values
# recommended in the DCGAN paper). Kept in one place so G and D cannot drift.
LEARNING_RATE = 0.0002
BETA1 = 0.5

G_loss, D_loss = model.loss(D_logit, D, D_logit_, D_)

# Batch-norm layers that register their moving-statistic updates in
# tf.GraphKeys.UPDATE_OPS (e.g. tf.layers.batch_normalization) are not run by
# minimize() on its own, so tie each network's update ops to its own step.
# If model.py registers none, the collections are empty and this is a no-op.
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope="discriminator")):
    D_solver = tf.train.AdamOptimizer(LEARNING_RATE, beta1=BETA1).minimize(D_loss, var_list=d_vars)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope="generator")):
    G_solver = tf.train.AdamOptimizer(LEARNING_RATE, beta1=BETA1).minimize(G_loss, var_list=g_vars)

In [None]:
# Training-loop settings.
NUM_ITERS = 50000     # number of (D step, G step) pairs
SAMPLE_EVERY = 1000   # log losses and save a sample grid this often
NUM_SAMPLES = 16      # images per saved grid (fills plot()'s 4x4 layout)
OUT_DIR = 'out/'

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(OUT_DIR, exist_ok=True)

i = 0  # index of the next sample grid written to OUT_DIR

for it in range(NUM_ITERS):
    # Labels are fetched but unused: label conditioning is disabled, so the
    # GAN is trained unconditionally.
    img_batch, lab_batch = mnist.train.next_batch(batch_size)
    # next_batch returns flattened images; the img placeholder expects NHWC.
    img_batch = np.reshape(img_batch, [batch_size, 28, 28, 1])

    # One discriminator step on a real batch plus a fresh noise batch ...
    _, D_loss_curr = sess.run([D_solver, D_loss],
                              feed_dict={img: img_batch,
                                         rand_num: gen_rand_num(batch_size, cond_dim)})
    # ... then one generator step on new noise; it is fed no real images.
    _, G_loss_curr = sess.run([G_solver, G_loss],
                              feed_dict={rand_num: gen_rand_num(batch_size, cond_dim)})

    if it % SAMPLE_EVERY == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'.format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

        samples = sess.run(G_infer, feed_dict={rand_num: gen_rand_num(NUM_SAMPLES, cond_dim)})

        fig = plot(samples)
        # Save through the figure handle rather than pyplot's implicit
        # "current figure", then close it so figures don't accumulate.
        fig.savefig(os.path.join(OUT_DIR, '{}.png'.format(str(i).zfill(3))), bbox_inches='tight')
        i += 1
        plt.close(fig)

Iter: 0
D loss: 1.399
G_loss: 3.212

Iter: 1000
D loss: 0.2292
G_loss: 4.179

Iter: 2000
D loss: 0.03599
G_loss: 4.888

Iter: 3000
D loss: 0.02266
G_loss: 5.113

Iter: 4000
D loss: 6.723
G_loss: 12.28

Iter: 5000
D loss: 0.003439
G_loss: 9.344

Iter: 6000
D loss: 2.845
G_loss: 16.49

Iter: 7000
D loss: 0.007591
G_loss: 4.885

Iter: 8000
D loss: 0.01026
G_loss: 7.339

Iter: 9000
D loss: 0.0227
G_loss: 5.334

Iter: 10000
D loss: 0.01317
G_loss: 5.082

Iter: 11000
D loss: 0.03526
G_loss: 6.885

Iter: 12000
D loss: 0.005731
G_loss: 5.517

Iter: 13000
D loss: 0.001897
G_loss: 6.014

Iter: 14000
D loss: 0.00675
G_loss: 5.577

Iter: 15000
D loss: 0.02729
G_loss: 4.251

Iter: 16000
D loss: 0.006403
G_loss: 5.7

Iter: 17000
D loss: 0.01469
G_loss: 5.816

Iter: 18000
D loss: 0.004165
G_loss: 5.942

Iter: 19000
D loss: 0.03064
G_loss: 6.298

Iter: 20000
D loss: 0.0231
G_loss: 6.042

Iter: 21000
D loss: 0.001982
G_loss: 9.123

Iter: 22000
D loss: 0.0006041
G_loss: 7.677

Iter: 23000
D loss: 0.0036

In [None]:
# Sanity check: one noise batch should match the rand_num placeholder,
# i.e. (batch_size, cond_dim).
np.shape(gen_rand_num(batch_size, cond_dim))