In [None]:
import tensorflow as tf
import numpy as np

# Load the Fashion-MNIST train/test splits as raw pixel arrays plus integer labels.
# Only the images are used below (the autoencoder reconstructs its input); the
# labels y_train / y_test are not used by the cells shown here.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

def mnist_to_convolutional(x):
    """Turn raw 28x28 grayscale images into a float32 NHWC batch for Conv2D.

    Args:
        x: Array-like of pixel values whose total size is a multiple of
            28 * 28, e.g. shape (N, 28, 28) or (N, 784).

    Returns:
        A new float32 array of shape (N, 28, 28, 1) with every value divided
        by 255, so 8-bit pixels in [0, 255] end up in [0, 1].
    """
    # astype() always copies, so the in-place division never touches `x`.
    pixels = np.asarray(x).astype(np.float32)
    pixels /= 255.0
    # Append the single channel axis expected by Keras conv layers (NHWC).
    return pixels.reshape(-1, 28, 28, 1)

# Replace both raw splits with their conv-ready form: (N, 28, 28, 1) float32
# arrays scaled by 1/255.
x_train, x_test = map(mnist_to_convolutional, (x_train, x_test))


In [None]:
import tensorflow as tf

# Seed the Python, NumPy and TensorFlow RNGs before any layer is built, so the
# initial weights and the training run are repeatable on Restart & Run All.
# (Bit-exact results on GPU may additionally need
# tf.config.experimental.enable_op_determinism().)
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Per-image input shape, e.g. (28, 28, 1) after mnist_to_convolutional.
x_shape = x_train.shape[1:]

# Encoder: "same"-padded convolutions with strides 1, 2, 2, so a 28x28 image
# becomes 28x28x16 -> 14x14x32 -> 7x7x8, the bottleneck code.
encoder = tf.keras.Sequential([
    tf.keras.layers.Input(shape=x_shape),
    tf.keras.layers.Conv2D(
        filters=16,
        kernel_size=3,
        activation=tf.keras.activations.relu,
        padding="same",
        strides=1,
    ),
    tf.keras.layers.Conv2D(
        filters=32,
        kernel_size=5,
        activation=tf.keras.activations.relu,
        padding="same",
        strides=2,
    ),
    tf.keras.layers.Conv2D(
        filters=8,
        kernel_size=5,
        activation=tf.keras.activations.relu,
        padding="same",
        strides=2,
    ),
])

# Decoder: takes the encoder's output shape as input. The three transposed
# convolutions below are listed as (filters, kernel_size, strides); the two
# stride-2 ones upsample 7x7 -> 14x14 -> 28x28 for 28x28 inputs. A final
# 1-filter sigmoid Conv2D then emits a single-channel image with values in
# (0, 1), matching the [0, 1]-scaled inputs.
decoder = tf.keras.Sequential(
    [tf.keras.layers.Input(shape=encoder.output_shape[1:])]
    + [
        tf.keras.layers.Conv2DTranspose(
            filters=n_filters,
            kernel_size=kernel,
            strides=stride,
            padding="same",
            activation=tf.keras.activations.relu,
        )
        for n_filters, kernel, stride in ((8, 5, 2), (32, 5, 2), (16, 3, 1))
    ]
    + [
        tf.keras.layers.Conv2D(
            filters=1,
            kernel_size=2,
            strides=1,
            padding="same",
            activation=tf.keras.activations.sigmoid,
        ),
    ]
)

# Chain encoder and decoder into one end-to-end model. The input tensor is
# named `autoencoder_input` rather than `input` so the Python builtin input()
# is not shadowed for the rest of the kernel session.
autoencoder_input = tf.keras.layers.Input(shape=x_shape)

autoencoder = tf.keras.Model(
    inputs=autoencoder_input,
    outputs=decoder(encoder(autoencoder_input)),
)

# Per-pixel Huber loss. Inputs are scaled to [0, 1] and the decoder ends in a
# sigmoid, so every error is below the default delta=1.0 and the loss stays in
# its quadratic regime (0.5 * squared error).
autoencoder.compile(
    loss=tf.keras.losses.Huber(),
    optimizer=tf.keras.optimizers.Adam(),
)

# Training hyperparameters, kept as named constants so they are easy to tune.
EPOCHS = 5
BATCH_SIZE = 32

# The autoencoder learns to reproduce its input, so the images are both x and y.
# NOTE(review): the test split doubles as validation data here, so no untouched
# hold-out set remains for a final, unbiased evaluation.
history = autoencoder.fit(
    x=x_train,
    y=x_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    shuffle=True,
    validation_data=(x_test, x_test),
)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

def deparallelized(f):
    """Wrap a batched function so it is applied to one sample at a time.

    Args:
        f: Callable taking a batch (leading axis = samples) and returning an
            indexable result whose first element corresponds to that sample.

    Returns:
        A callable that, for each element of its input, calls ``f`` on a batch
        of size one, keeps element 0 of the result, and stacks those results
        into a NumPy array.
    """
    def apply_per_sample(batch):
        rows = []
        for sample in batch:
            single = np.array([sample])  # add a leading batch axis of size 1
            rows.append(f(single)[0])
        return np.array(rows)

    return apply_per_sample

# How many test images to encode/decode and display, and how many per grid row.
N_VIS_SAMPLES = 100
N_VIS_COLS = 10

vis_images = x_test[:N_VIS_SAMPLES]

# Run each model once on the whole slice (Keras batches internally) instead of
# one forward call per image. Neither model contains batch-dependent layers
# (only Conv2D / Conv2DTranspose), so batched inference computes the same thing.
encoded_samples = encoder.predict(vis_images, verbose=0)
decoded_samples = decoder.predict(encoded_samples, verbose=0)

n = encoded_samples.shape[0]
n_cols = min(N_VIS_COLS, n)
n_grid_rows = (n + n_cols - 1) // n_cols  # ceil(n / n_cols)

# Wrap samples into a grid rather than one very wide row: each grid row uses
# two subplot rows, originals on top and their reconstructions directly below.
fig, axes = plt.subplots(
    2 * n_grid_rows,
    n_cols,
    figsize=(n_cols * 2, n_grid_rows * 4),
    squeeze=False,  # always a 2-D array of Axes, even for a single column
)
for ax in axes.flat:
    ax.axis("off")  # hide ticks/frames; also blanks unused cells in the last row

for i in range(n):
    grid_row, col = divmod(i, n_cols)
    for subplot_row, images, title in (
        (2 * grid_row, vis_images, "original"),
        (2 * grid_row + 1, decoded_samples, "reconstructed"),
    ):
        ax = axes[subplot_row, col]
        # Drop the trailing channel axis of the (H, W, 1) image and pin the
        # colour range to [0, 1]: imshow otherwise rescales each image to its
        # own min/max, which would hide brightness errors in reconstructions.
        ax.imshow(np.squeeze(images[i], axis=-1), cmap="gray", vmin=0.0, vmax=1.0)
        ax.set_title(title)

plt.show()
