In [36]:
import os
import torch
import pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam

# True whenever a variable named 'CI' is present in the process environment.
smoke_test = 'CI' in os.environ
# Switch Pyro's validation on globally -- this is always a good idea!
pyro.enable_validation(True)

# We'll use this helper to check our models are correct.
def test_model(model, guide, loss):
    """Evaluate ``loss`` once on a model/guide pair, starting from a clean slate.

    Args:
        model: passed through unchanged as the first argument of ``loss.loss``.
        guide: passed through unchanged as the second argument of ``loss.loss``.
        loss: any object exposing a ``loss(model, guide)`` method.

    Returns:
        None. The value produced by ``loss.loss`` is discarded; the call is made
        only so that any exception it raises propagates to the caller.
    """
    # Clear the Pyro parameter store before evaluating the loss.
    pyro.clear_param_store()
    evaluate = loss.loss
    evaluate(model, guide)

In [17]:
# Simplest case: a single Bernoulli with a scalar parameter.
d = Bernoulli(0.5)
x = d.sample()
# Neither batch dimensions nor event dimensions are present ...
assert d.batch_shape == d.event_shape == ()
# ... so a draw and its log-probability are both 0-dimensional tensors.
assert x.shape == d.log_prob(x).shape == ()

tensor(-0.6931) tensor(1.)


In [23]:
# A 3x4 grid of Bernoulli variables, every one with probability 0.5.
d = Bernoulli(torch.full((3, 4), 0.5))
s = d.sample()

# Each parameter entry is its own batch element; individual events stay scalar.
assert d.event_shape == ()
# Samples and log-probabilities both carry the full batch shape.
assert d.batch_shape == s.shape == d.log_prob(s).shape == (3, 4)

In [25]:
# A 3-dimensional Gaussian with zero mean and identity covariance:
# the whole vector is a single event, not a batch of three.
d = MultivariateNormal(torch.zeros(3), torch.eye(3))
s = d.sample()

assert d.batch_shape == ()
assert d.event_shape == s.shape == (3,)
# log_prob scores the vector jointly, hence one scalar per draw.
assert d.log_prob(s).shape == ()

In [43]:
# Reshaping distributions: .independent(1) moves the rightmost batch
# dimension of the 3x4 Bernoulli into the event shape.
d = Bernoulli(torch.full((3, 4), 0.5)).independent(1)
s = d.sample()

assert d.event_shape == (4,)
# One log-probability per remaining batch element.
assert d.batch_shape == d.log_prob(s).shape == (3,)
# A sample still holds every entry: batch dimensions followed by event dimensions.
assert s.shape == d.batch_shape + d.event_shape

In [40]:
# It is always safe to assume dependence: here the 10 draws are declared as
# one joint 10-dimensional event via .independent(1).
d = Normal(0, 1)                           # scalar base distribution
joint = d.expand_by([10]).independent(1)   # the distribution actually sampled
assert joint.batch_shape == ()
assert joint.event_shape == (10,)

s = pyro.sample("x", joint)
assert s.shape == (10,)
# Scored under the distribution it was drawn from, the sample gets a single
# joint log-probability ...
assert joint.log_prob(s).shape == ()
# ... whereas the scalar base merely broadcasts over the 10 entries. This was
# the only check in the original cell and says nothing about the sample site.
assert d.log_prob(s).shape == (10,)