In [1368]:
import torch
import pyro
from pyro.infer import config_enumerate
import pyro.poutine as poutine

# Distributions

Pyro distributions are wrappers around PyTorch distributions (`torch.distributions`). Their two main methods are:

* `d.sample(sample_shape=[])` : randomly sample values from the distribution

* `d.log_prob(x)` : returns the log probability of the value `x` under the distribution (the log probability mass for discrete distributions, the log density for continuous ones)

Note, however, that inside models we will generally draw values through the `pyro.sample("name", dist)` wrapper rather than calling `d.sample()` directly.

In [1369]:
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal, HalfNormal

In [1370]:
# A biased coin with P(1) = 0.7; draw a 2x5 grid of iid flips.
d = Bernoulli(probs=0.7)
x = d.sample(sample_shape=torch.Size([2, 5]))
x

tensor([[0., 1., 1., 1., 1.],
        [1., 1., 1., 1., 0.]])

In [1371]:
# Per-entry log-probabilities: log(0.7) for ones, log(0.3) for zeros.
log_px = d.log_prob(x)
log_px

tensor([[-1.2040, -0.3567, -0.3567, -0.3567, -0.3567],
        [-0.3567, -0.3567, -0.3567, -0.3567, -1.2040]])

In [1372]:
# A 4-class Categorical; draw a 2x10 grid of class indices.
category_probs = torch.tensor([0.1, 0.5, 0.3, 0.1])
d = Categorical(probs=category_probs)
d.sample(sample_shape=(2, 10))

tensor([[1, 1, 2, 2, 1, 1, 1, 2, 1, 1],
        [1, 1, 0, 3, 1, 1, 2, 1, 2, 1]])

In [1373]:
# Standard normal (mean 0, standard deviation 1); draw a 2x5 grid of samples.
d = Normal(loc=0, scale=1)
d.sample(sample_shape=(2, 5))

tensor([[ 0.2676,  1.1033,  1.2653, -0.7328, -1.2921],
        [ 0.4180,  1.9084, -0.3475,  1.8060, -0.4609]])

In [1374]:
# Log-density of the standard normal at 0, i.e. -0.5 * log(2 * pi) ≈ -0.9189.
# Pass a float tensor: Normal is continuous, so don't rely on int64 -> float promotion.
d.log_prob(torch.tensor(0.))

tensor(-0.9189)

In [1375]:
# Ten standard deviations out the density collapses: -0.9189 - 10**2 / 2 = -50.9189.
# Float tensor, as above: evaluate a continuous density at a real-valued point.
d.log_prob(torch.tensor(10.))

tensor(-50.9189)

In [1376]:
# Three uncorrelated unit-variance coordinates centred at 0, -10 and 10.
loc = torch.tensor([0., -10., 10.])
cov = torch.eye(3)
d = MultivariateNormal(loc, covariance_matrix=cov)
d.sample()

tensor([ 0.0129, -9.0116, 11.4071])

# Sampling within Pyro models

In order to implement inference algorithms, Pyro defines a wrapper for sampling named values (or "sites") within a model.

A site corresponds to a latent variable unless the `obs` argument is given, which marks the site as observed; in that case, calling the model directly always returns the observed value rather than a fresh sample.

In [1377]:
def model1():
    """Sample the latent site "Z" from a standard normal and return its value."""
    z = pyro.sample("Z", Normal(loc=0, scale=1))
    return z


model1()

tensor(-0.4096)

In [1378]:
def model2():
    """Observe the constant 100 at site "obs" under a standard normal.

    Because ``obs`` is supplied, pyro.sample returns the observed value itself.
    """
    observed = torch.tensor(100)
    return pyro.sample("obs", Normal(0, 1), obs=observed)


model2()

tensor(100)

# Declaring iid variables with Plate

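Pyro offers a `plate` context manager for declaring that a batch of sample sites is conditionally independent, e.g., a set of independently and identically distributed (iid) observations. Declaring independence this way lets Pyro vectorize over the batch dimension, which makes inference more efficient.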

In [1379]:
# Ten evenly spaced points from -100 to 100, used as observations below.
data = torch.linspace(-100, 100, 10)


def model3(data):
    """Observe every element of ``data`` under Normal(0, 1).

    The "data" plate declares the ``len(data)`` observations conditionally independent.
    """
    with pyro.plate("data", len(data)):
        observed = pyro.sample("obs", Normal(0, 1), obs=data)
    return observed


model3(data)

tensor([-100.0000,  -77.7778,  -55.5556,  -33.3333,  -11.1111,   11.1111,
          33.3333,   55.5556,   77.7778,  100.0000])

# Registering tunable parameters

When doing statistical inference with SVI, every parameter that should be updated during gradient descent must be registered with `pyro.param`.

Registered parameters are stored in a global parameter store, which can be inspected with `pyro.get_param_store()`.

It is also possible to impose a constraint on a parameter, e.g., `constraint=constraints.positive` to keep it greater than 0.

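NOTE: if training multiple models in the same script or REPL session, you need to clear the param store manually with `pyro.clear_param_store()`. Otherwise, a parameter registered under an existing name reuses the stored value and ignores the new initial value.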

In [1380]:
from pyro.distributions import constraints

pyro.clear_param_store()


def model4(data):
    """Observe ``data`` under a Normal with learnable location and scale.

    Registers param "loc" (initialised to -2.0) and param "scale" (initialised
    to 1.0 and constrained to stay positive); the observations share the
    "data" plate.
    """
    loc = pyro.param("loc", torch.tensor(-2.))
    scale = pyro.param("scale", torch.tensor(1.), constraint=constraints.positive)
    with pyro.plate("data", len(data)):
        observed = pyro.sample("obs", Normal(loc, scale), obs=data)
    return observed


# Running the model once registers its params in the global param store.
model4(data)

pyro.get_param_store()["loc"]

tensor(-2., requires_grad=True)

# Tensor shapes in Pyro

### Three types of dimensions that make up a tensor:

* __event shape__ : dependent random variables (e.g., one draw from a MultivariateNormal)

* __batch shape__ : independent random variables (e.g., two draws from two separate Normals)

* __sample shape__ : independent and identically distributed (iid) random variables (e.g., two draws from one Normal)

### Best practices

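A sampled value has shape `sample_shape + batch_shape + event_shape`: batch dimensions sit to the left of the event dimensions, and extra dimensions (from `sample_shape` or enclosing plates) are prepended further left. Index from the right so your code keeps working when such dimensions are added: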

* Negative indices like `x.sum(-1)`

* Ellipsis notation like `pixel = image[..., i, j]`

PyTorch broadcasting aligns shapes from the right (missing dimensions are added on the left), e.g., `torch.ones(3,4,5) + torch.ones(5)` has shape `(3, 4, 5)`.

In general, it's always safe to assume dependence, but declaring variables independent when appropriate can improve efficiency in inference.

In [1381]:
loc = torch.tensor([0., -10., 10.])
cov = torch.eye(3)
d = MultivariateNormal(loc, covariance_matrix=cov)
x = d.sample()

# With no sample_shape, a draw has shape batch_shape + event_shape.
print(f"{x.shape} = {d.batch_shape} + {d.event_shape}")

torch.Size([3]) = torch.Size([]) + torch.Size([3])


In [1382]:
# log_prob yields one value per batch element (the event dimension is consumed).
l = d.log_prob(x)
print(f"{l.shape} = {d.batch_shape}")

torch.Size([]) = torch.Size([])


In [1383]:
batch_size = 5
# Stack batch_size copies of the same 3-D Gaussian to form a batch of distributions.
loc = torch.tensor([0., -10., 10.]).repeat(batch_size, 1)
cov = torch.eye(3).repeat(batch_size, 1, 1)
d = MultivariateNormal(loc, covariance_matrix=cov)
x = d.sample()

print(f"{x.shape} = {d.batch_shape} + {d.event_shape}")

torch.Size([5, 3]) = torch.Size([5]) + torch.Size([3])


In [1384]:
# One log-density per batch member.
l = d.log_prob(x)
print(f"{l.shape} = {d.batch_shape}")

torch.Size([5]) = torch.Size([5])


In [1385]:
sample_size = 10
loc = torch.tensor([0., -10., 10.])
cov = torch.eye(3)
d = MultivariateNormal(loc, covariance_matrix=cov)

# The "samples" plate prepends a dimension of size sample_size to each draw.
with pyro.plate("samples", sample_size):
    x = pyro.sample("x", d)

print(f"{x.shape} = {torch.Size([sample_size])} + {d.event_shape}")

torch.Size([10, 3]) = torch.Size([10]) + torch.Size([3])


# Debugging with Poutine

Poutine is a library of _effect handlers_ that can record and modify the behavior of Pyro programs.

One common use is to trace the execution of a model and record the shapes of all sites in the model. In the printed table, batch/sample dimensions appear to the left of the `|` separator and event dimensions to the right.

In [1386]:
pyro.clear_param_store()


def model5(data):
    """Toy hierarchical model used to demonstrate trace shape inspection.

    Sites:
      * params "loc" (shape (3,)) and "cov" (shape (3, 3), constrained positive definite)
      * "factors": 5 draws from a 3-D MultivariateNormal inside plate "factors_all"
      * "scale": a HalfNormal draw whose scale is derived from the factors
      * "obs": ``data`` observed under Normal(0, scale) inside plate "data"

    Returns the observed ``data``.
    """
    loc = pyro.param("loc", torch.tensor([0., -10., 10.]))
    # A covariance matrix must remain positive definite if this param is ever optimized.
    cov = pyro.param("cov", torch.eye(3), constraint=constraints.positive_definite)

    with pyro.plate("factors_all", 5):
        factors = pyro.sample("factors", MultivariateNormal(loc, cov))

    # HalfNormal needs a strictly positive scale. The raw sum of the factors is
    # negative about half the time (the means 0, -10 and 10 cancel), which is an
    # invalid parameter and raises when argument validation is enabled.
    scale = pyro.sample("scale", HalfNormal(factors.sum().abs()))

    with pyro.plate("data", len(data)):
        return pyro.sample("obs", Normal(0, scale), obs=data)


trace = poutine.trace(model5).get_trace(data)
trace.compute_log_prob()
print(trace.format_shapes())

   Trace Shapes:       
    Param Sites:       
             loc    3  
             cov  3 3  
   Sample Sites:       
factors_all dist    |  
           value  5 |  
        log_prob    |  
    factors dist  5 | 3
           value  5 | 3
        log_prob  5 |  
      scale dist    |  
           value    |  
        log_prob    |  
       data dist    |  
           value 10 |  
        log_prob    |  
        obs dist 10 |  
           value 10 |  
        log_prob 10 |  
