# Inference with Discrete Latent Variables

In [1]:
import os
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.ops.indexing import Vindex

# True when a CI environment variable is set; presumably used by later cells
# to shrink workloads (no such use is visible in this excerpt).
smoke_test = ('CI' in os.environ)
# This walkthrough was written against Pyro 1.3.0; fail fast on other versions
# since the printed shapes/outputs below may differ.
assert pyro.__version__.startswith('1.3.0')
# Turn on Pyro's runtime validation checks.
pyro.enable_validation()
# Fix the RNG seed so the sampled values printed below are reproducible.
pyro.set_rng_seed(0)

## Mechanics of enumeration

不使用enumeration时，对discrete variables进行采样的传统方式如下：

In [4]:
def model():
    """Sample one categorical latent ``z`` with 5 equally likely values.

    The drawn value is printed so the reader can see what inference fed in
    for ``z`` (a single value normally, the whole support when enumerated).
    """
    uniform_weights = torch.ones(5)
    z = pyro.sample("z", dist.Categorical(uniform_weights))
    print("model z = {}".format(z))
    
def guide():
    """Guide matching ``model``: the same uniform categorical over 5 values.

    Prints the value it proposes for ``z`` so it can be compared with the
    value the model sees.
    """
    uniform_weights = torch.ones(5)
    z = pyro.sample("z", dist.Categorical(uniform_weights))
    print("guide z = {}".format(z))
    
# Plain (non-enumerated) ELBO: the guide draws a single z and the model sees
# that same value (compare the two printed lines below).
elbo = Trace_ELBO()
elbo.loss(model, guide)

guide z = 3
model z = 3


-0.0

当我们使用enumeration的时候，pyro会把应用于discrete variables的`sample`解释为enumerate，就会有以下的结果：

In [5]:
# max_plate_nesting=0 because model/guide use no plates.
# config_enumerate(guide, "parallel") marks the guide's discrete site for
# parallel enumeration, so z arrives as one tensor holding its whole support.
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, config_enumerate(guide, "parallel"))

guide z = tensor([0, 1, 2, 3, 4])
model z = tensor([0, 1, 2, 3, 4])


-0.0

可以看到我们的样本的维度发生了变化，这来自于distribution的`enumerate_support()`方法。

可以看到，这里使用的enumerate方式是`parallel`，它的计算效率很高；但因为它改变了sample的维度，所以在后续代码中必须做一些特定的处理（例如下文介绍的indexing方式），以防止得到错误的结果。

另一种可选的enumerate方式是`sequential`：每次运行时采样的结果和非enumerate时一样，是单个值，但model和guide会被运行多次，直到遍历完discrete variables所有可能的取值（组合）。这使得我们能够更加灵活地去实现模型，但也增加了运行成本。

In [6]:
# Sequential enumeration: model and guide are re-run once per value of z,
# each time with an ordinary scalar z (see the printed output below).
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, config_enumerate(guide, "sequential"))

guide z = 4
model z = 4
guide z = 3
model z = 3
guide z = 2
model z = 2
guide z = 1
model z = 1
guide z = 0
model z = 0


-0.0

## Multiple Latent variables

如果有多个discrete variables？

In [8]:
@config_enumerate
def model():
    """Chain of three categorical latents x -> y -> z sharing one 3x3 table.

    Row ``k`` of the simplex-constrained param ``"p"`` is the distribution of
    the next variable given that the previous one equals ``k``; ``x`` is drawn
    from row 0. The shapes are printed to show the extra enumeration dims.

    Returns:
        Tuple ``(x, y, z)`` of sampled (or enumerated) values.
    """
    table = pyro.param("p", torch.randn(3, 3).exp(), constraint=constraints.simplex)
    first_row = table[0]
    x = pyro.sample("x", dist.Categorical(first_row))
    y = pyro.sample("y", dist.Categorical(table[x]))
    z = pyro.sample("z", dist.Categorical(table[y]))
    print("model x.shape = {}".format(x.shape))
    print("model y.shape = {}".format(y.shape))
    print("model z.shape = {}".format(z.shape))
    return (x, y, z)
    
def guide():
    """Empty guide: every latent site of the model is discrete and enumerated."""

# Fresh param store so param "p" is (re)initialized for this model.
pyro.clear_param_store()
# This model has no plates, hence max_plate_nesting=0.
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, guide)

model x.shape = torch.Size([3])
model y.shape = torch.Size([3, 1])
model z.shape = torch.Size([3, 1, 1])


-0.0

正如在“Tensor Shapes”教程中介绍过的：enumeration mode会为每个discrete variable分配一个新的dimension（位于所有plate维度的左侧）来存放它的所有取值。为了节约空间，每个变量的tensor中只有它自己的enumeration dimension（即最左边的那一个）大小不为1，其余维度都是1。

In [9]:
# Calling the model directly (outside inference) draws ordinary scalar samples,
# so the printed shapes carry no enumeration dims.
model()

model x.shape = torch.Size([])
model y.shape = torch.Size([])
model z.shape = torch.Size([])


(tensor(2), tensor(0), tensor(0))

## Examining discrete latent states

我们可以看到，如果使用了enumeration，我们得到的是包含enumeration dimensions的tensor。为了得到这些discrete latent variables的具体采样值，我们可以用`infer_discrete`来wrap一下model，其中需要传入参数`first_available_dim=-1-max_plate_nesting`（这里max_plate_nesting=0，所以是-1）。

In [10]:
# infer_discrete needs first_available_dim = -1 - max_plate_nesting; this model
# has no plates, so -1.
serving_model = infer_discrete(model, first_available_dim=-1)
x, y, z = serving_model()   # Called with the same args as model (none here)
print("x = {}".format(x))
print("y = {}".format(y))
print("z = {}".format(z))

model x.shape = torch.Size([3])
model y.shape = torch.Size([3, 1])
model z.shape = torch.Size([3, 1, 1])
model x.shape = torch.Size([])
model y.shape = torch.Size([])
model z.shape = torch.Size([])
x = 2
y = 0
z = 0


可以看到，`infer_discrete`实际上运行了两次model：第一次是forward-filter模式（带有enumeration维度），第二次是replay-backward-sample模式（得到具体的采样值）。

## Indexing with enumerated variables

In [26]:
# Outside of enumeration x and y are sampled as ordinary scalars, so plain
# tensor indexing works: p_xy is the (5, 2) slice p[:, x, y, :].
pyro.clear_param_store()
p = pyro.param("p", torch.randn(5, 4, 3, 2).exp(), constraint=constraints.simplex)
x = pyro.sample("x", dist.Categorical(torch.ones(4)))
y = pyro.sample("y", dist.Categorical(torch.ones(3)))
with pyro.plate("z_plate", 5):
    p_xy = p[..., x, y, :]
    z = pyro.sample("z", dist.Categorical(p_xy))
    print("p.shape = {}".format(p.shape))
    print("x.shape = {}".format(x.shape))
    print("y.shape = {}".format(y.shape))
    print("p_xy.shape = {}".format(p_xy.shape))
    print("z.shape = {}".format(z.shape))

p.shape = torch.Size([5, 4, 3, 2])
x.shape = torch.Size([])
y.shape = torch.Size([])
p_xy.shape = torch.Size([5, 2])
z.shape = torch.Size([5])


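以上方式之所以能成功运行，是因为我们还没有使用enumeration mode。如果使用了enumeration mode，x和y就不再是scalar，而是带有enumeration维度的tensor（例如shape为(4, 1)和(3, 1, 1)）。此时直接写`p[..., x, y, :]`，这些enumeration维度会被放到错误的位置，得到的结果的shape也是错误的。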

一种纯依靠pytorch来解决的办法是这样的：

In [28]:
@config_enumerate
def model():
    """Index a 4-D param with enumerated ``x``/``y`` using plain PyTorch.

    ``"p"`` has shape (5, 4, 3, 2): z_plate position, x value, y value,
    z category. ``x`` and ``y`` carry enumeration dims, so each index is
    shaped so that all four broadcast together (PyTorch advanced indexing:
    the result takes the broadcast shape of the indices, and
    ``res[i...] = p[idx0[i...], idx1[i...], ...]``).

    Under ``TraceEnum_ELBO(max_plate_nesting=1)`` the index shapes are::

        plate positions     (5, 1)
        x.unsqueeze(-1)     (4, 1, 1)
        y.unsqueeze(-1)     (3, 1, 1, 1)
        z categories        (2,)
        ---------------------------------
        p_xy (broadcast)    (3, 4, 5, 2)

    Returns:
        Tuple ``(x, y, z)``.
    """
    probs = pyro.param("p", torch.randn(5, 4, 3, 2).exp(), constraint=constraints.simplex)
    x = pyro.sample("x", dist.Categorical(torch.ones(4)))
    y = pyro.sample("y", dist.Categorical(torch.ones(3)))
    with pyro.plate("z_plate", 5):
        plate_idx = torch.arange(5, device=probs.device).reshape(5, 1)
        x_idx = x.unsqueeze(-1)
        y_idx = y.unsqueeze(-1)
        category_idx = torch.arange(2, device=probs.device)
        p_xy = probs[plate_idx, x_idx, y_idx, category_idx]
        z = pyro.sample("z", dist.Categorical(p_xy))
    print("p.shape = {}".format(probs.shape))
    print("x.shape = {}".format(x.shape))
    print("y.shape = {}".format(y.shape))
    print("p_xy.shape = {}".format(p_xy.shape))
    print("z.shape = {}".format(z.shape))
    return (x, y, z)

def guide():
    """Empty guide: all discrete latents of the model are enumerated."""

# Fresh param store so "p" is re-created with shape (5, 4, 3, 2).
pyro.clear_param_store()
# One plate ("z_plate"), so max_plate_nesting=1; enumeration dims go to its left.
elbo = TraceEnum_ELBO(max_plate_nesting=1)
elbo.loss(model, guide)

p.shape = torch.Size([5, 4, 3, 2])
x.shape = torch.Size([4, 1])
y.shape = torch.Size([3, 1, 1])
p_xy.shape = torch.Size([3, 4, 5, 2])
z.shape = torch.Size([2, 1, 1, 1])


-0.0

更方便的是，pyro提供了一个工具`pyro.ops.indexing.Vindex`，使我们可以用与non-enumeration时相同的语法，对enumeration mode下的tensor进行index：它会在内部先对各个index做与上面类似的broadcasting。

In [29]:
# Vindex without enumeration: for scalar x and y it yields the same (5, 2)
# result as plain p[..., x, y, :].
pyro.clear_param_store()
p = pyro.param("p", torch.randn(5, 4, 3, 2).exp(), constraint=constraints.simplex)
x = pyro.sample("x", dist.Categorical(torch.ones(4)))
y = pyro.sample("y", dist.Categorical(torch.ones(3)))
with pyro.plate("z_plate", 5):
    p_xy = Vindex(p)[..., x, y, :]
    z = pyro.sample("z", dist.Categorical(p_xy))
    print("p.shape = {}".format(p.shape))
    print("x.shape = {}".format(x.shape))
    print("y.shape = {}".format(y.shape))
    print("p_xy.shape = {}".format(p_xy.shape))
    print("z.shape = {}".format(z.shape))

p.shape = torch.Size([5, 4, 3, 2])
x.shape = torch.Size([])
y.shape = torch.Size([])
p_xy.shape = torch.Size([5, 2])
z.shape = torch.Size([5])


In [30]:
@config_enumerate
def model():
    """Index a 4-D param with enumerated ``x``/``y`` via ``Vindex``.

    ``Vindex(p)[..., x, y, :]`` broadcasts the enumerated indices internally,
    so the expression reads exactly like the non-enumerated version. The
    printed shapes match a manual broadcasting-index implementation.

    Returns:
        Tuple ``(x, y, z)``.
    """
    probs = pyro.param("p", torch.randn(5, 4, 3, 2).exp(), constraint=constraints.simplex)
    x = pyro.sample("x", dist.Categorical(torch.ones(4)))
    y = pyro.sample("y", dist.Categorical(torch.ones(3)))
    with pyro.plate("z_plate", 5):
        vectorized = Vindex(probs)
        p_xy = vectorized[..., x, y, :]
        z = pyro.sample("z", dist.Categorical(p_xy))
    print("p.shape = {}".format(probs.shape))
    print("x.shape = {}".format(x.shape))
    print("y.shape = {}".format(y.shape))
    print("p_xy.shape = {}".format(p_xy.shape))
    print("z.shape = {}".format(z.shape))
    return (x, y, z)

def guide():
    """Empty guide: the model's discrete latents are all enumerated."""

# Fresh param store so "p" is re-created for this model.
pyro.clear_param_store()
# One plate ("z_plate"), so max_plate_nesting=1.
elbo = TraceEnum_ELBO(max_plate_nesting=1)
elbo.loss(model, guide)

p.shape = torch.Size([5, 4, 3, 2])
x.shape = torch.Size([4, 1])
y.shape = torch.Size([3, 1, 1])
p_xy.shape = torch.Size([3, 4, 5, 2])
z.shape = torch.Size([2, 1, 1, 1])


-0.0

## Plates and enumeration

pyro的enumeration strategy可以利用plates来reduce其cost。

下面是一个例子，其构建了一个有着不同mean、相同variance的gaussian mixture model（具体的gmm模型的介绍，要到后面进行）：

In [31]:
@config_enumerate
def model(data, num_components=3):
    """Gaussian mixture with per-component means and one shared scale.

    Args:
        data: 1-D observations; ``len(data)`` sets the size of the ``data`` plate.
        num_components: number of mixture components.

    The discrete assignment ``x`` is marked for enumeration
    (``@config_enumerate``), so ``TraceEnum_ELBO`` sums it out rather than
    sampling it. Shapes are printed to show the enumeration dims.
    """
    print("Runing model with {} data points".format(len(data)))
    # Mixture weights (Dirichlet = multivariate Beta). Sized by the
    # num_components argument so it matches the "components" plate below.
    p = pyro.sample("p", dist.Dirichlet(0.5 * torch.ones(num_components)))
    # One scale shared by every Gaussian component.
    scale = pyro.sample("scale", dist.LogNormal(0, num_components))
    # Component means differ, but all share the same Normal(0, 10) prior.
    with pyro.plate("components", num_components):
        loc = pyro.sample("loc", dist.Normal(0, 10))
    with pyro.plate("data", len(data)):
        # Which component each data point belongs to.
        x = pyro.sample("x", dist.Categorical(p))
        print("x.shape = {}".format(x.shape))
        # Likelihood of each point under its assigned component.
        d = dist.Normal(loc[x], scale)
        pyro.sample("obs", d, obs=data)
        print("dist.Normal(loc[x], scale).batch_shape = {}".format(d.batch_shape))
        
# Auto guide for the continuous latents only: "x" (and "data") are hidden from
# it via poutine.block, leaving the discrete "x" to be enumerated by
# TraceEnum_ELBO.
guide = AutoDiagonalNormal(poutine.block(model, hide=["x", "data"]))

# Toy data set: 10 standard-normal points.
data = torch.randn(10)

pyro.clear_param_store()
# The model's two plates ("components", "data") are not nested, so 1 suffices.
elbo = TraceEnum_ELBO(max_plate_nesting=1)
elbo.loss(model, guide, data)

Runing model with 10 data points
x.shape = torch.Size([10])
dist.Normal(loc[x], scale).batch_shape = torch.Size([10])
Runing model with 10 data points
x.shape = torch.Size([3, 1])
dist.Normal(loc[x], scale).batch_shape = torch.Size([3, 1])


41.60341262817383

我们看到，实际上每次这个model被运行了2次：

1. 第一次是`AutoDiagonalNormal`为了构建guide而运行的，此时x没有被enumerate，所以x.shape为(10,)；
2. 第二次是`elbo`为计算loss而运行的，此时x被enumerate，shape为(3, 1)：最左边的3是enumeration维度，右边的1对应data plate维度。enumeration并不是为每个data point分别生成取值，而是只生成这一个(3, 1)的tensor，再通过broadcasting在全部10个data point上共享，从而实现快速计算。

## Dependencies among plates

在vectorized plate中做enumeration所带来的计算节省，是以限制模型的依赖结构为代价的。这些限制是在通常的条件独立性要求之外额外增加的，一共有3种。其中第一种（条件独立性）一般无法被Pyro自动验证；后两种会被`TraceEnum_ELBO`检查，如果违反了则会报错：

第一种（条件独立性）：同一个plate中的变量，沿着batch维度不能相互依赖。

第二种（no downstream coupling）：vectorized plate外部的变量不能依赖于该plate内部的enumerated变量。下面的model就违反了这一条：

In [41]:
@config_enumerate
def invalid_model(data):
    """Deliberately invalid under enumeration: downstream coupling across a plate.

    ``obs`` lives outside the vectorized ``plate`` but depends on every
    enumerated ``x`` through ``x.sum()``, which violates the "no downstream
    coupling" restriction.

    Args:
        data: observed value for the ``obs`` site.
    """
    with pyro.plate("plate", 10):
        x = pyro.sample("x", dist.Bernoulli(0.5))
    # NOTE(review): holds when run without enumeration; an enumerated x would
    # presumably carry an extra enumeration dim instead — confirm.
    assert x.shape == (10,)
    # Observations must be passed via the obs= keyword; a positional third
    # argument to pyro.sample is not treated as observed data.
    pyro.sample("obs", dist.Normal(x.sum(), 1.), obs=data)

如果希望其成功，需要将vectorized plate改为sequential plate（即用`for i in pyro.plate(...)`逐个迭代）：

In [42]:
@config_enumerate
def valid_model(data):
    """Valid counterpart of the coupled model: a sequential plate instead.

    Iterating ``pyro.plate`` in a Python loop creates a separately named site
    ``x_i`` per index, so ``obs`` may depend on all of them.

    Args:
        data: observed value for the ``obs`` site.
    """
    x = []
    for i in pyro.plate("plate", 10):
        x.append(pyro.sample("x_{}".format(i), dist.Bernoulli(0.5)))
    assert len(x) == 10
    # Observations must be passed via the obs= keyword; a positional third
    # argument to pyro.sample is not treated as observed data.
    pyro.sample("obs", dist.Normal(sum(x), 1.), obs=data)