# GP regression

This notebook compares MESS variants to ESS ($M=1$) on the Gaussian process regression problem from Murray et al. (2010) and summarizes mixing
diagnostics.

In [None]:
import sys
import os
import time

# Make the repository's `src` directory importable.
# The repo root is taken to be the parent of the current working directory
# (i.e. the notebook is expected to run from a folder one level below it).
parent_of_cwd = os.path.join(os.getcwd(), '..')
repo_root = os.path.abspath(parent_of_cwd)
src_path = os.path.join(repo_root, 'src')

# Put it at the front of the search path so it takes precedence.
sys.path.insert(0, src_path)

# Echo the resolved locations so a wrong working directory is easy to spot.
print(f"Repo root: {repo_root}")
print(f"Added to path: {src_path}")
print(f"Does it exist? {os.path.exists(src_path)}")

import numpy as np
import matplotlib.pyplot as plt
from mess.data.gp_regression import generate_gp_regression_data
from mess.problems.gp_regression import GaussianProcessRegression
from mess.algorithms.ess import ess_step
from mess.algorithms.mess import mess_step
from mess.algorithms.effective_sample_size import estimate_effective_sample_size, compute_autocorrelation, integrated_autocorrelation_time, plot_ess_histograms

## Configuration

Set problem and sampler parameters used throughout the run.

In [None]:
# --- Problem parameters (forwarded to the data generator / GP problem) ---
num_data = 200           # passed as `num_data`; also the width of every chain array
D = 10                   # passed as `num_dims`
length_scale = 1.0
noise_variance = 0.09

# --- Sampler parameters ---
n_iters = 10_000         # sampler steps per chain
burn_in = 500            # leading iterations dropped before computing summaries

# Shared seed for data generation and for each sampler's random generator.
seed = 0

## Data generation

Simulate GP regression data and initialize the latent function.

In [None]:
# Collect the generator settings in one place, then simulate the data set.
gp_settings = dict(
    num_data=num_data,
    num_dims=D,
    length_scale=length_scale,
    noise_variance=noise_variance,
    seed=seed,
)
data = generate_gp_regression_data(**gp_settings)

# Inputs, observations, and the starting point handed to every sampler.
X, y, x0 = data["X"], data["y"], data["f_init"]

print("X shape:", X.shape)
print("y shape:", y.shape)

## Problem setup

Build the GP regression log-likelihood used by the samplers.

In [None]:
# The problem object exposes `log_likelihood`, which the samplers and the
# diagnostics below evaluate. It uses the same hyperparameters as the data.
problem = GaussianProcessRegression(
    X=X, y=y, length_scale=length_scale, noise_variance=noise_variance
)

# Sanity check: log-likelihood at the initial latent function.
initial_log_likelihood = problem.log_likelihood(x0)
print("Initial log-likelihood:", initial_log_likelihood)

## Run samplers

Run ESS, MESS, and MESS variants and record wall-clock time.

In [None]:
# Baseline run: ESS, started from x0 with its own seeded generator.
rng_ess = np.random.default_rng(seed)

# Row 0 stores the initial state, hence n_iters + 1 rows.
chain_ess = np.zeros((n_iters + 1, num_data))
chain_ess[0] = x0.copy()  # Initial state
intervals_ess = np.zeros(n_iters, dtype=int)  # second value returned by ess_step, per iteration
x = x0.copy()

# perf_counter() is monotonic and high resolution, so the measured duration
# cannot be distorted by system clock adjustments the way time.time() can.
t0 = time.perf_counter()
for t in range(n_iters):
    # P1_ess is returned by ess_step but not used in the visible cells.
    x, nr_intervals_ess, P1_ess = ess_step(x, problem, rng_ess)
    chain_ess[t + 1] = x
    intervals_ess[t] = nr_intervals_ess
t1 = time.perf_counter()
ess_time = t1 - t0
print(f"ESS sampling time: {ess_time:.2f} seconds")

In [None]:
# MESS run with default mess_step options, started from x0 with a freshly
# seeded generator (same seed as the other runs).
rng_mess = np.random.default_rng(seed)

M = 20
# Row 0 stores the initial state, hence n_iters + 1 rows.
chain_mess = np.zeros((n_iters + 1, num_data))
chain_mess[0] = x0.copy()  # Initial state
intervals_mess = np.zeros(n_iters, dtype=int)
x = x0.copy()

# Uniform transition matrix
# perf_counter() is monotonic and high resolution, so the measured duration
# cannot be distorted by system clock adjustments the way time.time() can.
t0 = time.perf_counter()
for t in range(n_iters):
    x, nr_intervals_mess, P1_mess = mess_step(x, problem, rng_mess, M=M)
    chain_mess[t + 1] = x
    intervals_mess[t] = nr_intervals_mess
