In [18]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os


def xavier_init(size):
    """Return a random-normal initial value for a weight tensor of shape `size`.

    The standard deviation is 1 / sqrt(fan_in / 2) == sqrt(2 / fan_in), with
    fan_in = size[0]. Despite the name, that is the He/Kaiming-normal scale
    (fan_out is ignored), not Glorot/Xavier's sqrt(2 / (fan_in + fan_out)).

    Args:
        size: shape of the weight tensor; size[0] is treated as the fan-in.

    Returns:
        A `tf.random_normal` tensor (mean 0) of shape `size`.
    """
    fan_in = size[0]
    return tf.random_normal(shape=size, stddev=1. / tf.sqrt(fan_in / 2.))


# Layer widths shared by the discriminator, generator and Q network.
# The generator consumes the concatenation [z, c], so its first layer must be
# exactly Z_dim + c_dim wide (this used to be the unexplained literal 26).
X_dim = 784     # flattened 28x28 image (plot() reshapes samples to 28x28)
Z_dim = 16      # width of the unstructured noise z
c_dim = 10      # width of the categorical code c (one-hot rows from sample_c)
D_h_dim = 128   # discriminator hidden units
G_h_dim = 256   # generator hidden units
Q_h_dim = 128   # Q-network hidden units

# Discriminator: X_dim -> D_h_dim -> 1.
X = tf.placeholder(tf.float32, shape=[None, X_dim])

D_W1 = tf.Variable(xavier_init([X_dim, D_h_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[D_h_dim]))

D_W2 = tf.Variable(xavier_init([D_h_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))

theta_D = [D_W1, D_W2, D_b1, D_b2]


# Generator: (Z_dim + c_dim) -> G_h_dim -> X_dim.
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])
c = tf.placeholder(tf.float32, shape=[None, c_dim])

G_W1 = tf.Variable(xavier_init([Z_dim + c_dim, G_h_dim]))
G_b1 = tf.Variable(tf.zeros(shape=[G_h_dim]))

G_W2 = tf.Variable(xavier_init([G_h_dim, X_dim]))
G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))

theta_G = [G_W1, G_W2, G_b1, G_b2]


# Q network (predicts the code c from an image): X_dim -> Q_h_dim -> c_dim.
Q_W1 = tf.Variable(xavier_init([X_dim, Q_h_dim]))
Q_b1 = tf.Variable(tf.zeros(shape=[Q_h_dim]))

Q_W2 = tf.Variable(xavier_init([Q_h_dim, c_dim]))
Q_b2 = tf.Variable(tf.zeros(shape=[c_dim]))

theta_Q = [Q_W1, Q_W2, Q_b1, Q_b2]


def sample_Z(m, n):
    """Draw an (m, n) array of noise, uniform on [-1, 1)."""
    low, high = -1., 1.
    return np.random.uniform(low, high, size=(m, n))


def sample_c(m, n_classes=10):
    """Draw m one-hot categorical codes, each class equally likely.

    Args:
        m: number of codes (rows) to draw.
        n_classes: number of categories (columns). Defaults to 10, matching
            the original hard-coded behaviour exactly.

    Returns:
        Integer array of shape (m, n_classes); every row contains a single 1.

    Raises:
        ValueError: if n_classes < 1 (a uniform distribution needs at least
            one category, and 1 / 0 would otherwise raise ZeroDivisionError).
    """
    if n_classes < 1:
        raise ValueError('n_classes must be >= 1, got {}'.format(n_classes))
    # One multinomial trial per row == a single uniformly chosen one-hot entry.
    return np.random.multinomial(1, n_classes * [1. / n_classes], size=m)


def generator(z, c):
    """Map noise z and code c to pixel intensities in (0, 1).

    The inputs are concatenated along the feature axis, passed through one
    ReLU hidden layer (G_W1, G_b1) and a linear output layer (G_W2, G_b2),
    then squashed with a sigmoid.

    Args:
        z: noise batch, shape [batch, 16] (matches the Z placeholder).
        c: code batch, shape [batch, 10] (matches the c placeholder).

    Returns:
        Tensor of shape [batch, 784] with values in (0, 1).
    """
    zc = tf.concat([z, c], 1)
    hidden = tf.nn.relu(tf.matmul(zc, G_W1) + G_b1)
    # Pre-sigmoid activations (logits), not log-probabilities.
    logits = tf.matmul(hidden, G_W2) + G_b2
    return tf.nn.sigmoid(logits)


def discriminator(x):
    """Estimate, per row of x, the probability that it is a real data sample.

    One ReLU hidden layer (D_W1, D_b1) followed by a single-unit linear layer
    (D_W2, D_b2) and a sigmoid.

    Args:
        x: batch of flattened images, shape [batch, 784] (matches D_W1).

    Returns:
        Tensor of shape [batch, 1] with values in (0, 1).
    """
    hidden = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    return tf.nn.sigmoid(tf.matmul(hidden, D_W2) + D_b2)


