In [5]:
import torch
import torch.nn as nn
from torch.distributions import constraints

import pyro
import pyro.contrib.examples.polyphonic_data_loader as poly
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex
from pyro.optim import Adam
from pyro.util import ignore_jit_warnings

In [6]:
data = poly.load_data(poly.JSB_CHORALES)

In [7]:
seqs = data['train']['sequences']
lengths = data['train']['sequence_lengths']

In [8]:
hd = 4

In [9]:
batch_size = 10

In [10]:
globals().keys()



In [11]:
# Next consider a Factorial HMM with two hidden states.
#
#    w[t-1] ----> w[t] ---> w[t+1]
#        \ x[t-1] --\-> x[t] --\-> x[t+1]
#         \  /       \  /       \  /
#          \/         \/         \/
#        y[t-1]      y[t]      y[t+1]
#
# Note that since the joint distribution of each y[t] depends on two variables,
# those two variables become dependent. Therefore during enumeration, the
# entire joint space of these variables w[t],x[t] needs to be enumerated.
# For that reason, we set the dimension of each to the square root of the
# target hidden dimension.
#
# Note that this is the "FHMM" model in reference [1].

In [20]:
def Factorial_HMM(sequences, lengths, hd=4, batch_size=4, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    hidden_dim = int(hd**0.5)  # split between w and x
    with poutine.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w", dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1)
        )
        probs_x = pyro.sample(
            "probs_x", dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1)
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim]).to_event(3),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        w, x = 0, 0
        for t in pyro.markov(range(max_length)):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample(
                    "w_{}".format(t),
                    dist.Categorical(probs_w[w]),
                    infer={"enumerate": "parallel"},
                )
                print(f'w t{t} is {w}')
                x = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x]),
                    infer={"enumerate": "parallel"},
                )
                print(f'x t{t} is {x}')
                with tones_plate as tones:
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[w, x, tones]),
                        obs=sequences[batch, t],
                    )

In [21]:
Factorial_HMM(seqs, lengths)

w t0 is tensor([[0],
        [0],
        [0],
        [0]])
x t0 is tensor([[0],
        [0],
        [1],
        [0]])
w t1 is tensor([[0],
        [0],
        [0],
        [0]])
x t1 is tensor([[1],
        [0],
        [1],
        [0]])
w t2 is tensor([[0],
        [0],
        [0],
        [0]])
x t2 is tensor([[1],
        [0],
        [1],
        [0]])
w t3 is tensor([[0],
        [1],
        [0],
        [0]])
x t3 is tensor([[1],
        [0],
        [1],
        [0]])
w t4 is tensor([[0],
        [1],
        [0],
        [0]])
x t4 is tensor([[1],
        [0],
        [1],
        [0]])
w t5 is tensor([[0],
        [1],
        [0],
        [0]])
x t5 is tensor([[1],
        [0],
        [1],
        [0]])
w t6 is tensor([[0],
        [1],
        [0],
        [0]])
x t6 is tensor([[1],
        [0],
        [1],
        [0]])
w t7 is tensor([[0],
        [1],
        [0],
        [0]])
x t7 is tensor([[1],
        [0],
        [1],
        [0]])
w t8 is tensor([[0],
   