In [84]:
import argparse
import logging
import sys

import torch
import torch.nn as nn
from torch.distributions import constraints

import pyro
import pyro.contrib.examples.polyphonic_data_loader as poly
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex
from pyro.optim import Adam
from pyro.util import ignore_jit_warnings

In [85]:
# Attach a stdout handler to the root logger that only passes records at
# DEBUG level or below, so debug output can be captured as a separate stream.
# NOTE(review): the root logger's level is not set in this cell; at Python's
# default (WARNING) no DEBUG record will reach this handler unless the level
# is lowered elsewhere.
log = logging.getLogger()
# Re-running this cell must not stack up duplicate handlers (each copy would
# echo every record again), so drop any handler installed by an earlier run.
for stale_handler in [h for h in log.handlers if h.get_name() == "debug_stdout"]:
    log.removeHandler(stale_handler)
debug_handler = logging.StreamHandler(sys.stdout)
debug_handler.set_name("debug_stdout")
debug_handler.setLevel(logging.DEBUG)
debug_handler.addFilter(filter=lambda record: record.levelno <= logging.DEBUG)
log.addHandler(debug_handler)

#### Next let's make our simple model faster in two ways: first we'll support vectorized minibatches of data, and second we'll support the PyTorch jit compiler. To add batch support, we'll introduce a second plate "seqs" and randomly subsample the data down to batch_size sequences. To add jit support we silence some warnings and try to avoid dynamic program structure.

#### Note that this is the "HMM" model in reference [1], except that in [1] the probabilities probs_x and probs_y are not MAP-regularized with Dirichlet and Beta priors for any of the models.

In [86]:
# Debugging aid: display the notebook's global namespace to see which names
# are currently defined.
# NOTE(review): the dump is large and includes IPython's input history (_ih);
# consider removing this cell before sharing the notebook.
globals()

{'__name__': '__main__',
 '__doc__': 'Automatically created module for IPython interactive environment',
 '__package__': None,
 '__loader__': None,
 '__spec__': None,
 '__builtin__': <module 'builtins' (built-in)>,
 '__builtins__': <module 'builtins' (built-in)>,
 '_ih': ['',
  'log = logging.getLogger()\ndebug_handler = logging.StreamHandler(sys.stdout)\ndebug_handler.setLevel(logging.DEBUG)\ndebug_handler.addFilter(filter=lambda record: record.levelno <= logging.DEBUG)\nlog.addHandler(debug_handler)',
  'globals()',
  'models = {\n    name[len("model_") :]: model\n    for name, model in globals().items()\n    if name.startswith("model_")\n}',
  'logging.info("-" * 40)\nmodel = models[args.model]\nlogging.info(\n    "Training {} on {} sequences".format(\n        model.__name__, len(data["train"]["sequences"])\n    )\n)',
  'data = poly.load_data(poly.JSB_CHORALES)',
  "seqs = data['train']['sequences']\nlengths = data['train']['sequence_lengths']",
  'seqs.shape',
  'present_notes = (

In [87]:
# Registry mapping a short model name (the part after "model_") to the model
# function defined in this notebook.
# - globals() is snapshotted with list(...) so the mapping is built from a
#   stable copy of the namespace.
# - Only callables are kept, so variables that merely share the prefix
#   (e.g. ``model_trace``) are not registered as models.
# NOTE(review): this cell only finds models that are already defined when it
# runs; executed before the ``model_*`` definitions it yields an empty dict.
models = {
    name[len("model_") :]: model
    for name, model in list(globals().items())
    if name.startswith("model_") and callable(model)
}

In [88]:
    # logging.info("-" * 40)
    # model = models[args.model]
    # logging.info(
    #     "Training {} on {} sequences".format(
    #         model.__name__, len(data["train"]["sequences"])
    #     )
    # )

In [143]:
# Load the JSB Chorales dataset through Pyro's polyphonic data loader; the
# result is indexed below by split ('train', 'test') and then by field name.
data = poly.load_data(poly.JSB_CHORALES)

In [144]:
# Training split: the note sequences and the length of each sequence.
seqs, lengths = data['train']['sequences'], data['train']['sequence_lengths']

In [145]:
# Shape check; model_1 unpacks this as (num_seqs, max_len, data_dim).
seqs.shape

torch.Size([229, 129, 88])

In [146]:
# Flag every note (last axis) that is switched on in at least one time step
# of at least one training sequence.
present_notes = (seqs == 1).any(dim=0).any(dim=0)

In [163]:
# Inspect the mask: True marks notes that occur somewhere in the training set.
present_notes

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False,  True, False, False,  True, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True, False, False, False, False,
        False, False, False, False, False, False, False, False])

In [148]:
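# Optional: keep only the notes that occur at least once in the training data.
# Left disabled, so the model below sees all 88 note dimensions.
#seqs = seqs[..., present_notes]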

In [149]:
def model_1(seqs, lengths, hd=16, batch_size=None, include_prior=True, jit=False):
    """Batched HMM over binary note vectors with enumerated hidden states.

    Args:
        seqs: 3-D tensor unpacked as ``(num_seqs, max_len, data_dim)``; entry
            ``seqs[i, t]`` is the observed note vector of sequence ``i`` at
            time step ``t``.
        lengths: tensor of shape ``(num_seqs,)`` with the number of valid time
            steps of each sequence; steps at or beyond a sequence's length are
            masked out.
        hd: number of hidden states.
        batch_size: number of sequences subsampled per call by the ``seqs``
            plate; ``None`` is passed straight to ``pyro.plate`` as the
            subsample size.
        include_prior: mask applied to the ``probs_x`` / ``probs_y`` sites;
            ``False`` masks out their contribution.
        jit: if true, always unroll ``max_len`` time steps so that every
            minibatch has the same program structure; otherwise unroll only up
            to the longest sequence in the current minibatch.

    Raises:
        AssertionError: if ``lengths`` does not have shape ``(num_seqs,)`` or
            a length exceeds ``max_len``.
    """
    # The shape unpacking and assertions below convert tensors to Python
    # values, which would trigger jit tracer warnings; silence them here.
    with ignore_jit_warnings():
        num_seqs, max_len, data_dim = map(int, seqs.shape)
        assert lengths.shape == (num_seqs,)
        # Extra padding past the longest sequence is harmless because padded
        # steps are masked out below; only overruns are an error.
        assert lengths.max() <= max_len

    # Global parameters: hidden-state transition matrix and per-state note
    # emission probabilities.
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            'probs_x',
            dist.Dirichlet(0.9 * torch.eye(hd) + 0.1).to_event(1),
        )

        probs_y = pyro.sample(
            'probs_y',
            dist.Beta(0.1, 0.9).expand([hd, data_dim]).to_event(2)
        )

    tones_plate = pyro.plate('tones', data_dim, dim=-1)
    # We subsample batch_size items out of num_seqs items. Since the 'tones'
    # plate already uses dim=-1, we need to batch over a different dimension,
    # here dim=-2.
    with pyro.plate('seqs', num_seqs, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        # Without the jit we can vary the program structure on each call and
        # run only for the longest sequence in this minibatch. With the jit we
        # keep a single program structure (max_len steps) for all minibatches,
        # since every new structure would trigger another compile stage.
        num_steps = max_len if jit else int(lengths.max())
        for t in pyro.markov(range(num_steps)):
            # Sequences that have already ended contribute nothing at step t.
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample(
                    'x_{}'.format(t),
                    dist.Categorical(probs_x[x]),
                    infer={'enumerate': 'parallel'}
                )
                with tones_plate:
                    pyro.sample(
                        'y_{}'.format(t),
                        dist.Bernoulli(probs_y[x.squeeze(-1)]),
                        obs=seqs[batch, t]
                    )

        

In [150]:
# AutoDelta guide built over a blocked copy of model_1: only the global
# 'probs_*' sites are exposed to the guide; all other sites are hidden from it.
guide = AutoDelta(
    poutine.block(
        model_1,
        expose_fn=lambda site: site['name'].startswith('probs_'),
    )
)

In [151]:
# Dims -1 ('tones') and -2 ('seqs') are taken by model_1's plates, so parallel
# enumeration has to start at dim -3.
first_available_dim = -3

In [152]:
# Run the guide once on a minibatch of 10 sequences and record its sites.
guide_trace = poutine.trace(guide).get_trace(seqs, lengths, hd=16, batch_size=10)

In [153]:
# Trace the model with parallel enumeration enabled, replaying the guide's
# values for the sites recorded in guide_trace.
model_trace = poutine.trace(
    poutine.replay(
        poutine.enum(model_1, first_available_dim),
        guide_trace,
    )
).get_trace(seqs, lengths, hd=16, batch_size=10)

In [154]:
# Print the shapes of every site in the trace to see where the batch
# dimension and the enumeration dimensions end up.
print(model_trace.format_shapes())

Trace Shapes:                    
 Param Sites:                    
Sample Sites:                    
 probs_x dist             | 16 16
        value             | 16 16
 probs_y dist             | 16 88
        value             | 16 88
   tones dist             |      
        value          88 |      
    seqs dist             |      
        value          10 |      
     x_0 dist       10  1 |      
        value    16  1  1 |      
     y_0 dist    16 10 88 |      
        value       10 88 |      
     x_1 dist    16 10  1 |      
        value 16  1  1  1 |      
     y_1 dist 16  1 10 88 |      
        value       10 88 |      
     x_2 dist 16  1 10  1 |      
        value    16  1  1 |      
     y_2 dist    16 10 88 |      
        value       10 88 |      
     x_3 dist    16 10  1 |      
        value 16  1  1  1 |      
     y_3 dist 16  1 10 88 |      
        value       10 88 |      
     x_4 dist 16  1 10  1 |      
        value    16  1  1 |      
     y_4 dist 

In [155]:
# Notice that we're now using dim=-2 as a batch dimension (of size 10),
# and that the enumeration dimensions are now dims -3 and -4.

In [156]:
# Optimizer and loss for training: Adam plus an ELBO that handles the
# enumerated discrete sites of model_1.
optim = Adam({'lr':0.01})
Elbo = TraceEnum_ELBO
elbo = Elbo(
    # model_1 declares two plate dimensions ('tones' at -1, 'seqs' at -2).
    max_plate_nesting = 2,
    strict_enumeration_warning=True,
    # time_compilation is an on/off flag, so pass a real boolean rather than
    # the argparse action name 'store_true' (a non-empty string that only
    # behaved like True by being truthy).
    # NOTE(review): presumably only consulted by the Jit* ELBO variants —
    # confirm before relying on it with TraceEnum_ELBO.
    jit_options={'time_compilation': True}
)

In [157]:
# Bundle model, guide, optimizer and loss into the SVI training object.
svi = SVI(model_1, guide, optim, elbo)

In [158]:
# Train for a fixed number of SVI steps on minibatches of 20 sequences.
num_steps = 100
# Fix the RNG seed and empty Pyro's global parameter store before training.
# NOTE(review): `guide` has already been called in an earlier cell (to build
# guide_trace), i.e. before this seed is set — confirm that its initial values
# are reproducible across kernel restarts.
pyro.set_rng_seed(111)
pyro.clear_param_store()
# Sum of all training sequence lengths; used to report the loss per time step.
num_observations = float(lengths.sum())
for step in range(num_steps):
    loss = svi.step(seqs, lengths, hd=16, batch_size=20)
    print('{:5d}\t{}'.format(step, loss / num_observations))

    0	11.25366548670964
    1	11.154480064460056
    2	14.366844037444775
    3	10.84172996487289
    4	12.35001946476425
    5	11.265643106757443
    6	12.534415287535309
    7	12.395510429492287
    8	13.153589438328384
    9	13.04273194756283
   10	11.651775367567176
   11	12.644023412037372
   12	13.04648344137032
   13	12.399373958861448
   14	12.423139078003912
   15	13.84249950206417
   16	12.357624302889839
   17	12.813258220467878
   18	12.750212754399943
   19	12.819715542840589
   20	12.382948866516983
   21	12.744561182733396
   22	11.976045891576737
   23	13.548979910552617
   24	10.964897787354241
   25	11.199694901137104
   26	11.954057498008257
   27	10.272297340117332
   28	12.355637086260593
   29	11.913723564134134
   30	11.729987506337364
   31	11.510302744984429
   32	11.706931945752155
   33	10.793706317447672
   34	11.96227909755921
   35	12.137931619830521
   36	11.811151046570581
   37	12.171738067646846
   38	11.689217878612299
   39	11.821370047801839
   40	1

In [159]:
# Held-out split: the note sequences and the length of each sequence.
test_sequences, test_lengths = (
    data['test']['sequences'],
    data['test']['sequence_lengths'],
)

In [160]:
# Evaluate the ELBO loss of the trained model/guide pair on the test split;
# no batch_size is passed, so model_1 falls back to its default (None).
test_loss = elbo.loss(model_1, guide, test_sequences, test_lengths, hd=16)

In [161]:
# Sum of all test sequence lengths, used to normalize the test loss.
# NOTE(review): this overwrites the training-set `num_observations` assigned
# in the training cell; re-run that cell before reusing the training value.
num_observations = float(test_lengths.sum())

In [162]:
# Test loss per observed time step.
test_loss / num_observations

9.59161789021164