### WGANomaly --  An Improved Wasserstein GAN with an Encoder reversing the Generator for anomaly detection

The Improved Wasserstein GAN algorithm can be found in the paper: https://arxiv.org/abs/1704.00028 

A documented Keras implementation of the Improved WGAN can be found in the Keras community contributions repository: https://github.com/keras-team/keras-contrib/blob/master/examples/improved_wgan.py
Similar code is used here, together with an additional encoder model that inverts the generator.

Here we apply anomaly detection to the MNIST dataset. In this notebook the digit 0 is considered abnormal and all other digits are considered normal.

First, we train the Improved Wasserstein GAN on digits from 1 to 9. Then an encoder $E$ stacked with the generator $G$ is trained. The reconstruction error $||x - G(E(x))||_2$ of a test digit serves as the anomaly detection score.


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Model
from keras.layers import Input, Dense, Reshape, Flatten, Activation, Dropout, Conv2DTranspose, BatchNormalization
from keras.layers.merge import _Merge
from keras.layers.convolutional import Convolution2D, Conv2DTranspose
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam
from keras.datasets import mnist
from keras import backend as K
from functools import partial
import keras.backend as K

In [None]:
BATCH_SIZE_WGAN = 64 # batch size used for both the critic and the generator updates of the WGAN
BATCH_SIZE_ENC_GEN = 64 # batch size used when training the WGANomaly encoder model
TRAINING_RATIO = 5 # number of critic updates performed for each generator update
GRADIENT_PENALTY_WEIGHT = 10 # weight of the gradient penalty term in the improved WGAN critic loss
LATENT_SPACE_DIM = 128 # dimension of the latent vector fed to the generator
WGAN_EPOCHS = 30 # number of passes of the WGAN training loop
WGANomaly_EPOCHS = 15 # number of passes of the encoder (WGANomaly) training loop

### Dataset loading

In [None]:
# MNIST dataset: add a trailing channel axis and scale the pixels to [0, 1).
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)
# Both splits must share the same scaling: test images are later passed
# through the encoder/generator and concatenated with training zeros when the
# anomaly scores are computed, so leaving X_test as raw uint8 (0..255) would
# make those scores incomparable.
X_train = X_train.astype(np.float32) / 256
X_test = X_test.astype(np.float32) / 256

# MNIST without the 0 digit, used for WGAN (and encoder) training:
X_train_0 = X_train[y_train != 0]

In [None]:
# Targets fed to the critic/generator models, one row per sample of a WGAN batch.
true_label = np.ones(shape=(BATCH_SIZE_WGAN, 1), dtype=np.float32)
fake_label = np.negative(true_label)
# Placeholder target for the gradient-penalty output (that loss ignores y_true).
dummy_y = np.zeros_like(true_label)

### Networks used and models

In [None]:
# Generator of the GAN
def make_gen():
    """Build the generator network.

    A latent vector of size ``LATENT_SPACE_DIM`` goes through two dense
    layers, is reshaped to a 7x7x128 feature map and is upsampled by two
    stride-2 transposed convolutions. The last convolution has a single
    filter with a sigmoid activation.

    Returns:
        An uncompiled Keras ``Model`` mapping the latent input to the
        generated image tensor.
    """
    latent = Input(shape=(LATENT_SPACE_DIM,))
    # Layers are listed (and therefore created) in the order they are applied.
    stack = [
        Dense(1024),
        LeakyReLU(),
        Dense(128 * 7 * 7),
        BatchNormalization(),
        LeakyReLU(),
        Reshape((7, 7, 128)),
        Conv2DTranspose(128, (5, 5), strides=2, padding='same'),
        BatchNormalization(),
        LeakyReLU(),
        Convolution2D(64, (5, 5), padding='same'),
        BatchNormalization(),
        LeakyReLU(),
        Conv2DTranspose(64, (5, 5), strides=2, padding='same'),
        BatchNormalization(),
        LeakyReLU(),
        Convolution2D(1, (5, 5), padding='same', activation='sigmoid'),
    ]
    tensor = latent
    for layer in stack:
        tensor = layer(tensor)
    return Model(latent, tensor)


