# Probabilistic Machine Learning: Lecture 22 - Summary and Cleanup

#### Introduction

Welcome to Lecture 22 of Probabilistic Machine Learning! This lecture is a **summary and cleanup** of the course. We revisit the fundamental principles, key models, and algorithms of probabilistic machine learning, and highlight how they connect and where they apply.

This notebook gives a structured overview of how probabilities, Gaussian processes, deep learning, and time series models fit into one probabilistic framework. As before, we use **JAX** for numerical computations and **Plotly** for visualizations where appropriate.

#### 1. Probabilities: The Language of Reasoning Under Uncertainty (Lectures 1-3)

The course began by establishing **probabilities as the language of reasoning under uncertainty** (Slide 3). We learned that all inference tasks can be described by assigning probabilities (or probability density functions for continuous variables) jointly to all variables in a problem.

Key rules of probability:
* **Normalization**: $\int_{\mathbb{R}^d} p(x) dx = 1$
* **Sum Rule (Marginalization)**: $p(x_1) = \int_{\mathbb{R}} p(x_1, x_2) dx_2$
* **Product Rule (Chain Rule)**: $p(x_1, x_2) = p(x_1|x_2) p(x_2)$
* **Bayes' Theorem**: $p(x_1|x_2) = \frac{p(x_1) p(x_2|x_1)}{\int p(x_1) p(x_2|x_1) dx_1}$

These fundamental rules allow us to update our beliefs about unknown quantities given new evidence.
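
The rules above can be checked numerically on a small discrete example, where integrals become sums. The following is a minimal sketch; the joint table `joint` is a hypothetical choice made for illustration, not data from the lecture.

```python
import jax.numpy as jnp

# Hypothetical joint distribution p(x1, x2) over two binary variables.
# Rows index x1 in {0, 1}, columns index x2 in {0, 1}.
joint = jnp.array([[0.30, 0.10],
                   [0.20, 0.40]])

# Normalization: all entries sum to one.
total = joint.sum()

# Sum rule: marginalize out the other variable.
p_x1 = joint.sum(axis=1)   # p(x1)
p_x2 = joint.sum(axis=0)   # p(x2)

# Product rule: p(x1, x2) = p(x1 | x2) p(x2), so p(x1 | x2) = joint / p(x2).
p_x1_given_x2 = joint / p_x2[None, :]

# Bayes' theorem: recover p(x1 | x2) from p(x1) and p(x2 | x1).
p_x2_given_x1 = joint / p_x1[:, None]
numerator = p_x1[:, None] * p_x2_given_x1            # p(x1) p(x2 | x1)
bayes_posterior = numerator / numerator.sum(axis=0, keepdims=True)

print("p(x1):", p_x1)
print("p(x1 | x2 = 1):", bayes_posterior[:, 1])
```

Both routes to $p(x_1|x_2)$ agree, which is exactly the statement of Bayes' theorem.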

#### 2. Exponential Families: Typed Reasoning (Lectures 4-6)

We then moved to **Exponential Families**, which provide a structured way to define probability distributions that allow for tractable inference (Slide 4). They are characterized by a specific functional form:

$$p_w(x) = h(x) \exp[\phi(x)^T w - \log Z(w)] = \frac{h(x)}{Z(w)} e^{\phi(x)^T w}$$ 

where:
* $h(x)$: base measure
* $\phi(x)$: sufficient statistics
* $w$: natural parameters
* $Z(w)$: partition function (normalizing constant)
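
As a concrete instance of this form, the Bernoulli distribution can be written with $h(x) = 1$, $\phi(x) = x$, natural parameter $w = \log\frac{p}{1-p}$, and $Z(w) = 1 + e^w$. The sketch below checks this against the usual parameterization; the function name `bernoulli_expfam_pmf` and the value of `p` are illustrative assumptions.

```python
import jax.numpy as jnp

def bernoulli_expfam_pmf(x, w):
    """Bernoulli pmf in exponential-family form with h(x) = 1 and phi(x) = x."""
    log_Z = jnp.log1p(jnp.exp(w))       # log Z(w) = log(1 + e^w)
    return jnp.exp(x * w - log_Z)

p = 0.3                                  # hypothetical success probability
w = jnp.log(p / (1.0 - p))               # natural parameter (log-odds)

x = jnp.array([0.0, 1.0])
pmf_expfam = bernoulli_expfam_pmf(x, w)
pmf_direct = p ** x * (1.0 - p) ** (1.0 - x)

print("exponential-family form:", pmf_expfam)
print("direct form:            ", pmf_direct)
```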

A crucial property is the existence of **conjugate priors** for exponential families (Slide 5). If the likelihood belongs to an exponential family, there exists a prior distribution (also an exponential family) such that the posterior distribution also belongs to the same family. This simplifies Bayesian inference significantly.
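
A standard example of conjugacy is the Beta prior for a Bernoulli likelihood: the posterior is again a Beta distribution whose parameters are the prior parameters plus the observed counts. The sketch below assumes a small hypothetical data set and prior values chosen only for illustration.

```python
import jax.numpy as jnp

# Hypothetical coin-flip data: 1 = heads, 0 = tails.
data = jnp.array([1, 0, 1, 1, 0, 1, 1, 1])

# Beta(a0, b0) prior on the heads probability (assumed values).
a0, b0 = 2.0, 2.0

# Conjugate update: add the sufficient statistics (counts) to the prior parameters.
heads = data.sum()
tails = data.shape[0] - heads
a_post = a0 + heads
b_post = b0 + tails

posterior_mean = a_post / (a_post + b_post)
print(f"Posterior: Beta({float(a_post):.0f}, {float(b_post):.0f}), "
      f"mean {float(posterior_mean):.3f}")
```

No integration is needed: inference reduces to adding counts, which is what makes conjugate exponential-family models tractable.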

#### 3. Gaussians: Inference as Linear Algebra (Lectures 5-6)

**Gaussian distributions** are a particularly important exponential family because **inference reduces to linear algebra** (Slide 6). This makes them computationally efficient and widely applicable.

Key properties of Gaussians:
* **Products of Gaussians are Gaussians**: The product of two Gaussian densities is, up to normalization, again a Gaussian density, so combining Gaussian likelihoods and priors results in a Gaussian posterior.
* **Linear projections of Gaussians are Gaussians**: If $z \sim \mathcal{N}(\mu, \Sigma)$, then $Az \sim \mathcal{N}(A\mu, A\Sigma A^T)$.
* **Marginals of Gaussians are Gaussians**: Integrating out variables from a joint Gaussian distribution results in a Gaussian marginal.
* **Conditionals of Gaussians are Gaussians**: Conditioning on observed variables in a joint Gaussian distribution results in a Gaussian conditional.

This means that Bayesian inference, when dealing with Gaussian assumptions, becomes a series of matrix operations.
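
The marginal and conditional properties can be written out with the standard block formulas. The following is a minimal sketch for a two-dimensional joint Gaussian; the numbers in `mu`, `Sigma`, and `b_obs` are hypothetical.

```python
import jax.numpy as jnp

# Hypothetical joint Gaussian over (x_a, x_b), each one-dimensional.
mu = jnp.array([1.0, 2.0])
Sigma = jnp.array([[0.5, 0.1],
                   [0.1, 0.8]])

# Marginal of x_a: read off the corresponding blocks.
mu_a, Sigma_aa = mu[:1], Sigma[:1, :1]

# Conditional p(x_a | x_b = b):
#   mean = mu_a + Sigma_ab Sigma_bb^{-1} (b - mu_b)
#   cov  = Sigma_aa - Sigma_ab Sigma_bb^{-1} Sigma_ba
b_obs = jnp.array([3.0])
Sigma_ab, Sigma_bb = Sigma[:1, 1:], Sigma[1:, 1:]
gain = jnp.linalg.solve(Sigma_bb, Sigma_ab.T).T   # Sigma_ab Sigma_bb^{-1} (Sigma_bb is symmetric)
mu_cond = mu_a + gain @ (b_obs - mu[1:])
Sigma_cond = Sigma_aa - gain @ Sigma_ab.T

print("Marginal of x_a: mean", mu_a, "cov", Sigma_aa)
print("Conditional of x_a given x_b = 3: mean", mu_cond, "cov", Sigma_cond)
```

Observing $x_b$ shifts the mean of $x_a$ and can only shrink its variance.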

In [None]:
import jax.numpy as jnp
import jax
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Set JAX to use 64-bit floats for numerical stability
jax.config.update("jax_enable_x64", True)

# --- Utility Functions (from previous lectures, adapted) ---
def plot_time_series_results(true_signal, observations, estimated_mean, estimated_std=None, title="", fig=None, row=None, col=None):
    """Plots time series data with estimates and uncertainty."""
    if fig is None:
        fig = go.Figure()
    time_steps = np.arange(len(true_signal))  # NumPy array, so Plotly only receives NumPy inputs

    fig.add_trace(go.Scatter(
        x=time_steps,
        y=np.asarray(true_signal).flatten(),
        mode='lines',
        name='True Signal',
        line=dict(color='blue', width=2)
    ), row=row, col=col)

    fig.add_trace(go.Scatter(
        x=time_steps,
        y=np.asarray(observations).flatten(),
        mode='markers',
        name='Observations',
        marker=dict(color='red', size=4, opacity=0.6)
    ), row=row, col=col)

    fig.add_trace(go.Scatter(
        x=time_steps,
        y=np.asarray(estimated_mean).flatten(),
        mode='lines',
        name='Estimated Mean',
        line=dict(color='green', width=1, dash='dash')
    ), row=row, col=col)

    if estimated_std is not None:
        estimated_std_np = np.asarray(estimated_std).flatten()
        upper_bound = np.asarray(estimated_mean).flatten() + 2 * estimated_std_np
        lower_bound = np.asarray(estimated_mean).flatten() - 2 * estimated_std_np
        fig.add_trace(go.Scatter(
            x=np.concatenate([time_steps, time_steps[::-1]]),
            y=np.concatenate([upper_bound, lower_bound[::-1]]),
            fill='toself',
            fillcolor='rgba(0,255,0,0.1)',
            line_color='rgba(255,255,255,0)',
            name='2 Std Dev Uncertainty',
            showlegend=False
        ), row=row, col=col)

    fig.update_layout(title_text=title, title_x=0.5,
                      xaxis_title='Time Step',
                      yaxis_title='Value')
    return fig

def generate_linear_gaussian_data(T=100, state_dim=1, obs_dim=1, A=None, Q_std=0.1, H=None, R_std=0.1, m0=None, P0_std=1.0, key=None):
    """Generates synthetic data from a linear Gaussian state-space model."""
    if key is None:
        key = jax.random.PRNGKey(0)

    if A is None: A = jnp.eye(state_dim) * 0.9 # State transition matrix
    if H is None: H = jnp.eye(obs_dim, state_dim) # Observation matrix
    
    Q = jnp.eye(state_dim) * Q_std**2 # State noise covariance
    R = jnp.eye(obs_dim) * R_std**2 # Observation noise covariance

    if m0 is None: m0 = jnp.zeros(state_dim) # Initial state mean
    P0 = jnp.eye(state_dim) * P0_std**2 # Initial state covariance

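    # The trajectory starts deterministically at m0. P0 is only returned (for use as
    # the filter's prior covariance); it is not used to sample the initial state.
    states = [m0]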
    observations = []

    for t in range(T):
        key, subkey_state, subkey_obs = jax.random.split(key, 3)
        # State transition: x_t = A @ x_{t-1} + w_t, w_t ~ N(0, Q)
        state_noise = jax.random.multivariate_normal(subkey_state, jnp.zeros(state_dim), Q)
        next_state = A @ states[-1] + state_noise
        states.append(next_state)

        # Observation: y_t = H @ x_t + v_t, v_t ~ N(0, R)
        obs_noise = jax.random.multivariate_normal(subkey_obs, jnp.zeros(obs_dim), R)
        observation = H @ next_state + obs_noise
        observations.append(observation)

    return jnp.array(states[1:]), jnp.array(observations), A, Q, H, R, m0, P0

# Kalman Filter (from Lecture 20)
@jax.jit
def kalman_filter(observations, A, Q, H, R, m0, P0):
    """
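    Implements the Kalman Filter for a linear Gaussian state-space model.

    Returns the filtered and the one-step-ahead predicted means and covariances.
    Note: under jax.jit the Python loop below is unrolled at trace time, so
    compilation time grows with T; jax.lax.scan avoids this for long sequences.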
    """
    T = observations.shape[0]
    state_dim = m0.shape[0]

    filtered_means = jnp.zeros((T, state_dim))
    filtered_covs = jnp.zeros((T, state_dim, state_dim))
    predicted_means = jnp.zeros((T, state_dim))
    predicted_covs = jnp.zeros((T, state_dim, state_dim))

    m_prev = m0
    P_prev = P0

    for t in range(T):
        # Prediction Step (Time Update)
        m_minus = A @ m_prev # Predictive mean
        P_minus = A @ P_prev @ A.T + Q # Predictive covariance

        predicted_means = predicted_means.at[t].set(m_minus)
        predicted_covs = predicted_covs.at[t].set(P_minus)

        # Update Step (Measurement Update)
        z = observations[t] - H @ m_minus # Innovation residual
        S = H @ P_minus @ H.T + R # Innovation covariance
        K = jnp.linalg.solve(S, H @ P_minus).T # Kalman gain P_minus H^T S^{-1}, via a solve (S and P_minus are symmetric)

        m_t = m_minus + K @ z # Updated mean
        P_t = (jnp.eye(state_dim) - K @ H) @ P_minus # Updated covariance

        filtered_means = filtered_means.at[t].set(m_t)
        filtered_covs = filtered_covs.at[t].set(P_t)

        m_prev = m_t
        P_prev = P_t

    return filtered_means, filtered_covs, predicted_means, predicted_covs

# RTS Smoother (from Lecture 20)
@jax.jit
def rts_smoother(filtered_means, filtered_covs, predicted_means, predicted_covs, A, Q):
    """
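    Implements the Rauch-Tung-Striebel (RTS) Smoother.

    Runs backward over the Kalman filter outputs. Q is accepted for interface
    symmetry with kalman_filter but is not used here, because predicted_covs
    already contains A @ P @ A.T + Q.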
    """
    T = filtered_means.shape[0]
    state_dim = filtered_means.shape[1]

    smoothed_means = jnp.copy(filtered_means)
    smoothed_covs = jnp.copy(filtered_covs)

    for t in reversed(range(T - 1)):
        # Smoother Gain
        G_t = jnp.linalg.solve(predicted_covs[t+1].T, (filtered_covs[t] @ A.T).T).T

        # Smoothed Mean
        smoothed_means = smoothed_means.at[t].set(
            filtered_means[t] + G_t @ (smoothed_means[t+1] - predicted_means[t+1])
        )

        # Smoothed Covariance
        smoothed_covs = smoothed_covs.at[t].set(
            filtered_covs[t] + G_t @ (smoothed_covs[t+1] - predicted_covs[t+1]) @ G_t.T
        )

    return smoothed_means, smoothed_covs

# --- Example: Gaussian Linear Algebra ---
print("\n--- Gaussian Linear Algebra Example ---")

# Define two Gaussian distributions: p1 = N(x; m1, C1), p2 = N(x; m2, C2)
m1 = jnp.array([1.0, 2.0])
C1 = jnp.array([[0.5, 0.1], [0.1, 0.8]])

m2 = jnp.array([0.5, 1.5])
C2 = jnp.array([[0.7, -0.2], [-0.2, 0.6]])

# Product of two Gaussians (results in a Gaussian)
# Precisions add; the mean is the precision-weighted combination of the two means.
C1_inv = jnp.linalg.inv(C1)
C2_inv = jnp.linalg.inv(C2)
C_prod_inv = C1_inv + C2_inv
C_prod = jnp.linalg.inv(C_prod_inv)
m_prod = C_prod @ (C1_inv @ m1 + C2_inv @ m2)

print("Product of Gaussians - Mean:", m_prod)
print("Product of Gaussians - Covariance:\n", C_prod)

# Linear projection: z = A @ x + b, where x ~ N(m, C)
A_proj = jnp.array([[1.0, 0.5], [0.2, 1.0]])
b_proj = jnp.array([0.1, -0.3])

m_proj = A_proj @ m1 + b_proj
C_proj = A_proj @ C1 @ A_proj.T

print("\nLinear Projection - Mean:", m_proj)
print("Linear Projection - Covariance:\n", C_proj)


#### 4. Parametric Regression: Learning Functions with Gaussians (Lecture 7)

We applied Gaussian principles to **parametric regression** (Slide 7), where we model functions $f(x) = \phi(x)^T w$. If we place Gaussian priors on the weights $w$ and assume Gaussian likelihoods for the observations $y$, then the posterior distribution over the weights $p(w|y, \phi_X)$ and the posterior over the function values $p(f_X|y, \phi_X)$ are also Gaussian. All these computations involve linear algebra, providing analytic solutions for the mean and covariance of the posteriors.

This framework connects directly to L2-regularized (ridge) linear regression: the MAP estimate of $w$ under a zero-mean isotropic Gaussian prior is the ridge solution, with the Gaussian prior supplying the regularization term.
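
The closed-form posterior and its link to L2 regularization can be sketched in a few lines. The feature map `features`, the synthetic data, and the values of `prior_var` and `noise_std` below are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def features(x):
    """Hypothetical feature map phi(x) = [1, x, x^2]."""
    return jnp.stack([jnp.ones_like(x), x, x**2], axis=-1)

# Synthetic data (assumed for illustration).
key = jax.random.PRNGKey(0)
x_train = jnp.linspace(-1.0, 1.0, 20)
w_true = jnp.array([0.5, -1.0, 2.0])
noise_std = 0.1
y_train = features(x_train) @ w_true + noise_std * jax.random.normal(key, x_train.shape)

# Prior w ~ N(0, prior_var * I), likelihood y ~ N(Phi w, noise_std^2 * I).
prior_var = 1.0
Phi = features(x_train)
precision = Phi.T @ Phi / noise_std**2 + jnp.eye(3) / prior_var
post_cov = jnp.linalg.inv(precision)
post_mean = post_cov @ (Phi.T @ y_train) / noise_std**2

# Ridge regression with lambda = noise_std^2 / prior_var gives the same point estimate.
lam = noise_std**2 / prior_var
w_ridge = jnp.linalg.solve(Phi.T @ Phi + lam * jnp.eye(3), Phi.T @ y_train)

# Posterior over the function value at a test input.
x_test = jnp.array([0.25])
f_mean = features(x_test) @ post_mean
f_var = features(x_test) @ post_cov @ features(x_test).T

print("Posterior mean of w:", post_mean)
print("Ridge estimate of w:", w_ridge)
```

The ridge estimate matches the posterior mean, while the probabilistic treatment additionally returns `post_cov` and `f_var` as uncertainty estimates.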

#### 5. Gaussian Processes: Models with "Infinite Freedom" (Lectures 8-11)

**Gaussian Processes (GPs)** extend the concept of Gaussian distributions to functions (Slide 8). A GP is a collection of random variables, any finite number of which have a joint Gaussian distribution. It is fully specified by its mean function $m(\bullet)$ and covariance function (kernel) $k(\bullet, \circ)$:

$$f(\bullet) \sim \mathcal{GP}(m(\bullet), k(\bullet, \circ))$$

GPs are nonparametric models: they do not have a fixed, finite number of parameters. Instead, they define a distribution over functions directly. This gives them "infinite freedom": with a universal kernel, the posterior mean can approximate any continuous function on a compact domain arbitrarily well.

Inference in GPs (Slide 9) involves computing the posterior mean and covariance, which also have analytic forms for Gaussian likelihoods. While theoretically powerful, exact GP inference scales as $\mathcal{O}(N^3)$ with the number of data points $N$, necessitating approximations for large datasets.
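
A minimal sketch of exact GP regression with a zero mean function and a squared-exponential kernel is given below. The kernel hyperparameters, the training data, and the name `rbf_kernel` are assumptions for illustration; the Cholesky factorization is where the $\mathcal{O}(N^3)$ cost arises.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

jax.config.update("jax_enable_x64", True)

def rbf_kernel(xa, xb, lengthscale=0.5, variance=1.0):
    """Squared-exponential kernel on 1D inputs (hyperparameters are assumed values)."""
    sq_dist = (xa[:, None] - xb[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)

# Synthetic training data (assumed for illustration).
x_train = jnp.linspace(0.0, 3.0, 8)
y_train = jnp.sin(2.0 * x_train)
x_test = jnp.linspace(0.0, 3.0, 50)
noise_var = 1e-2

# Zero-mean GP posterior:
#   mean = k(x*, X) (K + sigma^2 I)^{-1} y
#   cov  = k(x*, x*) - k(x*, X) (K + sigma^2 I)^{-1} k(X, x*)
K = rbf_kernel(x_train, x_train) + noise_var * jnp.eye(x_train.shape[0])
K_s = rbf_kernel(x_test, x_train)
K_ss = rbf_kernel(x_test, x_test)

# The Cholesky factorization of the N x N matrix is the O(N^3) step.
chol = cho_factor(K, lower=True)
alpha = cho_solve(chol, y_train)
post_mean = K_s @ alpha
post_cov = K_ss - K_s @ cho_solve(chol, K_s.T)
post_std = jnp.sqrt(jnp.maximum(jnp.diag(post_cov), 0.0))
```

The posterior standard deviation `post_std` is small near the training inputs and never exceeds the prior standard deviation.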

#### 6. Classification: Approximate Inference with Non-Gaussian Likelihoods (Lectures 14-15)

When dealing with classification tasks, the likelihood function is typically non-Gaussian (e.g., sigmoid for binary classification, softmax for multi-class) (Slide 10). This makes the posterior distribution over the latent function (or parameters) non-Gaussian and often intractable.

To overcome this, we introduced **Laplace approximations**. The idea is to:
1.  Find the Maximum A Posteriori (MAP) estimate $\hat{f}$ of the latent function at the training points.
2.  Approximate the posterior around $\hat{f}$ with a Gaussian distribution centered at $\hat{f}$, whose covariance is the inverse of the Hessian of the negative log-posterior at $\hat{f}$.

This provides an approximate Gaussian posterior from which predictions and uncertainty can be derived, even for non-Gaussian likelihoods.
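
The two steps can be sketched on a small parametric stand-in: Bayesian logistic regression with a Gaussian prior on the weights $w$, rather than on a latent function. The data, the prior variance, and the use of Newton's method for the MAP search are assumptions made for this illustration.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical 1D binary classification data with a bias feature.
x = jnp.array([-2.0, -1.5, -1.0, -0.5, 0.5, 1.0, 1.5, 2.0])
y = jnp.array([0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0])
Phi = jnp.stack([jnp.ones_like(x), x], axis=1)
prior_var = 4.0  # assumed prior variance on the weights

def neg_log_posterior(w):
    """Negative log-posterior (up to a constant) with a sigmoid likelihood."""
    logits = Phi @ w
    log_lik = jnp.sum(y * logits - jnp.logaddexp(0.0, logits))
    log_prior = -0.5 * jnp.sum(w**2) / prior_var
    return -(log_lik + log_prior)

grad_fn = jax.grad(neg_log_posterior)
hess_fn = jax.hessian(neg_log_posterior)

# Step 1: find the MAP estimate with Newton's method.
w = jnp.zeros(2)
for _ in range(25):
    w = w - jnp.linalg.solve(hess_fn(w), grad_fn(w))
w_map = w

# Step 2: Gaussian approximation N(w_map, H^{-1}), with H the Hessian at the mode.
laplace_cov = jnp.linalg.inv(hess_fn(w_map))

print("MAP weights:", w_map)
print("Laplace covariance:\n", laplace_cov)
```

Automatic differentiation supplies both the gradient for the MAP search and the Hessian for the covariance, so no derivatives are written by hand.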

#### 7. Deep Learning: Any Deep Network as a GP (Lectures 16-19)

A significant insight from the course is how **deep neural networks can be viewed as Gaussian Processes** through Laplace approximations (Slide 11). This involves four key steps:

1.  Realize that the deep network's regularized training loss (empirical risk plus regularizer) can be read as the negative log-posterior of its parameters, up to an additive constant.
2.  Train the deep network as usual to find the MAP estimate of its parameters, $\theta_*$.
3.  At $\theta_*$, compute a Gaussian (Laplace) approximation of the parameter posterior, with covariance given by the inverse Hessian of the loss.
4.  Linearize the deep network's output function around $\theta_*$.

This process yields an approximate GP where the mean function is the trained deep network itself, and the covariance function is the **Laplace tangent kernel**. This approach allows deep learning models to approximately inherit the probabilistic functionality of GPs: it provides principled uncertainty quantification, which helps mitigate issues like pathological overconfidence, and enables capabilities like continual learning.
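
Steps 3 and 4 can be sketched for a tiny regression network. Several simplifications are assumed here: `theta_star` is a random stand-in rather than an actually trained MAP estimate, the Hessian is replaced by the generalized Gauss-Newton matrix plus a prior precision, and the likelihood is Gaussian. The names `network` and `HIDDEN` are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

HIDDEN = 4  # hypothetical tiny network: 1 input, HIDDEN tanh units, 1 output

def network(theta, x):
    """Tiny MLP with all parameters packed into one flat vector theta."""
    w1 = theta[:HIDDEN]
    b1 = theta[HIDDEN:2 * HIDDEN]
    w2 = theta[2 * HIDDEN:3 * HIDDEN]
    b2 = theta[3 * HIDDEN]
    hidden = jnp.tanh(x[:, None] * w1[None, :] + b1[None, :])
    return hidden @ w2 + b2

# Stand-in for the trained weights theta_*: random values, NOT an actual MAP estimate.
theta_star = 0.5 * jax.random.normal(jax.random.PRNGKey(0), (3 * HIDDEN + 1,))
x_train = jnp.linspace(-1.0, 1.0, 10)
x_test = jnp.array([0.0, 3.0])
noise_var, prior_precision = 0.1**2, 1.0  # assumed values

# Step 4: linearize via Jacobians of the outputs with respect to theta.
J_train = jax.jacobian(network)(theta_star, x_train)   # shape (10, num_params)
J_test = jax.jacobian(network)(theta_star, x_test)     # shape (2, num_params)

# Step 3: Gaussian approximation of the parameter posterior. The Hessian is replaced
# by the generalized Gauss-Newton matrix plus the prior precision (an assumption).
num_params = theta_star.shape[0]
precision = J_train.T @ J_train / noise_var + prior_precision * jnp.eye(num_params)
param_cov = jnp.linalg.inv(precision)

# Resulting approximate GP: mean = the network itself, covariance = J(x) Sigma J(x')^T.
gp_mean = network(theta_star, x_test)
gp_cov = J_test @ param_cov @ J_test.T
gp_std = jnp.sqrt(jnp.diag(gp_cov))
```

The Gauss-Newton substitution keeps the precision matrix positive definite even when the true Hessian is not, which is one common reason for using it in place of the exact Hessian.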

#### 8. Markov Chains: $\mathcal{O}(T)$ Inference in Time Series (Lectures 20-21)

We then shifted our focus to **time series**, where data arrives sequentially and efficient, often real-time, inference is critical. **Markov Chains** formalize the notion of a stochastic process with a local finite memory (Slide 12).

Key operations in Markov Chains, performed in $\mathcal{O}(T)$ time for $T$ time steps:
* **Filtering**: Estimating the current state given past and current observations (Predict and Update steps).
* **Smoothing**: Estimating past states given all observations (past, current, and future).

For **Gauss-Markov Models** (where transitions and observations are linear and Gaussian), these operations have analytic solutions given by the **Kalman Filter** (for filtering) and the **Rauch-Tung-Striebel (RTS) Smoother** (for smoothing). We also saw how certain GPs (Gauss-Markov processes) can be derived from Stochastic Differential Equations (SDEs), which are continuous-time generalizations of these discrete-time linear Gaussian systems.
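
The $\mathcal{O}(T)$ structure of filtering is explicit when the recursion is written as a scan over time steps: each step touches only the previous belief and the current observation. The sketch below is one possible formulation using `jax.lax.scan`; the function name `kalman_filter_scan`, the scalar model, and the synthetic observations are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def kalman_filter_scan(observations, A, Q, H, R, m0, P0):
    """Kalman filter written with jax.lax.scan: one predict/update pair per time step."""

    def step(carry, y_t):
        m_prev, P_prev = carry
        # Predict
        m_minus = A @ m_prev
        P_minus = A @ P_prev @ A.T + Q
        # Update
        S = H @ P_minus @ H.T + R
        K = jnp.linalg.solve(S, H @ P_minus).T   # K = P_minus H^T S^{-1} (S, P_minus symmetric)
        m_t = m_minus + K @ (y_t - H @ m_minus)
        P_t = P_minus - K @ S @ K.T
        return (m_t, P_t), (m_t, P_t)

    _, (means, covs) = jax.lax.scan(step, (m0, P0), observations)
    return means, covs

# Hypothetical scalar model and data.
A = jnp.array([[0.95]])
Q = jnp.array([[0.01]])
H = jnp.array([[1.0]])
R = jnp.array([[0.25]])
m0 = jnp.zeros(1)
P0 = jnp.eye(1)
observations = jnp.sin(0.1 * jnp.arange(50.0))[:, None]

means, covs = jax.jit(kalman_filter_scan)(observations, A, Q, H, R, m0, P0)
```

Unlike a Python `for` loop under `jax.jit`, the scan is compiled once regardless of the number of time steps.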

In [None]:
# --- Example: Kalman Filter for Time Series (from Lecture 20) ---
print("\n--- Kalman Filter for Time Series Example ---")

# Generate data
T_kf = 100 # Number of time steps
state_dim_kf = 1 # 1D state
obs_dim_kf = 1 # 1D observation

A_kf = jnp.array([[0.95]]) # State transition matrix
Q_std_kf = 0.1 # Standard deviation of state noise
H_kf = jnp.array([[1.0]]) # Observation matrix
R_std_kf = 0.5 # Standard deviation of observation noise
m0_kf = jnp.array([0.0]) # Initial state mean
P0_std_kf = 1.0 # Initial state standard deviation

key_kf = jax.random.PRNGKey(123)
true_states_kf, observations_kf, _, _, _, _, _, _ = generate_linear_gaussian_data(
    T=T_kf, state_dim=state_dim_kf, obs_dim=obs_dim_kf,
    A=A_kf, Q_std=Q_std_kf, H=H_kf, R_std=R_std_kf, m0=m0_kf, P0_std=P0_std_kf, key=key_kf
)

filtered_means_kf, filtered_covs_kf, predicted_means_kf, predicted_covs_kf = kalman_filter(
    observations_kf, A_kf, jnp.array([[Q_std_kf**2]]), H_kf, jnp.array([[R_std_kf**2]]), m0_kf, jnp.array([[P0_std_kf**2]])
)
filtered_stds_kf = jnp.sqrt(jnp.diagonal(filtered_covs_kf, axis1=1, axis2=2))

smoothed_means_kf, smoothed_covs_kf = rts_smoother(
    filtered_means_kf, filtered_covs_kf, predicted_means_kf, predicted_covs_kf, A_kf, jnp.array([[Q_std_kf**2]])
)
smoothed_stds_kf = jnp.sqrt(jnp.diagonal(smoothed_covs_kf, axis1=1, axis2=2))

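# plot_time_series_results returns the figure; call .show() explicitly so that
# both figures are displayed (a notebook only auto-displays the last expression).
fig_filtered = plot_time_series_results(
    true_states_kf,
    observations_kf,
    filtered_means_kf,
    filtered_stds_kf,
    title='Kalman Filter Results (Summary Example)'
)
fig_filtered.show()

fig_smoothed = plot_time_series_results(
    true_states_kf,
    observations_kf,
    smoothed_means_kf,
    smoothed_stds_kf,
    title='RTS Smoother Results (Summary Example)'
)
fig_smoothed.show()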


#### 9. Bayesian Hierarchical Learning (Lecture 10 and Lecture 22, Slides 14-22)

Beyond simply inferring variables, we also discussed **Bayesian Hierarchical Learning** (Slide 14), which provides a formalism for training hyperparameters and fitting architectures/models. The core idea is to treat the hyperparameters $\theta$ as random variables with their own posterior. In practice, this posterior is often replaced by a point estimate obtained by maximizing the **model evidence (marginal likelihood)** $p(y | \theta)$ (Slide 15).

$$p(\theta | y) = \frac{p(y | \theta) p(\theta)}{\int p(y | \theta') p(\theta') d\theta'}$$

The evidence naturally balances model fit with model complexity (Occam's Razor, Slide 18), penalizing overly complex models that don't significantly improve the data likelihood. While the evidence integral is often intractable for general models, we explored two main approaches:

* **Laplace Approximations**: These provide a general, approximate way to calculate the evidence by approximating the posterior with a Gaussian (Slide 19). This was particularly relevant for deep learning models.

* **Expectation-Maximization (EM) Algorithm**: For certain models, the EM algorithm offers a tractable iterative solution for maximizing the evidence (Slides 20-21). It alternates between:
    * **E-step**: Computing the expected complete-data log-likelihood $q(\theta, \theta^t) = \int p(z | y, \theta^t) \log p(y, z | \theta) dz$.
    * **M-step**: Maximizing $q(\theta, \theta^t)$ with respect to $\theta$ to get $\theta^{t+1}$.

We saw an example of EM for Gauss-Markov models (Slide 22), where the complete-data log-likelihood neatly separates into local terms, making the expectation easy to compute. This allows for learning the system matrices ($A, Q, H, R$) from data.
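
The E-step and M-step can be illustrated on a much simpler linear-Gaussian latent model than the full Gauss-Markov case. In the sketch below, which is a toy model assumed for illustration, a latent $z_i \sim \mathcal{N}(0, s^2)$ with known $s^2$ is observed through noise of unknown variance $r$, and EM is used to learn $r$. For this model the evidence maximizer is also available in closed form, which allows a direct check.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Toy latent-variable model (assumed for illustration):
#   z_i ~ N(0, s2) with s2 known,  y_i = z_i + noise,  noise ~ N(0, r) with r unknown.
s2, r_true = 1.0, 4.0
key_z, key_e = jax.random.split(jax.random.PRNGKey(0))
z = jnp.sqrt(s2) * jax.random.normal(key_z, (500,))
y = z + jnp.sqrt(r_true) * jax.random.normal(key_e, (500,))

def log_evidence(r):
    """log p(y | r): marginally, y_i ~ N(0, s2 + r)."""
    var = s2 + r
    return jnp.sum(-0.5 * jnp.log(2.0 * jnp.pi * var) - 0.5 * y**2 / var)

r = 1.0  # initial guess
evidence_trace = [log_evidence(r)]
for _ in range(100):
    # E-step: posterior p(z_i | y_i, r) = N(m_i, v)
    v = 1.0 / (1.0 / s2 + 1.0 / r)
    m = v * y / r
    # M-step: maximize the expected complete-data log-likelihood over r
    r = jnp.mean((y - m) ** 2 + v)
    evidence_trace.append(log_evidence(r))

# Maximizer of the evidence, available in closed form for this toy model.
r_closed_form = jnp.mean(y**2) - s2

print("EM estimate of r:", r)
print("Closed-form maximizer:", r_closed_form)
```

Each EM iteration cannot decrease the evidence, so `evidence_trace` is non-decreasing. In the Gauss-Markov setting the same pattern applies, with the E-step provided by the smoother.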

#### Conclusion: The Probabilistic Machine Learning Toolbox

This course has equipped you with a powerful **toolbox** for probabilistic machine learning (Slide 13):

* **Framework**: The fundamental rules of probability (sum, product, Bayes' theorem) for reasoning under uncertainty.
* **Modeling**: Directed Graphical Models, Exponential Families, Gaussian Distributions, Kernels, Markov Chains, and Deep Networks as building blocks for complex models.
* **Computation**: Automatic differentiation, MAP estimation with Laplace approximations, and linear algebra as core computational primitives.

You've learned to approach machine learning problems from a probabilistic perspective, understanding not just predictions but also the uncertainty associated with them. This foundation is invaluable for building robust, interpretable, and data-efficient AI systems.

#### Exercises

These exercises encourage you to reflect on the entire course content.

**Exercise 1: Connecting Concepts**
Choose two distinct concepts from the course (e.g., "Exponential Families" and "Deep Learning as GPs") and explain, in your own words, how they are related or how insights from one can inform the other. Provide a brief example or scenario where understanding this connection would be beneficial.

**Exercise 2: Probabilistic vs. Deterministic Approaches**
Reflect on a real-world problem (e.g., medical diagnosis, financial forecasting, autonomous driving). Discuss why a probabilistic machine learning approach might be preferred over a purely deterministic one for this problem. Specifically, how would uncertainty quantification add value?

**Exercise 3: The Power of Linear Algebra**
The course emphasized "Inference as Linear Algebra" for Gaussians. Identify a specific algorithm or model discussed in the course (e.g., Kalman Filter, GP regression) and explain how its core computations are fundamentally linear algebraic operations. Why is this computationally advantageous?

**Exercise 4: Future Directions**
Based on the summary, what aspect of probabilistic machine learning or its applications are you most interested in exploring further? Briefly explain why and suggest a potential next step for learning more about it (e.g., a specific paper, book, or advanced course topic).