t1 = time.perf_counter()
mess_time = t1 - t0
print(f"MESS sampling time: {mess_time:.2f} seconds")

In [None]:
# MESS run with the LP-based transition matrix and angular distance.
rng_mess = np.random.default_rng(seed)

# Set M explicitly, like the sibling MESS cells do. Without this the cell
# silently reuses whatever value another cell last assigned to M (for example
# 100 after running the optional experiment), so results would depend on the
# order in which cells were executed.
M = 20
chain_mess_ang = np.zeros((n_iters + 1, num_data))
chain_mess_ang[0] = x0.copy()  # Initial state
intervals_mess_ang = np.zeros(n_iters, dtype=int)
x = x0.copy()

# Transition matrix with LP, angular distance
# perf_counter() is monotonic and high resolution, so the measured duration
# cannot be distorted by system clock adjustments the way time.time() can.
t0 = time.perf_counter()
for t in range(n_iters):
    x, nr_intervals_mess, P1_mess = mess_step(x, problem, rng_mess, M=M,
                                              use_lp=True, distance_metric='angular',
                                              lam=0.05)
    chain_mess_ang[t + 1] = x
    intervals_mess_ang[t] = nr_intervals_mess
t1 = time.perf_counter()
mess_ang_time = t1 - t0
print(f"MESS (angular) sampling time: {mess_ang_time:.2f} seconds")

In [None]:
# MESS run with the LP-based transition matrix and euclidean distance.
rng_mess = np.random.default_rng(seed)
M = 20
# Row 0 stores the initial state, hence n_iters + 1 rows.
chain_mess_eucl = np.zeros((n_iters + 1, num_data))
chain_mess_eucl[0] = x0.copy()  # Initial state
intervals_mess_eucl = np.zeros(n_iters, dtype=int)
x = x0.copy()

# Transition matrix with LP, euclidean distance
# perf_counter() is monotonic and high resolution, so the measured duration
# cannot be distorted by system clock adjustments the way time.time() can.
t0 = time.perf_counter()
for t in range(n_iters):
    x, nr_intervals_mess, P1_mess = mess_step(x, problem, rng_mess, M=M,
                                              use_lp=True, distance_metric='euclidean',
                                              lam=0.05)
    chain_mess_eucl[t + 1] = x
    intervals_mess_eucl[t] = nr_intervals_mess
t1 = time.perf_counter()
mess_eucl_time = t1 - t0
print(f"MESS (euclidean) sampling time: {mess_eucl_time:.2f} seconds")

### Optional
To experiment with settings in the LP optimization problem:

In [None]:
# Optional experiment: LP-based MESS with a larger M and no regularization.
rng_mess = np.random.default_rng(seed)
# Use a dedicated name instead of rebinding the shared `M`, so running this
# optional cell does not change the M seen by the main sampler cells.
M_lam0 = 100
chain_mess_lam0 = np.zeros((n_iters + 1, num_data))
chain_mess_lam0[0] = x0.copy()  # Initial state
intervals_mess_lam0 = np.zeros(n_iters, dtype=int)
x = x0.copy()

# Transition matrix with LP, euclidean distance, lambda=0
# NOTE(review): this comment used to say "angular distance" while the call
# below passes distance_metric='euclidean'; confirm which one is intended.
every = 5000
lam = 0
print(f"Testing MESS (M={M_lam0}) with regularization parameter lambda={lam}, \nprinting transition matrix row every {every} iterations:\n")
# perf_counter() is monotonic and high resolution, so the measured duration
# cannot be distorted by system clock adjustments the way time.time() can.
t0 = time.perf_counter()
for t in range(n_iters):
    x, nr_intervals_mess, P1_mess = mess_step(x, problem, rng_mess, M=M_lam0,
                                              use_lp=True, distance_metric='euclidean',
                                              lam=lam)
    chain_mess_lam0[t + 1] = x
    intervals_mess_lam0[t] = nr_intervals_mess

    # Print transition matrix row every `every` iterations
    if (t + 1) % every == 0 and P1_mess is not None:
        print(f"Iteration {t + 1}:")
        print(f"  Transition matrix row: {P1_mess}")
        print()

t1 = time.perf_counter()
mess_lam0_time = t1 - t0
print(f"MESS (lambda=0) sampling time: {mess_lam0_time:.2f} seconds")


## Interval diagnostics

Compare the number of intervals per iteration across samplers.

