In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST with one-hot labels.
# NOTE(review): absolute local Windows path; consider making it configurable.
mnist = input_data.read_data_sets('D:/MNIST/', one_hot=True)

# First 100 training images, mapped with x -> 2x - 1 so that pixels presumably in
# [0, 1] land in [-1, 1], the range of the generator's tanh output.
train_img = 2 * mnist.train.images[:100] - 1

Extracting D:/MNIST/train-images-idx3-ubyte.gz
Extracting D:/MNIST/train-labels-idx1-ubyte.gz
Extracting D:/MNIST/t10k-images-idx3-ubyte.gz
Extracting D:/MNIST/t10k-labels-idx1-ubyte.gz


In [2]:
def model_inputs(real_dim, z_dim):
    """Create the two input tensors of the GAN graph.

    Args:
        real_dim: length of one flattened real image. The reshape below treats
            every 28*28*1 = 784 values as one image, so 784 gives one image per row.
        z_dim: length of the generator's noise vector.

    Returns:
        (inputs_real, inputs_z):
            inputs_real -- the ``'inputs_real'`` placeholder of shape
                (None, real_dim), passed through ``tf.reshape`` to (-1, 28, 28, 1).
                Note that the reshape result, not the placeholder itself, is
                what gets returned.
            inputs_z -- the ``'inputs_z'`` placeholder of shape (None, z_dim).
    """
    flat_real = tf.placeholder(tf.float32, [None, real_dim], name='inputs_real')
    image_real = tf.reshape(flat_real, [-1, 28, 28, 1])

    noise = tf.placeholder(tf.float32, [None, z_dim], name='inputs_z')

    return image_real, noise

