# Training and Testing a Generative Adversarial Network

## Imports and loading the dataset

In [None]:
from numpy import zeros, ones, expand_dims, asarray
from numpy.random import randn, randint
import tensorflow as tf
from keras.datasets import mnist
from keras.optimizers import Adam
from keras.models import Model, load_model
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import Conv2D, Conv2DTranspose, Concatenate
from keras.layers import LeakyReLU, Dropout, Embedding
from keras.layers import BatchNormalization, Activation
from keras import initializers, Sequential
from keras.initializers import RandomNormal
from keras.optimizers import Adam, RMSprop, SGD
from matplotlib import pyplot
import numpy as np
from math import sqrt
import os
from keras.callbacks import Callback
from keras.losses import BinaryCrossentropy


# Keep only the training images; the labels and the test split are discarded.
(X_train, _), (_, _) = mnist.load_data()
# Rescale to float32 via x / 127.5 - 1 (maps [0, 255] onto [-1, 1], assuming
# 8-bit pixel data), which matches the generator's tanh output range.
X_train = X_train.astype(np.float32) / 127.5 - 1
# Append a trailing channel axis so each image is (height, width, 1).
X_train = X_train[..., np.newaxis]


## Defining Discriminator and Generator

In [None]:
def define_discriminator(input_dim=(28,28,1)):
    """Build and compile the fully connected discriminator.

    Args:
        input_dim: Shape of a single input image, given to the ``Input`` layer.

    Returns:
        A compiled ``Sequential`` model that flattens the image and passes it
        through three ReLU ``Dense`` layers (784, 128, 64 units) to a single
        sigmoid unit.
    """
    model = Sequential([
        Input(shape=input_dim),
        Flatten(),
        Dense(units=784, activation='relu'),
        Dense(units=128, activation='relu'),
        Dense(units=64, activation='relu'),
        Dense(units=1, activation='sigmoid')
    ])
    # NOTE(review): GANModel.train_step optimises a Wasserstein loss with a
    # gradient penalty; that setup usually expects an unbounded (linear) critic
    # output rather than a sigmoid -- confirm whether this is intentional.
    #
    # Use ``learning_rate``: the legacy ``lr`` alias is deprecated and is
    # rejected by newer Keras optimizers. This also matches how the optimizers
    # are built in get_GAN_training_network.
    model.compile(
        loss='binary_crossentropy',
        optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
        metrics=['accuracy'],
    )
    return model

def define_generator(latent_dim = 64, input_dim=(28,28,1)):
    """Build the (uncompiled) fully connected generator.

    Args:
        latent_dim: Length of the latent noise vector fed to the network.
        input_dim: Three-element shape of the images to produce; its first
            three entries are multiplied to size the final ``Dense`` layer and
            the whole tuple is used as the ``Reshape`` target.

    Returns:
        A ``Sequential`` model mapping a latent vector through two ReLU
        ``Dense`` layers (784, 256 units) and a tanh ``Dense`` layer, reshaped
        to ``input_dim``.
    """
    n_pixels = input_dim[0] * input_dim[1] * input_dim[2]
    layer_stack = [
        Input(shape=(latent_dim,)),
        Dense(units=784, activation='relu'),
        Dense(units=256, activation='relu'),
        # tanh keeps the outputs in [-1, 1].
        Dense(units=n_pixels, activation='tanh'),
        Reshape(target_shape=input_dim),
    ]
    return Sequential(layer_stack)

## Defining GAN Training Architecture

