In [2]:
import torch
import torch.nn as nn
from torch.distributions import constraints

import pyro
import pyro.contrib.examples.polyphonic_data_loader as poly
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex
from pyro.optim import Adam
from pyro.util import ignore_jit_warnings

In [3]:
# Load the JSB_CHORALES dataset through pyro.contrib.examples.polyphonic_data_loader.
# NOTE(review): the next cell indexes data['train']['sequences'] and
# data['train']['sequence_lengths'], so this is presumably a dict keyed by split — confirm.
data = poly.load_data(poly.JSB_CHORALES)

In [4]:
# Training split: the observed sequence tensor and the per-sequence lengths
# (model_2 uses the lengths to mask time steps past the end of each sequence).
seqs, lengths = (data['train'][key] for key in ('sequences', 'sequence_lengths'))

In [5]:
# Number of hidden states. Equal to model_2's default `hd=4`; note that the
# model_2 call further down does not pass this value explicitly.
hd = 4

In [6]:
# Presumably the subsample size for model_2's "sequences" plate.
# NOTE(review): the model_2 call further down does not pass this value, so that
# call runs with model_2's default batch_size=4 — confirm whether 10 was intended.
batch_size = 10

In [7]:
# Removed leftover debugging (`globals().keys()`): it dumped every name in the
# kernel, IPython internals included, and nothing in the analysis depends on it.



In [19]:
# Next let's add a dependency of y[t] on y[t-1].
#
#     x[t-1] --> x[t] --> x[t+1]
#        |        |         |
#        V        V         V
#     y[t-1] --> y[t] --> y[t+1]
#
# This is the autoregressive HMM ("arHMM") model from reference [1].

In [28]:
def model_2(sequences, lengths, hd=4, batch_size=4, include_prior=True, verbose=True):
    """Autoregressive HMM ("arHMM"): each y[t] depends on x[t] and on y[t-1].

    Args:
        sequences: observed tensor, unpacked below as
            ``(num_sequences, max_length, data_dim)``.
        lengths: tensor of shape ``(num_sequences,)``; time step ``t`` of
            sequence ``i`` is masked out when ``t >= lengths[i]``.
        hd: number of hidden states; ``probs_x`` is sampled with shape
            ``(hd, hd)`` and ``probs_y`` with shape ``(hd, 2, data_dim)``.
        batch_size: subsample size of the ``"sequences"`` plate.
        include_prior: used as the mask around the ``probs_x`` / ``probs_y``
            prior sites; ``False`` masks their log-probability out.
        verbose: print intermediate samples and shapes. ``True`` by default to
            keep the original exploratory behaviour. Pass ``False`` when the
            model is handed to SVI: it is re-executed on every step, and three
            of the prints fire once per time step.

    Returns:
        None. The model's effect is the sample sites it records.
    """
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hd) + 0.1).to_event(1),
        )
        if verbose:
            print(probs_x)
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hd, 2, data_dim]).to_event(3),
        )
        if verbose:
            print(probs_y.shape)
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        # At t=0 there is no previous state or observation, so index 0 stands in:
        # the first step uses row probs_x[0] and the y=0 slice of probs_y.
        x, y = 0, 0
        for t in pyro.markov(range(max_length)):
            # Steps past the end of a (shorter) sequence contribute no log-probability.
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x]),
                    infer={"enumerate": "parallel"},
                )
                if verbose:
                    print('x {} {}'.format(t, x))
                # Note the broadcasting tricks here: to index probs_y on tensors x and y,
                # we also need a final tensor for the tones dimension. This is conveniently
                # provided by the plate associated with that dimension.
                with tones_plate as tones:
                    if verbose:
                        print("tones is {}".format(tones))
                    y = pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[x, y, tones]),
                        obs=sequences[batch, t],
                    ).long()
                    if verbose:
                        print('y_{} is {}'.format(t, y))

In [29]:
# One un-guided forward run of the arHMM to inspect sampled values and shapes.
# NOTE(review): hd and batch_size are not passed, so model_2's defaults
# (hd=4, batch_size=4) apply; the batch_size=10 set earlier has no effect here.
model_2(sequences=seqs, lengths=lengths)

tensor([[7.6349e-01, 6.6955e-07, 2.0250e-01, 3.4009e-02],
        [3.5925e-03, 9.5878e-02, 9.0052e-01, 1.2116e-05],
        [7.0697e-02, 7.0483e-08, 9.2815e-01, 1.1479e-03],
        [8.8665e-08, 1.1867e-12, 5.0175e-03, 9.9498e-01]])
torch.Size([4, 2, 88])
x 0 tensor([[0],
        [0],
        [0],
        [0]])
tones is tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
        72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87])
y_0 is tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0,