In [1]:
from functools import reduce
import logging
import os

from discopy import PRO, Ty
import matplotlib.pyplot as plt
import numpy as np
import torch

In [2]:
# Notebook configuration.
#
# The `combinators` and `examples` packages imported below are resolved
# relative to the working directory, so the kernel must sit at the repository
# root. The working directory persists across cells and re-runs. Only step up
# two levels when we are not already at the root; this makes re-running the
# cell harmless instead of climbing further up the tree each time.
if not (os.path.isdir('combinators') and os.path.isdir('examples')):
    os.chdir(os.path.join('..', '..'))

# Everything below is sampled (ground-truth trajectory, SMC particles).
# Seed every RNG in use so Restart & Run All reproduces the same outputs.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

/home/eli/AnacondaProjects/combinators


In [3]:
from combinators import lens, sampler, utils
from combinators.inference import conditioning, resample
from combinators.model import collections
import examples.hmm.hmm as hmm

In [4]:
# Hyperparameters of the generative (ground-truth) model, each a 5 x 2 tensor.
# NOTE(review): rows presumably index the 5 states and columns the 2 emission
# dimensions, matching hmm.Parameters(5, 2, ...) below; confirm.
# `mu` repeats the state means 0, 2, 4, 6, 8 across both columns as an
# expanded (non-copied) view.
state_params = dict(
    mu=torch.arange(0, 10, 2, dtype=torch.float).unsqueeze(1).expand(5, 2),
    concentration=torch.full((5, 2), 2.0),
    rate=torch.full((5, 2), 2.0),
)

In [5]:
# Generative model boxes, each drawing a single sample (sample shape (1,)):
#   prior    -- 'hmm_parameters', built from the ground-truth state_params
#   hmm_step -- one 'hmm_step' transition/emission box
# Both share the same wire type on their interfaces.
wire_type = PRO(4) & Ty()
prior = sampler.importance_box(
    'hmm_parameters',
    hmm.Parameters(5, 2, state_params),
    None,
    (1,),
    lens.LensPRO(0),
    wire_type,
)
hmm_step = sampler.importance_box(
    'hmm_step',
    hmm.TransitionAndEmission(),
    None,
    (1,),
    wire_type,
    wire_type,
)

In [6]:
# Sequentially compose the single-step box into the generative chain
# (presumably one copy of hmm_step per time step).
num_steps = 50
chain = collections.sequential(hmm_step, num_steps)

In [7]:
# Feed the parameter prior into the step chain, then compile the composed
# model into the sampler that `hmm_model.trace` consumes in the next cell.
hmm_model = prior >> chain
hmm_sampler = hmm_model.compile()

In [8]:
# Run the compiled generative sampler once and keep only its trace (the second
# element of the returned pair). Note: binding `_` here shadows IPython's
# last-output variable for the rest of the session.
_, trace = hmm_model.trace(hmm_sampler)

In [9]:
# Fold the trace into a pair and keep the second element `p`. The cells below
# index it by site name and read `.value` from each entry. The first element
# is discarded here; in the inference cells it is used as the log weight.
_, p = trace.fold()

In [10]:
# Ground-truth state values: every site whose name contains the letter 'z',
# in the iteration order of `p`.
# NOTE(review): substring matching on site names is fragile -- any other site
# name containing 'z' would be collected too; confirm the naming scheme.
zs = [
    p[site_name].value
    for site_name in p
    if 'z' in site_name
]

In [11]:
# Display the sampled ground-truth state values (the targets for the accuracy
# check at the end of the notebook).
zs

[tensor([1]),
 tensor([2]),
 tensor([2]),
 tensor([1]),
 tensor([2]),
 tensor([2]),
 tensor([3]),
 tensor([4]),
 tensor([4]),
 tensor([2]),
 tensor([4]),
 tensor([2]),
 tensor([0]),
 tensor([2]),
 tensor([4]),
 tensor([2]),
 tensor([3]),
 tensor([4]),
 tensor([3]),
 tensor([4]),
 tensor([3]),
 tensor([1]),
 tensor([1]),
 tensor([2]),
 tensor([3]),
 tensor([1]),
 tensor([1]),
 tensor([2]),
 tensor([2]),
 tensor([4]),
 tensor([4]),
 tensor([2]),
 tensor([4]),
 tensor([2]),
 tensor([3]),
 tensor([1]),
 tensor([2]),
 tensor([0]),
 tensor([4]),
 tensor([3]),
 tensor([4]),
 tensor([2]),
 tensor([3]),
 tensor([0]),
 tensor([2]),
 tensor([2]),
 tensor([3]),
 tensor([0]),
 tensor([2]),
 tensor([3]),
 tensor([4])]