In [None]:
# One line per sampler: the per-iteration interval counts recorded above.
interval_series = (
    ("ESS", intervals_ess),
    ("MESS", intervals_mess),
    ("MESS Angular", intervals_mess_ang),
    ("MESS Euclidean", intervals_mess_eucl),
)
for series_label, counts in interval_series:
    plt.plot(counts, label=series_label)

plt.xlabel("Iteration")
plt.ylabel("Number of intervals")
plt.legend()
plt.title("Number of intervals per iteration")
plt.show()

## Log-likelihood trace

Inspect convergence and mixing via the log-likelihood trace.

In [None]:
# Evaluate the log-likelihood at every stored state (including the initial
# state in row 0) of each chain.
ll_ess = np.array([problem.log_likelihood(x) for x in chain_ess])
ll_mess = np.array([problem.log_likelihood(x) for x in chain_mess])
ll_mess_ang = np.array([problem.log_likelihood(x) for x in chain_mess_ang])
ll_mess_eucl = np.array([problem.log_likelihood(x) for x in chain_mess_eucl])

plt.plot(ll_ess, label="ESS")
plt.plot(ll_mess, label="MESS")
plt.plot(ll_mess_ang, label="MESS Angular")
plt.plot(ll_mess_eucl, label="MESS Euclidean")
plt.xlabel("Iteration")
plt.ylabel("Log-likelihood")  # the y-axis was previously unlabeled
plt.legend()
plt.title("Log-likelihood trace")
# Render explicitly, like the other plotting cells, so the cell does not end
# on a bare expression whose repr is echoed in the notebook output.
plt.show()

## Trace plots

Compare coordinate-level traces across samplers.

In [None]:
# Coordinates of the latent vector to inspect, one stacked panel each.
idx = [0, 10, 50]

fig, axes = plt.subplots(len(idx), 1, figsize=(6, 10))

# Samplers are drawn in the same order on every panel.
coordinate_chains = (
    ("ESS", chain_ess),
    ("MESS", chain_mess),
    ("MESS Angular", chain_mess_ang),
    ("MESS Euclidean", chain_mess_eucl),
)
for panel, k in zip(axes, idx):
    for sampler_label, states in coordinate_chains:
        panel.plot(states[:, k], alpha=0.7, label=sampler_label)
    panel.set_title(f"Coordinate {k}")
    panel.legend()

In [None]:
# Traceplot for all coordinates and algorithms at the same time
fig, axes = plt.subplots(1, 4, figsize=(15, 4))

# Named `trace_panels` rather than `algorithms`: the diagnostics cells further
# down bind `algorithms` to 3-tuples, and reusing the name here would break
# them if this cell were re-run afterwards.
trace_panels = [
    ("ESS", chain_ess),
    ("MESS", chain_mess),
    ("MESS Angular", chain_mess_ang),
    ("MESS Euclidean", chain_mess_eucl),
]

# Iterate over the axes directly instead of rebinding `idx`, which the
# previous cell uses for the list of inspected coordinates.
for panel, (algo_name, chain) in zip(axes, trace_panels):
    # Passing the 2D chain draws one line per column (coordinate) in a single
    # call, replacing one plot call per coordinate.
    panel.plot(chain, alpha=0.1, linewidth=0.5)
    panel.set_title(algo_name)
    panel.set_xlabel("Iteration")
    panel.set_ylabel("Value")

plt.tight_layout()
plt.show()

## Posterior summary

Compute posterior mean and standard deviation, then compare to the true latent function.

In [None]:
# Ground-truth latent function stored with the simulated data.
f_true = data["f_true"]

# Drop row 0 (the initial state) together with the first `burn_in` iterations.
post_burn_in = slice(burn_in + 1, None)
samples_ess = chain_ess[post_burn_in]
samples_mess = chain_mess[post_burn_in]
samples_mess_ang = chain_mess_ang[post_burn_in]
samples_mess_eucl = chain_mess_eucl[post_burn_in]

# Per-coordinate posterior mean and standard deviation (reduce over samples).
mean_ess, std_ess = samples_ess.mean(axis=0), samples_ess.std(axis=0)
mean_mess, std_mess = samples_mess.mean(axis=0), samples_mess.std(axis=0)
mean_mess_ang, std_mess_ang = samples_mess_ang.mean(axis=0), samples_mess_ang.std(axis=0)
mean_mess_eucl, std_mess_eucl = samples_mess_eucl.mean(axis=0), samples_mess_eucl.std(axis=0)

