***Convolutional Autoencoder***  
MNIST dataset

In [None]:
import os
import time

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from keras.datasets import mnist
from keras.models import Model, Sequential, load_model
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Flatten, Reshape
from keras import regularizers
from livelossplot import PlotLossesKeras

**Import dataset** 

In [None]:
(X_train, _), (X_test, _) = mnist.load_data()

# Both splits are divided by the training-set maximum so they share one scale.
max_value = float(X_train.max())

def _to_model_input(images):
    """Cast to float32, divide by max_value and append a channel axis of size 1."""
    scaled = images.astype("float32")/max_value
    return scaled.reshape((len(scaled), 28, 28, 1))

X_train = _to_model_input(X_train)
X_test = _to_model_input(X_test)

for split in (X_train, X_test):
    print(split.shape)

**Build model**

In [None]:
# The network is assembled from three layer lists, in order.
#
# padding = "same": a stride-1 convolution keeps the feature-map size
# (e.g. the first conv layer stays at 28x28); each pooling / strided step
# halves it, rounding up.

# Encoder: 28x28x1 -> 14x14 -> 7x7 -> 4x4x4
encoder_layers = [
    Conv2D(16, (3, 3), activation = "relu", padding = "same",
        input_shape = X_train.shape[1:]),
    MaxPooling2D((2, 2), padding = "same"),
    Conv2D(8, (3, 3), activation = "relu", padding = "same"),
    MaxPooling2D((2, 2), padding = "same"),
    # Strided convolution performs the last downsampling step.
    # (Alternative noted by the author: Conv2D(8, ...), i.e. a (4, 4, 8) code.)
    Conv2D(4, (3, 3), strides = (2, 2), activation = "relu", padding = "same"),
]

# Flatten the encoding for visualization, then restore its spatial layout
# so the decoder can convolve over it.
bottleneck_layers = [
    Flatten(),
    Reshape((4, 4, 4)),
]

# Decoder: 4x4 -> 8x8 -> 16x16 -> 14x14 -> 28x28x1
decoder_layers = [
    Conv2D(8, (3, 3), activation = "relu", padding = "same"),
    UpSampling2D((2, 2)),
    Conv2D(8, (3, 3), activation = "relu", padding = "same"),
    UpSampling2D((2, 2)),
    # No padding here: the default "valid" 3x3 conv trims 16x16 to 14x14,
    # so the final upsampling lands on 28x28.
    Conv2D(16, (3, 3), activation = "relu"),
    UpSampling2D((2, 2)),
    Conv2D(1, (3, 3), activation = "sigmoid", padding = "same"),
]

autoencoder = Sequential(encoder_layers + bottleneck_layers + decoder_layers)

autoencoder.summary()

To extract the **encoder** model from the autoencoder, we’re going to use a slightly different approach than before. Rather than extracting the first 6 layers, we’re going to create a new Model with the same input as the autoencoder, but the output will be that of the flattening layer. As a side note, this is a very useful technique for grabbing submodels for things like transfer learning.  

The encoded image is a vector of length 64 (the 4 × 4 × 4 bottleneck, flattened).

In [None]:
# Build the encoder as a sub-model: same input as the autoencoder, output
# taken at the Flatten layer (the flattened bottleneck encoding).
#
# The layer is located by type instead of by the auto-generated name
# "flatten_1": Keras derives default layer names from a running counter, so
# the name differs once the model cell has been re-run in the same kernel
# (and may differ between Keras versions), which made the lookup fail.
flatten_layer = next(
    layer for layer in autoencoder.layers if isinstance(layer, Flatten))

encoder = Model(
    inputs = autoencoder.input,
    outputs = flatten_layer.output)

encoder.summary()

**Train model**

In [None]:
epochs = 5

autoencoder.compile(optimizer = "adam", loss = "binary_crossentropy")

# The autoencoder is trained to reproduce its input, so the images serve as
# both features and targets (for the validation split as well).
fit_options = dict(
    epochs = epochs,
    batch_size = 128,
    validation_data = (X_test, X_test),
    callbacks = [PlotLossesKeras()])

start_time = time.time()
autoencoder.fit(X_train, X_train, **fit_options)
# Despite its name, end_time holds the elapsed wall-clock duration in seconds.
end_time = time.time() - start_time

print(f"Training time: {end_time} seconds for {epochs} epochs")
print(f"Training time: {end_time/epochs} per epoch on average")

**Save model**

In [None]:
# Make sure the output folder exists before saving. Depending on the Keras
# version / file format, save() may not create missing parent directories,
# in which case a fresh checkout would fail here after the full training run.
os.makedirs("Models", exist_ok = True)

# Persist the full autoencoder and the encoder-only sub-model separately.
autoencoder.save("Models/conv_autoencoder.model")
encoder.save("Models/conv_autoencoder_encoder.model")

**Load model**

In [None]:
# Restore both networks from disk: the first path is the full autoencoder,
# the second the encoder-only sub-model.
autoencoder, encoder = (
    load_model(path)
    for path in (
        "Models/conv_autoencoder.model",
        "Models/conv_autoencoder_encoder.model"))

**Display output**

In [None]:
n = 10
np.random.seed(34)

encoded_imgs = encoder.predict(X_test)
decoded_imgs = autoencoder.predict(X_test)

# One figure row per stage: (image source, 2-D shape used for display).
# Row 0: original image, row 1: encoded image, row 2: reconstructed image.
rows = (
    (X_test, (28, 28)),
    (encoded_imgs, (4, 16)),
    (decoded_imgs, (28, 28)),
)

plt.figure(figsize = (18, 4))

for i in range(n):
    # A single random test index per column, shared by all three rows.
    j = np.random.randint(0, len(X_test))
    for r, (images, shape) in enumerate(rows):
        ax = plt.subplot(len(rows), n, r*n + i + 1)
        if r == 0:
            # Only the top row is labelled with the sample index.
            ax.set_title(f"[{j}]")
        plt.imshow(images[j].reshape(shape))
        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)

plt.show()

In [None]:
n = 4

def seed(k):
    """Arbitrary mapping from row index k to some RNG seed."""
    return int(99/(2*k+1)*13)

# One figure per k, each a single row of n randomly chosen encodings.
for k in range(8):
    np.random.seed(seed(k))
    plt.figure(figsize = (18, 4))
    for position in range(1, n + 1):
        j = np.random.randint(0, len(X_test))
        ax = plt.subplot(1, n, position)
        plt.imshow(encoded_imgs[j].reshape(4, 16))
        ax.set_title(f"[{j}]")
        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)
    plt.show()