<b>Author:</b> Jhosimar George Arias Figueroa

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

### Load MNIST dataset

In [None]:
# Read the MNIST dataset, using 'MNIST_data' as the data directory.
# one_hot=True requests one-hot encoded labels; the rest of this notebook
# only uses the training images.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

In [None]:
# Size of the training split and length of one flattened image vector
num_features = mnist.train.images.shape[1]
num_samples = mnist.train.num_examples
print("Number of Samples: {}. Feature Dimension: {}".format(num_samples, num_features))

In [None]:
def weight_variable(shape):
    """Create a trainable weight tensor of the given shape.

    Values are drawn from a truncated normal distribution with a standard
    deviation of 0.1.
    """
    initial_value = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial_value)

def bias_variable(shape):
    """Create a trainable bias tensor of the given shape, filled with 0.1."""
    initial_value = tf.constant(0.1, shape=shape)
    return tf.Variable(initial_value)

In [None]:
class GAN:
    """Fully connected generative adversarial network (TensorFlow 1.x graph API).

    The generator maps a noise vector of length ``sample_size`` to a vector of
    ``num_features`` values in (0, 1); the discriminator maps such a vector to
    a single real/fake logit. Both are three-layer perceptrons with tanh
    hidden activations.

    Options (all keyword arguments, all optional):
        sample_size:   length of the generator's noise vector (default 20)
        batch_size:    training batch size (default 200)
        epochs:        number of passes over the training data (default 100)
        learning_rate: Adam learning rate for both networks (default 0.01)
        display_step:  print the averaged losses every this many epochs (default 1)
        k:             discriminator updates per generator update (default 1)
    """

    def __init__(self, **options):
        self.sample_size = options.get("sample_size", 20)        # noise sample for generator
        self.batch_size = options.get("batch_size", 200)         # training batch size
        self.epochs = options.get("epochs", 100)                 # number of training epochs
        self.learning_rate = options.get("learning_rate", 0.01)  # learning rate
        self.sess = tf.Session()                                 # tensorflow session
        self.display_step = options.get("display_step", 1)       # display loss
        self.discriminator_steps = options.get("k", 1)           # k iterations for discriminator

        # Trainable variables of each network; filled in when the graph is built.
        self._generator_params = None
        self._discriminator_params = None

    # Generator
    # input: sample noise (num_samples x sample_size)
    # output: generated image
    def generator(self, z):
        """Build the generator network on top of the noise tensor ``z``.

        Creates a fresh set of generator variables, remembers them so that the
        generator optimizer can update exactly these, and returns the sigmoid
        output tensor (one row of ``num_features`` values per noise row).
        """
        with tf.variable_scope("generator"):
            # first layer
            G_W1 = weight_variable([self.sample_size, 250])
            G_b1 = bias_variable([250])
            G_h1 = tf.nn.tanh(tf.matmul(z, G_W1) + G_b1)

            # second layer
            G_W2 = weight_variable([250, 500])
            G_b2 = bias_variable([500])
            G_h2 = tf.nn.tanh(tf.matmul(G_h1, G_W2) + G_b2)

            # output layer
            G_W3 = weight_variable([500, self.num_features])
            G_b3 = bias_variable([self.num_features])
            G_output = tf.nn.sigmoid(tf.matmul(G_h2, G_W3) + G_b3)

            self._generator_params = [G_W1, G_b1, G_W2, G_b2, G_W3, G_b3]
            return G_output

    # Discriminator
    # input: data (num_samples x num_features)
    # output: probabilities, logits
    def discriminator(self, X, reuse=False):
        """Build the discriminator network on top of ``X``.

        Args:
            X: input tensor with ``num_features`` columns.
            reuse: when True, apply the weights created by the previous call
                instead of creating new ones, so that real and generated data
                are scored by the *same* network.

        Returns:
            ``(probabilities, logits)`` tensors, one row per input row.
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            # tf.Variable ignores the scope's ``reuse`` flag (only
            # tf.get_variable honours it), so the weights are shared
            # explicitly by caching them on the instance.
            params = getattr(self, "_discriminator_params", None)
            if not reuse or params is None:
                params = [
                    weight_variable([self.num_features, 500]), bias_variable([500]),
                    weight_variable([500, 250]), bias_variable([250]),
                    weight_variable([250, 1]), bias_variable([1]),
                ]
                self._discriminator_params = params
            D_W1, D_b1, D_W2, D_b2, D_W3, D_b3 = params

            # first layer
            D_h1 = tf.nn.tanh(tf.matmul(X, D_W1) + D_b1)

            # second layer
            D_h2 = tf.nn.tanh(tf.matmul(D_h1, D_W2) + D_b2)

            # third layer
            D_logits = tf.matmul(D_h2, D_W3) + D_b3
            D_output = tf.nn.sigmoid(D_logits)

            return D_output, D_logits

    @staticmethod
    def sample_noise(shape):
        '''Uniform prior for G(Z)'''
        return np.random.uniform(-1., 1., size=shape)

    # Training discriminator k times and generator once
    def alternating_optimization(self, _X, _z):
        """Run k discriminator updates followed by one generator update.

        Args:
            _X: batch of real data fed to the ``X`` placeholder.
            _z: batch of noise fed to the ``z`` placeholder (the same noise is
                used for every discriminator step and for the generator step).

        Returns:
            ``(mean discriminator loss over the k steps, generator loss)``.
        """
        # discriminator steps
        D_loss_avg = 0
        for _ in range(self.discriminator_steps):
            _, D_loss = self.sess.run([self.discriminator_optimizer, self.discriminator_loss],
                                      feed_dict={self.X: _X, self.z: _z})
            D_loss_avg += D_loss
        D_loss_avg /= self.discriminator_steps

        # generator step
        _, G_loss = self.sess.run([self.generator_optimizer, self.generator_loss],
                                  feed_dict={self.z: _z})

        return D_loss_avg, G_loss

    def train(self, data):
        """Build the GAN graph for ``data`` and train it.

        Args:
            data: 2-D array, one sample per row. Batches are taken in order;
                a trailing partial batch is dropped.

        Raises:
            ValueError: if ``data`` has fewer rows than ``batch_size``.
        """
        self.num_features = data.shape[1]
        self.num_samples = data.shape[0]

        # Validate before touching the graph so a bad call leaves no stray ops.
        num_batches = self.num_samples // self.batch_size
        if num_batches == 0:
            raise ValueError(
                "data has {} samples, fewer than batch_size={}".format(
                    self.num_samples, self.batch_size))

        # Input variables
        self.X = tf.placeholder(tf.float32, [None, self.num_features])
        self.z = tf.placeholder(tf.float32, [None, self.sample_size])

        # Generated samples
        self.G_sample = self.generator(self.z)

        # The same discriminator scores real and generated data.
        _, D_real_logits = self.discriminator(self.X)
        _, D_fake_logits = self.discriminator(self.G_sample, reuse=True)

        # Discriminator loss: -mean(log D(x) + log(1 - D(G(z)))), computed
        # from the logits so a saturated sigmoid cannot produce log(0).
        D_loss_real = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_real_logits, labels=tf.ones_like(D_real_logits))
        D_loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=D_fake_logits, labels=tf.zeros_like(D_fake_logits))
        self.discriminator_loss = tf.reduce_mean(D_loss_real + D_loss_fake)

        # Generator loss (non-saturating form): -mean(log D(G(z)))
        self.generator_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=D_fake_logits, labels=tf.ones_like(D_fake_logits)))

        # Each optimizer only updates its own network's variables.
        self.discriminator_optimizer = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate
        ).minimize(self.discriminator_loss, var_list=self._discriminator_params)
        self.generator_optimizer = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate
        ).minimize(self.generator_loss, var_list=self._generator_params)

        # Variable initialization (network weights and Adam slots)
        self.sess.run(tf.global_variables_initializer())

        # Number of iterations
        for epoch in range(self.epochs):
            avg_loss = 0
            avg_generator = 0
            avg_discriminator = 0

            # Iterate on each batch
            for i in range(num_batches):
                start = i * self.batch_size
                end = start + self.batch_size
                # Get current batch
                batch_X = data[start:end]
                z = self.sample_noise([self.batch_size, self.sample_size])

                # Train GAN alternating discriminator and generator
                D_loss, G_loss = self.alternating_optimization(batch_X, z)

                avg_generator += G_loss
                avg_discriminator += D_loss
                avg_loss += D_loss + G_loss

            avg_discriminator /= num_batches
            avg_generator /= num_batches
            avg_loss /= num_batches

            if epoch % self.display_step == 0:
                print("Epoch {}: -- Discriminator={}, Generator={}, Loss={}".format(
                    epoch + 1, avg_discriminator, avg_generator, avg_loss))

    # Generate batch of samples
    # input: noise (n x sample_size)
    # output: generated images (n x num_features)
    def generate_data(self, num_samples, noise=None):
        """Generate samples from the trained generator.

        Args:
            num_samples: number of samples to draw when ``noise`` is not given.
            noise: optional array with ``sample_size`` columns; when supplied
                it is fed to the generator as-is and ``num_samples`` is ignored.

        Returns:
            The generator output evaluated in the model's session.
        """
        # Identity check: ``noise == None`` on an array is elementwise and
        # cannot be used as a condition.
        if noise is None:
            noise = self.sample_noise([num_samples, self.sample_size])
        generated = self.sess.run(self.G_sample, feed_dict={self.z: noise})
        return generated
        

In [None]:
# Training images, one flattened image per row
train_data = mnist.train.images

# Build the GAN. With display_step=5000 the averaged losses are printed
# only for epochs 1 and 5001.
GAN_model = GAN(
    sample_size=100,
    batch_size=100,
    epochs=10000,
    learning_rate=0.001,
    display_step=5000,
)

# Train the GAN on the MNIST training images
GAN_model.train(train_data)

### Data Generation

In [None]:
# Draw 100 samples from the generator and show them on a 10 x 10 grid
generated = GAN_model.generate_data(100)
plt.figure(figsize=[10, 10])
for cell in range(100):
    image = np.reshape(generated[cell], (28, 28))  # flat vector -> 28 x 28 image
    plt.subplot(10, 10, cell + 1)
    plt.imshow(image, interpolation='none', cmap=plt.get_cmap('gray'))
    plt.axis('off')