# Variational Autoencoder

Variational Autoencoders (VAEs) are dimensionality reduction devices that encode inputs into the latent parameters of a statistical process that is assumed to have given rise to them. In practice, this is typically a normal distribution. Decoding involves sampling from the statistical distribution with the given latent parameters and passing the result through a decoder network.

The difference from classical autoencoders is that a variational autoencoder maps an input to a function (the distribution), whereas a classical autoencoder maps an input to a single vector.

cf. this excellent post https://github.com/yaniv256/VAEs-in-Economics/blob/master/Notebooks/One_Dimensional_VAE_Workshop.ipynb

which uses Chapter 8 in http://faculty.neu.edu.cn/yury/AAI/Textbook/Deep%20Learning%20with%20Python.pdf

In [12]:
import keras
from keras import layers
from keras import backend as K 
from keras.models import Model 
import numpy as np
# Shape of one input image: 28 x 28 pixels with a single channel.
img_shape = (28, 28, 1)
# Number of samples per gradient step passed to `fit`.
batch_size = 16
# Number of latent dimensions produced by the encoder.
latent_dim = 2

import tensorflow as tf
# Force tf.function bodies to run eagerly. The stable API is used here because
# `experimental_run_functions_eagerly` is deprecated in favour of it.
# NOTE(review): presumably required because the custom loss layer reads
# tensors from module scope instead of receiving them as inputs; confirm.
tf.config.run_functions_eagerly(True)

"""
Encoder
"""
# Symbolic input: one image of shape `img_shape` per sample.
input_img = keras.Input(shape=img_shape)

# Convolutional feature extractor; the second layer uses a stride of 2.
x = layers.Conv2D(filters=32, kernel_size=3,
                  padding='same', activation='relu')(input_img)
x = layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2),
                  padding='same', activation='relu')(x)
x = layers.Conv2D(filters=64, kernel_size=3,
                  padding='same', activation='relu')(x)
x = layers.Conv2D(filters=64, kernel_size=3,
                  padding='same', activation='relu')(x)

# Remember the feature-map shape so the decoder can reshape back to it.
shape_before_flattening = K.int_shape(x)

x = layers.Flatten()(x)
x = layers.Dense(units=32, activation='relu')(x)

# Encoder output (latent parameters): two linear heads of size `latent_dim`.
z_mean = layers.Dense(units=latent_dim)(x)
z_log_var = layers.Dense(units=latent_dim)(x)


"""
Sampling
"""
def sampling(args):
    """Draw a latent sample using the reparameterization trick.

    Args:
        args: A pair ``(z_mean, z_log_var)`` of tensors; ``z_log_var`` is
            interpreted as the log of the variance.

    Returns:
        ``z_mean + sigma * epsilon``, where ``epsilon`` is standard normal
        noise of shape ``(batch, latent_dim)`` and
        ``sigma = exp(0.5 * z_log_var)``.
    """
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                              mean=0., stddev=1.)
    # The KL term of the loss uses exp(z_log_var) as the variance, so the
    # standard deviation that scales the noise is exp(0.5 * z_log_var).
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

# Latent sample drawn from the encoder's output via the `sampling` function.
z = layers.Lambda(sampling)([z_mean, z_log_var])


"""
Decoder
"""
# The decoder is a separate model whose input has the per-sample shape of `z`.
decoder_input = layers.Input(shape=K.int_shape(z)[1:])

# Expand the latent vector to the size of the encoder's last feature map,
# then restore that feature map's shape.
x = layers.Dense(units=np.prod(shape_before_flattening[1:]),
                 activation='relu')(decoder_input)
x = layers.Reshape(target_shape=shape_before_flattening[1:])(x)

# Stride-2 transposed convolution, followed by a single-channel sigmoid output.
x = layers.Conv2DTranspose(filters=32, kernel_size=3, strides=(2, 2),
                           padding='same', activation='relu')(x)
x = layers.Conv2D(filters=1, kernel_size=3,
                  padding='same', activation='sigmoid')(x)

decoder = Model(inputs=decoder_input, outputs=x)

# Decoder output for the sampled latent point.
z_decoded = decoder(z)


"""
Custom Loss
"""

class CustomVariationalLayer(keras.layers.Layer):
    """Pass-through layer that attaches the VAE loss to the model.

    The layer returns its first input unchanged; its purpose is to register
    the loss through ``add_loss`` so the model can be compiled with
    ``loss=None``.
    """

    def vae_loss(self, x, z_decoded):
        """Return the scalar VAE loss for a batch.

        Args:
            x: The original input tensor.
            z_decoded: The decoder's reconstruction of ``x``.

        Returns:
            The mean of the reconstruction cross-entropy plus the weighted
            regularization term.

        Note:
            ``z_mean`` and ``z_log_var`` are not parameters; they are read
            from the enclosing module scope.
        """
        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)

        # reconstruction loss
        xent_loss = keras.metrics.binary_crossentropy(x, z_decoded)

        # regularization loss: penalizes deviation of the latent distribution
        # (mean z_mean, variance exp(z_log_var)) from a standard normal,
        # weighted by 5e-4
        kl_loss = -5e-4 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)

        return K.mean(xent_loss + kl_loss)

    def call(self, inputs):
        """Register the VAE loss and return the original input.

        Args:
            inputs: A sequence whose first element is the original input and
                whose second element is its reconstruction.

        Returns:
            The first element of ``inputs``, unchanged.
        """
        x = inputs[0]
        z_decoded = inputs[1]
        loss = self.vae_loss(x, z_decoded)
        # The `inputs=` keyword of add_loss is deprecated and no longer
        # needed, so only the loss tensor is passed.
        self.add_loss(loss)
        return x

# Route the input and its reconstruction through the loss layer; the layer
# returns the input unchanged and registers the loss as a side effect.
y = CustomVariationalLayer()([input_img, z_decoded])


"""
Training (on MNIST)
"""

from keras.datasets import mnist

# Assemble the end-to-end model. No external loss is given to `compile`
# because the loss is already attached by the custom layer via `add_loss`.
vae = Model(inputs=input_img, outputs=y)
vae.compile(loss=None, optimizer='rmsprop')
vae.summary()

Instructions for updating:
Use `tf.config.run_functions_eagerly` instead of the experimental version.
Model: "functional_13"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_9 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 28, 28, 32)   320         input_9[0][0]                    
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 14, 14, 64)   18496       conv2d_20[0][0]                  
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, 14, 14, 64)   36928       conv2d_21[0][0]   

In [13]:
# Labels of the training split are discarded; the model trains on images only.
(x_train, _), (x_test, y_test) = mnist.load_data()

# Convert to float32, divide by 255, and append a trailing channel axis.
x_train = np.expand_dims(x_train.astype('float32') / 255., axis=-1)
x_test = np.expand_dims(x_test.astype('float32') / 255., axis=-1)

# No targets are passed (y=None): the loss is registered inside the model.
vae.fit(
    x=x_train,
    y=None,
    batch_size=batch_size,
    epochs=10,
    shuffle=True,
    validation_data=(x_test, None),
)



Epoch 1/10

KeyboardInterrupt: 