# Probabilistic Machine Learning: Lecture 21 - Hidden Markov Models

#### Introduction

Welcome to Lecture 21 of Probabilistic Machine Learning! This lecture continues our exploration of time series models, moving beyond the linear Gaussian assumptions of Gauss-Markov models to introduce **Hidden Markov Models (HMMs)**. We will delve into the connections between Gaussian Processes and Gauss-Markov Processes, discuss how to learn the parameters of these models, and, crucially, address the challenge of inference when the world isn't linear or Gaussian.

This notebook will provide theoretical insights and practical illustrations using **JAX** for numerical computations and **Plotly** for interactive visualizations.

#### 1. Reminder: Time Series as a Problem Class

As a reminder from Lecture 20 (Slide 2), our goal for this week is to understand time series as a problem class. We've covered:

* **Application Layer**: Data arriving as a stream.
* **Model Structure Layer**: Markov Chains / Hidden Markov Models.
* **Concrete Model Layer**: Gauss-Markov Models.
* **Algorithm Layer**: Kalman Filter & RTS Smoother.

Today, we focus on:

* **Theory**: What is the connection between Gaussian processes and Gauss-Markov Processes?
* **Parameters**: Can we learn the parameters of a Gauss-Markov model?
* **Generalization**: What if the world isn't Gaussian?

#### 2. Recap: Markov Chains and their Algorithmic Structure

A **Markov Chain** formalizes the notion of a stochastic process with finite, local memory: given the immediately preceding state $x_{t-1}$, the state $x_t$ is independent of all earlier states (Slides 3-4): $p(x_t | X_{0:t-1}) = p(x_t | x_{t-1})$. Observations $y_t$ are typically assumed to depend only on the current state $x_t$: $p(y_t | X) = p(y_t | x_t)$, where $X$ denotes the full state sequence.

This conditional independence structure allows for efficient inference operations (Slides 5-6), each performed in $\mathcal{O}(T)$ time for $T$ time steps:

* **Filtering**: Estimating the current state given all past and current observations.
    * **Predict**: $p(x_t | Y_{0:t-1}) = \int p(x_t | x_{t-1}) p(x_{t-1} | Y_{0:t-1}) dx_{t-1}$ (Chapman-Kolmogorov Equation)
    * **Update**: $p(x_t | Y_{0:t}) = \frac{p(y_t | x_t) p(x_t | Y_{0:t-1})}{p(y_t | Y_{0:t-1})}$ (Bayes' Theorem)
* **Smoothing**: Estimating past states given all observations (past, current, and future).
    * **Smooth**: $p(x_t | Y) = p(x_t | Y_{0:t}) \int p(x_{t+1} | x_t) \frac{p(x_{t+1} | Y)}{p(x_{t+1} | Y_{0:t})} dx_{t+1}$ (backward pass)

If the relationships are linear and Gaussian (Gauss-Markov Models), these operations are analytic and given by the Kalman Filter and Rauch-Tung-Striebel Smoother (Slides 7-9).
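
For a state with finitely many values, the integrals in the predict, update, and smooth equations above become sums, and the three steps can be written in a few lines. The following is a minimal sketch under made-up assumptions: the 2-state transition matrix `P`, the emission table `E`, the `prior`, and the observed symbols `ys` are all hypothetical, and `prior` is treated as the belief before the first transition (mirroring how the Kalman filter below predicts from `m0` before seeing the first observation).

```python
import jax.numpy as jnp

# Hypothetical 2-state chain; all numbers are made up for illustration.
P = jnp.array([[0.9, 0.1],
               [0.2, 0.8]])      # P[i, j] = p(x_t = j | x_{t-1} = i)
E = jnp.array([[0.7, 0.3],
               [0.1, 0.9]])      # E[i, k] = p(y_t = k | x_t = i)
prior = jnp.array([0.5, 0.5])    # belief about the state before the first transition
ys = [0, 0, 1, 1]                # observed symbols y_0, ..., y_3


def discrete_filter_smoother(P, E, prior, ys):
    """Forward (filter) and backward (smoother) pass for a discrete state."""
    predicted, filtered = [], []
    belief = prior
    for y in ys:
        pred = belief @ P              # predict: sum over x_{t-1} (Chapman-Kolmogorov)
        post = E[:, y] * pred          # update: likelihood times prediction ...
        post = post / post.sum()       # ... normalized by p(y_t | Y_{0:t-1})
        predicted.append(pred)
        filtered.append(post)
        belief = post
    smoothed = [filtered[-1]]          # at the final step, smoothing equals filtering
    for t in reversed(range(len(ys) - 1)):
        ratio = smoothed[0] / predicted[t + 1]           # p(x_{t+1} | Y) / p(x_{t+1} | Y_{0:t})
        smoothed.insert(0, filtered[t] * (P @ ratio))    # sum over x_{t+1}
    return jnp.stack(filtered), jnp.stack(smoothed)


filtered, smoothed = discrete_filter_smoother(P, E, prior, ys)
print(filtered)
print(smoothed)
```

Each pass touches every time step once, which is where the $\mathcal{O}(T)$ cost comes from; the Kalman filter and RTS smoother follow the same pattern with Gaussian means and covariances in place of probability vectors.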

In [None]:
import jax.numpy as jnp
import jax
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Set JAX to use 64-bit floats for numerical stability
jax.config.update("jax_enable_x64", True)


# --- Utility Functions (from Lecture 20, slightly adapted) ---
def generate_linear_gaussian_data(
    T=100,
    state_dim=1,
    obs_dim=1,
    A=None,
    Q_std=0.1,
    H=None,
    R_std=0.1,
    m0=None,
    P0_std=1.0,
    key=None,
):
    """Generates synthetic data from a linear Gaussian state-space model."""
    if key is None:
        key = jax.random.PRNGKey(0)

    if A is None:
        A = jnp.eye(state_dim) * 0.9  # State transition matrix
    if H is None:
        H = jnp.eye(obs_dim, state_dim)  # Observation matrix

    Q = jnp.eye(state_dim) * Q_std**2  # State noise covariance
    R = jnp.eye(obs_dim) * R_std**2  # Observation noise covariance

    if m0 is None:
        m0 = jnp.zeros(state_dim)  # Initial state mean
    P0 = jnp.eye(state_dim) * P0_std**2  # Initial state covariance

    states = [m0]
    observations = []

    for t in range(T):
        key, subkey_state, subkey_obs = jax.random.split(key, 3)
        # State transition: x_t = A @ x_{t-1} + w_t, w_t ~ N(0, Q)
        state_noise = jax.random.multivariate_normal(
            subkey_state, jnp.zeros(state_dim), Q
        )
        next_state = A @ states[-1] + state_noise
        states.append(next_state)

        # Observation: y_t = H @ x_t + v_t, v_t ~ N(0, R)
        obs_noise = jax.random.multivariate_normal(subkey_obs, jnp.zeros(obs_dim), R)
        observation = H @ next_state + obs_noise
        observations.append(observation)

    return jnp.array(states[1:]), jnp.array(observations), A, Q, H, R, m0, P0


def plot_time_series_results(
    true_signal,
    observations,
    estimated_mean,
    estimated_std=None,
    title="",
    fig=None,
    row=None,
    col=None,
):
    """Plots time series data with estimates and uncertainty."""
    if fig is None:
        fig = go.Figure()
    time_steps = np.arange(len(true_signal))  # NumPy array, since it is passed to np.concatenate and Plotly

    fig.add_trace(
        go.Scatter(
            x=time_steps,
            y=np.asarray(true_signal).flatten(),
            mode="lines",
            name="True Signal",
            line=dict(color="blue", width=2),
        ),
        row=row,
        col=col,
    )

    fig.add_trace(
        go.Scatter(
            x=time_steps,
            y=np.asarray(observations).flatten(),
            mode="markers",
            name="Observations",
            marker=dict(color="red", size=4, opacity=0.6),
        ),
        row=row,
        col=col,
    )

    fig.add_trace(
        go.Scatter(
            x=time_steps,
            y=np.asarray(estimated_mean).flatten(),
            mode="lines",
            name="Estimated Mean",
            line=dict(color="green", width=1, dash="dash"),
        ),
        row=row,
        col=col,
    )

    if estimated_std is not None:
        estimated_std_np = np.asarray(estimated_std).flatten()
        upper_bound = np.asarray(estimated_mean).flatten() + 2 * estimated_std_np
        lower_bound = np.asarray(estimated_mean).flatten() - 2 * estimated_std_np
        fig.add_trace(
            go.Scatter(
                x=np.concatenate([time_steps, time_steps[::-1]]),
                y=np.concatenate([upper_bound, lower_bound[::-1]]),
                fill="toself",
                fillcolor="rgba(0,255,0,0.1)",
                line_color="rgba(255,255,255,0)",
                name="2 Std Dev Uncertainty",
                showlegend=False,
            ),
            row=row,
            col=col,
        )

    fig.update_layout(
        title_text=title, title_x=0.5, xaxis_title="Time Step", yaxis_title="Value"
    )
    return fig


# Kalman Filter (from Lecture 20)
@jax.jit
def kalman_filter(observations, A, Q, H, R, m0, P0):
    """
    Implements the Kalman Filter for a linear Gaussian state-space model.
    """
    T = observations.shape[0]
    state_dim = m0.shape[0]

    filtered_means = jnp.zeros((T, state_dim))
    filtered_covs = jnp.zeros((T, state_dim, state_dim))
    predicted_means = jnp.zeros((T, state_dim))
    predicted_covs = jnp.zeros((T, state_dim, state_dim))

    m_prev = m0
    P_prev = P0

    for t in range(T):
        # Prediction Step (Time Update)
        m_minus = A @ m_prev  # Predictive mean
        P_minus = A @ P_prev @ A.T + Q  # Predictive covariance

        predicted_means = predicted_means.at[t].set(m_minus)
        predicted_covs = predicted_covs.at[t].set(P_minus)

        # Update Step (Measurement Update)
        z = observations[t] - H @ m_minus  # Innovation residual
        S = H @ P_minus @ H.T + R  # Innovation covariance
        K = jnp.linalg.solve(S.T, (P_minus @ H.T).T).T  # Kalman gain P^- H^T S^{-1}, via a linear solve

        m_t = m_minus + K @ z  # Updated mean
        P_t = (jnp.eye(state_dim) - K @ H) @ P_minus  # Updated covariance

        filtered_means = filtered_means.at[t].set(m_t)
        filtered_covs = filtered_covs.at[t].set(P_t)

        m_prev = m_t
        P_prev = P_t

    return filtered_means, filtered_covs, predicted_means, predicted_covs


# RTS Smoother (from Lecture 20)
@jax.jit
def rts_smoother(filtered_means, filtered_covs, predicted_means, predicted_covs, A, Q):
    """
    Implements the Rauch-Tung-Striebel (RTS) Smoother.
    """
    T = filtered_means.shape[0]
    state_dim = filtered_means.shape[1]

    smoothed_means = jnp.copy(filtered_means)
    smoothed_covs = jnp.copy(filtered_covs)

    for t in reversed(range(T - 1)):
        # Smoother Gain
        G_t = jnp.linalg.solve(predicted_covs[t + 1].T, (filtered_covs[t] @ A.T).T).T

        # Smoothed Mean
        smoothed_means = smoothed_means.at[t].set(
            filtered_means[t] + G_t @ (smoothed_means[t + 1] - predicted_means[t + 1])
        )

        # Smoothed Covariance
        smoothed_covs = smoothed_covs.at[t].set(
            filtered_covs[t]
            + G_t @ (smoothed_covs[t + 1] - predicted_covs[t + 1]) @ G_t.T
        )

    return smoothed_means, smoothed_covs


#### 3. Connection between Gaussian Processes and Gauss-Markov Processes

A fascinating theoretical connection exists between certain Gaussian Processes (GPs) and **Linear Time-Invariant Stochastic Differential Equations (LTI-SDEs)** (Slides 10-26). An LTI-SDE describes the local behavior of a Gaussian process. For our purposes, a linear, time-invariant SDE:

$$dx(t) = Fx(t) dt + L d\omega(t)$$

where $\omega(t)$ is a Wiener process (Brownian motion), describes a Gaussian process. For a deterministic initial value $x(0) = x_0$, its mean function $m(t)$ and covariance function $k(t, t')$ are available in closed form (Slide 25):

$$m(t) = e^{Ft} x_0$$
$$k(t, t') = \int_{0}^{\min(t,t')} e^{F(t-\tau)} LL^T e^{F^T(t'-\tau)} d\tau$$

Crucially, these LTI-SDEs can be **discretized exactly** to yield discrete, linear Gaussian transition models (Slide 25):

$$p(x(t_{i+1}) | x(t_i)) = \mathcal{N}(x(t_{i+1}); A_i x(t_i), Q_i)$$

with $A_i = e^{F\Delta t_i}$ and $Q_i = \int_0^{\Delta t_i} e^{F(\Delta t_i - \tau)} LL^T e^{F^T(\Delta t_i - \tau)} d\tau$, where $\Delta t_i = t_{i+1} - t_i$.

This means that certain GPs (known as Gauss-Markov processes) can be exactly represented as discrete-time linear Gaussian state-space models, allowing their inference to be performed in linear time using the Kalman Filter and Smoother (Slide 26).
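
The equivalence can be checked numerically in the simplest case. The sketch below assumes a scalar scaled Wiener process started at $x(0) = 0$ and illustrative values for `theta_sq`, `delta_t`, and the grid size `n`. It builds the covariance matrix once from the GP kernel $\theta^2 \min(t, t')$ and once by unrolling the discrete recursion $x_{i+1} = A x_i + w_i$ with $A = 1$ and $Q = \theta^2 \Delta t$.

```python
import jax.numpy as jnp

theta_sq, delta_t, n = 0.5, 0.1, 6          # illustrative values
t = delta_t * jnp.arange(1, n + 1)          # time grid t_1, ..., t_n (x(0) = 0 is fixed)

# GP view: kernel matrix k(t_i, t_j) = theta^2 * min(t_i, t_j)
K_gp = theta_sq * jnp.minimum(t[:, None], t[None, :])

# State-space view: x_i = sum_{j <= i} A^(i-j) w_j with w_j ~ N(0, Q)
A, Q = 1.0, theta_sq * delta_t
lags = jnp.arange(n)[:, None] - jnp.arange(n)[None, :]
M = jnp.where(lags >= 0, A ** jnp.maximum(lags, 0), 0.0)   # lower-triangular unrolling matrix
K_ssm = Q * (M @ M.T)

print(jnp.max(jnp.abs(K_gp - K_ssm)))       # largest entrywise difference
```

Both constructions describe the same joint Gaussian over the grid. The GP view needs the full $n \times n$ matrix, while the state-space view only ever stores the scalars `A` and `Q`.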

##### Examples of GPs and their SDE counterparts (Slides 28-31):

* **Scaled Wiener Process**: $F=0, L=\theta$. This corresponds to a GP with mean $x_0$ and covariance $k(t, t') = \theta^2 \min(t, t')$. Its discrete form has $A_i = I$ and $Q_i = \theta^2 \Delta t_i$.
* **Ornstein-Uhlenbeck Process**: $F = -1/\lambda, L = \sqrt{2}\theta/\sqrt{\lambda}$. In its stationary regime, this corresponds to a GP with the exponential kernel $k(t, t') = \theta^2 e^{-|t-t'|/\lambda}$.
* **Integrated Wiener Process (Wiener velocity model)**: This is a non-scalar example (the state contains position and velocity, and the noise drives the velocity), leading to a polynomial (cubic) spline kernel.
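
For non-scalar models such as the last example, the integral defining $Q_i$ is tedious by hand. One common numerical route, which the slides do not prescribe and which is used here only as an illustration, is the matrix-fraction construction often attributed to Van Loan: exponentiate one block matrix and read off both $A_i$ and $Q_i$. The helper name `discretize_lti_sde` and the values of `theta` and `delta_t` are assumptions made for this sketch.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)


def discretize_lti_sde(F, L, delta_t):
    """Return (A, Q) for dx = F x dt + L dw using one block matrix exponential."""
    n = F.shape[0]
    block = jnp.block([[-F, L @ L.T],
                       [jnp.zeros((n, n)), F.T]]) * delta_t
    Phi = expm(block)
    A = Phi[n:, n:].T          # lower-right block is exp(F^T dt), so its transpose is A
    Q = A @ Phi[:n, n:]        # upper-right block is A^{-1} Q
    return A, Q


theta, delta_t = 0.7, 0.1                    # illustrative values
F = jnp.array([[0.0, 1.0], [0.0, 0.0]])      # d(position)/dt = velocity
L = jnp.array([[0.0], [theta]])              # the noise enters through the velocity only
A, Q = discretize_lti_sde(F, L, delta_t)
print(A)
print(Q)
```

For this particular model the integral can also be done by hand, giving $A = \begin{pmatrix} 1 & \Delta t \\ 0 & 1 \end{pmatrix}$ and $Q = \theta^2 \begin{pmatrix} \Delta t^3/3 & \Delta t^2/2 \\ \Delta t^2/2 & \Delta t \end{pmatrix}$, which is a useful check on the numerical result.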

Let's illustrate the discretization of a simple Wiener process SDE into its discrete-time Kalman filter parameters.

In [None]:
# --- Illustrating SDE Discretization (Wiener Process) ---


def discretize_wiener_process(delta_t, theta_sq):
    """Discretizes a scaled Wiener process SDE into discrete-time Kalman filter parameters."""
    # SDE: dx(t) = theta * d_omega(t)  (F=0, L=theta)
    # Discrete-time: x_{i+1} = A * x_i + w_i, w_i ~ N(0, Q)
    # A = exp(F * delta_t) = exp(0 * delta_t) = I
    # Q = integral_0^delta_t exp(F*(delta_t - tau)) * L @ L.T * exp(F.T*(delta_t - tau)) d_tau
    # For F=0, L=theta: Q = integral_0^delta_t theta^2 I d_tau = theta^2 * delta_t * I

    A_discrete = jnp.array([[1.0]])  # Identity matrix for 1D
    Q_discrete = jnp.array([[theta_sq * delta_t]])
    return A_discrete, Q_discrete


# Example usage:
delta_t = 0.1
theta_sq = 0.5  # Diffusion intensity theta^2; the per-step noise variance is Q = theta_sq * delta_t
A_wp, Q_wp = discretize_wiener_process(delta_t, theta_sq)

print("\n--- Wiener Process Discretization Example ---")
print(f"For delta_t = {delta_t}, theta_sq = {theta_sq}:")
print(f"Discrete A matrix: {A_wp}")
print(f"Discrete Q matrix: {Q_wp}")

# Simulate a Wiener process using the discrete model and Kalman filter
T_sim = 200
H_sim = jnp.array([[1.0]])  # Direct observation
R_std_sim = 0.2  # Observation noise
R_sim = jnp.array([[R_std_sim**2]])
m0_sim = jnp.array([0.0])
P0_std_sim = 0.1
P0_sim = jnp.array([[P0_std_sim**2]])

key_sim = jax.random.PRNGKey(100)
true_states_sim, observations_sim, _, _, _, _, _, _ = generate_linear_gaussian_data(
    T=T_sim,
    state_dim=1,
    obs_dim=1,
    A=A_wp,
    Q_std=jnp.sqrt(Q_wp[0, 0]),
    H=H_sim,
    R_std=R_std_sim,
    m0=m0_sim,
    P0_std=P0_std_sim,
    key=key_sim,
)

filtered_means_sim, filtered_covs_sim, _, _ = kalman_filter(
    observations_sim, A_wp, Q_wp, H_sim, R_sim, m0_sim, P0_sim
)
filtered_stds_sim = jnp.sqrt(jnp.diagonal(filtered_covs_sim, axis1=1, axis2=2))

fig_wiener = plot_time_series_results(
    true_states_sim,
    observations_sim,
    filtered_means_sim,
    filtered_stds_sim,
    title="Simulated Wiener Process with Kalman Filter",
)
fig_wiener.update_layout(height=600, width=800)
fig_wiener.show()


#### 4. Parameter Learning in Gauss-Markov Models

A crucial question is: Can we learn the parameters (hyperparameters) of a Gauss-Markov model, such as $A, Q, H, R$? (Slide 33)

Similar to learning kernel parameters in GPs (recap from Lecture 10, Slide 34), we can use **Bayesian Hierarchical Inference**. The key is to optimize the **model evidence (marginal likelihood)** $p(y | \theta)$, where $\theta$ represents all the unknown model parameters.

For Gauss-Markov Models, the evidence can be computed efficiently in $\mathcal{O}(N)$ time using the **Prediction Error Decomposition** (Slides 35-37):

$$p(y | \theta) = \prod_{i=0}^N p(y_i | y_{0:i-1}, \theta)$$

Each term $p(y_i | y_{0:i-1}, \theta)$ is the likelihood of the current observation given all previous observations, which can be computed directly from the Kalman filter's innovation residual $z_i$ and innovation covariance $S_i$:

$$p(y_i | y_{0:i-1}, \theta) = \mathcal{N}(y_i; H m_i^-, H P_i^- H^T + R) = \mathcal{N}(z_i; 0, S_i)$$

Therefore, the log evidence is:

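$$\log p(y | \theta) = -\frac{1}{2} \sum_{i=0}^N \left( z_i^T S_i^{-1} z_i + \log |S_i| + d \log 2\pi \right)$$

Here $d$ is the dimension of the observation $y_i$ (so $d = 1$ in the scalar examples of this notebook).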

This objective function can be optimized using gradient-based methods (e.g., L-BFGS) to find the maximum likelihood estimates of the model parameters. The gradients of the log-evidence with respect to the parameters can be computed using automatic differentiation, making this a powerful approach for learning in linear Gaussian state-space models.

While we won't implement a full parameter learning loop here (as it requires careful handling of parameter constraints and optimization), understanding this formulation is crucial. The Kalman filter naturally provides all the components needed to compute this likelihood.
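
To make the formulation concrete without building a full learning loop, here is a minimal sketch of the log evidence for a scalar model and its gradient with respect to the process noise. Everything specific is an assumption for illustration: the function name `log_evidence`, the fixed values `a=0.9`, `h=1.0`, `r_std=0.1`, the simulated data, and the choice to parameterize $Q$ through `log_q_std` so that it stays positive.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def log_evidence(log_q_std, ys, a=0.9, h=1.0, r_std=0.1, m0=0.0, p0=1.0):
    """Log marginal likelihood of a scalar linear Gaussian model via the Kalman filter."""
    q = jnp.exp(2.0 * log_q_std)       # Q = q_std^2, parameterized on the log scale
    r = r_std**2

    def step(carry, y):
        m, p = carry
        m_pred, p_pred = a * m, a * p * a + q            # predict
        z = y - h * m_pred                               # innovation residual
        s = h * p_pred * h + r                           # innovation variance
        k = p_pred * h / s                               # Kalman gain
        log_term = -0.5 * (z**2 / s + jnp.log(s) + jnp.log(2.0 * jnp.pi))
        return (m_pred + k * z, (1.0 - k * h) * p_pred), log_term

    init = (jnp.asarray(m0, dtype=ys.dtype), jnp.asarray(p0, dtype=ys.dtype))
    _, log_terms = jax.lax.scan(step, init, ys)
    return jnp.sum(log_terms)


# Simulated data with true q_std = 0.3 (illustrative)
k_w, k_v = jax.random.split(jax.random.PRNGKey(0))
w = 0.3 * jax.random.normal(k_w, (200,))
v = 0.1 * jax.random.normal(k_v, (200,))
_, xs = jax.lax.scan(lambda x, w_t: (0.9 * x + w_t,) * 2, jnp.asarray(0.0, dtype=w.dtype), w)
ys = xs + v

theta0 = jnp.log(0.3)
value, grad = jax.value_and_grad(log_evidence)(theta0, ys)
print(value, grad)
```

The pair `(value, grad)` is what a gradient-based optimizer such as L-BFGS would consume at each iteration; the optimizer itself and the handling of further parameters are left out here.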

#### 5. Generalization: What if the World Isn't Linear Gaussian? (Hidden Markov Models)

The real world is often not perfectly linear or Gaussian. What happens then? This leads us to **Hidden Markov Models (HMMs)** in a broader sense, and various approximate Bayesian filtering and smoothing algorithms (Slides 38-39):

| System Type              | State Transition $p(x_t \mid x_{t-1})$                | Observation $p(y_t \mid x_t)$                | Algorithm                                         |
| :----------------------- | :---------------------------------------------------- | :------------------------------------------- | :------------------------------------------------ |
| **Markovian System**     | General                                               | General                                      | General Bayesian filtering and smoothing           |
| **Linear Gaussian**      | $\mathcal{N}(x_t; A x_{t-1}, Q)$                      | $\mathcal{N}(y_t; H x_t, R)$                 | Kalman filter, RTS smoother                       |
| **Nonlinear Gaussian**   | $\mathcal{N}(x_t; a(x_{t-1}), Q)$                     | $\mathcal{N}(y_t; h(x_t), R)$                | Extended/Unscented/Particle filter, etc.          |
| **Non-Gaussian Obs.**   | $\mathcal{N}(x_t; A x_{t-1}, Q)$                      | $p(y_t \mid h(x_t))$                         | (Requires approximations)                         |
| **Hidden Markov Model**  | $p(x_t = j \mid x_{t-1} = i) = \Pi_{ij}$ (discrete)   | $\mathcal{N}(y_t; h(x_t), R)$                | Viterbi, Forward-Backward, Baum-Welch             |

---
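
As an illustration of the discrete-state row of the table, the following sketch decodes the most probable state sequence of a small HMM with Gaussian emissions using the Viterbi recursion. The 2-state transition matrix `Pi`, the initial distribution `init`, the emission means `means`, the noise level `r_std`, and the observations `ys` are all hypothetical values chosen for this example.

```python
import jax.numpy as jnp

# Hypothetical 2-state HMM with Gaussian emissions; every number is illustrative.
Pi = jnp.array([[0.95, 0.05],
                [0.10, 0.90]])          # Pi[i, j] = p(x_t = j | x_{t-1} = i)
init = jnp.array([0.5, 0.5])            # p(x_0)
means = jnp.array([0.0, 3.0])           # emission mean h(x) for each state
r_std = 1.0                             # emission standard deviation
ys = jnp.array([0.2, -0.4, 0.1, 2.9, 3.3, 2.7])


def viterbi(ys, Pi, init, means, r_std):
    """Most probable state path, computed with max-sum recursions in log space."""
    log_lik = -0.5 * ((ys[:, None] - means[None, :]) / r_std) ** 2   # up to a constant
    log_Pi = jnp.log(Pi)
    score = jnp.log(init) + log_lik[0]
    backpointers = []
    for t in range(1, len(ys)):
        cand = score[:, None] + log_Pi          # cand[i, j]: best path ending in i, then i -> j
        backpointers.append(jnp.argmax(cand, axis=0))
        score = jnp.max(cand, axis=0) + log_lik[t]
    path = [int(jnp.argmax(score))]             # best final state
    for bp in reversed(backpointers):           # trace the best predecessors backwards
        path.insert(0, int(bp[path[0]]))
    return path


path = viterbi(ys, Pi, init, means, r_std)
print(path)
```

The recursion has the same structure as the filtering pass, with the sum over the previous state replaced by a maximum. For these illustrative numbers the decoded path switches from state 0 to state 1 at the fourth observation.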

For continuous systems with nonlinear dynamics and/or nonlinear observations, a number of **approximate filters** have been developed. The first two below approximate the posterior by a Gaussian; the third does not:

- **Extended Kalman Filter (EKF):**  
    Linearizes the nonlinear functions around the current mean estimate using a Taylor expansion. This is an approximation, but often works well in practice.

- **Unscented Kalman Filter (UKF):**  
    Uses a deterministic sampling approach (the unscented transform) to propagate mean and covariance through the nonlinearities, avoiding explicit Jacobians.

- **Particle Filter:**  
    A non-parametric approach that represents the posterior distribution with a set of weighted particles. This method is suitable for highly nonlinear or non-Gaussian systems.
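
The difference between the first two approaches in the list above can be seen on a single Gaussian pushed through the observation function $h(x) = x^2/20$ used later in this notebook. The sketch below is a simplified scalar illustration, not a full filter: the prior $\mathcal{N}(0, 4)$, the helper names, and the basic unscented weights with `kappa=2.0` are assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def h(x):
    return x**2 / 20.0          # observation function of the notebook's nonlinear example


def linearized_moments(m, P):
    """EKF-style: first-order Taylor expansion of h around the mean m."""
    slope = jax.grad(h)(m)
    return h(m), slope**2 * P


def unscented_moments(m, P, kappa=2.0):
    """Unscented transform for a scalar state: three sigma points, no derivatives."""
    n = 1
    spread = jnp.sqrt((n + kappa) * P)
    points = jnp.array([m, m + spread, m - spread])
    weights = jnp.array([kappa, 0.5, 0.5]) / (n + kappa)
    ys = h(points)
    mean = jnp.sum(weights * ys)
    var = jnp.sum(weights * (ys - mean) ** 2)
    return mean, var


m, P = 0.0, 4.0     # illustrative prior N(0, 4): the mean sits at the flat point of h
print("linearized:", linearized_moments(m, P))
print("unscented: ", unscented_moments(m, P))
print("exact:     ", ((m**2 + P) / 20.0, (4 * m**2 * P + 2 * P**2) / 400.0))
```

At $m = 0$ the derivative of $h$ vanishes, so the linearization reports zero variance for the predicted observation, while the sigma points still register the curvature. That the unscented moments agree with the exact ones here is specific to a quadratic function, a Gaussian input, and this choice of `kappa`; it should not be expected in general.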

---

> **Note:**  
> These methods are beyond the scope of this introductory course, but are crucial for real-world applications where linearity and Gaussianity assumptions do not hold.  
> Time series analysis is a vast field, and these advanced filters form its backbone.

In [None]:
# --- Illustrating Non-Linear System (Conceptual) ---


def generate_nonlinear_data(T=100, noise_std=0.1, key=None):
    """Generates synthetic data from a simple nonlinear state-space model."""
    if key is None:
        key = jax.random.PRNGKey(0)

    states = [jnp.array([0.1])]
    observations = []

    for t in range(T):
        key, subkey_state, subkey_obs = jax.random.split(key, 3)
        # Nonlinear state transition: x_t = 0.5 * x_{t-1} + 25 * x_{t-1} / (1 + x_{t-1}^2) + 8 * cos(1.2 * t) + w_t
        # This is a common benchmark system (from Särkkä's book)
        nonlinear_term = (
            0.5 * states[-1]
            + 25 * states[-1] / (1 + states[-1] ** 2)
            + 8 * jnp.cos(1.2 * t)
        )
        state_noise = jax.random.normal(subkey_state, (1,)) * noise_std * 2
        next_state = nonlinear_term + state_noise
        states.append(next_state)

        # Nonlinear observation: y_t = x_t**2 / 20 + v_t
        obs_noise = jax.random.normal(subkey_obs, (1,)) * noise_std
        observation = states[-1] ** 2 / 20 + obs_noise
        observations.append(observation)

    return jnp.array(states[1:]), jnp.array(observations)


# Generate nonlinear data
T_nl = 100
key_nl = jax.random.PRNGKey(200)
true_states_nl, observations_nl = generate_nonlinear_data(
    T=T_nl, noise_std=0.5, key=key_nl
)

print("\n--- Simulating a Nonlinear System ---")
print(f"Generated {T_nl} time steps of nonlinear data.")

# Attempt to filter with a *linear* Kalman Filter (will perform poorly)
# We need to define linear A, Q, H, R for the Kalman filter, which won't match the true nonlinear system.
# Let's use some 'best guess' linear parameters.
A_linear_approx = jnp.array([[1.0]])  # Random-walk dynamics as a crude stand-in for the true transition
Q_linear_approx = jnp.array([[1.0]])  # Large process noise to try to absorb the unmodeled nonlinearity
H_linear_approx = jnp.array([[0.1]])  # Simple linear observation approx
R_linear_approx = jnp.array([[0.5**2]])  # Observation noise
m0_linear_approx = jnp.array([0.0])
P0_linear_approx = jnp.array([[1.0]])

filtered_means_nl, filtered_covs_nl, _, _ = kalman_filter(
    observations_nl,
    A_linear_approx,
    Q_linear_approx,
    H_linear_approx,
    R_linear_approx,
    m0_linear_approx,
    P0_linear_approx,
)
filtered_stds_nl = jnp.sqrt(jnp.diagonal(filtered_covs_nl, axis1=1, axis2=2))

fig_nonlinear_fail = plot_time_series_results(
    true_states_nl,
    observations_nl,
    filtered_means_nl,
    filtered_stds_nl,
    title="Nonlinear System: Kalman Filter (Linear Assumption) Performance",
)
fig_nonlinear_fail.update_layout(height=600, width=800)
fig_nonlinear_fail.show()


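As expected, the linear Kalman filter struggles to track the highly nonlinear true state. The observation $y_t = x_t^2/20 + v_t$ carries no information about the sign of $x_t$, and no fixed linear model $(A, H)$ can represent the true dynamics. Moreover, the covariance recursion of a linear Kalman filter does not depend on the observed values, so its reported uncertainty cannot adapt to the actual tracking error. This illustrates the need for approximate nonlinear filters (like the EKF, UKF, or particle filters) when dealing with such systems.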
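
For comparison, here is a minimal sketch of a bootstrap particle filter for the same benchmark model. It is self-contained and simulates its own short data set. The function names, the particle count, the multinomial resampling at every step, and the seeds are assumptions for illustration; the noise levels (`q_std=1.0`, `r_std=0.5`) and the initial state `x0=0.1` mirror the simulation above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def f(x, t):
    return 0.5 * x + 25.0 * x / (1.0 + x**2) + 8.0 * jnp.cos(1.2 * t)


def h(x):
    return x**2 / 20.0


def bootstrap_particle_filter(ys, key, n_particles=2000, q_std=1.0, r_std=0.5, x0=0.1):
    """Posterior-mean estimates from a bootstrap particle filter with multinomial resampling."""
    particles = jnp.full((n_particles,), x0)
    means = []
    for t, y in enumerate(ys):
        key, k_prop, k_res = jax.random.split(key, 3)
        # Propagate every particle through the transition model (the "bootstrap" proposal)
        particles = f(particles, t) + q_std * jax.random.normal(k_prop, (n_particles,))
        # Weight by the Gaussian observation likelihood, normalized in log space
        log_w = -0.5 * ((y - h(particles)) / r_std) ** 2
        w = jnp.exp(log_w - logsumexp(log_w))
        means.append(jnp.sum(w * particles))
        # Resample with replacement to counter weight degeneracy
        idx = jax.random.choice(k_res, n_particles, shape=(n_particles,), p=w)
        particles = particles[idx]
    return jnp.stack(means)


# Simulate a short trajectory from the same model
key = jax.random.PRNGKey(0)
x, xs, ys = jnp.asarray(0.1), [], []
for t in range(50):
    key, k_w, k_v = jax.random.split(key, 3)
    x = f(x, t) + 1.0 * jax.random.normal(k_w)
    xs.append(x)
    ys.append(h(x) + 0.5 * jax.random.normal(k_v))
xs, ys = jnp.stack(xs), jnp.stack(ys)

est = bootstrap_particle_filter(ys, jax.random.PRNGKey(1))
print(float(jnp.sqrt(jnp.mean((est - xs) ** 2))))   # RMSE of the posterior mean
```

Because the particles are propagated through the true nonlinear transition, the filter can use the dynamics to help resolve the sign ambiguity of the squared observation, which a fixed linear model cannot do. The posterior can still be bimodal, so the posterior mean is only one of several reasonable point estimates.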

#### 6. Summary

To summarize (Slide 40):

* **Markov Chains** capture finite memory of a time series through conditional independence.
* **Gauss-Markov models** map this structure to linear algebra.
* The **Kalman filter** and the **RTS smoother** are the names of the corresponding algorithms.
* **SDEs (Stochastic Differential Equations)** are the continuous-time limit of discrete-time stochastic recurrence relations, providing a powerful theoretical framework for continuous-time Gauss-Markov systems.
* Parameters of the model can be learnt by optimizing the (log) evidence, which is also $\mathcal{O}(N)$.
* Inference in **nonlinear or non-Gaussian models** can be carried out approximately, analogous to non-Gaussian GP models, using methods like Extended/Unscented Kalman Filters or Particle Filters.

Gauss-Markov models form the algorithmic scaffold for time-series models, and understanding their extensions to non-linear and non-Gaussian scenarios is key to applying probabilistic machine learning to real-world dynamic systems.

#### Exercises

**Exercise 1: Discretize Ornstein-Uhlenbeck Process**
Based on Slide 29, implement a `discretize_ornstein_uhlenbeck_process(delta_t, theta_sq, lambda_param)` function that returns the discrete $A_i$ and $Q_i$ matrices for the Ornstein-Uhlenbeck process. Then, use this in the simulation and Kalman filter, similar to the Wiener process example.

**Exercise 2: Visualize the Effect of Discretization Time Step**
In the Wiener process simulation (Section 3), vary `delta_t` (e.g., 1.0, 0.5, 0.05). How does the smoothness of the true state and the performance of the Kalman filter change as `delta_t` decreases? Because the discretization is exact, a smaller `delta_t` does not make the model itself more accurate; discuss instead the trade-off between temporal resolution and computational cost.

**Exercise 3: Conceptual Parameter Learning**
Consider the `log_evidence` formula for Gauss-Markov models. If you wanted to learn the `Q_std` (state noise standard deviation) for a simple 1D linear model, conceptually describe the optimization loop you would set up. What JAX functionalities would you use (e.g., `jax.value_and_grad`)? (No need to implement, just describe the steps).

**Exercise 4 (Advanced): Implement a Simple EKF Step (Conceptual)**
Take the `generate_nonlinear_data` function. Conceptually, how would you modify the `kalman_filter` to become an `extended_kalman_filter`? Specifically, describe how you would compute the linearized `A` and `H` matrices at each time step using `jax.jacobian` around the current state estimate. What challenges might arise?