### GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training
Paper: https://arxiv.org/abs/1805.06725

In [None]:
from importlib import reload
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import models.ganomaly as mg
import datasets.common as cds
import datasets.mvtec_ad as mvds
import utils.datasets as dsu
import utils.plot as pu

# Sentinel telling tf.data to tune parallelism / buffer sizes dynamically at
# runtime; passed as num_parallel_calls in the pipeline cells below.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
# Re-import the project modules so that edits made while the notebook is
# running take effect. Every project module imported above is listed,
# including mvds (datasets.mvtec_ad), so changes to the MVTec loader are
# picked up too. importlib.reload re-executes the module in place, so the
# existing aliases (mg, cds, ...) keep pointing at the refreshed modules.
for mod in (mg, cds, mvds, dsu, pu):
    reload(mod)

## Legacy datasets

In [None]:
# Build the legacy anomaly-detection split from the 'mnist' dataset, with
# class 2 designated as the abnormal class.
mnist = cds.get_dataset('mnist')
train_split, test_split = dsu.create_anomaly_dataset(mnist, abnormal_class=2)
(train_images, train_labels) = train_split
(test_images, test_labels) = test_split
print(train_images.shape, train_labels.shape, test_images.shape, test_labels.shape)

In [None]:
# Normalise the legacy images to a power-of-two side length: anything larger
# than 64 is shrunk to 64x64; any other side length that is not a power of two
# is rounded up to the next power of two (e.g. 28 -> 32).
# NOTE(review): only the height (shape[1]) is inspected, so square images
# are assumed.
side = train_images.shape[1]
powers_of_two = {2 ** exponent for exponent in range(10)}
target_size = None
if side > 64:
    target_size = (64, 64)
elif side not in powers_of_two:
    rounded = 1
    while rounded < side:
        rounded *= 2
    target_size = (rounded, rounded)
    print("resizing to:", target_size)

if target_size is not None:
    train_images = tf.image.resize(train_images, target_size)
    test_images = tf.image.resize(test_images, target_size)

# Labels are reshaped into column vectors of shape (N, 1).
train_labels = train_labels.reshape((-1, 1))
test_labels = test_labels.reshape((-1, 1))
print(train_images.shape, train_labels.shape, test_images.shape, test_labels.shape)

In [None]:
# Index of the first abnormal sample in test_labels (going by the helper's
# name). Later cells treat test_images[:abnormal_start] as normal and
# test_images[abnormal_start:] as abnormal.
# NOTE(review): that assumes the test set is ordered with all normal samples
# first; confirm against dsu.create_anomaly_dataset.
abnormal_start = dsu.find_abnormal_start_index(test_labels)

In [None]:
# Show up to 5 test images on each side of the normal/abnormal boundary.
# NOTE(review): if abnormal_start < 5 the negative start index wraps around.
pu.plot_images(test_images[abnormal_start-5:abnormal_start+5])

## MVTec AD Dataset

In [None]:
# MVTec AD categories; `category` below is an index into this list.
categories = [
    'bottle', 'cable', 'capsule', 'carpet',
    'grid', 'hazelnut', 'leather', 'metal_nut',
    'pill', 'screw', 'tile', 'toothbrush',
    'transistor', 'wood', 'zipper'
]

category = 0        # index into `categories`
channels = 1        # image channels requested from the loader
resolution = 64     # images are resized to (resolution, resolution)
buffer_size = 1000  # shuffle / prefetch buffer size, in elements

# A file-backed tf.data cache persists across runs and is reused by file name
# alone: changing the pipeline before .cache() has no effect until the file is
# removed or renamed. Encode every setting that changes the cached content in
# the name so that switching category/channels/resolution cannot silently
# serve stale data from an earlier configuration.
cache_tag = '{}_c{}_r{}'.format(categories[category], channels, resolution)


def resize_image(image, label):
    """Resize `image` to (resolution, resolution) and pass `label` through."""
    return tf.image.resize(image, (resolution, resolution)), label


