In [None]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import os

In [None]:
# Reproducibility: fix TF's graph-level seed before any random op (weight
# initializers) is created, and seed numpy, which sample_Z draws noise from.
SEED = 42
tf.set_random_seed(SEED)

mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)

# NOTE(review): numpy is seeded *after* loading in case the MNIST loader
# re-seeds numpy's global RNG itself (behaviour varies across TF 1.x versions)
# — confirm for the installed version.
np.random.seed(SEED)

batch_size = 64
Z_dim = 128  # length of each latent noise vector fed to the generator

# define inputs: flattened images X, latent noise Z and one-hot labels Y
# (one_hot=True above), all with a free batch dimension.
with tf.variable_scope("inputs"):
    X = tf.placeholder(tf.float32, shape= [None, mnist.train.images.shape[1]])
    Z = tf.placeholder(tf.float32, shape= [None, Z_dim])
    Y = tf.placeholder(tf.float32, shape= [None, mnist.train.labels.shape[1]])

In [None]:
def generator(z, y):
    """Conditional generator: map latent noise ``z`` and a condition ``y`` to an image.

    Args:
        z: noise tensor. Assumed rank-2 ``(batch, Z_dim)`` like the ``Z``
            placeholder — it is concatenated with ``y`` along axis 1.
        y: condition tensor, assumed rank-2 ``(batch, n_classes)`` (one-hot
            labels when fed from the ``Y`` placeholder).

    Returns:
        A ``(batch, 784)`` tensor of sigmoid activations, i.e. pixel
        intensities in (0, 1) for a flattened 28x28 image.
    """
    with tf.variable_scope("generator"):
        # Conditioning: the generator sees the label alongside the noise, so
        # it can learn to produce a sample of the requested class.
        inputs = tf.concat([z, y], axis=1)
        G_h1 = tf.layers.dense(inputs, 128,
                               activation=tf.nn.relu,
                               kernel_initializer=tf.glorot_normal_initializer())
        # 784 = 28 * 28 output pixels; sigmoid keeps them in (0, 1).
        G_prob = tf.layers.dense(G_h1, 784,
                                 activation=tf.nn.sigmoid,
                                 kernel_initializer=tf.glorot_normal_initializer())

        return G_prob

def discriminator(x, y, reuse):
    """Conditional discriminator: score how "real" image ``x`` looks given label ``y``.

    Args:
        x: image tensor, assumed rank-2 ``(batch, 784)`` — either the ``X``
            placeholder or generator output. Concatenated with ``y`` on axis 1.
        y: condition tensor, assumed rank-2 ``(batch, n_classes)``.
        reuse: passed to both dense layers. Use ``None`` on the first call
            (creates the weights) and ``True`` on later calls so that every
            call shares the same "L1"/"L2" weights.

    Returns:
        ``(D_prob, D_logit)``: the sigmoid probability and the raw logit, each
        of shape ``(batch, 1)``.
    """
    with tf.variable_scope("discriminator"):
        # Conditioning: the discriminator judges the (image, label) pair.
        inputs = tf.concat([x, y], axis=1)
        D_h1 = tf.layers.dense(inputs, 128,
                               reuse = reuse,
                               activation=tf.nn.relu,
                               kernel_initializer=tf.glorot_normal_initializer(),
                               name = "L1")
        # Explicit layer names are what let reuse=True find the weights
        # created by the first call.
        D_logit = tf.layers.dense(D_h1, 1,
                                  reuse = reuse,
                                  kernel_initializer=tf.glorot_normal_initializer(),
                                  name = "L2")
        D_prob = tf.nn.sigmoid(D_logit)
        return D_prob, D_logit

def loss(D_logit_real, D_logit_fake):
    """Build the standard (non-saturating) GAN losses from discriminator logits.

    The discriminator is pushed to output 1 on real data and 0 on generated
    data. The generator is pushed to make the discriminator output 1 on its
    samples. Using ``sigmoid_cross_entropy_with_logits`` on raw logits is the
    numerically stable form of

        D_loss = -mean(log D(x) + log(1 - D(G(z))))
        G_loss = -mean(log D(G(z)))

    Returns:
        ``(G_loss, D_loss)`` — note the generator loss comes first.
    """
    def _mean_bce(logits, make_targets):
        # make_targets is tf.ones_like / tf.zeros_like, applied to the logits.
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                    labels=make_targets(logits)))

    real_term = _mean_bce(D_logit_real, tf.ones_like)
    fake_term = _mean_bce(D_logit_fake, tf.zeros_like)
    disc_total = real_term + fake_term
    gen_total = _mean_bce(D_logit_fake, tf.ones_like)

    return gen_total, disc_total

