In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as st

In [2]:
class Encoder(tf.keras.Model):
    """
    Inference network for q(z|x), a Gaussian with diagonal covariance.

    Two ReLU hidden layers (128 then 32 units) feed two linear heads that
    output, per latent dimension, the mean and the log-variance of z given x.

    Args:
        latent_dim: number of latent dimensions (width of each output head).
    """
    def __init__(self, latent_dim):
        super().__init__()
        # Shared feature trunk.
        self.fc1 = tf.keras.layers.Dense(128, activation='relu')
        self.fc2 = tf.keras.layers.Dense(32, activation='relu')
        # Linear (unbounded) heads; `fc_log_var`'s output is used as a
        # log-variance, so it may take any real value.
        self.fc_mean = tf.keras.layers.Dense(latent_dim)
        self.fc_log_var = tf.keras.layers.Dense(latent_dim)

    def call(self, x):
        """Return the tuple (mean, log_var) parameterizing q(z|x) for input x."""
        features = self.fc2(self.fc1(x))
        return self.fc_mean(features), self.fc_log_var(features)

In [3]:
class Decoder(tf.keras.Model):
    """
    Decoder p(x|z) as a diagonal MVN distribution.

    Feed-forward net that maps a latent code z to the parameters (mean and
    log-variance, one per data dimension) of the Gaussian that generates x.

    Args:
        data_dim: number of data dimensions (width of each output head).
    """
    def __init__(self, data_dim):
        # Must initialise *this* class's base; the previous
        # `super(Encoder, self)` raised TypeError because a Decoder is not
        # an instance of Encoder.
        super().__init__()
        self.fc1 = tf.keras.layers.Dense(32, activation='relu')
        self.fc2 = tf.keras.layers.Dense(128, activation='relu')
        # Linear heads: the second is consumed as a log-variance, so it is
        # left unbounded.
        self.fc_mean = tf.keras.layers.Dense(data_dim)
        self.fc_log_var = tf.keras.layers.Dense(data_dim)

    def call(self, z):
        """Return the tuple (mean, log_var) parameterizing p(x|z) for code z."""
        h1 = self.fc1(z)
        h = self.fc2(h1)
        mean = self.fc_mean(h)
        log_var = self.fc_log_var(h)
        return (mean, log_var)

In [None]:
class VAE(tf.keras.Model):
    """
    Variational autoencoder built from an Encoder q(z|x) and a Decoder p(x|z).

    Both distributions are diagonal Gaussians parameterized by
    (mean, log_var). Samples are drawn with the reparameterization trick,
    mean + exp(0.5 * log_var) * eps with eps ~ N(0, I), so gradients flow
    through the network outputs.

    Args:
        latent_dim: dimensionality of the latent code z.
        data_dim: dimensionality of the data x.
    """
    def __init__(self, latent_dim, data_dim):
        # tf.keras.Model must be initialised before sub-layers are assigned
        # as attributes, otherwise Keras raises when tracking them.
        super().__init__()
        self.latent_dim = latent_dim
        self.data_dim = data_dim
        self.enc = Encoder(latent_dim)
        self.dec = Decoder(data_dim)

    @staticmethod
    def _sample(mean, log_var):
        """Reparameterized draw from N(mean, diag(exp(log_var)))."""
        # tf.random.normal (rather than numpy) keeps the dtype consistent
        # with the network outputs, handles an unknown batch size via
        # tf.shape, and draws fresh noise on every call inside tf.function.
        eps = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
        # The standard deviation is exp(log_var / 2), not exp(log_var).
        return mean + eps * tf.exp(0.5 * log_var)

    def call(self, x):
        """
        Encode x, sample z, decode, and sample a reconstruction.

        Returns:
            (z, x_h): the sampled latent code and the sampled reconstruction.
        """
        # NOTE(review): training with an ELBO also needs (z_mean, z_log_var)
        # and (x_mean, x_log_var); they are not returned here.
        z_mean, z_log_var = self.enc(x)
        z = self._sample(z_mean, z_log_var)
        x_mean, x_log_var = self.dec(z)
        x_h = self._sample(x_mean, x_log_var)
        return (z, x_h)