# Kalman Filter
> Implementation of Kalman filters using PyTorch, with parameter optimization via gradient descent

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
#| default_exp kalman.filter

In [None]:
#| export
from fastcore.test import *
from fastcore.basics import *
from meteo_imp.utils import *
from meteo_imp.gaussian import *
from meteo_imp.data_preparation import MeteoDataTest
from typing import *
from functools import partial

import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torch.distributions import MultivariateNormal

## Introduction

The model uses a latent state variable $x$ that is modelled over time, to impute gaps in $y$

The assumption of the model is that the state variable at time $t$ ($x_t$) depends only on the last state $x_{t-1}$ and not on the previous states.

### Equations

The equations of the model are:

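$$\begin{align} p(x_t | x_{t-1}) & = \mathcal{N}(Ax_{t-1} + Bc_t + b, Q) \\
p(y_t | x_t) & = \mathcal{N}(Hx_t + d, R) \end{align}$$


where:

- $A$ is the state transition matrix (`A`)
- $B$ is the control matrix (`B`) and $c_t$ is the control at time $t$
- $b$ is the state transition offset (`b`)
- $Q$ is the state transition covariance (`Q`)
- $H$ is the observation matrix (`H`)
- $d$ is the observation offset (`d`)
- $R$ is the observation covariance (`R`)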

in addition the model has also the parameters of the initial state that are used to initialize the filter:

- `m0`
- `P0`

The Kalman filter has 3 steps:

- predict (estimate the state at time t using the observations up to time t-1)
- update (correct the state at time t using the observation at time t)
- smooth (refine the state at time t using the observations after time t)

In case of missing data the update step is skipped.

After smoothing, the missing data at time $t$ ($y_t$) can be imputed from the smoothed state $\mathcal{N}(m^s_t, P^s_t)$ using this formula:
$$p(y_t|Y) = \mathcal{N}(Hm^s_t + d, HP^s_tH^T + R)$$

The Kalman Filter is an algorithm designed to estimate $P(x_t | y_{0:t})$. As all state transitions and observations are linear with Gaussian distributed noise, these distributions can be represented exactly as Gaussian distributions with mean `ms[t]` and covariance `Ps[t]`.
Similarly, the Kalman Smoother is an algorithm designed to estimate $P(x_t | y_{0:T-1})$, i.e. the state at time $t$ given all the $T$ observations.



## Kalman Filter Base

TODO: fill nans with 0 for all data

In [None]:
#| export
class KalmanFilterBase(torch.nn.Module):
    """Shared parameter handling for the PyTorch Kalman filter models.

    Every model parameter is registered as a `torch.nn.Parameter`. A parameter
    that has a constraint in `params_constr` is stored in its unconstrained form
    under the name `C_<param>`, and is exposed under its original name through a
    property that applies the constraint transform on read and the inverse
    transform on write.
    """

    # parameter name -> constraint object; `None` means the tensor is stored as is
    params_constr = {
        'A': None,
        'b': None,
        'Q': DiagPosDef(),
        'B': None,
        'H': None,
        'd': None,
        'R': DiagPosDef(),
        'm0': None,
        'P0': PosDef(),
    }

    def __init__(self,
            A: Tensor,                               # [n_dim_state,n_dim_state] $A$, state transition matrix 
            H: Tensor,                               # [n_dim_obs, n_dim_state] $H$, observation matrix
            B: Tensor,                               # [n_dim_state, n_dim_contr] $B$ control matrix
            Q: Tensor,                               # [n_dim_state, n_dim_state] $Q$, state trans covariance matrix
            R: Tensor,                               # [n_dim_obs, n_dim_obs] $R$, observations covariance matrix
            b: Tensor,                               # [n_dim_state] $b$, state transition offset
            d: Tensor,                               # [n_dim_obs] $d$, observations offset
            m0: Tensor,                              # [n_dim_state] $m_0$
            P0: Tensor,                              # [n_dim_state, n_dim_state] $P_0$
            
            n_dim_state: int = None,                 # Number of dimensions for state - default inferred from parameters
            n_dim_obs: int = None,                   # Number of dimensions for observations - default inferred from parameters
            n_dim_contr: int = None,                 # Number of dimensions for control - default inferred from parameters
                 
            var_names: Iterable[str]|None = None,    # Names of variables for printing 
            contr_names: Iterable[str]|None = None,  # Names of control variables for printing
            
            cov_checker: CheckPosDef = CheckPosDef(),# Check covariance at every step
            use_conditional: bool = True,            # Use conditional distribution for gaps that don't have all variables missing
            use_control: bool = True,                # Use the control in the filter
            use_smooth: bool = True,                 # Use smoother for predictions (otherwise is filter only)
                ):
        super().__init__()
        # plain (non-tensor) settings are stored as regular attributes
        store_attr("var_names, contr_names, use_conditional, use_control, use_smooth, cov_checker")

        # dimensions first, then parameter registration (registration order is
        # the order reported by `named_parameters`)
        self._check_params(A, H, B, Q, R, b, d, m0, P0, n_dim_state, n_dim_obs, n_dim_contr)
        self._init_params(A=A, H=H, B=B, Q=Q, R=R, b=b, d=d, m0=m0, P0=P0)

        self.cov_checker = cov_checker

    def _check_params(self, A, H, B, Q, R, b, d, m0, P0, n_dim_state, n_dim_obs, n_dim_contr):
        """Set `n_dim_state`, `n_dim_obs` and `n_dim_contr` from the parameters.

        Each spec is `(tensor, converter, axis)`; the specs of one group are
        handed to `determine_dimensionality` together with the explicitly
        requested size (which may be `None`).
        """
        state_specs = [
            (A, array2d, -2),
            (b, array1d, -1),
            (Q, array2d, -2),
            (m0, array1d, -1),
            (P0, array2d, -2),
            (H, array2d, -1),
        ]
        obs_specs = [
            (H, array2d, -2),
            (d, array1d, -1),
            (R, array2d, -2),
        ]
        contr_specs = [(B, array2d, -1)]

        self.n_dim_state = determine_dimensionality(state_specs, n_dim_state)
        self.n_dim_obs = determine_dimensionality(obs_specs, n_dim_obs)
        self.n_dim_contr = determine_dimensionality(contr_specs, n_dim_contr)

    def _init_params(self, **params):
        """Register every tensor in `params` as a trainable parameter.

        Constrained parameters are first converted to their raw representation
        and renamed (see `_init_constraint`).
        """
        for name, value in params.items():
            constraint = self.params_constr[name]
            if constraint is not None:
                name, value = self._init_constraint(name, value, constraint)
            self._init_param(name, value, train=True)

    def _init_param(self, param_name, value, train):
        """Register `value` as the parameter `param_name`; `train` sets `requires_grad`."""
        parameter = torch.nn.Parameter(value, requires_grad=train)
        self.register_parameter(param_name, parameter)

    ### === Constraints utils
    def _init_constraint(self, param_name, value, constraint):
        """Prepare a constrained parameter for registration.

        Stores the constraint as `<param_name>_constraint`, installs the
        property that exposes the constrained value as `param_name`, and
        returns the raw name (`C_<param_name>`) with the inverse-transformed value.
        """
        raw_name = f"C_{param_name}"
        raw_value = constraint.inverse_transform(value)
        setattr(self, param_name + "_constraint", constraint)
        self._init_constraint_property(param_name)
        return raw_name, raw_value

    def _init_constraint_property(self, param_name):
        "Install a class-level property that reads/writes the constrained value of `param_name`"
        prop = property(
            partial(_get_constraint, param_name=param_name),
            partial(_set_constraint, param_name=param_name),
        )
        # properties only work when set on the class, not on the instance, see
        # https://stackoverflow.com/questions/1325673/how-to-add-property-to-a-class-dynamically
        setattr(type(self), param_name, prop)

    ### === Utility Func
    def _parse_obs(self, obs, mask=None):
        """Check that `obs` has 3 dimensions and return `obs` and `mask` unchanged.

        The mask is currently NOT inferred from the `nan` values of `obs`.
        """
        # TODO incorrect support for 2d input!!!!!!
        assert obs.dim() == 3
        return obs, mask

    def __repr__(self):
        return f"""Kalman Filter
        N dim obs: {self.n_dim_obs},
        N dim state: {self.n_dim_state},
        N dim contr: {self.n_dim_contr}"""

# These functions need to be global

def _get_constraint(self, param_name):
    """get the original value"""
    constraint = getattr(self, param_name + "_constraint")
    raw_value = getattr(self, f"C_{param_name}")
    return constraint.transform(raw_value)

def _set_constraint(self, value, param_name, train=True):
    """set the transformed value"""
    constraint = getattr(self, param_name + "_constraint")
    raw_value = constraint.inverse_transform(value)
    self._init_param(f"C_{param_name}", raw_value, train)

### Constructors

Giving all the parameters manually to the `KalmanFilterBase` init method is not convenient, hence we provide some methods that help initialize the class

#### Random parameters

In [None]:
#| export
@patch(cls_method=True)
def init_random(cls: KalmanFilterBase, n_dim_obs, n_dim_state, n_dim_contr, dtype=torch.float32, **kwargs):
    """Create a filter whose parameters are all drawn with `torch.rand`.

    `Q` and `R` are passed through `to_diagposdef` and `P0` through `to_posdef`
    before being given to the constructor. Extra `kwargs` are forwarded to `cls`.
    """
    rand = partial(torch.rand, dtype=dtype)
    # the draws are kept in this exact order so that seeded runs give the same parameters
    params = dict(
        A=rand(n_dim_state, n_dim_state),
        b=rand(n_dim_state),
        Q=to_diagposdef(rand(n_dim_state, n_dim_state)),
        B=rand(n_dim_state, n_dim_contr),
        H=rand(n_dim_obs, n_dim_state),
        d=rand(n_dim_obs),
        R=to_diagposdef(rand(n_dim_obs, n_dim_obs)),
        m0=rand(n_dim_state),
        P0=to_posdef(rand(n_dim_state, n_dim_state)),
    )
    return cls(**params, **kwargs)
        

In [None]:
kB = KalmanFilterBase.init_random(3,4, 3, dtype=torch.float64)
kB

Kalman Filter
        N dim obs: 3,
        N dim state: 4,
        N dim contr: 3

In [None]:
kB.P0 = to_posdef(torch.rand(3,3))

check that assignment works :)

In [None]:
kB.P0 = to_posdef(torch.rand(4, 4, dtype=torch.float64))

In [None]:
kB.C_P0

Parameter containing:
tensor([[ 1.1682,  0.0000,  0.0000,  0.0000],
        [ 0.8023,  0.8394,  0.0000,  0.0000],
        [ 0.6775,  0.1935,  0.5666,  0.0000],
        [ 0.8928, -0.0317,  0.5224,  0.0062]], dtype=torch.float64,
       requires_grad=True)

In [None]:
list(kB.named_parameters())

[('A',
  Parameter containing:
  tensor([[0.4743, 0.2929, 0.7767, 0.3891],
          [0.1858, 0.9826, 0.8748, 0.5956],
          [0.7667, 0.0196, 0.4954, 0.7218],
          [0.2282, 0.7413, 0.4577, 0.2354]], dtype=torch.float64,
         requires_grad=True)),
 ('H',
  Parameter containing:
  tensor([[0.2559, 0.4528, 0.9301, 0.9019],
          [0.0589, 0.5794, 0.8223, 0.2959],
          [0.4382, 0.5937, 0.3968, 0.9370]], dtype=torch.float64,
         requires_grad=True)),
 ('B',
  Parameter containing:
  tensor([[0.3702, 0.3493, 0.4764],
          [0.4421, 0.9268, 0.9782],
          [0.0114, 0.8504, 0.9450],
          [0.1107, 0.3799, 0.0615]], dtype=torch.float64, requires_grad=True)),
 ('C_Q',
  Parameter containing:
  tensor([0.2220, 0.6632, 0.6818, 0.6496], dtype=torch.float64,
         requires_grad=True)),
 ('C_R',
  Parameter containing:
  tensor([0.5654, 0.7635, 0.5842], dtype=torch.float64, requires_grad=True)),
 ('b',
  Parameter containing:
  tensor([0.5049, 0.6878, 0.7206, 0

### Test data

In [None]:
#| exporti
def get_test_data(n_obs = 10, n_dim_obs=3, n_dim_contr = 3, p_missing=.3, bs=2, dtype=torch.float32, device='cpu'):
    """Generate random observations, an observed-mask and control variables.

    Returns `(data, mask, control)` where `data` and `mask` have shape
    `[bs, n_obs, n_dim_obs]` and `control` has shape `[bs, n_obs, n_dim_contr]`.
    `mask` is `True` where a uniform draw exceeds `p_missing`; `data` is `nan`
    wherever `mask` is `False`.
    """
    obs_shape = (bs, n_obs, n_dim_obs)
    data = torch.rand(obs_shape, dtype=dtype, device=device)
    mask = torch.rand(obs_shape, device=device) > p_missing
    control = torch.rand((bs, n_obs, n_dim_contr), dtype=dtype, device=device)
    # overwrite the missing entries so that they cannot be used by mistake
    data.masked_fill_(~mask, torch.nan)
    return data, mask, control

In [None]:
reset_seed()
data, mask, control = get_test_data(dtype=torch.float64)
show_as_row(data, mask, control)

## Filter

### Filter predict

Probability of state at time `t` given state at time `t-1` 

$p(x_t) = \mathcal{N}(x_t; m_t^-, P_t^-)$ where:

- predicted state mean: $m_t^- = Am_{t-1} + B c_t + b$  

- predicted state covariance: $P_t^- = AP_{t-1}A^T + Q$

In [None]:
A, Q, b, B, m_pr,P_pr = (k.A.unsqueeze(0), k.Q.unsqueeze(0), k.b.unsqueeze(-1),
                                                  k.B.unsqueeze(0),
                                                  torch.stack([k.m0]*2).unsqueeze(-1),
                                                  torch.stack([k.P0]*2))

#### Covariance

##### Standard

implement $P_t^- = AP_{t-1}A^T + Q$

In [None]:
#| export
def _filter_predict_cov_stand(A, Q, Pm):
    """Standard - Kalman Filter predict covariance"""
    return A @ Pm @ A.mT + Q

In [None]:
P_m = _filter_predict_cov_stand(A, Q, P_pr)
P_m

tensor([[[3.9186, 5.2420, 4.0320, 3.6330],
         [5.2420, 7.5784, 5.7753, 5.3769],
         [4.0320, 5.7753, 5.8339, 4.6659],
         [3.6330, 5.3769, 4.6659, 4.3242]],

        [[3.9186, 5.2420, 4.0320, 3.6330],
         [5.2420, 7.5784, 5.7753, 5.3769],
         [4.0320, 5.7753, 5.8339, 4.6659],
         [3.6330, 5.3769, 4.6659, 4.3242]]], dtype=torch.float64,
       grad_fn=<AddBackward0>)

##### Num Stable

Implement the numerically stable version of the covariance update

In [None]:
C_curr = torch.linalg.cholesky(P_pr)

In [None]:
C_Q = torch.linalg.cholesky(Q)

$$W = \begin{bmatrix}AU_{t-1}&G\end{bmatrix}$$

In [None]:
(A @ C_curr).shape

torch.Size([2, 4, 4])

In [None]:
C_Q.expand(C_curr.shape[0], -1, -1).shape

torch.Size([2, 4, 4])

In [None]:
torch.concat([A @ C_curr, C_Q.expand(C_curr.shape[0], -1, -1)], dim=-1).shape

torch.Size([2, 4, 8])

In [None]:
W = torch.concat([A @ C_curr, C_Q.expand_as(C_curr)], dim=-1)
W.shape

torch.Size([2, 4, 8])

In [None]:
Q, R = torch.linalg.qr(W.mT)

In [None]:
C_pred = R.mT

In [None]:
test_close(C_pred @ C_pred.mT, P_m)

In [None]:
#| export
def _fitler_predict_cov(A, # transition covariance
                        C_Q, # Cholesky Factor of transition covariance
                        C_Pm # Cholesky Factor of current state covariance $P^-$
                       ):
    """Numerical stable Kalman filter predict for covariance"""
    W = torch.concat([A @ C_Pm, C_Q.expand_as(C_Q)], dim=-1)
    return torch.linalg.qr(W.mT).mT 

##### Mean

In [None]:
#| export
def unsqueeze_iter(*args, dim):
    """Return a list with every tensor of `args` unsqueezed at position `dim`."""
    return [torch.unsqueeze(tensor, dim=dim) for tensor in args]

# shortcuts: add a leading / trailing singleton dimension to several tensors at once
unsqueeze_first = partial(unsqueeze_iter, dim=0)
unsqueeze_last = partial(unsqueeze_iter, dim=-1)

In [None]:
#| export
def _filter_predict(A,
                    Q,
                    b,
                    B, #[n_dim_state, n_dim_contr]
                    m_pr,
                    P_pr,
                    control, #[n_batches, n_dim_contr]
                    cov_checker=CheckPosDef()):
    r"""Predict step: propagate the state `(m_pr, P_pr)` one time step forward.

    Computes the predicted mean `A m + B c + b` and the predicted covariance
    `A P A^T + Q`, hands the covariance to `cov_checker` and returns
    `(mean, covariance)`.
    """
    # a leading dimension is added to the model matrices so that they
    # broadcast against the batched state in the matmuls
    A_b = A.unsqueeze(0)
    state_term = A_b @ m_pr
    control_term = B.unsqueeze(0) @ control.unsqueeze(-1)
    m_m = state_term + control_term + b.unsqueeze(-1)

    P_m = A_b @ P_pr @ A_b.mT + Q.unsqueeze(0)

    cov_checker.check(P_m, caller='filter_predict')
    return (m_m, P_m)

In [None]:
m_m, P_m = _filter_predict(
    A, Q, b, B,
    m_pr,P_pr, control[:,0,:])

In [None]:
show_as_row(m_m, P_m)

In [None]:
show_as_row((m_m.shape, P_m.shape,))

### Filter Predict UD

#### UDU Decomposition

In [None]:
Lc = torch.linalg.cholesky(Q)

In [None]:
Lc

tensor([[ 1.0125,  0.0000,  0.0000,  0.0000],
        [ 1.0871,  0.8976,  0.0000,  0.0000],
        [ 0.7836, -0.1072,  0.5750,  0.0000],
        [ 1.4338,  0.5792,  0.2001,  0.1477]], dtype=torch.float64,
       grad_fn=<LinalgCholeskyExBackward0>)

In [None]:
D0 = torch.diag(Lc)

In [None]:
D0

tensor([1.0125, 0.8976, 0.5750, 0.1477], dtype=torch.float64,
       grad_fn=<DiagBackward0>)

In [None]:
D0.unsqueeze(0).shape

torch.Size([1, 4])

In [None]:
L = Lc / D0.unsqueeze(0)
L

tensor([[ 1.0000,  0.0000,  0.0000,  0.0000],
        [ 1.0737,  1.0000,  0.0000,  0.0000],
        [ 0.7740, -0.1194,  1.0000,  0.0000],
        [ 1.4161,  0.6454,  0.3479,  1.0000]], dtype=torch.float64,
       grad_fn=<DivBackward0>)

In [None]:
D = D0.pow(2)
D

tensor([1.0251, 0.8056, 0.3306, 0.0218], dtype=torch.float64,
       grad_fn=<PowBackward0>)

In [None]:
L @ torch.diag(D) @ L.T - Q

tensor([[-2.2204e-16,  0.0000e+00, -2.2204e-16,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-2.2204e-16, -1.1102e-16, -3.3307e-16, -2.2204e-16],
        [ 0.0000e+00,  0.0000e+00, -2.2204e-16,  4.4409e-16]],
       dtype=torch.float64, grad_fn=<SubBackward0>)

can compute gradients

In [None]:
L.sum().backward(retain_graph=True) 

We cannot use the PyTorch LDL function because it doesn't compute gradients, and in any case we don't need it! We can just use the Cholesky decomposition. Actually we may just use the Cholesky factor directly and $I$ for the diagonal ...

In [None]:
A = to_posdef(torch.rand(2, 3,3))

In [None]:
torch.diagonal(A, dim1=-2, dim2=-1).shape

torch.Size([2, 3])

In [None]:
def udu_decomposition(A):
    """Decompose (batches of) positive definite matrices as `U @ diag(D) @ U^T`.

    Built on the Cholesky factor so that gradients are available: every column
    of the factor is divided by its diagonal entry, giving a unit-diagonal lower
    triangular `U`, and `D` holds the squared diagonal entries.
    Returns `(U, D)` with `D` as a vector (last dimension), not a matrix.
    """
    chol = torch.linalg.cholesky(A)
    chol_diag = chol.diagonal(dim1=-2, dim2=-1)
    unit_factor = chol / chol_diag[..., None, :]
    return unit_factor, chol_diag.pow(2)

In [None]:
U, D = udu_decomposition(A)

In [None]:
torch.diag_embed(D), D

(tensor([[[1.3610, 0.0000, 0.0000],
          [0.0000, 0.1577, 0.0000],
          [0.0000, 0.0000, 0.0253]],
 
         [[1.6047, 0.0000, 0.0000],
          [0.0000, 0.0966, 0.0000],
          [0.0000, 0.0000, 0.0751]]]),
 tensor([[1.3610, 0.1577, 0.0253],
         [1.6047, 0.0966, 0.0751]]))

In [None]:
for _ in range(10):
    A = to_posdef(torch.rand(5, 10,10))
    U,D = udu_decomposition(A)
    test_close(A, U @ torch.diag_embed(D) @ U.mT)

In [None]:
torch.diag(Q)

tensor([1.0251, 1.9873, 0.9562, 2.4531], dtype=torch.float64,
       grad_fn=<DiagBackward0>)

In [None]:
udu_decomposition(torch.diag(torch.diag(Q)))

(tensor([[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.]], dtype=torch.float64, grad_fn=<DivBackward0>),
 tensor([1.0251, 1.9873, 0.9562, 2.4531], dtype=torch.float64,
        grad_fn=<PowBackward0>))

The first step is to decompose the covariance matrix Q
$$    Q = GD_QG^T $$

In [None]:
G, D_Q = udu_decomposition(Q)

$$W = \begin{bmatrix}AU_{t-1}&G\end{bmatrix}$$

In [None]:
U_curr, D_curr = udu_decomposition(P_pr[0])

In [None]:
W = torch.hstack([A @ U_curr, G])
W

tensor([[ 0.7179,  0.1642, -0.1333,  0.7937,  1.0000,  0.0000,  0.0000,  0.0000],
        [ 1.3292,  0.5025,  0.1748,  0.2773,  1.0737,  1.0000,  0.0000,  0.0000],
        [ 1.4469,  0.8537, -0.0175,  0.3410,  0.7740, -0.1194,  1.0000,  0.0000],
        [ 0.9866,  0.0631, -0.2660,  0.8306,  1.4161,  0.6454,  0.3479,  1.0000]],
       dtype=torch.float64, grad_fn=<CatBackward0>)

$$ D_w = \begin{bmatrix}D_{t-1} & 0 \\ 0& D_Q \end{bmatrix}$$

In [None]:
Dw = torch.diag(torch.hstack([D_curr, D_Q]))
Dw

tensor([[2.0770, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4582, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0431, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0169, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.0251, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8056, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3306, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0218]],
       dtype=torch.float64, grad_fn=<DiagBackward0>)

In [None]:
W.shape

torch.Size([4, 8])

In [None]:
C_dw = Dw.sqrt()

In [None]:
C_dw.shape

torch.Size([8, 8])

In [None]:
Q, R = torch.linalg.qr((W @ C_dw).mT)
Q, R

(tensor([[-0.7107, -0.3610,  0.4193, -0.2126],
         [-0.0763, -0.1629,  0.3247, -0.2840],
         [ 0.0190, -0.0712,  0.0029, -0.1181],
         [-0.0710,  0.1072, -0.0504,  0.0055],
         [-0.6955,  0.3739, -0.4589,  0.2447],
         [-0.0000, -0.8287, -0.4603,  0.2697],
         [-0.0000, -0.0000,  0.5419,  0.7722],
         [-0.0000, -0.0000,  0.0000,  0.3623]], dtype=torch.float64,
        grad_fn=<LinalgQrBackward0>),
 tensor([[-1.4558, -2.1453, -2.0742, -2.0196],
         [ 0.0000, -1.0830, -0.4600, -0.4487],
         [ 0.0000,  0.0000,  1.0611, -0.2117],
         [ 0.0000,  0.0000,  0.0000,  0.4077]], dtype=torch.float64,
        grad_fn=<LinalgQrBackward0>))

In [None]:
Q.T @ Q

tensor([[ 1.0000e+00, -4.4235e-17,  5.8981e-17,  6.9389e-18],
        [-4.4235e-17,  1.0000e+00, -1.1102e-16,  8.3267e-17],
        [ 5.8981e-17, -1.1102e-16,  1.0000e+00, -8.3267e-17],
        [ 6.9389e-18,  8.3267e-17, -8.3267e-17,  1.0000e+00]],
       dtype=torch.float64, grad_fn=<MmBackward0>)

In [None]:
U_pred = R.mT
V = Q.T @ torch.inverse(C_dw)
U_pred.shape, V.shape

(torch.Size([4, 4]), torch.Size([4, 8]))

In [None]:
D_pred = V @ Dw @ V.T

In [None]:
D_pred

tensor([[ 1.0000e+00,  4.7705e-17, -6.2450e-17,  7.6328e-17],
        [-1.1276e-17,  1.0000e+00, -1.6653e-16,  1.3878e-16],
        [-5.8981e-17, -1.1102e-16,  1.0000e+00, -8.3267e-17],
        [ 3.8164e-17,  1.3878e-16, -5.5511e-17,  1.0000e+00]],
       dtype=torch.float64, grad_fn=<MmBackward0>)

so the math works and D_pred is a diagonal matrix, actually an identity matrix

Actually all these derivations are probably not needed: the original idea of the algorithm is to avoid scalar square roots, but since in PyTorch we have to use the Cholesky decomposition anyway, we can just use the Cholesky factors and make life easier

In [None]:
U_pred @ D_pred @ U_pred.T

tensor([[2.1193, 3.1231, 3.0197, 2.9402],
        [3.1231, 5.7751, 4.9480, 4.8186],
        [3.0197, 4.9480, 5.6400, 4.1710],
        [2.9402, 4.8186, 4.1710, 4.4912]], dtype=torch.float64,
       grad_fn=<MmBackward0>)

In [None]:
test_close(P_m[0], U_pred @ D_pred @ U_pred.T)

#| export
def filter_predict_ud(A,
                    Q,
                    b,
                    B, #[n_dim_state, n_dim_contr]
                    m_pr,
                    P_pr,
                    control, #[n_batches, n_dim_contr]
                    cov_checker=CheckPosDef()):
    
    

### Filter correct

Probability of state at time `t` given the observations at time `t`

$p(x_t|y_t) = \mathcal{N}(x_t; m_t, P_t)$ where:

- predicted obs mean: $z_t = Hm_t^- + d$  

- predicted obs covariance: $S_t = HP_t^-H^T + R$

- Kalman gain: $K_t = P_t^-H^TS_t^{-1}$ 

- corrected state mean: $m_t = m_t^- + K_t(y_t - z_t)$ 

- corrected state covariance: $P_t = (I-K_tH)P_t^-$ 

if the observations are missing this step is skipped and the corrected state is equal to the predicted state


Need to figure out the Nans for the gradients ...

#### Missing observations

If all the observations at time $t$ are missing, the correct step is skipped and the filtered state at time $t$ ($m_t$, $P_t$) is the same as the predicted state ($m_t^-$, $P_t^-$).

If only some observations are missing a variation of equation can be used.

$y^{ng}_t$ is a vector containing the observations that are not missing at time $t$. 

It can be expressed as a linear transformation of $y_t$

$$ y^{ng}_t = My_t$$

where $M$ is a mask matrix that is used to select the subset of $y_t$ that is observed. $M \in \mathbb{R}^{n_{ng} \times n}$ and each of its rows is all zeros except for a single entry 1 in the column corresponding to a non-missing observation.
hence:

$$ p(y^{ng}_t) = \mathcal{N}(M\mu_{y_t},  M\Sigma_{y_t}M^T)$$

from which you can derive

$$ p(y^{ng}_t|x_t) = \mathcal{N}(MHx_t + Md, MRM^T) $${#eq-filter-correct}

Then the posterior $p(x_t|y_t^{ng})$ can be computed similarly to the standard correct step as:

$$ p(x_t|y^{ng}_t) = \mathcal{N}(x_t; m_t, P_t) $${#eq-filter_correct_missing}
    
where:

*  predicted obs mean: $z_t = MHm_t^- + Md$
*  predicted obs covariance: $S_t = MHP_t^-(MH)^T + MRM^T$
*  Kalman gain $K_t = P_t^-(MH)^TS_t^{-1}$
*  corrected state mean: $m_t = m_t^- + K_t(My_t - z_t)$
*  corrected state covariance: $P_t = (I-K_tMH)P_t^-$


In [None]:
k.d.shape

torch.Size([3])

##### Details implementation 

For the implementation the matrix multiplication $MH$ can be replaced with `H[m]` where `m` is the mask for the rows for `H` and $MRM^T$ with `R[m][:,m]`

In [None]:
H, R, d,obs, mm = (k.H, k.R, k.d, data[:,0,:], mask[:,0,:])

In [None]:
m = torch.tensor([False,True,True]) # mask batch
M = torch.tensor([[0,1,0], # mask matrix
                  [0,0,1]], dtype=torch.float64)
show_as_row(m, M, H, R)

In [None]:
M @ H, H[m]

(tensor([[0.2237, 0.0553, 0.2241, 0.4903],
         [0.4756, 0.3574, 0.5449, 0.4099]], dtype=torch.float64,
        grad_fn=<MmBackward0>),
 tensor([[0.2237, 0.0553, 0.2241, 0.4903],
         [0.4756, 0.3574, 0.5449, 0.4099]], dtype=torch.float64,
        grad_fn=<IndexBackward0>))

In [None]:
M @ R @ M.T, R[m][:,m]

(tensor([[1.1859, 0.0000],
         [0.0000, 0.6930]], dtype=torch.float64, grad_fn=<MmBackward0>),
 tensor([[1.1859, 0.0000],
         [0.0000, 0.6930]], dtype=torch.float64, grad_fn=<IndexBackward0>))

When observations are only partially missing, `_filter_correct` cannot be easily batched, as the shape of the intermediate variables depends on the number of observed variables. So the idea is to split the batch into sub-batches whose elements share the same mask.

In [None]:
mask_values, indices = torch.unique(mask[:,1,:], dim=0, return_inverse=True)
mask_values, indices

(tensor([[ True, False,  True],
         [ True,  True, False]]),
 tensor([0, 1]))

In [None]:
#| export
def _filter_correct_batch(
                    H,
                    R,
                    d,
                    m_m,
                    P_m,
                    obs, # [n_batches, n_dim_obs] observations at time `t`
                    mask, # [n_dim_obs] boolean mask shared by the whole batch, `True` where the variable is observed
                    cov_checker=CheckPosDef()):
    """Update state at time `t` given observations at time `t` assuming that all observations have the same mask

    Only the observed variables (rows of `H`, entries of `d`, rows/columns of
    `R`, columns of `obs`) are used. When no variable is observed the predicted
    state `(m_m, P_m)` is returned unchanged and no check is run.
    Returns `(corrected mean, corrected covariance)`.
    """
    
    if (~mask).all(): return (m_m, P_m)

    # select with the mask instead of multiplying so that `nan` in the missing
    # observations never enters the computation
    m_H, m_d, m_obs, m_R = H[mask], d[mask], obs[:, mask], R[mask][:,mask]
    
    # extra dim needed to have batched matmul working between matrices and means
    (m_H,), (m_d, m_obs) = unsqueeze_first(m_H), unsqueeze_last(m_d, m_obs) 
    
    pred_obs_mean = m_H @ m_m + m_d
    pred_R = m_H @ P_m @ m_H.mT + m_R
    kalman_gain = P_m @ m_H.mT @ torch.inverse(pred_R)

    corr_state_mean = m_m + kalman_gain @ (m_obs - pred_obs_mean)
    corr_state_cov = P_m - kalman_gain @ m_H @ P_m

    # validate the covariance produced by this step (the predicted covariance
    # `P_m` is an input here, not the result of the correction)
    cov_checker.check(corr_state_cov, caller='filter_correct')
    return (corr_state_mean, corr_state_cov)

In [None]:
H, R, d,obs, mm = (k.H, k.R, k.d, data[:,0,:], mask[:,0,:])

In [None]:
corr_s_mean,corr_s_cov = _filter_correct_batch(H, R, d, m_m[0:1], P_m[0:1], obs[0:1], mm[0])

In [None]:
corr_s_mean.shape, corr_s_cov.shape

(torch.Size([1, 4, 1]), torch.Size([1, 4, 4]))

In [None]:
#| export
def _filter_correct(H,
                    R,
                    d,
                    m_m,
                    P_m,
                    obs,
                    mask,
                    cov_checker=CheckPosDef()) -> ListMNormal:
    """Correct the predicted state `(m_m, P_m)` with the observations at time `t`.

    Batch elements are grouped by identical mask and every group is corrected
    with one call to `_filter_correct_batch`; the results are written back at
    the positions of the group.
    """
    means, covs = torch.empty_like(m_m), torch.empty_like(P_m)

    # one row per distinct mask, plus for every batch element the index of its mask
    unique_masks, group_of_row = torch.unique(mask, return_inverse=True, dim=0)
    for group, group_mask in enumerate(unique_masks):
        rows = group_of_row == group
        means[rows], covs[rows] = _filter_correct_batch(
            H, R, d,
            m_m[rows], P_m[rows],
            obs[rows], group_mask,
            cov_checker
        )
        # sanity check: the selected rows really have this mask
        assert all(mask[rows][0] == group_mask)

    return ListMNormal(means, covs)

In [None]:
H, R, d,obs, mm = (k.H, k.R, k.d, data[:,0,:], mask[:,0,:])

In [None]:
corr_s_mean, corr_s_cov = _filter_correct(H, R, d, m_m, P_m, obs, mm)

In [None]:
show_as_row(corr_s_mean, corr_s_cov)

In [None]:
corr_s_mean.shape, corr_s_cov.shape

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]))

In [None]:
corr_s_mean.sum().backward(retain_graph=True) # check that pytorch can compute gradients with the whole batch

### Filter

The recursive version of the Kalman filter apparently breaks PyTorch gradient calculations, so a workaround is needed.
During the loop the states are saved in a python list and then at the end they are combined back into a tensor.
The last line of the function does:

- convert lists to tensors
- correct order dimensions

In [None]:
#| export
def _times2batch(x):
    """Permutes `x` so that the first dimension is the number of batches and not the times"""
    return x.permute(1,0,-2,-1)

In [None]:
#| export
def _filter(A, H, B,
            Q, R,
            b, d,
            m0, P0,
            obs, mask, control,
            cov_checker=CheckPosDef()
           ) ->Tuple[List, List, List, List]: # m_ms, P_ms, ms, Ps
    """Run the Kalman filter over all the time steps of `obs`.

    Returns the list `[m_ms, P_ms, ms, Ps]`: predicted means/covariances and
    filtered means/covariances, each stacked over time and permuted so that the
    batch dimension comes first.
    """
    bs, n_timesteps = obs.shape[0], obs.shape[-2]
    # the states are collected in python lists and stacked only at the end
    m_ms, P_ms, ms, Ps = ([None] * n_timesteps for _ in range(4))

    for t in range(n_timesteps):
        if t == 0:
            # the first predicted state is the initial state repeated for every batch element
            m_ms[t] = torch.stack([m0]*bs).unsqueeze(-1)
            P_ms[t] = torch.stack([P0]*bs)
        else:
            m_ms[t], P_ms[t] = _filter_predict(
                A, Q, b, B,
                ms[t - 1], Ps[t - 1], control[:,t,:],
                cov_checker.add_args(t=t))

        ms[t], Ps[t] = _filter_correct(
            H, R, d,
            m_ms[t], P_ms[t],
            obs[:,t,:], mask[:,t,:],
            cov_checker.add_args(t=t))

    # lists -> tensors, then [times, batch, ...] -> [batch, times, ...]
    return [_times2batch(torch.stack(states)) for states in (m_ms, P_ms, ms, Ps)]

In [None]:
obs, m0, P0 = data, k.m0, k.P0

In [None]:
m_ms, P_ms, ms, Ps = _filter(
    A, H, B,
    Q, R,
    b, d,
    m0, P0,
    data, mask, control)

Predictions at time `0` for both batches

In [None]:
show_as_row(list(map(Self.shape(), (m_ms, P_ms, ms, Ps,))))

In [None]:
show_as_row(list(map(lambda x:x[0][0], (m_ms, P_ms, ms, Ps,))))

### KalmanFilter method

In [None]:
#| export
@patch
def _filter_all(self: KalmanFilter, obs, mask, control
               ) ->Tuple[List, List, List, List]: # m_ms, P_ms, ms, Ps
    """Run `_filter` with the parameters of this model.

    When `use_control` is false the control matrix is replaced by zeros, so the
    control has no effect on the predicted state.
    """
    obs, mask = self._parse_obs(obs, mask)
    control_matrix = self.B if self.use_control else torch.zeros_like(self.B)
    return _filter(
        self.A, self.H, control_matrix,
        self.Q, self.R,
        self.b, self.d,
        self.m0, self.P0,
        obs, mask, control,
        self.cov_checker,
    )

In [None]:
pred_mean, _, _, _ = k._filter_all(obs, mask, control);

In [None]:
type(k._filter_all(obs, mask, control))

list

In [None]:
pred_mean.sum().backward(retain_graph=True) # it works!

The `filter` method wraps `_filter_all` but in addition:

- returns only filtered state
- detach tensors

In [None]:
#| export
@patch
def filter(self: KalmanFilter,
          obs: Tensor, # [n_batches, n_timesteps, n_dim_obs] observations for times [0...n_timesteps-1]
          mask: Tensor,  # [n_batches, n_timesteps, n_dim_obs] mask for times [0...n_timesteps-1]
          control: Tensor, # [n_batches, n_timesteps, n_dim_contr] control, used for times [1...n_timesteps-1]
          ) -> ListMNormal: # Filtered state
    """Filter the observations and return only the filtered state (the predicted state is dropped)"""
    *_, filt_mean, filt_cov = self._filter_all(obs, mask, control)
    # the means carry a trailing singleton dimension that is dropped here
    return ListMNormal(filt_mean.squeeze(-1), filt_cov)

In [None]:
filt = k.filter(obs, mask, control)
filt.mean.shape, filt.cov.shape

(torch.Size([2, 10, 4]), torch.Size([2, 10, 4, 4]))

## Smooth

### Smooth step

compute the probability of the state at time `t` given all the observations

$p(x_t|Y) = \mathcal{N}(x_t; m_t^s, P_t^s)$ where:

- Kalman smoothing gain: $G_t = P_tA^T(P_{t+1}^-)^{-1}$
- smoothed mean: $m_t^s = m_t + G_t(m_{t+1}^s - m_{t+1}^-)$
- smoothed covariance: $P_t^s = P_t + G_t(P_{t+1}^s - P_{t+1}^-)G_t^T$

In [None]:
#| export
def _smooth_update(A,                # [n_dim_state, n_dim_state]
                   filt_state: MNormal,         # [n_dim_state] filtered state at time `t`
                   pred_state: MNormal,         # [n_dim_state] state before filtering at time `t + 1` (= using the observation until time t)
                   next_smoothed_state: Normal, # [n_dim_state] smoothed state at time  `t+1`
                   cov_checker = CheckPosDef()
                   ) -> MNormal:                # mean and cov of smoothed state at time `t`
    """One backward step of the Kalman smoother.

    Computes the smoothing gain `G = P_t A^T (P_{t+1}^-)^{-1}` and uses it to
    correct the filtered mean and covariance with the smoothed state at `t+1`.
    The resulting covariance is passed to `cov_checker.check`.
    """
    # explicit inverse of the predicted covariance
    # (a Cholesky-based inverse would be a possible alternative)
    pred_cov_inv = torch.inverse(pred_state.cov)
    gain = filt_state.cov @ A.unsqueeze(0).mT @ pred_cov_inv

    # differences between the smoothed and the predicted state at `t+1`
    mean_diff = next_smoothed_state.mean - pred_state.mean
    cov_diff = next_smoothed_state.cov - pred_state.cov

    smoothed_mean = filt_state.mean + gain @ mean_diff
    smoothed_cov = filt_state.cov + gain @ cov_diff @ gain.mT

    cov_checker.check(smoothed_cov, caller='smooth_update')

    return MNormal(smoothed_mean, smoothed_cov)

In [None]:
filt_state, pred_state, next_smoothed_state = [MNormal(m_m, P_m)] * 3 # just for testing

In [None]:
show_as_row(*_smooth_update(A, MNormal(m_m, P_m), MNormal(m_m, P_m), MNormal(m_m, P_m)))

In [None]:
show_as_row(*map(Self.shape(), _smooth_update(A, MNormal(m_m, P_m), MNormal(m_m, P_m), MNormal(m_m, P_m))))

### Smooth

In [None]:
#| export
def _smooth(A, # `[n_dim_state, n_dim_state]`
            filt_state: ListMNormal, # `[n_timesteps, n_dim_state]`
                # `ms[t]` is the state estimate for time t given obs from times `[0...t]`
            pred_state: ListMNormal, # `[n_timesteps, n_dim_state]`
                # `m_ms[t]` is the state estimate for time t given obs from times `[0...t-1]`
            cov_checker = CheckPosDef()
           ) -> ListMNormal: # `[n_timesteps, n_dim_state]` Smoothed state 
    """Apply the Kalman Smoother.

    Runs the backward recursion over the time axis (dimension 1): the last
    timestep is copied from the filtered state and every earlier timestep is
    obtained with `_smooth_update`. `cov_checker` is forwarded to each
    `_smooth_update` call, so the checker chosen by the caller is the one
    that validates the smoothed covariances.
    """
    x = pred_state.mean # sample for getting tensor properties
    bs, n_timesteps, n_dim_state = x.shape[:3]

    smoothed_state = ListMNormal(
        torch.zeros((bs, n_timesteps, n_dim_state, 1),           dtype=x.dtype, device=x.device),
        torch.zeros((bs, n_timesteps, n_dim_state, n_dim_state), dtype=x.dtype, device=x.device),
    )
    # For the last timestep cannot use the smoother
    smoothed_state.mean[:,-1,] = filt_state.mean[:,-1]
    smoothed_state.cov[:,-1] = filt_state.cov[:,-1]

    # walk backwards: step `t` needs the already smoothed state at `t+1`
    for t in reversed(range(n_timesteps - 1)):
        (smoothed_state.mean[:,t], smoothed_state.cov[:,t]) = (
            _smooth_update(
                A,
                filt_state[:,t],
                pred_state[:,t + 1],
                smoothed_state[:,t+1],
                cov_checker,
            )
        )
    return smoothed_state

In [None]:
(m_ms, P_ms, ms, Ps ) = k._filter_all(data, mask, control)
filt_state, pred_state = ListMNormal(ms, Ps), ListMNormal(m_ms, P_ms)

In [None]:
smooth_state = _smooth(k.A,  filt_state, pred_state)

In [None]:
show_as_row(smooth_state.mean[0][0], smooth_state.cov[0][0])

In [None]:
show_as_row(smooth_state.mean.shape, smooth_state.cov.shape)

### KalmanFilter method

In [None]:
#| export
@patch
def smooth(self: KalmanFilter,
           obs: Tensor,
           mask: Tensor,
           control: Tensor
          ) -> ListMNormal: # `[n_timesteps, n_dim_state]` smoothed state
    """Filter the observations, then run the Kalman smoother on the result.

    The trailing dimension of the smoothed means is squeezed in place before
    returning.
    """
    pred_mean, pred_cov, filt_mean, filt_cov = self._filter_all(obs, mask, control)

    filtered = ListMNormal(filt_mean, filt_cov)
    predicted = ListMNormal(pred_mean, pred_cov)

    result = _smooth(self.A, filtered, predicted, self.cov_checker)
    # in-place squeeze: the container is returned as is, only the mean tensor changes
    result.mean.squeeze_(-1)
    return result

In [None]:
smoothed_state = k.smooth(data, mask, control)

In [None]:
show_as_row(smoothed_state.mean.shape, smoothed_state.cov.shape)

## Predict

The predictions at time t ($y_t$) are computed from the state ($x_t$) using this formula:
$$p(y_t|x_t) = \mathcal{N}(Hx_t + d, R + HP^s_tH^T)$$

This works whether the state was filtered or smoothed.

This adds support for conditional predictions: at the time (t) when we are making the predictions, some of the variables may have actually been observed. Since the model prediction is a normal distribution, we can condition on the observed values and thus improve the predictions. See `conditional_gaussian`

In order to have conditional predictions that make sense it's not possible to return the full covariance matrix for the predictions but only the standard deviations

In [None]:
test_m = torch.tensor(
    [[True, True, True,],
    [False, True, True],
    [False, False, False]]
)

In [None]:
torch.logical_xor(test_m.all(-1), test_m.any(-1))

tensor([False,  True, False])

In [None]:
A = torch.rand(2,2,3,3)

In [None]:
(A @ A).shape

torch.Size([2, 2, 3, 3])

predict can be vectorized across both the batch and the timesteps, except for timesteps that require conditional predictions

In [None]:
#| export
@patch
def _obs_from_state(self: KalmanFilter, state: ListMNormal):
    """Map a state distribution to the observation space.

    The mean is `H x + d` and the covariance `H P H^T + R`; every element
    along the first dimension of the covariance is passed to the checker.
    """
    obs_matrix = self.H
    state_col = state.mean.unsqueeze(-1)  # column vector for the matmul

    obs_mean = (obs_matrix @ state_col + self.d.unsqueeze(-1)).squeeze(-1)
    obs_cov = obs_matrix @ state.cov @ obs_matrix.mT + self.R

    # each checked element is still batched over the remaining leading dimension
    for cov_item in obs_cov:
        self.cov_checker.check(cov_item, caller='predict')

    return ListMNormal(obs_mean, obs_cov)

In [None]:
smoothed_state.mean.shape, smoothed_state.cov.shape

(torch.Size([2, 10, 4]), torch.Size([2, 10, 4, 4]))

In [None]:
(k.H @ smoothed_state.mean.unsqueeze(-1)).shape

torch.Size([2, 10, 3, 1])

In [None]:
pred_obs0 = k._obs_from_state(smoothed_state)
pred_obs0.mean.shape

torch.Size([2, 10, 3])

In [None]:
pred_obs0.cov.shape

torch.Size([2, 10, 3, 3])

In [None]:
#| export
@patch
def predict(self: KalmanFilter, obs, mask, control, smooth=True):
    """Predicted observations at all times.

    The state is smoothed (or only filtered when `smooth` is False) and mapped
    to the observation space. When `self.use_conditional` is set, the entries
    where `mask` is False at positions that are only partially observed are
    replaced by the predictions conditioned on the observed variables.

    Returns a `ListNormal` with the predicted means and standard deviations.
    """
    state = self.smooth(obs, mask, control) if smooth else self.filter(obs, mask, control)
    obs, mask = self._parse_obs(obs, mask)

    pred_obs = self._obs_from_state(state)
    pred_mean, pred_std = pred_obs.mean, cov2std(pred_obs.cov)

    if self.use_conditional:
        # conditional predictions are slow, do only if some obs are missing
        cond_mask = torch.logical_xor(mask.all(-1), mask.any(-1))

        if cond_mask.any():
            # this cannot be batched so returns a list
            cond_preds = cond_gaussian_batched(
                pred_obs[cond_mask], obs[cond_mask], mask[cond_mask])

            # Indexing with a boolean mask returns a copy, so writing through
            # `pred_mean[cond_mask][i]` never reaches `pred_mean`. Write at
            # explicit integer positions instead, on clones so the in-place
            # updates do not touch tensors saved for the backward pass.
            pred_mean, pred_std = pred_mean.clone(), pred_std.clone()

            # `nonzero` lists the positions in the same row-major order used
            # by boolean-mask indexing, so they line up with `cond_preds`
            for position, c_pred in zip(cond_mask.nonzero(), cond_preds):
                pos = tuple(position.tolist())
                missing = ~mask[pos]
                pred_mean[pos][missing] = c_pred.mean
                pred_std[pos][missing] = cov2std(c_pred.cov)

    return ListNormal(pred_mean, pred_std)

In [None]:
pred = k.predict(data, mask, control)

In [None]:
pred.mean.shape, pred.std.shape

(torch.Size([2, 10, 3]), torch.Size([2, 10, 3]))

Gradients ...

In [None]:
def get_grad_mask(x):
    """Gradient of `R_raw` after replacing the masked values of `data` with `x`.

    The prediction is computed on the modified copy, so comparing the result
    for different `x` checks that the values at the positions where `mask` is
    False do not influence the gradient.
    """
    d = data.clone()
    d[~mask] = x
    # use the perturbed copy `d`, not the original `data`
    k.predict(d, mask, control).mean.sum().backward(retain_graph=True)
    grad = k.R_raw.grad.clone()
    k.zero_grad()  # reset so that consecutive calls do not accumulate gradients
    return grad

In [None]:
get_grad_mask(10)

tensor([-1.9622,  2.6291,  6.2310], dtype=torch.float64)

In [None]:
test_close(get_grad_mask(1), get_grad_mask(10))

In [None]:
@patch
def predict_times(self: KalmanFilter, times, obs, mask=None, smooth=True, check_args=None):
    """Predicted observations at specific times.

    NOTE(review): this cell is not exported and looks out of date with respect
    to the rest of the notebook; see the review notes below before using it.
    """
    # NOTE(review): `check_args` is passed in the position that `smooth`/`filter`
    # name `control` in this notebook -- confirm this is intended
    state = self.smooth(obs, mask, check_args) if smooth else self.filter(obs, mask, check_args)
    obs, mask = self._parse_obs(obs, mask)
    times = array1d(times)
    
    # NOTE(review): assumes the first dimension of the parsed `obs` is time -- TODO confirm
    n_timesteps = obs.shape[0]
    n_features = obs.shape[1] if len(obs.shape) > 1 else 1
    
    # NOTE(review): the check uses `>`, so `times.max() == n_timesteps` is accepted;
    # presumably that index is already out of range -- confirm
    if times.max() > n_timesteps or times.min() < 0:
        raise ValueError(f"provided times range from {times.min()} to {times.max()}, which is outside allowed range : 0 to {n_timesteps}")

    # one row of means / standard deviations per requested time
    means = torch.empty((times.shape[0], n_features), dtype=obs.dtype, device=obs.device)
    stds = torch.empty((times.shape[0], n_features), dtype=obs.dtype, device=obs.device) 
    for i, t in enumerate(times):
        # NOTE(review): `_obs_from_state` as defined in this notebook takes a single
        # `state` argument and returns a `ListMNormal`; this call passes three arguments
        mean, std = self._obs_from_state(
            state.mean[t],
            state.cov[t],
            {'t': t, **check_args} if check_args is not None else None
        )
        
        # NOTE(review): `_get_cond_pred` is not defined in the visible part of the notebook
        means[i], stds[i] = _get_cond_pred(ListNormal(mean, std), obs[t], mask[t])
    
    return ListNormal(means, stds)  

## Additional

### Get Info

In [None]:
k.H

Parameter containing:
tensor([[0.6627, 0.2116, 0.6291, 0.3670],
        [0.2237, 0.0553, 0.2241, 0.4903],
        [0.4756, 0.3574, 0.5449, 0.4099]], dtype=torch.float64,
       requires_grad=True)

In [None]:
#| export
@patch
def get_info(self: KalmanFilter):
    """Collect the model parameters as a dict of labelled dataframes.

    Rows/columns are named after the observed variables (`var_names`), the
    latent state (`x_i`) and the control variables (`contr_names`); default
    names are generated when the model does not provide them.
    """
    var_names = ifnone(self.var_names, [f"y_{i}" for i in range(self.H.shape[0])])
    latent_names = [f"x_{i}" for i in range(self.A.shape[0])]
    contr_names = ifnone(self.contr_names, [f"c_{i}" for i in range(self.B.shape[1])])

    # insertion order is the order of the returned dict
    return {
        'trans matrix (A)': array2df(self.A,  latent_names, latent_names, 'state'),
        'trans cov (Q)':    array2df(self.Q,  latent_names, latent_names, 'state'),
        'trans off':        array2df(self.b,  latent_names, ['offset'],   'state'),
        'obs matrix (H)':   array2df(self.H,  var_names,    latent_names, 'variable'),
        'obs cov (R)':      array2df(self.R,  var_names,    var_names,    'variable'),
        'obs off':          array2df(self.d,  var_names,    ['offset'],   'variable'),
        'contr matrix (B)': array2df(self.B,  latent_names, contr_names,  'state'),
        'init state mean':  array2df(self.m0, latent_names, ['mean'],     'state'),
        'init state cov':   array2df(self.P0, latent_names, latent_names, 'state'),
    }

In [None]:
k.B

Parameter containing:
tensor([[0.3640, 0.6339, 0.1636],
        [0.7750, 0.8999, 0.4475],
        [0.2087, 0.9819, 0.7791],
        [0.4139, 0.9798, 0.9052]], dtype=torch.float64, requires_grad=True)

In [None]:
display_as_row(k.get_info())

state,x_0,x_1,x_2,x_3
x_0,0.3923,0.7703,0.7068,0.1368
x_1,0.6002,0.4108,0.3458,0.6117
x_2,0.7857,0.167,0.0884,0.5066
x_3,0.1431,0.509,0.9645,0.2983

state,x_0,x_1,x_2,x_3
x_0,1.1591,0.7234,1.4105,0.9115
x_1,0.7234,0.8544,0.7177,0.5024
x_2,1.4105,0.7177,1.9741,1.459
x_3,0.9115,0.5024,1.459,1.7778

state,offset
x_0,0.6728
x_1,0.9842
x_2,0.6946
x_3,0.0369

variable,x_0,x_1,x_2,x_3
y_0,0.6627,0.2116,0.6291,0.367
y_1,0.2237,0.0553,0.2241,0.4903
y_2,0.4756,0.3574,0.5449,0.4099

variable,y_0,y_1,y_2
y_0,1.5641,0.0,0.0
y_1,0.0,1.1859,0.0
y_2,0.0,0.0,0.693

variable,offset
y_0,0.4045
y_1,0.2622
y_2,0.4132

state,c_0,c_1,c_2
x_0,0.364,0.6339,0.1636
x_1,0.775,0.8999,0.4475
x_2,0.2087,0.9819,0.7791
x_3,0.4139,0.9798,0.9052

state,mean
x_0,0.383
x_1,0.4214
x_2,0.5458
x_3,0.5241

state,x_0,x_1,x_2,x_3
x_0,1.3892,0.7181,1.6393,1.3942
x_1,0.7181,0.8852,0.9592,0.8515
x_2,1.6393,0.9592,2.1083,1.658
x_3,1.3942,0.8515,1.658,1.4405


In [None]:
#| export
@patch
def _repr_html_(self: KalmanFilter):
    """Rich notebook representation: the tables from `get_info` under a title with the model dimensions."""
    header = f"Kalman Filter ({self.n_dim_obs} obs, {self.n_dim_state} state, {self.n_dim_contr} contr)"
    tables = self.get_info()
    return row_dfs(tables, header, hide_idx=True)

In [None]:
k

state,x_0,x_1,x_2,x_3
x_0,0.3923,0.7703,0.7068,0.1368
x_1,0.6002,0.4108,0.3458,0.6117
x_2,0.7857,0.167,0.0884,0.5066
x_3,0.1431,0.509,0.9645,0.2983

state,x_0,x_1,x_2,x_3
x_0,1.1591,0.7234,1.4105,0.9115
x_1,0.7234,0.8544,0.7177,0.5024
x_2,1.4105,0.7177,1.9741,1.459
x_3,0.9115,0.5024,1.459,1.7778

state,offset
x_0,0.6728
x_1,0.9842
x_2,0.6946
x_3,0.0369

variable,x_0,x_1,x_2,x_3
y_0,0.6627,0.2116,0.6291,0.367
y_1,0.2237,0.0553,0.2241,0.4903
y_2,0.4756,0.3574,0.5449,0.4099

variable,y_0,y_1,y_2
y_0,1.5641,0.0,0.0
y_1,0.0,1.1859,0.0
y_2,0.0,0.0,0.693

variable,offset
y_0,0.4045
y_1,0.2622
y_2,0.4132

state,c_0,c_1,c_2
x_0,0.364,0.6339,0.1636
x_1,0.775,0.8999,0.4475
x_2,0.2087,0.9819,0.7791
x_3,0.4139,0.9798,0.9052

state,mean
x_0,0.383
x_1,0.4214
x_2,0.5458
x_3,0.5241

state,x_0,x_1,x_2,x_3
x_0,1.3892,0.7181,1.6393,1.3942
x_1,0.7181,0.8852,0.9592,0.8515
x_2,1.6393,0.9592,2.1083,1.658
x_3,1.3942,0.8515,1.658,1.4405


### Constructors

#### Simple parameters

In [None]:
#| export
@patch(cls_method=True)
def init_simple(cls: KalmanFilter,
                n_dim, # n_dim_obs and n_dim_state
                dtype=torch.float64):
    """Build a model with the simplest parameters: identity matrices and zero offsets/means."""
    # factories, so that every parameter gets its own tensor
    def identity():
        return torch.eye(n_dim, dtype=dtype)

    def zero_vector():
        return torch.zeros(n_dim, dtype=dtype)

    return cls(
        A=identity(),
        b=zero_vector(),
        Q=identity(),
        H=identity(),
        d=zero_vector(),
        R=identity(),
        B=identity(),
        m0=zero_vector(),
        P0=identity(),
    )

In [None]:
KalmanFilter.init_simple(2).state_dict()

OrderedDict([('trans_matrix',
              tensor([[1., 0.],
                      [0., 1.]], dtype=torch.float64)),
             ('trans_off', tensor([0., 0.], dtype=torch.float64)),
             ('trans_cov_raw',
              tensor([[1., 0.],
                      [0., 1.]], dtype=torch.float64)),
             ('contr_matrix',
              tensor([[1., 0.],
                      [0., 1.]], dtype=torch.float64)),
             ('obs_matrix',
              tensor([[1., 0.],
                      [0., 1.]], dtype=torch.float64)),
             ('obs_off', tensor([0., 0.], dtype=torch.float64)),
             ('obs_cov_raw', tensor([0., 0.], dtype=torch.float64)),
             ('init_state_mean', tensor([0., 0.], dtype=torch.float64)),
             ('init_state_cov_raw',
              tensor([[1., 0.],
                      [0., 1.]], dtype=torch.float64))])

#### Local slope

Local slope models are an extension of the local level model in which the state variable also keeps track of the slope.

Given $n$ as the number of dimensions of the observations

The transition matrix (`A`) is:

$$A = \left[\begin{array}{cc}I & I \\ 0 & I\end{array}\right]$$

where:

- $I \in \mathbb{R}^{n \times n}$
- $A \in \mathbb{R}^{2n \times 2n}$

the state $x \in \mathbb{R}^{2n \times 1}$ where the upper half keeps track of the level and the lower half of the slope. $A \in \mathbb{R}^{2n \times 2n}$

the observation matrix (`H`) is:

$$H = \left[\begin{array}{cc}I & 0 \end{array}\right]$$

For the multivariate case the 1s are replaced with an identity matrix


Assuming that the control has the same dimension as the observations, for a local slope model we have $B \in \mathbb{R}^{state \times contr}$:
$$ B = \begin{bmatrix} -I & I \\ 0 & 0 \end{bmatrix}$$

In [None]:
#| export
from torch import hstack, eye, vstack, ones, zeros, tensor
from functools import partial
from sklearn.decomposition import PCA

In [None]:
#| exporti
def set_dtype(*args, dtype=torch.float64):
    """Return a list with each callable in `args` wrapped so that `dtype` is passed as a default keyword."""
    wrapped = []
    for factory in args:
        # `partial` keeps `dtype` overridable at call time
        wrapped.append(partial(factory, dtype=dtype))
    return wrapped

eye, ones, zeros, tensor = set_dtype(eye, ones, zeros, tensor)

In [None]:
#| export
# @delegates(KalmanFilter)
@patch(cls_method=True)
def init_local_slope_pca(cls: KalmanFilter,
                n_dim_obs, # n_dim_obs and n_dim_contr
                n_dim_state: int, # n_dim_state
                df_pca: pd.DataFrame|None = None, # dataframe for PCA init, None no PCA init
                **kwargs
            ):
    """Local Slope + PCA init.

    Builds a local slope model whose state has `2 * n_dim_state` components
    (level and slope). With `df_pca` the maps between observations and state
    levels come from the PCA components fitted on it; without it rectangular
    identity matrices of the same shapes are used, so `n_dim_state` does not
    need to equal `n_dim_obs`. Extra `kwargs` are passed to the constructor.
    """
    if df_pca is not None:
        comp = PCA(n_dim_state).fit(df_pca).components_
        H = tensor(comp.T) # transform state -> obs
        B = tensor(comp) # transform obs -> state
    else:
        # same shapes as the PCA branch: H is [n_dim_obs, n_dim_state] and
        # B is [n_dim_state, n_dim_obs]; square identities when the dims match
        H, B = eye(n_dim_obs, n_dim_state), eye(n_dim_state, n_dim_obs)

    return cls(
        # level_t = level_{t-1} + slope_{t-1}, slope_t = slope_{t-1}
        A  = vstack([hstack([eye(n_dim_state),                eye(n_dim_state)]),
                     hstack([zeros(n_dim_state, n_dim_state), eye(n_dim_state)])]),
        b  = zeros(n_dim_state * 2),
        Q  = eye(n_dim_state * 2)*.1,
        # only the level half of the state is observed
        H  = hstack([H, zeros(n_dim_obs, n_dim_state)]),
        d  = zeros(n_dim_obs),
        R  = eye(n_dim_obs)*.01,
        # the control acts on the level half only
        B  = vstack([hstack([-B,                            B]),
                     hstack([zeros(n_dim_state, n_dim_obs), zeros(n_dim_state, n_dim_obs)])]),
        m0 = zeros(n_dim_state * 2),
        P0 = eye(n_dim_state * 2) * 3,
        **kwargs
    )

In [None]:
KalmanFilter.init_local_slope_pca(2,2,pd.DataFrame([[1,2], [2,4]])).state_dict()

OrderedDict([('trans_matrix',
              tensor([[1., 0., 1., 0.],
                      [0., 1., 0., 1.],
                      [0., 0., 1., 0.],
                      [0., 0., 0., 1.]], dtype=torch.float64)),
             ('trans_off', tensor([0., 0., 0., 0.], dtype=torch.float64)),
             ('trans_cov_raw',
              tensor([[0.3162, 0.0000, 0.0000, 0.0000],
                      [0.0000, 0.3162, 0.0000, 0.0000],
                      [0.0000, 0.0000, 0.3162, 0.0000],
                      [0.0000, 0.0000, 0.0000, 0.3162]], dtype=torch.float64)),
             ('contr_matrix',
              tensor([[-0.4472, -0.8944,  0.4472,  0.8944],
                      [-0.8944,  0.4472,  0.8944, -0.4472],
                      [ 0.0000,  0.0000,  0.0000,  0.0000],
                      [ 0.0000,  0.0000,  0.0000,  0.0000]], dtype=torch.float64)),
             ('obs_matrix',
              tensor([[ 0.4472,  0.8944,  0.0000,  0.0000],
                      [ 0.8944, -0.4472,  0.0000, 

## Export

In [None]:
#| hide
from nbdev import nbdev_export
nbdev_export()