# Wasserstein GAN - Improved

ref.:
GULRAJANI, Ishaan et al.  
Improved training of wasserstein gans.  
In: Advances in neural information processing systems. 2017. p. 5767-5777.

![new objective](wgan_gp_objective.png)

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from keras.datasets import mnist
from keras.layers.merge import _Merge
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import RMSprop
from functools import partial

import keras.backend as K

Using TensorFlow backend.


In [3]:
# auxiliary class

In [4]:
class RandomWeightedAverage(_Merge):
    """Merge layer that blends two tensors with a random per-sample weight.

    Called on a list ``[real, fake]`` of equally shaped tensors, it returns
    ``alpha * real + (1 - alpha) * fake`` where one ``alpha`` is drawn with
    ``K.random_uniform`` for every sample of the batch and broadcast over
    the remaining axes.
    """
    def _merge_function(self, inputs):
        # Read the batch size from the incoming tensor instead of hard-coding
        # 32, so the layer keeps working when training uses another batch size.
        batch_size = K.shape(inputs[0])[0]
        # One singleton axis per non-batch dimension -> (batch, 1, 1, 1) for
        # image tensors, which lets alpha broadcast over a whole sample.
        broadcast_axes = (1,) * (K.ndim(inputs[0]) - 1)
        alpha = K.random_uniform((batch_size,) + broadcast_axes)
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

## Generator

In [5]:
def build_generator(latent_dim, channels):
    """Build the generator network.

    Args:
        latent_dim: length of the noise vector fed to the generator.
        channels: number of channels of the produced image.

    Returns:
        A Keras ``Model`` mapping a ``(latent_dim,)`` noise input to the
        tanh-activated output of the convolutional stack.
    """
    net = Sequential()

    # Project the noise vector and fold it into a 7x7 map with 128 features.
    net.add(Dense(128 * 7 * 7, activation="relu", input_dim=latent_dim))
    net.add(Reshape((7, 7, 128)))

    # Two upsampling stages that differ only in the number of conv filters.
    for n_filters in (128, 64):
        net.add(UpSampling2D())
        net.add(Conv2D(n_filters, kernel_size=4, padding="same"))
        net.add(BatchNormalization(momentum=0.8))
        net.add(Activation("relu"))

    # Final projection to the requested channel count, squashed by tanh.
    net.add(Conv2D(channels, kernel_size=4, padding="same"))
    net.add(Activation("tanh"))

    net.summary()

    # Wrap the stack in a functional Model with an explicit noise input.
    z = Input(shape=(latent_dim,))
    generated = net(z)
    return Model(z, generated)

## Critic

In [6]:
def build_critic(img_shape):
    """Build the critic network.

    Args:
        img_shape: shape of one input image, passed as ``input_shape`` to the
            first convolution and as ``shape`` to the model input.

    Returns:
        A Keras ``Model`` mapping an image to a single unbounded score
        (the last layer is ``Dense(1)`` with no activation).
    """
    # Layers are instantiated in exactly the order they are stacked.
    stack = [
        # stage 1: strided conv, no batch norm
        Conv2D(16, kernel_size=3, strides=2, input_shape=img_shape, padding="same"),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        # stage 2: strided conv, then pad one row/column at the bottom/right
        Conv2D(32, kernel_size=3, strides=2, padding="same"),
        ZeroPadding2D(padding=((0, 1), (0, 1))),
        BatchNormalization(momentum=0.8),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        # stage 3: strided conv
        Conv2D(64, kernel_size=3, strides=2, padding="same"),
        BatchNormalization(momentum=0.8),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        # stage 4: stride-1 conv
        Conv2D(128, kernel_size=3, strides=1, padding="same"),
        BatchNormalization(momentum=0.8),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        # head: flatten and map to one scalar score
        Flatten(),
        Dense(1),
    ]

    critic = Sequential()
    for layer in stack:
        critic.add(layer)

    critic.summary()

    # Wrap the stack in a functional Model with an explicit image input.
    image = Input(shape=img_shape)
    score = critic(image)
    return Model(image, score)

## helper

