In [1]:
from discopy import PRO, Ty
from functools import reduce
import logging
import matplotlib.pyplot as plt
import numpy as np
import torch

In [2]:
%cd ../..

/home/eli/AnacondaProjects/combinators


In [3]:
from combinators import lens, sampler, utils
from combinators.inference import conditioning, resample
from combinators.model import collections
import examples.hmm.hmm as hmm

In [4]:
state_params = {
    'mu': (torch.arange(5, dtype=torch.float) * 2).expand(2, 5).t(),
    'concentration': torch.ones(5, 2) * 2,
    'rate': torch.ones(5, 2) * 2,
}

In [5]:
prior = sampler.importance_box('hmm_parameters', hmm.Parameters(5, 2, state_params), None, (1,), lens.LensPRO(0), PRO(4) & Ty())
hmm_step = sampler.importance_box('hmm_step', hmm.TransitionAndEmission(), None, (1,), PRO(4) & Ty(), PRO(4) & Ty())

In [6]:
chain = collections.sequential(hmm_step, 50)

In [7]:
hmm_model = prior >> chain
hmm_sampler = hmm_model.compile()

In [8]:
_, trace = hmm_model.trace(hmm_sampler)

In [9]:
_, p = trace.fold()

In [10]:
zs = [p[k].value for k in p if 'z' in k]

In [11]:
zs

[tensor([3]),
 tensor([3]),
 tensor([1]),
 tensor([2]),
 tensor([0]),
 tensor([2]),
 tensor([0]),
 tensor([2]),
 tensor([0]),
 tensor([2]),
 tensor([0]),
 tensor([2]),
 tensor([3]),
 tensor([2]),
 tensor([1]),
 tensor([3]),
 tensor([3]),
 tensor([1]),
 tensor([2]),
 tensor([1]),
 tensor([3]),
 tensor([0]),
 tensor([2]),
 tensor([2]),
 tensor([1]),
 tensor([3]),
 tensor([2]),
 tensor([1]),
 tensor([4]),
 tensor([4]),
 tensor([0]),
 tensor([3]),
 tensor([0]),
 tensor([1]),
 tensor([1]),
 tensor([3]),
 tensor([2]),
 tensor([4]),
 tensor([0]),
 tensor([2]),
 tensor([1]),
 tensor([3]),
 tensor([3]),
 tensor([2]),
 tensor([1]),
 tensor([4]),
 tensor([2]),
 tensor([1]),
 tensor([2]),
 tensor([3]),
 tensor([1])]

In [12]:
xs = [p[k].value for k in p if 'x' in k]

In [13]:
xs

