# Kalman Filter
> Implementation of Kalman filters using pytorch and parameter optimizations with gradient descend

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
#| default_exp kalman.filter

In [None]:
#| export
from fastcore.test import *
from fastcore.basics import *
from meteo_imp.utils import *
from meteo_imp.gaussian import *
from typing import *
from functools import partial

import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torch.distributions import MultivariateNormal

## Introduction

The models uses a latent state variable $x$ that is modelled over time, to impute gaps in $y$

The assumption of the model is that the state variable at time $x_t$ depends only on the last state $x_{t-1}$ and not on the previous states.

### Equations

The equations of the model are:

$$\begin{align} p(x_t | x_{t-1}) & = \mathcal{N}(Ax_{t-1} + b, Q) \\
p(y_t | x_t) & = \mathcal{N}(Hx_t + d, R) \end{align}$$


where:

- $A$ is the `A`
- $b$ is the `bset`
- $Q$ is the `Q`
- $H$ is the `obs_trans` 
- $d$ is the `d`
- $R$ is the `R`

in addition the model has also the parameters of the initial state that are used to initialize the filter:

- `m0`
- `P0`

The Kalman filter has 3 steps:

- filter (updating the state at time t with observations till time t-1)
- update (update the state at time t using the observation at time t)
- smooth (update the state using the observations at time t+1)

In case of missing data the update step is skipped.

After smoothing the missing data at time t ($y_t$) can be imputed from the state ($x_t$) using this formula:
$$p(y_t|x_t) = \mathcal{N}(Hx_t + d, R + HP^s_tH^T)$$

The Kalman Filter is an algorithm designed to estimate $P(x_t | y_{0:t})$.  As all state transitions and obss are linear with Gaussian distributed noise, these distributions can be represented exactly as Gaussian distributions with mean `ms[t]` and covs `Ps[t]`.
Similarly, the Kalman Smoother is an algorithm designed to estimate $P(x_t | y_{0:t-1})$



# Kalman Filter Base

### Utils

In [None]:
#| export
def _add_batch_dim(x):
    """make x 3 dimensional by adding empty dims in the correct place"""
    if x.dim() == 1: return x.unsqueeze(0).unsqueeze(-1)
    elif x.dim() == 2: return x.unsqueeze(0)
    else: return x

In [None]:
#| export
def _add_batch_dims_iter(*xs):
    """vectorize `add_batch_dim`"""
    return [_add_batch_dim(x) for x in xs]

In [None]:
show_as_row(_add_batch_dim(torch.ones(2)), _add_batch_dim(torch.ones(2,2)), _add_batch_dim(torch.ones(2,2,2)))

In [None]:
show_as_row(_add_batch_dim(torch.ones(2)).shape, _add_batch_dim(torch.ones(2,2)).shape, _add_batch_dim(torch.ones(2,2,2)).shape)

In [None]:
#| export
def _check_same_size(
    os: Sequence[tuple[Tensor, int]], # sequences of tensors and the dimension to check
    size=None, # Optional size of the common dimension
)-> int: # size of common dimension
    """Check that all args have the same size at the given dimension, raise `ValueError` if not """
    size = ifnone(size, os[0][0].shape[os[0][1]])
    if not all([size == x.shape[dim] for x, dim in os]):
        raise ValueError("All parameters must have the same size at the given dimension")
    return size

## Kalman Filter Base

In [None]:
#| export
class KalmanFilterBase(torch.nn.Module):
    """Base class for handling Kalman Filter implementation in PyTorch"""
    
    params_constr = {
        #name constraint
        'A':  None        ,
        'b':  None        ,
        'Q':  PosDef(),
        'B':  None        ,
        'H':  None        ,
        'd':  None        ,
        'R':  PosDef(),
        'm0': None       ,
        'P0': PosDef()   ,
        }
    
    def __init__(self,
            A: Tensor,                             # [n_dim_state,n_dim_state] $A$, state transition matrix 
            H: Tensor,                             # [n_dim_obs, n_dim_state] $H$, observation matrix
            B: Tensor,                             # [n_dim_state, n_dim_contr] $B$ control matrix
            Q: Tensor,                             # [n_dim_state, n_dim_state] $Q$, state trans covariance matrix
            R: Tensor,                             # [n_dim_obs, n_dim_obs] $R$, observations covariance matrix
            b: Tensor,                             # [n_dim_state] $b$, state transition offset
            d: Tensor,                             # [n_dim_obs] $d$, observations offset
            m0: Tensor,                            # [n_dim_state] $m_0$
            P0: Tensor,                            # [n_dim_state, n_dim_state] $P_0$
    
            n_dim_state: int = None,               # Number of dimensions for state - default infered from parameters
            n_dim_obs: int = None,                 # Number of dimensions for observations - default  infered from parameters
            n_dim_contr: int = None,               # Number of dimensions for control - default infered from parameters
            
            var_names: Iterable[str]|None = None,  # Names of variables for printing 
            contr_names: Iterable[str]|None = None,# Names of control variables for printing
    
            cov_checker: CheckPosDef|None = None,  # Check covariance at every step
            use_conditional: bool = False,         # Use conditional distribution for gaps that don't have all variables missing
            use_control: bool = True,              # Use the control in the filter
            use_smooth: bool = True,               # Use smoother for predictions (otherwise is filter only)
            pred_only_gap: bool = False,           # it True predictions are only for the gap
            pred_std: bool = False,                # return only stds and not covariances
                ):
        
        super().__init__()
        store_attr("var_names, contr_names, use_conditional, use_control, use_smooth, cov_checker, pred_only_gap, pred_std")
        
        A, H, B, Q, R, b, d, m0, P0 = _add_batch_dims_iter(A, H, B, Q, R, b, d, m0, P0)
        
        self._check_params(A, H, B, Q, R, b, d, m0, P0, n_dim_state, n_dim_obs, n_dim_contr)
        self._init_params(A=A, H=H, B=B, Q=Q, R=R, b=b, d=d, m0=m0, P0=P0)

    
    def _check_params(self, A, H, B, Q, R, b, d, m0, P0, n_dim_state, n_dim_obs, n_dim_contr):
        """Checks that the parameters are dimensions are consistent and sets n_dim"""
        self.n_dim_state = _check_same_size(
            [(A,  -2),
             (b,  -2),
             (Q,  -2),
             (m0, -2),
             (P0, -2),
             (H,  -1)],
            n_dim_state
        )
        self.n_dim_obs = _check_same_size(
            [(H, -2),
             (d, -2),
             (R, -2)],
            n_dim_obs
        )
        
        self.n_dim_contr = _check_same_size([(B, -1)], n_dim_contr)
        
        
    def _init_params(self, **params):
        for name, value in params.items():
            if (constraint := self.params_constr[name]) is not None:
                name, value = self._init_constraint(name, value, constraint)
            self._init_param(name, value, train=True)    
    
    def _init_param(self, param_name, value, train):
        self.register_parameter(param_name, torch.nn.Parameter(value, requires_grad=train))
    
    ### === Constraints utils
    def _init_constraint(self, param_name, value, constraint):
        name = f"{param_name}_raw"
        value = constraint.inverse_transform(value)
        setattr(self, param_name + "_constraint", constraint)
        return name, value
    
    def _get_constraint(self, param_name):
        """get the original value"""
        constraint = getattr(self, param_name + "_constraint")
        raw_value = getattr(self, f"{param_name}_raw")
        return constraint.transform(raw_value)
    
    def _get_constraint_cho_fact(self, param_name):
        """get the original value"""
        constraint = getattr(self, param_name + "_constraint")
        raw_value = getattr(self, f"{param_name}_raw")
        return constraint.transform_cho_factor(raw_value)

    def _set_constraint(self, value, param_name, train=True):
            """set the transformed value"""
            constraint = getattr(self, param_name + "_constraint")
            raw_value = constraint.inverse_transform(value)
            self._init_param(f"{param_name}_raw", raw_value, train)
            
               
    @property
    def Q_C(self): return self._get_constraint_cho_fact('Q')
    @property
    def Q(self): return self._get_constraint('Q')
    @Q.setter
    def Q(self, value): self._set_constraint(value, 'Q')

    @property
    def R_C(self): return self._get_constraint_cho_fact('R')
    @property
    def R(self): return self._get_constraint('R')
    @R.setter
    def R(self, value): self._set_constraint(value, 'R')
    
    @property
    def P0_C(self): return self._get_constraint_cho_fact('P0')
    @property
    def P0(self): return self._get_constraint('P0')
    @P0.setter
    def P0(self, value): self._set_constraint(value, 'P0')


    ### === Utility Func    
    def _parse_obs(self, obs, mask, control):
        """maybe get mask from `nan`"""
        # if mask is None: mask = ~torch.isnan(obs)
        return _add_batch_dim(obs).unsqueeze(-1), _add_batch_dim(mask), _add_batch_dim(control).unsqueeze(-1)
    def __repr__(self):
        return f"""Kalman Filter
        N dim obs: {self.n_dim_obs},
        N dim state: {self.n_dim_state},
        N dim contr: {self.n_dim_contr}"""

### Constructors

Giving all the parameters manually to the `KalmanFilterBase` init method is not convenient, hence we are having some methods that help initize the class

due to a bug in fastcore cannot subclass after creating class methods

In [None]:
#| export
class KalmanFilter(KalmanFilterBase):
    pass

In [None]:
#| export
class KalmanFilterSR(KalmanFilterBase):
    pass

In [None]:
#| export
filter_classes = [KalmanFilterBase, KalmanFilter, KalmanFilterSR]

#### Random parameters

In [None]:
#| export
#| include: false
@patch_to(filter_classes, cls_method=True)
def init_random(cls,
                n_dim_obs,
                n_dim_state,
                n_dim_contr,
                dtype=torch.float64,
                seed:int|None = 27,
                **kwargs):
    """kalman filter with random parameters"""
    if seed is not None: torch.manual_seed(seed)
    return cls(
        A  = torch.rand(n_dim_state, n_dim_state, dtype=dtype),
        b  = torch.rand(n_dim_state, dtype=dtype),        
        Q  = to_posdef(torch.rand(n_dim_state, n_dim_state, dtype=dtype)),        
        B  = torch.rand(n_dim_state, n_dim_contr, dtype=dtype),
        H  = torch.rand(n_dim_obs, n_dim_state, dtype=dtype),
        d  = torch.rand(n_dim_obs, dtype=dtype),          
        R  = to_posdef(torch.rand(n_dim_obs, n_dim_obs, dtype=dtype)),            
        m0 = torch.rand(n_dim_state, dtype=dtype),        
        P0 = to_posdef(torch.rand(n_dim_state, n_dim_state, dtype=dtype)),
        **kwargs) 
        

In [None]:
kB = KalmanFilterBase.init_random(3,4, 3, dtype=torch.float64)
kB

Kalman Filter
        N dim obs: 3,
        N dim state: 4,
        N dim contr: 3

In [None]:
kB.Q

tensor([[[1.5725, 0.3829, 0.0843, 0.3637],
         [0.3829, 1.4430, 0.3358, 1.1988],
         [0.0843, 0.3358, 1.7816, 0.7411],
         [0.3637, 1.1988, 0.7411, 1.6773]]], dtype=torch.float64,
       grad_fn=<UnsafeViewBackward0>)

In [None]:
kB.Q_C

tensor([[[1.2540, 0.0000, 0.0000, 0.0000],
         [0.3053, 1.1618, 0.0000, 0.0000],
         [0.0672, 0.2714, 1.3052, 0.0000],
         [0.2901, 0.9556, 0.3542, 0.7446]]], dtype=torch.float64,
       grad_fn=<DiagonalScatterBackward0>)

In [None]:
test_close(kB.Q_C @ kB.Q_C.mT, kB.Q, eps=2e-5)

In [None]:
kB.P0 = to_posdef(torch.rand(1,3,3))

check that assigment works :)

In [None]:
kB.P0 = to_posdef(torch.rand(4, 4, dtype=torch.float64))

In [None]:
kB.P0_C

tensor([[0.9349, 0.0000, 0.0000, 0.0000],
        [0.3928, 1.0748, 0.0000, 0.0000],
        [0.7406, 0.7533, 0.8326, 0.0000],
        [0.5903, 0.0391, 0.8217, 0.9638]], dtype=torch.float64,
       grad_fn=<DiagonalScatterBackward0>)

In [None]:
kB = KalmanFilterBase.init_random(3,4, 3, dtype=torch.float64)
kB

Kalman Filter
        N dim obs: 3,
        N dim state: 4,
        N dim contr: 3

In [None]:
list(kB.named_parameters())