In [7]:
def gradient_penalty_loss(y_true, y_pred, averaged_samples):
    """Gradient penalty: batch mean of ``(1 - ||grad||_2) ** 2``.

    Args:
        y_true: unused; present only because Keras losses take
            ``(y_true, y_pred)``.
        y_pred: critic output for the interpolated samples.
        averaged_samples: the interpolated input tensor with respect to which
            the gradient of ``y_pred`` is taken.

    Returns:
        Scalar tensor with the un-weighted penalty. The penalty coefficient
        is not applied here; the compile call weights this loss through
        ``loss_weights``.
    """
    grads = K.gradients(y_pred, averaged_samples)[0]
    # Reduce over every axis except the leading (batch) axis so that each
    # sample ends up with its own Euclidean gradient norm.
    non_batch_axes = np.arange(1, len(grads.shape))
    per_sample_norm = K.sqrt(K.sum(K.square(grads), axis=non_batch_axes))
    # Penalise the squared deviation of each norm from 1, averaged over the batch.
    return K.mean(K.square(1 - per_sample_norm))

In [8]:
def wasserstein_loss(y_true, y_pred):
    """Return the mean of ``y_true * y_pred`` (same loss as in WGAN).

    ``y_true`` acts as a per-sample multiplier on the critic score rather
    than as a target value.
    """
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)

In [9]:
def sample_images(G, latent_dim, epoch):
    """Save a 5x5 grid of generated images to ``images/mnist_<epoch>.png``.

    Args:
        G: object exposing ``predict(noise)``; called with an array of shape
            ``(25, latent_dim)``. Assumes the result is indexable as
            ``[sample, row, col, channel]`` with values in [-1, 1] -- TODO
            confirm against the generator in use.
        latent_dim: length of each noise vector.
        epoch: integer used in the output file name.
    """
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, latent_dim))
    gen_imgs = G.predict(noise)

    # Rescale images 0 - 1
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    # axs.flat walks the grid row by row, matching the sample order.
    for ax, image in zip(axs.flat, gen_imgs):
        ax.imshow(image[:, :, 0], cmap='gray')
        ax.axis('off')

    # Do not depend on another cell having created the output folder.
    out_dir = "images"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    fig.savefig("images/mnist_%d.png" % epoch)
    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)

## train

In [10]:
def train(G, critic_model, generator_model,
          n_critic, latent_dim,
          epochs, batch_size, sample_interval=50):
    """Run the alternating critic / generator training loop on MNIST.

    Args:
        G: generator model; only used here to render sample images.
        critic_model: compiled model trained on ``[images, noise]`` against
            the three targets ``[valid, fake, dummy]``.
        generator_model: compiled model trained on ``noise`` against ``valid``.
        n_critic: number of critic updates per generator update (>= 1).
        latent_dim: length of each noise vector.
        epochs: number of outer iterations (one generator update each).
        batch_size: number of samples per update.
        sample_interval: a sample grid is saved every this many epochs and
            after the final epoch.

    Raises:
        ValueError: if ``n_critic`` is smaller than 1.
    """
    # Without at least one critic step, `noise` and `d_loss` would never be
    # assigned and the generator step below would fail with a NameError.
    if n_critic < 1:
        raise ValueError("n_critic must be >= 1, got %r" % (n_critic,))

    # Load the dataset
    (X_train, _), (_, _) = mnist.load_data()

    # Rescale -1 to 1
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)

    # Adversarial ground truths
    valid = -np.ones((batch_size, 1))
    fake = np.ones((batch_size, 1))
    dummy = np.zeros((batch_size, 1))  # Dummy gt for gradient penalty
    for epoch in range(epochs):

        for _ in range(n_critic):

            # ---------------------
            #  Train Critic
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]
            # Sample generator input
            noise = np.random.normal(0, 1, (batch_size, latent_dim))
            # Train the critic
            d_loss = critic_model.train_on_batch([imgs, noise],
                                                 [valid, fake, dummy])

        # ---------------------
        #  Train Generator
        # ---------------------

        # The generator step reuses the noise batch of the last critic step.
        g_loss = generator_model.train_on_batch(noise, valid)

        # Plot the progress (d_loss[0] is the first value returned for the
        # multi-output critic model -- presumably the combined loss)
        print ("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))

        # If at save interval => save generated image samples
        if epoch % sample_interval == 0 or epoch == epochs - 1:
            sample_images(G, latent_dim, epoch)

# main()

In [11]:
# Output folder for the sample grids written by sample_images();
# created once so that fig.savefig("images/...") has somewhere to write.
if not os.path.exists('images'):
    os.makedirs('images')

In [12]:
# Image geometry (rows, cols, channels) and the size of the generator's
# noise input.
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)
latent_dim = 100

