Pyro distributions accept and return PyTorch tensors, which have a `shape`.
However, when dealing with the distributions themselves, there are several
`shape`s to keep in mind: `event_shape`, `batch_shape`, and `sample_shape`.

In [None]:
import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
import torch

from pyro import poutine

`event_shape` can be thought of as the shape of a single event or draw from the
distribution. The general rule is that a distribution over tensors of order `n`
has an `event_shape` that is an `n`-tuple. A univariate distribution
(e.g. `Normal`) is distributed over a single variable, which can be considered
a tensor of order `0`. Hence, such distributions have an empty `event_shape`.
Multivariate distributions over vectors of variables (e.g. `MultivariateNormal`)
have an `event_shape` that is a 1-tuple, since vectors are tensors of order `1`.
Distributions over matrices of variables (e.g. `LKJ`) have 2-tuple
`event_shape`s, and so forth.

In [None]:
print(dist.Normal(torch.zeros(5), torch.ones(5)).event_shape)
print(dist.MultivariateNormal(torch.zeros(5), torch.eye(5)).event_shape)
print(dist.LKJ(5).event_shape)

This may seem trivial when you're only drawing a single sample, but when you're
drawing multiple batches of samples in parallel (which is easy with Pyro), the
`event_shape` has a special meaning: it represents the shape of the dependent
random variables in the draw. If we sample from a multivariate distribution,
variables that are part of the same event (i.e. entries whose indices differ
only in the rightmost `len(event_shape)` dimensions) are dependent on each
other, whereas variables belonging to different events (i.e. entries whose
indices differ outside the rightmost `len(event_shape)` dimensions) are
independent.

In [None]:
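# Build a random positive-definite covariance matrix.
cov = torch.rand(5,5).abs()
cov = torch.mm(cov, cov.t()) + 0.01 * torch.eye(5)
# A batch of 2 Gaussians over 5 dimensions that share the same covariance.
d = dist.MultivariateNormal(torch.zeros(2,5), torch.tile(cov, (2,1,1)))
print(f"event shape: {d.event_shape}")
s = d.sample((100000,))  # shape: (100000, 2, 5)
# Keep only the draws in which dimension 0 of batch element 0 exceeds 1.
cond_s = s[s[:,0,0] > 1.]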
plt.hist(s[:,0,1].cpu().numpy(), bins=100, density=True, color="red", alpha=0.5)
plt.hist(cond_s[:,0,1].cpu().numpy(), bins=100, density=True, color="blue", alpha=0.5)
plt.show()

plt.hist(s[:,1,1].cpu().numpy(), bins=100, density=True, color="red", alpha=0.5)
plt.hist(cond_s[:,1,1].cpu().numpy(), bins=100, density=True, color="blue", alpha=0.5)
plt.show()

In the above, we sample from 2 i.i.d. multivariate Gaussians with 5
dimensions each. Using zero-based indices throughout, dimension `1` of
Gaussian `0` depends on dimension `0` of the same Gaussian (top plot:
conditioning on `s[:,0,0] > 1` shifts the histogram), but dimension `1` of
Gaussian `1` is independent of dimension `0` of Gaussian `0` (bottom plot: the
two histograms coincide).

In contrast, `batch_shape` represents the conditionally independent RVs. For
example, the `MultivariateNormal` created above has a `batch_shape` of `(2,)`.

In [None]:
print(d.batch_shape)


This is because the mean vector and covariance matrix used to construct the
distribution had shapes `(2,5)` and `(2,5,5)`, respectively. The batch
dimensions don't have to be identically distributed.

In [None]:
m1 = torch.zeros(5)
cov1 = torch.rand(5,5).abs()
cov1 = torch.mm(cov1, cov1.t()) + 0.01 * torch.eye(5)
m2 = torch.ones(5)
cov2 = torch.rand(5,5).abs()
cov2 = torch.mm(cov2, cov2.t()) + 0.01 * torch.eye(5)
d = dist.MultivariateNormal(torch.stack([m1, m2]), torch.stack([cov1, cov2]))

s = d.sample((100000,))
print(s.shape)
plt.hist(s[:, 0, 0].cpu().numpy(), bins=100, color="red", alpha=0.5)
plt.hist(s[:, 1, 0].cpu().numpy(), bins=100, color="blue", alpha=0.5)
plt.show()


Finally, note that when we call the `sample` method we pass a `sample_shape`.
This represents a tensor of i.i.d. samples drawn from the distribution.

In [None]:
plt.hist(s[:50000, 0, 0].cpu().numpy(), bins=100, color="red", alpha=0.5)
plt.hist(s[50000:, 0, 0].cpu().numpy(), bins=100, color="blue", alpha=0.5)
plt.show()

The shape of a sample is the concatenation of `sample_shape`, `batch_shape`, and
`event_shape`, in that order. The shape of `log_prob`, on the other hand, is the
concatenation of `sample_shape` and `batch_shape`, since the PDF is defined over
entire events and not individual dimensions of them. As far as we can tell, Pyro
does not support computing probabilities that are conditioned on some of the
variables of a multivariate distribution.


In [None]:
sample_shape = (100, 15)
s = d.sample(sample_shape)
assert s.shape == sample_shape + d.batch_shape + d.event_shape
assert d.log_prob(s).shape == sample_shape + d.batch_shape