# Gaussian Process Regression: An Extensive Example

Welcome back to our **Gaussian Processes (GPs)** learning journey! In previous posts, we introduced GPs, explored various kernel functions, and understood how they define the properties of functions a GP can model.

In this post, we'll take a deeper dive into applying GPs to more complex, real-world scenarios. We'll focus on:

- **Modeling structured data using combinations of kernels (additive kernels).**
- **Learning kernel hyperparameters from data.**
- **Source separation:** using GPs to disentangle multiple underlying signals from a combined observation.

This post works through a more extensive example, building on the theoretical foundations we've established.

---

### Key Topics

1. **Additive Kernels:**  
    Combine simple kernels to model complex data structures, e.g.,  
    $k_{\text{sum}}(x, x') = k_1(x, x') + k_2(x, x')$

2. **Hyperparameter Learning:**  
    Optimize kernel parameters by maximizing the marginal likelihood:  
    $$
    \log p(\mathbf{y} \mid X, \theta) = -\frac{1}{2} \mathbf{y}^\top (K_{XX} + \sigma_{\text{noise}}^2 I)^{-1} \mathbf{y}
    - \frac{1}{2} \log |K_{XX} + \sigma_{\text{noise}}^2 I|
    - \frac{N}{2} \log(2\pi)
    $$

3. **Source Separation:**  
    Decompose observed data $y$ into interpretable components, e.g.,  
    $$
    y = f_1(x) + f_2(x) + \text{noise}
    $$

---

By the end of this example, you'll see how GPs can be used for structured modeling, automatic hyperparameter tuning, and even for separating mixed signals in real-world time series data.

## A Real-World Dataset: Mauna Loa CO₂ Data

A classic example in time series analysis is the **NOAA Mauna Loa CO₂ dataset**. This dataset records the atmospheric carbon dioxide concentration (in parts per million, ppm) at the Mauna Loa Observatory in Hawaii, starting from the late 1950s.

If you look at the plot (similar to the one in the slides), you'll notice two prominent features:

- **A long-term upward trend:**  
    This reflects the increasing CO₂ concentration in the atmosphere due to human activities. The trend is roughly linear, but may include some non-linearities.

- **Annual seasonality:**  
    There's a clear oscillating pattern within each year, caused by seasonal changes in plant growth and decay. CO₂ levels drop during the Northern Hemisphere's growing season and rise during winter.

This dataset is a perfect candidate for Gaussian Processes because it exhibits clear structure that can be modeled by combining different kernels. We can think of the observed CO₂ concentration as a sum of a long-term trend function and a seasonal function, plus some noise:

$$
y(x) = f_{\text{trend}}(x) + f_{\text{seasonal}}(x) + \epsilon
$$

where  
- $f_{\text{trend}}(x)$ models the long-term trend,  
- $f_{\text{seasonal}}(x)$ models the annual seasonality,  
- $\epsilon$ is observational noise.

## Additive Kernels and Multi-Output GPs

One of the most powerful ways to build expressive Gaussian Processes (GPs) is by combining simpler kernels. As discussed previously, if $k_1(x, x')$ and $k_2(x, x')$ are valid kernels, then their sum

$$
k_{\text{sum}}(x, x') = k_1(x, x') + k_2(x, x')
$$

is also a valid kernel.

**Intuition:**  
If we have two independent Gaussian Processes, $f_1 \sim \mathcal{GP}(m_1, k_1)$ and $f_2 \sim \mathcal{GP}(m_2, k_2)$, then their sum $f = f_1 + f_2$ is also a Gaussian Process:

$$
f = f_1 + f_2 \sim \mathcal{GP}(m_1 + m_2,\, k_1 + k_2)
$$

This property is extremely useful for modeling data that is a combination of different underlying processes. For example, in the Mauna Loa CO$_2$ data, we can model the long-term trend with one kernel (e.g., a squared exponential or linear kernel) and the seasonality with another kernel (e.g., a periodic kernel). The overall covariance structure of the data is then the sum of these individual kernels.
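As a quick numerical sanity check of this property, the sketch below draws many independent samples of $f_1$ and $f_2$ on a small grid and compares the empirical covariance of $f_1 + f_2$ with $K_1 + K_2$. The helpers `rbf` and `periodic`, the grid, and the sample size are illustrative choices, not part of the original notebook.

```python
import jax
import jax.numpy as jnp


def rbf(x1, x2, sigma=1.0, lengthscale=1.0):
    # Squared exponential kernel for 1-D inputs
    d2 = (x1[:, None] - x2[None, :]) ** 2
    return sigma**2 * jnp.exp(-0.5 * d2 / lengthscale**2)


def periodic(x1, x2, sigma=1.0, lengthscale=1.0, period=1.0):
    # Periodic kernel for 1-D inputs
    s = jnp.sin(jnp.pi * jnp.abs(x1[:, None] - x2[None, :]) / period) ** 2
    return sigma**2 * jnp.exp(-2.0 * s / lengthscale**2)


x = jnp.linspace(0.0, 3.0, 6)
K1 = rbf(x, x, lengthscale=1.0)
K2 = periodic(x, x, sigma=0.5)

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
n_samples = 200_000
zeros = jnp.zeros(x.shape[0])
# method="svd" tolerates (near-)singular covariance matrices
f1 = jax.random.multivariate_normal(key1, zeros, K1, shape=(n_samples,), method="svd")
f2 = jax.random.multivariate_normal(key2, zeros, K2, shape=(n_samples,), method="svd")
f = f1 + f2  # samples of the sum, shape (n_samples, 6)

# Both priors have zero mean, so E[f f^T] estimates Cov[f]
K_empirical = f.T @ f / n_samples
max_abs_err = jnp.max(jnp.abs(K_empirical - (K1 + K2)))
print(f"max |empirical - (K1 + K2)| = {float(max_abs_err):.4f}")
```

The discrepancy shrinks roughly like $1/\sqrt{n_{\text{samples}}}$, as expected for a Monte Carlo estimate.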

---

This approach is related to **Multi-Output GPs** and **Composite Kernels**. In a general multi-output GP, the joint covariance of $(f_1, f_2)$ contains cross-covariance blocks that couple the outputs. In the simplest case the outputs are independent, so the joint covariance is block-diagonal, and summing the outputs leads directly to additive kernels:

$$
\begin{align*}
p(f_1) &= \mathcal{GP}(f_1;\, 0,\, k_1) \\
p(f_2) &= \mathcal{GP}(f_2;\, 0,\, k_2)
\end{align*}
$$

The joint distribution of $f_1$ and $f_2$ is:

$$
p(f_1, f_2) = \mathcal{GP}\left(
\begin{bmatrix} f_1 \\ f_2 \end{bmatrix};\,
\begin{bmatrix} 0 \\ 0 \end{bmatrix},\,
\begin{bmatrix} k_1 & 0 \\ 0 & k_2 \end{bmatrix}
\right)
$$

If we define $f = f_1 + f_2$, its distribution is:

$$
p(f) = \mathcal{GP}(f;\, 0,\, k_1 + k_2)
$$
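To see why, evaluate both functions at inputs $X$ and write the sum as a linear map of the stacked vector, $f = A \begin{bmatrix} f_1 \\ f_2 \end{bmatrix}$ with $A = \begin{bmatrix} I & I \end{bmatrix}$. A linear map sends a Gaussian with covariance $\Sigma$ to one with covariance $A \Sigma A^\top$, where $K_i = k_i(X, X)$:

$$
\operatorname{Cov}[f] =
\begin{bmatrix} I & I \end{bmatrix}
\begin{bmatrix} K_1 & 0 \\ 0 & K_2 \end{bmatrix}
\begin{bmatrix} I \\ I \end{bmatrix}
= K_1 + K_2
$$

With non-zero cross-covariance blocks $K_{12}$ and $K_{21}$ (correlated outputs), the same product would be $K_1 + K_{12} + K_{21} + K_2$; independence is what reduces it to the plain sum of kernels.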

**Conclusion:**  
We can model complex functions by summing simpler, interpretable components, each defined by its own kernel.

In [11]:
import polars as pl
import plotly.express as px

data = pl.read_csv("co2_annmean_mlo.csv", truncate_ragged_lines=True)
data.head()

| year (i64) | mean (f64) | unc (f64) |
|---|---|---|
| 1959 | 315.98 | 0.12 |
| 1960 | 316.91 | 0.12 |
| 1961 | 317.64 | 0.12 |
| 1962 | 318.45 | 0.12 |
| 1963 | 318.99 | 0.12 |


In [1]:
import jax.numpy as jnp
from typing import Callable

# --- Re-defining necessary kernel functions from previous lectures for completeness ---


def squared_exponential_kernel(
    x1: jnp.ndarray, x2: jnp.ndarray, sigma: float = 1.0, lengthscale: float = 1.0
) -> jnp.ndarray:
    """
    Computes the Squared Exponential (RBF) kernel matrix between two sets of points.
    Args:
        x1: First set of input points. Shape (N1, D).
        x2: Second set of input points. Shape (N2, D).
        sigma: Output variance (amplitude) hyperparameter.
        lengthscale: Length scale hyperparameter.
    Returns:
        The kernel matrix K, where K[i, j] = k(x1[i], x2[j]). Shape (N1, N2).
    """
    x1 = jnp.atleast_2d(x1)
    x2 = jnp.atleast_2d(x2)
    sq_dist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    K = sigma**2 * jnp.exp(-0.5 * sq_dist / lengthscale**2)
    return K


# --- New Kernel: Periodic Kernel ---
def periodic_kernel(
    x1: jnp.ndarray,
    x2: jnp.ndarray,
    sigma: float = 1.0,
    lengthscale: float = 1.0,
    period: float = 1.0,
) -> jnp.ndarray:
    """
    Computes the Periodic kernel matrix.
    This kernel is suitable for modeling cyclical patterns.
    Args:
        x1: First set of input points. Shape (N1, D). Assumes D=1 for simplicity.
        x2: Second set of input points. Shape (N2, D). Assumes D=1 for simplicity.
        sigma: Output variance (amplitude) hyperparameter.
        lengthscale: Length scale hyperparameter.
        period: The period of the cyclical pattern.
    Returns:
        The kernel matrix K. Shape (N1, N2).
    """
    # Flatten (N, 1) or (N,) inputs to shape (N,); unlike squeeze(), reshape keeps a
    # single input point as shape (1,) so the broadcasting below still works (D=1 assumed)
    x1 = jnp.reshape(x1, (-1,))
    x2 = jnp.reshape(x2, (-1,))

    # Compute the absolute difference between all pairs of points
    diff = jnp.abs(x1[:, None] - x2[None, :])

    # Compute the periodic term
    sin_term = jnp.sin(jnp.pi * diff / period) ** 2

    # Compute the kernel matrix
    K = sigma**2 * jnp.exp(-2.0 * sin_term / lengthscale**2)
    return K


# --- New Kernel: Linear Kernel ---
def linear_kernel(
    x1: jnp.ndarray, x2: jnp.ndarray, c: float = 0.0, sigma_b: float = 1.0
) -> jnp.ndarray:
    """
    Computes the Linear kernel matrix.
    This kernel models linear trends.
    Args:
        x1: First set of input points. Shape (N1, D).
        x2: Second set of input points. Shape (N2, D).
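        c: Input offset. There is no bias term, so every prior sample
           f(x) = w * (x - c) is exactly zero at x = c.
        sigma_b: Standard deviation of the slope prior (the kernel scales by sigma_b**2).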
    Returns:
        The kernel matrix K. Shape (N1, N2).
    """
    x1 = jnp.atleast_2d(x1)
    x2 = jnp.atleast_2d(x2)
    # The linear kernel is (x1 - c) @ (x2 - c).T
    K = sigma_b**2 * jnp.dot((x1 - c), (x2 - c).T)
    return K


# --- Additive Kernel Combiner ---
def add_kernels(*kernel_funcs: Callable) -> Callable:
    """
    Combines multiple kernel functions into an additive kernel.
    The resulting kernel function will sum the outputs of the input kernel functions.
    Args:
        *kernel_funcs: Variable number of kernel functions to be added.
                      Each kernel function should take (x1, x2) as arguments.
    Returns:
        A new callable kernel function that computes the sum of the input kernels.
    """

    def combined_kernel(x1: jnp.ndarray, x2: jnp.ndarray) -> jnp.ndarray:
        K_sum = jnp.zeros((x1.shape[0], x2.shape[0]))
        for k_func in kernel_funcs:
            K_sum += k_func(x1, x2)
        return K_sum

    return combined_kernel
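Since a sum of valid kernels is itself a valid kernel, the Gram matrix of an additive kernel should be symmetric positive semi-definite. A minimal numerical check, assuming compact re-implementations of the kernels and a hypothetical `add_kernels_sketch` that sums with Python's built-in `sum` instead of a zero-initialized buffer:

```python
import jax.numpy as jnp


def squared_exponential(x1, x2, sigma=1.0, lengthscale=1.0):
    sq = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return sigma**2 * jnp.exp(-0.5 * sq / lengthscale**2)


def periodic(x1, x2, sigma=1.0, lengthscale=1.0, period=1.0):
    # Inputs of shape (N, 1); pairwise absolute differences of shape (N, M)
    d = jnp.abs(x1[:, None, 0] - x2[None, :, 0])
    return sigma**2 * jnp.exp(-2.0 * jnp.sin(jnp.pi * d / period) ** 2 / lengthscale**2)


def add_kernels_sketch(*kernel_funcs):
    def combined(x1, x2):
        return sum(k(x1, x2) for k in kernel_funcs)

    return combined


X = jnp.linspace(0.0, 5.0, 40)[:, None]
k_sum = add_kernels_sketch(
    lambda a, b: squared_exponential(a, b, lengthscale=2.0),
    lambda a, b: periodic(a, b, sigma=0.5, period=1.0),
)
K = k_sum(X, X)
eigvals = jnp.linalg.eigvalsh(K)  # eigenvalues of a symmetric matrix, ascending
print(bool(jnp.allclose(K, K.T)), float(eigvals.min()))
```

In floating point the smallest eigenvalue can come out as a tiny negative number; that is rounding, not a kernel error, and it is why GP code adds noise or jitter to the diagonal.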

In [2]:
import jax.numpy as jnp
from jax.scipy.linalg import solve  # Use JAX's solve for numerical stability
from typing import Callable


def gp_predict(
    X_train: jnp.ndarray,
    y_train: jnp.ndarray,
    X_test: jnp.ndarray,
    mean_func: Callable[[jnp.ndarray], jnp.ndarray],
    kernel_func: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
    noise_variance: float = 1e-6,  # Small value for numerical stability if noise is zero
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Performs Gaussian Process regression prediction.
    Args:
        X_train: Training input points. Shape (N_train, D).
        y_train: Training output values. Shape (N_train,).
        X_test: Test input points. Shape (N_test, D).
        mean_func: Mean function.
        kernel_func: Kernel function.
        noise_variance: Variance of the observational noise.
    Returns:
        A tuple containing:
            mu_pred: Predictive mean at test points. Shape (N_test,).
            Sigma_pred: Predictive covariance matrix at test points. Shape (N_test, N_test).
    """
    X_train = jnp.atleast_2d(X_train)
    X_test = jnp.atleast_2d(X_test)
    y_train = jnp.atleast_1d(y_train)

    K_train_train = kernel_func(X_train, X_train) + noise_variance * jnp.eye(
        X_train.shape[0]
    )
    K_test_train = kernel_func(X_test, X_train)
    K_test_test = kernel_func(X_test, X_test)

    # Compute the inverse term (K_train_train + sigma_noise^2 I)^-1 (y_train - m(X_train)).
    # K_train_train is symmetric positive definite, so assume_a="pos" lets solve use a
    # Cholesky factorization instead of a general LU solve.
    K_train_train_inv_y_diff = solve(
        K_train_train, y_train - mean_func(X_train), assume_a="pos"
    )

    # Compute predictive mean
    mu_pred = mean_func(X_test) + jnp.dot(K_test_train, K_train_train_inv_y_diff)

    # Compute the term (K_train_train + sigma_noise^2 I)^-1 K_test_train^T
    K_train_train_inv_K_test_train_T = solve(
        K_train_train, K_test_train.T, assume_a="pos"
    )

    # Compute predictive covariance
    Sigma_pred = K_test_test - jnp.dot(K_test_train, K_train_train_inv_K_test_train_T)

    return mu_pred, Sigma_pred
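To see these formulas at work on a toy problem, the sketch below implements the same zero-mean posterior with a Cholesky factorization (`cho_factor`/`cho_solve` from `jax.scipy.linalg`) instead of two general solves. The name `gp_posterior`, the sine targets, and the kernel settings are hypothetical choices for illustration. Near the data the posterior mean tracks the targets with small variance; far from the data it reverts to the prior (mean 0, variance $k(x, x) = 1$).

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def rbf(x1, x2, lengthscale=1.0):
    # Unit-amplitude squared exponential kernel for inputs of shape (N, 1)
    return jnp.exp(-0.5 * (x1[:, None, 0] - x2[None, :, 0]) ** 2 / lengthscale**2)


def gp_posterior(X, y, X_star, kernel, noise_variance):
    """Zero-mean GP posterior via a Cholesky factorization of K_XX + noise * I."""
    K = kernel(X, X) + noise_variance * jnp.eye(X.shape[0])
    factor = cho_factor(K, lower=True)  # factor once, reuse for both solves
    K_sX = kernel(X_star, X)
    mean = K_sX @ cho_solve(factor, y)
    cov = kernel(X_star, X_star) - K_sX @ cho_solve(factor, K_sX.T)
    return mean, cov


X = jnp.linspace(0.0, 6.0, 8)[:, None]
y = jnp.sin(X[:, 0])
X_star = jnp.array([[1.5], [20.0]])  # one point inside the data range, one far outside
mean, cov = gp_posterior(X, y, X_star, rbf, noise_variance=1e-3)
print(mean, jnp.diag(cov))
```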


In [21]:
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt

# Set a random seed for reproducibility
key = random.PRNGKey(42)

# --- Simulate Mauna Loa-like Data ---
# We'll simulate data with a linear trend and a seasonal component
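num_points = 100
# Years from 1958 to 2023 (similar to Mauna Loa range)
# Note: the spacing is 65 / 99 ≈ 0.66 years, i.e. fewer than two samples per
# 1-year seasonal cycle, so the seasonal signal is aliased in this simulation.
X_sim = jnp.linspace(1958.0, 2023.0, num_points)[:, None]  # Inputs (years)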

# True underlying components
true_trend_slope = 1  # ppm per year
true_trend_intercept = 310.0  # ppm at year 1958
true_seasonal_amplitude = 5  # ppm
true_seasonal_period = 1.0  # years

true_trend = true_trend_intercept + true_trend_slope * (X_sim - X_sim.min())
true_seasonal = true_seasonal_amplitude * jnp.sin(
    2 * jnp.pi * (X_sim - X_sim.min()) / true_seasonal_period
)

# Observed data is sum of components plus noise
true_noise_variance = 0.5**2
y_sim = (true_trend + true_seasonal).squeeze() + random.normal(
    key, shape=(num_points,)
) * jnp.sqrt(true_noise_variance)

In [24]:
# --- Define GP Components with Additive Kernel ---
# 1. Linear Kernel for the trend
# We'll center the linear kernel around the start year to make its interpretation easier
linear_k_func = lambda x1, x2: linear_kernel(x1, x2, c=X_sim.min(), sigma_b=1.0)

# 2. Periodic Kernel for seasonality
# Assuming an annual period (1.0 year)
periodic_k_func = lambda x1, x2: periodic_kernel(
    x1, x2, sigma=1.0, lengthscale=0.5, period=1.0
)

# Combine them into an additive kernel
combined_kernel = add_kernels(linear_k_func, periodic_k_func)

# Zero mean function (common for GPs, especially when the trend is modeled by a kernel)
zero_mean_func = lambda x: jnp.zeros(x.shape[0])
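One caveat with this setup: `linear_kernel` has no bias term, so every prior sample is $f(x) = w\,(x - c)$ and is exactly zero at $x = c$. Combined with `zero_mean_func`, the trend component has no prior mechanism for the roughly 310 ppm level at the start of the series; only the other kernel components can absorb it. Two common remedies are a constant mean function (for example, the mean of the training targets) or an additional constant (bias) kernel with a broad prior. The sketch below illustrates both; `linear_kernel_sketch`, `constant_kernel`, `constant_mean_func`, `sigma_0=400.0`, and the three target values are hypothetical, chosen only for illustration.

```python
import jax.numpy as jnp


def linear_kernel_sketch(x1, x2, c=0.0, sigma_b=1.0):
    # Same form as the notebook's linear_kernel: sigma_b^2 (x1 - c)(x2 - c)^T
    return sigma_b**2 * (x1 - c) @ (x2 - c).T


def constant_kernel(x1, x2, sigma_0=1.0):
    # Hypothetical helper: every pair of inputs shares covariance sigma_0^2
    return sigma_0**2 * jnp.ones((x1.shape[0], x2.shape[0]))


X = jnp.array([[1958.0], [1990.0], [2023.0]])
K_lin = linear_kernel_sketch(X, X, c=1958.0)
print(jnp.diag(K_lin))  # first entry is 0.0: no prior variance at x = c

# Option 1 (hypothetical): a constant mean at the empirical mean of the targets
y_example = jnp.array([310.0, 342.0, 375.0])  # placeholder values, not real data
constant_mean_func = lambda x: jnp.full(x.shape[0], jnp.mean(y_example))

# Option 2 (hypothetical): add a bias kernel with a broad prior on the offset
K_with_bias = K_lin + constant_kernel(X, X, sigma_0=400.0)
print(jnp.diag(K_with_bias))
```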

In [25]:
# --- Perform GP Regression ---
# Use all simulated data as training data for this example
X_train = X_sim
y_train = y_sim

# Create test points for prediction
X_test = jnp.linspace(1950.0, 2030.0, 200)[:, None]

# Perform prediction using the combined kernel
mu_pred, Sigma_pred = gp_predict(
    X_train,
    y_train,
    X_test,
    zero_mean_func,
    combined_kernel,
    noise_variance=true_noise_variance,  # Use the true noise for this demo
)

predictive_std = jnp.sqrt(jnp.diag(Sigma_pred))


In [26]:
# --- Plotting Results ---
import plotly.graph_objects as go

fig = go.Figure()

# Scatter plot for training data
fig.add_trace(
    go.Scatter(
        x=X_train[:, 0],
        y=y_train,
        mode="markers",
        name="Simulated Data",
        marker=dict(size=6, opacity=0.7),
    )
)

# Predictive mean line
fig.add_trace(
    go.Scatter(
        x=X_test[:, 0],
        y=mu_pred,
        mode="lines",
        name="GP Predictive Mean",
        line=dict(color="red"),
    )
)

# Approximate 95% band (mean ± 2 std) for the latent function; observation noise is not included
fig.add_trace(
    go.Scatter(
        x=jnp.concatenate([X_test[:, 0], X_test[::-1, 0]]),
        y=jnp.concatenate(
            [mu_pred - 2 * predictive_std, (mu_pred + 2 * predictive_std)[::-1]]
        ),
        fill="toself",
        fillcolor="rgba(255,0,0,0.2)",
        line=dict(color="rgba(255,255,255,0)"),
        hoverinfo="skip",
        showlegend=True,
        name="95% Confidence Interval",
    )
)

fig.update_layout(
    title="Gaussian Process Regression with Additive Kernel (Simulated Mauna Loa Data)",
    xaxis_title="Year",
    yaxis_title="CO2 Concentration [ppm]",
    legend=dict(x=0.01, y=0.99),
    template="simple_white",
)

fig.show()

# --- Optional: plotting individual components ---
# Calling gp_predict with linear_k_func alone would not isolate the trend: y_train is the
# sum of both components, so the training covariance must still use the combined kernel.
# The posterior mean of the trend keeps the combined kernel inside the inverse and uses
# only the trend kernel for the test-to-train cross-covariance:
#     mu_trend = linear_k_func(X_test, X_train) @ solve(
#         combined_kernel(X_train, X_train) + true_noise_variance * jnp.eye(len(X_train)),
#         y_train,
#     )
# (see the source-separation section below for the general formula).


## Kernel Learning: Optimizing Hyperparameters

In our previous examples, we manually set the hyperparameters of our kernels (such as $\sigma$, $\text{lengthscale}$, $\text{period}$, $\sigma_{\text{noise}}^2$). However, in real applications, we need to learn these values from the data. This process is called **hyperparameter optimization** or **kernel learning**.

The standard approach for learning hyperparameters in GPs is to maximize the **marginal likelihood** (also known as the model evidence). The marginal likelihood $p(\mathbf{y} \mid X, \theta)$ is the probability of observing the training data $\mathbf{y}$ given the inputs $X$ and the hyperparameters $\theta$, after marginalizing out the latent function $f$:

$$
p(\mathbf{y} \mid X, \theta) = \int p(\mathbf{y} \mid f, X, \theta)\, p(f \mid X, \theta)\, df
$$

For Gaussian Processes with a Gaussian likelihood, this integral has a convenient analytical solution:

$$
p(\mathbf{y} \mid X, \theta) = \mathcal{N}(\mathbf{y};\, m(X),\, K_{XX} + \sigma_{\text{noise}}^2 I)
$$

where $m(X)$ is the mean function evaluated at the training inputs, $K_{XX}$ is the kernel matrix evaluated at training inputs with hyperparameters $\theta$, and $\sigma_{\text{noise}}^2 I$ is the noise covariance.

To find the optimal hyperparameters $\hat{\theta}$, we maximize this marginal likelihood, or equivalently, minimize its negative logarithm (the **negative log-marginal likelihood**):

$$
\hat{\theta} = \arg\max_{\theta}\, p(\mathbf{y} \mid X, \theta) = \arg\min_{\theta}\, -\log p(\mathbf{y} \mid X, \theta)
$$

The negative log-marginal likelihood (NLML) for a zero-mean GP is given by:

$$
-\log p(\mathbf{y} \mid X, \theta) = \frac{1}{2} \mathbf{y}^\top (K_{XX} + \sigma_{\text{noise}}^2 I)^{-1} \mathbf{y}
+ \frac{1}{2} \log \left| K_{XX} + \sigma_{\text{noise}}^2 I \right|
+ \frac{N}{2} \log(2\pi)
$$

where $N$ is the number of training data points.
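The NLML can be evaluated stably with a Cholesky factor $L$ of $K_{XX} + \sigma_{\text{noise}}^2 I$: the quadratic term becomes a pair of triangular solves, and $\frac{1}{2}\log|K_{XX} + \sigma_{\text{noise}}^2 I| = \sum_i \log L_{ii}$. The sketch below, with an illustrative unit-lengthscale RBF kernel, cosine targets, and a hypothetical helper `nlml`, cross-checks this against `jax.scipy.stats.multivariate_normal.logpdf`.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from jax.scipy.stats import multivariate_normal


def nlml(K, y, noise_variance):
    """Negative log-marginal likelihood of a zero-mean GP, via Cholesky."""
    N = y.shape[0]
    L, lower = cho_factor(K + noise_variance * jnp.eye(N), lower=True)
    alpha = cho_solve((L, lower), y)  # (K + noise I)^-1 y
    data_fit = 0.5 * y @ alpha
    complexity = jnp.sum(jnp.log(jnp.diag(L)))  # equals 0.5 * log|K + noise I|
    return data_fit + complexity + 0.5 * N * jnp.log(2.0 * jnp.pi)


x = jnp.linspace(0.0, 4.0, 10)
K = jnp.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)  # unit RBF kernel matrix
y = jnp.cos(x)
value = nlml(K, y, noise_variance=0.1)
reference = -multivariate_normal.logpdf(y, jnp.zeros(10), K + 0.1 * jnp.eye(10))
print(value, reference)
```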

This objective is generally non-convex in $\theta$, so an optimizer may converge to a local optimum, and the result can depend on the initial hyperparameters. In practice we use gradient-based optimizers (such as Adam or L-BFGS), which require the gradient of the NLML with respect to each hyperparameter. Libraries with **automatic differentiation**, such as JAX, compute these gradients for us, so they never have to be derived by hand.

The slides illustrate the computational graph for calculating the loss $L(\theta)$ and how automatic differentiation (forward or backward mode) can be used to compute gradients. This is a powerful concept that underpins much of modern machine learning.
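A minimal sketch of the automatic-differentiation step, assuming a single hyperparameter (an RBF lengthscale parametrized through softplus, as in the notebook) and a fixed noise variance of 0.1: `jax.value_and_grad` returns the NLML and its exact gradient, which the snippet compares with a central finite difference. The function name and data are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

x = jnp.linspace(0.0, 4.0, 10)
y = jnp.cos(x)


def nlml_of_raw_lengthscale(raw_ell):
    ell = jax.nn.softplus(raw_ell)  # keep the lengthscale positive
    K = jnp.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / ell**2) + 0.1 * jnp.eye(10)
    L = jnp.linalg.cholesky(K)
    alpha = cho_solve((L, True), y)
    # 0.5 y^T K^-1 y + 0.5 log|K| + (N / 2) log(2 pi) with N = 10
    return 0.5 * y @ alpha + jnp.sum(jnp.log(jnp.diag(L))) + 5.0 * jnp.log(2 * jnp.pi)


value, grad_auto = jax.value_and_grad(nlml_of_raw_lengthscale)(0.5)

eps = 1e-2
grad_fd = (nlml_of_raw_lengthscale(0.5 + eps) - nlml_of_raw_lengthscale(0.5 - eps)) / (2 * eps)
print(value, grad_auto, grad_fd)
```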

In [101]:
import jax.numpy as jnp
import jax
from jax import jit, grad
from jax.scipy.linalg import solve, cholesky
import optax  # A JAX-based optimization library
import matplotlib.pyplot as plt
from functools import partial

# --- Re-using previous kernel definitions for this example ---
# (squared_exponential_kernel, periodic_kernel, linear_kernel, add_kernels, gp_predict)
# Assume these are defined in the notebook context.


# --- Log-Marginal Likelihood Function ---
@partial(jit, static_argnames=["kernel_types", "mean_func", "base_kernel_funcs"])
def negative_log_marginal_likelihood(
    params: dict,
    kernel_types: tuple,
    X_train: jnp.ndarray,
    y_train: jnp.ndarray,
    mean_func: Callable[[jnp.ndarray], jnp.ndarray],
    base_kernel_funcs: tuple[Callable, ...],
) -> float:
    """
    Computes the negative log-marginal likelihood (NLML) for a GP.

    Args:
        params: A dictionary of unconstrained ("raw") hyperparameters. Expected keys:
                'noise_variance_raw': Scalar mapped through softplus to the noise variance.
                'kernel_params': A list of dictionaries, one for each base kernel,
                                 containing its raw hyperparameters (e.g., 'sigma_raw',
                                 'lengthscale_raw', 'period_raw', 'sigma_b_raw') and the
                                 unconstrained offset 'c' of the linear kernel.
        kernel_types: Tuple of kernel names, in the same order as params['kernel_params'].
        X_train: Training input points.
        y_train: Training output values.
        mean_func: Mean function.
        base_kernel_funcs: Tuple of base kernel functions (e.g., (linear_kernel, periodic_kernel)).
                           Kept for reference; the body dispatches on kernel_types instead.

    Returns:
        The negative log-marginal likelihood.
    """
    # Use softplus to ensure positive hyperparameters where necessary
    noise_variance = jax.nn.softplus(params["noise_variance_raw"])

    # Construct the combined kernel function with current hyperparameters
    def current_combined_kernel(x1, x2):
        K_sum = jnp.zeros((x1.shape[0], x2.shape[0]))
        for i, k_func_name in enumerate(kernel_types):
            k_params = params["kernel_params"][i]

            if k_func_name == "linear_kernel":
                # c can be negative, sigma_b must be positive
                K_sum += linear_kernel(
                    x1,
                    x2,
                    c=k_params["c"],
                    sigma_b=jax.nn.softplus(k_params["sigma_b_raw"]),
                )
            elif k_func_name == "periodic_kernel":
                # sigma, lengthscale, period must be positive
                K_sum += periodic_kernel(
                    x1,
                    x2,
                    sigma=jax.nn.softplus(k_params["sigma_raw"]),
                    lengthscale=jax.nn.softplus(k_params["lengthscale_raw"]),
                    period=jax.nn.softplus(k_params["period_raw"]),
                )
            elif k_func_name == "squared_exponential_kernel":
                # sigma, lengthscale must be positive
                K_sum += squared_exponential_kernel(
                    x1,
                    x2,
                    sigma=jax.nn.softplus(k_params["sigma_raw"]),
                    lengthscale=jax.nn.softplus(k_params["lengthscale_raw"]),
                )
            # Add more kernel types as needed
        return K_sum

    K_XX = current_combined_kernel(X_train, X_train)
    K_XX_noisy = K_XX + noise_variance * jnp.eye(X_train.shape[0])

    # Cholesky factor of the (symmetric positive definite) noisy kernel matrix.
    # Inside a jit-compiled function a failed Cholesky does not raise a Python
    # exception (the result typically contains NaNs), so a try/except here would
    # never trigger. The softplus-constrained noise variance keeps K_XX_noisy
    # positive definite in practice.
    L = cholesky(K_XX_noisy, lower=True)

    y_centered = y_train - mean_func(X_train)
    # alpha = (K_XX + sigma_noise^2 I)^-1 (y - m(X)), via two triangular solves with L
    alpha = jax.scipy.linalg.cho_solve((L, True), y_centered)

    # Log determinant using Cholesky factor: log|A| = 2 * sum(log(diag(L)))
    log_det_K = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))

    # Negative Log-Marginal Likelihood (NLML)
    nlml = (
        0.5 * jnp.dot(y_centered, alpha)
        + 0.5 * log_det_K
        + 0.5 * X_train.shape[0] * jnp.log(2 * jnp.pi)
    )

    return nlml

In [None]:
# --- Optimization Loop ---
def optimize_hyperparameters(
    X_train: jnp.ndarray,
    y_train: jnp.ndarray,
    mean_func: Callable[[jnp.ndarray], jnp.ndarray],
    base_kernel_funcs: tuple[Callable, ...],  # must be hashable (a tuple): it is a jit static argument
    initial_params: dict,
    num_steps: int = 1000,
    learning_rate: float = 0.01,
) -> tuple[dict, list]:
    """
    Optimizes GP hyperparameters by minimizing the NLML with the Adam optimizer (optax).

    Args:
        X_train: Training input points.
        y_train: Training output values.
        mean_func: Mean function.
        base_kernel_funcs: A tuple of base kernel functions.
        initial_params: Initial dictionary of raw hyperparameters, including a
                        'kernel_types' entry (tuple of kernel names).
        num_steps: Number of optimization steps.
        learning_rate: Learning rate for the optimizer.

    Returns:
        A tuple (optimized raw hyperparameters, list of per-step loss values).
    """
    optimizer = optax.adam(learning_rate)
    # Separate the static kernel names from the trainable parameters without
    # mutating the caller's dict, so the cell can be re-run safely.
    kernel_types = initial_params["kernel_types"]
    params = {k: v for k, v in initial_params.items() if k != "kernel_types"}
    opt_state = optimizer.init(params)

    # Define the loss function to be optimized
    loss_fn = lambda p: negative_log_marginal_likelihood(
        p, kernel_types, X_train, y_train, mean_func, base_kernel_funcs
    )

    # value_and_grad returns the loss and its gradient from a single evaluation
    loss_and_grad_fn = jax.value_and_grad(loss_fn)

    losses = []
    for step in range(num_steps):
        loss_value, grads = loss_and_grad_fn(params)
        losses.append(loss_value)

        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        if step % 100 == 0:
            print(f"Step {step}, Loss: {loss_value:.4f}")

    print(f"Final Loss: {loss_fn(params):.4f}")
    return params, losses


In [103]:
# Example usage with simulated Mauna Loa data:
# (Requires X_sim, y_sim, linear_kernel, periodic_kernel, squared_exponential_kernel, zero_mean_func from above)

# Define the base kernel functions to be used in the additive model
# We'll use a linear kernel for the trend and a periodic kernel for seasonality
base_kernels_for_optimization = (linear_kernel, periodic_kernel)
kernel_type_names = ("linear_kernel", "periodic_kernel")  # Names to map to params

# Initial hyperparameters, stored as raw (unconstrained) values.
# softplus(raw) = value, so raw = log(exp(value) - 1) is the inverse softplus.
inv_softplus = lambda v: jnp.log(jnp.exp(v) - 1.0)

initial_params = {
    "noise_variance_raw": jnp.array(inv_softplus(0.5**2)),  # initial noise variance 0.25
    "kernel_types": kernel_type_names,  # tuple of kernel names for lookup
    "kernel_params": [
        # Linear kernel params
        {
            "c": jnp.array(1958.0),  # unconstrained input offset
            "sigma_b_raw": jnp.array(inv_softplus(0.01)),  # initial slope std 0.01
        },
        # Periodic kernel params
        {
            "sigma_raw": jnp.array(inv_softplus(2.5)),  # initial amplitude 2.5
            "lengthscale_raw": jnp.array(inv_softplus(0.5)),  # initial lengthscale 0.5
            "period_raw": jnp.array(inv_softplus(1.0)),  # initial period 1.0 year
        },
    ],
}
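The raw parameters above rely on the inverse softplus $r = \log(e^{v} - 1)$. Written that way, $e^{v}$ overflows in float32 (JAX's default precision) once $v$ exceeds roughly 88. A sketch of an algebraically equivalent, overflow-free form, $r = v + \log(1 - e^{-v})$, with a round-trip check through `jax.nn.softplus` (the helper name `inverse_softplus` is hypothetical):

```python
import jax
import jax.numpy as jnp


def inverse_softplus(v):
    # Raw value r with softplus(r) = v, for v > 0.
    # log(exp(v) - 1) = v + log(1 - exp(-v)); the right-hand form avoids exp overflow.
    return v + jnp.log(-jnp.expm1(-v))


targets = jnp.array([0.01, 0.25, 1.0, 2.5, 100.0])
raw = inverse_softplus(targets)
recovered = jax.nn.softplus(raw)
naive = jnp.log(jnp.exp(targets) - 1.0)  # the 100.0 entry overflows to inf in float32
print(raw, recovered, naive)
```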


In [104]:
# Run optimization
optimized_raw_params, losses = optimize_hyperparameters(
    X_sim,
    y_sim,
    zero_mean_func,
    base_kernels_for_optimization,
    initial_params,
    num_steps=2000,
    learning_rate=0.01,
)


Step 0, Loss: 42144.8984
Step 100, Loss: 12721.4512
Step 200, Loss: 7400.3374
Step 300, Loss: 5400.9375
Step 400, Loss: 4330.1328
Step 500, Loss: 3649.0093
Step 600, Loss: 3170.9727
Step 700, Loss: 2813.5593
Step 800, Loss: 2534.2937
Step 900, Loss: 2308.9250
Step 1000, Loss: 2122.4521
Step 1100, Loss: 1965.2068
Step 1200, Loss: 1830.4062
Step 1300, Loss: 1713.3069
Step 1400, Loss: 1610.4957
Step 1500, Loss: 1519.4294
Step 1600, Loss: 1437.7784
Step 1700, Loss: 1365.2867
Step 1800, Loss: 1297.8950
Step 1900, Loss: 1237.5189
Final Loss: 1181.9966


In [109]:
# Convert raw optimized parameters back to their positive values for interpretation
optimized_params = {
    "noise_variance": jax.nn.softplus(optimized_raw_params["noise_variance_raw"]),
    "kernel_params": [],
}
for i, k_type in enumerate(kernel_type_names):
    k_raw_params = optimized_raw_params["kernel_params"][i]
    if k_type == "linear_kernel":
        optimized_params["kernel_params"].append(
            {
                "c": k_raw_params["c"],
                "sigma_b": jax.nn.softplus(k_raw_params["sigma_b_raw"]),
            }
        )
    elif k_type == "periodic_kernel":
        optimized_params["kernel_params"].append(
            {
                "sigma": jax.nn.softplus(k_raw_params["sigma_raw"]),
                "lengthscale": jax.nn.softplus(k_raw_params["lengthscale_raw"]),
                "period": jax.nn.softplus(k_raw_params["period_raw"]),
            }
        )
    elif k_type == "squared_exponential_kernel":
        optimized_params["kernel_params"].append(
            {
                "sigma": jax.nn.softplus(k_raw_params["sigma_raw"]),
                "lengthscale": jax.nn.softplus(k_raw_params["lengthscale_raw"]),
            }
        )

print("\nOptimized Hyperparameters (positive values):")
print(optimized_params)



Optimized Hyperparameters (positive values):
{'noise_variance': Array(0.20302267, dtype=float32), 'kernel_params': [{'c': Array(1950.8862, dtype=float32), 'sigma_b': Array(0.16462138, dtype=float32)}, {'sigma': Array(7.023222, dtype=float32), 'lengthscale': Array(2.5246286, dtype=float32), 'period': Array(0.99984074, dtype=float32)}]}


In [115]:
# Plot the optimization loss
import plotly.graph_objects as go

fig_loss = go.Figure()
fig_loss.add_trace(go.Scatter(y=[float(l) for l in losses], mode="lines", name="NLML"))
fig_loss.update_layout(
    title="NLML during Hyperparameter Optimization",
    xaxis_title="Optimization Step",
    yaxis_title="Negative Log-Marginal Likelihood (NLML)",
    template="simple_white",
)
fig_loss.show()


# --- Perform GP Prediction with Optimized Hyperparameters ---
# Construct the optimized combined kernel function
def optimized_combined_kernel(x1, x2):
    K_sum = jnp.zeros((x1.shape[0], x2.shape[0]))
    for i, k_type in enumerate(kernel_type_names):
        k_opt_params = optimized_params["kernel_params"][i]  # already-converted (positive) params
        if k_type == "linear_kernel":
            K_sum += linear_kernel(
                x1, x2, c=k_opt_params["c"], sigma_b=k_opt_params["sigma_b"]
            )
        elif k_type == "periodic_kernel":
            K_sum += periodic_kernel(
                x1,
                x2,
                sigma=k_opt_params["sigma"],
                lengthscale=k_opt_params["lengthscale"],
                period=k_opt_params["period"],
            )
        elif k_type == "squared_exponential_kernel":
            K_sum += squared_exponential_kernel(
                x1,
                x2,
                sigma=k_opt_params["sigma"],
                lengthscale=k_opt_params["lengthscale"],
            )
    return K_sum


mu_opt_pred, Sigma_opt_pred = gp_predict(
    X_sim,
    y_sim,
    X_test,
    zero_mean_func,
    optimized_combined_kernel,
    noise_variance=optimized_params["noise_variance"],
)

predictive_opt_std = jnp.sqrt(jnp.diag(Sigma_opt_pred))

# --- Plotting Results with Optimized Hyperparameters using Plotly ---

fig_opt = go.Figure()

# Scatter plot for training data
fig_opt.add_trace(
    go.Scatter(
        x=X_sim[:, 0],
        y=y_sim,
        mode="markers",
        name="Simulated Data",
        marker=dict(size=6, opacity=0.7),
    )
)

# Predictive mean line (optimized)
fig_opt.add_trace(
    go.Scatter(
        x=X_test[:, 0],
        y=mu_opt_pred,
        mode="lines",
        name="GP Predictive Mean (Optimized)",
        line=dict(color="green"),
    )
)

# 95% confidence interval (shaded area, optimized)
fig_opt.add_trace(
    go.Scatter(
        x=jnp.concatenate([X_test[:, 0], X_test[::-1, 0]]),
        y=jnp.concatenate(
            [
                mu_opt_pred - 2 * predictive_opt_std,
                (mu_opt_pred + 2 * predictive_opt_std)[::-1],
            ]
        ),
        fill="toself",
        fillcolor="rgba(0,128,0,0.2)",
        line=dict(color="rgba(255,255,255,0)"),
        hoverinfo="skip",
        showlegend=True,
        name="95% CI (Optimized)",
    )
)

fig_opt.update_layout(
    title="Gaussian Process Regression with Optimized Additive Kernel",
    xaxis_title="Year",
    yaxis_title="CO2 Concentration [ppm]",
    legend=dict(x=0.01, y=0.99),
    template="simple_white",
)

fig_opt.show()


## Source Separation with Gaussian Processes

One of the most exciting applications of structured GP models is **source separation**—extracting individual underlying signals from observed data where these signals are linearly mixed. The Mauna Loa CO₂ data is a prime example: we observe the total CO₂, but we want to separate it into its long-term trend and seasonal components.

The general principle comes from the properties of conditional Gaussian distributions. If we have a joint Gaussian distribution $p(\mathbf{x})$ and a linear observation model $p(\mathbf{y} \mid \mathbf{x}) = \mathcal{N}(\mathbf{y};\, A^\top \mathbf{x} + \mathbf{b},\, \Lambda)$, then the posterior $p(B^\top \mathbf{x} + \mathbf{c} \mid \mathbf{y})$ is also Gaussian.
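Concretely, the standard Gaussian conditioning result reads, with $G = A^\top \Sigma A + \Lambda$ the covariance of $\mathbf{y}$:

$$
\mathbf{x} \sim \mathcal{N}(\boldsymbol{\mu}, \Sigma),\quad
\mathbf{y} \mid \mathbf{x} \sim \mathcal{N}(A^\top \mathbf{x} + \mathbf{b},\, \Lambda)
\;\Longrightarrow\;
p(B^\top \mathbf{x} + \mathbf{c} \mid \mathbf{y}) = \mathcal{N}\Big(
B^\top \boldsymbol{\mu} + \mathbf{c} + B^\top \Sigma A\, G^{-1} (\mathbf{y} - A^\top \boldsymbol{\mu} - \mathbf{b}),\;
B^\top \Sigma B - B^\top \Sigma A\, G^{-1} A^\top \Sigma B
\Big)
$$

For two additive components, take $\mathbf{x}$ to stack $f_1$ and $f_2$ at the relevant points, with $\boldsymbol{\mu} = 0$, block-diagonal $\Sigma$ built from $k_1$ and $k_2$, $\mathbf{b} = 0$, $A^\top \mathbf{x} = f_1(X) + f_2(X)$, and $B^\top \mathbf{x} = f_1(\cdot)$ with $\mathbf{c} = 0$. Then $A^\top \Sigma A = K_{1,XX} + K_{2,XX}$, $B^\top \Sigma A = k_{1,\cdot X}$, and $B^\top \Sigma B = k_{1,\cdot\cdot}$.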

In our GP context, we can think of our observed data $\mathbf{y}$ as a linear combination of latent functions $f_1, f_2, \ldots$:

$$
\mathbf{y} = \sum_i f_i(X) + \text{noise}
$$

If we assume $f_i \sim \mathcal{GP}(0, k_i)$ are independent GPs, then their sum $f = \sum_i f_i$ is a GP with kernel $k = \sum_i k_i$.

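A specific case for two functions $f_1$ and $f_2$ is shown in the slides, where $\mathbf{y} = [1\ 1] \begin{bmatrix} f_1 \\ f_2 \end{bmatrix} + \boldsymbol{\epsilon}$ with $\boldsymbol{\epsilon} \sim \mathcal{N}(0, \Lambda)$: each observation is the sum of the two function values plus noise. The goal is to find the posterior distribution of $f_1$ given $\mathbf{y}$.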

The formula for $p(f_1(\cdot) \mid \mathbf{y})$ is:

$$
p(f_1(\cdot) \mid \mathbf{y}) = \mathcal{GP}\left(
f_1;\,
k_{1, \cdot X} (K_{1, XX} + K_{2, XX} + \Lambda)^{-1} \mathbf{y},\,
k_{1, \cdot \circ} - k_{1, \cdot X} (K_{1, XX} + K_{2, XX} + \Lambda)^{-1} k_{1, X \circ}
\right)
$$

Where:

- $k_{1, \cdot X}$ is the covariance vector between a new point $\cdot$ and the training points $X$, using kernel $k_1$.
- $K_{1, XX}$ and $K_{2, XX}$ are kernel matrices for $k_1$ and $k_2$ at training points $X$.
- $\Lambda$ is the noise covariance matrix (often $\sigma_\text{noise}^2 I$).
- $\mathbf{y}$ is the observed data.
- $k_{1, \cdot \circ}$ is the kernel $k_1$ evaluated between two new points $\cdot$ and $\circ$.
- $k_{1, X \circ}$ is the cross-covariance between training points $X$ and a new point $\circ$, using kernel $k_1$. Both sides of the subtracted term use $k_1$ because $\operatorname{Cov}(f_1(\circ), \mathbf{y}) = k_1(\circ, X)$; since $f_2$ is independent of $f_1$, it enters only through the inverse.
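A minimal sketch of this component posterior, assuming zero prior means: the training covariance uses the full sum of kernels, while the cross-covariance and prior covariance use only the component's own kernel. The helper `component_posterior`, the kernels `se_1d` and `periodic_1d`, and the toy trend-plus-seasonal data are illustrative assumptions, not Mauna Loa results.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def se_1d(x1, x2, sigma=1.0, lengthscale=1.0):
    return sigma**2 * jnp.exp(-0.5 * (x1[:, None] - x2[None, :]) ** 2 / lengthscale**2)


def periodic_1d(x1, x2, sigma=1.0, lengthscale=1.0, period=1.0):
    s = jnp.sin(jnp.pi * jnp.abs(x1[:, None] - x2[None, :]) / period) ** 2
    return sigma**2 * jnp.exp(-2.0 * s / lengthscale**2)


def component_posterior(k_component, k_total, X, y, X_star, noise_variance):
    """Posterior mean and covariance of one additive component given y (zero prior means)."""
    # The training covariance uses the FULL kernel: y contains every component plus noise
    factor = cho_factor(k_total(X, X) + noise_variance * jnp.eye(X.shape[0]), lower=True)
    # The cross-covariance uses only the component kernel: Cov(f_i(X_star), y) = k_i(X_star, X)
    K_sX = k_component(X_star, X)
    mean = K_sX @ cho_solve(factor, y)
    cov = k_component(X_star, X_star) - K_sX @ cho_solve(factor, K_sX.T)
    return mean, cov


# Toy data: slow trend + 1-year seasonality + noise (illustrative values only)
X = jnp.linspace(0.0, 10.0, 120)
trend_true = 0.3 * X
seasonal_true = jnp.sin(2 * jnp.pi * X)
y = trend_true + seasonal_true + 0.1 * jax.random.normal(jax.random.PRNGKey(1), X.shape)

k_trend = lambda a, b: se_1d(a, b, sigma=3.0, lengthscale=5.0)
k_season = lambda a, b: periodic_1d(a, b, sigma=1.0, lengthscale=1.0, period=1.0)
k_total = lambda a, b: k_trend(a, b) + k_season(a, b)

mu_trend, cov_trend = component_posterior(k_trend, k_total, X, y, X, 0.01)
mu_season, cov_season = component_posterior(k_season, k_total, X, y, X, 0.01)
mu_total, _ = component_posterior(k_total, k_total, X, y, X, 0.01)
```

Because both components share the weights $(K_{1,XX} + K_{2,XX} + \Lambda)^{-1}\mathbf{y}$, `mu_trend + mu_season` reproduces `mu_total` up to rounding. Both kernels can represent a constant, so how a constant offset is split between the components is not identified by the data alone.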

---

**Intuition for Source Separation:**  
When we observe $\mathbf{y}$ as a sum of $f_1$ and $f_2$, the GP framework allows us to "attribute" parts of the observed variance to each component. The posterior mean of $f_1$ will be influenced by how well $k_1$ can explain the observed data, while also considering what $k_2$ is capable of explaining. The posterior covariance of $f_1$ reflects the remaining uncertainty about $f_1$ after accounting for $\mathbf{y}$ and the presence of $f_2$.

Implementing this formula requires only one change to standard GP prediction: the training covariance still uses the full sum $K_{1,XX} + K_{2,XX} + \Lambda$, while the cross-covariance between new and training points uses only the component kernel $k_1$. Because the weight vector $(K_{1,XX} + K_{2,XX} + \Lambda)^{-1}\mathbf{y}$ is shared, the posterior means of the individual components add up exactly to the predictive mean of the combined GP, which is our best estimate of the sum $f_1 + f_2$. A composite kernel (like our additive kernel for Mauna Loa) therefore already contains everything needed for the separation; extracting a component amounts to swapping in that component's kernel for the cross-covariance.

> **Key takeaway:**  
> The additive kernel implicitly models the data as a sum of functions, and the GP framework provides a principled way to infer these components—even if we don't explicitly separate them in the prediction step.

## Summary of Gaussian Process Regression: Extensive Example

In this extensive example, we've explored how **Gaussian Processes (GPs)** can be applied to real-world time series data, such as the Mauna Loa CO$_2$ measurements. Here are the key takeaways:

---

### 1. Structured Modeling with Additive Kernels

- **Additive kernels** allow us to model complex data by combining simpler, interpretable kernel components.
    - *Example*: Use a **linear kernel** for the long-term trend and a **periodic kernel** for seasonality.
- This approach enables the GP to capture different aspects of the underlying process:
    $$
    k_{\text{sum}}(x, x') = k_{\text{trend}}(x, x') + k_{\text{seasonal}}(x, x')
    $$

---

### 2. Kernel Learning is Crucial

- **Manual hyperparameter selection** is rarely optimal.
- **Maximizing the marginal likelihood** provides a principled way to learn kernel parameters from data:
    $$
    \log p(\mathbf{y} \mid X, \theta) = -\frac{1}{2} \mathbf{y}^\top (K_{XX} + \sigma_{\text{noise}}^2 I)^{-1} \mathbf{y}
    - \frac{1}{2} \log |K_{XX} + \sigma_{\text{noise}}^2 I|
    - \frac{N}{2} \log(2\pi)
    $$
- **Automatic differentiation** (e.g., with JAX) makes this optimization feasible and efficient.

---

### 3. Source Separation Capability

- GPs with **composite kernels** inherently provide a framework for disentangling underlying signals that are linearly mixed in the observations.
- Because the components and the observations are jointly Gaussian, each component's posterior follows in closed form from Gaussian conditioning on the model:
    $$
    y(x) = f_{\text{trend}}(x) + f_{\text{seasonal}}(x) + \epsilon
    $$
- The GP framework allows us to infer each component's contribution to the observed data.

---

> **Conclusion:**  
> While GPs are powerful, their effectiveness relies on incorporating prior knowledge about the data's structure through the design of appropriate kernels. The ability to build such structured probabilistic models is a highly valued skill in machine learning.