In [3]:
def discriminator(x,reuse=False):
    """Score a batch of images as real vs. generated with a small LeNet-style CNN.

    Args:
        x: image batch. In this notebook it is either the (?, 28, 28, 1) real-image
            tensor or the generator output (assumed same shape -- confirm if reused).
        reuse: True to share variables with an earlier call made under the
            'discriminator' variable scope.

    Returns:
        (out, logit): sigmoid probability and the raw logit from a 1-unit dense layer.
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        h = x
        # Two (conv + ReLU) -> 2x2/stride-2 max-pool stages, using tf.layers' default padding.
        # The layer call order is fixed so auto-generated variable names match under reuse.
        for n_filters, ksize in ((6, (1, 1)), (16, (5, 5))):
            h = tf.layers.conv2d(h, n_filters, kernel_size=ksize, strides=(1, 1),
                                 activation=tf.nn.relu)
            h = tf.layers.max_pooling2d(h, pool_size=(2, 2), strides=(2, 2))

        h = tf.layers.flatten(h)
        for n_units in (120, 84):
            h = tf.layers.dense(h, n_units, activation=tf.nn.relu)
        logit = tf.layers.dense(h, 1)

        return tf.nn.sigmoid(logit), logit

In [4]:
def generator(z,z_dim,reuse=False):
    """Map a batch of noise vectors to single-channel images in [-1, 1].

    Spatial pipeline (batch dimension omitted), following the layers below:
        z -> dense(3136, ReLU) -> reshape (56, 56, 1) -> batch_norm
          -> conv 3x3 s2 'same' (ReLU) -> (28, 28, z_dim//2) -> batch_norm
          -> conv 3x3 s2 'same' (ReLU) -> (14, 14, z_dim//4) -> batch_norm
          -> resize to (56, 56) -> conv 3x3 s2 'same' -> (28, 28, 1) -> tanh

    Args:
        z: noise tensor, presumably of shape (batch, z_dim) -- confirm at call site.
        z_dim: noise length; also sets the filter counts of the two hidden convs.
        reuse: True to share variables with an earlier call under 'generator'.

    Returns:
        (out, logit): tanh-squashed image batch and the pre-tanh conv output.
    """
    # Filter counts must be whole numbers. ``z_dim/2`` is true division under
    # Python 3, giving a float that TF rejects or truncates whenever z_dim is not
    # divisible by 4. Use floor division and never drop below one filter.
    # The result is identical to before for z_dim = 100 (50 and 25 filters).
    hidden_filters_1 = max(1, z_dim // 2)
    hidden_filters_2 = max(1, z_dim // 4)

    with tf.variable_scope('generator',reuse=reuse):
        # Project the noise to 3136 = 56*56 units and view it as a 56x56x1 map.
        g1 = tf.layers.dense(z,3136,activation=tf.nn.relu)
        g1 = tf.reshape(g1,[-1,56,56,1])
        g1 = tf.contrib.layers.batch_norm(g1,epsilon=1e-5)
        
        g2 = tf.layers.conv2d(g1,hidden_filters_1,kernel_size=(3,3),strides=(2,2),padding='same',activation=tf.nn.relu)
        g2 = tf.contrib.layers.batch_norm(g2,epsilon=1e-5)
    
        g3 = tf.layers.conv2d(g2,hidden_filters_2,kernel_size=(3,3),strides=(2,2),padding='same',activation=tf.nn.relu)
        g3 = tf.contrib.layers.batch_norm(g3,epsilon=1e-5)
    
        # Upsample back to 56x56 so the final stride-2 conv yields 28x28.
        g3 = tf.image.resize_images(g3,[56,56])
    
        # final output with one channel
        logit = tf.layers.conv2d(g3,1,kernel_size=(3,3),strides=(2,2),padding='same')
        out = tf.nn.tanh(logit)
        
        return out,logit

In [5]:
# z_size: length of the generator's noise vector; lr: Adam learning rate for both nets.
z_size, lr = 100, 0.002

In [6]:
# Flattened 784-value real images (reshaped to 28x28x1 inside model_inputs) and z_size-long noise.
input_real,input_z = model_inputs(784,z_size)
# Generator graph; its variables live under the 'generator' scope.
g_model,g_logit = generator(input_z,z_size)

# The discriminator is built twice: first on real images (creating its variables),
# then on generated images with reuse=True so both passes share the same weights.
d_model_real,d_logit_real = discriminator(input_real)
d_model_fake,d_logit_fake = discriminator(g_model,reuse=True)

In [7]:
# Sanity check: the real-image input returned by model_inputs is 4-D with an unknown batch size.
print(input_real.shape)

(?, 28, 28, 1)


In [8]:
# Discriminator: sigmoid cross-entropy pushing real logits toward 1 and fake logits toward 0.
d_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(d_logit_real), logits=d_logit_real))
d_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(d_logit_fake), logits=d_logit_fake))

d_loss = d_loss_real + d_loss_fake

# Generator (non-saturating form): reward fakes the discriminator scores as real (target 1).
g_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(d_logit_fake), logits=d_logit_fake))

In [9]:
t_vars = tf.trainable_variables()

# Split trainable variables by the variable_scope prefix each network was built
# under, so each optimizer only updates its own network's weights.
g_var = list(filter(lambda v: v.name.startswith('generator'), t_vars))
d_var = list(filter(lambda v: v.name.startswith('discriminator'), t_vars))

In [10]:
# One Adam optimizer per network, each restricted to that network's variables.
d_train = tf.train.AdamOptimizer(lr).minimize(d_loss,var_list=d_var)

# tf.contrib.layers.batch_norm does not run its moving-mean/variance updates
# automatically; it puts them in tf.GraphKeys.UPDATE_OPS. Make the generator step
# depend on the generator's update ops so those statistics are actually maintained.
g_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator')
with tf.control_dependencies(g_update_ops):
    g_train = tf.train.AdamOptimizer(lr).minimize(g_loss,var_list=g_var)

In [11]:
epochs = 500
batch_size = 50
steps_per_epoch = 10  # mini-batches per "epoch" (not a full pass over MNIST)
n_samples = 3         # generated images shown after training

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Loss history: one entry per step, recorded only during every 10th epoch.
    g_cost,d_cost = [],[]

    for e in range(epochs):
        for i in range(steps_per_epoch):
            # next_batch returns (images, labels); only the flattened images are used.
            images = mnist.train.next_batch(batch_size)[0]
            # Map pixels to [-1, 1] (the generator's tanh range) and reshape to the
            # (batch, 28, 28, 1) layout of the tensor returned by model_inputs.
            batch_img = (images * 2 - 1).reshape(batch_size, 28, 28, 1)

            batch_z = np.random.uniform(-1, 1, (batch_size, z_size))

            sess.run(d_train, feed_dict={input_real: batch_img, input_z: batch_z})
            sess.run(g_train, feed_dict={input_z: batch_z})

            # Evaluate both post-update losses in a single run rather than two.
            d_, g_ = sess.run([d_loss, g_loss],
                              feed_dict={input_real: batch_img, input_z: batch_z})

            if e % 10 == 0:
                d_cost.append(d_)
                g_cost.append(g_)

        if e % 25 == 0:
            print('epoch : ',str(e),' ,d_loss : ',str(d_),' ,g_loss : ',str(g_))

    # Reuse the existing generator output tensor. Calling generator(..., reuse=True)
    # here would add a new copy of the generator subgraph on every iteration.
    for j in range(n_samples):
        batch_z = np.random.uniform(-1, 1, (batch_size, z_size))
        sample = sess.run(g_model, feed_dict={input_z: batch_z})
        plt.imshow(sample[1].reshape(28, 28), cmap='Greys_r')
        plt.show()
    # The `with` block closes the session; no explicit sess.close() is needed.

epoch :  0  ,d_loss :  0.9313  ,g_loss :  1.07941
epoch :  25  ,d_loss :  1.01288  ,g_loss :  1.77895
epoch :  50  ,d_loss :  0.996225  ,g_loss :  1.03075
epoch :  75  ,d_loss :  0.97052  ,g_loss :  1.93999
epoch :  100  ,d_loss :  0.537881  ,g_loss :  2.07897
epoch :  125  ,d_loss :  0.725354  ,g_loss :  1.91842
epoch :  150  ,d_loss :  0.63064  ,g_loss :  1.78402
epoch :  175  ,d_loss :  0.674166  ,g_loss :  2.1539
epoch :  200  ,d_loss :  0.804975  ,g_loss :  1.92757
epoch :  225  ,d_loss :  0.594217  ,g_loss :  1.74163
epoch :  250  ,d_loss :  0.531988  ,g_loss :  2.06198
epoch :  275  ,d_loss :  0.497663  ,g_loss :  2.04506
epoch :  300  ,d_loss :  0.625113  ,g_loss :  2.27976
epoch :  325  ,d_loss :  0.599595  ,g_loss :  1.87968
epoch :  350  ,d_loss :  0.530174  ,g_loss :  1.97243
epoch :  375  ,d_loss :  0.641245  ,g_loss :  2.14184
epoch :  400  ,d_loss :  0.744241  ,g_loss :  1.77239
epoch :  425  ,d_loss :  0.64474  ,g_loss :  2.25099
epoch :  450  ,d_loss :  0.563244  ,g_lo

In [12]:
# Loss curves from training; each point is one step of an epoch that was a multiple of 10.
fig, ax = plt.subplots()
ax.plot(d_cost, 'r-', label='discriminator loss')
ax.plot(g_cost, 'g-', label='generator loss')
ax.set(title='GAN training losses', xlabel='recorded step', ylabel='loss')
ax.legend()
plt.show()