[('A',
  Parameter containing:
  tensor([[[0.1751, 0.5375, 0.7676, 0.7450],
           [0.1204, 0.4777, 0.5823, 0.3786],
           [0.8484, 0.2317, 0.3969, 0.7088],
           [0.0270, 0.6178, 0.6097, 0.9128]]], dtype=torch.float64,
         requires_grad=True)),
 ('H',
  Parameter containing:
  tensor([[[0.4984, 0.1597, 0.4458, 0.7349],
           [0.1070, 0.9531, 0.1942, 0.6683],
           [0.9186, 0.7123, 0.1806, 0.5045]]], dtype=torch.float64,
         requires_grad=True)),
 ('B',
  Parameter containing:
  tensor([[[0.2827, 0.1138, 0.9378],
           [0.5594, 0.9364, 0.5136],
           [0.8592, 0.7647, 0.5183],
           [0.2376, 0.5618, 0.5096]]], dtype=torch.float64, requires_grad=True)),
 ('Q_raw',
  Parameter containing:
  tensor([[[0.4590, 0.0000, 0.0000, 0.0000],
           [0.5348, 0.5296, 0.0000, 0.0000],
           [0.0025, 0.0749, 0.4603, 0.0000],
           [0.0588, 0.5416, 0.9309, 0.5859]]], dtype=torch.float64,
         requires_grad=True)),
 ('R_raw',
  Parameter

In [None]:
kB.state_dict()

OrderedDict([('A',
              tensor([[[0.1751, 0.5375, 0.7676, 0.7450],
                       [0.1204, 0.4777, 0.5823, 0.3786],
                       [0.8484, 0.2317, 0.3969, 0.7088],
                       [0.0270, 0.6178, 0.6097, 0.9128]]], dtype=torch.float64)),
             ('H',
              tensor([[[0.4984, 0.1597, 0.4458, 0.7349],
                       [0.1070, 0.9531, 0.1942, 0.6683],
                       [0.9186, 0.7123, 0.1806, 0.5045]]], dtype=torch.float64)),
             ('B',
              tensor([[[0.2827, 0.1138, 0.9378],
                       [0.5594, 0.9364, 0.5136],
                       [0.8592, 0.7647, 0.5183],
                       [0.2376, 0.5618, 0.5096]]], dtype=torch.float64)),
             ('Q_raw',
              tensor([[[0.4590, 0.0000, 0.0000, 0.0000],
                       [0.5348, 0.5296, 0.0000, 0.0000],
                       [0.0025, 0.0749, 0.4603, 0.0000],
                       [0.0588, 0.5416, 0.9309, 0.5859]]], dtype=torch.float64)

#### From filter

In [None]:
#| export
@patch(cls_method=True)
def init_from(cls: KalmanFilter|KalmanFilterBase|KalmanFilterSR, o: filter_classes # Other filter
             ):
    """Initialize Filter by copying all parameters from another one"""
    return cls(o.A, o.H, o.B, o.Q, o.R, o.b, o.d, o.m0, o.P0,
               o.n_dim_state, o.n_dim_obs, o.n_dim_contr,
               o.var_names, o.contr_names, o.cov_checker,
               o.use_conditional, o.use_control, o.use_smooth, o.pred_only_gap, o.pred_std)

In [None]:
k1 = KalmanFilter.init_random(3,4,3)
k2 = KalmanFilterSR.init_from(k1)
for p1, p2 in zip(k1.parameters(), k2.parameters()):
    test_close(p1,p2, eps=1e-3) #noise added by contraints
                

### Get Info

In [None]:
#| export
@patch
def get_info(self: KalmanFilterBase):
    out = {}
    var_names = ifnone(self.var_names, [f"y_{i}" for i in range(self.n_dim_obs)])
    latent_names = [f"x_{i}" for i in range(self.n_dim_state)]
    contr_names = ifnone(self.contr_names, [f"c_{i}" for i in range(self.n_dim_contr)])
    out['$A$'] = array2df(self.A[0] , latent_names, latent_names, 'state')
    out['$Q$']    = array2df(self.Q[0] , latent_names, latent_names, 'state')
    out['$b$']        = array2df(self.b[0] , latent_names, ['offset'],   'state')
    out['$H$']   = array2df(self.H[0] , var_names,    latent_names, 'variable')
    out['$R$']      = array2df(self.R[0] , var_names,    var_names,    'variable')
    out['$d$']          = array2df(self.d[0] , var_names,    ['offset'],   'variable')
    out['$B$'] = array2df(self.B[0] , latent_names, contr_names,  'state')
    out['$m_0$']  = array2df(self.m0[0], latent_names, ['mean'],     'state')
    out['$P_0$']   = array2df(self.P0[0], latent_names, latent_names, 'state')

    return out

In [None]:
#| export
@patch
def _repr_html_(self: filter_classes):
    title = f"Kalman Filter ({self.n_dim_obs} obs, {self.n_dim_state} state, {self.n_dim_contr} contr)"
    return row_dfs(self.get_info(), title , hide_idx=True)

In [None]:
kB

state,x_0,x_1,x_2,x_3
x_0,0.1751,0.5375,0.7676,0.745
x_1,0.1204,0.4777,0.5823,0.3786
x_2,0.8484,0.2317,0.3969,0.7088
x_3,0.027,0.6178,0.6097,0.9128

state,x_0,x_1,x_2,x_3
x_0,0.9002,0.5074,0.0024,0.0558
x_1,0.5074,1.2712,0.0756,0.569
x_2,0.0024,0.0756,0.9072,0.9246
x_3,0.0558,0.569,0.9246,2.2208

state,offset
x_0,0.4661
x_1,0.3918
x_2,0.0571
x_3,0.9529

variable,x_0,x_1,x_2,x_3
y_0,0.4984,0.1597,0.4458,0.7349
y_1,0.107,0.9531,0.1942,0.6683
y_2,0.9186,0.7123,0.1806,0.5045

variable,y_0,y_1,y_2
y_0,0.5856,0.019,0.5991
y_1,0.019,1.3107,0.513
y_2,0.5991,0.513,2.3748

variable,offset
y_0,0.9334
y_1,0.5645
y_2,0.0695

state,c_0,c_1,c_2
x_0,0.2827,0.1138,0.9378
x_1,0.5594,0.9364,0.5136
x_2,0.8592,0.7647,0.5183
x_3,0.2376,0.5618,0.5096

state,mean
x_0,0.8234
x_1,0.385
x_2,0.338
x_3,0.4376

state,x_0,x_1,x_2,x_3
x_0,0.886,0.4535,0.449,0.2316
x_1,0.4535,1.8632,0.674,1.3448
x_2,0.449,0.674,2.0374,1.2929
x_3,0.2316,1.3448,1.2929,2.0978


### Test data

In [None]:
#| exporti
def get_test_data(n_obs = 10, n_dim_obs=3, n_dim_contr = 3, p_missing=.3, gap_len=None, bs=2, dtype=torch.float64, device='cpu'):
    data = torch.rand(bs, n_obs, n_dim_obs, dtype=dtype, device=device)
    mask = torch.rand(bs, n_obs, n_dim_obs, device=device)
    if gap_len is not None:
        mask[:, n_obs//2-gap_len//2,n_obs//2+gap_len//2, :] = False
    else:
        mask = mask > p_missing
    control = torch.rand(bs, n_obs, n_dim_contr, dtype=dtype, device=device)
    data[~mask] = torch.nan # ensure that the missing data cannot be used
    return data, mask, control

In [None]:
reset_seed()
data, mask, control = get_test_data()
show_as_row(data, mask, control)

# Standard Kalman Filter

In [None]:
k = KalmanFilter.init_random(3,4,3)

## Filter

### Filter predict

Probability of state at time `t` given state a time `t-1` 

$p(x_t) = \mathcal{N}(x_t; m_t^-, P_t^-)$ where:

- predicted state mean: $m_t^- = Am_{t-1} + B c_t + b$  

- predicted state covariance: $P_t^- = AP_{t-1}A^T + Q$

In [None]:
A, Q, b, B, m_pr,P_pr= (k.A, k.Q, k.b, k.B,torch.concat([k.m0]*2), torch.concat([k.P0]*2))

In [None]:
m_pr.shape, P_pr.shape, A.shape

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]), torch.Size([1, 4, 4]))

#### Covariance

implement $P_t^- = AP_{t-1}A^T + Q$

In [None]:
#| export
def _filter_predict_cov_stand(A, Q, P_pr):
    """Standard - Kalman Filter predict covariance"""
    return A @ P_pr @ A.mT + Q

In [None]:
P_m = _filter_predict_cov_stand(A, Q, P_pr)
P_m

tensor([[[6.2873, 3.7655, 4.8718, 5.8589],
         [3.7655, 4.8551, 4.7314, 5.1009],
         [4.8718, 4.7314, 5.9373, 6.0160],
         [5.8589, 5.1009, 6.0160, 8.5954]],

        [[6.2873, 3.7655, 4.8718, 5.8589],
         [3.7655, 4.8551, 4.7314, 5.1009],
         [4.8718, 4.7314, 5.9373, 6.0160],
         [5.8589, 5.1009, 6.0160, 8.5954]]], dtype=torch.float64,
       grad_fn=<AddBackward0>)

#### Mean

In [None]:
#| export
def _filter_predict_mean(
    A,      # transition matrix
    B,      # control matrix
    b,      # transition offset
    m_pr,   # Mean previous time step $m_{t-1}$
    control, # control variable
):
    return A @ m_pr + B @ control + b

#### Predict

In [None]:
#| export
def _filter_predict_mean(
    A,      # transition matrix
    B,      # control matrix
    b,      # transition offset
    m_pr,   # Mean previous time step $m_{t-1}$
    control, # control variable
):
    return A @ m_pr + B @ control + b

In [None]:
#| export
def _filter_predict(A,
                    Q,
                    b,
                    B, #[n_dim_state, n_dim_contr]
                    m_pr,
                    P_pr,
                    control, #[n_batches, n_dim_contr]
                    ):
    """Calculate the state at time `t` given the state at time `t-1`"""
    m_m = _filter_predict_mean(A, B, b, m_pr, control)
    P_m = _filter_predict_cov_stand(A, Q, P_pr)
    return (m_m, P_m)

In [None]:
B.shape

torch.Size([1, 4, 3])

In [None]:
control.shape

torch.Size([2, 10, 3])

In [None]:
B[0].shape

torch.Size([4, 3])

In [None]:
B[0] @ control[0, 0].unsqueeze(-1)

tensor([[0.7217],
        [0.9572],
        [0.5742],
        [1.2138]], dtype=torch.float64, grad_fn=<MmBackward0>)

In [None]:
m_m, P_m = _filter_predict(
    A, Q, b, B,
    m_pr,P_pr, control[:, 0, :].unsqueeze(-1))

In [None]:
show_as_row(m_m, P_m)

In [None]:
(m_m.shape, P_m.shape,)

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]))

### Filter update

Probability of state at time `t` given the observations at time `t`

$p(x_t|y_t) = \mathcal{N}(x_t; m_t, P_t)$ where:

- predicted obs mean: $z_t = Hm_t^- + d$  

- prediced obs covariance: $S_t = HP_t^-H^T + R$

- kalman gain$K_t = P_t^-H^TS_t^{-1}$ 

- corrected state mean: $m_t = m_t^- + K_t(y_t - z_t)$ 

- corrected state covariance: $P_t = (I-K_tH)P_t^-$ 

if the observation are missing this step is skipped and the corrected state is equal to the predicted state


Need to figure out the Nans for the gradients ...

#### Kalman Gain

Don't compute the inverse of the matrix, but use `cholesky_solve` to invert the matrix

In [None]:
H, d, R, obs = k.H, k.d, k.R, data[:,0,:].unsqueeze(-1)

In [None]:
#| export
def _filter_update_k_gain(H, R,P_m):
    "kalman gain for filter update"
    S = H @ P_m @ H.mT + R
    S_C = torch.linalg.cholesky(S)
    return torch.cholesky_solve(H @ P_m.mT, S_C).mT

In [None]:
K = _filter_update_k_gain(H, R, P_m)
K

tensor([[[ 0.1177,  0.2313,  0.0612],
         [-0.5573,  0.6582,  0.0288],
         [-0.3189,  0.5919, -0.0174],
         [ 0.3269,  0.3056, -0.0378]],

        [[ 0.1177,  0.2313,  0.0612],
         [-0.5573,  0.6582,  0.0288],
         [-0.3189,  0.5919, -0.0174],
         [ 0.3269,  0.3056, -0.0378]]], dtype=torch.float64,
       grad_fn=<TransposeBackward0>)

In [None]:
test_close(_filter_update_k_gain(H, R, P_m), P_m @ H.mT @ torch.inverse(H @ P_m @ H.mT + R))

#### Covariance

In [None]:
#| export
def _filter_update_cov(H, K, P_m):
    return (eye_like(P_m) - K @ H) @ P_m

In [None]:
P = _filter_update_cov(H, K, P_m)
P

tensor([[[ 2.1074, -0.2158,  0.3602,  0.1911],
         [-0.2158,  0.4808,  0.0472, -0.1503],
         [ 0.3602,  0.0472,  0.8035, -0.0177],
         [ 0.1911, -0.1503, -0.0177,  0.8448]],

        [[ 2.1074, -0.2158,  0.3602,  0.1911],
         [-0.2158,  0.4808,  0.0472, -0.1503],
         [ 0.3602,  0.0472,  0.8035, -0.0177],
         [ 0.1911, -0.1503, -0.0177,  0.8448]]], dtype=torch.float64,
       grad_fn=<UnsafeViewBackward0>)

#### Mean

In [None]:
z = H @ m_m + d; z
(obs - z)

tensor([[[-3.3343],
         [    nan],
         [    nan]],

        [[-2.4181],
         [-5.0754],
         [-3.9145]]], dtype=torch.float64, grad_fn=<SubBackward0>)

In [None]:
#| export
def _filter_update_mean(H, d, K, m_m, y):
    z = H @ m_m + d
    return m_m + K @ (y - z)

In [None]:
m = _filter_update_mean(H, d, K, m_m, obs)
m

tensor([[[    nan],
         [    nan],
         [    nan],
         [    nan]],

        [[ 0.4396],
         [-0.5094],
         [-0.6276],
         [ 0.6379]]], dtype=torch.float64, grad_fn=<AddBackward0>)

In [None]:
#| export
def _filter_update(
    H, # [1, n_dim_obs, n_dim_state]
    d, # [1, n_dim_obs, 1]
    R, # [1, n_dim_obs, n_dim_obs]
    m_m, # [n_batches, n_dim_state, 1]
    P_m, # [n_batches, n_dim_state, n_dim_state]
    obs # # [n_batches, n_dim_obs, 1]
) -> Tuple: # Filtered state (mean, covariance) [n_batches, n_dim_state]
    "Filter update state at `t` with obs at `t`"
    K = _filter_update_k_gain(H, R, P_m)
    m = _filter_update_mean(H, d, K, m_m, obs)
    P = _filter_update_cov(H, K, P_m)
    return m, P

In [None]:
m, P = _filter_update(H, d, R, m_m, P_m, obs)
show_as_row(m, P)
m.shape, P.shape

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]))

there are `nan` in the output because there are `nan` in the observations

The next functions adds the support for missing obsevations by also using a mask

#### Missing observations

If all the observations at time $t$ are missing the correct step is skipped and the filtered state at time $t$ () is the same of the filtered state.

If only some observations are missing a variation of equation can be used.

$y^{ng}_t$ is a vector containing the observations that are not missing at time $t$. 

It can be expressed as a linear transformation of $y_t$

$$ y^{ng}_t = My_t$$

where $M$ is a mask matrix that is used to select the subset of $y_t$ that is observed. $M \in \mathbb{R}^{n_{ng} \times n}$ and is made of columns which are made of all zeros but for an entry 1 at row corresponding to the non-missing observation.
hence:

$$ p(y^{ng}_t) = \mathcal{N}(M\mu_{y_t},  M\Sigma_{y_t}M^T)$$

from which you can derive

