In [None]:
import tensorflow as tf
import numpy as np

# Fetch the MNIST train/test splits as (images, labels) pairs, then rescale
# both image arrays to float32 by dividing by 255.0 (labels are left as-is).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = (
    split.astype(np.float32) / 255.0 for split in (x_train, x_test)
)


In [None]:
from matplotlib import pyplot as plt

# Show the first rows * columns training images in a grid, each titled with
# its label, as a quick sanity check of the loaded data.
columns = 4
rows = 4
fig = plt.figure(figsize=(10, 10))
for cell_index in range(rows * columns):
    ax = fig.add_subplot(rows, columns, cell_index + 1)
    ax.set_title("label {}".format(y_train[cell_index]))
    ax.imshow(x_train[cell_index], cmap="gray")
plt.show()


In [None]:
import tensorflow as tf
import numpy as np

# Seed Python, NumPy and TensorFlow in one call so weight initialisation and
# the per-epoch shuffle in `fit` repeat across Restart & Run All.
# NOTE(review): some GPU kernels may remain nondeterministic even with a seed.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Per-sample input shape (everything except the leading batch axis).
x_shape = x_train.shape[1:]
# Values per sample. Cast to a plain Python int because `Dense(units=...)` is
# documented to take an integer, while a NumPy reduction yields a NumPy scalar.
n_pixels = int(np.prod(x_shape))

# Hyperparameters, named instead of inlined as magic numbers.
HIDDEN_UNITS = n_pixels * 8  # width of the hidden Dense layer in encoder and decoder
LATENT_DIM = 32  # size of the bottleneck code produced by the encoder
EPOCHS = 5
BATCH_SIZE = 128

# Encoder: flatten each sample, expand to HIDDEN_UNITS, compress to LATENT_DIM.
encoder = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=x_shape),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=HIDDEN_UNITS, activation=tf.keras.activations.relu),
        tf.keras.layers.Dense(units=LATENT_DIM, activation=tf.keras.activations.relu),
    ]
)

# Decoder: mirrors the encoder. The sigmoid keeps outputs in (0, 1), and the
# final Reshape restores the original per-sample shape.
decoder = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(LATENT_DIM,)),
        tf.keras.layers.Dense(units=HIDDEN_UNITS, activation=tf.keras.activations.relu),
        tf.keras.layers.Dense(units=n_pixels, activation=tf.keras.activations.sigmoid),
        tf.keras.layers.Reshape(target_shape=x_shape),
    ]
)

# Named `autoencoder_input` rather than `input` so the builtin `input()` is not
# shadowed for the rest of the kernel session.
autoencoder_input = tf.keras.layers.Input(shape=x_shape)
autoencoder = tf.keras.Model(
    inputs=autoencoder_input, outputs=decoder(encoder(autoencoder_input))
)

# Huber loss (default delta=1.0). Assuming targets lie in [0, 1] after the /255
# scaling, |error| < 1 against sigmoid outputs, so it behaves like 0.5 * MSE.
autoencoder.compile(
    loss=tf.keras.losses.Huber(),
    optimizer=tf.keras.optimizers.Adam(),
)

# Reconstruction objective: the targets are the inputs themselves.
history = autoencoder.fit(
    x=x_train,
    y=x_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    shuffle=True,
    validation_data=(x_test, x_test),
)


In [None]:
# Encode and decode the test set, then plot the first `n` originals (top row)
# above their reconstructions (bottom row).
# `predict` processes the data in mini-batches instead of pushing the whole
# test set through the wide hidden layers as one batch (a large memory spike).
# verbose=0 suppresses the progress bar.
encoded_samples = encoder.predict(x_test, batch_size=256, verbose=0)
decoded_samples = decoder.predict(encoded_samples, batch_size=256, verbose=0)

n = 100
# squeeze=False keeps `axes` 2-D even when n == 1.
fig, axes = plt.subplots(2, n, figsize=(n * 2, 4), squeeze=False)
for i in range(n):
    for ax, image, title in (
        (axes[0, i], x_test[i], "original"),
        (axes[1, i], decoded_samples[i], "reconstructed"),
    ):
        # Pass cmap per image instead of plt.gray(), which would also change
        # the global default colormap for every later plot in the session.
        ax.imshow(image, cmap="gray")
        ax.set_title(title)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()
