In [None]:
from math import isqrt, sqrt

import matplotlib.pyplot as plt
from matplotlib.offsetbox import TextArea, AnnotationBbox, OffsetImage
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, MaxPool2D, Dense, Flatten, Reshape
from tensorflow.keras.models import Model

In [None]:
fashion_mnist = keras.datasets.fashion_mnist
(
    (train_images, train_labels),
    (test_images, test_labels),
) = fashion_mnist.load_data()

# Display names, indexed by integer class label.
class_names = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]

In [None]:
# 5x5 preview of the first 25 raw training images, captioned with their class names.
plt.figure(figsize=(10, 10))
first_images = train_images[:25]
first_labels = train_labels[:25]
for slot, (image, label) in enumerate(zip(first_images, first_labels), start=1):
    plt.subplot(5, 5, slot)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(image, cmap=plt.cm.binary)
    plt.xlabel(class_names[label])
plt.show()

In [None]:
def preprocess_images(images):
    """Scale uint8 pixels to float32 in [0, 1] and ensure a trailing channel axis.

    Idempotent: uint8 input is divided by 255 and a 3-D batch gets a channel
    axis appended; input that is already floating point and/or 4-D passes
    through unchanged. Re-running this cell therefore does not divide by 255
    a second time or stack extra channel axes.

    Args:
        images: numpy array of images; presumably (N, H, W) uint8 as returned
            by the dataset loader — confirm if the data source changes.

    Returns:
        numpy array with a trailing channel axis.
    """
    if images.dtype == np.uint8:
        # float32 is Keras's default float type; float64 would double memory for no gain.
        images = images.astype(np.float32) / 255.0
    if images.ndim == 3:
        # Conv2D expects an explicit channel dimension.
        images = np.expand_dims(images, -1)
    return images


train_images = preprocess_images(train_images)
test_images = preprocess_images(test_images)

In [None]:
def create_encoder(input_shape, encoding_dim, latent_activation="softmax"):
    """Build a convolutional encoder mapping one image to a flat latent vector.

    The encoder has two conv stages (3x3, "same" padding, ReLU; 32 then 64
    filters). Each stage is followed by a default MaxPool2D. The result is
    flattened and projected to `encoding_dim` units.

    Args:
        input_shape: Per-sample image shape, e.g. (28, 28, 1) for this notebook.
        encoding_dim: Number of latent units.
        latent_activation: Activation of the latent Dense layer. The default
            "softmax" reproduces the original model.
            NOTE(review): softmax makes the latent code non-negative and sum
            to 1, which constrains the bottleneck. A linear (None) or "relu"
            latent is the more common choice for an autoencoder — consider
            switching.

    Returns:
        An uncompiled keras Model. Its summary is printed as a side effect.
    """
    inputs = Input(shape=input_shape)
    x = Conv2D(32, 3, padding="same", activation="relu")(inputs)
    x = MaxPool2D()(x)
    x = Conv2D(64, 3, padding="same", activation="relu")(x)
    x = MaxPool2D()(x)
    x = Flatten()(x)
    latent = Dense(encoding_dim, activation=latent_activation)(x)

    encoder = Model(inputs, latent)
    encoder.summary()
    return encoder

In [None]:
def create_decoder(encoding_dim):
    """Build a decoder that maps a flat latent vector back to a 1-channel image.

    The latent vector is reshaped to a (side, side, 1) grid, where
    side = sqrt(encoding_dim). It is then upsampled by two stride-2
    transposed convolutions ("same" padding), so the output is
    (4 * side, 4 * side, 1). A final sigmoid keeps values in (0, 1).

    Args:
        encoding_dim: Length of the latent vector; must be a positive perfect
            square (an int).

    Returns:
        An uncompiled keras Model. Its summary is printed as a side effect.

    Raises:
        ValueError: If `encoding_dim` is not a positive perfect square.
    """
    if encoding_dim <= 0:
        raise ValueError("Encoding dim must be positive.")
    # Exact integer square root; avoids relying on float sqrt(...).is_integer().
    side = isqrt(encoding_dim)
    if side * side != encoding_dim:
        raise ValueError("Encoding dim must be a perfect square.")

    # Keras documents `shape` as a tuple; a bare int is not accepted by every Keras version.
    inputs = Input(shape=(encoding_dim,))
    grid = Reshape((side, side, 1))(inputs)
    up1 = Conv2DTranspose(64, 3, strides=2, padding="same", activation="relu")(grid)
    up2 = Conv2DTranspose(32, 3, strides=2, padding="same", activation="relu")(up1)
    outputs = Conv2D(1, 3, padding="same", activation="sigmoid")(up2)

    decoder = Model(inputs, outputs)
    decoder.summary()
    return decoder

