### Necessary Imports

In [28]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Input, Flatten, Reshape
from tensorflow.keras.layers import Dense, Conv2DTranspose, BatchNormalization, Activation
from tensorflow.keras.models import Model
from tensorflow.keras.layers import concatenate
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.utils import plot_model
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import mnist
import os
import math


### Loading Training Data

In [31]:
# Fetch MNIST and keep only the training images; the labels and the
# test split returned by load_data() are discarded.
(x_train, _), (_, _) = mnist.load_data()
# The side length is read from axis 1 and reused for both height and width.
image_size = x_train.shape[1]
# Add a trailing channel axis -> (N, image_size, image_size, 1), cast to
# float32 and divide by 255 so 8-bit pixel values land in [0, 1].
x_train = x_train.reshape(-1, image_size, image_size, 1).astype('float32') / 255
x_train.shape

(60000, 28, 28, 1)

### Model Parameters

In [32]:
# Length of the noise vector z fed to the generator (shape of 'z_input').
latent_size = 100
# Number of real images sampled per training step.
batch_size = 64
# Number of iterations of the training loop.
train_steps = 40000
# Learning rate and decay passed to the discriminator's RMSprop optimizer;
# the adversarial model uses half of each.
lr = 2e-4
decay = 6e-8
# Shape of one image as seen by the discriminator: (height, width, channels).
input_shape = (image_size, image_size, 1)


### Generator Function 

In [6]:
def build_generator(inputs, image_size):
    """Build the generator model.

    The latent vector is projected by a Dense layer and reshaped to a
    (image_size // 4, image_size // 4, 128) feature map, then passed through
    four BatchNormalization-ReLU-Conv2DTranspose blocks with 128, 64, 32 and
    1 filters. The first two transposed convolutions use stride 2 (each
    doubles the spatial size); the last two use stride 1.

    The output activation is sigmoid, so generated pixels lie in [0, 1],
    the same range as the training images (this notebook scales them with
    ``x_train / 255``). A tanh output would span [-1, 1] and would not match
    the real samples shown to the discriminator.

    # Arguments
    inputs (Layer): Input layer of the generator (the latent vector z)
    image_size (int): Side length of the square output image; must be a
        positive multiple of 4 because the network upsamples by 4 overall
    # Returns
    Model: Generator Model producing (image_size, image_size, 1) images
    # Raises
    ValueError: If image_size is not a positive multiple of 4
    """
    if image_size <= 0 or image_size % 4 != 0:
        # Two stride-2 layers upsample image_size // 4 by exactly 4, so any
        # other size would silently yield an output smaller than requested.
        raise ValueError(
            "image_size must be a positive multiple of 4, got %r" % (image_size,))

    image_resize = image_size // 4
    kernel_size = 5
    layer_filters = [128, 64, 32, 1]

    # Project z and reshape it into the initial low-resolution feature map.
    x = Dense(image_resize * image_resize * layer_filters[0])(inputs)
    x = Reshape((image_resize, image_resize, layer_filters[0]))(x)

    for filters in layer_filters:
        # Only the layers with more filters than the second-to-last entry
        # (128 and 64) upsample; the remaining two keep the resolution.
        if filters > layer_filters[-2]:
            strides = 2
        else:
            strides = 1
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Conv2DTranspose(filters=filters,
                            kernel_size=kernel_size,
                            strides=strides,
                            padding='same')(x)

    # Sigmoid keeps generated pixels in [0, 1], matching the real images.
    x = Activation('sigmoid')(x)
    generator = Model(inputs, x, name='generator')
    return generator
    

### Discriminator Function 

