[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/fonnesbeck/instats_gp/blob/main/sessions/Session_3.ipynb)

# Session 3: Scaling with HSGP and Sparse Methods

## Learning Objectives

By the end of this session, you will be able to:

- Understand the computational bottlenecks of standard Gaussian processes
- Apply inducing point methods and sparse approximations to scale GPs to larger datasets
- Implement Hilbert Space GP (HSGP) approximations in PyMC
- Choose appropriate approximation parameters using helper functions and heuristics
- Navigate the trade-offs between approximation fidelity and computational efficiency

## Introduction

In the previous sessions, we explored the foundations of Gaussian processes and built models with various kernels and likelihoods. However, you may have noticed that as datasets grow larger, GP computations become increasingly expensive. The standard GP formulation requires inverting an $n \times n$ covariance matrix, where $n$ is the number of data points. This operation has $\mathcal{O}(n^3)$ computational complexity and $\mathcal{O}(n^2)$ memory requirements, which quickly become prohibitive for datasets with thousands of observations.

In this session, we'll explore two powerful approaches to overcome these computational barriers: sparse GP approximations using inducing points, and the Hilbert Space GP (HSGP) method. These techniques allow us to apply GP models to much larger datasets while maintaining the flexibility and uncertainty quantification that make GPs so valuable.

Think of these methods as strategic compromises: we trade away some exactness in our GP representation to gain massive improvements in speed and scalability. The key question we'll answer throughout this session is: *how do we make this trade-off intelligently?*

In [None]:
import pymc as pm
import numpy as np
import polars as pl
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import arviz as az
import pytensor.tensor as pt
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

RANDOM_SEED = 8675309
RNG = np.random.default_rng(RANDOM_SEED)

DATA_DIR = "../data/"

print(f"PyMC version: {pm.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"Polars version: {pl.__version__}")
print(f"ArviZ version: {az.__version__}")

## Section 3.1: Understanding the Computational Challenge

Before diving into solutions, let's develop intuition for *why* standard GPs become computationally expensive. The bottleneck lies in computing and inverting the covariance matrix.

For a GP with $n$ observations, we need to:

1. **Compute** the $n \times n$ covariance matrix $K$ by evaluating the kernel function at all pairs of data points
2. **Invert** this matrix (or equivalently, solve a linear system) to compute the marginal likelihood
3. **Repeat** these operations at every step during MCMC sampling as hyperparameters change

The matrix inversion step dominates the computational cost, scaling as $\mathcal{O}(n^3)$. This cubic scaling means that doubling your dataset size increases computation time by roughly 8×. For a dataset with 10,000 points, a full GP could take hours or days to fit, making interactive model development essentially impossible.
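The three steps above can be sketched in plain NumPy. This is a minimal illustration rather than PyMC's implementation: the hand-written `matern52` helper, the toy sine data, and the fixed `eta` and `sigma` values are all assumptions made for the example.

```python
import numpy as np

def matern52(xa, xb, ell=1.0, eta=1.0):
    """Matern 5/2 kernel for 1-D inputs (hand-written for this sketch)."""
    r = np.abs(xa[:, None] - xb[None, :])
    s = np.sqrt(5.0) * r / ell
    return eta**2 * (1.0 + s + s**2 / 3.0) * np.exp(-s)

def gp_log_marginal_likelihood(x, y, ell, eta, sigma):
    """log p(y | x, hyperparameters) for a zero-mean GP with Gaussian noise."""
    n = len(x)
    # Step 1: build the n x n covariance matrix, O(n^2) kernel evaluations
    K = matern52(x, x, ell, eta) + sigma**2 * np.eye(n)
    # Step 2: factorize it (a Cholesky solve replaces the explicit inverse), O(n^3)
    chol = np.linalg.cholesky(K)
    alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, y))
    # log|K| = 2 * sum(log(diag(chol)))
    log_det_half = np.log(np.diag(chol)).sum()
    return -0.5 * y @ alpha - log_det_half - 0.5 * n * np.log(2.0 * np.pi)

rng = np.random.default_rng(0)
x_demo = np.sort(rng.uniform(0.0, 10.0, size=200))
y_demo = np.sin(x_demo) + 0.3 * rng.standard_normal(200)

# Step 3: a sampler repeats steps 1 and 2 for every new hyperparameter value
for ell in (0.5, 1.0, 2.0):
    lml = gp_log_marginal_likelihood(x_demo, y_demo, ell=ell, eta=1.0, sigma=0.3)
    print(f"ell={ell}: log marginal likelihood = {lml:.2f}")
```

Each pass through the loop rebuilds and refactorizes the full matrix, which is why the cost of a single evaluation matters so much during sampling.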

### The Cost of Cubic Scaling

To make this concrete, consider what happens as we increase dataset size:

- **n=50**: Covariance matrix has 2,500 elements, ~125,000 operations to invert
- **n=200**: Covariance matrix has 40,000 elements, ~8 million operations to invert (64× more than n=50)
- **n=1,000**: Covariance matrix has 1,000,000 elements, ~1 billion operations to invert (8,000× more than n=50)
- **n=10,000**: Covariance matrix has 100,000,000 elements, ~1 trillion operations to invert (8,000,000× more than n=50)

And remember: these matrix inversions happen at **every MCMC iteration**. With 4,000 draws in each of 4 chains, that's at least 16,000 inversions, and more in practice because NUTS evaluates the model several times per draw!
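The arithmetic behind these figures is easy to reproduce. The short sketch below uses the simplified cost model from the list ($n^2$ entries, roughly $n^3$ operations); real operation counts differ by constant factors.

```python
# Simplified cost model: an n x n matrix has n^2 entries and
# costs on the order of n^3 operations to invert.
sizes = [50, 200, 1_000, 10_000]
baseline_ops = sizes[0] ** 3

costs = {}
for n in sizes:
    costs[n] = {"entries": n**2, "ops": n**3, "ratio": n**3 // baseline_ops}
    print(
        f"n={n:>6,}: {costs[n]['entries']:>11,} entries, "
        f"~{costs[n]['ops']:,} ops, {costs[n]['ratio']:,}x the n=50 cost"
    )

# One inversion per draw is a lower bound on the work done during sampling
chains, draws_per_chain = 4, 4_000
min_inversions = chains * draws_per_chain
print(f"At least {min_inversions:,} inversions during sampling")
```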

This is where approximation methods become essential. Rather than abandoning GPs for large datasets, we can use clever mathematical tricks to reduce computational complexity while retaining most of the modeling flexibility.

### Generating Data for Sparse GP Demonstration

We'll create a dataset with 2000 observations: large enough to make standard GP inference slow, but small enough to allow us to compare against the exact solution. Our data will be drawn from a GP with a Matérn 5/2 kernel and moderate noise.

In [None]:
# Set parameters for data generation
n = 2000
ell_true = 1.0
eta_true = 3.0
sigma_true = 0.5

# Generate input locations
x = 10 * np.sort(RNG.random(n))

# Define true covariance function
cov_func = eta_true**2 * pm.gp.cov.Matern52(1, ell_true)

# Sample the latent GP function
K = cov_func(x[:, None]).eval()
K_stable = K + 1e-8 * np.eye(n)  # Add jitter for numerical stability
f_true = RNG.multivariate_normal(np.zeros(n), K_stable)

# Add observation noise
y = f_true + sigma_true * RNG.standard_normal(n)

# Create a Polars DataFrame
df = pl.DataFrame({
    'x': x,
    'y': y,
    'f_true': f_true
})

go.Figure().add_trace(go.Scatter(
    x=df['x'],
    y=df['y'],
    mode='markers',
    name='Observed data',
    marker=dict(size=3, color='gray', opacity=0.5)
)).add_trace(go.Scatter(
    x=df['x'],
    y=df['f_true'],
    mode='lines',
    name='True latent function',
    line=dict(color='dodgerblue', width=2)
)).update_layout(
    title='Simulated GP Data',
    xaxis_title='x',
    yaxis_title='y',
    height=400,
    showlegend=False
)

### Understanding Sparse GPs: The Inducing Point Idea

Standard GPs require computing and inverting an $n \times n$ covariance matrix $K_{nn}$, where $n$ is the number of observations. As we saw earlier, this $\mathcal{O}(n^3)$ operation becomes prohibitively expensive for large datasets. Sparse GPs solve this by introducing a clever approximation based on **inducing points**.

#### The Core Approximation Strategy

Instead of modeling correlations between all $n$ observations directly, sparse GPs introduce a smaller set of $m$ **inducing points** (also called pseudo-inputs) at strategic locations $\mathbf{X}_u$. These inducing points act as an **information bottleneck**: all correlations between observations flow through this compressed representation.

**Think of it like this**: Imagine understanding temperature patterns across a country. Rather than measuring correlations between all pairs of cities (expensive), you could:
1. Select $m$ strategically-placed weather stations (inducing points)
2. Model how each city correlates with these stations
3. Infer city-to-city relationships indirectly through the stations

This changes the problem from requiring $\mathcal{O}(n^3)$ operations to just $\mathcal{O}(nm^2)$ operations, where $m \ll n$.

#### The Mathematical Details

For a standard GP with Gaussian observation noise of variance $\sigma^2$, the posterior mean and covariance at test locations $\mathbf{x}_*$ are:

$$
\begin{align}
\boldsymbol{\mu}_* &= K_{*n} \underbrace{(K_{nn} + \sigma^2 I)^{-1}}_{\mathcal{O}(n^3)} \mathbf{y} \\
\boldsymbol{\Sigma}_* &= K_{**} - K_{*n} \underbrace{(K_{nn} + \sigma^2 I)^{-1}}_{\mathcal{O}(n^3)} K_{n*}
\end{align}
$$

The bottleneck is inverting the large $n \times n$ matrix $K_{nn} + \sigma^2 I$. This happens at every MCMC iteration as hyperparameters change.
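As a minimal NumPy sketch of the exact posterior, the code below adds the noise variance $\sigma^2$ to $K_{nn}$ before solving, as is standard for noisy observations. The `matern52` helper, the toy sine data, and the fixed hyperparameters are assumptions for illustration only.

```python
import numpy as np

def matern52(xa, xb, ell=1.0, eta=1.0):
    """Matern 5/2 kernel for 1-D inputs (hand-written for this sketch)."""
    r = np.abs(xa[:, None] - xb[None, :])
    s = np.sqrt(5.0) * r / ell
    return eta**2 * (1.0 + s + s**2 / 3.0) * np.exp(-s)

rng = np.random.default_rng(1)
n, sigma = 300, 0.5
x_train = np.sort(rng.uniform(0.0, 10.0, n))
y_train = np.sin(x_train) + sigma * rng.standard_normal(n)
x_star = np.linspace(0.0, 10.0, 50)

K_nn = matern52(x_train, x_train)
K_sn = matern52(x_star, x_train)   # K_{*n}
K_ss = matern52(x_star, x_star)    # K_{**}

# Factorize the n x n matrix once: this is the O(n^3) step
chol = np.linalg.cholesky(K_nn + sigma**2 * np.eye(n))

# Posterior mean: K_{*n} (K_nn + sigma^2 I)^{-1} y
alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, y_train))
mu_star = K_sn @ alpha

# Posterior covariance: K_{**} - K_{*n} (K_nn + sigma^2 I)^{-1} K_{n*}
V = np.linalg.solve(chol, K_sn.T)
Sigma_star = K_ss - V.T @ V

print(mu_star.shape, Sigma_star.shape)
```

Reusing the Cholesky factor for both the mean and the covariance avoids forming an explicit inverse, but the factorization itself still scales cubically in $n$.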

**Sparse GPs avoid this** by factorizing the covariance structure through inducing points. The FITC (Fully Independent Training Conditional) approximation assumes observations are conditionally independent given the inducing point values $\mathbf{u}$:

$$
p(\mathbf{f} \mid \mathbf{u}) \approx \prod_{i=1}^n p(f_i \mid \mathbf{u})
$$

This leads to the approximate posterior:

$$
\begin{align}
\tilde{\boldsymbol{\mu}}_* &= K_{*m} \underbrace{K_{mm}^{-1}}_{\mathcal{O}(m^3)} \Sigma_m K_{mm}^{-1} K_{mn} \Lambda^{-1} \mathbf{y} \\
\tilde{\boldsymbol{\Sigma}}_* &= K_{**} - K_{*m} \left(K_{mm}^{-1} - K_{mm}^{-1} \Sigma_m K_{mm}^{-1}\right) K_{m*}
\end{align}
$$

where $\Lambda = \text{diag}(K_{nn} - K_{nm} K_{mm}^{-1} K_{mn}) + \sigma^2 I$ is diagonal (cheap to invert) and $\Sigma_m = K_{mm} \left(K_{mm} + K_{mn} \Lambda^{-1} K_{nm}\right)^{-1} K_{mm}$ is the posterior covariance of the inducing values $\mathbf{u}$.

**The key insight**: Every matrix we invert is either $m \times m$ or diagonal, never the huge $n \times n$ matrix. We compute:
- $K_{mm}$: $m \times m$ covariance between inducing points (invert this: $\mathcal{O}(m^3)$)
- $K_{nm}$: $n \times m$ covariance between observations and inducing points (no inversion needed)
- $K_{mm} + K_{mn} \Lambda^{-1} K_{nm}$: $m \times m$, costing $\mathcal{O}(nm^2)$ to form and $\mathcal{O}(m^3)$ to invert
- $\Lambda$: diagonal, so trivial to invert

For our example with $n=2000$ and $m=20$:
- Standard GP: $\mathcal{O}(2000^3) \approx 8$ billion operations
- Sparse GP: $\mathcal{O}(2000 \times 20^2 + 20^3) \approx 800$ thousand operations
- **Speedup: ~10,000×**
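To make the cost saving concrete, here is a NumPy sketch of the FITC predictive mean in its standard form, $K_{*m} (K_{mm} + K_{mn} \Lambda^{-1} K_{nm})^{-1} K_{mn} \Lambda^{-1} \mathbf{y}$, computed without ever forming an $n \times n$ matrix. The `matern52` helper, the toy sine data, the evenly spaced inducing inputs, and the jitter value are assumptions made for this example.

```python
import numpy as np

def matern52(xa, xb, ell=1.0, eta=1.0):
    """Matern 5/2 kernel for 1-D inputs (hand-written for this sketch)."""
    r = np.abs(xa[:, None] - xb[None, :])
    s = np.sqrt(5.0) * r / ell
    return eta**2 * (1.0 + s + s**2 / 3.0) * np.exp(-s)

rng = np.random.default_rng(2)
n, m, sigma = 2000, 20, 0.5
x_train = np.sort(rng.uniform(0.0, 10.0, n))
y_train = np.sin(x_train) + sigma * rng.standard_normal(n)
x_u = np.linspace(0.0, 10.0, m)        # evenly spaced inducing inputs (assumption)
x_star = np.linspace(0.0, 10.0, 100)

K_mm = matern52(x_u, x_u) + 1e-6 * np.eye(m)   # small jitter for stability
K_nm = matern52(x_train, x_u)
K_sm = matern52(x_star, x_u)

# diag(K_nm K_mm^{-1} K_mn), computed row by row in O(n m^2)
A = np.linalg.solve(K_mm, K_nm.T)              # shape (m, n)
q_diag = np.einsum("nm,mn->n", K_nm, A)

# Lambda = diag(K_nn - Q_nn) + sigma^2 I; here k(x, x) = eta^2 = 1
lam = (1.0 - q_diag) + sigma**2

# Only an m x m system is solved
Q_m = K_mm + (K_nm.T / lam) @ K_nm
mu_fitc = K_sm @ np.linalg.solve(Q_m, K_nm.T @ (y_train / lam))

print(mu_fitc.shape)
```

The largest arrays held in memory are $n \times m$, so both time and memory grow linearly in $n$ for fixed $m$.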

#### What Inducing Points Actually Do

Inducing points don't replace your data; they **summarize** it by playing two complementary roles:

1. **As anchor points**: They define strategic locations where the GP explicitly represents function values. The function everywhere else is determined by kernel-based interpolation from these anchors.

2. **As a compression mechanism**: Instead of tracking $\frac{n(n-1)}{2}$ pairwise correlations between observations, the GP only needs:
   - $\frac{m(m-1)}{2}$ correlations between inducing points ($m \times m$ relationships)
   - $n \times m$ correlations from each observation to inducing points

The FITC approximation assumes that once we know the function values at the inducing points, observations become conditionally independent. This is a strong assumption but works well in practice when:
- The function is smooth (appropriate kernel)
- Inducing points are well-distributed
- $m$ is large enough to capture the function's complexity

#### Choosing Inducing Point Locations

The approximation quality depends critically on where we place the inducing points. A practical strategy is **K-means clustering**:
- Places inducing points at cluster centers in the input space
- Naturally concentrates them where data is dense (important regions)
- Still ensures coverage across the entire domain
- Fast and deterministic (given a random seed)

For our 2000-observation dataset, we'll use $m=20$ inducing points, a 100× compression that still captures the essential structure.

In [None]:
# Use K-means to select inducing points
m = 20  # Number of inducing points

kmeans = KMeans(n_clusters=m, random_state=RANDOM_SEED, n_init=10)
kmeans.fit(x[:, None])
Xu = np.sort(kmeans.cluster_centers_.flatten())

print(f"Selected {m} inducing points using K-means")
print(f"Inducing points span: [{Xu.min():.2f}, {Xu.max():.2f}]")

Let's visualize where K-means placed our inducing points relative to the data density.

In [None]:
# Visualize inducing points overlaid on the data scatter plot
subsample_idx = RNG.choice(len(df), size=500, replace=False)

go.Figure().add_trace(go.Scatter(
    x=df['x'][subsample_idx],
    y=df['y'][subsample_idx],
    mode='markers',
    name='Observed data',
    marker=dict(size=3, color='#0066cc', opacity=0.3)
)).add_trace(go.Scatter(
    x=Xu,
    y=[0] * m,  # Place at y=0
    mode='markers',
    name=f'Inducing points (m={m})',
    marker=dict(
        size=10, 
        color='#ff6b6b',
        line=dict(width=1.5, color='#c92a2a')
    )
)).update_layout(
    title='Data Distribution with K-Means Selected Inducing Points',
    xaxis_title='x',
    yaxis_title='y',
    height=450,
    showlegend=True,
    legend=dict(x=0.02, y=0.98, bgcolor='rgba(255,255,255,0.8)'),
    plot_bgcolor='rgba(0,0,0,0)',
    hovermode='closest'
)

### Interpreting Inducing Point Placement

The K-means algorithm has distributed the 20 inducing points fairly evenly across the input domain. Since our data is uniformly distributed, this even spacing makes sense: the inducing points provide coverage across the entire range where we'll need to make predictions.

Notice that **inducing points don't need to be at actual data locations**. They're auxiliary variables introduced solely to compress the GP's representation. The sparse GP will:
1. Learn latent function values at these 20 inducing locations
2. Use the kernel to propagate information from inducing points to observations
3. Interpolate smoothly from inducing points to any prediction location

#### Understanding the Trade-off

The approximation quality depends on the interplay between $m$ (number of inducing points) and the function's complexity:

- **Too few inducing points ($m$ too small)**: The GP becomes overly smooth, unable to capture rapid variations or fine-scale structure. Information is lost in the compression.

- **Too many inducing points ($m$ too large)**: Computational savings diminish as we approach the cost of a full GP. The approximation becomes nearly exact but defeats the purpose.

- **Well-chosen $m$**: Captures the essential structure while maintaining computational efficiency. For smooth functions, surprisingly few inducing points often suffice.

For our smooth underlying function with $m=20$ inducing points distributed across the domain, we're betting that kernel-based interpolation through these 20 anchors can effectively reconstruct the function between observations. The visualization in the next sections will show whether this bet pays off!
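The trade-off can be measured directly. The sketch below computes the relative Frobenius-norm error of the low-rank term $K_{nm} K_{mm}^{-1} K_{mn}$ for several values of $m$. It assumes a hand-written `matern52` helper with unit lengthscale, a regular grid of 500 inputs, and evenly spaced inducing inputs; the hypothetical `lowrank_rel_error` function is for illustration only.

```python
import numpy as np

def matern52(xa, xb, ell=1.0, eta=1.0):
    """Matern 5/2 kernel for 1-D inputs (hand-written for this sketch)."""
    r = np.abs(xa[:, None] - xb[None, :])
    s = np.sqrt(5.0) * r / ell
    return eta**2 * (1.0 + s + s**2 / 3.0) * np.exp(-s)

x_grid = np.linspace(0.0, 10.0, 500)
K_full = matern52(x_grid, x_grid)

def lowrank_rel_error(m, jitter=1e-8):
    """Relative Frobenius error of K_nm K_mm^{-1} K_mn with m inducing inputs."""
    x_u = np.linspace(0.0, 10.0, m)
    K_mm = matern52(x_u, x_u) + jitter * np.eye(m)
    K_nm = matern52(x_grid, x_u)
    K_approx = K_nm @ np.linalg.solve(K_mm, K_nm.T)
    return np.linalg.norm(K_full - K_approx) / np.linalg.norm(K_full)

errors = {m: lowrank_rel_error(m) for m in (5, 10, 20, 40)}
for m, err in errors.items():
    print(f"m={m:>2}: relative error = {err:.5f}")
```

Running the same loop with a shorter lengthscale is a quick way to see how much larger $m$ must be for a rougher function.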

### Visualizing the Covariance Approximation

The visualization below reveals why sparse GPs work: the low-rank factorization captures nearly all the correlation structure of the full covariance matrix using just $m=20$ inducing points. 

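The approximation error remains small because covariance matrices of smooth GPs are approximately low-rank: nearby points are highly correlated, so most rows carry little independent information. This property lets us replace a 2000×2000 matrix (4 million entries, about 2 million of them unique by symmetry) with an $n \times m$ cross-covariance and an $m \times m$ matrix built from just 20 inducing locations. The cell below prints the maximum, mean, and relative (Frobenius-norm) error so you can check the size of the loss directly. Note that the heatmaps show only the low-rank term $K_{nm} K_{mm}^{-1} K_{mn}$; FITC additionally corrects the diagonal so that the marginal variances are exact.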

In [None]:
# Compute the full n x n covariance matrix using the true covariance function
K_full = cov_func(x[:, None]).eval()

# Compute sparse approximation covariance
Xu_viz = Xu[:, None]
x_viz_2d = x[:, None]
K_mm = cov_func(Xu_viz).eval()
K_nm = cov_func(x_viz_2d, Xu_viz).eval()
K_sparse_approx = K_nm @ np.linalg.solve(K_mm, K_nm.T)

# Create side-by-side heatmaps
fig = make_subplots(
    rows=1, cols=3,
    subplot_titles=(
        'Full Covariance',
        'Sparse Approximation',
        'Approximation Error'
    ),
    horizontal_spacing=0.1
)

# Full covariance
fig.add_trace(
    go.Heatmap(z=K_full, colorscale='Viridis', showscale=False,
               hovertemplate='i=%{x}, j=%{y}<br>K=%{z:.3f}<extra></extra>'),
    row=1, col=1
)

# Sparse approximation
fig.add_trace(
    go.Heatmap(z=K_sparse_approx, colorscale='Viridis', showscale=False,
               hovertemplate='i=%{x}, j=%{y}<br>K̃=%{z:.3f}<extra></extra>'),
    row=1, col=2
)

# Error
error = K_full - K_sparse_approx
fig.add_trace(
    go.Heatmap(
        z=error,
        colorscale='RdBu_r',
        zmid=0,
        showscale=True,
        colorbar=dict(title='Error', x=1.0),
        hovertemplate='i=%{x}, j=%{y}<br>Error=%{z:.3f}<extra></extra>'
    ),
    row=1, col=3
)

fig.update_xaxes(title_text='Data point index', row=1, col=1)
fig.update_xaxes(title_text='Data point index', row=1, col=2)
fig.update_xaxes(title_text='Data point index', row=1, col=3)
fig.update_yaxes(title_text='Data point index', row=1, col=1)

fig.update_layout(
    height=400,
    showlegend=False
)

fig.show()

print(f"Maximum absolute error: {np.abs(error).max():.4f}")
print(f"Mean absolute error: {np.abs(error).mean():.4f}")
print(f"Relative error (Frobenius norm): {np.linalg.norm(error) / np.linalg.norm(K_full):.4f}")

### Building the Sparse GP Model with FITC

Now we'll build our sparse GP model using PyMC's `MarginalApprox` class (the modern replacement for `MarginalSparse`) with the FITC approximation. Notice how the model specification is nearly identical to a standard GP: we just provide the inducing point locations `Xu` and specify the approximation type.

In [None]:
with pm.Model() as sparse_model:
    # Priors on hyperparameters
    ell = pm.Gamma('ell', alpha=2, beta=1)
    eta = pm.HalfNormal('eta', sigma=5)
    
    # Define covariance function
    cov = eta**2 * pm.gp.cov.Matern52(1, ls=ell)
    
    # Sparse GP with FITC approximation
    gp = pm.gp.MarginalApprox(cov_func=cov, approx='FITC')
    
    # Observation noise
    sigma = pm.HalfNormal('sigma', sigma=2)
    
    # Marginal likelihood
    y_obs = gp.marginal_likelihood(
        'y_obs',
        X=x[:, None],
        Xu=Xu[:, None],
        y=y,
        sigma=sigma
    )
    
    # Sample posterior
    idata_sparse = pm.sample(
        500,
        tune=500,
        chains=2,
        random_seed=RANDOM_SEED,
        nuts_sampler="nutpie"
    )

### Examining the Posterior

Let's check the posterior distributions of our hyperparameters and verify that sampling was successful.

In [None]:
az.summary(
    idata_sparse,
    var_names=['ell', 'eta', 'sigma'],
    round_to=2
)

In [None]:
# Trace plot
az.plot_trace(
    idata_sparse,
    var_names=['ell', 'eta', 'sigma'],
    figsize=(10, 6)
)
plt.tight_layout();

### Making Predictions with the Sparse GP

One of the benefits of the sparse GP approximation is that prediction is also fast. Let's make predictions at a dense grid of test points and visualize the posterior predictive distribution.

In [None]:
# Create test points
x_test = np.linspace(-0.5, 10.5, 300)

# Add conditional distribution to model and sample
with sparse_model:
    f_pred = gp.conditional('f_pred', x_test[:, None])
    
    # Sample posterior predictive
    posterior_pred = pm.sample_posterior_predictive(
        idata_sparse,
        var_names=['f_pred'],
        random_seed=RANDOM_SEED,
    )

In [None]:
# Use built-in plot_gp_dist for cleaner visualization
f_true_interp = np.interp(x_test, x, f_true)
subsample_idx = RNG.choice(len(df), size=200, replace=False)

# Reshape samples for plot_gp_dist: needs shape (n_samples, n_points)
# posterior_predictive has shape (chain, draw, n_points)
f_pred_samples = posterior_pred.posterior_predictive['f_pred'].values
# Stack chain and draw dimensions: (chain, draw, n_points) -> (chain*draw, n_points)
n_chains, n_draws, n_points = f_pred_samples.shape
f_pred_samples = f_pred_samples.reshape(n_chains * n_draws, n_points)

fig, ax = plt.subplots(figsize=(12, 6))

# Plot GP distribution with credible intervals
pm.gp.util.plot_gp_dist(
    ax,
    f_pred_samples,
    x_test,
    palette='Reds',
    plot_samples=False
)

# Add true function and data
ax.plot(x_test, f_true_interp, 'b--', linewidth=2, label='True Function')
ax.plot(df['x'][subsample_idx], df['y'][subsample_idx], 'o',
        color='gray', alpha=0.5, markersize=3, label='Data (subsample)')

# Add inducing points
ax.plot(Xu, np.ones(len(Xu)) * ax.get_ylim()[0], 'cx',
        markersize=10, markeredgewidth=2, label='Inducing Points')

ax.set_xlabel('x', fontsize=12)
ax.set_ylabel('f(x)', fontsize=12)
ax.set_title('Sparse GP Predictions with FITC Approximation', fontsize=14)
ax.legend()
plt.tight_layout()

### 🤖 EXERCISE: Sparse GP with Cherry Blossoms Data

Now it's your turn to experiment with sparse GPs using real historical data. The cherry blossoms dataset contains over 1,000 years of recorded bloom dates from Kyoto, Japan, one of the longest phenological records in existence.

**Your task**: Apply the sparse GP techniques you've learned to model how cherry blossom bloom timing has changed over the centuries.

**Dataset**: The cherry blossoms data (`data/cherry_blossoms.csv`) contains:
- `year`: Year (801-2015 CE)
- `doy`: Day of year when cherry blossoms bloomed (with some missing values)

This dataset is sparse in time (many years have missing observations) and exhibits long-term trends that make it perfect for practicing sparse GP methods.

In [None]:
# Load cherry blossoms data
cherry_df = pl.read_csv(
    DATA_DIR + 'cherry_blossoms.csv', 
    separator=';',
    null_values=['NA']  # Treat 'NA' strings as null
)

# Remove rows with missing bloom dates
cherry_df = cherry_df.filter(pl.col('doy').is_not_null())

# Extract year and day-of-year
years = cherry_df['year'].to_numpy().astype(float)
doy = cherry_df['doy'].to_numpy().astype(float)

print(f"Cherry Blossoms Dataset: {len(cherry_df)} observations")
print(f"Year range: {int(years.min())} - {int(years.max())}")
print(f"Day-of-year range: {int(doy.min())} - {int(doy.max())}")
print(f"Mean bloom date: Day {doy.mean():.1f} (approximately {doy.mean():.0f} days after Jan 1)")

# Visualize the data
go.Figure().add_trace(go.Scatter(
    x=years,
    y=doy,
    mode='markers',
    name='Observed bloom dates',
    marker=dict(size=5, color='hotpink', opacity=0.7, line=dict(width=0.5, color='darkviolet'))
)).update_layout(
    title=f'Cherry Blossom Bloom Dates in Kyoto ({int(years.min())}-{int(years.max())})',
    xaxis_title='Year',
    yaxis_title='Day of Year',
    height=400,
    showlegend=False,
    plot_bgcolor='white',
    xaxis=dict(showgrid=False, zeroline=False),
    yaxis=dict(showgrid=False, zeroline=False)
)

In [None]:
# EXERCISE INSTRUCTIONS:
#
# STEP 1: Choose inducing points using K-means
# - Try different numbers of inducing points: M = 20, 50, 100
# - Use the K-means clustering approach shown earlier
# - Visualize where the inducing points are placed
#
# Prompt suggestion: "Help me use K-means to select M inducing points from the
# cherry blossoms years data and create a visualization showing their placement."

# STEP 2: Build a sparse GP model with MarginalApprox
# - Use a Matérn 5/2 or ExpQuad kernel (cherry blossoms show smooth long-term trends)
# - Consider appropriate priors for lengthscale (think in terms of decades or centuries)
# - Use the FITC approximation
#
# Prompt suggestion: "Help me build a pm.gp.MarginalApprox model for the cherry
# blossoms data with priors suitable for multi-century trends."

# STEP 3: Make predictions and visualize
# - Predict bloom dates across the full time range
# - Compare predictions with different numbers of inducing points (M=20 vs M=100)
# - Visualize uncertainty (credible intervals)
#
# Prompt suggestion: "Help me create predictions from the sparse GP and make a
# plotly visualization showing the posterior mean, credible intervals, and data points."

# STEP 4: Explore the trade-offs
# - How does increasing M affect:
#   * Approximation quality (smoothness, fit to data)
#   * Computation time (sampling speed)
#   * Uncertainty estimates
# - Can you identify interesting historical patterns (e.g., warming trends)?

# YOUR LLM-ASSISTED CODE HERE

## Section 3.3: Hilbert Space GP (HSGP) Theory

While sparse GPs use inducing points to reduce complexity, the Hilbert Space GP (HSGP) takes a completely different approach: it approximates the GP using a **basis function expansion**. This transforms the non-parametric GP into a parametric model with a fixed number of basis functions, making it compatible with standard MCMC samplers and dramatically improving computational efficiency.

The mathematical foundation of HSGP comes from spectral analysis of covariance functions. Any stationary covariance kernel can be represented through its **power spectral density**, essentially a Fourier transform that describes the kernel's behavior in frequency space. The HSGP approximation uses a finite set of basis functions (sinusoids) whose coefficients are drawn from a distribution determined by this spectral density.

### Why This Matters

Think of it this way: instead of defining a function through all pairwise correlations (which requires $n^2$ parameters and $n^3$ operations), HSGP defines it through $m$ basis function coefficients. These basis functions are pre-computed and don't depend on hyperparameters, so we only need to update the coefficients during sampling.

The computational complexity drops from $\mathcal{O}(n^3)$ for exact GPs to $\mathcal{O}(nm + m)$ for HSGP, where $m$ is the number of basis functions. Even better, HSGP is fully parametric: we can use `pm.set_data` for predictions without explicitly computing conditional distributions. This makes it much easier to integrate an HSGP into your existing PyMC model.

Additionally, unlike many other GP approximations, HSGPs can be used anywhere within a model and with any likelihood function. This flexibility is a major advantage over methods like sparse GPs that work best with Gaussian likelihoods.
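The basis-function idea can be checked numerically. The sketch below rebuilds a covariance matrix from $m$ sinusoidal basis functions weighted by the kernel's power spectral density, then compares it with the exact kernel. It uses the ExpQuad kernel rather than Matérn 5/2 because its spectral density has a simple closed form; the `hsgp_basis` and `expquad_psd` helpers are hand-written for illustration and are not PyMC functions.

```python
import numpy as np

def hsgp_basis(x, L, m):
    """Laplacian eigenfunctions on [-L, L] with zero boundary conditions."""
    j = np.arange(1, m + 1)
    sqrt_eigvals = j * np.pi / (2.0 * L)
    phi = np.sqrt(1.0 / L) * np.sin(sqrt_eigvals * (x[:, None] + L))
    return phi, sqrt_eigvals

def expquad_psd(omega, ell, eta):
    """Power spectral density of the 1-D ExpQuad kernel."""
    return eta**2 * np.sqrt(2.0 * np.pi) * ell * np.exp(-0.5 * (ell * omega) ** 2)

ell, eta = 1.0, 1.0
x_c = np.linspace(-5.0, 5.0, 200)   # centered inputs, so S = 5
L, m = 10.0, 60                     # boundary factor c = L / S = 2

# The basis depends only on x, L and m, so it can be computed once
phi, sqrt_eigvals = hsgp_basis(x_c, L, m)

# Only these m weights change when the hyperparameters change
psd = expquad_psd(sqrt_eigvals, ell, eta)

K_hsgp = (phi * psd) @ phi.T
K_exact = eta**2 * np.exp(-0.5 * (x_c[:, None] - x_c[None, :]) ** 2 / ell**2)
max_err = np.abs(K_hsgp - K_exact).max()
print(f"max abs difference between HSGP and exact kernel: {max_err:.2e}")
```

Shrinking `m` or `L` in this sketch shows how the reconstruction degrades when the basis is too small or the boundary too close to the data.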

### HSGP Restrictions

The HSGP approximation does carry some restrictions:

1. It **can only be used with stationary covariance kernels** such as the Matérn family or ExpQuad. The kernel must implement the `power_spectral_density` method.
2. It **does not scale well with input dimension**. HSGP is a good choice for 1D processes (like time series) or 2D spatial processes, but likely not efficient beyond 3 dimensions.
3. It **may struggle with very rapidly varying processes**. If the process changes very quickly relative to the domain extent, you may need very large $m$ to accurately represent it.
4. **For smaller datasets, the full unapproximated GP may still be more efficient**.
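The dimension restriction follows from simple counting: a multi-dimensional basis is a product of one-dimensional bases, so the total number of basis functions is the product of the per-dimension counts. The sketch below assumes 30 basis functions per dimension, a number chosen only for illustration.

```python
import math

m_per_dim = 30  # basis functions per input dimension (illustrative choice)

totals = {}
for n_dims in (1, 2, 3, 4):
    m_list = [m_per_dim] * n_dims          # one entry per input dimension
    totals[n_dims] = math.prod(m_list)     # size of the product basis
    print(f"{n_dims}D: m={m_list} -> {totals[n_dims]:,} basis functions in total")
```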

### Key Parameters: m and c

HSGP approximations are controlled by two parameters:

- **m**: The number of basis functions. Larger $m$ gives better approximation quality but increases computational cost. Think of $m$ as the "resolution" of your approximation: more basis functions can represent more complex, rapidly-varying patterns. Increasing $m$ helps the HSGP approximate GPs with smaller lengthscales.

- **c**: The boundary extension factor. HSGP basis functions are defined on a finite domain $[-L, L]$ where $L = c \cdot S$ and $S$ is half the range of your centered data. Larger $c$ values help approximate GPs with longer lengthscales and ensure predictions away from data aren't affected by boundary conditions. However, increasing $c$ may require increasing $m$ to compensate for loss of fidelity at smaller lengthscales.

The art of using HSGP effectively lies in choosing $m$ and $c$ appropriately for your data and expected lengthscales. Fortunately, PyMC provides a helper function to get you started.

### Visualizing HSGP Basis Functions

To build intuition, let's visualize what HSGP basis functions actually look like. These are the sinusoidal building blocks that will be combined to approximate our GP. Notice that we need to center the data first; this is an important requirement for HSGP.

In [None]:
x_grid = np.linspace(-5, 5, 1000)

fig, axs = plt.subplots(1, 3, figsize=(14, 4), sharey=True)

ylim = 0.55
axs[0].set_ylim([-ylim, ylim])
axs[1].set_yticks([])
axs[1].set_xlabel("x (centered)")
axs[2].set_yticks([])

L_options = [5.0, 6.0, 20.0]
m_options = [3, 3, 5]

for i, ax in enumerate(axs):
    L = L_options[i]
    m_val = m_options[i]
    
    eigvals = pm.gp.hsgp_approx.calc_eigenvalues(pt.as_tensor([L]), [m_val])
    phi = pm.gp.hsgp_approx.calc_eigenvectors(
        x_grid[:, None],
        pt.as_tensor([L]),
        eigvals,
        [m_val],
    ).eval()
    
    for j in range(phi.shape[1]):
        ax.plot(x_grid, phi[:, j])
    
    ax.set_xticks(np.arange(-5, 6, 5))
    
    S = 5.0
    c = L / S
    ax.text(-4.9, -0.45, f"L = {L}\nc = {c}", fontsize=12)
    ax.set_title(f"{m_val} basis functions")
    ax.set_xlabel("x (centered)")

axs[0].set_ylabel("Basis function value")
plt.suptitle("The Effect of Changing L on HSGP Basis Vectors", fontsize=14)
plt.tight_layout()

### Interpreting the Basis Functions

These plots reveal critical insights about HSGP basis functions:

**Left panel (L=5, c=1.0)**: When $L$ equals $S$, the half-width of the centered data (so $c = 1$), all basis vectors are forced to pinch to zero at the data boundaries (at $x=-5$ and $x=5$). This means the HSGP approximation becomes poor near the edges of your data. This is why we need $c > 1$.

**Middle panel (L=6, c=1.2)**: With $c=1.2$, the basis functions extend beyond the data range and are no longer forced to zero at the data boundaries. This helps the approximation remain accurate across the entire domain. Values of $c$ around 1.2 are considered the minimum for reasonable approximations.

**Right panel (L=20, c=4.0, m=5)**: With larger $L$ or $c$, the basis functions become lower frequency (longer wavelength). Notice how the first basis function (blue) is nearly flat: it's becoming partially unidentifiable with an intercept term. This is why we sometimes need to drop the first basis function, or increase $m$ to compensate.

Notice that the basis functions are sinusoids with increasing frequency. Lower-order basis functions capture long-range trends, while higher-order functions capture increasingly rapid oscillations. An HSGP approximation works by taking a weighted sum of these basis functions.

The key lessons:
- **Increasing $m$ helps approximate GPs with smaller lengthscales** (more basis functions = higher resolution)
- **Increasing $c$ or $L$ helps approximate GPs with larger lengthscales** but may require increasing $m$ to maintain fidelity at smaller lengthscales
- **Consider where predictions will be made**: they also need to be away from the boundary "pinch"
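The boundary "pinch" can be confirmed numerically by evaluating the basis functions at the edges of the data for different values of $c$. This sketch assumes centered data on $[-5, 5]$ (so $S = 5$) and uses a hand-written `hsgp_basis` helper based on the standard sinusoidal eigenfunction formula; it is not a PyMC function.

```python
import numpy as np

def hsgp_basis(x, L, m):
    """Sinusoidal basis functions on [-L, L] that vanish at -L and L."""
    j = np.arange(1, m + 1)
    return np.sqrt(1.0 / L) * np.sin(j * np.pi / (2.0 * L) * (x[:, None] + L))

S = 5.0                          # half-width of the centered data
x_edges = np.array([-S, S])      # the two data boundaries

edge_values = {}
for c in (1.0, 1.2, 4.0):
    L = c * S
    edge_values[c] = hsgp_basis(x_edges, L, m=3)
    largest = np.abs(edge_values[c]).max()
    print(f"c={c}: L={L}, largest |basis value| at the data edges = {largest:.3f}")
```

With $c = 1$ every basis function is numerically zero at both edges, so no weighted sum of them can fit a nonzero function value there.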

## Section 3.4: HSGP Implementation

Now let's implement an HSGP model and see it in action. We'll use the same dataset as before for direct comparison with the sparse GP.

### Choosing HSGP Parameters

PyMC provides a helper function `approx_hsgp_hyperparams` that suggests values for $m$ and $c$ based on:
- The range of your input data
- The range of lengthscales you expect (from your prior)
- The covariance function type

These recommendations are based on approximation error bounds derived in the HSGP literature. The heuristics help you choose $c$ large enough to handle the largest lengthscales you might fit, and $m$ large enough to accommodate the smallest lengthscales. Let's use this function to get started.
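As a rough illustration of the kind of rule of thumb involved, the sketch below encodes heuristics of the form $c \ge a_1 \ell_{\max} / S$ and $m \ge a_2 c / (\ell_{\min} / S)$. The function name `hsgp_heuristic` is hypothetical, and the constants are recalled from the HSGP literature (Riutort-Mayol et al.) and should be treated as assumptions; PyMC's helper may differ in its details, so rely on its output rather than on this sketch.

```python
def hsgp_heuristic(x_range, lengthscale_range, cov_func="matern52"):
    """Rule-of-thumb (m, c) for a 1-D HSGP. Constants are assumptions."""
    constants = {
        "expquad": (3.2, 1.75),
        "matern52": (4.1, 2.65),
        "matern32": (4.5, 3.42),
    }
    a1, a2 = constants[cov_func]
    S = (x_range[1] - x_range[0]) / 2.0        # half-width of the data
    ell_min, ell_max = lengthscale_range
    # c must be large enough for the longest lengthscale, and at least 1.2
    c = max(a1 * ell_max / S, 1.2)
    # m must be large enough for the shortest lengthscale, given c
    m = int(a2 * c / (ell_min / S))
    return m, c

m_sketch, c_sketch = hsgp_heuristic([0.0, 10.0], [0.5, 3.0], "matern52")
print(f"m = {m_sketch}, c = {c_sketch:.2f}")
```

The structure makes the dependence clear: the upper end of the lengthscale range drives $c$, and the lower end drives $m$.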

In [None]:
# Determine appropriate m and c
x_range = [x.min(), x.max()]
lengthscale_range = [0.5, 3.0]  # Based on our prior knowledge

m_recommended, c_recommended = pm.gp.hsgp_approx.approx_hsgp_hyperparams(
    x_range=x_range,
    lengthscale_range=lengthscale_range,
    cov_func='matern52'
)

print(f"Recommended m: {m_recommended}")
print(f"Recommended c: {c_recommended:.2f}")

m_hsgp = m_recommended
c_hsgp = c_recommended

### Building the HSGP Model

The HSGP model specification in PyMC is remarkably similar to a standard GP. The key difference is that we use `pm.gp.HSGP` instead of `pm.gp.Latent` or `pm.gp.Marginal`, and we specify the approximation parameters $m$ and $c$. 

Notice that we use the `.prior` method just like with `pm.gp.Latent`. For basic usage, HSGP can be treated as a drop-in replacement for the standard latent GP.

In [None]:
with pm.Model() as hsgp_model:

    # Priors on hyperparameters 
    ell = pm.Gamma('ell', alpha=2, beta=1)
    eta = pm.HalfNormal('eta', sigma=5)
    
    # Define covariance function
    cov = eta**2 * pm.gp.cov.Matern52(1, ls=ell)
    
    # HSGP approximation
    gp = pm.gp.HSGP(m=[m_hsgp], c=c_hsgp, cov_func=cov)
    
    # Prior over the latent function
    f = gp.prior('f', X=x[:, None])
    
    # Observation noise
    sigma = pm.HalfNormal('sigma', sigma=2)
    
    # Likelihood
    y_obs = pm.Normal('y_obs', mu=f, sigma=sigma, observed=y)
    
    idata_hsgp = pm.sample(
        500,
        tune=500,
        chains=2,
        random_seed=RANDOM_SEED,
        nuts_sampler="nutpie",
    )

### Examining HSGP Results

Let's check sampling diagnostics and posterior distributions for the HSGP model.

In [None]:
summary_hsgp = az.summary(
    idata_hsgp,
    var_names=['ell', 'eta', 'sigma'],
    round_to=2
)
print(summary_hsgp)

In [None]:
az.plot_trace(
    idata_hsgp,
    var_names=['ell', 'eta', 'sigma'],
    figsize=(10, 6)
)
plt.tight_layout()
plt.show()

### Making Predictions with HSGP

One major advantage of HSGP is the ease of prediction. Since it's parametric, we can use the `.conditional` method just like with other GPs. Let's make predictions and visualize the fit.
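To see why a parametric form makes prediction cheap, consider this conceptual sketch: once the basis coefficients are known, predicting at new inputs only requires evaluating the same basis functions there. The helpers are hand-written, the ExpQuad spectral density stands in for Matérn 5/2, and the random `beta` stands in for one posterior draw; none of this is PyMC's internal code.

```python
import numpy as np

def hsgp_basis(x, L, m):
    """Sinusoidal basis on [-L, L]; also returns the square-root eigenvalues."""
    j = np.arange(1, m + 1)
    sqrt_eigvals = j * np.pi / (2.0 * L)
    phi = np.sqrt(1.0 / L) * np.sin(sqrt_eigvals * (x[:, None] + L))
    return phi, sqrt_eigvals

def expquad_psd(omega, ell, eta):
    """Power spectral density of the 1-D ExpQuad kernel."""
    return eta**2 * np.sqrt(2.0 * np.pi) * ell * np.exp(-0.5 * (ell * omega) ** 2)

rng = np.random.default_rng(3)
ell, eta, L, m = 1.0, 1.0, 10.0, 60

x_train_c = np.linspace(-5.0, 5.0, 50)    # centered training inputs
x_new_c = np.linspace(-5.5, 5.5, 300)     # new inputs, centered the same way

phi_train, sqrt_eigvals = hsgp_basis(x_train_c, L, m)
phi_new, _ = hsgp_basis(x_new_c, L, m)

# Stand-in for one posterior draw of the standardized coefficients
beta = rng.standard_normal(m)
weights = np.sqrt(expquad_psd(sqrt_eigvals, ell, eta)) * beta

# The same weights define the function everywhere: no matrix solve needed
f_train = phi_train @ weights
f_new = phi_new @ weights
print(f_train.shape, f_new.shape)
```

The one requirement is that new inputs are shifted by the same center as the training inputs, so that both sets share one coordinate system on $[-L, L]$.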

In [None]:
x_test = np.linspace(-0.5, 10.5, 300)

with hsgp_model:
    f_pred_hsgp = gp.conditional('f_pred', x_test[:, None])
    
    posterior_pred_hsgp = pm.sample_posterior_predictive(
        idata_hsgp,
        var_names=['f_pred'],
        random_seed=RANDOM_SEED,
    )

In [None]:
# Use built-in plot_gp_dist for cleaner visualization
f_true_interp = np.interp(x_test, x, f_true)
subsample_idx = RNG.choice(len(df), size=200, replace=False)

# Reshape samples for plot_gp_dist: needs shape (n_samples, n_points)
# posterior_predictive has shape (chain, draw, n_points)
f_pred_samples = posterior_pred_hsgp.posterior_predictive['f_pred'].values
# Stack chain and draw dimensions: (chain, draw, n_points) -> (chain*draw, n_points)
n_chains, n_draws, n_points = f_pred_samples.shape
f_pred_samples = f_pred_samples.reshape(n_chains * n_draws, n_points)

fig, ax = plt.subplots(figsize=(12, 6))

# Plot GP distribution with credible intervals
pm.gp.util.plot_gp_dist(
    ax,
    f_pred_samples,
    x_test,
    palette='Purples',
    plot_samples=False
)

# Add true function and data
ax.plot(x_test, f_true_interp, color='gold', linestyle='--',
        linewidth=3, label='True Function')
ax.plot(df['x'][subsample_idx], df['y'][subsample_idx], 'o',
        color='gray', alpha=0.5, markersize=3, label='Data (subsample)')

ax.set_xlabel('x', fontsize=12)
ax.set_ylabel('f(x)', fontsize=12)
ax.set_title(f'HSGP Fit (m={m_hsgp}, c={c_hsgp:.2f})', fontsize=14)
ax.legend()
plt.tight_layout()

### Understanding the HSGP Fit

The HSGP inferred posterior (purple) accurately matches the true underlying GP (gold dashed line). We also see that the credible intervals appropriately capture uncertainty. This demonstrates that even with an approximation using basis functions, we can achieve excellent fit quality.

Notice that with recommended parameters from `approx_hsgp_hyperparams`, the approximation is essentially indistinguishable from what an exact GP would produce. The computational cost, however, is dramatically lower: $\mathcal{O}(nm)$ instead of $\mathcal{O}(n^3)$.
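To see where the $\mathcal{O}(nm)$ cost comes from, here is a minimal NumPy sketch of the basis-function view. It assumes sine basis functions on $[-L, L]$ and uses illustrative values for `n`, `m`, and `L`; the coefficient vector `beta` is a random stand-in, not a fitted posterior draw.

```python
import numpy as np

n, m = 2_000, 30     # assumed number of observations and basis functions
L = 12.5             # assumed boundary for inputs centered on [-5, 5]

x_centered = np.linspace(-5.0, 5.0, n)
sqrt_eigvals = np.arange(1, m + 1) * np.pi / (2.0 * L)

# n x m basis matrix: each column is one sine basis function on [-L, L]
Phi = np.sqrt(1.0 / L) * np.sin(sqrt_eigvals * (x_centered[:, None] + L))

beta = np.random.default_rng(0).standard_normal(m)  # stand-in for scaled coefficients
f = Phi @ beta       # one function evaluation: about n * m multiply-adds

exact_ops = n**3     # order of a Cholesky factorization of the n x n covariance
hsgp_ops = n * m
print(Phi.shape, f.shape)
print(f"n^3 = {exact_ops:.1e}, n*m = {hsgp_ops:.1e}, ratio = {exact_ops / hsgp_ops:,.0f}")
```

The same matrix-vector product gives predictions at new inputs: evaluate the basis at the new locations and reuse the coefficients. These counts ignore constants, so treat the ratio as an order-of-magnitude guide rather than a timing prediction.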

### Comparing HSGP to Sparse GP

Let's directly compare the posterior distributions from the HSGP and sparse GP models using a forest plot, which is ideal for comparing multiple models side-by-side.

In [None]:
az.plot_forest(
    [idata_sparse, idata_hsgp],
    model_names=['Sparse GP', 'HSGP'],
    var_names=['ell', 'eta', 'sigma'],
    combined=True,
    figsize=(10, 5)
)
plt.tight_layout()
plt.show()

### Interpreting the Comparison

The posterior distributions from HSGP and sparse GP should be very similar, particularly for the lengthscale and amplitude parameters that control the function's smoothness and scale. Small differences are expected since both are approximations, but substantial disagreement would suggest that one or both approximations are inadequate.

Both methods successfully inferred hyperparameters close to the true values (lengthscale=1.0, amplitude=3.0, noise=0.5), demonstrating that either approach can work well for moderate-sized datasets with smooth underlying functions.

## Section 3.5: Advanced HSGP - Centered vs Non-Centered Parameterization

An important consideration when using HSGP is choosing between centered and non-centered parameterizations. This is analogous to the choice you make in hierarchical models, and for similar reasons: the correlation structure in the posterior.
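As a rough sketch of what the two options mean for a basis-function GP, the NumPy snippet below draws the basis coefficients both ways. It is meant to mirror the idea, not reproduce PyMC's internals; the spectral density values in `sqrt_psd` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical square-root spectral densities for m = 3 basis functions
sqrt_psd = np.sqrt(np.array([1.0, 0.5, 0.1]))
n_draws = 200_000

# Non-centered: sample standard normals, then scale deterministically
beta_raw = rng.standard_normal((n_draws, 3))
beta_noncentered = beta_raw * sqrt_psd

# Centered: sample the coefficients directly at their prior scale
beta_centered = rng.normal(loc=0.0, scale=sqrt_psd, size=(n_draws, 3))

# Both imply the same prior on the coefficients (and hence on f = Phi @ beta)
std_noncentered = beta_noncentered.std(axis=0)
std_centered = beta_centered.std(axis=0)
print(std_noncentered.round(3), std_centered.round(3))
```

The prior is identical in both cases. What changes is the set of variables the sampler works with (`beta_raw` versus `beta`), and therefore the posterior geometry it has to explore.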

### When to Use Each Parameterization

**Centered parameterization** works better when:
- The underlying GP is strongly informed by the data
- You have lots of data relative to the lengthscale
- The signal-to-noise ratio is high

**Non-centered parameterization** (the default) works better when:
- The underlying GP is weakly informed by the data  
- You have sparse data or large lengthscales
- The signal-to-noise ratio is low

In our example with 2000 noisy observations, the centered parameterization might actually be better. Let's test this.

In [None]:
with pm.Model() as hsgp_centered:

    ell = pm.Gamma('ell', alpha=2, beta=1)
    eta = pm.HalfNormal('eta', sigma=5)
    
    cov = eta**2 * pm.gp.cov.Matern52(1, ls=ell)
    
    # HSGP with centered parameterization
    gp = pm.gp.HSGP(
        m=[m_hsgp], 
        c=c_hsgp, 
        cov_func=cov,
        parametrization='centered'  # Key difference!
    )
    
    f = gp.prior('f', X=x[:, None])
    sigma = pm.HalfNormal('sigma', sigma=2)
    y_obs = pm.Normal('y_obs', mu=f, sigma=sigma, observed=y)
    
    idata_hsgp_centered = pm.sample(
        500,
        tune=500,
        chains=2,
        random_seed=RANDOM_SEED,
        nuts_sampler="nutpie"
    )

In [None]:
print("Non-centered parameterization:")
print(az.summary(idata_hsgp, var_names=['ell', 'eta', 'sigma'])[['ess_bulk', 'ess_tail', 'r_hat']])
print("\nCentered parameterization:")
print(az.summary(idata_hsgp_centered, var_names=['ell', 'eta', 'sigma'])[['ess_bulk', 'ess_tail', 'r_hat']])

### Interpreting Parameterization Effects

Compare the effective sample sizes (ESS) between the two parameterizations. Higher ESS means more efficient sampling: you're getting more independent samples per iteration. For this dataset with strong signal, you may find the centered parameterization provides better ESS.

The choice of parameterization doesn't affect what you're learning about the hyperparameters; it only affects how efficiently the sampler explores the posterior. If you find sampling is slow or you see low ESS, try switching parameterizations.

### Visualizing the HSGP Approximate Gram Matrix

Another way to check HSGP fidelity is to directly compare the unapproximated Gram matrix (covariance matrix) $\mathbf{K}$ to the one resulting from the HSGP approximation:

$$
\tilde{\mathbf{K}} = \Phi \Lambda \Phi^T
$$

where $\Phi$ is the matrix of eigenvectors (basis functions), and $\Lambda$ is a diagonal matrix holding the kernel's power spectral density evaluated at the square roots of the eigenvalues. Let's visualize this for different values of $m$ and $c$ to see when the approximation starts to degrade.

In [None]:
# Use a subset for clearer visualization
n_viz = 100
x_viz = np.linspace(0, 10, n_viz)

# True GP covariance
chosen_ell = 1.5
cov_func_viz = pm.gp.cov.Matern52(1, ls=chosen_ell)
K_true = cov_func_viz(x_viz[:, None]).eval()

# Helper function to calculate HSGP approximate Gram matrix
def calculate_K_approx(x_centered, L, m_val, cov_func):
    """Calculate the HSGP approximate covariance matrix."""
    eigvals = pm.gp.hsgp_approx.calc_eigenvalues(L, m_val)
    phi = pm.gp.hsgp_approx.calc_eigenvectors(x_centered, L, eigvals, m_val)
    omega = pt.sqrt(eigvals)
    psd = cov_func.power_spectral_density(omega)
    return (phi @ pt.diag(psd) @ phi.T).eval()

# Center the data
x_center = (x_viz.max() + x_viz.min()) / 2.0
x_viz_centered = (x_viz - x_center)[:, None]

# Create comparison plot
fig, axs = plt.subplots(2, 3, figsize=(12, 8), sharey=True)

# True Gram matrix
axs[0, 0].imshow(K_true, cmap='inferno', vmin=0, vmax=1)
axs[0, 0].set_title(f'True Gram matrix\nℓ = {chosen_ell}')
axs[0, 0].set_ylabel('Index')

# Various m and c combinations
configs = [
    ([30], 2.5, 1),
    ([15], 2.5, 2),
    ([30], 1.2, 3),
    ([15], 1.2, 4),
    ([5], 2.5, 5),
]

for m_val, c_val, idx in configs:
    row = 0 if idx <= 2 else 1
    col = idx if idx <= 2 else idx - 3
    
    L = pm.gp.hsgp_approx.set_boundary(x_viz_centered, c_val)
    K_approx = calculate_K_approx(x_viz_centered, L, m_val, cov_func_viz)
    
    axs[row, col].imshow(K_approx, cmap='inferno', vmin=0, vmax=1, interpolation='none')
    axs[row, col].set_title(f'm = {m_val[0]}, c = {c_val}')
    
    if col == 0:
        axs[row, col].set_ylabel('Index')
    if row == 1:
        axs[row, col].set_xlabel('Index')

plt.suptitle('HSGP Approximation Quality: Comparing to True Gram Matrix', fontsize=14)
plt.tight_layout()

### Understanding Gram Matrix Approximations

These plots compare approximate Gram matrices to the unapproximated one (top left). The goal is visual similarity: the more alike they look, the better the approximation. Important caveats:

- These results are **only relevant for this specific domain and lengthscale** ($\ell = 1.5$). Different lengthscales will show different approximation quality.
- The approximation looks good for $m = 30$ or $m = 15$ with $c=2.5$. The rest show clear differences.
- $c=1.2$ is generally too small, regardless of $m$, showing degradation at the boundaries.
- Surprisingly, a configuration such as $m=5$, $c=1.2$ (not plotted above) can look better than $m=5$, $c=2.5$. When we "stretch" the basis to fill a larger domain, we lose fidelity at smaller lengthscales if $m$ is too small.

The lesson: **you need to experiment across your range of lengthscales** to find adequate $m$ and $c$ values. Often during prototyping, you can use lower fidelity (smaller $m$) for faster iteration, then dial in higher fidelity once you understand the relevant lengthscales.
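One way to run this experiment without plotting is to put a number on the mismatch. The sketch below is a NumPy-only illustration that swaps in an ExpQuad kernel (rather than the Matérn 5/2 used above) because its 1D spectral density has a simple closed form; the helper names are hypothetical and not part of PyMC.

```python
import numpy as np

def expquad_gram(x, ell):
    """Exact ExpQuad Gram matrix for 1D inputs (unit variance)."""
    d = x[:, None] - x[None, :]
    return np.exp(-0.5 * (d / ell) ** 2)

def hsgp_gram_expquad(x, ell, m, c):
    """HSGP approximation Phi @ diag(psd) @ Phi.T for a 1D ExpQuad kernel."""
    xc = x - (x.min() + x.max()) / 2.0             # center the inputs
    L = c * np.max(np.abs(xc))                     # boundary [-L, L]
    omega = np.arange(1, m + 1) * np.pi / (2 * L)  # square roots of the eigenvalues
    phi = np.sqrt(1.0 / L) * np.sin(omega * (xc[:, None] + L))
    psd = np.sqrt(2 * np.pi) * ell * np.exp(-0.5 * (ell * omega) ** 2)
    return (phi * psd) @ phi.T

x = np.linspace(0, 10, 100)
errors = {}
for ell in (0.5, 1.5, 5.0):
    K = expquad_gram(x, ell)
    for m, c in ((30, 2.5), (5, 2.5), (30, 1.2)):
        K_approx = hsgp_gram_expquad(x, ell, m, c)
        errors[(ell, m, c)] = np.max(np.abs(K - K_approx))
        print(f"ell={ell:>4}, m={m:>2}, c={c}: max abs error = {errors[(ell, m, c)]:.3g}")
```

Scanning the printed errors across lengthscales shows the same pattern as the heatmaps: too few basis functions hurts short lengthscales, while a boundary that is too tight hurts long ones.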

### Practical Heuristics for Choosing m and c

In practice, you'll need to infer the lengthscale from data, so HSGP needs to approximate a GP across a range of lengthscales representative of your prior. Based on the research literature (Riutort-Mayol et al., 2020) and the PyMC implementation:

1. **Start with `approx_hsgp_hyperparams`**: This function provides data-driven recommendations based on your input range and expected lengthscales. It uses approximation error bounds from the HSGP literature to choose $c$ large enough to handle your largest expected lengthscales and $m$ large enough for your smallest lengthscales.

2. **Understand the trade-offs**:
   - **Increasing $m$ helps approximate GPs with smaller lengthscales** at the cost of increased computation
   - **Increasing $c$ helps approximate GPs with larger lengthscales** but may require increasing $m$ to compensate for loss of fidelity at smaller scales

3. **Experiment with your specific problem**: You will need to verify approximation quality across your range of lengthscales. The recommendations from `approx_hsgp_hyperparams` provide a starting point, but you may be able to reduce $m$ for computational savings if your function is smooth, or need to increase it if the process varies rapidly.

4. **Check the basis vectors if sampling struggles**: The first eigenvector can become unidentifiable with an intercept when $c$ is large. Consider using the `drop_first` option.

5. **Verify approximation quality**: Compare the HSGP approximate Gram matrix to the true Gram matrix (as shown in the visualization above) on a data subset to confirm adequate approximation for your lengthscale range.
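The logic behind the first heuristic can be sketched in a few lines. The function `heuristic_m_c` is hypothetical, and the constants are assumptions based on our reading of Riutort-Mayol et al.; check them against the paper and the PyMC source before relying on them, and use `approx_hsgp_hyperparams` in real work.

```python
def heuristic_m_c(x_range, lengthscale_range, cov_func="matern52"):
    """Sketch of the m/c heuristic; constants are assumed, not authoritative."""
    # (multiplier for the minimum c, multiplier for m): assumed values
    constants = {
        "expquad": (3.2, 1.75),
        "matern52": (4.1, 2.65),
        "matern32": (4.5, 3.42),
    }
    a_c, a_m = constants[cov_func]

    lo, hi = x_range
    S = (hi - lo) / 2.0                  # half-width of the centered domain
    ell_min, ell_max = lengthscale_range

    c = max(a_c * ell_max / S, 1.2)      # largest lengthscale sets the boundary
    m = int(a_m * c / (ell_min / S))     # smallest lengthscale sets the basis size
    return m, c

m, c = heuristic_m_c(x_range=[0.0, 10.0], lengthscale_range=[1.0, 3.0])
print(m, round(c, 2))
```

The structure makes the trade-off in item 2 explicit: raising the upper lengthscale bound increases $c$, and because $m$ scales with $c$, it also increases the number of basis functions needed.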


In [None]:
# 🤖 EXERCISE: Use your LLM to help compare HSGP vs standard GP

# STEP 1: Ask your LLM to help you implement this function
def hsgp_vs_full_gp(X, y, m_values=(20, 50, 100), L_factor=1.5):
    """
    Fit HSGP for several m, compare to full GP in time and RMSE.
    
    Prompt suggestion: "Help me implement HSGP in PyMC and produce plots of
    computation time vs error across different m values."
    """
    # YOUR LLM-ASSISTED CODE HERE
    pass

# STEP 2: Test on a subset of data
# Use 500-1000 points for reasonable comparison times

Now that we understand how to choose HSGP parameters in one dimension, let's see how these principles extend to spatial modeling with 2D data.

## Section 3.6: Advanced HSGP - 2D Spatial Example

So far we've focused on one-dimensional examples. HSGP also works well in two dimensions, making it excellent for spatial modeling. Let's apply HSGP to a real-world spatial dataset to see how $m$ and $c$ work in multiple dimensions.

### The Walker Lake Dataset

We'll use the famous **Walker Lake dataset (Isaaks & Srivastava 1989)**, a classic dataset in spatial statistics. This involves spatial sampling of mineral concentrations across Walker Lake in Nevada. The data consist of spatial coordinates (Xloc, Yloc) in meters and measurements of variable V (concentration in parts per million). The samples are taken regularly over a coarse grid across the entire area, with additional irregular sampling in regions of interest.

In [None]:
# Load Walker Lake data into a Polars DataFrame
# The file has mixed whitespace (tabs and multiple spaces)
# Polars doesn't support regex separators, so we split each line in Python first

# Skip the first 8 lines and split the remaining non-empty lines on any whitespace
with open(DATA_DIR + 'walker.txt', 'r') as f:
    lines = [line.split() for line in f.readlines()[8:] if line.strip()]

# Keep only rows that have all six fields
rows = [row for row in lines if len(row) >= 6]

# Create Polars DataFrame from parsed data
walker_data = pl.DataFrame({
    'ID': [int(row[0]) for row in rows],
    'Xloc': [float(row[1]) for row in rows],
    'Yloc': [float(row[2]) for row in rows],
    'V': [float(row[3]) for row in rows],
    'U': [float(row[4]) for row in rows],
    'T': [int(row[5]) for row in rows]
}).with_columns(
    # Replace missing values (1E31) with None
    pl.when(pl.col('V') > 1e30).then(None).otherwise(pl.col('V')).alias('V'),
    pl.when(pl.col('U') > 1e30).then(None).otherwise(pl.col('U')).alias('U')
).filter(
    # Use only observations with valid V measurements
    pl.col('V').is_not_null()
)

# Extract spatial coordinates and V variable
X_2d = walker_data[['Xloc', 'Yloc']].to_numpy()
y_2d = walker_data['V'].to_numpy()

print(f"Loaded {len(X_2d)} observations from Walker Lake dataset")
print(f"X ranges: [{X_2d[:, 0].min():.0f}, {X_2d[:, 0].max():.0f}] x [{X_2d[:, 1].min():.0f}, {X_2d[:, 1].max():.0f}] meters")
print(f"V (concentration) range: [{y_2d.min():.1f}, {y_2d.max():.1f}] ppm")

Let's visualize the spatial distribution of the observed concentration data. The samples are irregularly distributed across the Walker Lake area.

In [None]:
go.Figure().add_trace(
    go.Scatter(
        x=X_2d[:, 0],
        y=X_2d[:, 1],
        mode='markers',
        marker=dict(
            size=12,
            color=y_2d,
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(
                title='V (ppm)',
                thickness=20,
                len=0.7
            ),
            line=dict(width=0.5, color='white')  
        ),
        text=[f'V: {v:.1f} ppm' for v in y_2d],
        hovertemplate='X: %{x:.0f}m<br>Y: %{y:.0f}m<br>%{text}<extra></extra>'
    )
).update_layout(
    title='Walker Lake Mineral Concentration Data',
    xaxis=dict(
        title='X Location (meters)',
        showgrid=False,
        zeroline=False,
        showline=False,
        mirror=False
    ),
    yaxis=dict(
        title='Y Location (meters)',
        scaleanchor='x',
        scaleratio=1,
        showgrid=False,
        zeroline=False,
        showline=False,
        mirror=False
    ),
    plot_bgcolor='rgba(240,240,240,0.3)',
    height=600,
    width=650
)

### Building a 2D HSGP Model

For 2D HSGPs, we specify $m$ as a two-element list, one value per dimension, while $c$ is a single scalar that sets the boundary in every dimension. The total number of basis functions is $m_1 \times m_2$, so computational cost grows multiplicatively with dimension.

The Walker Lake data spans 300 meters in each direction with irregular sampling. We'll use the helper function to determine appropriate HSGP parameters, though we may need to adjust based on the spatial scale of variation in mineral concentrations.

In [None]:
# Determine appropriate m and c for Walker Lake spatial scale
x_range = [X_2d.min(), X_2d.max()]  # Same range for both dimensions
lengthscale_range = [10.0, 100.0]  # Expected spatial correlation in meters

m_2d, c_2d = pm.gp.hsgp_approx.approx_hsgp_hyperparams(
    x_range=x_range,
    lengthscale_range=lengthscale_range,
    cov_func='matern52'
)

print(f"2D HSGP recommendations: m={m_2d}, c={c_2d:.2f}")
print(f"Total basis functions: {m_2d**2}")

# Standardize the concentration data for better sampling
y_mean = y_2d.mean()
y_std = y_2d.std()
y_2d_std = (y_2d - y_mean) / y_std

with pm.Model() as hsgp_2d_model:

    ell = pm.Gamma('ell', alpha=2, beta=0.05)  # Prior mean alpha / beta = 40 meters
    eta = pm.HalfNormal('eta', sigma=2)
    
    cov = eta**2 * pm.gp.cov.Matern52(2, ls=ell)
    
    gp = pm.gp.HSGP(
        m=[m_2d, m_2d],   # m for each dimension
        c=c_2d,           # c applies to both
        cov_func=cov
    )
    
    f = gp.prior('f', X=X_2d)
    sigma = pm.HalfNormal('sigma', sigma=1)
    y_obs = pm.Normal('y_obs', mu=f, sigma=sigma, observed=y_2d_std)
    
    idata_2d = pm.sample(
        500,
        tune=500,
        chains=2,
        random_seed=RANDOM_SEED,
        nuts_sampler="nutpie"
    )

### Visualizing the 2D HSGP Fit

Let's create a predicted surface over a regular grid to visualize the spatial pattern learned by the HSGP. We'll interpolate the GP to a fine grid covering the Walker Lake area, then compare the predictions to the observed data.

In [None]:
# Create a prediction grid
n_grid = 40
x1_grid = np.linspace(X_2d[:, 0].min(), X_2d[:, 0].max(), n_grid)
x2_grid = np.linspace(X_2d[:, 1].min(), X_2d[:, 1].max(), n_grid)
X1_mesh, X2_mesh = np.meshgrid(x1_grid, x2_grid)
X_pred = np.column_stack([X1_mesh.ravel(), X2_mesh.ravel()])

# Get posterior predictions
with hsgp_2d_model:
    f_pred = gp.conditional('f_pred', X_pred)
    posterior_pred_2d = pm.sample_posterior_predictive(
        idata_2d,
        var_names=['f_pred'],
        random_seed=RANDOM_SEED
    )

In [None]:
# Extract posterior mean and convert back to original scale
f_post_mean = posterior_pred_2d.posterior_predictive['f_pred'].mean(dim=['chain', 'draw']).values
f_post_mean_original = f_post_mean * y_std + y_mean
f_post_grid = f_post_mean_original.reshape(n_grid, n_grid)

# Determine shared color range
vmin = min(y_2d.min(), f_post_mean_original.min())
vmax = max(y_2d.max(), f_post_mean_original.max())

# Create side-by-side comparison with shared styling
make_subplots(
    rows=1, cols=2,
    subplot_titles=('HSGP Posterior Mean Surface', 'Observed Data Locations'),
    horizontal_spacing=0.12,
    vertical_spacing=0.08
).add_trace(
    go.Heatmap(
        z=f_post_grid,
        x=x1_grid,
        y=x2_grid,
        colorscale='Viridis',
        zmin=vmin,
        zmax=vmax,
        showscale=False,
        hovertemplate='X: %{x:.0f}m<br>Y: %{y:.0f}m<br>V: %{z:.1f} ppm<extra></extra>'
    ),
    row=1, col=1
).add_trace(
    go.Scatter(
        x=X_2d[:, 0],
        y=X_2d[:, 1],
        mode='markers',
        marker=dict(
            size=8,
            color=y_2d,
            colorscale='Viridis',
            cmin=vmin,
            cmax=vmax,
            showscale=True,
            colorbar=dict(
                title=dict(text='V (ppm)', side='right'),
                x=1.0,
                thickness=15,
                len=0.65,
                xpad=10
            ),
            line=dict(width=0.5, color='white')
        ),
        hovertemplate='X: %{x:.0f}m<br>Y: %{y:.0f}m<br>V: %{marker.color:.1f} ppm<extra></extra>'
    ),
    row=1, col=2
).update_xaxes(
    showgrid=False,
    zeroline=False,
    showline=False,
    ticks='outside',
    ticklen=5,
    row=1, col=1
).update_xaxes(
    showgrid=False,
    zeroline=False,
    showline=False,
    ticks='outside',
    ticklen=5,
    row=1, col=2
).update_yaxes(
    showgrid=False,
    zeroline=False,
    showline=False,
    ticks='outside',
    ticklen=5,
    scaleanchor='x',
    scaleratio=1,
    row=1, col=1
).update_yaxes(
    showgrid=False,
    zeroline=False,
    showline=False,
    ticks='outside',
    ticklen=5,
    scaleanchor='x2',
    scaleratio=1,
    row=1, col=2
).update_layout(
    title=dict(
        text='2D HSGP Fit to Walker Lake Data',
        x=0.5,
        xanchor='center',
        font=dict(size=16)
    ),
    height=520,
    plot_bgcolor='white',
    paper_bgcolor='white',
    showlegend=False,
    margin=dict(l=80, r=120, t=100, b=80),
    annotations=[
        dict(
            text='X Location (meters)',
            xref='paper', yref='paper',
            x=0.5, y=-0.12,
            xanchor='center', yanchor='top',
            showarrow=False,
            font=dict(size=14)
        ),
        dict(
            text='Y Location (meters)',
            xref='paper', yref='paper',
            x=-0.08, y=0.5,
            xanchor='center', yanchor='middle',
            textangle=-90,
            showarrow=False,
            font=dict(size=14)
        )
    ]
)

### Understanding 2D HSGP Performance on Real Spatial Data

The HSGP successfully learned the spatial structure of mineral concentrations across Walker Lake. Notice several important features:

1. **Smooth spatial interpolation**: The predicted surface (left panel) provides smooth estimates even in areas between observations, capturing the underlying spatial pattern while avoiding overfitting to individual noisy measurements.

2. **Irregular sampling handled naturally**: Unlike grid-based methods, the HSGP works seamlessly with the irregular sampling pattern (right panel), where some areas have dense measurements and others are sparse.

3. **Computational efficiency**: Despite having several hundred observations and predicting on a fine grid (1,600 locations), the HSGP completed fitting and prediction efficiently using basis function expansion.

For 2D problems like this spatial dataset, remember that the total number of basis functions is $m_1 \times m_2$. This is still far more efficient than exact inference, but it shows why HSGP doesn't scale well beyond 3 dimensions: the basis functions multiply quickly! For higher-dimensional problems, sparse GP methods or other approximations may be more suitable.
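A quick calculation makes the growth concrete. The value of `m_per_dim` below is illustrative, not a recommendation.

```python
import math

m_per_dim = 30  # illustrative number of basis functions along each input dimension

for n_dims in range(1, 6):
    total = math.prod([m_per_dim] * n_dims)
    print(f"{n_dims}D: {total:,} basis functions")

total_2d = math.prod([30, 30])
```

Each added dimension multiplies the basis size by another factor of `m_per_dim`, so the basis matrix quickly becomes wider than the dataset is long.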

### 🤖 EXERCISE: Exploring HSGP Parameter Choices

Now experiment with different HSGP parameter configurations.

In [None]:
# Use your LLM to help explore HSGP parameter choices

# STEP 1: Ask your LLM to help you implement this function
def tune_hsgp_params(X, y, m_grid=(30, 60, 120), c_grid=(1.5, 2.5, 4.0)):
    """
    Grid-search m and c to evaluate speed and accuracy trade-offs.
    
    Prompt suggestion: "Help me create a function that runs multiple HSGP fits
    over m and c grid and summarizes performance with Plotly heatmaps."
    """
    # YOUR LLM-ASSISTED CODE HERE
    pass

# STEP 2: Test on Walker Lake dataset
# Visualize RMSE and sampling time as heatmaps

## Section 3.7: Comparing All Approaches

We've now explored both sparse GPs and HSGP approximations in detail. Let's bring everything together with a comprehensive comparison that highlights when to use each approach.

### Computational Complexity Summary

Let's visualize the computational complexity of each approach as a function of dataset size.

In [None]:
n_values = np.logspace(2, 4, 50)  # 100 to 10,000 data points
m_sparse = 100  # inducing points
m_hsgp = 100    # basis functions

# Relative computational cost (arbitrary units)
cost_exact = n_values**3 / 1e6  # Scale for visibility
cost_sparse = n_values * m_sparse**2 / 1e6
cost_hsgp = n_values * m_hsgp / 1e6

go.Figure().add_trace(go.Scatter(
    x=n_values,
    y=cost_exact,
    mode='lines',
    name='Exact GP: O(n³)',
    line=dict(color='blue', width=3)
)).add_trace(go.Scatter(
    x=n_values,
    y=cost_sparse,
    mode='lines',
    name=f'Sparse GP: O(nm²), m={m_sparse}',
    line=dict(color='green', width=3)
)).add_trace(go.Scatter(
    x=n_values,
    y=cost_hsgp,
    mode='lines',
    name=f'HSGP: O(nm), m={m_hsgp}',
    line=dict(color='red', width=3)
)).update_layout(
    title='Computational Complexity: Exact GP vs. Approximations',
    xaxis_title='Number of data points (n)',
    yaxis_title='Relative computational cost',
    xaxis_type='log',
    yaxis_type='log',
    height=500,
    showlegend=True,
    hovermode='x unified'
)

### Decision Guide: Which Method to Use

Here's practical guidance for choosing between methods, based on evidence from PyMC reference materials:

**Use Standard (Exact) GP when:**
- Dataset is relatively small (typically $n < 1,000$, though this depends on hardware)
- You need exact inference without approximation error
- You're using **non-stationary kernels** (e.g., `Linear`, `Polynomial`)
- Computation time isn't critical

**Use Sparse GP (Inducing Points) when:**
- Moderate-to-large datasets where exact GP is too slow
- Data has **uneven sampling density** (inducing points can be placed strategically)
- **Lengthscale is larger than separation between inducing points** (critical condition for good approximation)
- You can use domain knowledge or K-means to select good inducing point locations
- You're willing to tune the number and placement of inducing points ($m$)

**Use HSGP when:**
- Large datasets (computational advantage grows with $n$)
- Using **stationary kernels only** (Matérn, ExpQuad; the kernel must implement the `power_spectral_density` method)
- **Input dimension is low** (1-3 dimensions; doesn't scale well beyond 3D)
- Process doesn't vary **extremely rapidly** relative to domain extent
- You need to integrate the GP into a larger hierarchical model
- You need predictions at many new locations (linear scaling advantage)

**Practical tips:**

1. **Prototyping**: Start with a low-fidelity HSGP (small $m$) for fast iteration. Once you understand the relevant lengthscales, use `pm.gp.hsgp_approx.approx_hsgp_hyperparams()` or manually dial in appropriate $m$ and $c$ values.

2. **Approximation quality**: Be aware that low-fidelity approximations may sometimes give more parsimonious fits than high-fidelity versions.

3. **Inducing points**: Can be selected via K-means (`pm.gp.util.kmeans_inducing_points()`), as a subset of data, or optimized as model parameters.
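The condition that the lengthscale should exceed the separation between inducing points is easy to check numerically. The sketch below is a hypothetical helper, not a PyMC function; the inducing locations `Xu` and the `lengthscale` value are made up for illustration.

```python
import numpy as np

def max_nearest_neighbor_gap(Xu):
    """Largest distance from any inducing point to its nearest neighbour."""
    diffs = Xu[:, None, :] - Xu[None, :, :]
    dists = np.sqrt((diffs ** 2).sum(axis=-1))
    np.fill_diagonal(dists, np.inf)   # ignore each point's distance to itself
    return dists.min(axis=1).max()

Xu = np.linspace(0.0, 10.0, 21)[:, None]  # hypothetical inducing points, spacing 0.5
lengthscale = 1.0                          # hypothetical lengthscale estimate

gap = max_nearest_neighbor_gap(Xu)
spacing_ok = lengthscale > gap
print(f"largest gap = {gap:.2f}, lengthscale = {lengthscale}, ok = {spacing_ok}")
```

If the check fails for plausible posterior lengthscales, adding inducing points or moving them toward sparsely covered regions is the usual remedy.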

## Section 3.8: Automatic Relevance Determination (ARD)

So far, we've focused on computational scaling for large datasets using sparse GPs and HSGP. But there's another challenge when working with real-world data: **handling many input features where not all are equally important**.

When you have dozens of potential predictors, manually selecting which features to include becomes tedious and risks overfitting. **Automatic Relevance Determination (ARD)** solves this by assigning a separate lengthscale to each input dimension, allowing the GP to automatically learn which features matter.

Think of ARD as built-in feature selection: relevant dimensions get small lengthscales (the model pays close attention to changes in these features), while irrelevant dimensions get large lengthscales (the model becomes insensitive to them, effectively removing them from predictions).
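A tiny NumPy example shows the mechanism. It assumes an ExpQuad kernel with unit amplitude; the function `ard_expquad` and the lengthscale values are illustrative only.

```python
import numpy as np

def ard_expquad(x1, x2, ls):
    """ExpQuad covariance between two points with one lengthscale per feature."""
    scaled = (np.asarray(x1) - np.asarray(x2)) / np.asarray(ls)
    return np.exp(-0.5 * np.sum(scaled ** 2))

ls = [0.5, 100.0]            # feature 0 relevant, feature 1 effectively ignored
origin = [0.0, 0.0]

k_move_relevant = ard_expquad(origin, [1.0, 0.0], ls)    # step along feature 0
k_move_irrelevant = ard_expquad(origin, [0.0, 1.0], ls)  # step along feature 1
print(f"{k_move_relevant:.3f}  {k_move_irrelevant:.5f}")
```

A unit step along the short-lengthscale feature drops the covariance sharply, while the same step along the long-lengthscale feature leaves it almost unchanged.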

Let's see ARD in action on a real dataset where we genuinely don't know which features are most important.

### Loading the Boston Housing Dataset

Instead of synthetic data, let's use a real-world dataset: the Boston housing dataset. This classic dataset contains information about housing in Boston suburbs from the 1970s, with 13 features describing each neighborhood plus the median home value. We use 12 of those features here, omitting `B`.

The features include:
- **CRIM**: Per capita crime rate by town
- **ZN**: Proportion of residential land zoned for large lots
- **INDUS**: Proportion of non-retail business acres
- **CHAS**: Charles River dummy variable (1 if tract bounds river)
- **NOX**: Nitric oxides concentration (pollution)
- **RM**: Average number of rooms per dwelling
- **AGE**: Proportion of owner-occupied units built before 1940
- **DIS**: Weighted distances to employment centers
- **RAD**: Index of accessibility to radial highways
- **TAX**: Property tax rate
- **PTRATIO**: Pupil-teacher ratio
- **LSTAT**: Percentage of lower status population
- **MEDV**: Median value of owner-occupied homes (target)

With ARD, we can fit a GP using **all these features** and let the model automatically discover which ones are most relevant for predicting house prices. Features that matter will get small lengthscales (the model pays close attention), while irrelevant features will get large lengthscales (the model effectively ignores them).

In [None]:
# Load Boston housing data
boston_df = pl.read_csv(DATA_DIR + 'HousingData.csv')

# Drop rows with missing values
boston_df = boston_df.drop_nulls()

# Extract all features except MEDV (target) and B (excluded for ethical reasons)
feature_cols = ['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'LSTAT']
X_boston = boston_df.select(feature_cols).to_numpy()
y_boston = boston_df['MEDV'].to_numpy()

# Standardize features for better GP performance
X_mean = X_boston.mean(axis=0)
X_std = X_boston.std(axis=0)
X_boston_std = (X_boston - X_mean) / X_std

# Standardize target
y_mean = y_boston.mean()
y_std = y_boston.std()
y_boston_std = (y_boston - y_mean) / y_std

print(f"Boston Housing Dataset: {X_boston_std.shape[0]} observations, {X_boston_std.shape[1]} features")
print(f"\nFeature names: {feature_cols}")
print(f"\nTarget (MEDV): mean=${y_boston.mean():.1f}k, std=${y_boston.std():.1f}k, range=[${y_boston.min():.1f}k, ${y_boston.max():.1f}k]")

### Fitting the ARD Model

Now we'll fit a GP with ARD by giving each of the 12 features its own lengthscale parameter. Notice how we specify `dims="features"` for the lengthscale - this creates a vector of 12 lengthscales, one per feature.

The model will simultaneously:
1. Learn the overall amplitude and noise level
2. Learn a separate lengthscale for each feature
3. Use these lengthscales to make predictions

**The key insight**: small lengthscales indicate relevance (function changes rapidly with that feature), while large lengthscales indicate irrelevance (covariance nearly constant across that feature's range).

In [None]:
with pm.Model(coords={"features": feature_cols}) as ard_model:

    # Separate lengthscale for each feature (ARD)
    ls = pm.Gamma("ls", alpha=2, beta=1, dims="features")
    eta = pm.HalfNormal("eta", sigma=2)
    
    # ExpQuad kernel with ARD
    cov_func = eta**2 * pm.gp.cov.ExpQuad(input_dim=12, ls=ls)
    gp = pm.gp.Marginal(cov_func=cov_func)
    
    # Observation noise
    sigma = pm.HalfNormal("sigma", sigma=0.5)
    
    # Likelihood
    y_ = gp.marginal_likelihood("y", X=X_boston_std, y=y_boston_std, sigma=sigma)
    
    # Sample posterior
    trace_ard = pm.sample(
        500, 
        tune=500, 
        nuts_sampler="nutpie", 
        random_seed=RANDOM_SEED, 
        chains=2
    )

### Visualizing Learned Lengthscales by Feature

Let's examine which features the model learned are most important. We'll create a forest plot showing the posterior distribution of the lengthscale for each feature.

In [None]:
az.plot_forest(trace_ard, var_names=['ls'], figsize=(6, 12), combined=True, rope=[0,4]);

### Interpreting the ARD Results

The forest plot above shows the posterior distributions of lengthscales for each feature. Because we **standardized all input features** to have mean 0 and standard deviation 1, these lengthscales are directly comparable across features - this is crucial for ARD interpretation.

#### Understanding Lengthscales with Standardized Features

A lengthscale represents **the distance you need to move along a feature's axis for function values to decorrelate substantially** (for the ExpQuad kernel, the correlation falls to about 0.61 after one lengthscale). With standardized features (mean=0, std=1):

- **Highly relevant (< 1)**: Moving just one standard deviation causes substantial decorrelation. The model sees the function changing rapidly with this feature.

- **Moderately relevant (1-3)**: The function changes moderately. These features contribute to predictions but with less sensitivity.

- **Weakly relevant (> 10)**: The covariance becomes nearly constant across this feature's range. Moving a few standard deviations barely affects predictions.

- **Effectively irrelevant (> 100)**: The model has learned to ignore this feature entirely. The covariance is nearly independent of this input.
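To attach numbers to these bands, the following sketch computes, for an ExpQuad kernel, the correlation after moving one standard deviation and the distance at which the correlation falls to 0.5. The helper `half_correlation_distance` is hypothetical, and the lengthscale values are chosen only to span the bands above.

```python
import math

def half_correlation_distance(ell):
    """Distance at which an ExpQuad correlation falls to 0.5."""
    return ell * math.sqrt(2.0 * math.log(2.0))

for ell in (0.5, 1.0, 3.0, 10.0, 100.0):
    d = half_correlation_distance(ell)
    corr_one_sd = math.exp(-0.5 / ell**2)   # correlation after moving one standard deviation
    print(f"ls={ell:>5}: corr after 1 sd = {corr_one_sd:.4f}, "
          f"half-correlation distance = {d:.1f} sd")
```

Because standardized features rarely span more than a few standard deviations, a half-correlation distance of ten or more means the kernel is almost flat over the observed range of that feature.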

#### What to Look for in the Forest Plot

Examine the lengthscale distributions shown above:

1. **Left side of the plot (small lengthscales)**: Features with tight distributions clustered near 0-2 are the most important predictors. These typically include features like **LSTAT** (% lower status population), **RM** (avg rooms), and **PTRATIO** (pupil-teacher ratio) - all well-known drivers of housing prices.

2. **Middle range (lengthscales 3-10)**: Features that contribute moderately to predictions. They matter, but changes in these features have less dramatic effects on house prices.

3. **Right side of the plot (large lengthscales > 10)**: Features the model has learned to downweight or ignore. This could indicate true irrelevance or that their information is captured by other correlated features. For example, **CHAS** (Charles River proximity) affects relatively few houses in the dataset.

#### The Power of ARD

This automatic feature discovery happens **during model fitting** without any manual intervention. We didn't need to:
- Manually select features beforehand
- Run separate feature importance analyses  
- Fit multiple models and compare

The learned lengthscales serve a dual purpose: they're both **correlation parameters** (controlling how smooth the GP is along each dimension) and **importance weights** (telling us which features actually matter for predictions). This makes ARD an elegant solution to the feature selection problem in high-dimensional GP regression.

### Connecting ARD to Scaling Methods

ARD is particularly valuable when combined with the scaling methods we explored earlier. For high-dimensional problems with many features:

- **ARD identifies which features matter**, potentially reducing the effective dimensionality
- **HSGP** can then efficiently handle the relevant dimensions (though remember HSGP works best for 1-3 dimensions)
- **Sparse GPs** can scale to larger datasets while still using ARD kernels

The combination of ARD for feature selection and approximation methods for computational scaling allows you to tackle real-world problems with both many observations and many features.

This completes our tour of GP scaling methods. You now have the tools to apply GPs to datasets that would be intractable with exact inference, while automatically discovering which features drive your predictions.