In [1]:
from functools import reduce
import logging
import matplotlib.pyplot as plt
import numpy as np
import torch

In [2]:
# Run from the repository root so that `combinators` and `examples` resolve as
# packages in the imports below (the notebook sits two levels beneath it).
# Guarded so that re-running this cell does not climb two more directories.
import os

if not (os.path.isdir('combinators') and os.path.isdir('examples')):
    os.chdir(os.path.join('..', '..'))
os.getcwd()

/home/eli/AnacondaProjects/combinators


In [3]:
from combinators import lens, sampler, utils
from combinators.inference import conditioning, resample
from combinators.model import collections
import examples.hmm.hmm as hmm

In [4]:
# Ground-truth parameter tables for generating synthetic data. Every entry is a
# (5, 2) float tensor (presumably 5 hidden states x 2 emission dimensions,
# matching hmm.Parameters(5, 2, ...) below).
# 'mu': row i is 2*i in both columns. It is an expanded view (stride 0 along
# dim 1), so its columns share storage and must not be written in place.
# 'concentration' / 'rate': constant 2.
state_params = {
    'mu': torch.arange(0, 10, 2, dtype=torch.float).unsqueeze(1).expand(5, 2),
    'concentration': torch.full((5, 2), 2.0),
    'rate': torch.full((5, 2), 2.0),
}

In [5]:
# Generative-model boxes used to draw one synthetic trajectory.
# NOTE(review): argument roles are inferred from names only — presumably
# (name, model, batch shape (1,) => a single sample, proposal, input lens,
# output lens). Confirm against sampler.importance_box.
prior = sampler.importance_box(
    'hmm_parameters',
    hmm.Parameters(5, 2, state_params),
    (1,),
    hmm.ParametersProposal(),
    lens.PRO(0),
    lens.PRO(4),
)
hmm_step = sampler.importance_box(
    'hmm_step',
    hmm.TransitionAndEmission(),
    (1,),
    hmm.TransitionProposal(),
    lens.PRO(4),
    lens.PRO(4),
)

In [6]:
# Chain `hmm_step` into a 50-step sequence (presumably `hmm_step` composed with
# itself 50 times — confirm against collections.sequential).
# NOTE(review): 50 is hard-coded here and again in the inference chain; keep in sync.
chain = collections.sequential(hmm_step, 50)

In [7]:
# Compose the parameter prior with the step chain (`>>` presumably feeds the
# prior's outputs into the chain) and compile the result into a sampler.
hmm_model = prior >> chain
hmm_sampler = sampler.compile(hmm_model)

In [8]:
# Seed the RNGs so the synthetic trajectory (and therefore every downstream
# result) is reproducible under Restart & Run All.
# NOTE(review): assumes the sampler draws from torch/numpy global RNGs — confirm.
torch.manual_seed(0)
np.random.seed(0)
# Run the compiled generative sampler, then read back its trace `p` and log
# weight (presumably `filter` executes the model and `trace` retrieves the result).
sampler.filter(hmm_sampler)
p, log_weight = sampler.trace(hmm_sampler)

In [9]:
# Ground-truth hidden states: the value of every trace entry whose key contains 'z'.
# NOTE(review): substring matching on trace keys is fragile, since any other address
# containing 'z' is picked up too. The saved output holds 51 entries for a
# 50-step chain, so confirm which extra address matches.
# Must run before `p` is rebound to the inference trace further down.
zs = [p[k].value for k in p if 'z' in k]

In [10]:
# Display the sampled hidden-state trajectory (one (1,)-shaped index tensor per entry).
zs

[tensor([2]),
 tensor([1]),
 tensor([2]),
 tensor([1]),
 tensor([2]),
 tensor([4]),
 tensor([0]),
 tensor([4]),
 tensor([3]),
 tensor([3]),
 tensor([0]),
 tensor([1]),
 tensor([0]),
 tensor([1]),
 tensor([4]),
 tensor([3]),
 tensor([2]),
 tensor([4]),
 tensor([0]),
 tensor([3]),
 tensor([2]),
 tensor([1]),
 tensor([4]),
 tensor([0]),
 tensor([2]),
 tensor([1]),
 tensor([3]),
 tensor([0]),
 tensor([4]),
 tensor([3]),
 tensor([1]),
 tensor([4]),
 tensor([0]),
 tensor([1]),
 tensor([3]),
 tensor([3]),
 tensor([0]),
 tensor([2]),
 tensor([4]),
 tensor([2]),
 tensor([2]),
 tensor([2]),
 tensor([2]),
 tensor([4]),
 tensor([3]),
 tensor([3]),
 tensor([3]),
 tensor([3]),
 tensor([0]),
 tensor([1]),
 tensor([3])]

