In [1]:
import argparse
import logging
import sys

import torch
import torch.nn as nn
from torch.distributions import constraints

import pyro
import pyro.contrib.examples.polyphonic_data_loader as poly
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex
from pyro.optim import Adam
from pyro.util import ignore_jit_warnings

In [2]:
log = logging.getLogger()
debug_handler = logging.StreamHandler(sys.stdout)
debug_handler.setLevel(logging.DEBUG)
debug_handler.addFilter(filter=lambda record: record.levelno <= logging.DEBUG)
log.addHandler(debug_handler)

#### Next let's make our simple model faster in two ways: first we'll support vectorized minibatches of data, and second we'll support the PyTorch jit compiler.  To add batch support, we'll introduce a second plate "sequences" and randomly subsample data to size batch_size.  To add jit support we silence some warnings and try to avoid dynamic program structure.

#### Note that this is the "HMM" model in reference [1] (with the difference that in [1] the probabilities probs_x and probs_y are not MAP-regularized with Dirichlet and Beta distributions for any of the models)

In [3]:
globals()

{'__name__': '__main__',
 '__doc__': 'Automatically created module for IPython interactive environment',
 '__package__': None,
 '__loader__': None,
 '__spec__': None,
 '__builtin__': <module 'builtins' (built-in)>,
 '__builtins__': <module 'builtins' (built-in)>,
 '_ih': ['',
  'log = logging.getLogger()\ndebug_handler = logging.StreamHandler(sys.stdout)\ndebug_handler.setLevel(logging.DEBUG)\ndebug_handler.addFilter(filter=lambda record: record.levelno <= logging.DEBUG)\nlog.addHandler(debug_handler)',
  'globals()'],
 '_oh': {},
 '_dh': [PosixPath('/home/dulunche/GP_VAE/drclab/HMM')],
 'In': ['',
  'log = logging.getLogger()\ndebug_handler = logging.StreamHandler(sys.stdout)\ndebug_handler.setLevel(logging.DEBUG)\ndebug_handler.addFilter(filter=lambda record: record.levelno <= logging.DEBUG)\nlog.addHandler(debug_handler)',
  'globals()'],
 'Out': {},
 'get_ipython': <bound method InteractiveShell.get_ipython of <ipykernel.zmqshell.ZMQInteractiveShell object at 0x7f8263010e20>>,
 'ex

In [4]:
models = {
    name[len("model_") :]: model
    for name, model in globals().items()
    if name.startswith("model_")
}

In [5]:
    logging.info("-" * 40)
    model = models[args.model]
    logging.info(
        "Training {} on {} sequences".format(
            model.__name__, len(data["train"]["sequences"])
        )
    )

NameError: name 'args' is not defined

In [19]:
data = poly.load_data(poly.JSB_CHORALES)

In [20]:
seqs = data['train']['sequences']
lengths = data['train']['sequence_lengths']

In [21]:
seqs.shape

torch.Size([229, 129, 88])

In [9]:
present_notes = (seqs == 1).sum(0).sum(0) > 0

In [10]:
seqs = seqs[..., present_notes].shape

In [26]:
def model_1(seqs, lengths, hd=16, batch_size = None, include_prior=True):
    with ignore_jit_warnings():
        num_seqs, max_len, data_dim = map(int, seqs.shape)
        assert lengths.shape == (num_seqs,)
        assert lengths.max() == max_len
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            'probs_x',
            dist.Dirichlet(0.9 * torch.eye(hd) + 0.1).to_event(1),
        )

        probs_y = pyro.sample(
            'probs_y',
            dist.Beta(0.1, 0.9).expand([hd, data_dim]).to_event(2)
        )

    tones_plate = pyro.plate('tones', data_dim, dim=-1)
    # We subsample batch_size items out of num_sequences items. Note that since
    # we're using dim=-1 for the notes plate, we need to batch over a different
    # dimension, here dim=-2.
    with pyro.plate('seqs', num_seqs, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        # If we are not using the jit, then we can vary the program structure
        # each call by running for a dynamically determined number of time
        # steps, lengths.max(). However if we are using the jit, then we try to
        # keep a single program structure for all minibatches; the fixed
        # structure ends up being faster since each program structure would
        # need to trigger a new jit compile stage.    
        for t in pyro.markov(range(lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample(
                    'x_{}'.format(t),
                    dist.Categorical(probs_x[x]),
                    infer={'enumerate': 'parallel'}
                )
                with tones_plate:
                    pyro.sample(
                        'y_{}'.format(t),
                        dist.Bernoulli(probs_y[x.squeeze(-1)]),
                        obs = seqs[batch, t]
                    )

        

In [27]:
guide = AutoDelta(
    poutine.block(model_1, expose_fn=lambda msg: msg['name'].startswith('probs_'))
)

In [28]:
first_available_dim = -3

In [36]:
guide_trace = poutine.trace(guide).get_trace(
    seqs, lengths, hd=16, batch_size=10
)

In [37]:
model_trace = poutine.trace(
    poutine.replay(poutine.enum(model_1, first_available_dim), guide_trace)
    ).get_trace(
        seqs, lengths, hd = 16, batch_size = 10
    )

In [38]:
print(model_trace.format_shapes())

Trace Shapes:                    
 Param Sites:                    
Sample Sites:                    
 probs_x dist             | 16 16
        value             | 16 16
 probs_y dist             | 16 88
        value             | 16 88
   tones dist             |      
        value          88 |      
    seqs dist             |      
        value          10 |      
     x_0 dist       10  1 |      
        value    16  1  1 |      
     y_0 dist    16 10 88 |      
        value       10 88 |      
     x_1 dist    16 10  1 |      
        value 16  1  1  1 |      
     y_1 dist 16  1 10 88 |      
        value       10 88 |      
     x_2 dist 16  1 10  1 |      
        value    16  1  1 |      
     y_2 dist    16 10 88 |      
        value       10 88 |      
     x_3 dist    16 10  1 |      
        value 16  1  1  1 |      
     y_3 dist 16  1 10 88 |      
        value       10 88 |      
     x_4 dist 16  1 10  1 |      
        value    16  1  1 |      
     y_4 dist 

In [39]:
# Notice that we're now using dim=-2 as a batch dimension (of size 10),
# and that the enumeration dimensions are now dims -3 and -4.

In [46]:
optim = Adam({'lr':0.01})
Elbo = TraceEnum_ELBO
elbo = Elbo(
    max_plate_nesting = 2,
    strict_enumeration_warning=True,
    jit_options={'time_compilation': 'store_true'}
)

In [47]:
svi = SVI(model_1, guide, optim, elbo)

In [48]:
num_steps = 100
pyro.set_rng_seed(111)
pyro.clear_param_store()
num_observations = float(lengths.sum())
for step in range(num_steps):
    loss = svi.step(seqs, lengths, hd=16, batch_size=10)
    print('{:5d}\t{}'.format(step, loss / num_observations))

    0	18.4637400865503
    1	17.390832096038242
    2	25.213609944231187
    3	19.927412272760193
    4	21.255819059172882
    5	21.20640119142464
    6	22.527997573694503
    7	23.41423734337655
    8	19.075568552183675
    9	21.856055805026436
   10	21.62752589266314
   11	22.25317547258637
   12	23.159353045556603
   13	22.561205366842906
   14	19.430997410733685
   15	21.856992829724053
   16	21.886047385384224
   17	23.827109889910915
   18	23.37222513942203
   19	20.69009379300355
   20	19.842464420221628
   21	21.500803487361484
   22	20.57880060838705
   23	22.649360378793364
   24	19.153334812051856
   25	20.96150503367857
   26	23.913721300789454
   27	19.218283750995873
   28	21.88987017454914
   29	19.513831299340914
   30	18.63533669515463
   31	19.696211613674222
   32	21.65799503874846
   33	19.707084721518072
   34	20.851984047946694
   35	22.923166238140073
   36	21.598462283624247
   37	21.916543691605707
   38	22.998823060766277
   39	20.331457775041645
   40	19.7984