# Generate and train a GAN model on MNIST dataset

In this notebook, we create and train a GAN model on the MNIST dataset. For computational reasons, the notebook is meant to be executed in a Google Colab environment. It can also be run locally, but training the model may then take hours (approx. 30 min/epoch on our machines vs. 5 min for 10 epochs on Google's GPUs).


In [0]:
# connect pipelines with Google drive - do not run this cell if running locally
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive 
from google.colab import auth 
from oauth2client.client import GoogleCredentials

In [0]:
# Authenticate with Google so that the trained models can be saved on Google Drive.
# Do not run this cell if the notebook is run locally.
# The resulting ``drive`` object is the one used by ``summarize_performance``
# to upload the saved generator files.
auth.authenticate_user()
gauth = GoogleAuth()
# reuse the application-default credentials for the PyDrive client
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

In [3]:
# library imports
# we do not use tf 2.0 but earlier versions. 
from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy import vstack
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Dropout
from matplotlib import pyplot

Using TensorFlow backend.


In [0]:
# definition of the functions for the model
# png images of the architecture of the models (generator and discriminator)
# are available in the "images" folder attached to the notebook.

# define the standalone discriminator model
def define_discriminator(in_shape=(28,28,1)):
    """
    Build and compile the standalone discriminator.

    arguments:
    - ``in_shape`` : shape of one input image, (28, 28, 1) by default

    returns a compiled Sequential model whose output is a single sigmoid
    unit per image.
    """
    # two strided convolutions, each followed by LeakyReLU and dropout,
    # then a flattened sigmoid classifier
    layers = [
        Conv2D(64, (3,3), strides=(2, 2), padding='same', input_shape=in_shape),
        LeakyReLU(alpha=0.2),
        Dropout(0.4),
        Conv2D(64, (3,3), strides=(2, 2), padding='same'),
        LeakyReLU(alpha=0.2),
        Dropout(0.4),
        Flatten(),
        Dense(1, activation='sigmoid'),
    ]
    discriminator = Sequential(layers)
    # binary cross-entropy with Adam; accuracy is tracked as a metric
    discriminator.compile(
        loss='binary_crossentropy',
        optimizer=Adam(lr=0.0002, beta_1=0.5),
        metrics=['accuracy'],
    )
    return discriminator

# define the standalone generator model
def define_generator(latent_dim):
    """
    Build the standalone generator (returned uncompiled).

    arguments :
    - ``latent_dim`` : the dimension of the latent space

    returns a Sequential model that maps a latent vector to a
    single-channel image with sigmoid-activated pixel values.
    """
    model = Sequential()
    # foundation for 7x7 image: a dense layer reshaped into 128 maps of 7x7
    n_nodes = 128 * 7 * 7
    model.add(Dense(n_nodes, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Reshape((7, 7, 128)))
    # upsample to 14x14
    model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # upsample to 28x28
    model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    # collapse the feature maps into one output channel
    model.add(Conv2D(1, (7,7), activation='sigmoid', padding='same'))
    return model

# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model):
    """
    Stack the generator and the discriminator into one model, used to
    update the generator.

    arguments:
    - ``g_model`` : a generator
    - ``d_model`` : a discriminator

    returns the compiled combined model
    """
    # mark the discriminator as not trainable before stacking it
    d_model.trainable = False
    # generator first, discriminator on top of it
    combined = Sequential([g_model, d_model])
    combined.compile(
        loss='binary_crossentropy',
        optimizer=Adam(lr=0.0002, beta_1=0.5),
    )
    return combined

# load and prepare mnist training images
def load_real_samples():
    """
    Load the MNIST training images and prepare them for the models.

    arguments : None

    returns a float32 array of the training images with a trailing
    channel axis added and values divided by 255.
    """
    # only the training images are used; labels and the test split are dropped
    train_split, _ = load_data()
    # add the channels dimension and convert from unsigned ints to floats
    images = expand_dims(train_split[0], axis=-1).astype('float32')
    # scale from [0,255] to [0,1]
    return images / 255.0

# select real samples
def generate_real_samples(dataset, n_samples):
    """
    Draw a random batch of real samples from a dataset.

    arguments:
    - ``dataset`` : an array indexed along its first axis
    - ``n_samples`` : an integer, the number of samples to pick

    returns
    the selected samples (drawn with replacement) and an array of shape
    (n_samples, 1) filled with ones, i.e. the 'real' class label
    """
    # random row indices in [0, number of rows)
    picks = randint(0, dataset.shape[0], n_samples)
    real_labels = ones((n_samples, 1))
    return dataset[picks], real_labels

# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
    """
    a function that generates points in the latent space
    points are drawn from a standard normal (Gaussian) distribution

    arguments
    - ``latent_dim`` the dimension of the latent space
    - ``n_samples`` the number of samples to generate

    returns an array of shape (n_samples, latent_dim), one latent
    vector per row
    """
    # draw the whole batch directly in the shape expected by the network
    return randn(n_samples, latent_dim)

# use the generator to generate n fake examples, with class labels
def generate_fake_samples(g_model, latent_dim, n_samples):
    """
    Produce a batch of generated samples together with their labels.

    arguments :
    - ``g_model``: a generative model
    - ``latent_dim``: the dimension of the latent space
    - ``n_samples`` : the number of samples to generate

    returns
    X : the output of ``g_model.predict`` on freshly drawn latent points
    y : an array of shape (n_samples, 1) filled with zeros (0, i.e. fake)
    """
    # sample the latent space, then run the generator on those points
    latent_batch = generate_latent_points(latent_dim, n_samples)
    fake_images = g_model.predict(latent_batch)
    return fake_images, zeros((n_samples, 1))

# create and save a plot of generated images (reversed grayscale)
def save_plot(examples, epoch, n=10):
    """
    Save an n x n grid of generated images to a png file.

    arguments:
    - ``examples`` : images to be displayed; the first n*n are plotted,
      using channel 0 of each
    - ``epoch`` : the epoch number (zero-based; the file name uses epoch+1)
    - ``n`` : the number of rows and of columns in the grid

    returns None
    """
    n_cells = n * n
    for cell in range(n_cells):
        # subplot positions are 1-based
        pyplot.subplot(n, n, cell + 1)
        pyplot.axis('off')
        # reversed grayscale colormap
        pyplot.imshow(examples[cell, :, :, 0], cmap='gray_r')
    # write the figure to disk, then close it
    pyplot.savefig('generated_plot_e%03d.png' % (epoch+1))
    pyplot.close()

# evaluate the discriminator, plot generated images, save generator model
def summarize_performance(epoch, g_model, d_model, dataset, latent_dim, n_samples=100):
    """
    a function that summarizes the performance of the model while
    training: it prints the discriminator accuracy on real and fake
    samples, saves a plot of generated images and saves the generator.

    When a module-level ``drive`` object exists (Google Colab
    environment, after the authentication cell has been run), the saved
    generator file is also uploaded to Google Drive. When ``drive`` is
    not defined (local run), the upload is skipped, so no line has to be
    commented out by hand.

    arguments
    - ``epoch`` the epoch count (zero-based; file names use epoch+1)
    - ``g_model`` the generator
    - ``d_model`` the discriminator
    - ``dataset`` the dataset on which training is made
    - ``latent_dim`` the dimension of the latent space
    - ``n_samples`` the number of real and of fake samples to evaluate on

    returns None
    """
    # prepare real samples
    X_real, y_real = generate_real_samples(dataset, n_samples)
    # evaluate discriminator on real examples
    _, acc_real = d_model.evaluate(X_real, y_real, verbose=0)
    # prepare fake examples
    x_fake, y_fake = generate_fake_samples(g_model, latent_dim, n_samples)
    # evaluate discriminator on fake examples
    _, acc_fake = d_model.evaluate(x_fake, y_fake, verbose=0)
    # summarize discriminator performance
    print('>Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real*100, acc_fake*100))
    # save plot
    save_plot(x_fake, epoch)
    # save the generator model to a file
    filename = 'generator_model_%03d.h5' % (epoch + 1)
    g_model.save(filename)
    # the Drive client only exists when the authentication cell was run
    try:
        drive_client = drive
    except NameError:
        # running locally: the model stays on the local disk only
        return
    # upload the saved model file to the google drive
    model_file = drive_client.CreateFile({'title' : filename})
    model_file.SetContentFile(filename)
    model_file.Upload()
    print('model %s downloaded to drive' %(filename))

# train the generator and discriminator
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=100, n_batch=256):
    """
    the main function. Alternately trains the discriminator and the
    generator.

    arguments:
    - ``g_model`` the generator
    - ``d_model`` the discriminator
    - ``gan_model`` the combined model used to update the generator
    - ``dataset`` the dataset on which training is made
    - ``latent_dim`` the dimension of the latent space
    - ``n_epochs`` the number of epochs
    - ``n_batch`` the batch size; the discriminator sees half a batch of
      real and half a batch of fake samples at each step

    Returns None
    """
    steps_per_epoch = int(dataset.shape[0] / n_batch)
    n_half = int(n_batch / 2)
    for epoch_idx in range(n_epochs):
        for step_idx in range(steps_per_epoch):
            # discriminator step: half real samples, half generated ones
            real_images, real_labels = generate_real_samples(dataset, n_half)
            fake_images, fake_labels = generate_fake_samples(g_model, latent_dim, n_half)
            disc_inputs = vstack((real_images, fake_images))
            disc_targets = vstack((real_labels, fake_labels))
            d_loss, _ = d_model.train_on_batch(disc_inputs, disc_targets)
            # generator step: latent points labelled as 'real' (inverted
            # labels), pushed through the combined model
            latent_batch = generate_latent_points(latent_dim, n_batch)
            inverted_labels = ones((n_batch, 1))
            g_loss = gan_model.train_on_batch(latent_batch, inverted_labels)
            # report both losses for this batch
            print('>%d, %d/%d, d=%.3f, g=%.3f' % (epoch_idx+1, step_idx+1, steps_per_epoch, d_loss, g_loss))
        # every 10th epoch: evaluate, plot and save the generator
        epoch_number = epoch_idx + 1
        if epoch_number % 10 == 0:
            summarize_performance(epoch_idx, g_model, d_model, dataset, latent_dim)

In [7]:
# train the model (from scratch)
# if run on a CPU, this can take hours!

# size of the latent space
latent_dim = 100
# create the discriminator
d_model = define_discriminator()
# create the generator
g_model = define_generator(latent_dim)
# create the gan (define_gan marks d_model as not trainable before stacking it)
gan_model = define_gan(g_model, d_model)
# load image data
dataset = load_real_samples()
# train model; every 10th epoch the generator is saved as
# generator_model_XXX.h5, so generator_model_010.h5 ... generator_model_100.h5
train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=100)

  'Discrepancy between trainable weights and collected trainable'


>1, 1/234, d=0.685, g=0.772
>1, 2/234, d=0.668, g=0.793


  'Discrepancy between trainable weights and collected trainable'


>1, 3/234, d=0.665, g=0.814
>1, 4/234, d=0.656, g=0.834
>1, 5/234, d=0.652, g=0.862
>1, 6/234, d=0.643, g=0.873
>1, 7/234, d=0.641, g=0.882
>1, 8/234, d=0.633, g=0.885
>1, 9/234, d=0.640, g=0.881
>1, 10/234, d=0.642, g=0.856
>1, 11/234, d=0.653, g=0.821
>1, 12/234, d=0.655, g=0.787
>1, 13/234, d=0.664, g=0.763
>1, 14/234, d=0.668, g=0.740
>1, 15/234, d=0.662, g=0.726
>1, 16/234, d=0.660, g=0.715
>1, 17/234, d=0.655, g=0.710
>1, 18/234, d=0.653, g=0.706
>1, 19/234, d=0.642, g=0.704
>1, 20/234, d=0.639, g=0.703
>1, 21/234, d=0.626, g=0.702
>1, 22/234, d=0.622, g=0.701
>1, 23/234, d=0.611, g=0.702
>1, 24/234, d=0.604, g=0.702
>1, 25/234, d=0.595, g=0.702
>1, 26/234, d=0.586, g=0.703
>1, 27/234, d=0.577, g=0.704
>1, 28/234, d=0.570, g=0.704
>1, 29/234, d=0.561, g=0.705
>1, 30/234, d=0.550, g=0.706
>1, 31/234, d=0.533, g=0.707
>1, 32/234, d=0.529, g=0.708
>1, 33/234, d=0.515, g=0.709
>1, 34/234, d=0.506, g=0.711
>1, 35/234, d=0.493, g=0.712
>1, 36/234, d=0.485, g=0.713
>1, 37/234, d=0.478, 

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/googleapiclient/discovery_cache/__init__.py", line 36, in autodetect
    from google.appengine.api import memcache
ModuleNotFoundError: No module named 'google.appengine'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/googleapiclient/discovery_cache/file_cache.py", line 33, in <module>
    from oauth2client.contrib.locked_file import LockedFile
ModuleNotFoundError: No module named 'oauth2client.contrib.locked_file'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/googleapiclient/discovery_cache/file_cache.py", line 37, in <module>
    from oauth2client.locked_file import LockedFile
ModuleNotFoundError: No module named 'oauth2client.locked_file'

During handling of the above exception, another exceptio

[1;30;43mLe flux de sortie a été tronqué et ne contient que les 5000 dernières lignes.[0m
>79, 156/234, d=0.688, g=0.715
>79, 157/234, d=0.692, g=0.721
>79, 158/234, d=0.691, g=0.712
>79, 159/234, d=0.689, g=0.685
>79, 160/234, d=0.694, g=0.677
>79, 161/234, d=0.694, g=0.705
>79, 162/234, d=0.694, g=0.742
>79, 163/234, d=0.697, g=0.704
>79, 164/234, d=0.690, g=0.682
>79, 165/234, d=0.690, g=0.681
>79, 166/234, d=0.694, g=0.704
>79, 167/234, d=0.697, g=0.710
>79, 168/234, d=0.690, g=0.712
>79, 169/234, d=0.690, g=0.679
>79, 170/234, d=0.688, g=0.698
>79, 171/234, d=0.695, g=0.704
>79, 172/234, d=0.690, g=0.717
>79, 173/234, d=0.683, g=0.724
>79, 174/234, d=0.696, g=0.699
>79, 175/234, d=0.693, g=0.689
>79, 176/234, d=0.693, g=0.688
>79, 177/234, d=0.693, g=0.722
>79, 178/234, d=0.696, g=0.721
>79, 179/234, d=0.688, g=0.695
>79, 180/234, d=0.693, g=0.683
>79, 181/234, d=0.695, g=0.700
>79, 182/234, d=0.693, g=0.719
>79, 183/234, d=0.689, g=0.730
>79, 184/234, d=0.693, g=0.726
>79, 185/

The model is now ready to be used for reconstruction. The file `generator_model_100.h5` will be imported in the subsequent notebooks for reconstruction.

Note that the training loop above saved a generator model every 10 epochs.