# Kalman Filter
> Implementation of Kalman filters using PyTorch, with parameter optimization by gradient descent

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
#| default_exp kalman.filter

In [None]:
#| export
from fastcore.test import *
from fastcore.basics import *
from meteo_imp.utils import *
from meteo_imp.gaussian import *
from meteo_imp.data_preparation import MeteoDataTest
from typing import *

import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torch.distributions import MultivariateNormal

## Introduction

The model uses a latent state variable $x$, modelled over time, to impute gaps in the observations $y$.

The model assumes that the state $x_t$ at time $t$ depends only on the previous state $x_{t-1}$ and not on any earlier state.

### Equations

The equations of the model are:

$$\begin{align} p(x_t | x_{t-1}) & = \mathcal{N}(x_t; Ax_{t-1} + b, Q) \\
p(y_t | x_t) & = \mathcal{N}(y_t; Hx_t + d, R) \end{align}$$

The Kalman filter has 3 steps:

- predict (estimate the state at time t using the observations up to time t-1)
- update (correct the state at time t using the observation at time t)
- smooth (refine the state at time t using all the observations, starting from the smoothed state at time t+1)

In case of missing data the update step is skipped.

After smoothing, the missing data at time t ($y_t$) can be imputed from the smoothed state $x_t \sim \mathcal{N}(m^s_t, P^s_t)$ using this formula:
$$p(y_t|Y) = \mathcal{N}(y_t; Hm^s_t + d, R + HP^s_tH^T)$$

## KalmanFilter

The Kalman Filter is an algorithm designed to estimate $P(x_t | y_{0:t})$. All state transitions and observations are linear with Gaussian distributed noise, so these distributions can be represented exactly as Gaussian distributions with mean `filt_state_means[t]` and covariance `filt_state_covs[t]`.
Similarly, the Kalman Smoother is an algorithm designed to estimate $P(x_t | y_{0:T})$, i.e. the state given all the observations up to the last time step $T$.



### Main class

TODO: fill nans with 0 for all data

In [None]:
#| export
class KalmanFilter(torch.nn.Module):
    """Base class for Kalman Filter and Smoother using PyTorch"""
    def __init__(self,
            trans_matrix: Tensor,    # [n_dim_state,n_dim_state] $A$, state transition matrix 
            obs_matrix: Tensor,      # [n_dim_obs, n_dim_state] $H$, observation matrix
            trans_cov: Tensor,       # [n_dim_state, n_dim_state] $Q$, state trans covariance matrix
            obs_cov: Tensor,         # [n_dim_obs, n_dim_obs] $R$, observations covariance matrix
            trans_off: Tensor,       # [n_dim_state] $b$, state transition offset
            obs_off: Tensor,         # [n_dim_obs] $d$, observations offset
            init_state_mean: Tensor, # [n_dim_state] $\mu_0$
            init_state_cov: Tensor,  # [n_dim_state, n_dim_state] $\Sigma_0$
            n_dim_state: Optional[int] = None, # Number of dimensions for state - defaults to 1 if cannot be infered from parameters
            n_dim_obs: Optional[int] = None,   # Number of dimensions for observations - defaults to 1 if cannot be infered from parameters
            cov_checker: Optional[CheckPosDef] = None, # checker for covariance matrices; a new `CheckPosDef()` per model when None
                ):
        
        super().__init__()
        # check parameters are consistent
        self.n_dim_state = determine_dimensionality(
            [(trans_matrix, array2d, -2),
             (trans_off, array1d, -1),
             (trans_cov, array2d, -2),
             (init_state_mean, array1d, -1),
             (init_state_cov, array2d, -2),
             (obs_matrix, array2d, -1)],
            n_dim_state
        )
        self.n_dim_obs = determine_dimensionality(
            [(obs_matrix, array2d, -2),
             (obs_off, array1d, -1),
             (obs_cov, array2d, -2)],
            n_dim_obs
        )
        
        params = {
        #name               value             constraint
        'trans_matrix':     [trans_matrix,    None    ],
        'trans_off':        [trans_off,       None    ],
        'trans_cov':        [trans_cov,       PosDef()],
        'obs_matrix':       [obs_matrix,      None    ],
        'obs_off':          [obs_off,         None    ],
        'obs_cov':          [obs_cov,         PosDef()],
        'init_state_mean':  [init_state_mean, None    ],
        'init_state_cov':   [init_state_cov,  PosDef()],
        }
        self._init_params(params)
        
        # a default instance in the signature would be shared by every model; build one per model instead
        self.cov_checker = cov_checker if cov_checker is not None else CheckPosDef()
        
    def _init_params(self, params):
        """Register every parameter as trainable; constrained ones are stored as `<name>_raw`"""
        for name, (value, constraint) in params.items():
            if constraint is not None:
                name, value = self._init_constraint(name, value, constraint)
            self._init_param(name, value, train=True)    
    
    def _init_param(self, param_name, value, train):
        self.register_parameter(param_name, torch.nn.Parameter(value, requires_grad=train))
    
    ### === Constraints utils
    def _init_constraint(self, param_name, value, constraint):
        """Store `constraint` as `<param_name>_constraint` and return the raw (unconstrained) name and value"""
        name = param_name + "_raw"
        value = constraint.inverse_transform(value)
        setattr(self, param_name + "_constraint", constraint)
        return name, value
    
    def _get_constraint(self, param_name):
        """Constrained value of `param_name`, computed from its raw parameter"""
        constraint = getattr(self, param_name + "_constraint")
        raw_value = getattr(self, param_name + "_raw")
        return constraint.transform(raw_value)
    
    def _set_constraint(self, param_name, value, train=True):
        """Replace the raw parameter of `param_name` so that its constrained value is `value`"""
        constraint = getattr(self, param_name + "_constraint")
        raw_value = constraint.inverse_transform(value)
        # registers a new Parameter object: optimizers built before this call still hold the old one
        self._init_param(param_name + "_raw", raw_value, train)
    
    ### === Convenience functions to get and set parameters that have a constraint
    @property
    def trans_cov(self): return self._get_constraint('trans_cov')
    @trans_cov.setter
    def trans_cov(self, value): self._set_constraint('trans_cov', value)

    @property
    def obs_cov(self): return self._get_constraint('obs_cov')
    @obs_cov.setter
    def obs_cov(self, value): self._set_constraint('obs_cov', value)
    
    @property
    def init_state_cov(self): return self._get_constraint('init_state_cov')
    @init_state_cov.setter
    def init_state_cov(self, value): self._set_constraint('init_state_cov', value)
    
    
    ### === Utility Func    
    def _parse_obs(self, obs, mask=None):
        """Return `(obs, mask)` as 3D tensors `[bs, n_timesteps, n_dim_obs]`.

        When `mask` is None it is derived from `obs` (True where the value is not `nan`).
        `nan`s in `obs` are replaced by 0: rows with masked values are discarded by the
        filter anyway, but a `nan` would still turn the gradients into `nan`.
        A 1D input is treated as a single batch of a 1-dimensional series and a 2D input
        as `[n_timesteps, n_dim_obs]` for a single batch.
        """
        if mask is None: mask = ~torch.isnan(obs)
        obs = obs.masked_fill(torch.isnan(obs), 0.)
        return self._to_3d(obs), self._to_3d(mask)

    @staticmethod
    def _to_3d(x):
        """Add a leading batch dim to 2D inputs; `atleast_3d` for the other ranks"""
        # `torch.atleast_3d` would turn `[n_timesteps, n_dim_obs]` into `[n_timesteps, n_dim_obs, 1]`
        return x.unsqueeze(0) if x.dim() == 2 else torch.atleast_3d(x)
    
    def __repr__(self):
        return f"""Kalman Filter
        N dim obs: {self.n_dim_obs}, N dim state: {self.n_dim_state}"""

