In [28]:
import keras.backend as K
from collections import defaultdict
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Embedding, Multiply, Dropout
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D,Conv2DTranspose
from keras.layers.noise import GaussianNoise
from keras.models import Sequential, Model
from keras.optimizers import Adam, SGD
from keras.backend.common import _EPSILON
from keras.utils.generic_utils import Progbar
import numpy as np

In [29]:
def wasserstein_loss(y_true, y_pred):
    """Wasserstein (critic) loss for one batch: the mean of label * score.

    Unlike a standard GAN discriminator, whose sigmoid output is read as a
    probability in [0, 1], a Wasserstein critic has a linear, unbounded
    output and tries to push its scores for real and generated samples as
    far apart as possible. Labelling the two groups with opposite signs
    (-1 and +1) turns that objective into a plain product, so the loss is
    simply the batch mean of ``y_true * y_pred``.

    Because nothing bounds the product, the returned value can be (and
    frequently will be) negative.

    Args:
        y_true: tensor of signed labels.
        y_pred: tensor of raw critic scores.

    Returns:
        The backend mean of the element-wise product.
    """
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)

In [30]:
def make_generator(latent_size):
    """Build the class-conditional generator.

    The returned model takes two inputs, ``[latent, image_class]``: a latent
    vector of length ``latent_size`` and an int32 class index of shape (1,).
    The class index is embedded into a vector of the same length and
    multiplied element-wise with the latent vector before being decoded into
    a single-channel image by a dense + transposed-convolution stack.

    The intermediate reshape is (128, 7, 7), i.e. channels come first, so the
    backend image ordering is expected to be channels-first ('th') when this
    function is called.

    Args:
        latent_size: length of the latent vector (and of the class embedding).

    Returns:
        A Keras ``Model`` mapping ``[latent, image_class]`` to the generated
        image tensor.
    """
    decoder = Sequential([
        Dense(1024, input_dim=latent_size),
        LeakyReLU(),
        Dense(128 * 7 * 7),
        LeakyReLU(),
        Reshape((128, 7, 7), input_shape=(128 * 7 * 7,)),
        # upsample to (..., 14, 14)
        Conv2DTranspose(128, (5, 5), strides=2, padding='same'),
        LeakyReLU(),
        Conv2D(64, (5, 5), padding='same'),
        LeakyReLU(),
        # upsample to (..., 28, 28)
        Conv2DTranspose(64, (5, 5), strides=2, padding='same'),
        LeakyReLU(),
        # Training images are scaled to [-1, 1], so the output layer uses
        # tanh to keep generated pixels in that same range.
        Conv2D(1, (5, 5), padding='same', activation='tanh'),
    ])

    # z-space input
    latent = Input(shape=(latent_size, ))

    # class label to condition on
    image_class = Input(shape=(1,), dtype='int32')

    # one learned embedding vector per MNIST class (10 of them)
    flatten = Flatten()
    embedded = Embedding(10, latent_size,
                         embeddings_initializer='glorot_uniform')(image_class)
    cls = flatten(embedded)

    # hadamard product between z-space and the class-conditional embedding
    conditioned = Multiply()([latent, cls])

    return Model(inputs=[latent, image_class], outputs=decoder(conditioned))

In [31]:
def make_discriminator():
    """Build the two-headed discriminator (critic + auxiliary classifier).

    A stack of three strided 5x5 convolutions with LeakyReLU activations
    extracts a flat feature vector from a (1, 28, 28) image. Two dense heads
    read that shared feature vector:

    * ``generation`` - a single linear unit (no activation), the critic score
      used with the Wasserstein loss;
    * ``auxiliary``  - a 10-way softmax, the predicted digit class.

    Returns:
        A Keras ``Model`` mapping an image to ``[generation, auxiliary]``.
    """
    feature_extractor = Sequential([
        Conv2D(64, (5, 5), padding='same', strides=(2, 2),
               input_shape=(1, 28, 28)),
        LeakyReLU(),
        Conv2D(128, (5, 5), kernel_initializer='he_normal', padding='same',
               strides=(2, 2)),
        LeakyReLU(),
        Conv2D(128, (5, 5), kernel_initializer='he_normal', padding='same',
               strides=(2, 2)),
        LeakyReLU(),
        Flatten(),
    ])

    image = Input(shape=(1, 28, 28))
    features = feature_extractor(image)

    # Head 1: unbounded real/fake score. Head 2: class probabilities.
    critic_score = Dense(1, activation='linear', name='generation')(features)
    class_probs = Dense(10, activation='softmax', name='auxiliary')(features)

    return Model(inputs=image, outputs=[critic_score, class_probs])


