In [1]:
from discopy import PRO, Ty
from functools import reduce
import logging
import matplotlib.pyplot as plt
import numpy as np
import torch

In [2]:
# Move the working directory two levels up; the stored output below shows this
# lands in the repository root, presumably so that `combinators` and `examples`
# can be imported in the next cell.
# NOTE(review): the relative path assumes the kernel was started in this
# notebook's own directory.
%cd ../..

/home/eli/AnacondaProjects/combinators


In [3]:
from combinators import lens, sampler, utils
from combinators.inference import conditioning, resample
from combinators.model import collections
import examples.hmm.hmm as hmm

In [4]:
# Parameters of the generative HMM, one row per state (5 states x 2 dims):
# state i has mean 2*i in both dimensions; concentration and rate are all 2.
state_params = {
    'mu': torch.arange(0, 10, 2, dtype=torch.float).unsqueeze(1).expand(5, 2),
    'concentration': torch.full((5, 2), 2.0),
    'rate': torch.full((5, 2), 2.0),
}

In [5]:
# Generative-model boxes, each with a batch shape of (1,): the parameter prior
# built from `state_params` (5 states, 2 dims, matching its tensor shapes) and
# one HMM transition/emission step.
# NOTE(review): positional arguments presumably are
# (name, model, proposal, batch_shape, dom, cod) — confirm against
# `sampler.importance_box`.
prior = sampler.importance_box(
    'hmm_parameters',
    hmm.Parameters(5, 2, state_params),
    None,
    (1,),
    lens.LensPRO(0),
    PRO(4) & Ty(),
)
hmm_step = sampler.importance_box(
    'hmm_step',
    hmm.TransitionAndEmission(),
    None,
    (1,),
    PRO(4) & Ty(),
    PRO(4) & Ty(),
)

In [6]:
# Unroll the one-step box into a 50-step chain (presumably 50 copies of
# `hmm_step` composed in sequence — TODO confirm `collections.sequential`).
# The inference chain built later in the notebook uses the same length.
chain = collections.sequential(hmm_step, 50)

In [7]:
# Prepend the parameter prior to the chain (`>>` presumably being sequential
# composition, as in discopy) and compile the composite into the object used
# as the sampler when tracing the model below.
hmm_model = prior >> chain
hmm_sampler = hmm_model.compile()

In [8]:
# Seed torch's global RNG before drawing the synthetic ground truth so that a
# Restart & Run All reproduces the same data (and, assuming the samplers draw
# from torch's global RNG, the same inference run later on).
torch.manual_seed(0)
# Trace one forward run of the generative model; only the trace is needed.
_, trace = hmm_model.trace(hmm_sampler)

In [9]:
# Fold the generative trace; `p` maps trace-site names to records exposing
# `.value`. The first element (the log weight, judging by the later
# `log_weight, p = trace.fold()` call) is not needed for the ground truth.
_, p = trace.fold()

In [10]:
# Ground-truth latent states: the value of every trace site whose name
# contains 'z', in the iteration order of `p`.
# NOTE(review): plain substring match — any other site whose name contains
# 'z' would be collected as well.
zs = [p[site_name].value for site_name in filter(lambda name: 'z' in name, p)]

In [11]:
zs

[tensor([3]),
 tensor([0]),
 tensor([4]),
 tensor([2]),
 tensor([1]),
 tensor([4]),
 tensor([3]),
 tensor([4]),
 tensor([2]),
 tensor([2]),
 tensor([4]),
 tensor([2]),
 tensor([3]),
 tensor([0]),
 tensor([1]),
 tensor([1]),
 tensor([1]),
 tensor([2]),
 tensor([1]),
 tensor([2]),
 tensor([1]),
 tensor([1]),
 tensor([4]),
 tensor([3]),
 tensor([3]),
 tensor([3]),
 tensor([4]),
 tensor([3]),
 tensor([3]),
 tensor([0]),
 tensor([0]),
 tensor([1]),
 tensor([2]),
 tensor([1]),
 tensor([1]),
 tensor([1]),
 tensor([4]),
 tensor([3]),
 tensor([2]),
 tensor([0]),
 tensor([0]),
 tensor([0]),
 tensor([1]),
 tensor([1]),
 tensor([1]),
 tensor([4]),
 tensor([3]),
 tensor([4]),
 tensor([3]),
 tensor([2]),
 tensor([3])]

