#### Import Modules

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.layers.core import Flatten
from keras.optimizers import SGD
from keras.datasets import mnist
import numpy as np
from PIL import Image
import argparse
import math

import tensorflow as tf

from keras import backend as K
# Use Theano-style ('th') channels-first image ordering: image tensors are laid
# out as (channels, rows, cols). The models and data in this notebook depend on
# it, e.g. the (1, 28, 28) discriminator input shape and the (128, 7, 7)
# reshape in the generator.
K.set_image_dim_ordering('th')

#### Define Generator and Discriminator Models

In [None]:
def generator_model():
    """Build the DCGAN generator network.

    The network maps a 100-dimensional input vector to a single-channel
    28x28 image in channels-first layout. The final tanh keeps pixel values
    in (-1, 1).

    Pipeline: Dense(1024) -> Dense(128*7*7) + BatchNorm -> reshape to a
    (128, 7, 7) feature map -> two rounds of 2x upsampling followed by a
    5x5 'same' convolution (64 filters, then 1 filter).
    """
    stack = [
        # Project the input vector through two fully connected layers.
        Dense(input_dim=100, output_dim=1024),
        Activation('tanh'),
        Dense(128*7*7),
        BatchNormalization(),
        Activation('tanh'),
        # Reinterpret the flat vector as a 128-channel 7x7 feature map.
        Reshape((128, 7, 7), input_shape=(128*7*7,)),
        # 7x7 -> 14x14
        UpSampling2D(size=(2, 2)),
        Convolution2D(64, 5, 5, border_mode='same'),
        Activation('tanh'),
        # 14x14 -> 28x28, collapsed to a single output channel.
        UpSampling2D(size=(2, 2)),
        Convolution2D(1, 5, 5, border_mode='same'),
        Activation('tanh'),
    ]
    # Sequential adds the layers in list order, exactly like repeated .add().
    return Sequential(stack)

