# Probabilistic Machine Learning: Lecture 20 - Gauss-Markov Models

#### Introduction

Welcome to Lecture 20 of Probabilistic Machine Learning! This lecture introduces **Gauss-Markov Models**, a fundamental class of probabilistic models used for time series analysis. We will explore how conditional independence assumptions lead to computationally efficient inference algorithms like the Kalman Filter and the Rauch-Tung-Striebel Smoother. These algorithms are crucial for processing data that arrives as a stream, enabling real-time predictions and retrospective analysis.

This notebook will guide you through the theoretical concepts and provide practical implementations using **JAX** for numerical computations and **Plotly** for interactive visualizations.

#### 1. Goal for this Week: Time Series as a Problem Class

The overarching goal for this week is to understand **Time Series as a problem class** (Slide 2). We'll break this down into several layers:

* **Conceptual Layer**: Conditional Independence affects computational complexity of inference.
* **Application Layer**: Data arriving as a stream.
* **Model Structure Layer**: Markov Chains / Hidden Markov Models.
* **Concrete Model Layer**: Gauss-Markov Models.
* **Algorithm Layer**: Kalman Filter & RTS Smoother.

We'll also briefly touch upon generalizations to non-Gaussian models (with approximations) and the theoretical layer of Stochastic Differential Equations.

#### 2. Conditional Independence and Computational Complexity of Inference

As we've seen throughout this course (and recapped in Lecture 2, Slides 3-4), conditional independence plays a crucial role in determining the computational complexity of inference. For complex probabilistic models, exploiting independence structures can drastically reduce the computational burden.

* **Parametric Models (Graphical View)**: In parametric models, data points are conditionally independent given the model weights. This leads to inference complexity often scaling as $\mathcal{O}(NF^2 + F^3)$ for $F$ parameters and $N$ data points (Slide 5).
* **Nonparametric Models (Graphical View)**: In contrast, nonparametric models like Gaussian Processes typically have no finite sufficient statistic, meaning all data points directly interact. This results in inference complexity scaling as $\mathcal{O}(N^3)$ (Slide 6), which becomes prohibitive for large $N$.

This distinction highlights the challenge when dealing with time series, where $N$ can grow indefinitely.
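The two cost profiles can be compared with a minimal sketch of Bayesian linear regression with $F$ features. It assumes a standard-normal weight prior and a known noise variance, both chosen only for illustration. The parametric (weight-space) computation solves an $F \times F$ system. The same model viewed as a Gaussian process with kernel $k(x, x') = \phi(x)^\top \phi(x')$ solves an $N \times N$ system instead. Both routes give the same predictive mean, but only the first stays cheap as $N$ grows. A genuinely nonparametric kernel, which has no finite feature map, leaves only the second route.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

N, F = 200, 5  # many data points, few features
key_x, key_w, key_n = jax.random.split(jax.random.PRNGKey(0), 3)
Phi = jax.random.normal(key_x, (N, F))  # row i is the feature vector phi(x_i)
w_true = jax.random.normal(key_w, (F,))
sigma2 = 0.1  # assumed known noise variance
y = Phi @ w_true + jnp.sqrt(sigma2) * jax.random.normal(key_n, (N,))
Phi_test = jax.random.normal(jax.random.PRNGKey(1), (3, F))  # features of 3 test inputs

# Weight space: posterior over w under the prior N(0, I); an F x F solve, O(N F^2 + F^3)
precision_w = Phi.T @ Phi / sigma2 + jnp.eye(F)
mean_w = jnp.linalg.solve(precision_w, Phi.T @ y / sigma2)
pred_weight_space = Phi_test @ mean_w

# Function space: GP with kernel matrix K = Phi Phi^T; an N x N solve, O(N^3)
K = Phi @ Phi.T
coeffs = jnp.linalg.solve(K + sigma2 * jnp.eye(N), y)
pred_function_space = (Phi_test @ Phi.T) @ coeffs
```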

#### 3. Time Series: $\mathcal{O}(N)$ Inference is Indispensable

When data arrives as a stream, we require inference that scales linearly with the number of observations, i.e., $\mathcal{O}(N)$ (Slide 7). This is the defining characteristic of a **time series** (Slide 8).

**Definition**: A time series is a sequence $[y(t_i)]_{i \in \mathbb{N}}$ of observations $y_i := y(t_i) \in \mathbb{Y}$, indexed by a scalar variable $t \in \mathbb{R}$. In many applications, the time points $t_i$ are equally spaced: $t_i = t_0 + i \cdot \delta_t$.

Examples of time series are ubiquitous:
* Climate & weather observations
* Sensor readings in cars
* EEG, ECG, patch clamp signals
* Stock prices, supply & demand data

Inference in time series often needs to happen in real-time and scale to an unbounded set of data, typically on small-scale or embedded systems. This necessitates (low) constant time and memory complexity per time step.

#### 4. Markov Chains: Processes with a "Local Memory"

To achieve $\mathcal{O}(N)$ inference, we need a model with "local memory" that is "passed forward" through time. This concept is formalized by **Markov Chains**, and associated models are called **State-Space Models** (Slide 9).

**Definition**: A joint distribution $p(X)$ over a sequence of random variables $X := [x_0, \dots, x_N]$ is said to have the **Markov property** if:
$$p(x_i | x_0, x_1, \dots, x_{i-1}) = p(x_i | x_{i-1})$$

The sequence is then called a Markov chain. Given the immediate past state $x_{i-1}$, the state $x_i$ is conditionally independent of the earlier history $x_0, \dots, x_{i-2}$, so the past influences the future only through the present. This conditional independence structure is key to efficient inference.
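For a discrete chain, the Markov property means the marginal can be pushed forward one step at a time with $p(x_i) = \sum_{x_{i-1}} p(x_i \mid x_{i-1})\, p(x_{i-1})$. Each step then costs the same, no matter how long the chain already is. The following is a minimal sketch that assumes a hypothetical 3-state transition matrix:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical 3-state chain: T_mat[j, k] = p(x_i = k | x_{i-1} = j)
T_mat = jnp.array([
    [0.90, 0.08, 0.02],
    [0.10, 0.80, 0.10],
    [0.05, 0.15, 0.80],
])
p0 = jnp.array([1.0, 0.0, 0.0])  # start in state 0


def propagate(p_prev, _):
    # One Markov step: only the current marginal p(x_{i-1}) is carried forward
    p_next = p_prev @ T_mat
    return p_next, p_next


n_steps = 50
_, marginals = jax.lax.scan(propagate, p0, None, length=n_steps)
```

Only the current marginal is carried between steps. This is the "local memory" that the Markov chain passes forward in time.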

In state-space models, we typically assume:
* **State transition**: $p(x_t | X_{0:t-1}) = p(x_t | x_{t-1})$ (the Markov property for the latent states).
* **Observation likelihood**: $p(y_t | X) = p(y_t | x_t)$ (observations depend only on the current latent state).

This structure allows for inference to be separated into three operations (Slides 15-22):

1.  **Predict (Chapman-Kolmogorov Equation)**: Propagating the belief about the state forward in time.
    $$p(x_t | Y_{0:t-1}) = \int p(x_t | x_{t-1}) p(x_{t-1} | Y_{0:t-1}) dx_{t-1}$$

2.  **Update (Bayes' Theorem)**: Incorporating a new observation $y_t$ to refine the belief about the current state.
    $$p(x_t | Y_{0:t}) = \frac{p(y_t | x_t) p(x_t | Y_{0:t-1})}{p(y_t | Y_{0:t-1})}$$

3.  **Smooth (Backward Pass)**: Refining past state estimates by incorporating future observations.
    $$p(x_t | Y) = p(x_t | Y_{0:t}) \int \frac{p(x_{t+1} | x_t) p(x_{t+1} | Y)}{p(x_{t+1} | Y_{0:t})} dx_{t+1}$$

Both filtering (predict + update) and smoothing can be performed in $\mathcal{O}(T)$ time, where $T$ is the number of time steps.
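The three operations are easiest to check in a discrete-state model, where the integrals become sums. The sketch below assumes a hypothetical 2-state hidden Markov model with binary observations. It runs predict and update forward, then runs the smoothing recursion backward, following the equations above term by term. Finally, it compares the smoothed marginals with brute-force enumeration of all $2^T$ state sequences. That enumeration is feasible only for tiny $T$.

```python
import itertools

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical 2-state HMM with binary observations
Tm = jnp.array([[0.9, 0.1], [0.2, 0.8]])  # Tm[j, k] = p(x_t = k | x_{t-1} = j)
E = jnp.array([[0.8, 0.2], [0.3, 0.7]])  # E[k, y] = p(y_t = y | x_t = k)
p0 = jnp.array([0.5, 0.5])  # p(x_0)
ys = [0, 0, 1, 1, 0, 1]  # observations y_0, ..., y_5


def forward_backward(ys, Tm, E, p0):
    filt, pred = [], []
    for t, y in enumerate(ys):
        # Predict (Chapman-Kolmogorov); at t = 0 the prediction is the prior p(x_0)
        p_pred = p0 if t == 0 else filt[-1] @ Tm
        # Update (Bayes): likelihood times prediction, normalized by p(y_t | Y_{0:t-1})
        unnorm = E[:, y] * p_pred
        filt.append(unnorm / unnorm.sum())
        pred.append(p_pred)
    # Smooth (backward pass), starting from the last filtered marginal
    smooth = [filt[-1]]
    for t in range(len(ys) - 2, -1, -1):
        ratio = smooth[0] / pred[t + 1]  # p(x_{t+1} | Y) / p(x_{t+1} | Y_{0:t})
        smooth.insert(0, filt[t] * (Tm @ ratio))
    return jnp.stack(filt), jnp.stack(smooth)


def brute_force_smoothing(ys, Tm, E, p0):
    """Sum the joint p(X, Y) over all 2^T state sequences (exponential cost)."""
    Tm, E, p0 = Tm.tolist(), E.tolist(), p0.tolist()
    post = [[0.0, 0.0] for _ in ys]
    for xs in itertools.product(range(2), repeat=len(ys)):
        p = p0[xs[0]] * E[xs[0]][ys[0]]
        for t in range(1, len(ys)):
            p *= Tm[xs[t - 1]][xs[t]] * E[xs[t]][ys[t]]
        for t, x in enumerate(xs):
            post[t][x] += p
    return jnp.array([[v / sum(row) for v in row] for row in post])


filtered, smoothed = forward_backward(ys, Tm, E, p0)
reference = brute_force_smoothing(ys, Tm, E, p0)
```

The forward-backward pass costs $\mathcal{O}(T K^2)$ for $K$ states. Enumeration, by contrast, costs $\mathcal{O}(K^T)$.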

In [None]:
import jax.numpy as jnp
import jax
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Set JAX to use 64-bit floats for numerical stability
jax.config.update("jax_enable_x64", True)


# --- Utility Functions ---
def generate_linear_gaussian_data(
    T=100,
    state_dim=1,
    obs_dim=1,
    A=None,
    Q_std=0.1,
    H=None,
    R_std=0.1,
    m0=None,
    P0_std=1.0,
    key=None,
):
    """Generates synthetic data from a linear Gaussian state-space model."""
    if key is None:
        key = jax.random.PRNGKey(0)

    if A is None:
        A = jnp.eye(state_dim) * 0.9  # State transition matrix
    if H is None:
        H = jnp.eye(obs_dim, state_dim)  # Observation matrix

    Q = jnp.eye(state_dim) * Q_std**2  # State noise covariance
    R = jnp.eye(obs_dim) * R_std**2  # Observation noise covariance

    if m0 is None:
        m0 = jnp.zeros(state_dim)  # Initial state mean
    P0 = jnp.eye(state_dim) * P0_std**2  # Initial state covariance

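    # Sample the initial state x_0 ~ N(m0, P0) so the simulated data follow the model's prior
    key, subkey_init = jax.random.split(key)
    states = [jax.random.multivariate_normal(subkey_init, m0, P0)]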
    observations = []

    for t in range(T):
        key, subkey_state, subkey_obs = jax.random.split(key, 3)
        # State transition: x_t = A @ x_{t-1} + w_t, w_t ~ N(0, Q)
        state_noise = jax.random.multivariate_normal(
            subkey_state, jnp.zeros(state_dim), Q
        )
        next_state = A @ states[-1] + state_noise
        states.append(next_state)

        # Observation: y_t = H @ x_t + v_t, v_t ~ N(0, R)
        obs_noise = jax.random.multivariate_normal(subkey_obs, jnp.zeros(obs_dim), R)
        observation = H @ next_state + obs_noise
        observations.append(observation)

    return jnp.array(states[1:]), jnp.array(observations), A, Q, H, R, m0, P0


def plot_kalman_results(
    true_states,
    observations,
    filtered_means,
    filtered_stds,
    smoothed_means=None,
    smoothed_stds=None,
    title="",
    state_idx=0,
):
    """Plots true states, observations, and Kalman filter/smoother results for a single dimension."""
    fig = go.Figure()
    time_steps = jnp.arange(len(true_states))

    # True States
    fig.add_trace(
        go.Scatter(
            x=time_steps,
            y=true_states[:, state_idx],
            mode="lines",
            name="True State",
            line=dict(color="blue", width=2),
        )
    )

    # Observations
    fig.add_trace(
        go.Scatter(
            x=time_steps,
            y=observations[:, state_idx],
            mode="markers",
            name="Observations",
            marker=dict(color="red", size=4, opacity=0.6),
        )
    )

    # Filtered Mean
    fig.add_trace(
        go.Scatter(
            x=time_steps,
            y=filtered_means[:, state_idx],
            mode="lines",
            name="Filtered Mean",
            line=dict(color="green", width=1, dash="dash"),
        )
    )
    # Filtered Uncertainty
    fig.add_trace(
        go.Scatter(
            x=jnp.concatenate([time_steps, time_steps[::-1]]),
            y=jnp.concatenate(
                [
                    filtered_means[:, state_idx] + 2 * filtered_stds[:, state_idx],
                    (filtered_means[:, state_idx] - 2 * filtered_stds[:, state_idx])[
                        ::-1
                    ],
                ]
            ),
            fill="toself",
            fillcolor="rgba(0,255,0,0.1)",
            line_color="rgba(255,255,255,0)",
            name="Filtered 2 Std Dev",
            showlegend=False,
        )
    )

    # Smoothed Mean (if provided)
    if smoothed_means is not None:
        fig.add_trace(
            go.Scatter(
                x=time_steps,
                y=smoothed_means[:, state_idx],
                mode="lines",
                name="Smoothed Mean",
                line=dict(color="purple", width=1),
            )
        )
        # Smoothed Uncertainty
        fig.add_trace(
            go.Scatter(
                x=jnp.concatenate([time_steps, time_steps[::-1]]),
                y=jnp.concatenate(
                    [
                        smoothed_means[:, state_idx] + 2 * smoothed_stds[:, state_idx],
                        (
                            smoothed_means[:, state_idx]
                            - 2 * smoothed_stds[:, state_idx]
                        )[::-1],
                    ]
                ),
                fill="toself",
                fillcolor="rgba(128,0,128,0.1)",
                line_color="rgba(255,255,255,0)",
                name="Smoothed 2 Std Dev",
                showlegend=False,
            )
        )

    fig.update_layout(
        title_text=title,
        title_x=0.5,
        xaxis_title="Time Step",
        yaxis_title="State Value",
    )
    fig.show()


#### 5. Gauss-Markov Models: Linear Gaussian Case

If all relationships are **linear and Gaussian**, then inference (filtering and smoothing) becomes analytic and involves only linear algebra. This is the realm of **Gauss-Markov Models** (Slide 23).

The model is defined by:
* **State transition**: $p(x_t | x_{t-1}) = \mathcal{N}(x_t; A x_{t-1}, Q)$
* **Observation likelihood**: $p(y_t | x_t) = \mathcal{N}(y_t; H x_t, R)$
* **Initial state**: $p(x_0) = \mathcal{N}(x_0; m_0, P_0)$

Here:
* $A$: State transition matrix.
* $Q$: State noise covariance matrix.
* $H$: Observation matrix.
* $R$: Observation noise covariance matrix.
* $m_0, P_0$: Mean and covariance of the initial state.
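One way to see why this model admits $\mathcal{O}(T)$ inference is to write down the joint prior over $x_{0:T}$ for a scalar state. The covariance is dense, but the precision (inverse covariance) is tridiagonal. This is because $x_i$ and $x_j$ with $|i - j| > 1$ are conditionally independent given the states between them. The following is a minimal sketch with illustrative values for $A$, $Q$ and $P_0$:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

A, Q, P0 = 0.9, 0.1, 1.0  # illustrative scalar Gauss-Markov model
T = 8  # states x_0, ..., x_T

idx = jnp.arange(T + 1)
lag = (idx[:, None] - idx[None, :]).astype(jnp.float64)
# x_t = A^t x_0 + sum_{k=1}^t A^(t-k) w_k: column 0 multiplies x_0, column k multiplies w_k
M = jnp.where(lag >= 0, A**lag, 0.0)
D = jnp.diag(jnp.concatenate([jnp.array([P0]), jnp.full(T, Q)]))  # Var(x_0), Var(w_k)

Sigma = M @ D @ M.T  # joint prior covariance of x_0, ..., x_T: dense
Lambda = jnp.linalg.inv(Sigma)  # joint prior precision: tridiagonal
far = jnp.abs(lag) > 1  # entries linking non-neighbouring time steps
```

The nonzero entries match the quadratic form of $-\log p(x_{0:T})$:

* $1/P_0 + A^2/Q$ in the first diagonal entry.
* $(1 + A^2)/Q$ in the interior of the diagonal.
* $1/Q$ in the last diagonal entry.
* $-A/Q$ on the first off-diagonals.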

The corresponding algorithms are:

### The Kalman Filter (Filtering: $\mathcal{O}(T)$)

The Kalman Filter performs the predict and update steps iteratively for each new observation. It maintains a Gaussian belief over the current state.

**Prediction Step (Time Update)** (Slide 24):
* Predictive mean: $m_t^- = A m_{t-1}$
* Predictive covariance: $P_t^- = A P_{t-1} A^T + Q$

**Update Step (Measurement Update)** (Slide 25):
* Innovation residual: $z = y_t - H m_t^-$
* Innovation covariance: $S = H P_t^- H^T + R$
* Kalman gain: $K = P_t^- H^T S^{-1}$
* Updated mean: $m_t = m_t^- + K z$
* Updated covariance: $P_t = (I - K H) P_t^-$

The overall complexity of the filtering pass through $T$ time steps is $\mathcal{O}(T \cdot (|X|^3 + |Y|^3))$, where $|X|$ is state dimension and $|Y|$ is observation dimension (Slide 27).
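Plugging scalar numbers into one predict/update cycle shows what the formulas do. The values below are purely illustrative. The last line rewrites the updated variance in information form, $1/P_t = 1/P_t^- + H^2/R$. In that form, the update simply adds the precision of the prediction to the precision of the measurement.

```python
# Illustrative scalar model and previous filtered belief (assumed values)
A, Q, H, R = 0.9, 0.01, 1.0, 0.25
m_prev, P_prev = 0.0, 1.0
y = 1.0  # new observation

# Prediction step
m_minus = A * m_prev  # 0.0
P_minus = A * P_prev * A + Q  # about 0.82

# Update step
z = y - H * m_minus  # innovation: 1.0
S = H * P_minus * H + R  # innovation variance: about 1.07
K = P_minus * H / S  # Kalman gain: about 0.766
m_t = m_minus + K * z  # updated mean: about 0.766
P_t = (1 - K * H) * P_minus  # updated variance: about 0.192

# Information form: precisions add
P_info = 1.0 / (1.0 / P_minus + H**2 / R)
```

The gain of about 0.77 moves the estimate most of the way toward the measurement. This happens because the predicted variance (0.82) is larger than the measurement noise (0.25).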

In [None]:
# --- Kalman Filter Implementation ---


@jax.jit
def kalman_filter(observations, A, Q, H, R, m0, P0):
    """
    Implements the Kalman Filter for a linear Gaussian state-space model.

    Args:
        observations (jnp.ndarray): Sequence of observations (T, obs_dim).
        A (jnp.ndarray): State transition matrix (state_dim, state_dim).
        Q (jnp.ndarray): State noise covariance (state_dim, state_dim).
        H (jnp.ndarray): Observation matrix (obs_dim, state_dim).
        R (jnp.ndarray): Observation noise covariance (obs_dim, obs_dim).
        m0 (jnp.ndarray): Initial state mean (state_dim,).
        P0 (jnp.ndarray): Initial state covariance (state_dim, state_dim).

    Returns:
        tuple: (filtered_means, filtered_covs, predicted_means, predicted_covs)
    """
    I = jnp.eye(m0.shape[0])

    def step(carry, y_t):
        m_prev, P_prev = carry

        # Prediction Step (Time Update)
        m_minus = A @ m_prev  # Predictive mean
        P_minus = A @ P_prev @ A.T + Q  # Predictive covariance

        # Update Step (Measurement Update)
        z = y_t - H @ m_minus  # Innovation residual
        S = H @ P_minus @ H.T + R  # Innovation covariance
        # Kalman gain K = P^- H^T S^{-1}, via a linear solve rather than an explicit inverse
        K = jnp.linalg.solve(S.T, (P_minus @ H.T).T).T

        m_t = m_minus + K @ z  # Updated mean
        P_t = (I - K @ H) @ P_minus  # Updated covariance

        # Carry the filtered belief forward; record filtered and predicted moments
        return (m_t, P_t), (m_t, P_t, m_minus, P_minus)

    # lax.scan compiles the loop body once instead of unrolling T Python iterations under jit
    _, (filtered_means, filtered_covs, predicted_means, predicted_covs) = jax.lax.scan(
        step, (m0, P0), observations
    )
    return filtered_means, filtered_covs, predicted_means, predicted_covs


## Understanding the Kalman Filter and Its Connection to State Space Models

### What is a State Space Model?

A **state space model** is a mathematical framework for modeling time series data where we assume there is an underlying (possibly unobserved) process, called the **state**, that evolves over time and generates the observed data. The model is defined by two equations:

1. **State Transition (Dynamics):**
    $$
    x_t = A x_{t-1} + w_t, \quad w_t \sim \mathcal{N}(0, Q)
    $$
    - $x_t$: The hidden (latent) state at time $t$.
    - $A$: State transition matrix (how the state evolves).
    - $w_t$: Process noise (randomness in the evolution), Gaussian with covariance $Q$.

2. **Observation (Measurement):**
    $$
    y_t = H x_t + v_t, \quad v_t \sim \mathcal{N}(0, R)
    $$
    - $y_t$: The observed data at time $t$.
    - $H$: Observation matrix (how the state is mapped to observations).
    - $v_t$: Observation noise, Gaussian with covariance $R$.

This structure is called a **Linear Gaussian State Space Model** (or Gauss-Markov Model).

---

### What is the Kalman Filter?

The **Kalman Filter** is an algorithm for **sequentially estimating** the hidden state $x_t$ of a linear Gaussian state space model, given a sequence of noisy observations $y_{1:t}$. It is optimal (in the mean squared error sense) when the model is linear and all noise is Gaussian.

#### Key Ideas:
- **Recursive:** The filter updates its estimate as each new observation arrives, without needing to store all past data.
- **Probabilistic:** It maintains a Gaussian belief (mean and covariance) about the current state.
- **Efficient:** Each step takes constant time and memory, regardless of how many observations have already been processed (for fixed state and observation dimensions).
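The filter only carries $(m_t, P_t)$ from one step to the next, so it can consume an unbounded stream one observation at a time. The following is a minimal sketch that assumes a Python generator as the data source; `stream`, `kf_step` and the parameter values are all hypothetical. For a time-invariant model, the covariance also converges to a fixed point of the predict/update recursion. Running one more step after the loop should therefore leave $P$ unchanged.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@jax.jit
def kf_step(m_prev, P_prev, y, A, Q, H, R):
    """One predict + update step; only the current (mean, covariance) is needed."""
    m_minus = A @ m_prev
    P_minus = A @ P_prev @ A.T + Q
    S = H @ P_minus @ H.T + R
    K = jnp.linalg.solve(S.T, (P_minus @ H.T).T).T  # K = P^- H^T S^{-1}
    m_t = m_minus + K @ (y - H @ m_minus)
    P_t = (jnp.eye(m_prev.shape[0]) - K @ H) @ P_minus
    return m_t, P_t


def stream(n, key):
    """Hypothetical data source yielding one noisy scalar reading at a time."""
    for _ in range(n):
        key, sub = jax.random.split(key)
        yield jax.random.normal(sub, (1,))


A = jnp.array([[0.95]])
Q = jnp.array([[0.05]])
H = jnp.array([[1.0]])
R = jnp.array([[0.5]])
m, P = jnp.zeros(1), jnp.eye(1)
for y in stream(500, jax.random.PRNGKey(0)):
    m, P = kf_step(m, P, y, A, Q, H, R)  # nothing but (m, P) is kept between readings
```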

---

### How Does the Kalman Filter Work?

At each time step, the Kalman filter performs two main operations:

1. **Prediction (Time Update):**
    - Use the previous state estimate to predict the current state **before** seeing the new observation.
    - Equations:
      $$
      m_t^- = A m_{t-1}
      $$
      $$
      P_t^- = A P_{t-1} A^T + Q
      $$
      Where $m_{t-1}$ and $P_{t-1}$ are the mean and covariance of the previous state estimate.

2. **Update (Measurement Update):**
    - Incorporate the new observation $y_t$ to refine the state estimate.
    - Equations:
      $$
      z_t = y_t - H m_t^- \quad \text{(innovation)}
      $$
      $$
      S_t = H P_t^- H^T + R \quad \text{(innovation covariance)}
      $$
      $$
      K_t = P_t^- H^T S_t^{-1} \quad \text{(Kalman gain)}
      $$
      $$
      m_t = m_t^- + K_t z_t \quad \text{(updated mean)}
      $$
      $$
      P_t = (I - K_t H) P_t^- \quad \text{(updated covariance)}
      $$

This process repeats for each new observation.

---

### Why is the Kalman Filter Important for Time Series Forecasting?

- **Real-Time Inference:** The Kalman filter is ideal for streaming data, as it updates estimates on-the-fly.
- **Forecasting:** The prediction step provides a forecast of the next state (and thus the next observation) before the new data arrives.
- **Uncertainty Quantification:** The filter not only gives a point estimate but also quantifies uncertainty via the covariance.
- **Optimality:** For linear-Gaussian models, the filtered mean is the exact posterior mean $\mathbb{E}[x_t \mid y_{1:t}]$, so no other estimator based on $y_{1:t}$ achieves a lower mean squared error.
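To forecast several steps ahead, the filter repeats the prediction step without any measurement updates. At each horizon, the predicted observation is $\mathcal{N}(H m^-, H P^- H^\top + R)$. The following is a minimal sketch that starts from a filtered belief $(m_t, P_t)$. The helper name `forecast` and the random-walk parameters are illustrative.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def forecast(m_t, P_t, A, Q, H, R, k):
    """Means and covariances of y_{t+1}, ..., y_{t+k} given the filtered belief at time t."""

    def predict(carry, _):
        m, P = carry
        m, P = A @ m, A @ P @ A.T + Q  # prediction step only; no data to update with
        return (m, P), (H @ m, H @ P @ H.T + R)

    _, (y_means, y_covs) = jax.lax.scan(predict, (m_t, P_t), None, length=k)
    return y_means, y_covs


# Random walk (A = H = 1), starting from a filtered belief N(2.0, 0.2)
A = jnp.array([[1.0]])
Q = jnp.array([[0.1]])
H = jnp.array([[1.0]])
R = jnp.array([[0.5]])
y_means, y_covs = forecast(jnp.array([2.0]), jnp.array([[0.2]]), A, Q, H, R, k=5)
```

For the random walk, the forecast variance at horizon $k$ is $P_t + kQ + R$. The uncertainty therefore grows linearly with the horizon, while the mean stays at $m_t$.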

---

### Intuitive Example

Suppose you are tracking the position of a car using noisy GPS measurements. The true position is the hidden state $x_t$, and the GPS reading is $y_t$. The Kalman filter combines your knowledge of how the car moves (the dynamics) and the noisy measurements to give the best possible estimate of the car's position at each time.
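The car example can be made concrete with a minimal sketch of a constant-velocity model, a common textbook choice assumed here:

* The state is $[\text{position}, \text{velocity}]$.
* The transition matrix is $A = \begin{bmatrix} 1 & \Delta t \\ 0 & 1 \end{bmatrix}$.
* The GPS observes position only, so $H = [1 \;\, 0]$.

The process noise (small random accelerations), the GPS noise level and the simulated drive are all illustrative. Velocity is never measured, yet the filter infers it from how the positions change.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

dt = 1.0  # seconds between GPS fixes (assumed)
A = jnp.array([[1.0, dt], [0.0, 1.0]])  # position += velocity * dt
q = 1e-4  # strength of random accelerations (assumed)
Q = q * jnp.array([[dt**3 / 3, dt**2 / 2], [dt**2 / 2, dt]])
H = jnp.array([[1.0, 0.0]])  # GPS measures position only
R = jnp.array([[1.0]])  # GPS noise variance (assumed)

# Simulated drive at a constant 2 m/s with noisy GPS fixes
T = 200
true_pos = 2.0 * dt * jnp.arange(1, T + 1)
gps = true_pos + jax.random.normal(jax.random.PRNGKey(0), (T,))


def step(carry, y):
    m, P = carry
    m, P = A @ m, A @ P @ A.T + Q  # predict
    S = H @ P @ H.T + R  # 1 x 1 innovation covariance
    K = P @ H.T / S  # gain; dividing is fine because the observation is scalar
    m = m + (K * (y - H @ m)).ravel()  # update
    P = (jnp.eye(2) - K @ H) @ P
    return (m, P), m


(m_final, P_final), means = jax.lax.scan(step, (jnp.zeros(2), 10.0 * jnp.eye(2)), gps)
```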

---

### Summary Table

| Step         | What it does                        | Formula (Mean)         | Formula (Covariance)        |
|--------------|-------------------------------------|------------------------|-----------------------------|
| **Predict**  | Forecast next state                 | $A m_{t-1}$            | $A P_{t-1} A^T + Q$         |
| **Update**   | Correct with new observation        | $m_t^- + K_t z_t$      | $(I - K_t H) P_t^-$         |

---

### Relation to State Space Models

- The Kalman filter is the **inference algorithm** for linear Gaussian state space models.
- It exploits the **Markov property** (future depends only on present) and **Gaussianity** (all distributions remain Gaussian).
- For **nonlinear** or **non-Gaussian** models, extensions like the Extended Kalman Filter (EKF) or Particle Filter are used.

---

### Further Reading

- [Wikipedia: Kalman Filter](https://en.wikipedia.org/wiki/Kalman_filter)
- [Probabilistic Machine Learning: Advanced Topics (Kevin Murphy)](https://probml.github.io/pml-book/book2.html#kalman-filter)

## Relation of the Kalman Filter (KF) to Other State Space Models (e.g., ETS)

### State Space Models: A General Framework

The **Kalman Filter** is an inference algorithm for **linear Gaussian state space models** (also called Gauss-Markov models). Here, "inference algorithm" means a method for **estimating the hidden (latent) states** of the model, given the observed data. This is different from "statistical inference" in the classical sense (such as hypothesis testing or parameter estimation).

- In the context of state space models, **inference** refers to the process of computing the probability distribution (or point estimates, such as the mean and covariance) of the hidden state variables at each time step, given all the observations up to that point (filtering), or all observations in the sequence (smoothing).
- The Kalman filter provides an efficient, recursive way to perform this state estimation for linear-Gaussian models.
- In contrast, **statistical inference** usually refers to learning the model parameters (like $A$, $Q$, $H$, $R$) from data, or testing hypotheses about them.

So, in summary:  
- **Kalman filter as an inference algorithm**: Computes the best estimate of the hidden states over time, given the model parameters and observed data.
- **Statistical inference**: Typically refers to learning or testing about the model parameters themselves.

In state space modeling, both types of inference are important, but the Kalman filter specifically addresses the problem of **state estimation** (sometimes called "latent variable inference" or "filtering/smoothing"), not parameter learning.
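Parameter learning still depends on the filter, through one of its by-products: the marginal likelihood. It comes from the prediction-error decomposition $\log p(y_{1:T}) = \sum_t \log \mathcal{N}(y_t;\, H m_t^-,\, S_t)$, which can then be maximized over $A$, $Q$, $H$, $R$. The following is a minimal sketch for a scalar model with illustrative values. It checks the result against a direct $\mathcal{O}(T^3)$ evaluation of the joint Gaussian density of $y_{1:T}$. The sketch follows this notebook's convention that $(m_0, P_0)$ is the prior on $x_0$ and that the first observation is of $x_1$. `kf_log_likelihood` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

jax.config.update("jax_enable_x64", True)


def kf_log_likelihood(ys, A, Q, H, R, m0, P0):
    """log p(y_{1:T}) as the sum of log N(y_t; H m_t^-, S_t) over one filtering pass."""

    def step(carry, y):
        m, P = carry
        m_minus, P_minus = A @ m, A @ P @ A.T + Q  # predict
        S = H @ P_minus @ H.T + R
        ll = multivariate_normal.logpdf(y, H @ m_minus, S)  # one-step predictive density
        K = jnp.linalg.solve(S.T, (P_minus @ H.T).T).T
        m_new = m_minus + K @ (y - H @ m_minus)  # update
        P_new = (jnp.eye(m.shape[0]) - K @ H) @ P_minus
        return (m_new, P_new), ll

    _, lls = jax.lax.scan(step, (m0, P0), ys)
    return lls.sum()


# Scalar model with illustrative values; observations have shape (T, 1)
a, q, r, m0_val, p0_val = 0.8, 0.2, 0.3, 0.5, 1.0
A, Q = jnp.array([[a]]), jnp.array([[q]])
H, R = jnp.array([[1.0]]), jnp.array([[r]])
m0, P0 = jnp.array([m0_val]), jnp.array([[p0_val]])
T = 6
ys = jax.random.normal(jax.random.PRNGKey(0), (T, 1))
ll_filter = kf_log_likelihood(ys, A, Q, H, R, m0, P0)

# Direct O(T^3) check: y_{1:T} is jointly Gaussian with x_t = a^t x_0 + sum_k a^(t-k) w_k
t_idx = jnp.arange(1, T + 1)[:, None]
k_idx = jnp.arange(T + 1)[None, :]  # column 0 multiplies x_0, column k multiplies w_k
lag = (t_idx - k_idx).astype(jnp.float64)
M = jnp.where(lag >= 0, a**lag, 0.0)
D = jnp.diag(jnp.concatenate([jnp.array([p0_val]), jnp.full(T, q)]))
cov_y = M @ D @ M.T + r * jnp.eye(T)
mean_y = M[:, 0] * m0_val
ll_direct = multivariate_normal.logpdf(ys[:, 0], mean_y, cov_y)
```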

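Beyond the linear Gaussian case handled by the Kalman filter itself, many widely used time series models can be cast in state space form:

- **ARMA/ARIMA models**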
- **Exponential Smoothing (ETS) models**
- **Structural Time Series models** (level, trend, seasonality, regression, etc.)
- **Dynamic Linear Models (DLMs)**
- **Nonlinear and non-Gaussian models** (with extensions like the Extended Kalman Filter or Particle Filters)

All these models can be written in the general state space form:
- **State equation:** $x_t = f(x_{t-1}) + w_t$
- **Observation equation:** $y_t = h(x_t) + v_t$
where $w_t$ and $v_t$ are noise terms.

---

### ETS Models as State Space Models

**ETS** stands for **Error, Trend, Seasonality**. ETS models (such as Holt-Winters exponential smoothing) are widely used for time series forecasting and can be written as state space models.

#### Example: Local Level + Trend + Seasonality

Suppose we want to model a time series with:
- **Level** ($\ell_t$): the baseline value
- **Trend** ($b_t$): the slope or growth rate
- **Seasonality** ($s_t$): repeating patterns (e.g., yearly, weekly)

A typical **additive ETS state space model** is:
$$
\begin{align*}
\text{State vector:} \quad & x_t = \begin{bmatrix} \ell_t \\ b_t \\ s_t \end{bmatrix} \\
\text{State transition:} \quad & x_t = A x_{t-1} + w_t \\
\text{Observation:} \quad & y_t = H x_t + v_t
\end{align*}
$$
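where $A$ and $H$ are designed to encode how level, trend, and seasonality evolve and contribute to the observation. The single seasonal state $s_t$ here is a simplification: a season of period $S$ needs $S-1$ seasonal states, as in the next example.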

#### Example: Local Linear Trend + Seasonality

For a time series with level, trend, and $S$-period seasonality:
- **State vector:** $x_t = [\ell_t, b_t, s_{t,1}, ..., s_{t,S-1}]^T$
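- **State transition matrix $A$:** updates level and trend and advances the seasonal states. With $S-1$ seasonal states, the usual choice sets the new seasonal effect to minus the sum of the previous $S-1$, so the effects sum to zero over a period, and shifts the older ones down one slot
- **Observation matrix $H$:** picks out the relevant components (here, the level plus the current seasonal effect)

This is the same linear Gaussian form used by the Kalman filter, so the Kalman filter can be used for inference if the noise is Gaussian.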

---

### How to Model Level, Trend, and Seasonality in State Space

- **Level:** Add a state variable $\ell_t$ that evolves over time (e.g., random walk: $\ell_t = \ell_{t-1} + w_t$).
- **Trend:** Add a state variable $b_t$ for the slope, with its own evolution (e.g., $b_t = b_{t-1} + w_t^{(b)}$).
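- **Seasonality:** Add $S-1$ state variables for seasonal effects. At each step, the newest effect is minus the sum of the previous $S-1$ (plus noise), and the older effects shift down one slot.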

**State vector example for additive model:**
$$
x_t = \begin{bmatrix}
\ell_t \\
b_t \\
s_{t,1} \\
\vdots \\
s_{t,S-1}
\end{bmatrix}
$$

**State transition matrix $A$** and **observation matrix $H$** are constructed to reflect the desired dynamics.
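One common way to build these matrices, assumed here, is the "dummy seasonal" form. The newest seasonal effect is minus the sum of the previous $S-1$, and the older effects shift down one slot. This seasonal block is then combined block-diagonally with a local linear trend. The following is a minimal sketch; the helper names are hypothetical. With this choice, the seasonal block of $A$ returns to the identity after exactly $S$ steps, so the seasonal pattern repeats with period $S$.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

jax.config.update("jax_enable_x64", True)


def seasonal_block(S):
    """(S-1) x (S-1) transition for the seasonal states [s_t, s_{t-1}, ..., s_{t-S+2}]."""
    B = jnp.zeros((S - 1, S - 1))
    B = B.at[0, :].set(-1.0)  # s_t = -(s_{t-1} + ... + s_{t-S+1})
    rows, cols = jnp.arange(1, S - 1), jnp.arange(0, S - 2)
    B = B.at[rows, cols].set(1.0)  # older effects shift down one slot
    return B


def local_linear_trend_seasonal(S):
    """A and H for the state [level, trend, s_t, ..., s_{t-S+2}]."""
    A_trend = jnp.array([[1.0, 1.0], [0.0, 1.0]])  # level += trend; trend persists
    A = block_diag(A_trend, seasonal_block(S))
    H = jnp.zeros((1, S + 1)).at[0, 0].set(1.0).at[0, 2].set(1.0)  # level + current season
    return A, H


S = 12  # e.g., monthly data with a yearly season
B = seasonal_block(S)
A, H = local_linear_trend_seasonal(S)
```

A plain cyclic rotation of only $S-1$ seasonal states would instead repeat every $S-1$ steps. The sum-to-zero row is what gives the pattern period $S$.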

---

### Summary Table

| Model Type         | State Space Form? | Kalman Filter Applicable? | Notes |
|--------------------|------------------|--------------------------|-------|
| ARMA/ARIMA         | Yes              | Yes (linear, Gaussian)   | Special case of state space |
| ETS (Exponential Smoothing) | Yes      | Yes (linear, Gaussian)   | Level, trend, seasonality as states |
| Structural Time Series | Yes           | Yes                      | Flexible, interpretable components |
| Nonlinear/Non-Gaussian | Yes          | No (use EKF, Particle Filter) | Need approximate inference |

---

### Further Reading

- [Durbin & Koopman, "Time Series Analysis by State Space Methods"](https://www.worldcat.org/title/1029823142)
- [Hyndman et al., "Forecasting: Principles and Practice" (ETS models)](https://otexts.com/fpp3/ets.html)
- [Wikipedia: State Space Representation](https://en.wikipedia.org/wiki/State-space_representation_(controls))

---

### Example: State Space Formulation for ETS(A,A,A) (Additive Error, Trend, Seasonality)

$$
\begin{align*}
\ell_t &= \ell_{t-1} + b_{t-1} + \alpha e_t \\
b_t &= b_{t-1} + \beta e_t \\
s_t &= s_{t-m} + \gamma e_t \\
y_t &= \ell_{t-1} + b_{t-1} + s_{t-m} + e_t
\end{align*}
$$
where $e_t$ is the error term, $m$ is the seasonal period, and $\alpha, \beta, \gamma$ are smoothing parameters.

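This is the *innovations* (single-source-of-error) form: the same error $e_t$ drives both the state and the observation equations, so process and observation noise are perfectly correlated. If $e_t$ is Gaussian, the model is still linear-Gaussian. It can therefore be written in the general state space form and handled with the Kalman filter, provided this noise correlation is taken into account.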
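The scalar recursions above can be checked against the matrix form:

* Observation: $y_t = w^\top x_{t-1} + e_t$.
* State: $x_t = F x_{t-1} + g\, e_t$.
* State vector: $x_t = [\ell_t, b_t, s_t, \dots, s_{t-m+1}]^\top$.

The sketch below computes the one-step forecast errors $e_t$ both ways and confirms that they agree. The smoothing parameters, initial state and data are illustrative, and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def ets_aaa_matrices(m, alpha, beta, gamma):
    """F, g, w for ETS(A,A,A) with state [l_t, b_t, s_t, s_{t-1}, ..., s_{t-m+1}]."""
    n = 2 + m
    F = jnp.zeros((n, n))
    F = F.at[0, 0].set(1.0).at[0, 1].set(1.0)  # l_t = l_{t-1} + b_{t-1} (+ alpha e_t)
    F = F.at[1, 1].set(1.0)  # b_t = b_{t-1} (+ beta e_t)
    F = F.at[2, n - 1].set(1.0)  # s_t = s_{t-m} (+ gamma e_t)
    for i in range(3, n):
        F = F.at[i, i - 1].set(1.0)  # older seasonal states shift down one slot
    g = jnp.zeros(n).at[0].set(alpha).at[1].set(beta).at[2].set(gamma)
    w = jnp.zeros(n).at[0].set(1.0).at[1].set(1.0).at[n - 1].set(1.0)  # l + b + s_{t-m}
    return F, g, w


def ets_errors_matrix(y, x0, F, g, w):
    """One-step forecast errors from y_t = w^T x_{t-1} + e_t, x_t = F x_{t-1} + g e_t."""

    def step(x_prev, y_t):
        e_t = y_t - w @ x_prev
        return F @ x_prev + g * e_t, e_t

    _, errors = jax.lax.scan(step, x0, y)
    return errors


def ets_errors_scalar(y, level, trend, seasonal0, m, alpha, beta, gamma):
    """Direct ETS(A,A,A) recursions; seasonal0 = [s_0, s_{-1}, ..., s_{-m+1}]."""
    history = list(reversed(seasonal0))  # oldest first, so history[-m] is s_{t-m}
    errors = []
    for y_t in y:
        s_tm = history[-m]
        e_t = y_t - (level + trend + s_tm)
        level, trend = level + trend + alpha * e_t, trend + beta * e_t
        history.append(s_tm + gamma * e_t)
        errors.append(e_t)
    return jnp.array(errors)


# Illustrative example with seasonal period m = 4
m, alpha, beta, gamma = 4, 0.3, 0.1, 0.2
seasonal0 = jnp.array([1.0, -0.5, 0.5, -1.0])
x0 = jnp.concatenate([jnp.array([10.0, 0.5]), seasonal0])
t = jnp.arange(24.0)
noise = 0.3 * jax.random.normal(jax.random.PRNGKey(0), (24,))
y = 10.0 + 0.5 * t + jnp.tile(jnp.array([-1.0, 0.5, -0.5, 1.0]), 6) + noise

F, g, w = ets_aaa_matrices(m, alpha, beta, gamma)
errors_matrix = ets_errors_matrix(y, x0, F, g, w)
errors_scalar = ets_errors_scalar(
    y.tolist(), 10.0, 0.5, seasonal0.tolist(), m, alpha, beta, gamma
)
```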

---

**In summary:**  
The linear Gaussian state space model, with the Kalman filter as its exact inference algorithm, is a special case of the general state space approach. Many popular time series models, including ETS, can be written as state space models with appropriate state vectors and transition/observation matrices. This lets you model not just the level but also trend and seasonality, and use the Kalman filter for efficient inference and forecasting.

### The Rauch-Tung-Striebel (RTS) Smoother (Smoothing: $\mathcal{O}(T)$)

The RTS Smoother performs a backward pass through the filtered estimates, improving them by incorporating information from future observations. Because it conditions on more data, its estimate at each time step is at least as accurate as the filter's, and its posterior variance is never larger.

**Smoothing Step** (Slide 26):
* Smoother gain: $G_t = P_t A^T (P_{t+1}^-)^{-1}$
* Smoothed mean: $m_t^s = m_t + G_t (m_{t+1}^s - m_{t+1}^-)$
* Smoothed covariance: $P_t^s = P_t + G_t (P_{t+1}^s - P_{t+1}^-) G_t^T$

The overall complexity of the smoothing pass is $\mathcal{O}(T \cdot |X|^3)$ (Slide 28).

In [None]:
# --- Rauch-Tung-Striebel (RTS) Smoother Implementation ---


@jax.jit
def rts_smoother(filtered_means, filtered_covs, predicted_means, predicted_covs, A, Q):
    """
    Implements the Rauch-Tung-Striebel (RTS) Smoother.

    Args:
        filtered_means (jnp.ndarray): Means from the Kalman filter (T, state_dim).
        filtered_covs (jnp.ndarray): Covariances from the Kalman filter (T, state_dim, state_dim).
        predicted_means (jnp.ndarray): Predicted means from the Kalman filter (T, state_dim).
        predicted_covs (jnp.ndarray): Predicted covariances from the Kalman filter (T, state_dim, state_dim).
        A (jnp.ndarray): State transition matrix (state_dim, state_dim).
        Q (jnp.ndarray): State noise covariance (state_dim, state_dim). Not used directly
            (it enters through predicted_covs); kept for interface compatibility.

    Returns:
        tuple: (smoothed_means, smoothed_covs)
    """

    def step(carry, inputs):
        m_s_next, P_s_next = carry  # smoothed moments at t+1
        m_f, P_f, m_pred_next, P_pred_next = inputs  # filtered at t, predicted for t+1

        # Smoother gain G_t = P_t A^T (P_{t+1}^-)^{-1}, via a linear solve
        # (more stable than forming the inverse explicitly)
        G_t = jnp.linalg.solve(P_pred_next.T, (P_f @ A.T).T).T

        m_s = m_f + G_t @ (m_s_next - m_pred_next)  # Smoothed mean
        P_s = P_f + G_t @ (P_s_next - P_pred_next) @ G_t.T  # Smoothed covariance
        return (m_s, P_s), (m_s, P_s)

    # At the last time step the filtered and smoothed estimates coincide
    init = (filtered_means[-1], filtered_covs[-1])
    inputs = (filtered_means[:-1], filtered_covs[:-1], predicted_means[1:], predicted_covs[1:])
    # reverse=True runs t = T-2, ..., 0 while keeping outputs in forward time order
    _, (m_head, P_head) = jax.lax.scan(step, init, inputs, reverse=True)

    smoothed_means = jnp.concatenate([m_head, filtered_means[-1:]], axis=0)
    smoothed_covs = jnp.concatenate([P_head, filtered_covs[-1:]], axis=0)
    return smoothed_means, smoothed_covs


## Understanding the Rauch-Tung-Striebel (RTS) Smoother

### What is Smoothing in State Space Models?

In time series models, **filtering** and **smoothing** are two different inference tasks:

- **Filtering**: At each time step $t$, estimate the hidden state $x_t$ using all observations up to and including time $t$ (i.e., $y_{1:t}$). This is what the **Kalman filter** does.
    - **Filtered estimate:** $p(x_t \mid y_{1:t})$
    - **Use case:** Real-time estimation, forecasting, control.

- **Smoothing**: After collecting the entire sequence of observations $y_{1:T}$, estimate the hidden state $x_t$ at each time step using **all** observations, both past and future.
    - **Smoothed estimate:** $p(x_t \mid y_{1:T})$
    - **Use case:** Retrospective analysis, denoising, signal reconstruction.

**Key Difference:**  
- The filter only uses information up to the current time step, while the smoother uses the entire dataset, including future observations, to refine the estimate at each time.

---

### Why is Smoothing More Accurate?

- **Filtering** is causal: it cannot "see the future." Its estimate at time $t$ is optimal given only $y_{1:t}$.
- **Smoothing** is acausal: it can use all data, including $y_{t+1}, y_{t+2}, ..., y_T$. This extra information allows it to correct or refine earlier state estimates. In the linear Gaussian case, the smoothed variance is never larger than the filtered one, and the error is typically smaller.

---

### The Rauch-Tung-Striebel (RTS) Smoother

The **RTS smoother** is an efficient algorithm for computing the smoothed state estimates in linear Gaussian state space models (i.e., Gauss-Markov models). It works in two stages:

1. **Forward Pass (Filtering):** Run the Kalman filter to compute $p(x_t \mid y_{1:t})$ for all $t$.
2. **Backward Pass (Smoothing):** Starting from the last time step, recursively refine the state estimates using future information.

#### RTS Smoother Equations

Let:
- $m_t$ and $P_t$ be the filtered mean and covariance at time $t$ (from the Kalman filter).
- $m_{t+1}^-$ and $P_{t+1}^-$ be the predicted mean and covariance for $x_{t+1}$ given $y_{1:t}$.
- $m_{t+1}^s$ and $P_{t+1}^s$ be the smoothed mean and covariance at time $t+1$.

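The recursion starts from $m_T^s = m_T$ and $P_T^s = P_T$, since at the final time step the filtered and smoothed estimates coincide. From there, the RTS smoother computes, for $t = T-1, ..., 0$: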
- **Smoother gain:**  
  $$
  G_t = P_t A^\top (P_{t+1}^-)^{-1}
  $$
- **Smoothed mean:**  
  $$
  m_t^s = m_t + G_t (m_{t+1}^s - m_{t+1}^-)
  $$
- **Smoothed covariance:**  
  $$
  P_t^s = P_t + G_t (P_{t+1}^s - P_{t+1}^-) G_t^\top
  $$

**Intuition:**  
- The smoother gain $G_t$ determines how much to adjust the filtered estimate at time $t$ based on the difference between the smoothed and predicted state at $t+1$.
- If the future data suggests that the prediction at $t+1$ was off, the smoother "corrects" the earlier state accordingly.

---

### Kalman Filter vs. RTS Smoother: Summary Table

| Task      | Estimate                | Uses Future Data? | Typical Use Case         |
|-----------|-------------------------|-------------------|-------------------------|
| Filter    | $p(x_t \mid y_{1:t})$   | No                | Real-time, forecasting  |
| Smoother  | $p(x_t \mid y_{1:T})$   | Yes               | Retrospective analysis  |

- **Kalman filter**: Fast, online, but cannot use information from after time $t$.
- **RTS smoother**: Two-pass (forward + backward), uses all data, gives more accurate and less uncertain state estimates.
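These claims can be checked numerically with a minimal sketch for a scalar model with illustrative values. It runs the filter forward and the RTS recursion backward. It then compares the result with brute-force Gaussian conditioning of $x_{1:T}$ on $y_{1:T}$, an $\mathcal{O}(T^3)$ computation. The sketch follows the notebook's convention that $(m_0, P_0)$ is the prior on $x_0$ and that $y_t$ observes $x_t$ for $t = 1, \dots, T$.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

a, q, r = 0.9, 0.1, 0.5  # scalar transition, process noise, observation noise (illustrative)
m_init, p_init = 0.0, 1.0  # prior mean and variance of x_0
T = 10
ys = jax.random.normal(jax.random.PRNGKey(1), (T,))  # observations y_1, ..., y_T


def filter_step(carry, y):
    m, P = carry
    m_pred, P_pred = a * m, a * P * a + q  # predict
    K = P_pred / (P_pred + r)  # H = 1
    m_filt, P_filt = m_pred + K * (y - m_pred), (1 - K) * P_pred  # update
    return (m_filt, P_filt), (m_filt, P_filt, m_pred, P_pred)


def smooth_step(carry, inputs):
    ms_next, Ps_next = carry  # smoothed moments at t+1
    m_f, P_f, m_p_next, P_p_next = inputs  # filtered at t, predicted for t+1
    G = P_f * a / P_p_next  # smoother gain
    ms = m_f + G * (ms_next - m_p_next)
    Ps = P_f + G * (Ps_next - P_p_next) * G
    return (ms, Ps), (ms, Ps)


init = (jnp.array(m_init, dtype=jnp.float64), jnp.array(p_init, dtype=jnp.float64))
_, (mf, Pf, mp, Pp) = jax.lax.scan(filter_step, init, ys)
_, (ms_head, Ps_head) = jax.lax.scan(
    smooth_step, (mf[-1], Pf[-1]), (mf[:-1], Pf[:-1], mp[1:], Pp[1:]), reverse=True
)
ms, Ps = jnp.append(ms_head, mf[-1]), jnp.append(Ps_head, Pf[-1])

# Brute force: Gaussian conditioning of x_{1:T} on y_{1:T}, O(T^3)
t_idx = jnp.arange(1, T + 1)[:, None]
k_idx = jnp.arange(T + 1)[None, :]  # column 0 multiplies x_0, column k multiplies w_k
lag = (t_idx - k_idx).astype(jnp.float64)
M = jnp.where(lag >= 0, a**lag, 0.0)
D = jnp.diag(jnp.concatenate([jnp.array([p_init]), jnp.full(T, q)]))
mu_x = M[:, 0] * m_init
Sigma_x = M @ D @ M.T
gain = jnp.linalg.solve(Sigma_x + r * jnp.eye(T), Sigma_x).T  # Sigma_x (Sigma_x + r I)^{-1}
post_mean = mu_x + gain @ (ys - mu_x)
post_var = jnp.diag(Sigma_x - gain @ Sigma_x)
```

The smoother reproduces the exact posterior marginals while visiting each time step only twice.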

---

### Visual Example

- In the plot produced by the example below, the **green dashed line** (filtered mean) is the Kalman filter's estimate at each time, using only past and present data.
- The **purple line** (smoothed mean) is the RTS smoother's estimate, which is typically closer to the true state and has smaller uncertainty, because it uses the entire sequence of observations.

---

### Further Reading

- [Wikipedia: Kalman Smoother](https://en.wikipedia.org/wiki/Kalman_filter#Fixed-interval_smoothers)
- [Murphy, "Probabilistic Machine Learning: Advanced Topics"](https://probml.github.io/pml-book/book2.html#kalman-smoother)

### Example: Applying Kalman Filter and RTS Smoother

Let's generate some synthetic 1D time series data and apply the Kalman Filter and RTS Smoother to estimate the underlying true state.

```python
# --- Main Execution: Kalman Filter and RTS Smoother Example ---

# Generate data
T = 100  # Number of time steps
state_dim = 1  # 1D state
obs_dim = 1  # 1D observation

A = jnp.array([[0.98]])  # Slightly decaying state
Q_std = 0.1  # Standard deviation of state noise
H = jnp.array([[1.0]])  # Direct observation of state
R_std = 0.5  # Standard deviation of observation noise
m0 = jnp.array([0.0])  # Initial state mean
P0_std = 1.0  # Initial state standard deviation

key = jax.random.PRNGKey(50)
true_states, observations, A, Q, H, R, m0, P0 = generate_linear_gaussian_data(
    T=T,
    state_dim=state_dim,
    obs_dim=obs_dim,
    A=A,
    Q_std=Q_std,
    H=H,
    R_std=R_std,
    m0=m0,
    P0_std=P0_std,
    key=key,
)

print("\n--- Running Kalman Filter ---")
filtered_means, filtered_covs, predicted_means, predicted_covs = kalman_filter(
    observations, A, Q, H, R, m0, P0
)
filtered_stds = jnp.sqrt(jnp.diagonal(filtered_covs, axis1=1, axis2=2))

print("\n--- Running RTS Smoother ---")
smoothed_means, smoothed_covs = rts_smoother(
    filtered_means, filtered_covs, predicted_means, predicted_covs, A, Q
)
smoothed_stds = jnp.sqrt(jnp.diagonal(smoothed_covs, axis1=1, axis2=2))

# Plotting results for the first (and only) state dimension
plot_kalman_results(
    true_states,
    observations,
    filtered_means,
    filtered_stds,
    smoothed_means,
    smoothed_stds,
    title="Kalman Filter and RTS Smoother Results (1D State)",
)
```


In the plot above, you can observe:
* **True State (Blue)**: The actual underlying process.
* **Observations (Red Markers)**: Noisy measurements of the true state.
* **Filtered Mean (Green Dashed Line)**: The Kalman filter's estimate of the state at each time step, using all observations up to that point. The green shaded area represents its uncertainty (2 standard deviations).
* **Smoothed Mean (Purple Line)**: The RTS smoother's estimate, which is generally closer to the true state and has smaller uncertainty (purple shaded area) because it uses *all* available observations (past and future) for each time step's estimate.
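If the model is correct, ±2σ bands should contain the true state about 95% of the time, and the smoothed variances should never exceed the filtered ones. The sketch below checks both numerically for a scalar version of the example model (assumed values a = 0.98, q = 0.1², r = 0.5², matching the cell above). It re-implements the filter and smoother with `jax.lax.scan` rather than calling the notebook's `kalman_filter`/`rts_smoother`; the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Assumed scalar model (same values as the example above):
# x_t = a x_{t-1} + w_t, w_t ~ N(0, q);  y_t = x_t + v_t, v_t ~ N(0, r)
a, q, r = 0.98, 0.1**2, 0.5**2
m0, p0 = 0.0, 1.0
T = 500


def simulate(key, T):
    k0, kw, kv = jax.random.split(key, 3)
    x0 = m0 + jnp.sqrt(p0) * jax.random.normal(k0)  # draw x_0 from the prior
    w = jnp.sqrt(q) * jax.random.normal(kw, (T,))
    v = jnp.sqrt(r) * jax.random.normal(kv, (T,))

    def step(x_prev, w_t):
        x_t = a * x_prev + w_t
        return x_t, x_t

    _, xs = jax.lax.scan(step, x0, w)
    return xs, xs + v


def kalman_filter_scalar(ys):
    def step(carry, y):
        m, p = carry
        m_pred, p_pred = a * m, a * a * p + q  # predict
        k = p_pred / (p_pred + r)  # Kalman gain
        m_f, p_f = m_pred + k * (y - m_pred), (1.0 - k) * p_pred  # update
        return (m_f, p_f), (m_f, p_f, m_pred, p_pred)

    _, outs = jax.lax.scan(step, (jnp.float32(m0), jnp.float32(p0)), ys)
    return outs  # filtered means/vars, predicted means/vars


def rts_smoother_scalar(m_f, p_f, m_pred, p_pred):
    def step(carry, inputs):
        m_s_next, p_s_next = carry
        mf, pf, mp_next, pp_next = inputs
        g = pf * a / pp_next  # smoother gain
        m_s = mf + g * (m_s_next - mp_next)
        p_s = pf + g * g * (p_s_next - pp_next)
        return (m_s, p_s), (m_s, p_s)

    # Backward pass over t = T-2, ..., 0, pairing time t with the prediction for t+1
    inputs = (m_f[:-1], p_f[:-1], m_pred[1:], p_pred[1:])
    _, (m_s, p_s) = jax.lax.scan(step, (m_f[-1], p_f[-1]), inputs, reverse=True)
    return jnp.append(m_s, m_f[-1]), jnp.append(p_s, p_f[-1])


xs, ys = simulate(jax.random.PRNGKey(0), T)
m_f, p_f, m_pred, p_pred = kalman_filter_scalar(ys)
m_s, p_s = rts_smoother_scalar(m_f, p_f, m_pred, p_pred)

# Fraction of true states inside the +/- 2 sigma bands
cover_f = jnp.mean(jnp.abs(xs - m_f) <= 2 * jnp.sqrt(p_f))
cover_s = jnp.mean(jnp.abs(xs - m_s) <= 2 * jnp.sqrt(p_s))
print(f"filtered coverage: {float(cover_f):.3f}, smoothed coverage: {float(cover_s):.3f}")
# Smoothed variances should not exceed filtered ones (small float tolerance)
print(bool(jnp.all(p_s <= p_f + 1e-6)))
```

Both coverages should typically land near 0.95. Noticeably lower values would hint at a mismatch between the assumed and the actual noise levels.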

```python
import jax
import jax.numpy as jnp

# Example: State Space Model with Level, Trend, and Seasonality (1D observation)
# We use a local linear trend + dummy-variable seasonal model and apply the
# Kalman Filter and RTS Smoother.

# Parameters
T = 120  # Number of time steps
season_period = 12  # e.g., monthly seasonality
state_dim = 2 + (season_period - 1)  # [level, trend, s_t, s_{t-1}, ..., s_{t-S+2}]
obs_dim = 1

# State transition matrix A
A = jnp.zeros((state_dim, state_dim))
# Level update: level_t = level_{t-1} + trend_{t-1}
A = A.at[0, 0].set(1.0)
A = A.at[0, 1].set(1.0)
# Trend update: trend_t = trend_{t-1}
A = A.at[1, 1].set(1.0)
# Seasonal update: s_t = -(s_{t-1} + ... + s_{t-S+1}), so any S consecutive
# seasonal effects sum to zero (up to noise) ...
A = A.at[2, 2:].set(-1.0)
# ... and the remaining seasonal components shift the stored lags down by one
for i in range(3, state_dim):
    A = A.at[i, i - 1].set(1.0)

# Observation matrix H: observe level + current seasonal effect s_t.
# A already advances the season, so H stays fixed over time.
H = jnp.zeros((obs_dim, state_dim))
H = H.at[0, 0].set(1.0)  # level
H = H.at[0, 2].set(1.0)  # current season

# Noise covariances (simplification: every seasonal component gets a little noise)
Q_std_level = 0.05
Q_std_trend = 0.01
Q_std_season = 0.01
Q = jnp.diag(
    jnp.array(
        [Q_std_level**2, Q_std_trend**2] + [Q_std_season**2] * (season_period - 1)
    )
)
R_std = 0.3
R = jnp.eye(obs_dim) * R_std**2

# Initial state
m0 = jnp.zeros(state_dim)
m0 = m0.at[0].set(2.0)  # initial level
m0 = m0.at[1].set(0.1)  # initial trend
# Initial seasonal effects s_0, s_{-1}, ..., s_{-10}; they sum to zero, so the
# implied twelfth effect -sum(others) is zero as well
season_init = jnp.array([0.5, -0.3, 0.2, -0.4, 0.1, 0.3, -0.2, 0.0, 0.1, -0.1, -0.2])
m0 = m0.at[2:].set(season_init)
P0 = jnp.eye(state_dim) * 1.0


# Generate synthetic data: x_t = A x_{t-1} + w_t,  y_t = H x_t + v_t,  x_0 = m0
def generate_trend_season_data(T, A, Q, H, R, m0, key):
    state_dim = A.shape[0]
    obs_dim = H.shape[0]
    states = [m0]
    observations = []
    for t in range(T):
        key, subkey1, subkey2 = jax.random.split(key, 3)
        state_noise = jax.random.multivariate_normal(subkey1, jnp.zeros(state_dim), Q)
        next_state = A @ states[-1] + state_noise
        obs_noise = jax.random.multivariate_normal(subkey2, jnp.zeros(obs_dim), R)
        observations.append(H @ next_state + obs_noise)
        states.append(next_state)
    return jnp.array(states[1:]), jnp.array(observations)


key = jax.random.PRNGKey(123)
true_states, observations = generate_trend_season_data(T, A, Q, H, R, m0, key)


# Kalman filter for a time-invariant model (A, Q, H, R constant)
def kalman_filter_time_invariant(observations, A, Q, H, R, m0, P0):
    T = observations.shape[0]
    state_dim = m0.shape[0]
    filtered_means = jnp.zeros((T, state_dim))
    filtered_covs = jnp.zeros((T, state_dim, state_dim))
    predicted_means = jnp.zeros((T, state_dim))
    predicted_covs = jnp.zeros((T, state_dim, state_dim))
    m_prev = m0
    P_prev = P0
    for t in range(T):
        # Prediction: p(x_t | y_{1:t-1})
        m_minus = A @ m_prev
        P_minus = A @ P_prev @ A.T + Q
        predicted_means = predicted_means.at[t].set(m_minus)
        predicted_covs = predicted_covs.at[t].set(P_minus)
        # Update: p(x_t | y_{1:t})
        z = observations[t] - H @ m_minus  # innovation
        S = H @ P_minus @ H.T + R  # innovation covariance
        K = jnp.linalg.solve(S, H @ P_minus).T  # gain P^- H^T S^{-1}, without inv
        m_t = m_minus + K @ z
        P_t = (jnp.eye(state_dim) - K @ H) @ P_minus
        filtered_means = filtered_means.at[t].set(m_t)
        filtered_covs = filtered_covs.at[t].set(P_t)
        m_prev = m_t
        P_prev = P_t
    return filtered_means, filtered_covs, predicted_means, predicted_covs


filtered_means, filtered_covs, predicted_means, predicted_covs = (
    kalman_filter_time_invariant(observations, A, Q, H, R, m0, P0)
)
filtered_stds = jnp.sqrt(jnp.diagonal(filtered_covs, axis1=1, axis2=2))


# RTS smoother (A is constant; Q enters only through the predicted covariances)
def rts_smoother_time_invariant(
    filtered_means, filtered_covs, predicted_means, predicted_covs, A
):
    T = filtered_means.shape[0]
    smoothed_means = jnp.copy(filtered_means)
    smoothed_covs = jnp.copy(filtered_covs)
    for t in reversed(range(T - 1)):
        # Smoother gain G_t = P_t A^T (P^-_{t+1})^{-1}, computed with a linear solve
        G_t = jnp.linalg.solve(predicted_covs[t + 1].T, (filtered_covs[t] @ A.T).T).T
        smoothed_means = smoothed_means.at[t].set(
            filtered_means[t] + G_t @ (smoothed_means[t + 1] - predicted_means[t + 1])
        )
        smoothed_covs = smoothed_covs.at[t].set(
            filtered_covs[t]
            + G_t @ (smoothed_covs[t + 1] - predicted_covs[t + 1]) @ G_t.T
        )
    return smoothed_means, smoothed_covs


smoothed_means, smoothed_covs = rts_smoother_time_invariant(
    filtered_means, filtered_covs, predicted_means, predicted_covs, A
)
smoothed_stds = jnp.sqrt(jnp.diagonal(smoothed_covs, axis1=1, axis2=2))

# Plot the level (state_idx=0). The observations also contain the seasonal
# effect, so they scatter around level + season rather than the level alone.
plot_kalman_results(
    true_states,
    observations,
    filtered_means,
    filtered_stds,
    smoothed_means,
    smoothed_stds,
    title="State Space Model with Trend and Seasonality: Kalman Filter & RTS Smoother",
    state_idx=0,
)
```


### Interpreting the Results of State Estimation from the Kalman Filter (KF) and RTS Smoother

When you run the Kalman Filter (KF) and Rauch-Tung-Striebel (RTS) Smoother on a state space model, you obtain **estimates of the hidden (latent) states** at each time step, not the model parameters themselves. Here’s how to interpret these results:

---

#### 1. **State Estimates vs. Parameter Estimates**

- **State Estimates** (`filtered_means`, `smoothed_means`):  
    These are the best guesses (means) of the hidden state vector at each time step, given the observations and the current model parameters ($A$, $Q$, $H$, $R$, $m_0$, $P_0$).
    - **Kalman Filter:** $p(x_t \mid y_{1:t})$, using only past and present data.
    - **RTS Smoother:** $p(x_t \mid y_{1:T})$, using all data (past and future).

- **Parameter Estimates**:  
    The matrices $A$, $Q$, $H$, $R$, $m_0$, $P_0$ are typically **assumed known** when running the filter/smoother. To fit (learn) these parameters from data, you need a separate procedure (e.g., maximum likelihood via Expectation-Maximization or gradient-based optimization, or Bayesian inference).

---

#### 2. **What Do the State Estimates Tell Us?**

- **Filtered Means (`filtered_means`)**:  
    At each time $t$, this is your best estimate of the state, using all data up to $t$. Useful for real-time tracking and forecasting.

- **Smoothed Means (`smoothed_means`)**:  
    At each time $t$, this is your best estimate of the state, using all data in the sequence. Useful for retrospective analysis, denoising, and signal reconstruction.

- **Uncertainties (`filtered_stds`, `smoothed_stds`)**:  
    The standard deviations (square roots of the diagonal of the covariance matrices) quantify your uncertainty about each state component at each time step.
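The last filtered estimate is also the natural starting point for forecasting: with no new observations, only the prediction step is repeated, $m_{t+k} = A m_{t+k-1}$ and $P_{t+k} = A P_{t+k-1} A^\top + Q$, so the mean follows the dynamics while the uncertainty grows. A minimal sketch with a hypothetical helper `forecast` and illustrative values (A = 0.98, Q = 0.1², last filtered state N(1.0, 0.05)), none of which come from the lecture:

```python
import jax
import jax.numpy as jnp


def forecast(m_T, P_T, A, Q, num_steps):
    """k-step-ahead predictive means/covariances of the state, starting from the
    last filtered estimate N(m_T, P_T); no observations are used."""

    def step(carry, _):
        m, P = carry
        m_next = A @ m  # prediction step for the mean
        P_next = A @ P @ A.T + Q  # prediction step for the covariance
        return (m_next, P_next), (m_next, P_next)

    _, (means, covs) = jax.lax.scan(step, (m_T, P_T), None, length=num_steps)
    return means, covs


# Illustrative (assumed) 1D model and last filtered state
A = jnp.array([[0.98]])
Q = jnp.array([[0.1**2]])
m_T = jnp.array([1.0])
P_T = jnp.array([[0.05]])

means, covs = forecast(m_T, P_T, A, Q, num_steps=50)
stds = jnp.sqrt(covs[:, 0, 0])
print(means[0, 0], means[-1, 0])  # mean decays as 0.98**k
print(stds[0], stds[-1])  # std grows toward sqrt(q / (1 - a**2))
```

For this stable scalar model the forecast variance approaches the stationary value $q / (1 - a^2)$, so long-range forecasts revert to the prior rather than becoming arbitrarily uncertain.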

---

#### 3. **How to Use These Results**

- **Interpretation**:  
    - The estimated states can be interpreted according to your model structure. For example, in a local trend + seasonality model:
        - The first state might represent the **level** (baseline).
        - The second state might represent the **trend** (slope).
        - The remaining states might represent **seasonal effects**.
    - By plotting these components, you can see how the underlying process evolves over time, separated from noise.

- **Diagnostics**:  
    - In simulation, where the true states are known, estimates that track the truth with ±2σ bands covering it roughly 95% of the time suggest that the implementation and the assumed parameters are consistent.
    - Large uncertainties or systematic deviations may indicate model mismatch or the need for parameter tuning.
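On real data the true states are unknown, so a common consistency check (a standard technique, not covered in the lecture) uses the innovations $z_t = y_t - H m_t^-$ and their predicted covariances $S_t = H P_t^- H^\top + R$: under a correctly specified model, the normalized innovation squared $z_t^\top S_t^{-1} z_t$ averages to roughly the observation dimension. The sketch below assumes a scalar model (a = 0.98, q = 0.1², true r = 0.5²) and hypothetical helper names, and compares the average when the assumed $R$ is too small, correct, and too large.

```python
import jax
import jax.numpy as jnp

# Assumed scalar model: x_t = a x_{t-1} + w_t,  y_t = x_t + v_t
a, q, r_true = 0.98, 0.1**2, 0.5**2
T = 1000


def simulate(key):
    kw, kv = jax.random.split(key)
    w = jnp.sqrt(q) * jax.random.normal(kw, (T,))
    v = jnp.sqrt(r_true) * jax.random.normal(kv, (T,))

    def step(x_prev, w_t):
        x_t = a * x_prev + w_t
        return x_t, x_t

    _, xs = jax.lax.scan(step, jnp.float32(0.0), w)
    return xs + v


def mean_nis(ys, r_assumed, m0=0.0, p0=1.0):
    """Average normalized innovation squared z_t^2 / S_t under an assumed R."""

    def step(carry, y):
        m, p = carry
        m_pred, p_pred = a * m, a * a * p + q
        s = p_pred + r_assumed  # predicted innovation variance S_t
        z = y - m_pred  # innovation z_t
        k = p_pred / s
        return (m_pred + k * z, (1.0 - k) * p_pred), z * z / s

    _, nis = jax.lax.scan(step, (jnp.float32(m0), jnp.float32(p0)), ys)
    return jnp.mean(nis)


ys = simulate(jax.random.PRNGKey(1))
nis = {r_a: float(mean_nis(ys, r_a)) for r_a in [0.05**2, r_true, 2.0**2]}
for r_a, value in nis.items():
    print(f"assumed R = {r_a:.4f}: mean NIS = {value:.2f}")
```

Averages far above 1 suggest the filter is overconfident (noise underestimated); averages far below 1 suggest it is too cautious.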

---

#### 4. **Parameter Fitting (Learning)**

- **KF/RTS do not fit parameters**:  
    The Kalman Filter and RTS Smoother **assume** the model parameters are known. They only estimate the hidden states.

- **Parameter Learning**:  
    To fit parameters ($A$, $Q$, $H$, $R$, etc.) from data, you typically use:
    - **Maximum Likelihood Estimation (MLE)**: Find parameters that maximize the marginal likelihood of the observed data, $p(y_{1:T})$, often using the Expectation-Maximization (EM) algorithm or gradient-based optimization.
    - **Bayesian Inference**: Place priors on parameters and infer their posterior distributions.

    After fitting, you can rerun the KF/RTS with the learned parameters for improved state estimation.
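The Kalman filter also yields the marginal likelihood needed for MLE, through the prediction-error decomposition $\log p(y_{1:T}) = \sum_{t=1}^{T} \log \mathcal{N}(y_t;\, H m_t^-,\, S_t)$. Because the filter is built from differentiable JAX operations, `jax.grad` can optimize parameters directly. A minimal sketch, assuming a scalar model in which only the observation noise standard deviation is unknown; plain gradient ascent on its logarithm (step size 0.5, 100 iterations) is an illustrative choice, not a recommendation.

```python
import jax
import jax.numpy as jnp

# Assumed scalar model: x_t = a x_{t-1} + w_t,  y_t = x_t + v_t
a, q, r_std_true = 0.98, 0.1**2, 0.5
T = 500


def simulate(key):
    kw, kv = jax.random.split(key)
    w = jnp.sqrt(q) * jax.random.normal(kw, (T,))
    v = r_std_true * jax.random.normal(kv, (T,))

    def step(x_prev, w_t):
        x_t = a * x_prev + w_t
        return x_t, x_t

    _, xs = jax.lax.scan(step, jnp.float32(0.0), w)
    return xs + v


def avg_log_likelihood(log_r_std, ys, m0=0.0, p0=1.0):
    """(1/T) log p(y_{1:T}) via the prediction-error decomposition."""
    r = jnp.exp(2.0 * log_r_std)

    def step(carry, y):
        m, p = carry
        m_pred, p_pred = a * m, a * a * p + q
        s = p_pred + r  # S_t
        z = y - m_pred  # z_t
        ll = -0.5 * (jnp.log(2.0 * jnp.pi * s) + z * z / s)  # log N(y_t; m_pred, S_t)
        k = p_pred / s
        return (m_pred + k * z, (1.0 - k) * p_pred), ll

    _, lls = jax.lax.scan(step, (jnp.float32(m0), jnp.float32(p0)), ys)
    return jnp.mean(lls)


ys = simulate(jax.random.PRNGKey(2))
loglik_and_grad = jax.jit(jax.value_and_grad(avg_log_likelihood))

log_r_std = jnp.float32(0.0)  # initial guess: R_std = 1.0
ll_start, _ = loglik_and_grad(log_r_std, ys)
for _ in range(100):
    _, g = loglik_and_grad(log_r_std, ys)
    log_r_std = log_r_std + 0.5 * g  # gradient ascent on the average log-likelihood
ll_end, _ = loglik_and_grad(log_r_std, ys)
r_std_hat = float(jnp.exp(log_r_std))
print(f"estimated R_std = {r_std_hat:.3f} (data simulated with R_std = {r_std_true})")
```

With several parameters, EM or a dedicated optimizer would usually be preferable. The point here is only that the log-likelihood comes out of the same forward pass that produces the filtered states.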

---

#### 5. **Summary Table**

| Output                | What it means                                      | How to use it                |
|-----------------------|----------------------------------------------------|------------------------------|
| `filtered_means`      | State estimates using data up to time $t$          | Real-time tracking, forecast |
| `smoothed_means`      | State estimates using all data                     | Retrospective analysis       |
| `filtered_stds`       | Uncertainty in filtered state estimates            | Confidence intervals         |
| `smoothed_stds`       | Uncertainty in smoothed state estimates            | Confidence intervals         |
| Model parameters      | Assumed fixed during filtering/smoothing           | Need to fit separately       |

---

#### 6. **Key Takeaway**

> **The Kalman Filter and RTS Smoother provide optimal estimates of the hidden states, given the model parameters. They do not fit the model parameters themselves. To learn parameters, you need additional optimization or inference procedures.**

---

**Further Reading:**
- [Murphy, "Probabilistic Machine Learning: Advanced Topics", parameter learning for state-space models](https://probml.github.io/pml-book/book2.html#kalman-learning)
- [Wikipedia: Kalman Filter – Parameter Estimation](https://en.wikipedia.org/wiki/Kalman_filter#Parameter_estimation)

#### 6. Summary

To summarize (Slide 30):

* **Markov Chains** formalize the notion of a stochastic process with a local finite memory through conditional independence.
* **Gauss-Markov models** map this structure onto linear algebra: state transitions and observations are linear maps with additive Gaussian noise.
* The **Kalman filter** is the name for the corresponding analytic filtering algorithm.
* **Bayesian filters and smoothers** (like the Kalman filter and RTS smoother) are not just for signal processing of "fast" signals, but a general tool for inference in a chain of experiments/observations, enabling efficient sequential and retrospective analysis.

#### Exercises

**Exercise 1: Impact of Noise Levels**
Experiment with different values for `Q_std` (state noise standard deviation) and `R_std` (observation noise standard deviation) passed to the `generate_linear_gaussian_data` function. How do changes in these parameters affect the filtered and smoothed estimates and their uncertainties? Discuss why this happens.
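To reason about this exercise, it can help to look at the scalar version of the example model ($A = a$, $H = 1$). There the filter variance settles at a fixed point of the Riccati recursion $P^- = a^2 P + q$, $P = P^- r / (P^- + r)$, and the steady-state gain $K = P^- / (P^- + r)$ depends only on $a$ and the ratio $q/r$. A minimal sketch with a hypothetical helper `steady_state_gain` that simply iterates the recursion:

```python
def steady_state_gain(a, q, r, num_iters=2000):
    """Iterate the scalar Riccati recursion to its fixed point; return (K, P)."""
    p = 1.0
    for _ in range(num_iters):
        p_pred = a * a * p + q  # predict
        k = p_pred / (p_pred + r)  # gain
        p = (1.0 - k) * p_pred  # update
    return k, p


a = 0.98
gains = {}
for q_std, r_std in [(0.1, 0.5), (0.3, 0.5), (0.1, 1.5)]:
    k, p = steady_state_gain(a, q_std**2, r_std**2)
    gains[(q_std, r_std)] = k
    print(f"Q_std={q_std}, R_std={r_std}: K={k:.3f}, filtered std={p**0.5:.3f}")
```

A larger $q/r$ gives a larger gain, so the filter follows the observations more closely; a smaller $q/r$ gives a smaller gain and a smoother but more lagged estimate.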

**Exercise 2: Higher Dimensional State**
Modify the `state_dim` to 2 or 3. You will also need to adjust `A`, `Q`, `H`, `R`, `m0`, and `P0` to be matrices/vectors of appropriate dimensions. For example, for `state_dim=2`:
```python
import jax.numpy as jnp

state_dim = 2
obs_dim = 1
A = jnp.array([[0.9, 0.1], [0.0, 0.9]])
Q_std = 0.1
Q = jnp.eye(state_dim) * Q_std**2
H = jnp.array([[1.0, 0.0]])  # Observe only the first state dimension
R_std = 0.5
R = jnp.eye(obs_dim) * R_std**2
m0 = jnp.array([0.0, 0.0])
P0_std = 1.0
P0 = jnp.eye(state_dim) * P0_std**2
```
Run the filter and smoother. How does the complexity of the matrices change? Plot the results for each state dimension (e.g., `state_idx=0` and `state_idx=1`).

**Exercise 3: Non-Identity Observation Matrix**
Change the `H` matrix to observe a linear combination of states. For example, if `state_dim=2` and `obs_dim=1`, set `H = jnp.array([[0.5, 0.5]])`. How does this affect the filter's ability to estimate individual state components? Discuss the role of observability.
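For a time-invariant model, a standard algebraic test is the observability matrix $\mathcal{O} = [H;\, HA;\, \dots;\, HA^{n-1}]$: the state can be reconstructed from noise-free observations exactly when $\mathcal{O}$ has full rank $n$. A minimal sketch using the matrices suggested in Exercises 2 and 3; the identity $A$ is an illustrative contrast, not part of the exercises, and `observability_matrix` is a hypothetical helper.

```python
import jax.numpy as jnp


def observability_matrix(A, H):
    """Stack H, HA, ..., HA^{n-1} for an n-dimensional state."""
    n = A.shape[0]
    blocks = [H]
    for _ in range(n - 1):
        blocks.append(blocks[-1] @ A)
    return jnp.concatenate(blocks, axis=0)


H = jnp.array([[0.5, 0.5]])
A_coupled = jnp.array([[0.9, 0.1], [0.0, 0.9]])  # A from Exercise 2
A_identity = jnp.eye(2)  # illustrative: two uncoupled components

ranks = {}
for name, A in [("coupled A", A_coupled), ("identity A", A_identity)]:
    ranks[name] = int(jnp.linalg.matrix_rank(observability_matrix(A, H)))
    print(f"{name}: rank {ranks[name]} of 2")
```

With the identity $A$, only the sum $x_1 + x_2$ ever reaches the observations, so the filter cannot separate the two components. With the coupled $A$ the matrix has full rank, although its small second singular value indicates that the difference between components is only weakly observable.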

**Exercise 4 (Advanced): Implementing a Simple Non-Linear Filter (e.g., Extended Kalman Filter concept)**
The Kalman Filter assumes linear state transitions and observations with Gaussian noise. Briefly research the concept of the **Extended Kalman Filter (EKF)**, which handles non-linear state transitions and/or observation models by linearizing them around the current mean estimate. Describe conceptually how you would modify the `kalman_filter` function to incorporate a non-linear transition $f(x_{t-1})$ or observation function $h(x_t)$, using JAX's `jax.jacobian` for the linearization.
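A minimal sketch of a single EKF step, assuming additive Gaussian noise. The transition `f` (a crude Euler-discretized pendulum) and the observation `h` are hypothetical examples, not part of the lecture. The means are propagated through the non-linear functions themselves, while the Jacobians take the place of $A$ and $H$ in the covariance and gain computations.

```python
import jax
import jax.numpy as jnp

dt = 0.1  # hypothetical time step


def f(x):
    """Hypothetical non-linear transition: pendulum with state [angle, angular velocity]."""
    return jnp.array([x[0] + dt * x[1], x[1] - dt * 9.81 * jnp.sin(x[0])])


def h(x):
    """Hypothetical non-linear observation: horizontal position of the pendulum bob."""
    return jnp.array([jnp.sin(x[0])])


def ekf_step(m, P, y, Q, R):
    # Predict: mean through f, covariance through the Jacobian of f at m
    F = jax.jacobian(f)(m)
    m_minus = f(m)
    P_minus = F @ P @ F.T + Q
    # Update: linearize h around the predicted mean
    H = jax.jacobian(h)(m_minus)
    z = y - h(m_minus)  # innovation
    S = H @ P_minus @ H.T + R
    K = jnp.linalg.solve(S, H @ P_minus).T  # P^- H^T S^{-1}
    m_new = m_minus + K @ z
    P_new = (jnp.eye(m.shape[0]) - K @ H) @ P_minus
    return m_new, P_new


m, P = jnp.array([0.5, 0.0]), 0.1 * jnp.eye(2)
Q, R = 1e-4 * jnp.eye(2), jnp.array([[0.01]])
m, P = ekf_step(m, P, y=jnp.array([0.45]), Q=Q, R=R)
print(m, jnp.diag(P))
```

If `f` and `h` are linear, their Jacobians are constant matrices and this step reduces to the ordinary Kalman filter predict/update.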