<small>
Copyright (c) 2017 Andrew Glassner

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
</small>



# Deep Learning From Basics to Practice
## by Andrew Glassner, https://dlbasics.com, http://glassner.com
------
## Chapter 26: Generative Adversarial Networks
### Notebook 3: GAN on MNIST

This notebook is provided as a “behind-the-scenes” look at code used to make some of the figures in this chapter. It is still in the hacked-together form used to develop the figures, and is only lightly commented.

In [1]:
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape, LeakyReLU
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import Conv2DTranspose
from keras.layers.convolutional import Conv2D, MaxPooling2D, UpSampling2D
from keras.layers.core import Flatten
from keras.optimizers import SGD
from keras.datasets import mnist
import numpy as np
from tqdm import tqdm
import math
import matplotlib.pyplot as plt

from keras import backend as keras_backend
# Use (rows, cols, channels) image tensors. The models below declare
# input_shape=(28, 28, 1) and reshape to (7, 7, 128), which relies on
# this channels-last ordering.
keras_backend.set_image_data_format('channels_last')

Using TensorFlow backend.


In [2]:
# Make a File_Helper for saving and loading files.

# Passed to File_Helper below. Presumably enables writing figures to disk --
# TODO confirm against DLBasics_Utilities.
save_files = True

import os, sys, inspect
# Directory associated with the currently executing frame.
# NOTE(review): inside a notebook this likely resolves to the working
# directory rather than a real source file -- confirm.
current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
# DLBasics_Utilities is imported from the parent of current_dir, so put
# that directory at the front of the module search path first.
sys.path.insert(0, os.path.dirname(current_dir)) # path to parent dir
from DLBasics_Utilities import File_Helper
# Shared helper used by the plotting and training functions below
# (save_figure, check_for_directory).
file_helper = File_Helper(save_files)

In [85]:
# The old models match those in [Gildenblat16],
# but use upsampling and downsampling layers.
# The new ones are modified to use the suggestions
# in [Radford16].
#
# These indices are read by get_models() to choose which
# discriminator_model_N / generator_model_N builder to call (N in 0..2);
# any other value falls back to version 0.
disc_model_index = 1
gen_model_index = 1

# Directory where train_one_epoch() writes the generator and
# discriminator weight files.
mnist_data_dir = 'MNIST-GAN-data'

In [93]:
def generator_model_0():
    """Build the upsampling-based generator (version 0).

    A 100-element input vector passes through two Dense stages, is
    reshaped to a 7x7x128 volume, and is then doubled in size twice by
    UpSampling2D, each followed by a 5x5 'same' convolution with tanh.
    The final convolution has a single filter, so the output has one channel.

    Returns:
        An uncompiled keras Sequential model.
    """
    flat_size = 128 * 7 * 7
    stack = [
        Dense(1024, input_shape=[100]),
        Activation('tanh'),
        Dense(flat_size),
        BatchNormalization(),
        Activation('tanh'),
        Reshape((7, 7, 128), input_shape=(flat_size,)),
        # First 2x enlargement: 7x7 -> 14x14
        UpSampling2D(size=(2, 2)),
        Conv2D(64, (5, 5), padding='same'),
        Activation('tanh'),
        # Second 2x enlargement: 14x14 -> 28x28
        UpSampling2D(size=(2, 2)),
        Conv2D(1, (5, 5), padding='same'),
        Activation('tanh'),
    ]
    return Sequential(stack)

def generator_model_1():
    """Build the transposed-convolution generator (version 1).

    Shares the Dense front end of version 0 (100 inputs -> 1024 -> 7x7x128)
    but enlarges the image with two stride-2 Conv2DTranspose layers that
    carry their own tanh activation, instead of upsample-then-convolve.

    Returns:
        An uncompiled keras Sequential model.
    """
    flat_size = 128 * 7 * 7
    net = Sequential()
    for stage in (
        Dense(1024, input_shape=[100]),
        Activation('tanh'),
        Dense(flat_size),
        BatchNormalization(),
        Activation('tanh'),
        Reshape((7, 7, 128), input_shape=(flat_size,)),
        # Each stride-2 transposed convolution doubles width and height.
        Conv2DTranspose(64, (5, 5), activation='tanh', strides=(2, 2), padding='same'),
        Conv2DTranspose(1, (5, 5), activation='tanh', strides=(2, 2), padding='same'),
    ):
        net.add(stage)
    return net