gen = make_gen()
gen.summary()

# Critic (discriminator) of the GAN
def make_critic():
    """Build the critic network.

    Three 5x5 convolutions (the last two with stride 2) followed by a
    1024-unit dense layer, each with a LeakyReLU, and a final single-unit
    dense layer without activation that outputs the critic score.

    Returns:
        An uncompiled Keras ``Model`` mapping a (28, 28, 1) image to one score.
    """
    image = Input(shape=(28, 28, 1))
    # Layers are listed (and therefore created) in the order they are applied.
    stack = [
        Convolution2D(64, (5, 5), padding='same'),
        LeakyReLU(),
        Convolution2D(128, (5, 5), kernel_initializer='he_normal', strides=[2, 2]),
        LeakyReLU(),
        Convolution2D(128, (5, 5), kernel_initializer='he_normal', padding='same', strides=[2, 2]),
        LeakyReLU(),
        Flatten(),
        Dense(1024, kernel_initializer='he_normal'),
        LeakyReLU(),
        Dense(1, kernel_initializer='he_normal'),
    ]
    tensor = image
    for layer in stack:
        tensor = layer(tensor)
    return Model(image, tensor)


critic = make_critic()
critic.summary()

# Encoder for WGANomaly
def make_encoder():
    """Build the encoder that maps an image back to the latent space.

    The convolutional/dense trunk uses the same layer configuration as the
    critic; the head is a dense layer of size ``LATENT_SPACE_DIM`` followed
    by batch normalization.

    Returns:
        An uncompiled Keras ``Model`` mapping a (28, 28, 1) image to a
        latent vector of size ``LATENT_SPACE_DIM``.
    """
    image = Input(shape=(28, 28, 1))
    # Layers are listed (and therefore created) in the order they are applied.
    stack = [
        Convolution2D(64, (5, 5), padding='same'),
        LeakyReLU(),
        Convolution2D(128, (5, 5), kernel_initializer='he_normal', strides=[2, 2]),
        LeakyReLU(),
        Convolution2D(128, (5, 5), kernel_initializer='he_normal', padding='same', strides=[2, 2]),
        LeakyReLU(),
        Flatten(),
        Dense(1024, kernel_initializer='he_normal'),
        LeakyReLU(),
        Dense(LATENT_SPACE_DIM, kernel_initializer='he_normal'),
        BatchNormalization(),
    ]
    tensor = image
    for layer in stack:
        tensor = layer(tensor)
    return Model(image, tensor)


encoder = make_encoder()
encoder.summary()

In [None]:
#Keras improved wasserstein loss from keras-contrib
def wasserstein_loss(y_true, y_pred):
    """Return the mean of the element-wise product of targets and predictions.

    With targets of +1 / -1 this averages the critic scores with the sign
    chosen by the label.
    """
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)

def gradient_penalty_loss(y_true, y_pred, averaged_samples, gradient_penalty_weight):
    """Gradient penalty term of the improved WGAN critic loss.

    Args:
        y_true: unused; present only to match the Keras loss signature.
        y_pred: critic output evaluated on ``averaged_samples``.
        averaged_samples: tensor with respect to which the gradient of
            ``y_pred`` is taken.
        gradient_penalty_weight: multiplier applied to the penalty.

    Returns:
        The batch mean of ``weight * (1 - ||grad||_2) ** 2``.
    """
    grads = K.gradients(y_pred, averaged_samples)[0]
    # Reduce over every axis except the first (batch) one.
    per_sample_axes = np.arange(1, len(grads.shape))
    # Per-sample Euclidean norm of the gradient.
    grad_norm = K.sqrt(K.sum(K.square(grads), axis=per_sample_axes))
    # Penalise the squared distance of each norm from 1.
    penalty_per_sample = gradient_penalty_weight * K.square(1 - grad_norm)
    return K.mean(penalty_per_sample)