In [7]:
def build_discriminator(inputs):
    """Build the discriminator model.

    Four LeakyReLU-Conv2D blocks (32, 64, 128 and 256 filters, 5x5 kernels,
    'same' padding) followed by Flatten, a single-unit Dense layer and a
    sigmoid, giving one real-vs-fake score per image. Every convolution
    uses stride 2 except the last one, which uses stride 1.

    Author's note: BatchNormalization is left out because the network was
    reported not to converge with it, unlike in [1] or the original paper.

    # Arguments
    inputs (Layer): Input layer of the discriminator (the image)
    # Returns
    Model: Discriminator Model
    """
    kernel_size = 5
    layer_filters = [32, 64, 128, 256]

    features = inputs
    for n_filters in layer_filters:
        # Downsample in every block except the one with the final filter count.
        stride = 1 if n_filters == layer_filters[-1] else 2
        features = LeakyReLU(alpha=0.2)(features)
        features = Conv2D(filters=n_filters,
                          kernel_size=kernel_size,
                          strides=stride,
                          padding='same')(features)

    flattened = Flatten()(features)
    score = Dense(1)(flattened)
    probability = Activation('sigmoid')(score)
    return Model(inputs, probability, name='discriminator')

###  Building Discriminator

In [33]:
# Discriminator input: a single image of shape `input_shape`.
inputs = Input(shape=input_shape, name='discriminator_input')
discriminator = build_discriminator(inputs)
# `learning_rate` is the supported keyword; `lr` is only a deprecated alias.
# NOTE(review): `decay` is a legacy-optimizer argument; newer Keras releases
# expect a learning-rate schedule instead. Confirm against the installed TF.
optimizer = RMSprop(learning_rate=lr, decay=decay)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])
discriminator.summary()


Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
discriminator_input (InputLa [(None, 28, 28, 1)]       0         
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 14, 14, 32)        832       
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 7, 7, 64)          51264     
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 4, 4, 128)       

### Building Generator

In [34]:
# A one-element tuple needs the trailing comma: `(latent_size )` is just the
# int 100, which not every Keras version accepts as an Input shape.
# NOTE: this cell rebinds `input_shape` and `inputs`, which previously held
# the image shape and the discriminator input. Re-run the parameter cell
# before re-running the discriminator cell.
input_shape = (latent_size,)
inputs = Input(shape=input_shape, name='z_input')
generator = build_generator(inputs, image_size)
generator.summary()

Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
z_input (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
dense_4 (Dense)              (None, 6272)              633472    
_________________________________________________________________
reshape_2 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
batch_normalization_8 (Batch (None, 7, 7, 128)         512       
_________________________________________________________________
activation_11 (Activation)   (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_transpose_8 (Conv2DTr (None, 14, 14, 128)       409728    
_________________________________________________________________
batch_normalization_9 (Batch (None, 14, 14, 128)       51

### Building Adversarial Network

In [37]:
# The stacked model is optimized with half the learning rate and half the
# decay used for the discriminator. `learning_rate` replaces the deprecated
# `lr` keyword.
# NOTE(review): `decay` is a legacy-optimizer argument; newer Keras releases
# expect a learning-rate schedule instead. Confirm against the installed TF.
optimizer = RMSprop(learning_rate=lr * 0.5, decay=decay * 0.5)
# Freeze the discriminator so that, inside the stacked model, only the
# generator's weights are listed as trainable (see the summary output).
discriminator.trainable = False
# adversarial = z -> generator -> discriminator
adversarial = Model(inputs,
                    discriminator(generator(inputs)),
                    name='model')
adversarial.compile(loss='binary_crossentropy',
                    optimizer=optimizer,
                    metrics=['accuracy'])
adversarial.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
z_input (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
generator (Model)            (None, 28, 28, 1)         1301505   
_________________________________________________________________
discriminator (Model)        (None, 1)                 1080577   
Total params: 2,382,082
Trainable params: 1,300,801
Non-trainable params: 1,081,281
_________________________________________________________________


### Training Function 

In [None]:
def train(models, x_train, paarams):
    generator, discriminator, adversarial = models
    batch_size, latent_size, train_steps, model_name = params
    save_interval = 500
    noise_input = np.random.uniform(-1.0, 1.0, size=[16, latent_size])
    train_size = x_train.shape[0]
    for i in range(train_steps):
        rand_indexes = np.random.randint(0, train_size, size=batch_size)
        