In [None]:
%pip install ipywidgets
%pip install tensorflow_datasets

In [8]:
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds

def load_mnist_dataset(batch_size=256):
    """Load MNIST from TFDS as batched ``(input, target)`` pairs for an autoencoder.

    Every image is zero-padded by two pixels on each side of its first two
    axes, cast to float32 and scaled by 1/255. The class label is discarded and
    the processed image is emitted twice, as both the model input and the
    reconstruction target.

    Args:
        batch_size: Number of samples per batch for both splits. Defaults to
            256, the value that used to be hard-coded.

    Returns:
        A ``(ds_train, ds_test)`` tuple of batched, prefetched
        ``tf.data.Dataset`` objects. Only the training split is shuffled, so
        the test split keeps a stable order across iterations.
    """
    def preprocess(image, label):
        # `label` is unused on purpose: the autoencoder reconstructs its input.
        image = tf.pad(image, [[2, 2], [2, 2], [0, 0]])
        image = tf.cast(image, tf.float32) / 255.0
        return image, image

    def batched(dataset):
        # Shared tail of the input pipeline for both splits.
        return (
            dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(batch_size)
            .prefetch(tf.data.AUTOTUNE)
        )

    (ds_train, ds_test), ds_info = tfds.load(
        'mnist',
        split=['train', 'test'],
        shuffle_files=True,
        as_supervised=True,
        with_info=True,
    )

    # `shuffle_files` only reorders the underlying shards, so samples are
    # shuffled explicitly. The buffer spans the whole split for a uniform
    # shuffle, and it is applied before `map` so it holds the raw images
    # instead of the larger padded float32 tensors.
    num_train_examples = ds_info.splits['train'].num_examples
    ds_train = batched(ds_train.shuffle(num_train_examples))
    ds_test = batched(ds_test)
    return ds_train, ds_test

# Build both splits once; the later cells read `ds_train` / `ds_test` as notebook globals.
ds_train, ds_test = load_mnist_dataset()


In [None]:
import tensorflow as tf
import numpy as np
from train_strategies import train_until_improvement_treshold

# Per-sample shape: the first component of each dataset element with its
# leading (batch) axis dropped.
input_shape = ds_train.element_spec[0].shape[1:]

# Width of the bottleneck layer that the encoder outputs.
latent_dimensions = 8

# Number of scalar values in one sample. Cast to a plain Python int so the
# Dense layers are configured with a built-in integer, not a NumPy scalar.
flat_units = int(np.multiply.reduce(input_shape))
hidden_units = flat_units * 8

# Deliberately not called `input`: rebinding that name would shadow the Python
# builtin `input()` for every later cell in the kernel session.
model_input = tf.keras.layers.Input(shape=input_shape)

# Encoder: flatten -> wide hidden layer -> `latent_dimensions` code.
encoder = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=input_shape),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=hidden_units, activation=tf.keras.activations.relu),
        tf.keras.layers.Dense(units=latent_dimensions, activation=tf.keras.activations.relu),
    ]
)

# Decoder: mirrors the encoder and reshapes back to the original sample shape.
# NOTE(review): the last Dense layer uses ReLU, so outputs are non-negative
# but unbounded above — confirm this is intended for the target value range.
decoder = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=encoder.output_shape[1:]),
        tf.keras.layers.Dense(units=hidden_units, activation=tf.keras.activations.relu),
        tf.keras.layers.Dense(units=flat_units, activation=tf.keras.activations.relu),
        tf.keras.layers.Reshape(target_shape=input_shape),
    ]
)

# Chain encoder and decoder into the end-to-end autoencoder.
output = decoder(encoder(model_input))

model = tf.keras.Model(inputs=model_input, outputs=output)

model.compile(
    loss=tf.keras.losses.Huber(),
    optimizer=tf.keras.optimizers.Adam(),
)

model.summary()

def fit():
    """Train `model` for a single epoch on `ds_train`, validating on `ds_test`.

    `batch_size` and `shuffle` are deliberately not forwarded. The input is a
    `tf.data.Dataset` that already yields batches, and Keras documents that
    `batch_size` must not be specified for dataset inputs and that `shuffle`
    is ignored for them. Batching and shuffling therefore have to be
    configured on the dataset itself.

    Returns:
        Whatever `model.fit` returns for this one-epoch run.
    """
    return model.fit(
        ds_train,
        validation_data=ds_test,
        epochs=1,
    )

# Pass the one-epoch `fit` callable to the project helper from `train_strategies`.
# NOTE(review): judging by its name, the helper presumably keeps calling `fit` until
# the improvement between runs falls below a threshold — confirm against its source.
train_until_improvement_treshold(fit)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

n = 100

# Take the first n individual test elements and keep only the first item of
# each (input, target) pair, then run them through encoder and decoder.
first_inputs = ds_test.unbatch().take(n).map(lambda sample, _target: sample)
test_samples = tf.convert_to_tensor(list(first_inputs))
encoded_samples = encoder(test_samples)
decoded_samples = decoder(encoded_samples)

# Grid of 2 rows x n columns: originals on top, reconstructions underneath.
fig, axes = plt.subplots(2, n, figsize=(n * 2, 4), squeeze=False)

# Switch the default colormap to gray (after the figure exists, so this call
# does not open an extra empty figure); every imshow below picks it up.
plt.gray()

panel_rows = (
    (axes[0], test_samples, "original"),
    (axes[1], decoded_samples, "reconstructed"),
)
for row_axes, images, caption in panel_rows:
    for i in range(n):
        ax = row_axes[i]
        ax.imshow(images[i])
        ax.set_title(caption)
        # Hide ticks and tick labels; the axes frame stays visible.
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
plt.show()
