In [1]:
from tensorflow.python import keras as kr
kr.__version__

'2.1.5-tf'

# Introduction to generative adversarial networks

This notebook contains the second code sample found in Chapter 8, Section 5 of [Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python?a_aid=keras&a_bid=76564dff). Note that the original text features far more content, in particular further explanations and figures: in this notebook, you will only find source code and related comments.

---
[...]

## A schematic GAN implementation


In what follows, we explain how to implement a GAN in Keras, in its barest form -- since GANs are quite advanced, diving deeply into the 
technical details would be out of scope for us. Our specific implementation will be a deep convolutional GAN, or DCGAN: a GAN where the 
generator and discriminator are deep convnets. In particular, it leverages a `Conv2DTranspose` layer for image upsampling in the generator.

We will train our GAN on images from CIFAR10, a dataset of 50,000 32x32 RGB images belonging to 10 classes (5,000 images per class). To make 
things even easier, we will only use images belonging to the class "frog".

Schematically, our GAN looks like this:

* A `generator` network maps vectors of shape `(latent_dim,)` to images of shape `(32, 32, 3)`.
* A `discriminator` network maps images of shape `(32, 32, 3)` to a binary score estimating the probability that the image is real.
* A `gan` network chains the generator and the discriminator together: `gan(x) = discriminator(generator(x))`. Thus this `gan` network maps 
latent space vectors to the discriminator's assessment of the realism of these latent vectors as decoded by the generator.
* We train the discriminator using examples of real and fake images along with "real"/"fake" labels, as we would train any regular image 
classification model.
* To train the generator, we use the gradients of the generator's weights with regard to the loss of the `gan` model. This means that, at 
every step, we move the weights of the generator in a direction that will make the discriminator more likely to classify as "real" the 
images decoded by the generator. In other words, we train the generator to fool the discriminator.
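
As a small numeric sketch of what "fooling the discriminator" means for the loss: assuming the loss is binary cross-entropy and that `p_real` stands for the probability the discriminator assigns to "real", the generator's loss against an "all real" target is just `-log(p_real)`. The helper name `generator_loss` is hypothetical and not part of the notebook.

```python
import math

def generator_loss(p_real, eps=1e-7):
    """Cross-entropy against an "all real" target: small when p_real is close to 1."""
    p_real = min(max(p_real, eps), 1.0 - eps)  # clip to avoid log(0)
    return -math.log(p_real)

# The more confidently the discriminator says "real", the lower the loss.
for p_real in (0.1, 0.5, 0.9):
    print(p_real, round(generator_loss(p_real), 4))
```

A discriminator that cannot tell either way (`p_real = 0.5`) gives a loss of ln 2, roughly 0.693, which is a useful reference point when reading GAN training logs.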

## A bag of tricks


Training GANs and tuning GAN implementations is notoriously difficult. There are a number of known "tricks" that one should keep in mind. 
Like most things in deep learning, it is more alchemy than science: these tricks are really just heuristics, not theory-backed guidelines. 
They are backed by some level of intuitive understanding of the phenomenon at hand, and they are known to work well empirically, albeit not 
necessarily in every context.

Here are a few of the tricks that we leverage in our own implementation of a GAN generator and discriminator below. It is not an exhaustive 
list of GAN-related tricks; you will find many more across the GAN literature.

* We use `tanh` as the last activation in the generator, instead of `sigmoid`, which would be more commonly found in other types of models.
* We sample points from the latent space using a _normal distribution_ (Gaussian distribution), not a uniform distribution.
* Stochasticity is good to induce robustness. Since GAN training results in a dynamic equilibrium, GANs are likely to get "stuck" in all sorts of ways. 
Introducing randomness during training helps prevent this. We introduce randomness in two ways: 1) we use dropout in the discriminator, 2) 
we add some random noise to the labels for the discriminator.
* Sparse gradients can hinder GAN training. In deep learning, sparsity is often a desirable property, but not in GANs. There are two things 
that can induce gradient sparsity: 1) max pooling operations, 2) ReLU activations. Instead of max pooling, we recommend using strided 
convolutions for downsampling, and we recommend using a `LeakyReLU` layer instead of a ReLU activation. It is similar to ReLU but it 
relaxes sparsity constraints by allowing small negative activation values.
* In generated images, it is common to see "checkerboard artifacts" caused by unequal coverage of the pixel space in the generator. To fix 
this, we use a kernel size that is divisible by the stride size, whenever we use a strided `Conv2DTranspose` or `Conv2D` in both the 
generator and discriminator.
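
The last trick can be illustrated in one dimension. The sketch below counts how many kernel taps of a transposed convolution land on each output position; it is a simplified model (1D, no padding, hypothetical helper names), not the Keras implementation.

```python
def transposed_conv_coverage(n_inputs, kernel_size, stride):
    """Count how many kernel taps land on each output position (1D, no padding)."""
    out_len = (n_inputs - 1) * stride + kernel_size
    coverage = [0] * out_len
    for i in range(n_inputs):
        for k in range(kernel_size):
            coverage[i * stride + k] += 1
    return coverage

def interior(coverage, kernel_size):
    """Drop the border positions, where coverage always tapers off."""
    return coverage[kernel_size:-kernel_size]

# Kernel size divisible by the stride: every interior position is covered equally.
even = interior(transposed_conv_coverage(8, kernel_size=4, stride=2), 4)
# Kernel size not divisible by the stride: coverage alternates between 1 and 2.
uneven = interior(transposed_conv_coverage(8, kernel_size=3, stride=2), 3)

print(even)
print(uneven)
```

With kernel size 4 and stride 2 the interior coverage is constant, while kernel size 3 with stride 2 produces an alternating pattern, which is the 1D analogue of a checkerboard.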

## The generator


First, we develop a `generator` model, which turns a vector (from the latent space -- during training it will be sampled at random) into a 
candidate image. One of the many issues that commonly arise with GANs is that the generator gets stuck with generated images that look like 
noise. A possible solution is to use dropout on both the discriminator and generator.

In [2]:
from tensorflow.python import keras
from tensorflow.python.keras import layers
import numpy as np

latent_dim = 32
height = 32
width = 32
channels = 3

generator_input = keras.Input(shape=(latent_dim,))

# First, transform the input into a 16x16 128-channel feature map
x = layers.Dense(128 * 16 * 16)(generator_input)
x = layers.LeakyReLU()(x)
x = layers.Reshape((16, 16, 128))(x)

# Then, add a convolution layer
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)

# Upsample to 32x32
x = layers.Conv2DTranspose(256, 4, strides=2, padding='same')(x)
x = layers.LeakyReLU()(x)

# A few more conv layers
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)

# Produce a 32x32 3-channel feature map (i.e. the shape of a CIFAR10 image)
x = layers.Conv2D(channels, 7, activation='tanh', padding='same')(x)
generator = keras.models.Model(generator_input, x)
generator.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 32)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 32768)             1081344   
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 32768)             0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 16, 16, 256)       819456    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 16, 16, 256)       0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 32, 32, 256)       1048832   
__________
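
The parameter counts in the summary above can be checked by hand. The sketch below uses the usual weights-plus-biases formulas for dense and 2D convolution layers; the helper names are hypothetical and only for illustration.

```python
def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out

def conv_params(kernel, in_channels, out_channels):
    """Square kernel weights plus one bias per output channel.

    The same count applies to Conv2D and Conv2DTranspose.
    """
    return kernel * kernel * in_channels * out_channels + out_channels

dense_1 = dense_params(32, 128 * 16 * 16)   # latent_dim -> 16*16*128 units
conv2d_1 = conv_params(5, 128, 256)         # Conv2D(256, 5) on 128 channels
conv2d_transpose_1 = conv_params(4, 256, 256)  # Conv2DTranspose(256, 4) on 256 channels

print(dense_1, conv2d_1, conv2d_transpose_1)
```

These values match the `Param #` column for `dense_1`, `conv2d_1`, and `conv2d_transpose_1`.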

## The discriminator


Then, we develop a `discriminator` model that takes as input a candidate image (real or synthetic) and classifies it into one of two 
classes, either "generated image" or "real image that comes from the training set".

In [3]:
discriminator_input = layers.Input(shape=(height, width, channels))
x = layers.Conv2D(128, 3)(discriminator_input)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Flatten()(x)

# One dropout layer - important trick!
x = layers.Dropout(0.4)(x)

# Classification layer
x = layers.Dense(1, activation='sigmoid')(x)

discriminator = keras.models.Model(discriminator_input, x)
discriminator.summary()

# To stabilize training, we use learning rate decay
# and gradient clipping (by value) in the optimizer.
discriminator_optimizer = keras.optimizers.RMSprop(lr=0.0008, clipvalue=1.0, decay=1e-8)
discriminator.compile(optimizer=discriminator_optimizer, loss='binary_crossentropy')

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 32, 32, 3)         0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 30, 30, 128)       3584      
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 30, 30, 128)       0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 14, 14, 128)       262272    
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 6, 6, 128)         262272    
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 6, 6, 128)         0         
__________
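
The spatial sizes in the summary above follow from the convolution arithmetic for unpadded ("valid") convolutions, which is the default padding for `Conv2D`. A minimal sketch, with a hypothetical helper name:

```python
def conv_output_size(size, kernel, stride=1):
    """Spatial size after a convolution with no padding."""
    return (size - kernel) // stride + 1

# (kernel, stride) for the four Conv2D layers of the discriminator, in order.
conv_layers = [(3, 1), (4, 2), (4, 2), (4, 2)]

size = 32
sizes = []
for kernel, stride in conv_layers:
    size = conv_output_size(size, kernel, stride)
    sizes.append(size)

print(sizes)             # spatial size after each convolution
print(size * size * 128) # number of features after Flatten
```

The first three sizes (30, 14, 6) agree with the visible part of the summary. The last value is derived from the code rather than read from the output, since the printed summary is cut off.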

## The adversarial network

Finally, we set up the GAN, which chains the generator and the discriminator. This is the model that, when trained, will move the generator 
in a direction that improves its ability to fool the discriminator. This model turns latent space points into a classification decision, 
"fake" or "real", and it is meant to be trained with labels that are always "these are real images". So training `gan` will update the 
weights of `generator` in a way that makes `discriminator` more likely to predict "real" when looking at fake images. Very importantly, we 
set the discriminator to be frozen during training (non-trainable): its weights will not be updated when training `gan`. If the 
discriminator weights could be updated during this process, then we would be training the discriminator to always predict "real", which is 
not what we want!
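
To make the role of the frozen discriminator concrete, here is a toy sketch with scalar "networks": a generator `g(z) = w_g * z` and a discriminator `d(x) = sigmoid(w_d * x)`. The weights, learning rate, and function names are all hypothetical, and `target` stands for whichever label value encodes "real". Only `w_g` is updated; `w_d` is treated as a constant.

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def gan_loss(w_g, w_d, z, target):
    """Binary cross-entropy of discriminator(generator(z)) against `target`."""
    p = sigmoid(w_d * (w_g * z))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))

def generator_step(w_g, w_d, z, target, lr=0.1):
    """One gradient step on w_g only; w_d is frozen."""
    p = sigmoid(w_d * (w_g * z))
    grad_w_g = (p - target) * w_d * z  # chain rule through the frozen discriminator
    return w_g - lr * grad_w_g

w_g, w_d = 0.5, -1.0   # hypothetical scalar weights
z, target = 1.0, 1.0   # one latent sample and the "real" target

loss_before = gan_loss(w_g, w_d, z, target)
w_g_new = generator_step(w_g, w_d, z, target)
loss_after = gan_loss(w_g_new, w_d, z, target)

print(loss_before, loss_after)
```

The loss drops after the step even though the discriminator weight never changes: the generator alone moves toward outputs that the fixed discriminator scores as "real".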

In [4]:
# Set discriminator weights to non-trainable
# (will only apply to the `gan` model)
discriminator.trainable = False

gan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input, gan_output)

gan_optimizer = keras.optimizers.RMSprop(lr=0.0004, clipvalue=1.0, decay=1e-8)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')

## How to train your DCGAN

Now we can start training. To recapitulate, this is schematically what the training loop looks like:

```
for each training step:
    * Draw random points in the latent space (random noise).
    * Generate images with `generator` using this random noise.
    * Mix the generated images with real ones.
    * Train `discriminator` using these mixed images, with corresponding targets, either "real" (for the real images) or "fake" (for the generated images).
    * Draw new random points in the latent space.
    * Train `gan` using these random vectors, with targets that all say "these are real images". This will update the weights of the generator (only, since discriminator is frozen inside `gan`) to move them towards getting the discriminator to predict "these are real images" for generated images, i.e. this trains the generator to fool the discriminator.
```
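
The label assembly for the discriminator step can be sketched without Keras. This assumes the encoding used in the training cell of this notebook (1 for generated images, 0 for real images, with generated images first in the batch); the function name and the `rng` parameter are hypothetical.

```python
import random

def make_discriminator_labels(batch_size, noise_scale=0.05, rng=random):
    """Labels for a batch laid out as [generated..., real...], with small positive noise."""
    hard_labels = [1.0] * batch_size + [0.0] * batch_size
    return [label + noise_scale * rng.random() for label in hard_labels]

# A seeded generator keeps the example reproducible.
labels = make_discriminator_labels(batch_size=4, rng=random.Random(0))
print(labels)
```

Each label is shifted by a random amount in `[0, 0.05)`, so the discriminator never sees exactly the same hard targets twice.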

Let's implement it:

In [5]:
import os
from tensorflow.python.keras.preprocessing import image

# Load CIFAR10 data
(x_train, y_train), (_, _) = keras.datasets.cifar10.load_data()

# Select frog images (class 6)
x_train = x_train[y_train.flatten() == 6]

# Normalize data
x_train = x_train.reshape(
    (x_train.shape[0],) + (height, width, channels)).astype('float32') / 255.

iterations = 10000
batch_size = 20
save_dir = '/home/claro/tf-fchollet/gan_images/'

# Start training loop
start = 0
for step in range(iterations):
    # Sample random points in the latent space
    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))

    # Decode them to fake images
    generated_images = generator.predict(random_latent_vectors)

    # Combine them with real images
    stop = start + batch_size
    real_images = x_train[start: stop]
    combined_images = np.concatenate([generated_images, real_images])

    # Assemble labels discriminating real from fake images
    # (1 for generated images, 0 for real images)
    labels = np.concatenate([np.ones((batch_size, 1)),
                             np.zeros((batch_size, 1))])
    # Add random noise to the labels - important trick!
    labels += 0.05 * np.random.random(labels.shape)

    # Train the discriminator
    d_loss = discriminator.train_on_batch(combined_images, labels)

    # Sample random points in the latent space
    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))

    # Assemble labels that say "all real images" (0 is the "real" label here)
    misleading_targets = np.zeros((batch_size, 1))

    # Train the generator (via the gan model,
    # where the discriminator weights are frozen)
    a_loss = gan.train_on_batch(random_latent_vectors, misleading_targets)
    
    start += batch_size
    if start > len(x_train) - batch_size:
        start = 0

    # Occasionally save / plot
    if step % 100 == 0:
        # Save model weights
        gan.save_weights('gan.h5')

        # Print metrics
        print('discriminator loss at step %s: %s' % (step, d_loss))
        print('adversarial loss at step %s: %s' % (step, a_loss))

        # Save one generated image
        img = image.array_to_img(generated_images[0] * 255., scale=False)
        img.save(os.path.join(save_dir, 'generated_frog' + str(step) + '.png'))

        # Save one real image, for comparison
        img = image.array_to_img(real_images[0] * 255., scale=False)
        img.save(os.path.join(save_dir, 'real_frog' + str(step) + '.png'))

discriminator loss at step 0: 0.697465
adversarial loss at step 0: 0.704368


discriminator loss at step 100: 0.674762
adversarial loss at step 100: 1.06405




discriminator loss at step 200: 0.70435
adversarial loss at step 200: 0.738509




discriminator loss at step 300: 0.704424
adversarial loss at step 300: 0.757434




discriminator loss at step 400: 0.678159
adversarial loss at step 400: 0.758902




discriminator loss at step 500: 0.759559
adversarial loss at step 500: 0.787402




discriminator loss at step 600: 0.688237
adversarial loss at step 600: 0.731752




discriminator loss at step 700: 0.701401
adversarial loss at step 700: 0.752684




discriminator loss at step 800: 0.690654
adversarial loss at step 800: 0.737324




discriminator loss at step 900: 0.704329
adversarial loss at step 900: 0.760033




discriminator loss at step 1000: 0.687321
adversarial loss at step 1000: 0.742191




discriminator loss at step 1100: 0.709247
adversarial loss at step 1100: 0.747085




discriminator loss at step 1200: 0.905955
adversarial loss at step 1200: 0.915063




discriminator loss at step 1300: 0.706997
adversarial loss at step 1300: 0.746751




discriminator loss at step 1400: 0.702261
adversarial loss at step 1400: 0.748807




discriminator loss at step 1500: 0.711617
adversarial loss at step 1500: 0.732273




discriminator loss at step 1600: 0.695692
adversarial loss at step 1600: 0.765894




discriminator loss at step 1700: 0.705757
adversarial loss at step 1700: 0.725899


discriminator loss at step 1800: 0.687978
adversarial loss at step 1800: 0.762718




discriminator loss at step 1900: 0.690869
adversarial loss at step 1900: 0.753191




discriminator loss at step 2000: 0.698868
adversarial loss at step 2000: 0.712694




discriminator loss at step 2100: 0.690721
adversarial loss at step 2100: 0.78837




discriminator loss at step 2200: 0.686589
adversarial loss at step 2200: 0.7416




discriminator loss at step 2300: 0.680353
adversarial loss at step 2300: 0.758624




discriminator loss at step 2400: 0.685275
adversarial loss at step 2400: 0.718926




discriminator loss at step 2500: 0.692279
adversarial loss at step 2500: 0.768586




discriminator loss at step 2600: 0.680852
adversarial loss at step 2600: 0.7249




discriminator loss at step 2700: 0.705707
adversarial loss at step 2700: 0.764143




discriminator loss at step 2800: 0.993573
adversarial loss at step 2800: 0.760023




discriminator loss at step 2900: 0.696314
adversarial loss at step 2900: 0.755919




discriminator loss at step 3000: 0.759032
adversarial loss at step 3000: 0.749412




discriminator loss at step 3100: 0.696107
adversarial loss at step 3100: 0.749971




discriminator loss at step 3200: 0.679407
adversarial loss at step 3200: 0.778419




discriminator loss at step 3300: 0.697299
adversarial loss at step 3300: 0.839425




discriminator loss at step 3400: 0.696949
adversarial loss at step 3400: 0.765225


discriminator loss at step 3500: 0.700387
adversarial loss at step 3500: 0.714678




discriminator loss at step 3600: 0.702057
adversarial loss at step 3600: 0.812257




discriminator loss at step 3700: 0.749928
adversarial loss at step 3700: 0.808666




discriminator loss at step 3800: 0.724461
adversarial loss at step 3800: 0.852933




discriminator loss at step 3900: 0.676857
adversarial loss at step 3900: 0.996093




discriminator loss at step 4000: 0.701145
adversarial loss at step 4000: 0.737347




discriminator loss at step 4100: 0.684454
adversarial loss at step 4100: 0.710306




discriminator loss at step 4200: 0.701707
adversarial loss at step 4200: 0.794332




discriminator loss at step 4300: 0.695996
adversarial loss at step 4300: 0.758902




discriminator loss at step 4400: 0.684178
adversarial loss at step 4400: 0.764428




discriminator loss at step 4500: 0.694421
adversarial loss at step 4500: 0.708887




discriminator loss at step 4600: 0.693588
adversarial loss at step 4600: 0.731158




discriminator loss at step 4700: 0.681196
adversarial loss at step 4700: 0.748954




discriminator loss at step 4800: 0.688656
adversarial loss at step 4800: 0.706575




discriminator loss at step 4900: 0.694722
adversarial loss at step 4900: 0.716507




discriminator loss at step 5000: 0.672391
adversarial loss at step 5000: 0.671072




discriminator loss at step 5100: 0.67501
adversarial loss at step 5100: 0.762064


discriminator loss at step 5200: 0.696213
adversarial loss at step 5200: 0.783617




discriminator loss at step 5300: 0.698123
adversarial loss at step 5300: 0.751154




discriminator loss at step 5400: 0.694693
adversarial loss at step 5400: 0.75365




discriminator loss at step 5500: 0.694983
adversarial loss at step 5500: 0.717285




discriminator loss at step 5600: 0.688334
adversarial loss at step 5600: 0.762959




discriminator loss at step 5700: 0.709189
adversarial loss at step 5700: 0.770763




discriminator loss at step 5800: 0.683092
adversarial loss at step 5800: 0.794479




discriminator loss at step 5900: 0.693747
adversarial loss at step 5900: 1.02038




discriminator loss at step 6000: 0.697186
adversarial loss at step 6000: 0.741706




discriminator loss at step 6100: 0.698089
adversarial loss at step 6100: 0.758648




discriminator loss at step 6200: 0.730755
adversarial loss at step 6200: 0.83014




discriminator loss at step 6300: 0.688814
adversarial loss at step 6300: 0.752135




discriminator loss at step 6400: 0.719298
adversarial loss at step 6400: 0.849147




discriminator loss at step 6500: 0.6941
adversarial loss at step 6500: 0.728641




discriminator loss at step 6600: 0.749416
adversarial loss at step 6600: 0.783581




discriminator loss at step 6700: 0.749906
adversarial loss at step 6700: 0.963508




discriminator loss at step 6800: 0.663971
adversarial loss at step 6800: 0.80178


discriminator loss at step 6900: 0.696122
adversarial loss at step 6900: 0.868248




discriminator loss at step 7000: 0.682202
adversarial loss at step 7000: 0.834885




discriminator loss at step 7100: 0.716239
adversarial loss at step 7100: 0.918414




discriminator loss at step 7200: 0.757434
adversarial loss at step 7200: 0.873729




discriminator loss at step 7300: 0.718922
adversarial loss at step 7300: 0.746266




discriminator loss at step 7400: 0.709303
adversarial loss at step 7400: 0.95767




discriminator loss at step 7500: 0.756372
adversarial loss at step 7500: 0.958449




discriminator loss at step 7600: 0.668054
adversarial loss at step 7600: 0.760841




discriminator loss at step 7700: 0.726013
adversarial loss at step 7700: 0.772983




discriminator loss at step 7800: 0.797494
adversarial loss at step 7800: 1.17269




discriminator loss at step 7900: 0.665403
adversarial loss at step 7900: 0.703456




discriminator loss at step 8000: 0.665854
adversarial loss at step 8000: 0.88767




discriminator loss at step 8100: 0.661141
adversarial loss at step 8100: 0.700408




discriminator loss at step 8200: 0.730118
adversarial loss at step 8200: 0.726783




discriminator loss at step 8300: 0.699406
adversarial loss at step 8300: 0.908771




discriminator loss at step 8400: 0.718274
adversarial loss at step 8400: 0.772979




discriminator loss at step 8500: 0.680671
adversarial loss at step 8500: 0.72406


discriminator loss at step 8600: 0.706079
adversarial loss at step 8600: 0.828559




discriminator loss at step 8700: 0.707501
adversarial loss at step 8700: 0.666524




discriminator loss at step 8800: 0.650842
adversarial loss at step 8800: 0.758531




discriminator loss at step 8900: 0.727954
adversarial loss at step 8900: 0.916939




discriminator loss at step 9000: 0.633656
adversarial loss at step 9000: 0.90855




discriminator loss at step 9100: 0.74661
adversarial loss at step 9100: 1.00563




discriminator loss at step 9200: 0.726609
adversarial loss at step 9200: 0.803032




discriminator loss at step 9300: 0.656879
adversarial loss at step 9300: 0.78234




discriminator loss at step 9400: 0.706414
adversarial loss at step 9400: 0.733639




discriminator loss at step 9500: 0.729519
adversarial loss at step 9500: 0.695822




discriminator loss at step 9600: 0.719787
adversarial loss at step 9600: 0.667924




discriminator loss at step 9700: 0.70726
adversarial loss at step 9700: 0.734278




discriminator loss at step 9800: 0.711661
adversarial loss at step 9800: 0.841383




discriminator loss at step 9900: 0.774486
adversarial loss at step 9900: 0.899968






Let's display a few of our fake images:

In [6]:
import matplotlib.pyplot as plt

# Sample random points in the latent space
random_latent_vectors = np.random.normal(size=(10, latent_dim))

# Decode them to fake images
generated_images = generator.predict(random_latent_vectors)

for i in range(generated_images.shape[0]):
    img = image.array_to_img(generated_images[i] * 255., scale=False)
    plt.figure()
    plt.imshow(img)
    
plt.show()

[Output: 10 matplotlib figures, one per generated image]

Froggy with some pixellated artifacts.