In [1]:
import argparse
import logging
import sys

import torch
import torch.nn as nn
from torch.distributions import constraints

import pyro
import pyro.contrib.examples.polyphonic_data_loader as poly
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO, TraceTMC_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex
from pyro.optim import Adam
from pyro.util import ignore_jit_warnings

In [2]:
# Attach a handler to the root logger that writes to stdout and only lets
# through records at DEBUG level or below.
log = logging.getLogger()
# Notebook cells get re-run; without this cleanup every re-run would stack one
# more handler on the root logger and each debug record would be printed once
# per handler.
for _stale_handler in [h for h in log.handlers if h.get_name() == "debug_stdout"]:
    log.removeHandler(_stale_handler)
debug_handler = logging.StreamHandler(sys.stdout)
debug_handler.set_name("debug_stdout")
debug_handler.setLevel(logging.DEBUG)
debug_handler.addFilter(filter=lambda record: record.levelno <= logging.DEBUG)
# NOTE(review): the root logger's level is not changed here; if it is still at
# its default, DEBUG records are dropped before reaching this handler - confirm
# whether a `log.setLevel(logging.DEBUG)` is intended.
log.addHandler(debug_handler)

#### Next let's make our simple model faster in two ways: first we'll support vectorized minibatches of data, and second we'll support the PyTorch jit compiler.  To add batch support, we'll introduce a second plate "sequences" and randomly subsample data to size batch_size.  To add jit support we silence some warnings and try to avoid dynamic program structure.

#### Note that this is the "HMM" model in reference [1] (with the difference that in [1] the probabilities `probs_x` and `probs_y` are not MAP-regularized with Dirichlet and Beta distributions for any of the models).

In [3]:
# Debugging aid: display every name currently bound in the notebook namespace.
# NOTE(review): the rendered output includes interpreter internals and a local
# filesystem path; consider clearing this cell before sharing the notebook.
globals()

{'__name__': '__main__',
 '__doc__': 'Automatically created module for IPython interactive environment',
 '__package__': None,
 '__loader__': None,
 '__spec__': None,
 '__builtin__': <module 'builtins' (built-in)>,
 '__builtins__': <module 'builtins' (built-in)>,
 '_ih': ['',
  'log = logging.getLogger()\ndebug_handler = logging.StreamHandler(sys.stdout)\ndebug_handler.setLevel(logging.DEBUG)\ndebug_handler.addFilter(filter=lambda record: record.levelno <= logging.DEBUG)\nlog.addHandler(debug_handler)',
  'globals()'],
 '_oh': {},
 '_dh': [PosixPath('/home/dulunche/GP_VAE/drclab/HMM')],
 'In': ['',
  'log = logging.getLogger()\ndebug_handler = logging.StreamHandler(sys.stdout)\ndebug_handler.setLevel(logging.DEBUG)\ndebug_handler.addFilter(filter=lambda record: record.levelno <= logging.DEBUG)\nlog.addHandler(debug_handler)',
  'globals()'],
 'Out': {},
 'get_ipython': <bound method InteractiveShell.get_ipython of <ipykernel.zmqshell.ZMQInteractiveShell object at 0x7fa7ec8c8eb0>>,
 'ex

In [4]:
# Registry of model callables keyed by the part of their name that follows
# "model_" (so a global called `model_1` is stored under "1").
# NOTE: this cell sits above the definition of `model_1`, so on a
# top-to-bottom run no matching names exist yet and the registry is empty.
models = dict(
    (global_name[len("model_") :], candidate)
    for global_name, candidate in list(globals().items())
    if global_name.startswith("model_")
)

In [88]:
    # logging.info("-" * 40)
    # model = models[args.model]
    # logging.info(
    #     "Training {} on {} sequences".format(
    #         model.__name__, len(data["train"]["sequences"])
    #     )
    # )

In [5]:
# Load the JSB chorales dataset. Later cells index the result first by split
# ('train' / 'test') and then by 'sequences' / 'sequence_lengths'.
data = poly.load_data(poly.JSB_CHORALES)

In [6]:
# Training split: the note tensor and the per-sequence lengths, in that order.
seqs, lengths = (
    data['train'][field] for field in ('sequences', 'sequence_lengths')
)

In [7]:
# model_1 unpacks this shape as (num_seqs, max_len, data_dim).
seqs.shape

torch.Size([229, 129, 88])

In [8]:
# Boolean mask over the last (note) axis: True where a note equals 1 at least
# once anywhere in the training split.
# It is computed from the raw training tensor rather than from `seqs`, because
# a later cell rebinds `seqs` to its filtered version; re-running this cell
# after that would otherwise produce a shorter mask that no longer fits the
# unfiltered test sequences it is also applied to.
present_notes = (data['train']['sequences'] == 1).sum(0).sum(0) > 0

In [9]:
# Display the mask of notes that occur at least once in the training data.
present_notes

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False,  True, False, False,  True, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True, False, False, False, False,
        False, False, False, False, False, False, False, False])

