In [27]:
import tensorflow as tf
import os
import struct
import numpy as np


# Root directory for TensorBoard event files; the training cell writes the graph under it.
logdir = "../../logs"

In [13]:
def read_idx(filename):
    """Read an IDX-format file (the format MNIST is distributed in) into a numpy array.

    The header is two zero bytes, one data-type code byte, and one byte giving the
    number of dimensions. It is followed by one big-endian uint32 per dimension and
    then the raw payload.

    Parameters
    ----------
    filename : str
        Path to the IDX file.

    Returns
    -------
    numpy.ndarray
        Read-only array with the shape declared in the header. Its dtype is chosen
        by the header's type code (0x08 -> uint8, which is what MNIST uses).

    Raises
    ------
    ValueError
        If the two magic bytes are not zero, or the type code is unknown.
        numpy also raises ValueError from ``reshape`` if the payload size does not
        match the declared shape.
    """
    # IDX type codes -> numpy dtypes. Multi-byte types are stored big-endian.
    idx_dtypes = {
        0x08: np.dtype(np.uint8),
        0x09: np.dtype(np.int8),
        0x0B: np.dtype('>i2'),
        0x0C: np.dtype('>i4'),
        0x0D: np.dtype('>f4'),
        0x0E: np.dtype('>f8'),
    }
    with open(filename, 'rb') as f:
        zero, data_type, dims = struct.unpack('>HBB', f.read(4))
        if zero != 0:
            raise ValueError("{} is not an IDX file (bad magic bytes)".format(filename))
        try:
            dtype = idx_dtypes[data_type]
        except KeyError:
            raise ValueError(
                "unsupported IDX data type code 0x{:02X} in {}".format(data_type, filename)
            ) from None
        # One big-endian uint32 per dimension.
        shape = struct.unpack('>' + 'I' * dims, f.read(4 * dims))
        return np.frombuffer(f.read(), dtype=dtype).reshape(shape)


def load_mnist(path="../../data/MNIST/"):
    """Load the MNIST train/test splits from IDX files stored under ``path``.

    Reads the files ``X_train``, ``Y_train``, ``X_test`` and ``Y_test`` in that order.
    Image arrays are flattened to two dimensions (first axis kept, the rest merged).
    Label arrays are returned exactly as read.

    Returns
    -------
    tuple
        ``(X_train, Y_train, X_test, Y_test)``
    """
    def _flattened(name):
        # Keep the leading (sample) axis and collapse every other axis into one.
        images = read_idx(os.path.join(path, name))
        return images.reshape((images.shape[0], -1))

    def _as_read(name):
        return read_idx(os.path.join(path, name))

    # Tuple elements are evaluated left to right, preserving the read order.
    return _flattened("X_train"), _as_read("Y_train"), _flattened("X_test"), _as_read("Y_test")

In [18]:
# Load the four MNIST arrays once; load_mnist returns the image arrays flattened to 2-D.
(X_train, Y_train,
 X_test, Y_test) = load_mnist()

In [31]:
batch_size = 20  # NOTE(review): reassigned to 100 in the optimizer cell before training runs

# tf.data pipelines over the in-memory arrays.
# - Training repeats indefinitely, in mini-batches of `batch_size`.
# - Testing yields one (image, label) pair at a time.
# NOTE(review): the training loop feeds numpy batches through placeholders and never
# consumes these datasets.
training_dataset = tf.data.Dataset.from_tensor_slices((X_train, Y_train))
training_dataset = training_dataset.repeat()
training_dataset = training_dataset.batch(batch_size)

testing_dataset = tf.data.Dataset.from_tensor_slices((X_test, Y_test))

In [17]:
def generator(z, reuse=None, output_dim=784, hidden_units=128):
    """Map latent vectors to flattened images with a two-hidden-layer MLP.

    Parameters
    ----------
    z : tf.Tensor
        Latent input. Presumably (batch, Z_size) — TODO confirm against the caller.
    reuse : bool or None
        Passed to ``tf.variable_scope('generator')``. ``True`` shares the weights
        created by an earlier call.
    output_dim : int
        Width of the output layer. It must equal the flattened image size
        (784 = 28 * 28 for MNIST).
    hidden_units : int
        Width of each of the two hidden layers. When ``reuse=True`` it must match
        the value used when the weights were first created.

    Returns
    -------
    tf.Tensor
        One row of ``output_dim`` values in [-1, 1] (tanh) per input row.
    """
    with tf.variable_scope('generator', reuse=reuse):
        hidden1 = tf.layers.dense(inputs=z, units=hidden_units, activation=tf.nn.leaky_relu)
        hidden2 = tf.layers.dense(inputs=hidden1, units=hidden_units, activation=tf.nn.leaky_relu)
        # tanh bounds outputs to [-1, 1]; real images fed to the discriminator
        # should be scaled to the same range.
        output = tf.layers.dense(inputs=hidden2, units=output_dim, activation=tf.nn.tanh)

        return output
    
