# additional deep autoencoder example (keras)

Separating the encoder and decoder parts of a deep autoencoder is not that obvious in Keras (see issue https://github.com/benjaminirving/mlseminars-autoencoders/issues/1 ). Below is a simple example of how to extract the encoder and the decoder as separate models.

You basically have to define the links between the layers you are interested in, so that Keras knows which layer is the beginning and which is the end (perhaps in the future Keras will define an `input_layer` input under `Model`).

The code basically looks like this:

decoder1 = autoencoder.layers[-3]
decoder2 = autoencoder.layers[-2]
decoder3 = autoencoder.layers[-1]

decoder = Model(input=encoded_input, output=decoder3(decoder2(decoder1(encoded_input))))

Below is a full working example (based on the autoencoder from https://blog.keras.io/building-autoencoders-in-keras.html).

In [None]:
from keras.layers import Input, Dense
from keras.models import Model
from keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt

# Size of the bottleneck, i.e. the encoded representation of each input.
encoding_dim = 32

# Model definition: a symmetric stack of fully connected layers,
# 784 -> 128 -> 64 -> encoding_dim -> 64 -> 128 -> 784.
input_img = Input(shape=(784,))

# Encoder half: after the loop, `encoded` is the bottleneck tensor.
encoded = input_img
for units in (128, 64, encoding_dim):
    encoded = Dense(units, activation='relu')(encoded)

# Decoder half: mirrors the encoder; only the final layer uses a sigmoid.
decoded = encoded
for units, activation in ((64, 'relu'), (128, 'relu'), (784, 'sigmoid')):
    decoded = Dense(units, activation=activation)(decoded)

# Training data: MNIST digits. The labels are discarded (`_`); only the
# images are kept.
(x_train, _), (x_test, _) = mnist.load_data()

# Convert to float32, scale by 1/255 and flatten each image into one vector.
x_train = (x_train.astype('float32') / 255.).reshape((len(x_train), -1))
x_test = (x_test.astype('float32') / 255.).reshape((len(x_test), -1))

# Report the resulting array shapes (train first, then test).
for split in (x_train, x_test):
    print(split.shape)

# Model construction and compilation
# NOTE(review): `input=` / `output=` look like the Keras 1 keyword names;
# Keras 2 spells them `inputs=` / `outputs=` -- confirm the target version.

# Full autoencoder: maps an input image to its reconstruction.
autoencoder = Model(input=input_img, output=decoded)
autoencoder.summary()

# Encoder on its own: maps an input image to the bottleneck tensor.
encoder = Model(input=input_img, output=encoded)

# Placeholder for an already-encoded (encoding_dim-dimensional) input.
encoded_input = Input(shape=(encoding_dim,))

# The last three layers of the autoencoder make up the decoder. Calling
# those same layer objects on the new placeholder links them into a
# standalone model with explicit beginning and end layers.
decoder1, decoder2, decoder3 = autoencoder.layers[-3:]
decoder_output = decoder3(decoder2(decoder1(encoded_input)))
decoder = Model(input=encoded_input, output=decoder_output)

autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

# Model training: the inputs are also the targets, and the test images
# serve as validation data.
# NOTE(review): `nb_epoch` looks like the Keras 1 keyword; Keras 2 spells it
# `epochs` -- confirm the target version.
fit_options = dict(
    nb_epoch=20,
    batch_size=256,
    shuffle=True,
    validation_data=(x_test, x_test),
)
autoencoder.fit(x_train, x_train, **fit_options)


# Testing

# Encode and then decode digits from the *test* set, running the encoder
# and the decoder as two separate models.
encoded_imgs = encoder.predict(x_test)
decoded_imgs = decoder.predict(encoded_imgs)

n = 10  # how many digits we will display
plt.figure(figsize=(20, 4))
for i in range(n):
    # Column i: original on the top row, reconstruction on the bottom row.
    for row, images in enumerate((x_test, decoded_imgs)):
        ax = plt.subplot(2, n, row * n + i + 1)
        plt.imshow(images[i].reshape(28, 28))
        plt.gray()
        # Hide both axes so that only the digit itself is shown.
        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)
plt.show()


(60000, 784)
(10000, 784)
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_3 (InputLayer)             (None, 784)           0                                            
____________________________________________________________________________________________________
dense_7 (Dense)                  (None, 128)           100480      input_3[0][0]                    
____________________________________________________________________________________________________
dense_8 (Dense)                  (None, 64)            8256        dense_7[0][0]                    
____________________________________________________________________________________________________
dense_9 (Dense)                  (None, 32)            2080        dense_8[0][0]                    
_________________________________________________________________