# Quick training playground for an autoencoder

Inspired by https://www.reddit.com/r/deeplearning/comments/jkci6f/exploring_mnist_latent_space/


# 1. Autoencoder

In [42]:
import numpy as np
import tensorflow as tf
from keras import models, layers, applications
from plotly import express as px

# Load dataset (labels are discarded; only the images are used)
(x_train, _), (x_test, _) = tf.keras.datasets.fashion_mnist.load_data()

# Convert to float32, divide by 255, and add an explicit trailing channel axis
# so the arrays match the model's Input(shape=(28, 28, 1)) instead of relying
# on Keras to expand the missing dimension implicitly.
x_train = x_train.astype('float32')[..., np.newaxis] / 255.0
x_test = x_test.astype('float32')[..., np.newaxis] / 255.0

In [47]:
# Encoder: maps a 28x28x1 image to a 2-value latent code.
input_img = layers.Input(shape=(28, 28, 1))
x = input_img

# Downsampling stages: 3x3 conv -> batch norm -> 2x2 max pooling.
for n_filters in (8, 16, 24, 32):
    x = layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2, 2), padding='same')(x)

# Additional conv stages without pooling (spatial size is left unchanged).
for n_filters in (64, 128, 156):
    x = layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)

x = layers.Flatten()(x)

# Two-unit bottleneck; the sigmoid keeps each latent coordinate in (0, 1).
encoded = layers.Dense(2, activation='sigmoid')(x)

# Build the encoder
encoder = models.Model(input_img, encoded, name='encoder')
encoder.summary()

Model: "encoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_8 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d_11 (Conv2D)          (None, 28, 28, 8)         80        
                                                                 
 batch_normalization_10 (Bat  (None, 28, 28, 8)        32        
 chNormalization)                                                
                                                                 
 max_pooling2d_6 (MaxPooling  (None, 14, 14, 8)        0         
 2D)                                                             
                                                                 
 conv2d_12 (Conv2D)          (None, 14, 14, 16)        1168      
                                                                 
 batch_normalization_11 (Bat  (None, 14, 14, 16)       64  

In [54]:
# Decoder: maps a 2-value latent code back to a 28x28x1 image.
# NOTE: this is a plain sequential stack; no skip connections from the
# encoder are wired in.
input_latent = layers.Input(shape=(2,))

# Project the latent vector onto a 7x7 feature map with 32 channels.
x = layers.Dense(7 * 7 * 32, activation='relu')(input_latent)
x = layers.Reshape((7, 7, 32))(x)

# First stage: conv + batch norm at 7x7, then upsample to 14x14.
x = layers.Conv2D(24, (3, 3), activation='relu', padding='same')(x)
x = layers.BatchNormalization()(x)
x = layers.UpSampling2D((2, 2))(x)

# Two conv + batch norm stages at 14x14, then upsample to 28x28.
for n_filters in (16, 8):
    x = layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
x = layers.UpSampling2D((2, 2))(x)

# Output layer: one channel, sigmoid so values lie in (0, 1).
x = layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

# Build the decoder and print its layer summary
decoder = models.Model(input_latent, x, name='decoder')
decoder.summary()

Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_14 (InputLayer)       [(None, 2)]               0         
                                                                 
 dense_9 (Dense)             (None, 1568)              4704      
                                                                 
 reshape_6 (Reshape)         (None, 7, 7, 32)          0         
                                                                 
 conv2d_31 (Conv2D)          (None, 7, 7, 24)          6936      
                                                                 
 batch_normalization_25 (Bat  (None, 7, 7, 24)         96        
 chNormalization)                                                
                                                                 
 up_sampling2d_12 (UpSamplin  (None, 14, 14, 24)       0         
 g2D)                                                      

In [49]:
# End-to-end model: image -> encoder -> 2-value latent code -> decoder -> image.
# Both the encoder and decoder cells must have been executed before this one.
latent_code = encoder(input_img)
reconstruction = decoder(latent_code)

# Named explicitly (like 'encoder' and 'decoder') so summaries show
# "autoencoder" rather than an auto-generated "model_N".
autoencoder = models.Model(input_img, reconstruction, name='autoencoder')
autoencoder.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_8 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 encoder (Functional)        (None, 2)                 13018     
                                                                 
 decoder (Functional)        (None, 28, 28, 3)         8759      
                                                                 
Total params: 21,777
Trainable params: 21,549
Non-trainable params: 228
_________________________________________________________________


In [None]:
# Optional early stopping, currently disabled. To enable it, uncomment this
# block and the `callbacks=[cb]` argument in the fit() call below.
# cb = tf.keras.callbacks.EarlyStopping(
#     monitor="val_loss",
#     min_delta=0,
#     patience=20,
#     verbose=0,
#     mode="auto",
#     baseline=None,
#     restore_best_weights=True,
#     start_from_epoch=0,
# )

opt = tf.keras.optimizers.Adam(learning_rate=0.001)

# Compile the model.
# Reconstruction is a regression on continuous pixel intensities, so the
# classification metric 'accuracy' is not meaningful here; track mean absolute
# error alongside the MSE loss instead.
autoencoder.compile(optimizer=opt, loss='mean_squared_error', metrics=['mae'])

# Train the autoencoder: the input images are also the targets.
# NOTE: the test split doubles as validation data.
history = autoencoder.fit(
    x_train, x_train,
    epochs=80,
    batch_size=32,
    shuffle=True,
    validation_data=(x_test, x_test),
    # callbacks=[cb],
)

In [None]:
import pickle

# Persist the per-epoch training metrics recorded by fit() to disk.
training_log = history.history
with open('trainHistoryDict.pkl', 'wb') as history_file:
    pickle.dump(training_log, history_file)

In [None]:
import pandas as pd

# Training vs. validation loss per epoch, exported to a PDF file.
loss_curves = pd.DataFrame({
    'train_loss': history.history['loss'],
    'val_loss': history.history['val_loss'],
})
fig = px.line(loss_curves)
fig.write_image("train_val_loss.pdf")

# 2. Test encoder

In [None]:
# Run the encoder on both splits (train first, then test) to get latent codes.
preds_train, preds_test = (
    encoder.predict(images) for images in (x_train, x_test)
)

In [None]:
# Test-split latent codes: column 0 on the x-axis, column 1 on the y-axis.
latent_test = pd.DataFrame(preds_test)
px.scatter(latent_test, x=0, y=1)

In [None]:
# Train-split latent codes: column 0 on the x-axis, column 1 on the y-axis.
latent_train = pd.DataFrame(preds_train)
px.scatter(latent_train, x=0, y=1)