def discriminator(X, reuse=None):
    """Score inputs as real vs. generated with a two-hidden-layer MLP.

    Parameters
    ----------
    X : tf.Tensor
        Flattened images, real or generated — one row per example (assumed).
    reuse : bool or None
        Passed to ``tf.variable_scope('discriminator')``. ``True`` shares the weights
        with an earlier call, so real and fake inputs are scored by the same network.

    Returns
    -------
    (tf.Tensor, tf.Tensor)
        ``(sigmoid(logits), logits)``, where ``logits`` has a single unit per row.
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        # Two hidden layers, built in the same order as always, so layer names
        # (and therefore weight sharing under reuse=True) stay consistent.
        features = X
        for _ in range(2):
            features = tf.layers.dense(inputs=features, units=128, activation=tf.nn.leaky_relu)
        logits = tf.layers.dense(features, units=1)
        return tf.sigmoid(logits), logits

In [22]:
Z_size = 100  # dimensionality of the generator's latent input


# Begin from an empty default graph so re-running this cell rebuilds the model cleanly.
tf.reset_default_graph()

# Inputs: flattened real images (width taken from the loaded data) and latent vectors.
# NOTE(review): generator() defaults to 784 output units; this must equal X_train.shape[1].
real_images = tf.placeholder(dtype=tf.float32, shape=[None, X_train.shape[1]])
z = tf.placeholder(dtype=tf.float32, shape=[None, Z_size])

# A single generator. The discriminator is built twice with shared weights:
# first on real images, then (reuse=True) on the generator's output.
G = generator(z)
D_output_real, D_logits_real = discriminator(real_images)
D_output_fake, D_logits_fake = discriminator(G, reuse=True)

In [24]:
def loss_func(logits_in, labels_in):
    """Return the mean sigmoid cross-entropy between ``logits_in`` and ``labels_in``.

    ``labels_in`` may be soft (e.g. 0.9 instead of 1.0) and must have the same shape
    as ``logits_in``.
    """
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_in, logits=logits_in)
    return tf.reduce_mean(per_element)

# Discriminator loss:
# - Real images target 0.9 rather than 1.0 (one-sided label smoothing, intended to
#   aid generalization).
# - Generated images target 0.
D_real_loss = loss_func(D_logits_real, tf.ones_like(D_logits_real) * 0.9)
# Fake targets are built from the fake logits so their shape always matches them,
# even if the real and fake batches ever differ in size.
D_fake_loss = loss_func(D_logits_fake, tf.zeros_like(D_logits_fake))
D_loss = D_real_loss + D_fake_loss

# Generator loss: push the discriminator to label generated images as real
# (target 1, unsmoothed).
G_loss = loss_func(D_logits_fake, tf.ones_like(D_logits_fake))

In [25]:
lr = 0.001

# Each optimizer must update only its own network, so split the trainable
# variables by the variable scope they were created in.
tvars = tf.trainable_variables()  # all variables created with trainable=True (both scopes)
# Match the scope prefix rather than a substring, so a variable whose name merely
# contains the other network's word is not misassigned.
d_vars = [var for var in tvars if var.name.startswith('discriminator/')]
g_vars = [var for var in tvars if var.name.startswith('generator/')]

D_trainer = tf.train.AdamOptimizer(lr).minimize(D_loss, var_list=d_vars)
G_trainer = tf.train.AdamOptimizer(lr).minimize(G_loss, var_list=g_vars)

batch_size = 100
epochs = 100  # NOTE(review): the training cell reassigns epochs before its loop
# Created after the optimizers so the extra variables they add (Adam's slots)
# are covered by the initializer.
init = tf.global_variables_initializer()

In [None]:
epochs = 20

# One generator sample per epoch, kept for visual inspection afterwards.
samples = []

with tf.Session() as sess:

    # Create a filewriter to write the model's graph to TensorBoard.
    # NOTE(review): the subdirectory name looks copied from another notebook; consider renaming it.
    writer = tf.summary.FileWriter(logdir + '/linear_regression', sess.graph)
    try:
        sess.run(init)
        num_batches = X_train.shape[0] // batch_size  # any trailing partial batch is dropped
        for epoch in range(epochs):
            # Reshuffle every epoch so mini-batches differ between epochs.
            order = np.random.permutation(X_train.shape[0])
            for i in range(num_batches):
                batch_idx = order[i * batch_size:(i + 1) * batch_size]
                # Scale pixels into [-1, 1] to match the generator's tanh range.
                # This is done in float32: the uint8 arithmetic `x * 2 - 1` would wrap around.
                # Assumes raw pixel values lie in [0, 255] (uint8 MNIST) — confirm if data changes.
                batch_images = X_train[batch_idx].astype(np.float32) / 127.5 - 1.0
                batch_z = np.random.uniform(-1, 1, size=(batch_size, Z_size))
                sess.run(D_trainer, feed_dict={real_images: batch_images, z: batch_z})
                sess.run(G_trainer, feed_dict={z: batch_z})

            print("on epoch{}".format(epoch))

            sample_z = np.random.uniform(-1, 1, size=(1, Z_size))
            # Evaluate the existing generator tensor; calling generator() here would
            # add new ops to the graph on every epoch.
            gen_sample = sess.run(G, feed_dict={z: sample_z})

            samples.append(gen_sample)
    finally:
        # Flush pending events and release the event file even if training fails.
        writer.close()

# NOTE(review): `plt` (matplotlib.pyplot) is never imported in this notebook. Add
# `import matplotlib.pyplot as plt` to the imports cell before running this cell.
# Show the first and the most recent epoch's sample side by side. Indexing from the
# end works for any number of epochs.
fig, (ax_first, ax_last) = plt.subplots(1, 2)
ax_first.imshow(samples[0].reshape(28, 28))
ax_first.set_title("epoch 0")
ax_last.imshow(samples[-1].reshape(28, 28))
ax_last.set_title("epoch {}".format(len(samples) - 1))