In [14]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Reshape, Conv2DTranspose
import tensorflow_probability as tfp
tfd = tfp.distributions
tfpl = tfp.layers

In [15]:
# z ~ N(0, I) = p(z)
# p(x|z) = decoder
# x ~ p(x|z)

# encoder(x) = q(z|x) \approx p(z|x)
# log p(x) >= E_{z~q(z|x)}[-log q(z|x) + log p(x, z)]
#           = -KL(q(z|x) || p(z)) + E_{z~q(z|x)}[log p(x|z)] (ELBO)

The first two lines represent the generative model that underlies the derivation of the algorithm. The modeling assumption is that the data examples are generated by the following process: we first sample a latent variable z from a simple distribution, such as a zero-mean isotropic Gaussian. This is the prior distribution. The sample z is then transformed by some function, denoted here as the decoder, which parameterizes a distribution from which we can sample a data point x.

The algorithm also requires an encoder, or recognition network, often denoted q. The encoder takes a data example x as input and outputs a distribution over the latent variable z. This distribution approximates the true posterior p(z|x) of our model.

The objective that we want to maximize is the ELBO. It can be written as the sum of two terms: the first is the negative KL divergence between the approximate posterior q(z|x) and the prior p(z). The second is the expected log-likelihood of the data under the decoder, where the expectation is taken over the approximate posterior given by the encoder.
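The step from log p(x) to the bound is left implicit in the comments above. One standard way to fill it in, shown here as a sketch using the same symbols, is to introduce q(z|x) inside the integral and apply Jensen's inequality:

```latex
\begin{aligned}
\log p(x) &= \log \int p(x, z)\, dz
           = \log \mathbb{E}_{z \sim q(z|x)}\!\left[\frac{p(x, z)}{q(z|x)}\right] \\
          &\ge \mathbb{E}_{z \sim q(z|x)}\left[-\log q(z|x) + \log p(x, z)\right]
           \quad \text{(Jensen's inequality)} \\
          &= \mathbb{E}_{z \sim q(z|x)}\left[\log p(z) - \log q(z|x)\right]
           + \mathbb{E}_{z \sim q(z|x)}\left[\log p(x|z)\right] \\
          &= -\mathrm{KL}\left(q(z|x) \,\|\, p(z)\right)
           + \mathbb{E}_{z \sim q(z|x)}\left[\log p(x|z)\right].
\end{aligned}
```

The third line uses the factorization p(x, z) = p(x|z) p(z) from the generative model.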

In [16]:
latent_size=2
event_shape=(28,28,1)

encoder = Sequential([
    Conv2D(8, (5, 5), strides=2, activation='tanh', input_shape=event_shape),
    Conv2D(8, (5, 5), strides=2, activation='tanh'),
    Flatten(),
    Dense(64, activation='tanh'),
    Dense(2*latent_size),
    tfpl.DistributionLambda(lambda t: tfd.MultivariateNormalDiag(
        loc=t[..., :latent_size], scale_diag=tf.math.exp(t[..., latent_size:])))
], name='encoder')

The last dense layer has `2*latent_size = 4` units, which is the number of parameters needed for a two-dimensional Gaussian with diagonal covariance: two means and two scales. The final layer is a `DistributionLambda` layer that returns a `MultivariateNormalDiag` distribution, the approximate posterior. The keyword arguments `loc` and `scale_diag` receive slices of the output tensor of the previous dense layer, and the scale slice is passed through `tf.math.exp` so that it is always positive. Overall, the encoder progressively compresses the input through its layers until it is encoded as a distribution over a two-dimensional latent variable.

If we pass in a batch of data examples of batch size 16, then the encoder network returns a multivariate normal diag distribution object with batch shape 16 and an event shape of 2.
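To make the slicing concrete, here is a minimal pure-Python sketch of what the lambda does to the output of the last dense layer for a single example. The helper name `split_gaussian_params` and the numeric values are hypothetical and chosen only for illustration; the real layer operates on batched tensors.

```python
import math

def split_gaussian_params(t, latent_size):
    """Split a vector of length 2 * latent_size into (loc, scale_diag)."""
    loc = t[:latent_size]
    # The second half is exponentiated, so the scale is always positive.
    scale_diag = [math.exp(v) for v in t[latent_size:]]
    return loc, scale_diag

# Hypothetical output of Dense(2 * latent_size) for one example.
dense_output = [0.5, -1.0, 0.0, math.log(2.0)]
loc, scale_diag = split_gaussian_params(dense_output, latent_size=2)
print(loc)         # [0.5, -1.0]
print(scale_diag)  # approximately [1.0, 2.0]
```

Reading the second half of the vector as log-scales is what lets the dense layer stay unconstrained while `scale_diag` remains valid.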

In [17]:
decoder = Sequential([
    Dense(64, activation='tanh', input_shape=(latent_size,)),
    Dense(128, activation='tanh'),
    Reshape((4, 4, 8)),
    Conv2DTranspose(8, (5, 5), strides=2, output_padding=1, activation='tanh'),
    Conv2DTranspose(8, (5, 5), strides=2, output_padding=1, activation='tanh'),
    Conv2D(1, (3, 3), padding='SAME'),
    Flatten(),
    tfpl.IndependentBernoulli(event_shape)
], name='decoder')

The decoder network progressively expands its input through the layers. The input to the decoder is a length-2 vector, which will be a sample of the latent variable. The layers are roughly the inverse of those in the encoder. The two `Conv2DTranspose` layers upsample the input back to the spatial dimensions of the data. We then use a `Conv2D` layer to reduce the number of filters to 1. It does not use an activation function, so its outputs can take any real value. Finally, we use an `IndependentBernoulli` layer with the same `event_shape` as the data. This layer takes the tensor from the `Flatten` layer as the logits that parameterize the Bernoulli distribution.
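The spatial sizes in both networks can be checked by hand. The sketch below assumes the default `'valid'` padding and uses the usual output-length formulas for a strided convolution and for a transposed convolution with an explicit `output_padding`. The helper names `conv_out` and `conv_transpose_out` are illustrative, not part of Keras.

```python
def conv_out(size, kernel, stride):
    """Output length of a 'valid' convolution along one spatial axis."""
    return (size - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride, output_padding):
    """Output length of a 'valid' transposed convolution with output_padding."""
    return (size - 1) * stride + kernel + output_padding

# Encoder: two Conv2D(8, (5, 5), strides=2) layers applied to a 28x28 image.
enc = [28]
for _ in range(2):
    enc.append(conv_out(enc[-1], kernel=5, stride=2))

# Decoder: two Conv2DTranspose(8, (5, 5), strides=2, output_padding=1) layers
# applied to the 4x4 feature map produced by Reshape((4, 4, 8)).
dec = [4]
for _ in range(2):
    dec.append(conv_transpose_out(dec[-1], kernel=5, stride=2, output_padding=1))

print(enc)                    # [28, 12, 4]
print(dec)                    # [4, 12, 28]
print(enc[-1] * enc[-1] * 8)  # 128, the size expected by Reshape((4, 4, 8))
```

Under these assumptions the decoder retraces the encoder's sizes in reverse, which is why `output_padding=1` is needed to land exactly on 12 and then 28.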

In [18]:
decoder(tf.random.normal([16, latent_size]))

<tfp.distributions.Independent 'decoder_independent_bernoulli_IndependentBernoulli_Independentdecoder_independent_bernoulli_IndependentBernoulli_Bernoulli' batch_shape=[16] event_shape=[28, 28, 1] dtype=float32>

The decoder outputs a batched independent Bernoulli distribution with batch shape equal to 16 and event shape equal to (28, 28, 1), which is the data event_shape.
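Because the decoder's final tensor is interpreted as logits, the log probability of an image is a sum of per-pixel Bernoulli terms. The following pure-Python sketch shows that computation for one flattened example. The function names and the 4-pixel input are hypothetical, and the code illustrates the formula rather than how TensorFlow Probability implements it.

```python
import math

def softplus(a):
    """Numerically stable log(1 + exp(a))."""
    return max(a, 0.0) + math.log1p(math.exp(-abs(a)))

def independent_bernoulli_log_prob(x, logits):
    """Sum of per-pixel Bernoulli log-probabilities, parameterized by logits."""
    # For x_i in {0, 1}: log p(x_i) = x_i * logit_i - softplus(logit_i)
    return sum(xi * li - softplus(li) for xi, li in zip(x, logits))

# Hypothetical 4-pixel binary "image" and decoder logits.
x = [1, 0, 1, 0]
logits = [0.0, 0.0, 0.0, 0.0]
log_prob = independent_bernoulli_log_prob(x, logits)
print(log_prob)  # equals 4 * log(0.5), since every pixel has probability 0.5
```

Summing over all pixels is what turns a batch of per-pixel Bernoulli distributions into one distribution per image, so `log_prob` yields a single value for each example in the batch.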

In [20]:
prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_size))
prior

<tfp.distributions.MultivariateNormalDiag 'MultivariateNormalDiag' batch_shape=[] event_shape=[2] dtype=float32>

In [22]:
def loss_fn(x_true, approx_posterior, x_pred, prior_dist):
    return tf.reduce_mean(tfd.kl_divergence(approx_posterior, prior_dist)
                         - x_pred.log_prob(x_true))

Our loss function is the negative of the ELBO. We want to maximize the ELBO, meaning that we want to minimize our loss function.

The loss function receives a batch of data examples `x_true`. `approx_posterior` is the output of the encoder, which is a `MultivariateNormalDiag` distribution. `x_pred` is the output of the decoder, which is an independent Bernoulli distribution. Finally, we also pass in the prior distribution.

To compute the loss, we take the KL divergence between the approximate posterior and the prior and subtract the likelihood term, which is the log probability of the data example `x_true` under the decoder's output distribution `x_pred`. Both terms have one value per example, and the function returns their mean over the batch.
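For a diagonal Gaussian posterior and a standard normal prior, the KL divergence has a well-known closed form, which is what makes the analytic version of the loss possible here. The sketch below writes that formula out in pure Python for a single example. The function name is hypothetical, and this illustrates the mathematics rather than the library's code.

```python
import math

def kl_diag_gaussian_to_standard_normal(loc, scale_diag):
    """KL( N(loc, diag(scale_diag^2)) || N(0, I) ) for one example."""
    # Per dimension: 0.5 * (sigma^2 + mu^2 - 1 - 2 * log(sigma))
    return 0.5 * sum(s * s + m * m - 1.0 - 2.0 * math.log(s)
                     for m, s in zip(loc, scale_diag))

# The divergence is zero when the posterior equals the prior ...
print(kl_diag_gaussian_to_standard_normal([0.0, 0.0], [1.0, 1.0]))  # 0.0
# ... and grows as the mean moves away from zero.
print(kl_diag_gaussian_to_standard_normal([1.0, 0.0], [1.0, 1.0]))  # 0.5
```

In the loss, this term penalizes the encoder for moving its posterior away from the prior, while the reconstruction term rewards it for encoding information about `x_true`.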

In [23]:
def loss_fn(x_true, approx_posterior, x_pred, prior_dist):
    reconstruction_loss = -x_pred.log_prob(x_true)
    approx_posterior_sample = approx_posterior.sample()
    kl_approx = (approx_posterior.log_prob(approx_posterior_sample)
                - prior_dist.log_prob(approx_posterior_sample))
    return tf.reduce_mean(kl_approx + reconstruction_loss)

We can also estimate the KL divergence with Monte Carlo samples instead of computing it analytically with the `kl_divergence` function. For some choices of prior and posterior no analytic KL divergence is available, and a Monte Carlo estimate is then required. The `reconstruction_loss` is the negative log-likelihood term. We then sample from the approximate posterior (only once in this case). The KL estimate is the log probability of the sample under the approximate posterior minus its log probability under the prior. This estimate is added to the `reconstruction_loss`, and the mean over the batch is returned as the loss.
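As a sanity check on the Monte Carlo approach, the pure-Python sketch below compares the sample-based estimate with the closed-form value for one hypothetical diagonal Gaussian posterior and a standard normal prior. All names, the parameter values, and the sample count are assumptions made for illustration. The training loss above uses a single sample per example, whereas many samples are averaged here only to show that the estimator converges to the analytic value.

```python
import math
import random

def normal_log_prob(z, loc, scale_diag):
    """Log density of a diagonal Gaussian evaluated at z."""
    return sum(-0.5 * math.log(2.0 * math.pi) - math.log(s)
               - 0.5 * ((zi - m) / s) ** 2
               for zi, m, s in zip(z, loc, scale_diag))

def mc_kl_to_standard_normal(loc, scale_diag, num_samples, rng):
    """Average of log q(z) - log p(z) over samples z drawn from q."""
    zeros, ones = [0.0] * len(loc), [1.0] * len(loc)
    total = 0.0
    for _ in range(num_samples):
        z = [rng.gauss(m, s) for m, s in zip(loc, scale_diag)]
        total += normal_log_prob(z, loc, scale_diag) - normal_log_prob(z, zeros, ones)
    return total / num_samples

rng = random.Random(0)                       # fixed seed for repeatability
loc, scale_diag = [1.0, -0.5], [0.5, 2.0]    # hypothetical posterior parameters

# Closed-form KL to N(0, I), used as the reference value.
exact = 0.5 * sum(s * s + m * m - 1.0 - 2.0 * math.log(s)
                  for m, s in zip(loc, scale_diag))
estimate = mc_kl_to_standard_normal(loc, scale_diag, num_samples=20000, rng=rng)
print(exact, estimate)
```

A single-sample estimate is noisy for any one example, but it is unbiased, so averaging over the batch and over many training steps still follows the same objective.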

In [24]:
@tf.function
def get_loss_and_grads(x):
    with tf.GradientTape() as tape:
        approx_posterior = encoder(x)
        approx_posterior_sample = approx_posterior.sample()
        x_pred = decoder(approx_posterior_sample)
        current_loss = loss_fn(x, approx_posterior, x_pred, prior)
    # Keras models expose `trainable_variables` (there is no `training_variables` attribute).
    grads = tape.gradient(current_loss, encoder.trainable_variables + decoder.trainable_variables)
    return current_loss, grads

This function takes a batch of data examples x as input and executes the network computations inside the `GradientTape` context. First, we compute the `approx_posterior` distribution by passing x through the encoder. We then compute the output distribution by taking a sample from the approximate posterior and feeding it to the decoder, so `x_pred` is an independent Bernoulli distribution object that tries to reconstruct the input x. The gradients are computed by calling `tape.gradient`, and the variables that we differentiate with respect to are the trainable variables of the encoder and decoder networks.

In [None]:
opt = tf.keras.optimizers.Adam()
num_epochs = 10
for epoch in range(num_epochs):
    # train_data is assumed to be an iterable of image batches (e.g. a tf.data.Dataset)
    for train_batch in train_data:
        loss, grads = get_loss_and_grads(train_batch)
        opt.apply_gradients(zip(grads, encoder.trainable_variables + decoder.trainable_variables))

In [None]:
def vae(inputs):
    approx_posterior = encoder(inputs)
    decoded = decoder(approx_posterior.sample())
    return decoded.sample()

reconstruction = vae(x_sample)

We can now define the end-to-end bottleneck architecture. The function above takes a batch of data examples as input and passes them through the encoder, samples from the approximate posterior, and passes the samples through the decoder, before sampling again from the output distribution returned by the decoder. This sample should be an approximate reconstruction of the original input. Given a batch `x_sample`, it returns a tensor with the same shape as the input.