## [HMM](http://pyro.ai/examples/hmm.html#example-hidden-markov-models)

#### This example shows how to marginalize out discrete model variables in Pyro.

This combines Stochastic Variational Inference (SVI) with a
variable elimination algorithm, where we use enumeration to exactly
marginalize out some variables from the ELBO computation. We might
call the resulting algorithm collapsed SVI or collapsed SGVB (i.e.
collapsed Stochastic Gradient Variational Bayes). In the case where
we exactly sum out all the discrete latent variables (as is the case here),
this algorithm reduces to a form of gradient-based Maximum
Likelihood Estimation (or MAP estimation when, as in `model_0` below,
priors on the global parameters are included).
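The following sketch shows what "exactly sum out" means. It uses a toy two-component Bernoulli mixture, not the HMM below, and the data and parameters are made up for illustration. Summing each observation's discrete component out gives the marginal log-likelihood log p(y). That value matches brute-force enumeration over all K^N joint assignments, and maximizing it with gradients is the objective the paragraph above describes. The helper names are hypothetical. `TraceEnum_ELBO` performs this kind of summation on tensors rather than Python lists.

```python
import itertools
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def bernoulli_logpmf(y, p):
    return math.log(p) if y == 1 else math.log(1.0 - p)


def collapsed_log_likelihood(ys, weights, probs):
    """Sum out each observation's discrete component z_n exactly, one observation at a time."""
    return sum(
        logsumexp([math.log(w) + bernoulli_logpmf(y, p) for w, p in zip(weights, probs)])
        for y in ys
    )


def brute_force_log_likelihood(ys, weights, probs):
    """Enumerate every joint assignment (z_1, ..., z_N): K**N terms."""
    terms = []
    for zs in itertools.product(range(len(weights)), repeat=len(ys)):
        terms.append(sum(math.log(weights[z]) + bernoulli_logpmf(y, probs[z]) for y, z in zip(ys, zs)))
    return logsumexp(terms)


ys = [1, 0, 1, 1]
weights = [0.3, 0.7]  # mixture weights pi_k (made up)
probs = [0.9, 0.2]    # Bernoulli parameter of each component (made up)
print(collapsed_log_likelihood(ys, weights, probs))
print(brute_force_log_likelihood(ys, weights, probs))
```

The per-observation sum costs O(NK) instead of O(K^N). Pyro exploits the same independence structure through plates.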

To marginalize out discrete variables ``x`` in Pyro's SVI:

1. Verify that the variable dependency structure in your model
    admits tractable inference, i.e. the dependency graph among
    enumerated variables should have narrow treewidth.
2. Annotate each such sample site in the model
    with ``infer={"enumerate": "parallel"}``.
3. Ensure your model can handle broadcasting of the sample values
    of those variables.
4. Use the ``TraceEnum_ELBO`` loss inside Pyro's ``SVI``.
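Step 1 is what keeps exact enumeration affordable. The back-of-the-envelope sketch below uses hypothetical helper names and counts only the dominant terms. It compares naive enumeration of every hidden path in a K-state chain of length T with eliminating the chain one variable at a time. It uses K = 16 (the `hidden_dim` used below) and T = 129 (the second dimension of `sequences`).

```python
def brute_force_terms(num_states, num_steps):
    """Number of hidden paths summed over when enumerating a chain naively: K**T."""
    return num_states ** num_steps


def chain_elimination_ops(num_states, num_steps):
    """Rough multiply-add count for eliminating a chain one variable at a time.

    A chain has treewidth 1, so each elimination step only touches a factor over
    two neighboring variables (a K x K table): about K**2 work per step.
    """
    return num_steps * num_states ** 2


K, T = 16, 129
print(f"brute force: a {len(str(brute_force_terms(K, T)))}-digit number of paths")
print(f"variable elimination: about {chain_elimination_ops(K, T)} multiply-adds")
```

The first count grows exponentially in T and the second grows linearly. That is why a chain-structured model such as this HMM admits exact enumeration, while densely connected discrete graphs generally do not.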

Note that empirical results for the models defined here can be found in
reference [1]. This paper also includes a description of the "tensor
variable elimination" algorithm that Pyro uses under the hood to
marginalize out discrete latent variables.


```python
import argparse
import logging
import sys

import torch
import torch.nn as nn
from torch.distributions import constraints

import pyro
import pyro.contrib.examples.polyphonic_data_loader as poly
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex
from pyro.optim import Adam
from pyro.util import ignore_jit_warnings
```

```python
logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.DEBUG)

# Add a second handler that also sends DEBUG-level records (e.g. profiling
# messages) to stdout, so they can be captured separately.
log = logging.getLogger()
debug_handler = logging.StreamHandler(sys.stdout)
debug_handler.setLevel(logging.DEBUG)
debug_handler.addFilter(filter=lambda record: record.levelno <= logging.DEBUG)
log.addHandler(debug_handler)
```
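Because of the filter, `debug_handler` forwards only records at DEBUG level (or lower). INFO and higher records still reach the stderr handler installed by `logging.basicConfig`, but they never reach stdout. The standard-library sketch below shows the same filter pattern on a separate logger. The logger is named `hmm_demo` for illustration only, so the root logger configured above is left alone, and output is captured in a string buffer.

```python
import io
import logging

demo_log = logging.getLogger("hmm_demo")
demo_log.setLevel(logging.DEBUG)
demo_log.propagate = False  # keep demo records away from the root handlers

buffer = io.StringIO()
handler = logging.StreamHandler(buffer)
handler.setLevel(logging.DEBUG)
# Same predicate as above: a plain callable works as a filter (Python 3.2+).
handler.addFilter(lambda record: record.levelno <= logging.DEBUG)
demo_log.addHandler(handler)

demo_log.debug("step 0 loss")  # passes the filter
demo_log.info("finished")      # rejected by the filter

print(repr(buffer.getvalue()))
```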

![](hmm.png)

```python
data = poly.load_data(poly.JSB_CHORALES)
sequences = data['train']['sequences']
lengths = data['train']['sequence_lengths']
```

Each sequence $\{y_1, \dots, y_T\}$ is a series of binary vectors $y_t \in \{0,1\}^{88}$, where each entry denotes the presence or absence of one of 88 distinct notes at time step $t$.
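Each $y_t$ is a length-88 binary vector, so the notes sounding at a time step are the indices of its nonzero entries. The pure-Python sketch below uses hypothetical helpers `active_notes` and `note_density`, which work on plain lists rather than torch tensors. For illustration it uses the four positions set in the first frame of `sequences[1]` printed below (indices 52, 56, 59 and 64).

```python
NUM_NOTES = 88  # length of each frame y_t


def active_notes(frame):
    """Indices of the notes that sound in one binary frame y_t."""
    return [i for i, v in enumerate(frame) if v == 1]


def note_density(frames):
    """Fraction of (time step, note) entries that are 1 across a list of frames."""
    total = sum(len(f) for f in frames)
    return sum(sum(f) for f in frames) / total


frame = [0] * NUM_NOTES
for key in (52, 56, 59, 64):
    frame[key] = 1

print(active_notes(frame))
print(note_density([frame]))
```

The fraction of active entries is the empirical counterpart of the per-tone activation probabilities that the model stores in `probs_y`.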

```python
sequences.shape  # (num_sequences, max_length, data_dim)
```

```
torch.Size([229, 129, 88])
```

```python
sequences[1, :3]
```

```
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
         0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1
```

```python
lengths.shape  # one length per sequence
```

```
torch.Size([229])
```

```python
# Dirichlet concentration: 1.0 on the diagonal, 0.1 elsewhere.
alpha = torch.eye(4) * 0.9 + 0.1
pyro.sample('dir', dist.Dirichlet(alpha).to_event(1))
```

```
tensor([[9.3248e-01, 1.0739e-10, 1.7987e-11, 6.7516e-02],
        [2.4896e-01, 6.1552e-01, 1.3552e-01, 3.1968e-11],
        [8.0716e-02, 1.7845e-03, 7.8750e-02, 8.3875e-01],
        [1.1319e-02, 2.4762e-02, 6.1067e-09, 9.6392e-01]])
```
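The concentration `0.9 * torch.eye(K) + 0.1` has 1.0 on the diagonal and 0.1 elsewhere, so in expectation each row favors staying in the current state. Three of the four sampled rows above put most of their mass on the diagonal. The mean of Dirichlet(c) is c / sum(c), so the expected self-transition probability is 1.0 / (1.0 + 0.1 (K - 1)). That depends on K: about 0.77 for the K = 4 example here, and 0.4 for the `hidden_dim = 16` used in the model below. A pure-Python sketch of that prior mean follows; `dirichlet_prior_mean` is a hypothetical helper.

```python
def dirichlet_prior_mean(num_states, stay=0.9, base=0.1):
    """Mean of each row of a transition matrix with rows ~ Dirichlet(stay * I + base).

    The mean of Dirichlet(c) is c / sum(c), so the diagonal entry is
    (stay + base) / (stay + num_states * base).
    """
    rows = []
    for i in range(num_states):
        conc = [stay + base if j == i else base for j in range(num_states)]
        total = sum(conc)
        rows.append([c / total for c in conc])
    return rows


for k in (4, 16):
    mean = dirichlet_prior_mean(k)
    # print: K, expected self-transition, expected probability of each other state
    print(k, round(mean[0][0], 3), round(mean[0][1], 3))
```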

```python
# Weak Beta(0.1, 0.9) prior, expanded to a 3 x 4 batch for illustration.
pyro.sample('piano', dist.Beta(0.1, 0.9).expand([3, 4]).to_event(2))
```

```
tensor([[9.4654e-01, 2.4930e-07, 1.2369e-01, 1.2845e-05],
        [2.6934e-01, 1.9805e-07, 2.5934e-07, 9.2198e-03],
        [4.4610e-11, 7.0637e-04, 2.7294e-03, 3.1136e-01]])
```
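Beta(0.1, 0.9) has mean 0.1, so a priori each hidden state turns each tone on with probability about 10% on average. Both concentrations are below 1, so the density is U-shaped and piles up at the boundaries, mostly near 0. This matches the samples above, 8 of which are below 0.01. The sketch below uses a hypothetical helper `beta_mean_var` for the closed-form moments, plus a seeded Monte Carlo check with Python's `random.betavariate`.

```python
import random


def beta_mean_var(a, b):
    """Closed-form mean and variance of Beta(a, b)."""
    s = a + b
    return a / s, a * b / (s * s * (s + 1))


mean, var = beta_mean_var(0.1, 0.9)
print(mean, var)

# Monte Carlo look at the shape of Beta(0.1, 0.9).
rng = random.Random(0)
samples = [rng.betavariate(0.1, 0.9) for _ in range(10_000)]
frac_tiny = sum(x < 0.01 for x in samples) / len(samples)
print(f"fraction of samples below 0.01: {frac_tiny:.2f}")
```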

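```python
def model_0(sequences, lengths, hidden_dim, batch_size=None, include_prior=True):
    # This model loops over sequences and time steps in Python, so it is not
    # meant to be traced by the PyTorch JIT:
    # assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences.shape
    with poutine.mask(mask=include_prior):
        # Transition matrix, hidden_dim x hidden_dim (16 x 16 here): each row has a
        # Dirichlet prior concentrated on the diagonal, favoring "sticky" states.
        probs_x = pyro.sample(
            'probs_x',
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1),
        )
        # Emission probabilities, hidden_dim x data_dim (16 x 88 here): a weak
        # Beta(0.1, 0.9) prior (mean 0.1) on each tone being active.
        probs_y = pyro.sample(
            'probs_y',
            dist.Beta(0.1, 0.9).expand([hidden_dim, data_dim]).to_event(2),
        )

    # The 88 tones are conditionally independent given the hidden state.
    tones_plate = pyro.plate('tones', data_dim, dim=-1)

    # Sequential plate: iterate over (a minibatch of) sequences one at a time.
    for i in pyro.plate('sequences', len(sequences), batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0  # the chain starts from state 0
        # pyro.markov lets enumeration dims be reused along the chain.
        for t in pyro.markov(range(length)):
            # Enumerate x_t in parallel so TraceEnum_ELBO sums it out exactly.
            x = pyro.sample(
                'x_{}_{}'.format(i, t),
                dist.Categorical(probs_x[x]),
                infer={'enumerate': 'parallel'},
            )
            with tones_plate:
                pyro.sample(
                    'y_{}_{}'.format(i, t),
                    dist.Bernoulli(probs_y[x.squeeze(-1)]),
                    obs=sequence[t],
                )
```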
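For one sequence, the data term that enumeration computes for `model_0` is the HMM marginal likelihood. The AutoDelta guide adds only the log-prior terms on `probs_x` and `probs_y`. The pure-Python sketch below computes that likelihood with the forward algorithm, under this reading of `model_0`. The chain starts from state 0 because the model sets `x = 0` before the loop. Each frame contributes a product of independent Bernoulli terms from the `tones` plate. The recursion is checked against brute-force enumeration of all hidden paths on a tiny made-up example. The function names are hypothetical, and this is not Pyro's tensor variable elimination code, only the same sum.

```python
import itertools
import math


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def emission_logp(frame, probs_y_row):
    """log prod_d Bernoulli(frame[d] | probs_y_row[d]), i.e. the tones plate."""
    return sum(math.log(p) if y == 1 else math.log(1.0 - p) for y, p in zip(frame, probs_y_row))


def forward_log_likelihood(sequence, probs_x, probs_y):
    """log p(y_0, ..., y_{T-1}) with every x_t summed out, starting from state 0."""
    K = len(probs_x)
    alpha = [math.log(probs_x[0][k]) + emission_logp(sequence[0], probs_y[k]) for k in range(K)]
    for frame in sequence[1:]:
        alpha = [
            logsumexp([alpha[j] + math.log(probs_x[j][k]) for j in range(K)])
            + emission_logp(frame, probs_y[k])
            for k in range(K)
        ]
    return logsumexp(alpha)


def brute_force_log_likelihood(sequence, probs_x, probs_y):
    """Same quantity by enumerating all K**T hidden paths (only feasible for tiny T)."""
    terms = []
    for path in itertools.product(range(len(probs_x)), repeat=len(sequence)):
        prev, total = 0, 0.0
        for x, frame in zip(path, sequence):
            total += math.log(probs_x[prev][x]) + emission_logp(frame, probs_y[x])
            prev = x
        terms.append(total)
    return logsumexp(terms)


# Tiny made-up example: K=2 hidden states, 3 tones, T=3 frames.
probs_x = [[0.8, 0.2], [0.3, 0.7]]
probs_y = [[0.9, 0.1, 0.5], [0.2, 0.6, 0.4]]
sequence = [[1, 0, 1], [0, 1, 0], [1, 1, 0]]
print(forward_log_likelihood(sequence, probs_x, probs_y))
print(brute_force_log_likelihood(sequence, probs_x, probs_y))
```

The forward pass costs O(T K^2 D) for D tones, whereas enumeration costs O(K^T T D). The forward pass keeps only one K-vector of messages between steps, which is the same locality that `pyro.markov` declares to Pyro.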
        


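```python
# Point estimates (AutoDelta) for the global parameters only: the enumerated
# x_* sites and the observed y_* sites are hidden from the guide.
guide = AutoDelta(
    poutine.block(model_0, expose_fn=lambda msg: msg['name'].startswith('probs_'))
)
```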

```python
# Dim -1 belongs to the tones plate (max_plate_nesting=1), so enumeration starts at -2.
first_available_dim = -2
```

```python
guide_trace = poutine.trace(guide).get_trace(
    sequences, lengths, hidden_dim=16
)
```

```python
# Replay the guide's point estimates and enumerate the discrete sites in the model.
model_trace = poutine.trace(
    poutine.replay(poutine.enum(model_0, first_available_dim), guide_trace)
).get_trace(sequences, lengths, hidden_dim=16)
```

```python
print(model_trace.format_shapes())
```

```
 Trace Shapes:                  
  Param Sites:                  
 Sample Sites:                  
  probs_x dist           | 16 16
         value           | 16 16
  probs_y dist           | 16 88
         value           | 16 88
    tones dist           |      
         value        88 |      
sequences dist           |      
         value       229 |      
    x_0_0 dist           |      
         value    16   1 |      
    y_0_0 dist    16  88 |      
         value        88 |      
    x_0_1 dist    16   1 |      
         value 16  1   1 |      
    y_0_1 dist 16  1  88 |      
         value        88 |      
    x_0_2 dist 16  1   1 |      
         value    16   1 |      
    y_0_2 dist    16  88 |      
         value        88 |      
    x_0_3 dist    16   1 |      
         value 16  1   1 |      
    y_0_3 dist 16  1  88 |      
         value        88 |      
    x_0_4 dist 16  1   1 |      
         value    16   1 |      
    y_0_4 dist    16  88 |      
         v
```
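In the trace above, `tones` occupies dim -1, which is why `max_plate_nesting=1` and `first_available_dim = -2`. The enumerated `x_0_t` values alternate between dim -2 (shape `16 1`) and dim -3 (shape `16 1 1`). With `pyro.markov`, only the previous state needs to stay in scope, so two enumeration dims can be recycled along the whole chain instead of allocating one new dim per time step. The hypothetical helpers below reproduce the printed shapes of `x_t`, of its `Categorical`, and of the `y_t` `Bernoulli`. They assume the alternation continues for all t, which is consistent with the rows shown but not checked beyond them.

```python
def enum_value_shape(t, num_states, first_available_dim=-2):
    """Shape of the enumerated value of x_t: dims alternate between
    first_available_dim and first_available_dim - 1 along the chain."""
    dim = first_available_dim - (t % 2)
    return (num_states,) + (1,) * (-dim - 1)


def transition_batch_shape(t, num_states, first_available_dim=-2):
    """Batch shape of Categorical(probs_x[x]): the shape of the previous x
    (the plain integer 0 at t = 0, hence an empty batch shape)."""
    if t == 0:
        return ()
    return enum_value_shape(t - 1, num_states, first_available_dim)


def emission_batch_shape(t, num_states, data_dim, first_available_dim=-2):
    """Batch shape of Bernoulli(probs_y[x.squeeze(-1)]): x's shape without its
    trailing 1, followed by data_dim."""
    x_shape = enum_value_shape(t, num_states, first_available_dim)
    return x_shape[:-1] + (data_dim,)


for t in range(5):
    print(t, transition_batch_shape(t, 16), enum_value_shape(t, 16), emission_batch_shape(t, 16, 88))
```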

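```python
optim = Adam({'lr': 0.01})
Elbo = TraceEnum_ELBO
elbo = Elbo(
    max_plate_nesting=1,  # only the `tones` plate (dim -1) is vectorized
    strict_enumeration_warning=True,
    # Only read by the Jit* ELBO variants (e.g. JitTraceEnum_ELBO); expects a bool.
    jit_options={'time_compilation': False},
)
```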

```python
svi = SVI(model_0, guide, optim, elbo)
```

```python
num_steps = 50
pyro.set_rng_seed(111)
pyro.clear_param_store()
# Normalize the loss by the total number of time steps in the training set.
num_observations = float(lengths.sum())
for step in range(num_steps):
    loss = svi.step(sequences, lengths, hidden_dim=16)
    print('{:5d}\t{}'.format(step, loss / num_observations))
```

```
    0	286705.125
    1	286114.15625
    2	285524.5625
    3	284937.0
    4	284350.90625
    5	283766.46875
    6	283183.90625
    7	282602.875
    8	282023.375
    9	281445.6875
   10	280869.5625
   11	280294.96875
   12	279722.09375
   13	279150.65625
   14	278580.8125
   15	278012.40625
   16	277445.65625
   17	276880.40625
   18	276316.5625
   19	275754.21875
   20	275193.375
   21	274633.96875
   22	274076.03125
   23	273519.46875
   24	272964.34375
   25	272410.625
   26	271858.3125
   27	271307.21875
   28	270757.625
   29	270209.34375
   30	269662.46875
   31	269116.8125
   32	268572.46875
   33	268029.40625
   34	267487.65625
   35	266947.15625
   36	266407.96875
   37	265870.03125
   38	265333.21875
   39	264797.71875
   40	264263.28125
   41	263730.1875
   42	263198.21875
   43	262667.5625
   44	262137.921875
   45	261609.546875
   46	261082.125
   47	260556.109375
   48	260030.953125
   49	259507.21875
```


[Repo](https://github.com/pyro-ppl/pyro/blob/dev/examples/hmm.py)