$$ p(y^{ng}_t|x_t) = p(MHx_t + Mb, MRM^T) $${#eq-filter-correct}

Then the posterior $p(x_t|y_t^{ng})$ can be computed similarly of equation @filter_correct as:

$$ p(x_t|y^{ng}_t) = \mathcal{N}(x_t; m_t, P_t) $${#eq-filter_correct_missing}
    
where:

*  predicted obs mean: $z_t = MHm_t^- + Md$
*  predicted obs covariance: $S_t = MHP_t^-(MH)^T + MRM^T$
*  Kalman gain $K_t = P_t^-(MH)^TS_t^{-1}$
*  corrected state mean: $m_t = m_t^- + K_t(My_t - z_t)$
*  corrected state covariance: $P_t = (I-K_tMH)P_t^-$


##### Details implementation 

For the implementation the matrix multiplication $MH$ can be replaced with `H[m]` where `m` is the mask for the rows for `H` and $MRM^T$ with `R[m][:,m]`

In [None]:
H, R, d,obs, mm = k.H, k.R, k.d, data[:,0,:].unsqueeze(-1), mask[:,0,:].unsqueeze(-1)

In [None]:
m = torch.tensor([False,True,True]) # mask batch
M = torch.tensor([[[0,1,0], # mask matrix
                  [0,0,1]]], dtype=torch.float64)
show_as_row(m, M, H, R)

In [None]:
M @ M.mT

tensor([[[1., 0.],
         [0., 1.]]], dtype=torch.float64)

In [None]:
M @ H, H[:, m]

(tensor([[[0.1679, 0.8635, 0.3753, 0.9760],
          [0.2125, 0.8049, 0.2124, 0.6794]]], dtype=torch.float64,
        grad_fn=<UnsafeViewBackward0>),
 tensor([[[0.1679, 0.8635, 0.3753, 0.9760],
          [0.2125, 0.8049, 0.2124, 0.6794]]], dtype=torch.float64,
        grad_fn=<IndexBackward0>))

In [None]:
M @ R @ M.mT, R[:,m][:,:,m]

(tensor([[[1.2831, 0.9935],
          [0.9935, 2.4550]]], dtype=torch.float64,
        grad_fn=<UnsafeViewBackward0>),
 tensor([[[1.2831, 0.9935],
          [0.9935, 2.4550]]], dtype=torch.float64, grad_fn=<IndexBackward0>))

By using partially missing observations `_filter_update` cannot be easily batched as the shape of the intermediate variables depends on the number of observed variables. So the idea is to divide the batch in blocks that share the same number of variables missing.

In [None]:
mask_values, indices = torch.unique(mask[:,1,:], dim=0, return_inverse=True)
mask_values, indices

(tensor([[ True, False,  True],
         [ True,  True, False]]),
 tensor([0, 1]))

##### Update mask

In [None]:
mm = mask[0,0,:]

In [None]:
#| export
def _filter_update_mask(
        H, # [1, n_dim_obs, n_dim_state]
        d, # [1, n_dim_obs, 1]
        R, # [1, n_dim_obs, n_dim_obs]
        m_m, # [n_batches, n_dim_state, 1]
        P_m, # [n_batches, n_dim_state, n_dim_state]
        obs, # [n_batches, n_dim_obs, 1] observations
        mask # [n_dim_obs] mask must be the same across batches
                       ):
    """Update state at time `t` given observations at time `t` assuming that all observations have the same mask"""
    if (~mask).all(): return (m_m, P_m) # all data is missing
    H_m, d_m, R_m, obs_m, = H[:, mask,:], d[:, mask,:], R[:, mask,:][:, :,mask], obs[:, mask] # _m for masked
    return _filter_update(H_m, d_m, R_m, m_m, P_m, obs_m)

In [None]:
H[:, mm].shape, d[:, mm].shape, R[:, mm][:, :,mm].shape, obs[:, mm].shape

(torch.Size([1, 1, 4]),
 torch.Size([1, 1, 1]),
 torch.Size([1, 1, 1]),
 torch.Size([2, 1, 1]))

In [None]:
show_as_row(*_filter_update_mask(H, d, R, m_m, P_m, obs, mask[0, 0, :] ))

In [None]:
m, P = _filter_update_mask(H, d, R, m_m, P_m, obs, mask[0, 0, :] )
m.shape, P.shape

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]))

##### Update mask batch

In [None]:
#| export
def _filter_update_mask_batch(
        H, # [1, n_dim_obs, n_dim_state]
        d, # [1, n_dim_obs, 1]
        R, # [1, n_dim_obs, n_dim_obs]
        m_m, # [n_batches, n_dim_state, 1]
        P_m, # [n_batches, n_dim_state, n_dim_state]
        obs, # [n_batches, n_dim_obs, 1] observations
        mask # [n_batches, n_dim_obs] mask must be the same across batches
                       ):
    """Support batches with different masks when update state at time `t` given observations at time `t`"""
    
    ms, Ps= torch.empty_like(m_m), torch.empty_like(P_m)
    
    # find the unique values of the mask and make a sub-batches with it
    mask_values, indices = torch.unique(mask, return_inverse=True, dim=0)  
    for i, mask_v in enumerate(mask_values):
        idx_select = indices == i 
        m, P = _filter_update_mask(
            H, d, R,
            m_m[idx_select], P_m[idx_select],
            obs[idx_select],
            mask_v,
        )
        ms[idx_select], Ps[idx_select] = m, P
    
    return ms, Ps

In [None]:
m, P = _filter_update_mask_batch(H, d, R, m_m, P_m, obs, mask[:,0,:] )
show_as_row(m, P)
m.shape, P.shape

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]))

In [None]:
m.sum().backward(retain_graph=True) # check that pytorch can compute gradients with the whole batch and gradients aren't nan
H.grad

tensor([[[-1.1769, -2.7620, -0.9158, -2.4954],
         [-2.0431,  0.8269,  0.5074, -1.5867],
         [ 0.1472,  0.0285,  0.1012,  0.0356]]], dtype=torch.float64)

### Filter All

The resursive version of the kalman filter is apperently breaking pytorch gradients calculations so a workaround is needed.
During the loop the states are saved in a python list and then at the end they are combined back into a tensor.
The last line of the function does:

- convert lists to tensors
- correct order dimensions

In [None]:
#| export
def _times2batch(x):
    """Permutes `x` so that the first dimension is the number of batches and not the times"""
    return x.permute(1,0,-2,-1)

In [None]:
#| export
@patch
def _filter_all(self: KalmanFilter,
            obs: Tensor, # `([n_batches], n_obs, [self.n_dim_obs])` where `n_batches` and `n_dim_obs` dimensions can be omitted if 1
            mask: Tensor, # `([n_batches], n_obs, [self.n_dim_obs])` where `n_batches` and `n_dim_obs` dimensions can be omitted if 1
            control: Tensor, # ([n_batches], n_obs, [self.n_dim_contr]) 
            
           ) ->Tuple[ListMNormal, ListMNormal]: # (Filtered state, predicted state) with shape (n_batches, n_obs, self.n_dim_state)
    """Filter observations using kalman filter """
    obs, mask, control = self._parse_obs(obs, mask, control)
    n_obs = obs.shape[1]
    bs = obs.shape[0]
    # lists are mutable so need to copy them
    m_ms, P_ms, ms, Ps = [[None for _ in range(n_obs)].copy() for _ in range(4)] 

    for t in range(n_obs):
        # --- Predict
        if t == 0:
            m_ms[t], P_ms[t] = self.m0.expand(bs, -1, -1), self.P0.expand(bs, -1, -1)
        else:
            m_ms[t], P_ms[t] = _filter_predict(self.A, self.Q, self.b,
                                               self.B if self.use_control else torch.zeros_like(self.B), # maybe disable control
                                               ms[t - 1], Ps[t - 1], control[:,t,:])
        
        # --- Update
        ms[t], Ps[t] = _filter_update_mask_batch(self.H, self.d, self.R, m_ms[t], P_ms[t], obs[:,t,:], mask[:,t,:])
        
        if self.cov_checker is not None:
            self.cov_checker.check(P_ms[t], t=t, name="filter_predict")
            self.cov_checker.check(Ps[t], t=t, name="filter_update")
    
    m_ms, P_ms, ms, Ps = list(maps(torch.stack, _times2batch, (m_ms, P_ms, ms, Ps,))) # reorder dimensions and convert to tensor
    return ListMNormal(ms, Ps), ListMNormal(m_ms, P_ms) 

In [None]:
filt_state, pred_state  = k._filter_all(data, mask, control)

In [None]:
(ms, Ps), (m_ms, P_ms) = filt_state, pred_state

Predictions at time `0` for both batches

In [None]:
show_as_row(*map(Self.shape(), (m_ms, P_ms, ms, Ps,)))

In [None]:
show_as_row(*map(lambda x:x[0][0], (m_ms, P_ms, ms, Ps,)))

### Filter

The filter methods wraps `_filter_all` but in addition:

- returns only filtered state
- remove last dimensions from mean

In [None]:
#| export
@patch
def filter(self: KalmanFilter,
            obs: Tensor, # `([n_batches], n_obs, [self.n_dim_obs])` where `n_batches` and `n_dim_obs` dimensions can be omitted if 1
            mask: Tensor, # `([n_batches], n_obs, [self.n_dim_obs])` where `n_batches` and `n_dim_obs` dimensions can be omitted if 1
            control: Tensor, # ([n_batches], n_obs, [self.n_dim_contr])
          ) -> ListMNormal: # Filtered state (n_batches, n_obs, self.n_dim_state)
    """Filter observation"""
    filt_state, _ = self._filter_all(obs, mask, control)
    return filt_state

In [None]:
filt = k.filter(data, mask, control)
filt.mean.shape, filt.cov.shape

(torch.Size([2, 10, 4, 1]), torch.Size([2, 10, 4, 4]))

## Smooth

### Smooth update step

compute the probability of the state at time `t` given all the observations

$p(x_t|Y) = \mathcal{N}(x_t; m_t^s, P_t^s)$ where:

- Kalman smoothing gain: $G_t = P_tA^T(P_{t+1}^-)^{-1}$
- smoothed mean: $m_t^s = m_t + G_t(m_{t+1}^s - m_{t+1}^-)$
- smoothed covariance: $P_t^s = P_t + G_t(P_{t+1}^s - P_{t+1}^-)G_t^T$

In [None]:
#| export
def _smooth_gain(A, P, P_m):
    S_C = torch.linalg.cholesky(P_m)
    return torch.cholesky_solve(A @ P, S_C).mT

In [None]:
test_close(_smooth_gain(A, filt_state.cov, pred_state.cov), filt_state.cov @ A.mT @ torch.inverse(pred_state.cov))

In [None]:
K_p = _smooth_gain(A, filt_state[:,0].cov, pred_state[:,0].cov)
K_p.shape

torch.Size([2, 4, 4])

In [None]:
#| export
def _smooth_mean(K_p,                # [n_dim_state, n_dim_state]
               m,         # [n_dim_state] filtered state at time `t`
               m_m,         # [n_dim_state] state before filtering at time `t + 1` (= using the observation until time t)
               next_m_p): # [n_dim_state] smoothed state at time  `t+1`
    return m + K_p @ (next_m_p - m_m)

In [None]:
_smooth_mean(K_p, filt_state[:,0].mean, pred_state[:,0].mean, filt_state[:,0].mean)

tensor([[[ 0.6254],
         [ 0.2922],
         [ 0.7965],
         [ 0.6964]],

        [[-0.1829],
         [-0.7088],
         [-0.0211],
         [ 0.2570]]], dtype=torch.float64, grad_fn=<AddBackward0>)

In [None]:
#| export
def _smooth_cov(K_p, P, P_m, next_P_p):
    P_p = P + K_p @ (next_P_p - P_m) @ K_p.mT 
    return (P_p + P_p.mT) / 2 # force symmetric to improve num stability 

In [None]:
_smooth_cov(K_p, filt_state[:,0].cov, pred_state[:,0].cov, filt_state[:,0].cov)

tensor([[[ 0.5327,  0.3321,  0.2094, -0.1773],
         [ 0.3321,  0.8061,  0.0930, -0.0363],
         [ 0.2094,  0.0930,  0.5427, -0.0125],
         [-0.1773, -0.0363, -0.0125,  0.6738]],

        [[ 0.3342,  0.0671,  0.0652, -0.2001],
         [ 0.0671,  0.3478, -0.1115, -0.1245],
         [ 0.0652, -0.1115,  0.4549, -0.0025],
         [-0.2001, -0.1245, -0.0025,  0.6981]]], dtype=torch.float64,
       grad_fn=<DivBackward0>)

In [None]:
#| export
def _smooth_update(A,                # [n_dim_state, n_dim_state]
                   filt_state: MNormal,         # [n_dim_state] filtered state at time `t`
                   pred_state: MNormal,         # [n_dim_state] state before filtering at time `t + 1` (= using the observation until time t)
                   next_smoothed_state: Normal, # [n_dim_state] smoothed state at time  `t+1`
                   ) -> MNormal:                # mean and cov of smoothed state at time `t`
    """Correct a pred state with a Kalman Smoother update"""
    smooth_gain = _smooth_gain(A, filt_state.cov, pred_state.cov)

    m_p = _smooth_mean(smooth_gain, filt_state.mean, pred_state.mean, next_smoothed_state.mean)
    P_p = _smooth_cov(smooth_gain,  filt_state.cov, pred_state.cov, next_smoothed_state.cov)
    
    return MNormal(m_p, P_p)

In [None]:
show_as_row(*_smooth_update(A, filt_state[:, 0], pred_state[:, 0], filt_state[:, 0]))

In [None]:
show_as_row(*map(Self.shape(), _smooth_update(A, MNormal(m_m, P_m), MNormal(m_m, P_m), MNormal(m_m, P_m))))

### Smooth

