### Convolutional Autoencoder

In [None]:
import logging
import sys

# Send log records to stdout: per the author's setup, the notebook did not
# show output written to the default stderr stream.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

In [None]:
from importlib import reload
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import models.cae as mcae
import datasets.common as cds
import datasets.mvtec_ad as mvds
import utils.callbacks as cbu
import utils.datasets as dsu
import utils.plot as pu

# Sentinel passed to tf.data (num_parallel_calls / buffer sizes) to let the
# runtime choose the value dynamically; used by the dataset cell below.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
# Re-import project modules so edits on disk take effect without restarting
# the kernel. `mvds` builds the datasets in the next cell and `cbu` provides
# the training callback, so both are reloaded too. `mvds` comes after `cds` and
# `dsu`, which it presumably depends on (verify), so that any names it imports
# from them are rebound to the freshly reloaded versions.
for mod in (mcae, cds, dsu, mvds, cbu, pu):
    reload(mod)

## MVTec AD Dataset

In [None]:
categories = [
    'bottle', 'cable', 'capsule', 'carpet',
    'grid', 'hazelnut', 'leather', 'metal_nut',
    'pill', 'screw', 'tile', 'toothbrush',
    'transistor', 'wood', 'zipper'
]

category = 0      # index into `categories`
channels = 1      # image channels requested from the loader
resolution = 64   # images are resized to resolution x resolution
buffer_size = 1000

# tf.data file caches outlive the kernel and are matched by filename only, so
# the name must encode every setting that changes the cached content.
# Otherwise, switching `category`, `channels` or `resolution` would silently
# reuse tensors cached by an earlier configuration.
cache_tag = '{}_c{}_r{}'.format(categories[category], channels, resolution)


def resize_image(image, label):
    """Resize `image` to `resolution` x `resolution`; pass `label` through."""
    return tf.image.resize(image, (resolution, resolution)), label