In [10]:
# Keep only the notes flagged in `present_notes` along the last axis.
# The raw training tensor is indexed (not `seqs` itself) so this cell is
# idempotent: filtering the already-filtered tensor a second time would apply
# a mask whose length no longer matches the last dimension.
seqs = data['train']['sequences'][..., present_notes]

In [11]:
# model_1 asserts that this equals (num_seqs,), i.e. one length per sequence.
lengths.shape

torch.Size([229])

In [12]:
def model_1(seqs, lengths, hd=16, batch_size=None, include_prior=True):
    """Hidden Markov model over binary note vectors with a discrete hidden state.

    Global transition probabilities ``probs_x`` (Dirichlet prior) and emission
    probabilities ``probs_y`` (Beta prior) are sampled once. For every time
    step a hidden state ``x_t`` is sampled with parallel enumeration and the
    notes at that step are observed through a Bernoulli likelihood.

    Args:
        seqs: tensor whose shape is unpacked as ``(num_seqs, max_len, data_dim)``;
            ``seqs[batch, t]`` is used as the observation at step ``t``.
        lengths: tensor of shape ``(num_seqs,)``; for each sequence, steps with
            ``t >= lengths`` are masked out.
        hd: number of hidden states.
        batch_size: subsample size handed to the ``'seqs'`` plate.
        include_prior: mask applied to the ``probs_x`` / ``probs_y`` sample
            sites.

    Raises:
        AssertionError: if ``lengths`` does not have shape ``(num_seqs,)`` or
            contains a length greater than ``max_len``.
    """
    with ignore_jit_warnings():
        num_seqs, max_len, data_dim = map(int, seqs.shape)
        assert lengths.shape == (num_seqs,)
        # Only require that every sequence fits into the tensor. Padding past
        # the longest sequence is fine because the time loop below never goes
        # beyond the longest length in the minibatch.
        assert lengths.max() <= max_len
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample(
            'probs_x',
            dist.Dirichlet(0.9 * torch.eye(hd) + 0.1).to_event(1),
        )

        probs_y = pyro.sample(
            'probs_y',
            dist.Beta(0.1, 0.9).expand([hd, data_dim]).to_event(2)
        )

    tones_plate = pyro.plate('tones', data_dim, dim=-1)
    # We subsample batch_size items out of num_seqs items. Since the tones
    # plate uses dim=-1, sequences are batched over a different dimension,
    # here dim=-2.
    with pyro.plate('seqs', num_seqs, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        # The number of time steps is data dependent: the longest sequence in
        # the current minibatch.
        # NOTE(review): this means the program structure can differ between
        # minibatches; if a Jit* ELBO is ever used, iterating over max_len
        # instead would keep a single structure - confirm before switching.
        num_steps = int(lengths.max())
        for t in pyro.markov(range(num_steps)):
            # Steps at or past a sequence's own length do not contribute.
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample(
                    'x_{}'.format(t),
                    dist.Categorical(probs_x[x]),
                    infer={'enumerate': 'parallel'}
                )
                with tones_plate:
                    pyro.sample(
                        'y_{}'.format(t),
                        dist.Bernoulli(probs_y[x.squeeze(-1)]),
                        obs=seqs[batch, t]
                    )

        

In [13]:
# AutoDelta guide built over model_1 with every sample site blocked except
# those whose name starts with 'probs_' (i.e. probs_x and probs_y).
guide = AutoDelta(
    poutine.block(
        model_1,
        expose_fn=lambda site: site['name'].startswith('probs_'),
    )
)

In [14]:
# model_1 declares its plates at dim=-1 ('tones') and dim=-2 ('seqs'), so
# enumeration dimensions start at -3.
first_available_dim = -3

In [18]:
# Run the guide once with batch_size=10 and record the resulting trace.
guide_trace = poutine.trace(guide).get_trace(seqs, lengths, hd=16, batch_size=10)

In [19]:
# Trace model_1 with enumeration enabled, replaying the sites recorded in
# guide_trace, using the same arguments as the guide run.
model_trace = poutine.trace(
    poutine.replay(
        poutine.enum(model_1, first_available_dim=first_available_dim),
        trace=guide_trace,
    )
).get_trace(seqs, lengths, hd=16, batch_size=10)

In [20]:
# Print the distribution and value shapes recorded at every site of the trace.
print(model_trace.format_shapes())

Trace Shapes:                    
 Param Sites:                    
Sample Sites:                    
 probs_x dist             | 16 16
        value             | 16 16
 probs_y dist             | 16 51
        value             | 16 51
   tones dist             |      
        value          51 |      
    seqs dist             |      
        value          10 |      
     x_0 dist       10  1 |      
        value    16  1  1 |      
     y_0 dist    16 10 51 |      
        value       10 51 |      
     x_1 dist    16 10  1 |      
        value 16  1  1  1 |      
     y_1 dist 16  1 10 51 |      
        value       10 51 |      
     x_2 dist 16  1 10  1 |      
        value    16  1  1 |      
     y_2 dist    16 10 51 |      
        value       10 51 |      
     x_3 dist    16 10  1 |      
        value 16  1  1  1 |      
     y_3 dist 16  1 10 51 |      
        value       10 51 |      
     x_4 dist 16  1 10  1 |      
        value    16  1  1 |      
     y_4 dist 

In [155]:
# Notice that we're now using dim=-2 as a batch dimension (of size 10),
# and that the enumeration dimensions are now dims -3 and -4.

In [15]:
# Optimizer and loss used for SVI.
optim = Adam({'lr': 0.01})
Elbo = TraceEnum_ELBO
elbo = Elbo(
    # model_1 nests two plates ('tones' at dim=-1, 'seqs' at dim=-2).
    max_plate_nesting=2,
    strict_enumeration_warning=True,
    # This was the string 'store_true', which is the name of an argparse
    # action rather than a flag value; a real boolean keeps the same
    # truthiness and states the intent.
    jit_options={'time_compilation': True},
)

In [16]:
# Bundle model, guide, optimizer and loss into a single inference object.
svi = SVI(model=model_1, guide=guide, optim=optim, loss=elbo)

In [17]:
# Train for a fixed number of SVI steps on minibatches of 20 sequences and
# print the loss of each step divided by the summed training lengths.
num_steps = 100
num_observations = float(lengths.sum())

# Seed and reset parameters right before the loop so reruns start identically.
pyro.set_rng_seed(111)
pyro.clear_param_store()

for step in range(num_steps):
    loss = svi.step(seqs, lengths, hd=16, batch_size=20)
    print('{:5d}\t{}'.format(step, loss / num_observations))

    0	23.721404903309914
    1	21.806744314478163
    2	21.520815981024118
    3	22.301357554139205
    4	23.2621835844137
    5	24.690358604331138
    6	22.676959603824148
    7	21.166122709495184
    8	22.2997551061056
    9	22.612531234156588
   10	20.290681357282537
   11	22.70064776924748
   12	23.689865195190844
   13	20.331367241254437
   14	21.966402911566597
   15	19.65037435721011
   16	20.775209133048453
   17	23.21005649308322
   18	21.7058444086333
   19	22.641825704352865
   20	22.27959549503875
   21	21.486175490693128
   22	20.686854946766132
   23	22.14786430795973
   24	22.97311599188817
   25	20.880153635836894
   26	21.303213496776998
   27	20.999744242051133
   28	20.014709477076845
   29	22.11995726805244
   30	21.833583055696387
   31	20.167274751937423
   32	20.370391830231043
   33	21.05716982327805
   34	19.201077804736727
   35	23.498125950604766
   36	21.822762004780184
   37	20.17323413848048
   38	20.117970051423193
   39	20.47915459549504
   40	21.4792021

In [21]:
# Test split: the note tensor and the per-sequence lengths, in that order.
test_sequences, test_lengths = (
    data['test'][field] for field in ('sequences', 'sequence_lengths')
)

In [23]:
# Evaluate the loss on the whole test split (no batch_size is passed), after
# restricting the test notes with the mask derived from the training data.
# include_prior is left at model_1's default (True), so the probs_x / probs_y
# prior sites are not masked out of this loss.
test_loss = elbo.loss(
    model_1, guide, test_sequences[..., present_notes], test_lengths, hd=16
)

In [24]:
# Sum of all test sequence lengths, used to normalize the test loss.
# NOTE: this rebinds the name the training cell uses for the training total.
num_observations = float(test_lengths.sum())

In [25]:
# Test loss divided by the summed test sequence lengths.
test_loss / num_observations

18.32296130952381