In [1]:
import tensorflow_probability as tfp
tfd = tfp.distributions

In [3]:
# Create a two-dimensional diagonal Gaussian distribution

mv_normal = tfd.MultivariateNormalDiag(loc=[-1., 0.5], scale_diag=[1., 1.5])
print(mv_normal)

tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[], event_shape=[2], dtype=float32)


In [4]:
mv_normal.event_shape

TensorShape([2])

An event_shape of [2] indicates that the random variable is two-dimensional: each draw is a vector with two components.

In [7]:
# Produce 3 independent samples from the multivariate distribution

mv_normal.sample(3)

<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[-0.40149927,  0.94770396],
       [-2.0028892 ,  1.6918546 ],
       [-1.2507302 ,  2.4528666 ]], dtype=float32)>

Notice that each of the 3 samples is two-dimensional.
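Conceptually, a sample from a diagonal Gaussian can be obtained by drawing independent standard normals eps and applying the affine map x = loc + scale_diag * eps, component by component. The plain-Python sketch below does exactly that; `sample_mvn_diag` is a hypothetical helper written for illustration, and it mirrors the math rather than TFP's internal implementation.

```python
import random

def sample_mvn_diag(loc, scale_diag, n, seed=None):
    """Hypothetical helper: draw n samples from a diagonal Gaussian.

    Each sample is loc + scale_diag * eps with eps ~ N(0, I), so the result
    is a list of n lists, each of length len(loc) (the event size).
    """
    rng = random.Random(seed)
    return [
        [mu + sigma * rng.gauss(0.0, 1.0) for mu, sigma in zip(loc, scale_diag)]
        for _ in range(n)
    ]

samples = sample_mvn_diag([-1.0, 0.5], [1.0, 1.5], n=3, seed=0)
print(len(samples), len(samples[0]))  # 3 2, matching the (3, 2) tensor above
```

The random values differ from TFP's output because a different random number generator is used; only the shape and the distribution of the draws correspond.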

In [8]:
# Let's compare with a batched normal

batched_normal = tfd.Normal(loc=[-1., 0.5], scale=[1., 1.5])

In [9]:
print(batched_normal)

tfp.distributions.Normal("Normal", batch_shape=[2], event_shape=[], dtype=float32)


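Notice that the batched normal has a batch_shape of [2] and an empty event_shape. This is because the Normal distribution is univariate: the object holds a batch of two independent scalar distributions rather than one vector-valued distribution.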
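The split between batch and event dimensions follows from the parameter shapes. Below is a simplified sketch, assuming `loc` and the scale parameter have identical shapes (TFP additionally broadcasts them against each other): for a univariate `Normal` every parameter dimension is a batch dimension, while for `MultivariateNormalDiag` the last dimension of `loc` is the event dimension and any leading dimensions form the batch. The helper `infer_shapes` is hypothetical.

```python
def infer_shapes(param_shape, event_ndims):
    """Hypothetical helper: split a parameter shape into (batch_shape, event_shape).

    event_ndims is 0 for a univariate distribution such as Normal and 1 for
    a vector-valued one such as MultivariateNormalDiag.
    """
    param_shape = tuple(param_shape)
    split = len(param_shape) - event_ndims
    return param_shape[:split], param_shape[split:]

print(infer_shapes((2,), event_ndims=0))    # Normal, loc of shape (2,): ((2,), ())
print(infer_shapes((2,), event_ndims=1))    # MultivariateNormalDiag, loc of shape (2,): ((), (2,))
print(infer_shapes((3, 2), event_ndims=1))  # MultivariateNormalDiag, loc of shape (3, 2): ((3,), (2,))
```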

In [10]:
# Sampling from the batched_normal

batched_normal.sample(3)

<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[-1.2561183 ,  0.47626746],
       [-0.7919155 ,  0.2661795 ],
       [-0.5263653 ,  1.6373231 ]], dtype=float32)>

Notice that sampling returns a tensor of shape (3, 2): 3 is the sample size and 2 is the batch size. Although the result has the same shape as in the multivariate case, there is an important distinction between the two. The multivariate case is a single two-dimensional random variable, whereas batched_normal is a batch of two univariate distributions, i.e. two separate normally distributed scalar random variables.

In [11]:
# Computing log probs for the mv normal

mv_normal.log_prob([-0.2, 1.8])

<tf.Tensor: shape=(), dtype=float32, numpy=-2.9388978>

In [13]:
# Computing log probs for the batched normal

batched_normal.log_prob([-0.2, 1.8])

<tf.Tensor: shape=(2,), dtype=float32, numpy=array([-1.2389386, -1.699959 ], dtype=float32)>

For mv_normal, the length-two array passed to log_prob represents a single realization of the two-dimensional random variable, so the returned tensor contains a single log-probability value.

For batched_normal, on the other hand, the input array supplies one value for each of the two normal distributions in the batch. Each value is evaluated under its own distribution, and the two log-probabilities are returned as a tensor of length two.
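Each entry of the batched result is simply the univariate normal log-density, log N(x; mu, sigma) = -0.5 * ((x - mu) / sigma)**2 - log(sigma) - 0.5 * log(2 * pi), evaluated with that batch member's own parameters. A plain-Python check of the two values above (the helper name `normal_log_prob` is illustrative, not part of TFP):

```python
import math

def normal_log_prob(x, mu, sigma):
    """Log-density of a univariate normal N(mu, sigma**2) evaluated at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)

locs, scales = [-1.0, 0.5], [1.0, 1.5]
x = [-0.2, 1.8]  # one value per batch member
batch_log_probs = [normal_log_prob(xi, m, s) for xi, m, s in zip(x, locs, scales)]
print(batch_log_probs)  # approximately [-1.2389386, -1.699959]
```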

In this particular example, the two components of the multivariate diagonal Gaussian are independent, so the log_prob value for mv_normal is just the sum of the two log-probabilities returned by batched_normal (up to float32 rounding in the last digit, as the output below shows).
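The identity follows from the multivariate normal log-density with a diagonal covariance Sigma = diag(sigma_1**2, ..., sigma_d**2): the quadratic form reduces to a sum of squared z-scores and log det(Sigma) = 2 * sum(log sigma_i), so the whole expression separates into one univariate term per component. The sketch below evaluates the multivariate formula directly in plain Python; `mvn_diag_log_prob` is a hypothetical helper.

```python
import math

def mvn_diag_log_prob(x, loc, scale_diag):
    """Log-density of a d-dimensional Gaussian with covariance diag(scale_diag**2).

    General formula:
        -0.5 * (x - loc)^T Sigma^{-1} (x - loc) - 0.5 * log det(Sigma) - d/2 * log(2*pi)
    For a diagonal Sigma the quadratic form is a sum of squared z-scores and
    log det(Sigma) = 2 * sum(log(scale_diag)).
    """
    d = len(loc)
    mahalanobis_sq = sum(((xi - m) / s) ** 2 for xi, m, s in zip(x, loc, scale_diag))
    log_det = 2.0 * sum(math.log(s) for s in scale_diag)
    return -0.5 * mahalanobis_sq - 0.5 * log_det - 0.5 * d * math.log(2.0 * math.pi)

print(mvn_diag_log_prob([-0.2, 1.8], [-1.0, 0.5], [1.0, 1.5]))  # approximately -2.9388978
```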

In [20]:
print('mv_normal: ')
print(mv_normal.log_prob([-0.2, 1.8]).numpy())

print('batched_normal: ')
print((batched_normal.log_prob([-0.2, 1.8])[0] + batched_normal.log_prob([-0.2, 1.8])[1]).numpy())

mv_normal: 
-2.9388978
batched_normal: 
-2.9388976


In [21]:
# Batched multivariate distribution

batched_mv_normal = tfd.MultivariateNormalDiag(loc=[[-1., 0.5],
                                                    [2., 0.],
                                                    [-0.5, 1.5]],
                                               scale_diag=[[1., 1.5],
                                                           [2., 0.5],
                                                           [1., 1.]])

In [22]:
print(batched_mv_normal)

tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[3], event_shape=[2], dtype=float32)


In [23]:
# Sample from these distributions

batched_mv_normal.sample(2)

<tf.Tensor: shape=(2, 3, 2), dtype=float32, numpy=
array([[[ 0.42425132, -1.1782391 ],
        [-0.09813643,  0.32067573],
        [ 0.17918658,  3.6430721 ]],

       [[-2.3836815 ,  2.0172658 ],
        [ 1.7220739 ,  0.5960695 ],
        [-1.6641382 ,  2.4435616 ]]], dtype=float32)>

The result is a (2, 3, 2) tensor: the first 2 is the sample size, the 3 is the batch_shape, and the last 2 is the event_shape, i.e. the number of dimensions of each random variable.
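The general rule, as described in the TFP documentation on distribution shapes, is that `sample(sample_shape)` returns a tensor of shape `sample_shape + batch_shape + event_shape`. A tiny plain-Python helper (hypothetical, for illustration only) reproduces all three sample shapes seen in this notebook:

```python
def sample_result_shape(sample_shape, batch_shape, event_shape):
    """Shape of dist.sample(sample_shape): sample dims, then batch dims, then event dims."""
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)

print(sample_result_shape((3,), (), (2,)))    # mv_normal.sample(3)         -> (3, 2)
print(sample_result_shape((3,), (2,), ()))    # batched_normal.sample(3)    -> (3, 2)
print(sample_result_shape((2,), (3,), (2,)))  # batched_mv_normal.sample(2) -> (2, 3, 2)
```

The first two calls produce the same shape even though the trailing 2 is an event dimension in one case and a batch dimension in the other, which is exactly the distinction drawn earlier between mv_normal and batched_normal.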