In [1]:
from functools import reduce
import logging
import matplotlib.pyplot as plt
import numpy as np
import torch

In [2]:
%cd ../..

/home/eli/AnacondaProjects/combinators


In [3]:
from combinators import lens, sampler, utils
from combinators.inference import conditioning, resample
from combinators.model import collections
import examples.hmm.hmm as hmm

In [4]:
state_params = {
    'mu': (torch.arange(5, dtype=torch.float) * 2).expand(2, 5).t(),
    'concentration': torch.ones(5, 2) * 2,
    'rate': torch.ones(5, 2) * 2,
}

In [5]:
prior = sampler.importance_box('hmm_parameters', hmm.Parameters(5, 2, state_params), (1,), hmm.ParametersProposal(), lens.PRO(0), lens.PRO(4))
hmm_step = sampler.importance_box('hmm_step', hmm.TransitionAndEmission(), (1,), hmm.TransitionProposal(), lens.PRO(4), lens.PRO(4))

In [6]:
chain = collections.sequential(hmm_step, 50)

In [7]:
hmm_model = prior >> chain
hmm_sampler = sampler.compile(hmm_model)

In [8]:
sampler.filter(hmm_sampler)
p, log_weight = sampler.trace(hmm_sampler)

In [9]:
zs = [p[k].value for k in p if 'z' in k]

In [10]:
zs

[tensor([4]),
 tensor([0]),
 tensor([0]),
 tensor([0]),
 tensor([1]),
 tensor([4]),
 tensor([4]),
 tensor([1]),
 tensor([4]),
 tensor([0]),
 tensor([0]),
 tensor([0]),
 tensor([4]),
 tensor([0]),
 tensor([1]),
 tensor([1]),
 tensor([4]),
 tensor([0]),
 tensor([3]),
 tensor([2]),
 tensor([1]),
 tensor([4]),
 tensor([0]),
 tensor([0]),
 tensor([0]),
 tensor([0]),
 tensor([1]),
 tensor([3]),
 tensor([1]),
 tensor([3]),
 tensor([4]),
 tensor([3]),
 tensor([2]),
 tensor([4]),
 tensor([0]),
 tensor([0]),
 tensor([1]),
 tensor([3]),
 tensor([3]),
 tensor([1]),
 tensor([3]),
 tensor([1]),
 tensor([3]),
 tensor([3]),
 tensor([1]),
 tensor([3]),
 tensor([1]),
 tensor([3]),
 tensor([2]),
 tensor([1]),
 tensor([4])]

In [11]:
xs = [p[k].value for k in p if 'x' in k]

In [12]:
xs

