In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Dataset choice: swap in tf.keras.datasets.fashion_mnist for Fashion-MNIST.
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Divide by 255 so pixel values (presumably 0-255 integers) land in [0, 1].
train_images, test_images = train_images / 255.0, test_images / 255.0

# Width of the autoencoder's bottleneck layer.
latent_dim = 128

class Autoencoder(tf.keras.Model):
    """Fully connected encoder/decoder that maps an image to an image.

    The encoder flattens its input, passes it through a stack of ReLU dense
    layers, and ends in a ``latent_dim``-wide ReLU layer. The decoder mirrors
    the hidden stack and emits sigmoid activations reshaped to ``image_shape``.

    Args:
        latent_dim: Width of the bottleneck (latent) layer.
        hidden_units: Width of every hidden dense layer in both halves.
        num_hidden_layers: Number of hidden dense layers in each half.
        image_shape: Shape of each decoded sample; the decoder's final dense
            layer has ``prod(image_shape)`` units before reshaping.

    The defaults reproduce the original 4 x 1024 architecture with a
    (28, 28) output, so ``Autoencoder(latent_dim)`` behaves as before.
    """

    def __init__(self, latent_dim, hidden_units=1024, num_hidden_layers=4,
                 image_shape=(28, 28)):
        super().__init__()

        def hidden_stack():
            # Build fresh layer objects on each call so the encoder and the
            # decoder never share weights.
            return [tf.keras.layers.Dense(hidden_units, activation='relu')
                    for _ in range(num_hidden_layers)]

        self.encoder = tf.keras.Sequential(
            [tf.keras.layers.Flatten()]
            + hidden_stack()
            + [tf.keras.layers.Dense(latent_dim, activation='relu')]
        )

        self.decoder = tf.keras.Sequential(
            hidden_stack()
            + [
                # Sigmoid keeps every output value in (0, 1).
                tf.keras.layers.Dense(int(np.prod(image_shape)),
                                      activation='sigmoid'),
                tf.keras.layers.Reshape(tuple(image_shape)),
            ]
        )

    def call(self, x):
        """Encode ``x`` into the latent code, then decode it to ``image_shape``."""
        encoded = self.encoder(x)
        return self.decoder(encoded)

# Instantiate the model and configure training: Adam with a small step size,
# optimizing the mean absolute error between predicted and target pixels.
autoencoder = Autoencoder(latent_dim)
autoencoder.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='mae',
)

# Build (current, next) pairs: the target for image i is image i + 1, so the
# final image has no successor and is dropped from the inputs. The *_future
# arrays must be taken before the inputs are truncated. Plain slicing yields
# the same values as np.roll(x, -1, 0)[:-1] but as views, avoiding a copy of
# the whole array.
train_images_future = train_images[1:]
train_images = train_images[:-1]

test_images_future = test_images[1:]
test_images = test_images[:-1]

# shuffle=False keeps batches in dataset order; batch_size=None lets Keras
# use its default batch size.
autoencoder.fit(
    train_images,
    train_images_future,
    batch_size=None,
    epochs=100,
    shuffle=False,
    validation_data=(test_images, test_images_future),
)

num = 8  # number of example columns to display


def _plot_two_rows(top_row, bottom_row, count):
    """Draw the first `count` images of each sequence as a 2 x count grid."""
    figure = plt.figure(figsize=(count * 2, 4), tight_layout=True)
    for row, images in enumerate((top_row, bottom_row)):
        for col in range(count):
            plt.subplot(2, count, row * count + col + 1)
            # imshow rescales float data to its own min/max by default, so
            # no manual * 255 rescale is needed.
            plt.imshow(images[col])
            plt.axis('off')
    return figure


# Only the first `num` samples are shown, so run the model on just those
# instead of the entire training set.
samples = train_images[:num]
test = autoencoder(samples)    # one-step-ahead predictions
test2 = autoencoder(test)      # predictions fed back in: two steps ahead
test3 = autoencoder(test2)     # three steps ahead

# Figure 1: inputs (top) vs. one-step predictions (bottom).
fig = _plot_two_rows(samples, test, num)
# Figure 2: two-step (top) vs. three-step (bottom) rollouts.
fig = _plot_two_rows(test2, test3, num)

plt.show()
