In [5]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, MaxPool2D
import tensorflow_probability as tfp
tfd = tfp.distributions
tfpl = tfp.layers

In [7]:
model = Sequential([
    tfpl.Convolution2DReparameterization(16, 
                                         [3, 3], 
                                         activation='relu', 
                                         input_shape=(28, 28, 1),
                                         kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
                                         kernel_prior_fn=tfpl.default_multivariate_normal_fn),
    MaxPool2D(3),
    Flatten(),
    tfpl.DenseReparameterization(tfpl.OneHotCategorical.params_size(10)),
    tfpl.OneHotCategorical(10)
])

Calling `default_mean_field_normal_fn()` returns a callable that builds an independent normal distribution with trainable mean and standard deviation parameters. The default posterior for the reparameterization layers is therefore similar to the one in the dense variational example we saw earlier. The prior is built by `default_multivariate_normal_fn`, which is also defined in the `tfp.layers` module. It returns a spherical Gaussian with zero mean and a standard deviation of 1 for each component (the same prior as in the dense variational example).

We can change these defaults, but any callable passed to the `kernel_posterior_fn` or `kernel_prior_fn` keyword arguments must follow a required signature.

In [8]:
def custom_multivariate_normal(dtype, shape, name, trainable, add_variable_fn):
    # Zero-mean normal with a fixed standard deviation of 2 for every weight.
    normal = tfd.Normal(loc=tf.zeros(shape, dtype), scale=2*tf.ones(shape, dtype))
    # Reinterpret all batch dimensions as event dimensions.
    batch_ndims = tf.size(normal.batch_shape_tensor())
    return tfd.Independent(normal, reinterpreted_batch_ndims=batch_ndims)

As before, the callable should return a distribution object. This is a simple example of a function that changes the standard deviation of the prior distribution. The default multivariate normal function returns a spherical Gaussian with a standard deviation of 1; here we give the prior a larger standard deviation of 2. We also wrap it in the `Independent` distribution so that all the dimensions are part of the event space of the distribution.
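To get a feel for what widening the prior does to the KL term, here is a minimal sketch in plain Python (no TensorFlow). It uses the closed-form KL divergence between univariate normals and sums it over independent components. The helper names `kl_normal` and `kl_to_zero_mean_prior` and the posterior values are hypothetical and chosen only for illustration.

```python
import math

def kl_normal(q_loc, q_scale, p_loc, p_scale):
    """Closed-form KL(q || p) between two univariate normal distributions."""
    var_ratio = (q_scale / p_scale) ** 2
    mean_term = ((q_loc - p_loc) / p_scale) ** 2
    return 0.5 * (var_ratio + mean_term - 1.0 - math.log(var_ratio))

def kl_to_zero_mean_prior(q_locs, q_scales, prior_scale):
    """Sum the per-weight KL terms, as is valid for independent components."""
    return sum(kl_normal(m, s, 0.0, prior_scale)
               for m, s in zip(q_locs, q_scales))

# Hypothetical posterior over three weights (illustrative values only).
q_locs = [1.5, -2.0, 0.5]
q_scales = [0.5, 0.5, 0.5]

kl_unit_prior = kl_to_zero_mean_prior(q_locs, q_scales, prior_scale=1.0)
kl_wide_prior = kl_to_zero_mean_prior(q_locs, q_scales, prior_scale=2.0)
print(kl_unit_prior, kl_wide_prior)
```

For these particular values the wider prior gives the smaller KL term, because posterior means far from zero are penalised less. This is not true in general: for a narrow posterior centred at zero, the wider prior gives the larger KL.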

Now, we can use it on our model.

In [None]:
model = Sequential([
    tfpl.Convolution2DReparameterization(16, 
                                         [3, 3], 
                                         activation='relu', 
                                         input_shape=(28, 28, 1),
                                         kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
                                         kernel_prior_fn=custom_multivariate_normal),
    MaxPool2D(3),
    Flatten(),
    tfpl.DenseReparameterization(tfpl.OneHotCategorical.params_size(10)),
    tfpl.OneHotCategorical(10)
])

The bias parameters also have prior and posterior function keyword arguments.

In [None]:
model = Sequential([
    tfpl.Convolution2DReparameterization(16, 
                                         [3, 3], 
                                         activation='relu', 
                                         input_shape=(28, 28, 1),
                                         kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
                                         kernel_prior_fn=tfpl.default_multivariate_normal_fn,
                                         bias_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=True),
                                         bias_prior_fn=None),
    MaxPool2D(3),
    Flatten(),
    tfpl.DenseReparameterization(tfpl.OneHotCategorical.params_size(10)),
    tfpl.OneHotCategorical(10)
])

For `bias_posterior_fn` we used `default_mean_field_normal_fn` again, but with the `is_singular` keyword set to `True`. This way the bias becomes a point estimate instead of being represented by a full distribution. The callable now returns a deterministic distribution object. This is a degenerate distribution: it places probability 1 on the value given by its `loc` parameter and 0 elsewhere, so it always returns that value.

So the bias is learned in the same way as in a regular convolutional layer, as a single trainable value rather than a distribution. Since we are not using a prior distribution for the bias, `bias_prior_fn` is set to `None`.

In [10]:
model = Sequential([
    tfpl.Convolution2DReparameterization(16, 
                                         [3, 3], 
                                         activation='relu', 
                                         input_shape=(28, 28, 1),
                                         kernel_posterior_tensor_fn=tfd.Distribution.sample,
                                         kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
                                         kernel_prior_fn=tfpl.default_multivariate_normal_fn,
                                         bias_posterior_tensor_fn=tfd.Distribution.sample,
                                         bias_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=True),
                                         bias_prior_fn=None),
    MaxPool2D(3),
    Flatten(),
    tfpl.DenseReparameterization(tfpl.OneHotCategorical.params_size(10)),
    tfpl.OneHotCategorical(10)
])

Here we also set the keyword argument `kernel_posterior_tensor_fn`. A reparameterization layer needs to know how to convert the posterior distribution into a tensor, and this is specified through the `kernel_posterior_tensor_fn` and `bias_posterior_tensor_fn` arguments. The resulting tensor is used to compute the forward pass of the network, and it can also be used to approximate the KL divergence, much like the posterior sample in the dense variational layer when the option `kl_use_exact` is set to false. The default is again just sampling from the distribution. In the case of the bias, since it is now a deterministic quantity, sampling from the deterministic distribution simply returns the value of the `loc` parameter, so it would make no difference if we passed the mean method instead.
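The difference between sampling and taking the mean can be sketched without TensorFlow. The classes `Normal` and `PointMass` and the helper `to_tensor` below are hypothetical stand-ins, not the TFP classes. They only mimic the idea that a tensor function maps a distribution to a concrete value.

```python
import random

class Normal:
    """Minimal stand-in for a normal distribution (not the TFP class)."""
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self):
        return random.gauss(self.loc, self.scale)

    def mean(self):
        return self.loc

class PointMass:
    """Minimal stand-in for a deterministic distribution: all mass at loc."""
    def __init__(self, loc):
        self.loc = loc

    def sample(self):
        return self.loc

    def mean(self):
        return self.loc

def to_tensor(distribution, tensor_fn):
    """Plays the role of a posterior tensor function: distribution -> value."""
    return tensor_fn(distribution)

random.seed(0)
kernel_posterior = Normal(loc=0.3, scale=0.1)
bias_posterior = PointMass(loc=-0.2)

# Sampling the kernel posterior gives a different value on each call.
kernel_draws = [to_tensor(kernel_posterior, Normal.sample) for _ in range(3)]
# Sampling the point mass always gives loc, exactly like taking the mean.
bias_draws = [to_tensor(bias_posterior, PointMass.sample) for _ in range(3)]
bias_mean = to_tensor(bias_posterior, PointMass.mean)
print(kernel_draws)
print(bias_draws, bias_mean)
```

Passing the unbound method (`Normal.sample`, `PointMass.mean`) mirrors how `tfd.Distribution.sample` is passed in the cell above: the function receives the distribution as its argument.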

In [None]:
# dataset_size must be defined beforehand as the number of training examples.
model = Sequential([
    tfpl.Convolution2DReparameterization(16, 
                                         [3, 3], 
                                         activation='relu', 
                                         input_shape=(28, 28, 1),
                                         kernel_posterior_tensor_fn=tfd.Distribution.sample,
                                         kernel_posterior_fn=tfpl.default_mean_field_normal_fn(),
                                         kernel_prior_fn=tfpl.default_multivariate_normal_fn,
                                         bias_posterior_tensor_fn=tfd.Distribution.sample,
                                         bias_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=True),
                                         bias_prior_fn=None,
                                         kernel_divergence_fn=(lambda q, p, _: tfd.kl_divergence(q, p) / dataset_size)),
    MaxPool2D(3),
    Flatten(),
    tfpl.DenseReparameterization(tfpl.OneHotCategorical.params_size(10)),
    tfpl.OneHotCategorical(10)
])

Finally, we can also set the function that computes the KL divergence between the posterior and prior distributions through `kernel_divergence_fn`. The default is a lambda function that takes the posterior, the prior and an unused third argument, and returns the KL divergence between q and p, computed with the `kl_divergence` function from the distributions module. This default attempts to compute the KL divergence analytically, which may or may not be possible depending on the choice of posterior and prior distributions; if it isn't possible, we get an error. The unused argument is the tensor obtained by applying the posterior tensor function to the posterior distribution. We could use it to approximate the KL divergence instead of computing it analytically.

The same issue with the scaling of the KL divergence loss also arises in these reparameterization layers. To make sure that the value of our objective equals the negative of the ELBO, we need to scale the KL divergence terms by a factor of 1/`dataset_size`.
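As a rough illustration of using the third argument, the sketch below estimates the KL divergence from posterior samples as log q(w) - log p(w) and applies the 1/dataset size scaling. It is plain Python rather than TFP: distributions are represented as hypothetical `(loc, scale)` tuples, and `sampled_divergence` and `dataset_size = 1000` are assumptions made for the example.

```python
import math
import random

def normal_log_prob(x, loc, scale):
    """Log density of a univariate normal distribution at x."""
    return (-0.5 * math.log(2.0 * math.pi * scale ** 2)
            - 0.5 * ((x - loc) / scale) ** 2)

def analytic_kl(q, p):
    """Closed-form KL(q || p) for univariate normals given as (loc, scale)."""
    (q_loc, q_scale), (p_loc, p_scale) = q, p
    var_ratio = (q_scale / p_scale) ** 2
    mean_term = ((q_loc - p_loc) / p_scale) ** 2
    return 0.5 * (var_ratio + mean_term - 1.0 - math.log(var_ratio))

def sampled_divergence(q, p, w, dataset_size):
    """Single-sample estimate of KL(q || p), scaled by 1 / dataset_size.

    The arguments follow the (posterior, prior, posterior sample) order.
    """
    return (normal_log_prob(w, *q) - normal_log_prob(w, *p)) / dataset_size

random.seed(0)
q = (0.8, 0.5)       # hypothetical posterior (loc, scale)
p = (0.0, 1.0)       # standard normal prior
dataset_size = 1000  # hypothetical number of training examples

# Average many single-sample estimates, each using a fresh draw w ~ q.
estimates = []
for _ in range(20000):
    w = random.gauss(*q)
    estimates.append(sampled_divergence(q, p, w, dataset_size))

mc_kl = sum(estimates) / len(estimates)
exact_kl = analytic_kl(q, p) / dataset_size
print(mc_kl, exact_kl)
```

A single-sample estimate is noisy and can even be negative, but it is unbiased, so its average approaches the analytic value. This is why the approach still works when no closed form is available for the chosen posterior and prior.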