# Gaussian Process Regression: An Extensive Example

Welcome back to our **Gaussian Processes (GPs)** learning journey! In previous posts, we introduced GPs, explored various kernel functions, and understood how they define the properties of functions a GP can model.

In this post, we'll take a deeper dive into applying GPs to more complex, real-world scenarios. We'll focus on:

- **Modeling structured data using combinations of kernels (additive kernels).**
- **Learning kernel hyperparameters from data.**
- **Source separation:** using GPs to disentangle multiple underlying signals from a combined observation.

This post works through a more extensive example, building on the theoretical foundations we've established.

---

### Key Topics

1. **Additive Kernels:**  
    Combine simple kernels to model complex data structures, e.g.,  
    $k_{\text{sum}}(x, x') = k_1(x, x') + k_2(x, x')$

2. **Hyperparameter Learning:**  
    Optimize kernel parameters by maximizing the marginal likelihood:  
    $$
    \log p(\mathbf{y} \mid X, \theta) = -\frac{1}{2} \mathbf{y}^\top (K_{XX} + \sigma_{\text{noise}}^2 I)^{-1} \mathbf{y}
    - \frac{1}{2} \log |K_{XX} + \sigma_{\text{noise}}^2 I|
    - \frac{N}{2} \log(2\pi)
    $$

3. **Source Separation:**  
    Decompose observed data $y$ into interpretable components, e.g.,  
    $$
    y = f_1(x) + f_2(x) + \text{noise}
    $$
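
The marginal likelihood in item 2 is usually evaluated through a Cholesky factorization rather than an explicit inverse and determinant. The following is a minimal sketch under that assumption, for a zero-mean GP; the function name `log_marginal_likelihood` and the toy inputs are hypothetical and not part of the `gaussians` module used later in this post. Double precision is switched on here as an assumption for numerical stability.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

jax.config.update("jax_enable_x64", True)  # double precision keeps the Cholesky stable


def log_marginal_likelihood(K, y, noise_std):
    """Log evidence of a zero-mean GP with Gram matrix K and i.i.d. Gaussian noise."""
    n = y.shape[0]
    A = K + noise_std**2 * jnp.eye(n)
    L = jnp.linalg.cholesky(A)  # A = L @ L.T, with L lower triangular
    alpha = cho_solve((L, True), y)  # A^{-1} y without forming the inverse
    data_fit = -0.5 * y @ alpha
    complexity = -jnp.sum(jnp.log(jnp.diag(L)))  # equals -0.5 * log|A|
    constant = -0.5 * n * jnp.log(2.0 * jnp.pi)
    return data_fit + complexity + constant


# Toy inputs (hypothetical, not the CO2 data)
t_toy = jnp.linspace(0.0, 1.0, 5)
K_toy = jnp.exp(-0.5 * (t_toy[:, None] - t_toy[None, :]) ** 2)
y_toy = jnp.sin(3.0 * t_toy)
lml = log_marginal_likelihood(K_toy, y_toy, noise_std=0.1)
```

The three terms map one-to-one onto the data-fit, complexity, and normalization terms of the formula, which makes it easy to inspect how each one reacts to a change in hyperparameters.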

---

By the end of this example, you'll see how GPs can be used for structured modeling, automatic hyperparameter tuning, and even for separating mixed signals in real-world time series data.

## A Real-World Dataset: Mauna Loa CO₂ Data

A classic example in time series analysis is the **NOAA Mauna Loa CO₂ dataset**. This dataset records the atmospheric carbon dioxide concentration (in parts per million, ppm) at the Mauna Loa Observatory in Hawaii, starting from the late 1950s.

If you look at the plot (similar to the one in the slides), you'll notice two prominent features:

- **A long-term upward trend:**  
    This reflects the increasing CO₂ concentration in the atmosphere due to human activities. The trend is roughly linear, but may include some non-linearities.

- **Annual seasonality:**  
    There's a clear oscillating pattern within each year, caused by seasonal changes in plant growth and decay. CO₂ levels drop during the Northern Hemisphere's growing season and rise during winter.

This dataset is a perfect candidate for Gaussian Processes because it exhibits clear structure that can be modeled by combining different kernels. We can think of the observed CO₂ concentration as a sum of a long-term trend function and a seasonal function, plus some noise:

$$
y(x) = f_{\text{trend}}(x) + f_{\text{seasonal}}(x) + \epsilon
$$

where  
- $f_{\text{trend}}(x)$ models the long-term trend,  
- $f_{\text{seasonal}}(x)$ models the annual seasonality,  
- $\epsilon$ is observational noise.

## Additive Kernels and Multi-Output GPs

One of the most powerful ways to build expressive Gaussian Processes (GPs) is by combining simpler kernels. As discussed previously, if $k_1(x, x')$ and $k_2(x, x')$ are valid kernels, then their sum

$$
k_{\text{sum}}(x, x') = k_1(x, x') + k_2(x, x')
$$

is also a valid kernel.

**Intuition:**  
If we have two independent Gaussian Processes, $f_1 \sim \mathcal{GP}(m_1, k_1)$ and $f_2 \sim \mathcal{GP}(m_2, k_2)$, then their sum $f = f_1 + f_2$ is also a Gaussian Process:

$$
f = f_1 + f_2 \sim \mathcal{GP}(m_1 + m_2,\, k_1 + k_2)
$$

This property is extremely useful for modeling data that is a combination of different underlying processes. For example, in the Mauna Loa CO₂ data, we can model the long-term trend with one kernel (e.g., a squared exponential or linear kernel) and the seasonality with another kernel (e.g., a periodic kernel). The overall covariance structure of the data is then the sum of these individual kernels.
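
The closure property above can be checked numerically. The following is a minimal sketch on a toy time grid: it builds the Gram matrices of a squared exponential and a periodic kernel, confirms that their sum has no negative eigenvalues beyond round-off, and draws one sample of $f = f_1 + f_2$ by adding independent draws. The helper names `se_gram` and `periodic_gram`, the grid, and the jitter value are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # assumption: double precision for stability


def se_gram(t, ell=1.0, theta=1.0):
    """Gram matrix of a squared exponential kernel on 1-D inputs t of shape (n,)."""
    d = t[:, None] - t[None, :]
    return theta**2 * jnp.exp(-0.5 * d**2 / ell**2)


def periodic_gram(t, period=1.0, ell=1.0, theta=1.0):
    """Gram matrix of a periodic kernel on 1-D inputs t of shape (n,)."""
    d = t[:, None] - t[None, :]
    return theta**2 * jnp.exp(-2.0 * jnp.sin(jnp.pi * d / period) ** 2 / ell**2)


t_grid = jnp.linspace(0.0, 5.0, 50)
K1 = se_gram(t_grid, ell=2.0)
K2 = periodic_gram(t_grid, period=1.0)
K_sum = K1 + K2

# The smallest eigenvalue of the summed Gram matrix should be >= 0 up to round-off
min_eig = jnp.linalg.eigvalsh(K_sum).min()

# Adding independent draws of f1 and f2 gives one draw from GP(0, k1 + k2)
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
jitter = 1e-6 * jnp.eye(t_grid.shape[0])  # small diagonal term so the Cholesky succeeds
f1 = jnp.linalg.cholesky(K1 + jitter) @ jax.random.normal(key1, t_grid.shape)
f2 = jnp.linalg.cholesky(K2 + jitter) @ jax.random.normal(key2, t_grid.shape)
f_sum = f1 + f2
```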

---

This approach is related to **Multi-Output GPs** or **Composite Kernels**. A general multi-output GP allows cross-covariances between its outputs, so the joint covariance of $p(f_1, f_2)$ is a full block matrix. The simplest case, where $f_1$ and $f_2$ are independent, has a block-diagonal covariance, and summing the two outputs leads directly to additive kernels:

$$
\begin{align*}
p(f_1) &= \mathcal{GP}(f_1;\, 0,\, k_1) \\
p(f_2) &= \mathcal{GP}(f_2;\, 0,\, k_2)
\end{align*}
$$

The joint distribution of $f_1$ and $f_2$ is:

$$
p(f_1, f_2) = \mathcal{GP}\left(
\begin{bmatrix} f_1 \\ f_2 \end{bmatrix};
\begin{bmatrix} 0 \\ 0 \end{bmatrix},
\begin{bmatrix} k_1 & 0 \\ 0 & k_2 \end{bmatrix}
\right)
$$

If we define $f = f_1 + f_2$, its distribution is:

$$
p(f) = \mathcal{GP}(f;\, 0,\, k_1 + k_2)
$$

**Conclusion:**  
We can model complex functions by summing simpler, interpretable components, each defined by its own kernel.
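
The same block structure is what makes source separation possible. Because $f_1$ and $f_2$ are independent, $\operatorname{cov}(f_1, y) = k_1$, so for a zero-mean model the posterior mean of each component at the training inputs is $K_i (K_1 + K_2 + \sigma^2 I)^{-1} \mathbf{y}$. The following is a minimal sketch of that formula on synthetic data; the helper names, the toy signal, and all hyperparameter values are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # assumption: double precision for stability


def se_gram(t, ell, theta):
    """Squared exponential Gram matrix on 1-D inputs."""
    d = t[:, None] - t[None, :]
    return theta**2 * jnp.exp(-0.5 * d**2 / ell**2)


def periodic_gram(t, period, ell, theta):
    """Periodic Gram matrix on 1-D inputs."""
    d = t[:, None] - t[None, :]
    return theta**2 * jnp.exp(-2.0 * jnp.sin(jnp.pi * d / period) ** 2 / ell**2)


def separate_sources(K1, K2, y, noise_std):
    """Posterior means of f1 and f2 at the training inputs, given y = f1 + f2 + noise."""
    n = y.shape[0]
    alpha = jnp.linalg.solve(K1 + K2 + noise_std**2 * jnp.eye(n), y)
    return K1 @ alpha, K2 @ alpha, alpha


# Synthetic mixture: a slow trend plus a yearly oscillation (hypothetical data)
t_mix = jnp.linspace(0.0, 6.0, 80)
y_mix = 0.5 * t_mix + jnp.sin(2.0 * jnp.pi * t_mix)

K_trend = se_gram(t_mix, ell=5.0, theta=3.0)
K_season = periodic_gram(t_mix, period=1.0, ell=1.0, theta=1.0)
trend_mean, season_mean, alpha = separate_sources(K_trend, K_season, y_mix, noise_std=0.1)
```

By construction, `trend_mean + season_mean` equals the posterior mean of the summed function $f$, so the decomposition only redistributes the overall fit between the two components.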

In [None]:
%load_ext autoreload
%autoreload 2

In [26]:
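import functools  # used below for functools.partial

import jax  # used below for jax.random and jax.scipy
import polars as pl
import plotly.express as px
import jax.numpy as jnp
from gaussians import *
import plotly.graph_objects as go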

data = pl.read_csv("co2_mm_mlo.csv")
X = jnp.asarray(data.get_column("decimal date").to_numpy())[:, None]
Y = jnp.asarray(data.get_column("average").to_numpy())
N = X.shape[0]
sigma = 0.1
x = jnp.linspace(1930, 2053, 2000)[:, None]
data.head()

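| year | month | decimal date | average | deseasonalized | ndays | sdev | unc |
| --- | --- | --- | --- | --- | --- | --- | --- |
| i64 | i64 | f64 | f64 | f64 | i64 | f64 | f64 |
| 1958 | 3 | 1958.2027 | 315.71 | 314.44 | -1 | -9.99 | -0.99 |
| 1958 | 4 | 1958.2877 | 317.45 | 315.16 | -1 | -9.99 | -0.99 |
| 1958 | 5 | 1958.3699 | 317.51 | 314.69 | -1 | -9.99 | -0.99 |
| 1958 | 6 | 1958.4548 | 317.27 | 315.15 | -1 | -9.99 | -0.99 |
| 1958 | 7 | 1958.537 | 315.87 | 315.2 | -1 | -9.99 | -0.99 |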


In [5]:
px.line(data, x="decimal date", y="average")

In [38]:
def plot_prior_posterior(prior, posterior, x, X, Y, num_samples=5):
    fig = go.Figure()

    # Plot the observed data
    fig.add_trace(
        go.Scatter(
            x=X[:, 0],
            y=Y,
            mode="markers",
            name="Observed",
            marker=dict(color="black", size=4),
        )
    )
    # Plot the prior mean
    fig.add_trace(
        go.Scatter(
            x=x[:, 0],
            y=prior.mu,
            mode="lines",
            name="Prior mean",
            line=dict(color="gray"),
        )
    )

    # Plot the 2-sigma confidence interval
    fig.add_trace(
        go.Scatter(
            x=jnp.concatenate([x[:, 0], x[::-1, 0]]),
            y=jnp.concatenate(
                [
                    prior.mu - 2 * prior.std,
                    (prior.mu + 2 * prior.std)[::-1],
                ]
            ),
            fill="toself",
            fillcolor="rgba(128,128,128,0.2)",
            line=dict(color="rgba(255,255,255,0)"),
            hoverinfo="skip",
            showlegend=False,
        )
    )

    # Plot prior samples
    key = jax.random.PRNGKey(0)
    samples = prior.sample(key, num_samples=num_samples).T
    for i in range(num_samples):
        fig.add_trace(
            go.Scatter(
                x=x[:, 0],
                y=samples[:, i],
                mode="lines",
                line=dict(color="gray", width=1, dash="dot"),
                opacity=0.3,
                name="Prior samples",
                showlegend=(i == 0),
            )
        )

    # Plot the posterior mean
    fig.add_trace(
        go.Scatter(
            x=x[:, 0],
            y=posterior.mu,
            mode="lines",
            name="Posterior mean",
            line=dict(color="red"),
        )
    )

    # Plot the 2-sigma confidence interval
    fig.add_trace(
        go.Scatter(
            x=jnp.concatenate([x[:, 0], x[::-1, 0]]),
            y=jnp.concatenate(
                [
                    posterior.mu - 2 * posterior.std,
                    (posterior.mu + 2 * posterior.std)[::-1],
                ]
            ),
            fill="toself",
            fillcolor="rgba(255,0,0,0.2)",
            line=dict(color="rgba(255,0,0,0)"),
            hoverinfo="skip",
            showlegend=False,
        )
    )

    # Plot posterior samples
    key = jax.random.PRNGKey(0)
    samples = posterior.sample(key, num_samples=num_samples).T
    for i in range(num_samples):
        fig.add_trace(
            go.Scatter(
                x=x[:, 0],
                y=samples[:, i],
                mode="lines",
                line=dict(color="red", width=1, dash="dot"),
                opacity=0.3,
                name="Posterior samples",
                showlegend=(i == 0),
            )
        )

    fig.update_layout(
        title="Prior and posterior",
        xaxis_title="Year",
        yaxis_title="CO₂ (ppm)",
        legend=dict(bgcolor="white"),
        template="plotly_white",
    )
    fig.show()

In [None]:
# We could just do linear regression, of course:
def polynomial_features(x, num_features=2):
    # output shape: (n_samples, num_features)
    return (
        x ** jnp.arange(num_features)
        / jnp.exp(jax.scipy.special.gammaln(jnp.arange(num_features) + 1))
        / jnp.sqrt(num_features)
    )


phi = functools.partial(polynomial_features, num_features=2)
phi_X = phi(X)
F = phi_X.shape[1]
prior = Gaussian(mu=jnp.zeros(F), Sigma=10**2 * jnp.eye(F))
posterior = prior.condition(phi_X, Y, sigma**2 * jnp.eye(len(X)))

# plot the posterior
prior_x = phi(x) @ prior
posterior_x = phi(x) @ posterior

plot_prior_posterior(prior_x, posterior_x, x, X, Y)
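
Linear regression with a Gaussian prior on the weights is itself a GP, with kernel $k(x, x') = \phi(x)^\top \Sigma\, \phi(x')$. The following minimal sketch checks this on toy data by computing the predictive mean once in weight space and once in function space. It uses a plain feature map `[1, t]` rather than the scaled `polynomial_features` above, and all names and values are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # assumption: double precision for stability


def features(t):
    """Hypothetical feature map [1, t] for 1-D inputs of shape (n,)."""
    return jnp.stack([jnp.ones_like(t), t], axis=1)


t_train = jnp.linspace(0.0, 1.0, 20)
y_train = 2.0 * t_train - 1.0 + 0.1 * jnp.cos(7.0 * t_train)
t_test = jnp.linspace(-0.5, 1.5, 7)
noise_std = 0.1
Sigma_w = 10.0**2 * jnp.eye(2)  # prior covariance of the weights

Phi, Phi_s = features(t_train), features(t_test)

# Weight space: Gaussian posterior over the two weights, then project onto test features
precision = jnp.linalg.inv(Sigma_w) + Phi.T @ Phi / noise_std**2
w_mean = jnp.linalg.solve(precision, Phi.T @ y_train / noise_std**2)
mean_weight_space = Phi_s @ w_mean

# Function space: GP regression with the kernel induced by the features
K = Phi @ Sigma_w @ Phi.T
K_s = Phi_s @ Sigma_w @ Phi.T
A = K + noise_std**2 * jnp.eye(t_train.shape[0])
mean_function_space = K_s @ jnp.linalg.solve(A, y_train)
```

Both routes give the same predictive mean. The weight-space route solves a system of size equal to the number of features, while the function-space route solves one of size equal to the number of data points.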

In [44]:
# or do "textbook" GP regression, e.g. using the rational quadratic kernel
def RQ_kernel(a, b, ell=1.0, alpha=1.0, theta=1.0):
    # squared distance is summed over the input dimensions, then scaled by ell**2
    return theta**2 * (
        1 + jnp.sum((a - b) ** 2, axis=-1) / (2 * alpha * ell**2)
    ) ** (-alpha)


# we set the mean to a constant function at the data mean
mean_Y = jnp.mean(Y)


def constant_mean(x):
    return mean_Y * jnp.ones_like(x[:, 0])


# instantiate the Gaussian process prior
prior = GaussianProcess(
    constant_mean, functools.partial(RQ_kernel, ell=1.0, alpha=0.2, theta=3.0)
)

# condition the prior on the data
posterior = prior.condition(Y, X, sigma)


plot_prior_posterior(prior(x), posterior(x), x, X, Y)

In [47]:
def long_term_trend_kernel(a, b, theta=1.0, ell=20.0):
    return theta**2 * jnp.exp(-jnp.sum((a - b) ** 2, axis=-1) / (2 * ell**2))


# Instantiate the Gaussian process prior with the long-term trend kernel
prior = GaussianProcess(
    constant_mean, functools.partial(long_term_trend_kernel, ell=20.0, theta=10.0)
)

# Condition the prior on the data
posterior = prior.condition(Y, X, sigma)

plot_prior_posterior(prior(x), posterior(x), x, X, Y)

In [48]:
def sum_kernel(
    a, b, theta_long=1.0, theta_rq=1.0, ell_long_term=20.0, ell=1.0, alpha=1.0
):
    return long_term_trend_kernel(
        a, b, theta=theta_long, ell=ell_long_term
    ) + RQ_kernel(a, b, theta=theta_rq, ell=ell, alpha=alpha)


# Instantiate the Gaussian process prior with the sum kernel
prior = GaussianProcess(
    constant_mean,
    functools.partial(
        sum_kernel,
        theta_long=100.0,
        theta_rq=5.0,
        ell_long_term=50.0,
        ell=1.0,
        alpha=0.2,
    ),
)

# Condition the prior on the data
posterior = prior.condition(Y, X, sigma)

plot_prior_posterior(prior(x), posterior(x), x, X, Y)

In [50]:
def periodic_kernel(x, y, period=1.0, ell=1.0, theta=1.0):
    return theta**2 * jnp.exp(
        -2 * jnp.sin(jnp.pi * jnp.sum(x - y, axis=-1) / period) ** 2 / ell**2
    )


# Instantiate the Gaussian process prior with the periodic kernel
prior = GaussianProcess(constant_mean, lambda a, b: periodic_kernel(a, b))

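# Define a decaying periodic prior: multiplying the periodic kernel by the squared
# exponential (long-term trend) kernel lets the periodic pattern drift over time
decaying_prior = GaussianProcess(
    constant_mean,
    lambda a, b: periodic_kernel(a, b) * long_term_trend_kernel(a, b, ell=10, theta=1),
)

# Reuse the plotting helper to compare the two priors: the purely periodic prior is drawn
# in gray, the decaying one in red (the slot normally used for the posterior)
plot_prior_posterior(prior(x), decaying_prior(x), x, X, Y)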
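
The effect of the product can be seen directly in the covariance at whole-year lags. The following is a small sketch with hypothetical helper functions that take the lag `d` as input: the purely periodic kernel stays at 1 for every integer lag, while the product with a squared exponential of length scale 10 decays with the lag.

```python
import jax.numpy as jnp


def periodic_of_lag(d, period=1.0, ell=1.0):
    """Periodic kernel as a function of the lag d = x - x'."""
    return jnp.exp(-2.0 * jnp.sin(jnp.pi * d / period) ** 2 / ell**2)


def se_of_lag(d, ell=10.0):
    """Squared exponential kernel as a function of the lag d = x - x'."""
    return jnp.exp(-0.5 * d**2 / ell**2)


lags = jnp.asarray([0.0, 10.0, 20.0, 30.0])  # whole-year lags

pure = periodic_of_lag(lags)  # stays at 1: perfectly repeating pattern
decaying = periodic_of_lag(lags) * se_of_lag(lags, ell=10.0)  # approx. 1, 0.61, 0.14, 0.011
```

Under the decaying prior, two Januaries thirty years apart are therefore almost uncorrelated, whereas the purely periodic prior forces them to be identical.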


In [52]:
def sum_kernel(
    x, y, theta_long=1.0, theta_periodic=1.0, ell_long_term=20.0, ell_periodic=1.0
):
    return long_term_trend_kernel(
        x, y, theta=theta_long, ell=ell_long_term
    ) + periodic_kernel(x, y, theta=theta_periodic, ell=ell_periodic, period=1.0)


# Instantiate the Gaussian process prior with the sum kernel
prior = GaussianProcess(
    constant_mean,
    functools.partial(
        sum_kernel,
        theta_long=100.0,
        theta_periodic=5.0,
        ell_long_term=50.0,
        ell_periodic=1.0,
    ),
)

# Condition the prior on the data
posterior = prior.condition(Y, X, sigma)

plot_prior_posterior(prior(x), posterior(x), x, X, Y)

In [54]:
def long_term_trend_kernel(x, y, theta=100.0, ell=100.0):
    return theta**2 * jnp.exp(-jnp.sum((x - y) ** 2, axis=-1) / (2 * ell**2))


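# Note: ell_period acts both as the period and as the length scale of the periodic factor,
# so the two cannot be tuned independently in this parameterization
def periodic_kernel(x, y, theta=1.0, ell_period=1.0, ell_decay=50.0):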
    return (
        theta**2
        * jnp.exp(
            -2
            * jnp.sin(jnp.pi * jnp.sum(x - y, axis=-1) / ell_period) ** 2
            / ell_period**2
        )
        * jnp.exp(-jnp.sum((x - y) ** 2, axis=-1) / (2 * ell_decay**2))
    )


def mid_term_trend(x, y, theta=1.0, ell=1.0, alpha=1.0):
    return theta**2 * (1 + jnp.sum((x - y) ** 2, axis=-1) / (2 * alpha * ell**2)) ** (
        -alpha
    )


def noise_kernel(x, y, theta_weather=0.1, ell_weather=0.1, theta_measurement=0.1):
    return theta_weather**2 * jnp.exp(
        -jnp.sum((x - y) ** 2, axis=-1) / (2 * ell_weather**2)
    ) + theta_measurement**2 * jnp.all(x == y, axis=-1)


def model_kernel(x, y, parameters):
    (
        theta_long,
        ell_long_term,
        theta_periodic,
        ell_periodic,
        ell_decay_periodic,
        theta_mid_term,
        ell_mid_term,
        shape_mid_term,
        theta_weather,
        ell_weather,
        theta_measurement,
    ) = parameters

    return (
        long_term_trend_kernel(x, y, theta=theta_long, ell=ell_long_term)
        + periodic_kernel(
            x,
            y,
            theta=theta_periodic,
            ell_period=ell_periodic,
            ell_decay=ell_decay_periodic,
        )
        + mid_term_trend(
            x, y, theta=theta_mid_term, ell=ell_mid_term, alpha=shape_mid_term
        )
        + noise_kernel(
            x,
            y,
            theta_weather=theta_weather,
            ell_weather=ell_weather,
            theta_measurement=theta_measurement,
        )
    )


# Initial guesses for the parameters
init_params = jnp.asarray(
    [
        100.0,  # theta_long
        100.0,  # ell_long_term
        5.0,  # theta_periodic
        1.0,  # ell_periodic
        50.0,  # ell_decay_periodic
        1.0,  # theta_mid_term
        1.0,  # ell_mid_term
        1.0,  # shape_mid_term
        0.1,  # theta_weather
        0.1,  # ell_weather
        0.1,  # theta_measurement
    ]
)

# Define the model
gp = GaussianProcess(
    constant_mean, functools.partial(model_kernel, parameters=init_params)
)

# Condition the prior on the data
gp_posterior = gp.condition(Y, X, sigma)

plot_prior_posterior(gp(x), gp_posterior(x), x, X, Y)

In [55]:
from jax import grad, hessian
from scipy import optimize
import jax.numpy as jnp


def NegEvidence(params):
    gp = GaussianProcess(
        constant_mean, functools.partial(model_kernel, parameters=params)
    )
    # gp(X) is the GP prior evaluated at the training inputs; the noise terms are part of
    # model_kernel, so its log density at Y is taken as the log evidence
    return -gp(X).log_pdf(Y)


# Compute the gradient and Hessian of the negative evidence function
grad_neg_ev = grad(NegEvidence)
hess_neg_ev = hessian(NegEvidence)

# Initial parameters
params = init_params

# Initial negative evidence and gradient
init_neg_evidence = NegEvidence(params)
grad_init = grad_neg_ev(params)
print(init_neg_evidence, grad_init)

# Optimize the parameters
results = optimize.minimize(
    NegEvidence,
    params,
    method="CG",
    jac=grad(NegEvidence),
    options={"gtol": 1e-6, "disp": True, "maxiter": 100},
)

# Retrieve the optimized parameters
optimized_params = results.x
print("Optimized Parameters:", optimized_params)


739.4111169681631 [-6.31420962e-02  6.39126553e-02  4.56166441e+00 -1.95125498e+03
 -3.49902047e-01  7.32921107e+00  7.15953269e+01  2.59031472e+01
 -5.03443458e+03  3.06239969e+03 -1.30526500e+04]
         Current function value: 248.765921
         Iterations: 100
         Function evaluations: 251
         Gradient evaluations: 248
Optimized Parameters: [100.00190487  99.99807848   4.84357293   1.00108951  50.01535196
   0.74712522   1.29753772   0.94492455   0.20347477   0.12097733
   0.20431567]



Maximum number of iterations has been exceeded.
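
All eleven hyperparameters are scales or shapes that must stay positive, and they span several orders of magnitude, which can slow down a gradient-based optimizer. One common remedy, not used in the run above, is to optimize the logarithms of the parameters instead. The following is a minimal sketch of that reparameterization on a hypothetical toy problem with a three-parameter kernel; the names and data are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

jax.config.update("jax_enable_x64", True)  # assumption: double precision for stability

# Hypothetical toy data, not the CO2 series
t_opt = jnp.linspace(0.0, 4.0, 30)
y_opt = jnp.sin(2.0 * jnp.pi * t_opt) + 0.5 * t_opt


def neg_log_evidence(params):
    """Negative log evidence of a zero-mean GP: SE kernel plus white noise."""
    theta, ell, noise_std = params[0], params[1], params[2]
    n = t_opt.shape[0]
    d = t_opt[:, None] - t_opt[None, :]
    A = theta**2 * jnp.exp(-0.5 * d**2 / ell**2) + noise_std**2 * jnp.eye(n)
    L = jnp.linalg.cholesky(A)
    alpha = cho_solve((L, True), y_opt)
    return 0.5 * y_opt @ alpha + jnp.sum(jnp.log(jnp.diag(L))) + 0.5 * n * jnp.log(2.0 * jnp.pi)


def neg_log_evidence_log_space(log_params):
    """Same objective, parameterized by log(params) so every parameter stays positive."""
    return neg_log_evidence(jnp.exp(log_params))


log_params = jnp.log(jnp.asarray([1.0, 1.0, 0.1]))
grad_log = jax.grad(neg_log_evidence_log_space)(log_params)
grad_raw = jax.grad(neg_log_evidence)(jnp.exp(log_params))
```

By the chain rule, `grad_log` equals `jnp.exp(log_params) * grad_raw`, so each gradient component is rescaled by the size of its parameter. The wrapper could be handed to `optimize.minimize` in the same way as `NegEvidence`, with `jnp.exp(results.x)` recovering the positive parameters afterwards.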



In [56]:
# Print each parameter's name, optimized value, initial value, reference value, and unit.
# names and RW_params must follow the parameter order unpacked in model_kernel.
names = [
    "theta_long",
    "ell_long_term",
    "theta_periodic",
    "ell_periodic",
    "ell_decay_periodic",
    "theta_mid_term",
    "ell_mid_term",
    "shape_mid_term",
    "theta_weather",
    "ell_weather",
    "theta_measurement",
]

# Reference values (presumably those reported by Rasmussen & Williams for this dataset)
RW_params = [66.0, 67.0, 2.4, 1.3, 90.0, 0.66, 1.2, 0.78, 0.18, 0.13, 0.19]
units = [
    "ppm",
    "years",
    "ppm",
    "years",
    "years",
    "ppm",
    "years",
    "unitless",
    "ppm",
    "years",
    "ppm",
]

for p, name, r, i, u in zip(results.x, names, RW_params, init_params, units):
    print(f"{name.rjust(25)} {p:.2f} {i:.2f} {r:.2f} {u}")

               theta_long 100.00 100.00 66.00 ppm
            ell_long_term 100.00 100.00 67.00 years
           theta_periodic 4.84 5.00 2.40 ppm
             ell_periodic 1.00 1.00 1.30 years
       ell_decay_periodic 50.02 50.00 90.00 years
           theta_mid_term 0.75 1.00 0.66 ppm
             ell_mid_term 1.30 1.00 1.20 years
           shape_mid_term 0.94 1.00 0.78 unitless
            theta_weather 0.20 0.10 0.18 ppm
              ell_weather 0.12 0.10 0.13 years
        theta_measurement 0.20 0.10 0.19 ppm


In [58]:
# Optimized parameters from the optimization process
opt_params = jnp.asarray(results.x)

# Define the Gaussian Process with optimized parameters
gp = GaussianProcess(
    constant_mean, functools.partial(model_kernel, parameters=opt_params)
)

# Condition the Gaussian Process on the data
gp_posterior = gp.condition(Y, X, sigma)

# Sample from the prior (with the optimized hyperparameters) at the training inputs
key = jax.random.PRNGKey(0)
samples = gp(X).sample(key=key, num_samples=5)

# Plot the data and samples using Plotly
import plotly.graph_objects as go

fig = go.Figure()

# Plot observed data
fig.add_trace(
    go.Scatter(
        x=X[:, 0],
        y=Y,
        mode="markers",
        name="Data",
        marker=dict(color="black", size=4),
        opacity=1.0,
    )
)

# Plot the samples
for i in range(samples.shape[0]):
    fig.add_trace(
        go.Scatter(
            x=X[:, 0],
            y=samples[i, :],
            mode="lines",
            name=f"Sample {i + 1}" if i == 0 else None,
            line=dict(color="royalblue"),
            opacity=0.7,
            showlegend=(i == 0),
        )
    )

fig.update_layout(
    title="Can you pick out the data?",
    xaxis_title="Year",
    yaxis_title="CO₂ (ppm)",
    template="plotly_white",
    legend=dict(bgcolor="white"),
)
fig.show()

In [59]:
plot_prior_posterior(gp(x), gp_posterior(x), x, X, Y, num_samples=5)