In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from keras.layers import *
from keras.models import Model
# Imported explicitly rather than relying on names leaking out of the star import above.
from keras import backend as K
from keras import objectives
from lasagnekit.datasets.mnist import MNIST
from lasagnekit.misc.plot_weights import dispims_color

In [None]:
# Load MNIST via lasagnekit. `data.X` is later used as both input and target of
# `vae.fit`, so it presumably holds flattened 28*28 images scaled to [0, 1]
# (required by the sigmoid/binary-crossentropy decoder) — TODO confirm.
data = MNIST()
data.load()


In [None]:
batch_size = 128  # minibatch size passed to vae.fit

In [None]:
# Model dimensions.
d = 28 * 28             # flattened MNIST image size
latent_dim = 10
intermediate_dim = 256  # width of the encoder's hidden layers
# Std of the reparameterisation noise eps in z = mu + sigma * eps. The KL term
# measures divergence from a N(0, I) prior under the assumption eps ~ N(0, 1);
# any other value silently rescales the posterior the decoder actually sees,
# so this should stay at 1.
epsilon_std = 1.0

# Encoder: two hidden ReLU layers, then two linear heads giving the mean and the
# log-sigma parameter of the approximate posterior q(z | x).
x = Input(shape=(d,))
h = Dense(intermediate_dim, activation='relu')(x)
h = Dense(intermediate_dim, activation='relu')(h)
z_mean = Dense(latent_dim)(h)
z_log_sigma = Dense(latent_dim)(h)

In [None]:
def sampling(args):
    """Draw a latent sample with the reparameterisation trick.

    Parameters
    ----------
    args : sequence of two tensors
        ``[z_mean, z_log_sigma]`` from the encoder. ``z_log_sigma`` is used
        here as the log of the standard deviation.

    Returns
    -------
    Tensor
        ``z_mean + exp(z_log_sigma) * eps`` where ``eps`` is drawn from
        ``N(0, epsilon_std**2)`` with shape ``(batch, latent_dim)``.
    """
    z_mean, z_log_sigma = args
    # Use the *dynamic* batch size: the static ``z_mean.shape[0]`` is unknown
    # (None) at graph-construction time on the TensorFlow backend.
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                              mean=0., std=epsilon_std)
    return z_mean + K.exp(z_log_sigma) * epsilon

In [None]:
# Wrap the stochastic `sampling` function in a Lambda layer so it becomes a node
# of the graph: it consumes both encoder heads and emits one latent vector per row.
sampling_layer = Lambda(sampling, output_shape=(latent_dim,))
z = sampling_layer([z_mean, z_log_sigma])

In [None]:

# The decoder layers are created as standalone objects (instead of calling
# Dense(...)(z) inline) so that the very same weights can later be applied to a
# fresh latent Input when the standalone generator is built.
decoder_h, decoder_mean = (Dense(256, activation='relu'),
                           Dense(d, activation='sigmoid'))

# Decode the sampled z: hidden ReLU layer, then per-pixel sigmoid means.
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)

In [None]:
# end-to-end autoencoder
vae = Model(x, x_decoded_mean)  # full pipeline: pixels -> sampled z -> reconstruction

encoder = Model(x, z_mean)      # pixels -> posterior mean (no sampling)

# Generator: a new latent Input routed through the *shared* decoder layers, so it
# uses whatever weights `vae` learns during training.
decoder_input = Input(shape=(latent_dim,))
_h_decoded = decoder_h(decoder_input)
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)

In [None]:
def vae_loss(x, x_decoded_mean):
    """Negative ELBO per example: reconstruction term plus KL(q(z|x) || N(0, I)).

    Parameters
    ----------
    x : tensor
        Target images (the autoencoder's input).
    x_decoded_mean : tensor
        Decoder's per-pixel Bernoulli means.

    Returns
    -------
    tensor
        One loss value per example. The KL term reads the module-level encoder
        heads ``z_mean`` and ``z_log_sigma`` through the closure.
    """
    # objectives.binary_crossentropy averages over the last (pixel) axis;
    # multiplying by d turns it into the per-example sum the ELBO requires, on
    # the same scale as the summed KL below.
    xent_loss = d * objectives.binary_crossentropy(x, x_decoded_mean)
    # `sampling` uses exp(z_log_sigma) as the standard deviation, so z_log_sigma
    # is log(sigma) and the variance is exp(2 * z_log_sigma). Closed-form
    # Gaussian KL, summed (not averaged) over the latent dimensions.
    kl_loss = -0.5 * K.sum(1 + 2 * z_log_sigma - K.square(z_mean)
                           - K.exp(2 * z_log_sigma), axis=-1)
    return xent_loss + kl_loss

vae.compile(optimizer='adam', loss=vae_loss)

In [None]:
# An autoencoder learns to reproduce its input, so data.X is both input and target.
vae.fit(data.X, data.X, batch_size=batch_size, nb_epoch=10, shuffle=True)

In [None]:
# Decode random latent vectors into images and show them as one tiled grid.
n_samples = 100
# Std of the latent draws. NOTE(review): the KL term regularises towards a
# N(0, I) prior, so 1.0 is the principled value; 0.1 only probes a small
# region around the origin of the latent space.
sample_std = 0.1
rng = np.random.RandomState(0)  # seeded so the figure is reproducible on re-run
z_samples = rng.normal(0, sample_std, size=(n_samples, latent_dim))
samples = generator.predict(z_samples)

# Reshape the flat outputs to 28x28 single-channel images, then broadcast the
# channel to 3 so dispims_color receives RGB images.
rgb_samples = samples.reshape((samples.shape[0], 28, 28, 1)) * np.ones((1, 1, 1, 3))
img = dispims_color(rgb_samples)

fig, ax = plt.subplots(figsize=(15, 15))
ax.imshow(img, interpolation='none')
ax.set_title('VAE samples decoded from random latent vectors')
ax.axis('off');