In [None]:
class GANModel(Model):
    """Keras model that trains a generator against a discriminator (critic).

    ``train_step`` optimises a Wasserstein objective with a gradient penalty:
    for every batch the discriminator is updated ``n_critic`` times and the
    generator once.
    """

    def __init__(self, generator, discriminator, generator_latent_dim, *args, **kwargs):
        """Store the two sub-models and the generator's latent size.

        Args:
            generator: Model mapping latent vectors to images.
            discriminator: Model scoring image batches.
            generator_latent_dim: Length of one latent vector.
            *args, **kwargs: Forwarded unchanged to the base ``Model``.
        """
        super().__init__(*args, **kwargs)

        self.generator = generator
        self.discriminator = discriminator
        self.generator_latent_dim = generator_latent_dim

    def compile(self, g_opt, d_opt, g_loss, d_loss, n_critic, LAMBDA, *args, **kwargs):
        """Compile the base model and record the training configuration.

        Args:
            g_opt: Optimizer applied to the generator's variables.
            d_opt: Optimizer applied to the discriminator's variables.
            g_loss: Stored on the instance; not used by ``train_step``.
            d_loss: Stored on the instance; not used by ``train_step``.
            n_critic: Number of discriminator updates per generator update.
            LAMBDA: Weight of the gradient penalty in the discriminator loss.
            *args, **kwargs: Forwarded unchanged to the base ``compile``.
        """
        super().compile(*args, **kwargs)

        self.g_opt = g_opt
        self.d_opt = d_opt
        self.g_loss = g_loss
        self.d_loss = d_loss
        self.n_critic = n_critic
        self.LAMBDA = LAMBDA

    def get_generator(self):
        """Return the generator model held by this instance."""
        return self.generator

    def generate_latent_points(self, n_samples):
        """Return a NumPy array of standard-normal latent vectors.

        The result has shape ``(n_samples, generator_latent_dim)``.

        This helper draws from NumPy's RNG, so inside a traced ``tf.function``
        it would be evaluated only once; ``train_step`` therefore samples with
        TensorFlow ops instead.
        """
        x_input = randn(self.generator_latent_dim * n_samples)
        return x_input.reshape(n_samples, self.generator_latent_dim)

    def _sample_latent(self, batch_size):
        """Draw ``batch_size`` latent vectors with a TensorFlow op.

        Unlike NumPy sampling, this yields fresh noise on every execution of a
        traced training step and accepts a dynamic (tensor) batch size.
        """
        return tf.random.normal(shape=(batch_size, self.generator_latent_dim))

    def gradient_penalty(self, real_images, fake_images):
        """Compute the gradient penalty on real/fake interpolations.

        Each real image is blended with its fake counterpart using a
        per-sample uniform weight; the penalty is the mean squared deviation
        from 1 of the L2 norm of the discriminator's gradient with respect to
        those blended images. Assumes rank-4 image batches (the norm is
        reduced over axes 1-3).

        Args:
            real_images: Batch of real images.
            fake_images: Batch of generated images with the same shape.

        Returns:
            Scalar tensor holding the penalty.
        """
        # Dynamic shape: the static batch dimension can be None when traced.
        batch_size = tf.shape(real_images)[0]
        epsilon = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0.0, maxval=1.0)
        interpolated_images = epsilon * tf.cast(real_images, tf.float32) + (1 - epsilon) * fake_images

        with tf.GradientTape() as penalty_tape:
            penalty_tape.watch(interpolated_images)
            yhat_interpolated = self.discriminator(interpolated_images, training=True)

        p_grad = penalty_tape.gradient(yhat_interpolated, interpolated_images)
        # The small constant keeps the square root differentiable when a
        # sample's gradient is exactly zero (otherwise the update becomes NaN).
        grad_norms = tf.sqrt(tf.reduce_sum(tf.square(p_grad), axis=[1, 2, 3]) + 1e-12)
        return tf.reduce_mean(tf.square(grad_norms - 1.0))

    def train_step(self, batch):
        """Run ``n_critic`` discriminator updates followed by one generator update.

        Args:
            batch: Batch of real images as delivered by ``fit``.

        Returns:
            Dict with ``"d_loss"`` (from the final discriminator update) and
            ``"g_loss"``.
        """
        # Depending on the Keras version / data source the batch may arrive
        # wrapped in a tuple even when no labels were supplied.
        if isinstance(batch, (tuple, list)):
            batch = batch[0]
        real_images = batch
        batch_size = tf.shape(real_images)[0]

        for _ in range(self.n_critic):
            # Fresh fakes for every discriminator update; the generator is not
            # being trained here, so it runs outside the tape.
            fake_images = self.generator(self._sample_latent(batch_size), training=False)

            with tf.GradientTape() as d_tape:
                yhat_real = self.discriminator(real_images, training=True)
                yhat_fake = self.discriminator(fake_images, training=True)

                # Must be computed under d_tape so that the penalty contributes
                # to the discriminator's gradients.
                penalty = self.gradient_penalty(real_images, fake_images)
                total_d_loss = (tf.reduce_mean(yhat_fake) - tf.reduce_mean(yhat_real)) + (self.LAMBDA * penalty)

            d_grad = d_tape.gradient(total_d_loss, self.discriminator.trainable_variables)
            self.d_opt.apply_gradients(zip(d_grad, self.discriminator.trainable_variables))

        # Generator update: push the discriminator's score on fakes upwards.
        with tf.GradientTape() as g_tape:
            gen_images = self.generator(self._sample_latent(batch_size), training=True)
            predicted_labels = self.discriminator(gen_images, training=False)
            total_g_loss = -tf.reduce_mean(predicted_labels)

        g_grad = g_tape.gradient(total_g_loss, self.generator.trainable_variables)
        self.g_opt.apply_gradients(zip(g_grad, self.generator.trainable_variables))

        return {"d_loss":total_d_loss, "g_loss":total_g_loss}

## Functions

In [None]:
def generate_latent_points(latent_dim, n_samples):
    """Draw standard-normal latent vectors.

    Args:
        latent_dim: Number of random values per sample (columns).
        n_samples: Number of samples to draw (rows).

    Returns:
        A 2D NumPy array of shape ``(n_samples, latent_dim)``.
    """
    # Sampling straight into the 2D shape consumes the random stream in the
    # same row-major order as drawing a flat vector and reshaping it.
    return randn(n_samples, latent_dim)