In [None]:
#| export
def _smooth(A, # `[n_dim_state, n_dim_state]`
            filt_state: ListMNormal, # `[n_timesteps, n_dim_state]`
                # `ms[t]` is the state estimate for time t given obs from times `[0...t]`
            pred_state: ListMNormal, # `[n_timesteps, n_dim_state]`
                # `m_ms[t]` is the state estimate for time t given obs from times `[0...t-1]`
            cov_checker = None
           ) -> ListMNormal: # `[n_timesteps, n_dim_state]` Smoothed state 
    """Apply the Kalman Smoother"""
    n_obs = pred_state.mean.shape[1]

    smoothed_state = ListMNormal(torch.zeros_like(filt_state.mean), torch.zeros_like(filt_state.cov))
    # For the last timestep cannot use the smoother
    smoothed_state.mean[:,-1] = filt_state.mean[:,-1]
    smoothed_state.cov[:,-1] = filt_state.cov[:,-1]

    for t in reversed(range(n_obs - 1)):
        (smoothed_state.mean[:,t], smoothed_state.cov[:,t]) = (
            _smooth_update(
                A,
                filt_state[:,t],
                pred_state[:,t + 1],
                smoothed_state[:,t+1],
            )
        )
        if cov_checker is not None:
            cov_checker.check(smoothed_state.cov[:,t], name="smooth", t=t)
    return smoothed_state

In [None]:
smooth_state = _smooth(k.A,  filt_state, pred_state)

In [None]:
show_as_row(smooth_state.mean[0][0], smooth_state.cov[0][0])

In [None]:
show_as_row(smooth_state.mean.shape, smooth_state.cov.shape)

### KalmanFilter method

In [None]:
#| export
@patch
def smooth(self: KalmanFilter,
           obs: Tensor,
           mask: Tensor,
           control: Tensor
          ) -> ListMNormal: # `[n_timesteps, n_dim_state]` smoothed state
        
    """Kalman Filter Smoothing"""

    filt_state, pred_state = self._filter_all(obs, mask, control)

    smoothed_state = _smooth(self.A,
                   filt_state, pred_state,
                   self.cov_checker)
    return smoothed_state

In [None]:
smoothed_state = k.smooth(data, mask, control)

In [None]:
show_as_row(smoothed_state.mean.shape, smoothed_state.cov.shape)

In [None]:
smoothed_state.mean.sum().backward(retain_graph=True)
A.grad

tensor([[[-4.9495, -5.2939, -7.8642,  0.2546],
         [ 0.2445,  8.3747, 13.5693, -7.4543],
         [-5.4811, -5.2497, -6.6092, -1.5187],
         [-1.5874,  2.1945,  3.6410, -2.3143]]], dtype=torch.float64)

## Predict

The prediction at time t ($y_t$) are computed rom the state ($x_t$) using this formula:
$$p(y_t|x_t) = \mathcal{N}(Hx_t + d, R + HP^s_tH^T)$$

this works both if the state was filtered or smoother

This add the supports for conditional predictions, which means that at the time (t) when we are making the predictions some of the variables have been actually observed. Since the model prediction is a normal distribution we can condition on the observed values and thus improve the predictions. See `conditional_gaussian`

In order to have conditional predictions that make sense it's not possible to return the full covariance matrix for the predictions but only the standard deviations

In [None]:
test_m = torch.tensor(
    [[True, True, True,],
    [False, True, True],
    [False, False, False]]
)

In [None]:
torch.logical_xor(test_m.all(-1), test_m.any(-1))

tensor([False,  True, False])

In [None]:
A = torch.rand(2,2,3,3)

In [None]:
(A @ A).shape

torch.Size([2, 2, 3, 3])

predict can be vectorized across both the batch and the timesteps, except for timesteps that require conditional predictions

In [None]:
#| export
@patch
def _obs_from_state(self: KalmanFilter, state: ListMNormal):

    mean = self.H @ state.mean + self.d
    cov = self.H @ state.cov @ self.H.mT + self.R
    
    if self.cov_checker is not None:
        for c in cov: # this is batched and for all timestamps
            self.cov_checker.check(c, caller='predict')
    
    return ListMNormal(mean.squeeze(-1), cov)

In [None]:
smoothed_state.mean.shape, smoothed_state.cov.shape

(torch.Size([2, 10, 4, 1]), torch.Size([2, 10, 4, 4]))

In [None]:
(k.H @ smoothed_state.mean).shape

torch.Size([2, 10, 3, 1])

In [None]:
pred_obs0 = k._obs_from_state(smoothed_state)
pred_obs0.mean.shape

torch.Size([2, 10, 3])

In [None]:
pred_obs0.cov.shape

torch.Size([2, 10, 3, 3])

In [None]:
#| export
@patch
def predict(self: KalmanFilter, obs, mask, control, smooth=True):
    """Predicted observations at all times """
    if self.use_conditional and self.pred_only_gap:
        raise ValueError("Kalman Filter predict cannot have conditional predictions and all predictions at the same time")
    
    state = self.smooth(obs, mask, control) if smooth else self.filter(obs, mask, control)
    obs, mask, control = self._parse_obs(obs, mask, control)
    
    pred_obs = self._obs_from_state(state)
    pred_mean, pred_std = pred_obs.mean, cov2std(pred_obs.cov)
    
    if self.use_conditional:
        # conditional predictions are slow, do only if some obs are missing 
        mask = mask.squeeze(0)
        cond_mask = torch.logical_xor(mask.all(-1), mask.any(-1))
        # there may be no conditional prediction to do
        if cond_mask.any():
            # this cannot be batched so returns a list
            cond_preds = cond_gaussian_batched(pred_obs[cond_mask], obs[cond_mask].squeeze(-1), mask[cond_mask])

            for i, c_pred in enumerate(cond_preds):
                m = ~mask[cond_mask][i]
                pred_mean[cond_mask][i][m] = c_pred.mean
                pred_std [cond_mask][i][m] = cov2std(c_pred.cov)
    
    return ListNormal(pred_mean, pred_std)

In [None]:
pred = k.predict(data, mask, control)

In [None]:
pred.mean.shape, pred.std.shape

(torch.Size([2, 10, 3]), torch.Size([2, 10, 3]))

Gradients ...

In [None]:
def get_grad_mask(x):
    "filter gradient after sub the masks value with x"
    d = data.clone()
    d[~mask] = x
    k.predict(data, mask, control).mean.sum().backward(retain_graph=True)
    grad = k.R_raw.grad.clone()
    k.zero_grad() 
    return grad

In [None]:
get_grad_mask(10)

tensor([[[ -2.5082,   0.0000,   0.0000],
         [-11.7667,  -0.1206,   0.0000],
         [ -7.6345,  -8.3390,  -4.0905]]], dtype=torch.float64)

In [None]:
test_close(get_grad_mask(1), get_grad_mask(10))

In [None]:
#| export
@patch
def predict(self: KalmanFilter, obs, mask, control, smooth=True):
    """Predicted observations at all times """
    state = self.smooth(obs, mask, control) if smooth else self.filter(obs, mask, control)
    obs, mask, control = self._parse_obs(obs, mask, control)
    
    pred_obs = self._obs_from_state(state)
    pred_mean, pred_std = pred_obs.mean, cov2std(pred_obs.cov)
    
    if self.use_conditional:
        # conditional predictions are slow, do only if some obs are missing 
        mask = mask.squeeze(0)
        cond_mask = torch.logical_xor(mask.all(-1), mask.any(-1))
        # this cannot be batched so returns a list
        cond_preds = cond_gaussian_batched(
            pred_obs[cond_mask], obs[cond_mask].squeeze(-1), mask[cond_mask])
        
        for i, c_pred in enumerate(cond_preds):
            m = ~mask[cond_mask][i]
            pred_mean[cond_mask][i][m] = c_pred.mean
            pred_std [cond_mask][i][m] = cov2std(c_pred.cov)
    
    return ListNormal(pred_mean, pred_std)

In [None]:
@patch
def predict_times(self: KalmanFilter, times, obs, mask=None, smooth=True, check_args=None):
    """Predicted observations at specific times """
    state = self.smooth(obs, mask, check_args) if smooth else self.filter(obs, mask, check_args)
    obs, mask = self._parse_obs(obs, mask)
    times = array1d(times)
    
    n_timesteps = obs.shape[0]
    n_features = obs.shape[1] if len(obs.shape) > 1 else 1
    
    if times.max() > n_timesteps or times.min() < 0:
        raise ValueError(f"provided times range from {times.min()} to {times.max()}, which is outside allowed range : 0 to {n_timesteps}")

    means = torch.empty((times.shape[0], n_features), dtype=obs.dtype, device=obs.device)
    stds = torch.empty((times.shape[0], n_features), dtype=obs.dtype, device=obs.device) 
    for i, t in enumerate(times):
        mean, std = self._obs_from_state(
            state.mean[t],
            state.cov[t],
            {'t': t, **check_args} if check_args is not None else None
        )
        
        means[i], stds[i] = _get_cond_pred(ListNormal(mean, std), obs[t], mask[t])
    
    return ListNormal(means, stds)  

# Numerical Stable Kalman Filter (Square Root Filter)

In [None]:
kSR = KalmanFilterSR.init_random(3,4,3)

## Filter 

### Filter predict

#### Covariance

Implement the numerical stable version of the covariance update

In [None]:
A, Q, b, B, m_pr,P_pr= (k.A, k.Q, k.b, k.B,torch.concat([k.m0]*2), torch.concat([k.P0]*2))

In [None]:
Q_C = kSR.Q_C

In [None]:
_filter_predict_cov_stand(kSR.A, kSR.Q_C @ kSR.Q_C.mT, P_pr)

tensor([[[1.9504, 2.3535, 2.1727, 2.2936],
         [2.3535, 5.8436, 5.0464, 6.0831],
         [2.1727, 5.0464, 6.2940, 6.0969],
         [2.2936, 6.0831, 6.0969, 8.1538]],

        [[1.9504, 2.3535, 2.1727, 2.2936],
         [2.3535, 5.8436, 5.0464, 6.0831],
         [2.1727, 5.0464, 6.2940, 6.0969],
         [2.2936, 6.0831, 6.0969, 8.1538]]], dtype=torch.float64,
       grad_fn=<AddBackward0>)

In [None]:
P_pr_C = torch.linalg.cholesky(P_pr)

$$W = \begin{bmatrix}AC_{t-1}&C_Q\end{bmatrix}$$

In [None]:
W = torch.concat([A @ P_pr_C, Q_C.expand_as(P_pr_C)], dim=-1)
W.shape

torch.Size([2, 4, 8])

In [None]:
P_m_C = torch.linalg.qr(W.mT).R.mT

In [None]:
P_m_C

tensor([[[-2.5709,  0.0000,  0.0000,  0.0000],
         [-1.7904,  1.1590,  0.0000,  0.0000],
         [-1.9907,  1.3402,  1.2995,  0.0000],
         [-2.2985,  0.5614,  0.8125, -1.0728]],

        [[-2.5709,  0.0000,  0.0000,  0.0000],
         [-1.7904,  1.1590,  0.0000,  0.0000],
         [-1.9907,  1.3402,  1.2995,  0.0000],
         [-2.2985,  0.5614,  0.8125, -1.0728]]], dtype=torch.float64,
       grad_fn=<TransposeBackward0>)

In [None]:
P_m_C @ P_m_C.mT

tensor([[[6.6096, 4.6031, 5.1178, 5.9093],
         [4.6031, 4.5489, 5.1174, 4.7660],
         [5.1178, 5.1174, 7.4476, 6.3838],
         [5.9093, 4.7660, 6.3838, 7.4094]],

        [[6.6096, 4.6031, 5.1178, 5.9093],
         [4.6031, 4.5489, 5.1174, 4.7660],
         [5.1178, 5.1174, 7.4476, 6.3838],
         [5.9093, 4.7660, 6.3838, 7.4094]]], dtype=torch.float64,
       grad_fn=<UnsafeViewBackward0>)

In [None]:
P_m = _filter_predict_cov_stand(A, Q_C @ Q_C.mT, P_pr)

In [None]:
test_close(P_m, P_m_C @ P_m_C.mT)

In [None]:
(P_m - P_m_C @ P_m_C.mT).max()

tensor(8.8818e-16, dtype=torch.float64, grad_fn=<MaxBackward1>)

In [None]:
test_P_m_C = torch.linalg.cholesky(P_m)

In [None]:
(test_P_m_C @ test_P_m_C.mT - P_m_C @ P_m_C.mT).max()

tensor(8.8818e-16, dtype=torch.float64, grad_fn=<MaxBackward1>)

In [None]:
(test_P_m_C - P_m_C).max()

tensor(5.1418, dtype=torch.float64, grad_fn=<MaxBackward1>)

Cholesky decomposition is not unique! but the solution is correct

In [None]:
#| export
def _filter_predict_cov_SR(A, # transition covariance $A_t$
                        Q_C, # Cholesky Factor of transition covariance $Q_t$
                        P_pr_C # Cholesky Factor of previous state covariance $P_{t-1}$
                       ):
    """Numerical stable Kalman filter predict for covariance"""
    W = torch.concat([A @ P_pr_C, Q_C.expand_as(P_pr_C)], dim=-1)
    return torch.linalg.qr(W.mT).R.mT 

In [None]:
P_m_C = _filter_predict_cov_SR(A, Q_C, P_pr_C)
test_close(P_m_C @ P_m_C.mT, _filter_predict_cov_stand(A, Q_C @ Q_C.mT, P_pr))

In [None]:
def fuzz_filter_predict_cov_SR(n=10):
    for _ in range(n):
        kSR = KalmanFilterSR.init_random(5,10,4)
        A, Q_C, b, B, m_pr,P_pr = (kSR.A.unsqueeze(0), kSR.Q_C.unsqueeze(0), kSR.b.unsqueeze(-1),
                                                  kSR.B.unsqueeze(0),
                                                  torch.stack([kSR.m0]*2).unsqueeze(-1),
                                                  torch.stack([kSR.P0]*2))
        P_pr_C = torch.linalg.cholesky(P_pr)
        P_m_C = _filter_predict_cov_SR(A, Q_C, P_pr_C)
        test_close(P_m_C @ P_m_C.mT, _filter_predict_cov_stand(A, Q_C @ Q_C.mT, P_pr), eps=5e-13) 

In [None]:
fuzz_filter_predict_cov_SR()

#### Predict

In [None]:
#| export
def _filter_predict_SR(A, Q_C, b, B, m_pr, P_pr_C,control) -> Tuple: # predicted state
    """Calculate the state at time `t` given the state at time `t-1`"""
    m_m = _filter_predict_mean(A, B, b, m_pr, control)
    P_m_C = _filter_predict_cov_SR(A, Q_C, P_pr_C)
    return (m_m, P_m_C)

