# Computation and Inference in Gaussian Processes

Welcome to the final lecture in our **Probabilistic Machine Learning** series! In previous sessions, we established **Gaussian Processes (GPs)** as powerful probabilistic models for functions, explored various kernel functions, and delved into the theoretical underpinnings of **RKHS** and the connection between Bayesian GPs and Frequentist kernel methods.

---

## Lecture Overview

In this lecture, we'll bring together the concepts of computation and inference. We'll challenge the common perception of computational complexity in GPs versus deep learning and show how core linear algebra routines are, in essence, learning algorithms themselves.

**Specifically, we will:**

- **Revisit** the computational bottleneck of GPs and address common misunderstandings about their $O(N^3)$ complexity.
- **Explore** how matrix decompositions, particularly Cholesky, can be viewed as iterative "data loading" or "learning" processes.
- **Introduce** Schur complements and their role in efficiently updating inverse matrices and solutions.
- **Demonstrate** how iterative methods like Conjugate Gradients (via the Lanczos process) can provide efficient ways to update GP posteriors, offering a bridge between exact Bayesian inference and scalable iterative optimization.
- **Conclude** that for GPs, there is no fundamental separation between "computing" and "learning"; numerical algorithms are learning machines.

This lecture will provide a holistic view of the interplay between mathematical theory, numerical methods, and the practical implications for probabilistic machine learning.

---

## The Training Metaphor: Context of the Course

Our overarching goal for this course has been to develop a **probabilistic perspective** on contemporary machine learning. In this view, the process of "learning" is re-phrased as **inference**, which involves manipulating probability distributions on hypothesis spaces (e.g., function spaces for GPs).

So far, we have developed the analytic framework of **Gaussian Process (GP) regression** from first principles. This yields a clean and elegant picture, offering exact posterior distributions over functions. However, this analytic framework is practically limited in a few key ways:

- It's primarily limited to real-valued functions $f: X \to \mathbb{R}$ (multi-output GPs extend this to vector-valued functions $f: X \to \mathbb{R}^C$).
- It inherently requires low-level linear algebra operations, particularly matrix inversions or solving linear systems involving the $N \times N$ kernel matrix, which leads to $O(N^3)$ computational complexity.

In the broader context of machine learning research, this path leads all the way to contemporary deep learning. Following it involves thinking both about the **model** (what we're trying to learn) and about the **computation** (how we actually perform the learning).

------

### Addressing a Common Misunderstanding: Computational Complexity

A frequently repeated comparison goes:

- **GPs are $O(N^3)$**
- **Deep learning is $O(1)$ (per training step, for a fixed batch size)**

This comparison is often misleading because it compares a hard upper bound (for exact GP inference) to a loose lower bound (for a single stochastic gradient descent step in deep learning). The true picture is more nuanced, and we'll explore why this distinction is critical.

---


## The Cholesky Decomposition: Recap & Pseudocode

In the previous lecture, we introduced the **Cholesky decomposition** as a fundamental tool for exact Gaussian Process (GP) inference. This method efficiently and stably decomposes a symmetric positive definite (SPD) matrix $A$ into a lower triangular matrix $L$ such that:

$$
A = L L^\top
$$

### Mathematical Formulation: Iterative Cholesky Decomposition

For $A \in \mathbb{R}^{n \times n}$, the entries of $L$ are computed iteratively as follows:

- **Diagonal entries:**
  $$
  L_{ii} = \sqrt{A_{ii} - \sum_{k=1}^{i-1} L_{ik}^2}
  $$

- **Off-diagonal entries (for $j > i$):**
  $$
  L_{ji} = \frac{1}{L_{ii}} \left( A_{ji} - \sum_{k=1}^{i-1} L_{jk} L_{ik} \right)
  $$

**Iterative Process:**

1. For $i = 1$ to $n$:
    - Compute $L_{ii}$ using the formula above.
    - For each $j = i+1$ to $n$, compute $L_{ji}$.

This process "loads" one column of $L$ at a time; the subtracted sums account for everything already explained by the previous columns. For SPD matrices the algorithm is numerically stable without pivoting, which makes it the standard approach for GP inference and many other applications in scientific computing.

### Conceptual Pseudocode for Cholesky Decomposition

Below is a step-by-step pseudocode outlining the Cholesky decomposition process:

---

**Algorithm 1: Cholesky Decomposition**

- **Input:** Symmetric Positive Definite (SPD) matrix $A \in \mathbb{R}^{n \times n}$
- **Output:** Lower triangular matrix $L$, such that $LL^\top = A$

```python
import jax.numpy as jnp


def cholesky_decomposition(A):
  """
  Perform Cholesky decomposition on a symmetric positive definite (SPD) matrix A.
  Returns lower triangular matrix L such that A = L @ L.T

  Args:
    A (jax.numpy.ndarray): SPD matrix of shape (n, n)

  Returns:
    L (jax.numpy.ndarray): Lower triangular matrix of shape (n, n)
  """
  n = A.shape[0]
  L = jnp.zeros_like(A)

  for i in range(n):
    # Diagonal entry: L_ii = sqrt(A_ii - sum_k L_ik^2)
    sum_k = jnp.sum(L[i, :i] ** 2)
    L = L.at[i, i].set(jnp.sqrt(A[i, i] - sum_k))

    # Entries below the diagonal in column i: L_ji for j > i
    for j in range(i + 1, n):
      sum_k = jnp.sum(L[j, :i] * L[i, :i])
      L = L.at[j, i].set((A[j, i] - sum_k) / L[i, i])

  return L
```

**Step-by-step Explanation:**

- **Initialization:**
    - Create a zero matrix $L$ of the same shape as $A$.
- **Iterative Construction:** for each column $i$:
    - Compute the diagonal entry $L_{ii}$ from the previously computed entries of row $i$.
    - For each row $j > i$, compute the off-diagonal entry $L_{ji}$.
- **Return:**
    - The lower triangular matrix $L$ such that $A = LL^\top$.

**Key Observations:**

- In the outer-product form of the algorithm (Algorithm 2 below), step $i$ updates the remaining $(n-i) \times (n-i)$ submatrix, so the cost per step shrinks as $i$ increases.
- Summed over all steps, the work is roughly $n^3/3$ floating-point operations, i.e. $O(N^3)$ overall; the early steps, which touch the largest submatrices, account for most of it.

**Visual Intuition:**

- Imagine the matrix as a grid:
    - **Blue areas:** Already processed (loaded) into $L$.
    - **Gray area:** Remaining submatrix to be decomposed.
- As the algorithm proceeds, the blue area grows and the gray area shrinks, mirroring the learning process.

---
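The cost claims above can be checked with simple arithmetic. The sketch below uses a hypothetical helper (not part of the lecture code) that counts the multiply-adds of the trailing-block update in the outer-product form, where step $i$ (0-based in the code) subtracts $l_i l_i^\top$ from the remaining $(n-i-1) \times (n-i-1)$ block. It deliberately ignores the factor-of-two saving from symmetry and the $O(n)$ work of forming each column.

```python
def outer_product_cholesky_costs(n):
    """Multiply-adds of the trailing-block update at each step (0-based i).

    Step i subtracts l_i l_i^T from the remaining (n-i-1) x (n-i-1) block.
    Symmetry (halving every count) and the O(n) work of forming l_i are ignored.
    """
    return [(n - i - 1) ** 2 for i in range(n)]


costs = outer_product_cholesky_costs(6)
print(costs)        # [25, 16, 9, 4, 1, 0]: the per-step cost shrinks
print(sum(costs))   # 55

n = 1000
ratio = sum(outer_product_cholesky_costs(n)) / (n**3 / 3)
print(ratio)        # close to 1: total work grows like n^3 / 3
```

The first half of the steps accounts for roughly $7/8$ of the total, which is why the early iterations dominate.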


## Computing the Inverse Alongside the Decomposition: Cholesky as a Linear Solver

The **Cholesky decomposition** is not just a tool for factorizing a matrix; it is also deeply connected to solving linear systems and, perhaps surprisingly, to iteratively building an approximation of the matrix inverse.

### Motivation

Suppose we want to find a low-rank approximation $C_i \approx A^{-1}$ as we proceed through the Cholesky decomposition. This is useful in Gaussian Processes and other kernel methods, where inverting large matrices is a computational bottleneck.

### Key Observation

If we have an approximation $L_i L_i^\top \approx A$, can we also approximate $A^{-1}$?

- If $L_i L_i^\top \approx A$, then $(A^{-1} L_i)(A^{-1} L_i)^\top \approx A^{-1}$.
- Define $C_i = (A^{-1} L_i)(A^{-1} L_i)^\top$ as our inverse approximation.

The core idea is to **track the effect of the inverse on the columns of $L_i$**. Consider the last column: $(A^{-1} L_i)_{:i} = A^{-1} l_i$. This term can be related to previous steps:

$$
A^{-1} l_i = A^{-1} \frac{A'_{:i}}{\sqrt{A'_{ii}}}
= A^{-1} (A - L_{i-1} L_{i-1}^\top) e_i / \|e_i\|_{A'}
= (I - A^{-1} L_{i-1} L_{i-1}^\top) e_i / \|e_i\|_{A'}
= (I - C_{i-1} A) e_i / \|e_i\|_{A'}
$$

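Here $A' = A - L_{i-1} L_{i-1}^\top$ is the residual matrix after the first $i-1$ steps, $\|e_i\|_{A'} = \sqrt{e_i^\top A' e_i} = \sqrt{A'_{ii}}$, and the last equality uses $C_{i-1} A = A^{-1} L_{i-1} L_{i-1}^\top A^{-1} A = A^{-1} L_{i-1} L_{i-1}^\top$. So the inverse's action on the current column of $L$ is governed by the **residual of the previous inverse approximation**, $I - C_{i-1} A$.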
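The identity can be checked numerically. The following minimal sketch uses NumPy rather than `jax.numpy` so it runs without JAX, and the random SPD test matrix is an illustrative assumption. It takes the exact factor from `np.linalg.cholesky`, builds $C_{i-1} = (A^{-1}L_{i-1})(A^{-1}L_{i-1})^\top$ from the columns computed before a chosen step, and compares both sides of the derivation.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6
B = rng.standard_normal((n, n))
A = B @ B.T + n * np.eye(n)         # random SPD test matrix (illustrative assumption)
A_inv = np.linalg.inv(A)
L = np.linalg.cholesky(A)           # full factor with A = L L^T

i = 3                               # 0-based index of the column under inspection
L_prev = L[:, :i]                   # L_{i-1}: the columns computed before this step
C_prev = A_inv @ L_prev @ L_prev.T @ A_inv   # C_{i-1} = (A^{-1} L_{i-1})(A^{-1} L_{i-1})^T
A_res = A - L_prev @ L_prev.T       # residual A' before this step
e_i = np.eye(n)[i]

lhs = A_inv @ L[:, i]                                               # A^{-1} l_i
rhs = (np.eye(n) - C_prev @ A) @ e_i / np.sqrt(e_i @ A_res @ e_i)   # (I - C_{i-1} A) e_i / ||e_i||_{A'}
print(np.allclose(lhs, rhs))
print(np.allclose(A_inv @ L @ L.T @ A_inv, A_inv))   # with all columns, C_n = A^{-1}
```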

---
### Algorithm 2: Cholesky Decomposition

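**Input:** SPD matrix $A$

```python
import jax.numpy as jnp


def cholesky(A):
    """Outer-product Cholesky: returns lower-triangular L with A = L @ L.T."""
    n = A.shape[0]
    A_prime = A                                 # residual matrix A'
    L = []
    for i in range(n):
        e_i = jnp.eye(n, dtype=A.dtype)[i]      # canonical basis vector e_i
        # l_i = A' e_i / sqrt(e_i^T A' e_i) = A'[:, i] / sqrt(A'[i, i])
        l_i = (A_prime @ e_i) / jnp.sqrt(e_i @ A_prime @ e_i)
        A_prime = A_prime - jnp.outer(l_i, l_i)  # deflate: zero out row/column i
        L.append(l_i)
    return jnp.stack(L, axis=1)                 # columns l_0, ..., l_{n-1}
```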

---

**Step-by-Step Explanation:**

1. **Initialization:**
   - Set $A'$ (A\_prime) to the input matrix $A$.
   - Initialize an empty list $L$ to store the columns of the Cholesky factor.

2. **Iterative Decomposition:**
   - For each index $i$ from $0$ to $n-1$:
     - Select the $i$-th canonical basis vector $e_i$ (a vector with $1$ at position $i$, $0$ elsewhere).
     - Compute the $i$-th column of $L$:
       - $l_i = A'_{:, i} / \sqrt{A'_{i, i}}$
       - This normalizes the $i$-th column of the current residual matrix.
     - Update the residual matrix $A'$:
       - $A' = A' - l_i l_i^\top$
       - This subtracts the outer product of $l_i$ with itself, removing the contribution of the $i$-th component.
     - Append $l_i$ to the list $L$.

3. **Return:**
   - After all iterations, stack the columns in $L$ to form the lower-triangular Cholesky factor.

---

**Intuition:**

- At each step, we extract the next column of $L$ by normalizing the current column of the residual matrix.
- We then "deflate" the matrix by removing the contribution of this column. This zeroes out the $i$-th row and column of the residual, so every subsequent column has zeros above the diagonal.
- This process continues until all columns are processed; after the last step the residual is zero and $A = LL^\top$ holds exactly (up to rounding).

---

**Key Points:**

- The Cholesky decomposition builds $L$ one column at a time, each time updating the matrix to account for what has already been "explained."
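- The normalization by $\sqrt{A'_{i, i}}$ makes $l_i l_i^\top$ agree with $A'$ in its $i$-th row and column, so subtracting it zeroes them out; this is what makes $L$ lower-triangular. It requires $A'_{ii} > 0$, which holds for SPD $A$.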
- This iterative approach mirrors how information is sequentially incorporated in probabilistic inference and learning.
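To make the deflation picture concrete, here is a minimal NumPy sketch of the same outer-product loop that also records each residual $A'$. The function name `cholesky_by_deflation` and the random test matrix are illustrative assumptions. It checks that step $i$ zeroes out the $i$-th row and column of the residual and that the stacked columns agree with `np.linalg.cholesky`.

```python
import numpy as np


def cholesky_by_deflation(A):
    """Outer-product Cholesky (Algorithm 2) that also records each residual A'."""
    n = A.shape[0]
    A_res = A.copy()
    cols, residuals = [], []
    for i in range(n):
        l_i = A_res[:, i] / np.sqrt(A_res[i, i])   # normalise the i-th residual column
        A_res = A_res - np.outer(l_i, l_i)         # deflate (creates a new array)
        cols.append(l_i)
        residuals.append(A_res)
    return np.stack(cols, axis=1), residuals


rng = np.random.default_rng(1)
B = rng.standard_normal((5, 5))
A = B @ B.T + 5 * np.eye(5)        # random SPD test matrix (illustrative assumption)
L, residuals = cholesky_by_deflation(A)

# After step i (0-based), rows and columns 0..i of the residual are (numerically) zero.
zeroed = [np.allclose(R[: i + 1, :], 0) and np.allclose(R[:, : i + 1], 0)
          for i, R in enumerate(residuals)]
lower_ok = np.allclose(L, np.tril(L))
matches_numpy = np.allclose(L, np.linalg.cholesky(A))
print(zeroed, lower_ok, matches_numpy)
```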


### Algorithm 3: Cholesky with Inverse Approximation

This algorithm integrates the computation of a low-rank inverse approximation $C_i$ directly into the Cholesky iterations.

**Input:** SPD matrix $A$

**Output:** Lower triangular $L_i$ (such that $L_i L_i^\top \approx A$), and low-rank $C_i \approx A^{-1}$

#### Step-by-Step Procedure: Cholesky with Inverse Approximation

Below is a clear, stepwise breakdown of the algorithm, with explanations for each step:

---

**Procedure:** `CHOLESKY(A)`

1. **Initialization**
    - Set $A' \leftarrow A$ (copy of the input matrix to be updated at each step)
    - Set $C_0 = 0$ (initialize the inverse approximation as a zero matrix)
    - Set $L_0 = [\ ]$ (initialize an empty list for columns of $L$)

2. **Iterative Updates (for $i = 1$ to $n$):**
    - **a. Select Action Vector**
        - $s_i \leftarrow e_i$  
          *(Choose the $i$-th canonical basis vector; this "loads" the $i$-th data point/column)*
    - **b. Compute Residual Direction**
        - $d_i \leftarrow (I - C_{i-1}A) s_i$  
          *(Find the part of $s_i$ not yet explained by the current inverse approximation)*
    - **c. Compute Normalization (Schur Complement)**
        - $\eta_i \leftarrow s_i^\top A d_i = e_i^\top A' e_i = \|e_i\|_{A'}^2$  
          *(Measures the "new information" added by $d_i$; ensures numerical stability)*
    - **d. Compute New Cholesky Column**
        - $l_i \leftarrow A \left(\frac{1}{\sqrt{\eta_i}}\right) d_i$  
          *(Construct the $i$-th column of $L$; scales $d_i$ appropriately)*
    - **e. Update Inverse Approximation**
        - $C_i \leftarrow C_{i-1} + \frac{1}{\eta_i} d_i d_i^\top$  
          *(Rank-1 update to the inverse estimate using the new direction)*
    - **f. Update Residual Matrix**
        - $A' \leftarrow A - L_i L_i^\top = A(A^{-1} - C_i)A = A(I - C_i A)$  
          *(Deflates $A$ by removing the contribution of the new column; prepares for next step)*
    - **g. Store Cholesky Column**
        - $L_i = (L_{i-1},\ l_i)$  
          *(Append $l_i$ as a new column to $L$)*

3. **Return**
    - Output the final lower-triangular matrix $L_n$ and the inverse approximation $C_n$.

---

**Summary Table of Steps**

| Step | Operation | Purpose |
|------|-----------|---------|
| 1    | $s_i \leftarrow e_i$ | Selects the $i$-th data direction |
| 2    | $d_i \leftarrow (I - C_{i-1}A) s_i$ | Finds unexplained component |
| 3    | $\eta_i \leftarrow s_i^\top A d_i$ | Normalizes update (Schur complement) |
| 4    | $l_i \leftarrow A (1/\sqrt{\eta_i}) d_i$ | Forms new Cholesky column |
| 5    | $C_i \leftarrow C_{i-1} + (1/\eta_i) d_i d_i^\top$ | Updates inverse approximation |
| 6    | $A' \leftarrow A - L_i L_i^\top$ | Updates residual matrix |
| 7    | $L_i = (L_{i-1},\ l_i)$ | Appends new column to $L$ |

```python
import jax.numpy as jnp


def cholesky_with_inverse_approximation(A):
  """
  Perform Cholesky decomposition with simultaneous low-rank inverse approximation.
  Args:
    A (jax.numpy.ndarray): SPD matrix of shape (n, n)
  Returns:
    L (jax.numpy.ndarray): Lower triangular matrix (Cholesky factor)
    C (jax.numpy.ndarray): Approximation of A^{-1}; exact (up to rounding) after n steps
  """
  n = A.shape[0]
  I = jnp.eye(n, dtype=A.dtype)
  C = jnp.zeros_like(A)  # C_0 = 0
  L_cols = []
  for i in range(n):
    s_i = I[i]  # action vector: canonical basis vector e_i
    d_i = s_i - C @ (A @ s_i)  # d_i = (I - C_{i-1} A) s_i
    z_i = A @ d_i  # A d_i, equal to the residual column A' e_i
    eta_i = s_i @ z_i  # Schur complement eta_i = s_i^T A d_i
    l_i = z_i / jnp.sqrt(eta_i)  # new Cholesky column
    C = C + jnp.outer(d_i, d_i) / eta_i  # rank-1 update of the inverse estimate
    L_cols.append(l_i)
  # The residual A' = A (I - C A) never has to be formed explicitly.
  L = jnp.stack(L_cols, axis=1)
  return L, C
```
---

**Intuition:**  
- Each iteration "loads" a new data direction, refines the inverse approximation, and updates the Cholesky factor.
- The process is analogous to sequentially learning from each data point, with uncertainty and mean estimates improving at every step.

---

**Key Takeaway:**  
This algorithm not only factorizes the matrix $A$ but also builds an increasingly accurate approximation to $A^{-1}$, step by step—mirroring the learning process in Gaussian Process inference.

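**Computational Complexity:** Each step $i$ costs $O(N^2)$, dominated by the matrix-vector product $A d_i$ and the dense rank-1 update of $C_i$ (keeping $C_{i-1}$ in factored form $\sum_{j<i} d_j d_j^\top/\eta_j$ reduces applying it to $O(iN)$). Over $N$ steps this is still $O(N^3)$ overall.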

**Key Insight:** Cholesky can be seen as an iterative learning algorithm for the kernel matrix and its inverse. Each step refines the approximation based on processing one "action" vector $s_i$.
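The complexity statement can be made concrete by never forming $C_i$ as a dense matrix. The sketch below is a NumPy variant of Algorithm 3 under assumptions made here, not in the lecture: the function name is hypothetical and $C_{i-1}$ is stored as $D_{i-1}\,\mathrm{diag}(1/\eta)\,D_{i-1}^\top$. Applying $C_{i-1}$ then costs $O(iN)$, and the dominant per-step cost is the single product $A d_i$.

```python
import numpy as np


def cholesky_with_factored_inverse(A):
    """Algorithm 3 with s_i = e_i, keeping C_i = D diag(1/eta) D^T in factored form."""
    n = A.shape[0]
    D = np.zeros((n, 0))          # columns d_1, ..., d_{i-1}
    etas = np.zeros(0)            # eta_1, ..., eta_{i-1}
    cols = []
    for i in range(n):
        a_i = A[:, i]                             # A s_i for s_i = e_i: a column lookup
        d = -D @ ((D.T @ a_i) / etas)             # -C_{i-1} A s_i, O(i n)
        d[i] += 1.0                               # d_i = s_i - C_{i-1} A s_i
        z = A @ d                                 # A d_i, O(n^2): the dominant cost
        eta = z[i]                                # eta_i = s_i^T A d_i
        cols.append(z / np.sqrt(eta))             # l_i
        D = np.column_stack([D, d])
        etas = np.append(etas, eta)
    return np.stack(cols, axis=1), D, etas


rng = np.random.default_rng(2)
B = rng.standard_normal((6, 6))
A = B @ B.T + 6 * np.eye(6)       # random SPD test matrix (illustrative assumption)
L, D, etas = cholesky_with_factored_inverse(A)
C = D @ np.diag(1.0 / etas) @ D.T  # densified only for the check
print(np.allclose(L, np.linalg.cholesky(A)), np.allclose(C, np.linalg.inv(A)))
```

Densifying $C$ at the end is only for verification; in an iterative GP setting one would keep $D$ and $\eta$ and apply them to vectors directly.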


```python
import jax.numpy as jnp


# Same squared_exponential_kernel as in previous notebooks
def squared_exponential_kernel(
    x1: jnp.ndarray, x2: jnp.ndarray, sigma: float = 1.0, lengthscale: float = 1.0
) -> jnp.ndarray:
    """Computes the Squared Exponential (RBF) kernel matrix."""
    x1 = jnp.atleast_2d(x1)
    x2 = jnp.atleast_2d(x2)
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    K = sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)
    return K


def iterative_cholesky_with_inverse_approx(A: jnp.ndarray):
    """
    Conceptual implementation of Cholesky with inverse approximation (Algorithm 3).
    This is for illustration and might not be numerically stable for large N.
    Only C_i is tracked; the Cholesky column would be l_i = A d_i / sqrt(eta_i).
    """
    N = A.shape[0]
    I = jnp.eye(N, dtype=A.dtype)
    C_i = jnp.zeros((N, N), dtype=A.dtype)  # C_0 = 0

    # For verification later
    A_inv_true = jnp.linalg.inv(A)

    print("--- Iterative Cholesky with Inverse Approximation ---")
    for i in range(N):
        s_i = I[i]  # e_i (canonical basis vector)

        # d_i = (I - C_{i-1}A)s_i
        d_i = s_i - C_i @ (A @ s_i)

        # eta_i = s_i^T A d_i (Schur complement)
        eta_i = s_i @ (A @ d_i)

        # C_i = C_{i-1} + (1/eta_i) d_i d_i^T; a tiny jitter guards against division by zero
        eta_i_stable = eta_i + 1e-12
        C_i = C_i + jnp.outer(d_i, d_i) / eta_i_stable

        print(f"\nIteration {i + 1}:")
        print(f"  eta_{i + 1}: {float(eta_i):.4f}")
        print(
            f"  Max abs diff C_{i + 1} vs True A_inv: {float(jnp.max(jnp.abs(C_i - A_inv_true))):.2e}"
        )

    print("\n--- Final Results ---")
    print("Final approximate inverse C_N:\n", C_i)
    print("\nTrue Inverse A_inv:\n", A_inv_true)
    print(
        f"\nFinal Max abs diff C_N vs True A_inv: {float(jnp.max(jnp.abs(C_i - A_inv_true))):.2e}"
    )
    return C_i


# --- Example Usage ---
# Create a symmetric positive definite matrix (e.g., from an RBF kernel)
N_matrix = 10  # Size of the matrix
X_data_matrix = jnp.linspace(-2, 2, N_matrix)[:, None]
K_XX_matrix = squared_exponential_kernel(
    X_data_matrix, X_data_matrix, sigma=1.0, lengthscale=0.5
)
# Add noise for positive definiteness (like in GP K_XX + sigma^2 I)
A_example = K_XX_matrix + 0.1**2 * jnp.eye(N_matrix)

iterative_cholesky_with_inverse_approx(A_example)
```


# Takeaways: Cholesky Iterations and Uncertainty

The iterative nature of the Cholesky decomposition reveals important insights about its computational properties and its role in Gaussian Process (GP) inference:

---

### 💡 **Key Insights**

- **Superlinear Expense:**  
    - Each iteration $i$ of (outer-product) Cholesky costs $\mathcal{O}((N-i)^2)$, and summing over all $N$ iterations gives a total of $\mathcal{O}(N^3)$.
    - So the cost of a single step grows quadratically with the dataset size $N$, and the total cost grows cubically.

- **Comparison to SGD:**  
    - Cholesky is computationally more expensive than a single step of Stochastic Gradient Descent (SGD), which costs $\mathcal{O}(B)$ for a batch size $B$ and a fixed model size, and does **not** depend on the total dataset size $N$.
    - In contrast, Cholesky's cost is tied directly to the full dataset (or kernel matrix).

- **Key Difference: Uncertainty Quantification:**  
    - **Cholesky** provides direct access to uncertainty: we obtain the full posterior covariance matrix (or its inverse), which quantifies our uncertainty about predictions.
    - **SGD** (in its basic form) only provides a point estimate (the mean), without uncertainty.

- **One "Epoch":**  
    - Cholesky completes "training" (i.e., exact inference) in a single pass (one "epoch") over the data or matrix columns.
    - SGD typically requires many epochs to converge to a good solution.

---

### 🧮 **Posterior Mean and Covariance in GP Regression**

The posterior mean $\mu_y(\cdot)$ and posterior covariance $k_y(\cdot, \circ)$ in GP regression are computed as:

$$
\mu_y(\cdot) = \mu_\cdot + k_{\cdot X} (K_{XX} + \sigma^2 I)^{-1} (y - \mu_X)
$$

$$
k_y(\cdot, \circ) = k_{\cdot \circ} - k_{\cdot X} (K_{XX} + \sigma^2 I)^{-1} k_{X \circ}
$$

- Here, $(K_{XX} + \sigma^2 I)^{-1}$ is the inverse of the noisy kernel matrix, typically computed via Cholesky decomposition.

---

### 🔄 **Iterative Approximation with Cholesky**

If we denote $C_i$ as our iterative approximation to $(K_{XX} + \sigma^2 I)^{-1}$ at step $i$, then at each step we can approximate the posterior mean and covariance as:

$$
\mu_y(\cdot) \approx \mu_\cdot + k_{\cdot X} C_i (y - \mu_X)
$$

$$
k_y(\cdot, \circ) \approx k_{\cdot \circ} - k_{\cdot X} C_i k_{X \circ}
$$

- **Interpretation:**  
    As the Cholesky decomposition proceeds, we are **iteratively refining our estimates** of the posterior mean and covariance.
    - Early iterations give rough approximations.
    - As more columns are processed, the estimates become more accurate.
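To see the refinement numerically, here is a minimal NumPy sketch. Its assumptions: zero prior mean, a 1-D toy dataset with $y = \sin(x)$, and a hypothetical helper `se_kernel` written to mirror `squared_exponential_kernel`. It runs the $C_i$ recursion from Algorithm 3 and, after every step, evaluates the approximate posterior mean and variance at a few test inputs. Because each update adds the positive semidefinite term $d_i d_i^\top/\eta_i$ to $C_i$, the approximate variance can only shrink from one step to the next.

```python
import numpy as np


def se_kernel(x1, x2, sigma=1.0, lengthscale=0.5):
    """Hypothetical 1-D squared-exponential kernel mirroring squared_exponential_kernel."""
    return sigma**2 * np.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / lengthscale**2)


X = np.linspace(-2, 2, 8)
y = np.sin(X)                      # toy targets; zero prior mean assumed, so y - mu_X = y
x_test = np.array([-1.5, 0.3, 1.7])
K = se_kernel(X, X) + 0.1**2 * np.eye(len(X))   # K_XX + sigma^2 I
k_tX = se_kernel(x_test, X)
k_tt = se_kernel(x_test, x_test)

C = np.zeros_like(K)               # C_0 = 0
variances = []
for i in range(len(X)):
    s = np.eye(len(X))[i]          # s_i = e_i
    d = s - C @ (K @ s)            # d_i = (I - C_{i-1} K) s_i
    eta = s @ K @ d                # eta_i
    C = C + np.outer(d, d) / eta   # C_i
    mean_i = k_tX @ C @ y                        # approximate posterior mean
    var_i = np.diag(k_tt - k_tX @ C @ k_tX.T)    # approximate posterior variance
    variances.append(var_i)
    print(f"i={i + 1}: mean={np.round(mean_i, 3)}, var={np.round(var_i, 4)}")

exact_mean = k_tX @ np.linalg.solve(K, y)
exact_var = np.diag(k_tt - k_tX @ np.linalg.solve(K, k_tX.T))
print(np.allclose(mean_i, exact_mean), np.allclose(variances[-1], exact_var))
```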

---

### 📌 **Summary**

- Cholesky decomposition is more computationally intensive than SGD, but it provides **exact inference** and **uncertainty quantification** in one pass.
- The iterative process allows us to see how uncertainty estimates improve as we process more data.
- This perspective highlights the deep connection between numerical linear algebra and probabilistic inference in machine learning.

# Computing the Solution Alongside the Decomposition: Cholesky as a Dataloader

Building on the idea of iteratively approximating the inverse, we can also iteratively compute the solution vector $\alpha = A^{-1}y$ (which corresponds to the representer weights for the posterior mean in GPs).

## Algorithm 4: Cholesky with Inverse $C$ and Solution $\alpha$

This algorithm extends Algorithm 3 by also tracking an estimate of the solution vector $\alpha_i \approx A^{-1}y$.

**Input:** SPD matrix $A$, vector $y$

**Output:** Lower triangular $L_i$, low-rank $C_i \approx A^{-1}$, solution estimate $\alpha_i \approx A^{-1}y$


The iterative update for the solution vector $\alpha_i$ in the Cholesky/IterGP algorithm is:

$$
\alpha_i = \alpha_{i-1} + \frac{1}{\eta_i} d_i \left( d_i^\top y \right)
$$

where:
- $\alpha_{i-1}$ is the previous solution estimate,
- $d_i = (I - C_{i-1}A) e_i$ is the update direction,
- $\eta_i = e_i^\top A d_i$ is the normalization (Schur complement),
- $y$ is the data vector.

**Expanded for clarity:**

$$
\begin{align*}
d_i &= (I - C_{i-1}A) e_i \\
\eta_i &= e_i^\top A d_i \\
\alpha_i &= \alpha_{i-1} + \frac{1}{\eta_i} d_i (d_i^\top y)
\end{align*}
$$

This formula shows how each iteration refines the solution $\alpha$ by projecting the data $y$ onto the new direction $d_i$, scaled by the Schur complement $\eta_i$.


## Iterative Solution Update

At each iteration, the algorithm refines its estimate of the solution vector $\alpha = A^{-1}y$ using the current approximation of the inverse $C_i \approx A^{-1}$:

$$
\alpha_i = C_i y = \left(C_{i-1} + \frac{1}{\eta_i} d_i d_i^\top\right) y = \alpha_{i-1} + \frac{1}{\eta_i} d_i d_i^\top y
$$

This update can be further interpreted in terms of the residual of the previous estimate:

$$
\alpha_i = \alpha_{i-1} + \frac{1}{\eta_i} d_i \left[ y - A \alpha_{i-1} \right]_i
$$

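where $[\,\cdot\,]_i = e_i^\top(\cdot)$ denotes the $i$-th component of the residual vector. The two forms agree because $A$ and $C_{i-1}$ are symmetric: $d_i^\top y = e_i^\top (I - A C_{i-1})\, y = e_i^\top (y - A \alpha_{i-1})$.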

---

### **Interpretation**

- **Data Loading:**  
    Each step "loads" new information from the data vector $y$ by projecting the current residual (the difference between $y$ and the current prediction $A \alpha_{i-1}$) onto the update direction $d_i$.
- **Refinement:**  
    The update incrementally improves the solution, making $\alpha_i$ a better approximation to $A^{-1}y$ as more columns are processed.
- **Connection to Learning:**  
    This process mirrors how learning algorithms iteratively refine their predictions as they see more data.

---

### **Summary Table: Iterative Solution Update**

| Step | What Happens? | Mathematical Operation |
|------|---------------|-----------------------|
| 1    | Compute residual | $r_{i-1} = y - A \alpha_{i-1}$ |
| 2    | Project onto $d_i$ | $d_i^\top r_{i-1}$ |
| 3    | Scale and update | $\alpha_i = \alpha_{i-1} + \frac{1}{\eta_i} d_i (d_i^\top r_{i-1})$ |
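The three expressions for the step coefficient, $d_i^\top y$, $[y - A\alpha_{i-1}]_i$ and $d_i^\top r_{i-1}$, are equal in exact arithmetic. A minimal NumPy sketch (a random SPD matrix and a random right-hand side are assumed for illustration) runs Algorithm 4 with $s_i = e_i$, compares the three at every step, and confirms that $\alpha_n$ matches `np.linalg.solve`.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 6
B = rng.standard_normal((n, n))
A = B @ B.T + n * np.eye(n)      # random SPD matrix (illustrative assumption)
y = rng.standard_normal(n)

C = np.zeros((n, n))             # C_0 = 0
alpha = np.zeros(n)              # alpha_0 = 0
agree = []
for i in range(n):
    e = np.eye(n)[i]
    d = e - C @ (A @ e)          # d_i
    eta = e @ A @ d              # eta_i
    r = y - A @ alpha            # r_{i-1}, residual of the previous estimate
    coeffs = np.array([d @ y, r[i], d @ r])   # three forms of the step coefficient
    agree.append(np.allclose(coeffs, coeffs[0]))
    alpha = alpha + d * coeffs[0] / eta
    C = C + np.outer(d, d) / eta

print(all(agree), np.allclose(alpha, np.linalg.solve(A, y)))
```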

---

> **In summary:**  
> The Cholesky decomposition, when paired with this iterative solution update, acts as a "dataloader"—sequentially incorporating information from $y$ to refine the solution $\alpha$. This perspective highlights the deep connection between numerical linear algebra and the learning process in probabilistic models like Gaussian Processes.


```python
import jax.numpy as jnp
import numpy as np

# Reuses squared_exponential_kernel, N_matrix, X_data_matrix and A_example
# defined in the previous cell.


def iterative_cholesky_with_solution_approx(A: jnp.ndarray, y_vec: jnp.ndarray):
    """
    Conceptual implementation of Cholesky with inverse and solution approximation (Algorithm 4).
    This is for illustration and might not be numerically stable for large N.
    """
    N = A.shape[0]
    I = jnp.eye(N, dtype=A.dtype)
    C_i = jnp.zeros((N, N), dtype=A.dtype)  # C_0 = 0
    alpha_i = jnp.zeros(N, dtype=y_vec.dtype)  # alpha_0 = 0

    # For verification later
    alpha_true = jnp.linalg.solve(A, y_vec)

    print("--- Iterative Cholesky with Solution Approximation ---")
    for i in range(N):
        s_i = I[i]  # e_i (canonical basis vector)

        # d_i = (I - C_{i-1}A)s_i
        d_i = s_i - C_i @ (A @ s_i)

        # eta_i = s_i^T A d_i; a tiny jitter guards against division by zero
        eta_i = s_i @ (A @ d_i)
        eta_i_stable = eta_i + 1e-12

        # C_i = C_{i-1} + (1/eta_i) d_i d_i^T
        C_i = C_i + jnp.outer(d_i, d_i) / eta_i_stable

        # alpha_i = alpha_{i-1} + (1/eta_i) d_i (d_i^T y)
        alpha_i = alpha_i + d_i * (d_i @ y_vec) / eta_i_stable

        print(f"\nIteration {i + 1}:")
        print(f"  eta_{i + 1}: {float(eta_i):.4f}")
        print(
            f"  Max abs diff alpha_{i + 1} vs True alpha: {float(jnp.max(jnp.abs(alpha_i - alpha_true))):.2e}"
        )

    print("\n--- Final Results ---")
    print("Final approximate alpha_N:\n", alpha_i)
    print("\nTrue alpha:\n", alpha_true)
    print(
        f"\nFinal Max abs diff alpha_N vs True alpha: {float(jnp.max(jnp.abs(alpha_i - alpha_true))):.2e}"
    )
    return alpha_i


# --- Example Usage ---
# Create a corresponding y vector (seeded so that runs are reproducible)
rng = np.random.default_rng(0)
y_example = jnp.sin(X_data_matrix).squeeze() + 0.2 * jnp.array(rng.random(N_matrix))

iterative_cholesky_with_solution_approx(A_example, y_example)
```


## Cholesky as Iterative Book-Keeping: Adding Data Points One-by-One

The iterative process of Cholesky decomposition—especially when used to compute the inverse and solution estimates—can be elegantly interpreted as a form of **effective book-keeping**. In this view, Cholesky is akin to **sequentially adding datapoints** to update a Gaussian Process (GP) posterior.

---

### **Sequential GP Posterior Updates: The Big Picture**

Suppose we have:
- A "training set" $f_X$ with a Gaussian prior:  
    $$
    p(f_X) = \mathcal{N}(f_X; \mu_X, k_{XX})
    $$
- A Gaussian likelihood:  
    $$
    p(y \mid f_X) = \mathcal{N}(y; f_X, \sigma^2 I)
    $$
- Define $K := k_{XX} + \sigma^2 I$ and $\tilde{y} := y - \mu_X$.

Now, imagine the observations in $y$ **arrive one at a time**. At iteration $i-1$, we've processed the first $i-1$ data points and have a posterior mean based on them:
$$
\mu_{i-1}(X) = \mu_X + k_{X, X_{[:i-1]}} K_{[:i-1], [:i-1]}^{-1} \tilde{y}_{[:i-1]}
$$

When the $i$-th data point arrives, we update to the full posterior mean after observing $i$ points:
$$
\mu_i(X) = \mu_X + 
\begin{pmatrix}
k_{X, X_{[:i-1]}} & k_{X, X_i}
\end{pmatrix}
\begin{pmatrix}
K_{[:i-1], [:i-1]} & K_{[:i-1], i} \\
K_{i, [:i-1]} & K_{i, i}
\end{pmatrix}^{-1}
\begin{pmatrix}
\tilde{y}_{[:i-1]} \\
\tilde{y}_i
\end{pmatrix}
$$

---

### **Cholesky = Sequential Posterior Updates**

- **Each step of Cholesky** effectively incorporates the information from one new "dimension" or "data point" into the overall system.
- This is achieved by updating the inverse and solution estimates using the **Schur complement**, which efficiently handles the addition of new rows/columns to the kernel matrix.
- The process mirrors how we would update the GP posterior if we received data points one at a time.
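The book-keeping interpretation can be checked directly: with $s_i = e_i$, the estimate $C_i$ after $i$ steps should equal the inverse of the leading $i \times i$ block of $K$, zero-padded to $N \times N$, so $k_{XX} C_i \tilde{y}$ is exactly the posterior-mean correction from conditioning on the first $i$ points. The NumPy sketch below assumes a zero prior mean, toy targets $\tilde{y} = \cos(x)$, and a hypothetical kernel helper `se_kernel`.

```python
import numpy as np


def se_kernel(x1, x2, lengthscale=0.5):
    """Hypothetical 1-D squared-exponential kernel (unit signal variance)."""
    return np.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / lengthscale**2)


X = np.linspace(-2, 2, 7)
y_tilde = np.cos(X)                       # centred targets (zero prior mean assumed)
k_XX = se_kernel(X, X)
K = k_XX + 0.1**2 * np.eye(len(X))        # K = k_XX + sigma^2 I
n = len(X)

C = np.zeros((n, n))
matches = []
for i in range(1, n + 1):                 # i = number of points processed so far
    s = np.eye(n)[i - 1]
    d = s - C @ (K @ s)
    C = C + np.outer(d, d) / (s @ K @ d)
    # Book-keeping view: a GP conditioned on the first i points only.
    padded = np.zeros((n, n))
    padded[:i, :i] = np.linalg.inv(K[:i, :i])
    mean_shift = k_XX[:, :i] @ np.linalg.solve(K[:i, :i], y_tilde[:i])   # mu_i(X) - mu_X
    matches.append(np.allclose(C, padded) and np.allclose(k_XX @ C @ y_tilde, mean_shift))

print(matches)
```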

---

### **Key Takeaways**

- **Cholesky decomposition is not just a matrix factorization**—it is a *learning process* that sequentially "loads" information from each data point.
- This perspective helps us understand why Cholesky is so powerful for GP inference: it provides a principled, efficient way to update our beliefs as new data arrives.
- The connection to Schur complements highlights the deep interplay between linear algebra and probabilistic inference.

---

> **In summary:**  
> The Cholesky decomposition, when viewed through the lens of probabilistic machine learning, is a beautiful example of how numerical algorithms can be interpreted as iterative learning or inference procedures—each step refining our understanding as if we were adding one more observation to our dataset.

# Schur Complements: Low-Rank Updates to Matrices, Inverses, and Solutions

*Named after Issai Schur (1875–1941)*

---

Schur complements are a fundamental concept in block matrix algebra. They provide a powerful framework for understanding how the inverse of a matrix changes when we add or remove data points (or dimensions). This is especially important in **Gaussian Processes (GPs)**, where we often need to update the inverse of a kernel matrix as new data arrives.

---

## What is a Schur Complement?

Suppose we have a block matrix $A$ partitioned as:

$$
A = \begin{pmatrix}
P & Q \\
R & S
\end{pmatrix}
$$

where $P$ is invertible.

The **Schur complement** of $P$ in $A$ is defined as:

$$
M = S - R P^{-1} Q
$$

---

## Inverse of a Block Matrix Using Schur Complements

The inverse of $A$ can be written in terms of its blocks and the Schur complement $M$:

$$
A^{-1} = 
\begin{pmatrix}
P^{-1} + P^{-1} Q M^{-1} R P^{-1} & -P^{-1} Q M^{-1} \\
- M^{-1} R P^{-1} & M^{-1}
\end{pmatrix}
$$

This formula allows us to compute the inverse of $A$ efficiently if we already know $P^{-1}$ and $M^{-1}$.

Alternatively, the inverse can be expressed to highlight the **low-rank update** structure:

$$
A^{-1} = 
\begin{pmatrix}
P^{-1} & 0 \\
0 & 0
\end{pmatrix}
+
\begin{pmatrix}
- P^{-1} Q \\
I
\end{pmatrix}
M^{-1}
\begin{pmatrix}
- R P^{-1} & I
\end{pmatrix}
$$
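A quick numerical check of the low-rank form, as a minimal NumPy sketch: the random matrix and the block sizes are arbitrary choices for illustration, and making it SPD only serves to guarantee that $P$ and $M$ are invertible. The zero-padded $P^{-1}$ plus the rank-$k$ correction should reproduce `np.linalg.inv(A)`.

```python
import numpy as np

rng = np.random.default_rng(4)
n, k = 5, 2                              # P is n x n, S is k x k
B = rng.standard_normal((n + k, n + k))
A = B @ B.T + (n + k) * np.eye(n + k)    # SPD, so P and M are invertible
P, Q = A[:n, :n], A[:n, n:]
R, S = A[n:, :n], A[n:, n:]

P_inv = np.linalg.inv(P)
M = S - R @ P_inv @ Q                    # Schur complement of P in A
M_inv = np.linalg.inv(M)

padded = np.zeros_like(A)
padded[:n, :n] = P_inv
U = np.vstack([-P_inv @ Q, np.eye(k)])   # (n+k) x k
V = np.hstack([-R @ P_inv, np.eye(k)])   # k x (n+k)
A_inv_lowrank = padded + U @ M_inv @ V   # rank-k correction of the padded P^{-1}

print(np.allclose(A_inv_lowrank, np.linalg.inv(A)))
```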

---

## Application to Gaussian Processes: Adding a Data Point

In GPs, when we add a new data point, we augment the kernel matrix. Suppose our kernel matrix is:

$$
K = \begin{pmatrix}
K_{XX} & k_{X i} \\
k_{i X} & k_{ii}
\end{pmatrix}
$$

- $K_{XX}$: Kernel matrix for the first $N-1$ points, including the noise term $\sigma^2 I$
- $k_{X i}$: Cross-covariance vector between existing points and the new point
- $k_{ii}$: Kernel value at the new point, plus the noise variance $\sigma^2$

The inverse $K^{-1}$ can be updated using the Schur complement:

$$
K^{-1} = 
\begin{pmatrix}
K_{XX}^{-1} + K_{XX}^{-1} k_{X i} \eta^{-1} k_{i X} K_{XX}^{-1} & -K_{XX}^{-1} k_{X i} \eta^{-1} \\
- \eta^{-1} k_{i X} K_{XX}^{-1} & \eta^{-1}
\end{pmatrix}
$$

where

$$
\eta = k_{ii} - k_{i X} K_{XX}^{-1} k_{X i}
$$

is the **Schur complement**.
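The update can be applied repeatedly to grow $K^{-1}$ one data point at a time. In this minimal NumPy sketch, `add_point_to_inverse` is a hypothetical helper name and the toy kernel matrix is an assumption; each call costs $\mathcal{O}(N^2)$ given the previous inverse. Explicit inverses are used only to mirror the formula; in practice, updating a Cholesky factor is the numerically safer way to add data.

```python
import numpy as np


def add_point_to_inverse(K_inv, k_new, k_nn):
    """Grow K^{-1} by one row/column via the Schur complement, in O(N^2).

    K_inv: inverse of the current (N-1) x (N-1) noisy kernel matrix
    k_new: cross-covariances k_{X i} between existing points and the new one
    k_nn:  k_{ii}, the (noisy) prior variance at the new point
    """
    v = K_inv @ k_new                     # K_XX^{-1} k_{X i}
    eta = k_nn - k_new @ v                # Schur complement
    top_left = K_inv + np.outer(v, v) / eta
    return np.block([[top_left, -v[:, None] / eta],
                     [-v[None, :] / eta, np.array([[1.0 / eta]])]])


X = np.linspace(-2, 2, 6)
K = np.exp(-0.5 * (X[:, None] - X[None, :]) ** 2 / 0.5**2) + 0.1**2 * np.eye(len(X))

K_inv = np.array([[1.0 / K[0, 0]]])      # inverse for the first point alone
for i in range(1, len(X)):
    K_inv = add_point_to_inverse(K_inv, K[:i, i], K[i, i])

print(np.allclose(K_inv, np.linalg.inv(K)))
```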

---

## Why is This Important?

- **Efficiency:**  
    Schur complements allow us to update the inverse of a matrix when adding (or removing) a data point without recomputing it from scratch: given $K_{XX}^{-1}$, adding one point costs $\mathcal{O}(N^2)$ instead of the $\mathcal{O}(N^3)$ of a fresh inversion.

- **Low-Rank Updates:**  
    The update to the inverse in Algorithm 3 ($C_i = C_{i-1} + \frac{1}{\eta_i} d_i d_i^\top$) is a direct application of the **rank-1 update formula** derived from Schur complements.

- **Probabilistic Interpretation:**  
    In GPs, this means we can efficiently update our posterior as new data arrives, making online or sequential inference practical.

---

## Summary

- **Schur complements** provide the mathematical machinery for efficient, incremental updates to matrix inverses.
- They are central to scalable Gaussian Process inference and many other areas in numerical linear algebra and statistics.
- Understanding Schur complements gives deep insight into how algorithms like Cholesky and iterative GP updates work under the hood.

---

> **Further Reading:**  
> - [Wikipedia: Schur complement](https://en.wikipedia.org/wiki/Schur_complement)  
> - Rasmussen & Williams, "Gaussian Processes for Machine Learning", Section 2.2.2  
> - Matrix Cookbook: [Schur Complement Section](https://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf)

## Cleaned-up Algorithm: Iterative GP Regression (Numerics Layer)

This algorithm presents a more general, cleaned-up version of the iterative process for GP regression, focusing on the numerics layer. It shows how the inverse estimate ($C_i$) and the solution estimate ($\alpha_i$) are updated iteratively.

---
### Algorithm 6: Iterative GP Regression (Numerics Layer)

**Input:**  
- $K = k_{XX} + \sigma^2 I$ (kernel matrix with noise)
- $\bar{y} = y - \mu_X$ (centered targets)

**Output:**  
- $S = [s_j]_{j \leq n}$ (list of processed vectors)
- $C_n \approx K^{-1}$ (approximate inverse)
- $\alpha_n \approx K^{-1} \bar{y}$ (approximate solution)

---

#### **Procedure: TRAIN($K$, $\bar{y}$)**

1. **Initialization**
  - $C_0 \leftarrow 0$ &nbsp;&nbsp;*(Initial inverse estimate)*
  - $\alpha_0 \leftarrow 0$ &nbsp;&nbsp;*(Initial solution estimate)*

2. **Iterative Updates (for $i = 1$ to $n$):**
  - **a. Select Action Vector**
    - $s_i \leftarrow e_i$  
      *(Load: Select the $i$-th canonical basis vector)*
  - **b. Compute Observation**
    - $z_i \leftarrow K s_i$  
      *(Apply $K$ to the action vector)*
  - **c. Compute Low-Rank Update / Residual**
    - $d_i \leftarrow (I - C_{i-1} K) s_i = s_i - C_{i-1} z_i$  
      *(Residual direction for update)*
  - **d. Compute Schur Complement / Normalization**
    - $\eta_i \leftarrow s_i^\top K d_i = z_i^\top d_i$  
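      *(Schur complement; for $s_i = e_i$ it equals $K_{ii} - K_{i,<i} K_{<i,<i}^{-1} K_{<i,i}$, the part of $K_{ii}$ not explained by the points loaded so far)*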
  - **e. Update Inverse Estimate**
    - $C_i \leftarrow C_{i-1} + \frac{1}{\eta_i} d_i d_i^\top$  
      *(Rank-1 update to inverse estimate)*
  - **f. Update Solution Estimate**
    - $\alpha_i \leftarrow \alpha_{i-1} + \frac{1}{\eta_i} d_i (d_i^\top \bar{y})$  
      *(Update solution using new direction)*

3. **Return**
  - $S = [s_j]_{j \leq n}$, $\alpha_n$, $C_n$

---

**Intuition:**  
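- Each iteration "loads" a new data direction, refines the inverse and solution estimates, and implicitly extends a triangular factorization: the scaled directions $d_i/\sqrt{\eta_i}$ are the columns of $L^{-\top}$, where $K = LL^\top$ is the Cholesky factorization.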
- This process is analogous to sequentially learning from each data point, with uncertainty and mean estimates improving at every step.
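To see the connection to Cholesky concretely, the following NumPy check runs Algorithm 6 (with $s_i = e_i$) on a random symmetric positive-definite stand-in for $K$ and collects the scaled directions $d_i/\sqrt{\eta_i}$ as columns of a matrix `D`. It is a sketch for illustration: `D` turns out to be upper triangular and equal to $L^{-\top}$, where $L$ is NumPy's lower-triangular Cholesky factor. NumPy is used for brevity; with `jax.numpy` the in-place updates would need `.at[...]`.

```python
import numpy as np

rng = np.random.default_rng(1)
N = 6
B = rng.standard_normal((N, N))
K = B @ B.T + N * np.eye(N)      # stand-in for k_XX + sigma^2 I

C = np.zeros((N, N))
D = np.zeros((N, N))             # column i holds d_i / sqrt(eta_i)
for i in range(N):
    s = np.zeros(N)
    s[i] = 1.0                   # s_i = e_i
    z = K @ s
    d = s - C @ z                # only the first i+1 entries can be nonzero
    eta = z @ d
    C += np.outer(d, d) / eta
    D[:, i] = d / np.sqrt(eta)

L = np.linalg.cholesky(K)        # K = L L^T with L lower triangular
print(np.allclose(D, np.triu(D)))           # True: D is upper triangular
print(np.allclose(D, np.linalg.inv(L).T))   # True: D = L^{-T}
print(np.allclose(C, np.linalg.inv(K)))     # True: C_N = K^{-1} = D D^T
```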

---


## GP Prediction: Mean and Uncertainty Formulae

After training a Gaussian Process (GP) model, making predictions at new test points $x$ is efficient and interpretable. The prediction step uses the statistics computed during training to provide both the **predictive mean** and **predictive uncertainty**.

---

### **Mathematical Formulation**

Given:
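- $k[x, S]$: Covariance vector (or matrix) between the test points $x$ and the training inputs. Because $\alpha$ and $C$ only have components along the span of the processed actions $S$, only the corresponding part of this covariance affects the prediction.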
- $\alpha$: Solution vector (posterior weights), typically $\alpha = K^{-1}(y - \mu_X)$.
- $C$: Inverse covariance estimate, typically $C \approx K^{-1}$.
- $k_{xx}$: Prior covariance at the test points.

The predictive mean and variance are:

$$
\begin{align*}
\mu_x &= \mu(x) + k[x, S]\, \alpha \\
v_{xx} &= k_{xx} - k[x, S]\, C\, k[S, x]
\end{align*}
$$

- $\mu(x)$: Mean function evaluated at $x$ (often zero).
- $k[x, S]$: Covariance between test points and training points.
- $k[S, x]$: Transpose of $k[x, S]$.

---

### **Step-by-Step Explanation**

1. **Compute Covariance to Training Data**
  - $k[x, S]$ gives how similar each test point is to each training point.

2. **Predictive Mean**
  - $\mu_x = \mu(x) + k[x, S]\, \alpha$
  - This is the expected value of the function at $x$ given the observed data.

3. **Predictive Variance (Uncertainty)**
  - $v_{xx} = k_{xx} - k[x, S]\, C\, k[S, x]$
  - This quantifies the model's uncertainty at $x$, accounting for both prior uncertainty and information gained from the data.

---

### **Intuition**

- The **predictive mean** is a weighted sum of the observed data, where the weights reflect both the similarity (via the kernel) and the influence of each training point.
- The **predictive variance** starts with the prior uncertainty ($k_{xx}$) and subtracts the amount of uncertainty explained by the training data (the second term).

---

### **Summary Table**

| Step                | Formula                                         | Interpretation                       |
|---------------------|-------------------------------------------------|--------------------------------------|
| Predictive Mean     | $\mu_x = \mu(x) + k[x, S]\, \alpha$             | Expected value at $x$                |
| Predictive Variance | $v_{xx} = k_{xx} - k[x, S]\, C\, k[S, x]$       | Model uncertainty at $x$             |

---

> **In summary:**  
> GP prediction is efficient and interpretable: you get both a mean prediction and a principled uncertainty estimate at every test point, using only matrix-vector products with the statistics computed during training.


## Key Takeaways from the Iterative GP Algorithm

- **Iterative Book-Keeping:**  
    The algorithm updates both the inverse estimate $C_i$ and the solution estimate $\alpha_i$ in an iterative fashion. At each step, it processes one "action" vector $s_i$ (often chosen as a canonical basis vector $e_i$), incrementally refining both the mean and uncertainty estimates.

- **Computational Cost per Step:**  
    - Each iteration involves matrix-vector products and outer products.
    - For $s_i = e_i$, the cost of computing $d_i$ is $\mathcal{O}(N^2)$ (since $C_{i-1} z_i$ is a matrix-vector product), and updating $C_i$ is also $\mathcal{O}(N^2)$.
    - With $N$ iterations, the total computational complexity is $\mathcal{O}(N^3)$, matching the cost of standard Cholesky decomposition.

- **Direct Uncertainty Quantification:**  
    Unlike stochastic gradient descent (SGD), this iterative process provides both:
    - The point estimate $\alpha_i$ (used for the predictive mean)
    - The uncertainty estimate $C_i$ (an approximation to the inverse covariance)
    
    This means that, at every step, you have access to both the mean and the uncertainty of your predictions—one of the key strengths of Gaussian Process inference.

---

**Summary Table**

| Feature                | Iterative GP Algorithm      | SGD                        |
|------------------------|----------------------------|----------------------------|
| Updates                | $C_i$, $\alpha_i$          | Parameter vector only      |
| Per-step cost          | $\mathcal{O}(N^2)$         | $\mathcal{O}(N)$           |
| Total cost             | $\mathcal{O}(N^3)$ ($N$ steps) | $\mathcal{O}(N \cdot T)$ ($T$ steps) |
| Uncertainty estimate   | Yes ($C_i$)                | No                         |
| Exact solution         | Yes (after $N$ steps)      | No (approximate)           |

---

> **In summary:**  
> The iterative GP algorithm is a principled way to compute both the predictive mean and uncertainty in Gaussian Process regression, and (in exact arithmetic) it recovers the exact solution after $N$ steps. Each iteration incrementally "loads" information from the data, making it a powerful alternative to standard optimization methods that only provide point estimates.

```python
import jax.numpy as jnp
import jax.random as random
from typing import Callable, Optional


# --- Re-using kernel definition ---
def squared_exponential_kernel(
    x1: jnp.ndarray, x2: jnp.ndarray, sigma: float = 1.0, lengthscale: float = 1.0
) -> jnp.ndarray:
    """Computes the Squared Exponential (RBF) kernel matrix."""
    x1 = jnp.atleast_2d(x1)
    x2 = jnp.atleast_2d(x2)
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    K = sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)
    return K


# --- Iterative GP Training Procedure (Algorithm 6) ---
def iterative_gp_train(
    K_matrix: jnp.ndarray,  # K_XX + sigma^2 I
    y_bar: jnp.ndarray,  # y - mu_X
    num_iterations: int,  # N gives the exact (Cholesky-equivalent) solution
    mean_func_train_data: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
    X_train_data: Optional[jnp.ndarray] = None,
):
    """Algorithm 6 with s_i = e_i.

    Returns the lists [alpha_0, ..., alpha_n] and [C_0, ..., C_n].
    The last two arguments are not used by the numerics layer; they are kept
    only so that existing calls continue to work.
    """
    N = K_matrix.shape[0]
    C_i = jnp.zeros((N, N), dtype=K_matrix.dtype)  # Inverse estimate C_0
    alpha_i = jnp.zeros(N, dtype=y_bar.dtype)  # Solution estimate alpha_0
    alpha_true = jnp.linalg.solve(K_matrix, y_bar)  # Reference, for progress printing only

    # Store intermediate results for plotting
    all_alpha_estimates = [alpha_i]
    all_C_estimates = [C_i]
    print_every = max(1, num_iterations // 5)  # Avoids modulo-by-zero when N < 5

    print("--- Starting Iterative GP Training ---")
    for i in range(num_iterations):
        # s_i: Action - load (here, the canonical basis vector e_i)
        s_i = jnp.zeros(N, dtype=K_matrix.dtype).at[i].set(1.0)

        # z_i: Observation - compute K @ s_i (the i-th column of K)
        z_i = K_matrix @ s_i

        # d_i: Low-rank update direction (s_i - C_{i-1} @ z_i)
        d_i = s_i - C_i @ z_i

        # eta_i: Schur complement (s_i^T K d_i = z_i^T d_i)
        eta_i = z_i @ d_i

        # Small jitter guards against dividing by a numerically zero eta_i
        eta_i_stable = eta_i + 1e-12

        # C_i: Rank-1 update of the inverse estimate
        C_i = C_i + jnp.outer(d_i, d_i) / eta_i_stable

        # alpha_i: Solution estimate update, d_i (d_i^T y_bar) / eta_i
        alpha_i = alpha_i + d_i * (d_i @ y_bar) / eta_i_stable

        all_alpha_estimates.append(alpha_i)
        all_C_estimates.append(C_i)

        if (i + 1) % print_every == 0 or (i + 1) == num_iterations:  # Print progress
            max_diff = float(jnp.max(jnp.abs(alpha_i - alpha_true)))
            print(
                f"  Iteration {i + 1}/{num_iterations}: max |alpha_i - K^-1 y_bar| = {max_diff:.2e}"
            )

    print("--- Iterative GP Training Complete ---")
    return all_alpha_estimates, all_C_estimates
```

```python
# --- GP Prediction Procedure (Algorithm 6, PREDICT) ---
def iterative_gp_predict(
    X_test: jnp.ndarray,
    X_train_data: jnp.ndarray,
    mean_func_test: Callable[[jnp.ndarray], jnp.ndarray],
    kernel_func: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
    alpha_final: jnp.ndarray,
    C_final: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Performs GP prediction using the final alpha and C estimates from iterative training.
    """
    # k_xS: Covariance to Observations
    k_xS = kernel_func(X_test, X_train_data)

    # mu_x: Point estimate
    mu_x = mean_func_test(X_test) + jnp.dot(k_xS, alpha_final)

    # v_xx: Uncertainty
    k_xx = kernel_func(X_test, X_test)
    v_xx = k_xx - jnp.dot(k_xS, jnp.dot(C_final, k_xS.T))

    return mu_x, v_xx
```


```python
# --- Example Usage (Simulated Data) ---
key = random.PRNGKey(10)
N_data = 20  # Number of training points for this example
X_train_sim = jnp.linspace(-5, 5, N_data)[:, None]
y_true = jnp.sin(X_train_sim) * jnp.exp(-0.1 * X_train_sim**2)
y_train_sim = y_true.squeeze() + 0.1 * random.normal(key, (N_data,))

# GP parameters
mean_func_sim = lambda x: jnp.zeros(x.shape[0])
kernel_func_sim = lambda x1, x2: squared_exponential_kernel(
    x1, x2, sigma=1.0, lengthscale=1.0
)
noise_var_sim = 0.1**2

# Precompute K_matrix and y_bar
K_matrix_sim = kernel_func_sim(X_train_sim, X_train_sim) + noise_var_sim * jnp.eye(
    N_data
)
y_bar_sim = y_train_sim - mean_func_sim(X_train_sim)

# Run iterative training
all_alpha_estimates, all_C_estimates = iterative_gp_train(
    K_matrix_sim, y_bar_sim, N_data, mean_func_sim, X_train_sim
)

# Get final alpha and C estimates
alpha_final_sim = all_alpha_estimates[-1]
C_final_sim = all_C_estimates[-1]

# Generate test points for prediction
X_test_sim = jnp.linspace(-6, 6, 100)[:, None]

# Perform prediction using the final estimates
mu_pred_sim, cov_pred_sim = iterative_gp_predict(
    X_test_sim,
    X_train_sim,
    mean_func_sim,
    kernel_func_sim,
    alpha_final_sim,
    C_final_sim,
)
# Clip tiny negative variances caused by round-off before taking the square root
std_pred_sim = jnp.sqrt(jnp.maximum(jnp.diag(cov_pred_sim), 0.0))
```


```python
import numpy as np
import plotly.graph_objs as go

# Convert JAX arrays to NumPy once so Plotly receives plain arrays
x_test = np.asarray(X_test_sim).squeeze()
mu = np.asarray(mu_pred_sim)
std = np.asarray(std_pred_sim)

# Create the main trace for the predictive mean
trace_mean = go.Scatter(
    x=x_test,
    y=mu,
    mode="lines",
    name="GP Predictive Mean (Iterative)",
    line=dict(color="red"),
)

# Mean ± 2 standard deviations as a filled band (≈95% credible interval)
trace_ci = go.Scatter(
    x=np.concatenate([x_test, x_test[::-1]]),
    y=np.concatenate([mu - 2 * std, (mu + 2 * std)[::-1]]),
    fill="toself",
    fillcolor="rgba(255,0,0,0.2)",
    line=dict(color="rgba(255,255,255,0)"),
    hoverinfo="skip",
    showlegend=True,
    name="±2 std (≈95% interval)",
)

# Training data as scatter points
trace_train = go.Scatter(
    x=np.asarray(X_train_sim).squeeze(),
    y=np.asarray(y_train_sim),
    mode="markers",
    name="Training Data",
    marker=dict(color="blue", size=7, opacity=0.8),
)

layout = go.Layout(
    title="Iterative Gaussian Process Regression",
    xaxis=dict(title="X"),
    yaxis=dict(title="Y"),
    legend=dict(x=0.01, y=0.99),
    template="plotly_white",
    hovermode="closest",
)

fig = go.Figure(data=[trace_ci, trace_mean, trace_train], layout=layout)
fig.show()
```

```python
# --- Plotting Iterative Posterior Updates ---
# Visualizes how the posterior mean and variance evolve with the iterations.
# Every snapshot re-runs the prediction, so we only plot a few of them.

import numpy as np
import plotly.graph_objs as go

num_plots = min(N_data, 5)  # Number of snapshots, the last one being the final iteration
# Evenly spaced iteration counts in 1..N_data, as plain Python ints for indexing and slicing
plot_indices = [int(v) for v in np.linspace(0, N_data, num_plots + 1).round()[1:]]

x_test = np.asarray(X_test_sim).squeeze()
x_train = np.asarray(X_train_sim)[:, 0]
y_train = np.asarray(y_train_sim)

plotly_figs = []
for idx in plot_indices:
    # Predict with the estimates obtained after `idx` iterations
    mu_iter, cov_iter = iterative_gp_predict(
        X_test_sim,
        X_train_sim,
        mean_func_sim,
        kernel_func_sim,
        all_alpha_estimates[idx],
        all_C_estimates[idx],
    )
    mu_iter = np.asarray(mu_iter)
    std_iter = np.sqrt(np.maximum(np.diag(np.asarray(cov_iter)), 0.0))

    # Plotly traces
    trace_mean = go.Scatter(
        x=x_test,
        y=mu_iter,
        mode="lines",
        name="Posterior Mean",
        line=dict(color="red"),
    )
    trace_ci = go.Scatter(
        x=np.concatenate([x_test, x_test[::-1]]),
        y=np.concatenate([mu_iter - 2 * std_iter, (mu_iter + 2 * std_iter)[::-1]]),
        fill="toself",
        fillcolor="rgba(255,0,0,0.2)",
        line=dict(color="rgba(255,255,255,0)"),
        hoverinfo="skip",
        showlegend=True,
        name="±2 std (≈95% interval)",
    )
    trace_train_processed = go.Scatter(
        x=x_train[:idx],
        y=y_train[:idx],
        mode="markers",
        name="Processed Data",
        marker=dict(color="blue", size=8, opacity=0.9),
    )
    traces = [trace_ci, trace_mean, trace_train_processed]

    if idx < N_data:
        trace_train_unprocessed = go.Scatter(
            x=x_train[idx:],
            y=y_train[idx:],
            mode="markers",
            name="Unprocessed Data",
            marker=dict(color="gray", size=7, opacity=0.5),
        )
        traces.append(trace_train_unprocessed)

    layout = go.Layout(
        title=f"Iterative GP Posterior Updates<br>Iteration {idx}/{N_data}",
        xaxis=dict(title="X"),
        yaxis=dict(title="Y"),
        legend=dict(x=0.01, y=0.99),
        template="plotly_white",
        hovermode="closest",
        height=400,
        width=700,
    )
    fig_iter = go.Figure(data=traces, layout=layout)
    plotly_figs.append(fig_iter)
    fig_iter.show()
```


# Why Load Individual Data Points? Generalization to Projections

In the iterative Cholesky algorithm (Algorithm 6), we used the **canonical basis vectors** $s_i = e_i$ (where $e_i$ is a vector with $1$ at the $i$-th position and $0$ elsewhere). This corresponds to sequentially "loading" individual data points into our computation.

---

## How Does the Algorithm Interact with the Data?

The core operations in each iteration are:

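- $z_i = K s_i$  
    *For $s_i = e_i$, this is the $i$-th column of the kernel matrix $K$.*

- $d_i^\top \bar{y} = s_i^\top(\bar{y} - K\alpha_{i-1})$  
    *Since $K$ and $C_{i-1}$ are symmetric and $\alpha_{i-1} = C_{i-1}\bar{y}$, this projects the current residual $r_{i-1} = \bar{y} - K\alpha_{i-1}$ onto the action; for $s_i = e_i$ it is the $i$-th entry of that residual.*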

**Key Insight:**  
These are both **linear projections** of the data!
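The second projection can be checked numerically. The sketch below runs Algorithm 6 with $s_i = e_i$ on random stand-ins for $K$ and $\bar{y}$ (illustrative data, not from the lecture) and asserts, at every step, that $d_i^\top \bar{y}$ equals the $i$-th entry of the current residual $\bar{y} - K\alpha_{i-1}$. NumPy is used for brevity.

```python
import numpy as np

rng = np.random.default_rng(2)
N = 8
B = rng.standard_normal((N, N))
K = B @ B.T + N * np.eye(N)          # stand-in for k_XX + sigma^2 I
y_bar = rng.standard_normal(N)       # stand-in for centered targets

C = np.zeros((N, N))
alpha = np.zeros(N)
for i in range(N):
    s = np.zeros(N)
    s[i] = 1.0                       # s_i = e_i
    z = K @ s                        # i-th column of K
    d = s - C @ z
    residual = y_bar - K @ alpha     # r_{i-1} = y_bar - K alpha_{i-1}
    # d_i^T y_bar is the i-th entry of the current residual
    assert np.isclose(d @ y_bar, residual[i])
    eta = z @ d
    C += np.outer(d, d) / eta
    alpha += d * (d @ y_bar) / eta

print(np.allclose(alpha, np.linalg.solve(K, y_bar)))  # True
```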

---

## Why Is This Profound?

- The Cholesky algorithm, in this context, is not just a matrix factorization.  
- It is a process of **sequentially loading information** from individual data points (or, more precisely, their corresponding columns in the kernel matrix).
- Each step treats these as new "observations" and performs a stable, rank-1 update to both the inverse and the solution.

---

## Generalizing Beyond Individual Data Points

This realization opens up powerful new directions for scalable and flexible GP inference:

### 1. **Batch Loading**
- Instead of $s_i \in \mathbb{R}^N$ being a single canonical basis vector, we could use $s_i \in \mathbb{R}^{N \times b}$, representing a **batch of $b$ vectors**.
- This leads to **rank-$b$ updates** at each step, reducing the number of sequential iterations by a factor of $b$ and making better use of modern parallel hardware.

### 2. **Random Projections**
- Instead of loading specific data points or batches, we could project the data onto **random directions**.
- This is closely related to **sketching-based low-rank approximations** of $K$ (e.g. randomized Nyström-type methods), which can dramatically reduce computational cost when far fewer directions than data points are loaded.
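To illustrate what loading random directions does, note that after loading the columns of an $N \times b$ matrix $S$ (in exact arithmetic, whether one at a time or as a block), the inverse estimate becomes $C = S (S^\top K S)^{-1} S^\top$. The sketch below assumes Gaussian random directions, $b = 15$, and a hypothetical 1-D `rbf` helper; only $KS$ and a $b \times b$ solve are needed. The resulting variance is never smaller than the exact one, because $K^{-1} - C$ is positive semi-definite: loading fewer directions removes less uncertainty.

```python
import numpy as np


def rbf(a, b, lengthscale=1.0):
    """Hypothetical helper: 1-D squared-exponential kernel with unit variance."""
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)


rng = np.random.default_rng(3)
N, b = 200, 15
x = np.sort(rng.uniform(-5, 5, N))
K = rbf(x, x) + 0.1**2 * np.eye(N)

# b random Gaussian directions instead of b canonical basis vectors
S = rng.standard_normal((N, b))
KS = K @ S                                       # the only O(N^2 b) step
C_sketch = S @ np.linalg.solve(S.T @ KS, S.T)    # C = S (S^T K S)^{-1} S^T

x_test = np.linspace(-5, 5, 50)
k_xX = rbf(x_test, x)
prior_var = np.ones_like(x_test)                 # k(x, x) = 1 for this kernel
var_sketch = prior_var - np.einsum("ij,jk,ik->i", k_xX, C_sketch, k_xX)
var_exact = prior_var - np.einsum("ij,jk,ik->i", k_xX, np.linalg.inv(K), k_xX)

print(np.all(var_sketch >= var_exact - 1e-6))    # True
```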

---

## The Role of the "Policy" for $s_i$

- The **choice of policy** for selecting $s_i$ (which vectors to load at each step) becomes a **critical design decision** in scalable GP algorithms.
- Different choices can lead to different trade-offs between computational efficiency, statistical accuracy, and scalability.

---

### **Summary Table**

| Approach                | $s_i$ Choice                | Update Type   | Example Use Case                  |
|-------------------------|-----------------------------|--------------|-----------------------------------|
| Sequential (classic)    | $e_i$ (canonical basis)     | Rank-1       | Exact Cholesky, classic GP        |
| Batch                   | Multiple $e_i$'s            | Rank-$b$     | Mini-batch Cholesky, scalable GP  |
| Random Projections      | Random directions           | Rank-$b$     | Randomized sketching, fast GP approx. |

---

> **In summary:**  
> The iterative Cholesky algorithm is more than just a numerical routine—it is a *framework for loading and processing information from data*. By generalizing the way we select and load $s_i$, we can design more scalable and flexible Gaussian Process algorithms that are well-suited to modern machine learning challenges.

# IterGP: Generalized Iterative Gaussian Process Regression Algorithm

The **IterGP algorithm** (Wenger, Pleiss, Pförtner, Hennig, Cunningham, NeurIPS 2022) generalizes iterative GP regression by allowing **flexible projections**. Instead of always using the canonical basis vectors $e_i$, it introduces a **POLICY** function to select the projection vectors $s_i$ at each step.

---

## 🌐 **Algorithm 7: Iterative GP Regression (Numerics Layer — Generalized)**

### **Key Idea**

- Instead of always loading one data point at a time (using $e_i$), we can load **arbitrary projections** $s_i$ at each step.
- The choice of $s_i$ is governed by a **policy**, which can be designed for efficiency, scalability, or statistical optimality.

---

### **Inputs**

- $K = k_{XX} + \sigma^2 I$  
  *Kernel matrix with noise (size $N \times N$).*
- $\bar{y} = y - \mu_X$  
  *Centered targets (subtract mean function at training points).*
- Initial guesses: $\alpha_0$, $C_0$  
  *Typically zeros.*

---

### **Outputs**

- $S = [s_j]_{j \leq n}$  
  *List of all projection vectors used (each $s_j \in \mathbb{R}^{N \times k_j}$).*
- $C_n \approx K^{-1}$  
  *Approximate inverse of the kernel matrix after $n$ iterations.*
- $\alpha_n \approx K^{-1} \bar{y}$  
  *Approximate solution (posterior mean weights).*

---

### **IterGP Update Steps**

At each iteration $i = 1, \ldots, n$:

1. **Select Projection(s):**
   $$
   s_i \leftarrow \text{POLICY}(S_{<i}, Z_{<i})
   $$
   - $s_i \in \mathbb{R}^{N \times k_i}$ (can be a single vector or a batch).
   - The policy can use all previous projections $S_{<i}$ and kernel columns $Z_{<i}$.

2. **Compute Projected Kernel Column(s):**
   $$
   z_i = K s_i
   $$
   - $z_i \in \mathbb{R}^{N \times k_i}$.

3. **Compute Low-Rank Update Direction:**
   $$
   d_i = (I - C_{i-1} K) s_i = s_i - C_{i-1} z_i
   $$
   - $d_i \in \mathbb{R}^{N \times k_i}$.
   - This is the "new information" not yet explained by previous steps.

4. **Compute Schur Complement (Normalization Matrix):**
   $$
   H_i = s_i^\top K d_i = z_i^\top d_i
   $$
   - $H_i \in \mathbb{R}^{k_i \times k_i}$.
   - Normalizes the update; for $k_i = 1$ and $s_i = e_i$ it reduces to the scalar Schur complement $\eta_i$ of Algorithm 6.

5. **Update Inverse Estimate (Rank-$k_i$ Update):**
   $$
   C_i = C_{i-1} + d_i H_i^{-1} d_i^\top
   $$
   - $C_i$ is the new approximation to $K^{-1}$.

6. **Update Solution Estimate (Posterior Mean Weights):**
   $$
   \alpha_i = \alpha_{i-1} + d_i H_i^{-1} d_i^\top \bar{y}
   $$
   - $\alpha_i$ is the new approximation to $K^{-1} \bar{y}$.

---

### **Summary Table of IterGP Steps**

| Step | Operation | Purpose |
|------|-----------|---------|
| 1    | $s_i \leftarrow \text{POLICY}(S_{<i}, Z_{<i})$ | Choose projection(s) |
| 2    | $z_i = K s_i$ | Project kernel |
| 3    | $d_i = s_i - C_{i-1} z_i$ | Find unexplained direction |
| 4    | $H_i = z_i^\top d_i$ | Normalize (Schur complement) |
| 5    | $C_i = C_{i-1} + d_i H_i^{-1} d_i^\top$ | Update inverse estimate |
| 6    | $\alpha_i = \alpha_{i-1} + d_i H_i^{-1} d_i^\top \bar{y}$ | Update solution estimate |
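The six steps can be sketched directly in NumPy. This is an illustrative implementation under a few assumptions: the policy receives the iteration index and the list of previous actions (a simplification of $\text{POLICY}(S_{<i}, Z_{<i})$), and `batch_policy` is a hypothetical policy that loads consecutive blocks of canonical vectors. The solve with $H_i$ avoids forming $H_i^{-1}$ explicitly.

```python
import numpy as np
from typing import Callable, List


def itergp_train(
    K: np.ndarray,
    y_bar: np.ndarray,
    policy: Callable[[int, List[np.ndarray]], np.ndarray],
    num_iters: int,
):
    """Sketch of Algorithm 7; policy(i, previous_actions) returns an (N, k_i) block s_i."""
    N = K.shape[0]
    C = np.zeros((N, N))
    alpha = np.zeros(N)
    actions: List[np.ndarray] = []
    for i in range(num_iters):
        s = policy(i, actions)                  # 1. select projection(s)
        z = K @ s                               # 2. z_i = K s_i
        d = s - C @ z                           # 3. d_i = s_i - C_{i-1} z_i
        H = z.T @ d                             # 4. H_i = z_i^T d_i  (k_i x k_i)
        H_inv_dT = np.linalg.solve(H, d.T)      #    H_i^{-1} d_i^T
        C = C + d @ H_inv_dT                    # 5. rank-k_i update of C
        alpha = alpha + d @ (H_inv_dT @ y_bar)  # 6. update of alpha
        actions.append(s)
    return actions, C, alpha


def batch_policy(batch_size: int, N: int):
    """Hypothetical policy: load data points in consecutive blocks of canonical vectors."""
    def policy(i: int, previous_actions: List[np.ndarray]) -> np.ndarray:
        idx = np.arange(i * batch_size, min((i + 1) * batch_size, N))
        return np.eye(N)[:, idx]
    return policy


rng = np.random.default_rng(4)
N = 12
B = rng.standard_normal((N, N))
K = B @ B.T + N * np.eye(N)         # stand-in for k_XX + sigma^2 I
y_bar = rng.standard_normal(N)      # stand-in for y - mu_X

actions, C, alpha = itergp_train(K, y_bar, batch_policy(4, N), num_iters=3)
print(np.allclose(C, np.linalg.inv(K)))               # True: 3 blocks of 4 cover all points
print(np.allclose(alpha, np.linalg.solve(K, y_bar)))  # True
```

After fewer iterations, the same code gives the partial estimate $C = S(S^\top K S)^{-1} S^\top$ for the actions loaded so far.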

---

### **Prediction Step (Same as Standard GP)**

Given a new test point $x$:

- **Covariance to Projections:**
  $$
  k_{xS} = k(x, S)
  $$
- **Predictive Mean:**
  $$
  \mu_x = \mu(x) + k_{xS} \alpha
  $$
- **Predictive Variance:**
  $$
  v_{xx} = k_{xx} - k_{xS} C k_{Sx}
  $$

---

### **Intuition & Practical Impact**

- **Flexibility:**  
  The policy for $s_i$ can be tailored: sequential, batch, random, or conjugate directions (Lanczos/CG).
- **Scalability:**  
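  Batch projections reduce the number of sequential iterations and map well onto parallel hardware, and well-chosen projections (e.g. Krylov directions) can reduce the error much faster than loading one data point at a time.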
- **Uncertainty Quantification:**  
  At every step, $C_i$ provides an up-to-date uncertainty estimate, not just a point prediction.

---

> **In summary:**  
> The IterGP framework unifies and generalizes iterative GP regression. By allowing flexible, policy-driven projections, it enables scalable, efficient, and uncertainty-aware inference for modern probabilistic machine learning.

# What is the Optimal Projection?  
## Data Loading as a Training Policy

The **choice of POLICY** for selecting the projection vectors $s_i$ is crucial in iterative Gaussian Process (GP) regression. Ideally, we want to choose projections that are **maximally informative** at each step, leading to the fastest reduction in the error of our inverse and solution estimates.

---

### The Ideal Case: Eigenvectors of $K$

Suppose we could choose the projections $s_i = u_i$ along the **eigenvectors** of the kernel matrix $K = U \Lambda U^\top$, where $U$ contains the eigenvectors and $\Lambda$ contains the eigenvalues. The iterative algorithm would then simplify beautifully:

- **Projection:**  
    $$
    z_i = K s_i = U \Lambda U^\top u_i = \lambda_i u_i
    $$
- **Residual Update:**  
    $$
    d_i = (I - C_{i-1} K) s_i = u_i - \lambda_i C_{i-1} u_i
    $$
    After $i-1$ steps of this policy, $C_{i-1} = \sum_{j < i} u_j \lambda_j^{-1} u_j^\top$, so $C_{i-1} u_i = 0$ by orthogonality of the eigenvectors and $d_i = u_i$.
- **Schur Complement:**  
    $$
    H_i = s_i^\top K d_i = u_i^\top U \Lambda U^\top u_i = \lambda_i
    $$
- **Inverse Update:**  
    $$
    C_i = C_{i-1} + d_i H_i^{-1} d_i^\top = C_{i-1} + u_i \lambda_i^{-1} u_i^\top = \sum_{j \leq i} u_j \lambda_j^{-1} u_j^\top
    $$
- **Solution Estimate:**  
    $$
    \alpha_i = C_i \bar{y} = \sum_{j \leq i} u_j \lambda_j^{-1} (u_j^\top \bar{y})
    $$
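The simplifications above can be verified numerically. This sketch (random stand-ins for $K$ and $\bar{y}$, NumPy for brevity) runs the iterative update with $s_i = u_i$, sorted by decreasing eigenvalue, and asserts at every step that $C_i$ matches the truncated eigen-expansion $\sum_{j \leq i} u_j \lambda_j^{-1} u_j^\top$. Computing the eigendecomposition up front is exactly the $\mathcal{O}(N^3)$ step that makes this policy impractical; it is done here only for illustration.

```python
import numpy as np

rng = np.random.default_rng(5)
N = 10
B = rng.standard_normal((N, N))
K = B @ B.T + np.eye(N)               # stand-in for k_XX + sigma^2 I
y_bar = rng.standard_normal(N)

lam, U = np.linalg.eigh(K)            # eigenvalues in ascending order
order = np.argsort(lam)[::-1]         # re-sort by decreasing eigenvalue
lam, U = lam[order], U[:, order]

C = np.zeros((N, N))
alpha = np.zeros(N)
for i in range(N):
    s = U[:, i]                       # policy: s_i = u_i
    z = K @ s                         # = lambda_i u_i
    d = s - C @ z                     # = u_i, because C_{i-1} u_i = 0
    H = z @ d                         # = lambda_i
    C += np.outer(d, d) / H
    alpha += d * (d @ y_bar) / H
    # C_i equals the truncated eigen-expansion sum_{j <= i} u_j u_j^T / lambda_j
    C_trunc = (U[:, : i + 1] / lam[: i + 1]) @ U[:, : i + 1].T
    assert np.allclose(C, C_trunc)

print(np.allclose(alpha, np.linalg.solve(K, y_bar)))  # True
```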

---

### Why Is This Optimal?

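If the $u_i$ are sorted by decreasing $\lambda_i$, this policy **maximally reduces the residual in expectation** in each iteration. After $i$ steps, $\lVert K^{-1}\bar{y} - \alpha_i \rVert_{K^2} = \lVert \bar{y} - K\alpha_i \rVert = \big(\sum_{j > i} (u_j^\top \bar{y})^2\big)^{1/2}$, and under the GP prior $\bar{y} \sim \mathcal{N}(0, K)$ each term satisfies $\mathbb{E}[(u_j^\top \bar{y})^2] = \lambda_j$, so loading the largest eigenvalues first removes the most residual on average.  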
- The largest eigenvalues correspond to the most dominant modes of variation in the kernel.
- Incorporating them first provides the most significant reduction in uncertainty.

---

### The Catch

However, we **do not have the eigenvectors of $K$ available upfront**. Computing them is itself an $\mathcal{O}(N^3)$ operation, which defeats the purpose of an iterative, scalable algorithm.

---

### The Solution: Iterative Methods

This is where **iterative methods** like the **Lanczos process** come into play. These methods can efficiently approximate the dominant eigenvectors and eigenvalues of $K$ without computing the full eigendecomposition, enabling scalable and effective projection policies for large-scale GP inference.

---

#### **Summary Table: Projection Choices in Iterative GP Regression**

| Approach                | $s_i$ Choice                | Update Type   | Example Use Case                  |
|-------------------------|-----------------------------|--------------|-----------------------------------|
| Sequential (classic)    | $e_i$ (canonical basis)     | Rank-1       | Exact Cholesky, classic GP        |
| Batch                   | Multiple $e_i$'s            | Rank-$b$     | Mini-batch Cholesky, scalable GP  |
| Random Projections      | Random directions           | Rank-$b$     | Randomized sketching, fast GP approx. |
| **Optimal (theoretical)** | Eigenvectors of $K$         | Rank-1       | Fastest convergence (in theory)   |

---

> **In summary:**  
> The optimal projection policy would use the eigenvectors of $K$, but this is computationally infeasible for large $N$. Iterative methods like Lanczos provide practical alternatives, allowing us to approximate these optimal directions and achieve efficient, scalable GP inference.

# The Lanczos Process: Iterative Construction of Conjugate Projections

When working with large symmetric matrices (like the kernel matrix $K$ in Gaussian Processes), directly computing their eigenvectors is computationally infeasible. However, we still want to construct "good" projection vectors $s_i$ that help us efficiently solve linear systems or approximate matrix functions. This is where the **Lanczos process** comes in.

---

## What is the Lanczos Process?

The **Lanczos process** (Kornél Lánczos, 1950) is an iterative algorithm that builds an **orthonormal basis** for the **Krylov subspace** associated with a symmetric matrix $K$ and an initial vector $s_0$.

- **Krylov Subspace:**  
    For a matrix $K$ and vector $s_0$, the Krylov subspace of order $n$ is:
    $$
    \mathcal{K}_n(K, s_0) = \text{span}\{s_0, Ks_0, K^2s_0, \ldots, K^{n-1}s_0\}
    $$

- **Goal:**  
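    Construct a sequence of **orthonormal** vectors $s_1, s_2, \ldots, s_N$ in which $K$ becomes **tridiagonal**, i.e. $s_i^\top K s_j = 0$ whenever $|i - j| > 1$. The Lanczos vectors themselves are not $K$-conjugate; the Conjugate Gradient method derives $K$-conjugate search directions ($p_i^\top K p_j = 0$ for $i \neq j$) that span the same Krylov subspaces.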

---

## How Does the Lanczos Process Work?

- **Initialization:**  
    Normalize the initial vector $s_0$ to obtain $s_1 = s_0 / \lVert s_0 \rVert$.

- **Iteration:**  
    At each step, apply $K$ to the current vector and orthogonalize the result against the current and previous vectors only. Because $K$ is symmetric, this three-term recurrence makes the new vector orthogonal (in the ordinary Euclidean sense) to all previous ones, in exact arithmetic.

- **Result:**  
    After $N$ steps, you obtain a set of vectors $S = [s_1, s_2, \ldots, s_N]$ such that:
    $$
    S^\top K S = T
    $$
    where $T$ is a **tridiagonal matrix** with diagonal entries $\alpha_i$ and sub-diagonal entries $\beta_i$ (these Lanczos coefficients are unrelated to the solution estimate $\alpha_i$ of IterGP).

---

## The Tridiagonal Matrix $T$

The matrix $T$ has the following structure:
$$
T = 
\begin{pmatrix}
\alpha_1 & \beta_2 & 0 & \cdots & 0 \\
\beta_2 & \alpha_2 & \beta_3 & \ddots & \vdots \\
0 & \beta_3 & \alpha_3 & \ddots & 0 \\
\vdots & \ddots & \ddots & \ddots & \beta_N \\
0 & \cdots & 0 & \beta_N & \alpha_N
\end{pmatrix}
$$

- $S = [s_1, s_2, \ldots, s_N]$ is the matrix of Lanczos vectors.
- $\alpha_i$ are the diagonal elements.
- $\beta_i$ are the sub-diagonal elements.

---

## Why is the Lanczos Process Important?

- **Efficient Projections:**  
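    The Lanczos vectors $s_i$ are "good" directions for iterative algorithms like Conjugate Gradients: they are orthonormal and span the Krylov subspaces $\mathcal{K}_n(K, s_0)$. With $s_0 = \bar{y}$, these are exactly the subspaces in which CG searches for the solution of $K\alpha = \bar{y}$.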
- **Dimensionality Reduction:**  
    By working in the Krylov subspace, we can approximate solutions to large linear systems or eigenvalue problems using only a small number of vectors.
- **Foundation for CG:**  
    The Conjugate Gradient (CG) method is essentially the Lanczos process applied to solving linear systems.

---

## Summary

- The Lanczos process provides a principled way to iteratively construct an orthonormal basis of projections for a symmetric matrix $K$.
- These projections are **orthonormal** and reduce $K$ to tridiagonal form, and CG turns them into $K$-conjugate search directions. This makes them ideal for efficient numerical algorithms in Gaussian Processes and other areas of scientific computing.
- The process transforms $K$ into a much simpler tridiagonal matrix $T$, capturing the essential structure needed for computation.

---

> **In essence:**  
> The Lanczos process is a powerful tool for scalable computation with large symmetric matrices, enabling efficient iterative algorithms for both mean and uncertainty estimation in probabilistic models.


### Simplified Lanczos Process (Algorithm)

The **Lanczos process** is an efficient iterative algorithm for constructing an orthonormal basis of the Krylov subspace for a symmetric matrix $K$ and an initial vector $s_0$. It is widely used in numerical linear algebra for eigenvalue problems and for solving large linear systems (e.g., in Conjugate Gradients).

---

#### **Algorithm: Simplified Lanczos Process**

**Input:**  
- Symmetric matrix $K \in \mathbb{R}^{N \times N}$
- Initial vector $s_0 \in \mathbb{R}^N$

**Output:**  
- Sequences of scalars $\alpha_i$, $\beta_i$
- Lanczos vectors $s_i$

---

### **Step-by-Step Algorithm (with Math and Intuition)**

1. **Initialization**
  - Normalize the initial vector:
    $$
    s_1 = \frac{s_0}{\|s_0\|}
    $$
    *Start with a unit-norm vector $s_1$.*

2. **First Iteration**
  - Compute the first matrix-vector product:
    $$
    z_1 = K s_1
    $$
    *Apply $K$ to the starting vector.*

  - Compute the first diagonal element:
    $$
    \alpha_1 = s_1^\top z_1 = s_1^\top K s_1
    $$
    *$\alpha_1$ captures how much $K$ stretches $s_1$ along itself.*

  - Compute the first residual:
    $$
    l_1 = z_1 - \alpha_1 s_1
    $$
    *Remove the component of $z_1$ along $s_1$ to get the part orthogonal to $s_1$.*

3. **Subsequent Iterations ($i = 2, \ldots, N$)**
  - Compute the sub-diagonal element:
    $$
    \beta_i = \|l_{i-1}\|
    $$
    *$\beta_i$ measures the norm of the new direction; if it's zero, the process terminates.*

  - Normalize to get the next Lanczos vector:
    $$
    s_i = \frac{l_{i-1}}{\beta_i}
    $$
    *$s_i$ is orthogonal to all previous $s_j$ (in exact arithmetic).*

  - Apply $K$ to the new vector:
    $$
    z_i = K s_i
    $$

  - Compute the next diagonal element:
    $$
    \alpha_i = s_i^\top z_i = s_i^\top K s_i
    $$
    *$\alpha_i$ is the Rayleigh quotient for $s_i$.*

  - Compute the new residual:
    $$
    l_i = z_i - \alpha_i s_i - \beta_i s_{i-1}
    $$
    *Subtract projections onto $s_i$ and $s_{i-1}$ to maintain orthogonality.*
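The steps above translate almost line by line into code. This is a minimal NumPy sketch of the simplified process, without re-orthogonalization, so it is only reliable for a modest number of steps; the early-termination tolerance `1e-12` and the random test matrix are assumptions for illustration. It checks that $S^\top S = I$ and $S^\top K S = T$.

```python
import numpy as np


def lanczos(K: np.ndarray, s0: np.ndarray, n: int):
    """Simplified Lanczos process (no re-orthogonalization).

    Returns S (N x m), the diagonal entries alphas (m,) and the off-diagonal
    entries betas (m - 1,) of T, with m = n unless the process stops early.
    betas[i] couples S[:, i] and S[:, i + 1].
    """
    N = K.shape[0]
    S = np.zeros((N, n))
    alphas = np.zeros(n)
    betas = np.zeros(max(n - 1, 0))
    s = s0 / np.linalg.norm(s0)          # s_1
    s_prev = np.zeros(N)
    beta = 0.0                           # no previous-vector term in l_1
    for i in range(n):
        S[:, i] = s
        z = K @ s                        # z_i = K s_i
        alphas[i] = s @ z                # alpha_i = s_i^T K s_i
        l_vec = z - alphas[i] * s - beta * s_prev   # l_i
        if i == n - 1:
            break
        beta = np.linalg.norm(l_vec)     # next beta = ||l_i||
        if beta <= 1e-12:                # invariant subspace found: stop early
            return S[:, : i + 1], alphas[: i + 1], betas[:i]
        betas[i] = beta
        s_prev, s = s, l_vec / beta      # next Lanczos vector
    return S, alphas, betas


rng = np.random.default_rng(6)
N, n = 50, 8
B = rng.standard_normal((N, N))
K = B @ B.T / N + np.eye(N)              # symmetric positive-definite test matrix
S, a, b = lanczos(K, rng.standard_normal(N), n)
T = np.diag(a) + np.diag(b, 1) + np.diag(b, -1)
print(np.allclose(S.T @ S, np.eye(n)))   # True: orthonormal columns
print(np.allclose(S.T @ K @ S, T))       # True: S^T K S is tridiagonal
```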

---

### **Summary Table of Steps**

| Step | Formula | Explanation |
|------|---------|-------------|
| 1 | $s_1 = \frac{s_0}{\|s_0\|}$ | Normalize initial vector |
| 2 | $z_1 = K s_1$ | Matrix-vector product |
| 3 | $\alpha_1 = s_1^\top K s_1$ | Diagonal entry of tridiagonal matrix |
| 4 | $l_1 = z_1 - \alpha_1 s_1$ | Orthogonalize against $s_1$ |
| 5 | $\beta_i = \|l_{i-1}\|$ | Sub-diagonal entry (norm) |
| 6 | $s_i = \frac{l_{i-1}}{\beta_i}$ | Next orthonormal vector |
| 7 | $z_i = K s_i$ | Matrix-vector product |
| 8 | $\alpha_i = s_i^\top K s_i$ | Diagonal entry |
| 9 | $l_i = z_i - \alpha_i s_i - \beta_i s_{i-1}$ | Orthogonalize against previous two vectors |

---

### **Intuition**

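- **Orthonormal Basis:** Each $s_i$ is orthogonal to all previous $s_j$ (for $j < i$), and $s_1, \ldots, s_n$ form an orthonormal basis of the Krylov subspace $\mathcal{K}_n(K, s_0) = \operatorname{span}\{s_0, K s_0, \ldots, K^{n-1} s_0\}$.
- **Tridiagonalization:** The process builds a tridiagonal matrix $T$ such that $S^\top K S = T$, where $S = [s_1, \ldots, s_n]$. $T$ has $\alpha_1, \ldots, \alpha_n$ on its diagonal and $\beta_2, \ldots, \beta_n$ on its sub- and super-diagonals.
- **Efficiency:** Only requires matrix-vector products and inner products, with no full matrix inversion or eigendecomposition.
- **Termination:** If some $\beta_i = 0$, the Lanczos vectors found so far span a $K$-invariant subspace, and the eigenvalues of the current $T$ are exact eigenvalues of $K$.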
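The termination and Krylov-subspace claims can be checked on a toy problem. The sketch below is a plain NumPy version of the recurrence above. The helper `lanczos` is our own illustration, not code from the lecture, and NumPy is used instead of the notebook's JAX. It uses a diagonal $K$ and a start vector that touches only three eigenvectors, so $\mathcal{K}_n(K, s_0)$ has dimension three. The recurrence should therefore stop after three steps, with Ritz values equal to the corresponding eigenvalues.

```python
import numpy as np


def lanczos(K, s0, max_iter, tol=1e-10):
    """Plain three-term Lanczos recurrence (no re-orthogonalization).

    Returns S (Lanczos vectors as columns) and the tridiagonal T = S^T K S.
    Stops early once beta < tol, i.e. once an invariant subspace is found.
    """
    s = s0 / np.linalg.norm(s0)
    s_prev, beta = np.zeros_like(s), 0.0
    S, alphas, betas = [s], [], []
    for i in range(max_iter):
        z = K @ s                            # the only matrix-vector product
        alpha = s @ z                        # diagonal entry of T
        alphas.append(alpha)
        l = z - alpha * s - beta * s_prev    # three-term recurrence
        beta = np.linalg.norm(l)
        if beta < tol or i == max_iter - 1:
            break
        betas.append(beta)                   # off-diagonal entry of T
        s_prev, s = s, l / beta
        S.append(s)
    T = np.diag(alphas) + np.diag(betas, 1) + np.diag(betas, -1)
    return np.column_stack(S), T


K = np.diag([1.0, 2.0, 5.0, 7.0, 11.0, 13.0])
s0 = np.array([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])  # only 3 eigen-directions present
S, T = lanczos(K, s0, max_iter=6)
print(S.shape)                  # (6, 3): stopped after 3 steps
print(np.linalg.eigvalsh(T))    # Ritz values: 1, 2, 5 up to round-off

# S spans the same space as the Krylov vectors s0, K s0, K^2 s0
krylov = np.column_stack([s0, K @ s0, K @ K @ s0])
print(np.linalg.norm(krylov - S @ (S.T @ krylov)))  # round-off level
```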

---

> **In summary:**  
> The Lanczos process is a powerful, efficient way to extract the most "informative" directions for $K$ using only matrix-vector products, making it essential for scalable Gaussian Process inference and large-scale linear algebra.

## Key Points: The Lanczos Process in Iterative GP Inference

- **Computational Cost:**  
    - Each iteration of the Lanczos process requires a single matrix-vector product $K s_i$, which is $\mathcal{O}(N^2)$ for a dense matrix $K$.
    - If $K$ is sparse or admits a fast matrix-vector product (e.g., via structure or approximation), this cost can be significantly reduced.

- **Numerical Stability:**  
    - The basic Lanczos algorithm, as presented, is a simplified version and can become numerically unstable for long sequences due to loss of orthogonality among the generated vectors.
    - In practice, **re-orthogonalization techniques** are used to maintain numerical stability and ensure the orthogonality of the Lanczos vectors.

- **Combination with IterGP:**  
    - The Lanczos process can be seamlessly integrated with the iterative GP algorithm (Algorithm 7).
    - The **POLICY** function in IterGP can be implemented by generating Lanczos vectors at each step.
    - This integration adds only $\mathcal{O}(N)$ extra work per iteration for selecting $s_i$. The product $z_i = K s_i$ that the next Lanczos vector needs is already computed inside IterGP.

- **Conjugate Gradients (CG):**  
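    - A particularly important choice for the initial vector $s_0$ is the **gradient** of the quadratic objective $L(\alpha)$ at the initial guess $\alpha_0$ (with $\bar{y} = y - \mu_X$):
      $$
      s_0 = K \alpha_0 - \bar{y}
      $$
      or, equivalently up to sign, the **residual of the linear system**,
      $$
      s_0 = \bar{y} - K \alpha_0.
      $$
      Flipping the sign of $s_0$ only flips the signs of the Lanczos vectors and leaves $T$ unchanged.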
    - This choice leads to the **Conjugate Gradient (CG) method**, a powerful iterative solver for symmetric positive definite linear systems.
    - CG implicitly constructs Lanczos vectors and is widely used for large-scale problems due to its efficiency and scalability.
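To make the CG connection concrete, the sketch below runs IterGP-style updates with $\alpha_0 = 0$, $C_0 = 0$ and the policy $s_i = \bar{y} - K\alpha_{i-1}$ (the current residual). It then compares the result with a textbook CG loop. In exact arithmetic both produce the same iterate: each is the Galerkin solution over the Krylov subspace spanned by $\bar{y}, K\bar{y}, \ldots$.

The following are assumptions of this sketch, not part of the lecture:

- The function names `cg` and `itergp_residual_policy` are our own.
- The code uses NumPy rather than JAX.
- A random SPD matrix stands in for $K$.

```python
import numpy as np


def cg(K, y, num_iter):
    """Textbook conjugate gradients for K x = y, started from x_0 = 0."""
    x = np.zeros_like(y)
    r = y.copy()                       # residual y - K x_0
    p = r.copy()                       # first search direction
    for _ in range(num_iter):
        Kp = K @ p
        step = (r @ r) / (p @ Kp)
        x = x + step * p
        r_new = r - step * Kp
        p = r_new + (r_new @ r_new) / (r @ r) * p
        r = r_new
    return x


def itergp_residual_policy(K, y_bar, num_iter):
    """IterGP-style updates (alpha_0 = 0, C_0 = 0) with policy s_i = current residual.

    Assumes the residual does not become exactly zero within num_iter steps.
    """
    N = K.shape[0]
    alpha, C = np.zeros(N), np.zeros((N, N))
    for _ in range(num_iter):
        s = y_bar - K @ alpha          # residual = negative gradient of L
        z = K @ s
        d = s - C @ z                  # part of s not yet covered by C
        eta = z @ d                    # scalar Schur complement s^T K d
        C = C + np.outer(d, d) / eta
        alpha = alpha + d * (d @ y_bar) / eta  # valid because alpha_{i-1} = C_{i-1} y_bar
    return alpha, C


rng = np.random.default_rng(0)
A = rng.standard_normal((30, 30))
K = A @ A.T / 30 + 0.5 * np.eye(30)    # SPD stand-in for K_XX + sigma^2 I
y_bar = rng.standard_normal(30)

alpha_itergp, _ = itergp_residual_policy(K, y_bar, num_iter=5)
alpha_cg = cg(K, y_bar, num_iter=5)
print(np.max(np.abs(alpha_itergp - alpha_cg)))  # agreement up to round-off
```

Standard CG never forms $C_i$. The IterGP version pays $\mathcal{O}(N^2)$ extra per step to keep it, which is what makes the uncertainty available.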

---

**Summary:**  
The Lanczos process is a foundational tool for efficient iterative inference in Gaussian Processes. By leveraging fast matrix-vector products and careful numerical techniques, it enables scalable computation of both the mean and uncertainty in GP regression. When combined with the IterGP framework, it provides a principled and practical approach to large-scale probabilistic machine learning.
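The re-orthogonalization mentioned under **Numerical Stability** can be sketched as follows. This is one simple variant of our choosing: full re-orthogonalization with two Gram-Schmidt passes. Selective or partial re-orthogonalization are cheaper alternatives. The code is NumPy, and the helper name `lanczos_full_reorth` is hypothetical. Each step projects the new residual against all stored Lanczos vectors. This adds $\mathcal{O}(Ni)$ work at step $i$ but keeps $S$ numerically orthonormal.

```python
import numpy as np


def lanczos_full_reorth(K, s0, num_iter, tol=1e-12):
    """Lanczos with full re-orthogonalization; returns S (N x k) and T (k x k)."""
    N = K.shape[0]
    S = np.zeros((N, num_iter))
    alphas, betas = [], []
    S[:, 0] = s0 / np.linalg.norm(s0)
    for i in range(num_iter):
        z = K @ S[:, i]
        alphas.append(S[:, i] @ z)
        l = z - alphas[-1] * S[:, i]
        if i > 0:
            l -= betas[-1] * S[:, i - 1]      # usual three-term recurrence
        Q = S[:, : i + 1]
        for _ in range(2):                    # second pass: "twice is enough"
            l -= Q @ (Q.T @ l)                # remove drift along all earlier s_j
        beta = np.linalg.norm(l)
        if i == num_iter - 1 or beta < tol:   # done, or invariant subspace found
            break
        betas.append(beta)
        S[:, i + 1] = l / beta
    k = len(alphas)
    T = np.diag(alphas) + np.diag(betas, 1) + np.diag(betas, -1)
    return S[:, :k], T


N = 200
K = np.diag(np.linspace(1.0, 100.0, N))
s0 = np.random.default_rng(0).standard_normal(N)
S, T = lanczos_full_reorth(K, s0, num_iter=60)
print(np.abs(S.T @ S - np.eye(S.shape[1])).max())  # stays near machine precision
```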

In [None]:
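import jax.numpy as jnp
import jax.random as random  # PRNG for the random initial vector
import matplotlib.pyplot as plt


# Squared Exponential (RBF) kernel, defined here so this cell runs on its own
def squared_exponential_kernel(
    x1: jnp.ndarray, x2: jnp.ndarray, sigma: float = 1.0, lengthscale: float = 1.0
) -> jnp.ndarray:
    """Computes the Squared Exponential (RBF) kernel matrix."""
    x1 = jnp.atleast_2d(x1)
    x2 = jnp.atleast_2d(x2)
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)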


def conceptual_lanczos_process(
    K_matrix: jnp.ndarray, s0: jnp.ndarray, num_iterations: int
):
    """
    Conceptual implementation of the Lanczos process.
    This version is simplified and may be numerically unstable for many iterations.
    """
    N = K_matrix.shape[0]
    if num_iterations > N:
        num_iterations = N  # Cannot generate more than N Lanczos vectors

    # Initialize lists to store alpha and beta values
    alphas = []
    betas = []

    # Store Lanczos vectors (orthonormal basis)
    S_vectors = []

    # Step 1: Initialize
    s_prev = s0 / jnp.linalg.norm(s0)  # Normalize initial vector
    S_vectors.append(s_prev)

    z_curr = jnp.dot(K_matrix, s_prev)
    alpha_curr = jnp.dot(s_prev.T, z_curr)
    alphas.append(alpha_curr)

    l_curr = z_curr - alpha_curr * s_prev

    print("--- Starting Conceptual Lanczos Process ---")
    for i in range(1, num_iterations):
        beta_curr = jnp.linalg.norm(l_curr)

        if beta_curr < 1e-10:  # Break if l_curr is zero (exact subspace found)
            print(f"  Converged at iteration {i} (beta is near zero).")
            break

        s_curr = l_curr / beta_curr
        S_vectors.append(s_curr)
        betas.append(beta_curr)

        z_next = jnp.dot(K_matrix, s_curr)
        alpha_next = jnp.dot(s_curr.T, z_next)
        alphas.append(alpha_next)

        l_next = (
            z_next - alpha_next * s_curr - beta_curr * s_prev
        )  # Note: beta_curr is beta_{i+1}
        l_curr = l_next
        s_prev = s_curr

        print(f"  Iteration {i + 1}: alpha={alpha_next:.4f}, beta={beta_curr:.4f}")

    # Construct the tridiagonal matrix T for verification
    T = jnp.diag(jnp.array(alphas))
    if len(betas) > 0:
        T += jnp.diag(jnp.array(betas), k=1)
        T += jnp.diag(jnp.array(betas), k=-1)

    # Convert list of vectors to a matrix
    S_matrix = jnp.stack(S_vectors, axis=1)

    print("\n--- Lanczos Process Results ---")
    print("Tridiagonal Matrix T:\n", T)
    print(
        "\nVerification: S^T @ K @ S (should be close to T):\n",
        jnp.dot(S_matrix.T, jnp.dot(K_matrix, S_matrix)),
    )

    return alphas, betas, S_vectors


# --- Example Usage ---
# Create a symmetric positive definite matrix
N_matrix = 10
X_data_matrix = jnp.linspace(-2, 2, N_matrix)[:, None]
K_XX_matrix = squared_exponential_kernel(
    X_data_matrix, X_data_matrix, sigma=1.0, lengthscale=0.5
)
A_example = K_XX_matrix + 0.1**2 * jnp.eye(N_matrix)  # Add jitter for SPD

# Random initial vector
key = random.PRNGKey(2)
s0_example = random.normal(key, (N_matrix,))

# Run Lanczos process for a few iterations
num_lanczos_iterations = 5
alphas, betas, S_vectors = conceptual_lanczos_process(
    A_example, s0_example, num_lanczos_iterations
)

# Plot the Lanczos vectors (basis functions)
plt.figure(figsize=(10, 6))
for i, s_vec in enumerate(S_vectors):
    plt.plot(jnp.arange(N_matrix), s_vec, label=f"s_{i + 1}")
plt.xlabel("Index")
plt.ylabel("Value")
plt.title("First few Lanczos Vectors")
plt.legend()
plt.grid(True)
plt.show()


In [None]:
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
from typing import Callable


# --- Re-using kernel definition ---
def squared_exponential_kernel(
    x1: jnp.ndarray, x2: jnp.ndarray, sigma: float = 1.0, lengthscale: float = 1.0
) -> jnp.ndarray:
    """Computes the Squared Exponential (RBF) kernel matrix."""
    x1 = jnp.atleast_2d(x1)
    x2 = jnp.atleast_2d(x2)
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    K = sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)
    return K


# --- Iterative GP Training Procedure (Algorithm 7, adapted for CG) ---
def iterative_gp_train_cg(
    K_matrix: jnp.ndarray,  # K_XX + sigma^2 I
    y_bar: jnp.ndarray,  # y - mu_X
    num_iterations: int,
    mean_func_train_data: Callable[[jnp.ndarray], jnp.ndarray],
    X_train_data: jnp.ndarray,
):
    N = K_matrix.shape[0]

    # Initialize alpha and C (from Algorithm 7, with alpha_0=0, C_0=0)
    alpha_i = jnp.zeros(N, dtype=y_bar.dtype)
    C_i = jnp.zeros((N, N), dtype=K_matrix.dtype)

    # Conjugate Gradient specific initializations
    r_i = y_bar - jnp.dot(
        K_matrix, alpha_i
    )  # Initial residual: the *negative* gradient of L(alpha) at alpha_0
    p_i = r_i  # Initial search direction

    # Store intermediate results for plotting
    all_alpha_estimates = [alpha_i]
    all_C_estimates = [C_i]
    all_residual_norms = [jnp.linalg.norm(r_i)]

    print("--- Starting Iterative GP Training (Conjugate Gradients) ---")
    for i in range(num_iterations):
        # This is where the 'POLICY' for s_i comes in.
        # For CG, s_i is the search direction p_i.
        s_i = p_i

        # z_i = K @ s_i (matrix-vector product)
        z_i = jnp.dot(K_matrix, s_i)

        # d_i = s_i - C_{i-1} @ z_i (Algorithm 7): the part of s_i that is
        # K-orthogonal to all previously loaded directions s_j.
        # For CG, the search directions are K-conjugate by construction, so
        # C_{i-1} @ K @ p_i = 0 and d_i equals p_i (in exact arithmetic).
        d_i = s_i - jnp.dot(C_i, z_i)

        # eta_i: Schur complement (s_i^T @ K @ d_i = z_i^T @ d_i)
        eta_i = jnp.dot(z_i.T, d_i)  # This will be p_i^T @ K @ p_i for CG

        # Add small jitter for numerical stability
        eta_i_stable = eta_i + 1e-12

        # Alpha update (from CG algorithm)
        alpha_step = jnp.dot(r_i.T, r_i) / eta_i_stable
        alpha_i_new = alpha_i + alpha_step * p_i
        alpha_i = alpha_i_new

        # Residual update (from CG algorithm)
        r_i_new = r_i - alpha_step * z_i

        # Beta for next search direction (from CG algorithm)
        beta_next = jnp.dot(r_i_new.T, r_i_new) / jnp.dot(r_i.T, r_i)

        # Update search direction
        p_i_new = r_i_new + beta_next * p_i

        r_i = r_i_new
        p_i = p_i_new

        # C_i: Inverse estimate update (from Algorithm 7)
        # This update is O(N^2) and is what makes the C_i matrix grow.
        # In practical CG, C_i is not explicitly formed.
        C_i_new = C_i + (1 / eta_i_stable) * jnp.outer(d_i, d_i)
        C_i = C_i_new

        all_alpha_estimates.append(alpha_i)
        all_C_estimates.append(C_i)
        all_residual_norms.append(jnp.linalg.norm(r_i))

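        res_norm = jnp.linalg.norm(r_i)
        # max(1, ...) avoids a modulo-by-zero error when N < 5
        if (i + 1) % max(1, N // 5) == 0 or (i + 1) == num_iterations:
            print(f"  Iteration {i + 1}/{num_iterations}: Residual norm: {res_norm:.2e}")
        # Check convergence every iteration, not only when printing
        if res_norm < 1e-6:
            print(f"  Residual norm very small after {i + 1} iterations, stopping.")
            break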

    print("--- Iterative GP Training (Conjugate Gradients) Complete ---")
    return all_alpha_estimates, all_C_estimates, all_residual_norms


In [None]:
# --- GP Prediction Procedure (Algorithm 6, PREDICT) ---
# Re-using the same prediction function as before
def iterative_gp_predict(
    X_test: jnp.ndarray,
    X_train_data: jnp.ndarray,
    mean_func_test: Callable[[jnp.ndarray], jnp.ndarray],
    kernel_func: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
    alpha_final: jnp.ndarray,
    C_final: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Performs GP prediction using the final alpha and C estimates from iterative training.
    """
    k_xS = kernel_func(X_test, X_train_data)
    mu_x = mean_func_test(X_test) + jnp.dot(k_xS, alpha_final)
    k_xx = kernel_func(X_test, X_test)
    v_xx = k_xx - jnp.dot(k_xS, jnp.dot(C_final, k_xS.T))
    return mu_x, v_xx


In [None]:
# --- Example Usage (Simulated Data) ---
key = random.PRNGKey(11)
N_data = 20  # Number of training points for this example
X_train_sim = jnp.linspace(-5, 5, N_data)[:, None]
y_true = jnp.sin(X_train_sim) * jnp.exp(-0.1 * X_train_sim**2)
y_train_sim = y_true.squeeze() + 0.1 * random.normal(key, (N_data,))

# GP parameters
mean_func_sim = lambda x: jnp.zeros(x.shape[0])
kernel_func_sim = lambda x1, x2: squared_exponential_kernel(
    x1, x2, sigma=1.0, lengthscale=1.0
)
noise_var_sim = 0.1**2

# Precompute K_matrix and y_bar
K_matrix_sim = kernel_func_sim(X_train_sim, X_train_sim) + noise_var_sim * jnp.eye(
    N_data
)
y_bar_sim = y_train_sim - mean_func_sim(X_train_sim)

# Run iterative training with CG-like updates
all_alpha_estimates_cg, all_C_estimates_cg, all_residual_norms_cg = (
    iterative_gp_train_cg(K_matrix_sim, y_bar_sim, N_data, mean_func_sim, X_train_sim)
)

# Get final alpha and C estimates
alpha_final_cg = all_alpha_estimates_cg[-1]
C_final_cg = all_C_estimates_cg[-1]

# Generate test points for prediction
X_test_sim = jnp.linspace(-6, 6, 100)[:, None]

# Perform prediction using the final estimates
mu_pred_cg, cov_pred_cg = iterative_gp_predict(
    X_test_sim, X_train_sim, mean_func_sim, kernel_func_sim, alpha_final_cg, C_final_cg
)
std_pred_cg = jnp.sqrt(jnp.maximum(jnp.diag(cov_pred_cg), 0.0))  # clip round-off negatives


In [None]:
import plotly.graph_objs as go

# Create the main trace for the predictive mean (CG Iterative)
trace_mean_cg = go.Scatter(
    x=X_test_sim.squeeze(),
    y=mu_pred_cg,
    mode="lines",
    name="GP Predictive Mean (CG Iterative)",
    line=dict(color="red"),
)

# Create the confidence interval as a filled area
trace_ci_cg = go.Scatter(
    x=jnp.concatenate([X_test_sim.squeeze(), X_test_sim.squeeze()[::-1]]),
    y=jnp.concatenate(
        [mu_pred_cg - 2 * std_pred_cg, (mu_pred_cg + 2 * std_pred_cg)[::-1]]
    ),
    fill="toself",
    fillcolor="rgba(255,0,0,0.2)",
    line=dict(color="rgba(255,255,255,0)"),
    hoverinfo="skip",
    showlegend=True,
    name="95% Confidence Interval",
)

# Training data as scatter points
trace_train_cg = go.Scatter(
    x=X_train_sim.squeeze(),
    y=y_train_sim,
    mode="markers",
    name="Training Data",
    marker=dict(color="blue", size=7, opacity=0.8),
)

layout_cg = go.Layout(
    title="Iterative Gaussian Process Regression with Conjugate Gradients",
    xaxis=dict(title="X"),
    yaxis=dict(title="Y"),
    legend=dict(x=0.01, y=0.99),
    template="plotly_white",
    hovermode="closest",
)

fig_cg = go.Figure(data=[trace_ci_cg, trace_mean_cg, trace_train_cg], layout=layout_cg)
fig_cg.show()

# Plot the residual norm convergence (log scale)
fig_residual = go.Figure()
fig_residual.add_trace(
    go.Scatter(
        x=jnp.arange(len(all_residual_norms_cg)),
        y=all_residual_norms_cg,
        mode="lines+markers",
        name="Residual Norm",
        line=dict(color="purple"),
    )
)
fig_residual.update_layout(
    title="Convergence of Residual Norm in CG-like Iteration",
    xaxis_title="Iteration",
    yaxis_title="Residual Norm (log scale)",
    yaxis_type="log",
    template="plotly_white",
    hovermode="closest",
)
fig_residual.show()


In [None]:
# --- Plotting Iterative Posterior Updates (Conceptual) ---
# This part visualizes how the posterior mean and variance evolve with each iteration.
# It can be computationally intensive for many iterations/large N.
# We'll plot a few intermediate steps.

import plotly.graph_objs as go

# Plot up to 5 snapshots, evenly spaced from iteration 0 (the prior) to the final iterate
num_plots = min(len(all_alpha_estimates_cg), 5)
plot_indices = jnp.linspace(
    0, len(all_alpha_estimates_cg) - 1, num_plots, endpoint=True, dtype=int
)

plotly_figs_cg = []
for k, idx in enumerate(plot_indices):
    current_alpha = all_alpha_estimates_cg[idx]
    current_C = all_C_estimates_cg[idx]

    # Predict with current estimates
    mu_iter, cov_iter = iterative_gp_predict(
        X_test_sim,
        X_train_sim,
        mean_func_sim,
        kernel_func_sim,
        current_alpha,
        current_C,
    )
    std_iter = jnp.sqrt(jnp.maximum(jnp.diag(cov_iter), 0.0))  # clip round-off negatives

    # Plotly traces
    trace_mean = go.Scatter(
        x=X_test_sim.squeeze(),
        y=mu_iter,
        mode="lines",
        name="Posterior Mean",
        line=dict(color="red"),
    )
    trace_ci = go.Scatter(
        x=jnp.concatenate([X_test_sim.squeeze(), X_test_sim.squeeze()[::-1]]),
        y=jnp.concatenate([mu_iter - 2 * std_iter, (mu_iter + 2 * std_iter)[::-1]]),
        fill="toself",
        fillcolor="rgba(255,0,0,0.2)",
        line=dict(color="rgba(255,255,255,0)"),
        hoverinfo="skip",
        showlegend=True,
        name="95% Confidence Interval",
    )
    trace_train = go.Scatter(
        x=X_train_sim.squeeze(),
        y=y_train_sim,
        mode="markers",
        name="Training Data",
        marker=dict(color="blue", size=7, opacity=0.8),
    )

    layout = go.Layout(
        title=f"Iterative GP Posterior Updates (CG)<br>Iteration {idx}/{len(all_alpha_estimates_cg) - 1}",
        xaxis=dict(title="X"),
        yaxis=dict(title="Y"),
        legend=dict(x=0.01, y=0.99),
        template="plotly_white",
        hovermode="closest",
        height=400,
        width=700,
    )
    fig_iter = go.Figure(data=[trace_ci, trace_mean, trace_train], layout=layout)
    plotly_figs_cg.append(fig_iter)
    fig_iter.show()


# 🚀 Takeaways: Iterative GP Updates with Lanczos

The synergy between **iterative Gaussian Process (GP) algorithms** and **projection methods** like the **Lanczos process** (which underpins Conjugate Gradients) offers both deep theoretical insights and practical computational advantages.

---

## 🌐 Projection-Based Interaction

- **Beyond Individual Data Points:**  
    Instead of loading one data point at a time, we can load **linear projections** of the data.  
    This flexibility allows for more efficient and informative updates in each iteration.

---

## 🎯 Optimal Projections with Lanczos

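- **Maximally Informative Directions:**  
    The **Lanczos process** efficiently constructs projections that are "maximally informative" in a precise sense. After $i$ steps, the CG iterate minimizes the error in the $K$-norm, $\|\alpha - \alpha^\star\|_K = \sqrt{(\alpha - \alpha^\star)^\top K (\alpha - \alpha^\star)}$, over all $\alpha \in \alpha_0 + \mathcal{K}_i(K, s_0)$.
- **K-Orthogonality:**  
    The CG search directions built from these vectors are **K-orthogonal**, i.e. conjugate with respect to the kernel matrix $K$: $p_i^\top K p_j = 0$ for $i \neq j$. The Lanczos vectors themselves are orthonormal in the usual Euclidean sense. Conjugacy decouples the quadratic objective, so minimizing along each new direction keeps the iterate optimal over all directions explored so far.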
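The distinction between $K$-orthogonal search directions and Euclidean orthogonality can be checked numerically. The sketch below records the CG search directions $p_1, \ldots, p_6$ for a random SPD matrix and compares their $K$-inner products with their plain inner products. It assumes NumPy and a random SPD matrix in place of a kernel matrix; the helper names are hypothetical.

```python
import numpy as np


def cg_directions(K, y, num_iter):
    """Run CG from x_0 = 0 on K x = y and return the search directions as columns."""
    r = y.copy()
    p = r.copy()
    directions = []
    for _ in range(num_iter):
        directions.append(p)
        Kp = K @ p
        step = (r @ r) / (p @ Kp)
        r_new = r - step * Kp
        p = r_new + (r_new @ r_new) / (r @ r) * p
        r = r_new
    return np.column_stack(directions)


def max_off_diagonal(G):
    """Largest off-diagonal magnitude, relative to the largest diagonal entry."""
    off = G - np.diag(np.diag(G))
    return np.abs(off).max() / np.abs(np.diag(G)).max()


rng = np.random.default_rng(1)
A = rng.standard_normal((50, 50))
K = A @ A.T / 50 + np.eye(50)         # SPD stand-in for the kernel matrix
y = rng.standard_normal(50)

P = cg_directions(K, y, num_iter=6)
print(max_off_diagonal(P.T @ K @ P))  # round-off level: K-orthogonal
print(max_off_diagonal(P.T @ P))      # clearly non-zero: not Euclidean-orthogonal
```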

---

## 📏 Native Mean and Uncertainty Quantification

- **Matching Mean and Uncertainty:**  
    By tracking projections and associated statistics (like $C_i$ and $\alpha_i$), we can compute both the **predictive mean** and **uncertainty** natively.
- **No Trade-off:**  
    This means we do **not** have to sacrifice uncertainty quantification for scalability—both are achieved together.

---

## ⚡ Halving Compute Time (for Exact Solve)

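- **Complexity:**  
    Run to completion on a dense $N \times N$ system, both Cholesky and Conjugate Gradients (CG) cost $\mathcal{O}(N^3)$. For CG this is up to $N$ iterations, each dominated by an $\mathcal{O}(N^2)$ matrix-vector product.
- **Practical Speed:**  
    However, **CG often reaches a good solution in $i \ll N$ iterations**, for example when $K$ is well-conditioned or its eigenvalues are clustered. The total cost is then $\mathcal{O}(iN^2)$, which makes CG much faster for large datasets in practice.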
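Clustered eigenvalues make the effect easy to see. The sketch below builds a low-rank-plus-diagonal matrix (a constructed example, not a kernel matrix from the lecture) with only six distinct eigenvalues. It then runs six CG iterations on a system of size $N = 500$. In exact arithmetic CG terminates after as many steps as there are distinct eigenvalues, so the residual should already be at round-off level. The code assumes NumPy.

```python
import numpy as np

rng = np.random.default_rng(2)
N, rank = 500, 5
U = rng.standard_normal((N, rank))
K = U @ U.T + 0.1 * np.eye(N)   # only rank + 1 = 6 distinct eigenvalues
y = rng.standard_normal(N)

x = np.zeros(N)
r = y.copy()
p = r.copy()
for _ in range(rank + 1):        # 6 CG iterations, far fewer than N = 500
    Kp = K @ p                   # one O(N^2) matrix-vector product per step
    step = (r @ r) / (p @ Kp)
    x = x + step * p
    r_new = r - step * Kp
    p = r_new + (r_new @ r_new) / (r @ r) * p
    r = r_new

print(np.linalg.norm(y - K @ x) / np.linalg.norm(y))  # relative residual near round-off
```

Kernel matrices with rapidly decaying spectra behave in a similar spirit, though their eigenvalues are not exactly clustered, so convergence there is fast but not finite.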

---

# 🤖 Computation and Inference: There Is Really No Difference

This lecture culminates in a profound realization:

> **For Gaussian Processes, "computing" and "learning" are fundamentally the same. Numerical algorithms are, in essence, learning machines.**

---

## 🔢 Least-Squares as Linear Algebra

- In **least-squares regression** (in its regularized kernel form, whose solution gives the GP posterior mean), "training" reduces to solving a linear system $A x = b$.
- The solution is a direct outcome of linear algebraic computation.

---

## 🏗️ Numerical Methods as Data Loaders

- **Cholesky decomposition** and **Conjugate Gradients** can be viewed as **smart data loaders**:
    - They load projections of the data in a specific, efficient order.
    - For example:
        - Choosing $s_i = e_{j(i)}$ (canonical basis vectors with a pivoting policy $j(i)$) yields the **pivoted Cholesky decomposition**.
        - The **Lanczos process**, initialized with $s_0 = K \alpha_0 - y$ (the gradient of the objective), yields the **preconditioned Conjugate Gradient (CG) method**. Here, $C_0$ acts as a preconditioner (an initial guess for $K^{-1}$).
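The pivoted-Cholesky claim can be checked with a small sketch. It runs the IterGP inverse updates with unit-vector actions $s_i = e_{j(i)}$ and a standard greedy partial pivoted Cholesky side by side. The pivoting rule is our assumption: $j(i)$ is the index of the largest entry on the diagonal of $K - K C_{i-1} K$, the current Schur complement. The code uses NumPy, both function names are hypothetical, and a random SPD matrix stands in for $K$. Both procedures should pick the same pivots and give the same low-rank approximation $K C_i K = L L^\top$.

```python
import numpy as np


def itergp_unit_vector_policy(K, num_iter):
    """IterGP inverse updates with s_i = e_{j(i)}, pivoting greedily on the residual diagonal."""
    N = K.shape[0]
    C = np.zeros((N, N))
    pivots = []
    for _ in range(num_iter):
        resid_diag = np.diag(K - K @ C @ K)  # diagonal of the current Schur complement
        j = int(np.argmax(resid_diag))
        pivots.append(j)
        s = np.zeros(N)
        s[j] = 1.0
        z = K @ s                            # = K[:, j]: load one data point
        d = s - C @ z
        eta = z @ d                          # equals resid_diag[j]
        C = C + np.outer(d, d) / eta
    return C, pivots


def pivoted_cholesky(K, rank):
    """Greedy partial pivoted Cholesky: K is approximated by L @ L.T."""
    N = K.shape[0]
    L = np.zeros((N, rank))
    diag = np.diag(K).copy()
    pivots = []
    for m in range(rank):
        j = int(np.argmax(diag))
        pivots.append(j)
        L[:, m] = (K[:, j] - L[:, :m] @ L[j, :m]) / np.sqrt(diag[j])
        diag -= L[:, m] ** 2
    return L, pivots


rng = np.random.default_rng(3)
A = rng.standard_normal((40, 40))
K = A @ A.T / 40 + 0.1 * np.eye(40)

C, piv_itergp = itergp_unit_vector_policy(K, num_iter=8)
L, piv_chol = pivoted_cholesky(K, rank=8)
print(piv_itergp == piv_chol)             # same pivot sequence
print(np.abs(K @ C @ K - L @ L.T).max())  # same low-rank approximation, up to round-off
```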

---

## 🧠 Prior Guesses as Initializers

- Think of $\alpha_0$ and $C_0$ as **prior guesses** for the solution $\alpha$ and the inverse $K^{-1}$.
- The iterative algorithms then **refine these estimates** step by step.

---

## 📚 Bayesian Interpretations of Iterative Solvers

- There are **deep Bayesian interpretations** of these iterative solvers:
    - The point estimates for $\alpha$ and $K^{-1}$ can be seen as **posterior means**.
    - The associated uncertainty can also be quantified (see Hennig, 2015; Wenger & Hennig, 2021; Hennig, Osborne, Kersting, 2022).
- This further **blurs the line** between numerical optimization and Bayesian inference.

---

## 📝 Summary

- The **choice of numerical algorithm** is not just an implementation detail—it fundamentally shapes how the model "learns" from data and quantifies uncertainty.
- **Computation = Inference:**  
    In GPs, the act of computation is inseparable from the act of learning.

---

### 📖 **References for Further Reading**

- Hennig, P. (2015). Probabilistic Interpretation of Linear Solvers. *SIAM Journal on Optimization*.
- Wenger, J., & Hennig, P. (2021). Probabilistic Linear Solvers for Machine Learning. *Proceedings of the IEEE*.
- Hennig, P., Osborne, M. A., & Kersting, H. (2022). Probabilistic Numerics: Computation as Machine Learning. *Cambridge University Press*.

---

# IterGP (Final Algorithm with Initial Guesses)

This is the most general form of the IterGP algorithm, incorporating initial guesses for the inverse $C_0$ and solution $\alpha_0$.

## Algorithm 8: Iterative GP Regression (Numerics Layer)

**Input:** Sufficient statistics $K = k_{XX} + \sigma^2 I$, $\bar{y} = y - \mu_X$, initial guesses $\alpha_0$, $C_0$

**Output:** Defragmented statistics $S$ (set of processed vectors), $C_n \approx K^{-1}$, $\alpha_n \approx K^{-1}\bar{y}$



### IterGP Training Procedure: Mathematical Formulation and Step-by-Step Explanation

The IterGP algorithm iteratively refines estimates of the inverse covariance and solution vector in Gaussian Process regression. Below, we present the procedure as a sequence of mathematical updates, with clear explanations for each step.

---

#### **Initialization**

- **Inverse Estimate:**  
    $$
    C_0 \leftarrow \text{initial guess (often } 0 \text{)}
    $$
- **Solution Estimate:**  
    $$
    \alpha_0 \leftarrow \text{initial guess (often } 0 \text{)}
    $$

---

#### **Iterative Updates (for $i = 1$ to $n$):**

1. **Select Projection Vector(s) (Action/Policy):**
     $$
     s_i \leftarrow \text{POLICY}(S_{<i}, Z_{<i})
     $$
     - $s_i \in \mathbb{R}^{N \times k_i}$  
     - *Choose the next direction(s) to load, possibly using previous projections and observations.*

2. **Compute Projected Kernel Column(s) (Observation):**
     $$
     z_i = K s_i
     $$
     - $z_i \in \mathbb{R}^{N \times k_i}$  
     - *Apply the kernel matrix to the projection(s).*

3. **Compute Low-Rank Update Direction:**
     $$
     d_i = (I - C_{i-1} K) s_i = s_i - C_{i-1} z_i
     $$
     - $d_i \in \mathbb{R}^{N \times k_i}$  
     - *Find the component of $s_i$ not yet explained by previous updates.*

4. **Compute Schur Complement (Normalization Matrix):**
     $$
     H_i = s_i^\top K d_i = z_i^\top d_i
     $$
     - $H_i \in \mathbb{R}^{k_i \times k_i}$  
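     - *For $C_0 = 0$, $H_i$ is the Schur complement of the already-processed block $S_{<i}^\top K S_{<i}$ in the projected Gram matrix of $[S_{<i}, s_i]$. Equivalently, it is the size of the genuinely new part $d_i$ measured in the $K$-inner product ($H_i = d_i^\top K d_i$). Its inverse sets the scale of the rank-$k_i$ updates below.*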

5. **Update Inverse Estimate:**
     $$
     C_i = C_{i-1} + d_i H_i^{-1} d_i^\top
     $$
     - *Rank-$k_i$ update to the inverse estimate.*

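6. **Update Solution Estimate:**
     $$
     \alpha_i = \alpha_{i-1} + d_i H_i^{-1} s_i^\top \left(\bar{y} - K \alpha_{i-1}\right)
     $$
     - *Refine the solution vector by projecting the current residual $\bar{y} - K \alpha_{i-1}$ onto the new direction(s). When $\alpha_{i-1} = C_{i-1} \bar{y}$ (e.g., $\alpha_0 = 0$ and $C_0 = 0$), this simplifies to $\alpha_i = \alpha_{i-1} + d_i H_i^{-1} d_i^\top \bar{y}$.*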
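Step 4's Schur-complement interpretation can be verified numerically for a single new direction ($k_i = 1$). The sketch assumes $C_0 = 0$, so that $C_{i-1} = S_{<i}(S_{<i}^\top K S_{<i})^{-1} S_{<i}^\top$. It uses NumPy, with random SPD and random directions standing in for real data. It computes $H_i$ via steps 2 to 4 and compares it with the Schur complement of the Gram matrix of $[S_{<i}, s_i]$. It also checks that step 5 reproduces $C_i$ for the enlarged set of directions.

```python
import numpy as np

rng = np.random.default_rng(4)
N = 8
A = rng.standard_normal((N, N))
K = A @ A.T + N * np.eye(N)           # SPD stand-in for K
S_prev = rng.standard_normal((N, 3))  # directions loaded so far, S_{<i}
s = rng.standard_normal(N)            # new direction s_i (k_i = 1)

# With C_0 = 0: C_{i-1} = S (S^T K S)^{-1} S^T, the inverse of K on span(S_prev)
C_prev = S_prev @ np.linalg.solve(S_prev.T @ K @ S_prev, S_prev.T)

z = K @ s              # step 2
d = s - C_prev @ z     # step 3
eta = z @ d            # step 4: H_i (a scalar here)

# Schur complement of S_prev^T K S_prev in the Gram matrix of [S_prev, s]
G = S_prev.T @ K @ S_prev
b = S_prev.T @ K @ s
schur = s @ K @ s - b @ np.linalg.solve(G, b)
print(eta, schur)      # the two numbers agree up to round-off

# Step 5 reproduces the same formula for the enlarged set of directions
C_new = C_prev + np.outer(d, d) / eta
S_full = np.column_stack([S_prev, s])
C_ref = S_full @ np.linalg.solve(S_full.T @ K @ S_full, S_full.T)
print(np.abs(C_new - C_ref).max())
```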

---

#### **Return**

- **All Projections:** $S = [s_j]_{j \leq n}$
- **Final Solution Estimate:** $\alpha_n$
- **Final Inverse Estimate:** $C_n$
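Putting the six steps together gives a compact sketch of the training loop. This is a NumPy illustration of our own, not the lecture's reference implementation. The `policy(i, alpha, C)` callable is a hypothetical interface standing in for $\text{POLICY}(S_{<i}, Z_{<i})$.

Step 6 uses the residual form $\alpha_i = \alpha_{i-1} + d_i H_i^{-1} s_i^\top(\bar{y} - K\alpha_{i-1})$. It agrees with $d_i H_i^{-1} d_i^\top \bar{y}$ whenever $\alpha_{i-1} = C_{i-1}\bar{y}$ (e.g. $\alpha_0 = 0$, $C_0 = 0$), but unlike it also handles a non-zero $\alpha_0$.

```python
import numpy as np


def itergp_train(K, y_bar, policy, num_iter, alpha0=None, C0=None):
    """NumPy sketch of the IterGP training loop (Algorithm 8).

    `policy(i, alpha, C)` is a user-supplied callable (hypothetical interface)
    returning the new direction(s) s_i as an (N,) or (N, k_i) array.
    """
    N = K.shape[0]
    alpha = np.zeros(N) if alpha0 is None else np.array(alpha0, dtype=float)
    C = np.zeros((N, N)) if C0 is None else np.array(C0, dtype=float)
    loaded = []
    for i in range(num_iter):
        s = np.asarray(policy(i, alpha, C), dtype=float).reshape(N, -1)  # step 1
        z = K @ s                                     # step 2: observation
        d = s - C @ z                                 # step 3: new part of s
        H = z.T @ d                                   # step 4: k_i x k_i Schur complement
        C = C + d @ np.linalg.solve(H, d.T)           # step 5: rank-k_i inverse update
        residual = y_bar - K @ alpha
        alpha = alpha + d @ np.linalg.solve(H, s.T @ residual)  # step 6
        loaded.append(s)
    return np.hstack(loaded), alpha, C


rng = np.random.default_rng(5)
N = 12
A = rng.standard_normal((N, N))
K = A @ A.T / N + 0.5 * np.eye(N)      # SPD stand-in for K_XX + sigma^2 I
y_bar = rng.standard_normal(N)
alpha0 = rng.standard_normal(N)        # deliberately non-zero initial guess


def one_data_point(i, alpha, C):
    return np.eye(N)[:, i]             # s_i = e_i: load data point i


def batch_of_three(i, alpha, C):
    return np.eye(N)[:, 3 * i : 3 * i + 3]  # k_i = 3 data points per step


S1, alpha_a, C_a = itergp_train(K, y_bar, one_data_point, N, alpha0=alpha0)
S2, alpha_b, C_b = itergp_train(K, y_bar, batch_of_three, N // 3, alpha0=alpha0)
alpha_star = np.linalg.solve(K, y_bar)
print(np.allclose(alpha_a, alpha_star), np.allclose(alpha_b, alpha_star))  # True True
```

Once all $N$ data points have been loaded, $C_N = K^{-1}$ and $\alpha_N = K^{-1}\bar{y}$ regardless of $\alpha_0$ (with $C_0 = 0$). With the $d_i^\top \bar{y}$ form and this non-zero $\alpha_0$, the loop would instead end at $\alpha_0 + K^{-1}\bar{y}$.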

---

#### **Step-by-Step Intuition**

- **Step 1:** *Choose the next direction(s) to load information from the data. This can be a single data point, a batch, or a projection (e.g., from the Lanczos process).*
- **Step 2:** *Compute how the kernel "sees" this direction—i.e., the corresponding column(s) of the kernel matrix.*
- **Step 3:** *Determine what part of this direction is new (not already explained by previous steps).*
- **Step 4:** *Compute the Schur complement $H_i$, which measures how much genuinely new information the direction carries and scales the update accordingly.*
- **Step 5:** *Update the running estimate of the inverse covariance matrix using a low-rank correction.*
- **Step 6:** *Update the solution vector (posterior mean weights) using the new information.*

---

This formulation provides a flexible, modular, and scalable framework for iterative Gaussian Process regression, enabling both mean and uncertainty estimation at every step.



**Explanation of the prediction step (PREDICT):**

- **Inputs:**
    - $x$: Test point(s) where we want to make predictions.
    - $S$: Set of processed training vectors (could be data points or projections).
    - $\alpha$: Solution vector (posterior weights), typically $\alpha \approx K^{-1}(y - \mu_X)$.
    - $C$: Inverse covariance estimate, typically $C \approx K^{-1}$.

- **Steps:**
    1. **Covariance to Observations:**  
        $k[x, S]$ computes the similarity between $x$ and each vector in $S$ using the kernel function.
    2. **Predictive Mean:**  
        $\mu_x = \mu(x) + k[x, S]\, \alpha$  
        This gives the expected value of the function at $x$ given the observed data.
    3. **Predictive Variance:**  
        $v_{xx} = k_{xx} - k[x, S]\, C\, k[S, x]$  
        This quantifies the model's uncertainty at $x$, accounting for both prior uncertainty and information gained from the data.
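The variance formula interpolates between prior and posterior as more of the data is loaded. The sketch below evaluates $v_{xx}$ (diagonal only) for three choices of $C$:

- $C = 0$: nothing loaded.
- The zero-padded inverse of $K$ on the first five training points, which is what IterGP yields after loading those points with unit-vector actions.
- $C = K^{-1}$: everything loaded.

The RBF kernel with unit output scale, the grid of inputs and the noise level $0.1^2$ are all illustrative assumptions, written in NumPy.

```python
import numpy as np


def rbf(x1, x2, lengthscale=1.0):
    """Squared-exponential kernel on 1-D inputs (unit output scale)."""
    return np.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / lengthscale**2)


X = np.linspace(-3.0, 3.0, 15)       # training inputs
x = np.linspace(-4.0, 4.0, 50)       # test inputs
K = rbf(X, X) + 0.1**2 * np.eye(15)  # K = k_XX + sigma^2 I
k_xS = rbf(x, X)                     # k[x, S] with S = all training points


def predictive_var(C):
    """Diagonal of v_xx = k_xx - k[x,S] C k[S,x] (k_xx = 1 for this kernel)."""
    return 1.0 - np.einsum("ij,jk,ik->i", k_xS, C, k_xS)


def inverse_on_subset(idx):
    """Zero-padded K_PP^{-1}: IterGP's C after loading the points idx."""
    C = np.zeros_like(K)
    C[np.ix_(idx, idx)] = np.linalg.inv(K[np.ix_(idx, idx)])
    return C


var_prior = predictive_var(np.zeros_like(K))  # C = 0: prior variance
var_partial = predictive_var(inverse_on_subset(np.arange(5)))
var_exact = predictive_var(np.linalg.inv(K))  # C = K^{-1}: exact posterior
print(var_prior.max(), var_partial.mean(), var_exact.mean())
```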

---

This algorithm forms the basis for highly scalable Gaussian Process (GP) methods, where the **POLICY** function can be designed to efficiently explore the data and approximate the necessary quantities. By updating $S$, $\alpha$, and $C$ iteratively, we can make fast, uncertainty-aware predictions at any test location.

# Relationship to Gradient Descent

Let's explore the connection and key differences between **iterative Cholesky/Conjugate Gradient (CG) methods** and **Gradient Descent (GD)** in the context of Gaussian Process (GP) regression.

---

## Problem Setup: GP Posterior Mean as a Quadratic Minimization

Recall that the solution $\alpha^\star$ for the GP posterior mean is the minimizer of the following quadratic objective:

$$
\alpha^\star = \arg\min_{\alpha \in \mathbb{R}^N} L(\alpha) = \arg\min_{\alpha \in \mathbb{R}^N} \frac{1}{2} \alpha^\top (K_{XX} + \sigma^2 I_N) \alpha - (y - \mu_X)^\top \alpha
$$

- $K_{XX}$: Kernel matrix for the training data
- $\sigma^2 I_N$: Noise term (diagonal matrix)
- $y$: Observed targets
- $\mu_X$: Mean function evaluated at training points

The **analytical solution** is:

$$
\alpha^\star = (K_{XX} + \sigma^2 I_N)^{-1} (y - \mu_X)
$$

The **Jacobian** of this solution with respect to $y$ is:

$$
\frac{d\alpha^\star}{dy} = (K_{XX} + \sigma^2 I_N)^{-1}
$$

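This Jacobian is also the inverse of the Hessian $\nabla_\alpha^2 L = K_{XX} + \sigma^2 I_N$. It is exactly the matrix that enters the GP **posterior covariance** $k_{xx} - k_{xX}(K_{XX} + \sigma^2 I_N)^{-1} k_{Xx}$.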
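A finite-difference check confirms the Jacobian formula on a small random problem. The sketch assumes NumPy, with a random SPD matrix standing in for $K_{XX} + \sigma^2 I_N$ and a random vector for $y - \mu_X$. Because $\mu_X$ does not depend on $y$, differentiating with respect to $\bar{y}$ gives the same Jacobian. Central differences are exact up to round-off here, since $\alpha^\star$ is linear in $y$.

```python
import numpy as np

rng = np.random.default_rng(6)
N = 6
A = rng.standard_normal((N, N))
K_noisy = A @ A.T + N * np.eye(N)  # stands in for K_XX + sigma^2 I_N
y_bar = rng.standard_normal(N)     # stands in for y - mu_X


def alpha_star(y):
    """Exact solution of the linear system for a given right-hand side."""
    return np.linalg.solve(K_noisy, y)


eps = 1e-6
jac_fd = np.column_stack(
    [(alpha_star(y_bar + eps * e) - alpha_star(y_bar - eps * e)) / (2 * eps) for e in np.eye(N)]
)
print(np.abs(jac_fd - np.linalg.inv(K_noisy)).max())  # round-off level
```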

---

## Gradient Descent (GD) Update Rule

Gradient Descent iteratively updates the parameter vector $\alpha$ using the gradient of the loss:

$$
\alpha_{t+1} = \alpha_t - \eta \nabla_\alpha L(\alpha_t)
$$

where the gradient is:

$$
\nabla_\alpha L(\alpha) = (K_{XX} + \sigma^2 I_N)\alpha - (y - \mu_X)
$$

- $\eta$: Learning rate (step size)
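To compare the two iterations on an equal footing, the sketch below runs $N$ steps of gradient descent and $N$ steps of CG on the same quadratic. Each method spends one matrix-vector product with $K$ per step. The following are assumptions for illustration:

- A random SPD matrix stands in for $K_{XX} + \sigma^2 I_N$, with condition number roughly 10.
- The step size is $\eta = 1/\lambda_{\max}$.
- The code uses NumPy.

```python
import numpy as np

rng = np.random.default_rng(7)
N = 50
A = rng.standard_normal((N, N))
K = A @ A.T / N + 0.5 * np.eye(N)  # SPD stand-in for K_XX + sigma^2 I_N
y_bar = rng.standard_normal(N)
alpha_true = np.linalg.solve(K, y_bar)

# Gradient descent with step size eta = 1 / lambda_max(K)
eta = 1.0 / np.linalg.eigvalsh(K).max()
alpha_gd = np.zeros(N)
for _ in range(N):
    grad = K @ alpha_gd - y_bar    # one O(N^2) matrix-vector product
    alpha_gd -= eta * grad

# Conjugate gradients for the same number of steps
alpha_cg, r = np.zeros(N), y_bar.copy()
p = r.copy()
for _ in range(N):
    Kp = K @ p                     # also one O(N^2) matrix-vector product
    step = (r @ r) / (p @ Kp)
    alpha_cg += step * p
    r_new = r - step * Kp
    p = r_new + (r_new @ r_new) / (r @ r) * p
    r = r_new

print("GD error:", np.linalg.norm(alpha_gd - alpha_true))
print("CG error:", np.linalg.norm(alpha_cg - alpha_true))
```

GD's error shrinks by at most a factor $(1 - \lambda_{\min}/\lambda_{\max})$ per step along the slowest direction. CG, by contrast, is essentially exact after $N$ steps.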

---

## Key Differences: Cholesky/CG vs. Gradient Descent

### 1. **Nature of the Update**

- **Cholesky / Conjugate Gradients (CG):**
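    - Iteratively compute the **total derivative** $d\alpha^\star/dy = (K_{XX} + \sigma^2 I_N)^{-1}$, i.e., the matrix inverse.
    - Aim to find the **exact solution** to the linear system in a finite number of steps (in exact arithmetic: at most $N$ for CG, exactly $N$ for Cholesky).
    - Can maintain an estimate $C_i$ of the inverse alongside the solution estimate $\alpha_i$ (as IterGP does). After the final step these equal $(K_{XX} + \sigma^2 I_N)^{-1}$ and $\alpha^\star$, up to numerical precision.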
    - Directly enable **uncertainty quantification** (posterior covariance).

- **Gradient Descent (GD):**
    - Computes a **partial derivative** (the gradient) at each step and moves in that direction.
    - Typically does **not converge in finite time**; requires many iterations.
    - Only provides a **point estimate** (the mean), without direct access to uncertainty (the inverse covariance).

---

### 2. **Convergence and Computational Cost**

- **Cholesky/CG:**
    - **Converge in finite time** in exact arithmetic (at most $N$ steps for CG, exactly $N$ for Cholesky).
    - Each step costs $\mathcal{O}(N^2)$ for a dense matrix: a matrix-vector product for CG, a new column of the factor for Cholesky. Tracking the inverse estimate $C_i$ explicitly adds further bookkeeping. The total number of steps is bounded by $N$.
    - Suitable for problems where **exact solutions and uncertainty estimates** are required.

- **Gradient Descent:**
    - **Requires many iterations** to converge, especially if the problem is ill-conditioned.
    - Each step also needs one matrix-vector product with $K_{XX} + \sigma^2 I_N$ to evaluate the gradient, i.e. $\mathcal{O}(N^2)$ for a dense matrix; the remaining vector operations are $\mathcal{O}(N)$.
    - Does **not provide uncertainty**; only gives a point estimate.

---

### 3. **Uncertainty Quantification**

- **Cholesky/CG:**  
    - By constructing the inverse (or an approximation), these methods provide **direct access to the posterior covariance**—a key feature of GPs.

- **Gradient Descent:**  
    - Only tracks the mean; **uncertainty information is lost** unless additional (often expensive) computations are performed.

---

## Implications and Trade-offs

- **Cholesky/CG** methods are powerful for **exact inference** and **uncertainty quantification** in GPs. Beyond the $\mathcal{O}(N^2)$ work per step that they share with GD, their extra per-iteration cost comes from the bookkeeping for the inverse estimate $C_i$.
- **Gradient Descent** is scalable and simple, but sacrifices **finite-step exactness** and **uncertainty** in exchange for less bookkeeping and memory.

> **Summary Table**

| Method             | Convergence (exact arithmetic) | Per-Step Cost (dense $K$) | Uncertainty | Exact Solution in Finite Steps | Typical Use Case               |
|--------------------|--------------------------------|---------------------------|-------------|--------------------------------|--------------------------------|
| Cholesky           | Finite ($N$ steps)             | $\mathcal{O}(N^2)$        | Yes         | Yes                            | Small/medium GPs, exact stats  |
| Conjugate Gradient | Finite ($\leq N$ steps)        | $\mathcal{O}(N^2)$        | Yes, if $C_i$ is tracked (as in IterGP) | Yes                  | Large GPs, iterative inference |
| Gradient Descent   | Asymptotic only                | $\mathcal{O}(N^2)$        | No          | No                             | Large-scale, point estimation  |

---

## Visual Summary

- **Cholesky/CG:** exact solution in finite steps, with uncertainty.
- **Gradient Descent:** gradual convergence, no uncertainty.

---

## Takeaway

The choice between these methods reflects a **trade-off between computational efficiency, convergence guarantees, and the richness of information** (point estimate vs. full posterior) provided by the algorithm. For Gaussian Processes, **iterative linear solvers like CG bridge the gap**—offering both scalability and uncertainty quantification, which are central to probabilistic machine learning.

# 🌟 Final Takeaways: Computation and Learning in Gaussian Processes

This lecture has illuminated the profound connections between **computation** and **inference** in probabilistic machine learning, with a special focus on **Gaussian Processes (GPs)**. Here are the key insights:

---

## 🧮 "Training" is Linear Algebra

- In **least-squares regression** (and thus in GP posterior mean estimation), "training" is fundamentally a **linear algebra problem**.
- Specifically, it reduces to **solving a linear system**:
    $$
    \alpha^\star = (K_{XX} + \sigma^2 I)^{-1}(y - \mu_X)
    $$
- This means that the process of learning from data is, at its core, a computational task rooted in matrix operations.

---

## 📦 Numerical Methods as Data Loaders

- **Numerical algorithms** such as **Cholesky decomposition** and **Conjugate Gradients (CG)** can be viewed as **smart data loaders**.
- These methods **iteratively load projections of the data** (e.g., columns of the kernel matrix or linear combinations thereof) in a specific, efficient order.
- Each iteration refines our estimates of both the **solution** (the mean) and the **uncertainty** (the covariance), making the process a form of learning.

---

## 🔄 No Separation of Computing and Learning

- For Gaussian Processes, there is **no fundamental separation** between "computing" and "learning."
- **Numerical algorithms are learning machines**:  
    - They process information from the data step by step.
    - They iteratively refine both the solution and our uncertainty about it.
- This perspective unifies the concepts of computation and inference, showing that **learning is computation, and computation is learning**.

---

## 🚀 Why Does This Matter?

- **Scalability:**  
        Understanding this connection is crucial for scaling probabilistic models to large datasets.
- **Algorithm Design:**  
        It inspires the development of new, efficient algorithms that bridge the gap between **exact Bayesian inference** and the **practical demands of big data**.
- **Uncertainty Quantification:**  
        Unlike many optimization methods, these algorithms provide not just point estimates, but also **uncertainty quantification**—a hallmark of probabilistic machine learning.

---

> **In summary:**  
> The act of "training" a Gaussian Process is not just about fitting a model—it's about performing computation in a way that is inherently probabilistic and iterative. By embracing this perspective, we can design algorithms that are both efficient and principled, bringing together the best of computation and learning.

---