## [Inference with Discrete Latent Variables](http://pyro.ai/examples/enumeration.html#Inference-with-Discrete-Latent-Variables)

In [1]:
import os
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
from pyro.infer.autoguide import AutoNormal

In [2]:
from pyro.ops.indexing import Vindex

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.2')
pyro.set_rng_seed(0)

##### Pyro’s enumeration strategy (Obermeyer et al. 2019) encompasses popular algorithms including variable elimination, exact message passing, forward-filter-backward-sample, inside-out, Baum-Welch, and many other special-case algorithms. Aside from enumeration, Pyro implements a number of inference strategies including variational inference (SVI) and Monte Carlo (HMC and NUTS).

##### Enumeration can be used either as a stand-alone strategy via infer_discrete, or as a component of other strategies. Thus enumeration allows Pyro to marginalize out discrete latent variables in HMC and SVI models, and to use variational enumeration of discrete variables in SVI guides.

The core idea of enumeration is to interpret discrete pyro.sample statements as full enumeration rather than random sampling. Other inference algorithms can then sum out the enumerated values. 
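To make the contrast concrete without relying on Pyro, here is a minimal sketch in plain Python. It assumes a uniform five-way model, like `dist.Categorical(torch.ones(5))`, and a hypothetical non-uniform guide chosen only so that the two estimates differ. The names `elbo_term`, `mc_elbo`, and `enumerated_elbo` are illustrative and not part of Pyro.

```python
import math
import random

# Model: uniform over 5 values, like dist.Categorical(torch.ones(5)).
p = [1.0 / 5] * 5
# Guide: hypothetical non-uniform weights, normalized (an assumption for illustration).
q_weights = [1.0, 2.0, 3.0, 4.0, 5.0]
q = [w / sum(q_weights) for w in q_weights]

def elbo_term(z):
    """ELBO integrand at a single value z: log p(z) - log q(z)."""
    return math.log(p[z]) - math.log(q[z])

# Random sampling: average the integrand over draws z ~ q (a noisy estimate).
rng = random.Random(0)
draws = rng.choices(range(5), weights=q, k=10_000)
mc_elbo = sum(elbo_term(z) for z in draws) / len(draws)

# Enumeration: visit every value of z once and weight it by q(z) (exact).
enumerated_elbo = sum(q[z] * elbo_term(z) for z in range(5))

print(f"Monte Carlo estimate: {mc_elbo:.4f}")
print(f"Enumerated value:     {enumerated_elbo:.4f}")
```

The enumerated value is exact and has no sampling noise, at the cost of evaluating every value in the support.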

In [3]:
def model():
    z = pyro.sample("z", dist.Categorical(torch.ones(5)))
    print(f"model z ={z}")

In [4]:
def guide():
    z = pyro.sample("z", dist.Categorical(torch.ones(5)))
    print(f"guide z ={z}")

In [5]:
elbo = Trace_ELBO()

In [6]:
elbo.loss(model, guide)

guide z =4
model z =4


-0.0

In [8]:
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, config_enumerate(guide, "parallel"))

guide z =tensor([0, 1, 2, 3, 4])
model z =tensor([0, 1, 2, 3, 4])


-0.0

In [9]:
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, config_enumerate(guide, "sequential"))

guide z =4
model z =4
guide z =3
model z =3
guide z =2
model z =2
guide z =1
model z =1
guide z =0
model z =0


-0.0

##### Parallel enumeration is cheaper but more complex than sequential enumeration, so we’ll focus the rest of this tutorial on the parallel variant. Note that both forms can be interleaved.

##### A model with a single discrete latent variable is a mixture model. Models with multiple discrete latent variables can be more complex, including HMMs, CRFs, DBNs, and other structured models. In models with multiple discrete latent variables, Pyro enumerates each variable in a different tensor dimension (counting from the right; see the Tensor Shapes Tutorial). This allows Pyro to determine the dependency graph among variables and then perform cheap exact inference using variable elimination algorithms.
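As a rough illustration of why variable elimination is cheap, the sketch below marginalizes a three-variable chain x → y → z in plain Python, once by brute force and once by eliminating one variable at a time. The 3×3 table `p` is a hypothetical row-stochastic matrix, and the code is an assumption-laden sketch of the general technique, not Pyro's implementation.

```python
import itertools
import math

# Hypothetical 3x3 transition table; each row sums to 1.
p = [
    [0.5, 0.3, 0.2],
    [0.1, 0.6, 0.3],
    [0.2, 0.2, 0.6],
]
K = len(p)

def joint(x, y, z):
    """p(x, y, z) for the chain x -> y -> z, with x drawn from row 0 of p."""
    return p[0][x] * p[x][y] * p[y][z]

# Brute force: sum over all K**2 assignments of (x, y) for each z.
marginal_z_brute = [
    sum(joint(x, y, z) for x, y in itertools.product(range(K), repeat=2))
    for z in range(K)
]

# Variable elimination: sum out x, then y, exploiting the chain structure.
# Each step touches only a K x K table, so the cost grows linearly with
# the number of variables instead of exponentially.
p_x = p[0]
p_y = [sum(p_x[x] * p[x][y] for x in range(K)) for y in range(K)]
marginal_z_elim = [sum(p_y[y] * p[y][z] for y in range(K)) for z in range(K)]

print(marginal_z_brute)
print(marginal_z_elim)
```

Both approaches give the same marginal over z. The elimination order here follows the chain, which is the kind of dependency structure the separate enumeration dimensions expose.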

In [14]:
@config_enumerate
def model():
    p = pyro.param("p", torch.randn(3, 3).exp(), constraint=constraints.simplex)

    x = pyro.sample("x", dist.Categorical(p[0]))
    y = pyro.sample("y", dist.Categorical(p[x]))
    z = pyro.sample("z", dist.Categorical(p[y]))

    print(f"  model x.shape = {x.shape}")
    print(f"  model y.shape = {y.shape}")
    print(f"  model z.shape = {z.shape}")

    return x,y,z

def guide():
    pass


In [16]:
pyro.clear_param_store()
print("Sampling")
model()
print("Enumerated Inference:")
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, guide)

Sampling
  model x.shape = torch.Size([])
  model y.shape = torch.Size([])
  model z.shape = torch.Size([])
Enumerated Inference:
  model x.shape = torch.Size([3])
  model y.shape = torch.Size([3, 1])
  model z.shape = torch.Size([3, 1, 1])


-0.0

In [18]:
serving_model = infer_discrete(model, first_available_dim=-1)
x,y,z = serving_model() # takes the same args as model()
print(f"x = {x}")
print(f"y = {y}")
print(f"z = {z}")

  model x.shape = torch.Size([3])
  model y.shape = torch.Size([3, 1])
  model z.shape = torch.Size([3, 1, 1])
  model x.shape = torch.Size([])
  model y.shape = torch.Size([])
  model z.shape = torch.Size([])
x = 0
y = 0
z = 1


##### Notice that under the hood infer_discrete runs the model twice: first in forward-filter mode where sites are enumerated, then in replay-backward-sample mode where sites are sampled. infer_discrete can also perform MAP inference by passing temperature=0. Note that while infer_discrete produces correct posterior samples, it does not currently produce correct log probabilities, and should not be used in other gradient-based inference algorithms.
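The two-pass idea can be sketched in plain Python for a chain x → y → z. This is only an illustration of the general filter-then-sample technique, not Pyro's actual implementation. The table `p`, the evidence weights `lik_z`, and the helpers `draw` and `infer_chain` are all hypothetical. The sketch assumes that `temperature=1` means sampling from the posterior and `temperature=0` means taking the most probable joint assignment, so it uses sums in the first case and maxima in the second.

```python
import random

# Hypothetical row-stochastic table and evidence weights on z (assumptions).
p = [
    [0.5, 0.3, 0.2],
    [0.1, 0.6, 0.3],
    [0.2, 0.2, 0.6],
]
lik_z = [0.1, 0.1, 0.8]
K = len(p)

def draw(weights, temperature, rng):
    """Sample in proportion to weights (temperature=1) or take the argmax (temperature=0)."""
    if temperature == 0:
        return max(range(len(weights)), key=weights.__getitem__)
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]

def infer_chain(temperature=1, rng=None):
    rng = rng or random.Random(0)
    # Sum-product for posterior samples, max-product for MAP.
    reduce = max if temperature == 0 else sum

    # Pass 1: enumerate every value and compute messages that summarize
    # everything downstream of each site.
    msg_y = [reduce(p[y][z] * lik_z[z] for z in range(K)) for y in range(K)]
    msg_x = [reduce(p[x][y] * msg_y[y] for y in range(K)) for x in range(K)]

    # Pass 2: replay the chain in order, choosing each site from its
    # conditional reweighted by the downstream message.
    x = draw([p[0][i] * msg_x[i] for i in range(K)], temperature, rng)
    y = draw([p[x][j] * msg_y[j] for j in range(K)], temperature, rng)
    z = draw([p[y][k] * lik_z[k] for k in range(K)], temperature, rng)
    return x, y, z

print("posterior sample:", infer_chain(temperature=1))
print("MAP assignment:  ", infer_chain(temperature=0))
```

Switching from sums to maxima in the first pass is what turns posterior sampling into MAP decoding in this sketch. The second pass is otherwise unchanged.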

### [Indexing with enumerated variables](http://pyro.ai/examples/enumeration.html#Indexing-with-enumerated-variables)