In [64]:
import os
import torch
import pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam

smoke_test = ('CI' in os.environ)
pyro.enable_validation(True)    # <---- This is always a good idea!

# We'll ue this helper to check our models are correct.
def test_model(model, guide, loss):
    pyro.clear_param_store()
    return loss.loss(model, guide)

In [17]:
d = Bernoulli(0.5)
assert d.batch_shape == ()
assert d.event_shape == ()
x = d.sample()
assert x.shape == ()
assert d.log_prob(x).shape == ()

tensor(-0.6931) tensor(1.)


In [23]:
d = Bernoulli(0.5 * torch.ones(3,4))
assert d.batch_shape == (3,4)
assert d.event_shape == ()

s = d.sample()
assert s.shape == (3,4)
assert d.log_prob(s).shape == (3,4)

In [25]:
d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == ()
assert d.event_shape == (3,)

s = d.sample()
assert s.shape == (3,)
assert d.log_prob(s).shape == ()

In [43]:
#Reshaping distributions:
d = Bernoulli(0.5 * torch.ones(3,4)).independent(1)
assert d.batch_shape == (3,)
assert d.event_shape == (4,)

s = d.sample()
assert s.shape == (3,4)
assert d.log_prob(s).shape == (3,)

In [76]:
#It is always safe to assume dependence:
d = Normal(0,1).expand_by([10]).independent(1)
s = pyro.sample("x", Normal(0, 1).expand_by([10]).independent(1))
assert s.shape == (10,)
assert d.log_prob(s).shape ==()

In [104]:
d = Normal(torch.zeros(2), 1).independent(1)
a = pyro.sample('a', d)
assert d.log_prob(a).shape == ()
xy = pyro.sample("xy", Normal(0, 1).expand_by([2, 3, 1]))

x_axis = pyro.iarange("x_axis", 1, dim=-1)
with x_axis:
    dist = Normal(0, 1).expand_by([2, 3, 1]).independent(1)
    s = pyro.sample('s', dist)
    print(dist.log_prob(s).shape)

torch.Size([2, 3])


##Important note: independent should be after the distribution, not the sample.

In [106]:
def model1():
    a = pyro.sample("a", Normal(0, 1))
    b = pyro.sample("b", Normal(torch.zeros(2), 1).independent(1))
    with pyro.iarange("c_iarange", 2):
        c = pyro.sample("c", Normal(torch.zeros(2), 1))
    with pyro.iarange("d_iarange", 3):
        d = pyro.sample("d", Normal(torch.zeros(3,4,5), 1).independent(2))
    assert a.shape == ()       # batch_shape == ()     event_shape == ()
    assert b.shape == (2,)     # batch_shape == ()     event_shape == (2,)
    assert c.shape == (2,)     # batch_shape == (2,)   event_sahpe == ()
    assert d.shape == (3,4,5)  # batch_shape == (3,)   event_shape == (4,5)

    x_axis = pyro.iarange("x_axis", 3, dim=-2)
    y_axis = pyro.iarange("y_axis", 2, dim=-3)
    with x_axis:
        x = pyro.sample("x", Normal(0, 1).expand_by([3, 1]))
    with y_axis:
        y = pyro.sample("y", Normal(0, 1).expand_by([2, 1, 1]))
    
    with x_axis, y_axis:
        xy = pyro.sample('xy', Normal(0, 1).expand([2,3,1]))
        z = pyro.sample("z", Normal(0, 1).expand_by([2, 3, 1, 5]).independent(1))
    assert x.shape == (3, 1)        # batch_shape == (3,1)     event_shape == ()
    assert y.shape == (2, 1, 1)     # batch_shape == (2,1,1)   event_shape == ()
    assert xy.shape == (2, 3, 1)    # batch_shape == (2,3,1)   event_shape == ()
    assert z.shape == (2, 3, 1, 5)  # batch_shape == (2,3,1)   event_shape == (5,)

test_model(model1, model1, Trace_ELBO())

-0.0

In [113]:
#Subsampling using iarange:
data = torch.arange(100)
def model2():
    mean = pyro.param('mean', torch.zeros(len(data)))
    with pyro.iarange('mean_range', len(data), subsample_size=10) as ind:
        batch = data[ind]
        mean_batch = mean[ind]
        x = pyro.sample('obs', Normal(mean_batch,1), obs=batch)
        assert len(x) == 10
test_model(model2, guide=lambda: None, loss=Trace_ELBO())

122071.890625

In [125]:
#Broadcasting to allow parallel enumeration
p = pyro.param("p", torch.arange(6) / 6)
b = pyro.sample('b', Categorical(torch.ones(6) / 6))
print(p)

tensor([ 0.0000,  0.1667,  0.3333,  0.5000,  0.6667,  0.8333])


