In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import sys
sys.path.append("..")

In [None]:
import numpy as np
from scipy.stats import norm

from MNIST_VAE import Hyper, MnistVae

In [None]:
from keras.datasets import mnist

# Load the MNIST digits the VAE is trained on. y_test is kept because it colours
# the latent-space scatter plot further down; y_train is not used.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale pixel intensities from [0, 255] to float32 in [0, 1], then flatten every
# 28x28 image into one row vector so each split becomes (n_samples, 784).
x_train = (x_train.astype('float32') / 255.).reshape(len(x_train), -1)
x_test = (x_test.astype('float32') / 255.).reshape(len(x_test), -1)

In [None]:
# Hyper-parameters: only `epochs` is set explicitly here; `batch_size` and the rest
# come from Hyper's own defaults (defined in MNIST_VAE, not visible here).
h = Hyper(epochs=50)
model = MnistVae(h)

# Train as an autoencoder: the inputs double as the targets, and the test split
# provides the per-epoch validation loss.
history = model.fit(
    x_train,
    x_train,
    validation_data=(x_test, x_test),
    batch_size=h.batch_size,
    epochs=h.epochs,
    shuffle=True,
)

In [None]:
# Plot every per-epoch series Keras recorded during fit(), typically 'loss' and
# 'val_loss' since validation_data was passed. This replaces dumping the raw
# dict of lists.
fig, ax = plt.subplots(figsize=(6, 4))
for metric_name, values in history.history.items():
    ax.plot(range(1, len(values) + 1), values, label=metric_name)
ax.set(xlabel='epoch', ylabel='value', title='Training history')
ax.legend()
plt.show()

In [None]:
# display a 2D plot of the digit classes in the latent space
# Only the first two latent coordinates are plotted. The manifold below decodes
# 2-element z vectors, so the latent space is presumably 2-D.
x_test_encoded = model.encoder.predict(x_test, batch_size=h.batch_size)
fig, ax = plt.subplots(figsize=(6, 6))
# Digit labels are categorical, so use a qualitative colormap (tab10 has exactly
# 10 distinct colours) rather than the default sequential one. Small markers
# reduce over-plotting of the ~10k test points.
points = ax.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1],
                    c=y_test, cmap='tab10', s=4)
fig.colorbar(points, ax=ax, label='digit class')
ax.set(xlabel='z[0]', ylabel='z[1]', title='Test digits in latent space')
plt.show()

# display a 2D manifold of the digits
n = 15  # figure with 15x15 digits
digit_size = 28
# Linearly spaced coordinates on the unit square are transformed through the inverse
# CDF (ppf) of the Gaussian to produce values of the latent variables z, since the
# prior of the latent space is Gaussian. The 0.05/0.95 bounds avoid ppf(0) = -inf and
# ppf(1) = +inf.
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
# Reversed so that image row 0, which imshow draws at the TOP, holds the largest z[1].
# This matches the upward-pointing y-axis of the latent scatter plot; without the
# reversal the manifold appears vertically flipped.
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))[::-1]

# The tile at (row i, column j) decodes z = (grid_x[j], grid_y[i]). Build all n*n
# latent points in row-major order and decode them in ONE batched predict call
# instead of n*n separate calls.
z_x, z_y = np.meshgrid(grid_x, grid_y)
z_grid = np.column_stack([z_x.ravel(), z_y.ravel()])
x_decoded = model.generator.predict(z_grid)

# Tile the decoded digits into a single image (each decoded sample has
# digit_size**2 values):
# (n*n, 784) -> (n, n, 28, 28) -> (n, 28, n, 28) -> (n*28, n*28).
figure = (x_decoded.reshape(n, n, digit_size, digit_size)
          .transpose(0, 2, 1, 3)
          .reshape(n * digit_size, n * digit_size))

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.title('Digits decoded over a grid of latent values')
plt.axis('off')
plt.show()

In [None]:
# Round-trip a few test digits through the trained VAE: encode them into the
# latent space, then decode back to pixel space.
n_examples = 3
encoding = model.encode(x_test[:n_examples])
decoding = model.generate(encoding)

# Show every encoded example, not just the first one. The top row holds the
# input digits and the bottom row holds their reconstructions.
fig, axes = plt.subplots(2, n_examples, figsize=(2 * n_examples, 4), squeeze=False)
for k in range(n_examples):
    axes[0, k].imshow(x_test[k].reshape((28, 28)), cmap='Greys_r')
    axes[0, k].set_title('input')
    axes[1, k].imshow(decoding[k].reshape((28, 28)), cmap='Greys_r')
    axes[1, k].set_title('reconstruction')
for ax in axes.ravel():
    ax.axis('off')
plt.show()

In [None]:
import os

# Persist the trained model. The target directory is created first, because
# writing the HDF5 file into a missing directory fails (e.g. on a fresh checkout
# with no ../models folder).
MODEL_PATH = '../models/mnist_test.h5'
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
model.save(MODEL_PATH)

In [None]:
# Rebuild a fresh VAE with the same hyper-parameters, then restore the weights
# written by model.save() above into it.
weights_path = '../models/mnist_test.h5'
model2 = MnistVae(h)
model2.load_weights(weights_path)

In [None]:
# Round-trip the same test digits through the RELOADED model to confirm that the
# saved weights reproduce the trained model.
n_examples = 3
encoding = model2.encode(x_test[:n_examples])
decoding = model2.generate(encoding)

# Numeric check instead of eyeballing: on identical latent codes, the reloaded
# decoder should match the original one.
# NOTE(review): assumes MnistVae.generate has no sampling step (is deterministic)
# — confirm in MNIST_VAE.
np.testing.assert_allclose(decoding, model.generate(encoding), rtol=1e-5, atol=1e-6)

# The top row holds the input digits and the bottom row holds the reloaded
# model's reconstructions.
fig, axes = plt.subplots(2, n_examples, figsize=(2 * n_examples, 4), squeeze=False)
for k in range(n_examples):
    axes[0, k].imshow(x_test[k].reshape((28, 28)), cmap='Greys_r')
    axes[0, k].set_title('input')
    axes[1, k].imshow(decoding[k].reshape((28, 28)), cmap='Greys_r')
    axes[1, k].set_title('reloaded')
for ax in axes.ravel():
    ax.axis('off')
plt.show()