## [Tensor Shapes in Pyro](http://pyro.ai/examples/tensor_shapes.html#Tensor-shapes-in-Pyro)

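Tensors broadcast by aligning on the right, so a tensor of shape (5,) combines with one of shape (3, 4, 5) by being repeated along the two leftmost dimensions:

In [4]: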
torch.ones(3,4,5) + torch.ones(5)

tensor([[[2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.]],

        [[2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.]],

        [[2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2.]]])

- Distribution.sample().shape == batch_shape + event_shape.
- Distribution.log_prob(x).shape == batch_shape (but not event_shape!).
- Use my_dist.to_event(1) to declare a dimension as dependent.
- Use with pyro.plate("name", size): to declare a dimension as conditionally independent.
- All dimensions must be declared either dependent or conditionally independent.
- Try to support batching on the left. This lets Pyro auto-parallelize.
  - Use negative indices like x.sum(-1) rather than x.sum(2).
  - Use ellipsis notation like pixel = image[..., i, j].
  - Use Vindex (from pyro.ops.indexing) if i, j are enumerated: pixel = Vindex(image)[..., i, j].
- When using pyro.plate's automatic subsampling, be sure to subsample your data, either:
  - manually, by capturing the index with pyro.plate(...) as i: ..., or
  - automatically, via batch = pyro.subsample(data, event_dim=...).
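Here is a minimal sketch of "batching on the left". The two helpers below index and reduce only from the right, so they give the intended result whether or not extra batch dimensions are prepended. The names pixel_brightness and row_totals are hypothetical and are not part of the Pyro API.

```python
import torch

def pixel_brightness(image, i, j):
    # The ellipsis lets any number of leading batch dimensions pass through.
    return image[..., i, j]

def row_totals(x):
    # Reduce over the rightmost dimension regardless of how many batch dims exist.
    return x.sum(-1)

single = torch.rand(4, 5)         # one 4x5 "image"
batched = torch.rand(2, 3, 4, 5)  # a 2x3 batch of 4x5 images

print(pixel_brightness(single, 1, 2).shape)   # torch.Size([])
print(pixel_brightness(batched, 1, 2).shape)  # torch.Size([2, 3])
print(row_totals(single).shape)               # torch.Size([4])
print(row_totals(batched).shape)              # torch.Size([2, 3, 4])

# A positive index silently changes meaning once batch dims are added:
# on `batched`, dim 2 is the row axis (size 4), not the column axis (size 5).
print(batched.sum(2).shape)                   # torch.Size([2, 3, 5])
```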

In [5]:
import torch
import os
import pyro

In [7]:
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal

In [8]:
from pyro.distributions.util import broadcast_shape

In [9]:
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate

In [10]:
import pyro.poutine as poutine
from pyro.optim import Adam

In [11]:
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.2')

In [12]:
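def test_model(model, guide, loss):
    # Start from an empty parameter store, then evaluate the loss once;
    # any shape mismatch in the model or guide raises an error here.
    pyro.clear_param_store()
    loss.loss(model, guide)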

---

Indices over .batch_shape denote conditionally independent random variables, whereas indices over .event_shape denote dependent random variables (i.e. one draw from a distribution). Because the dependent random variables jointly define a single probability, the .log_prob() method produces only one number for each event of shape .event_shape. Thus the total shape of .log_prob() is .batch_shape:

| | sample_shape | batch_shape | event_shape |
|---|---|---|---|
| dimensions are | iid | independent | dependent |

shape = sample_shape + batch_shape + event_shape

In [20]:
d = Bernoulli(0.5) # Univariate Distribution

In [15]:
d.batch_shape

torch.Size([])

In [16]:
d.event_shape

torch.Size([])

In [17]:
x = d.sample()

In [18]:
x.shape

torch.Size([])

In [19]:
d.log_prob(x).shape

torch.Size([])

In [21]:
d = Bernoulli(0.5 * torch.ones(3,4))

In [22]:
d.batch_shape

torch.Size([3, 4])

In [23]:
d.event_shape

torch.Size([])

In [24]:
x = d.sample()

In [25]:
x.shape

torch.Size([3, 4])

In [26]:
d.log_prob(x).shape

torch.Size([3, 4])

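Another way to batch a distribution is with the .expand() method. Because .expand() only broadcasts the existing parameters, it works only when the parameters are identical along the new leftmost dimensions. In the example below, every row shares the same four probabilities.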

In [27]:
d = Bernoulli(torch.tensor([0.1,0.2,0.4,0.3])).expand([3,4])

In [28]:
d.batch_shape

torch.Size([3, 4])

In [29]:
d.event_shape

torch.Size([])

In [30]:
x = d.sample()

In [31]:
x.shape

torch.Size([3, 4])

In [32]:
x

tensor([[1., 0., 1., 1.],
        [0., 0., 0., 1.],
        [1., 0., 0., 1.]])

In [33]:
d.log_prob(x)

tensor([[-2.3026, -0.2231, -0.9163, -1.2040],
        [-0.1054, -0.2231, -0.5108, -1.2040],
        [-2.3026, -0.2231, -0.5108, -1.2040]])

In [34]:
d = MultivariateNormal(torch.zeros(3), torch.eye(3,3))

In [35]:
d.batch_shape

torch.Size([])

In [36]:
d.event_shape

torch.Size([3])

In [37]:
x = d.sample()

In [38]:
x

tensor([-0.3474,  0.6078,  0.0177])

In [39]:
d.log_prob(x)

tensor(-3.0020)

#### In Pyro you can treat a univariate distribution as multivariate by calling the .to_event(n) method, where n is the number of batch dimensions (from the right) to declare as dependent.

In [40]:
d = Bernoulli(0.5 * torch.ones(3,4)).to_event(1)

In [41]:
d.batch_shape

torch.Size([3])

In [49]:
x=d.sample()
x

tensor([[0., 0., 0., 1.],
        [0., 0., 1., 1.],
        [1., 1., 0., 0.]])

In [48]:
d = Bernoulli(0.5 * torch.ones(3,4)).to_event(0)

In [47]:
d.batch_shape

torch.Size([3, 4])

In [50]:
x = d.sample()
x

tensor([[1., 0., 0., 1.],
        [1., 0., 1., 0.],
        [0., 0., 0., 0.]])

While you work with Pyro programs, keep in mind that samples have shape batch_shape + event_shape, whereas .log_prob(x) values have shape batch_shape. You’ll need to ensure that batch_shape is carefully controlled by either trimming it down with .to_event(n) or by declaring dimensions as independent via pyro.plate.

In [56]:
x = pyro.sample("x", Normal(0,1).expand([10]).to_event(1))

In [55]:
x

tensor([ 0.6164, -1.1625,  0.3950, -1.7467, -0.4075, -0.2454,  0.1144,  0.6614,
        -0.2194, -1.0083])

#### The difference between these two versions is that the second version, with plate, informs Pyro that it can make use of conditional independence information when estimating gradients, whereas in the first version Pyro must assume the elements are dependent (even though the normals are in fact conditionally independent). This is analogous to d-separation in graphical models: it is always safe to add edges and assume variables may be dependent (i.e. to widen the model class), but it is unsafe to assume independence when variables are actually dependent (i.e. to narrow the model class so that the true model lies outside it, as in mean field). In practice, Pyro's SVI inference algorithm uses reparameterized gradient estimators for Normal distributions, so both versions yield gradient estimators with the same performance.

In [68]:
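def model1():
    a = pyro.sample("a", Normal(0,1))
    b = pyro.sample("b", Normal(torch.zeros(2),1).to_event(1))
    with pyro.plate("c_plate", 2):
        c = pyro.sample("c", Normal(torch.zeros(2),1))
    with pyro.plate("d_plate", 3):
        d = pyro.sample("d", Normal(torch.zeros(3,4,5), 1).to_event(2))
    assert a.shape == ()            # batch_shape == ()         event_shape == ()
    assert b.shape == (2,)          # batch_shape == ()         event_shape == (2,)
    assert c.shape == (2,)          # batch_shape == (2,)       event_shape == ()
    assert d.shape == (3, 4, 5)     # batch_shape == (3,)       event_shape == (4, 5)

    x_axis = pyro.plate("x_axis", 3, dim=-2)
    y_axis = pyro.plate("y_axis", 2, dim=-3)

    with x_axis:
        x = pyro.sample("x", Normal(0,1))
    with y_axis:
        y = pyro.sample("y", Normal(0,1))

    with x_axis, y_axis:
        xy = pyro.sample("xy", Normal(0,1))
        z = pyro.sample("z", Normal(0,1).expand([5]).to_event(1))
    assert x.shape == (3, 1)        # batch_shape == (3, 1)     event_shape == ()
    assert y.shape == (2, 1, 1)     # batch_shape == (2, 1, 1)  event_shape == ()
    assert xy.shape == (2, 3, 1)    # batch_shape == (2, 3, 1)  event_shape == ()
    assert z.shape == (2, 3, 1, 5)  # batch_shape == (2, 3, 1)  event_shape == (5,)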

In [64]:
test_model(model1, model1, Trace_ELBO())

In [69]:
trace = poutine.trace(model1).get_trace()
trace.compute_log_prob()
print(trace.format_shapes())

Trace Shapes:            
 Param Sites:            
Sample Sites:            
       a dist       |    
        value       |    
     log_prob       |    
       b dist       | 2  
        value       | 2  
     log_prob       |    
 c_plate dist       |    
        value     2 |    
     log_prob       |    
       c dist     2 |    
        value     2 |    
     log_prob     2 |    
 d_plate dist       |    
        value     3 |    
     log_prob       |    
       d dist     3 | 4 5
        value     3 | 4 5
     log_prob     3 |    
  x_axis dist       |    
        value     3 |    
     log_prob       |    
  y_axis dist       |    
        value     2 |    
     log_prob       |    
       x dist   3 1 |    
        value   3 1 |    
     log_prob   3 1 |    
       y dist 2 1 1 |    
        value 2 1 1 |    
     log_prob 2 1 1 |    
      xy dist 2 3 1 |    
        value 2 3 1 |    
     log_prob 2 3 1 |    
       z dist 2 3 1 | 5  
        value 2 3 1 | 5  
     log_prob 2 3 1 |    

In [70]:
data = torch.arange(100.)

In [71]:
def model2():
    mean = pyro.param("mean", torch.zeros(len(data)))
    with pyro.plate("data", len(data), subsample_size=10) as ind:
        assert len(ind) == 10    # ind is a LongTensor that indexes the subsample.
        batch = data[ind]        # Select a minibatch of data.
        mean_batch = mean[ind]   # Take care to select the relevant per-datum parameters.
        # Observe the minibatch; the plate rescales its log-likelihood by len(data) / 10.
        x = pyro.sample("x", Normal(mean_batch, 1), obs=batch)

In [72]:
test_model(model2, guide=lambda: None, loss=Trace_ELBO())

### [Broadcasting to allow parallel enumeration](http://pyro.ai/examples/tensor_shapes.html#Broadcasting-to-allow-parallel-enumeration)