# The Role of Linear Algebra in Gaussian Processes

Welcome to the final installment of our Gaussian Processes (GPs) blog series! Throughout this series, we've explored:

- The fundamentals of GPs as distributions over functions
- The power and flexibility of kernels
- Building complex models with additive kernels
- Learning hyperparameters for better model performance

In this lecture, we'll shift our focus to the computational heart of Gaussian Processes: **linear algebra**. While GPs provide elegant probabilistic models, their practical application often depends on efficiently solving large linear systems. Understanding the linear algebra behind GPs is crucial for both using them effectively and appreciating their computational challenges.

We'll cover:

- **Why GPs are computationally expensive, especially for large datasets**
- **The key linear algebra operations that power GP inference**
- **How these operations relate to numerical methods like LU and Cholesky decomposition**
- **The computational complexity of GPs compared to deep learning models**
- **Whether uncertainty quantification is really what makes GPs expensive**

This lecture will provide a crucial "behind-the-scenes" look at what makes GPs tick computationally, equipping you with the knowledge to understand both their power and their limitations.

## Recap from Last Week: Key GP Insights

Before diving into the computational details, let's quickly recap the main insights from our previous lecture (**Lecture 11: Understanding Kernels and GPs**):

---

### 1. **GPs as Function Distributions**

- **Gaussian Processes (GPs)** are probability distributions over function spaces.
- The exact nature of this probability space is often subtle and depends on the kernel.
- Understanding the sample space of a GP requires studying the kernel itself.

---

### 2. **Kernel-Covariance Equivalence**

- Every covariance function is a kernel, and every positive-definite kernel can serve as the covariance function of a GP.
- **Key takeaway:** Specifying a valid kernel is sufficient to define a GP.

---

### 3. **Kernels as "Infinite Matrices"**

- Kernels have eigenfunctions and eigenvalues, analogous to matrices having eigenvectors and eigenvalues.
- This allows us to conceptualize kernels as a kind of "infinite matrix" that spans a space of functions.

---

### 4. **Reproducing Kernel Hilbert Space (RKHS)**

- The RKHS is the space of all possible posterior mean functions of the GP regression method.
- This links GPs to Frequentist kernel machines.

---

### 5. **Posterior Variance as Worst-Case Error**

- The GP's posterior variance (the Bayesian expected squared error) has a Frequentist interpretation as a worst-case squared error over the unit ball of the RKHS (in the noise-free case).

---

### 6. **GP Samples vs. RKHS**

- **Crucial distinction:** Sample paths drawn from a GP generally do **not** lie in the RKHS of the kernel.
- GP samples tend to be "rougher" and reside in a larger function space than the RKHS.

---

> **Note:** This last point, about GP samples not being in the RKHS, often raises questions, which we'll address further in this lecture.

---

## What About the Samples? Why Does it Matter if Samples are Outside the RKHS?

A common question from beginners:

> *"Why does it matter in practice that our samples can be outside of the RKHS? Aren't most of them concentrated around the mean and thus inside the RKHS anyway?"*

- The intuition that functions "close" to each other pointwise should also be "close" in a function space norm is appealing but often misleading.
- Two functions can be close at every point, i.e., $|f(x) - g(x)| < \epsilon$ for all $x \in X$, and still be far apart in a norm on the function space, i.e., we could have $\|f - g\| \gg \epsilon$.
- Different norms measure different aspects of "closeness" or "smoothness."

---

### **Driscoll's Zero-One Law**

In most interesting cases, especially for infinite-dimensional RKHSs, GP samples do **not** lie in the RKHS of the kernel $k$. This is a consequence of a powerful result:

> **Theorem (Driscoll's Zero-One Law, simplified; see Kanagawa et al., 2018, Theorem 4.9):**  
> Let $f \sim \mathcal{GP}(m, k)$ be a Gaussian process with $m \in \mathcal{H}_k$ on the probability space $(\Omega, \mathcal{F}, P)$. If $\mathcal{H}_k$ is infinite-dimensional, then:
>
> $$
> P(f \in \mathcal{H}_k) = 0
> $$
>
> In words: if the RKHS is infinite-dimensional (as it is for common kernels like the Squared Exponential), a sample path drawn from the GP lies in the RKHS with probability zero. The name "zero-one law" comes from the general form of the theorem, which says that $P(f \in \mathcal{H}_k)$ is always either 0 or 1, never anything in between.
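
A quick numerical illustration, under the assumption of a Mercer expansion $k(x, x') = \sum_i \lambda_i \phi_i(x) \phi_i(x')$ with $L^2$-orthonormal eigenfunctions: a zero-mean GP sample can be written as $f = \sum_i z_i \sqrt{\lambda_i}\, \phi_i$ with i.i.d. $z_i \sim \mathcal{N}(0, 1)$ (the Karhunen–Loève expansion), and the squared RKHS norm of $\sum_i c_i \phi_i$ is $\sum_i c_i^2 / \lambda_i$. The sketch below truncates the expansion after $M$ terms, using a hypothetical eigenvalue sequence $\lambda_i = i^{-2}$. The squared $L^2$ norm settles down, while the squared RKHS norm equals $\sum_{i \le M} z_i^2$ and grows roughly like $M$, no matter how fast the eigenvalues decay.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical eigenvalue sequence lambda_i = i^{-2} (any summable sequence behaves the same way)
M_max = 100_000
i = np.arange(1, M_max + 1)
lam = i ** -2.0

# Karhunen-Loeve coefficients of one GP sample: c_i = z_i * sqrt(lambda_i), z_i ~ N(0, 1)
z = rng.standard_normal(M_max)
c = z * np.sqrt(lam)

for M in [10, 100, 1_000, 10_000, 100_000]:
    l2_sq = np.sum(c[:M] ** 2)              # squared L2 norm: sum_i c_i^2
    rkhs_sq = np.sum(c[:M] ** 2 / lam[:M])  # squared RKHS norm: sum_i c_i^2 / lambda_i
    print(f"M={M:>6}: ||f_M||_L2^2 = {l2_sq:.3f}, ||f_M||_H^2 = {rkhs_sq:.1f}")
```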

---

## Why Should You Care in Practice? Smoothness Properties

This theoretical distinction has practical implications, particularly concerning **smoothness properties**:

- The RKHS norm often penalizes "roughness" or "complexity."
- If GP samples almost surely have infinite RKHS norm, it implies they are "rougher" than functions typically found in the RKHS.

**This matters when:**

- **Interpreting samples:** If you're sampling from a GP to visualize possible realizations, understand that these samples are not as "smooth" (in the RKHS sense) as the posterior mean.
- **Choosing kernels:** If you have prior knowledge about the true underlying function's smoothness, choose a kernel whose sample space (not just its RKHS) aligns with that knowledge. This can lead to better generalization and faster learning.

---

## What About the Samples? Detailed Look at Sample Spaces of Gaussian Processes

It's generally difficult to talk about the precise "sample space" of a GP. Instead, one typically identifies other spaces of functions that contain the samples as a subset.

- If we know the target function lies in a certain space (e.g., it's known to be continuous or differentiable), we should choose the kernel of the GP such that its sample space matches that space as closely as possible. This can accelerate learning.

**Examples of function spaces that might contain GP samples:**

- $\mathbb{R}^X$ (the space of all real-valued functions on $X$, which is too large for practical use).
- **Banach space $C(X)$** of continuous functions. Kernels like the Squared Exponential and Matérn ($\nu > 0$) produce continuous sample paths.
- **Banach space $C^k(X)$** of $k$-times continuously differentiable functions (relevant for derivative observations or modeling smooth physical processes).
- **Sobolev spaces $W_2^k(X)$** (e.g., for inferring solutions to PDEs, where functions are required to have square-integrable derivatives up to order $k$).
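
One simple way to make "roughness" concrete is the expected quadratic variation of sample paths on a uniform grid of $n$ steps over $[0, 1]$. For a stationary kernel written as a function of distance, $\mathbb{E}\big[\sum_j (f(x_{j+1}) - f(x_j))^2\big] = 2n\,(k(0) - k(1/n))$. The sketch below (unit variance and lengthscale are assumptions) compares the Matérn-1/2 (exponential) kernel with the Squared Exponential as the grid is refined: the former converges to a positive constant, as for continuous but nowhere-differentiable paths like Brownian motion, while the latter goes to zero, as for smooth paths.

```python
import numpy as np


def matern12(r, lengthscale=1.0, variance=1.0):
    """Matern-1/2 (exponential) kernel as a function of the distance r."""
    return variance * np.exp(-np.abs(r) / lengthscale)


def squared_exponential(r, lengthscale=1.0, variance=1.0):
    """Squared Exponential kernel as a function of the distance r."""
    return variance * np.exp(-0.5 * r**2 / lengthscale**2)


def expected_quadratic_variation(kernel, n):
    """E[sum_j (f(x_{j+1}) - f(x_j))^2] on an n-step uniform grid over [0, 1].

    By stationarity, each increment has variance k(0) + k(0) - 2 k(h) with h = 1/n.
    """
    h = 1.0 / n
    return n * 2.0 * (kernel(0.0) - kernel(h))


for n in [10, 100, 1_000, 10_000]:
    qv_m = expected_quadratic_variation(matern12, n)
    qv_se = expected_quadratic_variation(squared_exponential, n)
    print(f"n={n:>6}: Matern-1/2 QV = {qv_m:.4f}, SE QV = {qv_se:.6f}")
```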

---

## GP Samples are Not in the RKHS! But Almost...

While GP samples are almost surely **not** in the RKHS, they belong to a kind of "completion" of the RKHS, often referred to as a **"power" of the RKHS**.

> **Theorem (Kanagawa et al., 2018; restricted from Steinwart, 2017, itself generalized from Driscoll, 1973):**  
> Let $\mathcal{H}_k$ be an RKHS whose kernel has the eigen-expansion $k(x, x') = \sum_{i \in I} \lambda_i \phi_i(x) \phi_i(x')$, and let $0 < \theta \leq 1$. Consider the $\theta$-power of $\mathcal{H}_k$ given by:
>
> $$
> \mathcal{H}_k^\theta = \left\{ f(x) := \sum_{i \in I} \alpha_i \lambda_i^{\theta/2} \phi_i(x) \quad \text{such that} \quad \|f\|^2_{\mathcal{H}_k^\theta} := \sum_{i \in I} \alpha_i^2 < \infty \right\}
> $$
>
> with inner product $\langle f, g \rangle_{\mathcal{H}_k^\theta} := \sum_{i \in I} \alpha_i \beta_i$ for $g = \sum_{i \in I} \beta_i \lambda_i^{\theta/2} \phi_i$.
>
> Then, if $\sum_{i \in I} \lambda_i^{1-\theta} < \infty$, a sample $f \sim \mathcal{GP}(0, k)$ satisfies $f \in \mathcal{H}_k^\theta$ with probability 1.

**Interpretation:**  
GP samples are "almost" in the RKHS: they lie in the slightly larger space $\mathcal{H}_k^\theta$ for every $\theta$ at which the kernel's eigenvalues decay fast enough for $\sum_{i \in I} \lambda_i^{1-\theta}$ to converge. For $\theta < 1$ and an infinite-dimensional RKHS, this "power" of the RKHS is strictly larger than the RKHS itself.
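
As a worked example with two hypothetical eigenvalue decay rates, the condition of the theorem is easy to check for polynomial and for geometric decay:

$$
\lambda_i = i^{-p}\ (p > 1): \quad \sum_{i \ge 1} \lambda_i^{1-\theta} = \sum_{i \ge 1} i^{-p(1-\theta)} < \infty \iff p(1-\theta) > 1 \iff \theta < 1 - \tfrac{1}{p}
$$

$$
\lambda_i = r^{i}\ (0 < r < 1): \quad \sum_{i \ge 1} \lambda_i^{1-\theta} = \sum_{i \ge 1} \left(r^{1-\theta}\right)^{i} < \infty \quad \text{for every } \theta < 1
$$

So with $p = 2$, samples lie in $\mathcal{H}_k^\theta$ for every $\theta < 1/2$. With geometric decay they lie in $\mathcal{H}_k^\theta$ for every $\theta < 1$, arbitrarily "close" to the RKHS, yet (by Driscoll's zero-one law) still not in $\mathcal{H}_k = \mathcal{H}_k^1$ itself.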

## How Expensive is GP Regression? What if the Dataset is Very Large?

Gaussian Process (GP) regression is celebrated for its flexibility and principled uncertainty quantification, but these benefits come at a significant computational cost—especially as datasets grow larger.

---

### **The Computational Bottleneck**

- **Kernel Matrix Construction:**  
    In GP regression, we construct the kernel (covariance) matrix $K_{XX}$, which is of size $N \times N$, where $N$ is the number of training data points.
- **Matrix Inversion or Linear Solves:**  
    Making predictions and computing the marginal likelihood both require either inverting $K_{XX}$ or solving linear systems involving it.

---

### **Why is This Expensive?**

- **Cubic Complexity:**  
    The direct inversion of an $N \times N$ matrix has computational complexity $\mathcal{O}(N^3)$.  
    - *Implication:* Doubling the dataset size increases computation time by a factor of eight!
- **Memory Usage:**  
    Storing the kernel matrix requires $\mathcal{O}(N^2)$ memory, which can also become prohibitive for large $N$.
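
To put rough numbers on this, here is a back-of-the-envelope calculator. It assumes the dense kernel matrix is stored in double precision and takes the roughly $N^3/3$ flops of a Cholesky factorization (the factorization typically used for GPs) as the dominant cost; real run times also depend on hardware and the linear algebra library.

```python
def exact_gp_cost(n: int, bytes_per_float: int = 8) -> tuple[float, float]:
    """Rough cost of one exact GP fit: (Cholesky flops, kernel-matrix bytes).

    Assumes ~n^3 / 3 flops for the factorization and a dense n x n matrix
    stored with bytes_per_float bytes per entry (8 for float64).
    """
    flops = n**3 / 3
    memory_bytes = bytes_per_float * n**2
    return flops, memory_bytes


for n in [1_000, 10_000, 100_000]:
    flops, mem = exact_gp_cost(n)
    print(f"N={n:>7}: ~{flops:.1e} flops, kernel matrix ~{mem / 1e9:.3g} GB")

# Doubling N multiplies the cubic term by 2**3 = 8
print(exact_gp_cost(2_000)[0] / exact_gp_cost(1_000)[0])
```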

---

### **Scalability in Practice**

- For small to moderate datasets (up to a few thousand points), standard GP regression is practical and efficient on a single machine.
- For large datasets (tens of thousands of points or more), the cubic time and quadratic memory scaling quickly become a bottleneck, making naive GP regression infeasible.

---

### **Why Does This Matter?**

Understanding these computational challenges is crucial for:
- **Choosing the right model for your data size**
- **Selecting appropriate approximation methods for scalability**
- **Appreciating the trade-offs between model expressiveness and computational feasibility**

---

> **Next:**  
> We'll dive deeper into the linear algebra operations at the heart of GP regression, and explore how modern numerical methods and approximations can help overcome these challenges.

# The Inside View on GP Regression: A Deeper Look into the Probabilistic ML Stack

Gaussian Process (GP) regression is a powerful and flexible tool in probabilistic machine learning. To truly understand its computational cost and structure, let's break down the GP regression process into different "layers" of abstraction, from the high-level application to the low-level numerical operations.

---

## 1. Application Layer

**What problem are we solving?**

- **Goal:** Learn a function $f: \mathcal{X} \to \mathbb{R}$ from input-output pairs $\{(x_i, y_i)\}_{i=1}^N$.
- **Setting:** This is a supervised machine learning problem, where we observe data and want to make predictions at new, unseen points.

---

## 2. Model Layer

**How do we model the data probabilistically?**

- **Prior:** We assume $f$ is a sample from a Gaussian process:
    $$
    p(f) = \mathcal{GP}(\mu, k)
    $$
    where $\mu$ is the mean function and $k$ is the kernel (covariance) function.

- **Likelihood:** The observed data $y$ are noisy observations of $f$ at the training inputs $X = [x_1, ..., x_N]$:
    $$
    p(y \mid f_X) = \mathcal{N}(y \mid f_X, \sigma^2 I_N)
    $$
    where $f_X = [f(x_1), ..., f(x_N)]^\top$ and $\sigma^2$ is the noise variance.

- **Inference:** Our goal is to compute the posterior distribution over functions given the data:
    $$
    p(f \mid y)
    $$
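
For the Gaussian likelihood above, the posterior is again a GP, with the standard closed-form mean and covariance (writing $K_{XX} = k(X, X)$, $k_{xX} = k(x, X)$ and $\mu_X = \mu(X)$):

$$
\mu_y(x) = \mu(x) + k_{xX}\,(K_{XX} + \sigma^2 I_N)^{-1}\,(y - \mu_X),
\qquad
k_y(x, x') = k(x, x') - k_{xX}\,(K_{XX} + \sigma^2 I_N)^{-1}\,k_{Xx'}
$$

Every quantity that touches the data passes through the $N \times N$ matrix $K_{XX} + \sigma^2 I_N$; this is where the linear algebra, and the computational cost, enters.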

---

## 3. Object Layer

**How do we represent the model in code?**

We can encapsulate the GP model using Python classes. Below is a minimal, modular implementation using JAX for efficient computation.
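
A minimal sketch of such an encapsulation, assuming a zero prior mean and the Squared Exponential kernel, and written with NumPy rather than JAX for portability. The class `ExactGP` and its `fit`/`predict` methods are hypothetical names; the computations follow the standard closed-form GP regression equations, organized around a single Cholesky factorization.

```python
import numpy as np


def rbf_kernel(x1, x2, variance=1.0, lengthscale=1.0):
    """Squared Exponential kernel matrix between x1 of shape (n, d) and x2 of shape (m, d)."""
    sq_dist = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)


class ExactGP:
    """Hypothetical minimal exact GP regressor with a zero prior mean."""

    def __init__(self, kernel, noise_variance):
        self.kernel = kernel
        self.noise_variance = noise_variance

    def fit(self, X, y):
        """Factor K_XX + sigma^2 I = L L^T once (the O(N^3) step) and cache the weights."""
        self.X = X
        K = self.kernel(X, X) + self.noise_variance * np.eye(X.shape[0])
        self.L = np.linalg.cholesky(K)
        # Representer weights alpha = (L L^T)^{-1} y via two solves with the factor.
        # (np.linalg.solve treats L as a general matrix; scipy.linalg.solve_triangular
        # would exploit the triangular structure.)
        self.alpha = np.linalg.solve(self.L.T, np.linalg.solve(self.L, y))
        return self

    def predict(self, X_test):
        """Posterior mean and marginal variance at the rows of X_test."""
        K_sX = self.kernel(X_test, self.X)      # (M, N) test/train covariances
        mean = K_sX @ self.alpha                 # O(N) work per test point
        V = np.linalg.solve(self.L, K_sX.T)      # L^{-1} K_Xs, shape (N, M)
        prior_var = np.diag(self.kernel(X_test, X_test))
        var = prior_var - np.sum(V**2, axis=0)   # k(x, x) - k_xX (K + sigma^2 I)^{-1} k_Xx
        return mean, var


# Usage: fit noise-free sine values, predict inside and far outside the data
X = np.linspace(-3.0, 3.0, 20)[:, None]
y = np.sin(X).ravel()
gp = ExactGP(rbf_kernel, noise_variance=0.01).fit(X, y)
mean, var = gp.predict(np.array([[0.5], [10.0]]))
print(mean, var)
```

Splitting `fit` from `predict` mirrors the cost structure: the cubic work happens once in `fit`, and every later prediction reuses the cached factor.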



## How Do We Actually Solve Linear Systems of Equations? The LU Decomposition

Solving a linear system $A\mathbf{x} = \mathbf{b}$ is a fundamental operation in numerical linear algebra. One of the most widely used methods for this is the **LU Decomposition**.

---

### **What is LU Decomposition?**

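- **LU Decomposition** expresses a square matrix $A$ as the product of a **lower triangular matrix** $L$ (with unit diagonal) and an **upper triangular matrix** $U$:
    $$
    A = LU
    $$
- Not every square matrix has such a factorization (e.g. $\begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}$ does not), but every square matrix has one after a suitable row permutation $P$: $PA = LU$.
- This factorization allows us to solve linear systems efficiently by breaking the problem into simpler steps.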

---

### **How Does the Decomposition Work?**

The process involves recursively partitioning the matrix $A$:

- At each step $i$, we write:
    $$
    A^{(i)} = \begin{pmatrix}
        \alpha^{(i)} & (u^{(i)})^\top \\
        b^{(i)} & B^{(i)}
    \end{pmatrix}
    $$
    where:
    - $\alpha^{(i)}$ is the pivot (the top-left entry of $A^{(i)}$)
    - $(u^{(i)})^\top$ is the rest of the first row, to the right of the pivot
    - $b^{(i)}$ is the rest of the first column, below the pivot
    - $B^{(i)}$ is the remaining (trailing) submatrix

- The corresponding $L$ and $U$ blocks are:
    $$
    L^{(i)} = \begin{pmatrix}
        1 & 0 \\
        l^{(i)} & L^{(i+1)}
    \end{pmatrix}, \quad
    U^{(i)} = \begin{pmatrix}
        \alpha^{(i)} & (u^{(i)})^\top \\
        0 & U^{(i+1)}
    \end{pmatrix}
    $$

- The recursion is defined by:
    $$
    l^{(i)} = \frac{1}{\alpha^{(i)}} b^{(i)}
    $$
    $$
    A^{(i+1)} := L^{(i+1)} U^{(i+1)} = B^{(i)} - l^{(i)} (u^{(i)})^\top
    $$

- If all pivots $\alpha^{(i)}$ are non-zero, the recursion terminates and an LU decomposition exists.

- **Pivoting** (reordering rows/columns) is often used in practice for numerical stability, typically by choosing the largest absolute value in the current column as the pivot.
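
The effect of pivoting already shows up on a $2 \times 2$ example. The NumPy sketch below uses a hypothetical helper `lu` that follows the recursion above, with optional partial pivoting that returns $PA = LU$, and factors $A = \begin{pmatrix} 10^{-20} & 1 \\ 1 & 1 \end{pmatrix}$. Without pivoting, the tiny pivot produces the multiplier $10^{20}$, and $1 - 10^{20}$ rounds to $-10^{20}$ in floating point, so $LU$ loses the $(2,2)$ entry of $A$ entirely; with a row exchange the factorization is accurate.

```python
import numpy as np


def lu(A, pivot=True):
    """Right-looking LU following the recursion l = b / alpha, A' = B - l u^T.

    With pivot=True, uses partial pivoting and returns (P, L, U) with P @ A = L @ U.
    With pivot=False, P is the identity and the factorization may be inaccurate or fail.
    """
    U = np.array(A, dtype=float)
    n = U.shape[0]
    P = np.eye(n)
    L = np.zeros((n, n))
    for i in range(n):
        if pivot:
            # Pick the row with the largest |entry| in column i as the pivot row
            p = i + int(np.argmax(np.abs(U[i:, i])))
            U[[i, p]] = U[[p, i]]
            P[[i, p]] = P[[p, i]]
            L[[i, p], :i] = L[[p, i], :i]  # keep earlier multipliers aligned with their rows
        l_col = U[i + 1:, i] / U[i, i]      # l^(i) = b^(i) / alpha^(i)
        L[i + 1:, i] = l_col
        U[i + 1:, i:] -= np.outer(l_col, U[i, i:])  # Schur complement B - l u^T
        U[i + 1:, i] = 0.0                  # entries below the pivot are eliminated
    return P, L + np.eye(n), U


A = np.array([[1e-20, 1.0], [1.0, 1.0]])
for use_pivot in (False, True):
    P, L, U = lu(A, pivot=use_pivot)
    err = np.max(np.abs(P @ A - L @ U))
    print(f"pivot={use_pivot}: max |PA - LU| = {err:.1e}")
```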

---

### **Computational Complexity**

- **LU decomposition** of an $N \times N$ matrix requires approximately $\frac{2}{3}N^3$ floating-point operations (flops).
- This is the dominant cost when solving a single linear system.

---

### **Solving Linear Systems with LU Decomposition**

Once $A = LU$ is known, solving $A\mathbf{x} = \mathbf{b}$ becomes much more efficient:

1. **Forward Substitution:**  
     Solve $L\mathbf{y} = \mathbf{b}$ for $\mathbf{y}$ (since $L$ is lower triangular).  
     - Cost: $\mathcal{O}(N^2)$ flops.

2. **Backward Substitution:**  
     Solve $U\mathbf{x} = \mathbf{y}$ for $\mathbf{x}$ (since $U$ is upper triangular).  
     - Cost: $\mathcal{O}(N^2)$ flops.

**Total cost:**  
- $\frac{2}{3}N^3$ (for decomposition) $+$ $2N^2$ (for both substitutions).

---

### **Why is LU Decomposition Useful?**

- If you need to solve multiple systems with the same $A$ but different $\mathbf{b}$, you only compute the LU decomposition once. Each subsequent solve then costs only $2N^2$ flops.
- This is much more efficient than inverting $A$ directly or recomputing the decomposition for each new right-hand side.
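
A small flop-count comparison makes the reuse argument concrete. It uses the counts above ($\frac{2}{3}N^3$ for the factorization, $2N^2$ for the pair of substitutions) and assumes $k$ right-hand sides; the function name is illustrative only.

```python
def lu_solve_flops(n: int, k: int, refactor: bool) -> float:
    """Approximate flops to solve A x_j = b_j for j = 1..k with an n x n matrix A.

    refactor=False: factor once, then 2 n^2 flops of substitution per right-hand side.
    refactor=True:  redo the 2/3 n^3 factorization for every right-hand side.
    """
    factor = 2 / 3 * n**3
    substitute = 2 * n**2
    if refactor:
        return k * (factor + substitute)
    return factor + k * substitute


n, k = 1_000, 100
once = lu_solve_flops(n, k, refactor=False)
every_time = lu_solve_flops(n, k, refactor=True)
print(f"factor once: {once:.2e} flops, refactor each time: {every_time:.2e} flops")
print(f"speed-up: {every_time / once:.1f}x")
```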

---

### **Summary Table**

| Step                      | Operation                | Complexity      |
|---------------------------|--------------------------|----------------|
| LU Decomposition          | $A = LU$                 | $\frac{2}{3}N^3$ flops |
| Forward Substitution      | $L\mathbf{y} = \mathbf{b}$ | $N^2$ flops    |
| Backward Substitution     | $U\mathbf{x} = \mathbf{y}$ | $N^2$ flops    |

---

### **Key Takeaways**

- LU decomposition is a foundational tool for efficiently solving linear systems.
- It is especially advantageous when solving for multiple right-hand sides.
- Understanding its computational cost is crucial for scaling up to large problems, such as those encountered in Gaussian Process regression and other machine learning applications.

In [None]:
import jax.numpy as jnp
import numpy as np  # For creating a simple matrix example


def manual_lu_decomposition(A: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Performs LU decomposition manually for a square matrix (without pivoting for simplicity).

    Args:
        A: A square JAX array of shape (N, N).

    Returns:
        A tuple (L, U) where L is lower triangular and U is upper triangular,
        such that L @ U = A.
    """
    N = A.shape[0]
    L = jnp.eye(N, dtype=A.dtype)
    U = jnp.copy(A)

    for i in range(N):
        # Check for zero pivot (simplified, real implementations use pivoting)
        if U[i, i] == 0:
            print(
                f"Warning: Zero pivot encountered at step {i}. LU decomposition may fail or be unstable."
            )
            # In a real scenario, this would require pivoting or a different decomposition.
            return L, U  # Return current state

        # Calculate multipliers for the current column
        # L[j, i] = U[j, i] / U[i, i] for j > i
        for j in range(i + 1, N):
            multiplier = U[j, i] / U[i, i]
            L = L.at[j, i].set(multiplier)
            # Perform row operation on U: U[j, :] = U[j, :] - multiplier * U[i, :]
            U = U.at[j, :].set(U[j, :] - multiplier * U[i, :])
    return L, U


def forward_substitution(L: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    """Solves Ly = b for y where L is a lower triangular matrix."""
    N = L.shape[0]
    y = jnp.zeros(N, dtype=b.dtype)
    for i in range(N):
        y = y.at[i].set((b[i] - jnp.dot(L[i, :i], y[:i])) / L[i, i])
    return y


def backward_substitution(U: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Solves Ux = y for x where U is an upper triangular matrix."""
    N = U.shape[0]
    x = jnp.zeros(N, dtype=y.dtype)
    for i in range(N - 1, -1, -1):
        x = x.at[i].set((y[i] - jnp.dot(U[i, i + 1 :], x[i + 1 :])) / U[i, i])
    return x


# --- Example Usage ---
# Create a sample matrix A
np.random.seed(1)
A_np = np.random.rand(4, 4)
A = jnp.asarray(A_np)

print("Original Matrix A:\n", A)

# Perform manual LU decomposition
L_manual, U_manual = manual_lu_decomposition(A)
print("\nManual L:\n", L_manual)
print("\nManual U:\n", U_manual)

# Verify decomposition: L @ U should be close to A
A_reconstructed = L_manual @ U_manual
print("\nManual Reconstruction (L @ U):\n", A_reconstructed)
print(
    f"\nMax absolute difference (Manual LU): {jnp.max(jnp.abs(A - A_reconstructed)):.2e}"
)

# Solve a linear system A @ x = b
b_np = np.array([10.0, 12.0, 14.0, 16.0])
b = jnp.asarray(b_np)

print("\nRight-hand side b:\n", b)

# Solve Ly = b using forward substitution
y_solved = forward_substitution(L_manual, b)
print("\nSolved y (from Ly=b):\n", y_solved)

# Solve Ux = y using backward substitution
x_solved = backward_substitution(U_manual, y_solved)
print("\nSolved x (from Ux=y):\n", x_solved)

# Verify the solution with direct solve (for comparison)
x_direct = jnp.linalg.solve(A, b)
print("\nDirectly solved x (from A@x=b):\n", x_direct)

print(
    f"\nMax absolute difference (Manual vs Direct solve): {jnp.max(jnp.abs(x_solved - x_direct)):.2e}"
)


## The Cholesky Decomposition: For Symmetric Positive Definite Matrices

When working with Gaussian Processes, the kernel matrix $K_{XX}$ is symmetric positive semi-definite, and adding observation noise with $\sigma^2 > 0$ makes $K_{XX} + \sigma^2 I_N$ symmetric positive definite. For such matrices, we can use a more efficient and numerically stable factorization: the **Cholesky decomposition**.

---

### **What is Cholesky Decomposition?**

- If $A \in \mathbb{R}^{N \times N}$ is symmetric positive definite, it can be **uniquely** decomposed as:
    $$
    A = LL^\top
    $$
    where:
    - $L$ is a lower triangular matrix with positive diagonal entries,
    - $L^\top$ is its transpose.

---

### **Recursive Partitioning for Cholesky Decomposition**

At each step, we partition $A$ and $L$ as follows:
$$
A^{(i)} = \begin{pmatrix}
        \alpha^{(i)} & (b^{(i)})^\top \\
        b^{(i)} & B^{(i)}
\end{pmatrix}, \quad
L^{(i)} = \begin{pmatrix}
        \lambda^{(i)} & 0 \\
        l^{(i)} & L^{(i+1)}
\end{pmatrix}
$$

such that:
$$
A^{(i)} = L^{(i)} (L^{(i)})^\top
$$

This leads to the following update rules:
- $\lambda^{(i)} = \sqrt{\alpha^{(i)}}$  
    (since $\alpha^{(i)} > 0$ by positive definiteness)
- $l^{(i)} = \frac{1}{\lambda^{(i)}} b^{(i)}$
- $A^{(i+1)} := B^{(i)} - l^{(i)} (l^{(i)})^\top$
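
The three update rules translate almost line by line into code. This NumPy sketch (the name `cholesky_recursive` is hypothetical) follows the right-looking recursion: at step $i$ it takes the square root of the pivot, scales the column below it, and subtracts the rank-one term $l^{(i)} (l^{(i)})^\top$ from the trailing block. It computes the same factor as an element-by-element implementation, just organized around the Schur-complement update.

```python
import numpy as np


def cholesky_recursive(A):
    """Right-looking Cholesky: returns lower-triangular L with A = L @ L.T.

    Step i applies the update rules from the text:
      lambda = sqrt(alpha),  l = b / lambda,  A^(i+1) = B - l l^T.
    Raises ValueError if a pivot is not positive (A is not positive definite).
    """
    A = np.array(A, dtype=float)  # working copy; the trailing block holds A^(i)
    n = A.shape[0]
    L = np.zeros((n, n))
    for i in range(n):
        alpha = A[i, i]
        if alpha <= 0:
            raise ValueError(f"non-positive pivot {alpha:.3g} at step {i}")
        lam_i = np.sqrt(alpha)
        l_col = A[i + 1:, i] / lam_i
        L[i, i] = lam_i
        L[i + 1:, i] = l_col
        # Rank-one update of the trailing (n-i-1) x (n-i-1) block: O((n-i)^2) work
        A[i + 1:, i + 1:] -= np.outer(l_col, l_col)
    return L


rng = np.random.default_rng(0)
B = rng.standard_normal((6, 6))
A = B @ B.T + 6 * np.eye(6)  # symmetric positive definite test matrix
L = cholesky_recursive(A)
print(np.max(np.abs(L @ L.T - A)))
print(np.allclose(L, np.linalg.cholesky(A)))  # the Cholesky factor is unique
```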

---

### **Why Use Cholesky?**

- **Efficiency:**  
    The Cholesky decomposition requires approximately $\frac{1}{3}N^3$ floating-point operations (flops), which is **about half the cost** of LU decomposition for general matrices.
- **Numerical Stability:**  
    It needs no pivoting: for symmetric positive definite matrices every pivot $\alpha^{(i)}$ is positive, and the factorization is numerically stable. Such matrices are common in GP regression and many other ML applications.

---

### **Summary Table**

| Step                      | Operation                | Complexity      |
|---------------------------|--------------------------|----------------|
| Cholesky Decomposition    | $A = LL^\top$            | $\frac{1}{3}N^3$ flops |
| Forward/Backward Substitution | Solve $L\mathbf{y} = \mathbf{b}$, $L^\top\mathbf{x} = \mathbf{y}$ | $N^2$ flops each |

---

### **Key Takeaways**

- Cholesky decomposition is the **preferred method** for solving linear systems involving symmetric positive definite matrices.
- It is especially important in Gaussian Processes, where the matrix $K_{XX} + \sigma^2 I_N$ is symmetric positive definite whenever $\sigma^2 > 0$.
- Understanding and using Cholesky decomposition allows for faster, more stable computations in probabilistic machine learning.

---

In [None]:
import jax.numpy as jnp
from jax.scipy.linalg import cholesky as jax_cholesky  # JAX's optimized Cholesky
import numpy as np  # For creating a simple SPD matrix example


def manual_cholesky_decomposition(A: jnp.ndarray) -> jnp.ndarray:
    """
    Performs Cholesky decomposition manually for a symmetric positive definite matrix.

    Args:
        A: A symmetric positive definite JAX array of shape (N, N).

    Returns:
        A lower triangular JAX array L such that L @ L.T = A.
    """
    N = A.shape[0]
    L = jnp.zeros((N, N), dtype=A.dtype)

    for i in range(N):
        # Calculate L[i, i]
        # L[i, i] = sqrt(A[i, i] - sum(L[i, k]^2 for k from 0 to i-1))
        sum_sq_prev_elements = jnp.sum(L[i, :i] ** 2)
        L = L.at[i, i].set(jnp.sqrt(A[i, i] - sum_sq_prev_elements))

        # Calculate L[j, i] for j > i
        # L[j, i] = (A[j, i] - sum(L[j, k] * L[i, k] for k from 0 to i-1)) / L[i, i]
        for j in range(i + 1, N):
            sum_prod_prev_elements = jnp.sum(L[j, :i] * L[i, :i])
            L = L.at[j, i].set((A[j, i] - sum_prod_prev_elements) / L[i, i])
    return L


# --- Example Usage ---
# Create a symmetric positive definite matrix
# A simple way to get an SPD matrix is A = B @ B.T for some matrix B
np.random.seed(0)
B = np.random.rand(5, 5)  # A 5x5 random matrix
A_np = B @ B.T + np.eye(5) * 1e-6  # Add small identity for strict positive definiteness
A = jnp.asarray(A_np)

print("Original Matrix A:\n", A)

# Perform manual Cholesky decomposition
L_manual = manual_cholesky_decomposition(A)
print("\nManual Cholesky L:\n", L_manual)

# Verify the decomposition: L @ L.T should be close to A
A_reconstructed_manual = L_manual @ L_manual.T
print("\nManual Reconstruction (L @ L.T):\n", A_reconstructed_manual)
print(
    f"\nMax absolute difference (Manual): {jnp.max(jnp.abs(A - A_reconstructed_manual)):.2e}"
)

# Compare with JAX's built-in Cholesky (for verification)
L_jax = jax_cholesky(A, lower=True)
print("\nJAX built-in Cholesky L:\n", L_jax)

# Verify JAX's decomposition
A_reconstructed_jax = L_jax @ L_jax.T
print(
    f"\nMax absolute difference (JAX built-in): {jnp.max(jnp.abs(A - A_reconstructed_jax)):.2e}"
)

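# Note: in the recursive (right-looking) formulation, iteration i updates an (N-i) x (N-i) block,
# costing O((N-i)^2); summed over i this gives O(N^3). The first step already has to “touch” all N data points.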


# GP Regression: Computational Summary

Let's summarize the key computational insights from our exploration of Gaussian Process (GP) regression:

---

## 1. **GP Regression = Matrix Decomposition**

- **Training a GP** (i.e., fitting to data) fundamentally reduces to decomposing a matrix.
- Specifically, we work with the **kernel matrix**:  
    $$
    K_{XX} + \sigma^2 I_N
    $$
    where $K_{XX}$ is the $N \times N$ covariance matrix and $\sigma^2$ is the noise variance.

---

## 2. **Cholesky Decomposition: The Workhorse**

- For **symmetric positive definite matrices** (which $K_{XX} + \sigma^2 I_N$ always is when $\sigma^2 > 0$), the **Cholesky decomposition** is the gold standard:
    $$
    K_{XX} + \sigma^2 I_N = LL^\top
    $$
    where $L$ is lower triangular.
- **Why Cholesky?**
    - **Numerically stable**
    - **Efficient**: About half the cost of LU decomposition for general matrices
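
Beyond predictions, the same factor also yields the log marginal likelihood used for hyperparameter learning. Assuming a zero prior mean, $\log p(y) = -\tfrac{1}{2} y^\top (K_{XX} + \sigma^2 I_N)^{-1} y - \tfrac{1}{2} \log\det(K_{XX} + \sigma^2 I_N) - \tfrac{N}{2} \log 2\pi$. With $K_{XX} + \sigma^2 I_N = LL^\top$, the quadratic form is $\|L^{-1} y\|^2$ and $\log\det = 2 \sum_i \log L_{ii}$. A NumPy sketch (the helper name is hypothetical), checked against a direct computation:

```python
import numpy as np


def log_marginal_likelihood(K, y, noise_variance):
    """log N(y; 0, K + sigma^2 I) computed from a single Cholesky factorization."""
    n = y.shape[0]
    L = np.linalg.cholesky(K + noise_variance * np.eye(n))
    v = np.linalg.solve(L, y)                    # v = L^{-1} y, so y^T (L L^T)^{-1} y = v^T v
    log_det = 2.0 * np.sum(np.log(np.diag(L)))   # log det(L L^T) = 2 sum_i log L_ii
    return -0.5 * v @ v - 0.5 * log_det - 0.5 * n * np.log(2 * np.pi)


rng = np.random.default_rng(0)
X = np.linspace(0.0, 5.0, 30)
K = np.exp(-0.5 * (X[:, None] - X[None, :]) ** 2)  # SE kernel, unit lengthscale and variance
y = np.sin(X) + 0.1 * rng.standard_normal(30)

lml = log_marginal_likelihood(K, y, noise_variance=0.01)

# Direct reference computation (general solve + slogdet) for comparison
C = K + 0.01 * np.eye(30)
sign, logdet = np.linalg.slogdet(C)
lml_direct = -0.5 * y @ np.linalg.solve(C, y) - 0.5 * logdet - 15 * np.log(2 * np.pi)
print(lml, lml_direct)
```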

---

## 3. **Efficient Posterior Computation**

- Once the Cholesky factor $L$ is available:
    - **Posterior mean** and **posterior covariance** can be computed efficiently.
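    - For each new test point, the posterior mean costs $\mathcal{O}(N)$ (a dot product with the precomputed representer weights) and the posterior variance costs $\mathcal{O}(N^2)$ (one triangular solve with $L$).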

---

## 4. **The Real Bottleneck: Cubic Complexity**

- **Cholesky decomposition itself is $\mathcal{O}(N^3)$** in time and $\mathcal{O}(N^2)$ in memory.
- This is the main computational bottleneck for GPs, especially as $N$ grows large (thousands or millions of data points).

---

## 5. **Not Just "Being Bayesian"**

- The $\mathcal{O}(N^3)$ cost is **not** due to computing the full posterior covariance (i.e., "being Bayesian").
- **Even if you only want the point estimate** (the posterior mean), you must solve a linear system involving the kernel matrix, which still requires an $\mathcal{O}(N^3)$ decomposition.
- **Key point:**  
    - The cost comes from the need to solve for the "representer weights" in the kernel expansion, not from uncertainty quantification per se.

---

## 6. **Why Are GPs So Much More Expensive Than Deep Learning?**

- **Deep learning** (with stochastic gradient descent) costs $\mathcal{O}(B)$ per mini-batch iteration, independent of $N$, and $\mathcal{O}(N)$ per full pass (epoch) over the data.
- **GPs and kernel machines** scale as $\mathcal{O}(N^3)$.
- **Why?**
    - GPs require global operations (matrix decompositions) that touch all $N$ data points at once.
    - Deep learning leverages local, incremental updates (mini-batches) and does not require global matrix operations.

---

## **Summary Table: Computational Complexity**

| Operation                        | Complexity         |
|-----------------------------------|-------------------|
| Kernel matrix construction        | $\mathcal{O}(N^2)$|
| Cholesky decomposition            | $\mathcal{O}(N^3)$|
| Posterior mean (per test point)   | $\mathcal{O}(N)$|
| Posterior variance (per test point) | $\mathcal{O}(N^2)$|
| Deep learning (per mini-batch step) | $\mathcal{O}(B)$, independent of $N$|

---

> **Bottom line:**  
> The main computational challenge in GP regression is the $\mathcal{O}(N^3)$ scaling of matrix decomposition. This motivates the development of scalable approximations and specialized algorithms for large datasets.

# Grass is Greener on the Deep Side? Deep Learning Scalability

To understand why deep learning scales so well compared to Gaussian Processes (GPs), let's briefly examine how deep learning models are trained and contrast this with GP inference.

---

## Deep Learning: Empirical Risk Minimization and SGD

Training a deep neural network is typically formulated as an **empirical risk minimization (ERM)** problem. The goal is to find weights $\mathbf{w}$ that minimize a loss function $L(\mathbf{w})$:

$$
\mathbf{w}^* = \arg\min_{\mathbf{w} \in \mathbb{R}^D} L(\mathbf{w}) = \arg\min_{\mathbf{w} \in \mathbb{R}^D} \frac{1}{N} \sum_{i=1}^N \ell(y_i, f(\mathbf{w}, x_i)) + r(\mathbf{w})
$$

- $\ell(y_i, f(\mathbf{w}, x_i))$: Loss for data point $i$ (e.g., squared error, cross-entropy)
- $r(\mathbf{w})$: Regularization term (e.g., weight decay)

### Stochastic Gradient Descent (SGD)

This minimization is usually performed using **stochastic, first-order methods** such as SGD or Adam. These methods use **mini-batches** of data to estimate the gradient:

$$
\nabla L(\mathbf{w}) = \frac{1}{N} \sum_{i=1}^N \nabla \ell(y_i, f(\mathbf{w}, x_i)) + \nabla r(\mathbf{w}) \approx \frac{1}{B} \sum_{j=1}^B \nabla \ell(y_{i(j)}, f(\mathbf{w}, x_{i(j)})) + \nabla r(\mathbf{w}) =: g(\mathbf{w})
$$

- $B$: Mini-batch size ($B \ll N$)
- $g(\mathbf{w})$: Stochastic gradient estimate

**Key Point:**  
- Each iteration's computational cost is $\mathcal{O}(B)$, independent of the total dataset size $N$.
- This allows deep learning to scale to massive datasets.

---


In [None]:
import jax.numpy as jnp
import jax.random as random
import optax  # A JAX-based optimization library
from jax import grad

# --- Conceptual Deep Learning Training Loop ---

# 1. Simulate a very simple dataset
num_total_samples = 100000  # N: Large dataset
input_dim = 10
output_dim = 1
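key = random.PRNGKey(0)
# Split the key so inputs, targets, and the initial weights (below) get independent randomness
key, key_X, key_y = random.split(key, 3)

# Dummy data (inputs and targets)
X_data = random.normal(key_X, (num_total_samples, input_dim))
y_data = random.normal(key_y, (num_total_samples, output_dim))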


# 2. Define a very simple "deep learning" model (e.g., a linear model for simplicity)
# In reality, this would be a multi-layer neural network
def simple_model(params, x):
    """A simple linear model: y = x @ W + b"""
    return jnp.dot(x, params["W"]) + params["b"]


# 3. Define a simple loss function (Mean Squared Error)
def mse_loss(params, x_batch, y_batch):
    predictions = simple_model(params, x_batch)
    return jnp.mean((predictions - y_batch) ** 2)


# 4. Initialize model parameters
model_params = {
    "W": random.normal(key, (input_dim, output_dim)),
    "b": jnp.zeros(output_dim),
}

# 5. Setup optimizer
learning_rate = 0.01
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(model_params)

# Get the gradient function
loss_grad_fn = grad(mse_loss)

# --- Training Loop (Illustrating O(B) cost per iteration) ---
num_training_steps = 1000
batch_size = 64  # B: Mini-batch size (fixed, much smaller than N)

print(f"Simulating Deep Learning Training (N={num_total_samples}, B={batch_size})")
print("Cost per iteration is O(B), independent of N.")

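for step in range(num_training_steps):
    # Draw a fresh subkey each step; reusing one key would select the same mini-batch every iteration
    key, batch_key = random.split(key)
    batch_indices = random.randint(batch_key, (batch_size,), 0, num_total_samples)
    x_batch = X_data[batch_indices]
    y_batch = y_data[batch_indices]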

    # Compute loss and gradients for the mini-batch
    # This operation's cost depends on B, not N
    grads = loss_grad_fn(model_params, x_batch, y_batch)

    # Update model parameters
    updates, opt_state = optimizer.update(grads, opt_state, model_params)
    model_params = optax.apply_updates(model_params, updates)

    if step % 200 == 0:
        current_loss = mse_loss(model_params, x_batch, y_batch)
        print(f"Step {step}, Batch Loss: {current_loss:.4f}")

print("\nDeep Learning training simulation complete.")
print(
    "The key takeaway is that each step's computation depends on batch_size, not total_samples."
)



## Can GPs Be Trained with SGD? Re-phrasing Inference as Optimization

Given the scalability of SGD, a natural question arises:

> **Can we train GPs using SGD-like methods?**

### Step 1: GP Inference as Optimization

The **posterior mean** $\mu_y(X)$ at the training inputs $X$ is also the mode of $p(f_X \mid y)$, since this posterior is Gaussian. We can therefore find it by maximizing the log-posterior, or equivalently minimizing the negative log-posterior:

$$
\mu_y(X) = \arg\max_{f_X \in \mathbb{R}^N} \log p(f_X \mid y) = \arg\max_{f_X \in \mathbb{R}^N} \log p(y \mid f_X) + \log p(f_X)
$$

Or, equivalently:

$$
\mu_y(X) = \arg\min_{f_X \in \mathbb{R}^N} -\log p(y \mid f_X) - \log p(f_X)
$$

For a **Gaussian likelihood** and a **GP prior** with mean $\mu_X := \mu(X)$ (assuming $K_{XX}$ is invertible):

$$
\mu_y(X) = \arg\min_{f_X \in \mathbb{R}^N} \frac{1}{2\sigma^2} \sum_{i=1}^N |y_i - (f_X)_i|^2 + \frac{1}{2} (f_X - \mu_X)^\top K_{XX}^{-1} (f_X - \mu_X) + \text{const.}
$$

- The first term is the **data fidelity** (likelihood).
- The second term is the **regularization** (prior).

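This is a **convex optimization problem**; more precisely, a strictly convex quadratic (because $K_{XX}^{-1}$ is positive definite), so it has a unique minimizer.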

**Note:**  
Because of the conditional independence $f(\cdot) \perp\!\!\!\perp y \mid f(X)$, it suffices to find $\mu_y(X)$: predictions at new points depend on the data $y$ only through $f(X)$.
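
Setting the gradient of this objective with respect to $f_X$ to zero, and multiplying through by $\sigma^2 K_{XX}$, recovers the closed-form posterior mean and makes the link to linear algebra explicit:

$$
-\frac{1}{\sigma^2}\,(y - f_X) + K_{XX}^{-1}(f_X - \mu_X) = 0
\quad\Longrightarrow\quad
(K_{XX} + \sigma^2 I_N)\, f_X = K_{XX}\, y + \sigma^2 \mu_X
$$

$$
\Longrightarrow\quad
\mu_y(X) = \mu_X + K_{XX}\,(K_{XX} + \sigma^2 I_N)^{-1}\,(y - \mu_X)
$$

Finding the minimizer is therefore equivalent to solving an $N \times N$ linear system with $K_{XX} + \sigma^2 I_N$.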

---


In [None]:
import jax.numpy as jnp
import jax.random as random
from functools import partial
from jax import grad, jit
from typing import Callable
import optax  # A JAX-based optimization library


# --- Re-using kernel definition ---
def squared_exponential_kernel(
    x1: jnp.ndarray, x2: jnp.ndarray, sigma: float = 1.0, lengthscale: float = 1.0
) -> jnp.ndarray:
    """Computes the Squared Exponential (RBF) kernel matrix."""
    x1 = jnp.atleast_2d(x1)
    x2 = jnp.atleast_2d(x2)
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    K = sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)
    return K


# --- Optimization Problem for f_X ---
# We want to minimize:
# L(f_X) = (1 / (2 * sigma_noise^2)) * ||y - f_X||^2 + (1/2) * (f_X - mu_X)^T @ K_XX_inv @ (f_X - mu_X)


# mean_func and kernel_func are Python callables, not arrays, so jit must treat them as static
@partial(jit, static_argnames=("mean_func", "kernel_func"))
def gp_posterior_loss(
    f_X: jnp.ndarray,  # The function values at training points (our "parameters" to optimize)
    y_train: jnp.ndarray,
    X_train: jnp.ndarray,
    mean_func: Callable[[jnp.ndarray], jnp.ndarray],
    kernel_func: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
    noise_variance: float,
) -> float:
    """
    Computes the negative log-posterior of f_X given y_train.
    This is the loss function we would minimize to find the posterior mode of f_X.
    Note: This direct implementation requires K_XX_inv, which is O(N^3) to compute.
    """
    mu_X = mean_func(X_train)
    K_XX = kernel_func(X_train, X_train)

    # Add a small jitter for numerical stability if K_XX is nearly singular
    jitter = 1e-6 * jnp.eye(X_train.shape[0])
    K_XX_stable = K_XX + jitter

    # Compute K_XX_inv (O(N^3) operation)
    K_XX_inv = jnp.linalg.inv(K_XX_stable)

    # Likelihood term
    likelihood_term = jnp.sum((y_train - f_X) ** 2) / (2.0 * noise_variance)

    # Prior term
    prior_diff = f_X - mu_X
    prior_term = 0.5 * jnp.dot(prior_diff.T, jnp.dot(K_XX_inv, prior_diff))

    return likelihood_term + prior_term


# --- Example Usage ---
key = random.PRNGKey(789)
N_train = 50  # Number of training points
X_train_opt = jnp.linspace(-5, 5, N_train)[:, None]
y_train_opt = jnp.sin(X_train_opt).squeeze() + 0.2 * random.normal(key, (N_train,))

zero_mean_opt = lambda x: jnp.zeros(x.shape[0])
rbf_kernel_opt = lambda x1, x2: squared_exponential_kernel(
    x1, x2, sigma=1.0, lengthscale=1.0
)
noise_var_opt = 0.1**2

# Initialize f_X (our "parameters" for optimization)
# Start with the prior mean as an initial guess
initial_f_X = zero_mean_opt(X_train_opt)

# Setup optimizer for f_X
learning_rate_fX = 0.1
optimizer_fX = optax.adam(learning_rate_fX)
opt_state_fX = optimizer_fX.init(initial_f_X)

# Get the gradient function for our loss
loss_grad_fn_fX = grad(gp_posterior_loss)

# Optimization loop
num_opt_steps = 1000
f_X_current = initial_f_X
losses_fX = []

print(f"Optimizing f_X (N={N_train}) to find posterior mode.")
for step in range(num_opt_steps):
    loss_value = gp_posterior_loss(
        f_X_current,
        y_train_opt,
        X_train_opt,
        zero_mean_opt,
        rbf_kernel_opt,
        noise_var_opt,
    )
    losses_fX.append(loss_value)

    grads = loss_grad_fn_fX(
        f_X_current,
        y_train_opt,
        X_train_opt,
        zero_mean_opt,
        rbf_kernel_opt,
        noise_var_opt,
    )
    updates, opt_state_fX = optimizer_fX.update(grads, opt_state_fX, f_X_current)
    f_X_current = optax.apply_updates(f_X_current, updates)

    if step % 200 == 0:
        print(f"Step {step}, Loss: {loss_value:.4f}")

print(
    f"Final Loss: {gp_posterior_loss(f_X_current, y_train_opt, X_train_opt, zero_mean_opt, rbf_kernel_opt, noise_var_opt):.4f}"
)
print("\nOptimized f_X (posterior mode at training points):\n", f_X_current[:5])

# Plot the optimization loss
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 5))
plt.plot(losses_fX)
plt.xlabel("Optimization Step")
plt.ylabel("Negative Log-Posterior of f_X")
plt.title("Optimization of f_X to find Posterior Mode")
plt.grid(True)
plt.show()


# Compare optimized f_X with analytical posterior mean (from gp_predict)
# Need to define gp_predict function here or import it
def gp_predict(X_train, y_train, X_test, mean_func, kernel_func, noise_variance):
    K_train_train = kernel_func(X_train, X_train) + noise_variance * jnp.eye(
        X_train.shape[0]
    )
    K_test_train = kernel_func(X_test, X_train)
    mu_pred = mean_func(X_test) + jnp.dot(
        K_test_train, jnp.linalg.solve(K_train_train, y_train - mean_func(X_train))
    )
    # For this comparison, we only need the mean, not full covariance
    return mu_pred, None  # Return None for Sigma_pred


analytical_mu_pred_at_train, _ = gp_predict(
    X_train_opt, y_train_opt, X_train_opt, zero_mean_opt, rbf_kernel_opt, noise_var_opt
)

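print(
    f"\nMax absolute difference between optimized f_X and analytical posterior mean: {jnp.max(jnp.abs(f_X_current - analytical_mu_pred_at_train)):.2e}"
)
print(
    "Optimizing the negative log-posterior of f_X approaches the analytical posterior mean; "
    "any remaining gap reflects the finite number of Adam steps, the jitter added to K_XX in the loss, "
    "and float32 round-off in inverting the ill-conditioned K_XX."
)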



## The Problem: The Cost of Being Nonparametric

The core issue is that the "weights" in this optimization are the function values at the training points, $f_X$. This means:

- The number of parameters **grows with $N$** (the dataset size).
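- The parameter vector $f_X$ has $N$ entries, and the prior term couples all of them through $K_{XX}^{-1}$, so mini-batching the likelihood term does not help: each gradient step costs at least $\mathcal{O}(N)$, and $\mathcal{O}(N^2)$ for a matrix-vector product with a dense $N \times N$ matrix.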

This is **not** due to being Bayesian or probabilistic; it is an inherent property of **nonparametric models** (GPs also arise as the limit of infinitely wide neural networks, which is why they are sometimes described that way).

---

### Possible Solutions (Beyond This Lecture)

- **Return to a parametric representation:** $f(x) = \phi(x)^\top \mathbf{w}$
- **Finite approximations:** Use sparse GP regression, inducing points, random features, or spectral methods to reduce the effective number of parameters.
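As a rough illustration of how a finite approximation restores a fixed parameter count, random Fourier features build a finite feature map $\phi$ whose inner products approximate the RBF kernel, so $f(x) = \phi(x)^\top \mathbf{w}$ has $D$ weights regardless of $N$. This is a minimal sketch assuming the squared exponential kernel used throughout these notes; the helper name `random_fourier_features` and the choice $D = 4000$ are illustrative assumptions, not part of any library.

```python
import jax.numpy as jnp
import jax.random as random


def squared_exponential_kernel(x1, x2, sigma=1.0, lengthscale=1.0):
    """Exact RBF kernel matrix for inputs of shape (N, d) and (M, d)."""
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def random_fourier_features(x, key, num_features=4000, sigma=1.0, lengthscale=1.0):
    """Hypothetical helper: map x of shape (N, d) to phi(x) of shape (N, D).

    phi(x) @ phi(x').T is a Monte Carlo estimate of the RBF kernel k(x, x').
    Reusing the same key reuses the same random frequencies for every input.
    """
    key_w, key_b = random.split(key)
    dim = x.shape[1]
    # Frequencies drawn from the RBF kernel's spectral density, N(0, I / lengthscale^2)
    omega = random.normal(key_w, (dim, num_features)) / lengthscale
    # Random phases, uniform on [0, 2*pi)
    phase = random.uniform(key_b, (num_features,), minval=0.0, maxval=2.0 * jnp.pi)
    return sigma * jnp.sqrt(2.0 / num_features) * jnp.cos(x @ omega + phase)


X = jnp.linspace(-2.0, 2.0, 6)[:, None]
Phi = random_fourier_features(X, random.PRNGKey(0))
K_exact = squared_exponential_kernel(X, X)
K_approx = Phi @ Phi.T

print("Number of weights in w:", Phi.shape[1])  # D, independent of N
print("Max abs kernel approximation error:", jnp.max(jnp.abs(K_exact - K_approx)))
```

The Monte Carlo error shrinks roughly like $1/\sqrt{D}$, so accuracy is traded directly against the number of parameters.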

For now, we accept that we must deal with $N$ parameters.

---

## Computing Gradients for GP Optimization

Recall the optimization problem for $\mu_y(X)$:

$$
\mu_y(X) = \arg\min_{f_X \in \mathbb{R}^N} \frac{1}{2\sigma^2} \sum_{i=1}^N |y_i - (f_X)_i|^2 + \frac{1}{2} (f_X - \mu_X)^\top K_{XX}^{-1} (f_X - \mu_X)
$$

We could implement this loss directly and compute its gradient, but this still requires $K_{XX}^{-1}$, which is $\mathcal{O}(N^3)$.

---

### Analytical Solution for the Posterior Mean

We already know the analytical solution:

$$
\mu_y(\cdot) = \mu(\cdot) + k_{\cdot X} (K_{XX} + \sigma^2 I_N)^{-1} (y - \mu_X) = \mu(\cdot) + k_{\cdot X} \alpha
$$

where

$$
\alpha = (K_{XX} + \sigma^2 I_N)^{-1} (y - \mu_X)
$$

Because $K_{XX} + \sigma^2 I_N$ is positive definite, $\alpha$ is the unique minimizer of the quadratic function:

$$
L(\alpha) = \frac{1}{2} \alpha^\top (K_{XX} + \sigma^2 I_N) \alpha - (y - \mu_X)^\top \alpha + \text{const.}
$$

The gradient is:

$$
\nabla_\alpha L(\alpha) = (K_{XX} + \sigma^2 I_N) \alpha - (y - \mu_X)
$$

Setting this to zero gives the solution for $\alpha$.

- **Computing the gradient:** $\mathcal{O}(N^2)$ (one matrix-vector product)
- **Solving for $\alpha$ directly:** $\mathcal{O}(N^3)$ (a matrix factorization such as Cholesky)

Modern frameworks like JAX can compute these gradients efficiently, but the underlying matrix operations remain the computational bottleneck.
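To make the cost split concrete, here is a minimal sketch that minimizes $L(\alpha)$ by plain gradient descent, so each iteration needs only one matrix-vector product with $K_{XX} + \sigma^2 I_N$. The step size uses a Gershgorin upper bound on the largest eigenvalue; that choice, the helper name `representer_weights_gd`, and the step count are assumptions for illustration. The number of iterations needed depends on the conditioning of $K_{XX} + \sigma^2 I_N$, so this does not remove the cost in general; the toy inputs here are deliberately well separated.

```python
import jax
import jax.numpy as jnp


def squared_exponential_kernel(x1, x2, sigma=1.0, lengthscale=1.0):
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def representer_weights_gd(K_XX_noisy, y_diff, num_steps=200):
    """Hypothetical helper: minimize 0.5 a^T K_XX_noisy a - y_diff^T a by gradient descent."""
    # Gershgorin: lambda_max <= max absolute row sum (an O(N^2) bound), so 1/bound is a stable step
    step = 1.0 / jnp.max(jnp.sum(jnp.abs(K_XX_noisy), axis=1))

    def update(_, alpha):
        grad_L = K_XX_noisy @ alpha - y_diff  # nabla L(alpha): one O(N^2) matrix-vector product
        return alpha - step * grad_L

    return jax.lax.fori_loop(0, num_steps, update, jnp.zeros_like(y_diff))


X = jnp.linspace(0.0, 10.0, 5)[:, None]
y_diff = jnp.sin(X).squeeze()  # zero prior mean, so y - mu_X = y
K_XX_noisy = squared_exponential_kernel(X, X) + 0.1**2 * jnp.eye(5)

alpha_gd = representer_weights_gd(K_XX_noisy, y_diff)
alpha_direct = jnp.linalg.solve(K_XX_noisy, y_diff)  # O(N^3) direct solve, for comparison
print("Max |alpha_gd - alpha_direct|:", jnp.max(jnp.abs(alpha_gd - alpha_direct)))
```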

---

## **Summary Table: Deep Learning vs. GP Regression**

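| Operation                      | Deep Learning (SGD)                                      | GP Regression (exact)                                                               |
|--------------------------------|----------------------------------------------------------|-------------------------------------------------------------------------------------|
| Per-iteration cost             | $\mathcal{O}(B)$ for batch size $B$, independent of $N$ | At least $\mathcal{O}(N)$; $\mathcal{O}(N^2)$ per matrix-vector product with $K_{XX}$ |
| Parameter count                | Fixed ($D$)                                              | Grows with $N$                                                                      |
| Matrix inversion/decomposition | Rare                                                     | Required ($\mathcal{O}(N^3)$)                                                       |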

---

> **Bottom line:**  
> Deep learning scales well because its per-iteration cost is independent of dataset size, thanks to mini-batching and fixed parameter count.  
> GP regression, being nonparametric, fundamentally requires global operations that scale poorly with $N$.

In [None]:
import jax.numpy as jnp
from jax import grad, jit
from jax.scipy.linalg import solve
from typing import Callable


# --- Re-using kernel definition ---
def squared_exponential_kernel(
    x1: jnp.ndarray, x2: jnp.ndarray, sigma: float = 1.0, lengthscale: float = 1.0
) -> jnp.ndarray:
    """Computes the Squared Exponential (RBF) kernel matrix."""
    x1 = jnp.atleast_2d(x1)
    x2 = jnp.atleast_2d(x2)
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    K = sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)
    return K


# --- Define the quadratic loss function for alpha ---
# L(alpha) = 0.5 * alpha^T @ (K_XX + sigma^2 I) @ alpha - (y - mu_X)^T @ alpha + const.
@jit
def alpha_loss_function(
    alpha: jnp.ndarray,
    y_diff: jnp.ndarray,  # (y - mu_X)
    K_XX_noisy: jnp.ndarray,  # (K_XX + sigma^2 I)
) -> float:
    """
    Computes the quadratic loss function for alpha.
    Minimizing this loss gives the representer weights.
    """
    term1 = 0.5 * jnp.dot(alpha.T, jnp.dot(K_XX_noisy, alpha))
    term2 = jnp.dot(y_diff.T, alpha)
    return term1 - term2


# --- Example Usage ---
# Use dummy data from previous examples
N_example = 5
X_train_ex = jnp.linspace(0, 10, N_example)[:, None]
y_train_ex = jnp.sin(X_train_ex).squeeze() + 0.1 * jnp.array(
    [0.5, -0.2, 0.1, -0.3, 0.4]
)
zero_mean_ex = lambda x: jnp.zeros(x.shape[0])
rbf_kernel_ex = lambda x1, x2: squared_exponential_kernel(
    x1, x2, sigma=1.0, lengthscale=1.0
)
noise_variance_ex = 0.1**2

# Precompute necessary terms
mu_X_ex = zero_mean_ex(X_train_ex)
y_diff_ex = y_train_ex - mu_X_ex
K_XX_ex = rbf_kernel_ex(X_train_ex, X_train_ex)
K_XX_noisy_ex = K_XX_ex + noise_variance_ex * jnp.eye(N_example)

# Define the gradient function using JAX's grad
grad_alpha_loss = grad(alpha_loss_function)

# Compute the gradient at an arbitrary alpha (e.g., zeros)
alpha_initial = jnp.zeros(N_example)
gradient_at_initial_alpha = grad_alpha_loss(alpha_initial, y_diff_ex, K_XX_noisy_ex)

print("Initial alpha (zeros):\n", alpha_initial)
print("\nGradient of L(alpha) at initial alpha:\n", gradient_at_initial_alpha)

# The analytical solution for alpha_star is when the gradient is zero:
# (K_XX + sigma^2 I) @ alpha - (y - mu_X) = 0
# alpha_star = inv(K_XX + sigma^2 I) @ (y - mu_X)
alpha_star_analytical = solve(K_XX_noisy_ex, y_diff_ex)

print("\nAnalytical alpha_star:\n", alpha_star_analytical)

# Compute the gradient at the analytical alpha_star (should be close to zero)
gradient_at_alpha_star = grad_alpha_loss(
    alpha_star_analytical, y_diff_ex, K_XX_noisy_ex
)
print(
    "\nGradient of L(alpha) at analytical alpha_star (should be near zero):\n",
    gradient_at_alpha_star,
)
print(
    f"Max absolute value of gradient at alpha_star: {jnp.max(jnp.abs(gradient_at_alpha_star)):.2e}"
)

# This illustrates that the gradient computation itself is feasible,
# but finding the alpha_star (by setting gradient to zero or optimizing)
# still involves solving the linear system, which is the O(N^3) step.


# What About the Uncertainty? Getting the Posterior Covariance

Quantifying uncertainty is a **major advantage of Gaussian Processes (GPs)**. The **posterior covariance** $k_y(\cdot, \circ)$ (or $\operatorname{Cov}(f_X, f_X)$ for training points) provides a principled measure of this uncertainty.

---

## The Posterior Covariance in GP Regression

### **1. The Laplace Approximation (Exact for Gaussians)**

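For Gaussian models, the **Laplace approximation** (a Gaussian centered at the posterior mode whose covariance is the inverse Hessian of the negative log-posterior) is **exact**, because the negative log-posterior is exactly quadratic in $f_X$: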

$$
-\log p(f_X \mid y) = \frac{1}{2\sigma^2} \|y - f_X\|^2 + \frac{1}{2} (f_X - \mu_X)^\top K_{XX}^{-1} (f_X - \mu_X) + \text{const.}
$$

- The first term is the **likelihood** (data fit).
- The second term is the **prior** (regularization).

---

### **2. Computing the Posterior Covariance**

The **Hessian** of the negative log-posterior with respect to $f_X$ is the **inverse posterior covariance**; equivalently, the Hessian of the log-posterior is its negative:

$$
\nabla \nabla^\top \log p(f_X \mid y) = -(\sigma^{-2} I + K_{XX}^{-1})
$$

For a Gaussian, we know:

$$
\nabla \nabla^\top \log \mathcal{N}(x; m, V) = -V^{-1}
$$

Thus, the **posterior covariance** on $f_X$ is:

$$
\operatorname{Cov}(f_X, f_X) = (\sigma^{-2} I + K_{XX}^{-1})^{-1}
$$
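Because the objective is quadratic, automatic differentiation can confirm this directly. The sketch below assumes zero prior mean, unit RBF hyperparameters, and five well-separated inputs so that $K_{XX}$ is safely invertible; it compares `jax.hessian` of the negative log-posterior with $\sigma^{-2} I + K_{XX}^{-1}$ and inverts the result to obtain the posterior covariance.

```python
import jax
import jax.numpy as jnp


def squared_exponential_kernel(x1, x2, sigma=1.0, lengthscale=1.0):
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)


N = 5
X = jnp.linspace(0.0, 10.0, N)[:, None]
y = jnp.sin(X).squeeze()
mu_X = jnp.zeros(N)      # zero prior mean (assumption)
noise_variance = 0.1**2  # sigma^2
K_XX = squared_exponential_kernel(X, X)
K_XX_inv = jnp.linalg.inv(K_XX)


def neg_log_posterior(f_X):
    """-log p(f_X | y), up to an additive constant."""
    diff = f_X - mu_X
    return jnp.sum((y - f_X) ** 2) / (2.0 * noise_variance) + 0.5 * diff @ K_XX_inv @ diff


# The objective is quadratic, so its Hessian is the same at every f_X; evaluate it at zero.
H = jax.hessian(neg_log_posterior)(jnp.zeros(N))
# Autodiff returns the symmetric part of K_XX_inv, which equals K_XX_inv up to round-off.
H_formula = jnp.eye(N) / noise_variance + 0.5 * (K_XX_inv + K_XX_inv.T)
post_cov = jnp.linalg.inv(H)  # Cov(f_X, f_X) = (sigma^-2 I + K_XX^-1)^-1

print("Max |H - H_formula|:", jnp.max(jnp.abs(H - H_formula)))
print("Posterior variances at the training inputs:", jnp.diag(post_cov))
```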

---

### **3. The Woodbury Matrix Identity**

Using the **Woodbury matrix identity**, we can rewrite the posterior covariance as:

$$
\operatorname{Cov}(f_X, f_X) = K_{XX} - K_{XX} (\sigma^2 I + K_{XX})^{-1} K_{XX}
$$
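The substitution behind this step uses the Woodbury identity $(A + UCW)^{-1} = A^{-1} - A^{-1}U\,(C^{-1} + W A^{-1} U)^{-1}\,W A^{-1}$ with $A = K_{XX}^{-1}$, $U = W = I$, and $C = \sigma^{-2} I$, so that $A^{-1} = K_{XX}$ and $C^{-1} = \sigma^{2} I$:

$$
(K_{XX}^{-1} + \sigma^{-2} I)^{-1} = K_{XX} - K_{XX}\big((\sigma^{-2} I)^{-1} + K_{XX}\big)^{-1} K_{XX} = K_{XX} - K_{XX}\,(\sigma^{2} I + K_{XX})^{-1} K_{XX}
$$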

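- This form avoids inverting $K_{XX}$ itself, which is often severely ill-conditioned (e.g., for smooth kernels and closely spaced inputs); every eigenvalue of $K_{XX} + \sigma^2 I$ is at least $\sigma^2$, so it is better conditioned.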
- However, we **still need to invert** $(K_{XX} + \sigma^2 I)$ (or solve a linear system involving it), which is $\mathcal{O}(N^3)$ in time.
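In practice the linear system is usually handled with a single Cholesky factorization rather than an explicit inverse. A minimal sketch, assuming the same RBF kernel and toy inputs as elsewhere in this lecture: factor $K_{XX} + \sigma^2 I = L L^\top$, form $B = L^{-1} K_{XX}$ by triangular solves, and use $K_{XX}(K_{XX} + \sigma^2 I)^{-1} K_{XX} = B^\top B$. The helper name `posterior_cov_cholesky` is hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def squared_exponential_kernel(x1, x2, sigma=1.0, lengthscale=1.0):
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def posterior_cov_cholesky(K_XX, noise_variance):
    """Hypothetical helper: K_XX - K_XX (K_XX + sigma^2 I)^{-1} K_XX via one Cholesky factorization."""
    N = K_XX.shape[0]
    L = jnp.linalg.cholesky(K_XX + noise_variance * jnp.eye(N))  # lower-triangular, O(N^3)
    B = solve_triangular(L, K_XX, lower=True)  # B = L^{-1} K_XX (N right-hand sides, O(N^3))
    # (K_XX + sigma^2 I)^{-1} = L^{-T} L^{-1}, hence K_XX (K_XX + sigma^2 I)^{-1} K_XX = B^T B
    return K_XX - B.T @ B


X = jnp.linspace(0.0, 10.0, 5)[:, None]
noise_variance = 0.1**2
K_XX = squared_exponential_kernel(X, X)
post_cov_chol = posterior_cov_cholesky(K_XX, noise_variance)
post_cov_inv = jnp.linalg.inv(jnp.eye(5) / noise_variance + jnp.linalg.inv(K_XX))
print(
    "Max difference from the (sigma^-2 I + K_XX^-1)^-1 form:",
    jnp.max(jnp.abs(post_cov_chol - post_cov_inv)),
)
```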

---

## **Summary Table: Posterior Covariance Computation**

| Step                              | Formula                                                                 | Computational Cost      |
|------------------------------------|------------------------------------------------------------------------|------------------------|
| Inverse posterior covariance       | $(\sigma^{-2} I + K_{XX}^{-1})$                                        | $\mathcal{O}(N^3)$     |
| Posterior covariance (Woodbury)    | $K_{XX} - K_{XX} (\sigma^2 I + K_{XX})^{-1} K_{XX}$                    | $\mathcal{O}(N^3)$     |

---

## **Why Does This Matter?**

- The ability to **quantify uncertainty** is a key reason to use GPs.
- Computing the posterior covariance is **computationally expensive** for large datasets, motivating scalable approximations.

---

## **Let's Illustrate with Code**

We'll use a simple RBF kernel and some dummy data to demonstrate how to compute the posterior covariance in practice.

In [None]:
import jax.numpy as jnp
from jax.scipy.linalg import solve
from jax.numpy.linalg import (
    inv,
)  # For direct inverse illustration, but solve is preferred for stability


# Assume squared_exponential_kernel is defined as in previous notebooks
def squared_exponential_kernel(
    x1: jnp.ndarray, x2: jnp.ndarray, sigma: float = 1.0, lengthscale: float = 1.0
) -> jnp.ndarray:
    """
    Computes the Squared Exponential (RBF) kernel matrix.
    """
    x1 = jnp.atleast_2d(x1)
    x2 = jnp.atleast_2d(x2)
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    K = sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)
    return K


# Dummy data and parameters
N = 5  # Number of training points
X_train = jnp.linspace(0, 10, N)[:, None]
y_train = jnp.sin(X_train).squeeze() + 0.1 * jnp.array(
    [0.5, -0.2, 0.1, -0.3, 0.4]
)  # Some dummy observations
mean_func = lambda x: jnp.zeros(x.shape[0])
noise_variance = 0.1**2  # sigma^2

# Compute K_XX
K_XX = squared_exponential_kernel(X_train, X_train, sigma=1.0, lengthscale=1.0)

# K_XX_noisy = K_XX + sigma^2 * I
K_XX_noisy = K_XX + noise_variance * jnp.eye(N)

# Method 1: Using the Woodbury Identity form for posterior covariance
# Cov(fX, fX) = K_XX - K_XX @ inv(sigma^2 * I + K_XX) @ K_XX
# This is equivalent to K_XX - K_XX @ inv(K_XX_noisy) @ K_XX
# For numerical stability, use solve instead of inv where possible
term_inv_K_XX_noisy_K_XX = solve(K_XX_noisy, K_XX)
post_cov_fX_woodbury = K_XX - jnp.dot(K_XX, term_inv_K_XX_noisy_K_XX)

print("Posterior Covariance (Woodbury Identity):\n", post_cov_fX_woodbury)

# Alternative form: (sigma^-2 * I + K_XX^-1)^-1
# This requires K_XX_inv, which is also O(N^3)
K_XX_inv = inv(K_XX)
post_cov_fX_laplace_direct = inv((1 / noise_variance) * jnp.eye(N) + K_XX_inv)

print("\nPosterior Covariance (Laplace Direct Form):\n", post_cov_fX_laplace_direct)

# Check if they are numerically close
print(
    f"\nMax absolute difference between Woodbury and Laplace forms: {jnp.max(jnp.abs(post_cov_fX_woodbury - post_cov_fX_laplace_direct)):.2e}"
)


## Second Approach: Jacobian/Sensitivity of the Minimizer

Recall that the optimal $\alpha^\star$ for the GP regression problem is given by:

$$
\alpha^\star = (K_{XX} + \sigma^2 I_N)^{-1}(y - \mu_X)
$$

This is the minimizer of the quadratic loss $L(\alpha)$. The **Jacobian** of this minimizer with respect to the observations $y$ is:

$$
\frac{d\alpha^\star}{dy} = (K_{XX} + \sigma^2 I_N)^{-1}
$$

This Jacobian tells us how sensitive the optimal weights $\alpha^\star$ are to changes in the observed data $y$.

---

### Posterior Covariance via the Jacobian

The posterior covariance function $k_y(\cdot, \circ)$ can be expressed in terms of this Jacobian:

$$
\begin{align*}
k_y(\cdot, \circ) &= k_{\cdot \circ} - k_{\cdot X}(K_{XX} + \sigma^2 I_N)^{-1}k_{X \circ} \\
                  &= k_{\cdot \circ} - k_{\cdot X} \left( \frac{d\alpha^\star}{dy} \right) k_{X \circ}
\end{align*}
$$

- $k_{\cdot X}$: Covariance vector between test point(s) and training points
- $k_{X \circ}$: Covariance vector between training points and another test point

This formulation highlights that the posterior covariance is directly related to the sensitivity of the solution $\alpha^\star$ to the data $y$.

---

> **Note:**  
> Even with this alternative perspective, computing the posterior covariance **exactly** with standard dense methods (e.g., a Cholesky factorization) still costs $\mathcal{O}(N^3)$ time, because it requires inverting, or solving linear systems with, the $N \times N$ matrix $K_{XX} + \sigma^2 I_N$.

---

## Let's Illustrate the Jacobian $\frac{d\alpha^\star}{dy}$ in Code

We'll use JAX to compute and visualize the Jacobian, and show how it relates to the posterior covariance.

In [None]:
import jax.numpy as jnp
from jax.scipy.linalg import solve
from jax.numpy.linalg import inv  # Used below to compare the Jacobian with a direct inverse
from jax import jacfwd  # For forward-mode automatic differentiation

# Re-using K_XX_noisy from the previous code block
# K_XX_noisy = K_XX + noise_variance * jnp.eye(N)


# Define the function for alpha_star
# alpha_star(y) = inv(K_XX_noisy) @ (y - mean_func(X_train))
def alpha_star_func(y_obs: jnp.ndarray) -> jnp.ndarray:
    """
    Computes alpha_star given observations y_obs.
    K_XX_noisy and mean_func(X_train) are assumed to be fixed (closure over outer scope).
    """
    # Using solve for numerical stability instead of direct inv
    return solve(K_XX_noisy, y_obs - mean_func(X_train))


# Compute the Jacobian of alpha_star_func with respect to y_obs
# jacfwd computes the Jacobian matrix by applying forward-mode AD.
jacobian_alpha_star_dy = jacfwd(alpha_star_func)(y_train)

print("\nJacobian d(alpha_star)/dy:\n", jacobian_alpha_star_dy)

# Compare with the inverse of K_XX_noisy
K_XX_noisy_inv_computed = inv(K_XX_noisy)
print("\nInverse of (K_XX + sigma^2 I):\n", K_XX_noisy_inv_computed)

print(
    f"\nMax absolute difference between Jacobian and direct inverse: {jnp.max(jnp.abs(jacobian_alpha_star_dy - K_XX_noisy_inv_computed)):.2e}"
)

# This confirms that the Jacobian of alpha_star with respect to y is indeed the inverse of (K_XX + sigma^2 I).
# The full posterior covariance calculation (k_y(bullet, circ)) then uses this inverse.
# These code snippets demonstrate the core linear algebra operations involved in calculating the posterior covariance, highlighting the O(N^3) complexity due to matrix inversion or solving linear systems.
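To carry the Jacobian through to the posterior covariance function itself, the sketch below plugs $\frac{d\alpha^\star}{dy}$ into $k_y(\cdot, \circ) = k_{\cdot \circ} - k_{\cdot X}\frac{d\alpha^\star}{dy}k_{X \circ}$ at a few test inputs and checks the result against a direct linear solve. It is self-contained; zero prior mean, unit RBF hyperparameters, and the locations in `X_test` are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve


def squared_exponential_kernel(x1, x2, sigma=1.0, lengthscale=1.0):
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)


X_train = jnp.linspace(0.0, 10.0, 5)[:, None]
y_train = jnp.sin(X_train).squeeze()
X_test = jnp.array([[1.0], [4.0], [8.5]])  # arbitrary test inputs
noise_variance = 0.1**2
K_XX_noisy = squared_exponential_kernel(X_train, X_train) + noise_variance * jnp.eye(5)


def alpha_star(y_obs):
    # Zero prior mean, so y - mu_X = y_obs
    return solve(K_XX_noisy, y_obs)


J = jax.jacfwd(alpha_star)(y_train)                 # d alpha* / dy = (K_XX + sigma^2 I)^{-1}
k_tX = squared_exponential_kernel(X_test, X_train)  # k_{. X}
k_tt = squared_exponential_kernel(X_test, X_test)   # k_{. o}

post_cov_jacobian = k_tt - k_tX @ J @ k_tX.T
post_cov_direct = k_tt - k_tX @ solve(K_XX_noisy, k_tX.T)

print("Posterior variances at X_test:", jnp.diag(post_cov_jacobian))
print("Max difference from direct solve:", jnp.max(jnp.abs(post_cov_jacobian - post_cov_direct)))
```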


# Summary: The Role of Linear Algebra in Gaussian Processes

This lecture provided a comprehensive exploration of the computational foundations of Gaussian Processes (GPs), with a special focus on the pivotal role of linear algebra. Here are the key takeaways:

---

## 1. **GP Model Instantiation is Computationally Free**

- **Defining a GP** (specifying the mean and kernel functions) incurs negligible computational cost.
- The real computational challenges arise only when fitting the model to data.

---

## 2. **"Training" a GP = Solving an Optimization Problem**

- **Training a GP** means finding the posterior mean function, which is equivalent to minimizing the negative log-posterior.
- This is a **convex quadratic optimization problem** and can be solved analytically or with gradient-based methods.
- The optimization involves the kernel (covariance) matrix, which encodes the relationships between all pairs of training points.

---

## 3. **Efficient Point Predictions (After Training)**

- Once the kernel matrix has been decomposed (e.g., via Cholesky decomposition), **making predictions at new test points is efficient**:
    - **Posterior mean:** $\mathcal{O}(N)$ per test point
    - **Posterior variance:** $\mathcal{O}(N^2)$ per test point
- The initial matrix decomposition is the main computational hurdle.
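A minimal sketch of this split, assuming zero prior mean and the RBF kernel used throughout: the Cholesky factor $L$ of $K_{XX} + \sigma^2 I_N$ and the weights $\alpha$ are computed once, after which each test point needs $N$ kernel evaluations, one dot product with $\alpha$ for the mean, and one triangular solve with $L$ for the variance. The helper `predict_single` is a hypothetical name.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular


def squared_exponential_kernel(x1, x2, sigma=1.0, lengthscale=1.0):
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)


# --- One-off "training" cost: O(N^3) ---
X_train = jnp.linspace(0.0, 10.0, 10)[:, None]
y_train = jnp.sin(X_train).squeeze()
noise_variance = 0.1**2
K_XX_noisy = squared_exponential_kernel(X_train, X_train) + noise_variance * jnp.eye(10)
L = jnp.linalg.cholesky(K_XX_noisy)    # K_XX + sigma^2 I = L L^T (L lower-triangular)
alpha = cho_solve((L, True), y_train)  # alpha = (K_XX + sigma^2 I)^{-1} y


def predict_single(x_star):
    """Hypothetical helper: posterior mean and variance at one test input of shape (d,)."""
    x_star = x_star[None, :]
    k_star = squared_exponential_kernel(x_star, X_train)[0]  # N kernel evaluations
    mean = k_star @ alpha                                    # O(N)
    v = solve_triangular(L, k_star, lower=True)              # O(N^2) triangular solve
    var = squared_exponential_kernel(x_star, x_star)[0, 0] - v @ v
    return mean, var


mean, var = predict_single(jnp.array([3.3]))
print(f"mean = {float(mean):.4f}, variance = {float(var):.2e}")
```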

---

## 4. **The $\mathcal{O}(N^3)$ Bottleneck**

- **Both the posterior mean and covariance** require solving linear systems involving the $N \times N$ kernel matrix.
- **Cholesky decomposition** (or similar matrix factorization) has a computational complexity of $\mathcal{O}(N^3)$.
- There is currently **no known linear-time algorithm** for exact GP inference with general kernels.

---

## 5. **Nonparametric Nature Drives the Cost**

- The $\mathcal{O}(N^3)$ scaling is a consequence of the **nonparametric** nature of GPs:
    - The number of "parameters" (function values at training points) grows with the dataset size $N$.
- This is **not** a result of being Bayesian or probabilistic.
- In contrast, **deep learning models** have a fixed number of parameters and use stochastic optimization, so each mini-batch iteration costs an amount independent of $N$ (and even a full-batch gradient costs only $\mathcal{O}(N)$).

---

## **Looking Ahead**

- In the next lecture, we will:
    - Dive deeper into **Cholesky decomposition** and its practical implementation.
    - Discuss **data loading strategies** for large-scale GP models.
    - Explore advanced methods for **uncertainty estimation** in GPs, especially for large datasets.

---

> **Bottom line:**  
> Mastery of linear algebra is essential for understanding both the power and the computational limitations of Gaussian Processes. Efficient matrix operations are at the heart of scalable and robust GP inference.