def Q(x):
    """Auxiliary network predicting the categorical code c from an image x.

    One ReLU hidden layer (Q_W1, Q_b1) followed by a linear layer
    (Q_W2, Q_b2) and a softmax over the 10 code categories.

    Args:
        x: batch of flattened images, shape [batch, 784] (matches Q_W1).

    Returns:
        Tensor of shape [batch, 10]; each row sums to 1.
    """
    hidden = tf.nn.relu(tf.matmul(x, Q_W1) + Q_b1)
    logits = tf.matmul(hidden, Q_W2) + Q_b2
    return tf.nn.softmax(logits)


def plot(samples):
    """Tile flattened square images into one borderless grid figure.

    The grid is the smallest square that holds every sample. For 16 samples
    this is 4x4, as before. Batches larger than 16 no longer raise IndexError
    from the GridSpec. Each sample is reshaped to side x side with
    side = sqrt(sample.size), which is 28 for 784-pixel MNIST vectors.

    Args:
        samples: iterable of 1-D numpy arrays whose sizes are perfect squares.

    Returns:
        The matplotlib Figure. The caller is responsible for saving and
        closing it (the training loop calls plt.close on it).
    """
    samples = list(samples)
    n_cols = max(1, int(np.ceil(np.sqrt(len(samples)))))

    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(n_cols, n_cols)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        side = int(round(np.sqrt(sample.size)))
        # Draw on explicit axes of *this* figure rather than pyplot's implicit
        # "current" axes, so the result does not depend on global plot state.
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(side, side), cmap='Greys_r')

    return fig


# Wire the networks together:
# - fake images from (Z, c),
# - discriminator scores for real and fake batches,
# - Q's prediction of the code from the fakes.
G_sample = generator(Z, c)
D_real = discriminator(X)
D_fake = discriminator(G_sample)
Q_c_given_x = Q(G_sample)

# GAN losses on sigmoid outputs; the 1e-8 keeps tf.log away from log(0).
# The generator minimises -log D(G(z, c)) (the non-saturating form).
D_loss = -tf.reduce_mean(
    tf.log(D_real + 1e-8) + tf.log(1 - D_fake + 1e-8)
)
G_loss = -tf.reduce_mean(tf.log(D_fake + 1e-8))

# Mutual-information term: cross-entropy between the fed code c and Q(c | G(z, c)).
# `ent` reads only the placeholder c, so it contributes no gradient to any
# variable. With the one-hot codes fed during training it is ~0.
cross_ent = tf.reduce_mean(
    -tf.reduce_sum(tf.log(Q_c_given_x + 1e-8) * c, axis=1)
)
ent = tf.reduce_mean(-tf.reduce_sum(tf.log(c + 1e-8) * c, axis=1))
Q_loss = cross_ent + ent

# One Adam optimizer (default hyper-parameters) per objective, each limited to
# its own variables. Q_loss also updates theta_G, pushing G to encode c visibly.
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)
Q_solver = tf.train.AdamOptimizer().minimize(Q_loss, var_list=theta_G + theta_Q)

# Training configuration.
mb_size = 32
# Read the noise/code widths off the placeholders so the sampled arrays can
# never disagree with the graph. Previously Z_dim was a second, independent
# literal 16.
Z_dim = Z.get_shape().as_list()[1]
c_dim = c.get_shape().as_list()[1]
n_iters = 1000000   # effectively "train until interrupted"
log_every = 1000    # iterations between saved sample grids / loss printouts
n_plot = 16         # images per saved grid

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# exist_ok replaces the racy "exists? then create" pair.
os.makedirs('out/', exist_ok=True)

i = 0  # index of the next out/NNN.png grid

for it in range(n_iters):
    if it % log_every == 0:
        # All n_plot samples share one randomly chosen category, so each grid
        # shows what a single code value generates across different noise.
        Z_noise = sample_Z(n_plot, Z_dim)

        idx = np.random.randint(0, c_dim)
        c_noise = np.zeros([n_plot, c_dim])
        c_noise[:, idx] = 1

        samples = sess.run(G_sample,
                           feed_dict={Z: Z_noise, c: c_noise})

        fig = plot(samples)
        # Save via the figure handle rather than pyplot's implicit current figure.
        fig.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)

    X_mb, _ = mnist.train.next_batch(mb_size)
    Z_noise = sample_Z(mb_size, Z_dim)
    c_noise = sample_c(mb_size)

    # One update each for D, G and Q, all on the same (Z, c) draw.
    _, D_loss_curr = sess.run([D_solver, D_loss],
                              feed_dict={X: X_mb, Z: Z_noise, c: c_noise})

    _, G_loss_curr = sess.run([G_solver, G_loss],
                              feed_dict={Z: Z_noise, c: c_noise})

    sess.run([Q_solver], feed_dict={Z: Z_noise, c: c_noise})

    if it % log_every == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'.format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

