In [None]:
import argparse
import logging
import sys

import torch
import torch.nn as nn
from torch.distributions import constraints

import pyro
import pyro.contrib.examples.polyphonic_data_loader as poly
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex
from pyro.optim import Adam
from pyro.util import ignore_jit_warnings

In [None]:
# Send DEBUG-level records from the root logger to stdout through a dedicated
# handler, so they can be captured separately from other log output.
log = logging.getLogger()

# Notebook cells get re-executed: drop the handler installed by a previous run
# of this cell, otherwise each debug record is printed once per execution.
for _stale_handler in [
    h for h in log.handlers if getattr(h, "name", None) == "debug_stdout"
]:
    log.removeHandler(_stale_handler)

debug_handler = logging.StreamHandler(sys.stdout)
debug_handler.set_name("debug_stdout")
debug_handler.setLevel(logging.DEBUG)
# The handler level only sets a lower bound; this filter adds the upper bound
# so that nothing above DEBUG goes through this handler.
debug_handler.addFilter(filter=lambda record: record.levelno <= logging.DEBUG)
log.addHandler(debug_handler)

#### Next let's make our simple model faster in two ways: first we'll support vectorized minibatches of data, and second we'll support the PyTorch jit compiler.  To add batch support, we'll introduce a second plate "sequences" and randomly subsample data to size batch_size.  To add jit support we silence some warnings and try to avoid dynamic program structure.

#### Note that this is the "HMM" model in reference [1]. The difference is that in [1] the probabilities `probs_x` and `probs_y` are not MAP-regularized with Dirichlet and Beta distributions in any of the models.

In [None]:
# NOTE(review): looks like a debugging leftover — this displays the entire
# module namespace as the cell output. The following cell only needs the
# entries whose names start with "model_"; consider removing this cell.
globals()

In [None]:
# Registry of every module-level object whose name begins with "model_",
# keyed by the remainder of the name (e.g. "model_1" is stored under "1").
models = dict(
    (name.split("model_", 1)[1], model)
    for name, model in globals().items()
    if name[:6] == "model_"
)

In [None]:
    # NOTE(review): this cell is indented as if it were lifted out of a
    # function body, `args` is not defined in any visible cell, and `data` is
    # only assigned in a later cell — confirm where this is meant to run.
    #
    # Logs a separator line, looks up the selected model by name in `models`,
    # and reports how many training sequences are available.
    logging.info("-" * 40)
    model = models[args.model]
    logging.info(
        "Training {} on {} sequences".format(
            model.__name__, len(data["train"]["sequences"])
        )
    )

In [7]:
# Load the JSB Chorales dataset through Pyro's polyphonic data loader; the
# cells below index the result by split ('train') and then by field name.
data = poly.load_data(poly.JSB_CHORALES)

In [16]:
# Pull the training sequences and their per-sequence lengths out of the
# 'train' split in one step.
seqs, lengths = (data['train'][field] for field in ('sequences', 'sequence_lengths'))

In [25]:
# Show the shape of the training sequences tensor as the cell output.
seqs.shape

torch.Size([229, 129, 88])

In [28]:
# Boolean mask over the last dimension of `seqs`: True where the value 1
# occurs at least once across the first two dimensions.
present_notes = torch.eq(seqs, 1).sum(dim=(0, 1)).gt(0)

In [33]:
# Keep only the entries of the last dimension selected by `present_notes`.
# The original line bound the resulting *shape* to `seqs`, which replaced the
# tensor with a torch.Size; assign the tensor and display the shape separately.
# This cell narrows `seqs` in place, so re-run the cell that defines `seqs`
# before executing it a second time.
seqs = seqs[..., present_notes]
seqs.shape

In [None]:
def model_1(seqs, lengths, hd=16, batch_size=None, include_prior=True):
    """Set up the batched HMM: global parameter sites and a subsampled sequence plate.

    Args:
        seqs: tensor whose shape unpacks as ``(num_seqs, max_len, data_dim)``.
        lengths: tensor of shape ``(num_seqs,)`` whose maximum must equal
            ``max_len``.
        hd: size of the square ``probs_x`` matrix and number of rows of
            ``probs_y`` (presumably the number of hidden states).
        batch_size: subsample size handed to the ``'seqs'`` plate; ``None``
            is passed through unchanged.
        include_prior: mask applied through ``poutine.mask`` to the two
            parameter sample sites.

    NOTE(review): the visible body stops right after the sequence plate is
    opened; the per-time-step loop described by the trailing comment is not
    present here, so ``tones_plate`` and ``x`` are not used yet.
    """
    with ignore_jit_warnings():
        num_seqs, max_len, data_dim = map(int, seqs.shape)
        assert lengths.shape == (num_seqs,)
        assert lengths.max() == max_len
    with poutine.mask(mask=include_prior):
        # Use the `hd` argument rather than a global `args` object, which is
        # not defined in this notebook.
        probs_x = pyro.sample(
            'probs_x',
            dist.Dirichlet(0.9 * torch.eye(hd) + 0.1).to_event(1),
        )

        probs_y = pyro.sample(
            'probs_y',
            dist.Beta(0.1, 0.9).expand([hd, data_dim]).to_event(2)
        )

    tones_plate = pyro.plate('tones', data_dim, dim=-1)
    # We subsample batch_size items out of num_seqs items. Note that since
    # we're using dim=-1 for the 'tones' plate, we need to batch over a
    # different dimension, here dim=-2.
    with pyro.plate('seqs', num_seqs, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        # If we are not using the jit, then we can vary the program structure
        # each call by running for a dynamically determined number of time
        # steps, lengths.max(). However if we are using the jit, then we try to
        # keep a single program structure for all minibatches; the fixed
        # structure ends up being faster since each program structure would
        # need to trigger a new jit compile stage.


        