In [2]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

In [3]:
# Rates laid out as a 2x3 nested list, so this univariate Exponential has
# batch_shape [2, 3] and a scalar event shape.
exp = tfd.Exponential(
    rate=[
        [1., 1.5, 0.8],
        [0.3, 0.4, 1.8],
    ]
)

In [4]:
# Reinterpret the rightmost batch axis as an event axis, giving
# batch_shape [2] and event_shape [3].
# `reinterpreted_batch_ndims=1` is spelled out rather than relying on the
# `None` default (which moves all but the first batch axis), so the intent
# is visible at the call site.
ind_exp = tfd.Independent(exp, reinterpreted_batch_ndims=1)
# Fixed seed so Restart & Run All reproduces the same draws.
# Result shape = sample_shape + batch_shape + event_shape = (4, 2, 3).
ind_exp.sample(4, seed=42)

<tf.Tensor: shape=(4, 2, 3), dtype=float32, numpy=
array([[[ 1.114776  ,  1.9618077 ,  1.3775986 ],
        [ 2.2595909 ,  1.27244   ,  0.4738011 ]],

       [[ 0.47329444,  1.4709959 ,  1.1949813 ],
        [ 4.6351604 ,  1.7579396 ,  0.09245706]],

       [[ 0.6893939 ,  0.7641494 ,  0.3967921 ],
        [13.440917  ,  0.17377266,  0.18867895]],

       [[ 0.31947738,  0.41454008,  1.9400377 ],
        [ 0.21371049,  1.519958  ,  0.5807488 ]]], dtype=float32)>

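### Increasing the rank

Here the rate tensor has rank 4, and `reinterpreted_batch_ndims` chooses how many of its rightmost batch dimensions are moved into the event shape.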

In [14]:
# Rank-4 nested list of rates with shape [2, 1, 2, 3]:
# two outer entries, each wrapping a single 2x3 block.
rates = [
    [
        [[1., 1.5, 0.8], [0.3, 0.4, 1.8]],
    ],
    [
        [[0.2, 0.4, 1.4], [0.4, 1.1, 0.9]],
    ],
]

In [16]:
# Every axis of the rank-4 `rates` becomes a batch axis of this univariate
# distribution: batch_shape [2, 1, 2, 3], event_shape [].
exp = tfd.Exponential(
    rate=rates,
)
exp

<tfp.distributions.Exponential 'Exponential' batch_shape=[2, 1, 2, 3] event_shape=[] dtype=float32>

In [18]:
# Reinterpret the two rightmost batch axes as event axes:
# batch_shape [2, 1], event_shape [2, 3].
ind_exp = tfd.Independent(
    distribution=exp,
    reinterpreted_batch_ndims=2,
)
ind_exp

<tfp.distributions.Independent 'IndependentExponential' batch_shape=[2, 1] event_shape=[2, 3] dtype=float32>

In [20]:
# `sample_shape` may itself be multidimensional. The draw has shape
# sample_shape + batch_shape + event_shape = [4, 2] + [2, 1] + [2, 3].
ind_exp.sample(sample_shape=(4, 2)).shape

TensorShape([4, 2, 2, 1, 2, 3])

In [21]:
# A scalar value is broadcast across the full batch + event shape
# ([2, 1, 2, 3]). Log-probs are summed over the event dims [2, 3], leaving
# one value per batch member (shape (2, 1)).
ind_exp.log_prob(value=0.5)

<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[-4.2501554],
       [-5.3155975]], dtype=float32)>

In [22]:
# Another example, this time with a value of shape (1, 3).
#
# It is broadcast along the first (size-2) event axis and across both batch
# axes, so each batch member still gets one summed log-prob: shape (2, 1).
ind_exp.log_prob(value=[[0.3, 0.5, 0.8]])

<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[-4.7701554],
       [-5.885597 ]], dtype=float32)>

In [23]:
# Now pass a rank-5 value of shape (5, 1, 1, 2, 1). It broadcasts against
# batch_shape + event_shape = [2, 1, 2, 3] to (5, 2, 1, 2, 3). Summing over
# the two event dims leaves shape (5, 2, 1).
# Seeded so the displayed log-probs are reproducible on Restart & Run All.
ind_exp.log_prob(tf.random.uniform((5, 1, 1, 2, 1), seed=42))

<tf.Tensor: shape=(5, 2, 1), dtype=float32, numpy=
array([[[-4.8681116],
        [-5.8733225]],

       [[-3.7062087],
        [-4.810455 ]],

       [[-4.3748903],
        [-5.3849773]],

       [[-3.6069574],
        [-5.03947  ]],

       [[-5.6348534],
        [-6.529839 ]]], dtype=float32)>