class RandomWeightedAverage(_Merge):
    """Merge layer returning a random convex combination of its two inputs.

    One uniform weight is drawn per sample; the weight tensor has a fixed
    leading dimension of ``BATCH_SIZE_WGAN``.
    """

    def _merge_function(self, inputs):
        first = inputs[0]
        second = inputs[1]
        # Shape (batch, 1, 1, 1) so each weight broadcasts over a whole image.
        alpha = K.random_uniform((BATCH_SIZE_WGAN, 1, 1, 1))
        return (alpha * first) + ((1 - alpha) * second)

In [None]:

# Generator training model: latent vector -> gen -> critic score.
# The trainable flags are set before compile() so that only the generator is
# meant to be updated through gen_model.
# NOTE(review): this relies on Keras capturing the trainable flags at compile
# time - confirm for the Keras version in use.
gen.trainable = True
critic.trainable = False # only the generator should be trained through gen_model
input_gen_model = Input(shape=(LATENT_SPACE_DIM,))
output_gen_model = critic(gen(input_gen_model))
gen_model = Model(inputs=[input_gen_model], outputs=[output_gen_model])
gen_model.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9), loss=wasserstein_loss)

# Critic training model: the flags are flipped before the second compile().
critic.trainable = True # only the critic should be trained through critic_model
gen.trainable = False

# Inputs: a batch of real images and a batch of latent vectors.
input_true = Input(shape=X_train[0].shape)
input_latent = Input(shape=(LATENT_SPACE_DIM,))
input_false = gen(input_latent)
averaged_inputs = RandomWeightedAverage()([input_true, input_false]) # random real/fake mix used for the gradient penalty

# The same critic is evaluated on generated, real and averaged images.
output_false = critic(input_false)
output_true = critic(input_true)
output_average = critic(averaged_inputs)

# Bind the extra arguments so the penalty has the (y_true, y_pred) loss signature.
partial_gp_loss = partial(gradient_penalty_loss,averaged_samples=averaged_inputs,gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
# functools.partial objects have no __name__; presumably Keras needs one to label the loss.
partial_gp_loss.__name__ = 'gradient_penalty' 

# Outputs and losses are paired by position: real, generated, averaged.
critic_model = Model(inputs=[input_true, input_latent],outputs=[output_true, output_false,output_average])
critic_model.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9),loss=[wasserstein_loss,wasserstein_loss,partial_gp_loss])


In [None]:
# Print the summary of the compiled critic training model.
critic_model.summary()

### Improved WGAN training

In [None]:
def minibatch_train(Xtrain,bacth_size):
    """Return a random mini-batch of ``bacth_size`` rows drawn from ``Xtrain``.

    Row indices are sampled uniformly with replacement, so the same row can
    appear more than once in a batch.
    """
    n_samples = len(Xtrain)
    picked = np.random.randint(low=0, high=n_samples, size=bacth_size)
    return Xtrain[picked]