### Constructors

Giving all the parameters manually to the `KalmanFilter` init method is not convenient, so there are some methods that help initialize the class.

#### Random parameters

In [None]:
#| export
@patch(cls_method=True)
def init_random(cls: KalmanFilter, n_dim_obs, n_dim_state, dtype=torch.float32):
    """Build a Kalman filter whose parameters are sampled with `torch.rand`.

    The covariances (`trans_cov`, `obs_cov`, `init_state_cov`) are random matrices made
    positive definite by `to_posdef`. `dtype` is used for every parameter.
    """
    def rand(*shape):
        return torch.rand(*shape, dtype=dtype)

    def rand_posdef(n):
        return to_posdef(rand(n, n))

    # keyword arguments are evaluated in this order, so the random draws keep the same sequence
    return cls(
        trans_matrix=rand(n_dim_state, n_dim_state),
        trans_off=rand(n_dim_state),
        trans_cov=rand_posdef(n_dim_state),
        obs_matrix=rand(n_dim_obs, n_dim_state),
        obs_off=rand(n_dim_obs),
        obs_cov=rand_posdef(n_dim_obs),
        init_state_mean=rand(n_dim_state),
        init_state_cov=rand_posdef(n_dim_state),
    )
        

In [None]:
k = KalmanFilter.init_random(3,4, dtype=torch.float64)
k

Kalman Filter
        N dim obs: 3, N dim state: 4

In [None]:
k.init_state_cov

tensor([[1.3752, 1.1063, 1.3391, 1.3323],
        [1.1063, 1.1409, 1.0593, 0.9916],
        [1.3391, 1.0593, 1.4392, 1.3858],
        [1.3323, 0.9916, 1.3858, 1.3901]], dtype=torch.float64,
       grad_fn=<AddBackward0>)

Check that assignment works :)

In [None]:
k.init_state_cov = to_posdef(torch.rand(4, 4, dtype=torch.float64))

In [None]:
k.init_state_cov_raw

Parameter containing:
tensor([[ 1.6133,  0.0000,  0.0000,  0.0000],
        [ 0.6912,  0.8888,  0.0000,  0.0000],
        [ 1.0255,  0.0125,  0.2927,  0.0000],
        [ 0.5586,  0.7375, -0.4368,  0.0874]], dtype=torch.float64,
       requires_grad=True)

In [None]:
list(k.named_parameters())

