# Preface

In this notebook, we implement the basic generative adversarial network and the iterative training algorithm.

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
sns.set(font_scale=1.5, style='darkgrid')

# Data

We will use the Fashion-MNIST dataset we have seen before. We only need the training images. Our goal is to generate new images using a GAN.

In [None]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
x_train = x_train[:, :, :, None] / 255.0
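
As a quick sanity check on this preprocessing, the sketch below applies the same scaling to a small stand-in array (hypothetical random `uint8` data, not the real dataset) and prints the resulting shape and value range. It also shows an alternative rescaling to [-1, 1], which some GAN implementations prefer because it matches the range of a `tanh` output layer; that alternative is included only for illustration and is not what this notebook uses.

```python
import numpy as np

# Hypothetical stand-in for the raw images: uint8 values in [0, 255].
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(8, 28, 28), dtype=np.uint8)

# Same preprocessing as in the notebook: add a channel axis and scale to [0, 1].
scaled_01 = raw[:, :, :, None] / 255.0

# Alternative (not used in this notebook): scale to [-1, 1] to match a tanh output.
scaled_pm1 = raw[:, :, :, None] / 127.5 - 1.0

print(scaled_01.shape, scaled_01.min(), scaled_01.max())
print(scaled_pm1.shape, scaled_pm1.min(), scaled_pm1.max())
```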

# Building the GAN

In [None]:
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import Dense, Conv2D, Conv2DTranspose, Reshape, LeakyReLU, Dropout, Flatten, Input
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam

First we define the latent dimension, which is the dimension of the noise vector $z$ fed into the generator.

In [None]:
latent_dim = 128
image_shape = (28, 28, 1)
optimizer = Adam(0.0002, 0.5)
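
To make the role of `latent_dim` concrete, here is a minimal sketch of drawing a batch of noise vectors $z$ from a standard Gaussian, the same distribution the plotting and training cells sample from. The batch size of 4 and the fixed seed are arbitrary choices for illustration.

```python
import numpy as np

latent_dim = 128  # same value as in the notebook
batch = 4         # arbitrary illustrative batch size

rng = np.random.default_rng(0)  # seeded only so the sketch is reproducible
z = rng.normal(loc=0.0, scale=1.0, size=(batch, latent_dim))

print(z.shape)  # one latent vector per row
```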

## Build Generator

Next, we write a function to build the generator. We will use a simple convolutional neural network. Its input is a 128-dimensional random noise vector and its output is a (28, 28, 1) greyscale image.

In [None]:
def build_generator():
    generator = Sequential()
    generator.add(Dense(7*7*128, kernel_initializer=RandomNormal(0, 0.02), input_dim=latent_dim))
    generator.add(LeakyReLU(0.2))
    generator.add(Reshape((7, 7, 128)))
    
    # 7x7x128
    generator.add(Conv2DTranspose(128, (3, 3), strides=1, padding='same', kernel_initializer=RandomNormal(0, 0.02)))
    generator.add(LeakyReLU(0.2))
    
    # 14x14x128
    generator.add(Conv2DTranspose(128, (3, 3), strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02)))
    generator.add(LeakyReLU(0.2))
    
    # 28x28x128
    generator.add(Conv2DTranspose(128, (3, 3), strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02)))
    generator.add(LeakyReLU(0.2))
    
    # 28x28x1
    generator.add(Conv2D(1, (3, 3), padding='same', activation='tanh', kernel_initializer=RandomNormal(0, 0.02)))
    
    generator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return generator

In [None]:
generator = build_generator()
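
The shape comments in `build_generator` can be checked with a little arithmetic. With `padding='same'`, a transposed convolution multiplies the spatial size by its stride, so the sketch below simply replays the three strides used above. The helper name `transposed_same_size` is made up for this illustration.

```python
def transposed_same_size(size, stride):
    """Spatial size after a transposed convolution with padding='same'."""
    return size * stride

size = 7  # spatial size after the Dense + Reshape((7, 7, 128)) step
sizes = [size]
for stride in (1, 2, 2):  # strides of the three Conv2DTranspose layers
    size = transposed_same_size(size, stride)
    sizes.append(size)

print(sizes)  # [7, 7, 14, 28]
```

The final `Conv2D` layer uses stride 1 with `padding='same'`, so it keeps the 28x28 size and only reduces the channel count to 1.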

To visualize generated images, we write a function to plot them. It samples random Gaussian noise and then decodes it via $G$.

In [None]:
def plot_generated_images(generator, examples=25, dim=(5,5), figsize=(10,10)):
    noise = np.random.normal(size=[examples, latent_dim])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(examples,28,28)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='Greys_r')
        plt.axis('off')
    plt.tight_layout()

In [None]:
plot_generated_images(generator)

As expected, an untrained generator produces rubbish.

## Build Discriminator

Now let us build the discriminator, which is again a CNN. It outputs a value in [0, 1], the estimated probability that the input image is real.

In [None]:
def build_discriminator():
    discriminator = Sequential()
    discriminator.add(
        Conv2D(64, (3, 3), padding='same', kernel_initializer=RandomNormal(0, 0.02), input_shape=(28, 28, 1)))
    discriminator.add(LeakyReLU(0.2))
    
    discriminator.add(Conv2D(128, (3, 3), strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02)))
    discriminator.add(LeakyReLU(0.2))
    
    discriminator.add(Conv2D(128, (3, 3), strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02)))
    discriminator.add(LeakyReLU(0.2))
    
    discriminator.add(Conv2D(256, (3, 3), strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02)))
    discriminator.add(LeakyReLU(0.2))
    
    discriminator.add(Flatten())
    discriminator.add(Dropout(0.4))
    discriminator.add(Dense(1, activation='sigmoid'))
    
    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return discriminator

In [None]:
discriminator = build_discriminator()
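
As a rough check of the discriminator's downsampling, the sketch below traces the spatial size through its four convolutions. With `padding='same'`, a strided convolution produces an output of size ceil(input / stride). The helper name `conv_same_size` is made up for this illustration.

```python
import math

def conv_same_size(size, stride):
    """Spatial size after a convolution with padding='same'."""
    return math.ceil(size / stride)

size = 28  # input images are 28x28
sizes = [size]
for stride in (1, 2, 2, 2):  # strides of the four Conv2D layers
    size = conv_same_size(size, stride)
    sizes.append(size)

# The last Conv2D has 256 filters, so Flatten yields size * size * 256 features.
flat_features = size * size * 256
print(sizes, flat_features)  # [28, 28, 14, 7, 4] 4096
```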

## Combining into GAN

Now we will combine the discriminator and the generator into a combined GAN model for training. 

The training will proceed in two parts:
  * Train discriminator to tell noise-generated samples apart from real samples
  * Train generator to generate hard-to-tell-apart samples

We will use a trick: while training the generator, we disable training of the discriminator. This is done by setting
`discriminator.trainable = False` before building the `gan` model. Note that this does *not* affect the trainability of
the standalone `discriminator` model, since it was compiled before the flag was set.

In [None]:
discriminator.trainable = False

Build the combined model
$$
    D_\phi ( G_\theta (z) ) 
$$

In [None]:
gan_input = Input(shape=(latent_dim,))
fake_image = generator(gan_input)
gan_output = discriminator(fake_image)
gan = Model(gan_input, gan_output)
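# Use a separate optimizer instance with the same settings: some newer TensorFlow
# versions reject one optimizer shared across models with different trainable variables.
gan.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))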
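
For reference, the standard GAN objective that the two networks compete over can be written in the same $D_\phi$, $G_\theta$ notation. The symbols $p_{\text{data}}$ (distribution of real images) and $p_z$ (Gaussian noise distribution) are introduced here for illustration only.

$$
    \min_\theta \max_\phi \;
    \mathbb{E}_{x \sim p_{\text{data}}} \left[ \log D_\phi(x) \right]
    + \mathbb{E}_{z \sim p_z} \left[ \log \left( 1 - D_\phi ( G_\theta (z) ) \right) \right]
$$

Since the discriminator is frozen inside `gan`, a training step on `gan` is expected to update only $\theta$.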

# Train GAN

Now let us train the GAN model by alternating training of the discriminator and the generator. 
  * Training the discriminator is very simple, as this is just a basic binary classification problem
  * Training the generator: recall that we want to maximize the discriminator's error on generated samples. This is easily done by *reversing* the labels: set label = 1 for generated samples while keeping the discriminator fixed.
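
The label-reversal trick in the second bullet can be checked numerically. The sketch below implements binary cross-entropy in NumPy (a hand-written helper, not the Keras loss itself) and shows that labelling generated samples as 1 gives the loss $-\log D(G(z))$, which is small exactly when the discriminator is fooled. The discriminator outputs used here are made-up numbers.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy, with clipping to avoid log(0)."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# Made-up discriminator outputs on generated images.
d_rejects_fakes = np.array([0.1, 0.2, 0.05])  # discriminator spots the fakes
d_is_fooled = np.array([0.9, 0.8, 0.95])      # discriminator believes the fakes

# Reversed labels: generated samples are labelled 1 ("real").
flipped_labels = np.ones(3)

loss_rejected = binary_crossentropy(flipped_labels, d_rejects_fakes)
loss_fooled = binary_crossentropy(flipped_labels, d_is_fooled)

# With labels of 1, the loss reduces to -mean(log D(G(z))).
print(np.isclose(loss_rejected, -np.mean(np.log(d_rejects_fakes))))
print(loss_rejected > loss_fooled)
```

Minimizing this loss with respect to the generator therefore pushes $D(G(z))$ towards 1.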

In [None]:
from tqdm import tqdm

In [None]:
batch_size = 16
steps_per_epoch = 3750
epochs = 10
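
The value of `steps_per_epoch` appears to be chosen so that one epoch draws roughly as many images as the training set contains. A minimal check, assuming the 60,000 training images of Fashion-MNIST:

```python
num_train_images = 60000  # size of the Fashion-MNIST training set
batch_size = 16

steps_per_epoch = num_train_images // batch_size
print(steps_per_epoch)  # 3750
```

Note that the training loop samples real images at random with replacement, so an "epoch" here is only approximately one pass over the data.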

In [None]:
# for epoch in tqdm(range(epochs)):
#     for batch in tqdm(range(steps_per_epoch)):
#         noise = np.random.normal(0, 1, size=(batch_size, latent_dim))
#         fake_x = generator.predict(noise)  # generated images
#         real_x = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]  # real images
        
#         x = np.concatenate((real_x, fake_x))

#         disc_y = np.zeros(2*batch_size)
#         disc_y[:batch_size] = 0.9

#         d_loss = discriminator.train_on_batch(x, disc_y)

#         y_gen = np.ones(batch_size)
#         g_loss = gan.train_on_batch(noise, y_gen)

# generator.save_weights('gan_generator.h5')
# discriminator.save_weights('gan_discriminator.h5')
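
To illustrate how one discriminator batch is assembled in the loop above, here is a NumPy-only sketch with random stand-in arrays in place of real and generated images. The target of 0.9 rather than 1 for real images is a trick commonly called one-sided label smoothing.

```python
import numpy as np

batch_size = 16
rng = np.random.default_rng(0)

# Hypothetical stand-ins for one batch of real and generated images.
real_x = rng.random((batch_size, 28, 28, 1))
fake_x = rng.random((batch_size, 28, 28, 1))

# Real images come first, generated images second.
x = np.concatenate((real_x, fake_x))

# Targets: 0.9 (smoothed "real") for the first half, 0 ("fake") for the second.
disc_y = np.zeros(2 * batch_size)
disc_y[:batch_size] = 0.9

print(x.shape, disc_y[:3], disc_y[-3:])
```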

In [None]:
generator.load_weights('gan_generator.h5')
discriminator.load_weights('gan_discriminator.h5')

Let us see what images we can generate using the trained generator.

In [None]:
plot_generated_images(generator)

# Exercise

1. Try the various performance-improving techniques introduced in class on this simple GAN.
2. Try implementing the Wasserstein GAN.