In [2]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

In [5]:
# Build a batch of three 2-D Gaussians with diagonal covariance.
# Each row of `loc` / `scale_diag` parameterizes one batch member, so the
# resulting distribution has batch_shape [3] and event_shape [2].
normal_distribution = tfd.MultivariateNormalDiag(
    loc=[[0.5, 1.0],
         [0.1, 0.0],
         [0.0, 0.2]],
    scale_diag=[[2.0, 3.0],
                [1.0, 3.0],
                [4.0, 4.0]],
)
normal_distribution

<tfp.distributions.MultivariateNormalDiag 'MultivariateNormalDiag' batch_shape=[3] event_shape=[2] dtype=float32>

In [6]:
# Draw 5 samples. The result has shape
# sample_shape + batch_shape + event_shape = [5] + [3] + [2].
# A fixed seed makes the draws reproducible on a fresh Restart & Run All
# (re-running only this cell may still produce new draws).
normal_distribution.sample(5, seed=42)

<tf.Tensor: shape=(5, 3, 2), dtype=float32, numpy=
array([[[ 0.809603  , -4.3343606 ],
        [-0.5144393 , -3.8310034 ],
        [-1.5013548 ,  4.2068233 ]],

       [[-2.1317112 ,  8.458184  ],
        [ 0.8523676 , -5.365728  ],
        [ 1.831659  ,  7.0399384 ]],

       [[ 0.901067  ,  0.8922256 ],
        [ 0.17544934,  5.216042  ],
        [ 0.31113783, -2.0891843 ]],

       [[ 2.9185498 ,  2.9119256 ],
        [ 0.19371161, -1.3805535 ],
        [-0.02062214,  2.4009593 ]],

       [[-0.47351587,  1.4214468 ],
        [-1.047671  ,  5.4978595 ],
        [ 0.65695685, -3.8748133 ]]], dtype=float32)>

In [7]:
# Batched multivariate normal built by broadcasting parameters:
# `loc` has shape [2, 2, 3] while `scale_diag` has shape [3], and TFP
# broadcasts them against each other. The last axis is the event dimension,
# giving batch_shape [2, 2] and event_shape [3].
loc = [
    [[0.3, 1.5, 1.0], [0.2, 0.4, 2.8]],
    [[2.0, 2.3, 8.0], [1.4, 1, 1.3]],
]
scale_diag = [0.4, 1.0, 0.7]
normal_distribution = tfd.MultivariateNormalDiag(
    loc=loc,
    scale_diag=scale_diag,
)
normal_distribution

<tfp.distributions.MultivariateNormalDiag 'MultivariateNormalDiag' batch_shape=[2, 2] event_shape=[3] dtype=float32>

In [9]:
# Use tfd.Independent to reinterpret the rightmost batch dimension as an
# event dimension: batch_shape [2, 2] -> [2] and event_shape [3] -> [2, 3].
# log_prob then sums the component log-densities over that reinterpreted
# dimension, returning one value per remaining batch member.
ind_normal_distributions = tfd.Independent(
    normal_distribution,
    reinterpreted_batch_ndims=1,
)
ind_normal_distributions

<tfp.distributions.Independent 'IndependentMultivariateNormalDiag' batch_shape=[2] event_shape=[2, 3] dtype=float32>

In [12]:
# Draw 5 samples from the Independent distribution. The shape is
# sample_shape + batch_shape + event_shape = [5] + [2] + [2, 3].
# Seeded so `samples` is reproducible on a fresh Restart & Run All.
samples = ind_normal_distributions.sample(5, seed=42)
samples.shape

TensorShape([5, 2, 2, 3])

In [13]:
# '[B, E]' shaped input: batch_shape + event_shape = [2] + [2, 3].
# log_prob returns one value per batch member, i.e. shape [2].
# The op-level seed makes the displayed log-probs reproducible on a fresh run.
inp = tf.random.uniform((2, 2, 3), seed=1)
ind_normal_distributions.log_prob(inp)

<tf.Tensor: shape=(2,), dtype=float32, numpy=array([-11.014406, -67.21721 ], dtype=float32)>

In [16]:
# '[E]' shaped input: only the event shape [2, 3]. It is broadcast across
# the batch dimension, so log_prob still returns shape [2].
# The op-level seed makes the displayed log-probs reproducible on a fresh run.
inp = tf.random.uniform((2, 3), seed=2)
ind_normal_distributions.log_prob(inp)

<tf.Tensor: shape=(2,), dtype=float32, numpy=array([-11.905516, -68.43455 ], dtype=float32)>

In [17]:
# '[S, B, E]' shaped input: sample_shape + batch_shape + event_shape
# = [5] + [2] + [2, 3]. The input already has the full batch and event
# shape, so no broadcasting is needed. log_prob returns shape [5, 2]
# (one value per sample and batch member).
# The op-level seed makes the displayed log-probs reproducible on a fresh run.
inp = tf.random.uniform((5, 2, 2, 3), seed=3)
ind_normal_distributions.log_prob(inp)

<tf.Tensor: shape=(5, 2), dtype=float32, numpy=
array([[ -8.453414, -79.14973 ],
       [ -9.594481, -83.12004 ],
       [-11.334316, -81.84296 ],
       [ -8.279422, -70.33672 ],
       [ -9.213768, -79.41513 ]], dtype=float32)>

In [22]:
# '[s, b, e]' shaped input where [b, e] is broadcastable against [B, E].
# Here the input has shape [5, 1, 2, 1]: the trailing [1, 2, 1] broadcasts
# against batch_shape + event_shape = [2, 2, 3], and the leading 5 acts as a
# sample dimension, so log_prob returns shape [5, 2].
# The op-level seed makes the displayed log-probs reproducible on a fresh run.
inp = tf.random.uniform((5, 1, 2, 1), seed=4)
ind_normal_distributions.log_prob(inp)



<tf.Tensor: shape=(5, 2), dtype=float32, numpy=
array([[ -9.840872, -58.58171 ],
       [-12.37376 , -83.1953  ],
       [-11.041252, -85.17639 ],
       [ -9.692735, -59.609215],
       [-10.273886, -69.077835]], dtype=float32)>