## Library imports 

In [None]:
# The basic code is taken from https://github.com/ckmarkoh/GAN-tensorflow

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
from skimage.io import imsave
import os
import shutil

## Constants and flags 

In [None]:
# Image geometry: height, width and the length of the flattened image vector.
img_height, img_width = 28, 28
img_size = img_width * img_height

# Run-mode switches: the driver reads to_train, train() reads to_restore.
to_train, to_restore = True, False
# Directory used for checkpoints (and, via hard-coded paths, sample images).
output_path = "output"

# Upper bound of the epoch loop in train().
max_epoch = 500

# NOTE(review): h1_size / h2_size are not referenced by the code visible in
# this file section -- presumably hidden-layer widths used elsewhere.
h1_size, h2_size = 150, 300
# Length of the noise vector fed to the generator, and the batch size shared
# by all placeholders.
z_size = 100
batch_size = 256
# Filter-count base passed to the generator's convolution layers.
ngf = 128

## Function definitions 

### Convolution layer 
Defines a general 2D convolution layer with batch normalization

In [None]:
def general_conv2d(inputconv, name="conv2d",
                   o_d=64, f_h=7, f_w=7, s_h=1, s_w=1,
                   stddev=0.02, padding=None,
                   do_norm=True, do_relu=True):
    '''Defines a general 2D convolution layer with batch normalization.

    Args:
        inputconv: 4-D input tensor; its last dimension is used as the number
            of input channels of the filter.
        name: variable scope under which the layer's variables are created.
        o_d: number of output channels (filters).
        f_h, f_w: filter height and width.
        s_h, s_w: stride along the height and width axes.
        stddev: standard deviation of the truncated-normal weight initializer.
        padding: "SAME" or "VALID"; ``None`` (the default) is treated as
            "VALID" because ``tf.nn.conv2d`` does not accept ``None``.
        do_norm: if True, normalize the output with moments taken over axis 0
            and apply a learned ``scale`` / ``beta``.
        do_relu: if True, apply a ReLU to the result.

    Returns:
        The output tensor of the layer.
    '''
    if padding is None:
        padding = "VALID"

    with tf.variable_scope(name):
        # get_shape() returns the static shape; it has to be indexed, not
        # called with the axis as an argument.
        in_channels = int(inputconv.get_shape()[-1])
        w = tf.get_variable(
            'w',
            [f_h, f_w, in_channels, o_d],
            initializer=tf.truncated_normal_initializer(stddev=stddev))
        # Strides follow the input layout [batch, height, width, channels],
        # so the height stride comes before the width stride.
        conv = tf.nn.conv2d(inputconv,
                            filter=w, strides=[1, s_h, s_w, 1],
                            padding=padding)
        biases = tf.get_variable('b',
                                 [o_d],
                                 initializer=tf.constant_initializer(0.0))
        conv = tf.nn.bias_add(conv, biases)

        # Add batch_norm layer
        if do_norm:
            # Moments are taken over axis 0 only, so mean/variance (and hence
            # scale/beta) have one entry per remaining position.
            norm_shape = conv.get_shape().as_list()[1:]
            # The initializer must be passed by keyword: the third positional
            # parameter of tf.get_variable is dtype.
            scale = tf.get_variable('scale',
                                    norm_shape,
                                    initializer=tf.constant_initializer(1.0))
            beta = tf.get_variable('beta',
                                   norm_shape,
                                   initializer=tf.constant_initializer(0.0))
            conv_mean, conv_var = tf.nn.moments(conv, [0])
            conv = tf.nn.batch_normalization(conv, conv_mean, conv_var,
                                             beta, scale, 0.001)
        # Add ReLU activation
        if do_relu:
            # tf.nn.relu's second parameter is the op name, not a slope.
            conv = tf.nn.relu(conv)
    return conv

### Resnet block 