[tensor([[4.7026, 5.3827]]),
 tensor([[6.3133, 4.3302]]),
 tensor([[3.5163, 1.0125]]),
 tensor([[5.3809, 6.1912]]),
 tensor([[-0.3168,  1.9988]]),
 tensor([[4.8447, 2.3403]]),
 tensor([[-1.0862,  1.4617]]),
 tensor([[6.2852, 2.1984]]),
 tensor([[-1.8893, -0.5716]]),
 tensor([[3.8280, 3.7149]]),
 tensor([[-0.2647, -1.6588]]),
 tensor([[5.4067, 5.2456]]),
 tensor([[5.6621, 6.6706]]),
 tensor([[6.2415, 0.9347]]),
 tensor([[1.2064, 2.6123]]),
 tensor([[5.6483, 6.9035]]),
 tensor([[4.8398, 6.3641]]),
 tensor([[1.7894, 2.6727]]),
 tensor([[5.4235, 4.5007]]),
 tensor([[1.2857, 1.5523]]),
 tensor([[5.5103, 5.4754]]),
 tensor([[-1.1651, -0.3810]]),
 tensor([[6.2783, 1.1581]]),
 tensor([[6.5669, 4.5248]]),
 tensor([[2.8847, 1.2194]]),
 tensor([[4.7251, 4.6440]]),
 tensor([[6.7429, 1.4427]]),
 tensor([[2.2225, 0.8942]]),
 tensor([[8.3902, 8.5772]]),
 tensor([[8.3048, 8.9782]]),
 tensor([[-1.2131, -0.6445]]),
 tensor([[5.2527, 6.4494]]),
 tensor([[-0.9364,  0.8777]]),
 tensor([[1.9180, 1.8063]]),


In [14]:
num_particles = 250

In [15]:
inference_params = {
    'mu': (torch.arange(5, dtype=torch.float) * 2).expand(2, 5).t(),
    'concentration': torch.ones(5, 2),
    'rate': torch.ones(5, 2),
}

In [16]:
params = sampler.importance_box('hmm_parameters', hmm.Parameters(5, 2, inference_params), None, (num_particles,), lens.LensPRO(0), PRO(4) & Ty())
hmm_step = sampler.importance_box('hmm_step', hmm.TransitionAndEmission(), None, (num_particles,), PRO(4) & Ty(), PRO(4) & Ty())

In [17]:
chain = collections.sequential(hmm_step, 50)

In [18]:
conditioned_chain = conditioning.SequentialConditioner(hmm_step=xs)(chain)

In [19]:
hmm_inference = resample.resampler(params >> conditioned_chain)
hmm_resampler = hmm_inference.compile()

In [20]:
result, trace = hmm_inference.trace(hmm_resampler)

In [21]:
log_weight, p = trace.fold()

In [22]:
log_weight

tensor([-257.0720, -257.0720, -257.0720, -257.0720, -257.0720, -257.0720,
        -257.0720, -257.0720, -257.0720, -257.0720, -257.0720, -257.0720,
        -257.0720, -257.0720, -257.0720, -257.0720, -257.0720, -257.0720,
        -257.0720, -257.0720, -257.0720, -257.0720, -257.0720, -257.0720,
        -257.0720, -257.0720, -257.0720, -257.0720, -257.0720, -257.0720,
        -257.0720, -257.0720, -257.0720, -257.0720, -257.0720, -257.0720,
        -257.0720, -257.0720, -257.0720, -257.0720, -257.0720, -257.0720,
        -257.0720, -257.0720, -257.0720, -257.0720, -257.0720, -257.0720,
        -257.0720, -257.0720, -257.0720, -257.0720, -257.0720, -257.0720,
        -257.0720, -257.0720, -257.0720, -257.0720, -257.0720, -257.0720,
        -257.0720, -257.0720, -257.0720, -257.0720, -257.0720, -257.0720,
        -257.0720, -257.0720, -257.0720, -257.0720, -257.0720, -257.0720,
        -257.0720, -257.0720, -257.0720, -257.0720, -257.0720, -257.0720,
        -257.0720, -257.0720, -257.072

In [23]:
inferred_zs = [p[k].value for k in p if 'z' in k]

In [24]:
accuracy = [(pr == tr).to(dtype=torch.float) for pr, tr in zip(inferred_zs, zs)]

In [25]:
for t in range(50):
    print('SMC percent accuracy at time %d: %f' % (t, accuracy[t].mean(dim=0) * 100))

SMC percent accuracy at time 0: 24.799999
SMC percent accuracy at time 1: 20.400000
SMC percent accuracy at time 2: 19.600000
SMC percent accuracy at time 3: 18.799999
SMC percent accuracy at time 4: 2.800000
SMC percent accuracy at time 5: 5.600000
SMC percent accuracy at time 6: 12.000000
SMC percent accuracy at time 7: 11.200001
SMC percent accuracy at time 8: 17.200001
SMC percent accuracy at time 9: 7.200000
SMC percent accuracy at time 10: 39.200001
SMC percent accuracy at time 11: 12.000000
SMC percent accuracy at time 12: 12.000000
SMC percent accuracy at time 13: 52.000000
SMC percent accuracy at time 14: 14.800000
SMC percent accuracy at time 15: 14.000000
SMC percent accuracy at time 16: 17.600000
SMC percent accuracy at time 17: 2.400000
SMC percent accuracy at time 18: 5.200000
SMC percent accuracy at time 19: 0.400000
SMC percent accuracy at time 20: 19.600000
SMC percent accuracy at time 21: 0.000000
SMC percent accuracy at time 22: 12.000000
SMC percent accuracy at time