In [None]:
import argparse
import logging
import sys

import torch
import torch.nn as nn
from torch.distributions import constraints

import pyro
import pyro.contrib.examples.polyphonic_data_loader as poly
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex
from pyro.optim import Adam
from pyro.util import ignore_jit_warnings

In [None]:
log = logging.getLogger()
debug_handler = logging.StreamHandler(sys.stdout)
debug_handler.setLevel(logging.DEBUG)
debug_handler.addFilter(filter=lambda record: record.levelno <= logging.DEBUG)
log.addHandler(debug_handler)

#### Next let's make our simple model faster in two ways: first we'll support vectorized minibatches of data, and second we'll support the PyTorch jit compiler.  To add batch support, we'll introduce a second plate "sequences" and randomly subsample data to size batch_size.  To add jit support we silence some warnings and try to avoid dynamic program structure.

#### Note that this is the "HMM" model in reference [1] (with the difference that in [1] the probabilities probs_x and probs_y are not MAP-regularized with Dirichlet and Beta distributions for any of the models)

In [None]:
globals()

In [None]:
models = {
    name[len("model_") :]: model
    for name, model in globals().items()
    if name.startswith("model_")
}

In [None]:
    logging.info("-" * 40)
    model = models[args.model]
    logging.info(
        "Training {} on {} sequences".format(
            model.__name__, len(data["train"]["sequences"])
        )
    )

In [61]:
data = poly.load_data(poly.JSB_CHORALES)

In [62]:
seqs = data['train']['sequences']
lengths = data['train']['sequence_lengths']

In [63]:
seqs.shape

torch.Size([229, 129, 88])

In [64]:
present_notes = (seqs == 1).sum(0).sum(0) > 0

In [65]:
#seqs = seqs[..., present_notes]

In [66]:
def model_1(seqs, lengths, hd=16, batch_size = None, include_prior=True):
    with ignore_jit_warnings():
        num_seqs, max_len, data_dim = map(int, seqs.shape)
        assert lengths.shape == (num_seqs,)
        assert lengths.max() == max_len
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            'probs_x',
            dist.Dirichlet(0.9 * torch.eye(hd) + 0.1).to_event(1),
        )

        probs_y = pyro.sample(
            'probs_y',
            dist.Beta(0.1, 0.9).expand([hd, data_dim]).to_event(2)
        )

    tones_plate = pyro.plate('tones', data_dim, dim=-1)
    # We subsample batch_size items out of num_sequences items. Note that since
    # we're using dim=-1 for the notes plate, we need to batch over a different
    # dimension, here dim=-2.
    with pyro.plate('seqs', num_seqs, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        # If we are not using the jit, then we can vary the program structure
        # each call by running for a dynamically determined number of time
        # steps, lengths.max(). However if we are using the jit, then we try to
        # keep a single program structure for all minibatches; the fixed
        # structure ends up being faster since each program structure would
        # need to trigger a new jit compile stage.    
        for t in pyro.markov(range(lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample(
                    'x_{}'.format(t),
                    dist.Categorical(probs_x[x]),
                    infer={'enumerate': 'parallel'}
                )
                with tones_plate:
                    pyro.sample(
                        'y_{}'.format(t),
                        dist.Bernoulli(probs_y[x.squeeze(-1)]),
                        obs = seqs[batch, t]
                    )

        

In [67]:
guide = AutoDelta(
    poutine.block(model_1, expose_fn=lambda msg: msg['name'].startswith('probs_'))
)

In [68]:
first_available_dim = -3

In [69]:
guide_trace = poutine.trace(guide).get_trace(
    seqs, lengths, hd=16, batch_size=10
)

In [70]:
model_trace = poutine.trace(
    poutine.replay(poutine.enum(model_1, first_available_dim), guide_trace)
    ).get_trace(
        seqs, lengths, hd = 16, batch_size = 10
    )

In [71]:
print(model_trace.format_shapes())

Trace Shapes:                    
 Param Sites:                    
Sample Sites:                    
 probs_x dist             | 16 16
        value             | 16 16
 probs_y dist             | 16 88
        value             | 16 88
   tones dist             |      
        value          88 |      
    seqs dist             |      
        value          10 |      
     x_0 dist       10  1 |      
        value    16  1  1 |      
     y_0 dist    16 10 88 |      
        value       10 88 |      
     x_1 dist    16 10  1 |      
        value 16  1  1  1 |      
     y_1 dist 16  1 10 88 |      
        value       10 88 |      
     x_2 dist 16  1 10  1 |      
        value    16  1  1 |      
     y_2 dist    16 10 88 |      
        value       10 88 |      
     x_3 dist    16 10  1 |      
        value 16  1  1  1 |      
     y_3 dist 16  1 10 88 |      
        value       10 88 |      
     x_4 dist 16  1 10  1 |      
        value    16  1  1 |      
     y_4 dist 

In [72]:
# Notice that we're now using dim=-2 as a batch dimension (of size 10),
# and that the enumeration dimensions are now dims -3 and -4.

In [73]:
optim = Adam({'lr':0.01})
Elbo = TraceEnum_ELBO
elbo = Elbo(
    max_plate_nesting = 2,
    strict_enumeration_warning=True,
    jit_options={'time_compilation': 'store_true'}
)

In [74]:
svi = SVI(model_1, guide, optim, elbo)

In [75]:
num_steps = 100
pyro.set_rng_seed(111)
pyro.clear_param_store()
num_observations = float(lengths.sum())
for step in range(num_steps):
    loss = svi.step(seqs, lengths, hd=16, batch_size=20)
    print('{:5d}\t{}'.format(step, loss / num_observations))

    0	16.54754947671471
    1	16.05238511262403
    2	20.599012276381547
    3	15.632935286448904
    4	17.890405455566018
    5	16.278584911639022
    6	17.86796326138915
    7	17.998771003838634
    8	19.041405627580215
    9	18.651307986890707
   10	16.73214334214529
   11	18.139346210255667
   12	18.518398728905627
   13	17.985785063735786
   14	17.915470866227277
   15	19.886749022235097
   16	17.819775521474615
   17	18.399899054827262
   18	18.29599795393641
   19	18.50268319511842
   20	17.690746767943796
   21	18.390542387919172
   22	17.211871695516766
   23	18.905843729629897
   24	15.76610935576157
   25	16.072071684652713
   26	17.236328266459044
   27	14.772938998334178
   28	17.96401281958427
   29	16.961864905482727
   30	16.59761352936916
   31	16.452603299051205
   32	16.65044225755052
   33	15.514194566162091
   34	16.9913483649598
   35	17.34855281741146
   36	16.74412322553777
   37	17.462540513869776
   38	16.5800431846165
   39	16.83498859274281
   40	15.89796050

In [76]:
test_sequences = data['test']['sequences']
test_lengths = data['test']['sequence_lengths']

In [79]:
test_loss = elbo.loss(
    model_1,
    guide,
    test_sequences,
    test_lengths,
    hd = 16
)

In [82]:
num_observations = float(test_lengths.sum())

In [83]:
test_loss / num_observations

13.834368386243387