# Simple GAN

* Ian Goodfellow et al.'s paper, "Generative Adversarial Networks" (2014): https://arxiv.org/pdf/1406.2661.pdf
* Source code: https://github.com/PacktPublishing/Learning-Generative-Adversarial-Networks/ 
* Simple GAN implementations in TensorFlow:
  1. book - https://github.com/PacktPublishing/Learning-Generative-Adversarial-Networks/blob/master/Chapter02/Code/simple-gan.ipynb 
  2. blog - https://wiseodd.github.io/techblog/2016/09/17/gan-tensorflow/ 
* GAN implementation in pure NumPy: https://towardsdatascience.com/only-numpy-implementing-gan-general-adversarial-networks-and-adam-optimizer-using-numpy-with-2a7e4e032021

In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.contrib.layers import xavier_initializer
from tensorflow.examples.tutorials.mnist import input_data

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
%matplotlib inline

# Generator

In [None]:
# Manual initializer from the original ebook (kept for reference; the cells
# below use tf.contrib's xavier_initializer instead).
def xavier_init_manual(size):
    """Return a normal random tensor whose stddev is scaled by the fan-in.

    Note: 1 / sqrt(fan_in / 2) == sqrt(2 / fan_in), i.e. He-style scaling
    rather than Glorot/Xavier's sqrt(2 / (fan_in + fan_out)).

    Arguments:
        size - shape of the weight matrix; size[0] is taken as the fan-in
    Returns:
        a tf.random_normal tensor of shape `size`
    """
    return tf.random_normal(shape=size, stddev=1. / tf.sqrt(size[0] / 2.))

def plot(samples):
    """Render flattened 28x28 images on a 4x4 grid and return the figure.

    Arguments:
        samples - iterable of at most 16 arrays, each reshapeable to (28, 28);
                  a 17th sample would index past the 4x4 GridSpec and fail
    Returns:
        the matplotlib Figure holding the grid (caller saves/closes it)
    """
    figure = plt.figure(figsize=(4, 4))
    cells = gridspec.GridSpec(4, 4)
    # Tight spacing so the digits read as one mosaic.
    cells.update(wspace=0.05, hspace=0.05)

    for idx, image in enumerate(samples):
        cell_ax = plt.subplot(cells[idx])
        cell_ax.axis('off')
        cell_ax.set_xticklabels([])
        cell_ax.set_yticklabels([])
        cell_ax.set_aspect('equal')
        plt.imshow(image.reshape(28, 28), cmap='Greys_r')

    return figure

In [None]:
# Placeholder for the generator's input noise: a batch (size left open via
# None) of 100-dim vectors. The training loop feeds it from sample_Z.
# NOTE(review): the width 100 must agree with Z_dim in the training cell.
Z = tf.placeholder(tf.float32, shape=[None, 100], name='Z')

# Generator parameters: a 100 -> 128 -> 784 two-layer MLP.
# Weights use tf.contrib's xavier_initializer; biases start at zero.
# `xavier_init` is also reused by the discriminator's weights.
xavier_init = xavier_initializer()
G_W1 = tf.Variable(xavier_init([100, 128]), name='G_W1')
G_b1 = tf.Variable(tf.zeros(shape=[128]), name='G_b1')
G_W2 = tf.Variable(xavier_init([128, 784]), name='G_W2')
G_b2 = tf.Variable(tf.zeros(shape=[784]), name='G_b2')
# Variables the generator's optimizer is restricted to (var_list of G_solver).
theta_G = [G_W1, G_W2, G_b1, G_b2]

In [None]:
# generator network
def generator(z):
    """Map a batch of noise vectors to flattened 28x28 images.

    Two-layer MLP (100 -> 128 -> 784) built from the module-level
    G_W1/G_b1/G_W2/G_b2 variables.

    Arguments:
        z - noise batch, presumably shape [batch, 100] (fed via the Z placeholder)
    Returns:
        G_prob - sigmoid outputs in (0, 1), one per pixel (784 per sample);
                 these are pixel intensities, not a probability of G(z)
    """
    hidden = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    pixel_logits = tf.matmul(hidden, G_W2) + G_b2
    # Sigmoid squashes each pixel into (0, 1), matching MNIST's intensity range.
    return tf.sigmoid(pixel_logits)

# Discriminator

