# Kalman Filter
> Implementation of Kalman filters using PyTorch, with parameter optimization by gradient descent

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
#| default_exp kalman.filter

In [None]:
#| export
from fastcore.test import *
from fastcore.basics import *
from meteo_imp.utils import *
from meteo_imp.gaussian import *
from meteo_imp.data_preparation import MeteoDataTest
from typing import *
from functools import partial

import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torch.distributions import MultivariateNormal

## Introduction

The model uses a latent state variable $x$, modelled over time, to impute gaps in the observations $y$.

The model assumes that the state $x_t$ at time $t$ depends only on the previous state $x_{t-1}$ and not on earlier states.

### Equations

The equations of the model are:

$$\begin{align} p(x_t | x_{t-1}) & = \mathcal{N}(Ax_{t-1} + Bc_t + b, Q) \\
p(y_t | x_t) & = \mathcal{N}(Hx_t + d, R) \end{align}$$


where:

- $A$ is the `trans_matrix`
- $B$ is the `contr_matrix`, applied to the control variables $c_t$
- $b$ is the `trans_off`
- $Q$ is the `trans_cov`
- $H$ is the `obs_matrix`
- $d$ is the `obs_off`
- $R$ is the `obs_cov`

In addition, the model has the parameters of the initial state, which are used to initialize the filter:

- `init_state_mean`
- `init_state_cov`

The Kalman filter has 3 steps:

- predict (estimate the state at time t using the observations up to time t-1)
- update/correct (update the state at time t using the observation at time t)
- smooth (update the state at time t using the observations after time t, so that it is conditioned on all the observations)

In case of missing data the update step is skipped.

After smoothing, the missing data at time t ($y_t$) can be imputed from the smoothed state ($x_t$, with mean $m^s_t$ and covariance $P^s_t$) using this formula:
$$p(y_t|Y) = \mathcal{N}(Hm^s_t + d, R + HP^s_tH^T)$$

The Kalman Filter is an algorithm designed to estimate $P(x_t | y_{0:t})$. As all state transitions and observations are linear with Gaussian distributed noise, these distributions can be represented exactly as Gaussian distributions with mean `filt_state_means[t]` and covariance `filt_state_covs[t]`.
Similarly, the Kalman Smoother is an algorithm designed to estimate $P(x_t | y_{0:T})$, i.e. the state given all the observations.



## Main class

TODO: fill nans with 0 for all data

In [None]:
#| export
class KalmanFilter(torch.nn.Module):
    """Base class for Kalman Filter and Smoother using PyTorch.

    Parameters that need a constraint (`trans_cov`, `obs_cov`, `init_state_cov`) are stored as
    unconstrained `<name>_raw` parameters and exposed through properties that apply the constraint's
    `transform` on read and its `inverse_transform` on assignment.
    """
    def __init__(self,
            trans_matrix: Tensor,                    # [n_dim_state,n_dim_state] $A$, state transition matrix 
            obs_matrix: Tensor,                      # [n_dim_obs, n_dim_state] $H$, observation matrix
            contr_matrix: Tensor,                    # [n_dim_state, n_dim_contr] $B$ control matrix
            trans_cov: Tensor,                       # [n_dim_state, n_dim_state] $Q$, state trans covariance matrix
            obs_cov: Tensor,                         # [n_dim_obs, n_dim_obs] $R$, observations covariance matrix
            trans_off: Tensor,                       # [n_dim_state] $b$, state transition offset
            obs_off: Tensor,                         # [n_dim_obs] $d$, observations offset
            init_state_mean: Tensor,                 # [n_dim_state] $\mu_0$
            init_state_cov: Tensor,                  # [n_dim_state, n_dim_state] $\Sigma_0$
            n_dim_state: int|None = None,            # Number of dimensions for state - defaults to 1 if cannot be infered from parameters
            n_dim_obs: int|None = None,              # Number of dimensions for observations - defaults to 1 if cannot be infered from parameters
            var_names: Iterable[str]|None = None,    # Names of variables for printing 
            contr_names: Iterable[str]|None = None,  # Names of control variables for printing
            cov_checker: CheckPosDef|None = None     # Check covariance at every step; a new `CheckPosDef()` is created when `None`
                ):
        
        super().__init__()
        store_attr("var_names, contr_names")
        # check parameters are consistent
        self.n_dim_state = determine_dimensionality(
            [(trans_matrix, array2d, -2),
             (trans_off, array1d, -1),
             (trans_cov, array2d, -2),
             (init_state_mean, array1d, -1),
             (init_state_cov, array2d, -2),
             (obs_matrix, array2d, -1)],
            n_dim_state
        )
        self.n_dim_obs = determine_dimensionality(
            [(obs_matrix, array2d, -2),
             (obs_off, array1d, -1),
             (obs_cov, array2d, -2)],
            n_dim_obs
        )
        
        self.n_dim_contr = determine_dimensionality([(contr_matrix, array2d, -1)], None)
        
        params = {
        #name               value             constraint
        'trans_matrix':     [trans_matrix,    None        ],
        'trans_off':        [trans_off,       None        ],
        'trans_cov':        [trans_cov,       PosDef()    ],
        'contr_matrix':     [contr_matrix,    None        ],
        'obs_matrix':       [obs_matrix,      None        ],
        'obs_off':          [obs_off,         None        ],
        'obs_cov':          [obs_cov,         DiagPosDef()],
        'init_state_mean':  [init_state_mean, None        ],
        'init_state_cov':   [init_state_cov,  PosDef()    ],
        }
        self._init_params(params)
        
        # a fresh checker per instance, so instances never share checker state through a default argument
        self.cov_checker = CheckPosDef() if cov_checker is None else cov_checker
        
    def _init_params(self, params):
        """Register every entry of `params` as a trainable parameter.

        Constrained entries are registered as `<name>_raw` after applying the constraint's `inverse_transform`.
        """
        for name, (value, constraint) in params.items():
            if constraint is not None:
                name, value = self._init_constraint(name, value, constraint)
            self._init_param(name, value, train=True)    
    
    def _init_param(self, param_name, value, train):
        """Register (or overwrite) `param_name` as a `torch.nn.Parameter` with `requires_grad=train`."""
        self.register_parameter(param_name, torch.nn.Parameter(value, requires_grad=train))
    
    ### === Constraints utils
    def _init_constraint(self, param_name, value, constraint):
        """Store `constraint` as `<param_name>_constraint`; return the raw name and unconstrained value."""
        name = param_name + "_raw"
        value = constraint.inverse_transform(value)
        setattr(self, param_name + "_constraint", constraint)
        return name, value
    
    def _get_constraint(self, param_name):
        """Return the constrained value of `param_name`, computed from `<param_name>_raw`."""
        constraint = getattr(self, param_name + "_constraint")
        raw_value = getattr(self, param_name + "_raw")
        return constraint.transform(raw_value)
    
    def _set_constraint(self, param_name, value, train=True):
        """Re-register `<param_name>_raw` from a constrained `value`."""
        constraint = getattr(self, param_name + "_constraint")
        raw_value = constraint.inverse_transform(value)
        self._init_param(param_name + "_raw", raw_value, train)
    
    ### === Convenience functions to get and set parameters that have a constraint
    @property
    def trans_cov(self): return self._get_constraint('trans_cov')
    @trans_cov.setter
    def trans_cov(self, value): return self._set_constraint('trans_cov', value)

    @property
    def obs_cov(self): return self._get_constraint('obs_cov')
    @obs_cov.setter
    def obs_cov(self, value): return self._set_constraint('obs_cov', value)
    
    @property
    def init_state_cov(self): return self._get_constraint('init_state_cov')
    @init_state_cov.setter
    def init_state_cov(self, value): return self._set_constraint('init_state_cov', value)
    
    
    ### === Utility Func    
    def _parse_obs(self, obs, mask=None):
        """Return `(obs, mask)`, deriving `mask` from the non-`nan` entries of `obs` when it is `None`.

        Only batched (3-dimensional) `obs` is supported; the call to `atleast_3d` does not handle
        2d input correctly (see TODO).
        """
        if mask is None: mask = ~torch.isnan(obs)
        # TODO incorrect support for 2d input!!!!!!
        assert obs.dim() == 3
        obs, mask = torch.atleast_3d(obs), torch.atleast_3d(mask)
        return obs, mask
    
    def __repr__(self):
        return f"""Kalman Filter
        N dim obs: {self.n_dim_obs}, N dim state: {self.n_dim_state}, N dim contr: {self.n_dim_contr}"""

## Constructors

Passing all the parameters manually to the `KalmanFilter` init method is not convenient, hence there are some methods that help initialize the class

### Random parameters

In [None]:
#| export
@patch(cls_method=True)
def init_random(cls: KalmanFilter, n_dim_obs, n_dim_state, n_dim_contr, dtype=torch.float32, **kwargs):
    """Build a `KalmanFilter` whose parameters are all drawn from `torch.rand`.

    Covariance parameters are made positive definite with `to_posdef`; any extra `kwargs`
    are forwarded to the constructor.
    """
    rand = partial(torch.rand, dtype=dtype)
    # keyword arguments are evaluated left to right: the draw order (and so the values
    # obtained under a fixed seed) matches the parameter order below
    params = dict(
        trans_matrix    = rand(n_dim_state, n_dim_state),
        trans_off       = rand(n_dim_state),
        trans_cov       = to_posdef(rand(n_dim_state, n_dim_state)),
        contr_matrix    = rand(n_dim_state, n_dim_contr),
        obs_matrix      = rand(n_dim_obs, n_dim_state),
        obs_off         = rand(n_dim_obs),
        obs_cov         = to_posdef(rand(n_dim_obs, n_dim_obs)),
        init_state_mean = rand(n_dim_state),
        init_state_cov  = to_posdef(rand(n_dim_state, n_dim_state)),
    )
    return cls(**params, **kwargs)
        

In [None]:
k = KalmanFilter.init_random(3,4, 3, dtype=torch.float64)
k

Kalman Filter
        N dim obs: 3, N dim state: 4, N dim contr: 3

In [None]:
k.init_state_cov

tensor([[0.7817, 0.6556, 0.5306, 0.8167],
        [0.6556, 0.6433, 0.6856, 0.6853],
        [0.5306, 0.6856, 1.1192, 0.5465],
        [0.8167, 0.6853, 0.5465, 0.8930]], dtype=torch.float64,
       grad_fn=<AddBackward0>)

check that assignment works :)

In [None]:
k.init_state_cov = to_posdef(torch.rand(4, 4, dtype=torch.float64))

In [None]:
k.init_state_cov_raw

Parameter containing:
tensor([[0.9359, 0.0000, 0.0000, 0.0000],
        [0.8074, 0.5687, 0.0000, 0.0000],
        [0.5154, 0.5700, 0.7047, 0.0000],
        [0.7349, 0.1574, 0.1750, 0.3656]], dtype=torch.float64,
       requires_grad=True)

In [None]:
list(k.named_parameters())

[('trans_matrix',
  Parameter containing:
  tensor([[0.6023, 0.3413, 0.3698, 0.6603],
          [0.6889, 0.2815, 0.5281, 0.8140],
          [0.2740, 0.7634, 0.2004, 0.6965],
          [0.8038, 0.2749, 0.6018, 0.9625]], dtype=torch.float64,
         requires_grad=True)),
 ('trans_off',
  Parameter containing:
  tensor([0.0022, 0.9561, 0.7751, 0.1734], dtype=torch.float64,
         requires_grad=True)),
 ('trans_cov_raw',
  Parameter containing:
  tensor([[ 1.4683,  0.0000,  0.0000,  0.0000],
          [ 0.9920,  0.2792,  0.0000,  0.0000],
          [ 0.8064, -0.0434,  0.2780,  0.0000],
          [ 0.4284,  0.3965,  0.1591,  0.0303]], dtype=torch.float64,
         requires_grad=True)),
 ('contr_matrix',
  Parameter containing:
  tensor([[2.6977e-01, 7.9453e-01, 1.6282e-01],
          [2.3017e-01, 6.1964e-01, 4.0785e-01],
          [5.3003e-01, 4.5534e-01, 4.3027e-01],
          [7.5926e-01, 2.3556e-04, 1.8805e-01]], dtype=torch.float64,
         requires_grad=True)),
 ('obs_matrix',
  Pa

### Test data

In [None]:
#| exporti
def get_test_data(n_obs = 10, n_dim_obs=3, n_dim_contr = 3, p_missing=.3, bs=2, dtype=torch.float32, device='cpu'):
    """Random test batch `(data, mask, control)`.

    `data` and `mask` are `[bs, n_obs, n_dim_obs]`, `control` is `[bs, n_obs, n_dim_contr]`.
    Each entry of `mask` is `False` with probability `p_missing`, and `data` is `nan` there.
    """
    obs_shape = (bs, n_obs, n_dim_obs)
    data = torch.rand(*obs_shape, dtype=dtype, device=device)
    mask = torch.rand(*obs_shape, device=device) > p_missing
    control = torch.rand(bs, n_obs, n_dim_contr, dtype=dtype, device=device)
    # poison the missing entries so that any accidental use of them propagates as nan
    data = data.masked_fill(~mask, torch.nan)
    return data, mask, control

In [None]:
reset_seed()
data, mask, control = get_test_data(dtype=torch.float64)
show_as_row(data, mask, control)

## Filter

### Filter predict

Probability of the state at time `t` given the state at time `t-1` 

$p(x_t) = \mathcal{N}(x_t; m_t^-, P_t^-)$ where:

- predicted state mean: $m_t^- = Am_{t-1} + B c_t + b$  

- predicted state covariance: $P_t^- = AP_{t-1}A^T + Q$

In [None]:
#| export
def unsqueeze_iter(*args, dim):
    "Return a list with every tensor in `args` unsqueezed at `dim`."
    return [arg.unsqueeze(dim) for arg in args]
unsqueeze_first = partial(unsqueeze_iter, dim=0)
unsqueeze_last = partial(unsqueeze_iter, dim=-1)

In [None]:
#| export
from datetime import datetime
def _filter_predict(trans_matrix,
                    trans_cov,
                    trans_off,
                    contr_matrix, #[n_dim_state, n_dim_contr]
                    curr_state_mean,
                    curr_state_cov,
                    control, #[n_batches, n_dim_contr]
                    cov_checker=CheckPosDef()):
    r"""Calculate the state at time `t+1` given the state at time `t`"""
    
    pred_state_mean = trans_matrix.unsqueeze(0) @ curr_state_mean + contr_matrix.unsqueeze(0) @ control.unsqueeze(-1) + trans_off.unsqueeze(-1)
    pred_state_cov =  trans_matrix.unsqueeze(0) @ curr_state_cov @ trans_matrix.unsqueeze(0).mT + trans_cov.unsqueeze(0)

    cov_checker.check(pred_state_cov, caller='filter_predict')
    return (pred_state_mean, pred_state_cov)

In [None]:
trans_matrix, trans_cov, trans_off, contr_matrix, curr_state_mean,curr_state_cov = (k.trans_matrix, k.trans_cov, k.trans_off,
                                                  k.contr_matrix,
                                                  torch.stack([k.init_state_mean]*2).unsqueeze(-1),
                                                  torch.stack([k.init_state_cov]*2))

In [None]:
pred_state_mean, pred_state_cov = _filter_predict(
    trans_matrix, trans_cov, trans_off, contr_matrix,
    curr_state_mean,curr_state_cov, control[:,0,:])

In [None]:
show_as_row(pred_state_mean, pred_state_cov)

In [None]:
show_as_row((pred_state_mean.shape, pred_state_cov.shape,))

### Filter correct

Probability of state at time `t` given the observations at time `t`

$p(x_t|y_t) = \mathcal{N}(x_t; m_t, P_t)$ where:

- predicted obs mean: $z_t = Hm_t^- + d$  

- predicted obs covariance: $S_t = HP_t^-H^T + R$

- Kalman gain: $K_t = P_t^-H^TS_t^{-1}$ 

- corrected state mean: $m_t = m_t^- + K_t(y_t - z_t)$ 

- corrected state covariance: $P_t = (I-K_tH)P_t^-$ 

If all the observations are missing, this step is skipped and the corrected state is equal to the predicted state.


TODO: figure out how the `nan`s in the missing observations affect the gradients.

#### Missing observations

If all the observations at time $t$ are missing, the correct step is skipped and the filtered state at time $t$ is the same as the predicted state.

If only some observations are missing, a variation of the equations can be used.

$y^{ng}_t$ is a vector containing the observations that are not missing at time $t$. 

It can be expressed as a linear transformation of $y_t$

$$ y^{ng}_t = My_t$$

where $M$ is a mask matrix that is used to select the subset of $y_t$ that is observed. $M \in \mathbb{R}^{n_{ng} \times n}$ and each of its rows is all zeros except for a single 1 in the column corresponding to one non-missing observation.
hence:

$$ p(y^{ng}_t) = \mathcal{N}(M\mu_{y_t},  M\Sigma_{y_t}M^T)$$

from which you can derive

$$ p(y^{ng}_t|x_t) = \mathcal{N}(MHx_t + Md, MRM^T) $${#eq-filter-correct}

Then the posterior $p(x_t|y_t^{ng})$ can be computed similarly to the standard filter correct step as:

$$ p(x_t|y^{ng}_t) = \mathcal{N}(x_t; m_t, P_t) $${#eq-filter_correct_missing}
    
where:

*  predicted obs mean: $z_t = MHm_t^- + Md$
*  predicted obs covariance: $S_t = MHP_t^-(MH)^T + MRM^T$
*  Kalman gain $K_t = P_t^-(MH)^TS_t^{-1}$
*  corrected state mean: $m_t = m_t^- + K_t(My_t - z_t)$
*  corrected state covariance: $P_t = (I-K_tMH)P_t^-$


In [None]:
k.obs_off.shape

torch.Size([3])

##### Details implementation 

For the implementation the matrix multiplication $MH$ can be replaced with `H[m]` where `m` is the mask for the rows for `H` and $MRM^T$ with `R[m][:,m]`

In [None]:
obs_matrix, obs_cov, obs_off,obs, mm = (k.obs_matrix, k.obs_cov, k.obs_off, data[:,0,:], mask[:,0,:])

In [None]:
m = torch.tensor([False,True,True]) # mask batch
M = torch.tensor([[0,1,0], # mask matrix
                  [0,0,1]], dtype=torch.float64)
show_as_row(m, M, obs_matrix, obs_cov)

In [None]:
M @ obs_matrix, obs_matrix[m]

(tensor([[0.1210, 0.6346, 0.0657, 0.3273],
         [0.4661, 0.8703, 0.4889, 0.6028]], dtype=torch.float64,
        grad_fn=<MmBackward0>),
 tensor([[0.1210, 0.6346, 0.0657, 0.3273],
         [0.4661, 0.8703, 0.4889, 0.6028]], dtype=torch.float64,
        grad_fn=<IndexBackward0>))

In [None]:
M @ obs_cov @ M.T, obs_cov[m][:,m]

(tensor([[0.9728, 0.0000],
         [0.0000, 1.4943]], dtype=torch.float64, grad_fn=<MmBackward0>),
 tensor([[0.9728, 0.0000],
         [0.0000, 1.4943]], dtype=torch.float64, grad_fn=<IndexBackward0>))

With partially missing observations, `_filter_correct` cannot be easily batched, because the shape of the intermediate variables depends on which variables are observed. So the idea is to divide the batch into sub-batches whose elements all share the same mask (i.e. the same observed variables).

In [None]:
mask_values, indices = torch.unique(mask[:,1,:], dim=0, return_inverse=True)
mask_values, indices

(tensor([[ True, False,  True],
         [ True,  True, False]]),
 tensor([0, 1]))

In [None]:
#| export
def _filter_correct_batch(
                    obs_matrix,
                    obs_cov,
                    obs_off,
                    pred_state_mean,
                    pred_state_cov,
                    obs, # [n_obs]
                    mask, # [n_obs_np, n_obs] mask to obtain non missing obs from obs
                    cov_checker=CheckPosDef()):
    """Update state at time `t` given observations at time `t` assuming that all observations have the same mask"""

    m_obs_matrix, m_obs_off, m_obs, m_obs_cov = obs_matrix[mask], obs_off[mask], obs[:, mask], obs_cov[mask][:,mask]
    
    # extra dim needed to have batched matmul working between matrices and means
    (m_obs_matrix,), (m_obs_off, m_obs) = unsqueeze_first(m_obs_matrix), unsqueeze_last(m_obs_off, m_obs) 
    
    pred_obs_mean = m_obs_matrix @ pred_state_mean + m_obs_off
    pred_obs_cov = m_obs_matrix @ pred_state_cov @ m_obs_matrix.mT + m_obs_cov
    kalman_gain = pred_state_cov @ m_obs_matrix.mT @ torch.inverse(pred_obs_cov) # torch.cholesky_inverse(torch.linalg.cholesky(pred_obs_cov))

    corr_state_mean = pred_state_mean + kalman_gain @ (m_obs - pred_obs_mean) #select with the mask instead of multipling so that support nan in the dataset
    corr_state_cov = pred_state_cov - kalman_gain @ m_obs_matrix @ pred_state_cov

    cov_checker.check(pred_state_cov, caller='filter_correct')
    return (corr_state_mean, corr_state_cov)

In [None]:
obs_matrix, obs_cov, obs_off,obs, mm = (k.obs_matrix, k.obs_cov, k.obs_off, data[:,0,:], mask[:,0,:])

In [None]:
corr_s_mean,corr_s_cov = _filter_correct_batch(obs_matrix, obs_cov, obs_off, pred_state_mean[0:1], pred_state_cov[0:1], obs[0:1], mm[0])

In [None]:
corr_s_mean.shape, corr_s_cov.shape

(torch.Size([1, 4, 1]), torch.Size([1, 4, 4]))

In [None]:
#| export
def _filter_correct(obs_matrix,      # [n_dim_obs, n_dim_state] $H$
                    obs_cov,         # [n_dim_obs, n_dim_obs] $R$
                    obs_off,         # [n_dim_obs] $d$
                    pred_state_mean, # [n_batches, n_dim_state, 1]
                    pred_state_cov,  # [n_batches, n_dim_state, n_dim_state]
                    obs,             # [n_batches, n_dim_obs]
                    mask,            # [n_batches, n_dim_obs] bool, `True` where observed
                    cov_checker=None) -> ListMNormal:
    """Update state at time `t` given observations at time `t`.

    Batch elements with different missing-data patterns need intermediate tensors of different
    shapes, so the batch is split into sub-batches that share the same mask and each one is
    corrected with `_filter_correct_batch`.
    """
    # None default instead of a shared `CheckPosDef()` instance created at definition time
    if cov_checker is None: cov_checker = CheckPosDef()

    corr_state_mean, corr_state_cov = torch.empty_like(pred_state_mean), torch.empty_like(pred_state_cov)
    
    # `indices[j]` is the position in `mask_values` of the mask of batch element `j`
    mask_values, indices = torch.unique(mask, return_inverse=True, dim=0)
    for i, mask_v in enumerate(mask_values):
        idx_select = indices == i
        corr_state_mean[idx_select], corr_state_cov[idx_select] = _filter_correct_batch(
            obs_matrix, obs_cov, obs_off,
            pred_state_mean[idx_select], pred_state_cov[idx_select],
            obs[idx_select], mask_v,
            cov_checker
        )
    
    return ListMNormal(corr_state_mean, corr_state_cov)

In [None]:
obs_matrix, obs_cov, obs_off,obs, mm = (k.obs_matrix, k.obs_cov, k.obs_off, data[:,0,:], mask[:,0,:])

In [None]:
corr_s_mean, corr_s_cov = _filter_correct(obs_matrix, obs_cov, obs_off, pred_state_mean, pred_state_cov, obs, mm)

In [None]:
show_as_row(corr_s_mean, corr_s_cov)

In [None]:
corr_s_mean.shape, corr_s_cov.shape

(torch.Size([2, 4, 1]), torch.Size([2, 4, 4]))

In [None]:
corr_s_mean.sum().backward(retain_graph=True) # check that pytorch can compute gradients with the whole batch

### Filter

The recursive version of the Kalman filter apparently breaks PyTorch gradient calculations, so a workaround is needed.
During the loop the states are saved in a python list and then at the end they are combined back into a tensor.
The last line of the function does:

- convert lists to tensors
- correct order dimensions

In [None]:
#| export
def _times2batch(x):
    """Permutes `x` so that the first dimension is the number of batches and not the times"""
    return x.permute(1,0,-2,-1)

In [None]:
#| export
def _filter(trans_matrix, obs_matrix, contr_matrix,
            trans_cov, obs_cov,
            trans_off, obs_off,
            init_state_mean, init_state_cov,
            obs, mask, control,
            cov_checker=None
           ) -> List[Tensor]: # [pred_state_means, pred_state_covs, filt_state_means, filt_state_covs]
    """Filter observations using kalman filter.

    `obs`/`mask` are indexed as `[bs, t, n_dim_obs]` and `control` as `[bs, t, n_dim_contr]`;
    `control[:, 0]` is not used, as the prediction at time 0 is the initial state.
    Returns a list of 4 tensors with the batch dim first: the means are `[bs, n_timesteps, n_dim_state, 1]`
    and the covariances `[bs, n_timesteps, n_dim_state, n_dim_state]`.

    The per-step states are collected in python lists and only stacked at the end, as a workaround for
    gradient computation apparently breaking with the recursive version.
    """
    # None default instead of a shared `CheckPosDef()` instance created at definition time
    if cov_checker is None: cov_checker = CheckPosDef()

    n_timesteps = obs.shape[-2]
    bs = obs.shape[0]
    # one independent list per quantity, indexed by time step
    pred_state_means, pred_state_covs, filt_state_means, filt_state_covs = [[None] * n_timesteps for _ in range(4)]

    for t in range(n_timesteps):
        if t == 0:
            # the first prediction is the initial state, repeated for every batch element
            pred_state_means[t], pred_state_covs[t] = torch.stack([init_state_mean]*bs).unsqueeze(-1), torch.stack([init_state_cov]*bs)
        else:
            pred_state_means[t], pred_state_covs[t] = _filter_predict(trans_matrix, trans_cov, trans_off, contr_matrix,
                                                                      filt_state_means[t - 1], filt_state_covs[t - 1], control[:,t,:],
                                                                      cov_checker.add_args(t=t))

        filt_state_means[t], filt_state_covs[t] = _filter_correct(obs_matrix, obs_cov, obs_off,
                                                                  pred_state_means[t], pred_state_covs[t],
                                                                  obs[:,t,:], mask[:,t,:],
                                                                  cov_checker.add_args(t=t))

    # stack along time ([n_timesteps, bs, ...]) and then move the batch dim first
    return [_times2batch(torch.stack(states))
            for states in (pred_state_means, pred_state_covs, filt_state_means, filt_state_covs)]

In [None]:
obs, init_state_mean, init_state_cov = data, k.init_state_mean, k.init_state_cov

In [None]:
pred_state_means, pred_state_covs, filt_state_means, filt_state_covs = _filter(
    trans_matrix, obs_matrix, contr_matrix,
    trans_cov, obs_cov,
    trans_off, obs_off,
    init_state_mean, init_state_cov,
    data, mask, control)

Predictions at time `0` for both batches

In [None]:
show_as_row(list(map(Self.shape(), (pred_state_means, pred_state_covs, filt_state_means, filt_state_covs,))))

In [None]:
show_as_row(list(map(lambda x:x[0][0], (pred_state_means, pred_state_covs, filt_state_means, filt_state_covs,))))

### KalmanFilter method

In [None]:
#| export
@patch
def _filter_all(self: KalmanFilter, obs, mask, control
               ) ->Tuple[List, List, List, List]: # pred_state_means, pred_state_covs, filt_state_means, filt_state_covs
    """Run `_filter` with this model's parameters, after normalising `obs`/`mask` with `_parse_obs`."""
    obs, mask = self._parse_obs(obs, mask)
    model_params = (self.trans_matrix, self.obs_matrix, self.contr_matrix,
                    self.trans_cov, self.obs_cov,
                    self.trans_off, self.obs_off,
                    self.init_state_mean, self.init_state_cov)
    return _filter(*model_params, obs, mask, control, self.cov_checker)

In [None]:
pred_mean, _, _, _ = k._filter_all(obs, mask, control);

In [None]:
type(k._filter_all(obs, mask, control))

list

In [None]:
pred_mean.sum().backward(retain_graph=True) # it works!

The `filter` method wraps `_filter_all` but in addition:

- returns only filtered state
- detach tensors

In [None]:
#| export
@patch
def filter(self: KalmanFilter,
          obs: Tensor, # [bs, n_timesteps, n_dim_obs] obs for times [0...n_timesteps-1], `nan` where missing
          mask: Tensor,  # [bs, n_timesteps, n_dim_obs] `True` where the observation is available
          control: Tensor, # [bs, n_timesteps, n_dim_contr] control for times [1...n_timesteps-1]
          ) -> ListMNormal: # Filtered state
    """Filter observation"""
    *_, means, covs = self._filter_all(obs, mask, control)
    # drop the trailing singleton dim of the means -> [bs, n_timesteps, n_dim_state]
    return ListMNormal(means.squeeze(-1), covs)

In [None]:
filt = k.filter(obs, mask, control)
filt.mean.shape, filt.cov.shape

(torch.Size([2, 10, 4]), torch.Size([2, 10, 4, 4]))

## Smooth

### Smooth step

compute the probability of the state at time `t` given all the observations

$p(x_t|Y) = \mathcal{N}(x_t; m_t^s, P_t^s)$ where:

- Kalman smoothing gain: $G_t = P_tA^T(P_{t+1}^-)^{-1}$
- smoothed mean: $m_t^s = m_t + G_t(m_{t+1}^s - m_{t+1}^-)$
- smoothed covariance: $P_t^s = P_t + G_t(P_{t+1}^s - P_{t+1}^-)G_t^T$

In [None]:
#| export
def _smooth_update(trans_matrix,                 # [n_dim_state, n_dim_state] $A$
                   filt_state: MNormal,          # filtered state at time `t`; batched: mean [bs, n_dim_state, 1], cov [bs, n_dim_state, n_dim_state]
                   pred_state: MNormal,          # state before filtering at time `t + 1` (= using the observation until time t)
                   next_smoothed_state: MNormal, # smoothed state at time `t+1`
                   cov_checker = None            # covariance checker, a new `CheckPosDef()` when `None`
                   ) -> MNormal:                 # mean and cov of smoothed state at time `t`
    """Correct a pred state with a Kalman Smoother update.

    $G_t = P_tA^T(P_{t+1}^-)^{-1}$, $m_t^s = m_t + G_t(m_{t+1}^s - m_{t+1}^-)$,
    $P_t^s = P_t + G_t(P_{t+1}^s - P_{t+1}^-)G_t^T$
    """
    # None default instead of a shared `CheckPosDef()` instance created at definition time
    if cov_checker is None: cov_checker = CheckPosDef()

    kalman_smoothing_gain = filt_state.cov @ trans_matrix.unsqueeze(0).mT @ torch.inverse(pred_state.cov)

    smoothed_state_mean = filt_state.mean + kalman_smoothing_gain @ (next_smoothed_state.mean - pred_state.mean)
    smoothed_state_cov = filt_state.cov + kalman_smoothing_gain @ (next_smoothed_state.cov - pred_state.cov) @ kalman_smoothing_gain.mT

    cov_checker.check(smoothed_state_cov, caller='smooth_update')
    
    return MNormal(smoothed_state_mean, smoothed_state_cov)

In [None]:
filt_state, pred_state, next_smoothed_state = [MNormal(pred_state_mean, pred_state_cov)] * 3 # just for testing

In [None]:
show_as_row(*_smooth_update(trans_matrix, MNormal(pred_state_mean, pred_state_cov), MNormal(pred_state_mean, pred_state_cov), MNormal(pred_state_mean, pred_state_cov)))

In [None]:
show_as_row(*map(Self.shape(), _smooth_update(trans_matrix, MNormal(pred_state_mean, pred_state_cov), MNormal(pred_state_mean, pred_state_cov), MNormal(pred_state_mean, pred_state_cov))))

### Smooth

In [None]:
#| export
def _smooth(trans_matrix, # `[n_dim_state, n_dim_state]`
            filt_state: ListMNormal, # mean `[bs, n_timesteps, n_dim_state, 1]`, cov `[bs, n_timesteps, n_dim_state, n_dim_state]`
                # `filt_state.mean[:, t]` is the state estimate for time t given obs from times `[0...t]`
            pred_state: ListMNormal, # same shapes as `filt_state`
                # `pred_state.mean[:, t]` is the state estimate for time t given obs from times `[0...t-1]`
            cov_checker = None # covariance checker, a new `CheckPosDef()` when `None`
           ) -> ListMNormal: # Smoothed state, same shapes as `filt_state`
    """Apply the Kalman Smoother as a backward pass from the last time step.

    `cov_checker` is forwarded to every `_smooth_update` (with the current time step), so the
    smoothed covariances are checked with the caller's checker.
    """
    # None default instead of a shared `CheckPosDef()` instance created at definition time
    if cov_checker is None: cov_checker = CheckPosDef()

    x = pred_state.mean # sample for getting tensor properties
    bs, n_timesteps, n_dim_state = x.shape[0], x.shape[1], x.shape[2]

    smoothed_state = ListMNormal(torch.zeros((bs, n_timesteps, n_dim_state, 1),           dtype=x.dtype, device=x.device), 
                                 torch.zeros((bs, n_timesteps, n_dim_state, n_dim_state), dtype=x.dtype, device=x.device))
    # For the last timestep cannot use the smoother: smoothed state == filtered state
    smoothed_state.mean[:, -1] = filt_state.mean[:, -1]
    smoothed_state.cov[:, -1] = filt_state.cov[:, -1]

    for t in reversed(range(n_timesteps - 1)):
        (smoothed_state.mean[:, t], smoothed_state.cov[:, t]) = (
            _smooth_update(
                trans_matrix,
                filt_state[:, t],
                pred_state[:, t + 1],
                smoothed_state[:, t + 1],
                cov_checker.add_args(t=t),
            )
        )
    return smoothed_state

In [None]:
(pred_state_means, pred_state_covs, filt_state_means, filt_state_covs ) = k._filter_all(data, mask, control)
filt_state, pred_state = ListMNormal(filt_state_means, filt_state_covs), ListMNormal(pred_state_means, pred_state_covs)

In [None]:
smooth_state = _smooth(k.trans_matrix,  filt_state, pred_state)

In [None]:
show_as_row(smooth_state.mean[0][0], smooth_state.cov[0][0])

In [None]:
show_as_row(smooth_state.mean.shape, smooth_state.cov.shape)

### KalmanFilter method

In [None]:
#| export
@patch
def smooth(self: KalmanFilter,
           obs: Tensor,
           mask: Tensor,
           control: Tensor
          ) -> ListMNormal: # `[n_timesteps, n_dim_state]` smoothed state
        
    """Kalman Filter Smoothing"""

    (pred_state_means, pred_state_covs, filt_state_means, filt_state_covs) = self._filter_all(obs, mask, control)

    smoothed_state = _smooth(self.trans_matrix,
                   ListMNormal(filt_state_means, filt_state_covs), ListMNormal(pred_state_means, pred_state_covs),
                   self.cov_checker)
    smoothed_state.mean.squeeze_(-1)
    return smoothed_state

In [None]:
smoothed_state = k.smooth(data, mask, control)

In [None]:
show_as_row(smoothed_state.mean.shape, smoothed_state.cov.shape)

## Predict

The predictions at time t ($y_t$) are computed from the state ($x_t$, with mean $m_t$ and covariance $P_t$) using this formula:
$$p(y_t|Y) = \mathcal{N}(Hm_t + d, R + HP_tH^T)$$

this works both if the state was filtered or smoothed

This adds support for conditional predictions: at the time (t) when we are making the predictions, some of the variables may have actually been observed. Since the model prediction is a normal distribution, we can condition on the observed values and thus improve the predictions. See `conditional_gaussian`

In order to have conditional predictions that make sense it's not possible to return the full covariance matrix for the predictions but only the standard deviations

In [None]:
test_m = torch.tensor(
    [[True, True, True,],
    [False, True, True],
    [False, False, False]]
)

In [None]:
torch.logical_xor(test_m.all(-1), test_m.any(-1))

tensor([False,  True, False])

In [None]:
A = torch.rand(2,2,3,3)

In [None]:
(A @ A).shape

torch.Size([2, 2, 3, 3])

predict can be vectorized across both the batch and the timesteps, except for timesteps that require conditional predictions

In [None]:
#| export
@patch
def _obs_from_state(self: KalmanFilter, state: ListMNormal):
    """Map a (filtered or smoothed) state distribution to the observation space.

    mean $Hm + d$ and covariance $HPH^T + R$. `state.mean` has no trailing singleton dim
    and neither has the returned mean.
    """
    H = self.obs_matrix
    obs_mean = (H @ state.mean.unsqueeze(-1) + self.obs_off.unsqueeze(-1)).squeeze(-1)
    obs_cov = H @ state.cov @ H.mT + self.obs_cov

    # one check call per element of the leading dim (each covers all the timestamps)
    for batch_cov in obs_cov:
        self.cov_checker.check(batch_cov, caller='predict')

    return ListMNormal(obs_mean, obs_cov)

In [None]:
smoothed_state.mean.shape, smoothed_state.cov.shape

(torch.Size([2, 10, 4]), torch.Size([2, 10, 4, 4]))

In [None]:
(k.obs_matrix @ smoothed_state.mean.unsqueeze(-1)).shape

torch.Size([2, 10, 3, 1])

In [None]:
pred_obs0 = k._obs_from_state(smoothed_state)
pred_obs0.mean.shape

torch.Size([2, 10, 3])

In [None]:
pred_obs0.cov.shape

torch.Size([2, 10, 3, 3])

In [None]:
#| export
@patch
def predict(self: KalmanFilter, obs, mask, control, smooth=True):
    """Predicted observations at all times.

    Uses the smoothed state (or the filtered one if `smooth=False`). Where only some variables are
    observed at a time step, the missing ones are replaced by the prediction conditioned on the
    observed ones. Returns a `ListNormal` of means and standard deviations.
    """
    state = self.smooth(obs, mask, control) if smooth else self.filter(obs, mask, control)
    obs, mask = self._parse_obs(obs, mask)
    
    pred_obs = self._obs_from_state(state)
    # conditional predictions are slow, do only if some (but not all) obs are missing
    cond_mask = torch.logical_xor(mask.all(-1), mask.any(-1))
    
    # this cannot be batched so returns a list, in the row-major order of `cond_mask`
    cond_preds = cond_gaussian_batched(
        pred_obs[cond_mask], obs[cond_mask], mask[cond_mask])
    
    # clone so the in-place writes below do not modify the tensors held by `pred_obs`
    pred_mean, pred_std = pred_obs.mean.clone(), cov2std(pred_obs.cov).clone()
    
    # `x[cond_mask][i][m] = v` would write into a temporary copy (boolean indexing copies),
    # so address each (batch, time) position explicitly; `nonzero` shares the row-major order of `cond_mask`
    for (b, t), c_pred in zip(cond_mask.nonzero().tolist(), cond_preds):
        missing = ~mask[b, t]
        pred_mean[b, t][missing] = c_pred.mean
        pred_std[b, t][missing] = cov2std(c_pred.cov)
    
    return ListNormal(pred_mean, pred_std)

In [None]:
pred = k.predict(data, mask, control)

In [None]:
pred.mean.shape, pred.std.shape

(torch.Size([2, 10, 3]), torch.Size([2, 10, 3]))

Gradients: check that they do not depend on the values used to fill the missing observations

In [None]:
def get_grad_mask(x):
    """Gradient of `k.obs_cov_raw` for the sum of the predicted means, after
    replacing the missing values (`~mask`) of `data` with `x`."""
    d = data.clone()
    d[~mask] = x
    # predict on the modified copy `d`: using `data` here would ignore `x`
    # and make the comparison between different `x` values meaningless
    k.predict(d, mask, control).mean.sum().backward(retain_graph=True)
    grad = k.obs_cov_raw.grad.clone()
    k.zero_grad()  # reset so gradients do not accumulate across calls
    return grad

In [None]:
# gradient of `obs_cov_raw` returned by `get_grad_mask` for x = 10
get_grad_mask(10)

tensor([-9.9394, 17.8374,  2.9461], dtype=torch.float64)

In [None]:
# the gradient should not depend on the value used to fill the missing positions
test_close(get_grad_mask(1), get_grad_mask(10))

In [None]:
@patch
def predict_times(self: KalmanFilter, times, obs, mask=None, smooth=True, check_args=None):
    """Predicted observations at specific times.

    Raises:
        ValueError: if any of `times` is outside the valid index range
            `0 .. n_timesteps - 1`.

    NOTE(review): this (non-exported) cell looks out of sync with the rest of
    the module; confirm before relying on it:
    - `predict` passes `control` as the third positional argument of
      `smooth`/`filter`, while here `check_args` is passed in that slot;
    - `_obs_from_state` as defined in this module takes a single `ListMNormal`
      state, while it is called below with three positional arguments;
    - `obs.shape[0]` is treated as the number of timesteps (unbatched input).
    """
    state = self.smooth(obs, mask, check_args) if smooth else self.filter(obs, mask, check_args)
    obs, mask = self._parse_obs(obs, mask)
    times = array1d(times)
    
    n_timesteps = obs.shape[0]
    n_features = obs.shape[1] if len(obs.shape) > 1 else 1
    
    # valid indices are 0 .. n_timesteps - 1: `t == n_timesteps` is out of range
    if times.max() >= n_timesteps or times.min() < 0:
        raise ValueError(f"provided times range from {times.min()} to {times.max()}, which is outside allowed range : 0 to {n_timesteps - 1}")

    means = torch.empty((times.shape[0], n_features), dtype=obs.dtype, device=obs.device)
    stds = torch.empty((times.shape[0], n_features), dtype=obs.dtype, device=obs.device) 
    for i, t in enumerate(times):
        mean, std = self._obs_from_state(
            state.mean[t],
            state.cov[t],
            {'t': t, **check_args} if check_args is not None else None
        )
        
        means[i], stds[i] = _get_cond_pred(ListNormal(mean, std), obs[t], mask[t])
    
    return ListNormal(means, stds)  

## Additional

### Get Info

In [None]:
# inspect a single parameter: the observation matrix (H)
k.obs_matrix

Parameter containing:
tensor([[0.5533, 0.3582, 0.8828, 0.3156],
        [0.1210, 0.6346, 0.0657, 0.3273],
        [0.4661, 0.8703, 0.4889, 0.6028]], dtype=torch.float64,
       requires_grad=True)

In [None]:
#| export
@patch
def get_info(self: KalmanFilter):
    """Return every model parameter as a labelled dataframe.

    The result is a dict keyed by a human-readable parameter name, with each
    value built by `array2df(values, row_labels, col_labels, index_name)`.
    Observed variables are labelled with `self.var_names` (default `y_i`),
    latent states with `x_i`, and controls with `self.contr_names`
    (default `c_i`).
    """
    obs_lbl = ifnone(self.var_names, [f"y_{i}" for i in range(self.obs_matrix.shape[0])])
    state_lbl = [f"x_{i}" for i in range(self.trans_matrix.shape[0])]
    contr_lbl = ifnone(self.contr_names, [f"c_{i}" for i in range(self.contr_matrix.shape[1])])

    # (key, values, row labels, column labels, index name) -- order defines the display order
    table = [
        ('trans matrix (A)', self.trans_matrix,    state_lbl, state_lbl, 'state'),
        ('trans cov (Q)',    self.trans_cov,       state_lbl, state_lbl, 'state'),
        ('trans off',        self.trans_off,       state_lbl, ['offset'], 'state'),
        ('obs matrix (H)',   self.obs_matrix,      obs_lbl,   state_lbl, 'variable'),
        ('obs cov (R)',      self.obs_cov,         obs_lbl,   obs_lbl,   'variable'),
        ('obs off',          self.obs_off,         obs_lbl,   ['offset'], 'variable'),
        ('contr matrix (B)', self.contr_matrix,    state_lbl, contr_lbl, 'state'),
        ('init state mean',  self.init_state_mean, state_lbl, ['mean'],  'state'),
        ('init state cov',   self.init_state_cov,  state_lbl, state_lbl, 'state'),
    ]
    return {key: array2df(values, rows, cols, idx_name)
            for key, values, rows, cols, idx_name in table}

In [None]:
# control matrix (B): one row per latent state, one column per control variable
k.contr_matrix

Parameter containing:
tensor([[2.6977e-01, 7.9453e-01, 1.6282e-01],
        [2.3017e-01, 6.1964e-01, 4.0785e-01],
        [5.3003e-01, 4.5534e-01, 4.3027e-01],
        [7.5926e-01, 2.3556e-04, 1.8805e-01]], dtype=torch.float64,
       requires_grad=True)

In [None]:
# all parameters as labelled tables (see `get_info`)
display_as_row(k.get_info())

Unnamed: 0,state,x_0,x_1,x_2,x_3
0,x_0,0.6023,0.3413,0.3698,0.6603
1,x_1,0.6889,0.2815,0.5281,0.814
2,x_2,0.274,0.7634,0.2004,0.6965
3,x_3,0.8038,0.2749,0.6018,0.9625

Unnamed: 0,state,x_0,x_1,x_2,x_3
0,x_0,2.1561,1.4566,1.184,0.629
1,x_1,1.4566,1.062,0.7878,0.5357
2,x_2,1.184,0.7878,0.7294,0.3725
3,x_3,0.629,0.5357,0.3725,0.367

Unnamed: 0,state,offset
0,x_0,0.0022
1,x_1,0.9561
2,x_2,0.7751
3,x_3,0.1734

Unnamed: 0,variable,x_0,x_1,x_2,x_3
0,y_0,0.5533,0.3582,0.8828,0.3156
1,y_1,0.121,0.6346,0.0657,0.3273
2,y_2,0.4661,0.8703,0.4889,0.6028

Unnamed: 0,variable,y_0,y_1,y_2
0,y_0,0.554,0.0,0.0
1,y_1,0.0,0.9728,0.0
2,y_2,0.0,0.0,1.4943

Unnamed: 0,variable,offset
0,y_0,0.4648
1,y_1,0.9363
2,y_2,0.8193

Unnamed: 0,control,c_0,c_1,c_2
0,x_0,0.2698,0.7945,0.1628
1,x_1,0.2302,0.6196,0.4079
2,x_2,0.53,0.4553,0.4303
3,x_3,0.7593,0.0002,0.1881

Unnamed: 0,state,mean
0,x_0,0.7244
1,x_1,0.6438
2,x_2,0.957
3,x_3,0.8085

Unnamed: 0,state,x_0,x_1,x_2,x_3
0,x_0,0.8758,0.7556,0.4823,0.6878
1,x_1,0.7556,0.9753,0.7402,0.6828
2,x_2,0.4823,0.7402,1.0871,0.5918
3,x_3,0.6878,0.6828,0.5918,0.7291


In [None]:
#| export
@patch
def _repr_html_(self: KalmanFilter):
    "HTML representation: all tables from `get_info`, rendered by `row_dfs` under a title with the model dimensions."
    tables = self.get_info()
    title = f"Kalman Filter ({self.n_dim_obs} obs, {self.n_dim_state} state, {self.n_dim_contr} contr)"
    return row_dfs(tables, title, hide_idx=True)

In [None]:
# rich display of the filter (rendered via `_repr_html_`)
k

state,x_0,x_1,x_2,x_3
x_0,0.6023,0.3413,0.3698,0.6603
x_1,0.6889,0.2815,0.5281,0.814
x_2,0.274,0.7634,0.2004,0.6965
x_3,0.8038,0.2749,0.6018,0.9625

state,x_0,x_1,x_2,x_3
x_0,2.1561,1.4566,1.184,0.629
x_1,1.4566,1.062,0.7878,0.5357
x_2,1.184,0.7878,0.7294,0.3725
x_3,0.629,0.5357,0.3725,0.367

state,offset
x_0,0.0022
x_1,0.9561
x_2,0.7751
x_3,0.1734

variable,x_0,x_1,x_2,x_3
y_0,0.5533,0.3582,0.8828,0.3156
y_1,0.121,0.6346,0.0657,0.3273
y_2,0.4661,0.8703,0.4889,0.6028

variable,y_0,y_1,y_2
y_0,0.554,0.0,0.0
y_1,0.0,0.9728,0.0
y_2,0.0,0.0,1.4943

variable,offset
y_0,0.4648
y_1,0.9363
y_2,0.8193

control,c_0,c_1,c_2
x_0,0.2698,0.7945,0.1628
x_1,0.2302,0.6196,0.4079
x_2,0.53,0.4553,0.4303
x_3,0.7593,0.0002,0.1881

state,mean
x_0,0.7244
x_1,0.6438
x_2,0.957
x_3,0.8085

state,x_0,x_1,x_2,x_3
x_0,0.8758,0.7556,0.4823,0.6878
x_1,0.7556,0.9753,0.7402,0.6828
x_2,0.4823,0.7402,1.0871,0.5918
x_3,0.6878,0.6828,0.5918,0.7291


### Constructors

#### Simple parameters

In [None]:
#| export
@patch(cls_method=True)
def init_simple(cls: KalmanFilter,
                n_dim, # n_dim_obs and n_dim_state
                dtype=torch.float32):
    """Build the simplest Kalman filter: identity matrices and covariances, zero offsets.

    Observations, state and controls all have `n_dim` dimensions.
    """
    # helpers return a fresh tensor on every call, so no two parameters share storage
    def eye():
        return torch.eye(n_dim, dtype=dtype)

    def zeros():
        return torch.zeros(n_dim, dtype=dtype)

    params = dict(
        trans_matrix=eye(), trans_off=zeros(), trans_cov=eye(),
        obs_matrix=eye(), obs_off=zeros(), obs_cov=eye(),
        contr_matrix=eye(),
        init_state_mean=zeros(), init_state_cov=eye(),
    )
    return cls(**params)

In [None]:
# parameters of the simplest filter with 2 dimensions
KalmanFilter.init_simple(2).state_dict()

#### Local slope

Local slope models are an extension of the local level model in which the state variable also keeps track of the slope.

Let $n$ be the number of dimensions of the observations.

The transition matrix (`A`) is:

$$A = \left[\begin{array}{cc}I & I \\ 0 & I\end{array}\right]$$

where:

- $I \in \mathbb{R}^{n \times n}$
- $A \in \mathbb{R}^{2n \times 2n}$

The state is $x \in \mathbb{R}^{2n \times 1}$, where the upper half keeps track of the level and the lower half of the slope; hence $A \in \mathbb{R}^{2n \times 2n}$.

The observation matrix (`H`), with $H \in \mathbb{R}^{n \times 2n}$, is:

$$H = \left[\begin{array}{cc}I & 0 \end{array}\right]$$

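In the univariate case ($n = 1$) the blocks are the scalars 1 and 0. In the multivariate case each 1 is replaced by an identity matrix $I$, and each 0 by an $n \times n$ zero matrix.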


In [None]:
#| export
@patch(cls_method=True)
def init_local_slope(cls: KalmanFilter,
                n_dim, # n_dim_obs; the state has 2 * n_dim dimensions
                dtype=torch.float32):
    """Local slope model: the state holds a level and a slope per observed variable.

    With `I` the `n_dim x n_dim` identity and `0` the zero matrix:
    - transition matrix `A = [[I, I], [0, I]]`, so the level is advanced by
      the slope and the slope is carried over;
    - observation matrix `H = [I, 0]`, so only the level is observed.

    Covariances are identities and offsets zeros. The control matrix is all
    zeros (controls have no effect), with `n_dim` control columns as in
    `init_simple`. NOTE(review): confirm the intended number of controls.
    """
    n_dim_state = 2 * n_dim
    # A = [[I, I], [0, I]]: identity plus an identity block in the upper-right corner
    trans_matrix = torch.eye(n_dim_state, dtype=dtype)
    trans_matrix[:n_dim, n_dim:] = torch.eye(n_dim, dtype=dtype)
    return cls(
        trans_matrix =     trans_matrix,
        trans_off =        torch.zeros(n_dim_state, dtype=dtype),
        trans_cov =        torch.eye(n_dim_state, dtype=dtype),
        # H = [I, 0]: an (n_dim x n_dim_state) matrix with ones on the main diagonal
        obs_matrix =       torch.eye(n_dim, n_dim_state, dtype=dtype),
        obs_off =          torch.zeros(n_dim, dtype=dtype),
        obs_cov =          torch.eye(n_dim, dtype=dtype),
        # 2-D (state x control): `get_info` reads `contr_matrix.shape[1]`
        contr_matrix =     torch.zeros(n_dim_state, n_dim, dtype=dtype),
        init_state_mean =  torch.zeros(n_dim_state, dtype=dtype),
        init_state_cov =   torch.eye(n_dim_state, dtype=dtype),
    )

In [None]:
# parameters of a local slope model with 2 observed variables (4 state dimensions)
KalmanFilter.init_local_slope(2).state_dict()

## Export

In [None]:
#| hide
# regenerate the library modules from the `#| export` cells of the notebooks
from nbdev import nbdev_export
nbdev_export()