In [10]:
from __future__ import print_function
import os
from collections import defaultdict
import numpy as np
import scipy.stats
import torch
from torch.distributions import constraints
from pyro.distributions.util import broadcast_shape


import matplotlib.pyplot as plt
%matplotlib inline

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, TraceGraph_ELBO, config_enumerate

# True when running under CI (the `CI` environment variable is set).
# NOTE(review): not used in this excerpt; presumably later cells shrink their
# workloads when it is set — confirm.
smoke_test = ('CI' in os.environ)
# This notebook was written against Pyro 0.3.0; fail fast with an explicit
# message instead of a bare AssertionError when another version is installed.
assert pyro.__version__.startswith('0.3.0'), (
    "this notebook requires pyro 0.3.0, found " + pyro.__version__)
# Turn on Pyro's runtime validation so shape/argument mistakes raise errors.
pyro.enable_validation(True)

def test_model(model, guide, loss):
    """Check that ``loss`` can be evaluated for the pair ``(model, guide)``.

    The global Pyro parameter store is emptied first, so every check starts
    from freshly initialised parameters. The value computed by ``loss.loss``
    is discarded: the call only has to complete without raising (for
    example, without tripping a shape assertion inside the model or guide).
    """
    param_store = pyro.get_param_store()
    param_store.clear()
    loss.loss(model, guide)

In [18]:
# Grid size: `width` is the size of the x_axis plate, `height` of the y_axis plate.
width, height = 8, 10
# (x, y) coordinates of the pixels that are "on"; int64 so they can index tensors.
sparse_pixels = torch.tensor([[3, 2], [3, 5], [3, 9], [7, 1]], dtype=torch.long)
# Global flag read by `fun` to decide which tensor shapes to assert; it is
# assigned False/True right before each test run below.
enumerated = None

def fun(observe, verbose=True):
    """Shared model/guide body for a sparse 2-D Bernoulli pixel grid.

    Samples one Bernoulli flag per entry of the ``x_axis`` plate (size
    ``width``, dim -2) and one per entry of the ``y_axis`` plate (size
    ``height``, dim -1). It combines them by broadcasting into per-pixel
    probabilities ``p`` and writes the global ``sparse_pixels`` coordinates
    into a dense 0/1 tensor. When ``observe`` is true, it conditions on that
    tensor.

    The same code must work whether or not Pyro enumerates the discrete
    sites. The module-level flag ``enumerated`` selects which shapes are
    asserted.

    Args:
        observe: if true, add the ``"pixels"`` observation site (model use);
            if false, only sample the latent sites (guide use).
        verbose: if true (the default), print the intermediate tensor shapes.

    Raises:
        AssertionError: if a tensor shape differs from the one expected for
            the current value of ``enumerated``.
    """
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.plate('x_axis', width, dim=-2)
    y_axis = pyro.plate('y_axis', height, dim=-1)

    # Note that the shapes of these sites depend on whether Pyro is enumerating.
    with x_axis:
        x_active = pyro.sample("x_active", dist.Bernoulli(p_x))
    with y_axis:
        y_active = pyro.sample("y_active", dist.Bernoulli(p_y))

    if enumerated:
        # Each enumerated site gets its own size-2 dim to the left of the
        # two plate dims.
        assert x_active.shape == (2, 1, 1)
        assert y_active.shape == (2, 1, 1, 1)
    else:
        assert x_active.shape == (width, 1)
        assert y_active.shape == (height,)

    # The first trick is to broadcast. This works with or without enumeration.
    p = 0.1 + 0.5 * x_active * y_active

    if verbose:
        print("------")
        print("x_active.shape", x_active.shape)
        print("y_active.shape", y_active.shape)
        print("p.shape", p.shape)

    if enumerated:
        assert p.shape == (2, 2, 1, 1)
    else:
        assert p.shape == (width, height)
    dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
    if verbose:
        print("dense_pixels.shape", dense_pixels.shape)

    # The second trick is to index using ellipsis slicing.
    # This allows Pyro to add arbitrary dimensions on the left.
    # Indexing the last two dims with the coordinate columns sets every listed
    # pixel in one vectorized assignment instead of a Python loop;
    # reshape(-1, 2) keeps this a no-op when `sparse_pixels` is empty.
    pixel_xs, pixel_ys = sparse_pixels.reshape(-1, 2).unbind(-1)
    dense_pixels[..., pixel_xs, pixel_ys] = 1

    if enumerated:
        assert dense_pixels.shape == (2, 2, width, height)
    else:
        assert dense_pixels.shape == (width, height)

    with x_axis, y_axis:
        if observe:
            pyro.sample("pixels", dist.Bernoulli(p), obs=dense_pixels)

def model4():
    """Model side: the shared pixel-grid body with the observation enabled."""
    return fun(True)

def guide4():
    """Guide side: the shared pixel-grid body without the observation site."""
    return fun(False)

# Run the same model/guide pair twice: first with ordinary sampling under
# Trace_ELBO, then with the guide's discrete sites enumerated in parallel
# under TraceEnum_ELBO. `fun` reads the global `enumerated` flag to choose
# its shape assertions, so the loop variable is that same global.
for enumerated in (False, True):
    if not enumerated:
        test_model(model4, guide4, Trace_ELBO())
    else:
        # Two plates (dims -1 and -2) are nested, hence max_plate_nesting=2.
        test_model(model4, config_enumerate(guide4, "parallel"),
                   TraceEnum_ELBO(max_plate_nesting=2))

------
x_active.shape torch.Size([8, 1])
y_active.shape torch.Size([10])
p.shape torch.Size([8, 10])
dense_pixels.shape torch.Size([8, 10])
------
x_active.shape torch.Size([8, 1])
y_active.shape torch.Size([10])
p.shape torch.Size([8, 10])
dense_pixels.shape torch.Size([8, 10])
------
x_active.shape torch.Size([2, 1, 1])
y_active.shape torch.Size([2, 1, 1, 1])
p.shape torch.Size([2, 2, 1, 1])
dense_pixels.shape torch.Size([2, 2, 8, 10])
------
x_active.shape torch.Size([2, 1, 1])
y_active.shape torch.Size([2, 1, 1, 1])
p.shape torch.Size([2, 2, 1, 1])
dense_pixels.shape torch.Size([2, 2, 8, 10])