In [12]:
# Synthetic observations: the value of every trace site whose name contains
# 'x', in the iteration order of `p`; these are what the inference model is
# conditioned on below.
# NOTE(review): plain substring match — confirm no other site name contains 'x'.
xs = [p[site_name].value for site_name in filter(lambda name: 'x' in name, p)]

In [13]:
xs

[tensor([[6.1955, 5.8491]]),
 tensor([[-0.0691, -1.6537]]),
 tensor([[ 9.1041, 12.9941]]),
 tensor([[1.8954, 4.5132]]),
 tensor([[4.0078, 2.3241]]),
 tensor([[ 9.4159, 16.2984]]),
 tensor([[4.3161, 4.5761]]),
 tensor([[10.2149, 19.3497]]),
 tensor([[5.6183, 5.4863]]),
 tensor([[5.1698, 3.7430]]),
 tensor([[10.9026, 13.1368]]),
 tensor([[5.1688, 5.1994]]),
 tensor([[6.8682, 4.9132]]),
 tensor([[-0.4011, -0.0201]]),
 tensor([[3.7056, 1.7032]]),
 tensor([[2.7756, 0.9904]]),
 tensor([[2.7033, 3.8080]]),
 tensor([[7.9982, 6.3588]]),
 tensor([[3.0295, 1.9842]]),
 tensor([[4.7519, 4.3505]]),
 tensor([[2.0645, 0.4220]]),
 tensor([[3.0319, 4.1446]]),
 tensor([[ 9.2503, 23.3514]]),
 tensor([[2.5907, 6.4205]]),
 tensor([[3.6691, 4.9724]]),
 tensor([[5.7848, 4.9801]]),
 tensor([[ 8.6790, 14.0938]]),
 tensor([[4.9799, 6.2408]]),
 tensor([[2.6180, 6.0833]]),
 tensor([[-0.0958, -0.3813]]),
 tensor([[ 0.9988, -0.2101]]),
 tensor([[2.9947, 1.5811]]),
 tensor([[7.4542, 5.5030]]),
 tensor([[2.6193, 1.528

In [14]:
# Number of particles: used as the batch shape of both boxes in the inference model.
num_particles = 250

In [15]:
# Parameters handed to the inference model: same per-state means as the
# generator (2*i for state i), but concentration and rate set to 1 instead of 2.
inference_params = {
    'mu': torch.arange(0, 10, 2, dtype=torch.float).unsqueeze(1).expand(5, 2),
    'concentration': torch.full((5, 2), 1.0),
    'rate': torch.full((5, 2), 1.0),
}

In [16]:
# Inference-model boxes, batched over `num_particles`: the parameter box built
# from `inference_params` and one transition/emission step. `hmm_step` is
# rebound here, replacing the single-sample generative step.
# NOTE(review): positional arguments presumably are
# (name, model, proposal, batch_shape, dom, cod) — confirm against
# `sampler.importance_box`.
params = sampler.importance_box(
    'hmm_parameters',
    hmm.Parameters(5, 2, inference_params),
    None,
    (num_particles,),
    lens.LensPRO(0),
    PRO(4) & Ty(),
)
hmm_step = sampler.importance_box(
    'hmm_step',
    hmm.TransitionAndEmission(),
    None,
    (num_particles,),
    PRO(4) & Ty(),
    PRO(4) & Ty(),
)

In [17]:
# Unroll the particle-batched step box over the same number of steps (50) as
# the generative chain, presumably so each step lines up with one observation.
chain = collections.sequential(hmm_step, 50)

In [18]:
# Condition the chain's 'hmm_step' boxes on the synthetic observations `xs`
# (the keyword presumably selects the boxes by name — TODO confirm against
# `conditioning.SequentialConditioner`).
conditioned_chain = conditioning.SequentialConditioner(hmm_step=xs)(chain)

In [19]:
# Prepend the parameter box to the observation-conditioned chain, compile it,
# and wrap the result with `resample.resampler`. The wrapper presumably inserts
# resampling between steps, turning plain importance sampling into SMC
# (TODO confirm). For an un-resampled baseline, use `hmm_inference.compile()`
# directly as the sampler instead.
hmm_inference = params >> conditioned_chain
hmm_resampler = resample.resampler(hmm_inference.compile())

In [20]:
# Run the conditioned inference model through the resampling sampler.
# `result` is not used in the visible part of the notebook; `trace` is folded
# into log weights and site values in the next cell.
result, trace = hmm_inference.trace(hmm_resampler)

In [21]:
# Fold the inference trace into per-particle log weights and the site mapping
# `p`. This rebinds `p`, replacing the ground-truth mapping from the generative run.
log_weight, p = trace.fold()

In [22]:
# Per-particle log weights of the inference run. In the stored output every
# particle has the same value, presumably because the population was resampled
# at the final step.
# NOTE(review): this renders one entry per particle; a summary such as
# `log_weight.unique()` would keep the output short.
log_weight

tensor([-1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961,
        -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961,
        -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961,
        -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961,
        -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961,
        -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961,
        -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961,
        -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961,
        -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961,
        -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961,
        -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961,
        -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961, -1303.3961,
        -1303.3961, -1303.3961, -1303.39

In [23]:
# Inferred latent states (one batch of particles per site): the value of every
# site in the inference trace whose name contains 'z', in the order of `p`.
inferred_zs = [p[site_name].value for site_name in filter(lambda name: 'z' in name, p)]

In [24]:
# Both lists come from the same 'z' name filter applied to traces of models
# with the same step structure (presumably one 'z' site per step), so they must
# pair up one-to-one. Fail loudly instead of letting `zip` silently truncate.
assert len(inferred_zs) == len(zs), (len(inferred_zs), len(zs))
# accuracy[t] is a float 0/1 tensor: 1 where a particle's inferred state at
# step t equals the ground-truth state (broadcast against the single true value).
accuracy = [(pr == tr).to(dtype=torch.float) for pr, tr in zip(inferred_zs, zs)]

In [25]:
# Report, per time step, the percentage of particles whose inferred latent
# state matches the ground truth. Iterate over `accuracy` itself rather than a
# hard-coded range(50). The stored `zs` output above has 51 entries, so the old
# loop silently skipped the last step (and a shorter list would raise IndexError).
for t, step_accuracy in enumerate(accuracy):
    print('SMC percent accuracy at time %d: %f' % (t, step_accuracy.mean(dim=0) * 100))

SMC percent accuracy at time 0: 24.000000
SMC percent accuracy at time 1: 28.799999
SMC percent accuracy at time 2: 17.200001
SMC percent accuracy at time 3: 20.800001
SMC percent accuracy at time 4: 14.800000
SMC percent accuracy at time 5: 17.200001
SMC percent accuracy at time 6: 12.400000
SMC percent accuracy at time 7: 29.600000
SMC percent accuracy at time 8: 29.600000
SMC percent accuracy at time 9: 23.199999
SMC percent accuracy at time 10: 22.000000
SMC percent accuracy at time 11: 17.200001
SMC percent accuracy at time 12: 3.600000
SMC percent accuracy at time 13: 16.799999
SMC percent accuracy at time 14: 18.400000
SMC percent accuracy at time 15: 18.799999
SMC percent accuracy at time 16: 20.400000
SMC percent accuracy at time 17: 31.600002
SMC percent accuracy at time 18: 20.400000
SMC percent accuracy at time 19: 24.799999
SMC percent accuracy at time 20: 24.799999
SMC percent accuracy at time 21: 6.400000
SMC percent accuracy at time 22: 17.200001
SMC percent accuracy at