In [1]:
import tensorflow as tf
import numpy as np
import os
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tqdm import tqdm

In [2]:
# Load the MNIST dataset through the TF tutorial helper. Files are read from
# the 'MNIST_data' directory and labels are one-hot encoded (one_hot=True).
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [5]:
#Generator Net
def genrator(z):
    """Build the generator: latent noise ``z`` -> 784-dim sample in (0, 1).

    The network is a 128-unit ReLU hidden layer followed by a 784-unit linear
    layer whose output is squashed by a sigmoid. Both layers use Xavier weight
    initialisation and live under the 'GEN' variable scope, which is opened
    with ``tf.AUTO_REUSE`` so repeated calls share the same weights.

    Args:
        z: tensor of latent noise fed to the first fully connected layer.

    Returns:
        The sigmoid output tensor (op name 'G_prob' inside the 'GEN' scope).
    """
    dense = tf.contrib.layers.fully_connected
    xavier = tf.contrib.layers.xavier_initializer
    with tf.variable_scope('GEN', reuse=tf.AUTO_REUSE):
        hidden = dense(
            inputs=z,
            num_outputs=128,
            activation_fn=tf.nn.relu,
            weights_initializer=xavier(),
        )
        # Linear output layer; the sigmoid is applied explicitly below.
        pre_activation = dense(
            inputs=hidden,
            num_outputs=784,
            activation_fn=None,
            weights_initializer=xavier(),
        )
        generated = tf.nn.sigmoid(pre_activation, name='G_prob')
    return generated

#Discriminator Net
def discriminator(x):
    """Build the discriminator: input ``x`` -> (probability, logit).

    The network is a 128-unit ReLU hidden layer followed by a single linear
    output unit. Both layers use Xavier weight initialisation and live under
    the 'DES' variable scope, which is opened with ``tf.AUTO_REUSE`` so the
    same weights are shared across calls.

    Args:
        x: tensor fed to the first fully connected layer.

    Returns:
        A tuple ``(prob, logits)`` where ``logits`` is the raw output of the
        last layer and ``prob`` is its sigmoid.
    """
    dense = tf.contrib.layers.fully_connected
    xavier = tf.contrib.layers.xavier_initializer
    with tf.variable_scope('DES', reuse=tf.AUTO_REUSE):
        hidden = dense(
            inputs=x,
            num_outputs=128,
            activation_fn=tf.nn.relu,
            weights_initializer=xavier(),
        )
        logits = dense(
            inputs=hidden,
            num_outputs=1,
            activation_fn=None,
            weights_initializer=xavier(),
        )
        prob = tf.nn.sigmoid(logits)
    return prob, logits

#uniform prior used to sample the generator's latent input
def sample_Z(m, n):
    """Return an ``(m, n)`` NumPy array drawn uniformly from [-1, 1)."""
    noise_shape = (m, n)
    return np.random.uniform(low=-1.0, high=1.0, size=noise_shape)
#render generated samples on a 5x5 image grid
def plot(samples):
    """Draw each sample as a 28x28 greyscale image on a 5x5 grid.

    Args:
        samples: iterable of arrays, each reshapeable to (28, 28). Sample
            ``i`` is placed in grid cell ``i``, so at most 25 fit the grid.

    Returns:
        The matplotlib figure holding the grid. The caller is responsible
        for saving and closing it.
    """
    fig = plt.figure(figsize=(5, 5))
    grid = gridspec.GridSpec(5, 5)
    grid.update(wspace=0.05, hspace=0.05)

    for cell_index, image in enumerate(samples):
        ax = fig.add_subplot(grid[cell_index])
        # Hide axis decorations so only the digit image is visible.
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(image.reshape(28, 28), cmap='Greys_r')
    return fig

In [13]:
# ---- Configuration -------------------------------------------------------
chkpt_dir='chkpt000'          # checkpoints and TensorBoard summaries go here
out_dir='out'                 # sample image grids are written here
batch_size=128
z_dim=100                     # dimensionality of the generator's latent input
num_steps=1000000             # total number of training iterations
sample_every=1000             # save a sample grid + checkpoint every N steps
experiment_id=0
chkpt_path=os.path.join(chkpt_dir,'model.chkpt')

tf.reset_default_graph()
T_graph=tf.Graph()
batch_count = int(mnist.train.num_examples/batch_size)