def generator_model_2():
    """Build the transposed-convolution generator with extra batchnorm (version 2).

    Like version 1, but the first Conv2DTranspose has no built-in
    activation; it is followed by BatchNormalization and a ReLU layer.
    The last transposed convolution keeps tanh and has no batchnorm.

    Returns:
        An uncompiled keras Sequential model.
    """
    flat_size = 128 * 7 * 7
    stride = (2, 2)
    kernel = (5, 5)
    return Sequential([
        Dense(1024, input_shape=[100], activation='tanh'),
        Dense(flat_size),
        BatchNormalization(),
        Activation('tanh'),
        Reshape((7, 7, 128), input_shape=(flat_size,)),
        Conv2DTranspose(64, kernel, padding='same', strides=stride, activation=None),
        BatchNormalization(),
        Activation('relu'),
        # No batchnorm for last convolution in generator [Radford16]
        Conv2DTranspose(1, kernel, padding='same', strides=stride, activation='tanh'),
    ])

In [94]:
def discriminator_model_0():
    """Build the pooling-based discriminator (version 0).

    Takes 28x28x1 images, applies a 5x5 convolution with tanh, 2x2 max
    pooling, a second 5x5 convolution with tanh, and then two Dense
    layers ending in a single sigmoid unit.

    Returns:
        An uncompiled keras Sequential model.
    """
    stack = [
        Conv2D(64, (5, 5), padding='same', input_shape=(28, 28, 1)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(128, (5, 5), padding='same'),
        Activation('tanh'),
        Flatten(),
        Dense(1024),
        Activation('tanh'),
        # Single sigmoid output unit
        Dense(1),
        Activation('sigmoid'),
    ]
    return Sequential(stack)

def discriminator_model_1():
    """Build the strided-convolution discriminator (version 1).

    Downsamples with stride-2 convolutions instead of pooling and uses
    LeakyReLU activations. Only the second convolution is batch-normalized.
    Ends in Dense(1024, tanh) and a single sigmoid unit.

    Returns:
        An uncompiled keras Sequential model.
    """
    kernel = (5, 5)
    stride = (2, 2)
    net = Sequential()
    for stage in (
        # No batchnorm for first convolution in discriminator [Radford16]
        Conv2D(64, kernel, padding='same', strides=stride,
               activation=None, input_shape=(28, 28, 1)),
        LeakyReLU(),  # LeakyReLU has to be in its own layer
        Conv2D(128, kernel, padding='same', strides=stride, activation=None),
        BatchNormalization(),
        LeakyReLU(),
        Flatten(),
        Dense(1024, activation='tanh'),
        Dense(1, activation='sigmoid'),
    ):
        net.add(stage)
    return net

def discriminator_model_2():
    """Build the strided-convolution discriminator (version 2).

    As written, this stacks exactly the same layers with the same
    arguments as discriminator_model_1; it is kept as a separate
    builder so the two can diverge.

    Returns:
        An uncompiled keras Sequential model.
    """
    conv_options = dict(padding='same', strides=(2, 2), activation=None)
    return Sequential([
        # No batchnorm for first convolution in discriminator [Radford16]
        Conv2D(64, (5, 5), input_shape=(28, 28, 1), **conv_options),
        LeakyReLU(),  # LeakyReLU has to be in its own layer
        Conv2D(128, (5, 5), **conv_options),
        BatchNormalization(),
        LeakyReLU(),
        Flatten(),
        Dense(1024, activation='tanh'),
        Dense(1, activation='sigmoid'),
    ])

In [95]:
def generator_containing_discriminator(generator, discriminator):
    """Chain the generator into the discriminator as one Sequential model.

    The discriminator's `trainable` flag is set to False as a side effect,
    so the caller must restore it if the discriminator is to be trained
    on its own afterwards.

    Args:
        generator: model placed first in the stack.
        discriminator: model placed second; marked non-trainable here.

    Returns:
        An uncompiled Sequential model: generator followed by discriminator.
    """
    discriminator.trainable = False
    stacked = Sequential()
    for part in (generator, discriminator):
        stacked.add(part)
    return stacked

In [96]:
def plot_loss(epoch_number, losses):
    """Plot the discriminator and generator loss histories, save, and show.

    Args:
        epoch_number: used only to build the saved figure's name.
        losses: dict with sequences under the keys "d" and "g".
    """
    curves = (
        ("d", 'discriminitive loss'),
        ("g", 'generative loss'),
    )
    plt.figure(figsize=(10, 8))
    for key, caption in curves:
        plt.plot(losses[key], label=caption)
    plt.legend()
    figure_name = 'MNIST-losses-epoch-{}'.format(epoch_number)
    file_helper.save_figure(figure_name)
    plt.show()
        
def plot_gen(epoch_number, generator, discriminator):
    """Show a 10x10 grid of the generated images the discriminator rates highest.

    Generates 20x more images than the grid needs, scores each with the
    discriminator, and plots the top-scoring ones in descending order.
    The figure is saved through file_helper and then shown.

    Args:
        epoch_number: used only to build the saved figure's name.
        generator: model mapping (N, 100) noise to images; the images are
            indexed as (N, rows, cols, channels) to match the
            'channels_last' format configured for this notebook.
        discriminator: model returning one score per image, shape (N, 1).
    """
    cols = 10
    rows = 10
    grid_size = rows * cols

    # Generate many more candidates than we display so the grid only
    # contains the most convincing ones.
    overbuild_factor = 20
    number_of_images = grid_size * overbuild_factor
    noise = np.random.uniform(-1, 1, size=(number_of_images, 100))
    generated_images = generator.predict(noise, verbose=1)
    print(">>>>> generated_images.shape=",generated_images.shape)
    predictions = discriminator.predict(generated_images, verbose=1)

    # Indices of images ordered by descending discriminator score.
    # mergesort is stable, so ties keep their original relative order.
    best_first = np.argsort(-predictions[:, 0], kind='mergesort')

    plt.figure(figsize=(cols, rows))
    for i, image_number in enumerate(best_first[:grid_size]):
        plt.subplot(rows, cols, i + 1)
        # Channels-last layout: take the full 2D image from channel 0.
        # (Indexing [n, 0, :, :] would instead pick a single 28x1 strip.)
        plt.imshow(generated_images[image_number, :, :, 0], cmap="gray")
        ax = plt.gca()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    file_helper.save_figure("MNIST-generated-images-epoch-"+str(epoch_number))
    plt.show()

In [97]:
def get_X_train_from_MNIST(max_samples=1000):
    """Load the MNIST training images, scaled and shaped for the GAN.

    Args:
        max_samples: keep only the first `max_samples` training images to
            save time. Pass None to keep the whole training set.

    Returns:
        float32 array of shape (num_samples, 28, 28, 1) with pixel values
        0..255 mapped linearly onto [-1, 1] (matching the generators'
        tanh output range).
    """
    # Only the training images are used; labels and the test split are ignored.
    (X_train, _), _ = mnist.load_data()
    if max_samples is not None:
        # Report the number of samples actually kept.
        print("******* CLIPPING TO {} samples *********".format(max_samples)) # save time
        X_train = X_train[:max_samples]
    X_train = (X_train.astype(np.float32) - 127.5)/127.5
    # Add the trailing channel axis expected by the Conv2D input layer.
    X_train = X_train.reshape(X_train.shape + (-1,))
    return X_train

In [98]:
def train_one_epoch(X_train, generator, discriminator, discriminator_on_generator, batch_size, epoch_number, losses):
    """Run one epoch of alternating discriminator / generator updates.

    For each full batch of real images:
      1. the generator produces `batch_size` fake images from fresh noise;
      2. the discriminator is trained on real images (label 1) followed by
         the fakes (label 0);
      3. the stacked model is trained on fresh noise with label 1, with the
         discriminator's `trainable` flag switched off for that step.
    Any samples left over after the last full batch are not used.

    After the loop, the weights of both models are saved twice into
    `mnist_data_dir`: once under an epoch-numbered name and once under a
    "most-recent" name.

    Args:
        X_train: array of real images; the first axis is the sample axis.
        generator: model mapping (batch_size, 100) noise to images.
        discriminator: compiled model scoring images.
        discriminator_on_generator: compiled generator+discriminator stack.
        batch_size: number of real (and of fake) images per step.
        epoch_number: used only to name the saved weight files.
        losses: dict with lists under "g" and "d"; the losses of the
            *last* batch of the epoch are appended to them.

    Raises:
        ValueError: if X_train holds fewer than `batch_size` samples, in
            which case no training step could run.
    """
    noise_size = 100
    number_of_batches = X_train.shape[0] // batch_size
    if number_of_batches < 1:
        # Without this check the loop body never runs and the loss
        # variables below would be undefined.
        raise ValueError(
            "X_train has {} samples, fewer than batch_size={}".format(
                X_train.shape[0], batch_size))

    # Targets as arrays rather than Python lists, so Keras cannot
    # misread them as a list of per-output targets.
    real_then_fake_labels = np.concatenate((np.ones(batch_size), np.zeros(batch_size)))
    all_real_labels = np.ones(batch_size)

    for index in tqdm(range(number_of_batches)):
        # Discriminator step: real batch followed by freshly generated fakes.
        noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
        image_batch = X_train[index*batch_size:(index+1)*batch_size]
        generated_images = generator.predict(noise, verbose=0)
        X = np.concatenate((image_batch, generated_images))
        d_loss = discriminator.train_on_batch(X, real_then_fake_labels)

        # Generator step: new noise, and targets of 1 so the generator is
        # rewarded when the discriminator calls its output real.
        noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
        discriminator.trainable = False
        g_loss = discriminator_on_generator.train_on_batch(
            noise, all_real_labels)
        discriminator.trainable = True

    file_helper.check_for_directory(mnist_data_dir)
    generator.save_weights(mnist_data_dir+'/generator-epoch-'+str(epoch_number)+'.h5', True)
    discriminator.save_weights(mnist_data_dir+'/discriminator-epoch-'+str(epoch_number)+'.h5', True)
    generator.save_weights(mnist_data_dir+'/generator-most-recent.h5', True)
    discriminator.save_weights(mnist_data_dir+'/discriminator-most-recent.h5', True)
    # Only the final batch's losses are recorded for the epoch.
    losses["g"].append(g_loss)
    losses["d"].append(d_loss)

In [99]:
def get_models():
    """Create and compile the generator, discriminator, and stacked model.

    The builders are chosen by the module-level `disc_model_index` and
    `gen_model_index`; any value other than 0, 1 or 2 prints a notice and
    falls back to version 0.

    Returns:
        Tuple (generator, discriminator, discriminator_on_generator), all
        compiled with binary cross-entropy loss.
    """
    def pick(index, builders, fallback_notice):
        # Call the builder whose position equals `index`; otherwise
        # announce the fallback and build version 0.
        for position, build in enumerate(builders):
            if index == position:
                return build()
        print(fallback_notice)
        return builders[0]()

    print("Discriminator ",disc_model_index," generator ",gen_model_index)

    # The discriminator is built first, then the generator.
    discriminator = pick(
        disc_model_index,
        (discriminator_model_0, discriminator_model_1, discriminator_model_2),
        ">>>>> DEFAULT DISCRIMINATOR VERSION 0")
    generator = pick(
        gen_model_index,
        (generator_model_0, generator_model_1, generator_model_2),
        ">>>>> DEFAULT GENERATOR VERSION 0")

    # Stacking marks the discriminator as non-trainable.
    discriminator_on_generator = generator_containing_discriminator(
        generator, discriminator)

    sgd_settings = dict(lr=0.0005, momentum=0.9, nesterov=True)
    d_optim = SGD(**sgd_settings)
    g_optim = SGD(**sgd_settings)

    generator.compile(loss='binary_crossentropy', optimizer="SGD")
    # The stacked model is compiled while the discriminator is still
    # flagged non-trainable; the flag is restored just before the
    # discriminator itself is compiled.
    discriminator_on_generator.compile(
        loss='binary_crossentropy', optimizer=g_optim)
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=d_optim)

    return (generator, discriminator, discriminator_on_generator)