In [None]:
m_m, P_m_C = _filter_predict_SR(kSR.A, kSR.Q_C, kSR.b, kSR.B, m_pr, P_pr_C, control[:,0].unsqueeze(-1)) 
show_as_row(m_m, P_m_C)

In [None]:
is_posdef(P_m_C @ P_m_C.mT).all()

tensor(True)

### Filter Update

In [None]:
def is_sr(x_C, x): return torch.allclose(x_C @ x_C.mT, x)

In [None]:
#| export
def cat_2d(x): # matrix as list of list of Tensor
    return torch.cat([torch.cat(row, dim=-1) for row in x], dim=-2)

In [None]:
x = [[]]

#### Example calculations

In [None]:
H, R, R_C, obs = kSR.H, kSR.R, kSR.R_C, data[:,0,:].unsqueeze(-1)
P_m = P_m_C @ P_m_C.mT

# use standard filter to compute expected result
S = H @ P_m @ H.mT + R
S_C = torch.linalg.cholesky(S)
K = _filter_update_k_gain(H, R, P_m)
K_bar = K @ S_C
P_stand = _filter_update_cov(H, K, P_m)
P_C_stand = torch.linalg.cholesky(P_stand)

In [None]:
assert all([is_sr(R_C, R), is_sr(S_C, S)])

$$M = \begin{bmatrix} R^{1/2} & H(P^-)^{1/2} \\ 0 & (P^-)^{1/2} \end{bmatrix}$$

In [None]:
M = cat_2d([[R_C.expand(2,-1,-1)              , H @ P_m_C], 
            [torch.zeros_like((H @ P_m_C).mT),  P_m_C]])

In [None]:
M[0]

tensor([[ 1.2478,  0.0000,  0.0000, -1.2662, -1.1753,  0.7242, -0.3448],
        [ 0.3190,  1.2801,  0.0000, -2.1349, -1.5926,  1.0952, -0.3843],
        [ 0.0341,  0.9601,  0.7766, -3.2461, -2.1290,  0.8250, -0.5179],
        [ 0.0000,  0.0000,  0.0000, -1.3966,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000, -1.6852, -1.7331,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000, -1.5557, -1.3991,  1.3843,  0.0000],
        [ 0.0000,  0.0000,  0.0000, -1.6423, -1.9130,  0.6253, -1.1858]],
       dtype=torch.float64, grad_fn=<SelectBackward0>)

$$V = \begin{bmatrix} S^{1/2} & 0 \\ \bar{K} & P^{1/2} \end{bmatrix}$$

In [None]:
# V from standard filter
V_stand = cat_2d([[S_C,   torch.zeros_like(K_bar.mT)],
                  [K_bar, P_C_stand]])

In [None]:
V_stand[0]

tensor([[ 2.2770,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 2.5904,  1.8631,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 3.2634,  2.2593,  1.3379,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.7766,  0.5205,  0.6151,  0.8355,  0.0000,  0.0000,  0.0000],
        [ 1.8316,  0.8659,  0.9168, -0.1000,  0.9427,  0.0000,  0.0000],
        [ 2.0275,  0.9734,  0.2654, -0.0859,  0.2527,  1.0461,  0.0000],
        [ 2.2790,  0.9606,  0.6923, -0.4813,  0.4183,  0.2013,  1.0540]],
       dtype=torch.float64, grad_fn=<SelectBackward0>)

In [None]:
test_close(M @ M.mT, V_stand @ V_stand.mT)

In [None]:
V_sr = torch.linalg.qr(M.mT).R.mT

In [None]:
V_sr[0]

tensor([[-2.2770,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-2.5904, -1.8631,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-3.2634, -2.2593, -1.3379,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.7766, -0.5205, -0.6151,  0.8355,  0.0000,  0.0000,  0.0000],
        [-1.8316, -0.8659, -0.9168, -0.1000,  0.9427,  0.0000,  0.0000],
        [-2.0275, -0.9734, -0.2654, -0.0859,  0.2527, -1.0461,  0.0000],
        [-2.2790, -0.9606, -0.6923, -0.4813,  0.4183, -0.2013, -1.0540]],
       dtype=torch.float64, grad_fn=<SelectBackward0>)

In [None]:
test_close(V_sr @ V_sr.mT, V_stand @ V_stand.mT)

In [None]:
n_dim_obs = R_C.shape[-1]
P_C = V_sr[:, n_dim_obs:, n_dim_obs:]
P_C

tensor([[[ 0.8355,  0.0000,  0.0000,  0.0000],
         [-0.1000,  0.9427,  0.0000,  0.0000],
         [-0.0859,  0.2527, -1.0461,  0.0000],
         [-0.4813,  0.4183, -0.2013, -1.0540]],

        [[ 0.8355,  0.0000,  0.0000,  0.0000],
         [-0.1000,  0.9427,  0.0000,  0.0000],
         [-0.0859,  0.2527, -1.0461,  0.0000],
         [-0.4813,  0.4183, -0.2013, -1.0540]]], dtype=torch.float64,
       grad_fn=<SliceBackward0>)

In [None]:
is_sr(P_C, P_stand)

True

$P_C$ computed with the QR decomposition is not the same from the cholesky decomposition, but that's okay! They are all valid square roots of the posterior covariance

In [None]:
(P_C == P_C_stand).all()

tensor(False)

In [None]:
(H @ P_m_C).shape

torch.Size([2, 3, 4])

#### Covariance

In [None]:
#| export
def tensor_info(x): return {'dtype': x.dtype, 'device': x.device}

In [None]:
#| export
def _filter_update_cov_SR(
    H,
    R_C,
    P_m_C
) -> Tuple: # (P_C, S_C) Square roots of filtered covariance and S
    """Covariance measurement update for Square root Filter"""
    bs,n, k = P_m_C.shape[0], R_C.shape[-1], H.shape[-1] # batch size, dim observations, dim_state
    zeros = partial(torch.zeros, **tensor_info(H))
    
    M = cat_2d([[R_C.expand(bs,-1,-1),   H @ P_m_C], 
                [zeros(bs, k, n),        P_m_C    ]])

    V = torch.linalg.qr(M.mT).R.mT

    P_C = V[:, n:, n:]
    S_C = V[:, :n, :n]
    return P_C, S_C

In [None]:
P_C, S_C = _filter_update_cov_SR(kSR.H, kSR.R_C, P_m_C) 

In [None]:
def fuzz_filter_update_cov_SR(n=10):
    errs = []
    for _ in range(n):
        kSR = KalmanFilterSR.init_random(5,10,4)
        H, R_C, P_m = (kSR.H, kSR.R_C, torch.cat([kSR.P0]*5))
        R = R_C @ R_C.mT
        P_m_C = torch.linalg.cholesky(P_m)
        P_C, _ = _filter_update_cov_SR(H, R_C, P_m_C)
        K = _filter_update_k_gain(H, R, P_m)
        P_stand = _filter_update_cov(H, K, P_m)
        errs.append((P_C @ P_C.mT -  P_stand).abs().max().item())
    return torch.tensor(errs)

In [None]:
err = fuzz_filter_update_cov_SR(100)
assert err.max() < torch.tensor(1e-10)
err.median(), err.max()

(tensor(2.8588e-15), tensor(1.0214e-14))

#### Kalman Gain

Don't compute the inverse of the matrix, but use `cholesky_solve` to invert the matrix

In [None]:
#| export
def _filter_update_k_gain_SR(
    H,
    P_m_C, # square root of $P^-$
    S_C # square root of S = (HPH^T +R)
):
    """kalman gain for filter update for SR filter"""
    return torch.cholesky_solve(H @ P_m_C @ P_m_C.mT, S_C).mT

In [None]:
S = kSR.H @ P_m_C @ P_m_C.mT @ kSR.H.mT + kSR.R

In [None]:
test_close(kSR.R, kSR.R_C @ kSR.R_C.mT)

In [None]:
test_close(S_C @ S_C.mT, S)

In [None]:
_filter_update_k_gain_SR(H, P_m_C, S_C)

tensor([[[-0.0014, -0.2782,  0.4598],
         [ 0.2389, -0.3662,  0.6852],
         [ 0.2854,  0.2819,  0.1984],
         [ 0.3866, -0.1119,  0.5174]],

        [[-0.0014, -0.2782,  0.4598],
         [ 0.2389, -0.3662,  0.6852],
         [ 0.2854,  0.2819,  0.1984],
         [ 0.3866, -0.1119,  0.5174]]], dtype=torch.float64,
       grad_fn=<TransposeBackward0>)

In [None]:
test_close(_filter_update_k_gain(H, R, P_m), _filter_update_k_gain_SR(H, P_m_C, S_C))

In [None]:
def fuzz_kalman_gain_SR(n=10):
    errs = []
    for _ in range(n):
        kSR = KalmanFilterSR.init_random(5,10,4)
        H, R_C, P_m = (kSR.H, kSR.R_C, torch.cat([kSR.P0]*5))
        R = R_C @ R_C.mT
        P_m_C = torch.linalg.cholesky(P_m)
        P_C, S_C = _filter_update_cov_SR(H, R_C, P_m_C)
        K_stand = _filter_update_k_gain(H, R, P_m)
        K = _filter_update_k_gain_SR(H, P_m_C, S_C)
        errs.append((K_stand - K).abs().max().item())
    return torch.tensor(errs)

In [None]:
err = fuzz_kalman_gain_SR(100)
assert err.max() < torch.tensor(1e-10)
err.median(), err.max()

(tensor(4.5797e-15), tensor(2.5202e-14))

#### Measurement update

In [None]:
#| export
def _filter_update_SR(
    H, # [1, n_dim_obs, n_dim_state]
    d, # [1, n_dim_obs, 1]
    R_C, # [1, n_dim_obs, n_dim_obs]
    m_m, # [n_batches, n_dim_state, 1]
    P_m_C, # [n_batches, n_dim_state, n_dim_state] square root predicted covariance
    obs # # [n_batches, n_dim_obs, 1]
) -> Tuple: # Filtered state (mean, chol_covariance) [n_batches, n_dim_state]
    "Filter update state at `t` with obs at `t`"
    P_C, S_C = _filter_update_cov_SR(H, R_C, P_m_C)
    K = _filter_update_k_gain_SR(H, P_m_C, S_C)
    m = _filter_update_mean(H, d, K, m_m, obs)
    return m, P_C

In [None]:
m, P_C = _filter_update_SR(H, d, R_C, m_m, P_m_C, obs)
show_as_row(m, P_C)
m.shape, P_C.shape

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]))

In [None]:
get_test_data(1, 5,4, bs=1)[1].shape

torch.Size([1, 1, 5])

In [None]:
def fuzz_filter_update_SR(n=10):
    errs = {'mean': [], 'cov': []}
    for _ in range(n):
        kSR = KalmanFilterSR.init_random(5,10,4)
        H, d, R, R_C, m_m, P_m = (kSR.H, kSR.d, kSR.R, kSR.R_C, torch.cat([kSR.m0]*5), torch.cat([kSR.P0]*5))
        obs = torch.randn_like(H @ m_m)
        P_m_C = torch.linalg.cholesky(P_m)
        mSR, P_m_C = _filter_update_SR(H, d, R_C, m_m, P_m_C, obs)
        m, P_C = _filter_update(H, d, R, m_m, P_m, obs)
        errs['mean'].append((mSR - m).abs().max().item())
        errs['cov'].append((P_C - P_m_C @ P_m_C.mT).abs().max().item())
    return pd.DataFrame(errs)

err = fuzz_filter_update_SR(100)

err.median(), err.max()

(mean    1.049161e-14
 cov     2.740863e-15
 dtype: float64,
 mean    5.573320e-14
 cov     1.099121e-14
 dtype: float64)

#### Missing observations

##### Update mask

Here need to compute the square root of $R$, because cannot apply the mask to $R^{1/2}$

In [None]:
R

tensor([[[1.5571, 0.3980, 0.0426],
         [0.3980, 1.7403, 1.2398],
         [0.0426, 1.2398, 1.5260]]], dtype=torch.float64,
       grad_fn=<UnsafeViewBackward0>)

In [None]:
is_posdef(R)

tensor([True])

In [None]:
R_m = torch.tensor([[1.5571,  0.0426], [0.0426, 1.5259]])

In [None]:
R_m

tensor([[1.5571, 0.0426],
        [0.0426, 1.5259]])

In [None]:
is_posdef(R_m)

tensor(True)

In [None]:
m = [True, False, True]

In [None]:
is_posdef(R[:, m,:][:, :, m])

tensor([True])

In [None]:
#| export
def _filter_update_mask_SR(
        H, # [1, n_dim_obs, n_dim_state]
        d, # [1, n_dim_obs, 1]
        R, # [1, n_dim_obs, n_dim_obs]
        m_m, # [n_batches, n_dim_state, 1]
        P_m_C, # [n_batches, n_dim_state, n_dim_state]
        obs, # [n_batches, n_dim_obs, 1] observations
        mask # [n_dim_obs] mask must be the same across batches
                       ):
    """SR Filter Update state at time `t` given observations at time `t` assuming that all observations have the same mask"""
    if (~mask).all(): return (m_m, P_m_C) # all data is missing
    H_m, d_m, R_m, obs_m, = H[:, mask,:], d[:, mask,:], R[:, mask,:][:, :,mask], obs[:, mask] # _m for masked
    R_C_m = torch.linalg.cholesky(R_m)
    return _filter_update_SR(H_m, d_m, R_C_m, m_m, P_m_C, obs_m)

In [None]:
H_m, d_m, R_m, R_C_m, obs_m, = H[:, m,:], d[:, m,:], R[:, m,:][:, :,m], R_C[:, m,:][:, :,m], obs[:, m]

In [None]:
R2 = R_m
R2_C_m = torch.linalg.cholesky(R_m) 

In [None]:
is_sr(R_C_m, R_m)

False

In [None]:
_filter_update_SR(H_m, d_m, R_C_m, m_m, P_m_C, obs_m)[0] - _filter_update(H_m, d_m, R_m, m_m, P_m_C @ P_m_C.mT, obs_m)[0]

tensor([[[    nan],
         [    nan],
         [    nan],
         [    nan]],

        [[-0.1152],
         [-0.1834],
         [-0.1472],
         [-0.1785]]], dtype=torch.float64, grad_fn=<SubBackward0>)

In [None]:
_filter_update_SR(H_m, d_m, R2_C_m, m_m, P_m_C, obs_m)[0] - _filter_update(H_m, d_m, R_m, m_m, P_m_C @ P_m_C.mT, obs_m)[0]

