# Fixed-lag smoothing

In [4]:
import jax
import chex
import jax.numpy as jnp
import matplotlib.pyplot as plt

In [5]:
%config InlineBackend.figure_format = "retina"

## The equations for fixed-lag smoothing
$$
\begin{aligned}
    \theta_{t|t+k} &= \theta_{t|t} + \sum_{s=1}^k {\bf K}_{t,t+s}\,\varepsilon_{t+s}\\
    \Sigma_{t|t+k} &= \Sigma_{t|t} - \sum_{s=1}^k {\bf K}_{t,t+s}\,{\bf S}_{t+s}\,{\bf K}_{t,t+s}^\intercal\\
    {\bf K}_{t,t+s} &= {\rm Cov}(\theta_t, \varepsilon_{t+s})\,{\bf S}_{t+s}^{-1}
\end{aligned}
$$

Most of the computation goes into evaluating ${\rm Cov}(\theta_t, \varepsilon_{t+s})$, which takes the general form
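$$
\begin{aligned}
    {\rm Cov}(\theta_t, \varepsilon_{t+s}) &= \Sigma_{t|t-1}\,\overrightarrow{\bf M}_{t:t+s-1}^\intercal\,{\bf H}_{t+s}^\intercal\\
    \overrightarrow{\bf M}_{t:t+s-1} &= {\bf M}_{t+s-1}\,{\bf M}_{t+s-2}\cdots{\bf M}_{t}\\
    {\bf M}_t &= {\bf F}_t\,\left({\bf I} - \Sigma_{t|t-1}\,{\bf H}_t^\intercal\,{\bf S}_{t}^{-1}\,{\bf H}_{t}\right)
\end{aligned}
$$

Here ${\bf F}_t$ maps $\theta_t$ to $\theta_{t+1}$, so ${\bf M}_t$ propagates the one-step prediction error: $\theta_{t+1} - \theta_{t+1|t} = {\bf M}_t\,(\theta_t - \theta_{t|t-1}) + \text{noise}$, where the noise term is independent of $\theta_t - \theta_{t|t-1}$.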

In [6]:
def kf_step(bel, y, H, F, R, Q, n_forecast):
    """Run one predict/update cycle of a linear-Gaussian Kalman filter.

    Parameters
    ----------
    bel : tuple
        ``(mu, Sigma)``: the filtered mean and covariance from the previous step.
    y : array
        Observation for the current step.
    H : array
        Observation matrix, mapping the state to observation space.
    F : array
        State transition matrix.
    R : array
        Observation-noise covariance.
    Q : array
        Process-noise covariance.
    n_forecast
        Passed through unchanged to ``forecast_obs_mean``.

    Returns
    -------
    bel_next : tuple
        ``(mu, Sigma)`` after conditioning on ``y``.
    out : dict
        ``mu`` / ``Sigma``: the updated belief;
        ``err``: the innovation ``y - H @ mu_pred``;
        ``yhat``: the one-step-ahead predicted observation ``H @ mu_pred``;
        ``y_filter``: the observation implied by the updated mean;
        ``y_forecast``: the result of ``forecast_obs_mean``.
    """
    mean_prev, cov_prev = bel

    # --- Prediction: push the previous belief through the linear dynamics.
    mean_prior = F @ mean_prev
    cov_prior = F @ cov_prev @ F.T + Q

    # --- Innovation and its covariance.
    obs_prior = H @ mean_prior  # one-step-ahead predicted observation
    innov = y - obs_prior
    innov_cov = H @ cov_prior @ H.T + R
    # Gain K = cov_prior H^T S^{-1}, obtained with a linear solve rather than
    # an explicit inverse; the transpose gives K exactly when S and
    # cov_prior are symmetric, as covariances should be.
    gain = jnp.linalg.solve(innov_cov, H @ cov_prior).T

    # --- Update: condition the prior on the observation y.
    mean_post = mean_prior + gain @ innov
    cov_post = cov_prior - gain @ innov_cov @ gain.T

    diagnostics = dict(
        mu=mean_post,
        Sigma=cov_post,
        err=innov,
        yhat=obs_prior,
        y_filter=H @ mean_post,
        y_forecast=forecast_obs_mean(F, H, mean_post, n_forecast),
    )
    return (mean_post, cov_post), diagnostics