In [100]:
def train_model(number_of_epochs, X_train, batch_size, losses, generator, discriminator, discriminator_on_generator):
    """Train for `number_of_epochs` epochs, plotting progress after each.

    Each epoch runs train_one_epoch(), then plots the loss history and a
    grid of generated images.

    Args:
        number_of_epochs: how many epochs to run (numbered from 0).
        X_train: array of real training images.
        batch_size: images per training step.
        losses: dict with "d" and "g" lists, extended once per epoch.
        generator, discriminator, discriminator_on_generator: the compiled
            models returned by get_models().
    """
    for epoch_number in range(number_of_epochs):
        print("Starting Epoch ",epoch_number)
        train_one_epoch(
            X_train=X_train,
            generator=generator,
            discriminator=discriminator,
            discriminator_on_generator=discriminator_on_generator,
            batch_size=batch_size,
            epoch_number=epoch_number,
            losses=losses,
        )
        plot_loss(epoch_number=epoch_number, losses=losses)
        plot_gen(
            epoch_number=epoch_number,
            generator=generator,
            discriminator=discriminator,
        )

In [101]:
# Load the training images and build the three compiled models.
X_train = get_X_train_from_MNIST()
generator, discriminator, discriminator_on_generator = get_models()

# Loss history, extended once per epoch: "d" = discriminator, "g" = generator.
losses = dict(d=[], g=[])
batch_size = 32
number_of_epochs = 100