In [None]:
# Training of the improved WGAN.
# Each step performs TRAINING_RATIO critic updates followed by one generator
# update; the number of steps per epoch is the number of full batches that fit
# in the zero-free training set.
for epoch in range(WGAN_EPOCHS):
    print("Epoch: ", epoch)
    critic_loss = []
    generator_loss = []
    for i in range(X_train_0.shape[0] // BATCH_SIZE_WGAN):
        for j in range(TRAINING_RATIO):
            train_images = minibatch_train(X_train_0, BATCH_SIZE_WGAN)
            # One vectorized draw of the whole (batch, latent) noise matrix
            # instead of a Python-level loop of per-row draws.
            noise_critic = np.random.normal(
                0, 1, (BATCH_SIZE_WGAN, LATENT_SPACE_DIM)).astype(np.float32)
            # Targets are paired with the critic model outputs by position:
            # real images, generated images, averaged images (penalty).
            critic_loss.append(critic_model.train_on_batch(
                [train_images, noise_critic],
                [true_label, fake_label, dummy_y]))
        noise_gen = np.random.normal(
            0, 1, (BATCH_SIZE_WGAN, LATENT_SPACE_DIM)).astype(np.float32)
        # The generator is trained against the "true" label on the critic score.
        generator_loss.append(gen_model.train_on_batch(noise_gen, true_label))
        print("critic loss:",critic_loss[-1], end = '  ')
        print("generator_loss:",generator_loss[-1])

### WGANomaly model

In [None]:
# Encoder model for WGANomaly: image -> encoder -> generator -> reconstruction.
# gen2 is a fresh generator built by make_gen(); its weights are overwritten
# with those of the trained generator `gen` further down via set_weights().
gen2= make_gen()
# Flags are set before compile(): the encoder is trainable, the generator copy is frozen.
encoder.trainable = True
gen2.trainable = False

Input_encoder = Input(shape=X_train_0.shape[1:])
Output_encoder_gen = gen2(encoder(Input_encoder))

# The stacked model is trained to reproduce its input with a mean squared error loss.
encoder_model = Model(Input_encoder,Output_encoder_gen)
encoder_model.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9),loss='mean_squared_error')
# Copy the trained generator weights into the generator used inside encoder_model.
gen2.set_weights(gen.get_weights())
gen2.summary()

### WGANomaly training

In [None]:
# Training of WGANomaly after the convergence of the generator.
# The encoder model is fitted to reconstruct its own input, so each batch is
# used both as input and as target.
for epoch in range(WGANomaly_EPOCHS):
    # Print the index of this loop (the original printed a stale variable left
    # over from the WGAN training loop).
    print("Epoch: ", epoch)
    WGANomal_loss = []
    # Batches are drawn from X_train_0, so the number of steps per epoch is
    # derived from that same zero-free set.
    for ii in range(len(X_train_0) // BATCH_SIZE_ENC_GEN):
        images_batch = minibatch_train(X_train_0, BATCH_SIZE_ENC_GEN)
        WGANomal_loss.append(encoder_model.train_on_batch(images_batch, images_batch))
        print("WGANomaly loss:", WGANomal_loss[-1])

### WGANomaly scores
The higher the score, the higher the probability that the sample is abnormal.

In [None]:
# Abnormal evaluation set: every 0 digit from both splits (none were used for training).
zeros_test = np.concatenate([X_train[y_train == 0], X_test[y_test == 0]])
# Normal evaluation set: test digits other than 0.
non_zeros_test = X_test[np.not_equal(y_test, 0)]

In [None]:
from sklearn.metrics import mean_squared_error


def _reconstruction_scores(model, samples):
    """Return the per-image mean squared reconstruction error as a list.

    Args:
        model: model whose ``predict`` maps images to their reconstructions.
        samples: array of images reshapeable to (n, 28, 28, 1).

    Returns:
        One score per image, in the order of ``samples``.
    """
    if len(samples) == 0:
        return []
    # One batched predict call instead of one call per image.
    reconstructions = model.predict(samples.reshape(-1, 28, 28, 1))
    return [
        mean_squared_error(recon.reshape(28, 28), sample.reshape(28, 28))
        for recon, sample in zip(reconstructions, samples)
    ]


# Scores of the abnormal (digit 0) and normal (other digits) evaluation sets.
abnormal_samples_score = _reconstruction_scores(encoder_model, zeros_test)
normal_samples_score = _reconstruction_scores(encoder_model, non_zeros_test)
# The original initialised `normal_sample_score` but appended to
# `normal_samples_score` (a NameError); both names now refer to the same list.
normal_sample_score = normal_samples_score