In [None]:
def build_resnet_block(inputres, dim, name="resnet"):
    '''Builds a residual block: two 3x3 convolutions plus a skip connection.

    Args:
        inputres: input tensor of the block.
        dim: number of output channels of both convolutions. The result is
            added to ``inputres``, so this is expected to match the channel
            count of ``inputres``.
        name: variable scope wrapping the two convolution layers.

    Returns:
        ``relu(conv2(conv1(inputres)) + inputres)``.
    '''
    with tf.variable_scope(name):
        # Keyword arguments are required here: general_conv2d takes ``name``
        # as its second positional parameter, so passing ``dim`` positionally
        # would bind it to the scope name and shift every later argument.
        out_res = general_conv2d(inputres, name="c1", o_d=dim,
                                 f_h=3, f_w=3, s_h=1, s_w=1,
                                 stddev=0.02, padding="SAME")
        # No ReLU on the second layer; the activation is applied after the
        # skip connection is added.
        out_res = general_conv2d(out_res, name="c2", o_d=dim,
                                 f_h=3, f_w=3, s_h=1, s_w=1,
                                 stddev=0.02, padding="SAME",
                                 do_relu=False)
    return tf.nn.relu(out_res + inputres)

def build_generator_resnet_6blocks(inputgen, name="generator"):
    '''Builds the initial convolution stack of the generator.

    Three convolution layers are stacked: a 7x7 stride-1 layer with ``ngf``
    filters, followed by two 3x3 stride-2 layers with ``2 * ngf`` and
    ``4 * ngf`` filters.

    NOTE(review): despite the function name, no residual blocks are built
    here yet -- confirm whether the remainder of the generator is still to
    be added.

    Args:
        inputgen: input tensor of the generator.
        name: variable scope wrapping all layers.

    Returns:
        The output tensor of the third convolution layer.
    '''
    with tf.variable_scope(name):
        f = 7
        ks = 3
        # Keyword arguments are required: general_conv2d's second positional
        # parameter is ``name``, so positional calls would shift every value
        # into the wrong parameter.
        o_c1 = general_conv2d(inputgen, name="c1", o_d=ngf,
                              f_h=f, f_w=f, s_h=1, s_w=1,
                              stddev=0.02, padding="SAME")
        # padding=None defers to general_conv2d's default padding.
        o_c2 = general_conv2d(o_c1, name="c2", o_d=ngf * 2,
                              f_h=ks, f_w=ks, s_h=2, s_w=2,
                              stddev=0.02, padding=None)
        o_c3 = general_conv2d(o_c2, name="c3", o_d=ngf * 4,
                              f_h=ks, f_w=ks, s_h=2, s_w=2,
                              stddev=0.02, padding=None)
    # Return the last feature map so the layers built above are reachable.
    return o_c3

### Show results 

