In [1]:
'''Example of a VAE on the MNIST dataset using an MLP.
The VAE has a modular design: the encoder, decoder and VAE
are three models that share weights. After training the VAE model,
the encoder can be used to generate latent vectors.
The decoder can be used to generate MNIST digits by sampling the
latent vector from a Gaussian distribution with mean=0 and std=1.
# Reference
[1] Kingma, Diederik P., and Max Welling.
"Auto-encoding variational bayes."
https://arxiv.org/abs/1312.6114
'''
%matplotlib notebook

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import matplotlib.pyplot as plt
from keras.callbacks import ReduceLROnPlateau, EarlyStopping

from keras.datasets import mnist
from abyss_deep_learning.keras.autoencoder import VAE, config_gpu
config_gpu(0)


  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# MNIST dataset. The 60k training images are used as-is; the 10k official test
# images are split by alternating index into validation (even) and test (odd).
(x_train, __y_train), (x_test, y_test) = mnist.load_data()
image_size = x_train.shape[1]

val_rows = slice(0, None, 2)   # even positions -> validation set
test_rows = slice(1, None, 2)  # odd positions  -> held-out test set

# Append a trailing channel axis and scale pixel values by 1/255.
x_train = x_train[..., np.newaxis].astype('float32') / 255
# Both halves are built from the raw x_test before the name is rebound.
x_val, x_test = (
    x_test[val_rows, ..., np.newaxis].astype('float32') / 255,
    x_test[test_rows, ..., np.newaxis].astype('float32') / 255,
)
y_val, y_test = y_test[val_rows], y_test[test_rows]

# network parameters
batch_size = 20
epochs = 50  # NOTE(review): not read by the training cell, which passes its own epoch limit to fit()
print(x_val.shape)

(5000, 28, 28, 1)


In [3]:
# Build the VAE from abyss_deep_learning. Later cells rely on three attributes
# populated here: vae.vae (the model that is trained with fit), vae.encoder
# and vae.decoder (passed to plot_results).
# NOTE(review): the cell output suggests create_model() compiles self.vae with
# optimizer='nadam' and no explicit loss -- confirm in abyss_deep_learning.
vae = VAE()
vae.create_model()

  self.vae.compile(optimizer='nadam')


In [4]:
# Train the VAE on images only: targets are None for both training and
# validation data, so the reconstruction/KL loss is presumably attached inside
# the model (compile() appears in the previous cell's output without a loss).
# max_epochs is only an upper bound: ReduceLROnPlateau lowers the learning rate
# after 10 epochs without val_loss improvement, and EarlyStopping ends training
# after 30 such epochs.
max_epochs = 1000
# Keep the History so loss curves can be inspected after training.
history = vae.vae.fit(
    x=x_train, validation_data=(x_val, None),
    batch_size=batch_size, epochs=max_epochs, verbose=1,
    callbacks=[
        ReduceLROnPlateau(monitor='val_loss', patience=10, verbose=1),
        EarlyStopping(monitor='val_loss', patience=30, verbose=1)])

Train on 60000 samples, validate on 5000 samples
Epoch 1/1000

KeyboardInterrupt: 

In [None]:
def plot_results(models,
                 data,
                 batch_size=16,
                 model_name="vae_mnist",
                 grid_size=30,
                 latent_limit=8.0):
    """Plot test digits in a 2-D latent space and digits decoded over a latent grid.

    Produces two figures and saves them as PNG files in the ``model_name``
    directory (created if missing):

    * ``vae_mean.png`` -- scatter of the encoder's ``z_mean`` (first two
      coordinates) for every test image, coloured by label.
    * ``digits_over_latent.png`` -- a ``grid_size`` x ``grid_size`` mosaic of
      decoder outputs for latent points evenly spaced over
      ``[-latent_limit, latent_limit]`` on both axes.

    # Arguments:
        models (tuple): ``(encoder, decoder)`` models. ``encoder.predict`` must
            return three outputs, the first being ``z_mean``. The decoder is fed
            2-element latent vectors, so this assumes a 2-D latent space.
        data (tuple): ``(x_test, y_test)`` test images and labels.
        batch_size (int): batch size used for encoder and decoder prediction.
        model_name (string): directory the figures are written to.
        grid_size (int): number of grid points per latent axis in the mosaic.
        latent_limit (float): the mosaic spans [-latent_limit, latent_limit]
            on each latent axis.
    """
    encoder, decoder = models
    x_test, y_test = data
    os.makedirs(model_name, exist_ok=True)

    # 1) Latent means of the test set, coloured by digit label.
    z_mean, _, _ = encoder.predict(x_test, batch_size=batch_size)
    fig, ax = plt.subplots(figsize=(12, 10))
    points = ax.scatter(z_mean[:, 0], z_mean[:, 1], c=y_test)
    fig.colorbar(points, ax=ax, label="digit label")
    ax.set_xlabel("z[0]")
    ax.set_ylabel("z[1]")
    ax.set_title("Encoder z_mean of test digits")
    fig.savefig(os.path.join(model_name, "vae_mean.png"))

    # 2) Decode an evenly spaced latent grid. grid_y is reversed so that rows
    # run from +latent_limit (top) to -latent_limit (bottom), matching the
    # orientation of the scatter plot.
    grid_x = np.linspace(-latent_limit, latent_limit, grid_size)
    grid_y = np.linspace(-latent_limit, latent_limit, grid_size)[::-1]
    xx, yy = np.meshgrid(grid_x, grid_y)  # xx[i, j] = grid_x[j], yy[i, j] = grid_y[i]
    z_grid = np.column_stack([xx.ravel(), yy.ravel()])  # row-major: row i, column j
    # One batched predict instead of grid_size**2 single-sample calls.
    decoded = decoder.predict(z_grid, batch_size=batch_size)
    # Each decoded sample is a square image (possibly with a channel axis).
    digit_size = int(round(np.sqrt(decoded[0].size)))
    # (rows, cols, h, w) -> (rows, h, cols, w) -> a single (rows*h, cols*w) mosaic.
    figure = (decoded.reshape(grid_size, grid_size, digit_size, digit_size)
              .transpose(0, 2, 1, 3)
              .reshape(grid_size * digit_size, grid_size * digit_size))

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(figure, cmap='Greys_r')
    # One tick at the centre of each tile: exactly grid_size positions, so the
    # count matches the grid_size labels and no tick falls outside the image.
    pixel_range = digit_size // 2 + digit_size * np.arange(grid_size)
    ax.set_xticks(pixel_range)
    ax.set_xticklabels(np.round(grid_x, 1))
    ax.set_yticks(pixel_range)
    ax.set_yticklabels(np.round(grid_y, 1))
    ax.set_xlabel("z[0]")
    ax.set_ylabel("z[1]")
    ax.set_title("Decoder output over the latent grid")
    fig.savefig(os.path.join(model_name, "digits_over_latent.png"))


In [None]:
# Visualise the encoder/decoder on the held-out (odd-index) test half.
# NOTE(review): the training cell's saved output ends in KeyboardInterrupt, so
# these plots may reflect a partially trained model.
plot_results(models=(vae.encoder, vae.decoder),
             data=(x_test, y_test),
             batch_size=batch_size)