In [2]:
%load_ext autoreload
%autoreload 2



# Understanding Gaussian Processes: Kernels and Their Properties

Welcome back to the blog! In our last post, we introduced **Gaussian Processes** as a powerful non-parametric model for regression, allowing us to model a distribution over functions. We saw how GPs are defined by a **mean function** and a **covariance function** (also known as a **kernel**).

In this post, we'll dive deeper into the fascinating world of **kernels**. We'll explore different types of kernels, understand how they shape the functions a GP can represent, and learn how new kernels can be constructed from existing ones.

## Recap: Lazy Evaluation and Kernels

Let's quickly recap the connection we made in the previous lecture between probabilistic parametric regression and Gaussian Processes.

---

### From Linear Models to Gaussian Processes

We started with a linear model with features $\phi(x)$ and weights $w$ with a Gaussian prior:

$$
f(x) = \phi(x)^\top w
$$

$$
p(w) = \mathcal{N}(w; \mu, \Sigma)
$$

We found that the marginal distribution over the function $f(\cdot)$ (before observing any data) is also a Gaussian Process:

$$
p(f(\cdot)) = \int p(f(\cdot) \mid w) p(w) \, dw
$$

Following the slides, if we also let the function values scatter independently around $\phi(\cdot)^\top w$, i.e. $p(f(\cdot) \mid w) = \mathcal{N}(f(\cdot); \phi(\cdot)^\top w, \sigma^2 I)$, this integral results in:

$$
p(f(\cdot)) = \mathcal{N}(f(\cdot); \phi(\cdot)^\top \mu, \phi(\cdot)^\top \Sigma \phi(\circ) + \sigma^2 I)
$$

> **Note:** The slide writes this term as $\sigma I$ in the first line; it is best read as $\sigma^2 I$, i.e. independent noise on the function values. In the standard derivation, $f(x) = \phi(x)^\top w$ is deterministic given $w$, so the prior over noise-free function values has covariance $\phi(\cdot)^\top \Sigma \phi(\circ)$, and observation noise is added only when we condition on data. We stick to that convention: the kernel describes the covariance of the noise-free function values.

---

### GP Mean and Kernel Functions

So, the marginal distribution of the function $f(\cdot)$ is a GP with:

- **Mean function:** $m(\cdot) := \phi(\cdot)^\top \mu$
- **Covariance function (kernel):** $k(\cdot, \circ) := \phi(\cdot)^\top \Sigma \phi(\circ)$

This shows how the mean and kernel functions of a GP naturally emerge from a probabilistic parametric model with a Gaussian prior on the weights. The "lazy evaluation" refers to the fact that we define the distribution over functions directly through the mean and kernel, rather than explicitly working with the (potentially infinite) feature vectors and weights.
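
To make the correspondence concrete, here is a minimal sketch assuming a hypothetical three-feature polynomial basis $\phi(x) = [1, x, x^2]$, $\mu = 0$ and $\Sigma = I$ (illustrative choices, not from the slides). It builds the kernel $\phi(x)^\top \Sigma \phi(x')$ directly and compares it with the empirical covariance of functions $f = \phi^\top w$ obtained by sampling the weights.

```python
import jax
import jax.numpy as jnp


def poly_features(x: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical feature map phi(x) = [1, x, x^2] for scalar inputs. Shape (N, 3)."""
    x = jnp.ravel(x)
    return jnp.stack([jnp.ones_like(x), x, x**2], axis=-1)


def feature_kernel(x1, x2, Sigma):
    """k(x, x') = phi(x)^T Sigma phi(x') for all pairs. Shape (N1, N2)."""
    return poly_features(x1) @ Sigma @ poly_features(x2).T


X = jnp.linspace(-1.0, 1.0, 5)
Sigma = jnp.eye(3)
K_analytic = feature_kernel(X, X, Sigma)

# Monte Carlo: sample w ~ N(0, I) and evaluate f_X = Phi_X w for each draw
num_draws = 200_000
W = jax.random.normal(jax.random.PRNGKey(0), (num_draws, 3))
F = W @ poly_features(X).T          # each row is one sampled function evaluated at X
K_empirical = F.T @ F / num_draws   # zero mean, so E[f f^T] is the covariance

print(jnp.max(jnp.abs(K_analytic - K_empirical)))  # small Monte Carlo error
```

The kernel route never needs the weights; it only needs inner products of features, which is what makes infinite feature sets usable.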

---

### What is a Kernel?

We also revisited the definition of a kernel:  
A function $k: \mathcal{X} \times \mathcal{X} \to \mathbb{R}$ is a **Mercer (or positive definite) kernel** if for any finite set of points $X = x_1, \ldots, x_N$, the matrix $K_{XX}$ with $[K_{XX}]_{ij} = k(x_i, x_j)$ is symmetric and positive semidefinite ($v^\top K_{XX} v \geq 0$ for all $v \in \mathbb{R}^N$).
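
The definition can be probed numerically for a particular set of points. The sketch below (the helper name `is_psd_gram` is hypothetical) builds a Gram matrix and checks symmetry and non-negative eigenvalues up to a floating-point tolerance relative to the largest eigenvalue. Such a check on one point set can only refute positive semidefiniteness, never prove it.

```python
import jax.numpy as jnp


def se_kernel(x1, x2, lengthscale=1.0):
    """Squared-exponential kernel for scalar inputs, evaluated for all pairs."""
    d = jnp.ravel(x1)[:, None] - jnp.ravel(x2)[None, :]
    return jnp.exp(-0.5 * d**2 / lengthscale**2)


def is_psd_gram(kernel, X, rel_tol=1e-5):
    """Return (passes, min_eig): symmetric and no eigenvalue below -rel_tol * max|eig|."""
    K = kernel(X, X)
    symmetric = bool(jnp.allclose(K, K.T))
    eigs = jnp.linalg.eigvalsh(K)
    min_eig = float(jnp.min(eigs))
    scale = float(jnp.max(jnp.abs(eigs)))
    return symmetric and min_eig >= -rel_tol * scale, min_eig


X = jnp.linspace(-3.0, 3.0, 50)
print(is_psd_gram(se_kernel, X))

# Not a kernel: -|x - x'| gives a nonzero symmetric matrix with zero trace,
# so it must have a negative eigenvalue
bad_kernel = lambda x1, x2: -jnp.abs(jnp.ravel(x1)[:, None] - jnp.ravel(x2)[None, :])
print(is_psd_gram(bad_kernel, X))
```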

## Kernels and Feature Expansions Revisited

The connection between **kernels** and **feature expansions** is fundamental. As stated by **Mercer's Theorem**, a function is a valid kernel if and only if it can be written as an inner product in some (possibly infinite-dimensional) feature space:

$$
k(x, x') = \langle \phi(x), \phi(x') \rangle_{\mathcal{H}}
$$

or, more generally, with a positive measure $\nu$:

$$
k(x, x') = \int_{L} \phi_{\ell}(x)\, \phi_{\ell}(x')\, d\nu(\ell)
$$

This means that **every kernel implicitly defines a feature space**. When we use a kernel, we are essentially working with a model that has a potentially infinite number of features, but the **kernel trick** allows us to compute the necessary covariances without explicitly handling these features.
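
A small finite example of the kernel trick (our own illustration, not from the slides): for scalar inputs, the polynomial kernel $k(x, x') = (1 + x x')^2$ expands to $1 + 2 x x' + x^2 x'^2$, which is the inner product of the explicit features $\phi(x) = [1, \sqrt{2} x, x^2]$. The sketch checks that both routes give the same Gram matrix; the kernel route never forms the features.

```python
import jax.numpy as jnp


def poly2_kernel(x1, x2):
    """(1 + x x')^2 evaluated directly, without features."""
    return (1.0 + jnp.ravel(x1)[:, None] * jnp.ravel(x2)[None, :]) ** 2


def poly2_features(x):
    """Explicit feature map phi(x) = [1, sqrt(2) x, x^2] matching poly2_kernel."""
    x = jnp.ravel(x)
    return jnp.stack([jnp.ones_like(x), jnp.sqrt(2.0) * x, x**2], axis=-1)


X = jnp.array([-1.0, 0.0, 0.5, 2.0])
K_kernel = poly2_kernel(X, X)
K_features = poly2_features(X) @ poly2_features(X).T
print(jnp.allclose(K_kernel, K_features))  # True
```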

## Gaussian Process Definition Revisited

Let's restate the definition of a **Gaussian Process**, emphasizing its nature as a distribution over functions:

A function $f: \mathcal{X} \to \mathbb{R}$ is drawn from a Gaussian Process with mean function $m: \mathcal{X} \to \mathbb{R}$ and Mercer kernel $k: \mathcal{X} \times \mathcal{X} \to \mathbb{R}$, denoted as

$$
f \sim \mathcal{GP}(m, k),
$$

if for any finite set of input points $X = x_1, \ldots, x_N$, the vector of function values

$$
f_X = [f(x_1), \ldots, f(x_N)]^\top
$$

follows a multivariate Gaussian distribution:

$$
p(f_X) = \mathcal{N}(f_X; m_X, K_{XX})
$$

where

- $m_X = [m(x_1), \ldots, m(x_N)]^\top$
- $K_{XX}$ is the kernel matrix with $[K_{XX}]_{ij} = k(x_i, x_j)$

---

The Python `GaussianProcess` class we defined previously captures this idea by providing the mean and kernel functions.

In [3]:
import jax.numpy as jnp
from typing import Callable
import dataclasses


@dataclasses.dataclass
class GaussianProcess:
    """
    Conceptual representation of a Gaussian Process.

    A GP is defined by its mean function and covariance function (kernel).
    Any finite collection of function values from a GP follows a multivariate Gaussian distribution.
    """

    # Mean function: maps input(s) to mean value(s)
    m: Callable[[jnp.ndarray], jnp.ndarray]

    # Covariance function (kernel): maps two input(s) to a covariance matrix
    k: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]

    def __call__(self, X: jnp.ndarray):
        """
        Evaluates the GP at a set of input points X.
        Returns the mean vector and covariance matrix for the function values at X.

        Args:
            X: A JAX array of input points. Shape (N, D) where N is the number of points
               and D is the input dimension.

        Returns:
            A dictionary containing:
                'mu': Mean vector at X. Shape (N,).
                'Sigma': Covariance matrix at X. Shape (N, N).
        """
        # Ensure X is at least 2D
        X = jnp.atleast_2d(X)
        # Evaluate the mean function at all input points
        mu_X = self.m(X)

        # Evaluate the covariance function (kernel) for all pairs of input points
        K_XX = self.k(X, X)

        return {"mu": mu_X, "Sigma": K_XX}


# Conditioning Gaussian Processes (The Posterior GP)

One of the most powerful aspects of **Gaussian Processes** is that they are **closed under conditioning**. This means that if we start with a GP prior over functions and observe some data $(X_{\text{train}}, y_{\text{train}})$, the resulting posterior distribution over functions is also a Gaussian Process. This is incredibly convenient for analytical inference.

---

## Joint Distribution

Recall the joint distribution of training outputs $y_{\text{train}}$ and test function values $f_{\text{test}}$ from the previous lecture:

$$
\begin{pmatrix}
y_{\text{train}} \\
f_{\text{test}}
\end{pmatrix}
\sim \mathcal{N}\left(
\begin{pmatrix}
m(X_{\text{train}}) \\
m(X_{\text{test}})
\end{pmatrix},
\begin{pmatrix}
K_{\text{train,train}} + \sigma_{\text{noise}}^2 I & K_{\text{train,test}} \\
K_{\text{test,train}} & K_{\text{test,test}}
\end{pmatrix}
\right)
$$

---

## Posterior GP

The posterior distribution over the function $f$ given the training data is a new Gaussian Process, $\mathcal{GP}(m_{\text{post}}, k_{\text{post}})$, with a posterior mean function $m_{\text{post}}(x)$ and a posterior covariance function $k_{\text{post}}(x, x')$.

The formulas for the predictive mean and covariance for a finite set of test points $X_{\text{test}}$ can be generalized to define these posterior functions for any input point $x$:

### Posterior Mean Function

$$
m_{\text{post}}(x) = m(x) + k(x, X_{\text{train}}) \left[ K_{\text{train,train}} + \sigma_{\text{noise}}^2 I \right]^{-1} \left( y_{\text{train}} - m(X_{\text{train}}) \right)
$$

### Posterior Covariance Function

$$
k_{\text{post}}(x, x') = k(x, x') - k(x, X_{\text{train}}) \left[ K_{\text{train,train}} + \sigma_{\text{noise}}^2 I \right]^{-1} k(X_{\text{train}}, x')
$$

Here, $k(x, X_{\text{train}})$ is a row vector $[k(x, x_{\text{train},1}), \ldots, k(x, x_{\text{train},N_{\text{train}}})]$, and $k(X_{\text{train}}, x')$ is a column vector $[k(x_{\text{train},1}, x'), \ldots, k(x_{\text{train},N_{\text{train}}}, x')]^\top$.
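
A minimal sketch of these formulas, assuming a zero prior mean, a squared-exponential kernel, and made-up training data (all hypothetical choices for illustration). It factorizes $K_{\text{train,train}} + \sigma_{\text{noise}}^2 I$ once with a Cholesky decomposition and reuses the factor for both the mean and the covariance, rather than forming the inverse explicitly.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def se_kernel(x1, x2, lengthscale=1.0):
    d = jnp.ravel(x1)[:, None] - jnp.ravel(x2)[None, :]
    return jnp.exp(-0.5 * d**2 / lengthscale**2)


def gp_posterior(X_train, y_train, X_test, kernel, noise_var):
    """Posterior mean vector and covariance matrix at X_test for a zero-mean GP prior."""
    K_tt = kernel(X_train, X_train) + noise_var * jnp.eye(X_train.shape[0])
    factor = cho_factor(K_tt, lower=True)
    K_st = kernel(X_test, X_train)          # rows are k(x, X_train) for each test x
    alpha = cho_solve(factor, y_train)      # (K + sigma^2 I)^{-1} (y - m_X) with m = 0
    mean = K_st @ alpha
    cov = kernel(X_test, X_test) - K_st @ cho_solve(factor, K_st.T)
    return mean, cov


X_train = jnp.array([-2.0, 0.0, 1.5])
y_train = jnp.array([0.5, -0.3, 1.0])
X_test = jnp.linspace(-3.0, 3.0, 7)
mean, cov = gp_posterior(X_train, y_train, X_test, se_kernel, noise_var=0.01)
print(mean.shape, cov.shape)  # (7,) (7, 7)
```

With small noise, the posterior mean passes close to the training targets and the posterior variance shrinks near the training inputs.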

---

## Notation from Slides

The slides present these formulas in a slightly different notation, emphasizing the GP nature:

$$
\mathcal{GP}\left(
f;
m_{\bullet} + k_{\bullet, X} \alpha_X,
\;
k_{\bullet, \circ} - k_{\bullet, X} (K_{XX} + \sigma^2 I)^{-1} k_{X, \circ}
\right)
$$

where:

- $\bullet$ and $\circ$ represent arbitrary input points.
- $m_{\bullet}$ is $m(x)$.
- $k_{\bullet, X}$ is $k(x, X_{\text{train}})$.
- $m_X$ is $m(X_{\text{train}})$.
- $K_{XX}$ is $K_{\text{train,train}}$.
- $k_{\bullet, \circ}$ is $k(x, x')$.
- $k_{X, \circ}$ is $k(X_{\text{train}}, x')$.
- $\alpha_X := (K_{XX} + \sigma^2 I)^{-1} (y - m_X)$ is a vector resulting from solving a linear system.

So the posterior mean function is $m(x) + k(x, X_{\text{train}}) \alpha_X$, and the posterior covariance function is $k(x, x') - k(x, X_{\text{train}}) (K_{XX} + \sigma^2 I)^{-1} k(X_{\text{train}}, x')$.

---

These match our formulas. The result of conditioning a GP on data is indeed another GP, with updated mean and covariance functions.

# Exploring Different Kernels

The choice of **kernel** is critical because it determines the properties of the functions that the GP can model. Different kernels encode different assumptions about the underlying function. Let's look at some examples beyond the Squared Exponential kernel.

---

## The Wiener Process Kernel

The **Wiener process**, also known as **Brownian motion**, is one of the oldest Gaussian Processes. It can be constructed by integrating white noise. Alternatively, as shown in the slides, it can be derived from a feature expansion using step functions $\theta(x - c_i)$:

$$
\phi_i(x) = \theta(x - c_i) = 
\begin{cases}
1 & \text{if } x \geq c_i \\
0 & \text{else}
\end{cases}
$$

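With a suitable choice of $\Sigma$, for instance $\Sigma = \frac{\sigma^2 (c_{\max} - c_0)}{F} I$ with the $F$ feature locations $c_i$ evenly spaced in $[c_0, c_{\max}]$, the kernel $\phi(x_i)^\top \Sigma \phi(x_j)$ becomes a Riemann sum over the products $\theta(x_i - c)\,\theta(x_j - c) = \theta(\min(x_i, x_j) - c)$. As the number of features $F$ goes to infinity, it approaches: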

$$
k(x_i, x_j) \to \sigma^2 \int_{c_0}^{c_{\max}} \theta(\min(x_i, x_j) - c) \, dc = \sigma^2 \int_{c_0}^{\min(x_i, x_j)} 1 \, dc
$$

Assuming $c_0 \leq \min(x_i, x_j)$, this integral evaluates to:

$$
k_{\text{Wiener}}(x_i, x_j) = \sigma^2 (\min(x_i, x_j) - c_0)
$$

This kernel results in functions that are **continuous but (almost surely) nowhere differentiable**, like a random walk. The parameter $\sigma^2$ controls the variance, and $c_0$ is the starting point: since $k(c_0, c_0) = 0$, the function value there is fixed to the mean (typically $0$).
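
The limit can be checked numerically. This sketch assumes $F$ step features at evenly spaced locations $c_i \in [c_0, c_{\max}]$ with $\Sigma = \frac{\sigma^2 (c_{\max} - c_0)}{F} I$ (our assumption for the "specific choice" of $\Sigma$); the bounds, grid sizes and helper names are illustrative.

```python
import jax.numpy as jnp


def step_feature_kernel(x1, x2, num_features, sigma=1.0, c0=0.0, c_max=10.0):
    """Finite-feature kernel phi(x)^T Sigma phi(x') with step features theta(x - c_i)."""
    c = jnp.linspace(c0, c_max, num_features)                          # feature locations c_i
    phi1 = (jnp.ravel(x1)[:, None] >= c[None, :]).astype(jnp.float32)  # theta(x - c_i), (N1, F)
    phi2 = (jnp.ravel(x2)[:, None] >= c[None, :]).astype(jnp.float32)  # (N2, F)
    weight = sigma**2 * (c_max - c0) / num_features                    # Sigma = weight * I
    return weight * phi1 @ phi2.T


def wiener_closed_form(x1, x2, sigma=1.0, c0=0.0):
    return sigma**2 * (jnp.minimum(jnp.ravel(x1)[:, None], jnp.ravel(x2)[None, :]) - c0)


X = jnp.array([0.5, 1.0, 2.5, 4.0])
for F in [10, 100, 10_000]:
    err = jnp.max(jnp.abs(step_feature_kernel(X, X, F) - wiener_closed_form(X, X)))
    print(F, float(err))  # the error shrinks roughly in proportion to 1 / F
```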

---

Let's implement the Wiener kernel and sample from a GP with this kernel.

In [14]:
import jax.numpy as jnp


def wiener_kernel(
    x1: jnp.ndarray, x2: jnp.ndarray, sigma: float = 1.0, c0: float = 0.0
) -> jnp.ndarray:
    """
    Computes the Wiener Process kernel matrix.

    Args:
        x1: First set of input points. Shape (N1, D). Assumes D=1 for this kernel.
        x2: Second set of input points. Shape (N2, D). Assumes D=1 for this kernel.
        sigma: Output scale; the kernel is multiplied by sigma**2.
        c0: Starting point, where the process has zero variance (pinned to the mean).

    Returns:
        The kernel matrix K, where K[i, j] = k(x1[i], x2[j]). Shape (N1, N2).
    """
    # Flatten inputs to 1D; ravel (unlike squeeze) keeps a single point as shape (1,)
    x1 = jnp.ravel(x1)
    x2 = jnp.ravel(x2)

    # Compute the minimum of all pairs of points
    # (N1, 1) vs (1, N2) -> (N1, N2)
    min_x = jnp.minimum(x1[:, None], x2[None, :])

    # Clamp at zero so inputs left of c0 get zero covariance (the process is pinned there)
    K = sigma**2 * jnp.maximum(0.0, min_x - c0)

    return K


# Example usage:
points1 = jnp.array([0.0, 1.0, 2.0])
points2 = jnp.array([0.5, 1.5])
kernel_matrix = wiener_kernel(points1, points2, sigma=1.0, c0=0.0)
print("Wiener Kernel Matrix:\n", kernel_matrix)


Wiener Kernel Matrix:
 [[0.  0. ]
 [0.5 1. ]
 [0.5 1.5]]


In [36]:
import jax
import jax.random as random
import jax.numpy as jnp
from typing import Callable, Optional
from jax.scipy.linalg import cholesky  # For sampling
from functools import partial


# Function to sample from a GP (assuming zero mean for simplicity)
def sample_from_gp(
    X: jnp.ndarray,
    kernel_func: Callable,
    num_samples: int = 1,
    key: Optional[jax.Array] = None,
    jitter: float = 1e-4,
) -> jnp.ndarray:
    """
    Draws samples from a Gaussian Process with zero mean.

    Args:
        X: Input points to sample at. Shape (N, D); a 1D array is treated as N scalar inputs.
        kernel_func: Covariance function (kernel).
        num_samples: Number of samples to draw.
        key: JAX PRNG key. Defaults to PRNGKey(0), so calls with the same N and
            num_samples reuse the same normal draws, which makes plots comparable.
        jitter: Value added to the diagonal for numerical stability of the Cholesky factor.

    Returns:
        An array of samples. Shape (num_samples, N).
    """
    # Treat a 1D array as N scalar inputs, shape (N, 1)
    X = jnp.asarray(X)
    if X.ndim == 1:
        X = X[:, None]
    N = X.shape[0]

    # Covariance matrix with a small jitter on the diagonal
    K_XX = kernel_func(X, X) + jitter * jnp.eye(N)

    # Cholesky factor: K_XX = L @ L.T
    L = cholesky(K_XX, lower=True)

    # Standard normal variables, one row per sample
    if key is None:
        key = random.PRNGKey(0)
    z = random.normal(key, shape=(num_samples, N))

    # Each sample is L @ z_i, which has covariance L @ L.T = K_XX.
    # Stacking rows: (num_samples, N) @ (N, N) -> (num_samples, N)
    samples = z @ L.T

    return samples


In JAX, `cholesky` does not raise an error when the input matrix is **not positive definite** (or not numerically positive definite due to rounding errors); it returns an array of `nan`s instead. This often happens in Gaussian Process code when the kernel matrix is nearly singular or has zero or negative eigenvalues, due to:

1. **Numerical precision issues** (the kernel matrix is very close to singular).
2. **Insufficient jitter** (the diagonal regularization is too small).
3. **A bug in the kernel function** (e.g., returning negative values on the diagonal).
4. **Duplicate or very close input points** (leading to identical or nearly identical rows/columns).

**How to debug:**
- Check whether `K_XX` contains any `nan`s or non-positive diagonal entries (for the Wiener kernel, a zero variance at $c_0$ is expected and is exactly what the jitter fixes):  
    ```python
    print(jnp.diag(K_XX))
    print(jnp.any(jnp.isnan(K_XX)), jnp.any(jnp.diag(K_XX) <= 0))
    ```
- Try increasing the jitter (e.g., `1e-6` → `1e-4` or higher):
    ```python
    jitter = 1e-4 * jnp.eye(K_XX.shape[0])
    K_XX += jitter
    ```
- Check for duplicate or very close points in your input `X_test`.

**Summary:**  
Cholesky returns `nan` if the matrix is not positive definite. Add more jitter and check your kernel and data for issues.
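
The "add more jitter" advice can be automated. The helper below is a hypothetical sketch, not part of JAX: it retries the Cholesky factorization with a geometrically increasing jitter until the result contains no `nan`. Because it inspects the result in Python, it is meant for eager (un-jitted) use.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cholesky


def cholesky_with_jitter(K, initial_jitter=1e-6, max_tries=6):
    """Return (L, jitter) with L lower-triangular and L @ L.T ~= K + jitter * I."""
    eye = jnp.eye(K.shape[0], dtype=K.dtype)
    jitter = initial_jitter
    for _ in range(max_tries):
        L = cholesky(K + jitter * eye, lower=True)
        if not bool(jnp.any(jnp.isnan(L))):  # JAX signals failure with nan, not an exception
            return L, jitter
        jitter *= 10.0                        # escalate: 1e-6, 1e-5, 1e-4, ...
    raise ValueError(f"Cholesky failed even with jitter={jitter / 10.0:g}")


# A rank-deficient matrix: the duplicated input gives two identical rows
x = jnp.array([0.0, 0.0, 1.0])
K = jnp.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)
L, used_jitter = cholesky_with_jitter(K)
print(used_jitter)
```

A matrix with a clearly negative eigenvalue, such as `[[1, 2], [2, 1]]`, still fails after all retries, which usually points to a bug in the kernel rather than a rounding problem.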

In [37]:
# --- Sample from a GP with Wiener Kernel ---
# Generate input points
X_test = jnp.linspace(0.0, 8.0, 200)[:, None]  # Test inputs (needs to be 2D)

# Define the Wiener kernel
wiener_sigma = 1.0
wiener_c0 = 0.0  # Start at 0
wiener_k = partial(wiener_kernel, sigma=wiener_sigma, c0=wiener_c0)

# Sample functions
num_samples = 5
wiener_samples = sample_from_gp(X_test, wiener_k, num_samples=num_samples)


In [38]:
import plotly.graph_objs as go

fig = go.Figure()

for i in range(num_samples):
    fig.add_trace(
        go.Scatter(
            x=X_test[:, 0],
            y=wiener_samples[i, :],
            mode="lines",
            name=f"Sample {i + 1}",
            opacity=0.7,
        )
    )

# Add the mean function (zero)
fig.add_trace(
    go.Scatter(
        x=X_test[:, 0],
        y=jnp.zeros_like(X_test[:, 0]),
        mode="lines",
        name="Mean Function",
        line=dict(color="black", dash="dash"),
    )
)

fig.update_layout(
    title="Samples from a Gaussian Process with Wiener Kernel",
    xaxis_title="x",
    yaxis_title="f(x)",
    template="simple_white",
)

fig.show()


# Spline Kernels

Another class of kernels can be derived from integrating simpler basis functions. The slides mention a connection to **splines**, particularly **cubic splines**.

Consider integrating the step function features used for the Wiener process:

- Integrating $\theta(x - c_i)$ once gives $\max(0, x - c_i)$, piecewise linear features that are sometimes called **Rectified Linear Unit (ReLU)** features in the context of neural networks.
- Integrating twice gives piecewise quadratic functions, $\tfrac{1}{2}\max(0, x - c_i)^2$.
- Integrating three times gives piecewise cubic functions, and so on.

---

## Integrating a GP

Integrating a GP also results in a GP. If $f \sim \mathcal{GP}(m, k)$, and

$$
\tilde{f}(x) = \int_{-\infty}^x f(\tilde{x})\, d\tilde{x},
$$

then $\tilde{f} \sim \mathcal{GP}(\tilde{m}, \tilde{k})$, where

- $\tilde{m}(x) = \int_{-\infty}^x m(\tilde{x})\, d\tilde{x}$
- $\tilde{k}(x, x') = \int_{-\infty}^x \int_{-\infty}^{x'} k(\tilde{x}, \tilde{x}')\, d\tilde{x}'\, d\tilde{x}$
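
As a numerical illustration, take the Wiener kernel started at $c_0 = 0$ (zero for negative inputs, so the lower limits can be taken as $0$). Integrating $\min(\tilde x, \tilde x')$ by hand gives $\tfrac{1}{2} x^2 x' - \tfrac{1}{6} x^3$ for $x \le x'$, which coincides with the cubic spline kernel below for $x_0 = 0$, $\sigma = 1$. The sketch compares this with a midpoint Riemann sum of the double integral; the grid size `n` is an arbitrary choice.

```python
import jax.numpy as jnp


def integrated_wiener_numeric(x, x_prime, n=1000):
    """Midpoint-rule approximation of int_0^x int_0^x' min(s, t) dt ds."""
    s = (jnp.arange(n) + 0.5) * (x / n)        # midpoints on [0, x]
    t = (jnp.arange(n) + 0.5) * (x_prime / n)  # midpoints on [0, x']
    return jnp.sum(jnp.minimum(s[:, None], t[None, :])) * (x / n) * (x_prime / n)


def integrated_wiener_closed_form(x, x_prime):
    a, b = min(x, x_prime), max(x, x_prime)
    return a**2 * b / 2.0 - a**3 / 6.0


pairs = [(1.0, 2.0), (1.5, 1.5), (3.0, 0.5)]
for x, xp in pairs:
    print(x, xp, float(integrated_wiener_numeric(x, xp)), integrated_wiener_closed_form(x, xp))
```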

---

## The Cubic Spline Kernel

Integrating the Wiener process once, i.e. applying the double integral above to the Wiener kernel (once in each argument), leads to a **cubic spline kernel**. In feature terms, this replaces each step feature by its integral $\max(0, x - c_i)$. The slides give the formula:

$$
k_{\text{CubicSpline}}(x_i, x_j) = \sigma^2 \left( \frac{1}{3} \min^3(x_i - x_0, x_j - x_0) + \frac{1}{2} |x_i - x_j| \min^2(x_i - x_0, x_j - x_0) \right)
$$

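where $x_0$ is the reference point at which the process starts (often $0$). Sample functions are smoother than Wiener process samples: they are once differentiable, and their derivative is a Wiener process. The name comes from the posterior mean, which (for $x > x_0$) is piecewise cubic with continuous first and second derivatives, i.e. a cubic spline.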
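
A quick numerical check of the feature view: assuming ReLU features $\max(0, x - c_i)$ at $F$ evenly spaced locations in $[x_0, c_{\max}]$ with $\Sigma = \frac{\sigma^2 (c_{\max} - x_0)}{F} I$ (an illustrative choice; $c_{\max}$, $F$ and the helper names are hypothetical), the finite-feature kernel should approach the cubic spline formula above as $F$ grows.

```python
import jax.numpy as jnp


def relu_feature_kernel(x1, x2, num_features, sigma=1.0, x0=0.0, c_max=10.0):
    """phi(x)^T Sigma phi(x') with ReLU features max(0, x - c_i) on an even grid of c_i."""
    c = jnp.linspace(x0, c_max, num_features)
    phi1 = jnp.maximum(0.0, jnp.ravel(x1)[:, None] - c[None, :])  # shape (N1, F)
    phi2 = jnp.maximum(0.0, jnp.ravel(x2)[:, None] - c[None, :])  # shape (N2, F)
    weight = sigma**2 * (c_max - x0) / num_features
    return weight * phi1 @ phi2.T


def cubic_spline_closed_form(x1, x2, sigma=1.0, x0=0.0):
    a = jnp.ravel(x1)[:, None] - x0
    b = jnp.ravel(x2)[None, :] - x0
    m = jnp.maximum(0.0, jnp.minimum(a, b))
    return sigma**2 * (m**3 / 3.0 + 0.5 * jnp.abs(a - b) * m**2)


X = jnp.array([0.5, 1.0, 2.0])
err = jnp.max(jnp.abs(relu_feature_kernel(X, X, 20_000) - cubic_spline_closed_form(X, X)))
print(float(err))
```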

---

Let's implement this kernel.

In [39]:
import jax.numpy as jnp


def cubic_spline_kernel(
    x1: jnp.ndarray, x2: jnp.ndarray, sigma: float = 1.0, x0: float = 0.0
) -> jnp.ndarray:
    """
    Computes the Cubic Spline kernel matrix.

    Args:
        x1: First set of input points. Shape (N1, D). Assumes D=1.
        x2: Second set of input points. Shape (N2, D). Assumes D=1.
        sigma: Scaling factor.
        x0: Reference point.

    Returns:
        The kernel matrix K. Shape (N1, N2).
    """
    # Flatten inputs to 1D; ravel (unlike squeeze) keeps a single point as shape (1,)
    x1 = jnp.ravel(x1)
    x2 = jnp.ravel(x2)

    # Shift inputs by x0
    x1_shifted = x1 - x0
    x2_shifted = x2 - x0

    # Compute min(x1_shifted, x2_shifted) for all pairs
    min_shifted = jnp.minimum(x1_shifted[:, None], x2_shifted[None, :])

    # Compute absolute difference |x1 - x2| for all pairs
    abs_diff = jnp.abs(x1[:, None] - x2[None, :])

    # Compute the kernel matrix
    term1 = (1.0 / 3.0) * jnp.maximum(
        0.0, min_shifted
    ) ** 3  # Use maximum to handle cases where min is less than x0
    term2 = (1.0 / 2.0) * abs_diff * jnp.maximum(0.0, min_shifted) ** 2

    K = sigma**2 * (term1 + term2)

    return K


# Example usage:
points1 = jnp.array([0.0, 1.0, 2.0])
points2 = jnp.array([0.5, 1.5])
kernel_matrix = cubic_spline_kernel(points1, points2, sigma=1.0, x0=0.0)
print("Cubic Spline Kernel Matrix:\n", kernel_matrix)


Cubic Spline Kernel Matrix:
 [[0.         0.        ]
 [0.10416667 0.5833334 ]
 [0.22916667 1.6875    ]]


In [40]:
# --- Sample from a GP with Cubic Spline Kernel ---
# Generate input points
# Note: the kernel is zero for x < x0, so samples stay at 0 (up to jitter) on the left half
X_test = jnp.linspace(-8.0, 8.0, 200)[:, None]  # Test inputs

# Define the Cubic Spline kernel
cs_sigma = 1.0
cs_x0 = 0.0
cs_k = partial(cubic_spline_kernel, sigma=cs_sigma, x0=cs_x0)

# Sample functions
num_samples = 5
cs_samples = sample_from_gp(X_test, cs_k, num_samples=num_samples)

In [41]:
import plotly.graph_objs as go

fig = go.Figure()

for i in range(num_samples):
    fig.add_trace(
        go.Scatter(
            x=X_test[:, 0],
            y=cs_samples[i, :],
            mode="lines",
            name=f"Sample {i + 1}",
            opacity=0.7,
        )
    )

# Add the mean function (zero)
fig.add_trace(
    go.Scatter(
        x=X_test[:, 0],
        y=jnp.zeros_like(X_test[:, 0]),
        mode="lines",
        name="Mean Function",
        line=dict(color="black", dash="dash"),
    )
)

fig.update_layout(
    title="Samples from a Gaussian Process with Cubic Spline Kernel",
    xaxis_title="x",
    yaxis_title="f(x)",
    template="simple_white",
)

fig.show()


# The Matérn Family of Kernels

The **Matérn family** is a very popular class of kernels that allows us to control the smoothness of the functions drawn from the GP using a parameter $\nu$. The general form involves the modified Bessel function of the second kind $K_\nu$:

$$
k_{\nu, \ell}(|r|) = \frac{2^{1-\nu}}{\Gamma(\nu)} \left( \frac{\sqrt{2\nu}|r|}{\ell} \right)^\nu K_\nu \left( \frac{\sqrt{2\nu}|r|}{\ell} \right)
$$

where $|r| = \|x - x'\|$ is the distance between input points and $\ell$ is the length scale.

---

## Smoothness Parameter $\nu$

The parameter $\nu$ influences the differentiability of the functions. For $\nu = p + 1/2$ where $p$ is a non-negative integer, the Matérn kernel has a simpler closed form and corresponds to functions that are $p$ times mean-square differentiable.

---

## Common Special Cases

- **$\nu = 1/2$ (Ornstein-Uhlenbeck kernel):**  
    Results in functions that are continuous but not differentiable.
    $$
    k_{1/2, \ell}(|r|) = \exp\left(-\frac{|r|}{\ell}\right)
    $$

- **$\nu = 3/2$:**  
    Functions are once mean-square differentiable.
    $$
    k_{3/2, \ell}(|r|) = \left(1 + \frac{\sqrt{3}|r|}{\ell}\right) \exp\left(-\frac{\sqrt{3}|r|}{\ell}\right)
    $$

- **$\nu = 5/2$:**  
    Functions are twice mean-square differentiable.
    $$
    k_{5/2, \ell}(|r|) = \left(1 + \frac{\sqrt{5}|r|}{\ell} + \frac{5|r|^2}{3\ell^2}\right) \exp\left(-\frac{\sqrt{5}|r|}{\ell}\right)
    $$

- **As $\nu \to \infty$ (Squared Exponential / RBF kernel):**  
    Produces infinitely differentiable functions.
    $$
    k_{\infty, \ell}(|r|) = \exp\left(-\frac{|r|^2}{2\ell^2}\right)
    $$
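
The closed forms can be checked against the general Bessel-function expression. This sketch switches to NumPy/SciPy because it needs the modified Bessel function $K_\nu$ (`scipy.special.kv`) and the gamma function (`scipy.special.gamma`). It evaluates at positive distances only, since the general formula is a $0 \cdot \infty$ form at $r = 0$ (where the kernel's limit is $1$); the length scale $2$ is arbitrary.

```python
import numpy as np
from scipy.special import gamma, kv


def matern_general(r, nu, lengthscale=1.0):
    """General Matérn kernel via the modified Bessel function K_nu; requires r > 0."""
    z = np.sqrt(2.0 * nu) * r / lengthscale
    return (2.0 ** (1.0 - nu) / gamma(nu)) * z**nu * kv(nu, z)


closed_forms = {
    0.5: lambda r, l: np.exp(-r / l),
    1.5: lambda r, l: (1 + np.sqrt(3) * r / l) * np.exp(-np.sqrt(3) * r / l),
    2.5: lambda r, l: (1 + np.sqrt(5) * r / l + 5 * r**2 / (3 * l**2)) * np.exp(-np.sqrt(5) * r / l),
}

r = np.linspace(0.1, 5.0, 50)
for nu, closed in closed_forms.items():
    print(nu, np.max(np.abs(matern_general(r, nu, 2.0) - closed(r, 2.0))))

# A large nu is already close to the squared exponential exp(-r^2 / (2 l^2)) with l = 2
print(np.max(np.abs(matern_general(r, 50.0, 2.0) - np.exp(-r**2 / 8.0))))
```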

---

Let's implement these specific Matérn kernels and sample from GPs using them to see the effect of $\nu$.

In [8]:
import jax.numpy as jnp


def matern_kernel(
    x1: jnp.ndarray,
    x2: jnp.ndarray,
    nu: float,
    lengthscale: float = 1.0,
    sigma: float = 1.0,
) -> jnp.ndarray:
    """
    Computes the Matérn kernel matrix for specific nu values (1/2, 3/2, 5/2).

    Args:
        x1: First set of input points. Shape (N1, D).
        x2: Second set of input points. Shape (N2, D).
        nu: The parameter controlling smoothness (1/2, 3/2, or 5/2).
        lengthscale: Length scale hyperparameter.
        sigma: Output variance (amplitude) hyperparameter.

    Returns:
        The kernel matrix K. Shape (N1, N2).
    Raises:
        ValueError: If nu is not one of the implemented values.
    """
    # Ensure inputs are JAX arrays and have at least 2 dimensions
    x1 = jnp.atleast_2d(x1)
    x2 = jnp.atleast_2d(x2)

    # Compute the Euclidean distance between all pairs of points
    # (N1, 1, D) - (1, N2, D) -> (N1, N2, D)
    # Sum over the last dimension and take sqrt: (N1, N2)
    dist = jnp.sqrt(jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1))

    # Compute the scaled distance
    scaled_dist = dist / lengthscale

    if nu == 0.5:
        K = sigma**2 * jnp.exp(-scaled_dist)
    elif nu == 1.5:
        K = (
            sigma**2
            * (1.0 + jnp.sqrt(3.0) * scaled_dist)
            * jnp.exp(-jnp.sqrt(3.0) * scaled_dist)
        )
    elif nu == 2.5:
        K = (
            sigma**2
            * (1.0 + jnp.sqrt(5.0) * scaled_dist + (5.0 / 3.0) * scaled_dist**2)
            * jnp.exp(-jnp.sqrt(5.0) * scaled_dist)
        )
    else:
        raise ValueError(f"Matérn nu={nu} not implemented. Use 0.5, 1.5, or 2.5.")

    return K


# Implement the RBF kernel again for comparison (Matérn nu -> infinity)
def rbf_kernel(
    x1: jnp.ndarray, x2: jnp.ndarray, lengthscale: float = 1.0, sigma: float = 1.0
) -> jnp.ndarray:
    """
    Computes the Squared Exponential (RBF) kernel matrix.
    This is the limit of the Matérn kernel as nu -> infinity.
    """
    # Ensure inputs are JAX arrays and have at least 2 dimensions
    x1 = jnp.atleast_2d(x1)
    x2 = jnp.atleast_2d(x2)

    # Compute the squared Euclidean distance
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)

    # Compute the kernel matrix
    K = sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)
    return K


# Example usage:
# points1 = jnp.array([[0.0], [1.0], [2.0]])
# points2 = jnp.array([[0.5], [1.5]])
# matern_1_2_k = matern_kernel(points1, points2, nu=0.5, lengthscale=1.0, sigma=1.0)
# print("Matérn nu=1/2 Kernel Matrix:\n", matern_1_2_k)


In [42]:
# --- Sample from GPs with Different Matérn Kernels ---
# Generate input points
X_test = jnp.linspace(-8.0, 8.0, 200)[:, None]  # Test inputs

# Define kernels with different nu values
matern_1_2_k = lambda x1, x2: matern_kernel(x1, x2, nu=0.5, lengthscale=2.0, sigma=1.0)
matern_3_2_k = lambda x1, x2: matern_kernel(x1, x2, nu=1.5, lengthscale=2.0, sigma=1.0)
matern_5_2_k = lambda x1, x2: matern_kernel(x1, x2, nu=2.5, lengthscale=2.0, sigma=1.0)
rbf_k = lambda x1, x2: rbf_kernel(
    x1, x2, lengthscale=2.0, sigma=1.0
)  # Matérn nu -> infinity

# Sample functions from each GP
num_samples_per_kernel = 3
matern_1_2_samples = sample_from_gp(
    X_test, matern_1_2_k, num_samples=num_samples_per_kernel
)
matern_3_2_samples = sample_from_gp(
    X_test, matern_3_2_k, num_samples=num_samples_per_kernel
)
matern_5_2_samples = sample_from_gp(
    X_test, matern_5_2_k, num_samples=num_samples_per_kernel
)
rbf_samples = sample_from_gp(X_test, rbf_k, num_samples=num_samples_per_kernel)


In [45]:
import plotly.subplots as sp
import plotly.graph_objs as go

fig = sp.make_subplots(
    rows=4,
    cols=1,
    shared_xaxes=True,
    shared_yaxes=True,
    vertical_spacing=0.05,
    subplot_titles=[
        "Matérn ν=1/2 (Ornstein-Uhlenbeck)",
        "Matérn ν=3/2",
        "Matérn ν=5/2",
        "Matérn ν→∞ (RBF)",
    ],
)

for i in range(num_samples_per_kernel):
    fig.add_trace(
        go.Scatter(
            x=X_test[:, 0],
            y=matern_1_2_samples[i, :],
            mode="lines",
            name=f"ν=1/2 Sample {i + 1}",
            showlegend=(i == 0),
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=X_test[:, 0],
            y=matern_3_2_samples[i, :],
            mode="lines",
            name=f"ν=3/2 Sample {i + 1}",
            showlegend=False,
        ),
        row=2,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=X_test[:, 0],
            y=matern_5_2_samples[i, :],
            mode="lines",
            name=f"ν=5/2 Sample {i + 1}",
            showlegend=False,
        ),
        row=3,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=X_test[:, 0],
            y=rbf_samples[i, :],
            mode="lines",
            name=f"RBF Sample {i + 1}",
            showlegend=False,
        ),
        row=4,
        col=1,
    )

fig.update_xaxes(title_text="x", row=4, col=1)
for r in range(1, 5):
    fig.update_yaxes(title_text="f(x)", row=r, col=1)

fig.update_layout(
    height=900,
    width=1200,
    title_text="Samples from Gaussian Processes with Different Matérn Kernels",
    template="simple_white",
)

fig.show()


# The Rational Quadratic Kernel

The **Rational Quadratic (RQ) kernel** can be seen as an infinite mixture of Squared Exponential (SE) kernels with different length scales. This gives it the ability to model functions that have variations at multiple length scales.

Its formula is:

$$
k_{\text{RQ}}(x, x') = \sigma^2 \left(1 + \frac{\|x - x'\|^2}{2 \alpha \ell^2}\right)^{-\alpha}
$$

Here, $\alpha > 0$ controls the scale mixture. Smaller values of $\alpha$ spread the mixture over a wider range of length scales; as $\alpha \to \infty$, the mixture concentrates on the single length scale $\ell$ and the RQ kernel approaches the SE kernel.
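
The scale-mixture interpretation can be verified by Monte Carlo. Assuming the inverse squared length scale $\tau = 1/\ell'^2$ is Gamma-distributed with shape $\alpha$ and rate $\alpha \ell^2$ (so its mean is $1/\ell^2$), averaging SE kernels $\exp(-\tau r^2/2)$ over $\tau$ should reproduce the RQ kernel with $\sigma = 1$. The sample count and parameter values are arbitrary choices.

```python
import jax
import jax.numpy as jnp

alpha, lengthscale = 1.0, 1.0
r = jnp.linspace(0.0, 4.0, 9)

# Draw tau ~ Gamma(shape=alpha, rate=alpha * lengthscale**2); jax.random.gamma has rate 1
key = jax.random.PRNGKey(0)
tau = jax.random.gamma(key, alpha, shape=(200_000,)) / (alpha * lengthscale**2)

# Average the SE kernel exp(-tau r^2 / 2) over the sampled precisions
mixture = jnp.mean(jnp.exp(-0.5 * tau[:, None] * r[None, :] ** 2), axis=0)
rq = (1.0 + r**2 / (2.0 * alpha * lengthscale**2)) ** (-alpha)
print(jnp.max(jnp.abs(mixture - rq)))  # Monte Carlo error

# A large alpha (here 1e3) concentrates tau near 1 / lengthscale**2 and recovers the SE kernel
big = 1e3
rq_large_alpha = (1.0 + r**2 / (2.0 * big * lengthscale**2)) ** (-big)
print(jnp.max(jnp.abs(rq_large_alpha - jnp.exp(-0.5 * r**2 / lengthscale**2))))
```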

Let's implement the RQ kernel.

In [10]:
import jax.numpy as jnp


def rational_quadratic_kernel(
    x1: jnp.ndarray,
    x2: jnp.ndarray,
    alpha: float = 1.0,
    lengthscale: float = 1.0,
    sigma: float = 1.0,
) -> jnp.ndarray:
    """
    Computes the Rational Quadratic kernel matrix.

    Args:
        x1: First set of input points. Shape (N1, D).
        x2: Second set of input points. Shape (N2, D).
        alpha: Scale-mixture parameter.
        lengthscale: Length scale hyperparameter.
        sigma: Output variance (amplitude) hyperparameter.

    Returns:
        The kernel matrix K. Shape (N1, N2).
    """
    # Ensure inputs are JAX arrays and have at least 2 dimensions
    x1 = jnp.atleast_2d(x1)
    x2 = jnp.atleast_2d(x2)

    # Compute the squared Euclidean distance
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)

    # Compute the kernel matrix
    K = sigma**2 * (1.0 + sq_dist / (2.0 * alpha * lengthscale**2)) ** (-alpha)

    return K


# Example usage:
# points1 = jnp.array([[0.0], [1.0], [2.0]])
# points2 = jnp.array([[0.5], [1.5]])
# kernel_matrix = rational_quadratic_kernel(points1, points2, alpha=1.0, lengthscale=1.0, sigma=1.0)
# print("Rational Quadratic Kernel Matrix:\n", kernel_matrix)


In [48]:
# --- Sample from a GP with Rational Quadratic Kernel ---
# Generate input points
X_test = jnp.linspace(-8.0, 8.0, 200)[:, None]  # Test inputs

# Define the Rational Quadratic kernel
rq_alpha = 1.0
rq_lengthscale = 1.0
rq_sigma = 1.0
rq_k = partial(
    rational_quadratic_kernel,
    alpha=rq_alpha,
    lengthscale=rq_lengthscale,
    sigma=rq_sigma,
)

# Sample functions
num_samples = 5
rq_samples = sample_from_gp(X_test, rq_k, num_samples=num_samples)

import plotly.graph_objs as go

fig = go.Figure()

for i in range(num_samples):
    fig.add_trace(
        go.Scatter(
            x=X_test[:, 0],
            y=rq_samples[i, :],
            mode="lines",
            name=f"Sample {i + 1}",
            opacity=0.7,
        )
    )

# Add the mean function (zero)
fig.add_trace(
    go.Scatter(
        x=X_test[:, 0],
        y=jnp.zeros_like(X_test[:, 0]),
        mode="lines",
        name="Mean Function",
        line=dict(color="black", dash="dash"),
    )
)

fig.update_layout(
    title="Samples from a Gaussian Process with Rational Quadratic Kernel",
    xaxis_title="x",
    yaxis_title="f(x)",
    template="simple_white",
)

fig.show()


# Building New Kernels from Existing Ones

A powerful property of kernels is that we can **combine existing valid kernels to create new ones**. This allows us to construct complex kernels that capture various properties of the data.

The slides list several ways to combine kernels:

---

If $k_1(x, x')$ and $k_2(x, x')$ are Mercer kernels, and $\phi: Y \to X$ is any function, then the following are also Mercer kernels:

---

### 1. Scaling

$$
\alpha \cdot k_1(x, x'), \quad \text{for any } \alpha \in \mathbb{R}^+
$$

Scaling a kernel by $\alpha$ scales the variance of the GP by $\alpha$, and hence the amplitude of its samples by $\sqrt{\alpha}$.

---

### 2. Transformation of Inputs

$$
k_1(\phi(y), \phi(y')), \quad \text{for } y, y' \in Y
$$

Applying a transformation to the inputs before evaluating the kernel can introduce non-linearities and capture different structures.  
This is how length scales are often incorporated into kernels like the RBF:

$$
k(x, x') = k_{\text{base}}\left(\frac{x}{\ell}, \frac{x'}{\ell}\right)
$$
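
A small sketch of rule 2, assuming a unit-length-scale SE kernel as $k_{\text{base}}$: dividing the inputs by $\ell$ before calling it gives the same matrix as the SE kernel with length scale $\ell$. A non-linear warp such as $\phi(y) = \log y$ (a hypothetical choice, valid for positive inputs) gives a non-stationary kernel whose correlations depend on where the inputs are, not only on their difference. The helper names are illustrative.

```python
import jax.numpy as jnp


def se_unit(x1, x2):
    """Base SE kernel with unit length scale, for scalar inputs."""
    d = jnp.ravel(x1)[:, None] - jnp.ravel(x2)[None, :]
    return jnp.exp(-0.5 * d**2)


def se(x1, x2, lengthscale):
    d = jnp.ravel(x1)[:, None] - jnp.ravel(x2)[None, :]
    return jnp.exp(-0.5 * d**2 / lengthscale**2)


def transformed(kernel, phi):
    """Rule 2: k(phi(y), phi(y')) is again a kernel for any map phi."""
    return lambda y1, y2: kernel(phi(y1), phi(y2))


X = jnp.linspace(0.5, 5.0, 6)
ell = 2.0
K_scaled = transformed(se_unit, lambda y: y / ell)(X, X)
print(jnp.allclose(K_scaled, se(X, X, ell)))  # True

K_warped = transformed(se_unit, jnp.log)(X, X)
# Neighbouring pairs are equally spaced in y but not in log(y), so their correlations differ
print(K_warped[0, 1], K_warped[4, 5])
```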

---

### 3. Addition

$$
k_1(x, x') + k_2(x, x')
$$

Adding kernels allows the GP to model functions that are a sum of functions drawn from the individual GPs.  
The resulting GP has a covariance function that is the sum of the individual kernels.  
This is useful for modeling components of a function, e.g., a smooth trend plus a periodic component.

If $f_1 \sim \mathcal{GP}(m_1, k_1)$ and $f_2 \sim \mathcal{GP}(m_2, k_2)$ are independent, then

$$
f_1 + f_2 \sim \mathcal{GP}(m_1 + m_2, k_1 + k_2)
$$

---

### 4. Multiplication (Schur Product)

$$
k_1(x, x') \cdot k_2(x, x')
$$

Multiplying kernels is less intuitive than adding them, but it is a valid way to construct new kernels.

If $f_1 \sim \mathcal{GP}(0, k_1)$ and $f_2 \sim \mathcal{GP}(0, k_2)$ are independent GPs, their product $h(x) = f_1(x) f_2(x)$ is **not generally a GP**.  
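However, the product of their kernels is a valid kernel (by the Schur product theorem, the elementwise product of two positive semidefinite matrices is positive semidefinite), and it matches $h$ to second order: $h$ has zero mean and covariance $\mathbb{E}[h(x) h(x')] = \mathbb{E}[f_1(x) f_1(x')]\,\mathbb{E}[f_2(x) f_2(x')] = k_1(x, x')\,k_2(x, x')$.  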
This can be used to model interactions or dependencies between different aspects of the function.
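
Both the sum and the product rule can be checked by simulation. The sketch below, with arbitrary illustrative kernels, grid and sample count, draws many pairs of independent zero-mean GP samples $f_1, f_2$ and compares the empirical covariances of $f_1 + f_2$ and $f_1 f_2$ with $K_1 + K_2$ and the elementwise product $K_1 \circ K_2$.

```python
import jax
import jax.numpy as jnp


def se(x1, x2, lengthscale):
    d = x1[:, None] - x2[None, :]
    return jnp.exp(-0.5 * d**2 / lengthscale**2)


X = jnp.linspace(-2.0, 2.0, 6)
K1 = se(X, X, 0.5)
K2 = se(X, X, 1.0)
jitter = 1e-5 * jnp.eye(X.shape[0])
L1 = jnp.linalg.cholesky(K1 + jitter)
L2 = jnp.linalg.cholesky(K2 + jitter)

# Independent zero-mean samples f1 ~ GP(0, k1), f2 ~ GP(0, k2), one per row
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
n = 200_000
f1 = jax.random.normal(key1, (n, X.shape[0])) @ L1.T
f2 = jax.random.normal(key2, (n, X.shape[0])) @ L2.T


def empirical_cov(F):
    # Mean is zero by construction, so E[f f^T] is the covariance
    return F.T @ F / F.shape[0]


print(jnp.max(jnp.abs(empirical_cov(f1 + f2) - (K1 + K2))))  # sum -> k1 + k2
print(jnp.max(jnp.abs(empirical_cov(f1 * f2) - K1 * K2)))    # product -> k1 * k2 (elementwise)
```

The product samples match the product kernel in mean and covariance, but their marginals are not Gaussian, which is why $h$ itself is not a GP.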

---

These rules make kernels a **"semi-ring"**: the set of valid kernels is closed under addition, under multiplication by positive scalars, and under pointwise multiplication.  
This provides a powerful framework for designing complex kernels by combining simpler ones.
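One way to mirror these closure rules in code is a small wrapper class that overloads `+` and `*`. The `Kernel` class below is hypothetical, not from any particular library. It shows that any expression built from scaling, input warping, sums and products still evaluates to a symmetric positive semidefinite Gram matrix.

```python
import numpy as np

class Kernel:
    """Hypothetical minimal kernel wrapper implementing the closure rules."""

    def __init__(self, fn):
        self.fn = fn  # fn(a, b) -> Gram matrix for 1-D input arrays a, b

    def __call__(self, a, b):
        return self.fn(a, b)

    def __add__(self, other):
        # Sum of two kernels
        return Kernel(lambda a, b: self(a, b) + other(a, b))

    def __mul__(self, other):
        if isinstance(other, Kernel):
            # Pointwise (Schur) product of two kernels
            return Kernel(lambda a, b: self(a, b) * other(a, b))
        if other <= 0:
            raise ValueError("scale factor must be positive")
        # Scaling by a positive constant
        return Kernel(lambda a, b: other * self(a, b))

    __rmul__ = __mul__  # allows 2.0 * k as well as k * 2.0

    def warp(self, phi):
        """Input transformation k(phi(x), phi(x'))."""
        return Kernel(lambda a, b: self(phi(a), phi(b)))

rbf = Kernel(lambda a, b: np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2))
per = Kernel(lambda a, b: np.exp(-2.0 * np.sin(np.pi * np.abs(a[:, None] - b[None, :])) ** 2))

# Long-length-scale trend plus a locally periodic component
k = 2.0 * rbf.warp(lambda t: t / 3.0) + 0.5 * (rbf * per)

x = np.linspace(0.0, 10.0, 80)
K = k(x, x)
print(np.allclose(K, K.T), np.linalg.eigvalsh(K).min() > -1e-8)  # True True
```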

The slides illustrate examples of these combinations, showing:

- How scaling affects the amplitude of samples,
- How transforming inputs (like scaling by a length scale or using a non-linear transformation) changes the effective length scale or introduces non-stationarity,
- How sums and products of kernels can create functions with combined properties.

# Learning the Kernel (Hyperparameter Optimization)

So far, we've been assuming that the hyperparameters of our kernels (like the length scale $\ell$, variance $\sigma^2$, noise variance $\sigma_{\text{noise}}^2$, $\nu$ for Matérn, $\alpha$ for RQ) are fixed. However, in practice, we need to **learn these hyperparameters from the data**.

This is typically done using **Hierarchical Bayesian Inference** or **Bayesian model adaptation**. We treat the hyperparameters $\theta$ as unknown variables and place a prior distribution $p(\theta)$ over them. Then, we aim to find the posterior distribution over the hyperparameters given the data:

$$
p(\theta \mid y) = \frac{p(y \mid \theta) p(\theta)}{p(y)}
$$

The term $p(y \mid \theta)$ is the **marginal likelihood** or **model evidence**. It represents the probability of observing the data given the hyperparameters, marginalized over the function $f$:

$$
p(y \mid \theta) = \int p(y \mid f, \theta) \, p(f \mid \theta) \, df
$$

For Gaussian Processes with Gaussian likelihoods, this marginal likelihood can be computed analytically. Recall the joint distribution of $y$ and $f_{\text{test}}$. The marginal distribution of $y$ is:

$$
p(y \mid X_{\text{train}}, \theta) = \mathcal{N}\left(y; m(X_{\text{train}}), K_{\text{train,train}} + \sigma_{\text{noise}}^2 I\right)
$$
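As a sanity check of this result, the sketch below estimates $p(y \mid \theta) = \int p(y \mid f)\, p(f \mid \theta)\, df$ by Monte Carlo. It averages the Gaussian likelihood $\mathcal{N}(y; f, \sigma_{\text{noise}}^2 I)$ over prior draws $f \sim \mathcal{N}(0, K)$ and compares the average with the closed form $\mathcal{N}(y; 0, K + \sigma_{\text{noise}}^2 I)$. The three training inputs, the targets, the zero prior mean and the unit RBF kernel are made-up illustrative values.

```python
import numpy as np

def gaussian_logpdf(y, mean, cov):
    """log N(y; mean, cov) via a Cholesky factorization."""
    L = np.linalg.cholesky(cov)
    r = np.linalg.solve(L, y - mean)
    return -0.5 * r @ r - np.sum(np.log(np.diag(L))) - 0.5 * len(y) * np.log(2 * np.pi)

x_train = np.array([0.0, 0.8, 2.0])
y = np.array([0.3, -0.2, 0.5])
noise = 0.5  # sigma_noise (standard deviation)
K = np.exp(-0.5 * (x_train[:, None] - x_train[None, :]) ** 2)  # unit RBF, zero prior mean

# Closed form: p(y | X, theta) = N(y; 0, K + sigma_noise^2 I)
analytic = np.exp(gaussian_logpdf(y, np.zeros(3), K + noise ** 2 * np.eye(3)))

# Monte Carlo: average p(y | f) over prior draws f ~ N(0, K)
rng = np.random.default_rng(0)
f = rng.multivariate_normal(np.zeros(3), K, size=200_000)
log_lik = (-0.5 * np.sum((y - f) ** 2, axis=1) / noise ** 2
           - 0.5 * len(y) * np.log(2 * np.pi * noise ** 2))
monte_carlo = np.mean(np.exp(log_lik))

print(analytic, monte_carlo)  # agree up to Monte Carlo error
```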

The **log-marginal likelihood** is then:

$$
\log p(y \mid X_{\text{train}}, \theta) = 
- \frac{1}{2} (y - m(X_{\text{train}}))^\top (K_{\text{train,train}} + \sigma_{\text{noise}}^2 I)^{-1} (y - m(X_{\text{train}}))
- \frac{1}{2} \log \left| K_{\text{train,train}} + \sigma_{\text{noise}}^2 I \right|
- \frac{N_{\text{train}}}{2} \log(2\pi)
$$
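In code, this quantity is usually evaluated through a Cholesky factorization $K_{\text{train,train}} + \sigma_{\text{noise}}^2 I = L L^\top$ rather than an explicit inverse, using $\log|L L^\top| = 2 \sum_i \log L_{ii}$. Below is a minimal NumPy sketch, cross-checked against a naive evaluation of the same formula. The data, the constant prior mean and the kernel settings are illustrative assumptions.

```python
import numpy as np

def log_marginal_likelihood(y, m, K, noise_var):
    """GP log-marginal likelihood with Gaussian noise, computed via Cholesky.

    y: (N,) targets, m: (N,) prior mean m(X_train),
    K: (N, N) kernel matrix K_train,train, noise_var: sigma_noise^2.
    """
    N = len(y)
    C = K + noise_var * np.eye(N)
    L = np.linalg.cholesky(C)                            # C = L L^T
    r = y - m
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, r))  # C^{-1} (y - m)
    data_fit = -0.5 * r @ alpha
    complexity = -np.sum(np.log(np.diag(L)))             # -1/2 log|C|
    return data_fit + complexity - 0.5 * N * np.log(2 * np.pi)

# Illustrative data and settings
rng = np.random.default_rng(1)
x = np.linspace(0.0, 4.0, 15)
y = np.sin(x) + 0.1 * rng.standard_normal(len(x))
m = np.full_like(x, 0.2)                                 # constant prior mean
K = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)        # RBF, unit length scale
noise_var = 0.01

# Naive evaluation with an explicit inverse, for comparison only
C = K + noise_var * np.eye(len(x))
sign, logdet = np.linalg.slogdet(C)
naive = (-0.5 * (y - m) @ np.linalg.inv(C) @ (y - m)
         - 0.5 * logdet - 0.5 * len(x) * np.log(2 * np.pi))
stable = log_marginal_likelihood(y, m, K, noise_var)
print(np.isclose(naive, stable))  # True
```

The first term rewards fitting the data and the log-determinant term penalizes overly flexible covariances. This balance is why maximizing the marginal likelihood does not simply pick the most flexible kernel.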

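We can find a point estimate of the hyperparameters by **maximizing this log-marginal likelihood** with respect to $\theta$. This is known as type-II maximum likelihood (or empirical Bayes). With a flat prior $p(\theta)$, it coincides with the mode of $p(\theta \mid y)$. The objective is typically non-convex in $\theta$, so it is maximized numerically, usually with gradient-based optimizers.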
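Here is a minimal sketch of this maximization. It applies SciPy's `scipy.optimize.minimize` with the L-BFGS-B method to the negative log-marginal likelihood. The setup is illustrative:

- a zero-mean GP with an RBF kernel and synthetic sine data;
- three hyperparameters (length scale, signal variance, noise variance), optimized in log space;
- box bounds that keep all three positive and the Cholesky factorization well conditioned.

Gradients are approximated by finite differences here. Practical GP libraries typically supply analytic or automatic-differentiation gradients instead.

```python
import numpy as np
from scipy.optimize import minimize

def neg_lml(log_params, x, y):
    """Negative log-marginal likelihood of a zero-mean RBF GP.

    log_params = [log lengthscale, log signal variance, log noise variance].
    """
    ell, sf2, sn2 = np.exp(log_params)
    d = x[:, None] - x[None, :]
    C = sf2 * np.exp(-0.5 * (d / ell) ** 2) + sn2 * np.eye(len(x))
    L = np.linalg.cholesky(C)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    return 0.5 * y @ alpha + np.sum(np.log(np.diag(L))) + 0.5 * len(x) * np.log(2 * np.pi)

# Synthetic data (illustrative only)
rng = np.random.default_rng(0)
x = np.sort(rng.uniform(0.0, 6.0, 30))
y = np.sin(1.5 * x) + 0.2 * rng.standard_normal(len(x))

x0 = np.log([1.0, 1.0, 0.1])  # initial guess in log space
res = minimize(neg_lml, x0, args=(x, y), method="L-BFGS-B",
               bounds=[(-4.0, 4.0)] * 3)
ell, sf2, sn2 = np.exp(res.x)
print(f"lengthscale={ell:.2f}, signal var={sf2:.2f}, noise var={sn2:.3f}")
```

Because the objective is non-convex, a different `x0` can land in a different local optimum. Restarting from several initial points and keeping the best result is a common safeguard.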

While the marginal likelihood is analytically available, its dependence on the hyperparameters $\theta$ is generally non-linear, which is why we can't do analytic Gaussian inference over $\theta$ itself.
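Since $p(\theta \mid y)$ has no closed form, one simple approximation for a single hyperparameter is to evaluate it on a grid. The sketch below does this for the RBF length scale only. It assumes a zero-mean GP, fixed signal and noise variances, synthetic data, and a standard-normal prior on $\log \ell$; all of these are illustrative choices. The unnormalized log posterior is $\log p(y \mid \theta) + \log p(\theta)$, which is then normalized over the grid.

```python
import numpy as np

def log_marginal_likelihood(x, y, lengthscale, variance=1.0, noise_var=0.01):
    """log N(y; 0, K + noise_var * I) for a zero-mean GP with an RBF kernel."""
    d = x[:, None] - x[None, :]
    C = variance * np.exp(-0.5 * (d / lengthscale) ** 2) + noise_var * np.eye(len(x))
    L = np.linalg.cholesky(C)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    return -0.5 * y @ alpha - np.sum(np.log(np.diag(L))) - 0.5 * len(x) * np.log(2 * np.pi)

# Synthetic data (illustrative only)
rng = np.random.default_rng(0)
x = np.linspace(0.0, 5.0, 20)
y = np.sin(2.0 * x) + 0.1 * rng.standard_normal(len(x))

# Grid over theta = log(lengthscale), with an assumed prior log(lengthscale) ~ N(0, 1)
log_ell_grid = np.linspace(-3.0, 2.5, 300)
log_prior = -0.5 * log_ell_grid ** 2
log_lik = np.array([log_marginal_likelihood(x, y, np.exp(t)) for t in log_ell_grid])

# Unnormalized log posterior, exponentiated stably and normalized over the grid
log_post = log_lik + log_prior
posterior = np.exp(log_post - log_post.max())
posterior /= posterior.sum()

map_log_ell = log_ell_grid[np.argmax(posterior)]
post_mean_log_ell = np.sum(posterior * log_ell_grid)
print(f"MAP lengthscale ~ {np.exp(map_log_ell):.2f}, "
      f"posterior mean of log-lengthscale ~ {post_mean_log_ell:.2f}")
```

The cost of grid evaluation grows exponentially with the number of hyperparameters. Beyond one or two of them, sampling methods such as MCMC, or a point estimate from maximizing the marginal likelihood, are used instead.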

**Finding the hyperparameters by maximizing the marginal likelihood is a common approach in GP modeling and allows the model to adapt its properties (like smoothness and scale) to the observed data.**

# Summary

In this post, we've deepened our understanding of **Gaussian Processes (GPs)** by focusing on the role of **kernels**:

---

- **GPs arise naturally from probabilistic parametric models with Gaussian priors** through "lazy evaluation."

- **Kernels** are functions that produce a symmetric positive semidefinite (Gram) matrix for any finite set of inputs; they define the covariance between function values at different input points.

- **Every kernel corresponds to an inner product** in a potentially infinite-dimensional feature space.

- **Conditioning a GP on data results in a new GP (the posterior GP)** with updated mean and covariance functions.

- **Different kernel functions** (like Wiener, Cubic Spline, Matérn, Rational Quadratic) encode different assumptions about the function's properties (e.g., smoothness, periodicity).

- **New kernels can be constructed** by combining existing kernels through scaling, transformation of inputs, addition, and multiplication.

- **Kernel hyperparameters can be learned from data** by maximizing the marginal likelihood:
  
  $$
  \log p(y \mid X_{\text{train}}, \theta) = 
  - \frac{1}{2} (y - m(X_{\text{train}}))^\top (K_{\text{train,train}} + \sigma_{\text{noise}}^2 I)^{-1} (y - m(X_{\text{train}}))
  - \frac{1}{2} \log \left| K_{\text{train,train}} + \sigma_{\text{noise}}^2 I \right|
  - \frac{N_{\text{train}}}{2} \log(2\pi)
  $$

---

This exploration of kernels highlights their flexibility and power in modeling a wide variety of functions.  
By choosing or constructing appropriate kernels, we can tailor a GP to the specific characteristics of our data.