tensor([[[        nan],
         [        nan],
         [        nan],
         [        nan]],

        [[-4.4409e-16],
         [-4.4409e-16],
         [ 0.0000e+00],
         [-8.8818e-16]]], dtype=torch.float64, grad_fn=<SubBackward0>)

In [None]:
show_as_row(*_filter_update_mask_SR(H, d, R_C, m_m, P_m_C, obs, mask[0, 0, :] ))

In [None]:
m, P_C = _filter_update_mask_SR(H, d, R_C, m_m, P_m_C, obs, mask[0, 0, :] )
m.shape, P_C.shape

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]))

In [None]:
mask[0,0].shape

torch.Size([3])

In [None]:
def fuzz_filter_update_SR(n=10):
    errs = {'mean': [], 'cov': []}
    for _ in range(n):
        kSR = KalmanFilterSR.init_random(5,10,4)
        H, d, R, R_C, m_m, P_m = (kSR.H, kSR.d, kSR.R, kSR.R_C, torch.cat([kSR.m0]*5), torch.cat([kSR.P0]*5))
        obs, mask, _  = get_test_data(1, 5, 4, bs=5)
        obs, mask = obs[:,0].unsqueeze(-1), mask[0,0]
        
        P_m_C = torch.linalg.cholesky(P_m)
        mSR, P_m_C = _filter_update_mask_SR(H, d, R, m_m, P_m_C, obs, mask)
        m, P_C = _filter_update_mask(H, d, R, m_m, P_m, obs, mask)
        errs['mean'].append((mSR - m).abs().max().item())
        errs['cov'].append((P_C - P_m_C @ P_m_C.mT).abs().max().item())
    return pd.DataFrame(errs)

err = fuzz_filter_update_SR(100)

err.median(), err.max()

(mean    3.996803e-15
 cov     1.970646e-15
 dtype: float64,
 mean    3.996803e-15
 cov     7.827072e-15
 dtype: float64)

##### Update mask batch

In [None]:
#| export
def _filter_update_mask_batch_SR(
        H, # [1, n_dim_obs, n_dim_state]
        d, # [1, n_dim_obs, 1]
        R, # [1, n_dim_obs, n_dim_obs]
        m_m, # [n_batches, n_dim_state, 1]
        P_m_C, # [n_batches, n_dim_state, n_dim_state]
        obs, # [n_batches, n_dim_obs, 1] observations
        mask # [n_batches, n_dim_obs] mask must be the same across batches
                       ):
    """Support batches with different masks when update state at time `t` given observations at time `t`"""
    
    ms, P_Cs= torch.empty_like(m_m), torch.empty_like(P_m_C)
    
    # find the unique values of the mask and make a sub-batches with it
    mask_values, indices = torch.unique(mask, return_inverse=True, dim=0)  
    for i, mask_v in enumerate(mask_values):
        idx_select = indices == i 
        m, P_C = _filter_update_mask_SR(
            H, d, R,
            m_m[idx_select], P_m_C[idx_select],
            obs[idx_select],
            mask_v,
        )
        ms[idx_select], P_Cs[idx_select] = m, P_C
    
    return ms, P_Cs

In [None]:
m, P_C = _filter_update_mask_batch_SR(H, d, R, m_m, P_m_C, obs, mask[:,0,:] )
show_as_row(m, P_C)
m.shape, P_C.shape

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]))

In [None]:
m.sum().backward(retain_graph=True) # check that pytorch can compute gradients with the whole batch and gradients aren't nan
H.grad

tensor([[[-5.3359, -5.2561, -7.0879, -7.6131],
         [ 0.0176, -0.2548, -0.2340, -0.2578],
         [-0.2513, -0.9494, -1.2821, -1.5227]]], dtype=torch.float64)

### Filter All

The resursive version of the kalman filter is apperently breaking pytorch gradients calculations so a workaround is needed.
During the loop the states are saved in a python list and then at the end they are combined back into a tensor.
The last line of the function does:

- convert lists to tensors
- correct order dimensions

In [None]:
#| export
@patch
def _filter_all(self: KalmanFilterSR,
            obs: Tensor, # `([n_batches], n_obs, [self.n_dim_obs])` where `n_batches` and `n_dim_obs` dimensions can be omitted if 1
            mask: Tensor, # `([n_batches], n_obs, [self.n_dim_obs])` where `n_batches` and `n_dim_obs` dimensions can be omitted if 1
            control: Tensor, # ([n_batches], n_obs, [self.n_dim_contr]) 
            
           ) ->Tuple[ListMNormal, ListMNormal]: # (Filtered state, predicted state) with shape (n_batches, n_obs, self.n_dim_state)
    """Filter observations using kalman filter """
    obs, mask, control = self._parse_obs(obs, mask, control)
    bs, n_obs = obs.shape[0], obs.shape[1]
    # lists are mutable so need to copy them
    m_ms, P_m_Cs, ms, P_Cs = [[None for _ in range(n_obs)].copy() for _ in range(4)] 

    for t in range(n_obs):
        # Predict
        if t == 0:
            m_ms[t], P_m_Cs[t] = self.m0.expand(bs, -1, -1), self.P0_C.expand(bs, -1, -1)
        else:
            m_ms[t], P_m_Cs[t] = _filter_predict_SR(self.A, self.Q_C, self.b,
                                                    self.B if self.use_control else torch.zeros_like(self.B),
                                                    ms[t - 1], P_Cs[t - 1], control[:,t,:])
        
        # Update
        ms[t], P_Cs[t] = _filter_update_mask_batch_SR(self.H, self.d, self.R, m_ms[t], P_m_Cs[t], obs[:,t,:], mask[:,t,:])
        
        if self.cov_checker is not None:
            self.cov_checker.check(P_m_Cs[t] @ P_m_Cs[t].mT, t=t, name="filter_predict", type="SR")
            self.cov_checker.check(P_Cs[t] @ P_Cs[t].mT, t=t, name="filter_update", type="SR")
    
    m_ms, P_m_Cs, ms, P_Cs = list(maps(torch.stack, _times2batch, (m_ms, P_m_Cs, ms, P_Cs,))) # reorder dimensions and convert to tensor
    return ListMNormal(ms, P_Cs), ListMNormal(m_ms, P_m_Cs) 

In [None]:
filt_stateSR, pred_stateSR  = kSR._filter_all(data, mask, control)

In [None]:
(ms, P_Cs), (m_ms, P_m_Cs) = filt_stateSR, pred_stateSR

Predictions at time `0` for both batches

In [None]:
show_as_row(*map(Self.shape(), (m_ms, P_m_Cs, ms, P_Cs,)))

In [None]:
show_as_row(*map(lambda x:x[0][0], (m_ms, P_m_Cs, ms, P_Cs,)))

### Filter

The filter methods wraps `_filter_all` but in addition:

- returns only filtered state
- remove last dimensions from mean

In [None]:
#| export
@patch
def filter(self: KalmanFilterSR,
            obs: Tensor, # `([n_batches], n_obs, [self.n_dim_obs])` where `n_batches` and `n_dim_obs` dimensions can be omitted if 1
            mask: Tensor, # `([n_batches], n_obs, [self.n_dim_obs])` where `n_batches` and `n_dim_obs` dimensions can be omitted if 1
            control: Tensor, # ([n_batches], n_obs, [self.n_dim_contr])
          ) -> ListMNormal: # Filtered state (n_batches, n_obs, self.n_dim_state)
    """Filter observation"""
    filt_state, _ = self._filter_all(obs, mask, control)
    return filt_state

In [None]:
filtSR = kSR.filter(data, mask, control)
filtSR.mean.shape, filtSR.cov.shape

(torch.Size([2, 10, 4, 1]), torch.Size([2, 10, 4, 4]))

In [None]:
def fuzz_filter_SR(n=10):
    errs = {'mean': [], 'cov': []}
    for _ in range(n):
        kSR = KalmanFilterSR.init_random(5,10,4)
        k = KalmanFilter.init_from(kSR)
        dat = get_test_data(20, 5, 4)
        mean, cov = k.filter(*dat)
        meanSR, covSR = kSR.filter(*dat)
        covSR = covSR @ covSR.mT
        errs['mean'].append((meanSR - mean).abs().max().item())
        errs['cov'].append((covSR - cov).abs().max().item())
    return pd.DataFrame(errs)

err = fuzz_filter_SR(10)

err.median(), err.max()

(mean    8.108847e-12
 cov     8.149592e-12
 dtype: float64,
 mean    4.446576e-11
 cov     1.261309e-10
 dtype: float64)

## Smooth

### Smooth update step

compute the probability of the state at time `t` given all the observations

$p(x_t|Y) = \mathcal{N}(x_t; m_t^s, P_t^s)$ where:

- Kalman smoothing gain: $G_t = P_tA^T(P_{t+1}^-)^{-1}$
- smoothed mean: $m_t^s = m_t + G_t(m_{t+1}^s - m_{t+1}^-)$
- smoothed covariance: $P_t^s = P_t + G_t(P_{t+1}^s - P_{t+1}^-)G_t^T$

In [None]:
filt_state = ListMNormal(filt_stateSR.mean, filt_stateSR.cov @ filt_stateSR.cov.mT)
pred_state = ListMNormal(pred_stateSR.mean, pred_stateSR.cov @ pred_stateSR.cov.mT)

In [None]:
#| export
def _smooth_gain_SR(A, P_C, P_m_C):
    return torch.cholesky_solve(A @ P_C @ P_C.mT, P_m_C).mT

In [None]:
K_p = _smooth_gain_SR(kSR.A, filt_stateSR[:, 0].cov, pred_stateSR[:, 0].cov)

In [None]:
test_close(
    _smooth_gain_SR(kSR.A, filt_stateSR[:, 0].cov, pred_stateSR[:, 0].cov),
    _smooth_gain(kSR.A, filt_state[:, 0].cov, pred_state[:, 0].cov)
)

In [None]:
test_close(filt_state[0,0].cov, filt_stateSR[0,0].cov @ filt_stateSR[0,0].cov.mT)

In [None]:
test_close(pred_state[0,0].cov, pred_stateSR[0,0].cov @ pred_stateSR[0,0].cov.mT)

In [None]:
def fuzz_gain_SR(n=10):
    errs = {'K': []}
    for _ in range(n):
        kSR = KalmanFilterSR.init_random(5,10,4)
        k = KalmanFilter.init_from(kSR)
        dat = get_test_data(2, 5, 4)
        (_, f_cov), (_, p_cov) = k._filter_all(*dat)
        (_, f_covSR), (_, p_covSR) = kSR._filter_all(*dat)
        K = _smooth_gain(k.A, f_cov[:, 0], p_cov[:, 0])
        K_SR = _smooth_gain_SR(k.A, f_covSR[:, 0], p_covSR[:, 0])
        errs['K'].append((K - K_SR).abs().max().item())
    return pd.DataFrame(errs)

err = fuzz_gain_SR(10)

err.median(), err.max()

(K    1.920686e-14
 dtype: float64,
 K    4.352074e-14
 dtype: float64)

In [None]:
#| export
def _smooth_update_SR(A,                # [n_dim_state, n_dim_state]
                   filt_stateSR: MNormal,         # [n_dim_state] filtered state at time `t`
                   pred_stateSR: MNormal,         # [n_dim_state] state before filtering at time `t + 1` (= using the observation until time t)
                   next_smoothed_state: Normal, # [n_dim_state] smoothed state at time  `t+1`
                   ) -> MNormal:                # mean and cov of smoothed state at time `t`
    """Correct a pred state with a Kalman Smoother update"""
    # for now use standard smoother
    smooth_gain = _smooth_gain_SR(A, filt_stateSR.cov, pred_stateSR.cov)
    
    # Convert to full covariance matrix only the filter output
    filt_state_cov, pred_state_cov = map(lambda x: x @ x.mT, [filt_stateSR.cov, pred_stateSR.cov])
    
    m_p = _smooth_mean(smooth_gain, filt_stateSR.mean, pred_stateSR.mean, next_smoothed_state.mean)
    P_p = _smooth_cov(smooth_gain,  filt_state_cov, pred_state_cov, next_smoothed_state.cov)
    
    return MNormal(m_p, P_p)

In [None]:
m_p, P_p = _smooth_update_SR(kSR.A, filt_stateSR[:, 0, :], pred_stateSR[:, 0, :], filt_stateSR[:, 0, :])

In [None]:
test_close(filt_state[:, 0, :].mean, filt_stateSR[:, 0, :].mean)

In [None]:
_smooth_update(kSR.A, filt_state[0,0],  pred_state[0,0] , filt_state[0,0] )

MultiNormal(mean=tensor([[[0.6325],
         [0.9899],
         [0.5680],
         [0.6284]]], dtype=torch.float64, grad_fn=<AddBackward0>), cov=tensor([[[ 0.7480, -0.0959, -0.0944, -0.2622],
         [-0.0959,  0.6667,  0.1856,  0.0380],
         [-0.0944,  0.1856,  0.9192, -0.0280],
         [-0.2622,  0.0380, -0.0280,  1.3021]]], dtype=torch.float64,
       grad_fn=<DivBackward0>))

In [None]:
_smooth_update_SR(kSR.A, filt_stateSR[0,0],  pred_stateSR[0,0], filt_stateSR[0,0])

MultiNormal(mean=tensor([[[0.6325],
         [0.9899],
         [0.5680],
         [0.6284]]], dtype=torch.float64, grad_fn=<AddBackward0>), cov=tensor([[[-1.8502, -2.4971, -5.2813, -3.0804],
         [-2.4971, -1.6908, -4.8194, -2.8887],
         [-5.2813, -4.8194, -9.7933, -6.0797],
         [-3.0804, -2.8887, -6.0797, -2.6434]]], dtype=torch.float64,
       grad_fn=<DivBackward0>))

In [None]:
test_close((filt_stateSR.mean, pred_stateSR.mean), (filt_state.mean, pred_state.mean))

In [None]:
test_close(
    _smooth_update_SR(kSR.A, filt_stateSR[:, 0], pred_stateSR[:, 0], filt_state[:, 0]),
    _smooth_update(kSR.A, filt_state[:, 0], pred_state[:, 0], filt_state[:, 0])
)

