# Fixed-lag smoothing

Some things are obvious in hindsight

In [19]:
import jax
import chex
import jax.numpy as jnp
import seaborn as sns
import matplotlib.pyplot as plt
from functools import partial

In [23]:
%config InlineBackend.figure_format = "retina"

sns.set_palette("colorblind")
plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
plt.rcParams["font.size"] = 12
# plt.rcParams["figure.figsize"] = (7.2, 4.0)
plt.rcParams["figure.figsize"] = (7.2, 3.0)

jnp.set_printoptions(linewidth=200)

## The equations for fixed-lag smoothing
$$
\begin{aligned}
    \theta_{t|t+k} &= \theta_{t|t} + \sum_{s=1}^k {\bf K}_{t,t+s}\,\varepsilon_{t+s}\\
    \Sigma_{t|t+k} &= \Sigma_{t|t} - \sum_{s=1}^k {\bf K}_{t,t+s}\,{\bf S}_{t+s}\,{\bf K}_{t,t+s}^\intercal\\
\end{aligned}
$$

with
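$$
\begin{aligned}
    \varepsilon_{t+s} &= {\bf y}_{t+s} - {\bf H}_{t+s}\,\theta_{t+s|t+s-1}\\
    {\bf K}_{t,t+s} &= {\rm Cov}(\theta_t, \varepsilon_{t+s})\,{\bf S}_{t+s}^{-1}\\
    {\bf S}_{t+s} &= {\bf H}_{t+s}\,\Sigma_{t+s|t+s-1}\,{\bf H}_{t+s}^\intercal + {\bf R}_{t+s}\\
    {\rm Cov}(\theta_t, \varepsilon_{t+s}) &= \Sigma_{t|t-1}\,\overrightarrow{\bf M}_{t:t+s-1}^\intercal\,{\bf H}_{t+s}^\intercal\\
    \overrightarrow{\bf M}_{t:t+s-1} &= {\bf M}_{t+s-1}\,{\bf M}_{t+s-2}\cdots{\bf M}_{t}\\
    {\bf M}_t &= {\bf F}_{t+1}\,\left({\bf I} - \Sigma_{t|t-1}\,{\bf H}_t^\intercal\,{\bf S}_{t}^{-1}\,{\bf H}_{t}\right)
\end{aligned}
$$

Here ${\bf F}_{t+1}$ maps $\theta_t$ to $\theta_{t+1}$, and the covariance is taken conditional on ${\bf y}_{1:t}$. For $s=1$ this reduces to ${\rm Cov}(\theta_t, \varepsilon_{t+1}) = \Sigma_{t|t}\,{\bf F}_{t+1}^\intercal\,{\bf H}_{t+1}^\intercal$.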

In [24]:
def kf_step(bel, y, H, F, R, Q):
    """Run one predict-then-update step of the linear Kalman filter.

    Shaped as a ``jax.lax.scan``-style step: it takes the carried belief and
    returns ``(new_belief, outputs)``.

    Parameters
    ----------
    bel : tuple
        ``(mu, Sigma)``, the filtered mean and covariance from the previous step.
    y : array
        Observation at the current step.
    H, F, R, Q : arrays
        Observation matrix, transition matrix, observation-noise covariance
        and process-noise covariance.

    Returns
    -------
    bel_next : tuple
        ``(mu_update, Sigma_update)``, the new filtered mean and covariance.
    out : dict
        ``"mu"`` and ``"Sigma"`` (filtered moments), ``"err"`` (innovation
        ``y - H @ mu_pred``), ``"yhat"`` (one-step-ahead prediction
        ``H @ mu_pred``) and ``"y_filter"`` (``H @ mu_update``).
    """
    mean_prev, cov_prev = bel

    # Time update: push the belief through the dynamics.
    mean_prior = F @ mean_prev
    cov_prior = F @ cov_prev @ F.T + Q

    # Predicted observation, innovation and innovation covariance S.
    y_pred = H @ mean_prior
    innovation = y - y_pred
    cov_innov = H @ cov_prior @ H.T + R

    # Gain K = Sigma_pred H^T S^{-1} via a linear solve instead of an explicit
    # inverse; the transpose trick assumes Sigma_pred and S are symmetric.
    gain = jnp.linalg.solve(cov_innov, H @ cov_prior).T

    # Measurement update.
    mean_post = mean_prior + gain @ innovation
    cov_post = cov_prior - gain @ cov_innov @ gain.T

    outputs = {
        "mu": mean_post,
        "Sigma": cov_post,
        "err": innovation,
        "yhat": y_pred,
        "y_filter": H @ mean_post,
    }
    return (mean_post, cov_post), outputs

## Tracking problem