In [None]:
def show_result(batch_res, fname, grid_size=(8, 8), grid_pad=5):
    '''Tiles a batch of generated images into one grid image and saves it.

    Args:
        batch_res: numpy array holding one image per entry along axis 0; it
            is reshaped to ``(n, img_height, img_width)`` and mapped with
            ``0.5 * x + 0.5`` before being scaled to 0..255.
        fname: path of the image file to write.
        grid_size: ``(rows, columns)`` of the grid; images beyond
            ``rows * columns`` are ignored.
        grid_pad: number of blank pixels between neighbouring tiles.
    '''
    n_rows, n_cols = grid_size
    batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5
    img_h, img_w = batch_res.shape[1], batch_res.shape[2]
    grid_h = img_h * n_rows + grid_pad * (n_rows - 1)
    grid_w = img_w * n_cols + grid_pad * (n_cols - 1)
    img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)
    for i, res in enumerate(batch_res[:n_rows * n_cols]):
        # Clip before the uint8 cast so out-of-range values saturate instead
        # of wrapping around.
        img = (np.clip(res, 0.0, 1.0) * 255).astype(np.uint8)
        # Tiles are laid out row-major: the row index is the tile index
        # divided by the number of columns (not the number of rows).
        row = (i // n_cols) * (img_h + grid_pad)
        col = (i % n_cols) * (img_w + grid_pad)
        img_grid[row:row + img_h, col:col + img_w] = img
    imsave(fname, img_grid)

### Training and testing routines 

In [None]:
def train():
    '''Trains the generator/discriminator pair on MNIST.

    Builds the graph, then either restores the latest checkpoint from
    ``output_path`` (when ``to_restore`` is set) or recreates that directory
    from scratch. After every epoch two sample grids are written (one from a
    fixed noise batch, one from a fresh noise batch) and a checkpoint is
    saved. The epoch counter is stored in ``global_step`` so a restored run
    resumes at the epoch where it stopped.
    '''
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

    x_data = tf.placeholder(tf.float32, [batch_size, img_size], name="x_data")
    z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
    keep_prob = tf.placeholder(tf.float32, name="keep_prob")
    global_step = tf.Variable(0, name="global_step", trainable=False)

    x_generated, g_params = build_generator(z_prior)
    y_data, y_generated, d_params = build_discriminator(x_data, x_generated, keep_prob)

    # NOTE(review): presumably y_data / y_generated are the discriminator's
    # outputs for real and generated images -- confirm against
    # build_discriminator.
    d_loss = - (tf.log(y_data) + tf.log(1 - y_generated))
    g_loss = - tf.log(y_generated)

    optimizer = tf.train.AdamOptimizer(0.0001)

    d_trainer = optimizer.minimize(d_loss, var_list=d_params)
    g_trainer = optimizer.minimize(g_loss, var_list=g_params)

    # Build the epoch-counter update once; creating tf.assign inside the
    # epoch loop would add a new op to the graph on every epoch.
    next_step = tf.placeholder(tf.int32, [], name="next_step")
    update_global_step = tf.assign(global_step, next_step)

    init = tf.initialize_all_variables()

    saver = tf.train.Saver()

    # Integer division: range() rejects the float that "/" yields on Python 3.
    batches_per_epoch = 60000 // batch_size
    keep_prob_value = np.float32(0.7)

    # The context manager releases the session's resources even if training
    # is interrupted by an exception.
    with tf.Session() as sess:
        sess.run(init)

        if to_restore:
            chkpt_fname = tf.train.latest_checkpoint(output_path)
            saver.restore(sess, chkpt_fname)
        else:
            if os.path.exists(output_path):
                shutil.rmtree(output_path)
            os.mkdir(output_path)

        # Fixed noise batch, reused every epoch for the "sample" grids.
        z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)

        for i in range(sess.run(global_step), max_epoch):
            for j in range(batches_per_epoch):
                print("epoch:%s, iter:%s" % (i, j))
                x_value, _ = mnist.train.next_batch(batch_size)
                x_value = 2 * x_value.astype(np.float32) - 1
                z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
                feed = {x_data: x_value, z_prior: z_value, keep_prob: keep_prob_value}
                # One discriminator step followed by one generator step on
                # the same batch.
                sess.run(d_trainer, feed_dict=feed)
                sess.run(g_trainer, feed_dict=feed)
            x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})
            show_result(x_gen_val, os.path.join(output_path, "sample{0}.jpg".format(i)))
            z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
            x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})
            show_result(x_gen_val, os.path.join(output_path, "random_sample{0}.jpg".format(i)))
            sess.run(update_global_step, feed_dict={next_step: i + 1})
            saver.save(sess, os.path.join(output_path, "model"), global_step=global_step)


def test():
    '''Generates one batch of images from the latest checkpoint.

    Rebuilds the generator, restores the most recent checkpoint found in
    ``output_path``, feeds one batch of standard-normal noise through the
    generator and saves the resulting grid as ``test_result.jpg`` inside
    ``output_path``.
    '''
    z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
    x_generated, _ = build_generator(z_prior)
    chkpt_fname = tf.train.latest_checkpoint(output_path)

    init = tf.initialize_all_variables()
    saver = tf.train.Saver()
    # The context manager closes the session once the samples are computed.
    with tf.Session() as sess:
        sess.run(init)
        saver.restore(sess, chkpt_fname)
        z_test_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
        x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})
    # Write next to the checkpoints instead of a hard-coded "output/" path.
    show_result(x_gen_val, os.path.join(output_path, "test_result.jpg"))

## Main/driver

In [None]:
if __name__ == '__main__':
    # Script entry point: the to_train flag selects training or generation
    # from the latest checkpoint.
    entry_point = train if to_train else test
    entry_point()