In [None]:
def fuzz_update_SR(n=10):
    errs = {'mean': [], 'cov': []}
    for _ in range(n):
        kSR = KalmanFilterSR.init_random(5,10,4)
        k = KalmanFilter.init_from(kSR)
        dat = get_test_data(2, 5, 4)
        f_state, p_state = k._filter_all(*dat)
        f_stateSR, p_stateSR = kSR._filter_all(*dat)
        mean, cov = _smooth_update(k.A, f_state[:, 0], p_state[:, 0], f_state[:, 1])
        meanSR, covSR = _smooth_update_SR(k.A, f_stateSR[:, 0], p_stateSR[:, 0], f_state[:, 1])
        errs['mean'].append((meanSR - mean).abs().max().item())
        errs['cov'].append((covSR - cov).abs().max().item())
    return pd.DataFrame(errs)

err = fuzz_update_SR(10)

err.median(), err.max()

(mean    1.154632e-14
 cov     2.771117e-13
 dtype: float64,
 mean    3.375078e-14
 cov     8.242296e-13
 dtype: float64)

### Smooth

In [None]:
#| export
def _smooth_SR(A, # `[n_dim_state, n_dim_state]`
            filt_stateSR: ListMNormal, # `[n_timesteps, n_dim_state]`
                # `ms[t]` is the state estimate for time t given obs from times `[0...t]`
            pred_stateSR: ListMNormal, # `[n_timesteps, n_dim_state]`
                # `m_ms[t]` is the state estimate for time t given obs from times `[0...t-1]`
            until=0, # iteration where to stop the smoother
            cov_checker = None
           ) -> ListMNormal: # `[n_timesteps, n_dim_state]` Smoothed state 
    """Apply the Kalman Smoother"""
    n_obs = pred_stateSR.mean.shape[1]

    smoothed_state = ListMNormal(torch.zeros_like(filt_stateSR.mean), torch.zeros_like(filt_stateSR.cov))
    
    # For the last timestep cannot use the smoother
    smoothed_state.mean[:,-1] = filt_stateSR.mean[:,-1]
    smoothed_state.cov[:,-1] = filt_stateSR.cov[:,-1] @ filt_stateSR.cov[:,-1].mT
    
    for t in reversed(range(until, n_obs - 1)):
        (smoothed_state.mean[:,t], smoothed_state.cov[:,t]) = (
            _smooth_update_SR(
                A,
                filt_stateSR[:,t],
                pred_stateSR[:,t + 1],
                smoothed_state[:,t+1],
            )
        )
        if cov_checker is not None:
            cov_checker.check(smoothed_state.cov[:,t], name="smooth", t=t)
    return smoothed_state

In [None]:
smooth_state = _smooth_SR(kSR.A,  filt_stateSR, pred_stateSR)

In [None]:
is_sr(filt_stateSR.cov, filt_state.cov)

True

In [None]:
s_mean, s_cov =  _smooth(kSR.A,  filt_state, pred_state)
s_meanSR, s_covSR =  _smooth_SR(kSR.A,  filt_stateSR, pred_stateSR)
test_close(s_cov, s_covSR)
(s_cov - s_covSR).median(), (s_cov - s_covSR).max()

(tensor(0., dtype=torch.float64, grad_fn=<MedianBackward0>),
 tensor(3.5527e-15, dtype=torch.float64, grad_fn=<MaxBackward1>))

In [None]:
show_as_row(smooth_state.mean[0][0], smooth_state.cov[0][0])

In [None]:
show_as_row(smooth_state.mean.shape, smooth_state.cov.shape)

### KalmanFilter method

In [None]:
mask

tensor([[[ True, False, False],
         [ True, False,  True],
         [False,  True, False],
         [False, False, False],
         [False, False,  True],
         [ True,  True,  True],
         [ True,  True,  True],
         [ True,  True, False],
         [ True, False,  True],
         [ True,  True,  True]],

        [[ True,  True,  True],
         [ True,  True, False],
         [ True,  True,  True],
         [ True, False,  True],
         [False,  True,  True],
         [ True,  True,  True],
         [ True,  True,  True],
         [ True,  True,  True],
         [ True,  True,  True],
         [ True,  True,  True]]])

In [None]:
torch.argwhere((~mask).any(-1).any(0)).min()

tensor(0)

In [None]:
torch.ones(1,2,0)

tensor([], size=(1, 2, 0))

In [None]:
#| export
@patch
def smooth(self: KalmanFilterSR,
           obs: Tensor,
           mask: Tensor,
           control: Tensor
          ) -> ListMNormal: # `[n_timesteps, n_dim_state]` smoothed state
        
    """Kalman Filter Smoothing"""

    filt_stateSR, pred_stateSR = self._filter_all(obs, mask, control)
    # run smoother until there is a gap
    # if self.pred_only_gap:
    #     gap_idx = torch.argwhere((~mask).any(-1).any(0))
    #     # no data to predict       
    #     if gap_idx.numel() == 0: return ListMNormal(torch.zeros(0), torch.zeros(0))
    #     until = gap_idx.min()
    # else:
    #     until = 0
    until = 0
        
    smoothed_state = _smooth_SR(self.A,
                   filt_stateSR, pred_stateSR,
                    until=until,
                   cov_checker = self.cov_checker)
    return smoothed_state

In [None]:
smoothed_state = kSR.smooth(data, mask, control)

In [None]:
show_as_row(smoothed_state.mean.shape, smoothed_state.cov.shape)

In [None]:
smoothed_state.cov.isnan().any()

tensor(False)

In [None]:
smoothed_state.mean.sum().backward(retain_graph=True)
A.grad

tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]], dtype=torch.float64)

In [None]:
smoothed_state_stand = k.smooth(data, mask, control)

In [None]:
(smoothed_state.mean - smoothed_state_stand.mean).max()
(smoothed_state.cov - smoothed_state_stand.cov).max()

tensor(0.6755, dtype=torch.float64, grad_fn=<MaxBackward1>)

In [None]:
def fuzz_smooth_SR(n=10):
    errs = {'mean': [], 'cov': []}
    for _ in range(n):
        kSR = KalmanFilterSR.init_random(5,10,4)
        k = KalmanFilter.init_from(kSR)
        dat = get_test_data(20, 5, 4)
        mean, cov = k.smooth(*dat)
        meanSR, covSR = kSR.smooth(*dat)
        errs['mean'].append((meanSR - mean).abs().max().item())
        errs['cov'].append((covSR - cov).abs().max().item())
    return pd.DataFrame(errs)

err = fuzz_smooth_SR(10)

err.median(), err.max()

(mean    9.733464e-12
 cov     8.838374e-12
 dtype: float64,
 mean    4.984400e-09
 cov     4.896766e-09
 dtype: float64)

## Predict

predict can be vectorized across both the batch and the timesteps, except for timesteps that require conditional predictions

### Obs from State

In [None]:
#| export
@patch
def _obs_from_state(self: KalmanFilterSR, state: ListMNormal):
    
    mean = self.H @ state.mean + self.d
    
    if (self.use_sr_pred if hasattr(self, 'use_sr_pred') else False):
        HP = self.H @ state.cov
        W = torch.cat([HP, self.R_C.expand(*HP.shape[:-2], -1, -1)], dim=-1)
        cov = torch.linalg.qr(W.mT).R.mT
    else: # actually compute the covariance matrix
        cov = self.H @ state.cov @ self.H.mT + self.R
    
    if self.cov_checker is not None:
        for c in cov: # this is batched and for all timestamps
            self.cov_checker.check(c, caller='predict')
    
    return ListMNormal(mean, cov)

In [None]:
smoothed_state.mean.shape, smoothed_state.cov.shape

(torch.Size([2, 10, 4, 1]), torch.Size([2, 10, 4, 4]))

In [None]:
pred_obs0 = kSR._obs_from_state(smoothed_state)
pred_obs0.mean.shape, pred_obs0.cov.shape

(torch.Size([2, 10, 3, 1]), torch.Size([2, 10, 3, 3]))

In [None]:
pred_obs0.cov.isnan().any()

tensor(False)

In [None]:
gap_mask = ~mask.all(-1)

In [None]:
gap_mask.shape

torch.Size([2, 10])

Predict has various modes:

- `pred_only_gap` is True, returns predictions only where the mask is False
    - `use_conditional` returns a list (for each batch) of list (for each time stamp) of Tensors of shape [1, gap_len] 
    - `use_conditional` is False, returns a list (for each batch) of Tensor of shape [n_times_gap, n_dim_obs] 

### Masked Batch

In [None]:
#| export
def _masked2batch(x: Tensor, # (`n_time_missing` for every `batch`, n, [n])
                  mask: Tensor, # (`n_batch`, `n_times`, `n`)
                 ) -> list[list[Tensor]]: 
    """transform a flattened masked prediction, into a prediction with a batch shape and select only predictions where the mask is false"""
    batches = []
    n_prev = 0
    gap_mask = ~mask.all(-1)
    for i, n in enumerate(gap_mask.sum(-1)):
        batch = x[n_prev:n_prev+n]
        mask_batch = mask[i][gap_mask[i]]
        assert  (mask_batch == mask[gap_mask][n_prev:n_prev+n]).all() # sanity check that the function is working
        times = []
        for t_pred, t_mask in zip(batch, mask_batch):
            times.append(t_pred[~t_mask] if t_pred.dim() == 1 else t_pred[~t_mask, :][:,~t_mask])
        batches.append(times)
        n_prev += n
    return batches

In [None]:
gap_mask = ~mask.all(-1)

In [None]:
from pprint import pp

In [None]:
pp(_masked2batch(mask[gap_mask], mask)[0])

[tensor([False, False]),
 tensor([False]),
 tensor([False, False]),
 tensor([False, False, False]),
 tensor([False, False]),
 tensor([False]),
 tensor([False])]


In [None]:
str(_masked2batch(mask[gap_mask], mask)[0])

'[tensor([False, False]), tensor([False]), tensor([False, False]), tensor([False, False, False]), tensor([False, False]), tensor([False]), tensor([False])]'

In [None]:
show_as_row(all_mask = mask[0] , only_gap=_masked2batch(mask[gap_mask], mask)[0])

### Predict

In [None]:
#| export
@patch
def predict(self: KalmanFilterSR, obs, mask, control, smooth=True):
    """Predicted observations at all times """
    if self.use_conditional and not self.pred_only_gap:
        raise ValueError("Kalman Filter predict cannot have conditional predictions and all predictions at the same time")
        
    state = self.smooth(obs, mask, control) if smooth else self.filter(obs, mask, control)
    self.use_sr_pred = not smooth
    if not smooth:
        state = ListMNormal(state.mean, state.cov @ state.cov.mT) # convert to actual covariance
    
    if self.pred_only_gap:
        gap_mask = ~mask.all(-1)
        # this destroy batches! so need to do some magic after
        state = state[gap_mask]
    pred_obs = self._obs_from_state(state)
    pred_obs.mean.squeeze_(-1)
    pred_mean, pred_cov = pred_obs.mean, pred_obs.cov
    if self.use_sr_pred:
        pred_cov = pred_cov @ pred_cov.mT
    # pred_std = cov2std(pred_cov)
    
    if self.use_conditional:
        obs, mask, control = self._parse_obs(obs, mask, control)
        # conditional predictions are slow, do only if some obs are missing 
        # cond_mask = torch.logical_xor(mask.all(-1), mask.any(-1))
        cond_mask = mask[gap_mask]
        if cond_mask.any():
            # this cannot be batched so returns a list
            cond_preds = cond_gaussian_batched(
                pred_obs, obs[gap_mask].squeeze(-1), cond_mask)

            for i, c_pred in enumerate(cond_preds):
                m = ~cond_mask[i]
                pred_mean[i][m] = c_pred.mean
                pred_cov[i][m,:][:,m]= c_pred.cov
    
    if self.pred_only_gap:
        pred_mean = _masked2batch(pred_mean, mask)
        pred_cov =  _masked2batch(pred_cov, mask)
    return ListMNormal(pred_mean, pred_cov) if not self.pred_std else ListNormal(pred_mean, cov2std(pred_cov))

### Exploration

In [None]:
#| export
from fastai.learner import replacing_yield
from fastcore.xtras import ContextManagers
from contextlib import contextmanager

In [None]:
#| export
@contextmanager
def replacing_ctx(*args): return replacing_yield(*args)

def with_settings(k, **kwargs):
    return ContextManagers([replacing_ctx(k, attr, v) for attr,v in kwargs.items()])

In [None]:
with with_settings(kSR, use_conditional=False, pred_only_gap=False):
    pred_mean, pred_cov = kSR.predict(data, mask, control)
show_as_row(mean= pred_mean.shape, cov = pred_cov.shape)

In [None]:
with with_settings(kSR, use_conditional=False, pred_only_gap=True):
    pred_mean, pred_cov = kSR.predict(data, mask, control)
show_as_row(mean= pred_mean[0], cov = pred_cov[0])

In [None]:
with with_settings(kSR, use_conditional=True, pred_only_gap=True):
    pred_mean, pred_cov = kSR.predict(data, mask, control)
show_as_row(mean= pred_mean[0], cov = pred_cov[0])

In [None]:
with with_settings(kSR, use_conditional = True, pred_only_gap = False):
    test_fail(kSR.predict, [data, mask, control]) # this params combination is invalid

#### Gap only prediction

copy paste from other notebook to make visualization easier

In [None]:
#| export
def buffer_pred_single(preds: list[Tensor],
                masks: Tensor) -> Tensor:
    """For predictions are for gaps only add buffer of `Nan` so they have same shape of targets"""
    all_pred = torch.empty(masks.shape, dtype=preds[0][0].dtype).fill_(torch.nan)
    i_p = 0
    for i, (mask) in enumerate(masks.cpu()):
        if not mask.all():
            all_pred[i][~mask] = preds[i_p].detach().cpu()
            i_p += 1
    assert i_p == len(preds)
    return all_pred

In [None]:
#| export
def buffer_pred(preds: list[list[Tensor]],
                masks: Tensor) -> Tensor:
    """For predictions are for gaps only add buffer of `Nan` so they have same shape of targets"""
    return torch.stack([buffer_pred_single(pred, mask) for pred, mask in zip(preds, masks)])

In [None]:
with with_settings(kSR, use_conditional=False, pred_only_gap=True):
    pred_m_gap, _ = kSR.predict(data, mask, control)
    
