# Session 3: Scaling with HSGP and Sparse Methods

**Duration:** 2–3 hours

## Learning Objectives

By the end of this session, you will be able to:

- Understand the computational bottlenecks of standard Gaussian processes
- Apply inducing point methods and sparse approximations to scale GPs to larger datasets
- Implement Hilbert Space GP (HSGP) approximations in PyMC
- Choose appropriate approximation parameters using helper functions and heuristics
- Navigate the trade-offs between approximation fidelity and computational efficiency

## Introduction

In the previous sessions, we explored the foundations of Gaussian processes and built models with various kernels and likelihoods. However, you may have noticed that as datasets grow larger, GP computations become increasingly expensive. The standard GP formulation requires inverting an $n \times n$ covariance matrix, where $n$ is the number of data points. This operation has $\mathcal{O}(n^3)$ computational complexity and $\mathcal{O}(n^2)$ memory requirements, which quickly become prohibitive for datasets with thousands of observations.

In this session, we'll explore two powerful approaches to overcome these computational barriers: sparse GP approximations using inducing points, and the Hilbert Space GP (HSGP) method. These techniques allow us to apply GP models to much larger datasets while maintaining the flexibility and uncertainty quantification that make GPs so valuable.

Think of these methods as strategic compromises: we trade away some exactness in our GP representation to gain massive improvements in speed and scalability. The key question we'll answer throughout this session is: *how do we make this trade-off intelligently?*

## Setup

Let's begin by importing our standard libraries and setting up our environment. We'll use the same stack as in previous sessions: PyMC for modeling, Polars for data manipulation, and Plotly for interactive visualization.

In [None]:
import pymc as pm
import numpy as np
import polars as pl
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import arviz as az
import pytensor.tensor as pt
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

# Set random seed for reproducibility
RANDOM_SEED = 42
rng = np.random.default_rng(RANDOM_SEED)

# Print versions for reproducibility
print(f"PyMC version: {pm.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"Polars version: {pl.__version__}")
print(f"ArviZ version: {az.__version__}")

## Section 3.1: Understanding the Computational Challenge

Before diving into solutions, let's develop intuition for *why* standard GPs become computationally expensive. The bottleneck lies in computing and inverting the covariance matrix.

For a GP with $n$ observations, we need to:

1. **Compute** the $n \times n$ covariance matrix $K$ by evaluating the kernel function at all pairs of data points
2. **Invert** this matrix (or equivalently, solve a linear system) to compute the marginal likelihood
3. **Repeat** these operations at every step during MCMC sampling as hyperparameters change

The matrix inversion step dominates the computational cost, scaling as $\mathcal{O}(n^3)$. This cubic scaling means that doubling your dataset size increases computation time by roughly 8×. For a dataset with 10,000 points, a full GP could take hours or days to fit, making interactive model development essentially impossible.
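The 8× figure can be checked empirically. The sketch below times a Cholesky factorization (the standard way GP code "inverts" the covariance matrix) for a few matrix sizes. It uses a hand-built squared exponential matrix plus a noise term, which is an illustrative choice rather than PyMC's internal code path, and the measured ratios depend on your BLAS library and CPU, so treat them as rough.

```python
import time

import numpy as np


def time_cholesky(n, ell=1.0, noise_var=0.25, repeats=3):
    """Best-of-`repeats` wall-clock seconds to factor K + noise_var * I for a squared exponential K."""
    x = np.linspace(0, 10, n)
    K = np.exp(-0.5 * ((x[:, None] - x[None, :]) / ell) ** 2)
    K += noise_var * np.eye(n)  # the noisy covariance a Gaussian marginal likelihood factors
    best = np.inf
    for _ in range(repeats):
        start = time.perf_counter()
        np.linalg.cholesky(K)  # O(n^3) factorization: the expensive step
        best = min(best, time.perf_counter() - start)
    return best


timings = {n: time_cholesky(n) for n in (500, 1000, 2000)}
for n, t in timings.items():
    print(f"n={n:5d}: {t * 1e3:8.2f} ms")
print(f"2000 vs 1000: {timings[2000] / timings[1000]:.1f}x (cubic scaling predicts ~8x)")
```

Multithreaded BLAS often makes small factorizations look cheaper than the cubic law suggests, so the observed ratio may fall short of 8× at these sizes.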

Let's visualize this by looking at the structure of covariance matrices for different dataset sizes. We'll use a simple squared exponential kernel and observe how the matrices grow.

In [None]:
def visualize_covariance_matrices():
    """Visualize how covariance matrix size grows with data."""
    
    fig = make_subplots(
        rows=1, cols=3,
        subplot_titles=['n=50', 'n=200', 'n=1000'],
        horizontal_spacing=0.1
    )
    
    sizes = [50, 200, 1000]
    
    for idx, n in enumerate(sizes, 1):
        # Generate data
        x = np.linspace(0, 10, n)[:, None]
        
        # Create covariance matrix
        cov_func = pm.gp.cov.ExpQuad(1, ls=1.0)
        K = cov_func(x).eval()
        
        # Add to subplot
        fig.add_trace(
            go.Heatmap(
                z=K,
                colorscale='Viridis',
                showscale=(idx == 3),
                colorbar=dict(title="Covariance")
            ),
            row=1, col=idx
        )
    
    fig.update_xaxes(title_text="Data point index")
    fig.update_yaxes(title_text="Data point index", row=1, col=1)
    fig.update_layout(
        title_text="Covariance Matrix Size Grows Quadratically with Data",
        height=400,
        showlegend=False
    )
    
    return fig

fig = visualize_covariance_matrices()
fig.show()

### Interpreting the Visualization

Look at how the covariance matrices expand as we increase $n$ from 50 to 1000. Each pixel represents a covariance calculation between two data points. The bright diagonal shows that each point is perfectly correlated with itself (covariance = 1). The off-diagonal elements show how points are correlated with each other, with the correlation decaying as points become more distant.

The key insight: **the number of elements grows quadratically** ($n^2$), but the computational cost of inverting this matrix grows **cubically** ($n^3$). When $n=1000$, we're working with a million-element matrix whose factorization takes on the order of a billion ($1000^3 = 10^9$) floating-point operations, and we need to repeat it at every MCMC step!
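To put numbers on this growth, the sketch below counts the bytes needed for one dense float64 covariance matrix, together with the textbook flop count for its Cholesky factorization (about $n^3/3$). Both are lower bounds: a real library may allocate extra copies for gradients and intermediate results.

```python
def dense_gp_cost(n, bytes_per_entry=8):
    """Memory (bytes) for one dense n x n float64 matrix and approximate Cholesky flops (n^3 / 3)."""
    memory_bytes = n * n * bytes_per_entry
    cholesky_flops = n**3 / 3
    return memory_bytes, cholesky_flops


for n in (1_000, 10_000, 100_000):
    mem, flops = dense_gp_cost(n)
    print(f"n={n:>7,}: {mem / 1e9:8.3f} GB, ~{flops:.1e} flops per factorization")
```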

This is where approximation methods become essential. Rather than abandoning GPs for large datasets, we can use clever mathematical tricks to reduce computational complexity while retaining most of the modeling flexibility.

## Section 3.2: Sparse GP Approximations with Inducing Points

The first approach to scaling GPs is the **sparse** or **inducing point** method. The central idea is elegant: instead of computing correlations between all $n$ data points, we select a smaller set of $m < n$ "inducing points" (also called "pseudo-inputs") that act as summary locations. These inducing points capture the essential structure of the function, allowing us to approximate the full GP at a fraction of the computational cost.

Think of inducing points like strategic observation posts. If you're trying to understand temperature variations across a city, you don't need thermometers at every house: carefully placed weather stations at key locations can give you an excellent picture of the temperature field everywhere.

Mathematically, the approximation reduces complexity from $\mathcal{O}(n^3)$ to $\mathcal{O}(nm^2)$. When $m \ll n$, this is a massive speedup. The trade-off is that our approximation quality depends on how well the $m$ inducing points can represent the underlying function.
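The $\mathcal{O}(nm^2)$ cost comes from never forming an $n \times n$ matrix. Below is a minimal NumPy sketch of the FITC log marginal likelihood computed with the Woodbury identity and the matrix determinant lemma, checked against a dense $\mathcal{O}(n^3)$ evaluation of the same covariance. It assumes a hand-written squared exponential kernel and a fixed jitter of `1e-6`; PyMC's internal implementation may organize the computation differently.

```python
import numpy as np


def expquad(a, b, ell=1.0, eta=1.0):
    """Squared exponential covariance between two 1D input arrays."""
    return eta**2 * np.exp(-0.5 * ((a[:, None] - b[None, :]) / ell) ** 2)


def fitc_log_marginal(x, y, xu, ell, eta, sigma, jitter=1e-6):
    """FITC log marginal likelihood in O(n m^2) time; no n x n matrix is ever built."""
    n, m = len(x), len(xu)
    Kuu = expquad(xu, xu, ell, eta) + jitter * np.eye(m)
    Kuf = expquad(xu, x, ell, eta)                  # m x n
    Luu = np.linalg.cholesky(Kuu)                   # O(m^3)
    V = np.linalg.solve(Luu, Kuf)                   # m x n, with Q_ff = V.T @ V
    # FITC covariance: Q_ff + diag(K_ff - Q_ff) + sigma^2 I = V.T @ V + diag(d)
    d = eta**2 - np.sum(V**2, axis=0) + sigma**2    # ExpQuad has k(x, x) = eta^2
    Vd = V / d                                      # V @ diag(1 / d)
    B = np.eye(m) + Vd @ V.T                        # m x m, costs O(n m^2)
    LB = np.linalg.cholesky(B)
    c = np.linalg.solve(LB, Vd @ y)
    quad = np.sum(y**2 / d) - c @ c                 # y^T Sigma^{-1} y via Woodbury
    logdet = np.sum(np.log(d)) + 2 * np.sum(np.log(np.diag(LB)))  # determinant lemma
    return -0.5 * (quad + logdet + n * np.log(2 * np.pi))


# Check against a dense O(n^3) evaluation of the same FITC covariance
rng = np.random.default_rng(0)
x = np.sort(rng.uniform(0, 10, 300))
y = np.sin(x) + 0.3 * rng.standard_normal(300)
xu = np.linspace(0, 10, 15)
fast = fitc_log_marginal(x, y, xu, ell=1.0, eta=1.0, sigma=0.3)

Kuu = expquad(xu, xu) + 1e-6 * np.eye(len(xu))
Kuf = expquad(xu, x)
Q = Kuf.T @ np.linalg.solve(Kuu, Kuf)
Sigma = Q + np.diag(1.0 - np.diag(Q) + 0.3**2)
L = np.linalg.cholesky(Sigma)
alpha = np.linalg.solve(L, y)
dense = -0.5 * (alpha @ alpha + 2 * np.sum(np.log(np.diag(L))) + len(x) * np.log(2 * np.pi))
print(f"Woodbury: {fast:.6f}  dense: {dense:.6f}")
```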

PyMC implements three variants of sparse GPs:

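- **DTC (Deterministic Training Conditional)**: The simplest approach, which tends to underestimate predictive uncertainty
- **FITC (Fully Independent Training Conditional)**: Restores the exact diagonal of the covariance matrix, which acts like point-specific noise variance and improves uncertainty estimates
- **VFE (Variational Free Energy)**: Maximizes a lower bound on the exact marginal likelihood, which makes it well suited to optimizing inducing point locations; it is the default `approx` in PyMC's `MarginalApprox`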
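The three variants differ in how they correct the low-rank Nyström matrix $Q = K_{fu} K_{uu}^{-1} K_{uf}$, where $u$ indexes the inducing points and $f$ the data points. The dense sketch below writes out the textbook objectives on a small problem so they can be compared with the exact log marginal likelihood. It re-implements the formulas by hand with a squared exponential kernel and is not PyMC's code. Because VFE subtracts a non-negative trace penalty from the DTC objective, it never exceeds DTC, and it is a lower bound on the exact value.

```python
import numpy as np


def expquad(a, b, ell=1.0, eta=1.0):
    """Squared exponential covariance between two 1D input arrays."""
    return eta**2 * np.exp(-0.5 * ((a[:, None] - b[None, :]) / ell) ** 2)


def gaussian_logpdf(y, cov):
    """log N(y | 0, cov) via a Cholesky factorization."""
    L = np.linalg.cholesky(cov)
    alpha = np.linalg.solve(L, y)
    return -0.5 * (alpha @ alpha + 2 * np.sum(np.log(np.diag(L))) + len(y) * np.log(2 * np.pi))


def sparse_objectives(x, y, xu, sigma, jitter=1e-6):
    """Dense (small-n) versions of the DTC, FITC, and VFE objectives plus the exact value."""
    Kff = expquad(x, x)
    Kuu = expquad(xu, xu) + jitter * np.eye(len(xu))
    Kuf = expquad(xu, x)
    Q = Kuf.T @ np.linalg.solve(Kuu, Kuf)        # Nystrom approximation, rank <= m
    noise = sigma**2 * np.eye(len(x))
    dtc = gaussian_logpdf(y, Q + noise)
    fitc = gaussian_logpdf(y, Q + np.diag(np.diag(Kff - Q)) + noise)  # exact diagonal restored
    vfe = dtc - np.trace(Kff - Q) / (2 * sigma**2)                   # trace penalty
    exact = gaussian_logpdf(y, Kff + noise)
    return {"DTC": dtc, "FITC": fitc, "VFE": vfe, "exact": exact}


rng = np.random.default_rng(0)
x = np.sort(rng.uniform(0, 10, 200))
y = np.sin(x) + 0.3 * rng.standard_normal(200)
results = sparse_objectives(x, y, xu=np.linspace(0, 10, 12), sigma=0.3)
for name, value in results.items():
    print(f"{name:>5}: {value:9.2f}")
```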

For most practical applications, FITC provides a good balance between accuracy and simplicity. Let's see it in action with a moderately large dataset.

### Generating Data for Sparse GP Demonstration

We'll create a dataset with 2000 observations, large enough to make standard GP inference slow but small enough to allow us to compare against the exact solution. Our data will be drawn from a GP with a Matérn 5/2 kernel and moderate noise.

In [None]:
# Set parameters for data generation
n = 2000
ell_true = 1.0
eta_true = 3.0
sigma_true = 0.5

# Generate input locations
x = 10 * np.sort(rng.random(n))

# Define true covariance function
cov_func = eta_true**2 * pm.gp.cov.Matern52(1, ell_true)

# Sample the latent GP function
K = cov_func(x[:, None]).eval()
K_stable = K + 1e-8 * np.eye(n)  # Add jitter for numerical stability
f_true = rng.multivariate_normal(np.zeros(n), K_stable)

# Add observation noise
y = f_true + sigma_true * rng.standard_normal(n)

# Create a Polars DataFrame
df = pl.DataFrame({
    'x': x,
    'y': y,
    'f_true': f_true
})

print(f"Generated {len(df)} observations")
print(f"x range: [{df['x'].min():.2f}, {df['x'].max():.2f}]")
print(f"y range: [{df['y'].min():.2f}, {df['y'].max():.2f}]")

### Visualizing the Data

Let's plot our simulated data. With 2000 points, we'll use transparency to show the density while still seeing the underlying smooth function.

In [None]:
# Plot the data
fig = go.Figure()

# Subsample for clearer visualization
subsample_idx = rng.choice(len(df), size=500, replace=False)
subsample_idx = np.sort(subsample_idx)

fig.add_trace(go.Scatter(
    x=df['x'][subsample_idx],
    y=df['y'][subsample_idx],
    mode='markers',
    name='Observed data (subsample)',
    marker=dict(size=3, color='gray', opacity=0.5)
))

fig.add_trace(go.Scatter(
    x=df['x'],
    y=df['f_true'],
    mode='lines',
    name='True latent function',
    line=dict(color='dodgerblue', width=2)
))

fig.update_layout(
    title='Simulated GP Data (2000 points)',
    xaxis_title='x',
    yaxis_title='y',
    height=400,
    showlegend=True
)

fig.show()

### Choosing Inducing Points with K-Means

A practical strategy for selecting inducing points is to use K-means clustering. This places inducing points at cluster centers in the input space, naturally concentrating them where we have more data while still covering the entire domain.

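We'll use $m=20$ inducing points, a 100× reduction in effective data size. This gives us $\mathcal{O}(2000 \times 20^2) = \mathcal{O}(800\text{K})$ operations instead of $\mathcal{O}(2000^3) = \mathcal{O}(8\text{B})$ operations, a ratio of roughly 10,000×. Big-O counts ignore constant factors and overheads, so the actual wall-clock speedup will differ, although it is typically still substantial.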
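The ratio above is simply $(n/m)^2$, since $n^3 / (n m^2) = (n/m)^2$. A quick check of the arithmetic (operation counts only, constant factors ignored):

```python
n, m = 2000, 20
exact_cost = n**3          # dense GP: O(n^3)
sparse_cost = n * m**2     # inducing points: O(n m^2)
ratio = exact_cost // sparse_cost
print(f"exact ~ {exact_cost:.0e}, sparse ~ {sparse_cost:.0e}, ratio = {ratio:,}x")
```

Because the saving is quadratic in $n/m$, adding inducing points erodes it quickly: with $m = 50$ the ratio is $(2000/50)^2 = 1{,}600$.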

In [None]:
# Use K-means to select inducing points
m = 20  # Number of inducing points

kmeans = KMeans(n_clusters=m, random_state=RANDOM_SEED, n_init=10)
kmeans.fit(x[:, None])
Xu = np.sort(kmeans.cluster_centers_.flatten())

print(f"Selected {m} inducing points using K-means")
print(f"Inducing points span: [{Xu.min():.2f}, {Xu.max():.2f}]")

Let's visualize where K-means placed our inducing points relative to the data density.

In [None]:
fig = go.Figure()

# Data histogram
fig.add_trace(go.Histogram(
    x=df['x'],
    nbinsx=50,
    name='Data density',
    marker=dict(color='lightblue', opacity=0.6),
    yaxis='y2'
))

# Inducing points
fig.add_trace(go.Scatter(
    x=Xu,
    y=np.zeros(m),
    mode='markers',
    name='Inducing points',
    marker=dict(size=10, color='cyan', symbol='x', line=dict(width=2))
))

fig.update_layout(
    title='Inducing Point Locations from K-Means',
    xaxis_title='x',
    yaxis_title='',
    yaxis2=dict(title='Count', overlaying='y', side='right'),
    height=300,
    showlegend=True
)

fig.show()

### Interpreting Inducing Point Placement

Notice how K-means distributes the inducing points fairly evenly across the input domain. Since our data is uniformly distributed, the inducing points spread out to cover the range. This is exactly what we want: the inducing points act as strategic summary locations that can represent the entire function.

The key insight is that these $m=20$ points don't need to be at data locations; they're auxiliary variables that help us approximate the GP efficiently. Think of them as anchor points that define a lower-dimensional representation of the function.

### Building the Sparse GP Model with FITC

Now we'll build our sparse GP model using PyMC's `MarginalApprox` class (the modern replacement for `MarginalSparse`) with the FITC approximation. Notice how the model specification is nearly identical to a standard GP: we just provide the inducing point locations `Xu` and specify the approximation type.

In [None]:
with pm.Model() as sparse_model:
    # Priors on hyperparameters
    ell = pm.Gamma('ell', alpha=2, beta=1)
    eta = pm.HalfNormal('eta', sigma=5)
    
    # Define covariance function
    cov = eta**2 * pm.gp.cov.Matern52(1, ls=ell)
    
    # Sparse GP with FITC approximation
    gp = pm.gp.MarginalApprox(cov_func=cov, approx='FITC')
    
    # Observation noise
    sigma = pm.HalfNormal('sigma', sigma=2)
    
    # Marginal likelihood
    y_obs = gp.marginal_likelihood(
        'y_obs',
        X=x[:, None],
        Xu=Xu[:, None],
        y=y,
        sigma=sigma
    )
    
    # Sample posterior
    idata_sparse = pm.sample(
        1000,
        tune=1000,
        random_seed=RANDOM_SEED,
    )

### Examining the Posterior

Let's check the posterior distributions of our hyperparameters and verify that sampling was successful.

In [None]:
# Summary statistics
summary = az.summary(
    idata_sparse,
    var_names=['ell', 'eta', 'sigma'],
    round_to=2
)
print(summary)

In [None]:
# Trace plot
az.plot_trace(
    idata_sparse,
    var_names=['ell', 'eta', 'sigma'],
    figsize=(10, 6)
)
plt.tight_layout()
plt.show()

### Interpreting the Results

Look at the trace plots and summary statistics. We should see good mixing (the traces look like "hairy caterpillars"), high effective sample sizes (ESS), and $\hat{R}$ values close to 1.0. These diagnostics tell us that the sampler successfully explored the posterior, despite using only 20 inducing points to represent 2000 data points.

The posterior means should be close to our true values (lengthscale=1.0, amplitude=3.0, noise=0.5), though with some uncertainty since we're working with finite data and an approximation.

### Making Predictions with the Sparse GP

One of the benefits of the sparse GP approximation is that prediction is also fast. Let's make predictions at a dense grid of test points and visualize the posterior predictive distribution.

In [None]:
# Create test points
x_test = np.linspace(-0.5, 10.5, 300)

# Add conditional distribution to model and sample
with sparse_model:
    f_pred = gp.conditional('f_pred', x_test[:, None])
    
    # Sample posterior predictive
    posterior_pred = pm.sample_posterior_predictive(
        idata_sparse,
        var_names=['f_pred'],
        random_seed=RANDOM_SEED,
    )

In [None]:
# Compute posterior summary
f_pred_mean = posterior_pred.posterior_predictive['f_pred'].mean(dim=['chain', 'draw']).values
f_pred_std = posterior_pred.posterior_predictive['f_pred'].std(dim=['chain', 'draw']).values

# Plot results
fig = go.Figure()

# Credible interval
fig.add_trace(go.Scatter(
    x=np.concatenate([x_test, x_test[::-1]]),
    y=np.concatenate([f_pred_mean + 2*f_pred_std, (f_pred_mean - 2*f_pred_std)[::-1]]),
    fill='toself',
    fillcolor='rgba(255,0,0,0.2)',
    line=dict(color='rgba(255,0,0,0)'),
    name='95% Credible Interval',
    showlegend=True
))

# Posterior mean
fig.add_trace(go.Scatter(
    x=x_test,
    y=f_pred_mean,
    mode='lines',
    name='Posterior Mean',
    line=dict(color='red', width=2)
))

# True function (interpolated to test points)
f_true_interp = np.interp(x_test, x, f_true)
fig.add_trace(go.Scatter(
    x=x_test,
    y=f_true_interp,
    mode='lines',
    name='True Function',
    line=dict(color='dodgerblue', width=2, dash='dash')
))

# Data subsample
subsample_idx = rng.choice(len(df), size=200, replace=False)
fig.add_trace(go.Scatter(
    x=df['x'][subsample_idx],
    y=df['y'][subsample_idx],
    mode='markers',
    name='Data (subsample)',
    marker=dict(size=3, color='gray', opacity=0.5)
))

# Inducing points
fig.add_trace(go.Scatter(
    x=Xu,
    y=np.ones(len(Xu)) * (f_pred_mean.min() - 1),
    mode='markers',
    name='Inducing Points',
    marker=dict(size=10, color='cyan', symbol='x', line=dict(width=2))
))

fig.update_layout(
    title='Sparse GP Predictions with FITC Approximation',
    xaxis_title='x',
    yaxis_title='f(x)',
    height=500,
    showlegend=True,
    hovermode='x unified'
)

fig.show()

### Understanding the Sparse GP Fit

This plot reveals several important features of the sparse GP approximation:

1. **The posterior mean (red line) closely tracks the true function (blue dashed line)**, demonstrating that just 20 inducing points can effectively represent the smooth underlying pattern in 2000 observations.

2. **The credible intervals appropriately capture uncertainty**, widening slightly in regions with fewer nearby inducing points and remaining tight where inducing points are dense.

3. **The inducing points (cyan X markers at bottom) are strategically distributed** across the domain, acting as anchor points for the approximation.

The key takeaway: we've achieved dramatic computational savings while maintaining good approximation quality. For smooth functions and well-placed inducing points, the sparse GP can come very close to the exact GP. The downside of sparse approximations is that they reduce the expressiveness of the GP: the approximate covariance matrix $Q = K_{fu} K_{uu}^{-1} K_{uf}$ (with $u$ indexing the inducing points and $f$ the data points) has rank at most $m$, so at most $m$ eigenvectors are available to fit the data.
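The rank limitation is easy to check numerically. This sketch builds the Nyström matrix for 20 evenly spaced inducing points, using a hand-written unit-amplitude Matérn 5/2 kernel (an illustrative setup rather than the K-means points above), and counts eigenvalues above a small relative threshold for both the approximate and the full Gram matrix.

```python
import numpy as np


def matern52(a, b, ell=1.0):
    """Unit-amplitude Matern 5/2 covariance between two 1D input arrays."""
    r = np.abs(a[:, None] - b[None, :]) / ell
    return (1 + np.sqrt(5) * r + 5 * r**2 / 3) * np.exp(-np.sqrt(5) * r)


def numerical_rank(A, rel_tol=1e-8):
    """Count eigenvalues of a symmetric PSD matrix above rel_tol times the largest one."""
    eig = np.linalg.eigvalsh((A + A.T) / 2)  # symmetrize to remove round-off asymmetry
    return int(np.sum(eig > rel_tol * eig.max()))


x = np.linspace(0, 10, 400)
xu = np.linspace(0, 10, 20)
Kuu = matern52(xu, xu) + 1e-8 * np.eye(len(xu))
Q = matern52(x, xu) @ np.linalg.solve(Kuu, matern52(xu, x))   # Nystrom approximation
K = matern52(x, x)

rank_Q, rank_K = numerical_rank(Q), numerical_rank(K)
print(f"numerical rank of Q: {rank_Q} (m = {len(xu)}), of K: {rank_K} (n = {len(x)})")
```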

### Exercise: Sparse GP with Inducing Points

Now it's your turn to experiment with sparse GPs and explore how the number of inducing points affects approximation quality.

In [None]:
# 🤖 EXERCISE: Use your LLM to help implement a sparse GP

# STEP 1: Ask your LLM to help you implement this function
def sparse_gp_with_kmeans(X, y, M=200):
    """
    Build pm.gp.MarginalApprox using KMeans to initialize M inducing points.
    
    Prompt suggestion: "Help me set up MarginalApprox with KMeans initialization
    for a dataset, including sampling and prediction."
    """
    # YOUR LLM-ASSISTED CODE HERE
    pass

# STEP 2: Test your implementation
# Try different values of M (e.g., 10, 50, 100) and compare fit quality vs speed

## Section 3.3: Hilbert Space GP (HSGP) Theory

While sparse GPs use inducing points to reduce complexity, the Hilbert Space GP (HSGP) takes a completely different approach: it approximates the GP using a **basis function expansion**. This transforms the non-parametric GP into a parametric (linear) model with a fixed number of basis functions, which standard MCMC samplers can handle far more cheaply than an exact GP.

The mathematical foundation of HSGP comes from spectral analysis of covariance functions. Any stationary covariance kernel can be represented through its **power spectral density**, essentially a Fourier transform that describes the kernel's behavior in frequency space. The HSGP approximation uses a finite set of basis functions (sinusoids), and the prior variance of each basis function's coefficient is the spectral density evaluated at that basis function's frequency.
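As a concrete example, the standard closed-form 1D spectral density of the Matérn 5/2 kernel is $S(\omega) = \eta^2 \frac{16}{3} \frac{5^{5/2}}{\ell^5} \left(\frac{5}{\ell^2} + \omega^2\right)^{-3}$. The sketch below writes it out by hand (rather than calling PyMC's `power_spectral_density`) and checks numerically with SciPy that transforming it back, $k(r) = \frac{1}{\pi}\int_0^\infty S(\omega)\cos(\omega r)\,d\omega$, recovers the familiar kernel.

```python
import numpy as np
from scipy import integrate


def matern52_psd(omega, ell=1.0, eta=1.0):
    """1D Matern 5/2 spectral density, normalized so that k(0) = eta^2."""
    return eta**2 * (16.0 / 3.0) * 5.0**2.5 / ell**5 * (5.0 / ell**2 + omega**2) ** -3


def matern52_kernel(r, ell=1.0, eta=1.0):
    """Matern 5/2 covariance as a function of distance r."""
    s = np.sqrt(5.0) * np.abs(r) / ell
    return eta**2 * (1 + s + s**2 / 3) * np.exp(-s)


def kernel_from_psd(r, ell, eta):
    """Inverse Fourier transform of the (even) spectral density, evaluated numerically."""
    f = lambda w: matern52_psd(w, ell, eta)
    if r == 0:
        val, _ = integrate.quad(f, 0, np.inf)
    else:
        val, _ = integrate.quad(f, 0, np.inf, weight="cos", wvar=r)  # Fourier-type quadrature
    return val / np.pi


for r in (0.0, 0.5, 2.0):
    print(f"r={r}: from PSD {kernel_from_psd(r, 1.0, 3.0):.6f}, "
          f"closed form {matern52_kernel(r, 1.0, 3.0):.6f}")
```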

### Why This Matters

Think of it this way: instead of defining a function through all pairwise correlations (which requires $n^2$ covariance entries and $\mathcal{O}(n^3)$ operations), HSGP defines it through $m$ basis function coefficients. The basis functions depend only on the domain boundary $L$ and on $m$, not on the kernel hyperparameters, so they can be pre-computed once; during sampling, the hyperparameters only change the prior scales of the coefficients.
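The sketch below writes out the standard HSGP basis on a 1D domain $[-L, L]$ (Laplacian eigenfunctions that vanish at the boundary), $\phi_j(x) = \frac{1}{\sqrt{L}}\sin\left(\frac{\pi j (x + L)}{2L}\right)$ with $\sqrt{\lambda_j} = \frac{\pi j}{2L}$, together with the standard closed-form Matérn 5/2 spectral density. It is a hand-rolled NumPy illustration, not PyMC's implementation, and $c = 1.5$ is an arbitrary choice. Note that $\Phi$ is computed once, while changing the lengthscale only rescales the coefficients.

```python
import numpy as np


def hsgp_basis(x_centered, L, m):
    """First m Dirichlet Laplacian eigenfunctions on [-L, L] and their square-root eigenvalues."""
    sqrt_eigvals = np.arange(1, m + 1) * np.pi / (2 * L)
    phi = np.sin(sqrt_eigvals[None, :] * (x_centered[:, None] + L)) / np.sqrt(L)
    return phi, sqrt_eigvals


def matern52_psd(omega, ell=1.0, eta=1.0):
    """1D Matern 5/2 spectral density, normalized so that k(0) = eta^2."""
    return eta**2 * (16.0 / 3.0) * 5.0**2.5 / ell**5 * (5.0 / ell**2 + omega**2) ** -3


x = np.linspace(0, 10, 200)
x_centered = x - (x.max() + x.min()) / 2
L, m = 1.5 * np.abs(x_centered).max(), 30      # c = 1.5 (illustrative)
phi, omega = hsgp_basis(x_centered, L, m)       # depends only on x, L, and m

# The basis functions are orthonormal on [-L, L]; a fine uniform grid confirms it
grid = np.linspace(-L, L, 1001)
phi_grid, _ = hsgp_basis(grid, L, m)
gram = phi_grid.T @ phi_grid * (grid[1] - grid[0])

# One set of standard-normal weights, rescaled for two different lengthscales
rng = np.random.default_rng(1)
z = rng.standard_normal(m)
f_short = phi @ (np.sqrt(matern52_psd(omega, ell=0.5)) * z)
f_long = phi @ (np.sqrt(matern52_psd(omega, ell=2.0)) * z)
print(f"max |gram - I| = {np.abs(gram - np.eye(m)).max():.1e}")
```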

The computational complexity drops from $\mathcal{O}(n^3)$ for exact GPs to $\mathcal{O}(nm + m)$ for HSGP, where $m$ is the number of basis functions. Even better, HSGP is fully parametricâ€”we can use `pm.set_data` for predictions without explicitly computing conditional distributions. This makes it much easier to integrate an HSGP into your existing PyMC model.

Additionally, unlike many other GP approximations, HSGPs can be used anywhere within a model and with any likelihood function. This flexibility is a major advantage over the sparse `MarginalApprox` GPs above, which integrate out the latent function analytically and therefore require a Gaussian likelihood.

### HSGP Restrictions

The HSGP approximation does carry some restrictions:

1. It **can only be used with stationary covariance kernels** such as the Matérn family or ExpQuad. The kernel must implement the `power_spectral_density` method.
2. It **does not scale well with input dimension**. HSGP is a good choice for 1D processes (like time series) or 2D spatial processes, but likely not efficient beyond 3 dimensions.
3. It **may struggle with very rapidly varying processes**. If the process changes very quickly relative to the domain extent, you may need very large $m$ to accurately represent it.
4. **For smaller datasets, the full unapproximated GP may still be more efficient**.
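To see why restriction 2 matters, note that in $D$ input dimensions the HSGP basis is built from products of 1D basis functions, so the total count is the product of the per-dimension values in `m`. A quick count, assuming a hypothetical 20 basis functions per dimension:

```python
from math import prod


def total_basis_functions(m_per_dim):
    """Total basis functions for a tensor-product basis with the given per-dimension counts."""
    return prod(m_per_dim)


for D in (1, 2, 3, 4):
    print(f"D={D}: {total_basis_functions([20] * D):,} basis functions")
```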

### Key Parameters: m and c

HSGP approximations are controlled by two parameters:

- **m**: The number of basis functions. Larger $m$ gives better approximation quality but increases computational cost. Think of $m$ as the "resolution" of your approximationâ€”more basis functions can represent more complex, rapidly-varying patterns. Increasing $m$ helps the HSGP approximate GPs with smaller lengthscales.

- **c**: The boundary extension factor. HSGP basis functions are defined on a finite domain $[-L, L]$ where $L = c \cdot S$ and $S$ is half the range of your centered data. Larger $c$ values help approximate GPs with longer lengthscales and ensure predictions away from data aren't affected by boundary conditions. However, increasing $c$ may require increasing $m$ to compensate for loss of fidelity at smaller lengthscales.
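Here is that arithmetic for the data range used in this session, sketched in NumPy under two assumptions: the inputs are centered at the midpoint of their range, and predictions are made on a grid from -0.5 to 10.5. The helper `hsgp_boundary` is hypothetical and not part of PyMC.

```python
import numpy as np


def hsgp_boundary(x, c):
    """Hypothetical helper: center (midpoint), half-range S, and boundary L = c * S for 1D inputs."""
    center = (x.max() + x.min()) / 2
    S = (x.max() - x.min()) / 2
    return center, S, c * S


x_train = np.linspace(0, 10, 2000)       # stand-in for simulated inputs on [0, 10]
x_test = np.linspace(-0.5, 10.5, 300)
for c in (1.0, 1.2, 1.5):
    center, S, L = hsgp_boundary(x_train, c)
    margin = L - np.abs(x_test - center).max()   # negative: test points fall outside [-L, L]
    print(f"c={c}: S={S:.2f}, L={L:.2f}, margin for test points = {margin:+.2f}")
```

With $c = 1$ the edges of the prediction grid fall outside $[-L, L]$, where the approximation is no longer valid.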

The art of using HSGP effectively lies in choosing $m$ and $c$ appropriately for your data and expected lengthscales. Fortunately, PyMC provides a helper function to get you started.

### Visualizing HSGP Basis Functions

To build intuition, let's visualize what HSGP basis functions actually look like. These are the sinusoidal building blocks that will be combined to approximate our GP. Notice that we need to center the data first; this is an important requirement for HSGP.

In [None]:
# Create a centered grid
x_grid = np.linspace(-5, 5, 1000)

# Create subplots to show the effect of L and m.
# sharey=True shares one y-axis (limits and ticks) and hides the inner tick labels,
# so calling set_yticks([]) on one panel would also remove the ticks from the first panel.
fig, axs = plt.subplots(1, 3, figsize=(14, 4), sharey=True)

ylim = 0.55
axs[0].set_ylim([-ylim, ylim])

# Change L as we create the basis vectors
L_options = [5.0, 6.0, 20.0]
m_options = [3, 3, 5]

for i, ax in enumerate(axs):
    L = L_options[i]
    m_val = m_options[i]
    
    eigvals = pm.gp.hsgp_approx.calc_eigenvalues(pt.as_tensor([L]), [m_val])
    phi = pm.gp.hsgp_approx.calc_eigenvectors(
        x_grid[:, None],
        pt.as_tensor([L]),
        eigvals,
        [m_val],
    ).eval()
    
    for j in range(phi.shape[1]):
        ax.plot(x_grid, phi[:, j])
    
    ax.set_xticks(np.arange(-5, 6, 5))
    
    S = 5.0
    c = L / S
    ax.text(-4.9, -0.45, f"L = {L}\nc = {c}", fontsize=12)
    ax.set_title(f"{m_val} basis functions")
    ax.set_xlabel("x (centered)")

axs[0].set_ylabel("Basis function value")
plt.suptitle("The Effect of Changing L on HSGP Basis Vectors", fontsize=14)
plt.tight_layout()
plt.show()

### Interpreting the Basis Functions

These plots reveal critical insights about HSGP basis functions:

**Left panel (L=5, c=1.0)**: When $L$ equals the half-range $S$ of the centered data ($c = 1$), all basis vectors are forced to pinch to zero at the boundaries (at $x=-5$ and $x=5$). This means the HSGP approximation becomes poor near the edges of your data. This is why we need $c > 1$.

**Middle panel (L=6, c=1.2)**: With $c=1.2$, the basis functions extend beyond the data range and are no longer forced to zero at the data boundaries. This helps the approximation remain accurate across the entire domain. Values of $c$ around 1.2 are considered the minimum for reasonable approximations.

**Right panel (L=20, c=4.0, m=5)**: With larger $L$ or $c$, the basis functions become lower frequency (longer wavelength). Notice how the first basis function (blue) is nearly flatâ€”it's becoming partially unidentifiable with an intercept term. This is why we sometimes need to drop the first basis function, or increase $m$ to compensate.

Notice that the basis functions are sinusoids with increasing frequency. Lower-order basis functions capture long-range trends, while higher-order functions capture increasingly rapid oscillations. An HSGP approximation works by taking a weighted sum of these basis functions.

The key lessons:
- **Increasing $m$ helps approximate GPs with smaller lengthscales** (more basis functions = higher resolution)
- **Increasing $c$ or $L$ helps approximate GPs with larger lengthscales** but may require increasing $m$ to maintain fidelity at smaller lengthscales
- **Consider where predictions will be made**â€”they also need to be away from the boundary "pinch"
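These lessons can be checked without any sampling. The NumPy sketch below computes the HSGP approximation to the prior variance $k(x, x) = 1$, i.e. the diagonal of the approximate Gram matrix built from the basis functions and the spectral densities, for several $c$ and $m$. The basis and Matérn 5/2 spectral density are written out by hand, and the unit amplitude, lengthscale of 1, and centered inputs on $[-5, 5]$ are illustrative choices. With $c = 1$ the variance collapses to zero at the data edge; with $c = 1.2$ a boundary effect remains near the edge even for large $m$; with $c = 2$, adding basis functions steadily shrinks the error.

```python
import numpy as np


def matern52_psd(omega, ell=1.0, eta=1.0):
    """1D Matern 5/2 spectral density, normalized so that k(0) = eta^2."""
    return eta**2 * (16.0 / 3.0) * 5.0**2.5 / ell**5 * (5.0 / ell**2 + omega**2) ** -3


def hsgp_prior_variance(x_centered, L, m, ell=1.0):
    """Diagonal of Phi @ diag(S(sqrt(lambda))) @ Phi.T for a unit-amplitude Matern 5/2 kernel."""
    omega = np.arange(1, m + 1) * np.pi / (2 * L)                 # square-root eigenvalues
    phi = np.sin(omega[None, :] * (x_centered[:, None] + L)) / np.sqrt(L)
    return (phi**2) @ matern52_psd(omega, ell)


x_centered = np.linspace(-5, 5, 201)    # half-range S = 5
max_err = {}
for c in (1.0, 1.2, 2.0):
    for m in (10, 40, 160):
        var = hsgp_prior_variance(x_centered, L=c * 5, m=m)
        max_err[c, m] = np.abs(1 - var).max()                    # exact prior variance is 1
        print(f"c={c:.1f}, m={m:3d}: max error {max_err[c, m]:.4f}, variance at x=5: {var[-1]:.4f}")
```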

## Section 3.4: HSGP Implementation

Now let's implement an HSGP model and see it in action. We'll use the same dataset as before for direct comparison with the sparse GP.

### Choosing HSGP Parameters

PyMC provides a helper function `approx_hsgp_hyperparams` that suggests values for $m$ and $c$ based on:
- The range of your input data
- The range of lengthscales you expect (from your prior)
- The covariance function type

These recommendations are based on approximation error bounds derived in the HSGP literature. The heuristics help you choose $c$ large enough to handle the largest lengthscales you might fit, and $m$ large enough to accommodate the smallest lengthscales. Let's use this function to get started.
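Because the helper's suggestions depend on the `lengthscale_range` you pass in, it can be worth checking how much prior mass that range actually covers. A quick SciPy check for the `Gamma(alpha=2, beta=1)` lengthscale prior used in the models below (PyMC's `beta` is a rate, so SciPy's `scale` is `1 / beta`):

```python
import numpy as np
from scipy import stats

ell_prior = stats.gamma(a=2.0, scale=1.0)      # matches pm.Gamma("ell", alpha=2, beta=1)
lengthscale_range = [0.5, 3.0]

below = ell_prior.cdf(lengthscale_range[0])
above = ell_prior.sf(lengthscale_range[1])
print(f"P(ell < {lengthscale_range[0]}) = {below:.3f}")
print(f"P(ell > {lengthscale_range[1]}) = {above:.3f}")
print(f"P(inside range) = {1 - below - above:.3f}")
```

Roughly 29% of the prior mass falls outside `[0.5, 3.0]`, so if the posterior lengthscale ends up near either bound, it is worth re-running the helper with a wider range.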

In [None]:
# Determine appropriate m and c
x_range = [x.min(), x.max()]
lengthscale_range = [0.5, 3.0]  # Based on our prior knowledge

m_recommended, c_recommended = pm.gp.hsgp_approx.approx_hsgp_hyperparams(
    x_range=x_range,
    lengthscale_range=lengthscale_range,
    cov_func='matern52'
)

print(f"Recommended m: {m_recommended}")
print(f"Recommended c: {c_recommended:.2f}")

# We'll use these values for our model
m_hsgp = m_recommended
c_hsgp = c_recommended

### Building the HSGP Model

The HSGP model specification in PyMC is remarkably similar to a standard GP. The key difference is that we use `pm.gp.HSGP` instead of `pm.gp.Latent` or `pm.gp.Marginal`, and we specify the approximation parameters $m$ and $c$. 

Notice that we use the `.prior` method just like with `pm.gp.Latent`. For basic usage, HSGP can be treated as a drop-in replacement for the standard latent GP.

In [None]:
with pm.Model() as hsgp_model:
    # Priors on hyperparameters (same as sparse GP)
    ell = pm.Gamma('ell', alpha=2, beta=1)
    eta = pm.HalfNormal('eta', sigma=5)
    
    # Define covariance function
    cov = eta**2 * pm.gp.cov.Matern52(1, ls=ell)
    
    # HSGP approximation
    gp = pm.gp.HSGP(m=[m_hsgp], c=c_hsgp, cov_func=cov)
    
    # Prior over the latent function
    f = gp.prior('f', X=x[:, None])
    
    # Observation noise
    sigma = pm.HalfNormal('sigma', sigma=2)
    
    # Likelihood
    y_obs = pm.Normal('y_obs', mu=f, sigma=sigma, observed=y)
    
    # Sample posterior
    idata_hsgp = pm.sample(
        1000,
        tune=1000,
        random_seed=RANDOM_SEED,
    )

### Examining HSGP Results

Let's check sampling diagnostics and posterior distributions for the HSGP model.

In [None]:
# Summary statistics
summary_hsgp = az.summary(
    idata_hsgp,
    var_names=['ell', 'eta', 'sigma'],
    round_to=2
)
print(summary_hsgp)

In [None]:
# Trace plot
az.plot_trace(
    idata_hsgp,
    var_names=['ell', 'eta', 'sigma'],
    figsize=(10, 6)
)
plt.tight_layout()
plt.show()

### Making Predictions with HSGP

One major advantage of HSGP is the ease of prediction. HSGP supports the same `.conditional` method as other PyMC GPs, and because the model is parametric, predicting at new inputs only requires evaluating the basis functions there. Let's make predictions and visualize the fit.

In [None]:
# Create test points
x_test = np.linspace(-0.5, 10.5, 300)

with hsgp_model:
    f_pred_hsgp = gp.conditional('f_pred', x_test[:, None])
    
    # Sample posterior predictive
    posterior_pred_hsgp = pm.sample_posterior_predictive(
        idata_hsgp,
        var_names=['f_pred'],
        random_seed=RANDOM_SEED,
    )

In [None]:
# Compute posterior summary
f_hsgp_mean = posterior_pred_hsgp.posterior_predictive['f_pred'].mean(dim=['chain', 'draw']).values
f_hsgp_std = posterior_pred_hsgp.posterior_predictive['f_pred'].std(dim=['chain', 'draw']).values

# Plot HSGP results
fig = go.Figure()

# Credible interval
fig.add_trace(go.Scatter(
    x=np.concatenate([x_test, x_test[::-1]]),
    y=np.concatenate([f_hsgp_mean + 2*f_hsgp_std, (f_hsgp_mean - 2*f_hsgp_std)[::-1]]),
    fill='toself',
    fillcolor='rgba(139,0,139,0.2)',
    line=dict(color='rgba(139,0,139,0)'),
    name='95% Credible Interval',
    showlegend=True
))

# Posterior mean
fig.add_trace(go.Scatter(
    x=x_test,
    y=f_hsgp_mean,
    mode='lines',
    name='HSGP Posterior Mean',
    line=dict(color='darkviolet', width=2)
))

# True function
f_true_interp = np.interp(x_test, x, f_true)
fig.add_trace(go.Scatter(
    x=x_test,
    y=f_true_interp,
    mode='lines',
    name='True Function',
    line=dict(color='gold', width=3, dash='dash')
))

# Data subsample
subsample_idx = rng.choice(len(df), size=200, replace=False)
fig.add_trace(go.Scatter(
    x=df['x'][subsample_idx],
    y=df['y'][subsample_idx],
    mode='markers',
    name='Data (subsample)',
    marker=dict(size=3, color='gray', opacity=0.5)
))

fig.update_layout(
    title=f'HSGP Fit (m={m_hsgp}, c={c_hsgp:.2f})',
    xaxis_title='x',
    yaxis_title='f(x)',
    height=500,
    showlegend=True,
    hovermode='x unified'
)

fig.show()

### Understanding the HSGP Fit

The HSGP posterior mean (purple) should closely track the true latent function (gold dashed line), with credible intervals that capture the remaining uncertainty. This demonstrates that even an approximation built from a modest number of basis functions can achieve excellent fit quality.

With the parameters recommended by `approx_hsgp_hyperparams`, the approximation should be very close to what an exact GP would produce (we have not fit the exact GP here to confirm this). The computational cost, however, is dramatically lower: $\mathcal{O}(nm)$ instead of $\mathcal{O}(n^3)$.

### Comparing HSGP to Sparse GP

Let's directly compare the posterior distributions from the HSGP and sparse GP models using a forest plot, which is ideal for comparing multiple models side-by-side.

In [None]:
# Compare posteriors using forest plot
az.plot_forest(
    [idata_sparse, idata_hsgp],
    model_names=['Sparse GP', 'HSGP'],
    var_names=['ell', 'eta', 'sigma'],
    combined=True,
    figsize=(10, 5)
)
plt.tight_layout()
plt.show()

### Interpreting the Comparison

The posterior distributions from HSGP and sparse GP should be very similar, particularly for the lengthscale and amplitude parameters that control the function's smoothness and scale. Small differences are expected since both are approximations, but substantial disagreement would suggest that one or both approximations is inadequate.

Both methods should recover hyperparameters close to the true values (lengthscale=1.0, amplitude=3.0, noise=0.5), which suggests that either approach can work well for moderate-sized datasets with smooth underlying functions.

## Section 3.5: Advanced HSGP - Centered vs Non-Centered Parameterization

An important consideration when using HSGP is choosing between centered and non-centered parameterizations. This is analogous to the choice you make in hierarchical models, and for similar reasons: the correlation structure in the posterior.

### When to Use Each Parameterization

**Centered parameterization** works better when:
- The underlying GP is strongly informed by the data
- You have lots of data relative to the lengthscale
- The signal-to-noise ratio is high

**Non-centered parameterization** (the default) works better when:
- The underlying GP is weakly informed by the data  
- You have sparse data or large lengthscales
- The signal-to-noise ratio is low
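The two parameterizations describe the same prior on the basis coefficients; they differ only in which variables the sampler moves. A minimal NumPy sketch, with three hypothetical prior scales standing in for the square roots of the spectral densities:

```python
import numpy as np

rng = np.random.default_rng(2)
prior_sd = np.array([2.0, 0.5, 0.05])   # hypothetical sqrt(spectral density) per coefficient
n_draws = 200_000

# Centered: the sampler works directly with beta ~ Normal(0, prior_sd)
beta_centered = rng.normal(0.0, prior_sd, size=(n_draws, 3))

# Non-centered: the sampler works with z ~ Normal(0, 1); beta is a deterministic rescaling
z = rng.standard_normal((n_draws, 3))
beta_noncentered = z * prior_sd

print("centered sd:    ", beta_centered.std(axis=0).round(3))
print("non-centered sd:", beta_noncentered.std(axis=0).round(3))
```

In the non-centered form, changing the lengthscale or amplitude rescales every coefficient at once without moving `z`, which helps when the data say little about the function. When the data pin down the coefficients tightly, that coupling tends to become a liability, and the centered form often samples better.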

In our example with 2000 noisy observations, the centered parameterization might actually be better. Let's test this.

In [None]:
with pm.Model() as hsgp_centered:
    # Same priors
    ell = pm.Gamma('ell', alpha=2, beta=1)
    eta = pm.HalfNormal('eta', sigma=5)
    
    cov = eta**2 * pm.gp.cov.Matern52(1, ls=ell)
    
    # HSGP with centered parameterization
    gp = pm.gp.HSGP(
        m=[m_hsgp], 
        c=c_hsgp, 
        cov_func=cov,
        parametrization='centered'  # Key difference!
    )
    
    f = gp.prior('f', X=x[:, None])
    sigma = pm.HalfNormal('sigma', sigma=2)
    y_obs = pm.Normal('y_obs', mu=f, sigma=sigma, observed=y)
    
    # Sample
    idata_hsgp_centered = pm.sample(
        1000,
        tune=1000,
        random_seed=RANDOM_SEED,
    )

In [None]:
# Compare sampling efficiency
print("Non-centered parameterization:")
print(az.summary(idata_hsgp, var_names=['ell', 'eta', 'sigma'])[['ess_bulk', 'ess_tail', 'r_hat']])
print("\nCentered parameterization:")
print(az.summary(idata_hsgp_centered, var_names=['ell', 'eta', 'sigma'])[['ess_bulk', 'ess_tail', 'r_hat']])

### Interpreting Parameterization Effects

Compare the effective sample sizes (ESS) between the two parameterizations. ESS estimates how many independent draws your autocorrelated MCMC draws are worth, so for the same number of iterations, higher ESS means more efficient sampling. For this dataset with strong signal, you may find the centered parameterization provides better ESS.

The choice of parameterization doesn't affect what you're learning about the hyperparameters; it only affects how efficiently the sampler explores the posterior. If sampling is slow or ESS is low, try switching parameterizations.
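To see what ESS measures, here is a crude single-chain estimator applied to two synthetic chains of the same length. This is an illustration only: ArviZ uses a more careful rank-normalized, multi-chain estimator, and the helper names `simple_ess` and `ar1_chain` are hypothetical. An AR(1) chain with lag-1 autocorrelation 0.9 stands in for a poorly mixing sampler.

```python
import numpy as np


def simple_ess(chain):
    """Crude single-chain ESS: N / (1 + 2 * sum of autocorrelations),
    truncating the sum at the first negative autocorrelation."""
    x = np.asarray(chain, dtype=float)
    n = x.size
    x = x - x.mean()
    var = np.dot(x, x) / n
    rho_sum = 0.0
    for lag in range(1, n // 2):
        rho = np.dot(x[:-lag], x[lag:]) / (n * var)
        if rho < 0:
            break
        rho_sum += rho
    return n / (1.0 + 2.0 * rho_sum)


def ar1_chain(rho, n, rng):
    """Stationary AR(1) chain with unit variance; rho mimics a sampler's lag-1 autocorrelation."""
    x = np.empty(n)
    x[0] = rng.standard_normal()
    innovation_sd = np.sqrt(1.0 - rho**2)
    for t in range(1, n):
        x[t] = rho * x[t - 1] + innovation_sd * rng.standard_normal()
    return x


rng = np.random.default_rng(1)
n_draws = 5000
ess_well_mixed = simple_ess(ar1_chain(0.0, n_draws, rng))
ess_sticky = simple_ess(ar1_chain(0.9, n_draws, rng))
print(f"ESS, independent draws:   {ess_well_mixed:.0f} of {n_draws}")
print(f"ESS, autocorrelation 0.9: {ess_sticky:.0f} of {n_draws}")
```

For an AR(1) chain the theoretical value is $N(1-\rho)/(1+\rho)$, about 263 for $\rho = 0.9$ and $N = 5{,}000$, so a sampler that sticky needs roughly 19 times as many iterations for the same Monte Carlo precision.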

### Visualizing the HSGP Approximate Gram Matrix

Another way to check HSGP fidelity is to directly compare the unapproximated Gram matrix (covariance matrix) $\mathbf{K}$ to the one resulting from the HSGP approximation:

$$
\tilde{\mathbf{K}} = \Phi \Lambda \Phi^T
$$

where $\Phi$ is the $n \times m$ matrix of basis functions (the Laplacian eigenfunctions evaluated at the inputs), and $\Lambda$ is a diagonal matrix whose entries are the kernel's spectral density evaluated at the square roots of the eigenvalues, $S(\sqrt{\lambda_j})$. Let's visualize this for different values of $m$ and $c$ to see when the approximation starts to degrade.
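The formula can also be checked without PyMC. The NumPy sketch below builds $\tilde{\mathbf{K}}$ from the 1D HSGP ingredients on the domain $[-L, L]$: sine eigenfunctions $\phi_j(x) = \sqrt{1/L}\,\sin\big(\sqrt{\lambda_j}\,(x + L)\big)$ with $\sqrt{\lambda_j} = j\pi/(2L)$, and the unit-amplitude Matérn-5/2 spectral density in one dimension, $S(\omega) = \tfrac{16}{3}\,5^{5/2}\,\ell\,(5 + \omega^2\ell^2)^{-3}$, written out by hand. It reports the largest absolute deviation from the exact Gram matrix instead of a picture. Treat it as a cross-check sketch; the helper names are hypothetical, and in practice you would use `power_spectral_density` and the `hsgp_approx` helpers as in the cell below.

```python
import numpy as np


def matern52(r, ell):
    """Unit-amplitude Matern-5/2 kernel as a function of distance r."""
    s = np.sqrt(5.0) * np.abs(r) / ell
    return (1.0 + s + s**2 / 3.0) * np.exp(-s)


def matern52_psd_1d(omega, ell):
    """1D spectral density of the unit-amplitude Matern-5/2 kernel."""
    return (16.0 / 3.0) * 5.0**2.5 * ell * (5.0 + (omega * ell) ** 2) ** -3


def hsgp_gram_1d(x_centered, L, m, ell):
    """Phi @ diag(S(sqrt(lambda_j))) @ Phi.T with m sine eigenfunctions on [-L, L]."""
    sqrt_eigvals = np.arange(1, m + 1) * np.pi / (2.0 * L)  # sqrt(lambda_j)
    phi = np.sqrt(1.0 / L) * np.sin(sqrt_eigvals[None, :] * (x_centered[:, None] + L))
    psd = matern52_psd_1d(sqrt_eigvals, ell)
    return (phi * psd) @ phi.T  # same as phi @ np.diag(psd) @ phi.T


x = np.linspace(0, 10, 100)
x_centered = x - (x.max() + x.min()) / 2.0  # data now span [-5, 5]
ell = 1.5
K_true = matern52(x_centered[:, None] - x_centered[None, :], ell)

errors = {}
for m, c in [(30, 2.5), (15, 2.5), (5, 2.5), (30, 1.2), (5, 1.2)]:
    L = c * np.abs(x_centered).max()  # boundary at c times the half-range
    K_approx = hsgp_gram_1d(x_centered, L, m, ell)
    errors[(m, c)] = np.abs(K_approx - K_true).max()
    print(f"m = {m:>2}, c = {c}: max |K_approx - K| = {errors[(m, c)]:.3f}")
```

A single number per configuration is coarser than the heatmaps, but it is convenient when you want to sweep many $(m, c, \ell)$ combinations in a loop.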

In [None]:
import matplotlib.pyplot as plt
import pytensor.tensor as pt

# Use a subset for clearer visualization
n_viz = 100
x_viz = np.linspace(0, 10, n_viz)

# True GP covariance
chosen_ell = 1.5
cov_func_viz = pm.gp.cov.Matern52(1, ls=chosen_ell)
K_true = cov_func_viz(x_viz[:, None]).eval()

# Helper function to calculate HSGP approximate Gram matrix
def calculate_K_approx(x_centered, L, m_val, cov_func):
    """Calculate the HSGP approximate covariance matrix Phi @ diag(psd) @ Phi.T."""
    eigvals = pm.gp.hsgp_approx.calc_eigenvalues(L, m_val)
    phi = pm.gp.hsgp_approx.calc_eigenvectors(x_centered, L, eigvals, m_val)
    omega = pt.sqrt(eigvals)  # the spectral density is evaluated at sqrt(eigenvalues)
    psd = cov_func.power_spectral_density(omega)
    return (phi @ pt.diag(psd) @ phi.T).eval()

# Center the data
x_center = (x_viz.max() + x_viz.min()) / 2.0
x_viz_centered = (x_viz - x_center)[:, None]

# Create comparison plot: true Gram matrix in the first column, six approximations in the rest
fig, axs = plt.subplots(2, 4, figsize=(14, 7), sharey=True)

# True Gram matrix
axs[0, 0].imshow(K_true, cmap='inferno', vmin=0, vmax=1)
axs[0, 0].set_title(f'True Gram matrix\nℓ = {chosen_ell}')
axs[0, 0].set_ylabel('Index')
axs[1, 0].axis('off')

# Various m and c combinations: (m, c, row, col)
configs = [
    ([30], 2.5, 0, 1),
    ([15], 2.5, 1, 1),
    ([30], 1.2, 0, 2),
    ([15], 1.2, 1, 2),
    ([5], 2.5, 0, 3),
    ([5], 1.2, 1, 3),
]

for m_val, c_val, row, col in configs:
    L = pm.gp.hsgp_approx.set_boundary(x_viz_centered, c_val)
    K_approx = calculate_K_approx(x_viz_centered, L, m_val, cov_func_viz)

    axs[row, col].imshow(K_approx, cmap='inferno', vmin=0, vmax=1, interpolation='none')
    axs[row, col].set_title(f'm = {m_val[0]}, c = {c_val}')

    if row == 1:
        axs[row, col].set_xlabel('Index')

plt.suptitle('HSGP Approximation Quality: Comparing to True Gram Matrix', fontsize=14)
plt.tight_layout()
plt.show()

### Understanding Gram Matrix Approximations

These plots compare approximate Gram matrices to the unapproximated one (top left). The goal is visual similarity: the more alike they look, the better the approximation. Important caveats:

- These results are **only relevant for this specific domain and lengthscale** ($\ell = 1.5$). Different lengthscales will show different approximation quality.
- The approximation looks good for $m = 30$ or $m = 15$ with $c=2.5$. The rest show clear differences.
- $c=1.2$ is generally too small, regardless of $m$, showing degradation at the boundaries.
- Surprisingly, $m=5$, $c=1.2$ can look better than $m=5$, $c=2.5$. When we "stretch" the basis to fill a larger domain, we lose fidelity at smaller lengthscales if $m$ is too small.

The lesson: **you need to experiment across your range of lengthscales** to find adequate $m$ and $c$ values. Often during prototyping, you can use lower fidelity (smaller $m$) for faster iteration, then dial in higher fidelity once you understand the relevant lengthscales.

### Practical Heuristics for Choosing m and c

In practice, you'll need to infer the lengthscale from data, so HSGP needs to approximate a GP across a range of lengthscales representative of your prior. Based on the research literature and empirical experience:

1. **Start with `approx_hsgp_hyperparams`**: This function provides good default values. It chooses $c$ large enough to handle your largest expected lengthscales and $m$ large enough for your smallest lengthscales.

2. **For smooth functions with moderate lengthscales**: You can often reduce $m$ to 50-100, lowering computational cost.

3. **For rapidly-varying functions**: Increase $m$ to 100-200 or more to capture high-frequency components.

4. **For long lengthscales**: Increase $c$ to 2.5-4.0 to ensure basis functions extend well beyond your data.

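5. **Check the basis vectors if sampling struggles**: The first basis vector can be nearly flat over the data, especially when $c$ is large, which makes it hard to distinguish from an intercept term. If your model includes an intercept, consider setting `drop_first=True` in `pm.gp.HSGP`.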

6. **Verify approximation quality**: Compare HSGP to exact GP on a data subset when possible.
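To see why the helper pushes $c$ up for long lengthscales and $m$ up for short ones, here is a plain-Python sketch of the rule of thumb from Riutort-Mayol et al., which PyMC's `approx_hsgp_hyperparams` is based on. The function name is hypothetical, and the constants `a1=4.1`, `a2=2.65` (Matérn-5/2) and the floor $c \ge 1.2$ reflect our reading of that work and the PyMC source; verify them against your installed version before relying on them.

```python
def hsgp_m_c_heuristic(x_range, lengthscale_range, a1=4.1, a2=2.65):
    """Sketch of the m/c rule of thumb (defaults assume Matern-5/2 constants).

    c must be large enough for the *largest* lengthscale; m must then be
    large enough to resolve the *smallest* lengthscale given that c.
    """
    x_min, x_max = x_range
    ell_min, ell_max = lengthscale_range
    S = (x_max - x_min) / 2.0          # half-width of the data domain
    c = max(a1 * ell_max / S, 1.2)     # boundary factor, floored at 1.2
    m = int(a2 * c / (ell_min / S))    # basis functions per dimension
    return m, c


m, c = hsgp_m_c_heuristic([0, 10], [1.0, 3.0])
print(m, round(c, 2))  # prints: 32 2.46

# Halving the smallest lengthscale keeps c but roughly doubles m
m_narrow, c_narrow = hsgp_m_c_heuristic([0, 10], [0.5, 3.0])
print(m_narrow, round(c_narrow, 2))
```

With $x \in [0, 10]$ and $\ell \in [1, 3]$ this reproduces the $m = 32$ per dimension quoted in the 2D example of Section 3.6, which is a useful sanity check on the constants.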

### Exercise: Comparing HSGP vs Full GP

Now you'll explore the trade-off between approximation quality and computation time.

In [None]:
# 🤖 EXERCISE: Use your LLM to help compare HSGP vs standard GP

# STEP 1: Ask your LLM to help you implement this function
def hsgp_vs_full_gp(X, y, m_values=(20, 50, 100), L_factor=1.5):
    """
    Fit HSGP for several m, compare to full GP in time and RMSE.
    
    Prompt suggestion: "Help me implement HSGP in PyMC and produce plots of
    computation time vs error across different m values."
    """
    # YOUR LLM-ASSISTED CODE HERE
    pass

# STEP 2: Test on a subset of data
# Use 500-1000 points for reasonable comparison times

## Section 3.6: Advanced HSGP - 2D Spatial Example

So far we've focused on one-dimensional examples. HSGP also works well in two dimensions, making it excellent for spatial modeling. Let's see a quick 2D example to understand how $m$ and $c$ work in multiple dimensions.

### Simulating 2D Spatial Data

We'll create data on a 2D grid by drawing a spatial field from a GP and adding Gaussian observation noise.

In [None]:
def simulate_2d_spatial(n_points=400, ell_true=1.5, eta_true=1.0, sigma_true=0.3):
    """Simulate 2D spatial data from a GP."""
    # Create spatial grid
    n_side = int(np.sqrt(n_points))
    x1, x2 = np.meshgrid(
        np.linspace(0, 10, n_side),
        np.linspace(0, 10, n_side)
    )
    X = np.column_stack([x1.flatten(), x2.flatten()])
    
    # Sample from 2D GP
    cov_func = eta_true**2 * pm.gp.cov.Matern52(2, ls=ell_true)
    K = cov_func(X).eval()
    K_stable = K + 1e-8 * np.eye(len(X))
    f_true = rng.multivariate_normal(np.zeros(len(X)), K_stable)
    
    # Add noise
    y = f_true + sigma_true * rng.standard_normal(len(X))
    
    return X, y, f_true

X_2d, y_2d, f_true_2d = simulate_2d_spatial()

print(f"Generated {len(X_2d)} observations on 2D grid")
print(f"X ranges: [{X_2d[:, 0].min():.2f}, {X_2d[:, 0].max():.2f}] x [{X_2d[:, 1].min():.2f}, {X_2d[:, 1].max():.2f}]")

Let's visualize the true spatial field and the observed data.

In [None]:
# Create side-by-side plots
fig = make_subplots(
    rows=1, cols=2,
    subplot_titles=['True Spatial Field', 'Observed Data'],
    horizontal_spacing=0.15
)

# Reshape for plotting
n_side = int(np.sqrt(len(X_2d)))
f_grid = f_true_2d.reshape(n_side, n_side)
y_grid = y_2d.reshape(n_side, n_side)

fig.add_trace(
    go.Heatmap(z=f_grid, colorscale='RdBu', zmid=0, showscale=True),
    row=1, col=1
)

fig.add_trace(
    go.Heatmap(z=y_grid, colorscale='RdBu', zmid=0, showscale=True),
    row=1, col=2
)

fig.update_xaxes(title_text='x1')
fig.update_yaxes(title_text='x2', row=1, col=1)
fig.update_layout(height=400, title_text='2D Spatial GP Data')

fig.show()

### Building a 2D HSGP Model

For 2D HSGPs, we specify $m$ as a two-element list, one value per dimension, while $c$ here is a single scalar applied to both dimensions. The total number of basis functions is $m_1 \times m_2$, so computational cost grows multiplicatively with dimension.
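The multiplicative growth comes from building each 2D basis function as a product of 1D sine eigenfunctions, one index per dimension. The NumPy sketch below (illustrative, not PyMC's code; `hsgp_basis_2d` is a hypothetical helper) constructs that tensor-product basis for a 20 × 20 grid and shows that every index pair $(j_1, j_2)$ gets its own column.

```python
import itertools

import numpy as np


def hsgp_basis_2d(X_centered, L, m):
    """Tensor-product Dirichlet eigenfunctions on [-L1, L1] x [-L2, L2].

    Returns Phi with shape (n, m1 * m2) and per-dimension sqrt-eigenvalues with shape (m1 * m2, 2).
    """
    L = np.asarray(L, dtype=float)
    # All index pairs (j1, j2) with j_d in 1..m_d
    idx = np.array(list(itertools.product(range(1, m[0] + 1), range(1, m[1] + 1))))
    sqrt_eig = idx * np.pi / (2.0 * L)  # shape (M, 2)
    # phi_j(x) = prod_d sqrt(1/L_d) * sin(sqrt_eig_jd * (x_d + L_d))
    per_dim = np.sqrt(1.0 / L) * np.sin(sqrt_eig[None, :, :] * (X_centered[:, None, :] + L))
    return per_dim.prod(axis=2), sqrt_eig


# 20 x 20 grid on [0, 10]^2, centered to [-5, 5]^2
g = np.linspace(0, 10, 20)
X = np.column_stack([a.ravel() for a in np.meshgrid(g, g)]) - 5.0
L = 2.46 * np.abs(X).max(axis=0)  # c = 2.46 in each dimension
Phi, sqrt_eig = hsgp_basis_2d(X, L, m=(32, 32))
print(Phi.shape)  # prints: (400, 1024)
```

A third dimension with $m_3 = 32$ would give $32^3 = 32{,}768$ columns, which is why HSGP is usually reserved for one to three input dimensions.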

In [None]:
# Get recommendations for 2D
m_2d, c_2d = pm.gp.hsgp_approx.approx_hsgp_hyperparams(
    x_range=[0, 10],
    lengthscale_range=[1.0, 3.0],
    cov_func='matern52'
)

print(f"2D HSGP recommendations: m={m_2d}, c={c_2d:.2f}")
print(f"Total basis functions: {m_2d**2}")

with pm.Model() as hsgp_2d_model:
    # Priors
    ell = pm.Gamma('ell', alpha=2, beta=1)
    eta = pm.HalfNormal('eta', sigma=2)
    
    # 2D covariance
    cov = eta**2 * pm.gp.cov.Matern52(2, ls=ell)
    
    # HSGP with 2D specification
    gp = pm.gp.HSGP(
        m=[m_2d, m_2d],  # m for each dimension
        c=c_2d,           # c applies to both
        cov_func=cov
    )
    
    f = gp.prior('f', X=X_2d)
    sigma = pm.HalfNormal('sigma', sigma=1)
    y_obs = pm.Normal('y_obs', mu=f, sigma=sigma, observed=y_2d)
    
    # Sample
    idata_2d = pm.sample(
        500,
        tune=500,
        random_seed=RANDOM_SEED,
    )

### Visualizing the 2D HSGP Fit

Let's extract the posterior mean of the spatial field and compare it to the truth.

In [None]:
# Extract posterior mean
f_post = idata_2d.posterior['f'].mean(dim=['chain', 'draw']).values
f_post_grid = f_post.reshape(n_side, n_side)

# Create comparison plot
fig = make_subplots(
    rows=1, cols=2,
    subplot_titles=['True Spatial Field', 'HSGP Posterior Mean'],
    horizontal_spacing=0.15
)

fig.add_trace(
    go.Heatmap(z=f_grid, colorscale='RdBu', zmid=0, showscale=False),
    row=1, col=1
)

fig.add_trace(
    go.Heatmap(z=f_post_grid, colorscale='RdBu', zmid=0, showscale=True),
    row=1, col=2
)

fig.update_xaxes(title_text='x1')
fig.update_yaxes(title_text='x2', row=1, col=1)
fig.update_layout(height=400, title_text='2D HSGP Recovery of True Field')

fig.show()

### Understanding 2D HSGP Performance

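If the fit is working, the posterior mean should recover the large-scale spatial structure from the noisy observations, smoothing out the noise rather than tracking it point by point.

For 2D problems, remember that the total number of basis functions is $m_1 \times m_2$. With $m=[32, 32]$, we're using 1,024 basis functions, which is more coefficients than there are observations (400). Each likelihood evaluation costs roughly $n \times m_1 m_2 \approx 400 \times 1{,}024 \approx 410{,}000$ multiply-adds for the basis-matrix product, versus on the order of $400^3 = 64$ million operations to factorize the exact $400 \times 400$ covariance matrix. HSGP is still cheaper per evaluation here, but the sampler now has 1,024 latent coefficients to explore, and a third dimension would multiply that count again. This is why HSGP rarely pays off beyond about 3 input dimensions.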

### Exercise: Exploring HSGP Parameter Choices

Now experiment with different HSGP parameter configurations.

In [None]:
# 🤖 EXERCISE: Use your LLM to help explore HSGP parameter choices

# STEP 1: Ask your LLM to help you implement this function
def tune_hsgp_params(X, y, m_grid=(30, 60, 120), c_grid=(1.5, 2.5, 4.0)):
    """
    Grid-search m and c to evaluate speed and accuracy trade-offs.
    
    Prompt suggestion: "Help me create a function that runs multiple HSGP fits
    over m and c grid and summarizes performance with Plotly heatmaps."
    """
    # YOUR LLM-ASSISTED CODE HERE
    pass

# STEP 2: Test on a moderate-sized dataset
# Visualize RMSE and sampling time as heatmaps

## Section 3.7: Comparing All Approaches

We've now explored both sparse GPs and HSGP approximations in detail. Let's bring everything together with a comprehensive comparison that highlights when to use each approach.

### Computational Complexity Summary

Let's visualize the computational complexity of each approach as a function of dataset size.

In [None]:
# Create comparison plot of computational complexity
n_values = np.logspace(2, 4, 50)  # 100 to 10,000 data points
m_sparse = 100  # inducing points
m_hsgp = 100    # basis functions

# Relative computational cost (arbitrary units)
cost_exact = n_values**3 / 1e6  # Scale for visibility
cost_sparse = n_values * m_sparse**2 / 1e6
cost_hsgp = n_values * m_hsgp / 1e6

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=n_values,
    y=cost_exact,
    mode='lines',
    name='Exact GP: O(n³)',
    line=dict(color='blue', width=3)
))

fig.add_trace(go.Scatter(
    x=n_values,
    y=cost_sparse,
    mode='lines',
    name=f'Sparse GP: O(nm²), m={m_sparse}',
    line=dict(color='green', width=3)
))

fig.add_trace(go.Scatter(
    x=n_values,
    y=cost_hsgp,
    mode='lines',
    name=f'HSGP: O(nm), m={m_hsgp}',
    line=dict(color='red', width=3)
))

fig.update_layout(
    title='Computational Complexity: Exact GP vs. Approximations',
    xaxis_title='Number of data points (n)',
    yaxis_title='Relative computational cost',
    xaxis_type='log',
    yaxis_type='log',
    height=500,
    showlegend=True,
    hovermode='x unified'
)

fig.show()

### Understanding the Complexity Comparison

This log-log plot illustrates why approximations are essential for large datasets. On log-log axes a power law $n^k$ is a straight line with slope $k$, so the exponents can be read directly from the plot:

- **The blue line (exact GP)** has slope 3, reflecting $\mathcal{O}(n^3)$ growth: each tenfold increase in $n$ multiplies the cost by 1,000. By $n=10{,}000$, repeating the $n \times n$ factorization at every step of an MCMC run becomes impractical.

- **The green line (sparse GP)** grows as $\mathcal{O}(nm^2)$. With $m$ held fixed this is linear in $n$ (slope 1), which makes datasets of several thousand points tractable.

- **The red line (HSGP)** grows as $\mathcal{O}(nm)$. It is also linear in $n$, so it runs parallel to the green line but sits a constant factor of $m$ (here 100) below it.

The crossover points where approximations become worthwhile depend on your patience, hardware, and accuracy requirements, but as a rough guide: consider sparse GPs beyond ~1,000 points and HSGP beyond ~5,000 points.
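Because the plotted curves are pure power laws, their log-log slopes can be checked numerically. This short sketch reuses the same illustrative cost formulas as the plot (constant factors ignored) and confirms that, with $m$ fixed, the sparse and HSGP costs are both linear in $n$ and differ only by a constant factor of $m$.

```python
import numpy as np

n = np.logspace(2, 4, 50)  # same grid as the plot: 100 to 10,000 points
m = 100                    # inducing points / basis functions

# Illustrative cost models (constant factors ignored, as in the plot)
costs = {
    "exact": n**3,
    "sparse": n * m**2,
    "hsgp": n * m,
}

# The slope of log10(cost) against log10(n) is the exponent of n
slopes = {name: np.polyfit(np.log10(n), np.log10(cost), 1)[0] for name, cost in costs.items()}
for name, slope in slopes.items():
    print(f"{name:>6}: slope {slope:.2f}")

# With m fixed, the sparse/HSGP gap is the same factor m at every n
gap = costs["sparse"] / costs["hsgp"]
print(gap.min(), gap.max())
```

The same arithmetic shows where the exact and sparse curves meet in these units: $n^3 = nm^2$ when $n = m$, i.e. at $n = 100$ here. Real crossovers depend on constant factors and hardware, so read the plot as a statement about growth rates rather than wall-clock time.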

### Decision Guide: Which Method to Use

Here's practical guidance for choosing between methods:

**Use Standard (Exact) GP when:**
- $n < 1{,}000$ points
- You need exact inference without approximation error
- You're using non-stationary kernels
- Computation time isn't critical

**Use Sparse GP (Inducing Points) when:**
- $1{,}000 < n < 10{,}000$ points
- Data has uneven sampling density
- You have domain knowledge about where to place inducing points
- You're primarily using Gaussian likelihoods
- **Typical use cases**: Spatial data with known regions of interest, time series with known change points

**Use HSGP when:**
- $n > 5{,}000$ points
- Using stationary kernels (Matérn, ExpQuad)
- Input dimension is 1, 2, or 3
- You need to integrate the GP into a larger model
- You need predictions at many new locations
- **Typical use cases**: Long time series, spatial data on regular grids, any large dataset with smooth variation

**Practical tip**: When prototyping, start with a low-fidelity HSGP (small $m$) for fast iteration. Once you understand the relevant lengthscales, dial in appropriate $m$ and $c$ for production.

## Summary

In this session, we've tackled one of the most practical challenges in GP modeling: scaling to larger datasets. We explored two complementary approaches that dramatically reduce computational complexity while maintaining the core benefits of GP modeling.

### Sparse GP Approximations

Sparse methods reduce complexity from $\mathcal{O}(n^3)$ to $\mathcal{O}(nm^2)$ by representing the GP through $m < n$ strategically placed inducing points. We saw how K-means provides a practical initialization strategy, and how the FITC approximation balances speed with uncertainty quantification. These methods work well for moderately large datasets where inducing points can be placed thoughtfully.

### HSGP Approximations

The Hilbert Space GP uses basis function expansions to achieve $\mathcal{O}(nm + m)$ complexity. By representing the GP as a weighted sum of pre-computed sinusoidal basis functions, HSGP becomes fully parametric and easy to integrate into larger models. The approximation quality is controlled by $m$ (number of basis functions) and $c$ (boundary factor), which can be chosen using PyMC's `approx_hsgp_hyperparams` helper function.

### Key Insights

We learned that choosing between these methods requires understanding:
- Your dataset size and structure
- Whether your kernel is stationary
- The dimensionality of your input space
- The expected range of lengthscales
- Whether you need exact inference or can tolerate approximation

Both approximations involve trade-offs: we exchange some exactness for massive computational gains. The art lies in making these trade-offs intelligently based on your specific application.

### Looking Ahead

With these scaling techniques in your toolkit, you can now apply GP models to real-world datasets that would have been computationally prohibitive with exact inference. In the next session, we'll explore multi-output GPs and comprehensive case studies that bring together all the concepts from this workshop, including using HSGP for complex real-world modeling tasks.