def discriminator_model():
    """Build the DCGAN discriminator network.

    Input: a single-channel 28x28 image in channels-first layout, with shape
    (1, 28, 28).
    Output: one sigmoid unit. The training loop labels real images 1 and
    generated images 0, so this output is the score for "real".
    """
    stack = [
        # First feature extractor, 'same' padding keeps the 28x28 extent.
        Convolution2D(64, 5, 5, border_mode='same', input_shape=(1, 28, 28)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        # Second conv has no border_mode, so the Keras default padding applies.
        Convolution2D(128, 5, 5),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        # Classifier head.
        Flatten(),
        Dense(1024),
        Activation('tanh'),
        Dense(1),
        Activation('sigmoid'),
    ]
    return Sequential(stack)

def generator_containing_discriminator(generator, discriminator):
    """Chain `generator` into `discriminator` as one model (noise -> score).

    Side effect: sets ``discriminator.trainable = False`` on the passed-in
    discriminator. The intent is that compiling the returned model trains
    only the generator's weights. The flag stays False on the discriminator
    object afterwards, so the caller must set it back to True before
    compiling or training the discriminator on its own.
    """
    # Keras reads `trainable` when the model is compiled, not when layers are
    # added, so freezing before building the stack is equivalent.
    discriminator.trainable = False
    return Sequential([generator, discriminator])

#### Helper Function for Visualization

In [None]:
def combine_images(generated_images):
    """Tile a batch of images into one 2-D grid image for visualization.

    Args:
        generated_images: array-like of shape (N, C, H, W), channels-first.
            Only channel 0 of each image is drawn.

    Returns:
        A 2-D array of shape (rows*H, cols*W) with the input's dtype, where
        cols = floor(sqrt(N)) and rows = ceil(N / cols). Images are placed in
        row-major order. Unused cells in the last row are left as zeros.

    Raises:
        ValueError: if the input is not 4-dimensional or the batch is empty.
    """
    generated_images = np.asarray(generated_images)
    if generated_images.ndim != 4:
        raise ValueError(
            "expected a 4-D (N, C, H, W) batch, got shape %r"
            % (generated_images.shape,))
    num = generated_images.shape[0]
    if num == 0:
        # Without this guard, width would be 0 and the ceil() below would
        # divide by zero.
        raise ValueError("cannot combine an empty batch of images")

    width = int(math.sqrt(num))
    height = int(math.ceil(float(num) / width))
    img_h, img_w = generated_images.shape[2:]
    image = np.zeros((height * img_h, width * img_w),
                     dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        i, j = divmod(index, width)  # grid row, grid column
        image[i * img_h:(i + 1) * img_h, j * img_w:(j + 1) * img_w] = img[0]
    return image

#### Train the GAN

In [None]:
import os

mini_batch = 100
max_epoch = 100
noise_dim = 100  # must match the generator's Dense(input_dim=100)

# Image.save() and save_weights() do not create missing directories. Create
# both up front instead of crashing at the first snapshot or checkpoint.
os.makedirs('./output', exist_ok=True)
os.makedirs('./models', exist_ok=True)

# Rescale uint8 pixels [0, 255] -> [-1, 1] to match the generator's final tanh,
# then insert a channel axis for the channels-first (N, 1, H, W) layout.
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = X_train.reshape((X_train.shape[0], 1) + X_train.shape[1:])

discriminator = discriminator_model()
generator = generator_model()
# generator_containing_discriminator() sets discriminator.trainable = False,
# so the stacked model is compiled with the discriminator frozen.
discriminator_on_generator = generator_containing_discriminator(generator, discriminator)
d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
generator.compile(loss='binary_crossentropy', optimizer="SGD")
discriminator_on_generator.compile(loss='binary_crossentropy', optimizer=g_optim)
# Unfreeze before compiling the standalone discriminator so it can learn.
discriminator.trainable = True
discriminator.compile(loss='binary_crossentropy', optimizer=d_optim)

K.get_session().run(tf.global_variables_initializer())

num_batches = X_train.shape[0] // mini_batch
# Discriminator targets: real images -> 1, generated images -> 0.
d_labels = np.array([1] * mini_batch + [0] * mini_batch)
# Generator targets: the stacked model is trained to make fakes score as real.
g_labels = np.array([1] * mini_batch)

for epoch in range(max_epoch):
    print("Epoch is", epoch)
    print("Number of batches", num_batches)
    for index in range(num_batches):
        # One vectorized draw, instead of one draw per row.
        noise = np.random.uniform(-1, 1, (mini_batch, noise_dim))
        image_batch = X_train[index * mini_batch:(index + 1) * mini_batch]
        generated_images = generator.predict(noise, verbose=0)
        if index % 20 == 0:
            # Save a snapshot grid, mapped back from [-1, 1] to [0, 255].
            image = combine_images(generated_images)
            image = image * 127.5 + 127.5
            Image.fromarray(image.astype(np.uint8)).save(
                './output/' + str(epoch) + "_" + str(index) + ".png")

        # Discriminator step on half real, half generated images.
        X = np.concatenate((image_batch, generated_images))
        d_loss = discriminator.train_on_batch(X, d_labels)

        # Generator step with fresh noise through the frozen discriminator.
        # NOTE(review): in Keras 2, toggling `trainable` after compile() has no
        # effect until the model is recompiled. The freeze that matters was
        # applied when discriminator_on_generator was compiled above. The
        # toggle is kept for older Keras versions; confirm for the installed one.
        noise = np.random.uniform(-1, 1, (mini_batch, noise_dim))
        discriminator.trainable = False
        g_loss = discriminator_on_generator.train_on_batch(noise, g_labels)
        discriminator.trainable = True
        print("Epoch : %d, Iter %d, D_loss : %f, G_loss : %f" % (epoch, index, d_loss, g_loss))

        # Checkpoint both networks every 10 batches (second arg = overwrite).
        if index % 10 == 9:
            generator.save_weights('./models/generator', True)
            discriminator.save_weights('./models/discriminator', True)

#### Generate Images Using the Trained GAN

In [None]:
nice = True
num_generate = 100
# In "nice" mode, generate this many times more candidates than needed and
# keep the num_generate images the discriminator scores highest.
candidate_factor = 20
noise_dim = 100  # must match the generator's Dense(input_dim=100)

generator = generator_model()
generator.compile(loss='binary_crossentropy', optimizer="SGD")
generator.load_weights('./models/generator')
if nice:
    discriminator = discriminator_model()
    discriminator.compile(loss='binary_crossentropy', optimizer="SGD")
    discriminator.load_weights('./models/discriminator')
    num_candidates = num_generate * candidate_factor
    noise = np.random.uniform(-1, 1, (num_candidates, noise_dim))
    generated_images = generator.predict(noise, verbose=1)
    d_pret = discriminator.predict(generated_images, verbose=1)
    # Candidate indices ordered by descending discriminator score. A stable
    # sort (mergesort) keeps ties in generation order.
    best = np.argsort(-d_pret[:, 0], kind='mergesort')[:num_generate]
    # Keep only channel 0 of each selected image, as a float32 batch.
    nice_images = generated_images[best, 0:1].astype(np.float32)
    image = combine_images(nice_images)
else:
    noise = np.random.uniform(-1, 1, (num_generate, noise_dim))
    generated_images = generator.predict(noise, verbose=1)
    image = combine_images(generated_images)
# Map the tanh output range [-1, 1] back to displayable 8-bit grayscale.
image = image*127.5+127.5
image = image.astype(np.uint8)
plt.imshow(image, cmap='gray')