test_ds = mvds.get_labeled_dataset(
    category=categories[category],
    split='test',
    image_channels=channels,
    binary_labels=True
)
test_ds = test_ds.map(resize_image, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.cache('/tmp/tfdata_test_ds_{}.cache'.format(cache_tag))

train_ds = mvds.get_labeled_dataset(
    category=categories[category],
    split='train',
    image_channels=channels,
    binary_labels=True
)
train_ds = train_ds.map(resize_image, num_parallel_calls=AUTOTUNE)
# Cache only the deterministic resize. The random augmentation below runs
# after the cache, so every pass draws fresh flips and crops.
train_ds = train_ds.cache('/tmp/tfdata_train_ds_{}.cache'.format(cache_tag))
# One pass over train_ds covers the training images 10 times.
train_ds = train_ds.repeat(10)
train_ds = train_ds.shuffle(buffer_size)


def augment_image(image, label):
    """Randomly flip and jitter-crop `image`; pass `label` through.

    Applies random left-right and up-down flips. It then zero-pads centrally
    to (resolution + 6) on each spatial axis, i.e. 3 pixels per side, and
    takes a random (resolution, resolution, channels) crop. The net effect is
    a random shift of up to 3 pixels in each direction.
    """
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.resize_with_crop_or_pad(image, resolution + 6, resolution + 6)
    image = tf.image.random_crop(image, size=[resolution, resolution, channels])
    return image, label


train_ds = train_ds.map(augment_image, num_parallel_calls=AUTOTUNE)
# Prefetch after the augmentation map (not before it), so the augmentation
# work is also overlapped with training.
train_ds = train_ds.prefetch(buffer_size)

# NOTE(review): cardinality() returns -1 (infinite) or -2 (unknown) when
# tf.data cannot determine the size. test_count is handed to the evaluator
# callback later, so confirm it is a real count for this loader.
train_count = tf.data.experimental.cardinality(train_ds).numpy()
test_count = tf.data.experimental.cardinality(test_ds).numpy()
print("train_count: {}, test_count: {}".format(train_count, test_count))

In [None]:
# Refresh models/ganomaly.py, then create, compile and build the GANomaly
# model for the MVTec image size configured above.
reload(mg)
input_shape = (resolution, resolution, channels)
model = mg.GANomaly(input_shape=input_shape, latent_size=300)
model.compile(metrics=[tf.keras.metrics.AUC()])
# Build with an unspecified (None) batch dimension.
model.build((None, *input_shape))

In [None]:
# Reload models/ganomaly.py to pick up edits.
# NOTE(review): `model` was created from the module before this reload, so the
# model and the evaluator may come from different class objects. This is
# presumably harmless unless the callback does type checks; confirm.
reload(mg)
#adcb = mg.ADModelEvaluator(test_images.shape[0])
# Evaluator callback sized from the MVTec test set. Later cells read its
# best_weights and test_results attributes.
adcb = mg.ADModelEvaluator(test_count)

In [None]:
# Train with tf.functions executed as graphs (not eagerly), validating on the
# MVTec test split after each epoch via the evaluator callback.
tf.config.run_functions_eagerly(False)
batch_size = 128
train_batches = train_ds.batch(batch_size)
test_batches = test_ds.batch(batch_size)
results = model.fit(
    x=train_batches,
    validation_data=test_batches,
    epochs=100,
    callbacks=[adcb],
)

In [None]:
# Replace the final-epoch weights with those stored by the evaluator callback
# (presumably the best-scoring epoch).
# NOTE(review): this fails if best_weights was never populated; confirm that
# ADModelEvaluator always sets it.
model.set_weights(adcb.best_weights)

In [None]:
# Metric / loss keys recorded by fit().
print(results.history.keys())


def _plot_curves(curves, title, ylabel, legend=None):
    """Plot every series in `curves` on one figure against epoch, then show it.

    A legend (upper right) is drawn only when `legend` is given.
    """
    for curve in curves:
        plt.plot(curve)
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    if legend is not None:
        plt.legend(legend, loc='upper right')
    plt.show()


history = results.history
_plot_curves([adcb.test_results], 'test results', 'AUC (ROC)')
_plot_curves([history['loss_gen']], 'generator loss', 'loss', ['generator'])
_plot_curves(
    [history['loss_gen_adv'], history['loss_gen_rec'], history['loss_gen_enc']],
    'generator specific losses',
    'loss',
    ['adversarial', 'reconstruction', 'encoder'],
)
_plot_curves(
    [history['loss_dis'], history['loss_dis_real'], history['loss_dis_fake']],
    'discriminator losses',
    'loss',
    ['real/fake', 'real', 'fake'],
)

In [None]:
# Of limited value: per the original author, the model output evaluated here
# is only the per-image MSE between latent_i and latent_o.
# NOTE(review): test_images/test_labels come from the legacy-dataset cells,
# not from the MVTec test_ds the model was trained on. Confirm that their
# image shape matches (resolution, resolution, channels).
# Eager execution is switched on here and stays on for the following cells.
tf.config.run_functions_eagerly(True)
# Aim for about ten batches. Clamp to at least 1, because fewer than ten test
# images would otherwise give an invalid batch size of 0.
eval_batch_size = max(1, test_images.shape[0] // 10)
eval_results = model.evaluate(
    x=test_images,
    y=test_labels,
    batch_size=eval_batch_size,
    verbose=0
)

In [None]:
# Score the whole test set in about ten batches. Clamp to at least 1, because
# fewer than ten test images would otherwise give an invalid batch size of 0.
predict_batch_size = max(1, test_images.shape[0] // 10)
predictions = model.predict(
    x=test_images,
    batch_size=predict_batch_size
)
print(predictions.shape)

In [None]:
# Min-max normalise the prediction scores in place to the range [0, 1].
min_val = np.min(predictions)
ptp_val = np.ptp(predictions)  # max - min
print("ptp_val:", ptp_val, "min_val:", min_val)

predictions -= min_val
# When every score is identical, ptp_val is 0 and the division would turn all
# scores into NaN (0 / 0); in that case they are left at 0 instead.
if ptp_val > 0:
    predictions /= ptp_val

# Scores around the normal/abnormal boundary. The start index is clamped at 0
# so that abnormal_start < 15 cannot produce a negative index, which would
# wrap around to the end of the array.
print(predictions[max(0, abnormal_start - 15):abnormal_start + 15])

In [None]:
# Labels around the normal/abnormal boundary, to check the ordering that the
# abnormal_start slicing relies on.
# NOTE(review): negative start wraps if abnormal_start < 5.
print(test_labels[abnormal_start-5:abnormal_start+5])

In [None]:
# Raw (un-normalised) scores for the normal part of the test set, presumably
# test_images[:abnormal_start], predicted in one batch covering the slice.
normal_images = test_images[:abnormal_start]
predictions_normal = model.predict(
    x=normal_images,
    batch_size=normal_images.shape[0],
)

In [None]:
# Summary statistics of the raw scores for the normal test images.
for stat_label, reducer in (("max:", np.max), ("min:", np.min), ("mean:", np.mean)):
    print(stat_label, reducer(predictions_normal))
for stat_label, percent in (("q(50):", 50), ("q(75):", 75), ("q(90):", 90),
                            ("q(95):", 95), ("q(99):", 99)):
    print(stat_label, np.percentile(predictions_normal, percent))

In [None]:
# Raw (un-normalised) scores for the abnormal part of the test set, presumably
# test_images[abnormal_start:], predicted in one batch covering the slice.
abnormal_images = test_images[abnormal_start:]
predictions_abnormal = model.predict(
    x=abnormal_images,
    batch_size=abnormal_images.shape[0],
)

In [None]:
# Summary statistics of the raw scores for the abnormal test images.
for stat_label, reducer in (("max:", np.max), ("min:", np.min), ("mean:", np.mean)):
    print(stat_label, reducer(predictions_abnormal))
for stat_label, percent in (("q(50):", 50), ("q(75):", 75), ("q(90):", 90),
                            ("q(95):", 95), ("q(99):", 99)):
    print(stat_label, np.percentile(predictions_abnormal, percent))

In [None]:
# Call the model directly in inference mode (training=False) on up to 10
# images around the normal/abnormal boundary. As unpacked here it returns
# ((reconstructed, latent_i, latent_o), (classifier, features)). These are
# presumably the generator and discriminator outputs; confirm in
# models/ganomaly.py.
(reconstructed, latent_i, latent_o), (classifier, features) = model(test_images[abnormal_start-5:abnormal_start+5], training=False)

In [None]:
# Pair each test image around the boundary with its reconstruction for a
# side-by-side plot. `reconstructed` comes from the model call on the same
# slice of test_images.
image_tuples = list(zip(
    test_images[abnormal_start-5:abnormal_start+5],
    reconstructed
))
# Plot helpers live in utils.plot, imported as `pu`.
pu.plot_image_tuples(image_tuples)