In [0]:
import tensorflow as tf
from tensorflow import keras

# Record which TensorFlow release produced the outputs below; results of the
# cells that follow may depend on it.
print(f"{tf.__version__}")

In [0]:
import random

import numpy as np

# Seed every RNG the run may draw from so Restart & Run All reproduces results.
# Which generator Keras uses for weight initialisation and `fit(shuffle=True)`
# depends on the installed version (TF global generator for tf.keras-backed
# Keras, Python's `random` for default seeds in newer Keras), so seed all three.
# NOTE(review): GPU kernels may still be non-deterministic — confirm if exact
# bit-for-bit reproducibility is required.
SEED = 7
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [0]:
from keras.datasets import mnist

# load_data() returns ((train_images, train_labels), (test_images, test_labels));
# flatten the two pairs into the four names used throughout the notebook.
X_train, y_train, X_test, y_test = (array for split in mnist.load_data() for array in split)

In [0]:
# Per-image (height, width), taken from axes 1 and 2 of the training array;
# each image is later flattened into height * width input features.
image_shape = X_train.shape[1:3]
input_size = image_shape[0] * image_shape[1]
print('input size -', input_size)

In [0]:
# Map pixel values into [0, 1] float32 so they match the decoder's sigmoid
# output range and the binary cross-entropy loss.
# NOTE(review): assumes raw pixels are in 0..255 — true for the Keras MNIST loader.
X_train, X_test = (images.astype('float32') / 255. for images in (X_train, X_test))

In [0]:
# Dense layers expect one feature vector per sample: collapse each 2-D image
# into a row of `input_size` values, keeping the sample count on axis 0.
X_train, X_test = (images.reshape(len(images), input_size) for images in (X_train, X_test))

In [0]:
from keras.layers import Input, Dense
from keras.models import Sequential, Model

In [0]:
# Width of the bottleneck code produced by the encoder and consumed by the decoder.
embedding_size: int = 32

In [0]:
# Encoder: maps a flattened image (input_size values) through ReLU Dense layers
# of width 128 -> 64 -> embedding_size, yielding the compressed code.
encoder = Sequential(
    [
        Dense(128, activation='relu', input_shape=(input_size,), name='encoder_layer1'),
        Dense(64, activation='relu', name='encoder_layer2'),
        Dense(embedding_size, activation='relu', name='encoder_layer3'),
    ],
    name='encoder',
)
encoder.summary()

In [0]:
# Decoder: mirrors the encoder (embedding_size -> 64 -> 128 -> input_size).
# The final sigmoid keeps every reconstructed pixel in (0, 1), matching the
# [0, 1]-scaled inputs and the binary cross-entropy loss used at compile time.
decoder = Sequential(
    [
        Dense(64, activation='relu', input_shape=(embedding_size,), name='decoder_layer1'),
        Dense(128, activation='relu', name='decoder_layer2'),
        Dense(input_size, activation='sigmoid', name='decoder_layer3'),
    ],
    name='decoder',
)
decoder.summary()

In [0]:
# Chain encoder and decoder into one trainable model: image -> code -> image.
# The sub-models are shared, so training this also trains `encoder` and `decoder`.
autoencoder = Sequential([encoder, decoder], name='autoencoder')
autoencoder.summary()

In [0]:
from keras.optimizers import Adadelta

# Adadelta is meant to run at learning rate 1.0 (the paper's form and the old
# Keras default). Recent Keras releases default `learning_rate` to 0.001, so the
# bare 'adadelta' string trains very slowly and the reconstructions stay blurry
# after 100 epochs. Pin 1.0 explicitly (the `learning_rate` keyword needs Keras >= 2.3).
autoencoder.compile(optimizer=Adadelta(learning_rate=1.0), loss='binary_crossentropy')

In [0]:
# Training schedule: number of passes over the training set, and samples per gradient step.
epochs, batch_size = 100, 256

In [0]:
# The reconstruction target is the input itself. The test images serve as
# validation data, so held-out reconstruction loss is reported after each epoch.
history = autoencoder.fit(
    x=X_train,
    y=X_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(X_test, X_test),
    shuffle=True,
)

In [0]:
# Reconstruct every test image; row i of the result corresponds to X_test[i].
decoded_images = autoencoder.predict(x=X_test)

In [0]:
import matplotlib.pyplot as plot

def _draw_image(image, image_shape, number_of_rows, number_of_columns, plot_index):
  """Draw one image, reshaped to `image_shape`, into grid cell `plot_index` with hidden axes."""
  ax = plot.subplot(number_of_rows, number_of_columns, plot_index)
  # Per-image grayscale colormap; unlike plot.gray() this leaves the global default untouched.
  ax.imshow(image.reshape(image_shape), cmap='gray')
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)


def show_images(input_images, processed_images, image_shape, number_of_samples=10):
  """Plot sample images, optionally with their processed counterparts underneath.

  The first row shows the first `number_of_samples` entries of `input_images`.
  When `processed_images` is given, a second row shows the entries at the same
  indices, so each column compares an original with its processed (e.g.
  reconstructed) version.

  Args:
    input_images: sized, indexable collection of images. Each entry must be
      reshapeable to `image_shape` (e.g. the flattened rows of a 2-D array).
    processed_images: collection indexed like `input_images`, or None to plot
      only the input row.
    image_shape: shape each entry is reshaped to before display, e.g. (28, 28).
    number_of_samples: number of columns to draw. Clamped to the number of
      images available; nothing is drawn when the result is not positive.
  """
  if processed_images is None:
    rows = [input_images]
    available = len(input_images)
  else:
    rows = [input_images, processed_images]
    available = min(len(input_images), len(processed_images))

  # Avoid an IndexError when more samples are requested than exist.
  number_of_samples = min(number_of_samples, available)
  if number_of_samples <= 0:
    return
  number_of_rows = len(rows)

  # About 2 inches per cell: width follows the column (sample) count, height
  # follows the row count. With 10 samples x 2 rows this is the original 20x4.
  plot.figure(figsize=(2 * number_of_samples, 2 * number_of_rows))

  for row_index, images in enumerate(rows):
    for sample_index in range(number_of_samples):
      # subplot indices are 1-based and run left-to-right, then top-to-bottom.
      plot_index = row_index * number_of_samples + sample_index + 1
      _draw_image(images[sample_index], image_shape,
                  number_of_rows, number_of_samples, plot_index)
  plot.show()

In [0]:
# Top row: original test digits; bottom row: the autoencoder's reconstructions.
show_images(X_test, processed_images=decoded_images, image_shape=image_shape)