# ---- Graph definition ----------------------------------------------------
with T_graph.as_default():
    X=tf.placeholder(tf.float32,shape=[None,784],name='X')
    # The latent placeholder gets its own name and follows z_dim rather than
    # a second hard-coded 100.
    Z=tf.placeholder(tf.float32,shape=[None,z_dim],name='Z')
    G_sample=genrator(Z)
    D_real,D_logit_real=discriminator(X)
    D_fake,D_logit_fake=discriminator(G_sample)

    # Same objective as -mean(log D(x) + log(1 - D(G(z)))) and -mean(log D(G(z))),
    # but computed from the logits: log(sigmoid(.)) turns into -inf/NaN as soon
    # as the discriminator saturates, the cross-entropy op does not.
    D_loss_real=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(D_logit_real),logits=D_logit_real))
    D_loss_fake=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(D_logit_fake),logits=D_logit_fake))
    D_loss=D_loss_real+D_loss_fake
    G_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(D_logit_fake),logits=D_logit_fake))

    # Select each network's variables by variable scope instead of slicing
    # the first three characters of the variable name.
    gen_vars=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope='GEN/')
    des_vars=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope='DES/')
    D_optimize=tf.train.AdamOptimizer(learning_rate=0.0001,name='D_Adam').minimize(D_loss,var_list=des_vars)
    G_optimize=tf.train.AdamOptimizer(learning_rate=0.0001,name='G_Adam').minimize(G_loss,var_list=gen_vars)

    D_summary_op=tf.summary.merge([tf.summary.scalar('D_loss',D_loss)])
    G_summary_op=tf.summary.merge([tf.summary.scalar('G_loss',G_loss)])

# ---- Training ------------------------------------------------------------
with tf.Session(graph=T_graph) as sess:
    saver=tf.train.Saver()
    summary_writer=tf.summary.FileWriter(chkpt_dir,graph=T_graph)
    init=[tf.global_variables_initializer(),tf.local_variables_initializer()]
    sess.run(init)
    tf_vars=tf.trainable_variables()
    if os.path.exists(os.path.join(chkpt_dir,'checkpoint')):
        print('Found Chcekpoint')
        saver.restore(sess,chkpt_path)
        print('checkpoint_restored !')
    elif not os.path.isdir(chkpt_dir):
        os.makedirs(chkpt_dir)
    # savefig does not create missing directories, so make sure it exists.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    print('Trainng Started...')
    step_D_loss,step_G_loss=[],[]
    try:
        for i in tqdm(range(num_steps)):
            X_real, _ = mnist.train.next_batch(batch_size)
            # One discriminator update, then one generator update, each on
            # freshly sampled noise.
            _, D_loss_curr,D_summary = sess.run([D_optimize, D_loss,D_summary_op], feed_dict={X: X_real, Z: sample_Z(batch_size, z_dim)})
            _, G_loss_curr,G_summary = sess.run([G_optimize, G_loss,G_summary_op], feed_dict={Z: sample_Z(batch_size, z_dim)})
            step_D_loss.append(D_loss_curr)
            step_G_loss.append(G_loss_curr)
            summary_writer.add_summary(D_summary, i)
            summary_writer.add_summary(G_summary, i)
            if i % sample_every == 0:
                # Start a new averaging window for the running losses.
                step_D_loss,step_G_loss=[],[]
                samples=sess.run(G_sample,feed_dict={Z:sample_Z(25,z_dim)})
                fig=plot(samples)
                # Integer division keeps the index an int ('00001', not '001.0').
                image_name='{}.png'.format(str(experiment_id).zfill(3)+'_'+str(i//sample_every).zfill(5))
                fig.savefig(os.path.join(out_dir,image_name))
                plt.close(fig)
                saver.save(sess,chkpt_path)
    finally:
        # Flush pending summaries even if training is interrupted.
        summary_writer.close()
    if step_G_loss and step_D_loss:
        # Arguments are ordered to match the labels in the message.
        print('G_loss : {} ; D_loss : {}'.format((sum(step_G_loss)/len(step_G_loss)),(sum(step_D_loss)/len(step_D_loss))))
    print('training Complete.')

In [11]:
# #Alternative losses (TF 1.x requires `logits=` / `labels=` to be passed by keyword):
# D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
# D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
# D_loss = D_loss_real + D_loss_fake
# G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

In [14]:
# Display the trainable variables captured as `tf_vars` in the training cell.
# NOTE(review): the recorded output lists 'wrap/DES/...' variables, but no
# visible cell opens a 'wrap' scope — the output looks stale; re-run to confirm.
tf_vars

[<tf.Variable 'GEN/fully_connected/weights:0' shape=(100, 128) dtype=float32_ref>,
 <tf.Variable 'GEN/fully_connected/biases:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'GEN/fully_connected_1/weights:0' shape=(128, 784) dtype=float32_ref>,
 <tf.Variable 'GEN/fully_connected_1/biases:0' shape=(784,) dtype=float32_ref>,
 <tf.Variable 'DES/fully_connected/weights:0' shape=(784, 128) dtype=float32_ref>,
 <tf.Variable 'DES/fully_connected/biases:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'DES/fully_connected_1/weights:0' shape=(128, 1) dtype=float32_ref>,
 <tf.Variable 'DES/fully_connected_1/biases:0' shape=(1,) dtype=float32_ref>,
 <tf.Variable 'wrap/DES/fully_connected/weights:0' shape=(784, 128) dtype=float32_ref>,
 <tf.Variable 'wrap/DES/fully_connected/biases:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'wrap/DES/fully_connected_1/weights:0' shape=(128, 1) dtype=float32_ref>,
 <tf.Variable 'wrap/DES/fully_connected_1/biases:0' shape=(1,) dtype=float32_ref>]