with with_settings(kSR, use_conditional=False, pred_only_gap=False):
    pred_m, _ = kSR.predict(data, mask, control)

soooooooo this is a problem!!! those should be the same

In [None]:
show_as_row(gap = buffer_pred(pred_m_gap, mask), no_gap = pred_m) 

In [None]:
state = kSR.smooth(data, mask, control)

In [None]:
state.mean.shape

torch.Size([2, 10, 4, 1])

In [None]:
gap_mask = ~mask.all(-1)
# this destroy batches! so need to do some magic after
state_gap = state[gap_mask]

In [None]:
state_gap.mean.shape

torch.Size([10, 4, 1])

In [None]:
pred_obs = kSR._obs_from_state(state)
pred_obs.mean.squeeze_(-1)

tensor([[[ 0.0322, -0.3666,  0.3765],
         [ 0.1475, -0.0844,  0.4320],
         [ 0.1625, -0.0796,  0.3119],
         [-0.1102, -0.5474, -0.1901],
         [-0.1439, -0.4996,  0.0158],
         [ 0.1736, -0.0500,  0.1638],
         [ 0.0998, -0.2115,  0.2026],
         [ 0.2642,  0.0727,  0.4966],
         [ 0.2667,  0.0202,  0.5999],
         [ 0.5621,  0.4854,  1.0911]],

        [[ 0.0523, -0.3558,  0.1773],
         [ 0.0572, -0.2074,  0.2059],
         [ 0.0361, -0.2809,  0.0575],
         [-0.0671, -0.3611,  0.1202],
         [ 0.1032, -0.2193, -0.1928],
         [-0.1298, -0.5874, -0.3787],
         [ 0.0450, -0.2554, -0.2029],
         [ 0.0938, -0.0932,  0.3782],
         [ 0.2229, -0.1440, -0.1133],
         [ 0.3722,  0.3017,  1.1004]]], dtype=torch.float64,
       grad_fn=<SqueezeBackward3>)

In [None]:
pred_obs_gap = kSR._obs_from_state(state_gap)
pred_obs_gap.mean.squeeze_(-1)

tensor([[ 0.0322, -0.3666,  0.3765],
        [ 0.1475, -0.0844,  0.4320],
        [ 0.1625, -0.0796,  0.3119],
        [-0.1102, -0.5474, -0.1901],
        [-0.1439, -0.4996,  0.0158],
        [ 0.2642,  0.0727,  0.4966],
        [ 0.2667,  0.0202,  0.5999],
        [ 0.0572, -0.2074,  0.2059],
        [-0.0671, -0.3611,  0.1202],
        [ 0.1032, -0.2193, -0.1928]], dtype=torch.float64,
       grad_fn=<SqueezeBackward3>)

In [None]:
(pred_obs[gap_mask].mean == pred_obs_gap.mean).all() # so far good

tensor(True)

In [None]:
pred_obs.mean.shape

torch.Size([2, 10, 3])

In [None]:
mask.shape

torch.Size([2, 10, 3])

In [None]:
show_as_row(pred = buffer_pred(_masked2batch(pred_obs_gap.mean, mask), mask), mask = mask, all = pred_obs.mean)

In [None]:
pred_gap_buff = buffer_pred(_masked2batch(pred_obs_gap.mean, mask), mask)
mask_na = ~pred_gap_buff.isnan()

In [None]:
test_close(pred_gap_buff[mask_na], pred_obs.mean[mask_na])

In [None]:
with with_settings(kSR, use_conditional=False, pred_only_gap=True):
    pred_gap_buff = buffer_pred(kSR.predict(data, mask, control).mean, mask)
mask_na = ~pred_gap_buff.isnan()
with with_settings(kSR, use_conditional=False, pred_only_gap=False):
    pred_ng = kSR.predict(data, mask, control).mean

In [None]:
test_close(pred_gap_buff[mask_na], pred_ng[mask_na])

In [None]:
show_as_row(pred = pred_gap_buff, mask = mask, all = pred_ng)

#### Conditional

In [None]:
with with_settings(kSR, use_conditional=False, pred_only_gap=True):
    pred_gap = buffer_pred(kSR.predict(data, mask, control).mean, mask)
with with_settings(kSR, use_conditional=True, pred_only_gap=True):
    pred_gap_cond = buffer_pred(kSR.predict(data, mask, control).mean, mask)

In [None]:
show_as_row(no_conditional = pred_gap, conditional = pred_gap_cond)

In [None]:
mask[gap_mask]

tensor([[ True, False, False],
        [ True, False,  True],
        [False,  True, False],
        [False, False, False],
        [False, False,  True],
        [ True,  True, False],
        [ True, False,  True],
        [ True,  True, False],
        [ True, False,  True],
        [False,  True,  True]])

In [None]:
data[gap_mask]

tensor([[0.8775,    nan,    nan],
        [0.6706,    nan, 0.9272],
        [   nan, 0.4967,    nan],
        [   nan,    nan,    nan],
        [   nan,    nan, 0.4760],
        [0.9991, 0.1775,    nan],
        [0.6734,    nan, 0.6468],
        [0.3725, 0.2052,    nan],
        [0.5927,    nan, 0.6441],
        [   nan, 0.9132, 0.0329]], dtype=torch.float64)

In [None]:
data[gap_mask]

tensor([[0.8775,    nan,    nan],
        [0.6706,    nan, 0.9272],
        [   nan, 0.4967,    nan],
        [   nan,    nan,    nan],
        [   nan,    nan, 0.4760],
        [0.9991, 0.1775,    nan],
        [0.6734,    nan, 0.6468],
        [0.3725, 0.2052,    nan],
        [0.5927,    nan, 0.6441],
        [   nan, 0.9132, 0.0329]], dtype=torch.float64)

In [None]:
assert is_posdef(pred_cov[0][0]).all()

In [None]:
state = kSR.smooth(data, mask, control)
assert not state.cov.isnan().any()
kSR.use_sr_pred = False
pred_obs = kSR._obs_from_state(state)
assert not pred_obs.cov.isnan().any()

In [None]:
cov2std(pred_obs.cov).isnan().any()

tensor(False)

In [None]:
assert not kSR.predict(data, mask, control).cov.isnan().any()

In [None]:
assert not kSR.predict(data, mask, control, smooth=False).cov.isnan().any()

In [None]:
#| export
@patch
def _predict_filter(self: KalmanFilterSR, data, mask, control):
    """Predict every obsevation using only the filter step"""
    # use the predicted state not the filtered state!
    self.use_sr_pred = True
    filt_state, pred_state = self._filter_all(data, mask, control)
    mean, cov = self._obs_from_state(pred_state)
            
    return ListNormal(mean.squeeze(-1), cov @ cov.mT) # convert to actual covariances

In [None]:
kSR._predict_filter(data, mask, control).std.shape

torch.Size([2, 10, 3, 3])

In [None]:
kSR.use_conditional = False

In [None]:
pred = kSR.predict(data, mask, control, smooth=True)

In [None]:
pred.mean.shape, pred.std.shape

AttributeError: 'ListMultiNormal' object has no attribute 'std'

In [None]:
filt_state, pred_state = kSR._filter_all(data, mask, control)
mean, cov = kSR._obs_from_state(pred_state)

In [None]:
is_posdef(cov @ cov.mT)

In [None]:
pred

## Debug nan

In [None]:
import polars as pl
import altair as alt

In [None]:
kSR.predict(*get_test_data(200)).cov.isnan().any()

In [None]:
data.shape

In [None]:
reset_seed()

### SR Filter

In [None]:
nan = [{'nan':kSR.predict(*get_test_data(n), smooth=smooth).cov.isnan().any().item(), 'n': n, 'rep': rep, 'smooth': smooth}
       for n in [10, 20, 30, 40, 50, 55, 60, 70, 100, 150, 200] for rep in range(10) for smooth in [True, False]]

In [None]:
nan_df = pl.DataFrame(nan).groupby(['nan', 'n', 'smooth']).count().to_pandas()

In [None]:
alt.Chart(nan_df).mark_line().encode(x='n:Q', y='count', color='nan', column='smooth')

### Standard Filter

In [None]:
k = KalmanFilter.init_from(kSR)

In [None]:
nan = [{'nan':k.predict(*get_test_data(n), smooth=smooth).std.isnan().any().item(), 'n': n, 'rep': rep, 'smooth': smooth}
       for n in [10, 20, 30, 40, 50, 55, 60, 70, 100, 150, 200] for rep in range(10) for smooth in [True, False]]

In [None]:
nan_df = pl.DataFrame(nan).groupby(['nan', 'n', 'smooth']).count().to_pandas()

In [None]:
alt.Chart(nan_df).mark_line().encode(x='n:Q', y='count', color='nan', column='smooth')

so the standard filter is working better than the SR for the smoothing (with this parameter setting), so there is a way to make the sr smoother not that bad

But I just want to see how we have nan

So the problem is that we have a negative number on the diagonal ... so is not positive definite and even the standard deviation is nan

In [None]:
for i in range(20):
    dat = get_test_data(10)
    pred = kSR.predict(*dat)
    if pred.cov.isnan().any():
        print(i)
        break

In [None]:
k = KalmanFilter.init_from(kSR)

In [None]:
for p1, p2 in zip(k.parameters(), kSR.parameters()):
    test_close(p1,p2)                

In [None]:
f_state_stand = k.filter(*dat)
s_state_stand = k.smooth(*dat)
pred_stand = k.predict(*dat)

In [None]:
pred_stand[1, -1].std

In [None]:
(f_state_stand.mean - filt_state.mean).mean()

In [None]:
filt_state = kSR.filter(*dat)

In [None]:
s_state = kSR.smooth(*dat)

In [None]:
is_posdef(s_state.cov)

In [None]:
is_posdef(s_state[1, -1].cov)

In [None]:
filt_state[1,-1].cov

In [None]:
kSR.H @ s_state[1, -1].cov @ kSR.H.mT

In [None]:
kSR.use_sr_pred = False
pred_obs = kSR._obs_from_state(s_state)
pred_obs[1,-1].cov

In [None]:
kSR.predict(*dat).

In [None]:
pred.std[1,-1]

## Additional

### Constructors

#### Simple parameters

In [None]:
#| export
@patch(cls_method=True)
def init_simple(cls: KalmanFilter|KalmanFilterSR,
                n_dim, # n_dim_obs and n_dim_state
                dtype=torch.float64):
    """Simplest version of kalman filter parameters"""
    return cls(
        A =     torch.eye(n_dim, dtype=dtype),
        b =        torch.zeros(n_dim, dtype=dtype),        
        Q =        torch.eye(n_dim, dtype=dtype),        
        H =       torch.eye(n_dim, dtype=dtype),
        d =          torch.zeros(n_dim, dtype=dtype),          
        R =          torch.eye(n_dim, dtype=dtype),            
        B =     torch.eye(n_dim, dtype=dtype),
        m0 =  torch.zeros(n_dim, dtype=dtype),        
        P0 =   torch.eye(n_dim, dtype=dtype),
    )

In [None]:
KalmanFilter.init_simple(2).state_dict()

#### Local slope

Local slope models are an extentions of local level model that in the state variable keep track of also the slope

Given $n$ as the number of dimensions of the observations

The transition matrix (`A`) is:

$$A = \left[\begin{array}{cc}I & I \\ 0 & I\end{array}\right]$$

where:

- $I \in \mathbb{R}^{n \times n}$
- $A \in \mathbb{R}^{2n \times 2n}$

the state $x \in \mathbb{R}^{2N \times 1}$ where the upper half keep track of the level and the lower half of the slope. $A \in \mathbb{R}^2N \times 2N$

the observation matrix (`H`) is:

$$H = \left[\begin{array}{cc}I & 0 \end{array}\right]$$

For the multivariate case the 1 are replaced with an identiy matrix


assuming that the control has the same dimensions of the observations then if we are doing a local slope model we have $B \in \mathbb{R}^{state \times contr}$:
$$ B = \begin{bmatrix} -I & I \\ 0 & 0 \end{bmatrix}$$

In [None]:
#| export
from torch import hstack, eye, vstack, ones, zeros, tensor
from functools import partial
from sklearn.decomposition import PCA

In [None]:
#| exporti
def set_dtype(*args, dtype=torch.float64):
    return [partial(arg, dtype=dtype) for arg in args] 

eye, ones, zeros, tensor = set_dtype(eye, ones, zeros, tensor)

In [None]:
#| export
@patch(cls_method=True)
def init_local_slope_pca(cls: KalmanFilter|KalmanFilterSR,
                n_dim_obs, # n_dim_obs and n_dim_contr
                n_dim_state: int, # n_dim_state
                n_dim_contr:int, #n dim control
                df_pca: pd.DataFrame|None = None, # dataframe for PCA init, None no PCA init,
                pca_contr:int = False,
                **kwargs
            ):
    """Local Slope + PCA init"""
    
    if df_pca is not None:
        comp = PCA(n_dim_state).fit(df_pca).components_
        H = tensor(comp.T) # transform state -> obs 
        if pca_contr:
            if n_dim_obs != n_dim_contr:
                raise ValueError("n dim obs and n dim contr must be the same for pca of control")
            else:
                B = torch.tensor(comp)
        else:
            B = eye(n_dim_contr)
    else:
        H, B = eye(n_dim_obs), eye(n_dim_contr)
        
    return cls(
        A =     vstack([hstack([eye(n_dim_state),                eye(n_dim_state)]),
                                   hstack([zeros(n_dim_state, n_dim_state), eye(n_dim_state)])]),
        b =        zeros(n_dim_state * 2),        
        Q =        eye(n_dim_state * 2)*.1,        
        H =       hstack([H, zeros(n_dim_obs, n_dim_state)]),
        d =          zeros(n_dim_obs),          
        R =          eye(n_dim_obs)*.01,            
        B =     vstack([hstack([-B,                  B]),
                        hstack([ zeros(2 * n_dim_state-n_dim_contr,n_dim_contr), zeros(2 * n_dim_state-n_dim_contr, n_dim_contr)])]),
        m0 =  zeros(n_dim_state * 2),        
        P0 =   eye(n_dim_state * 2) * 3,
        **kwargs
    ) 

In [None]:
KalmanFilter.init_local_slope_pca(2,2,pd.DataFrame([[1,2], [2,4]])).state_dict()

## Export

In [None]:
#| hide
from nbdev import nbdev_export
nbdev_export()