This notebook covers tensor shapes in Pyro. It can be read as a friendlier version of the official tutorial, [http://pyro.ai/examples/tensor_shapes.html](http://pyro.ai/examples/tensor_shapes.html).

In [30]:
import pyro
import pyro.distributions as dist
import torch

Every Pyro distribution has two shapes: `batch_shape` and `event_shape`. `batch_shape` indexes the separate random variables drawn in one call to `sample` (it can be used to indicate independence among variables; see `3_random_variable_dependency`). For example:

In [42]:
d = dist.Bernoulli(0.5 * torch.ones(3,4))
assert d.batch_shape == (3, 4)
x = d.sample()  # 3×4 random variables 
x

tensor([[0., 0., 1., 0.],
        [1., 0., 0., 1.],
        [0., 0., 0., 1.]])

In [44]:
d = dist.Bernoulli(torch.as_tensor([0.5, 0.4]))
assert d.batch_shape == (2, )
x = d.sample()
x 

tensor([1., 1.])

Typically `expand` is used to enlarge `batch_shape` by adding new dimensions on the left:

In [34]:
d = dist.Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])) 
assert d.batch_shape == (4, ) # 4 variables with different Bernoulli parameters
d1 = d.expand([2, 3, 4]) 
assert d1.batch_shape == (2,3,4)
d11 = d1.expand([2,2,3,4])
assert d11.batch_shape == (2,2,3,4)

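try:
    d2 = d.expand([2,3,5])  # fails: the last dimension (5) does not match the existing batch_shape (4,)
except Exception:  # avoid a bare except, which would also swallow KeyboardInterrupt
    print('the argument of expand must be [batch_shape_to_add] + [batch_shape_already]')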

the argument of expand must be [batch_shape_to_add] + [batch_shape_already]


When a single draw from the distribution is itself a tensor (vector, matrix, etc.), `event_shape` gives the shape of that tensor. For example:

In [32]:
d = dist.MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == ()        # not specifying batch, so only one random vector will be generated each time
assert d.event_shape == (3,)
x = d.sample()
assert x.shape == (3,)            # == batch_shape + event_shape

`to_event` is used to move batch dimensions into `event_shape`:

In [35]:
d = dist.Bernoulli(0.5 * torch.ones(3, 4, 2, 5)).to_event(2)
assert d.batch_shape == (3, 4)
assert d.event_shape == (2, 5)
# to_event(n) turns the last n dimensions of batch_shape to event_shape
d.shape()  # shape is a method, so it must be called; returns batch_shape + event_shape

torch.Size([3, 4, 2, 5])

The shape of a sample is `batch_shape + event_shape`:

In [33]:
d = dist.MultivariateNormal(torch.zeros(3), torch.eye(3, 3)).expand([3, 4])
assert d.batch_shape == (3, 4)
assert d.event_shape == (3, )
x = d.sample()
assert x.shape == (3, 4, 3) # 3×4 random vectors of size 3
x

tensor([[[ 0.8970,  0.8968, -1.1532],
         [ 1.7721,  0.8879,  2.1872],
         [ 0.1745, -0.2508,  0.0775],
         [ 1.0610, -0.6311,  1.3013]],

        [[ 0.5469, -0.0949,  1.2606],
         [ 0.2467,  1.3282,  2.0833],
         [ 0.2237,  0.8364, -0.7097],
         [ 0.7022,  1.1130, -0.1235]],

        [[-1.0572,  1.1353, -0.2552],
         [ 0.6337, -1.3282,  0.5570],
         [ 0.1824,  1.7431, -0.6060],
         [-1.0008,  2.3194, -0.7607]]])

We can use the following code to view the shapes of all random variables in a model (here `model` stands for any Pyro model function):

In [None]:
from pyro import poutine
trace = poutine.trace(model).get_trace()
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())

## Log probability of samples

For distributions that provide `log_prob`, we can use it to compute the log probability of a sample.

In [36]:
d = dist.Bernoulli(0.5)
x = d.sample()
d.log_prob(x) # ln(0.5)

tensor(-0.6931)

The result of `log_prob` has the same shape as the distribution's `batch_shape` (variables indexed by `event_shape` are considered dependent, so a single joint log probability is returned for them):

In [37]:
d = dist.Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])).expand([3, 4])
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)

In [38]:
d = dist.Bernoulli(0.5 * torch.ones(3, 4, 2, 5)).to_event(2)
x = d.sample()
assert d.log_prob(x).shape == (3, 4)