[tensor([[ 7.5328, 12.6918]]),
 tensor([[3.5403, 1.3866]]),
 tensor([[ 2.8297, -0.2054]]),
 tensor([[ 1.9631, -0.1380]]),
 tensor([[ 0.8396, -0.5853]]),
 tensor([[ 7.4074, 10.8746]]),
 tensor([[8.7990, 9.9349]]),
 tensor([[1.7775, 0.2576]]),
 tensor([[ 5.4967, 12.2110]]),
 tensor([[ 3.5204, -0.1720]]),
 tensor([[1.2444, 1.3263]]),
 tensor([[ 2.7883, -2.0661]]),
 tensor([[ 8.7756, 16.4658]]),
 tensor([[2.6608, 0.4371]]),
 tensor([[ 1.8057, -2.3360]]),
 tensor([[-1.4781,  0.4416]]),
 tensor([[ 4.9871, 12.6439]]),
 tensor([[-0.1186,  1.7229]]),
 tensor([[6.5152, 8.5877]]),
 tensor([[2.6754, 4.3759]]),
 tensor([[ 0.8380, -1.0770]]),
 tensor([[9.2146, 5.6433]]),
 tensor([[2.9043, 0.7807]]),
 tensor([[2.2589, 0.5164]]),
 tensor([[-0.3089,  0.5315]]),
 tensor([[2.9178, 0.0752]]),
 tensor([[1.0121, 1.6633]]),
 tensor([[8.0800, 8.2057]]),
 tensor([[2.8460, 0.2995]]),
 tensor([[7.2104, 7.6419]]),
 tensor([[ 6.9215, 15.2308]]),
 tensor([[ 8.7435, 10.2405]]),
 tensor([[0.8680, 2.3731]]),
 tensor([

In [13]:
num_particles = 250

In [14]:
inference_params = {
    'mu': (torch.arange(5, dtype=torch.float) * 2).expand(2, 5).t(),
    'concentration': torch.ones(5, 2),
    'rate': torch.ones(5, 2),
}

In [15]:
params = sampler.importance_box('hmm_parameters', hmm.Parameters(5, 2, inference_params), (num_particles,), hmm.ParametersProposal(), lens.PRO(0), lens.PRO(4))
hmm_step = sampler.importance_box('hmm_step', hmm.TransitionAndEmission(), (num_particles,), hmm.TransitionProposal(), lens.PRO(4), lens.PRO(4))

In [16]:
chain = collections.sequential(hmm_step, 50)

In [17]:
conditioned_chain = conditioning.SequentialConditioner(hmm_step=xs)(chain)

In [18]:
hmm_inference = params >> conditioned_chain
hmm_resampler = sampler.compile(hmm_inference)
resample.hook_resampling(hmm_resampler, method='get', resampler_cls=resample.SystematicResampler)

In [19]:
sampler.filter(hmm_resampler)
p, log_weight = sampler.trace(hmm_resampler)

In [20]:
log_weight

tensor([-740.4091, -740.4091, -740.4091, -740.4091, -740.4091, -740.4091,
        -740.4091, -740.4091, -740.4091, -740.4091, -740.4091, -740.4091,
        -740.4091, -740.4091, -740.4091, -740.4091, -740.4091, -740.4091,
        -740.4091, -740.4091, -740.4091, -740.4091, -740.4091, -740.4091,
        -740.4091, -740.4091, -740.4091, -740.4091, -740.4091, -740.4091,
        -740.4091, -740.4091, -740.4091, -740.4091, -740.4091, -740.4091,
        -740.4091, -740.4091, -740.4091, -740.4091, -740.4091, -740.4091,
        -740.4091, -740.4091, -740.4091, -740.4091, -740.4091, -740.4091,
        -740.4091, -740.4091, -740.4091, -740.4091, -740.4091, -740.4091,
        -740.4091, -740.4091, -740.4091, -740.4091, -740.4091, -740.4091,
        -740.4091, -740.4091, -740.4091, -740.4091, -740.4091, -740.4091,
        -740.4091, -740.4091, -740.4091, -740.4091, -740.4091, -740.4091,
        -740.4091, -740.4091, -740.4091, -740.4091, -740.4091, -740.4091,
        -740.4091, -740.4091, -740.409

In [21]:
inferred_zs = [p[k].value for k in p if 'z' in k]

In [22]:
accuracy = [(pr == tr).to(dtype=torch.float) for pr, tr in zip(inferred_zs, zs)]

In [23]:
for t in range(50):
    print('SMC percent accuracy at time %d: %f' % (t, accuracy[t].mean(dim=0) * 100))

SMC percent accuracy at time 0: 100.000000
SMC percent accuracy at time 1: 100.000000
SMC percent accuracy at time 2: 0.000000
SMC percent accuracy at time 3: 0.000000
SMC percent accuracy at time 4: 100.000000
SMC percent accuracy at time 5: 100.000000
SMC percent accuracy at time 6: 100.000000
SMC percent accuracy at time 7: 100.000000
SMC percent accuracy at time 8: 0.000000
SMC percent accuracy at time 9: 0.000000
SMC percent accuracy at time 10: 0.000000
SMC percent accuracy at time 11: 100.000000
SMC percent accuracy at time 12: 100.000000
SMC percent accuracy at time 13: 100.000000
SMC percent accuracy at time 14: 100.000000
SMC percent accuracy at time 15: 0.000000
SMC percent accuracy at time 16: 100.000000
SMC percent accuracy at time 17: 100.000000
SMC percent accuracy at time 18: 83.600006
SMC percent accuracy at time 19: 83.600006
SMC percent accuracy at time 20: 100.000000
SMC percent accuracy at time 21: 0.000000
SMC percent accuracy at time 22: 0.000000
SMC percent accu