# Quick training playground for an autoencoder

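Inspired by https://www.reddit.com/r/deeplearning/comments/jkci6f/exploring_mnist_latent_space/. This notebook trains a convolutional autoencoder on Fashion-MNIST with a 2-D latent space, so the learned latent codes can be plotted directly.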


# 1. Autoencoder

In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import tensorflow as tf
from keras import models, layers, applications, backend as K
from tensorflow.keras.losses import MeanSquaredError, KLDivergence
from plotly import express as px

# Seed NumPy and TensorFlow so reruns (weight init, dropout, shuffling) are
# more reproducible.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Load Fashion-MNIST. Labels are discarded: the autoencoder is trained to
# reproduce its own input.
(x_train, _), (x_test, _) = tf.keras.datasets.fashion_mnist.load_data()

# Scale pixels to [0, 1] (the decoder ends in a sigmoid) and add an explicit
# channel axis. The arrays then match the encoder's (28, 28, 1) input and the
# decoder's (28, 28, 1) output exactly, instead of relying on Keras to
# squeeze/expand the mismatched rank implicitly.
x_train = (x_train.astype('float32') / 255.0)[..., np.newaxis]
x_test = (x_test.astype('float32') / 255.0)[..., np.newaxis]

assert x_train.shape[1:] == (28, 28, 1), x_train.shape
assert x_test.shape[1:] == (28, 28, 1), x_test.shape

In [2]:
# Encoder: a small convolutional stack mapping a 28x28x1 image to a 2-D code.
# Layers are created in exactly the same order as a flat layer-by-layer
# listing would, so the auto-generated layer names stay the same.


def _conv_bn(tensor, filters):
    """Apply a 3x3 same-padded ReLU convolution, then batch normalization."""
    conv_out = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
    return layers.BatchNormalization()(conv_out)


input_img = layers.Input(shape=(28, 28, 1))

# Stage 1 (28x28, 32 filters): dropout only after the first convolution.
h = _conv_bn(input_img, 32)
h = layers.Dropout(0.2)(h)
h = _conv_bn(h, 32)
h = layers.MaxPooling2D((2, 2), padding='same')(h)

# Stage 2 (14x14, 64 filters).
h = _conv_bn(h, 64)
h = _conv_bn(h, 64)
h = layers.MaxPooling2D((2, 2), padding='same')(h)

# Stage 3 (7x7, 128 filters), then a dense bottleneck.
h = _conv_bn(h, 128)
h = layers.Flatten()(h)
h = layers.Dense(128, activation='relu')(h)
h = layers.BatchNormalization()(h)

# Sigmoid keeps each latent coordinate in (0, 1).
encoded = layers.Dense(2, activation='sigmoid')(h)

# Build the encoder
encoder = models.Model(input_img, encoded, name='encoder')
encoder.summary()

2023-08-04 11:27:32.950312: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


Model: "encoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 28, 28, 32)        320       
                                                                 
 batch_normalization (BatchN  (None, 28, 28, 32)       128       
 ormalization)                                                   
                                                                 
 dropout (Dropout)           (None, 28, 28, 32)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 28, 28, 32)        9248      
                                                                 
 batch_normalization_1 (Batc  (None, 28, 28, 32)       128       
 hNormalization)                                           

In [4]:
# Decoder: maps a 2-D latent code back to a 28x28x1 image with values in (0, 1).
# It only receives the latent vector, so it is a plain upsampling stack with
# no skip connections from the encoder.
input_latent = layers.Input(shape=(2,))

# Project the code to a 7x7x128 feature map.
x = layers.Dense(7 * 7 * 128, activation='relu')(input_latent)
x = layers.Reshape((7, 7, 128))(x)

# 7x7 -> 14x14
x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = layers.BatchNormalization()(x)
x = layers.UpSampling2D((2, 2))(x)

# 14x14 -> 28x28
x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = layers.BatchNormalization()(x)
x = layers.UpSampling2D((2, 2))(x)

# Refine at 28x28, then emit one sigmoid channel so pixels match the [0, 1]
# scaling of the input images.
x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = layers.BatchNormalization()(x)
decoded = layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

# Build the decoder
decoder = models.Model(input_latent, decoded, name='decoder')
decoder.summary()

Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 2)]               0         
                                                                 
 dense_5 (Dense)             (None, 6272)              18816     
                                                                 
 reshape_1 (Reshape)         (None, 7, 7, 128)         0         
                                                                 
 conv2d_10 (Conv2D)          (None, 7, 7, 128)         147584    
                                                                 
 batch_normalization_11 (Bat  (None, 7, 7, 128)        512       
 chNormalization)                                                
                                                                 
 up_sampling2d_2 (UpSampling  (None, 14, 14, 128)      0         
 2D)                                                       

                                                                 
 batch_normalization_12 (Bat  (None, 14, 14, 64)       256       
 chNormalization)                                                
                                                                 
 conv2d_12 (Conv2D)          (None, 14, 14, 64)        36928     
                                                                 
 batch_normalization_13 (Bat  (None, 14, 14, 64)       256       
 chNormalization)                                                
                                                                 
 up_sampling2d_3 (UpSampling  (None, 28, 28, 64)       0         
 2D)                                                             
                                                                 
 conv2d_13 (Conv2D)          (None, 28, 28, 32)        18464     
                                                                 
 batch_normalization_14 (Bat  (None, 28, 28, 32)       128       
 chNormali

In [None]:
# Run the encoder ONCE and reuse its output. The latent code `z` feeds both
# the decoder (reconstruction output) and the density loss (latent output).
# Calling `encoder(input_img)` twice would create two independent passes: each
# samples its own Dropout mask and updates BatchNormalization statistics again.
# The density loss would then see a different code from the one the decoder
# reconstructs from, at twice the compute.
z = encoder(input_img)
reconstruction = decoder(z)

# Outputs: [reconstructed image, latent code].
autoencoder = models.Model(inputs=input_img, outputs=[reconstruction, z])
autoencoder.summary()

In [None]:
# Inspect the model's symbolic outputs: [reconstructed image, latent code].
autoencoder.output

In [None]:
def density_loss(z_true, z_pred, radius=0.05, temperature=0.25):
    """Differentiable penalty on how crowded the latent codes in a batch are.

    For each code in the batch, softly counts how many *other* codes lie
    within ``radius`` of it, and returns the mean of those counts. Minimizing
    it pushes codes apart.

    A hard indicator such as ``tf.cast(distance < radius, tf.float32)`` has no
    gradient, so it cannot influence training. Instead, a sigmoid of the
    squared distance is used. It is centred on ``radius**2`` and becomes a
    sharper step as ``temperature`` -> 0.

    Args:
        z_true: Ignored. Keras needs a target for every output, but this term
            depends only on the predictions.
        z_pred: Latent codes, indexed as (batch, latent_dim).
        radius: Neighbourhood radius in latent-space units.
        temperature: Width of the soft threshold, relative to ``radius**2``.

    Returns:
        Scalar tensor: mean soft neighbour count per code.
    """
    z_pred = tf.convert_to_tensor(z_pred)

    # Pairwise squared distances, shape (batch, batch). Squared distances avoid
    # the undefined gradient of the Euclidean norm at zero (the diagonal).
    diffs = z_pred[:, tf.newaxis, :] - z_pred[tf.newaxis, :, :]
    sq_dist = tf.reduce_sum(tf.square(diffs), axis=-1)

    r2 = tf.cast(radius, z_pred.dtype) ** 2
    inside = tf.sigmoid((r2 - sq_dist) / (temperature * r2))

    # Drop each code's distance to itself so only true neighbours are counted.
    batch = tf.shape(z_pred)[0]
    inside = inside * (1.0 - tf.eye(batch, dtype=z_pred.dtype))

    num_neighbors = tf.reduce_sum(inside, axis=1)
    return tf.reduce_mean(num_neighbors)

In [None]:
def mse_loss(x_true, x_pred):
    """Mean squared reconstruction error over the whole batch.

    ``x_true`` is cast to ``x_pred``'s dtype and reshaped to its shape first.
    Without that, targets lacking a channel axis, (batch, 28, 28), would
    silently broadcast against predictions of shape (batch, 28, 28, 1) into
    (batch, 28, 28, 28) and produce a wrong value. The reshape fails loudly if
    the element counts differ.

    Returns:
        Scalar tensor.
    """
    x_pred = tf.convert_to_tensor(x_pred)
    x_true = tf.reshape(tf.cast(x_true, x_pred.dtype), tf.shape(x_pred))
    return tf.reduce_mean(tf.square(x_true - x_pred))

In [None]:
# Adam optimizer, learning rate 1e-3.
opt = tf.keras.optimizers.Adam(learning_rate=0.001)

# One loss per output: reconstruction (weight 1.0) and a small latent-density
# regularizer (weight 0.001).
autoencoder.compile(optimizer=opt, loss=[mse_loss, density_loss], loss_weights=[1.0, 0.001])

# Callbacks: stop after 20 epochs without val_loss improvement, and keep the
# best-val_loss weights on disk.
early_stop_cb = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath='checkpoint_autoencoder.h5',
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_best_only=True)

# density_loss ignores its target, but Keras still needs one per output. Use a
# small zero array shaped like the latent code instead of shipping a second
# full copy of the images as the target.
latent_dim = encoder.output.shape[-1]
z_dummy_train = np.zeros((len(x_train), latent_dim), dtype='float32')
z_dummy_test = np.zeros((len(x_test), latent_dim), dtype='float32')

# Train the model
history = autoencoder.fit(
    x_train,
    [x_train, z_dummy_train],
    epochs=75,
    batch_size=256,
    validation_data=(x_test, [x_test, z_dummy_test]),
    callbacks=[checkpoint_cb, early_stop_cb],
)

In [None]:
# One curve per series Keras recorded during fit (training and validation;
# total and per-output losses), plotted against the epoch index.
history_df = pd.DataFrame(history.history)
fig = px.line(
    history_df,
    labels={'index': 'epoch', 'value': 'loss', 'variable': 'series'},
    title='Training history',
)
fig

# 2. Evaluate encoder mapping

In [None]:
# Project both splits into the latent space (train first, then test).
preds_train, preds_test = (encoder.predict(split) for split in (x_train, x_test))

In [None]:
# Latent codes of the held-out test images. Equal axis scaling keeps on-screen
# distances faithful to latent-space distances.
fig = px.scatter(
    x=preds_test[:, 0],
    y=preds_test[:, 1],
    labels={'x': 'z[0]', 'y': 'z[1]'},
    title='Latent space: test set',
)
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig

In [None]:
# Latent codes of the training images. Equal axis scaling keeps on-screen
# distances faithful to latent-space distances.
fig = px.scatter(
    x=preds_train[:, 0],
    y=preds_train[:, 1],
    labels={'x': 'z[0]', 'y': 'z[1]'},
    title='Latent space: training set',
)
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig

# 3. Save the best models

In [None]:
# Reload the lowest-val_loss weights written by the ModelCheckpoint callback
# before exporting. Depending on the Keras version,
# EarlyStopping(restore_best_weights=True) may only restore them when training
# actually stops early; otherwise the in-memory weights are the last epoch's.
# The encoder and decoder are sub-models of `autoencoder`, so loading into it
# updates both.
checkpoint_path = Path('checkpoint_autoencoder.h5')
if checkpoint_path.exists():
    autoencoder.load_weights(str(checkpoint_path))

# Create the output directory if needed; saving into a missing directory fails.
save_dir = Path('saved_models')
save_dir.mkdir(parents=True, exist_ok=True)

encoder.save(str(save_dir / 'best_encoder.h5'))
decoder.save(str(save_dir / 'best_decoder.h5'))