[('trans_matrix',
  Parameter containing:
  tensor([[0.9959, 0.6486, 0.4152, 0.8179],
          [0.1256, 0.5073, 0.0909, 0.9389],
          [0.3924, 0.7097, 0.1217, 0.2662],
          [0.3080, 0.4001, 0.9252, 0.2889]], dtype=torch.float64,
         requires_grad=True)),
 ('trans_off',
  Parameter containing:
  tensor([0.6283, 0.3925, 0.7199, 0.8452], dtype=torch.float64,
         requires_grad=True)),
 ('trans_cov_raw',
  Parameter containing:
  tensor([[ 1.0998,  0.0000,  0.0000,  0.0000],
          [ 1.2269,  0.5530,  0.0000,  0.0000],
          [ 0.8618, -0.3656,  0.5690,  0.0000],
          [ 0.9289,  0.0318,  0.6003,  0.3249]], dtype=torch.float64,
         requires_grad=True)),
 ('obs_matrix',
  Parameter containing:
  tensor([[0.2592, 0.9640, 0.3593, 0.7745],
          [0.9641, 0.6363, 0.6401, 0.0719],
          [0.1985, 0.0254, 0.0740, 0.6775]], dtype=torch.float64,
         requires_grad=True)),
 ('obs_off',
  Parameter containing:
  tensor([0.9729, 0.3114, 0.5707], dtype=torc

#### Test data

In [None]:
#| exporti
def get_test_data(n_obs = 10, n_dim_obs=3, p_missing=.3, bs=2, dtype=torch.float32, device='cpu'):
    """Random observations plus a random "observed" mask, for tests.

    Returns `(data, mask)`, both of shape `[bs, n_obs, n_dim_obs]`. `mask` is True
    (observed) with probability `1 - p_missing`. Missing entries of `data` are left as
    random numbers rather than being set to `nan`.
    """
    shape = (bs, n_obs, n_dim_obs)
    values = torch.rand(shape, dtype=dtype, device=device)
    observed = torch.rand(shape, device=device) > p_missing
    return values, observed

In [None]:
reset_seed()
data, mask = get_test_data()
show_as_row(data, mask)

### Filter

#### Filter predict

Probability of the state at time `t` given the state at time `t-1`

$p(x_t) = \mathcal{N}(x_t; m_t^-, P_t^-)$ where:

- predicted state mean: $m_t^- = Am_{t-1} + B c_t + b$ (the control term $B c_t$ is not implemented yet)

- predicted state covariance: $P_t^- = AP_{t-1}A^T + Q$

In [None]:
#| export
from datetime import datetime
def _filter_predict(trans_matrix,
                    trans_cov,
                    trans_off,
                    curr_state_mean,
                    curr_state_cov,
                    control_matrix=0,
                    control=0,
                    cov_checker=CheckPosDef()):
    r"""Calculate the state at time `t+1` given the state at time `t`"""
    pred_state_mean = trans_matrix.unsqueeze(0) @ curr_state_mean + trans_off.unsqueeze(-1)
    pred_state_cov =  trans_matrix.unsqueeze(0) @ curr_state_cov @ trans_matrix.unsqueeze(0).mT + trans_cov.unsqueeze(0)

    cov_checker.check(pred_state_cov, caller='filter_predict')
    return (pred_state_mean, pred_state_cov)

In [None]:
trans_matrix, trans_cov, trans_off,curr_state_mean,curr_state_cov = (k.trans_matrix, k.trans_cov, k.trans_off,
                                                  torch.stack([k.init_state_mean]*2).unsqueeze(-1),
                                                  torch.stack([k.init_state_cov]*2))

In [None]:
pred_state_mean, pred_state_cov = _filter_predict(trans_matrix, trans_cov, trans_off,curr_state_mean,curr_state_cov)

In [None]:
show_as_row(pred_state_mean, pred_state_cov)

In [None]:
show_as_row(pred_state_mean.shape, pred_state_cov.shape)

#### Filter correct

Probability of state at time `t` given the observations at time `t`

$p(x_t|y_t) = \mathcal{N}(x_t; m_t, P_t)$ where:

- predicted obs mean: $z_t = Hm_t^- + d$  

- predicted obs covariance: $S_t = HP_t^-H^T + R$

- Kalman gain: $K_t = P_t^-H^TS_t^{-1}$

- corrected state mean: $m_t = m_t^- + K_t(y_t - z_t)$ 

- corrected state covariance: $P_t = (I-K_tH)P_t^-$ 

If any observation at time `t` is missing, this step is skipped and the corrected state is equal to the predicted state.


Need to figure out the Nans for the gradients ...

In [None]:
k.obs_off.shape

torch.Size([3])

In [None]:
#| export
def _filter_correct(obs_matrix,
                    obs_cov,
                    obs_off,
                    pred_state_mean,
                    pred_state_cov,
                    obs,
                    mask,
                    cov_checker=CheckPosDef()):
    """Update state at time `t` given observations at time `t`"""
    
    pred_obs_mean = obs_matrix.unsqueeze(0) @ pred_state_mean + obs_off.unsqueeze(-1) # extra dim needed to hae batched matmul working
    pred_obs_cov = obs_matrix.unsqueeze(0) @ pred_state_cov @ obs_matrix.unsqueeze(0).mT + obs_cov

    kalman_gain = pred_state_cov @ obs_matrix.T @ torch.inverse(pred_obs_cov) # torch.cholesky_inverse(torch.linalg.cholesky(pred_obs_cov))

    corr_state_mean = pred_state_mean + kalman_gain @ (obs.unsqueeze(-1) - pred_obs_mean)
    corr_state_cov = pred_state_cov - kalman_gain @ obs_matrix @ pred_state_cov
    
    mask = mask.all(-1) # if any observation is missing need to discard the whole vector
    corr_state_mean[~mask] = pred_state_mean[~mask]
    corr_state_cov[~mask] = pred_state_cov[~mask]

    cov_checker.check(pred_state_cov, caller='filter_correct')
    return (kalman_gain, corr_state_mean, corr_state_cov)

In [None]:
obs_matrix, obs_cov, obs_off,obs, mm = (k.obs_matrix, k.obs_cov, k.obs_off, data[:,0,:], mask[:,0,:])

In [None]:
k_gain, corr_s_mean, corr_s_cov = _filter_correct(obs_matrix, obs_cov, obs_off, pred_state_mean, pred_state_cov, obs, mm)

In [None]:
show_as_row(k_gain, corr_s_mean, corr_s_cov)

In [None]:
show_as_row(*map(lambda x:x.shape, (k_gain, corr_s_mean, corr_s_cov,)))

In [None]:
test_close(corr_s_mean[1], pred_state_mean[1]) # correctly ignoring the missing data

In [None]:
corr_s_mean.sum().backward(retain_graph=True) # check that pytorch can compute gradients with the whole batch

#### Filter

The recursive version of the Kalman filter apparently breaks PyTorch gradient calculations, so a workaround is needed.
During the loop the states are saved in a python list and then at the end they are combined back into a tensor.
The last line of the function does:

- convert lists to tensors
- correct order dimensions

In [None]:
#| export
def _times2batch(x):
    """Permutes `x` so that the first dimension is the number of batches and not the times"""
    return x.permute(1,0,-2,-1)

In [None]:
#| export
def _filter(trans_matrix, obs_matrix,
            trans_cov, obs_cov,
            trans_off, obs_off,
            init_state_mean, init_state_cov,
            obs, mask,
            cov_checker=None
           ) -> List[Tensor]: # [pred_state_means, pred_state_covs, filt_state_means, filt_state_covs]
    """Filter observations using kalman filter.

    `obs` and `mask` are `[bs, n_timesteps, n_dim_obs]`. Returns the predicted and filtered
    state means (`[bs, n_timesteps, n_dim_state, 1]`) and covariances
    (`[bs, n_timesteps, n_dim_state, n_dim_state]`) for every time step.
    `cov_checker` defaults to a new `CheckPosDef()`; each predict/correct step receives
    `cov_checker.add_args(t=t)`.
    """
    if cov_checker is None: cov_checker = CheckPosDef()
    n_timesteps = obs.shape[-2]
    bs = obs.shape[0]
    # per-timestep states are collected in python lists and stacked at the end,
    # instead of being written into a preallocated tensor (which broke gradient computation)
    pred_state_means, pred_state_covs, filt_state_means, filt_state_covs = [], [], [], []

    for t in range(n_timesteps):
        if t == 0:
            pred_mean = torch.stack([init_state_mean] * bs).unsqueeze(-1)
            pred_cov = torch.stack([init_state_cov] * bs)
        else:
            # keyword argument: passed positionally it would land in `control_matrix`
            pred_mean, pred_cov = _filter_predict(trans_matrix, trans_cov, trans_off,
                                                  filt_state_means[t - 1], filt_state_covs[t - 1],
                                                  cov_checker=cov_checker.add_args(t=t))

        _, filt_mean, filt_cov = _filter_correct(obs_matrix, obs_cov, obs_off,
                                                 pred_mean, pred_cov,
                                                 obs[:, t, :], mask[:, t, :],
                                                 cov_checker=cov_checker.add_args(t=t))
        pred_state_means.append(pred_mean)
        pred_state_covs.append(pred_cov)
        filt_state_means.append(filt_mean)
        filt_state_covs.append(filt_cov)

    # stack along time (dim 0) then move the batch dim first
    return [_times2batch(torch.stack(states))
            for states in (pred_state_means, pred_state_covs, filt_state_means, filt_state_covs)]

In [None]:
obs, init_state_mean, init_state_cov = data, k.init_state_mean, k.init_state_cov

In [None]:
pred_state_means, pred_state_covs, filt_state_means, filt_state_covs = _filter(trans_matrix, obs_matrix, trans_cov, obs_cov, trans_off, obs_off, init_state_mean, init_state_cov, data, mask)

Predictions at time `0` for both batches

In [None]:
show_as_row(*map(Self.shape(), (pred_state_means, pred_state_covs, filt_state_means, filt_state_covs,)))

In [None]:
show_as_row(*map(lambda x:x[0][0], (pred_state_means, pred_state_covs, filt_state_means, filt_state_covs,)))

#### KalmanFilter method

In [None]:
#| export
@patch
def _filter_all(self: KalmanFilter, obs, mask=None
               ) ->Tuple[List, List, List, List]: # pred_state_means, pred_state_covs, filt_state_means, filt_state_covs
    """Run `_filter` with this model's parameters on `obs`.

    `obs`/`mask` are first normalized by `_parse_obs`. The result is the list
    `[pred_state_means, pred_state_covs, filt_state_means, filt_state_covs]` returned by `_filter`.
    """
    parsed_obs, parsed_mask = self._parse_obs(obs, mask)
    return _filter(
        trans_matrix=self.trans_matrix, obs_matrix=self.obs_matrix,
        trans_cov=self.trans_cov, obs_cov=self.obs_cov,
        trans_off=self.trans_off, obs_off=self.obs_off,
        init_state_mean=self.init_state_mean, init_state_cov=self.init_state_cov,
        obs=parsed_obs, mask=parsed_mask,
        cov_checker=self.cov_checker,
    )

In [None]:
pred_mean, _, _, _ = k._filter_all(obs);

In [None]:
type(k._filter_all(obs, mask))

list

In [None]:
pred_mean.sum().backward(retain_graph=True) # it works!

The `filter` method wraps `_filter_all`, but in addition it:

- returns only the filtered state
- removes the trailing singleton dimension of the state means

In [None]:
#| export
@patch
def filter(self: KalmanFilter,
           obs: Tensor, # [bs, n_timesteps, n_dim_obs] observations; `nan` marks missing values when `mask` is None
           mask = None, # boolean mask, True where `obs` is observed
           ) -> ListMNormal: # Filtered state
    """Run the Kalman filter and return the filtered state at every time step"""
    *_, means, covs = self._filter_all(obs, mask)
    # means come out as column vectors `[..., n_dim_state, 1]`
    return ListMNormal(means.squeeze(-1), covs)

In [None]:
filt = k.filter(obs)
filt.mean.shape, filt.cov.shape

(torch.Size([2, 10, 4]), torch.Size([2, 10, 4, 4]))

### Smooth

#### Smooth step

compute the probability of the state at time `t` given all the observations

$p(x_t|Y) = \mathcal{N}(x_t; m_t^s, P_t^s)$ where:

- Kalman smoothing gain: $G_t = P_tA^T(P_{t+1}^-)^{-1}$
- smoothed mean: $m_t^s = m_t + G_t(m_{t+1}^s - m_{t+1}^-)$
- smoothed covariance: $P_t^s = P_t + G_t(P_{t+1}^s - P_{t+1}^-)G_t^T$

In [None]:
#| export
def _smooth_update(trans_matrix,                 # [n_dim_state, n_dim_state]
                   filt_state: MNormal,          # filtered state at time `t` (batched)
                   pred_state: MNormal,          # state before filtering at time `t + 1` (= using the observation until time t)
                   next_smoothed_state: MNormal, # smoothed state at time  `t+1`
                   cov_checker = None            # defaults to a new `CheckPosDef()`
                   ) -> MNormal:                 # mean and cov of smoothed state at time `t`
    """Correct a pred state with a Kalman Smoother update.

    $G_t = P_t A^T (P^-_{t+1})^{-1}$, $m^s_t = m_t + G_t(m^s_{t+1} - m^-_{t+1})$,
    $P^s_t = P_t + G_t(P^s_{t+1} - P^-_{t+1})G_t^T$.
    """
    if cov_checker is None: cov_checker = CheckPosDef()
    kalman_smoothing_gain = filt_state.cov @ trans_matrix.unsqueeze(0).mT @ torch.inverse(pred_state.cov) # torch.cholesky_inverse(torch.linalg.cholesky(pred_state.cov))

    smoothed_state_mean = filt_state.mean + kalman_smoothing_gain @ (next_smoothed_state.mean - pred_state.mean)
    smoothed_state_cov = filt_state.cov + kalman_smoothing_gain @ (next_smoothed_state.cov - pred_state.cov) @ kalman_smoothing_gain.mT

    cov_checker.check(smoothed_state_cov, caller='smooth_update')
    
    return MNormal(smoothed_state_mean, smoothed_state_cov)

In [None]:
filt_state, pred_state, next_smoothed_state = [MNormal(pred_state_mean, pred_state_cov)] * 3 # just for testing

In [None]:
show_as_row(*_smooth_update(trans_matrix, MNormal(pred_state_mean, pred_state_cov), MNormal(pred_state_mean, pred_state_cov), MNormal(pred_state_mean, pred_state_cov)))

In [None]:
show_as_row(*map(Self.shape(), _smooth_update(trans_matrix, MNormal(pred_state_mean, pred_state_cov), MNormal(pred_state_mean, pred_state_cov), MNormal(pred_state_mean, pred_state_cov))))

#### Smooth

In [None]:
#| export
def _smooth(trans_matrix, # `[n_dim_state, n_dim_state]`
            filt_state: ListMNormal, # `[bs, n_timesteps, n_dim_state, 1]` means
                # `filt_state_means[t]` is the state estimate for time t given obs from times `[0...t]`
            pred_state: ListMNormal, # `[bs, n_timesteps, n_dim_state, 1]` means
                # `pred_state_means[t]` is the state estimate for time t given obs from times `[0...t-1]`
            cov_checker = None # defaults to a new `CheckPosDef()`; each step receives `cov_checker.add_args(t=t)`
           ) -> ListMNormal: # `[bs, n_timesteps, n_dim_state, 1]` means of the smoothed state
    """Apply the Kalman Smoother"""
    if cov_checker is None: cov_checker = CheckPosDef()
    x = pred_state.mean # sample for getting tensor properties
    bs, n_timesteps, n_dim_state = x.shape[:3]

    smoothed_state = ListMNormal(torch.zeros((bs, n_timesteps, n_dim_state, 1),           dtype=x.dtype, device=x.device), 
                                 torch.zeros((bs, n_timesteps, n_dim_state, n_dim_state), dtype=x.dtype, device=x.device))
    # For the last timestep cannot use the smoother
    smoothed_state.mean[:, -1] = filt_state.mean[:, -1]
    smoothed_state.cov[:, -1] = filt_state.cov[:, -1]

    # backward pass: each step needs the smoothed state of the following timestep
    for t in reversed(range(n_timesteps - 1)):
        (smoothed_state.mean[:, t], smoothed_state.cov[:, t]) = (
            _smooth_update(
                trans_matrix,
                filt_state[:, t],
                pred_state[:, t + 1],
                smoothed_state[:, t + 1],
                cov_checker=cov_checker.add_args(t=t),
            )
        )
    return smoothed_state

In [None]:
(pred_state_means, pred_state_covs, filt_state_means, filt_state_covs ) = k._filter_all(data)
filt_state, pred_state = ListMNormal(filt_state_means, filt_state_covs), ListMNormal(pred_state_means, pred_state_covs)

In [None]:
smooth_state = _smooth(k.trans_matrix,  filt_state, pred_state)

In [None]:
show_as_row(smooth_state.mean[0][0], smooth_state.cov[0][0])

In [None]:
show_as_row(smooth_state.mean.shape, smooth_state.cov.shape)

#### KalmanFilter method

In [None]:
#| export
@patch
def smooth(self: KalmanFilter,
           obs: Tensor,
           mask: Tensor = None,
          ) -> ListMNormal: # `[bs, n_timesteps, n_dim_state]` smoothed state
    """Run the Kalman filter forward, then the smoother backward, and return the smoothed state"""
    pred_means, pred_covs, filt_means, filt_covs = self._filter_all(obs, mask)

    filtered = ListMNormal(filt_means, filt_covs)
    predicted = ListMNormal(pred_means, pred_covs)
    smoothed = _smooth(self.trans_matrix, filtered, predicted, self.cov_checker)

    # means are column vectors `[..., n_dim_state, 1]`: drop the trailing dim in place
    smoothed.mean.squeeze_(-1)
    return smoothed

In [None]:
smoothed_state = k.smooth(data)

In [None]:
show_as_row(smoothed_state.mean.shape, smoothed_state.cov.shape)

### Predict

In order to have conditional predictions that make sense it's not possible to return the full covariance matrix for the predictions but only the standard deviations

This adds support for conditional predictions: at the time (t) when we make the predictions, some of the variables have actually been observed. Since the model prediction is a normal distribution, we can condition on the observed values and thus improve the predictions. See `conditional_gaussian`.

In [None]:
test_m = torch.tensor(
    [[True, True, True,],
    [False, True, True],
    [False, False, False]]
)

In [None]:
torch.logical_xor(test_m.all(-1), test_m.any(-1))

tensor([False,  True, False])

In [None]:
A = torch.rand(2,2,3,3)

In [None]:
(A @ A).shape

torch.Size([2, 2, 3, 3])

predict can be vectorized across both the batch and the timesteps, except for timesteps that require conditional predictions

In [None]:
#| export
@patch
def _obs_from_state(self: KalmanFilter, state: ListMNormal):
    """Distribution of the observations implied by a state distribution.

    mean $Hm + d$, covariance $HPH^T + R$. `state.mean` is `[..., n_dim_state]` and
    `state.cov` is `[..., n_dim_state, n_dim_state]`; returns a `ListMNormal` with mean
    `[..., n_dim_obs]` and cov `[..., n_dim_obs, n_dim_obs]`.
    """
    # row-vector form keeps the offset `[n_dim_obs]` aligned with the last dim;
    # `H @ m[..., None] + d` would broadcast `[..., n_dim_obs, 1] + [n_dim_obs]` into a square matrix
    mean = state.mean @ self.obs_matrix.mT + self.obs_off
    cov = self.obs_matrix @ state.cov @ self.obs_matrix.mT + self.obs_cov
    
    self.cov_checker.check(cov, caller='predict')
    
    return ListMNormal(mean, cov)

In [None]:
pred_obs0 = k._obs_from_state(smoothed_state)
pred_obs0.mean.shape

torch.Size([2, 10, 3])

In [None]:
pred_obs0.cov.shape

torch.Size([2, 10, 3, 3])

In [None]:
#| export
@patch
def predict(self: KalmanFilter, obs, mask=None, smooth=True):
    """Predicted observations at all times.

    Uses the smoothed state (or the filtered state when `smooth=False`). At time steps
    where only some variables are observed, the prediction for the missing variables is
    conditioned on the observed ones (`cond_gaussian_batched`).
    Returns a `ListNormal` with the predicted mean and standard deviation of every variable.
    """
    state = self.smooth(obs, mask) if smooth else self.filter(obs, mask)
    obs, mask = self._parse_obs(obs, mask)
    
    pred_obs = self._obs_from_state(state)
    # conditional predictions are slow, do only if some (but not all) obs are missing
    cond_mask = torch.logical_xor(mask.all(-1), mask.any(-1))
    
    # this cannot be batched so returns a list, in the row-major order of `cond_mask`
    cond_preds = cond_gaussian_batched(
        pred_obs[cond_mask], obs[cond_mask], mask[cond_mask])
    
    # clone: the in-place writes below must not modify tensors that autograd saved for backward
    pred_mean = pred_obs.mean.clone()
    pred_std = cov2std(pred_obs.cov).clone()
    
    # `pred_mean[cond_mask]` is a copy, so writing into it would be lost: use explicit indices
    batch_idx, time_idx = cond_mask.nonzero(as_tuple=True)
    for b, t, c_pred in zip(batch_idx.tolist(), time_idx.tolist(), cond_preds):
        missing = ~mask[b, t]
        pred_mean[b, t, missing] = c_pred.mean
        pred_std[b, t, missing] = cov2std(c_pred.cov)
    
    return ListNormal(pred_mean, pred_std)

In [None]:
pred = k.predict(data)

In [None]:
pred.mean.shape, pred.std.shape

(torch.Size([2, 10, 3]), torch.Size([2, 10, 3]))

In [None]:
state = k.smooth(data)

In [None]:
k.obs_matrix @ state.mean

RuntimeError: mat1 and mat2 shapes cannot be multiplied (8x10 and 4x3)

In [None]:
state.mean[0,0].shape

torch.Size([4])

In [None]:
state.mean.shape

torch.Size([2, 10, 4])

In [None]:
k.obs_matrix @ state.mean[0,0].unsqueeze(-1)

tensor([[0.6712],
        [0.6381],
        [0.2612]], dtype=torch.float64, grad_fn=<MmBackward0>)

In [None]:
k.obs_matrix @ state.mean[0,0]

tensor([0.6712, 0.6381, 0.2612], dtype=torch.float64, grad_fn=<MvBackward0>)

In [None]:
(k.obs_matrix @ state.mean.unsqueeze(-1)).shape

torch.Size([2, 10, 3, 1])

In [None]:
state.mean[0,0]

(torch.Size([3, 4]),
 tensor([[0.1567],
         [0.2523],
         [0.4789],
         [0.2779]], dtype=torch.float64, grad_fn=<UnsqueezeBackward0>))

In [None]:
k._obs_from_state(state)

ListMultiNormal(mean=tensor([[[ 6.7115e-01,  6.3813e-01,  2.6123e-01],
         [-2.1993e-01,  1.3188e-01,  4.4874e-01],
         [-1.3071e-01,  1.3131e-01,  6.4592e-02],
         [-1.5902e-01,  3.6143e-01, -2.9935e-02],
         [-2.4394e-01,  3.0209e-01, -3.5254e-04],
         [ 1.0110e-01,  5.5177e-01,  4.0669e-02],
         [-2.1634e-01,  5.8063e-01, -1.0452e-01],
         [-5.9903e-01,  9.2058e-02,  1.1058e-01],
         [-3.3310e-01, -3.8377e-01,  1.7389e-01],
         [-6.7295e-01, -3.5709e-01, -3.4562e-01]],

        [[-8.1683e-01, -4.6740e-01,  2.4914e-02],
         [-1.1661e-02, -1.1949e-01, -1.6048e-01],
         [ 3.0735e-01,  7.6889e-01,  9.2516e-02],
         [-6.3984e-01, -2.1339e-01,  1.1235e-01],
         [-5.3536e-01, -2.9153e-01, -2.0968e-01],
         [-1.6897e-01, -4.0205e-02, -1.4994e-01],
         [ 1.2525e-01,  5.1477e-01, -8.5688e-02],
         [-4.2193e-01,  7.8435e-02,  3.6673e-01],
         [-3.2254e-01, -4.1101e-01,  4.7112e-02],
         [-5.2914e-01,  1.2

In [None]:
k.smooth(data).mean.shape

torch.Size([2, 10, 4])

In [None]:
k.smooth(data).cov.shape

torch.Size([2, 10, 4, 4])

In [None]:
pred.mean.shape

torch.Size([2, 10, 3])

In [None]:
pred.std.shape

torch.Size([2, 10, 3])

In [None]:
k.predict(data).mean.sum().backward(retain_graph=True)

In [None]:
k.obs_cov_raw.grad

tensor([[ 37.0426,  27.7586,  14.0329],
        [-31.4088, -43.6932, -10.6448],
        [ 82.8144, -59.2140,  17.9308]], dtype=torch.float64)

In [None]:
k.trans_matrix.grad

tensor([[ -6.2729,  16.4753,   2.7112,   5.3477],
        [ 16.0184, -21.5304,   1.1225,   0.4614],
        [  2.7967,  -6.8813,   1.6863,   1.5134],
        [  0.9788, -15.3344,   1.1676,  -3.5403]], dtype=torch.float64)

Gradients ...

In [None]:
data[~mask] = 0

In [None]:
data

tensor([[[0.9847, 0.0852, 0.5334],
         [0.0000, 0.2617, 0.7972],
         [0.2088, 0.4545, 0.1455],
         [0.0000, 0.0000, 0.2881],
         [0.0000, 0.9087, 0.0000],
         [0.5610, 0.9079, 0.2507],
         [0.0000, 0.7851, 0.0212],
         [0.0000, 0.6513, 0.3955],
         [0.8111, 0.2558, 0.7570],
         [0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.2511, 0.4720],
         [0.6684, 0.0000, 0.1489],
         [0.6714, 0.4719, 0.5053],
         [0.0000, 0.7793, 0.3246],
         [0.0000, 0.0000, 0.0000],
         [0.8191, 0.7040, 0.3264],
         [0.0842, 0.0000, 0.0000],
         [0.0000, 0.3308, 0.7610],
         [0.3228, 0.0961, 0.3075],
         [0.0947, 0.4745, 0.0000]]])

In [None]:
k.predict(data, mask).mean.sum().backward(retain_graph=True)

print(k.obs_cov_raw.grad)

k.zero_grad()

tensor([[ 45.0106,  49.0797,  19.3835],
        [-37.1283, -66.3086, -15.8536],
        [ 92.5008, -79.6189,  18.7456]], dtype=torch.float64)


In [None]:
@patch
def predict_times(self: KalmanFilter, times, obs, mask=None, smooth=True, check_args=None):
    """Predicted observations at specific times.

    `times` (an int or a sequence of ints) selects time steps of the batched predictions
    returned by `predict`. Returns a `ListNormal` whose mean and std keep the batch dim
    first and have `len(times)` time steps.
    `check_args` is accepted for backward compatibility and ignored: covariance checks are
    configured through `self.cov_checker`.

    Raises `ValueError` if any time is outside `[0, n_timesteps - 1]`.
    """
    pred = self.predict(obs, mask, smooth)
    n_timesteps = pred.mean.shape[1]
    times = torch.as_tensor(times, dtype=torch.long, device=pred.mean.device).reshape(-1)

    if times.numel() and (times.max() >= n_timesteps or times.min() < 0):
        raise ValueError(f"provided times range from {times.min().item()} to {times.max().item()}, which is outside allowed range : 0 to {n_timesteps - 1}")

    return ListNormal(pred.mean[:, times], pred.std[:, times])

### Get Info

In [None]:
k.obs_matrix

Parameter containing:
tensor([[0.2592, 0.9640, 0.3593, 0.7745],
        [0.9641, 0.6363, 0.6401, 0.0719],
        [0.1985, 0.0254, 0.0740, 0.6775]], dtype=torch.float64,
       requires_grad=True)

In [None]:
#| export
@patch
def get_info(self: KalmanFilter, var_names=None):
    """Model parameters as labelled DataFrames (via `array2df`), keyed by parameter name.

    `var_names` labels the observed variables (default `x_0, x_1, ...`); latent state
    dimensions are labelled `z_0, z_1, ...`.
    """
    n_vars, n_latent = self.obs_matrix.shape[0], self.trans_matrix.shape[0]
    var_names = ifnone(var_names, [f"x_{i}" for i in range(n_vars)])
    latent_names = [f"z_{i}" for i in range(n_latent)]
    # (output key, parameter, row labels, column labels, index header)
    specs = [
        ('trans_matrix (A)', self.trans_matrix,    latent_names, latent_names, 'latent'),
        ('trans_cov (Q)',    self.trans_cov,       latent_names, latent_names, 'latent'),
        ('trans_off',        self.trans_off,       latent_names, ['offset'],   'latent'),
        ('obs_matrix (H)',   self.obs_matrix,      var_names,    latent_names, 'variable'),
        ('obs_cov (R)',      self.obs_cov,         var_names,    var_names,    'variable'),
        ('obs_off',          self.obs_off,         var_names,    ['offset'],   'variable'),
        ('init_state_mean',  self.init_state_mean, latent_names, ['mean'],     'latent'),
        ('init_state_cov',   self.init_state_cov,  latent_names, latent_names, 'latent'),
    ]
    return {key: array2df(value, rows, cols, header) for key, value, rows, cols, header in specs}

In [None]:
display_as_row(k.get_info())

latent,z_0,z_1,z_2,z_3
z_0,0.9959,0.6486,0.4152,0.8179
z_1,0.1256,0.5073,0.0909,0.9389
z_2,0.3924,0.7097,0.1217,0.2662
z_3,0.308,0.4001,0.9252,0.2889

latent,z_0,z_1,z_2,z_3
z_0,1.2095,1.3493,0.9478,1.0216
z_1,1.3493,1.811,0.8551,1.1573
z_2,0.9478,0.8551,1.2,1.1305
z_3,1.0216,1.1573,1.1305,1.3299

latent,offset
z_0,0.6283
z_1,0.3925
z_2,0.7199
z_3,0.8452

variable,z_0,z_1,z_2,z_3
x_0,0.2592,0.964,0.3593,0.7745
x_1,0.9641,0.6363,0.6401,0.0719
x_2,0.1985,0.0254,0.074,0.6775

variable,x_0,x_1,x_2
x_0,1.2377,1.0162,0.4392
x_1,1.0162,1.1789,0.2616
x_2,0.4392,0.2616,0.2088

variable,offset
x_0,0.9729
x_1,0.3114
x_2,0.5707

latent,mean
z_0,0.254
z_1,0.2803
z_2,0.2163
z_3,0.7995

latent,z_0,z_1,z_2,z_3
z_0,2.6027,1.115,1.6544,0.9013
z_1,1.115,1.2677,0.7199,1.0417
z_2,1.6544,0.7199,1.1375,0.4543
z_3,0.9013,1.0417,0.4543,1.0545


## Constructor Additional

#### Simple parameters

In [None]:
#| export
@patch(cls_method=True)
def init_simple(cls: KalmanFilter,
                n_dim, # n_dim_obs and n_dim_state
                dtype=torch.float32):
    """Kalman filter with the simplest parameters: identity matrices and covariances, zero offsets and initial mean"""
    # fresh tensors for every parameter: `nn.Parameter` shares storage with the tensor it wraps
    def eye():
        return torch.eye(n_dim, dtype=dtype)

    def zeros():
        return torch.zeros(n_dim, dtype=dtype)

    return cls(
        trans_matrix=eye(),
        trans_off=zeros(),
        trans_cov=eye(),
        obs_matrix=eye(),
        obs_off=zeros(),
        obs_cov=eye(),
        init_state_mean=zeros(),
        init_state_cov=eye(),
    )

In [None]:
KalmanFilter.init_simple(2).state_dict()

OrderedDict([('trans_matrix',
              tensor([[1., 0.],
                      [0., 1.]])),
             ('trans_off', tensor([0., 0.])),
             ('trans_cov_raw',
              tensor([[1., 0.],
                      [0., 1.]])),
             ('obs_matrix',
              tensor([[1., 0.],
                      [0., 1.]])),
             ('obs_off', tensor([0., 0.])),
             ('obs_cov_raw',
              tensor([[1., 0.],
                      [0., 1.]])),
             ('init_state_mean', tensor([0., 0.])),
             ('init_state_cov_raw',
              tensor([[1., 0.],
                      [0., 1.]]))])

#### Local slope

In [None]:
#| export
@patch(cls_method=True)
def init_local_slope(cls: KalmanFilter,
                n_dim, # n_dim_obs and n_dim_state
                dtype=torch.float32):
    """Kalman filter for a local slope model.

    NOTE(review): this currently builds exactly the same parameters as `init_simple`
    (identity matrices and covariances, zero offsets and initial mean); a dedicated local
    slope parametrization appears to be still TODO.
    """
    # fresh tensors for every parameter: `nn.Parameter` shares storage with the tensor it wraps
    def eye():
        return torch.eye(n_dim, dtype=dtype)

    def zeros():
        return torch.zeros(n_dim, dtype=dtype)

    return cls(
        trans_matrix=eye(),
        trans_off=zeros(),
        trans_cov=eye(),
        obs_matrix=eye(),
        obs_off=zeros(),
        obs_cov=eye(),
        init_state_mean=zeros(),
        init_state_cov=eye(),
    )

In [None]:
KalmanFilter.init_simple(2).state_dict()

OrderedDict([('trans_matrix',
              tensor([[1., 0.],
                      [0., 1.]])),
             ('trans_off', tensor([0., 0.])),
             ('trans_cov_raw',
              tensor([[1., 0.],
                      [0., 1.]])),
             ('obs_matrix',
              tensor([[1., 0.],
                      [0., 1.]])),
             ('obs_off', tensor([0., 0.])),
             ('obs_cov_raw',
              tensor([[1., 0.],
                      [0., 1.]])),
             ('init_state_mean', tensor([0., 0.])),
             ('init_state_cov_raw',
              tensor([[1., 0.],
                      [0., 1.]]))])

## Export

In [None]:
#| hide
from nbdev import nbdev_export
# nbdev_export()