In [None]:
# Shape of a single sample: every axis except the leading batch axis.
input_shape = train_images.shape[1:]
# Latent size. create_decoder needs a perfect square; its two stride-2 upsamplings
# turn sqrt(49) = 7 back into a 28-pixel side.
encoding_dim = 49

In [None]:
encoder = create_encoder(input_shape, encoding_dim)
decoder = create_decoder(encoding_dim)

# Chain encoder -> decoder on a fresh input tensor to get one end-to-end model.
inputs = Input(shape=input_shape)
encoded = encoder(inputs)
decoded = decoder(encoded)
autoencoder = Model(inputs=inputs, outputs=decoded)
autoencoder.summary()

# The training target is the input itself, so pixel-wise MSE is the reconstruction loss.
autoencoder.compile(loss="mse", optimizer="adam")

In [None]:
# Reconstruction training: the images serve as both inputs and targets.
autoencoder.fit(x=train_images, y=train_images, epochs=250, batch_size=256)

In [None]:
# Held-out reconstruction loss (the compiled MSE) on the test split.
test_loss = autoencoder.evaluate(x=test_images, y=test_images, batch_size=256)

In [None]:
# Compare the first few test images with their reconstructions:
# originals on the top row, autoencoder outputs on the bottom row.
n_show = 5
visual = test_images[:n_show]
preds = autoencoder.predict(visual)

# A 2 x n_show grid. The previous 5x5 grid left 15 of its 25 slots empty.
fig, axes = plt.subplots(2, n_show, figsize=(2 * n_show, 4), squeeze=False)
for row_axes, images, label in zip(axes, (visual, preds), ("Original", "Reconstructed")):
    for ax, image in zip(row_axes, images):
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(False)
        # squeeze() drops the trailing channel axis so imshow gets a 2-D array.
        ax.imshow(image.squeeze(), cmap=plt.cm.binary)
        ax.set_xlabel(label)
plt.show()

In [None]:
def plot_latent(mode, count, seed=None):
    """Plot a random sample of test images at their encoder latent coordinates.

    Only the first two latent dimensions are used as (x, y). The remaining
    dimensions are ignored, so this is a 2-D slice of the latent space, not a
    projection of it.

    Args:
        mode: "imgs" draws each sampled image as a thumbnail at its latent
            coordinates. "dots" draws one point per image, coloured by class
            label, and places a label box at the mean position of each class.
        count: Number of test images to sample. Sampling is without
            replacement, so values above len(test_images) are capped.
        seed: Optional seed for the sampling RNG, for reproducible plots.

    Raises:
        ValueError: If `mode` is not "imgs" or "dots".
    """
    if mode not in ("imgs", "dots"):
        raise ValueError("mode must be 'imgs' or 'dots', got {!r}".format(mode))

    rng = np.random.default_rng(seed)
    # Sample without replacement: np.random.choice defaults to replace=True, which
    # made count == len(test_images) plot many images more than once.
    n_samples = min(count, len(test_images))
    idx = rng.choice(len(test_images), size=n_samples, replace=False)
    inputs = test_images[idx]
    coords = encoder.predict(inputs)[:, :2]

    fig, ax = plt.subplots(figsize=(10, 7))
    ax.set_title("Autoencoder Latent Space")

    if mode == "imgs":
        for image, (x, y) in zip(inputs, coords):
            thumb = OffsetImage(image.squeeze(), zoom=1, cmap="gray")
            ax.add_artist(AnnotationBbox(thumb, (x, y), xycoords="data", frameon=False))
        # add_artist does not grow the data limits, so extend them explicitly before autoscaling.
        ax.update_datalim(coords)
        ax.autoscale()
    else:
        classes = test_labels[idx]
        points = ax.scatter(coords[:, 0], coords[:, 1], c=classes)
        fig.colorbar(points, ax=ax)
        for label, name in enumerate(class_names):
            members = coords[classes == label]
            if len(members) == 0:
                # Class absent from this sample: its mean would be NaN.
                continue
            text = TextArea("{} ({})".format(name, label))
            ax.add_artist(
                AnnotationBbox(text, members.mean(axis=0), xycoords="data", frameon=True)
            )
    plt.show()

In [None]:
# Latent positions of 10000 test images, coloured by class.
plot_latent(mode="dots", count=10000)

In [None]:
# Same latent slice, drawn with the 1000 sampled images as thumbnails.
plot_latent(mode="imgs", count=1000)