test_ds = mvds.get_labeled_dataset(
    category=categories[category],
    split='test',
    image_channels=channels,
    binary_labels=True
)
test_ds = test_ds.map(resize_image, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.cache('/tmp/tfdata_test_ds_{}.cache'.format(cache_tag))

train_ds = mvds.get_labeled_dataset(
    category=categories[category],
    split='train',
    image_channels=channels,
    binary_labels=True
)
train_ds = train_ds.map(resize_image, num_parallel_calls=AUTOTUNE)
# Cache only the deterministic (resized) images; repeat/shuffle/augment run on
# top of the cache, so each pass draws fresh random augmentations.
train_ds = train_ds.cache('/tmp/tfdata_train_ds_{}.cache'.format(cache_tag))
train_ds = train_ds.repeat(10)
train_ds = train_ds.shuffle(buffer_size)


def augment_image(image, label, padding=6):
    """Randomly flip and shift `image`; `label` is passed through unchanged.

    The image is zero-padded to (resolution + padding) along each spatial
    axis -- `resize_with_crop_or_pad` splits the padding between both
    sides -- and then randomly cropped back to resolution x resolution.
    That amounts to a random translation of roughly +/- padding / 2 pixels.
    """
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    padded = resolution + padding
    image = tf.image.resize_with_crop_or_pad(image, padded, padded)
    image = tf.image.random_crop(image, size=[resolution, resolution, channels])
    return image, label


# Augment *before* prefetching so the random augmentation work is overlapped
# with training too, instead of running downstream of the prefetch buffer.
train_ds = train_ds.map(augment_image, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.prefetch(buffer_size)

# cardinality() returns -1 for infinite and -2 for unknown dataset sizes.
train_count = tf.data.experimental.cardinality(train_ds).numpy()
test_count = tf.data.experimental.cardinality(test_ds).numpy()
print("train_count: {}, test_count: {}".format(train_count, test_count))

In [None]:
reload(mcae)

latent_size = 300

model = mcae.CAE(
    input_shape=(resolution, resolution, channels),
    latent_size=latent_size
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.MeanSquaredError(),
    # Use metric classes rather than loss instances. A loss instance reduces
    # each batch to one scalar, so the epoch value becomes an unweighted mean
    # of batch means and over-weights a smaller final batch. The metric
    # classes accumulate per-sample values and keep the same history names
    # ('mean_absolute_error', 'binary_crossentropy').
    metrics=[
        tf.keras.metrics.MeanAbsoluteError(),
        tf.keras.metrics.BinaryCrossentropy()
    ]
)
# Build with a batch-agnostic input shape so summary() can report layer shapes.
model.build((None, resolution, resolution, channels))
model.summary()

In [None]:
# Warm-start from the checkpoint prefix that the save cell at the end writes.
model.load_weights(filepath='/tmp/cae')

In [None]:
# Pick up local edits to utils/callbacks.py without restarting the kernel.
reload(cbu)
# Project callback from utils.callbacks; it receives the test-set size and an
# early-stopping setting of 20, and exposes `best_weights` (read in a later
# cell). NOTE(review): presumably it scores anomaly detection on the
# validation data each epoch and stops after 20 epochs without improvement --
# confirm against cbu.ADModelEvaluator.
adcb = cbu.ADModelEvaluator(test_count, early_stopping=20)

In [None]:
# Run tf.functions as compiled graphs rather than eagerly.
tf.config.run_functions_eagerly(False)
batch_size = 128
train_batches = train_ds.batch(batch_size)
test_batches = test_ds.batch(batch_size)
# 1000 epochs is only an upper bound; the adcb callback (configured with
# early_stopping=20) is expected to end training sooner.
results = model.fit(
    train_batches,
    validation_data=test_batches,
    callbacks=[adcb],
    epochs=1000,
    verbose=0
)

In [None]:
# Restore the weights the ADModelEvaluator callback recorded as best during
# fit (presumably the best validation epoch -- confirm in utils.callbacks).
model.set_weights(adcb.best_weights)

In [None]:
print(results.history.keys())

# Plot only the curves that were actually recorded, labelling each line with
# its own history key so the legend can never drift out of sync with the
# lines drawn (a fixed 3-item legend mislabels curves when a key is missing).
for key in ('loss', 'mean_absolute_error', 'binary_crossentropy'):
    if key in results.history:
        plt.plot(results.history[key], label=key)
plt.title('losses')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(loc='upper right')
plt.show()

In [None]:
# Reconstruct every test image; rows follow test_ds order (it is not shuffled).
predictions = model.predict(test_ds.batch(batch_size))
print(predictions.shape)

In [None]:
n_rows = len(predictions)
n_cols = 2
greyscale = True
# 'gray' maps low values to black and high values to white. The 'binary'
# colormap is inverted (low = white) and would show the images as negatives.
cmap = plt.cm.gray if greyscale else None
# Inches per panel. The previous 7.5 in per row exceeded Matplotlib's
# 2**16-pixel figure limit (at 100 dpi) beyond ~87 rows, which several MVTec
# test splits have.
panel_size = 3

# squeeze=False keeps `axarr` 2-D even when there is only a single prediction.
_, axarr = plt.subplots(
    n_rows, n_cols,
    figsize=(panel_size * n_cols, panel_size * n_rows),
    squeeze=False
)
for idx, ((img, lbl), pre) in enumerate(zip(test_ds, predictions)):
    # Per-image reconstruction error, shown next to the ground-truth label.
    mse = tf.reduce_mean(tf.keras.losses.MSE(img, pre)).numpy()
    axarr[idx, 0].set_title("{}: {:.6f}".format(
        "good" if lbl == 0 else "broken",
        mse
    ))
    axarr[idx, 0].imshow(img, cmap=cmap)
    axarr[idx, 1].imshow(pre, cmap=cmap)
plt.show()

In [None]:
model.save_weights('/tmp/cae')
!(ls -lah /tmp/cae*)