In [12]:
# Sampled observations: every site whose name contains the letter 'x', in the
# iteration order of `p`. These are what the inference chain is conditioned on.
# NOTE(review): substring matching on site names is fragile; confirm that no
# other site name contains 'x'.
xs = [
    p[site_name].value
    for site_name in p
    if 'x' in site_name
]

In [13]:
# Display the sampled observations that SequentialConditioner feeds to the
# inference chain below.
xs

[tensor([[1.4746, 3.3599]]),
 tensor([[3.9921, 4.0480]]),
 tensor([[4.9024, 3.5653]]),
 tensor([[ 4.6347, -2.8364]]),
 tensor([[4.4370, 3.4809]]),
 tensor([[4.3954, 5.1040]]),
 tensor([[ 5.7672, 10.4145]]),
 tensor([[9.0420, 5.3427]]),
 tensor([[8.6188, 3.6343]]),
 tensor([[4.8799, 5.2575]]),
 tensor([[8.5914, 5.2310]]),
 tensor([[5.1362, 3.1733]]),
 tensor([[-4.0906,  3.3254]]),
 tensor([[5.0562, 5.8982]]),
 tensor([[8.5278, 2.1567]]),
 tensor([[5.0269, 5.1158]]),
 tensor([[ 6.2899, 12.1606]]),
 tensor([[8.8641, 5.2949]]),
 tensor([[5.8415, 8.2887]]),
 tensor([[8.9697, 4.4221]]),
 tensor([[ 6.1824, 10.7855]]),
 tensor([[ 4.1662, -4.9828]]),
 tensor([[ 1.0008, -3.0696]]),
 tensor([[6.1971, 1.5096]]),
 tensor([[6.7948, 5.6344]]),
 tensor([[2.0574, 0.6066]]),
 tensor([[1.7141, 2.6838]]),
 tensor([[4.1885, 5.2205]]),
 tensor([[5.1534, 3.6736]]),
 tensor([[9.5779, 4.9779]]),
 tensor([[9.4242, 5.7606]]),
 tensor([[4.4452, 6.4319]]),
 tensor([[8.1938, 6.9407]]),
 tensor([[4.5168, 3.9082]]),


In [14]:
# Number of particles: the sample shape `(num_particles,)` passed to both
# inference boxes below.
num_particles = 250

In [15]:
# Hyperparameters used for inference. The means match the generative model,
# but concentration and rate are 1 instead of 2.
# `mu` repeats the state means 0, 2, 4, 6, 8 across both columns as an
# expanded (non-copied) view.
inference_params = dict(
    mu=torch.arange(0, 10, 2, dtype=torch.float).unsqueeze(1).expand(5, 2),
    concentration=torch.full((5, 2), 1.0),
    rate=torch.full((5, 2), 1.0),
)

In [16]:
# Inference boxes, mirroring the generative ones but drawing num_particles
# samples each and using the inference hyperparameters. Note that this
# rebinds `hmm_step`, replacing the single-sample generative box.
wire_type = PRO(4) & Ty()
params = sampler.importance_box(
    'hmm_parameters',
    hmm.Parameters(5, 2, inference_params),
    None,
    (num_particles,),
    lens.LensPRO(0),
    wire_type,
)
hmm_step = sampler.importance_box(
    'hmm_step',
    hmm.TransitionAndEmission(),
    None,
    (num_particles,),
    wire_type,
    wire_type,
)

In [17]:
# Rebuild the chain from the particle-batched hmm_step. This rebinds `chain`;
# the length matches the generative chain.
num_steps = 50
chain = collections.sequential(hmm_step, num_steps)

In [18]:
# Condition the boxes named 'hmm_step' on the sampled observations `xs`.
# NOTE(review): presumably the i-th step is paired with xs[i]; confirm in
# SequentialConditioner.
conditioner = conditioning.SequentialConditioner(hmm_step=xs)
conditioned_chain = conditioner(chain)

In [19]:
# Full inference model: the parameter box feeds the observation-conditioned
# chain. The compiled sampler is wrapped by `resample.resampler`, so particles
# are resampled during the run (reported below as SMC).
hmm_inference = params >> conditioned_chain
hmm_resampler = resample.resampler(hmm_inference.compile())

In [20]:
# Run the resampling sampler over the conditioned model and record its trace.
# `result` (the first returned element) is not used by any visible cell.
result, trace = hmm_inference.trace(hmm_resampler)

In [21]:
# Fold the inference trace into log weights and the site mapping `p`.
# This rebinds `p`, so the site reads below see the inferred values rather than
# the generative ones.
log_weight, p = trace.fold()

In [22]:
# Display the final log weights (one entry per particle in the output below).
# NOTE(review): all entries are identical in the recorded output, presumably
# because resampling equalizes the weights; confirm with the resampler.
log_weight

tensor([-821.2757, -821.2757, -821.2757, -821.2757, -821.2757, -821.2757,
        -821.2757, -821.2757, -821.2757, -821.2757, -821.2757, -821.2757,
        -821.2757, -821.2757, -821.2757, -821.2757, -821.2757, -821.2757,
        -821.2757, -821.2757, -821.2757, -821.2757, -821.2757, -821.2757,
        -821.2757, -821.2757, -821.2757, -821.2757, -821.2757, -821.2757,
        -821.2757, -821.2757, -821.2757, -821.2757, -821.2757, -821.2757,
        -821.2757, -821.2757, -821.2757, -821.2757, -821.2757, -821.2757,
        -821.2757, -821.2757, -821.2757, -821.2757, -821.2757, -821.2757,
        -821.2757, -821.2757, -821.2757, -821.2757, -821.2757, -821.2757,
        -821.2757, -821.2757, -821.2757, -821.2757, -821.2757, -821.2757,
        -821.2757, -821.2757, -821.2757, -821.2757, -821.2757, -821.2757,
        -821.2757, -821.2757, -821.2757, -821.2757, -821.2757, -821.2757,
        -821.2757, -821.2757, -821.2757, -821.2757, -821.2757, -821.2757,
        -821.2757, -821.2757, -821.275

In [23]:
# Inferred state values: the 'z'-named sites of the inference trace, read in
# the same way as the ground-truth `zs`.
inferred_zs = [
    p[site_name].value
    for site_name in p
    if 'z' in site_name
]

In [24]:
# Per-site match indicators: 1.0 where an inferred state equals the true state,
# else 0.0. The true value broadcasts against the inferred one; presumably that
# gives one entry per particle -- confirm shapes.
# Inferred and true sites are paired by position. zip() would silently drop
# any unmatched tail and misalign the report, so check the lengths agree first.
assert len(inferred_zs) == len(zs), (
    'expected %d inferred state sites, got %d' % (len(zs), len(inferred_zs)))
accuracy = [(pr == tr).to(dtype=torch.float) for pr, tr in zip(inferred_zs, zs)]

In [25]:
# Report, for each state site, the percentage of indicators that match.
# Iterate over `accuracy` itself rather than a hard-coded range(50). The
# ground-truth `zs` above holds 51 sites for a 50-step chain, so a fixed
# range(50) silently skipped the last one, and would raise IndexError for a
# shorter chain. `.mean()` over all elements equals the previous
# `.mean(dim=0)` whenever that produced a single value.
for t, step_accuracy in enumerate(accuracy):
    print('SMC percent accuracy at time %d: %f' % (t, step_accuracy.mean().item() * 100))

SMC percent accuracy at time 0: 15.600000
SMC percent accuracy at time 1: 17.600000
SMC percent accuracy at time 2: 20.000000
SMC percent accuracy at time 3: 14.800000
SMC percent accuracy at time 4: 25.200001
SMC percent accuracy at time 5: 51.599998
SMC percent accuracy at time 6: 9.600000
SMC percent accuracy at time 7: 13.600000
SMC percent accuracy at time 8: 32.000000
SMC percent accuracy at time 9: 8.800000
SMC percent accuracy at time 10: 21.600000
SMC percent accuracy at time 11: 8.800000
SMC percent accuracy at time 12: 5.200000
SMC percent accuracy at time 13: 20.800001
SMC percent accuracy at time 14: 22.799999
SMC percent accuracy at time 15: 18.000000
SMC percent accuracy at time 16: 29.600000
SMC percent accuracy at time 17: 21.600000
SMC percent accuracy at time 18: 34.799999
SMC percent accuracy at time 19: 22.400002
SMC percent accuracy at time 20: 36.000000
SMC percent accuracy at time 21: 0.000000
SMC percent accuracy at time 22: 1.200000
SMC percent accuracy at tim