In [None]:
# Overlay each sampler's posterior mean on the true latent function,
# plotted against the data index.
plt.figure(figsize=(10, 4))
plt.plot(f_true, label="True f", color="black", linewidth=2)
plt.plot(mean_ess, label="ESS posterior mean", alpha=0.8)
plt.plot(mean_mess, label="MESS posterior mean", alpha=0.8)
plt.plot(mean_mess_ang, label="MESS Angular posterior mean", alpha=0.8)
plt.plot(mean_mess_eucl, label="MESS Euclidean posterior mean", alpha=0.8)
plt.legend()
plt.title("Posterior mean vs true latent function")
plt.xlabel("Index")
plt.ylabel("f")
# Render explicitly, like the other plotting cells, so the cell does not end
# on a bare expression whose repr is echoed in the notebook output.
plt.show()


## Eff. Sample Size diagnostics

Compute Eff. Sample Size and Eff. Sample Size per minute for each sampler.

In [None]:
# Effective sample size of each chain, computed on the post-burn-in states
# only (row 0 is the initial state and is dropped along with the burn-in).
max_lag = 1000
kept_rows = slice(burn_in + 1, None)
ess_ess = estimate_effective_sample_size(chain_ess[kept_rows, :], max_lag=max_lag)
ess_mess = estimate_effective_sample_size(chain_mess[kept_rows, :], max_lag=max_lag)
ess_mess_ang = estimate_effective_sample_size(chain_mess_ang[kept_rows, :], max_lag=max_lag)
ess_mess_eucl = estimate_effective_sample_size(chain_mess_eucl[kept_rows, :], max_lag=max_lag)

# Fraction of iterations kept after burn-in; the measured run time is scaled
# by it so only the time attributed to the kept samples is counted.
percentage_useful_samples = (n_iters - burn_in) / n_iters


def _per_minute(eff_sample_size, sampling_seconds):
    """Divide an effective sample size by the kept share of run time, in minutes."""
    return eff_sample_size / (percentage_useful_samples * sampling_seconds / 60.0)


# Eff. Sample Size per minute
ess_minute_ess = _per_minute(ess_ess, ess_time)
ess_minute_mess = _per_minute(ess_mess, mess_time)
ess_minute_mess_ang = _per_minute(ess_mess_ang, mess_ang_time)
ess_minute_mess_eucl = _per_minute(ess_mess_eucl, mess_eucl_time)

In [None]:
# One column per algorithm: (display name, Eff. Sample Size, Eff. Sample Size / min).
# The summary-table cell reads this list too.
algorithms = [
    ("ESS (M=1)", ess_ess, ess_minute_ess),
    ("MESS Uniform", ess_mess, ess_minute_mess),
    ("MESS Angular", ess_mess_ang, ess_minute_mess_ang),
    ("MESS Euclidean", ess_mess_eucl, ess_minute_mess_eucl),
]

# Top row: Eff. Sample Size; bottom row: Eff. Sample Size per minute.
# Axes are shared within each row so the columns are directly comparable.
n_columns = len(algorithms)
fig, axes = plt.subplots(2, n_columns, figsize=(16, 7), sharey='row', sharex='row')
for col, (name, ess_vals, ess_min_vals) in enumerate(algorithms):
    is_first_column = col == 0
    axes[0, col].set_title(name)

    row_specs = (
        (ess_vals, "steelblue", "Eff. Sample Size"),
        (ess_min_vals, "seagreen", "Eff. Sample Size / min"),
    )
    for row, (values, bar_color, x_label) in enumerate(row_specs):
        ax = axes[row, col]
        ax.hist(values, bins=10, alpha=0.7, color=bar_color)
        if is_first_column:
            # Only the leftmost panel of each row carries the y-axis label.
            ax.set_ylabel("Frequency")
        ax.set_xlabel(x_label)

plt.suptitle("Eff. Sample Size diagnostics", y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# Summary table for Eff. Sample Size metrics:
# one row per algorithm, holding the means of both metrics.
summary_rows = [
    [name, float(np.mean(ess_vals)), float(np.mean(ess_min_vals))]
    for name, ess_vals, ess_min_vals in algorithms
]

# Format the numbers to two decimals for display.
cell_text = [
    [algo_label, f"{avg_size:.2f}", f"{avg_rate:.2f}"]
    for algo_label, avg_size, avg_rate in summary_rows
]
column_labels = ["Algorithm", "Eff. Sample Size (mean)", "Eff. Sample Size / min (mean)"]

# Figure height grows with the number of rows; the axes frame is hidden so
# only the table is shown.
fig_height = 1.6 + 0.4 * len(summary_rows)
fig, ax = plt.subplots(figsize=(7.5, fig_height))
ax.axis("off")
table = ax.table(
    cellText=cell_text,
    colLabels=column_labels,
    loc="center",
    cellLoc="center",
)
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 1.4)
plt.tight_layout()
plt.show()