In [1]:
import os
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
from pyro.infer.autoguide import AutoNormal
from pyro.ops.indexing import Vindex

# True whenever a "CI" variable is present in the environment
# (presumably read later in the notebook to shorten runs -- TODO confirm).
smoke_test = "CI" in os.environ
# Fail fast if the installed Pyro is not a 1.8.3 release.
assert pyro.__version__.startswith("1.8.3")
# Fixed seed so the sampled values printed by the cells are repeatable.
pyro.set_rng_seed(0)

In [10]:
def model():
    """Draw sample site ``"z"`` from an equal-weight 5-way categorical and print it."""
    uniform_weights = torch.ones(5)
    z = pyro.sample("z", dist.Categorical(uniform_weights))
    print(f"model z = {z}")

def guide():
    """Propose sample site ``"z"`` from an equal-weight 5-way categorical and print it."""
    equal_weights = torch.ones(5)
    z = pyro.sample("z", dist.Categorical(equal_weights))
    print(f"guide z = {z}")

# Score the model/guide pair with the plain Trace_ELBO estimator (no enumeration requested).
elbo = Trace_ELBO()
# Left as a bare final expression (no trailing ";") so the notebook displays the returned loss.
elbo.loss(model, guide)

guide z = 3
model z = 3


-0.0

In [8]:
# Parallel enumeration of the guide's discrete site: model and guide each run once
# and see every value of z batched into a single tensor.
elbo = TraceEnum_ELBO(max_plate_nesting=0)
# Trailing ";" keeps the returned loss out of the cell output.
elbo.loss(model, config_enumerate(guide, default="parallel"));

guide z = tensor([0, 1, 2, 3, 4])
model z = tensor([0, 1, 2, 3, 4])


In [9]:
# Sequential enumeration of the guide's discrete site: guide and model are re-run
# once for each value of z, so each run sees a single scalar value.
elbo = TraceEnum_ELBO(max_plate_nesting=0)
# Trailing ";" keeps the returned loss out of the cell output.
elbo.loss(model, config_enumerate(guide, default="sequential"));

guide z = 4
model z = 4
guide z = 3
model z = 3
guide z = 2
model z = 2
guide z = 1
model z = 1
guide z = 0
model z = 0


In [14]:
@config_enumerate
def model():
    """Three-step discrete chain ``x -> y -> z`` driven by one 3x3 parameter ``p``.

    ``x`` is drawn from ``Categorical(p[0])``; ``y`` and ``z`` are each drawn from
    the entry of ``p`` indexed by the previous site's value. The shape of every
    sampled value is printed before returning.

    Returns:
        The tuple ``(x, y, z)`` of sampled values.
    """
    # Positive 3x3 starting value; the param is registered under the simplex constraint.
    p_init = torch.randn(3, 3).exp()
    p = pyro.param("p", p_init, constraint=constraints.simplex)

    x = pyro.sample("x", dist.Categorical(p[0]))
    # Indexing p with the previous site's value selects the next distribution.
    y = pyro.sample("y", dist.Categorical(p[x]))
    z = pyro.sample("z", dist.Categorical(p[y]))

    print(f"  model x.shape = {x.shape}")
    print(f"  model y.shape = {y.shape}")
    print(f"  model z.shape = {z.shape}")
    return x, y, z

def guide():
    """Empty guide: declares no sample sites and returns ``None``."""

# Empty the parameter store so "p" is initialized afresh by the model's pyro.param call.
pyro.clear_param_store()
print("Sampling:")
# Direct call outside any inference algorithm; the return value is not displayed
# because this is not the cell's last expression.
model()
print("Enumerated Inference:")
# The model declares no pyro.plate, hence max_plate_nesting=0.
elbo = TraceEnum_ELBO(max_plate_nesting=0)
# Trailing ";" keeps the returned loss out of the cell output, leaving only the shape prints.
elbo.loss(model, guide);

Sampling:
  model x.shape = torch.Size([])
  model y.shape = torch.Size([])
  model z.shape = torch.Size([])
Enumerated Inference:
  model x.shape = torch.Size([3])
  model y.shape = torch.Size([3, 1])
  model z.shape = torch.Size([3, 1, 1])