In [13]:
# Number of critic updates performed per generator update in train();
# value chosen as recommended in the paper.
n_critic = 5

In [14]:
# Create the optimizer; this one instance is passed to both
# critic_model.compile() and generator_model.compile().
# NOTE(review): the referenced WGAN-GP paper appears to use Adam in its
# algorithm listing, while RMSprop with lr=5e-5 looks like the original WGAN
# setting -- confirm which one is intended here.
optimizer = RMSprop(lr=0.00005)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [15]:
# Build the generator model (prints its layer summary as a side effect)
G = build_generator(latent_dim, channels)

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 6272)              633472    
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 128)       262272    
_________________________________________________________________
batch_normalization_1 (Batch (None, 14, 14, 128)       512       
_________________________________________________________________
activation_1 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 28, 28, 128)      

In [16]:
# Build the critic model (prints its layer summary as a side effect)
C = build_critic(img_shape)

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_4 (Conv2D)            (None, 14, 14, 16)        160       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 14, 14, 16)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 14, 14, 16)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 7, 7, 32)          4640      
_________________________________________________________________
zero_padding2d_1 (ZeroPaddin (None, 8, 8, 32)          0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 8, 8, 32)          128       
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 8, 8, 32)         

In [17]:
#-------------------------------
# Construct Computational Graph
#       for the Critic
#-------------------------------

In [18]:
# Freeze generator's layers while training critic: the flag is set before
# critic_model is built and compiled, and flipped back before the
# generator graph is built.
G.trainable = False

In [19]:
# Image input (real sample); first input of critic_model
real_img = Input(shape=img_shape)

In [20]:
# Noise input; second input of critic_model
z_disc = Input(shape=(latent_dim,))

# Generate image based on noise (fake sample)
fake_img = G(z_disc)

In [21]:
# Critic scores for the generated and for the real images.
# NOTE: these module-level tensors share their names with the label arrays
# `valid` / `fake` that train() creates locally; the two sets are unrelated.
# `valid` is also rebound later when the generator graph is built.
fake = C(fake_img)
valid = C(real_img)

In [22]:
# Construct weighted average between real and fake images; this tensor is
# what the gradient penalty differentiates the critic output against.
interpolated_img = RandomWeightedAverage()([real_img, fake_img])

In [23]:
# Critic score of the interpolated sample (third output of critic_model)
validity_interpolated = C(interpolated_img)

In [24]:
# Critic training model: takes a real image batch and a noise batch and
# yields the critic score for the real, the generated and the interpolated
# images, in that order.
critic_model = Model(inputs=[real_img, z_disc],
                     outputs=[valid, fake, validity_interpolated])

# Keras calls a loss as loss(y_true, y_pred); bind the extra
# 'averaged_samples' argument up front with functools.partial.
partial_gp_loss = partial(gradient_penalty_loss, averaged_samples=interpolated_img)
partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names

# One loss per output; the gradient penalty is weighted by 10.
critic_model.compile(optimizer=optimizer,
                     loss=[wasserstein_loss, wasserstein_loss, partial_gp_loss],
                     loss_weights=[1, 1, 10])

In [25]:
#-------------------------------
# Construct Computational Graph
#         for Generator
#-------------------------------

In [26]:
# For the generator we freeze the critic's layers and unfreeze the generator,
# i.e. the opposite of the flags used when critic_model was built.
# NOTE(review): presumably each model keeps the trainable state it had when
# it was compiled, which would explain the "Discrepancy between trainable
# weights and collected trainable" warnings in the training output -- confirm.
C.trainable = False
G.trainable = True  # this is the default

In [27]:
# Sampled noise for input to generator; sole input of generator_model
z_gen = Input(shape=(latent_dim,))