def generate_real_samples(X_train, n_samples):
    """Pick random training samples and label them as real.

    Indices are drawn uniformly with replacement from the first axis of
    ``X_train``.

    Args:
        X_train: Array of training samples, indexed along axis 0.
        n_samples: Number of samples to select.

    Returns:
        Tuple ``(X, y)`` where ``X`` holds the ``n_samples`` chosen rows of
        ``X_train`` and ``y`` is an ``(n_samples, 1)`` array of ones.
    """
    chosen = randint(0, X_train.shape[0], n_samples)
    real_labels = ones((n_samples, 1))
    return X_train[chosen], real_labels

def generate_fake_samples(generator, latent_dim, n_samples):
    """Generate samples with the generator and label them as fake.

    Args:
        generator: Model whose ``predict`` maps latent points to samples.
        latent_dim: Length of each latent vector.
        n_samples: Number of samples to generate.

    Returns:
        Tuple ``(outputs, y)`` where ``outputs`` is the result of
        ``generator.predict`` on ``n_samples`` latent points and ``y`` is an
        ``(n_samples, 1)`` array of zeros.
    """
    latent_points = generate_latent_points(latent_dim, n_samples)
    generated = generator.predict(latent_points)
    fake_labels = zeros((n_samples, 1))
    return generated, fake_labels

def get_GAN_training_network(generator, discriminator, latent_dim, generator_learning_rate = 0.002, discriminator_learning_rate = 0.002, n_critic=1, LAMBDA=1):
    """Wrap a generator and discriminator in a compiled ``GANModel``.

    Args:
        generator: Generator model.
        discriminator: Discriminator model.
        latent_dim: Length of the generator's latent vector.
        generator_learning_rate: Learning rate of the generator's Adam optimizer.
        discriminator_learning_rate: Learning rate of the discriminator's Adam
            optimizer.
        n_critic: Discriminator updates per generator update.
        LAMBDA: Gradient-penalty weight.

    Returns:
        The compiled ``GANModel``; both optimizers are Adam with ``beta_1=0.5``.
    """
    network = GANModel(generator=generator, discriminator=discriminator, generator_latent_dim=latent_dim)
    network.compile(
        g_opt=Adam(learning_rate=generator_learning_rate, beta_1=0.5),
        d_opt=Adam(learning_rate=discriminator_learning_rate, beta_1=0.5),
        g_loss=BinaryCrossentropy(),
        d_loss=BinaryCrossentropy(),
        n_critic=n_critic,
        LAMBDA=LAMBDA,
    )
    return network

def get_generator_and_discriminator(latent_dim, input_dim):
    """Build both networks for the given latent size and image shape.

    Note that, despite the function's name, the discriminator comes first in
    the returned pair.

    Args:
        latent_dim: Length of the generator's latent vector.
        input_dim: Image shape shared by both networks.

    Returns:
        Tuple ``(discriminator, generator)``.
    """
    return (
        define_discriminator(input_dim=input_dim),
        define_generator(latent_dim=latent_dim, input_dim=input_dim),
    )

## Training the GAN

### Get Discriminator and Generator

In [None]:
# The helper returns its pair as (discriminator, generator).
discriminator, generator = get_generator_and_discriminator(input_dim=(28,28,1), latent_dim=100)

# Same learning rate for both optimizers, five discriminator updates per
# generator update, and a gradient-penalty weight of 1.
training_config = {
    "generator_learning_rate": 0.00005,
    "discriminator_learning_rate": 0.00005,
    "n_critic": 5,
    "LAMBDA": 1,
}
gan_model = get_GAN_training_network(
    generator=generator,
    discriminator=discriminator,
    latent_dim=100,
    **training_config,
)

### Train GAN

In [None]:
# Only images are supplied (no labels); GANModel.train_step is run on batches
# of 1250 images for 100 epochs. The returned history holds the per-epoch
# "d_loss" and "g_loss" values reported by train_step.
hist = gan_model.fit(X_train, epochs=100, batch_size=1250)

### Review Training History

In [None]:
# Plot both recorded loss curves on a single figure.
pyplot.suptitle('Loss')
for loss_name in ('d_loss', 'g_loss'):
    pyplot.plot(hist.history[loss_name], label=loss_name)
pyplot.legend()
pyplot.show()

### Get the generator from the GAN

In [None]:
# get_generator() returns the generator object stored on gan_model, i.e. the
# model whose weights were updated during fit().
generator = gan_model.get_generator()

## Print resulting images

In [None]:
# Draw 64 real and 64 generated images, then display them alternately:
# real image i followed by generated image i, each in its own figure.
X_real, y_real = generate_real_samples(X_train=X_train, n_samples=64)
X_fake, y_fake = generate_fake_samples(generator=generator, latent_dim=100, n_samples=64)
for i in range(64):
    for image_set in (X_real, X_fake):
        pyplot.imshow(image_set[i])
        pyplot.show()