In [11]:
# Synthetic observations: the value of every trace entry whose key contains 'x'.
# These are later used to condition the inference chain.
# NOTE(review): substring matching on trace keys is fragile — confirm that only
# emission addresses contain 'x'.
# Must run before `p` is rebound to the inference trace further down.
xs = [p[k].value for k in p if 'x' in k]

In [12]:
# Display the synthetic observations (each a (1, 2) tensor in the saved output).
xs

[tensor([[3.5967, 1.4736]]),
 tensor([[3.5154, 4.2596]]),
 tensor([[1.9807, 1.9720]]),
 tensor([[1.1656, 1.4333]]),
 tensor([[2.1127, 4.1003]]),
 tensor([[11.1091,  5.5806]]),
 tensor([[1.1744, 0.8495]]),
 tensor([[8.5569, 7.9489]]),
 tensor([[4.5335, 7.5802]]),
 tensor([[1.8427, 7.0639]]),
 tensor([[0.0169, 0.9083]]),
 tensor([[0.9960, 2.1963]]),
 tensor([[1.5277, 0.6737]]),
 tensor([[3.1667, 1.5478]]),
 tensor([[7.2655, 7.6514]]),
 tensor([[-2.1437,  9.0428]]),
 tensor([[1.9585, 4.2542]]),
 tensor([[9.6790, 7.8078]]),
 tensor([[0.5741, 0.3282]]),
 tensor([[3.8368, 4.0328]]),
 tensor([[2.2000, 0.4654]]),
 tensor([[1.7091, 2.8359]]),
 tensor([[9.9724, 5.7619]]),
 tensor([[0.5057, 0.1197]]),
 tensor([[1.4625, 0.5163]]),
 tensor([[1.5817, 1.9312]]),
 tensor([[2.7219, 4.9598]]),
 tensor([[-1.1733,  1.5189]]),
 tensor([[8.7343, 5.9980]]),
 tensor([[1.8930, 7.4649]]),
 tensor([[2.6239, 2.3302]]),
 tensor([[6.9376, 7.2948]]),
 tensor([[0.9021, 0.8497]]),
 tensor([[-0.0568,  2.6382]]),
 tenso

In [13]:
# Number of SMC particles; used as the batch shape of the inference boxes below.
num_particles = 250

In [14]:
# Parameter tables for the inference model, using the same (5, 2) layout as the
# generative tables. 'mu' is again 2*i in row i, as an expanded (stride-0)
# view. 'concentration' and 'rate' are 1 here rather than 2.
# NOTE(review): presumably a deliberately different prior for inference — confirm.
inference_params = {
    'mu': torch.arange(0, 10, 2, dtype=torch.float).unsqueeze(1).expand(5, 2),
    'concentration': torch.full((5, 2), 1.0),
    'rate': torch.full((5, 2), 1.0),
}

In [15]:
# Inference-model boxes: same structure as the generative ones, but batched
# over `num_particles` and built from `inference_params`.
# NOTE(review): this rebinds `hmm_step`, shadowing the generative step. Re-running
# the earlier generative `chain` cell after this one would silently use this step.
params = sampler.importance_box(
    'hmm_parameters',
    hmm.Parameters(5, 2, inference_params),
    (num_particles,),
    hmm.ParametersProposal(),
    lens.PRO(0),
    lens.PRO(4),
)
hmm_step = sampler.importance_box(
    'hmm_step',
    hmm.TransitionAndEmission(),
    (num_particles,),
    hmm.TransitionProposal(),
    lens.PRO(4),
    lens.PRO(4),
)

In [16]:
# Chain the particle-batched `hmm_step` for 50 steps (presumably matching the
# length of the generative chain; the 50 is hard-coded in both places).
chain = collections.sequential(hmm_step, 50)

In [17]:
# Condition the chain on the synthetic observations `xs`.
# NOTE(review): presumably the `hmm_step` keyword selects the boxes to condition,
# with `xs` consumed one entry per step. Confirm against SequentialConditioner.
conditioned_chain = conditioning.SequentialConditioner(hmm_step=xs)(chain)

In [18]:
# Compose the inference prior with the conditioned chain and compile it.
hmm_inference = params >> conditioned_chain
hmm_resampler = sampler.compile(hmm_inference)
# Attach multinomial resampling to the compiled sampler.
# NOTE(review): the meaning of method='get' is not visible here (presumably the
# hook point at which resampling runs). Confirm against resample.hook_resampling.
resample.hook_resampling(hmm_resampler, method='get', resampler_cls=resample.MultinomialResampler)

In [19]:
# Seed the RNGs so this cell's particle draws are reproducible on their own.
# NOTE(review): assumes the sampler draws from torch/numpy global RNGs — confirm.
torch.manual_seed(1)
np.random.seed(1)
# Run SMC over the conditioned model and read back the particle trace and weights.
sampler.filter(hmm_resampler)
# This rebinds `p` and `log_weight`, replacing the data-generation trace. The cells
# that extract `zs`/`xs` from `p` must not be re-run after this one.
p, log_weight = sampler.trace(hmm_resampler)

In [20]:
# Display per-particle log weights. In the saved output every entry is identical
# (-1177.9653), presumably because the final resampling step equalizes the weights. Confirm.
log_weight

tensor([-1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653,
        -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653,
        -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653,
        -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653,
        -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653,
        -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653,
        -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653,
        -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653,
        -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653,
        -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653,
        -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653,
        -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653, -1177.9653,
        -1177.9653, -1177.9653, -1177.96

In [21]:
# Inferred hidden states: the value of every entry in the SMC trace whose key
# contains 'z'. Uses the same substring filter as the ground-truth `zs`, so the
# two lists line up entry-for-entry.
inferred_zs = [p[k].value for k in p if 'z' in k]

In [22]:
# Per-step float indicator of whether each particle's inferred state equals the
# ground-truth state. `tr` is (1,)-shaped (see the `zs` output), so the comparison
# broadcasts over the particle dimension of `pr` (presumably (num_particles,)).
# zip() would silently truncate if the two traces differed in length, hiding a
# misalignment, so check the lengths explicitly first.
assert len(inferred_zs) == len(zs), (len(inferred_zs), len(zs))
accuracy = [(pr == tr).to(dtype=torch.float) for pr, tr in zip(inferred_zs, zs)]

In [23]:
# Report, for every step, the percentage of particles whose inferred hidden state
# equals the ground truth. Iterating over `accuracy` itself (instead of a
# hard-coded 50) neither drops trailing entries nor raises IndexError when the
# chain length changes.
# NOTE(review): the saved output is always exactly 0% or 100%, i.e. all particles
# agree at every step. Presumably the particle set has collapsed to a single
# trajectory (compare the identical log weights); worth investigating.
for t, step_accuracy in enumerate(accuracy):
    print('SMC percent accuracy at time %d: %f' % (t, step_accuracy.mean(dim=0).item() * 100))

SMC percent accuracy at time 0: 0.000000
SMC percent accuracy at time 1: 100.000000
SMC percent accuracy at time 2: 0.000000
SMC percent accuracy at time 3: 0.000000
SMC percent accuracy at time 4: 0.000000
SMC percent accuracy at time 5: 0.000000
SMC percent accuracy at time 6: 100.000000
SMC percent accuracy at time 7: 0.000000
SMC percent accuracy at time 8: 100.000000
SMC percent accuracy at time 9: 100.000000
SMC percent accuracy at time 10: 100.000000
SMC percent accuracy at time 11: 100.000000
SMC percent accuracy at time 12: 100.000000
SMC percent accuracy at time 13: 100.000000
SMC percent accuracy at time 14: 0.000000
SMC percent accuracy at time 15: 100.000000
SMC percent accuracy at time 16: 100.000000
SMC percent accuracy at time 17: 0.000000
SMC percent accuracy at time 18: 100.000000
SMC percent accuracy at time 19: 0.000000
SMC percent accuracy at time 20: 0.000000
SMC percent accuracy at time 21: 100.000000
SMC percent accuracy at time 22: 0.000000
SMC percent accuracy