In [28]:
# Generate images based on noise
img = G(z_gen)

In [29]:
# Critic score of the generated images.
# NOTE: this rebinds the module-level name `valid`, which previously held the
# critic score of the real images; critic_model was already built from the
# earlier tensor at this point.
valid = C(img)

In [30]:
# Defines generator model: noise -> generator -> critic score, trained with
# the Wasserstein loss and the same optimizer instance as critic_model.
generator_model = Model(z_gen, valid)
generator_model.compile(loss=wasserstein_loss, optimizer=optimizer)

In [31]:
# train

In [None]:
# Shortened run; the longer setting that was used before: epochs=30000
# NOTE: if the interpolation layer assumes a fixed batch size, the
# batch_size passed here has to agree with it.
epochs=3000
train(G, critic_model, generator_model, 
      n_critic, latent_dim,
      epochs=epochs, batch_size=32, sample_interval=100)

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


  'Discrepancy between trainable weights and collected trainable'





  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 25.070177] [G loss: 0.297224]


  'Discrepancy between trainable weights and collected trainable'


1 [D loss: 15.671294] [G loss: 0.142512]
2 [D loss: 15.337543] [G loss: 0.230281]
3 [D loss: 12.432339] [G loss: 0.056703]
4 [D loss: 10.851179] [G loss: 0.147348]
5 [D loss: 8.952878] [G loss: 0.442589]
6 [D loss: 7.261176] [G loss: -0.013684]
7 [D loss: 6.972205] [G loss: 0.325016]
8 [D loss: 5.035443] [G loss: 0.185003]
9 [D loss: 5.259032] [G loss: 0.004149]
10 [D loss: 4.227239] [G loss: 0.058557]
11 [D loss: 3.524243] [G loss: 0.075562]
12 [D loss: 3.562151] [G loss: 0.075373]
13 [D loss: 1.219465] [G loss: 0.086477]
14 [D loss: 1.879110] [G loss: -0.013041]
15 [D loss: 1.456353] [G loss: 0.004118]
16 [D loss: 1.266511] [G loss: -0.044474]
17 [D loss: 0.833396] [G loss: -0.189196]
18 [D loss: 0.918401] [G loss: -0.054217]
19 [D loss: 0.098476] [G loss: -0.063081]
20 [D loss: 0.282075] [G loss: -0.028315]
21 [D loss: 0.161475] [G loss: -0.138403]
22 [D loss: -0.262619] [G loss: -0.491620]
23 [D loss: -0.266399] [G loss: -0.451789]
24 [D loss: -0.206631] [G loss: -0.407562]
25 [D l

191 [D loss: -1.064137] [G loss: -2.537212]
192 [D loss: -0.683361] [G loss: -2.275415]
193 [D loss: -0.389910] [G loss: -2.385437]
194 [D loss: -0.748865] [G loss: -2.279803]
195 [D loss: -0.814931] [G loss: -2.298585]
196 [D loss: -0.349325] [G loss: -2.511345]
197 [D loss: -0.440804] [G loss: -2.617544]
198 [D loss: -0.169316] [G loss: -2.125013]
199 [D loss: -0.246127] [G loss: -2.563831]
200 [D loss: -0.301772] [G loss: -2.467508]
201 [D loss: -0.017004] [G loss: -2.208903]
202 [D loss: -0.580021] [G loss: -2.741154]
203 [D loss: -0.481001] [G loss: -2.209312]
204 [D loss: -0.145891] [G loss: -2.632309]
205 [D loss: -0.099586] [G loss: -2.275704]
206 [D loss: -0.447367] [G loss: -2.638948]
207 [D loss: -0.253114] [G loss: -2.339246]
208 [D loss: 0.010422] [G loss: -2.480601]
209 [D loss: -0.445769] [G loss: -2.459414]
210 [D loss: -0.323106] [G loss: -2.534847]
211 [D loss: -0.299835] [G loss: -2.672637]
212 [D loss: -0.196829] [G loss: -2.628883]
213 [D loss: -0.516484] [G loss: 

378 [D loss: -0.116739] [G loss: -2.648970]
379 [D loss: 0.010750] [G loss: -2.589654]
380 [D loss: -0.441586] [G loss: -2.460656]
381 [D loss: -0.206813] [G loss: -2.441152]
382 [D loss: 0.281966] [G loss: -2.474371]
383 [D loss: -0.133488] [G loss: -2.311149]
384 [D loss: -0.292996] [G loss: -2.403698]
385 [D loss: -0.067222] [G loss: -2.365760]
386 [D loss: -0.745061] [G loss: -2.979283]
387 [D loss: -0.386559] [G loss: -2.123550]
388 [D loss: -0.137915] [G loss: -2.242775]
389 [D loss: 0.084619] [G loss: -2.664877]
390 [D loss: -0.497101] [G loss: -2.509724]
391 [D loss: -0.264946] [G loss: -2.902325]
392 [D loss: -0.116550] [G loss: -2.418616]
393 [D loss: -0.295383] [G loss: -2.787032]
394 [D loss: -0.089682] [G loss: -2.645774]
395 [D loss: -0.830465] [G loss: -2.395820]
396 [D loss: -0.566966] [G loss: -2.665959]
397 [D loss: -0.239933] [G loss: -2.603440]
398 [D loss: -0.456670] [G loss: -2.405239]
399 [D loss: -0.190577] [G loss: -2.686404]
400 [D loss: -0.003322] [G loss: -2

565 [D loss: -0.297545] [G loss: -2.039172]
566 [D loss: -0.113412] [G loss: -2.333783]
567 [D loss: -0.248245] [G loss: -1.928136]
568 [D loss: -0.206390] [G loss: -1.970325]
569 [D loss: -0.012758] [G loss: -1.821019]
570 [D loss: -0.846002] [G loss: -2.105019]
571 [D loss: -0.063511] [G loss: -1.975502]
572 [D loss: -0.610505] [G loss: -1.660715]
573 [D loss: -0.236640] [G loss: -1.977567]
574 [D loss: -0.347666] [G loss: -1.988389]
575 [D loss: 0.019874] [G loss: -1.950062]
576 [D loss: -0.383859] [G loss: -1.820282]
577 [D loss: -0.378650] [G loss: -1.856554]
578 [D loss: 0.213223] [G loss: -1.944825]
579 [D loss: -0.539018] [G loss: -2.008958]
580 [D loss: -0.522736] [G loss: -2.102289]
581 [D loss: -0.426613] [G loss: -2.002568]
582 [D loss: -0.589562] [G loss: -2.386572]
583 [D loss: 0.024214] [G loss: -2.328669]
584 [D loss: -0.349392] [G loss: -2.131880]
585 [D loss: 0.075837] [G loss: -2.163935]
586 [D loss: -0.381199] [G loss: -1.793773]
587 [D loss: -0.307189] [G loss: -1.

752 [D loss: -1.178323] [G loss: -1.816783]
753 [D loss: -0.934460] [G loss: -1.955165]
754 [D loss: -0.054417] [G loss: -1.774947]
755 [D loss: -1.108276] [G loss: -1.511630]
756 [D loss: -0.133919] [G loss: -1.576594]
757 [D loss: -0.886952] [G loss: -2.077778]
758 [D loss: -0.239481] [G loss: -1.233009]
759 [D loss: -0.465669] [G loss: -1.725977]
760 [D loss: -0.373216] [G loss: -1.939457]
761 [D loss: -0.955613] [G loss: -1.587133]
762 [D loss: -0.172634] [G loss: -1.825816]
763 [D loss: -0.171664] [G loss: -2.080106]
764 [D loss: -0.312583] [G loss: -1.789814]
765 [D loss: -0.087180] [G loss: -1.606876]
766 [D loss: -0.622036] [G loss: -1.540080]
767 [D loss: -0.182765] [G loss: -2.016765]
768 [D loss: 0.049768] [G loss: -1.949809]
769 [D loss: -0.102701] [G loss: -1.811347]
770 [D loss: -0.862979] [G loss: -1.765056]
771 [D loss: -0.153048] [G loss: -2.138021]
772 [D loss: -0.980550] [G loss: -1.665675]
773 [D loss: -0.211273] [G loss: -1.793531]
774 [D loss: -0.725310] [G loss: 