In [3]:
import tensorflow.keras as keras

In [4]:
# Build a vanilla (fully connected) autoencoder.
# input_dims is an integer containing the dimensions of the model input
# hidden_layers is a list containing the number of nodes for each hidden layer in the encoder, in order
# the decoder uses the same hidden layer sizes in reverse order
# latent_dims is an integer containing the dimensions of the latent space representation
# Example: encoder, decoder, auto = autoencoder(784, [128, 64], 32)

# Returns:
# encoder is the encoder model
# decoder is the decoder model
# auto is the full autoencoder model

def autoencoder(input_dims, hidden_layers, latent_dims):
  """Build and compile a vanilla (fully connected) autoencoder.

  Args:
    input_dims: integer, dimensionality of the flat model input.
    hidden_layers: iterable of integers, the number of nodes in each hidden
      layer of the encoder, in order. The decoder uses the same sizes in
      reverse order. May be empty, in which case the input feeds the latent
      layer directly and the latent input feeds the output layer directly.
    latent_dims: integer, dimensionality of the latent space representation.

  Returns:
    A tuple ``(encoder, decoder, auto)``:
      encoder: model mapping an input to its latent representation.
      decoder: model mapping a latent representation to a reconstruction.
      auto: the full autoencoder (decoder applied to the encoder output),
        compiled with the Adam optimizer and binary cross-entropy loss.
  """
  # Materialize once: the sizes are traversed twice (encoder, then reversed
  # for the decoder), which would silently exhaust a one-shot iterator.
  hidden_sizes = list(hidden_layers)

  # Encoder: input -> relu hidden layers -> relu latent layer.
  encoder_input = keras.layers.Input(shape=(input_dims,))
  encoded = encoder_input
  for units in hidden_sizes:
    encoded = keras.layers.Dense(units, activation='relu')(encoded)
  # `encoded` is bound before the loop, so an empty `hidden_layers` no longer
  # raises NameError here.
  latent = keras.layers.Dense(latent_dims, activation='relu')(encoded)
  encoder = keras.models.Model(encoder_input, latent)

  # Decoder: mirror of the encoder, with relu hidden layers in reverse order.
  decoder_input = keras.layers.Input(shape=(latent_dims,))
  decoded = decoder_input
  for units in reversed(hidden_sizes):
    decoded = keras.layers.Dense(units, activation='relu')(decoded)
  # Sigmoid keeps reconstructions in (0, 1), matching binary cross-entropy;
  # inputs are expected to be scaled to [0, 1] as well.
  reconstruction = keras.layers.Dense(input_dims, activation='sigmoid')(decoded)
  decoder = keras.models.Model(decoder_input, reconstruction)

  # Full autoencoder: route a fresh input through both sub-models. Calling a
  # model reuses its layers, so training `auto` also trains encoder/decoder.
  auto_input = keras.layers.Input(shape=(input_dims,))
  auto = keras.models.Model(inputs=auto_input,
                            outputs=decoder(encoder(auto_input)))
  auto.compile(optimizer='adam', loss='binary_crossentropy')

  return encoder, decoder, auto

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist

# Load MNIST (labels are not needed for an autoencoder), scale pixel values
# from [0, 255] to float32 in [0, 1], and flatten each 28x28 image into a
# 784-long vector for the dense layers.
(x_train, _), (x_test, _) = mnist.load_data()
x_train, x_test = (
    (images.astype('float32') / 255.).reshape((-1, 784))
    for images in (x_train, x_test)
)
SEED = 0          # seed for NumPy and TensorFlow
EPOCHS = 50       # full passes over the training set
BATCH_SIZE = 256  # samples per gradient step
N_SHOW = 10       # number of test digits to encode and reconstruct

# Seed before the model is built so weight initialization is repeatable
# (full determinism on every backend is not guaranteed).
np.random.seed(SEED)
tf.random.set_seed(SEED)

encoder, decoder, auto = autoencoder(784, [128, 64], 32)
# An autoencoder learns to reproduce its input, so the inputs double as targets.
history = auto.fit(x_train, x_train, epochs=EPOCHS, batch_size=BATCH_SIZE,
                   shuffle=True, validation_data=(x_test, x_test))

encoded = encoder.predict(x_test[:N_SHOW])
print(np.mean(encoded))
reconstructed = decoder.predict(encoded)

# Compare each test digit (top row) with its reconstruction (bottom row).
n_show = len(reconstructed)
# squeeze=False keeps `axes` 2-D even if only one digit is shown.
fig, axes = plt.subplots(2, n_show, figsize=(1.5 * n_show, 3.5), squeeze=False)
for i in range(n_show):
    for ax, image in ((axes[0, i], x_test[i]), (axes[1, i], reconstructed[i])):
        # MNIST digits are grayscale; the default colormap would false-color them.
        ax.imshow(image.reshape((28, 28)), cmap='gray')
        ax.axis('off')
fig.suptitle('Top: original test digits; bottom: autoencoder reconstructions')
plt.show()