# Kalman Filter
> Implementation of Kalman filters using PyTorch and parameter optimization with gradient descent

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
#| default_exp kalman.filter

In [None]:
#| export
from fastcore.test import *
from fastcore.basics import *
from meteo_imp.utils import *
from meteo_imp.gaussian import *
from meteo_imp.data_preparation import MeteoDataTest
from typing import *
from functools import partial

import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torch.distributions import MultivariateNormal

## Introduction

The model uses a latent state variable $x$, modelled over time, to impute gaps in the observations $y$.

The model assumes that the state $x_t$ at time $t$ depends only on the previous state $x_{t-1}$ and not on earlier states.

### Equations

The equations of the model are:

$$\begin{align} p(x_t | x_{t-1}) & = \mathcal{N}(Ax_{t-1} + Bc_t + b, Q) \\
p(y_t | x_t) & = \mathcal{N}(Hx_t + d, R) \end{align}$$


where:

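- $A$ is the state transition matrix `A`
- $B$ is the control matrix `B` and $c_t$ are the control variables at time $t$
- $b$ is the state transition offset `b`
- $Q$ is the state transition covariance `Q`
- $H$ is the observation matrix `H` 
- $d$ is the observation offset `d`
- $R$ is the observation covariance `R`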

In addition, the model has the parameters of the initial state, which are used to initialize the filter:

- `m0`
- `P0`

The Kalman filter has 3 steps:

- predict (estimate the state at time t using the observations up to time t-1)
- update (update the state at time t using the observation at time t)
- smooth (update the state at time t using the observations after time t)

If all the observations at time t are missing, the update step is skipped (partially missing observations are handled separately, see below).

After smoothing, the missing data at time t ($y_t$) can be imputed from the smoothed state (mean $m^s_t$, covariance $P^s_t$) using this formula, where $Y$ are all the observations:
$$p(y_t|Y) = \mathcal{N}(Hm^s_t + d, R + HP^s_tH^T)$$

The Kalman Filter is an algorithm designed to estimate $P(x_t | y_{0:t})$. As all state transitions and observations are linear with Gaussian distributed noise, these distributions can be represented exactly as Gaussian distributions with mean `ms[t]` and covariance `Ps[t]`.
Similarly, the Kalman Smoother is an algorithm designed to estimate $P(x_t | y_{0:T})$, i.e. the state given all the observations.



# Kalman Filter Base

### Utils

In [None]:
#| export
def _add_batch_dim(x):
    """make x 3 dimensional by adding empty dims in the correct place"""
    if x.dim() == 1: return x.unsqueeze(0).unsqueeze(-1)
    elif x.dim() == 2: return x.unsqueeze(0)
    else: return x

In [None]:
#| export
def _add_batch_dims_iter(*xs):
    """Apply `_add_batch_dim` to every argument and return the results as a list."""
    return list(map(_add_batch_dim, xs))

In [None]:
show_as_row(_add_batch_dim(torch.ones(2)), _add_batch_dim(torch.ones(2,2)), _add_batch_dim(torch.ones(2,2,2)))

In [None]:
show_as_row(_add_batch_dim(torch.ones(2)).shape, _add_batch_dim(torch.ones(2,2)).shape, _add_batch_dim(torch.ones(2,2,2)).shape)

In [None]:
#| export
def _check_same_size(
    os: Sequence[tuple[Tensor, int]], # sequences of tensors and the dimension to check
    size=None, # Optional size of the common dimension
)-> int: # size of common dimension
    """Check that all args have the same size at the given dimension, raise `ValueError` if not """
    size = ifnone(size, os[0][0].shape[os[0][1]])
    if not all([size == x.shape[dim] for x, dim in os]):
        raise ValueError("All parameters must have the same size at the given dimension")
    return size

## Kalman Filter Base

In [None]:
#| export
class KalmanFilterBase(torch.nn.Module):
    """Base class for handling Kalman Filter implementation in PyTorch.

    All parameters are registered as `torch.nn.Parameter`s after giving them a
    leading batch dimension (1-d/2-d inputs become 3-d, see `_add_batch_dim`).
    Parameters that have a constraint in `params_constr` (`Q`, `R`, `P0`) are
    stored unconstrained as `<name>_raw` and exposed through properties that
    apply the constraint on every access.
    """
    
    params_constr = {
        #name constraint
        'A':  None        ,
        'b':  None        ,
        'Q':  PosDef(),
        'B':  None        ,
        'H':  None        ,
        'd':  None        ,
        'R':  PosDef(),
        'm0': None       ,
        'P0': PosDef()   ,
        }
    
    def __init__(self,
            A: Tensor,                             # [n_dim_state,n_dim_state] $A$, state transition matrix 
            H: Tensor,                             # [n_dim_obs, n_dim_state] $H$, observation matrix
            B: Tensor,                             # [n_dim_state, n_dim_contr] $B$ control matrix
            Q: Tensor,                             # [n_dim_state, n_dim_state] $Q$, state trans covariance matrix
            R: Tensor,                             # [n_dim_obs, n_dim_obs] $R$, observations covariance matrix
            b: Tensor,                             # [n_dim_state] $b$, state transition offset
            d: Tensor,                             # [n_dim_obs] $d$, observations offset
            m0: Tensor,                            # [n_dim_state] $m_0$
            P0: Tensor,                            # [n_dim_state, n_dim_state] $P_0$
    
            n_dim_state: int = None,               # Number of dimensions for state - default infered from parameters
            n_dim_obs: int = None,                 # Number of dimensions for observations - default  infered from parameters
            n_dim_contr: int = None,               # Number of dimensions for control - default infered from parameters
            
            var_names: Iterable[str]|None = None,  # Names of variables for printing 
            contr_names: Iterable[str]|None = None,# Names of control variables for printing
    
            cov_checker: CheckPosDef|None = None,  # Check covariance at every step
            use_conditional: bool = True,          # Use conditional distribution for gaps that don't have all variables missing
            use_control: bool = True,              # Use the control in the filter
            use_smooth: bool = True,               # Use smoother for predictions (otherwise is filter only)
                ):
        
        super().__init__()
        store_attr("var_names, contr_names, use_conditional, use_control, use_smooth, cov_checker")
        
        # give every parameter a leading batch dimension so it broadcasts over batches
        A, H, B, Q, R, b, d, m0, P0 = _add_batch_dims_iter(A, H, B, Q, R, b, d, m0, P0)
        
        self._check_params(A, H, B, Q, R, b, d, m0, P0, n_dim_state, n_dim_obs, n_dim_contr)
        self._init_params(A=A, H=H, B=B, Q=Q, R=R, b=b, d=d, m0=m0, P0=P0)

    
    def _check_params(self, A, H, B, Q, R, b, d, m0, P0, n_dim_state, n_dim_obs, n_dim_contr):
        """Check that the parameter dimensions are consistent and set the `n_dim_*` attributes.

        Raises `ValueError` (from `_check_same_size`) when two parameters disagree
        on a shared dimension or disagree with an explicitly given `n_dim_*`.
        """
        self.n_dim_state = _check_same_size(
            [(A,  -2),
             (b,  -2),
             (Q,  -2),
             (B,  -2), # control matrix maps the control into the state space
             (m0, -2),
             (P0, -2),
             (H,  -1)],
            n_dim_state
        )
        self.n_dim_obs = _check_same_size(
            [(H, -2),
             (d, -2),
             (R, -2)],
            n_dim_obs
        )
        
        self.n_dim_contr = _check_same_size([(B, -1)], n_dim_contr)
        
        
    def _init_params(self, **params):
        """Register every parameter; constrained ones are stored as `<name>_raw`."""
        for name, value in params.items():
            if (constraint := self.params_constr[name]) is not None:
                name, value = self._init_constraint(name, value, constraint)
            self._init_param(name, value, train=True)    
    
    def _init_param(self, param_name, value, train):
        """Register (or replace) `value` as the parameter `param_name`."""
        self.register_parameter(param_name, torch.nn.Parameter(value, requires_grad=train))
    
    ### === Constraints utils
    def _init_constraint(self, param_name, value, constraint):
        """Store `constraint` as `<param_name>_constraint` and return the raw name and unconstrained value."""
        name = f"{param_name}_raw"
        value = constraint.inverse_transform(value)
        setattr(self, param_name + "_constraint", constraint)
        return name, value
    
    def _get_constraint(self, param_name):
        """get the original value (apply the constraint transform to `<param_name>_raw`)"""
        constraint = getattr(self, param_name + "_constraint")
        raw_value = getattr(self, f"{param_name}_raw")
        return constraint.transform(raw_value)

    def _set_constraint(self, value, param_name, train=True):
        """set the transformed value (store its inverse transform as `<param_name>_raw`)"""
        constraint = getattr(self, param_name + "_constraint")
        raw_value = constraint.inverse_transform(value)
        self._init_param(f"{param_name}_raw", raw_value, train)
            
               
    @property
    def Q_C(self):
        "Cholesky factor of Q"
        return self.Q_raw
    @property
    def Q(self): return self._get_constraint('Q')
    @Q.setter
    def Q(self, value): self._set_constraint(value, 'Q')

    @property
    def R_C(self):
        "Cholesky factor of R"
        return self.R_raw
    @property
    def R(self): return self._get_constraint('R')
    @R.setter
    def R(self, value): self._set_constraint(value, 'R')
    
    @property
    def P0_C(self):
        "Cholesky factor of P0"
        return self.P0_raw
    @property
    def P0(self): return self._get_constraint('P0')
    @P0.setter
    def P0(self, value): self._set_constraint(value, 'P0')


    ### === Utility Func    
    def _parse_obs(self, obs, mask, control):
        """Add the batch dimension to `obs`, `mask` and `control`, plus a trailing dimension to `obs` and `control`.

        The mask is used as given: it is not inferred from `NaN`s in `obs`.
        """
        return _add_batch_dim(obs).unsqueeze(-1), _add_batch_dim(mask), _add_batch_dim(control).unsqueeze(-1)
    def __repr__(self):
        return f"""Kalman Filter
        N dim obs: {self.n_dim_obs},
        N dim state: {self.n_dim_state},
        N dim contr: {self.n_dim_contr}"""

### Constructors

Giving all the parameters manually to the `KalmanFilterBase` init method is not convenient, hence there are some methods that help initialize the class.

Due to a bug in fastcore, it is not possible to subclass after creating class methods, so all the subclasses are defined first.

In [None]:
#| export
class KalmanFilter(KalmanFilterBase):
    """Standard Kalman filter; its filtering methods are attached later with fastcore `patch`."""

In [None]:
#| export
class KalmanFilterSR(KalmanFilterBase):
    """Kalman filter variant (the `SR` suffix presumably means square-root — TODO confirm); currently adds nothing to `KalmanFilterBase`."""

In [None]:
#| export
# Every filter class: used as the target of fastcore `patch_to`/`patch` so that shared
# methods (e.g. `init_random`, `_repr_html_`) are attached to all the classes at once
filter_classes = [KalmanFilterBase, KalmanFilter, KalmanFilterSR]

#### Random parameters

In [None]:
#| export
#| include: false
@patch_to(filter_classes, cls_method=True)
def init_random(cls, n_dim_obs, n_dim_state, n_dim_contr, dtype=torch.float64, **kwargs):
    """Build a filter whose parameters are drawn with `torch.rand`.

    `Q` and `R` are passed through `to_diagposdef` and `P0` through `to_posdef`;
    extra `kwargs` are forwarded to the constructor. The random draws happen in
    the fixed order A, b, Q, B, H, d, R, m0, P0, so results are reproducible under a seed.
    """
    rand = partial(torch.rand, dtype=dtype)
    n_s, n_o, n_c = n_dim_state, n_dim_obs, n_dim_contr
    params = dict(
        A  = rand(n_s, n_s),
        b  = rand(n_s),
        Q  = to_diagposdef(rand(n_s, n_s)),
        B  = rand(n_s, n_c),
        H  = rand(n_o, n_s),
        d  = rand(n_o),
        R  = to_diagposdef(rand(n_o, n_o)),
        m0 = rand(n_s),
        P0 = to_posdef(rand(n_s, n_s)),
    )
    return cls(**params, **kwargs)
        

In [None]:
kB = KalmanFilterBase.init_random(3,4, 3, dtype=torch.float64)
kB

Kalman Filter
        N dim obs: 3,
        N dim state: 4,
        N dim contr: 3

In [None]:
test_close(kB.Q_C @ kB.Q_C.mT, kB.Q, eps=2e-5)

In [None]:
# P0 must match the filter: n_dim_state=4 and float64 like the other parameters
kB.P0 = to_posdef(torch.rand(1, 4, 4, dtype=torch.float64))

Check that assignment works :)

In [None]:
kB.P0 = to_posdef(torch.rand(4, 4, dtype=torch.float64))

In [None]:
kB.P0_C

Parameter containing:
tensor([[0.9340, 0.0000, 0.0000, 0.0000],
        [1.0605, 0.2445, 0.0000, 0.0000],
        [1.0207, 0.4495, 0.8036, 0.0000],
        [0.4206, 0.1265, 0.0767, 0.0983]], dtype=torch.float64,
       requires_grad=True)

In [None]:
kB = KalmanFilterBase.init_random(3,4, 3, dtype=torch.float64)
kB

Kalman Filter
        N dim obs: 3,
        N dim state: 4,
        N dim contr: 3

In [None]:
list(kB.named_parameters())

[('A',
  Parameter containing:
  tensor([[[0.2082, 0.1327, 0.0447, 0.6860],
           [0.1764, 0.5756, 0.4995, 0.9907],
           [0.6185, 0.2050, 0.4548, 0.4365],
           [0.7427, 0.5919, 0.4975, 0.0609]]], dtype=torch.float64,
         requires_grad=True)),
 ('H',
  Parameter containing:
  tensor([[[0.2712, 0.3454, 0.9073, 0.9889],
           [0.3629, 0.4210, 0.5668, 0.9796],
           [0.8114, 0.0584, 0.9681, 0.8442]]], dtype=torch.float64,
         requires_grad=True)),
 ('B',
  Parameter containing:
  tensor([[[0.9434, 0.6200, 0.8832],
           [0.9176, 0.5619, 0.1946],
           [0.5095, 0.0843, 0.9053],
           [0.5472, 0.8411, 0.9707]]], dtype=torch.float64, requires_grad=True)),
 ('Q_raw',
  Parameter containing:
  tensor([[[0.1125, 0.0000, 0.0000, 0.0000],
           [0.0000, 0.4965, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.9714, 0.0000],
           [0.0000, 0.0000, 0.0000, 0.9973]]], dtype=torch.float64,
         requires_grad=True)),
 ('R_raw',
  Parameter

In [None]:
kB.state_dict()

OrderedDict([('A',
              tensor([[[0.2082, 0.1327, 0.0447, 0.6860],
                       [0.1764, 0.5756, 0.4995, 0.9907],
                       [0.6185, 0.2050, 0.4548, 0.4365],
                       [0.7427, 0.5919, 0.4975, 0.0609]]], dtype=torch.float64)),
             ('H',
              tensor([[[0.2712, 0.3454, 0.9073, 0.9889],
                       [0.3629, 0.4210, 0.5668, 0.9796],
                       [0.8114, 0.0584, 0.9681, 0.8442]]], dtype=torch.float64)),
             ('B',
              tensor([[[0.9434, 0.6200, 0.8832],
                       [0.9176, 0.5619, 0.1946],
                       [0.5095, 0.0843, 0.9053],
                       [0.5472, 0.8411, 0.9707]]], dtype=torch.float64)),
             ('Q_raw',
              tensor([[[0.1125, 0.0000, 0.0000, 0.0000],
                       [0.0000, 0.4965, 0.0000, 0.0000],
                       [0.0000, 0.0000, 0.9714, 0.0000],
                       [0.0000, 0.0000, 0.0000, 0.9973]]], dtype=torch.float64)

#### From filter

In [None]:
#| export
@patch(cls_method=True)
def init_from(cls: KalmanFilter|KalmanFilterBase|KalmanFilterSR, o: filter_classes # Other filter
             ):
    """Initialize Filter by copying all parameters from another one.

    Every tensor is detached and cloned: `torch.nn.Parameter` wraps the given tensor
    without copying it, so without the clone the new filter would share storage
    with `o` and an in-place optimizer step on one filter would change the other.
    """
    def _copy(t): return t.detach().clone()
    return cls(A=_copy(o.A), H=_copy(o.H), B=_copy(o.B), Q=_copy(o.Q), R=_copy(o.R),
               b=_copy(o.b), d=_copy(o.d), m0=_copy(o.m0), P0=_copy(o.P0),
               n_dim_state=o.n_dim_state, n_dim_obs=o.n_dim_obs, n_dim_contr=o.n_dim_contr,
               var_names=o.var_names, contr_names=o.contr_names, cov_checker=o.cov_checker,
               use_conditional=o.use_conditional, use_control=o.use_control, use_smooth=o.use_smooth)

In [None]:
k1 = KalmanFilter.init_random(3,4,3)
k2 = KalmanFilterSR.init_from(k1)
for p1, p2 in zip(k1.parameters(), k2.parameters()):
    test_close(p1,p2, eps=1e-3) #noise added by contraints
                

### Get Info

In [None]:
#| export
@patch
def get_info(self: KalmanFilterBase):
    """Collect every model parameter as a labelled table.

    Returns a dict mapping the LaTeX name of each parameter to the table built by
    `array2df` from its first batch element. Labels come from `var_names` /
    `contr_names` when set, otherwise they are generated (`y_i`, `x_i`, `c_i`).
    """
    obs_labels = ifnone(self.var_names, [f"y_{i}" for i in range(self.n_dim_obs)])
    state_labels = [f"x_{i}" for i in range(self.n_dim_state)]
    contr_labels = ifnone(self.contr_names, [f"c_{i}" for i in range(self.n_dim_contr)])
    return {
        '$A$':   array2df(self.A[0],  state_labels, state_labels, 'state'),
        '$Q$':   array2df(self.Q[0],  state_labels, state_labels, 'state'),
        '$b$':   array2df(self.b[0],  state_labels, ['offset'],   'state'),
        '$H$':   array2df(self.H[0],  obs_labels,   state_labels, 'variable'),
        '$R$':   array2df(self.R[0],  obs_labels,   obs_labels,   'variable'),
        '$d$':   array2df(self.d[0],  obs_labels,   ['offset'],   'variable'),
        '$B$':   array2df(self.B[0],  state_labels, contr_labels, 'state'),
        '$m_0$': array2df(self.m0[0], state_labels, ['mean'],     'state'),
        '$P_0$': array2df(self.P0[0], state_labels, state_labels, 'state'),
    }

In [None]:
#| export
@patch
def _repr_html_(self: filter_classes):
    """Notebook HTML representation: the parameter tables from `get_info`, rendered with `row_dfs` under a summary title."""
    dims = f"{self.n_dim_obs} obs, {self.n_dim_state} state, {self.n_dim_contr} contr"
    return row_dfs(self.get_info(), f"Kalman Filter ({dims})", hide_idx=True)

In [None]:
kB

state,x_0,x_1,x_2,x_3
x_0,0.2082,0.1327,0.0447,0.686
x_1,0.1764,0.5756,0.4995,0.9907
x_2,0.6185,0.205,0.4548,0.4365
x_3,0.7427,0.5919,0.4975,0.0609

state,x_0,x_1,x_2,x_3
x_0,0.0127,0.0,0.0,0.0
x_1,0.0,0.2465,0.0,0.0
x_2,0.0,0.0,0.9436,0.0
x_3,0.0,0.0,0.0,0.9946

state,offset
x_0,0.317
x_1,0.5824
x_2,0.8291
x_3,0.61

variable,x_0,x_1,x_2,x_3
y_0,0.2712,0.3454,0.9073,0.9889
y_1,0.3629,0.421,0.5668,0.9796
y_2,0.8114,0.0584,0.9681,0.8442

variable,y_0,y_1,y_2
y_0,0.568,0.0,0.0
y_1,0.0,0.4421,0.0
y_2,0.0,0.0,0.3442

variable,offset
y_0,0.8146
y_1,0.4824
y_2,0.2898

state,c_0,c_1,c_2
x_0,0.9434,0.62,0.8832
x_1,0.9176,0.5619,0.1946
x_2,0.5095,0.0843,0.9053
x_3,0.5472,0.8411,0.9707

state,mean
x_0,0.3079
x_1,0.7986
x_2,0.6972
x_3,0.5632

state,x_0,x_1,x_2,x_3
x_0,1.2375,0.4817,1.4902,0.8995
x_1,0.4817,0.6202,1.0675,0.4405
x_2,1.4902,1.0675,2.7101,1.0355
x_3,0.8995,0.4405,1.0355,0.8541


### Test data

In [None]:
#| exporti
def get_test_data(n_obs = 10, n_dim_obs=3, n_dim_contr = 3, gap=.3, fixed_gap=False, bs=2, dtype=torch.float64, device='cpu'):
    """Random test data, returns `(data, mask, control)`.

    - `data`: `[bs, n_obs, n_dim_obs]`, set to `nan` where `mask` is False
    - `mask`: boolean `[bs, n_obs, n_dim_obs]`, True where the observation is available
    - `control`: `[bs, n_obs, n_dim_contr]`

    With `fixed_gap=False` every value is missing independently with probability `gap`.
    With `fixed_gap=True`, `gap` is the length in time steps (truncated to an even integer)
    of a gap centred at `n_obs // 2` where all variables of all batches are missing.
    """
    data = torch.rand(bs, n_obs, n_dim_obs, dtype=dtype, device=device)
    # drawn in both branches so the random sequence does not depend on `fixed_gap`
    noise = torch.rand(bs, n_obs, n_dim_obs, device=device)
    if fixed_gap:
        half = int(gap) // 2
        start, end = max(n_obs//2 - half, 0), n_obs//2 + half
        mask = torch.ones(bs, n_obs, n_dim_obs, dtype=torch.bool, device=device)
        mask[:, start:end, :] = False
    else:
        mask = noise > gap
    control = torch.rand(bs, n_obs, n_dim_contr, dtype=dtype, device=device)
    data[~mask] = torch.nan # ensure that the missing data cannot be used
    return data, mask, control

In [None]:
reset_seed()
data, mask, control = get_test_data()
show_as_row(data, mask, control)

# Standard Kalman Filter

In [None]:
k = KalmanFilter.init_random(3,4,3)

## Filter

### Filter predict

Probability of the state at time `t` given the state at time `t-1` 

$p(x_t) = \mathcal{N}(x_t; m_t^-, P_t^-)$ where:

- predicted state mean: $m_t^- = Am_{t-1} + B c_t + b$  

- predicted state covariance: $P_t^- = AP_{t-1}A^T + Q$

In [None]:
A, Q, b, B, m_pr,P_pr= (k.A, k.Q, k.b, k.B,torch.concat([k.m0]*2), torch.concat([k.P0]*2))

In [None]:
m_pr.shape, P_pr.shape, A.shape

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]), torch.Size([1, 4, 4]))

#### Covariance

implement $P_t^- = AP_{t-1}A^T + Q$

In [None]:
#| export
def _filter_predict_cov_stand(A, Q, P_pr):
    """Standard - Kalman Filter predict covariance"""
    return A @ P_pr @ A.mT + Q

In [None]:
P_m = _filter_predict_cov_stand(A, Q, P_pr)
P_m

tensor([[[4.8187, 4.4353, 4.9715, 4.9817],
         [4.4353, 5.3111, 5.1082, 5.1311],
         [4.9715, 5.1082, 5.7240, 5.7099],
         [4.9817, 5.1311, 5.7099, 6.3261]],

        [[4.8187, 4.4353, 4.9715, 4.9817],
         [4.4353, 5.3111, 5.1082, 5.1311],
         [4.9715, 5.1082, 5.7240, 5.7099],
         [4.9817, 5.1311, 5.7099, 6.3261]]], dtype=torch.float64,
       grad_fn=<AddBackward0>)

#### Mean

In [None]:
#| export
def _filter_predict_mean(
    A,      # transition matrix
    B,      # control matrix
    b,      # transition offset
    m_pr,   # Mean previous time step $m_{t-1}$
    control, # control variable
):
    return A @ m_pr + B @ control + b

#### Predict

In [None]:
#| export
def _filter_predict(A,
                    Q,
                    b,
                    B, # [1, n_dim_state, n_dim_contr]
                    m_pr,
                    P_pr,
                    control, # [n_batches, n_dim_contr, 1]
                    ):
    """Predict step: propagate the state at `t-1` (`m_pr`, `P_pr`) to time `t`.

    Returns the tuple `(m_m, P_m)` with the predicted mean and covariance.
    """
    return (_filter_predict_mean(A, B, b, m_pr, control),
            _filter_predict_cov_stand(A, Q, P_pr))

In [None]:
#| export
def _filter_predict_mean(
    A,      # transition matrix
    B,      # control matrix
    b,      # transition offset
    m_pr,   # Mean previous time step $m_{t-1}$
    control, # control variable
):
    return A @ m_pr + B @ control + b

In [None]:
#| export
def unsqueeze_iter(*args, dim):
    """Unsqueeze every tensor in `args` at `dim`, returning the results as a list."""
    return [torch.unsqueeze(t, dim=dim) for t in args]
unsqueeze_first = partial(unsqueeze_iter, dim=0)
unsqueeze_last = partial(unsqueeze_iter, dim=-1)

In [None]:
#| export
def _filter_predict(A,
                    Q,
                    b,
                    B, # [1, n_dim_state, n_dim_contr]
                    m_pr,
                    P_pr,
                    control, # [n_batches, n_dim_contr, 1]
                    ):
    """Calculate the state at time `t` given the state at time `t-1`; returns `(mean, covariance)`."""
    mean = _filter_predict_mean(A, B, b, m_pr, control)
    cov = _filter_predict_cov_stand(A, Q, P_pr)
    return mean, cov

In [None]:
B.shape

torch.Size([1, 4, 3])

In [None]:
control.shape

torch.Size([2, 10, 3])

In [None]:
B[0].shape

torch.Size([4, 3])

In [None]:
B[0] @ control[0, 0].unsqueeze(-1)

tensor([[0.7217],
        [0.9572],
        [0.5742],
        [1.2138]], dtype=torch.float64, grad_fn=<MmBackward0>)

In [None]:
m_m, P_m = _filter_predict(
    A, Q, b, B,
    m_pr,P_pr, control[:, 0, :].unsqueeze(-1))

In [None]:
show_as_row(m_m, P_m)

In [None]:
(m_m.shape, P_m.shape,)

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]))

### Filter update

Probability of state at time `t` given the observations at time `t`

$p(x_t|y_t) = \mathcal{N}(x_t; m_t, P_t)$ where:

- predicted obs mean: $z_t = Hm_t^- + d$  

- predicted obs covariance: $S_t = HP_t^-H^T + R$

- Kalman gain: $K_t = P_t^-H^TS_t^{-1}$ 

- corrected state mean: $m_t = m_t^- + K_t(y_t - z_t)$ 

- corrected state covariance: $P_t = (I-K_tH)P_t^-$ 

If the observations are missing, this step is skipped and the corrected state is equal to the predicted state.


TODO: figure out how to handle the `NaN`s of the missing observations so that the gradients are not `NaN`.

#### Kalman Gain

Don't compute the inverse of the matrix, but use `cholesky_solve` to invert the matrix

In [None]:
H, d, R, obs = k.H, k.d, k.R, data[:,0,:].unsqueeze(-1)

In [None]:
#| export
def _filter_update_k_gain(H, R,P_m):
    "kalman gain for filter update"
    S = H @ P_m @ H.mT + R
    S_C = torch.linalg.cholesky(S)
    return torch.cholesky_solve(H @ P_m.mT, S_C).mT

In [None]:
K = _filter_update_k_gain(H, R, P_m)
K

tensor([[[ 0.0307,  0.3311,  0.0291],
         [-0.2621,  0.4961,  0.0624],
         [ 0.0015,  0.4326, -0.0126],
         [ 0.2272,  0.3601, -0.0535]],

        [[ 0.0307,  0.3311,  0.0291],
         [-0.2621,  0.4961,  0.0624],
         [ 0.0015,  0.4326, -0.0126],
         [ 0.2272,  0.3601, -0.0535]]], dtype=torch.float64,
       grad_fn=<TransposeBackward0>)

In [None]:
test_close(_filter_update_k_gain(H, R, P_m), P_m @ H.mT @ torch.inverse(H @ P_m @ H.mT + R))

#### Covariance

In [None]:
#| export
def _filter_update_cov(H, K, P_m):
    """Corrected state covariance $P_t = (I - K_t H) P_t^-$."""
    identity = eye_like(P_m)
    return (identity - K @ H) @ P_m

In [None]:
P = _filter_update_cov(H, K, P_m)
P

tensor([[[ 0.5970, -0.1075,  0.1579, -0.0676],
         [-0.1075,  0.3210, -0.0807, -0.2334],
         [ 0.1579, -0.0807,  0.2325, -0.0442],
         [-0.0676, -0.2334, -0.0442,  0.2359]],

        [[ 0.5970, -0.1075,  0.1579, -0.0676],
         [-0.1075,  0.3210, -0.0807, -0.2334],
         [ 0.1579, -0.0807,  0.2325, -0.0442],
         [-0.0676, -0.2334, -0.0442,  0.2359]]], dtype=torch.float64,
       grad_fn=<UnsafeViewBackward0>)

#### Mean

In [None]:
z = H @ m_m + d; z
(obs - z)

tensor([[[-3.3343],
         [    nan],
         [    nan]],

        [[-2.4181],
         [-5.0754],
         [-3.9145]]], dtype=torch.float64, grad_fn=<SubBackward0>)

In [None]:
#| export
def _filter_update_mean(H, d, K, m_m, y):
    z = H @ m_m + d
    return m_m + K @ (y - z)

In [None]:
m = _filter_update_mean(H, d, K, m_m, obs)
m

tensor([[[    nan],
         [    nan],
         [    nan],
         [    nan]],

        [[ 0.2687],
         [-0.5319],
         [-0.6131],
         [ 0.6636]]], dtype=torch.float64, grad_fn=<AddBackward0>)

In [None]:
#| export
def _filter_update(
    H, # [1, n_dim_obs, n_dim_state]
    d, # [1, n_dim_obs, 1]
    R, # [1, n_dim_obs, n_dim_obs]
    m_m, # [n_batches, n_dim_state, 1]
    P_m, # [n_batches, n_dim_state, n_dim_state]
    obs # [n_batches, n_dim_obs, 1]
) -> Tuple: # Filtered state: mean [n_batches, n_dim_state, 1], covariance [n_batches, n_dim_state, n_dim_state]
    """Filter update: correct the predicted state at `t` with all the observations at `t` (no missing values)."""
    gain = _filter_update_k_gain(H, R, P_m)
    return (_filter_update_mean(H, d, gain, m_m, obs),
            _filter_update_cov(H, gain, P_m))

In [None]:
m, P = _filter_update(H, d, R, m_m, P_m, obs)
show_as_row(m, P)
m.shape, P.shape

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]))

there are `nan` in the output because there are `nan` in the observations

The next functions add support for missing observations by also using a mask.

#### Missing observations

If all the observations at time $t$ are missing, the update step is skipped and the filtered state at time $t$ ($m_t$, $P_t$) is the same as the predicted state ($m_t^-$, $P_t^-$).

If only some observations are missing, a variation of the update equations can be used.

$y^{ng}_t$ is a vector containing the observations that are not missing at time $t$. 

It can be expressed as a linear transformation of $y_t$

$$ y^{ng}_t = My_t$$

where $M$ is a mask matrix that is used to select the subset of $y_t$ that is observed. $M \in \mathbb{R}^{n_{ng} \times n}$ and each of its rows is all zeros except for a single 1 in the column corresponding to a non-missing observation.
hence:

$$ p(y^{ng}_t) = \mathcal{N}(M\mu_{y_t},  M\Sigma_{y_t}M^T)$$

from which you can derive

$$ p(y^{ng}_t|x_t) = \mathcal{N}(MHx_t + Md, MRM^T) $${#eq-filter-correct}

Then the posterior $p(x_t|y_t^{ng})$ can be computed similarly to the standard filter update, using @eq-filter-correct as the observation model:

$$ p(x_t|y^{ng}_t) = \mathcal{N}(x_t; m_t, P_t) $${#eq-filter_correct_missing}
    
where:

*  predicted obs mean: $z_t = MHm_t^- + Md$
*  predicted obs covariance: $S_t = MHP_t^-(MH)^T + MRM^T$
*  Kalman gain $K_t = P_t^-(MH)^TS_t^{-1}$
*  corrected state mean: $m_t = m_t^- + K_t(My_t - z_t)$
*  corrected state covariance: $P_t = (I-K_tMH)P_t^-$


##### Details implementation 

For the implementation the matrix multiplication $MH$ can be replaced with `H[m]` where `m` is the mask for the rows for `H` and $MRM^T$ with `R[m][:,m]`

In [None]:
H, R, d,obs, mm = k.H, k.R, k.d, data[:,0,:].unsqueeze(-1), mask[:,0,:].unsqueeze(-1)

In [None]:
m = torch.tensor([False,True,True]) # mask batch
M = torch.tensor([[[0,1,0], # mask matrix
                  [0,0,1]]], dtype=torch.float64)
show_as_row(m, M, H, R)

In [None]:
M @ H, H[:, m]

(tensor([[[0.1679, 0.8635, 0.3753, 0.9760],
          [0.2125, 0.8049, 0.2124, 0.6794]]], dtype=torch.float64,
        grad_fn=<UnsafeViewBackward0>),
 tensor([[[0.1679, 0.8635, 0.3753, 0.9760],
          [0.2125, 0.8049, 0.2124, 0.6794]]], dtype=torch.float64,
        grad_fn=<IndexBackward0>))

In [None]:
M @ R @ M.mT, R[:,m][:,:,m]

(tensor([[[0.0022, 0.0000],
          [0.0000, 0.9592]]], dtype=torch.float64,
        grad_fn=<UnsafeViewBackward0>),
 tensor([[[0.0022, 0.0000],
          [0.0000, 0.9592]]], dtype=torch.float64, grad_fn=<IndexBackward0>))

With partially missing observations, `_filter_update` cannot be easily batched, as the shape of the intermediate variables depends on the number of observed variables. So the idea is to divide the batch into blocks that share the same mask (i.e. the same set of missing variables).

In [None]:
mask_values, indices = torch.unique(mask[:,1,:], dim=0, return_inverse=True)
mask_values, indices

(tensor([[ True, False,  True],
         [ True,  True, False]]),
 tensor([0, 1]))

##### Update mask

In [None]:
mm = mask[0,0,:]

In [None]:
#| export
def _filter_update_mask(
        H, # [1, n_dim_obs, n_dim_state]
        d, # [1, n_dim_obs, 1]
        R, # [1, n_dim_obs, n_dim_obs]
        m_m, # [n_batches, n_dim_state, 1]
        P_m, # [n_batches, n_dim_state, n_dim_state]
        obs, # [n_batches, n_dim_obs, 1] observations
        mask # [n_dim_obs] mask must be the same across batches
                       ):
    """Update state at time `t` given observations at time `t` assuming that all observations have the same mask.

    Only the observed variables (`mask` True) are used: rows of `H`, `d`, `obs`
    and rows/columns of `R` are selected before calling `_filter_update`.
    If nothing is observed, the predicted state `(m_m, P_m)` is returned unchanged.
    """
    all_missing = (~mask).all()
    if all_missing:
        return (m_m, P_m)
    R_obs = R[:, mask, :][:, :, mask]
    return _filter_update(H[:, mask, :], d[:, mask, :], R_obs, m_m, P_m, obs[:, mask])

In [None]:
H[:, mm].shape, d[:, mm].shape, R[:, mm][:, :,mm].shape, obs[:, mm].shape

(torch.Size([1, 1, 4]),
 torch.Size([1, 1, 1]),
 torch.Size([1, 1, 1]),
 torch.Size([2, 1, 1]))

In [None]:
show_as_row(*_filter_update_mask(H, d, R, m_m, P_m, obs, mask[0, 0, :] ))

In [None]:
m, P = _filter_update_mask(H, d, R, m_m, P_m, obs, mask[0, 0, :] )
m.shape, P.shape

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]))

##### Update mask batch

In [None]:
#| export
def _filter_update_mask_batch(
        H, # [1, n_dim_obs, n_dim_state]
        d, # [1, n_dim_obs, 1]
        R, # [1, n_dim_obs, n_dim_obs]
        m_m, # [n_batches, n_dim_state, 1]
        P_m, # [n_batches, n_dim_state, n_dim_state]
        obs, # [n_batches, n_dim_obs, 1] observations
        mask # [n_batches, n_dim_obs] mask, can be different for each batch element
                       ):
    """Support batches with different masks when update state at time `t` given observations at time `t`.

    The batch is split into sub-batches that share the same mask. Each sub-batch is
    updated with `_filter_update_mask` and the results are written back in place, so
    the outputs have the same shapes as `m_m` and `P_m`.
    """
    ms, Ps = torch.empty_like(m_m), torch.empty_like(P_m)
    
    # every batch element belongs to exactly one unique mask, so all entries of ms/Ps get filled
    mask_values, indices = torch.unique(mask, return_inverse=True, dim=0)  
    for i, mask_v in enumerate(mask_values):
        idx_select = indices == i # batch elements that have this mask
        m, P = _filter_update_mask(
            H, d, R,
            m_m[idx_select], P_m[idx_select],
            obs[idx_select],
            mask_v,
        )
        ms[idx_select], Ps[idx_select] = m, P
    
    return ms, Ps

In [None]:
m, P = _filter_update_mask_batch(H, d, R, m_m, P_m, obs, mask[:,0,:] )
show_as_row(m, P)
m.shape, P.shape

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]))

In [None]:
m.sum().backward(retain_graph=True) # check that pytorch can compute gradients with the whole batch and gradients aren't nan
H.grad

tensor([[[-1.8992e+00, -1.4530e+00,  3.0663e-01, -3.5863e+00],
         [-7.0442e-01,  9.0830e-01,  8.6987e-01, -1.0243e+00],
         [ 6.8845e-02,  4.2718e-04,  5.0290e-02, -3.1175e-02]]],
       dtype=torch.float64)

### Filter All

The recursive version of the Kalman filter apparently breaks PyTorch gradient calculations, so a workaround is needed.
During the loop the states are saved in a python list and then at the end they are combined back into a tensor.
The last line of the function does:

- convert lists to tensors
- reorder the dimensions so that the batch dimension comes first and the time dimension second

In [None]:
#| export
def _times2batch(x):
    """Permutes `x` so that the first dimension is the number of batches and not the times"""
    return x.permute(1,0,-2,-1)

In [None]:
#| export
@patch
def _filter_all(self: KalmanFilter,
            obs: Tensor, # `([n_batches], n_obs, [self.n_dim_obs])` where `n_batches` and `n_dim_obs` dimensions can be omitted if 1
            mask: Tensor, # `([n_batches], n_obs, [self.n_dim_obs])` where `n_batches` and `n_dim_obs` dimensions can be omitted if 1
            control: Tensor, # ([n_batches], n_obs, [self.n_dim_contr]) 
            
           ) ->Tuple[ListMNormal, ListMNormal]: # (Filtered state, predicted state); means (n_batches, n_obs, self.n_dim_state, 1), covs (n_batches, n_obs, self.n_dim_state, self.n_dim_state)
    """Filter observations using kalman filter.

    At `t = 0` the predicted state is the initial state (`m0`, `P0`); afterwards it is
    computed from the previous filtered state with `_filter_predict`. Each predicted
    state is then corrected with the (possibly partially missing) observations at `t`.
    """
    obs, mask, control = self._parse_obs(obs, mask, control)
    bs, n_obs = obs.shape[0], obs.shape[1]
    # parameters do not change over time: read them once (the covariance
    # properties re-apply their constraint transform on every access)
    A, Q, b, H, d, R = self.A, self.Q, self.b, self.H, self.d, self.R
    B = self.B if self.use_control else torch.zeros_like(self.B) # zero control matrix disables the control
    # states are collected in python lists and stacked at the end (workaround for autograd issues with the recursion)
    m_ms, P_ms, ms, Ps = [[None] * n_obs for _ in range(4)]

    for t in range(n_obs):
        # --- Predict
        if t == 0:
            m_ms[t], P_ms[t] = self.m0.expand(bs, -1, -1), self.P0.expand(bs, -1, -1)
        else:
            m_ms[t], P_ms[t] = _filter_predict(A, Q, b, B, ms[t - 1], Ps[t - 1], control[:,t,:])
        
        # --- Update
        ms[t], Ps[t] = _filter_update_mask_batch(H, d, R, m_ms[t], P_ms[t], obs[:,t,:], mask[:,t,:])
        
        if self.cov_checker is not None:
            self.cov_checker.check(P_ms[t], t=t, name="filter_predict")
            self.cov_checker.check(Ps[t], t=t, name="filter_update")
    
    m_ms, P_ms, ms, Ps = list(maps(torch.stack, _times2batch, (m_ms, P_ms, ms, Ps,))) # convert to tensor and put batch dim first
    return ListMNormal(ms, Ps), ListMNormal(m_ms, P_ms) 

In [None]:
filt_state, pred_state  = k._filter_all(data, mask, control)

In [None]:
(ms, Ps), (m_ms, P_ms) = filt_state, pred_state

Predictions at time `0` for both batches

In [None]:
show_as_row(*map(Self.shape(), (m_ms, P_ms, ms, Ps,)))

In [None]:
show_as_row(*map(lambda x:x[0][0], (m_ms, P_ms, ms, Ps,)))

### Filter

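The filter method wraps `_filter_all` and returns only the filtered state (the predicted state is discarded).

Note that the mean keeps its trailing singleton dimension, i.e. it has shape `(n_batches, n_obs, n_dim_state, 1)`, as shown below.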

In [None]:
#| export
@patch
def filter(self: KalmanFilter,
            obs: Tensor, # `([n_batches], n_obs, [self.n_dim_obs])` where `n_batches` and `n_dim_obs` dimensions can be omitted if 1
            mask: Tensor, # `([n_batches], n_obs, [self.n_dim_obs])` where `n_batches` and `n_dim_obs` dimensions can be omitted if 1
            control: Tensor, # ([n_batches], n_obs, [self.n_dim_contr])
          ) -> ListMNormal: # Filtered state (n_batches, n_obs, self.n_dim_state)
    """Run the Kalman filter and return only the filtered state.

    Thin wrapper around `_filter_all`: the predicted (prior) states it also
    computes are discarded. The state is returned exactly as `_filter_all`
    produces it, so the mean keeps its trailing singleton dimension.
    """
    states = self._filter_all(obs, mask, control)
    # `_filter_all` returns (filtered, predicted)
    return states[0]

In [None]:
filt = k.filter(data, mask, control)
filt.mean.shape, filt.cov.shape

(torch.Size([2, 10, 4, 1]), torch.Size([2, 10, 4, 4]))

## Smooth

### Smooth update step

compute the probability of the state at time `t` given all the observations

$p(x_t|Y) = \mathcal{N}(x_t; m_t^s, P_t^s)$ where:

- Kalman smoothing gain: $G_t = P_tA^T(P_{t+1}^-)^{-1}$
- smoothed mean: $m_t^s = m_t + G_t(m_{t+1}^s - m_{t+1}^-)$
- smoothed covariance: $P_t^s = P_t + G_t(P_{t+1}^s - P_{t+1}^-)G_t^T$

In [None]:
#| export
def _smooth_gain(A, filt_state, pred_state):
    """Kalman smoother gain $G_t = P_t A^T (P^-_{t+1})^{-1}$.

    A: transition matrix.
    filt_state: filtered state at time `t` (only `.cov`, $P_t$, is used).
    pred_state: predicted state at time `t + 1` (only `.cov`, $P^-_{t+1}$, is used).

    Exported because the exported `_smooth_update` depends on it.
    """
    # Cholesky factor of the predicted covariance $P^-_{t+1}$
    P_pred_C = torch.linalg.cholesky(pred_state.cov)
    # cholesky_solve gives (P^-)^{-1} A P_t^T; transposing yields P_t A^T (P^-)^{-1}
    # because both covariances are symmetric. This avoids forming the explicit inverse.
    return torch.cholesky_solve(A @ filt_state.cov.mT, P_pred_C).mT

In [None]:
test_close(_smooth_gain(A, filt_state, pred_state), filt_state.cov @ A.mT @ torch.inverse(pred_state.cov))

In [None]:
#| export
def _smooth_update(A,                           # transition matrix `[n_dim_state, n_dim_state]`
                   filt_state: MNormal,          # filtered state at time `t` (observations up to `t`)
                   pred_state: MNormal,          # predicted state at time `t + 1` (observations up to `t`, before the update with obs `t + 1`)
                   next_smoothed_state: MNormal, # smoothed state at time `t + 1`
                   ) -> MNormal:                 # mean and cov of smoothed state at time `t`
    """One backward smoother step: correct the filtered state at `t` using the smoothed state at `t + 1`.

    m^s_t = m_t + G_t (m^s_{t+1} - m^-_{t+1})
    P^s_t = P_t + G_t (P^s_{t+1} - P^-_{t+1}) G_t^T,  with G_t = P_t A^T (P^-_{t+1})^{-1}

    The states may carry leading batch dimensions (as when called from `_smooth`).
    """
    G = _smooth_gain(A, filt_state, pred_state)

    m_s = filt_state.mean + G @ (next_smoothed_state.mean - pred_state.mean)
    P_s = filt_state.cov + G @ (next_smoothed_state.cov - pred_state.cov) @ G.mT

    return MNormal(m_s, P_s)

In [None]:
show_as_row(*_smooth_update(A, filt_state[:, 0, :], pred_state[:, 0, :], filt_state[:, 0, :]))

In [None]:
show_as_row(*map(Self.shape(), _smooth_update(A, MNormal(m_m, P_m), MNormal(m_m, P_m), MNormal(m_m, P_m))))

### Smooth

In [None]:
#| export
def _smooth(A, # `[n_dim_state, n_dim_state]`
            filt_state: ListMNormal, # filtered states, time along dim 1
                # `mean[:, t]` is the state estimate for time t given obs from times `[0...t]`
            pred_state: ListMNormal, # predicted states, time along dim 1
                # `mean[:, t]` is the state estimate for time t given obs from times `[0...t-1]`
            cov_checker = None
           ) -> ListMNormal: # Smoothed state, same shapes as `filt_state`
    """Apply the Kalman smoother (backward pass) to the output of the filter.

    The smoothed sequence is built backwards in time and stacked along dim 1
    at the end, rather than written in place into a preallocated tensor.
    If `cov_checker` is given, each smoothed covariance (except the last
    timestep) is passed to `cov_checker.check`.
    """
    n_obs = pred_state.mean.shape[1]

    # the last timestep has no future observations: smoothed == filtered
    rev_means = [filt_state.mean[:, -1]]
    rev_covs = [filt_state.cov[:, -1]]

    for t in range(n_obs - 2, -1, -1):
        m_s, P_s = _smooth_update(
            A,
            filt_state[:, t],
            pred_state[:, t + 1],
            MNormal(rev_means[-1], rev_covs[-1]),  # smoothed state at t + 1
        )
        rev_means.append(m_s)
        rev_covs.append(P_s)
        if cov_checker is not None:
            cov_checker.check(P_s, name="smooth", t=t)

    # lists were filled from the last timestep to the first
    return ListMNormal(torch.stack(rev_means[::-1], dim=1),
                       torch.stack(rev_covs[::-1], dim=1))

In [None]:
smooth_state = _smooth(k.A,  filt_state, pred_state)

In [None]:
show_as_row(smooth_state.mean[0][0], smooth_state.cov[0][0])

In [None]:
show_as_row(smooth_state.mean.shape, smooth_state.cov.shape)

### KalmanFilter method

In [None]:
#| export
@patch
def smooth(self: KalmanFilter,
           obs: Tensor,
           mask: Tensor,
           control: Tensor
          ) -> ListMNormal: # `[n_timesteps, n_dim_state]` smoothed state
    """Kalman smoothing: a forward filtering pass followed by the backward smoother pass.

    `_smooth` needs both the filtered and the predicted states from `_filter_all`;
    `self.cov_checker` (if any) is forwarded to validate the smoothed covariances.
    """
    filtered, predicted = self._filter_all(obs, mask, control)
    return _smooth(self.A, filtered, predicted, cov_checker=self.cov_checker)

In [None]:
smoothed_state = k.smooth(data, mask, control)

In [None]:
show_as_row(smoothed_state.mean.shape, smoothed_state.cov.shape)

In [None]:
smoothed_state.mean.sum().backward(retain_graph=True)
A.grad

tensor([[[-13.6413,  -6.2732, -14.2362,   8.2959],
         [ -0.3648,   6.0936,   3.7623,  -2.9837],
         [-13.0810,  -3.2975, -13.0166,   6.8282],
         [  2.2838,  13.9427,   7.3058,  -3.1199]]], dtype=torch.float64)

## Predict

The predictions at time t ($y_t$) are computed from the state distribution $\mathcal{N}(m_t, P_t)$ by marginalizing over $x_t$:
$$p(y_t) = \mathcal{N}(Hm_t + d, R + HP_tH^T)$$

This works whether the state ($m_t$, $P_t$) was filtered or smoothed.

This adds support for conditional predictions: at time t, some of the variables may already have been observed. Since the model prediction is a normal distribution, we can condition on the observed values and thus improve the predictions of the missing ones. See `conditional_gaussian`

In order to have conditional predictions that make sense it's not possible to return the full covariance matrix for the predictions but only the standard deviations

In [None]:
test_m = torch.tensor(
    [[True, True, True,],
    [False, True, True],
    [False, False, False]]
)

In [None]:
torch.logical_xor(test_m.all(-1), test_m.any(-1))

tensor([False,  True, False])

In [None]:
A = torch.rand(2,2,3,3)

In [None]:
(A @ A).shape

torch.Size([2, 2, 3, 3])

predict can be vectorized across both the batch and the timesteps, except for timesteps that require conditional predictions

In [None]:
#| export
@patch
def _obs_from_state(self: KalmanFilter, state: ListMNormal):
    """Map a (filtered or smoothed) state distribution to the observation space.

    mean = H m + d and cov = H P H^T + R, computed for every batch element and
    timestep at once through matmul broadcasting. The trailing singleton
    dimension of the mean is dropped in the returned `ListMNormal`.
    """
    mean = self.H @ state.mean + self.d
    cov = self.H @ state.cov @ self.H.mT + self.R

    if self.cov_checker is not None:
        # iterate over the leading (batch) dimension; each element holds all timesteps
        for c in cov:
            # `name=` matches the keyword used by the other `cov_checker.check` call sites
            self.cov_checker.check(c, name='predict')

    return ListMNormal(mean.squeeze(-1), cov)

In [None]:
smoothed_state.mean.shape, smoothed_state.cov.shape

(torch.Size([2, 10, 4, 1]), torch.Size([2, 10, 4, 4]))

In [None]:
(k.H @ smoothed_state.mean).shape

torch.Size([2, 10, 3, 1])

In [None]:
pred_obs0 = k._obs_from_state(smoothed_state)
pred_obs0.mean.shape

torch.Size([2, 10, 3])

In [None]:
pred_obs0.cov.shape

torch.Size([2, 10, 3, 3])

In [None]:
#| export
@patch
def predict(self: KalmanFilter, obs, mask, control, smooth=True):
    """Predicted observations at all times.

    obs, mask, control: as for `filter` / `smooth`.
    smooth: use the smoothed state if True, otherwise the filtered state.

    Returns a `ListNormal` with the predicted mean and standard deviation of every
    observed variable. If `self.use_conditional`, timesteps where only some variables
    are observed get predictions for the missing variables conditioned on the observed
    ones (via `cond_gaussian_batched`).
    """
    state = self.smooth(obs, mask, control) if smooth else self.filter(obs, mask, control)
    obs, mask, control = self._parse_obs(obs, mask, control)

    pred_obs = self._obs_from_state(state)
    pred_mean, pred_std = pred_obs.mean, cov2std(pred_obs.cov)

    if self.use_conditional:
        # conditional predictions are slow: only where some (but not all) obs are missing.
        # `mask` keeps its batch dimension so `cond_mask` lines up with `pred_obs` and `obs`
        # (squeezing dim 0 broke the indexing when the batch size is 1).
        cond_mask = torch.logical_xor(mask.all(-1), mask.any(-1))
        # this cannot be batched so returns a list, one element per selected position
        cond_preds = cond_gaussian_batched(
            pred_obs[cond_mask], obs[cond_mask].squeeze(-1), mask[cond_mask])

        # Boolean indexing (`x[cond_mask]`) returns a copy, so assigning into it is lost.
        # Write through explicit integer positions instead. Clone first so the in-place
        # writes do not alter tensors that autograd may have saved for backward.
        pred_mean, pred_std = pred_mean.clone(), pred_std.clone()
        cond_idx = cond_mask.nonzero(as_tuple=True)  # same order as boolean indexing
        for i, c_pred in enumerate(cond_preds):
            pos = tuple(int(ix[i]) for ix in cond_idx)
            missing = ~mask[pos]
            pred_mean[pos][missing] = c_pred.mean
            pred_std[pos][missing] = cov2std(c_pred.cov)

    return ListNormal(pred_mean, pred_std)

In [None]:
pred = k.predict(data, mask, control)

In [None]:
pred.mean.shape, pred.std.shape

(torch.Size([2, 10, 3]), torch.Size([2, 10, 3]))

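Gradients: check that the gradient with respect to the model parameters does not depend on the values at the missing (masked) positions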

In [None]:
def get_grad_mask(x):
    "Gradient of `k.R_raw` for the summed predicted mean, after replacing the values of `data` where `mask` is False with `x`"
    filled = data.clone()
    filled[~mask] = x
    # use the filled copy: passing `data` here would make `x` irrelevant and the check vacuous
    k.predict(filled, mask, control).mean.sum().backward(retain_graph=True)
    grad = k.R_raw.grad.clone()
    k.zero_grad()
    return grad

In [None]:
get_grad_mask(10)

tensor([[[ -4.0769,  -0.6436,  -5.1756],
         [-13.0135, -13.6321, -14.3604],
         [ -5.0067,  -0.6871,  -9.5430]]], dtype=torch.float64)

In [None]:
test_close(get_grad_mask(1), get_grad_mask(10))

In [None]:
@patch
def predict_times(self: KalmanFilter, times, obs, mask=None, smooth=True, check_args=None):
    """Predicted observations at specific times.

    NOTE(review): this cell is not exported and looks out of date with the rest of
    the notebook; as written it is unlikely to run:
    - `smooth` / `filter` now take `(obs, mask, control)`, but `check_args` is passed as `control`;
    - `_parse_obs` is unpacked into 2 values here, while `predict` unpacks 3 (`obs, mask, control`);
    - `_obs_from_state` now takes a single state argument, not `(mean, cov, check_args)`;
    - `_get_cond_pred` is not defined in the visible part of the notebook — confirm it exists.
    """
    state = self.smooth(obs, mask, check_args) if smooth else self.filter(obs, mask, check_args)
    obs, mask = self._parse_obs(obs, mask)
    times = array1d(times)
    
    # NOTE(review): elsewhere dim 0 is the batch dimension after `_parse_obs`, so this is
    # presumably the batch size rather than the number of timesteps — confirm.
    n_timesteps = obs.shape[0]
    n_features = obs.shape[1] if len(obs.shape) > 1 else 1
    
    # NOTE(review): `times.max() == n_timesteps` passes this check but is out of range below.
    if times.max() > n_timesteps or times.min() < 0:
        raise ValueError(f"provided times range from {times.min()} to {times.max()}, which is outside allowed range : 0 to {n_timesteps}")

    means = torch.empty((times.shape[0], n_features), dtype=obs.dtype, device=obs.device)
    stds = torch.empty((times.shape[0], n_features), dtype=obs.dtype, device=obs.device) 
    for i, t in enumerate(times):
        mean, std = self._obs_from_state(
            state.mean[t],
            state.cov[t],
            {'t': t, **check_args} if check_args is not None else None
        )
        
        means[i], stds[i] = _get_cond_pred(ListNormal(mean, std), obs[t], mask[t])
    
    return ListNormal(means, stds)  

# Numerical Stable Kalman Filter (Square Root Filter)

In [None]:
kSR = KalmanFilterSR.init_random(3,4,3)

## Filter 

### Filter predict

#### Covariance

Implement the numerically stable version of the covariance predict step

In [None]:
A, Q, b, B, m_pr,P_pr= (k.A, k.Q, k.b, k.B,torch.concat([k.m0]*2), torch.concat([k.P0]*2))

In [None]:
Q_C = kSR.Q_C

In [None]:
_filter_predict_cov_stand(kSR.A, kSR.Q_C @ kSR.Q_C.mT, P_pr)

tensor([[[1.3994, 1.7755, 1.6101, 2.3405],
         [1.7755, 5.1566, 4.5819, 6.5831],
         [1.6101, 4.5819, 5.0435, 5.9596],
         [2.3405, 6.5831, 5.9596, 8.7927]],

        [[1.3994, 1.7755, 1.6101, 2.3405],
         [1.7755, 5.1566, 4.5819, 6.5831],
         [1.6101, 4.5819, 5.0435, 5.9596],
         [2.3405, 6.5831, 5.9596, 8.7927]]], dtype=torch.float64,
       grad_fn=<AddBackward0>)

In [None]:
P_pr_C = torch.linalg.cholesky(P_pr)

$$W = \begin{bmatrix}AC_{t-1}&C_Q\end{bmatrix}$$

In [None]:
W = torch.concat([A @ P_pr_C, Q_C.expand_as(P_pr_C)], dim=-1)
W.shape

torch.Size([2, 4, 8])

In [None]:
P_m_C = torch.linalg.qr(W.mT).R.mT

In [None]:
P_m_C

tensor([[[-2.2660,  0.0000,  0.0000,  0.0000],
         [-1.9573,  0.8825,  0.0000,  0.0000],
         [-2.1939,  0.9222,  0.9708,  0.0000],
         [-2.1984,  0.9382,  0.0220, -0.3203]],

        [[-2.2660,  0.0000,  0.0000,  0.0000],
         [-1.9573,  0.8825,  0.0000,  0.0000],
         [-2.1939,  0.9222,  0.9708,  0.0000],
         [-2.1984,  0.9382,  0.0220, -0.3203]]], dtype=torch.float64,
       grad_fn=<TransposeBackward0>)

In [None]:
P_m_C @ P_m_C.mT

tensor([[[5.1348, 4.4353, 4.9715, 4.9817],
         [4.4353, 4.6100, 5.1082, 5.1311],
         [4.9715, 5.1082, 6.6063, 5.7099],
         [4.9817, 5.1311, 5.7099, 5.8165]],

        [[5.1348, 4.4353, 4.9715, 4.9817],
         [4.4353, 4.6100, 5.1082, 5.1311],
         [4.9715, 5.1082, 6.6063, 5.7099],
         [4.9817, 5.1311, 5.7099, 5.8165]]], dtype=torch.float64,
       grad_fn=<UnsafeViewBackward0>)

In [None]:
P_m = _filter_predict_cov_stand(A, Q_C @ Q_C.mT, P_pr)

In [None]:
test_close(P_m, P_m_C @ P_m_C.mT)

In [None]:
(P_m - P_m_C @ P_m_C.mT).max()

tensor(2.6645e-15, dtype=torch.float64, grad_fn=<MaxBackward1>)

In [None]:
test_P_m_C = torch.linalg.cholesky(P_m)

In [None]:
(test_P_m_C @ test_P_m_C.mT - P_m_C @ P_m_C.mT).max()

tensor(2.6645e-15, dtype=torch.float64, grad_fn=<MaxBackward1>)

In [None]:
(test_P_m_C - P_m_C).max()

tensor(4.5320, dtype=torch.float64, grad_fn=<MaxBackward1>)

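The triangular factor is not unique: the one obtained from QR can differ from `torch.linalg.cholesky` by the signs of its columns (note the negative diagonal entries above). Both reproduce $P^-$, so the solution is correct.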

In [None]:
#| export
def _filter_predict_cov_SR(A, # transition matrix $A_t$
                        Q_C, # Cholesky Factor of transition covariance $Q_t$
                        P_pr_C # Cholesky Factor of previous state covariance $P_{t-1}$
                       ):
    """Numerically stable (square-root) Kalman filter predict step for the covariance.

    Builds $W = [A C_{t-1} \; C_Q]$, so that $W W^T = A P_{t-1} A^T + Q = P^-_t$, then
    takes the QR decomposition $W^T = QR$. Since $W W^T = R^T R$, $R^T$ is a
    lower-triangular square-root factor of $P^-_t$, and the covariance is never formed.
    The factor may differ from `torch.linalg.cholesky` by the signs of its columns.

    Returns the lower-triangular factor, with the broadcast batch shape of `A @ P_pr_C`.
    """
    AC = A @ P_pr_C
    # expand Q_C to the broadcast shape of `A @ P_pr_C` (not just `P_pr_C`), so a batched
    # `A` combined with an unbatched `P_pr_C` also works
    W = torch.concat([AC, Q_C.expand_as(AC)], dim=-1)
    return torch.linalg.qr(W.mT).R.mT

In [None]:
P_m_C = _filter_predict_cov_SR(A, Q_C, P_pr_C)
test_close(P_m_C @ P_m_C.mT, _filter_predict_cov_stand(A, Q_C @ Q_C.mT, P_pr))

In [None]:
def fuzz_filter_predict_cov_SR(n=10):
    """Randomized check of `_filter_predict_cov_SR` against the standard covariance predict.

    For `n` randomly initialised `KalmanFilterSR` models, compare `C @ C.mT` from the
    square-root predict with `_filter_predict_cov_stand` on a batch of two copies of `P0`.
    """
    for _ in range(n):
        kSR = KalmanFilterSR.init_random(5,10,4)
        # only the covariance path is exercised, so b, B and m0 are not needed
        A, Q_C = kSR.A.unsqueeze(0), kSR.Q_C.unsqueeze(0)
        P_pr = torch.stack([kSR.P0]*2)
        P_pr_C = torch.linalg.cholesky(P_pr)
        P_m_C = _filter_predict_cov_SR(A, Q_C, P_pr_C)
        test_close(P_m_C @ P_m_C.mT, _filter_predict_cov_stand(A, Q_C @ Q_C.mT, P_pr), eps=5e-13)

In [None]:
fuzz_filter_predict_cov_SR()

#### Predict

In [None]:
#| export
def _filter_predict_SR(A, Q_C, b, B, m_pr, P_pr_C,control) -> Tuple: # predicted state
    """Square-root filter predict: state at time `t` given the state at time `t-1`.

    The mean uses the standard `_filter_predict_mean`, while the covariance is
    propagated as a square-root factor with `_filter_predict_cov_SR`.
    Returns `(m_m, P_m_C)`: the predicted mean and the factor of the predicted covariance.
    """
    return (_filter_predict_mean(A, B, b, m_pr, control),
            _filter_predict_cov_SR(A, Q_C, P_pr_C))

### Filter Update

$$M = \begin{bmatrix}R^{T/2} & 0 \\ (C^-)^TH^T & (C^-)^T \end{bmatrix}$$

In [None]:
H, R, R_C, obs = kSR.H, kSR.R, kSR.R_C, data[:,0,:].unsqueeze(-1)

In [None]:
R_C = kSR.R_C

In [None]:
# P_m_C = torch.linalg.cholesky(P_m)

In [None]:
P_pr_C.shape

torch.Size([2, 4, 4])

In [None]:
P_m_C.shape

torch.Size([2, 4, 4])

In [None]:
[[R_C.shape,0],
 [(P_m_C.mT @ H.mT).shape, P_m_C.mT.shape]]

[[torch.Size([1, 3, 3]), 0], [torch.Size([2, 4, 3]), torch.Size([2, 4, 4])]]

In [None]:
M_21 = P_m_C.mT @ H.mT 

In [None]:
M_21.mT.shape

torch.Size([2, 3, 4])

In [None]:
M_21.shape[:-2]

torch.Size([2])

In [None]:
[[R_C.expand(*M_21.shape[:-2], -1, -1).shape, torch.zeros_like(M_21.mT).shape],
 [M_21.shape, P_m_C.mT.shape]]

[[torch.Size([2, 3, 3]), torch.Size([2, 3, 4])],
 [torch.Size([2, 4, 3]), torch.Size([2, 4, 4])]]

In [None]:
M_1 = torch.cat([R_C.expand(*M_21.shape[:-2], -1, -1).mT, torch.zeros_like(M_21.mT)], dim=-1) 

In [None]:
M_2 = torch.cat([M_21, P_m_C.mT], dim=-1)

In [None]:
M_1.shape, M_2.shape

(torch.Size([2, 3, 7]), torch.Size([2, 4, 7]))

In [None]:
M = torch.cat([M_1, M_2], dim=-2)

In [None]:
M[0]

tensor([[ 0.9094,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.9543,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.1604,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.7577, -3.0682, -4.6017, -2.2660, -1.9573, -2.1939, -2.1984],
        [ 0.6702,  0.9346,  1.1521,  0.0000,  0.8825,  0.9222,  0.9382],
        [ 0.3868,  0.6330,  0.3967,  0.0000,  0.0000,  0.9708,  0.0220],
        [-0.0931, -0.1038, -0.1399,  0.0000,  0.0000,  0.0000, -0.3203]],
       dtype=torch.float64, grad_fn=<SelectBackward0>)

In [None]:
U = torch.linalg.qr(M).R

so $U \in \mathbb{R}^{(k+n) \times (k+n)}$, where $k$ is the observation dimension and $n$ the state dimension, and we need its bottom-right block of size $n \times n$

In [None]:
U[0]

tensor([[-2.1270, -2.9497, -4.2442, -1.8726, -1.8956, -2.2802, -2.1305],
        [ 0.0000, -1.7055, -1.7253, -0.8378, -0.7263, -0.8690, -0.8122],
        [ 0.0000,  0.0000, -1.3102, -0.7895, -0.5537, -0.2800, -0.6167],
        [ 0.0000,  0.0000,  0.0000,  0.5503, -0.2911, -0.4499, -0.3183],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.3127,  0.4201, -0.2192],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.4413,  0.1756],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.2399]],
       dtype=torch.float64, grad_fn=<SelectBackward0>)

In [None]:
n_dim_state = P_m_C.shape[-1]
n_dim_obs = R_C.shape[-1]

In [None]:
P_C = U[:, n_dim_obs:, n_dim_obs:].mT

In [None]:
P_C[0]

tensor([[ 0.5503,  0.0000,  0.0000,  0.0000],
        [-0.2911, -0.3127,  0.0000,  0.0000],
        [-0.4499,  0.4201, -0.4413,  0.0000],
        [-0.3183, -0.2192,  0.1756, -0.2399]], dtype=torch.float64,
       grad_fn=<SelectBackward0>)

In [None]:
test_close(M.mT @ M, U.mT @ U) # this is just to say the the QR decomposition is correct

Check with standard implementation

In [None]:
P_C @ P_C.mT

tensor([[[ 3.0279e-01, -1.6016e-01, -2.4755e-01, -1.7516e-01],
         [-1.6016e-01,  1.8248e-01, -4.1026e-04,  1.6120e-01],
         [-2.4755e-01, -4.1026e-04,  5.7354e-01, -2.6354e-02],
         [-1.7516e-01,  1.6120e-01, -2.6354e-02,  2.3776e-01]],

        [[ 3.0279e-01, -1.6016e-01, -2.4755e-01, -1.7516e-01],
         [-1.6016e-01,  1.8248e-01, -4.1026e-04,  1.6120e-01],
         [-2.4755e-01, -4.1026e-04,  5.7354e-01, -2.6354e-02],
         [-1.7516e-01,  1.6120e-01, -2.6354e-02,  2.3776e-01]]],
       dtype=torch.float64, grad_fn=<UnsafeViewBackward0>)

In [None]:
K = _filter_update_k_gain(H, R_C @ R_C.mT, P_m_C @ P_m_C.mT)

In [None]:
_filter_update_cov(H, K, P_m_C @ P_m_C.mT) 

tensor([[[ 3.0279e-01, -1.6016e-01, -2.4755e-01, -1.7516e-01],
         [-1.6016e-01,  1.8248e-01, -4.1026e-04,  1.6120e-01],
         [-2.4755e-01, -4.1026e-04,  5.7354e-01, -2.6354e-02],
         [-1.7516e-01,  1.6120e-01, -2.6354e-02,  2.3776e-01]],

        [[ 3.0279e-01, -1.6016e-01, -2.4755e-01, -1.7516e-01],
         [-1.6016e-01,  1.8248e-01, -4.1026e-04,  1.6120e-01],
         [-2.4755e-01, -4.1026e-04,  5.7354e-01, -2.6354e-02],
         [-1.7516e-01,  1.6120e-01, -2.6354e-02,  2.3776e-01]]],
       dtype=torch.float64, grad_fn=<UnsafeViewBackward0>)

comparison with the $U$ matrix computed using the reference values

In [None]:
P_m.shape

torch.Size([2, 4, 4])

In [None]:
R.shape

torch.Size([1, 3, 3])

In [None]:
S = H @ P_m @ H.mT + R

In [None]:
S_C = torch.linalg.cholesky(S)

In [None]:
K = _filter_update_k_gain(H, R, P_m)

In [None]:
test_close(K, P_m @ H.mT @ torch.inverse(S))

In [None]:
K_til = K @ S_C

In [None]:
K_til.shape

torch.Size([2, 4, 3])

In [None]:
K_til

tensor([[[1.8726, 0.8378, 0.7895],
         [1.8956, 0.7264, 0.5537],
         [2.2802, 0.8690, 0.2800],
         [2.1305, 0.8122, 0.6167]],

        [[1.8726, 0.8378, 0.7895],
         [1.8956, 0.7264, 0.5537],
         [2.2802, 0.8690, 0.2800],
         [2.1305, 0.8122, 0.6167]]], dtype=torch.float64,
       grad_fn=<UnsafeViewBackward0>)

In [None]:
P = _filter_update_cov(H, K, P_m) 

In [None]:
P

tensor([[[ 3.0280e-01, -1.6016e-01, -2.4755e-01, -1.7516e-01],
         [-1.6016e-01,  1.8249e-01, -4.0924e-04,  1.6121e-01],
         [-2.4755e-01, -4.0924e-04,  5.7354e-01, -2.6353e-02],
         [-1.7516e-01,  1.6121e-01, -2.6353e-02,  2.3776e-01]],

        [[ 3.0280e-01, -1.6016e-01, -2.4755e-01, -1.7516e-01],
         [-1.6016e-01,  1.8249e-01, -4.0924e-04,  1.6121e-01],
         [-2.4755e-01, -4.0924e-04,  5.7354e-01, -2.6353e-02],
         [-1.7516e-01,  1.6121e-01, -2.6353e-02,  2.3776e-01]]],
       dtype=torch.float64, grad_fn=<UnsafeViewBackward0>)

In [None]:
P_C = torch.linalg.cholesky(P)

In [None]:
P_C

tensor([[[ 0.5503,  0.0000,  0.0000,  0.0000],
         [-0.2910,  0.3127,  0.0000,  0.0000],
         [-0.4499, -0.4200,  0.4413,  0.0000],
         [-0.3183,  0.2193, -0.1755,  0.2399]],

        [[ 0.5503,  0.0000,  0.0000,  0.0000],
         [-0.2910,  0.3127,  0.0000,  0.0000],
         [-0.4499, -0.4200,  0.4413,  0.0000],
         [-0.3183,  0.2193, -0.1755,  0.2399]]], dtype=torch.float64,
       grad_fn=<LinalgCholeskyExBackward0>)

In [None]:
S_C.shape

torch.Size([2, 3, 3])

In [None]:
K_til.shape

torch.Size([2, 4, 3])

In [None]:
U_1 = torch.cat([S_C.mT, K_til.mT], dim=-1)

In [None]:
U_1.shape

torch.Size([2, 3, 7])

In [None]:
U_2 = torch.cat([torch.zeros_like(K_til), P_C.mT], dim=-1)

In [None]:
U_corr = torch.cat([U_1, U_2], dim=-2)

In [None]:
(torch.tril(U_corr[0].mT) == U_corr[0].mT).all()

tensor(True)

In [None]:
U_corr.mT @ U_corr

tensor([[[ 4.5240,  6.2739,  9.0271,  3.9830,  4.0319,  4.8499,  4.5314],
         [ 6.2739, 11.6096, 15.4616,  6.9526,  6.8304,  8.2079,  7.6694],
         [ 9.0271, 15.4616, 22.7061, 10.4276, 10.0239, 11.5435, 11.2512],
         [ 3.9830,  6.9526, 10.4276,  5.1348,  4.4353,  4.9715,  4.9817],
         [ 4.0319,  6.8304, 10.0239,  4.4353,  4.6100,  5.1082,  5.1311],
         [ 4.8499,  8.2079, 11.5435,  4.9715,  5.1082,  6.6063,  5.7099],
         [ 4.5314,  7.6694, 11.2512,  4.9817,  5.1311,  5.7099,  5.8165]],

        [[ 4.5240,  6.2739,  9.0271,  3.9830,  4.0319,  4.8499,  4.5314],
         [ 6.2739, 11.6096, 15.4616,  6.9526,  6.8304,  8.2079,  7.6694],
         [ 9.0271, 15.4616, 22.7061, 10.4276, 10.0239, 11.5435, 11.2512],
         [ 3.9830,  6.9526, 10.4276,  5.1348,  4.4353,  4.9715,  4.9817],
         [ 4.0319,  6.8304, 10.0239,  4.4353,  4.6100,  5.1082,  5.1311],
         [ 4.8499,  8.2079, 11.5435,  4.9715,  5.1082,  6.6063,  5.7099],
         [ 4.5314,  7.6694, 11.2512,

Now we build $U^TU$ manually and check that it is the same

In [None]:
test_close(K_til @ K_til.mT, K @ H @ P_m)
K_til @ K_til.mT

tensor([[[4.8320, 4.5955, 5.2190, 5.1569],
         [4.5955, 4.4276, 5.1086, 4.9699],
         [5.2190, 5.1086, 6.0327, 5.7362],
         [5.1569, 4.9699, 5.7362, 5.5787]],

        [[4.8320, 4.5955, 5.2190, 5.1569],
         [4.5955, 4.4276, 5.1086, 4.9699],
         [5.2190, 5.1086, 6.0327, 5.7362],
         [5.1569, 4.9699, 5.7362, 5.5787]]], dtype=torch.float64,
       grad_fn=<UnsafeViewBackward0>)

In [None]:
test_close(K_til @ K_til.mT + P, P_m)

In [None]:
UTU = torch.cat([torch.cat([S, H@P_m], dim=-1),
           torch.cat([P_m@H.mT, K_til@K_til.mT + P], dim=-1)
          ], dim=-2)
UTU

tensor([[[ 4.5240,  6.2739,  9.0271,  3.9830,  4.0319,  4.8499,  4.5314],
         [ 6.2739, 11.6096, 15.4616,  6.9526,  6.8304,  8.2079,  7.6694],
         [ 9.0271, 15.4616, 22.7061, 10.4276, 10.0239, 11.5435, 11.2512],
         [ 3.9830,  6.9526, 10.4276,  5.1348,  4.4353,  4.9715,  4.9817],
         [ 4.0319,  6.8304, 10.0239,  4.4353,  4.6100,  5.1082,  5.1311],
         [ 4.8499,  8.2079, 11.5435,  4.9715,  5.1082,  6.6063,  5.7099],
         [ 4.5314,  7.6694, 11.2512,  4.9817,  5.1311,  5.7099,  5.8165]],

        [[ 4.5240,  6.2739,  9.0271,  3.9830,  4.0319,  4.8499,  4.5314],
         [ 6.2739, 11.6096, 15.4616,  6.9526,  6.8304,  8.2079,  7.6694],
         [ 9.0271, 15.4616, 22.7061, 10.4276, 10.0239, 11.5435, 11.2512],
         [ 3.9830,  6.9526, 10.4276,  5.1348,  4.4353,  4.9715,  4.9817],
         [ 4.0319,  6.8304, 10.0239,  4.4353,  4.6100,  5.1082,  5.1311],
         [ 4.8499,  8.2079, 11.5435,  4.9715,  5.1082,  6.6063,  5.7099],
         [ 4.5314,  7.6694, 11.2512,

In [None]:
test_close(U_corr.mT @ U_corr, UTU)

In [None]:
U_corr

tensor([[[ 2.1270,  2.9497,  4.2441,  1.8726,  1.8956,  2.2802,  2.1305],
         [ 0.0000,  1.7055,  1.7253,  0.8378,  0.7264,  0.8690,  0.8122],
         [ 0.0000,  0.0000,  1.3102,  0.7895,  0.5537,  0.2800,  0.6167],
         [ 0.0000,  0.0000,  0.0000,  0.5503, -0.2910, -0.4499, -0.3183],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.3127, -0.4200,  0.2193],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.4413, -0.1755],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.2399]],

        [[ 2.1270,  2.9497,  4.2441,  1.8726,  1.8956,  2.2802,  2.1305],
         [ 0.0000,  1.7055,  1.7253,  0.8378,  0.7264,  0.8690,  0.8122],
         [ 0.0000,  0.0000,  1.3102,  0.7895,  0.5537,  0.2800,  0.6167],
         [ 0.0000,  0.0000,  0.0000,  0.5503, -0.2910, -0.4499, -0.3183],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.3127, -0.4200,  0.2193],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.4413, -0.1755],
         [ 0.0000,  0.0000,  0.0000,

In [None]:
torch.linalg.cholesky(P).mT

tensor([[[ 0.5503, -0.2910, -0.4499, -0.3183],
         [ 0.0000,  0.3127, -0.4200,  0.2193],
         [ 0.0000,  0.0000,  0.4413, -0.1755],
         [ 0.0000,  0.0000,  0.0000,  0.2399]],

        [[ 0.5503, -0.2910, -0.4499, -0.3183],
         [ 0.0000,  0.3127, -0.4200,  0.2193],
         [ 0.0000,  0.0000,  0.4413, -0.1755],
         [ 0.0000,  0.0000,  0.0000,  0.2399]]], dtype=torch.float64,
       grad_fn=<TransposeBackward0>)

In [None]:
U_corr[0, 3:-1, 3:-1]

tensor([[ 0.5503, -0.2910, -0.4499],
        [ 0.0000,  0.3127, -0.4200],
        [ 0.0000,  0.0000,  0.4413]], dtype=torch.float64,
       grad_fn=<SliceBackward0>)

In [None]:
U_corr[:, n_dim_obs:, n_dim_obs:]

tensor([[[ 0.5503, -0.2910, -0.4499, -0.3183],
         [ 0.0000,  0.3127, -0.4200,  0.2193],
         [ 0.0000,  0.0000,  0.4413, -0.1755],
         [ 0.0000,  0.0000,  0.0000,  0.2399]],

        [[ 0.5503, -0.2910, -0.4499, -0.3183],
         [ 0.0000,  0.3127, -0.4200,  0.2193],
         [ 0.0000,  0.0000,  0.4413, -0.1755],
         [ 0.0000,  0.0000,  0.0000,  0.2399]]], dtype=torch.float64,
       grad_fn=<SliceBackward0>)

In [None]:
torch.linalg.cholesky(P).mT

tensor([[[ 0.5503, -0.2910, -0.4499, -0.3183],
         [ 0.0000,  0.3127, -0.4200,  0.2193],
         [ 0.0000,  0.0000,  0.4413, -0.1755],
         [ 0.0000,  0.0000,  0.0000,  0.2399]],

        [[ 0.5503, -0.2910, -0.4499, -0.3183],
         [ 0.0000,  0.3127, -0.4200,  0.2193],
         [ 0.0000,  0.0000,  0.4413, -0.1755],
         [ 0.0000,  0.0000,  0.0000,  0.2399]]], dtype=torch.float64,
       grad_fn=<TransposeBackward0>)

In [None]:
test_close(U_corr[:, n_dim_obs:, n_dim_obs:], torch.linalg.cholesky(P).mT)

In [None]:
M = torch.cat([torch.cat([R_C.mT.expand(2,-1,-1), torch.zeros_like(H @ P_m_C)],   dim=-1),
               torch.cat([P_m_C.mT @ H.mT,        P_m_C.mT                   ],   dim=-1)
              ], dim=-2)

In [None]:
(torch.linalg.cholesky(R) @ torch.linalg.cholesky(R).mT - R).max()

tensor(0., dtype=torch.float64, grad_fn=<MaxBackward1>)

In [None]:
(torch.linalg.cholesky(P_m) @ torch.linalg.cholesky(P_m).mT - P_m).max()

tensor(8.8818e-16, dtype=torch.float64, grad_fn=<MaxBackward1>)

In [None]:
R_C@R_C.mT + H @ P_m_C @ P_m_C.mT @ H.mT

tensor([[[ 4.5240,  6.2739,  9.0271],
         [ 6.2739, 11.6096, 15.4616],
         [ 9.0271, 15.4616, 22.7061]],

        [[ 4.5240,  6.2739,  9.0271],
         [ 6.2739, 11.6096, 15.4616],
         [ 9.0271, 15.4616, 22.7061]]], dtype=torch.float64,
       grad_fn=<AddBackward0>)

In [None]:
R_C@R_C.mT + H @ torch.linalg.cholesky(P_m) @ torch.linalg.cholesky(P_m).mT @ H.mT

tensor([[[ 4.5240,  6.2739,  9.0271],
         [ 6.2739, 11.6096, 15.4616],
         [ 9.0271, 15.4616, 22.7061]],

        [[ 4.5240,  6.2739,  9.0271],
         [ 6.2739, 11.6096, 15.4616],
         [ 9.0271, 15.4616, 22.7061]]], dtype=torch.float64,
       grad_fn=<AddBackward0>)

In [None]:
R + H @ P_m @ H.mT

tensor([[[ 4.5240,  6.2739,  9.0271],
         [ 6.2739, 11.6096, 15.4616],
         [ 9.0271, 15.4616, 22.7061]],

        [[ 4.5240,  6.2739,  9.0271],
         [ 6.2739, 11.6096, 15.4616],
         [ 9.0271, 15.4616, 22.7061]]], dtype=torch.float64,
       grad_fn=<AddBackward0>)

In [None]:
test_close(M.mT@M, UTU, eps=2e-5)

In [None]:
M

tensor([[[ 0.9094,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.9543,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.1604,  0.0000,  0.0000,  0.0000,  0.0000],
         [-1.7577, -3.0682, -4.6017, -2.2660, -1.9573, -2.1939, -2.1984],
         [ 0.6702,  0.9346,  1.1521,  0.0000,  0.8825,  0.9222,  0.9382],
         [ 0.3868,  0.6330,  0.3967,  0.0000,  0.0000,  0.9708,  0.0220],
         [-0.0931, -0.1038, -0.1399,  0.0000,  0.0000,  0.0000, -0.3203]],

        [[ 0.9094,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.9543,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.1604,  0.0000,  0.0000,  0.0000,  0.0000],
         [-1.7577, -3.0682, -4.6017, -2.2660, -1.9573, -2.1939, -2.1984],
         [ 0.6702,  0.9346,  1.1521,  0.0000,  0.8825,  0.9222,  0.9382],
         [ 0.3868,  0.6330,  0.3967,  0.0000,  0.0000,  0.9708,  0.0220],
         [-0.0931, -0.1038, -0.1399,

In [None]:
U = torch.linalg.qr(M).R

In [None]:
test_close(U.mT @ U, UTU)

In [None]:
test_close(U_corr.mT @ U_corr, U.mT @ U)

but the decomposition is different so:

In [None]:
(U_corr - U).max()

tensor(8.4883, dtype=torch.float64, grad_fn=<MaxBackward1>)

In [None]:
P_C = U[:, 3:, 3:].mT

In [None]:
P_C

tensor([[[ 0.5503,  0.0000,  0.0000,  0.0000],
         [-0.2911, -0.3127,  0.0000,  0.0000],
         [-0.4499,  0.4201, -0.4413,  0.0000],
         [-0.3183, -0.2192,  0.1756, -0.2399]],

        [[ 0.5503,  0.0000,  0.0000,  0.0000],
         [-0.2911, -0.3127,  0.0000,  0.0000],
         [-0.4499,  0.4201, -0.4413,  0.0000],
         [-0.3183, -0.2192,  0.1756, -0.2399]]], dtype=torch.float64,
       grad_fn=<TransposeBackward0>)

In [None]:
P_C @ P_C.mT

tensor([[[ 3.0279e-01, -1.6016e-01, -2.4755e-01, -1.7516e-01],
         [-1.6016e-01,  1.8248e-01, -4.1026e-04,  1.6120e-01],
         [-2.4755e-01, -4.1026e-04,  5.7354e-01, -2.6354e-02],
         [-1.7516e-01,  1.6120e-01, -2.6354e-02,  2.3776e-01]],

        [[ 3.0279e-01, -1.6016e-01, -2.4755e-01, -1.7516e-01],
         [-1.6016e-01,  1.8248e-01, -4.1026e-04,  1.6120e-01],
         [-2.4755e-01, -4.1026e-04,  5.7354e-01, -2.6354e-02],
         [-1.7516e-01,  1.6120e-01, -2.6354e-02,  2.3776e-01]]],
       dtype=torch.float64, grad_fn=<UnsafeViewBackward0>)

In [None]:
P

tensor([[[ 3.0280e-01, -1.6016e-01, -2.4755e-01, -1.7516e-01],
         [-1.6016e-01,  1.8249e-01, -4.0924e-04,  1.6121e-01],
         [-2.4755e-01, -4.0924e-04,  5.7354e-01, -2.6353e-02],
         [-1.7516e-01,  1.6121e-01, -2.6353e-02,  2.3776e-01]],

        [[ 3.0280e-01, -1.6016e-01, -2.4755e-01, -1.7516e-01],
         [-1.6016e-01,  1.8249e-01, -4.0924e-04,  1.6121e-01],
         [-2.4755e-01, -4.0924e-04,  5.7354e-01, -2.6353e-02],
         [-1.7516e-01,  1.6121e-01, -2.6353e-02,  2.3776e-01]]],
       dtype=torch.float64, grad_fn=<UnsafeViewBackward0>)

We have $U^TU = U_{corr}^TU_{corr}$ but $U \neq U_{corr}$: a triangular factor of a matrix is only unique up to the signs of its rows, so both are valid.

However, in both cases the bottom-right block of $U^TU$ is $KHP^- + P$

In [None]:
test_close(UTU[:, 3:, 3:], K @ H @ P_m + P)

In [None]:
test_close((U.mT @ U)[:, 3:, 3:], K @ H @ P_m + P)

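The top-right block of $U$ satisfies $U_{12}^TU_{12} = KHP^-$ (checked below). Therefore the bottom-right block satisfies $U_{22}^TU_{22} = P$, i.e. $U_{22}^T$ is a valid (Cholesky-type) factor of $P$.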

In [None]:
U[0]

tensor([[-2.1270, -2.9497, -4.2442, -1.8726, -1.8956, -2.2802, -2.1305],
        [ 0.0000, -1.7055, -1.7253, -0.8378, -0.7263, -0.8690, -0.8122],
        [ 0.0000,  0.0000, -1.3102, -0.7895, -0.5537, -0.2800, -0.6167],
        [ 0.0000,  0.0000,  0.0000,  0.5503, -0.2911, -0.4499, -0.3183],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.3127,  0.4201, -0.2192],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.4413,  0.1756],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.2399]],
       dtype=torch.float64, grad_fn=<SelectBackward0>)

In [None]:
K @ H @ P_m + P

tensor([[[5.1348, 4.4353, 4.9715, 4.9817],
         [4.4353, 4.6100, 5.1082, 5.1311],
         [4.9715, 5.1082, 6.6063, 5.7099],
         [4.9817, 5.1311, 5.7099, 5.8165]],

        [[5.1348, 4.4353, 4.9715, 4.9817],
         [4.4353, 4.6100, 5.1082, 5.1311],
         [4.9715, 5.1082, 6.6063, 5.7099],
         [4.9817, 5.1311, 5.7099, 5.8165]]], dtype=torch.float64,
       grad_fn=<AddBackward0>)

Let's check, block by block, that $U$ is equivalent to $U_{corr}$

In [None]:
test_close(U[:, :3, :3].mT @ U[:, :3, :3], S)

In [None]:
test_close(U[:, :3, 3:].mT @ U[:, :3, 3:], K @ H @ P_m)

In [None]:
test_close(U[:, 3:, 3:].mT @ U[:, 3:, 3:], P)

#### Covariance

In [None]:
#| export
def _filter_update_cov_SR(
    H, # [1, n_dim_obs, n_dim_state] observation matrix
    R_C, # [1, n_dim_obs, n_dim_obs] square-root factor of the obs noise: R = R_C @ R_C.mT
    P_m_C # [n_batches, n_dim_state, n_dim_state] square-root factor of the predicted cov: P^- = P_m_C @ P_m_C.mT
) -> Tuple: # (P_C, S_C) Chol of filtered covariance and chol factor of S
    """Square-root covariance update of the Kalman filter using a QR decomposition.

    The pre-array
        M = [[R_C^T,           0      ],
             [P_m_C^T @ H^T,   P_m_C^T]]
    satisfies M^T M = [[S, H P^-], [P^- H^T, P^-]] with S = H P^- H^T + R.
    With M = Q U (U upper triangular), U^T U = M^T M. So the top-left block of U is a
    square-root factor of S, and the bottom-right block is one of the filtered
    covariance P = P^- - K H P^-.
    Both returned factors are lower triangular. Their diagonals can be negative,
    because QR does not normalise the signs.
    Any number of leading batch dimensions is supported.
    """
    M_21 = P_m_C.mT @ H.mT # [..., n_dim_state, n_dim_obs]
    # R_C must be transposed (like P_m_C) so that the top-left block of M^T M is R_C @ R_C.mT = R
    R_C_T = R_C.mT.expand(*M_21.shape[:-2], -1, -1)
    M_1 = torch.cat([R_C_T, torch.zeros_like(M_21.mT)], dim=-1)
    M_2 = torch.cat([M_21,  P_m_C.mT                 ], dim=-1)
    M = torch.cat([M_1, M_2], dim=-2)

    # default mode='reduced': mode='r' would not compute Q and does not support backward
    U = torch.linalg.qr(M).R

    n_dim_obs = R_C.shape[-1]
    P_C = U[..., n_dim_obs:, n_dim_obs:].mT
    S_C = U[..., :n_dim_obs, :n_dim_obs].mT

    return P_C, S_C
 

In [None]:
# square-root covariance update on the example data
P_C, S_C = _filter_update_cov_SR(H, R_C, P_m_C) 

In [None]:
def fuzz_filter_update_cov_SR(n=10):
    """Compare the square-root covariance update with the standard one on `n` random filters.

    Each run asserts closeness (eps=1e-13). Returns a tensor with the max absolute
    error of each run."""
    max_errors = []
    for _ in range(n):
        kSR = KalmanFilterSR.init_random(5,10,4)
        H, R_C = kSR.H, kSR.R_C
        P_m = torch.cat([kSR.P0] * 5)  # same prior covariance replicated over 5 batches
        R = R_C @ R_C.mT
        P_m_C = torch.linalg.cholesky(P_m)
        P_C, _ = _filter_update_cov_SR(H, R_C, P_m_C)
        std_P = _filter_update_cov(H, _filter_update_k_gain(H, R, P_m), P_m)
        sr_P = P_C @ P_C.mT
        max_errors.append((sr_P - std_P).abs().max().item())
        test_close(sr_P, std_P, eps=1e-13)
    return torch.tensor(max_errors)

In [None]:
# median, over 1000 random filters, of the max abs error against the standard update
fuzz_filter_update_cov_SR(1000).median()

tensor(8.7985e-15)

In [None]:
# scratch computation: square root of 1e-5
torch.sqrt(torch.tensor(1e-5))

tensor(0.0032)

#### Kalman Gain

Don't compute the inverse of $S$ explicitly; instead use `cholesky_solve` with the Cholesky factor of $S$ to solve the linear system.

In [None]:
#| export
def _filter_update_k_gain_SR(
    H,
    P_m_C, # Chol factor of $P^-$
    S_C # Cholesky factor of S = (HPH^T +R)
):
    """Kalman gain K = P^- H^T S^{-1} for the square-root filter, computed without inverting S."""
    # S and P^- are symmetric, so K^T = S^{-1} (H P^-): solve S X = H P^- using the factor of S
    HP_m = H @ P_m_C @ P_m_C.mT
    K_T = torch.cholesky_solve(HP_m, S_C)
    return K_T.mT

In [None]:
# S_C from the QR update is a valid square-root factor of S
test_close(S_C @ S_C.mT, S)

In [None]:
# Kalman gain computed from the square-root factors
_filter_update_k_gain_SR(H, P_m_C, S_C)

tensor([[[-1.5789e-01, -1.1835e-01,  6.0261e-01],
         [ 5.0220e-02, -1.6250e-03,  4.2261e-01],
         [ 2.3883e-01,  2.9332e-01,  2.1371e-01],
         [ 6.2363e-02,  3.9408e-05,  4.7069e-01]],

        [[-1.5789e-01, -1.1835e-01,  6.0261e-01],
         [ 5.0220e-02, -1.6250e-03,  4.2261e-01],
         [ 2.3883e-01,  2.9332e-01,  2.1371e-01],
         [ 6.2363e-02,  3.9408e-05,  4.7069e-01]]], dtype=torch.float64,
       grad_fn=<TransposeBackward0>)

In [None]:
# the square-root gain matches the standard Kalman gain
test_close(_filter_update_k_gain(H, R, P_m), _filter_update_k_gain_SR(H, P_m_C, S_C))

In [None]:
#| export
def _filter_update_SR(
    H, # [1, n_dim_obs, n_dim_state]
    d, # [1, n_dim_obs, 1]
    R_C, # [1, n_dim_obs, n_dim_obs]
    m_m, # [n_batches, n_dim_state, 1]
    P_m_C, # [n_batches, n_dim_state, n_dim_state] Cholesky factor predicted covariance
    obs # [n_batches, n_dim_obs, 1]
) -> Tuple: # Filtered state (mean [n_batches, n_dim_state, 1], chol_covariance [n_batches, n_dim_state, n_dim_state])
    "Square-root filter update: correct the predicted state at `t` with the observation at `t`"
    # the QR step gives both the filtered covariance factor and the factor of S;
    # the latter is reused for the gain so S is never inverted
    P_C, S_C = _filter_update_cov_SR(H, R_C, P_m_C)
    gain = _filter_update_k_gain_SR(H, P_m_C, S_C)
    return _filter_update_mean(H, d, gain, m_m, obs), P_C

In [None]:
# observations of one time step: [n_batches, n_dim_obs, 1]
obs.shape

torch.Size([2, 3, 1])

In [None]:
# square-root update using every observation (no mask)
m, P_C = _filter_update_SR(H, d, R_C, m_m, P_m_C, obs)
show_as_row(m, P_C)
m.shape, P_C.shape

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]))

#### Missing observations

##### Update mask

In [None]:
#| export
def _chol_masked(
    R_C, # [1, n_dim_obs, n_dim_obs] square-root factor of the obs noise: R = R_C @ R_C.mT
    mask # [n_dim_obs] boolean mask, True for the observed variables
):
    "Lower Cholesky factor of the observation noise restricted to the observed variables, i.e. of R[mask][:, mask]"
    # The rows R_C[mask] satisfy R_C[mask] @ R_C[mask].mT == R[mask][:, mask].
    # The square slice R_C[mask][:, mask] does not, unless R_C is diagonal or only
    # trailing variables are missing, so a square factor is recomputed from those rows.
    rows = R_C[..., mask, :]
    return torch.linalg.cholesky(rows @ rows.mT)


def _filter_update_mask_SR(
        H, # [1, n_dim_obs, n_dim_state]
        d, # [1, n_dim_obs, 1]
        R_C, # [1, n_dim_obs, n_dim_obs]
        m_m, # [n_batches, n_dim_state, 1]
        P_m_C, # [n_batches, n_dim_state, n_dim_state]
        obs, # [n_batches, n_dim_obs, 1] observations
        mask # [n_dim_obs] mask must be the same across batches
    ):
    """SR filter update of the state at time `t` with the observations at time `t`.

    All batch elements are assumed to share the same `mask`. Only the observed
    variables are used. If nothing is observed, the prediction is returned
    unchanged."""
    if (~mask).all(): return (m_m, P_m_C) # all data is missing
    if mask.all(): return _filter_update_SR(H, d, R_C, m_m, P_m_C, obs) # nothing missing, no subsetting needed
    H_m, d_m, obs_m = H[:, mask,:], d[:, mask,:], obs[:, mask] # _m for masked
    R_C_m = _chol_masked(R_C, mask)
    return _filter_update_SR(H_m, d_m, R_C_m, m_m, P_m_C, obs_m)

In [None]:
# masked update: the single mask `mask[0, 0, :]` is applied to every batch element
show_as_row(*_filter_update_mask_SR(H, d, R_C, m_m, P_m_C, obs, mask[0, 0, :] ))

In [None]:
# the masked update keeps the shapes of the unmasked one
m, P_C = _filter_update_mask_SR(H, d, R_C, m_m, P_m_C, obs, mask[0, 0, :] )
m.shape, P_C.shape

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]))

##### Update mask batch

In [None]:
#| export
def _filter_update_mask_batch_SR(
        H, # [1, n_dim_obs, n_dim_state]
        d, # [1, n_dim_obs, 1]
        R_C, # [1, n_dim_obs, n_dim_obs]
        m_m, # [n_batches, n_dim_state, 1]
        P_m_C, # [n_batches, n_dim_state, n_dim_state]
        obs, # [n_batches, n_dim_obs, 1] observations
        mask # [n_batches, n_dim_obs] one mask per batch element; masks may differ between batches
    ):
    """SR filter update at time `t` for a batch whose elements may have different masks.

    Batch elements sharing the same mask pattern are grouped, updated together with
    `_filter_update_mask_SR`, and written back into the output tensors."""
    out_m = torch.empty_like(m_m)
    out_P_C = torch.empty_like(P_m_C)

    patterns, pattern_idx = torch.unique(mask, return_inverse=True, dim=0)
    for k, pattern in enumerate(patterns):
        rows = pattern_idx == k  # batch elements that have this mask pattern
        out_m[rows], out_P_C[rows] = _filter_update_mask_SR(
            H, d, R_C,
            m_m[rows], P_m_C[rows],
            obs[rows],
            pattern,
        )

    return out_m, out_P_C

In [None]:
# batched masked update: `mask[:,0,:]` provides one mask per batch element
m, P_C = _filter_update_mask_batch_SR(H, d, R_C, m_m, P_m_C, obs, mask[:,0,:] )
show_as_row(m, P_C)
m.shape, P_C.shape

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]))

In [None]:
# backpropagate through the batched masked update and inspect the gradient of H
m.sum().backward(retain_graph=True) # check that pytorch can compute gradients with the whole batch and gradients aren't nan
H.grad

tensor([[[-6.9338e+00, -5.2713e+00, -2.8998e+00, -1.0339e+01],
         [-1.9440e-03,  4.2816e-02,  6.5510e-02, -1.3323e-01],
         [ 1.5014e-01,  3.1149e-01,  4.6507e-01, -1.4330e+00]]],
       dtype=torch.float64)

### Filter All

The recursive version of the Kalman filter apparently breaks PyTorch gradient calculations, so a workaround is needed.
During the loop the states are saved in Python lists, which are combined back into tensors at the end.
The last line of the function:

- converts the lists to tensors
- reorders the dimensions

In [None]:
#| export
@patch
def _filter_all(self: KalmanFilterSR,
            obs: Tensor, # `([n_batches], n_obs, [self.n_dim_obs])` where `n_batches` and `n_dim_obs` dimensions can be omitted if 1
            mask: Tensor, # `([n_batches], n_obs, [self.n_dim_obs])` where `n_batches` and `n_dim_obs` dimensions can be omitted if 1
            control: Tensor, # ([n_batches], n_obs, [self.n_dim_contr]) 
            
           ) ->Tuple[ListMNormal, ListMNormal]: # (Filtered state, predicted state) with shape (n_batches, n_obs, self.n_dim_state)
    """Run the square-root Kalman filter over all time steps.

    Returns the filtered and the predicted states. In both, the covariance slot of
    `ListMNormal` holds the square-root (Cholesky) factors, not the covariances."""
    obs, mask, control = self._parse_obs(obs, mask, control)
    bs, n_obs = obs.shape[0], obs.shape[1]
    # one independent list per quantity (the comprehension builds a new list each iteration)
    m_ms, P_m_Cs, ms, P_Cs = [[None] * n_obs for _ in range(4)]

    for t in range(n_obs):
        # Predict (at t=0 the prediction is the initial state)
        if t == 0:
            m_ms[t], P_m_Cs[t] = self.m0.expand(bs, -1, -1), self.P0_C.expand(bs, -1, -1)
        else:
            m_ms[t], P_m_Cs[t] = _filter_predict_SR(self.A, self.Q_C, self.b, self.B, ms[t - 1], P_Cs[t - 1], control[:,t,:])
        
        # Update (batch elements may have different masks)
        ms[t], P_Cs[t] = _filter_update_mask_batch_SR(self.H, self.d, self.R_C, m_ms[t], P_m_Cs[t], obs[:,t,:], mask[:,t,:])
        
        if self.cov_checker is not None:
            # the checker receives the full covariances rebuilt from the factors
            self.cov_checker.check(P_m_Cs[t] @ P_m_Cs[t].mT, t=t, name="filter_predict", type="SR")
            self.cov_checker.check(P_Cs[t] @ P_Cs[t].mT, t=t, name="filter_update", type="SR")
    
    # stack the per-time lists into tensors (time first), then move time after the batch dim
    m_ms, P_m_Cs, ms, P_Cs = list(maps(torch.stack, _times2batch, (m_ms, P_m_Cs, ms, P_Cs,)))
    return ListMNormal(ms, P_Cs), ListMNormal(m_ms, P_m_Cs) 

In [None]:
# filtered and predicted states for all time steps
filt_stateSR, pred_stateSR  = kSR._filter_all(data, mask, control)

In [None]:
# unpack means and square-root covariance factors
(ms, P_Cs), (m_ms, P_m_Cs) = filt_stateSR, pred_stateSR

Shapes of the predicted and filtered states, followed by their values at time `0` for the first batch element

In [None]:
# shapes of predicted/filtered means and covariance factors
show_as_row(*map(Self.shape(), (m_ms, P_m_Cs, ms, P_Cs,)))

In [None]:
# values for the first batch element at time 0
show_as_row(*map(lambda x:x[0][0], (m_ms, P_m_Cs, ms, P_Cs,)))

### Filter

The `filter` method wraps `_filter_all` but in addition:

- returns only the filtered state

Note that the trailing singleton dimension of the mean is kept: its shape is `(n_batches, n_obs, n_dim_state, 1)`.

In [None]:
#| export
@patch
def filter(self: KalmanFilterSR,
            obs: Tensor, # `([n_batches], n_obs, [self.n_dim_obs])` where `n_batches` and `n_dim_obs` dimensions can be omitted if 1
            mask: Tensor, # `([n_batches], n_obs, [self.n_dim_obs])` where `n_batches` and `n_dim_obs` dimensions can be omitted if 1
            control: Tensor, # ([n_batches], n_obs, [self.n_dim_contr])
          ) -> ListMNormal: # Filtered state (n_batches, n_obs, self.n_dim_state)
    """Filter the observations and return only the filtered state.

    The predicted state computed by `_filter_all` is discarded."""
    return self._filter_all(obs, mask, control)[0]

In [None]:
# filtered state only; the mean keeps its trailing singleton dimension
filtSR = kSR.filter(data, mask, control)
filtSR.mean.shape, filtSR.cov.shape

(torch.Size([2, 10, 4, 1]), torch.Size([2, 10, 4, 4]))

### Comparison

## Additional

### Constructors

#### Simple parameters

In [None]:
#| export
@patch(cls_method=True)
def init_simple(cls: KalmanFilter,
                n_dim, # n_dim_obs and n_dim_state
                dtype=torch.float64):
    """Simplest filter parameters.

    All matrices are identities and all vectors are zeros, with every dimension equal
    to `n_dim`."""
    def _identity(): return torch.eye(n_dim, dtype=dtype)
    def _zero_vec(): return torch.zeros(n_dim, dtype=dtype)

    return cls(
        A=_identity(),
        b=_zero_vec(),
        Q=_identity(),
        H=_identity(),
        d=_zero_vec(),
        R=_identity(),
        B=_identity(),
        m0=_zero_vec(),
        P0=_identity(),
    )

In [None]:
# parameters of the simplest 2-dimensional filter
KalmanFilter.init_simple(2).state_dict()

OrderedDict([('A',
              tensor([[[1., 0.],
                       [0., 1.]]], dtype=torch.float64)),
             ('H',
              tensor([[[1., 0.],
                       [0., 1.]]], dtype=torch.float64)),
             ('B',
              tensor([[[1., 0.],
                       [0., 1.]]], dtype=torch.float64)),
             ('Q_raw',
              tensor([[[1., 0.],
                       [0., 1.]]], dtype=torch.float64)),
             ('R_raw',
              tensor([[[1., 0.],
                       [0., 1.]]], dtype=torch.float64)),
             ('b',
              tensor([[[0.],
                       [0.]]], dtype=torch.float64)),
             ('d',
              tensor([[[0.],
                       [0.]]], dtype=torch.float64)),
             ('m0',
              tensor([[[0.],
                       [0.]]], dtype=torch.float64)),
             ('P0_raw',
              tensor([[[1., 0.],
                       [0., 1.]]], dtype=torch.float64))])

#### Local slope

Local slope models are an extension of the local level model in which the state variable also keeps track of the slope

Given $n$ as the number of dimensions of the observations

The transition matrix (`A`) is:

$$A = \left[\begin{array}{cc}I & I \\ 0 & I\end{array}\right]$$

where:

- $I \in \mathbb{R}^{n \times n}$
- $A \in \mathbb{R}^{2n \times 2n}$

The state is $x \in \mathbb{R}^{2n \times 1}$, where the upper half keeps track of the level and the lower half of the slope; hence $A \in \mathbb{R}^{2n \times 2n}$.

the observation matrix (`H`) is:

$$H = \left[\begin{array}{cc}I & 0 \end{array}\right]$$

In the univariate case each $I$ is simply the scalar 1; in the multivariate case it is an identity matrix.


Assuming that the control has the same dimensions as the observations, for a local slope model $B \in \mathbb{R}^{state \times contr}$ is:
$$ B = \begin{bmatrix} -I & I \\ 0 & 0 \end{bmatrix}$$

In [None]:
#| export
from torch import hstack, eye, vstack, ones, zeros, tensor
from functools import partial
from sklearn.decomposition import PCA

In [None]:
#| exporti
def set_dtype(*args, dtype=torch.float64):
    "Wrap each factory function in `args` so that it uses `dtype` unless called with an explicit `dtype`"
    return list(map(lambda fn: partial(fn, dtype=dtype), args))

# rebind the torch factories imported above to float64 versions for the constructors below
eye, ones, zeros, tensor = set_dtype(eye, ones, zeros, tensor)

In [None]:
#| export
# @delegates(KalmanFilter)
@patch(cls_method=True)
def init_local_slope_pca(cls: KalmanFilter,
                n_dim_obs, # n_dim_obs and n_dim_contr
                n_dim_state: int, # n_dim_state
                df_pca: pd.DataFrame|None = None, # dataframe for PCA init, None no PCA init
                **kwargs
            ):
    """Local Slope + PCA init.

    The state has `2 * n_dim_state` dimensions: level (upper half) and slope (lower half).
    The level is mapped to the observations by `H`, and the control enters through `B`
    (shape `[2 * n_dim_state, 2 * n_dim_obs]`). When `df_pca` is given, these mappings
    are the PCA components fitted on it; otherwise they are (rectangular) identities.
    Extra `kwargs` are forwarded to `cls`.
    """
    if df_pca is not None:
        comp = PCA(n_dim_state).fit(df_pca).components_
        H = tensor(comp.T) # transform state -> obs
        B = tensor(comp) # transform obs -> state
    else:
        # rectangular identities keep the shapes consistent when n_dim_obs != n_dim_state
        # (they equal the square identity when the two dimensions match)
        H = eye(n_dim_obs, n_dim_state) # transform state -> obs
        B = eye(n_dim_state, n_dim_obs) # transform obs -> state
        
    return cls(
        A =     vstack([hstack([eye(n_dim_state),                eye(n_dim_state)]),
                        hstack([zeros(n_dim_state, n_dim_state), eye(n_dim_state)])]),
        b =     zeros(n_dim_state * 2),
        Q =     eye(n_dim_state * 2)*.1,
        H =     hstack([H, zeros(n_dim_obs, n_dim_state)]), # only the level is observed
        d =     zeros(n_dim_obs),
        R =     eye(n_dim_obs)*.01,
        B =     vstack([hstack([-B,                            B]),
                        hstack([zeros(n_dim_state, n_dim_obs), zeros(n_dim_state, n_dim_obs)])]),
        m0 =    zeros(n_dim_state * 2),
        P0 =    eye(n_dim_state * 2) * 3,
        **kwargs
    ) 

In [None]:
# 2 obs dims, 2 state dims, PCA fitted on a toy dataframe
KalmanFilter.init_local_slope_pca(2,2,pd.DataFrame([[1,2], [2,4]])).state_dict()

OrderedDict([('A',
              tensor([[[1., 0., 1., 0.],
                       [0., 1., 0., 1.],
                       [0., 0., 1., 0.],
                       [0., 0., 0., 1.]]], dtype=torch.float64)),
             ('H',
              tensor([[[ 0.4472,  0.8944,  0.0000,  0.0000],
                       [ 0.8944, -0.4472,  0.0000,  0.0000]]], dtype=torch.float64)),
             ('B',
              tensor([[[-0.4472, -0.8944,  0.4472,  0.8944],
                       [-0.8944,  0.4472,  0.8944, -0.4472],
                       [ 0.0000,  0.0000,  0.0000,  0.0000],
                       [ 0.0000,  0.0000,  0.0000,  0.0000]]], dtype=torch.float64)),
             ('Q_raw',
              tensor([[[0.3162, 0.0000, 0.0000, 0.0000],
                       [0.0000, 0.3162, 0.0000, 0.0000],
                       [0.0000, 0.0000, 0.3162, 0.0000],
                       [0.0000, 0.0000, 0.0000, 0.3162]]], dtype=torch.float64)),
             ('R_raw',
              tensor([[[0.1000, 0.0000]

## Export

In [None]:
#| hide
# write the cells marked for export into the library module
from nbdev import nbdev_export
nbdev_export()