# Tensor shapes in Pyro

Pyro tensors follow PyTorch's broadcasting rules; see the [PyTorch broadcasting notes](https://pytorch.org/docs/master/notes/broadcasting.html) for how broadcasting works.
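As a rough illustration of the rule described in those notes, here is a minimal pure-Python sketch. The helper name `broadcast_shape` is hypothetical (it is not part of Pyro or PyTorch); it only mimics how shapes are aligned from the right and how size-1 dimensions are stretched.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Hypothetical helper: the shape that broadcasting these shapes would give."""
    result = []
    # Align the shapes on the right; missing leading dimensions count as size 1.
    for sizes in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        # Sizes other than 1 must all agree; size-1 dimensions are stretched.
        non_one = {size for size in sizes if size != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(non_one.pop() if non_one else 1)
    return tuple(reversed(result))


assert broadcast_shape((4,), (3, 4)) == (3, 4)
assert broadcast_shape((3, 1), (4,)) == (3, 4)
assert broadcast_shape((2, 1, 1), (3, 1)) == (2, 3, 1)
```

The last case is the same arithmetic that later gives a sample inside two plates at `dim=-3` and `dim=-2` its `(2, 3, 1)` batch shape.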

# Distribution shapes 

The simplest distribution shape is a single univariate distribution.

In [29]:
import pyro
import torch
from pyro.distributions import Bernoulli, Categorical, Normal, MultivariateNormal
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam

In [2]:
d = Bernoulli(0.5)

In [3]:
assert d.event_shape == ()
assert d.batch_shape == ()

In [4]:
x = d.sample()

In [5]:
assert x.shape == ()

In [6]:
d.log_prob(x)  # the log likelihood of the sample

tensor(-0.6931)

In [7]:
assert d.log_prob(x).shape == ()

Distributions can be **batched** by passing in batched parameters. Distributions have two shape attributes:
- `.batch_shape` = conditionally independent random variables. This can be e.g. the number of IID samples you generate.
- `.event_shape` = dependent random variables. E.g. a distribution over scalars has `len(event_shape) == 0`, vectors `len(event_shape) == 1`, and matrices `len(event_shape) == 2`.

`.log_prob()` produces a single number for each event, and so has the same shape as `.batch_shape`.

In [8]:
d = Bernoulli(0.5 * torch.ones(3,4))

In [9]:
assert d.batch_shape == (3,4)
assert d.event_shape == ()

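Another way to batch distributions is via `.expand()`. This only works when the parameters are identical along the new leftmost dimensions: `.expand()` adds dimensions on the left and repeats the existing parameters along them, just as broadcasting would.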

In [10]:
d = Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])).expand([3,4])

In [11]:
assert d.batch_shape == (3,4)
assert d.event_shape == ()

In [12]:
x = d.sample()

In [13]:
assert x.shape == (3,4)
assert d.log_prob(x).shape == (3,4)

A multivariate normal distribution has a non-empty `.event_shape`.

In [14]:
d = MultivariateNormal(loc=torch.zeros(3), covariance_matrix=torch.eye(3,3))

In [15]:
assert d.batch_shape == ()
assert d.event_shape == (3,)
assert len(d.event_shape) == 1

In [16]:
x = d.sample()

In [17]:
assert x.shape == (3,)
assert d.log_prob(x).shape == ()

## Reshaping distributions

You can treat a univariate distribution as multivariate by calling `to_event(n)` where `n` is the number of batch dimensions (from the right) to declare as *dependent*.

In [20]:
d = Bernoulli(0.5).expand([3,4]).to_event(1)

In [21]:
assert d.batch_shape == (3,)
assert d.event_shape == (4,)

In [22]:
x = d.sample()

Samples have shape `batch_shape + event_shape` whereas `.log_prob(x)` has shape `batch_shape`. 

In [27]:
assert x.shape == (3,4)
assert d.log_prob(x).shape == (3,)

It is important to ensure that `batch_shape` is carefully controlled by either trimming it down with `.to_event` or by declaring dimensions as independent via `pyro.plate`.

## It is always safe to assume dependence

Often in Pyro we'll declare some dimensions as dependent even though they are in fact independent. For example:

In [30]:
x = pyro.sample("x", Normal(0,1).expand([10]).to_event(1))

In [32]:
assert x.shape == (10,)

This is useful because:
- It allows us to swap in a genuinely multivariate distribution (e.g. a `MultivariateNormal`) later.
- It simplifies the code, because we then don't need a plate as in the following version.

In [33]:
with pyro.plate("x_plate", 10):
    x = pyro.sample("x", Normal(0,1))  # .expand([10]) is automatic
    assert x.shape == (10,)

The difference between these two versions is that, with `.to_event`, Pyro assumes that the samples are dependent (even though they are actually independent).

This is always a safe thing to do. As with d-separation, it is always safe to assume variables may be dependent, but unsafe to assume independence, because doing so narrows the model class so that the true model may lie outside it.

In practice, SVI uses reparametrized gradient estimators for `Normal` so, in this case, both gradient estimators have the same performance.
 

## Declaring independent dimensions with `plate`

`pyro.plate` allows you to declare certain **batch** dimensions as independent, allowing inference algorithms to take advantage of this independence. For example, the index over data in a minibatch is an independent dimension.

By default, a plate declares the **rightmost** available batch dimension as independent. Plates can also be nested:
```python
with pyro.plate("x_axis", 320):
    # within this context, batch dimension -1 is independent
    with pyro.plate("y_axis", 200):
        # within this context, batch dimensions -2 and -1 are independent
        pass
```

One can also mix and match plates
```python
x_axis = pyro.plate("x_axis", 3, dim=-2)
y_axis = pyro.plate("y_axis", 2, dim=-3)
with x_axis:
    # within this context, batch dimension -2 is independent
    pass
with y_axis:
    # within this context, batch dimension -3 is independent
    pass
with x_axis, y_axis:
    # within this context, batch dimensions -3 and -2 are independent
    pass
```

In [34]:
b = pyro.sample("b", Normal(torch.zeros(2), 1).to_event(1))

In [42]:
assert b.shape == (2,)

In [55]:
def model1():
    a = pyro.sample("a", Normal(0,1))
    b = pyro.sample("b", Normal(torch.zeros(2), 1).to_event(1)) 
    with pyro.plate("c_plate", 2):
        c = pyro.sample("c", Normal(torch.zeros(2), 1))  
    with pyro.plate("d_plate", 3):
        d = pyro.sample("d", Normal(torch.zeros(3,4,5), 1).to_event(2))
    
    assert a.shape == ()       # batch_shape == ()       event_shape == ()
    assert b.shape == (2,)     # batch_shape == ()       event_shape == (2,)
    assert c.shape == (2,)     # batch_shape == (2,)     event_shape == ()
    assert d.shape == (3,4,5)  # batch_shape == (3,)     event_shape == (4,5)
    
    #----------------
    
    x_axis = pyro.plate("x_axis", 3, dim=-2)
    y_axis = pyro.plate("y_axis", 2, dim=-3)
    with x_axis:
        x = pyro.sample("x", Normal(0,1))
    with y_axis:
        y = pyro.sample("y", Normal(0,1))
    with x_axis, y_axis:
        xy = pyro.sample("xy", Normal(0,1))
        z = pyro.sample("z", Normal(0,1).expand([5]).to_event(1))
    
    assert x.shape  == (3,1)      # batch_shape == (3,1)    event_shape == ()
    assert y.shape  == (2,1,1)    # batch_shape == (2,1,1)  event_shape == ()
    assert xy.shape == (2,3,1)    # batch_shape == (2,3,1)  event_shape == ()
    # Non-intuitive: plate dims index into batch_shape, which sits to the left of event_shape.
    assert z.shape  == (2,3,1,5)  # batch_shape == (2,3,1)  event_shape == (5,)
    

In [57]:
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)

In [59]:
test_model(model1, model1, Trace_ELBO())

This example might also be helpful:

In [66]:
x_axis = pyro.plate("x_axis", 3, dim=-5)
y_axis = pyro.plate("y_axis", 2, dim=-6)
with x_axis, y_axis:        
    z = pyro.sample("z", Normal(0,1).expand([5,6]).to_event(2))
z.shape

torch.Size([2, 3, 1, 1, 1, 1, 5, 6])

We can programmatically inspect the shapes of all objects in a model with `trace.format_shapes()`, which prints:
1. The distribution shape
2. The value shape
3. The log probability shape (if calculated)

In [71]:
trace = poutine.trace(model1).get_trace()
# trace.compute_log_prob()  # <- optional
print(trace.format_shapes())

Trace Shapes:            
 Param Sites:            
Sample Sites:            
       a dist       |    
        value       |    
       b dist       | 2  
        value       | 2  
 c_plate dist       |    
        value     2 |    
       c dist     2 |    
        value     2 |    
 d_plate dist       |    
        value     3 |    
       d dist     3 | 4 5
        value     3 | 4 5
  x_axis dist       |    
        value     3 |    
  y_axis dist       |    
        value     2 |    
       x dist   3 1 |    
        value   3 1 |    
       y dist 2 1 1 |    
        value 2 1 1 |    
      xy dist 2 3 1 |    
        value 2 3 1 |    
       z dist 2 3 1 | 5  
        value 2 3 1 | 5  


## Subsampling tensors inside a `plate`

One of the main uses of `plate` is to subsample data. Since data are conditionally independent inside a plate, the expected value of loss on e.g. half the data should be half the expected loss on the full data.

To subsample data, Pyro needs to know both the original data size and the subsample size. Pyro then chooses a random subset of the data and yields a set of indices (although this is customizable).

In [72]:
data = torch.arange(100.)

def model2():
    mean = pyro.param("mean", torch.zeros(len(data)))
    with pyro.plate("data", len(data), subsample_size=10) as ind:
        assert len(ind) == 10
        batch = data[ind]        
        mean_batch = mean[ind]
        
        # do stuff with the batch
        x = pyro.sample("x", Normal(mean_batch, 1), obs=batch)
        assert len(x) == 10
        

In [73]:
test_model(model2, guide=lambda: None, loss=Trace_ELBO())