def sample_Z(m, n):
    """Return an ``(m, n)`` array of latent noise drawn uniformly from [-1, 1).

    Draws from numpy's global RNG, so results repeat only if ``np.random`` is seeded.
    """
    noise_shape = (m, n)
    return np.random.uniform(low=-1.0, high=1.0, size=noise_shape)

def plot(samples, n_rows=4, n_cols=4):
    """Draw flattened square images as a borderless grayscale grid.

    Args:
        samples: sequence of flattened square images (e.g. 784 values are
            shown as 28x28). Filled into the grid row by row.
        n_rows, n_cols: grid shape. The defaults reproduce the original 4x4
            layout with a 4x4-inch figure.

    Returns:
        The matplotlib Figure. The caller is responsible for saving and closing it.

    Raises:
        ValueError: if there are more samples than grid cells, or a sample's
            size is not a perfect square.
    """
    n_cells = n_rows * n_cols
    # Validate before creating a figure, so a bad call does not leak one.
    if len(samples) > n_cells:
        raise ValueError('plot: got {} samples but the grid has only {} cells'
                         .format(len(samples), n_cells))

    fig = plt.figure(figsize=(n_cols, n_rows))
    gs = gridspec.GridSpec(n_rows, n_cols)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        img = np.asarray(sample)
        side = int(round(np.sqrt(img.size)))
        if side * side != img.size:
            plt.close(fig)
            raise ValueError('plot: sample {} has {} values, not a square image'
                             .format(i, img.size))

        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(img.reshape(side, side), cmap='Greys_r')
    return fig

In [None]:
# Build the conditional GAN graph. The discriminator is applied twice — to
# real images and to generator output — and the second call reuses the
# weights created by the first (reuse=True).
G_sample = generator(z=Z, y=Y)
D_real, D_logit_real = discriminator(x=X, y=Y, reuse=None)
D_fake, D_logit_fake = discriminator(x=G_sample, y=Y, reuse=True)

In [None]:
# Partition trainable variables by the variable scope that created them, so
# each optimizer below only updates its own network's weights.
t_vars = tf.trainable_variables()

def _vars_with_prefix(prefix, variables):
    """Return the variables whose name starts with ``prefix``, in original order."""
    return [v for v in variables if v.name.startswith(prefix)]

g_vars = _vars_with_prefix("generator", t_vars)
d_vars = _vars_with_prefix("discriminator", t_vars)

In [None]:
# Losses, then one Adam optimizer per network (default hyperparameters),
# each restricted to that network's variables via var_list.
G_loss, D_loss = loss(D_logit_real=D_logit_real, D_logit_fake=D_logit_fake)

d_optimizer = tf.train.AdamOptimizer()
D_solver = d_optimizer.minimize(D_loss, var_list=d_vars)

g_optimizer = tf.train.AdamOptimizer()
G_solver = g_optimizer.minimize(G_loss, var_list=g_vars)

In [None]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Sample grids are written here. Create the full path, not just 'out/';
# otherwise savefig below fails with FileNotFoundError at the first checkpoint.
OUT_DIR = os.path.join('out', 'ex05_CGAN')
if not os.path.exists(OUT_DIR):
    os.makedirs(OUT_DIR)

N_ITERS = 50000      # total training iterations
LOG_EVERY = 1000     # print losses / save a sample grid every this many iterations
N_SAMPLES = 16       # images per saved grid (plot() defaults to a 4x4 grid)

for it in range(N_ITERS):
    X_mb, Y_mb = mnist.train.next_batch(batch_size)
    # One discriminator step, then one generator step, on fresh noise each time.
    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb,
                                                             Z: sample_Z(batch_size, Z_dim),
                                                             Y: Y_mb})
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(batch_size, Z_dim),
                                                             Y: Y_mb})

    if it % LOG_EVERY == 0:
        print('Iter: {}, D loss: {:.4}, G_loss: {:.4}'.format(it,D_loss_curr,G_loss_curr))

        # Condition on labels from the current batch. Size the noise to the
        # labels actually taken, so the two feeds always have matching rows.
        sample_labels = Y_mb[:N_SAMPLES]
        samples = sess.run(G_sample, feed_dict={Z: sample_Z(len(sample_labels), Z_dim),
                                                Y: sample_labels})

        fig = plot(samples)
        fig.savefig(os.path.join(OUT_DIR, '{}.png'.format(str(it).zfill(5))),
                    bbox_inches='tight')
        plt.close(fig)  # avoid accumulating open figures over the run