In [133]:
@config_enumerate(default="parallel")
def model3():
    p = pyro.param("p", torch.arange(6) / 6)
    locs = pyro.param("locs", torch.tensor([-1., 1.]))
    a = pyro.sample("a", Categorical(torch.ones(6) / 6))
    b = pyro.sample("b", Bernoulli(p[a]))  # Note this depends on a.
    with pyro.iarange("c_iarange", 4):
        c = pyro.sample("c", Bernoulli(0.3).expand_by([4]))
        with pyro.iarange("d_iarange", 5):
            d = pyro.sample("d", Bernoulli(0.4).expand_by([5,4]))
            e_loc = locs[d.long()].unsqueeze(-1)
            e_scale = torch.arange(1, 8)
            e = pyro.sample("e", Normal(e_loc, e_scale)
                            .independent(1))  # Note this depends on d.

    #                   enumerated|batch|event dims
    assert a.shape == (         6, 1, 1   )  # Six enumerated values of the Categorical.
    assert b.shape == (      2, 6, 1, 1   )  # 2 enumerated Bernoullis x 6 Categoricals.
    assert c.shape == (   2, 1, 1, 1, 4   )  # Only 2 Bernoullis; does not depend on a or b.
    assert d.shape == (2, 1, 1, 1, 5, 4   )  # Only two Bernoullis.
    assert e.shape == (2, 1, 1, 1, 5, 4, 7)  # This is sampled and depends on d.

    assert e_loc.shape   == (2, 1, 1, 1, 5, 4, 1,)
    assert e_scale.shape == (                  7,)

test_model(model3, model3, TraceEnum_ELBO(max_iarange_nesting=2))

tensor([-1.,  1.])
tensor([-1.,  1.])


-2.4835440370907236e-09

In [141]:
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
enumerated = None  # set to either True or False below

def fun(observe):
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.iarange('x_axis', width, dim=-2)
    y_axis = pyro.iarange('y_axis', height, dim=-1)

    # Note that the shapes of these sites depend on whether Pyro is enumerating.
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x).expand_by([width, 1]))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y).expand_by([height]))
    if enumerated:
        assert x_active.shape  == (2, width, 1)
        assert y_active.shape  == (2, 1, 1, height)
    else:
        assert x_active.shape  == (width, 1)
        assert y_active.shape  == (height,)

    # The first trick is to broadcast. This works with or without enumeration.
    p = 0.1 + 0.5 * x_active * y_active
    if enumerated:
        assert p.shape == (2, 2, width, height)
    else:
        assert p.shape == (width, height)

    # The second trick is to index using ellipsis slicing.
    # This allows Pyro to add arbitrary dimensions on the left.
    dense_pixels = torch.zeros_like(p)
    for x, y in sparse_pixels:
        dense_pixels[..., x, y] = 1
    if enumerated:
        assert dense_pixels.shape == (2, 2, width, height)
    else:
        assert dense_pixels.shape == (width, height)

    with x_axis, y_axis:
        if observe:
            pyro.sample("pixels", Bernoulli(p), obs=dense_pixels)

def model4():
    fun(observe=True)

@config_enumerate(default="parallel")
def guide4():
    fun(observe=False)

# Test without enumeration.
enumerated = False
test_model(model4, guide4, Trace_ELBO())

# Test with enumeration.
enumerated = True
test_model(model4, guide4, TraceEnum_ELBO(max_iarange_nesting=2))

If you want to enumerate sites, you need to use TraceEnum_ELBO instead.
  'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))


17.217742919921875

In [144]:
#Automatic broadcasting via broadcast poutine
num_particles = 100  # Number of samples for the ELBO estimator
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])

def sample_pixel_locations_no_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x).expand_by([num_particles, width, 1]))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y).expand_by([num_particles, 1, height]))
    return x_active, y_active

def sample_pixel_locations_automatic_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y))
    return x_active, y_active

def sample_pixel_locations_partial_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x).expand_by([width, 1]))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y).expand_by([height]))
    return x_active, y_active

def fun(observe, sample_fn):
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.iarange('x_axis', width, dim=-2)
    y_axis = pyro.iarange('y_axis', height, dim=-1)

    with pyro.iarange("num_particles", 100, dim=-3):
        x_active, y_active = sample_fn(p_x, p_y, x_axis, y_axis)
        # Indices corresponding to "parallel" enumeration are appended
        # to the left of the "num_particles" iarange dim.
        assert x_active.shape  == (2, num_particles, width, 1)
        assert y_active.shape  == (2, 1, num_particles, 1, height)
        p = 0.1 + 0.5 * x_active * y_active
        assert p.shape == (2, 2, num_particles, width, height)

        dense_pixels = torch.zeros_like(p)
        for x, y in sparse_pixels:
            dense_pixels[..., x, y] = 1
        assert dense_pixels.shape == (2, 2, num_particles, width, height)

        with x_axis, y_axis:
            if observe:
                pyro.sample("pixels", Bernoulli(p), obs=dense_pixels)

def test_model_with_sample_fn(sample_fn, broadcast=False):
    def model():
        fun(observe=True, sample_fn=sample_fn)

    @config_enumerate(default="parallel")
    def guide():
        fun(observe=False, sample_fn=sample_fn)

    if broadcast:
        model = poutine.broadcast(model)
        guide = poutine.broadcast(guide)
    print(test_model(model, guide, TraceEnum_ELBO(max_iarange_nesting=3)))

test_model_with_sample_fn(sample_pixel_locations_no_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_automatic_broadcasting, broadcast=True)
test_model_with_sample_fn(sample_pixel_locations_partial_broadcasting, broadcast=True)

1776.2379150390625
1776.2379150390625
1776.2379150390625