Extracting ../../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz
Iter: 0
D loss: 1.794
G_loss: 1.644

Iter: 1000
D loss: 0.0596
G_loss: 6.506

Iter: 2000
D loss: 0.04315
G_loss: 5.196

Iter: 3000
D loss: 0.02801
G_loss: 5.146

Iter: 4000
D loss: 0.06162
G_loss: 6.035

Iter: 5000
D loss: 0.1742
G_loss: 6.137

Iter: 6000
D loss: 0.2251
G_loss: 3.609

Iter: 7000
D loss: 0.5696
G_loss: 3.728

Iter: 8000
D loss: 0.3282
G_loss: 3.383

Iter: 9000
D loss: 0.5757
G_loss: 3.857

Iter: 10000
D loss: 0.4715
G_loss: 3.47

Iter: 11000
D loss: 0.6542
G_loss: 3.517

Iter: 12000
D loss: 0.9598
G_loss: 2.434

Iter: 13000
D loss: 0.7852
G_loss: 1.627

Iter: 14000
D loss: 0.8352
G_loss: 2.336

Iter: 15000
D loss: 0.4842
G_loss: 2.7

Iter: 16000
D loss: 0.7724
G_loss: 2.336

Iter: 17000
D loss: 0.7604
G_loss: 1.698

Iter: 18000
D loss: 0.9171
G_loss: 1.9



Iter: 191000
D loss: 0.9797
G_loss: 1.411

Iter: 192000
D loss: 0.7651
G_loss: 1.574

Iter: 193000
D loss: 1.047
G_loss: 1.52

Iter: 194000
D loss: 1.205
G_loss: 1.532

Iter: 195000
D loss: 0.8976
G_loss: 1.374

Iter: 196000
D loss: 0.6857
G_loss: 1.532

Iter: 197000
D loss: 0.9933
G_loss: 1.374

Iter: 198000
D loss: 1.022
G_loss: 1.284

Iter: 199000
D loss: 0.867
G_loss: 1.647

Iter: 200000
D loss: 1.015
G_loss: 1.215

Iter: 201000
D loss: 0.8116
G_loss: 2.28

Iter: 202000
D loss: 0.8718
G_loss: 1.558

Iter: 203000
D loss: 0.8696
G_loss: 1.377

Iter: 204000
D loss: 0.9371
G_loss: 1.385

Iter: 205000
D loss: 0.7908
G_loss: 1.418

Iter: 206000
D loss: 0.9821
G_loss: 1.329

Iter: 207000
D loss: 0.8882
G_loss: 1.718

Iter: 208000
D loss: 0.8698
G_loss: 1.478

Iter: 209000
D loss: 0.9314
G_loss: 1.139

Iter: 210000
D loss: 1.02
G_loss: 1.311

Iter: 211000
D loss: 0.8162
G_loss: 1.566

Iter: 212000
D loss: 0.7806
G_loss: 1.328

Iter: 213000
D loss: 1.041
G_loss: 1.339

Iter: 214000
D loss: 

Iter: 384000
D loss: 1.171
G_loss: 1.479

Iter: 385000
D loss: 0.8222
G_loss: 1.565

Iter: 386000
D loss: 0.8687
G_loss: 1.589

Iter: 387000
D loss: 1.08
G_loss: 1.415

Iter: 388000
D loss: 1.076
G_loss: 1.539

Iter: 389000
D loss: 0.9053
G_loss: 1.432

Iter: 390000
D loss: 0.7742
G_loss: 1.685

Iter: 391000
D loss: 1.001
G_loss: 1.356

Iter: 392000
D loss: 0.9235
G_loss: 1.732

Iter: 393000
D loss: 1.141
G_loss: 1.507

Iter: 394000
D loss: 0.8695
G_loss: 1.472

Iter: 395000
D loss: 1.121
G_loss: 1.585

Iter: 396000
D loss: 0.8259
G_loss: 1.479

Iter: 397000
D loss: 0.9954
G_loss: 1.294

Iter: 398000
D loss: 0.9686
G_loss: 1.402

Iter: 399000
D loss: 0.9265
G_loss: 1.751

Iter: 400000
D loss: 1.058
G_loss: 1.291

Iter: 401000
D loss: 1.22
G_loss: 1.247

Iter: 402000
D loss: 1.017
G_loss: 1.656

Iter: 403000
D loss: 0.9881
G_loss: 1.265

Iter: 404000
D loss: 0.8803
G_loss: 1.632



KeyboardInterrupt: 