In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt

import keras
from keras import layers
from keras import backend as K
from keras import regularizers
from keras.datasets import mnist
from keras.losses import mse
from keras.models import load_model
from keras.utils import plot_model
%matplotlib notebook

*Let's prepare our input data first. We will normalize all pixel values to the range [0, 1] and flatten the 28x28 images into vectors of size 784.*

In [None]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Pixel intensities arrive as uint8 in [0, 255]; rescale both splits to float32 in [0, 1].
x_train, x_test = (split.astype('float32') / 255 for split in (x_train, x_test))

# Flatten every image into a single feature vector (28 * 28 = 784 values per sample).
x_train, x_test = (split.reshape(-1, np.prod(split.shape[1:])) for split in (x_train, x_test))

In [None]:
# Encoder: map each flattened image to the parameters of a diagonal Gaussian
# over the latent space.

original_dim = 28 * 28   # length of a flattened MNIST image
intermediate_dim = 64    # width of the hidden Dense layer
latent_dim = 2           # 2-D latent space, so it can be scattered directly

inputs = keras.Input(shape=(original_dim,))
h = layers.Dense(intermediate_dim, activation='relu')(inputs)
z_mean = layers.Dense(latent_dim)(h)
# NOTE: despite its name, the KL term of the VAE loss treats this head as the
# log-variance of q(z|x) (it uses exp(z_log_sigma) as the variance).
z_log_sigma = layers.Dense(latent_dim)(h)

def sampling(args):
    """Sample a latent point with the reparameterization trick.

    Args:
        args: pair ``(z_mean, z_log_var)`` of tensors with the same shape,
            as produced by the encoder's two Dense heads. The second tensor
            is interpreted as the *log-variance*, matching the closed-form
            KL term used in the VAE loss.

    Returns:
        Tensor shaped like ``z_mean``: ``z_mean + sigma * epsilon`` with
        ``sigma = exp(0.5 * z_log_var)`` and ``epsilon ~ N(0, I)``.
    """
    z_mean, z_log_var = args
    # Unit-variance noise: with any other stddev the samples would not follow
    # the posterior whose KL divergence the loss penalises.
    epsilon = K.random_normal(shape=K.shape(z_mean), mean=0., stddev=1.0)
    # exp(0.5 * log-variance) is the standard deviation.
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

# Sample z from q(z|x) using the encoder heads.
z = layers.Lambda(sampling)([z_mean, z_log_sigma])

# Encoder: image -> [z_mean, z_log_sigma, sampled z]
encoder = keras.Model(inputs, [z_mean, z_log_sigma, z], name='encoder')

# Decoder: latent point -> flattened image with pixel values in [0, 1] (sigmoid).
latent_inputs = keras.Input(shape=(latent_dim,), name='z_sampling')
decoder_hidden = layers.Dense(intermediate_dim, activation='relu')(latent_inputs)
decoder_outputs = layers.Dense(original_dim, activation='sigmoid')(decoder_hidden)
decoder = keras.Model(latent_inputs, decoder_outputs, name='decoder')

# End-to-end VAE: encode, take the sampled z (encoder output index 2), decode.
outputs = decoder(encoder(inputs)[2])
vae = keras.Model(inputs, outputs, name='vae_mlp')

# Loss = reconstruction term + KL divergence regularization term.
# mse averages over the pixels, so multiply by original_dim to get a per-image sum.
reconstruction_loss = mse(inputs, outputs) * original_dim
# Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian, with z_log_sigma as log-variance.
kl_loss = -0.5 * K.sum(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)
# No `loss=` here: the objective is attached via add_loss above.
vae.compile(optimizer='adam')

EPOCHS = 100
BATCH_SIZE = 256
vae.fit(x_train, x_train,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        shuffle=True,
        validation_data=(x_test, x_test))

# Saving to HDF5 does not create missing parent directories, so ensure it exists.
MODEL_DIR = './models'
os.makedirs(MODEL_DIR, exist_ok=True)
vae.save(os.path.join(MODEL_DIR, 'vae.h5'))
encoder.save(os.path.join(MODEL_DIR, 'vae_encoder.h5'))
decoder.save(os.path.join(MODEL_DIR, 'vae_decoder.h5'))

In [None]:
n = 10  # how many test digits to display

# Reconstruct the first n test digits. Decoding z_mean (encoder output 0)
# instead of a random sample of z gives a deterministic reconstruction.
z_mean_test = encoder.predict(x_test[:n])[0]
decoded_imgs = decoder.predict(z_mean_test)

# squeeze=False keeps `axes` 2-D even when n == 1.
fig, axes = plt.subplots(2, n, figsize=(6, 4), squeeze=False)
fig.suptitle('Top: original test digits, bottom: VAE reconstructions')
for i in range(n):
    # Row 0 shows the originals, row 1 the reconstructions.
    for row, images in enumerate((x_test, decoded_imgs)):
        ax = axes[row, i]
        ax.imshow(images[i].reshape(28, 28), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()

In [None]:
# Encode the whole test set; encoder output 0 is z_mean, the deterministic
# 2-D latent position of each digit.
encoded_imgs = encoder.predict(x_test)[0]

fig, ax = plt.subplots(1, 2)
ax[0].scatter(x=encoded_imgs[:, 0], y=encoded_imgs[:, 1], c=y_test, s=8, cmap='tab10')
ax[0].set_title('Latent space (test set)')
ax[1].set_title('Decoded latent point')
# Create the image artist once and only swap its data per event, instead of
# stacking a new image on ax[1] for every mouse movement. The decoder ends in a
# sigmoid, so a fixed [0, 1] color range fits its output.
latent_img = ax[1].imshow(np.zeros((28, 28)), cmap='gray', vmin=0, vmax=1)


def onclick(event):
    """Decode the latent point under the mouse and show it in the right panel.

    Connected to ``motion_notify_event`` below, so despite its name it runs on
    every mouse move. Events outside the latent scatter axes are ignored: there
    ``event.xdata``/``event.ydata`` are None (outside any axes) or coordinates
    of the right-hand image, not points in latent space.
    """
    if event.inaxes is not ax[0]:
        return
    latent_vector = np.array([[event.xdata, event.ydata]])
    # verbose=0 avoids printing a progress bar on every mouse move.
    decoded_img = decoder.predict(latent_vector, verbose=0).reshape(28, 28)
    latent_img.set_data(decoded_img)
    fig.canvas.draw_idle()


cid = fig.canvas.mpl_connect('motion_notify_event', onclick)
plt.show()