In [32]:
# Training schedule. The batch size and latent dimensionality follow the paper.
latent_size = 100
batch_size = 100
nb_epochs = 50

# Adam settings recommended in https://arxiv.org/abs/1511.06434
adam_beta_1 = 0.5
adam_lr = 2e-4


In [33]:
# Switch Keras to channels-first ('th') image ordering; the models below are
# declared with channels-first shapes such as (1, 28, 28) and (128, 7, 7).
K.set_image_dim_ordering('th')
# build the discriminator
# The loss list is positional: wasserstein_loss applies to the first output
# ('generation'), sparse categorical cross-entropy to the second ('auxiliary').
discriminator = make_discriminator()
discriminator.compile(optimizer= Adam(lr=adam_lr, beta_1=adam_beta_1),
loss=[wasserstein_loss, 'sparse_categorical_crossentropy'])
    
    
# build the generator
# NOTE(review): in this notebook the generator is only called through
# predict() and as part of `combined`, so the loss given here looks unused -
# confirm before relying on it.
generator = make_generator(latent_size)
generator.compile(optimizer=Adam(lr=adam_lr, beta_1=adam_beta_1),
                      loss='binary_crossentropy')

# Inputs of the stacked generator -> discriminator model.
latent = Input(shape=(latent_size, ))
image_class = Input(shape=(1,), dtype='int32')
    
# get a fake image
fake = generator([latent, image_class])
# we only want to be able to train generation for the combined model
# The discriminator was already compiled above, before this flag is flipped.
# NOTE(review): presumably Keras fixes the trainable weights at compile time,
# so `discriminator.train_on_batch` keeps updating them while `combined`
# treats them as frozen - verify against the Keras version in use.
discriminator.trainable = False
# `fake` is rebound here: from the generated image to the critic score.
fake, aux = discriminator(fake)
combined = Model(inputs=[latent, image_class], outputs=[fake, aux])

# Note that this model is compiled with the 'RMSprop' optimizer string, not
# the Adam settings used for the two models above.
combined.compile(optimizer='RMSprop', loss=[wasserstein_loss, 'sparse_categorical_crossentropy'])


In [34]:
# Load MNIST and bring the images of BOTH splits to shape (..., 1, 28, 28)
# with values in [-1, 1] (matching the generator's tanh output range).
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Pixel values 0..255 -> [-1, 1], then insert the channel axis at position 1.
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=1)

# Apply the identical transform to the test split so it is never left as raw
# uint8 (N, 28, 28) data that the models could not consume.
X_test = (X_test.astype(np.float32) - 127.5) / 127.5
X_test = np.expand_dims(X_test, axis=1)

nb_train, nb_test = X_train.shape[0], X_test.shape[0]


In [35]:
from PIL import Image
def generate_images(epoch):
    """Generate a grid of sample digits and save it as a PNG file.

    Draws 100 latent vectors, pairs them with ten copies of each label 0-9,
    runs them through the module-level ``generator`` and writes the resulting
    grid to ``plot_epoch_<epoch>_generated.png`` in the working directory.

    Args:
        epoch: integer used (zero-padded to three digits) in the file name.
    """
    # Sample from the same N(0, 1) prior that the training loop feeds to the
    # generator; a different mean here would show images from a latent
    # distribution the generator was never trained on.
    noise = np.random.normal(0, 1, (100, latent_size))

    # Labels 0,0,...,0,1,1,...,9: ten consecutive samples per class.
    sampled_labels = np.array([[i] * 10 for i in range(10)]).reshape(-1, 1)

    # get a batch to display
    generated_images = generator.predict([noise, sampled_labels], verbose=0)

    # Arrange them into a grid: each of the 10 chunks (one per label) is
    # stacked into a tall strip of 28-pixel-wide rows, and the strips are
    # placed side by side. Pixels are mapped from [-1, 1] back to 0..255.
    strips = [chunk.reshape(-1, 28) for chunk in np.split(generated_images, 10)]
    img = (np.concatenate(strips, axis=-1) * 127.5 + 127.5).astype(np.uint8)

    Image.fromarray(img).save('plot_epoch_{0:03d}_generated.png'.format(epoch))


