### GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training
Paper: https://arxiv.org/abs/1805.06725

In [None]:
from importlib import reload
from hashlib import sha256
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import models.ganomaly as mg

In [None]:
# Per-sample image shape (height, width, channels) used to build the
# sub-networks below; 32x32 RGB is the CIFAR-10 sample size.
input_shape = (32, 32, 3)

In [None]:
# Encoder: re-import the model module so source edits are picked up without
# restarting the kernel, then build on a batch of unknown size and print it.
reload(mg)
en = mg.Encoder(input_shape)
en.build((None,) + tuple(input_shape))
en.summary()

In [None]:
# Decoder: built from the latent side rather than the image side.
# NOTE(review): (4, 4, 100) is presumably the Encoder's per-sample output
# shape for 32x32 inputs — confirm against en.summary() above.
reload(mg)
de = mg.Decoder(input_shape)
de.build((None, 4, 4, 100))
de.summary()

In [None]:
# Discriminator (NetD): built on a batch of `input_shape` images of unknown
# batch size, then summarised.
reload(mg)
dis = mg.NetD(input_shape)
dis.build((None,) + tuple(input_shape))
dis.summary()

In [None]:
reload(mg)
# Generator (NetG). No explicit build()/summary() here: the model gets built
# on its first call, i.e. test_model(gen) further down.
# NOTE(review): presumably building NetG from a bare shape did not work for
# this subclassed model — confirm before re-adding build()/summary().
gen = mg.NetG(input_shape)

In [None]:
def test_model(model, shape=None, nimg=10):
    """Run ``model`` on a constant all-ones batch and print output digests.

    Serves as a cheap fingerprint: two models (e.g. before and after a
    save/load round trip) that print identical digests for this input produce
    bitwise-identical outputs.

    Args:
        model: callable taking a float32 array of shape ``(nimg, *shape)``
            and returning exactly three tensors ``(gimg, lati, lato)``, each
            exposing ``.numpy()``.
        shape: per-sample input shape. Defaults to the notebook-level
            ``input_shape`` (looked up at call time, as before).
        nimg: number of samples in the batch (default 10).

    Prints one SHA-256 hex digest per output, in the order gimg, lati, lato.
    Returns None.
    """
    if shape is None:
        shape = input_shape
    rimg = np.ones((nimg, *shape), dtype='float32')

    gimg, lati, lato = model(rimg)
    for out in (gimg, lati, lato):
        # hashlib needs a C-contiguous buffer; ascontiguousarray leaves an
        # already-contiguous array untouched, so digests are unchanged.
        print(sha256(np.ascontiguousarray(out.numpy())).hexdigest())

In [None]:
# Fingerprint the generator's outputs (this first call also builds it).
test_model(model=gen)

In [None]:
# Round-trip the whole model through save()/load_model(); the restored copy is
# expected to print the same digests as the original generator above.
gen.save("my_model/test")
restored = tf.keras.models.load_model("my_model/test")
test_model(model=restored)

In [None]:
# Checkpoint only the weights, then re-fingerprint the (unchanged) generator
# so its digests can be compared after reloading below.
gen.save_weights("my_model/weights")
test_model(model=gen)

In [None]:
# Replace `gen` with a brand-new generator; with freshly initialised weights
# its digests presumably differ from the checkpointed model's.
gen = mg.NetG(input_shape)
test_model(model=gen)

In [None]:
# Restore the checkpointed weights into the fresh generator; the digests are
# expected to match those printed right after save_weights().
gen.load_weights("my_model/weights")
test_model(model=gen)

In [None]:
# Load CIFAR-10 through the project's dataset helper (brought in by the
# wildcard import, which later cells may also rely on).
from utils.datasets import *
cifar10_splits = get_dataset_cifar10()
(train_images, train_labels), (test_images, test_labels) = cifar10_splits

In [None]:
reload(mg)
# Full GANomaly model, sized from the shape of the first training sample.
model = mg.GANomaly(input_shape=train_images[0].shape)
# NOTE(review): compile() takes no optimizer/loss here — presumably GANomaly
# configures its own in an overridden compile()/train_step; confirm.
model.compile()

In [None]:
# Compile tf.functions to graphs instead of executing them eagerly.
tf.config.run_functions_eagerly(False)
# Short training run on the first 640 images (10 batches of 64 per epoch).
results = model.fit(
    x=train_images[:640],
    epochs=10,
    batch_size=64,
)

In [None]:
def _plot_losses(history, keys, title, legend):
    """Plot per-epoch loss curves from a Keras ``History.history`` dict.

    Args:
        history: mapping of metric name -> list of per-epoch values, as
            stored in ``results.history`` by ``model.fit``.
        keys: metric names to plot, one line each, in legend order.
        title: figure title.
        legend: legend labels, parallel to ``keys``.

    Raises:
        ValueError: if ``keys`` and ``legend`` differ in length (the legend
            would silently mislabel or drop curves).
    """
    if len(keys) != len(legend):
        raise ValueError(
            f"got {len(keys)} keys but {len(legend)} legend labels"
        )
    for key in keys:
        values = history[key]
        # History holds one value per epoch; number epochs from 1 so the axis
        # matches the "Epoch 1/N" counter fit() prints (plotting without x
        # values would label the first epoch as 0).
        plt.plot(range(1, len(values) + 1), values)
    plt.title(title)
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(legend, loc='upper right')
    plt.show()


print(results.history.keys())

_plot_losses(
    results.history,
    ['err_g'],
    'generator loss',
    ['generator'],
)
_plot_losses(
    results.history,
    ['err_g_adv', 'err_g_con', 'err_g_enc'],
    'generator specific losses',
    ['adversarial', 'reconstruction', 'encoder'],
)
_plot_losses(
    results.history,
    ['err_d', 'err_d_real', 'err_d_fake'],
    'discriminator losses',
    ['sum', 'real', 'fake'],
)