train_model(
    number_of_epochs=number_of_epochs,
    X_train=X_train,
    batch_size=batch_size,
    losses=losses,
    generator=generator,
    discriminator=discriminator,
    discriminator_on_generator=discriminator_on_generator,
)

******* CLIPPING TO 10000 samples *********
Discriminator  0  generator  0


  0%|          | 0/31 [00:00<?, ?it/s]

Starting Epoch  0
noise.shape= (32, 100)
X_train.shape= (1000, 28, 28, 1)
image_batch.shape= (32, 28, 28, 1)
generated_images.shape= (32, 28, 28, 1)


  3%|▎         | 1/31 [00:03<01:46,  3.54s/it]

noise.shape= (32, 100)
X_train.shape= (1000, 28, 28, 1)
image_batch.shape= (32, 28, 28, 1)
generated_images.shape= (32, 28, 28, 1)


  6%|▋         | 2/31 [00:04<01:20,  2.77s/it]

noise.shape= (32, 100)
X_train.shape= (1000, 28, 28, 1)
image_batch.shape= (32, 28, 28, 1)
generated_images.shape= (32, 28, 28, 1)


 10%|▉         | 3/31 [00:05<01:02,  2.24s/it]

noise.shape= (32, 100)
X_train.shape= (1000, 28, 28, 1)
image_batch.shape= (32, 28, 28, 1)
generated_images.shape= (32, 28, 28, 1)


 13%|█▎        | 4/31 [00:06<00:50,  1.88s/it]

noise.shape= (32, 100)
X_train.shape= (1000, 28, 28, 1)
image_batch.shape= (32, 28, 28, 1)
generated_images.shape= (32, 28, 28, 1)





KeyboardInterrupt: 

In [None]:
# Sanity check: display the shape of the training array built above.
X_train.shape