In [36]:
# Per-epoch training history. Each entry is the mean, over the epoch's
# batches, of the loss list returned by train_on_batch for that model.
discriminator_train_loss = []
generator_train_loss = []

# Number of full batches per epoch; a trailing partial batch is dropped.
nb_batches = X_train.shape[0] // batch_size

# NOTE(review): no Lipschitz constraint (weight clipping or gradient penalty)
# is applied to the discriminator anywhere in this loop, although it is
# trained with a Wasserstein loss - confirm whether that is intentional.

for epoch in range(nb_epochs):

    print('Epoch {} of {}'.format(epoch + 1, nb_epochs))
    # Keras progress bar
    progress_bar = Progbar(target=nb_batches)

    gen_loss = []
    dis_loss = []

    for index in range(nb_batches):
        # generate a new batch of noise
        noise = np.random.normal(0, 1, (batch_size, latent_size))

        # get a batch of real images
        image_batch = X_train[index * batch_size:(index + 1) * batch_size]
        label_batch = y_train[index * batch_size:(index + 1) * batch_size]

        # sample some labels from p_c
        sampled_labels = np.random.randint(0, 10, batch_size)

        # generate a batch of fake images, using the generated labels as a
        # conditioner.
        generated_images = generator.predict(
            [noise, sampled_labels.reshape((-1, 1))], verbose=0)

        # Real images are labelled -1 and generated images +1 for the
        # Wasserstein output; class labels feed the auxiliary output.
        X = np.concatenate((image_batch, generated_images))
        y = np.array([-1] * batch_size + [1] * batch_size)
        aux_y = np.concatenate((label_batch, sampled_labels), axis=0)

        # Training the discriminator network
        dis_loss.append(discriminator.train_on_batch(X, [y, aux_y]))

        # make new noise. we generate 2 * batch size here such that we have
        # the generator optimize over an identical number of images as the
        # discriminator
        noise = np.random.normal(0, 1, (2 * batch_size, latent_size))
        sampled_labels = np.random.randint(0, 10, 2 * batch_size)

        # we want to train the generator to trick the discriminator:
        # every generated sample gets the "real" label (-1).
        trick = -np.ones(2 * batch_size)

        gen_loss.append(combined.train_on_batch(
            [noise, sampled_labels.reshape((-1, 1))], [trick, sampled_labels]))

        # Report AFTER the step so the bar reaches its target on the last
        # batch. Progbar averages the values it is given, so pass the
        # per-step total loss (element 0) rather than a running mean.
        progress_bar.update(index + 1, values=[('dis_loss', dis_loss[-1][0]),
                                               ('gen_loss', gen_loss[-1][0])])

    dis_epoch_mean = np.mean(np.array(dis_loss), axis=0)
    gen_epoch_mean = np.mean(np.array(gen_loss), axis=0)
    discriminator_train_loss.append(dis_epoch_mean)
    generator_train_loss.append(gen_epoch_mean)

    # No evaluation on held-out data happens here, so report what was
    # actually measured: the mean training loss for this epoch.
    print('\nEpoch {} mean training loss - dis_loss: {:.4f} - gen_loss: {:.4f}'.format(
        epoch + 1, dis_epoch_mean[0], gen_epoch_mean[0]))

    # save weights every epoch (second argument: overwrite existing files)
    generator.save_weights(
        'params_generator_epoch_{0:03d}.hdf5'.format(epoch), True)
    discriminator.save_weights(
        'params_discriminator_epoch_{0:03d}.hdf5'.format(epoch), True)

    # function to generate image at every epoch!
    generate_images(epoch)
        
    

Epoch 1 of 50
  3/600 [..............................] - ETA: 6401s - dis_loss: 2.2597 - gen_loss: 2.1866

KeyboardInterrupt: 