In [1]:
import warnings

import numpy as np
import torch
from torch.distributions.normal import Normal

import pyro
import pyro.distributions as dist

from generation_with_evaluation_period import generative_procedure_with_evaluation_period
import kalman_filter

In [2]:
# --- Series length and seasonality --------------------------------------------
L = 500        # total number of simulated time points (passed to the generator)
L_train = 350  # training length; presumably the rest is the evaluation period — confirm
S = 7          # seasonal period; the filter state below has S + 1 entries

# --- Noise standard deviations --------------------------------------------------
# All six are passed to the generator; they are also squared / fed into the
# filter's covariances in the state-space setup cell.
sigma_e = 0.1     # observation noise, squared into H_e
sigma_o = 4.      # observation noise of the alternative regime, squared into H_o
sigma_u = 0.1     # first argument of gen_Q_eta when building Q_xi
sigma_r = 1.      # first argument of gen_Q_eta when building Q_eta
sigma_v = 0.0004  # shared by Q_eta and Q_xi
sigma_w = 0.01    # shared by Q_eta and Q_xi

# --- Initial state of the simulated series --------------------------------------
mu_0 = 20.    # presumably the initial level
delta_0 = 0.  # presumably the initial slope
# S - 1 = 6 values; presumably the initial seasonal effects.
gamma_0 = np.array([1, 2, 4, -1, -3, -2]) / 10

# --- Event probabilities ----------------------------------------------------------
# Written as "expected count / 350" (350 == L_train): presumably ~4 changes and
# ~10 anomalies expected over the training window — confirm intent.
p_c = 4 / 350
p_a = 10 / 350

# --- Fixed injected events ----------------------------------------------------------
# "330 - 1" presumably converts step 330 to a 0-based index — confirm.
t_c_fixed = 330 - 1
c_fixed = 1.
t_r_fixed = 330 - 1
r_fixed = 2.

In [3]:
# Seed torch's global RNG so the simulated series is the same on every run.
# NOTE(review): numpy's global RNG is not seeded here; if the generator also
# draws from numpy, reruns will not be reproducible — confirm.
torch.manual_seed(4)

# Simulate the series. The generator returns nine values; only the first (the
# observations y, later fed to the Kalman filter) is kept.
(y,
 _, _, _, _,
 _, _, _, _) = generative_procedure_with_evaluation_period(
    mu_0, delta_0, gamma_0,
    sigma_e, sigma_o, sigma_u,
    sigma_r, sigma_v, sigma_w,
    p_a, p_c, L, L_train, S,
    t_c_fixed, c_fixed, t_r_fixed, r_fixed,
)

# --- Dimensions -------------------------------------------------------------------
m = S + 1  # state size; matches the 8-entry a_t / 8x8 P_t printed below for S = 7
r = 3      # passed to gen_R_t; presumably the disturbance dimension — confirm
n = 100    # first argument of kalman_filter; presumably the number of steps filtered — confirm

# --- Initial state: the mean is built from the simulated data y -----------------------
a_1 = kalman_filter.gen_a_1(S, y)
P_1 = kalman_filter.gen_P_1(S)

# --- System matrices (presumably observation Z, transition T, disturbance selection R) ---
Z_t = kalman_filter.gen_Z_t(S)
T_t = kalman_filter.gen_T_t(S)
R_t = kalman_filter.gen_R_t(S, r)

# --- Noise covariances ------------------------------------------------------------------
# Observation variances for the two regimes.
H_e, H_o = sigma_e ** 2, sigma_o ** 2
# Both state-disturbance covariances use the same builder and differ only in the
# first standard deviation: sigma_r for Q_eta, sigma_u for Q_xi.
# NOTE(review): the generator takes (sigma_e, sigma_o, sigma_u, sigma_r) and the
# filter takes (H_e, H_o, Q_eta, Q_xi). If Q_eta/Q_xi follow the same pairing as
# H_e/H_o, Q_eta would be built from sigma_u, not sigma_r — confirm the intent.
Q_eta, Q_xi = (kalman_filter.gen_Q_eta(first_sd, sigma_v, sigma_w)
               for first_sd in (sigma_r, sigma_u))

In [4]:
# Run the filter. In the saved output, a_t is a length-8 tensor and P_t an 8x8
# tensor, i.e. m = S + 1 entries for S = 7.
a_t, P_t = kalman_filter.kalman_filter(
    n, m, y, a_1, P_1, Z_t, p_a, p_c, T_t, H_e, H_o, Q_eta, Q_xi, R_t
)

# Guard against silent divergence. A previous run printed state entries around
# 1e31 for a series generated with mu_0 = 20, and nothing flagged it. Warn when
# the filtered mean is non-finite or exceeds the data scale by a wide margin
# (factor 1e3 is a deliberately loose threshold).
# NOTE(review): in that run P_t stayed O(1) while a_t exploded, which suggests
# the problem is in the mean update inside kalman_filter.kalman_filter.
_y_scale = float(torch.nan_to_num(torch.as_tensor(y, dtype=torch.float64), nan=0.0).abs().max())
_a_scale = float(a_t.detach().abs().max())
if not bool(torch.isfinite(a_t).all()) or _a_scale > 1e3 * max(_y_scale, 1.0):
    warnings.warn(
        f"Filtered state looks diverged: max |a_t| = {_a_scale:.3g} "
        f"vs max |y| = {_y_scale:.3g}."
    )

print("Filtered state mean a_t:")
print(a_t)
print()
print("Filtered state covariance P_t:")
print(P_t)

tensor([40163547713553574673709329809408., 11611557203154669101822488084480.,
        -5659870926556405726557606248448.,  3292034059641879104389317132288.,
          766229953750721650449432182784.,  1187672903705803993814841425920.,
           92496047826671166871620812800.,   927346671161039768964107862016.])

tensor([[ 2.8501,  1.8405,  1.3916, -0.1933,  0.1200, -0.1254,  0.0418, -0.2882],
        [ 1.8405,  1.4855,  0.9846, -0.0492,  0.0965, -0.0411,  0.0428, -0.1071],
        [ 1.3916,  0.9846,  1.9327, -0.1803, -0.2005, -0.1160, -0.0595, -0.2541],
        [-0.1933, -0.0492, -0.1803,  0.8643, -0.0293, -0.3426, -0.0295, -0.2148],
        [ 0.1200,  0.0965, -0.2005, -0.0293,  0.8432, -0.0092, -0.3549, -0.0079],
        [-0.1254, -0.0411, -0.1160, -0.3426, -0.0092,  0.8243,  0.0023, -0.3757],
        [ 0.0418,  0.0428, -0.0595, -0.0295, -0.3549,  0.0023,  0.8174,  0.0150],
        [-0.2882, -0.1071, -0.2541, -0.2148, -0.0079, -0.3757,  0.0150,  0.7952]])