In [None]:
# Placeholder for real MNIST images fed to the discriminator: a batch of
# flattened 28x28 = 784-pixel vectors (batch size left open via None).
X = tf.placeholder(tf.float32, shape=[None, 784], name='X')

# Discriminator parameters: a 784 -> 128 -> 1 MLP. Weights reuse the
# `xavier_init` initializer created with the generator parameters; biases
# start at zero.
D_W1 = tf.Variable(xavier_init([784, 128]), name='D_W1')
D_b1 = tf.Variable(tf.zeros([128]), name='D_b1')
D_W2 = tf.Variable(xavier_init([128, 1]), name="D_W2")
D_b2 = tf.Variable(tf.zeros([1]), name='D_b2')
# Variables the discriminator's optimizer is restricted to (var_list of D_solver).
theta_D = [D_W1, D_W2, D_b1, D_b2]

In [None]:
# discriminator network
def discriminator(x):
    """Score a batch of flattened images as real (pushed toward 1) or generated (toward 0).

    Two-layer MLP (784 -> 128 -> 1) built from the module-level
    D_W1/D_b1/D_W2/D_b2 variables; it is applied to both real data and
    generator output, so both calls share the same weights.

    Arguments:
        x - input batch, presumably shape [batch, 784] (X or G_sample)
    Returns:
        D_prob - sigmoid of the logit, in (0, 1)
        D_logit - raw pre-sigmoid score, shape [batch, 1]
    """
    hidden = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    score = tf.matmul(hidden, D_W2) + D_b2
    return tf.sigmoid(score), score

# Loss

In [None]:
D_real, D_logit_real = discriminator(X)
G_sample = generator(Z)
# This is the *discriminator's* logit on generated samples; the old name
# G_logit_fake is kept as an alias in case other cells refer to it.
D_fake, D_logit_fake = discriminator(G_sample)
G_logit_fake = D_logit_fake

# GAN's loss function (Goodfellow et al., 2014), written on logits.
# Mathematically identical to
#   D_loss = -mean(log D(x) + log(1 - D(G(z))))
#   G_loss = -mean(log D(G(z)))          (non-saturating generator loss)
# but tf.log(tf.sigmoid(...)) underflows to log(0) = -inf once the
# discriminator saturates, producing NaN losses/gradients. The fused
# sigmoid cross-entropy evaluates the same quantity stably.
D_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

# Only update D(X)'s parameters, var_list=theta_D
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
# Only update G(Z)'s parameters, var_list=theta_G
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)

In [None]:
def sample_Z(m, n):
    """Draw generator noise from the uniform prior U(-1, 1).

    Arguments:
        m - number of noise vectors (rows), e.g. the batch size
        n - dimensionality of each vector (columns), e.g. Z_dim
    Returns:
        float64 ndarray of shape (m, n) with values in [-1, 1)
    """
    low, high = -1., 1.
    return np.random.uniform(low, high, size=(m, n))

In [None]:
mnist = input_data.read_data_sets('MNIST/', one_hot=True)

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs('output/', exist_ok=True)

batch_size = 128
Z_dim = 100
n_iterations = 1000000
# Every `log_every` iterations: save a 4x4 sample grid and print both losses.
log_every = 1000

# Z's width is hard-coded in its placeholder; fail fast if the two disagree
# rather than erroring inside sess.run with a shape mismatch.
if Z.get_shape().as_list()[1] != Z_dim:
    raise ValueError('Z_dim ({}) does not match the Z placeholder width ({})'.format(
        Z_dim, Z.get_shape().as_list()[1]))

i = 0  # index of the next sample grid written to output/
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for itr in range(n_iterations):
        if itr % log_every == 0:
            # 16 samples fill plot()'s 4x4 grid.
            samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})
            fig = plot(samples)
            # Save the figure we just built explicitly instead of relying on
            # pyplot's implicit "current figure".
            fig.savefig('output/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
            i += 1
            plt.close(fig)

        # One discriminator step on a real batch plus fresh noise, then one
        # generator step on a separate fresh noise batch.
        X_mb, _ = mnist.train.next_batch(batch_size)
        _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(batch_size, Z_dim)})
        _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(batch_size, Z_dim)})

        if itr % log_every == 0:
            print('Iter: {}'.format(itr))
            print('D_loss: {:.4}'.format(D_loss_curr))
            print('G_loss: {:.4}'.format(G_loss_curr))
            print()