## [HMM](http://pyro.ai/examples/hmm.html#example-hidden-markov-models)

#### This example shows how to marginalize out discrete model variables in Pyro.

This combines Stochastic Variational Inference (SVI) with a
variable elimination algorithm, where we use enumeration to exactly
marginalize out some variables from the ELBO computation. We might
call the resulting algorithm collapsed SVI or collapsed SGVB (i.e.
collapsed Stochastic Gradient Variational Bayes). In the case where
we exactly sum out all the latent variables (as is the case here),
this algorithm reduces to a form of gradient-based Maximum
Likelihood Estimation.
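To see why summing out every latent variable turns the ELBO into the log-likelihood, here is a minimal sketch on a hypothetical toy model with a single discrete latent `x` and one observation `y`. The numbers in `prior` and `likelihood` are made up for illustration; only the identity being checked comes from the text above.

```python
import math

# Hypothetical toy model: one discrete latent x in {0, 1, 2} and one observed y.
prior = [0.5, 0.3, 0.2]          # p(x), illustrative values
likelihood = [0.9, 0.2, 0.05]    # p(y | x) for the single observed y, illustrative values

joint = [p * l for p, l in zip(prior, likelihood)]   # p(x, y) for each x
evidence = sum(joint)                                # p(y): x summed out exactly
log_evidence = math.log(evidence)


def elbo(q):
    """ELBO = sum_x q(x) * (log p(x, y) - log q(x)) for a discrete q."""
    return sum(qi * (math.log(ji) - math.log(qi)) for qi, ji in zip(q, joint) if qi > 0)


posterior = [j / evidence for j in joint]   # exact posterior p(x | y)
uniform = [1.0 / 3.0] * 3                   # an arbitrary, non-optimal q

elbo_exact = elbo(posterior)     # matches log p(y) up to rounding
elbo_uniform = elbo(uniform)     # strictly below log p(y)
```

With the exact posterior the bound is tight, so maximizing it over model parameters is the same as maximizing `log p(y)`. Any other `q`, such as `uniform`, gives a strictly smaller value.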

To marginalize out discrete variables ``x`` in Pyro's SVI:

1. Verify that the variable dependency structure in your model
    admits tractable inference, i.e. the dependency graph among
    enumerated variables should have narrow treewidth.
2. Annotate each such sample site in the model
    with ``infer={"enumerate": "parallel"}``.
3. Ensure your model can handle broadcasting of the sample values
    of those variables.
4. Use the ``TraceEnum_ELBO`` loss inside Pyro's ``SVI``.

Note that empirical results for the models defined here can be found in
reference [1]. This paper also includes a description of the "tensor
variable elimination" algorithm that Pyro uses under the hood to
marginalize out discrete latent variables.


In [2]:
import argparse
import logging
import sys

In [3]:
import torch
import torch.nn as nn
from torch.distributions import constraints

In [4]:
import pyro
import pyro.contrib.examples.polyphonic_data_loader as poly
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex
from pyro.optim import Adam
from pyro.util import ignore_jit_warnings

In [5]:
logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.DEBUG)

In [6]:
log = logging.getLogger()
debug_handler = logging.StreamHandler(sys.stdout)
debug_handler.setLevel(logging.DEBUG)
debug_handler.addFilter(filter=lambda record: record.levelno <= logging.DEBUG)
log.addHandler(debug_handler)

![](hmm.png)

In [7]:
data = poly.load_data(poly.JSB_CHORALES)
sequences = data['train']['sequences']
lengths = data['train']['sequence_lengths']

The training data consist of sequences $\{y_1, \dots, y_T\}$, where each $y_t \in \{0,1\}^{88}$ denotes the presence or absence of each of 88 distinct notes.

In [8]:
sequences.shape

torch.Size([229, 129, 88])

In [9]:
sequences[1, :3]

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
         0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1

In [10]:
lengths.shape

torch.Size([229])

In [17]:
alpha = torch.eye(4) * 0.9 + 0.1

In [21]:
pyro.sample('dir', dist.Dirichlet(alpha).to_event(1))

tensor([[6.3419e-01, 1.7992e-02, 3.4779e-01, 2.7824e-05],
        [9.2109e-01, 5.5405e-02, 2.3501e-02, 1.3257e-11],
        [6.7614e-05, 1.3328e-02, 9.8660e-01, 7.3104e-09],
        [2.5200e-02, 5.5516e-02, 1.3254e-01, 7.8674e-01]])

In [25]:
pyro.sample('piano',  dist.Beta(0.1,0.9).expand([3,4]).to_event(2))

tensor([[3.5649e-11, 3.1512e-05, 3.7805e-02, 7.9353e-01],
        [1.6339e-05, 2.0791e-02, 1.5849e-04, 4.9668e-11],
        [1.3496e-17, 1.1442e-10, 3.3311e-06, 1.1068e-03]])
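The two samples above come from the priors used later in the model. As a rough guide to what those priors encode, the sketch below computes their means in plain Python, using the standard formulas (a Dirichlet mean is the concentration normalized by its sum, and a Beta(a, b) mean is a / (a + b)). The helper name `dirichlet_mean` is hypothetical.

```python
hidden_dim = 4

# Same concentration matrix as torch.eye(4) * 0.9 + 0.1:
# 1.0 on the diagonal and 0.1 everywhere else.
alpha = [[0.9 * (i == j) + 0.1 for j in range(hidden_dim)] for i in range(hidden_dim)]


def dirichlet_mean(row):
    """Mean of a Dirichlet distribution with the given concentration vector."""
    total = sum(row)
    return [a / total for a in row]


# Expected transition matrix under the prior: each row favors staying in the same state.
transition_mean = [dirichlet_mean(row) for row in alpha]

# Expected probability that a given note is on, under Beta(0.1, 0.9).
beta_mean = 0.1 / (0.1 + 0.9)
```

Each diagonal entry of `transition_mean` is 1.0 / 1.3 and each off-diagonal entry is 0.1 / 1.3, so the prior favors self-transitions. Concentrations below 1 also push individual draws toward sparse values, which matches the many near-zero entries in the sampled tensors.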

In [None]:
def model_0(sequences, lengths, hidden_dim, batch_size=None, include_prior=True):
    # This model uses Python loops over sequences and is not meant to be JIT-traced.
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences.shape
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(  # hidden_dim x hidden_dim, e.g. 16 x 16
            'probs_x',
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1)
        )

        probs_y = pyro.sample( # 16 x 88
            'probs_y',
            dist.Beta(0.1, 0.9).expand([hidden_dim, data_dim]).to_event(2)
        )

    tones_plate = pyro.plate('tones', data_dim, dim=-1)

    for i in pyro.plate('sequences', len(sequences), batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0
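        for t in pyro.markov(range(length)):
            x = pyro.sample(
                'x_{}_{}'.format(i, t),
                dist.Categorical(probs_x[x]),
                infer={'enumerate': 'parallel'},
            )
            # Assumed completion, following the linked Pyro HMM example:
            # each of the 88 notes is Bernoulli given the current hidden state.
            with tones_plate:
                pyro.sample(
                    'y_{}_{}'.format(i, t),
                    dist.Bernoulli(probs_y[x.squeeze(-1)]),
                    obs=sequence[t],
                )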
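
Enumerating the hidden states of a chain like this amounts to the forward algorithm. The sketch below shows that computation in plain Python and checks it against brute-force summation over every state path. It assumes Bernoulli emissions with one probability per note and hidden state, as the `probs_y` prior suggests, and a chain that starts from state 0 as `x = 0` does in `model_0`. All function names and numbers are hypothetical, and this is not Pyro's actual implementation.

```python
import itertools


def emission_prob(note_probs, y):
    """p(y | x) for independent Bernoulli notes with per-note probabilities."""
    p = 1.0
    for prob_on, obs in zip(note_probs, y):
        p *= prob_on if obs == 1 else 1.0 - prob_on
    return p


def forward_likelihood(probs_x, probs_y, sequence, init_state=0):
    """p(y_1..T), eliminating one hidden state per time step."""
    num_states = len(probs_x)
    # alpha[k] = p(y_1..t, x_t = k); the first transition leaves init_state.
    alpha = [probs_x[init_state][k] * emission_prob(probs_y[k], sequence[0])
             for k in range(num_states)]
    for y in sequence[1:]:
        alpha = [
            sum(alpha[j] * probs_x[j][k] for j in range(num_states))
            * emission_prob(probs_y[k], y)
            for k in range(num_states)
        ]
    return sum(alpha)


def brute_force_likelihood(probs_x, probs_y, sequence, init_state=0):
    """Same quantity, summing over all num_states ** T state paths."""
    total = 0.0
    for path in itertools.product(range(len(probs_x)), repeat=len(sequence)):
        p, prev = 1.0, init_state
        for x, y in zip(path, sequence):
            p *= probs_x[prev][x] * emission_prob(probs_y[x], y)
            prev = x
        total += p
    return total


# Illustrative parameters: 2 hidden states, 3 notes, 3 time steps.
probs_x = [[0.8, 0.2], [0.3, 0.7]]
probs_y = [[0.9, 0.1, 0.2], [0.2, 0.6, 0.7]]
sequence = [[1, 0, 0], [0, 1, 1], [1, 0, 1]]

fast = forward_likelihood(probs_x, probs_y, sequence)
slow = brute_force_likelihood(probs_x, probs_y, sequence)
```

The forward pass costs on the order of `T * num_states ** 2` operations, while the brute-force sum grows as `num_states ** T`. This is why the chain structure declared with `pyro.markov` matters for tractability.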

        
        
