# Semi-Blind Deconvolution with ESS and MESS

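This notebook implements the semi-blind deconvolution (SBD) model from Senn et al. (2025, 2026) and samples its posterior with MESS, the multi-proposal variant of elliptical slice sampling (ESS) that uses M proposals per step; M = 1 corresponds to plain ESS.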

The forward model is: **d = W @ c + e**

where:
- **c** is the true image (vectorized), with n = n_v × n_h pixels
- **W** is the 1D convolution matrix formed from the blur kernel **w**
- **e** is observation noise
- **d** is the observed blurred and noisy image

Key metrics tracked:
- Effective Sample Size (ESS)
- Mean Squared Jumping Distance (MSJD)
- Computation time
- Reconstruction quality (RMSE)

## Import Required Libraries

In [None]:
import sys
import os
import time

# The notebook runs from <repo>/notebooks: step up one directory to reach the
# repository root, then prepend <repo>/src to the module search path so the
# local `mess` package is importable.
repo_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
src_path = os.path.join(repo_root, "src")
sys.path.insert(0, src_path)

print("Repo root: " + repo_root)
print("Added to path: " + src_path)

import numpy as np
import matplotlib.pyplot as plt
from mess.algorithms.ess import ess_step
from mess.algorithms.mess import mess_step
from mess.algorithms.effective_sample_size import (
    estimate_effective_sample_size,
    compute_mean_squared_jumping_distance,
    compute_normalized_jumping_distance
)
from mess.problems.sbd import SemiBlindDeconvolution
from mess.data.sbd import generate_sbd_data

## Generate Synthetic Semi-Blind Deconvolution Data

In [None]:
# Model hyperparameters, defined once and shared by the data generator and the
# posterior so the sampler always targets the model that produced the data.
prior_var = 1.0
noise_variance = 0.1

# Generate synthetic SBD data
data = generate_sbd_data(
    n_v=30,              # Image height (rows)
    n_h=30,              # Image width (columns)
    kernel_length=5,     # Blur kernel length
    prior_var=prior_var,
    noise_variance=noise_variance,
    seed=42
)

c_true = data['c_true']
c_init = data['c_init']
d = data['d']
w = data['w']
W = data['W']
n_v = data['n_v']
n_h = data['n_h']

# Create problem (same hyperparameters as the generator above)
problem = SemiBlindDeconvolution(
    d=d,
    w=w,
    n_v=n_v,
    n_h=n_h,
    prior_var=prior_var,
    noise_variance=noise_variance,
)

print("Data shapes:")
print(f"  Image size: {n_v} x {n_h} = {problem.dim} pixels")
print(f"  Blur kernel length: {len(w)}")
print(f"  Convolution matrix W: {W.shape}")
print(f"\nBlur kernel: {w}")
print(f"\nInitial log-likelihood: {problem.log_likelihood(c_init):.4f}")
print(f"True image log-likelihood: {problem.log_likelihood(c_true):.4f}")

# Visualize data
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# True image
im0 = axes[0].imshow(c_true.reshape(n_v, n_h), cmap='gray', aspect='auto')
axes[0].set_title('True Image', fontsize=12, fontweight='bold')
axes[0].axis('off')
plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)

# Observed (blurred + noisy)
# NOTE(review): assumes d has n_v * n_h entries (a "same"-size convolution);
# the reshape fails otherwise.
im1 = axes[1].imshow(d.reshape(n_v, n_h), cmap='gray', aspect='auto')
axes[1].set_title('Observed (Blurred + Noisy)', fontsize=12, fontweight='bold')
axes[1].axis('off')
plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

# Blur kernel
axes[2].plot(w, 'o-', linewidth=2, markersize=8)
axes[2].set_title('Blur Kernel', fontsize=12, fontweight='bold')
axes[2].set_xlabel('Index')
axes[2].set_ylabel('Value')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Experiment Setup

In [None]:
# Sampler configuration: total MCMC iterations per run, number of leading
# iterations discarded as burn-in, and the RNG seed shared by every run.
n_iters, burn_in, seed = 200, 100, 42

# Numbers of proposals M per MESS step; every variant is run for each value.
M_values = [10]

print("\n".join([
    "Experiment setup:",
    f"  n_iters: {n_iters}",
    f"  burn_in: {burn_in}",
    f"  M values: {M_values}",
]))

## Run MCMC Sampling Experiments

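Run MESS for each value of M with three proposal-selection variants: uniform (`use_lp=False`) and the angular and Euclidean distance-based variants (`use_lp=True`). The angular and Euclidean variants are skipped for M = 1.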

In [None]:
# Store results for all variants.
# Chain layout (later cells rely on it):
#   - 'uniform' chains have n_iters rows; row t is the state after iteration t+1.
#   - 'angular' / 'euclidean' chains have n_iters + 1 rows; row 0 is c_init and
#     row t is the state after iteration t.
# The post-burn-in slices are therefore chain[burn_in:] and chain[burn_in + 1:]
# respectively, and both cover the same iterations.
results = {
    'uniform': {'intervals': {}, 'chains': {}, 'times': [], 'statistics': {}},
    'angular': {'intervals': {}, 'chains': {}, 'times': [], 'statistics': {}},
    'euclidean': {'intervals': {}, 'chains': {}, 'times': [], 'statistics': {}},
    'ess': {'uniform': {}, 'angular': {}, 'euclidean': {}},
    'msjd': {'uniform': {}, 'angular': {}, 'euclidean': {}},
    'log_likelihood': {'uniform': {}, 'angular': {}, 'euclidean': {}},
    'rmse': {'uniform': {}, 'angular': {}, 'euclidean': {}},
}


def _run_mess_variant(variant, M, store_init, **step_kwargs):
    """Run one MESS chain and record its outputs in ``results``.

    Parameters
    ----------
    variant : str
        Results key: 'uniform', 'angular' or 'euclidean'.
    M : int
        Number of proposals forwarded to ``mess_step``.
    store_init : bool
        If True, row 0 of the stored chain is ``c_init`` (n_iters + 1 rows);
        otherwise only the post-step states are stored (n_iters rows).
    **step_kwargs
        Extra keyword arguments forwarded to ``mess_step``.

    Records the chain, the per-iteration interval counts, the elapsed time
    (appended to the variant's ``times`` list), post-burn-in interval
    statistics, and the RMSE of the post-burn-in posterior mean vs. ``c_true``.
    """
    # Fresh, identically seeded RNG per run so variants are comparable.
    rng = np.random.default_rng(seed)
    offset = 1 if store_init else 0
    chain = np.zeros((n_iters + offset, problem.dim))
    if store_init:
        chain[0] = c_init
    intervals = np.zeros(n_iters, dtype=int)
    x = c_init.copy()

    # perf_counter is monotonic and high-resolution, suited to interval timing.
    t0 = time.perf_counter()
    for t in range(n_iters):
        x, nr_intervals, _ = mess_step(x, problem, rng, M=M, **step_kwargs)
        chain[t + offset] = x
        intervals[t] = nr_intervals
    elapsed = time.perf_counter() - t0

    store = results[variant]
    store['chains'][M] = chain
    store['intervals'][M] = intervals
    store['times'].append(elapsed)

    # Interval counts have one entry per iteration (no initial state).
    intervals_post = intervals[burn_in:]
    store['statistics'][M] = {
        'mean_intervals': np.mean(intervals_post),
        'std_intervals': np.std(intervals_post),
        'median_intervals': np.median(intervals_post),
    }

    # RMSE of the posterior mean; `offset` skips the stored c_init row.
    c_mean = np.mean(chain[burn_in + offset:], axis=0)
    rmse = np.sqrt(np.mean((c_mean - c_true) ** 2))
    results['rmse'][variant][M] = rmse

    print(f"    Time: {elapsed:.2f}s, Mean intervals: {np.mean(intervals_post):.4f}, RMSE: {rmse:.6f}")


# (results key, header, store initial state, skip M == 1, mess_step kwargs)
_MESS_VARIANTS = (
    ('uniform', "Running MESS Uniform...", False, False,
     {'use_lp': False}),
    ('angular', "\nRunning MESS Angular...", True, True,
     {'use_lp': True, 'distance_metric': 'angular', 'lam': 0.05}),
    ('euclidean', "\nRunning MESS Euclidean...", True, True,
     {'use_lp': True, 'distance_metric': 'euclidean', 'lam': 0.05}),
)

for variant, header, store_init, skip_m1, step_kwargs in _MESS_VARIANTS:
    print(header)
    for M in M_values:
        if skip_m1 and M == 1:
            continue
        print(f"  M = {M}")
        _run_mess_variant(variant, M, store_init, **step_kwargs)

print("\nAll experiments completed!")

## Compute ESS and MSJD

In [None]:
max_lag = 50

print("\nComputing ESS and MSJD")

# First post-burn-in row of each stored chain: the angular/euclidean chains
# carry c_init in row 0, the uniform chain does not.
_post_burn_in_start = {
    'uniform': burn_in,
    'angular': burn_in + 1,
    'euclidean': burn_in + 1,
}

# ESS and MSJD on the post-burn-in samples, one variant at a time.
for _variant, _start in _post_burn_in_start.items():
    for M in M_values:
        if _variant != 'uniform' and M == 1:
            continue  # angular/euclidean runs do not exist for M == 1
        _post = results[_variant]['chains'][M][_start:, :]
        results['ess'][_variant][M] = estimate_effective_sample_size(_post, max_lag=max_lag)
        results['msjd'][_variant][M] = compute_mean_squared_jumping_distance(_post)


def _log_likelihood_trace(rows):
    """Evaluate the problem's log-likelihood at every row (sample) of ``rows``."""
    return np.array([problem.log_likelihood(row) for row in rows])


# Full log-likelihood traces (including burn-in, and c_init where stored).
for M in M_values:
    results['log_likelihood']['uniform'][M] = _log_likelihood_trace(
        results['uniform']['chains'][M][:n_iters]
    )
    if not M > 1:
        continue
    for _variant in ('angular', 'euclidean'):
        results['log_likelihood'][_variant][M] = _log_likelihood_trace(
            results[_variant]['chains'][M][:n_iters + 1]
        )

print("ESS, MSJD, and log-likelihood computation completed!")

## Figure 1: Effective Sample Size and MSJD Boxplots

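Compare the distributions of effective sample size (ESS) and mean squared jumping distance (MSJD) across the MESS variants (Uniform, Angular, Euclidean) for each value of M. A group labelled "M=1 (ESS)" contains only the uniform sampler, which reduces to elliptical slice sampling.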

In [None]:
def _grouped_boxplot(ax, M_box, uniform_data, angular_data, euclidean_data,
                     ylabel, title):
    """Draw one group of boxplots per M value, comparing the MESS variants.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    M_box : list
        M values; one group is drawn per entry.
    uniform_data : list
        One dataset per entry of ``M_box``.
    angular_data, euclidean_data : list
        One dataset per entry of ``M_box`` with ``M > 1`` (these variants are
        not run for M = 1).
    ylabel, title : str
        Y-axis label and subplot title.

    Within a group the boxes are placed in consecutive slots (uniform, then
    angular and euclidean when M > 1), with one empty slot between groups.
    Each group's x tick sits at the mean position of its boxes.
    """
    series = (
        ('Uniform', uniform_data, '#1f77b4'),
        ('Angular', angular_data, '#ff7f0e'),
        ('Euclidean', euclidean_data, '#2ca02c'),
    )
    positions = {label: [] for label, _, _ in series}
    x_ticks, x_labels = [], []

    pos = 0
    for M in M_box:
        group = [pos]
        positions['Uniform'].append(pos)
        pos += 1
        if M > 1:
            for label in ('Angular', 'Euclidean'):
                positions[label].append(pos)
                group.append(pos)
                pos += 1
        pos += 1  # empty slot between groups

        x_ticks.append(sum(group) / len(group))
        # A lone uniform box at M = 1 is plain ESS (elliptical slice sampling).
        x_labels.append(f"M={int(M)}\n(ESS)" if M == 1 else str(int(M)))

    for label, data, color in series:
        if not data:
            # Skip empty series (e.g. only M = 1 was run) instead of calling
            # boxplot with no data and no positions.
            continue
        # NOTE(review): `label=` on boxplot requires a recent matplotlib
        # (added around 3.9) — confirm the installed version supports it.
        bp = ax.boxplot(data, positions=positions[label], widths=0.6,
                        patch_artist=True, label=label)
        for patch in bp['boxes']:
            patch.set_facecolor(color)
            patch.set_alpha(0.7)

    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_labels)
    ax.set_xlabel('M (Number of Proposals)', fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.legend(fontsize=10, loc='best')
    ax.grid(True, alpha=0.3, axis='y')


# Create 1x2 figure: ESS (left) and MSJD (right)
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
fig.suptitle('Semi-Blind Deconvolution: Eff. Sample Size and MSJD Comparison', fontsize=16, fontweight='bold')

# Prepare boxplot data; angular/euclidean results only exist for M > 1.
M_box = list(M_values)
uniform_ess_box = [results['ess']['uniform'][M] for M in M_box]
angular_ess_box = [results['ess']['angular'][M] for M in M_box if M > 1]
euclidean_ess_box = [results['ess']['euclidean'][M] for M in M_box if M > 1]
uniform_msjd_box = [results['msjd']['uniform'][M] for M in M_box]
angular_msjd_box = [results['msjd']['angular'][M] for M in M_box if M > 1]
euclidean_msjd_box = [results['msjd']['euclidean'][M] for M in M_box if M > 1]

# ===== ESS BOXPLOT (LEFT) =====
_grouped_boxplot(axes[0], M_box, uniform_ess_box, angular_ess_box, euclidean_ess_box,
                 ylabel='Effective Sample Size (ESS)', title='ESS')

# ===== MSJD BOXPLOT (RIGHT) =====
_grouped_boxplot(axes[1], M_box, uniform_msjd_box, angular_msjd_box, euclidean_msjd_box,
                 ylabel='Mean Squared Jumping Distance (MSJD)', title='MSJD')

plt.tight_layout()
plt.show()

## Figure 2: Shrinking Steps and Computation Time

In [None]:
def _grouped_boxplot(ax, M_box, uniform_data, angular_data, euclidean_data,
                     ylabel, title):
    """Draw one group of boxplots per M value, comparing the MESS variants.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    M_box : list
        M values; one group is drawn per entry.
    uniform_data : list
        One dataset per entry of ``M_box``.
    angular_data, euclidean_data : list
        One dataset per entry of ``M_box`` with ``M > 1`` (these variants are
        not run for M = 1).
    ylabel, title : str
        Y-axis label and subplot title.

    Within a group the boxes are placed in consecutive slots (uniform, then
    angular and euclidean when M > 1), with one empty slot between groups.
    Each group's x tick sits at the mean position of its boxes.
    """
    series = (
        ('Uniform', uniform_data, '#1f77b4'),
        ('Angular', angular_data, '#ff7f0e'),
        ('Euclidean', euclidean_data, '#2ca02c'),
    )
    positions = {label: [] for label, _, _ in series}
    x_ticks, x_labels = [], []

    pos = 0
    for M in M_box:
        group = [pos]
        positions['Uniform'].append(pos)
        pos += 1
        if M > 1:
            for label in ('Angular', 'Euclidean'):
                positions[label].append(pos)
                group.append(pos)
                pos += 1
        pos += 1  # empty slot between groups

        x_ticks.append(sum(group) / len(group))
        # A lone uniform box at M = 1 is plain ESS (elliptical slice sampling).
        x_labels.append(f"M={int(M)}\n(ESS)" if M == 1 else str(int(M)))

    for label, data, color in series:
        if not data:
            # Skip empty series (e.g. only M = 1 was run) instead of calling
            # boxplot with no data and no positions.
            continue
        # NOTE(review): `label=` on boxplot requires a recent matplotlib
        # (added around 3.9) — confirm the installed version supports it.
        bp = ax.boxplot(data, positions=positions[label], widths=0.6,
                        patch_artist=True, label=label)
        for patch in bp['boxes']:
            patch.set_facecolor(color)
            patch.set_alpha(0.7)

    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_labels)
    ax.set_xlabel('M (Number of Proposals)', fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.legend(fontsize=10, loc='best')
    ax.grid(True, alpha=0.3, axis='y')


# Create 1x2 figure: Intervals and Time
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
fig.suptitle('Semi-Blind Deconvolution: Shrinking Steps and Computation Time', fontsize=16, fontweight='bold')

# Interval counts have one entry per iteration for every variant (no initial
# state is stored), so the same burn-in slice applies to all three.
M_box = list(M_values)
uniform_intervals_box = [results['uniform']['intervals'][M][burn_in:] for M in M_box]
angular_intervals_box = [results['angular']['intervals'][M][burn_in:] for M in M_box if M > 1]
euclidean_intervals_box = [results['euclidean']['intervals'][M][burn_in:] for M in M_box if M > 1]

# ===== SHRINKING STEPS BOXPLOT (LEFT) =====
_grouped_boxplot(axes[0], M_box, uniform_intervals_box, angular_intervals_box,
                 euclidean_intervals_box,
                 ylabel='Number of Shrinking Steps', title='Shrinking Steps')

# ===== COMPUTATION TIME PLOT (RIGHT) =====
ax = axes[1]

# Timing lists are appended once per run in M_values order; the angular and
# euclidean runs skip M == 1, so pair their times with the M values actually run.
M_uniform = list(M_values)
M_lp = [M for M in M_values if M != 1]

ax.plot(M_uniform, results['uniform']['times'], 'o-', label='Uniform', linewidth=2, markersize=8, color='#1f77b4')
ax.plot(M_lp, results['angular']['times'], 's-', label='Angular', linewidth=2, markersize=8, color='#ff7f0e')
ax.plot(M_lp, results['euclidean']['times'], '^-', label='Euclidean', linewidth=2, markersize=8, color='#2ca02c')

ax.set_xlabel('M (Number of Proposals)', fontsize=12, fontweight='bold')
ax.set_ylabel('Computation Time (seconds)', fontsize=12, fontweight='bold')
ax.set_title('Computation Time', fontsize=13, fontweight='bold')
ax.legend(fontsize=10, loc='best')
ax.grid(True, alpha=0.3)
ax.set_xscale('log')
ax.set_yscale('log')

plt.tight_layout()
plt.show()

## Figure 3: Log-Likelihood Trace Plots

Compare how the log-likelihood evolves over the iterations for each MESS variant; the red dashed line marks the log-likelihood of the true image.

In [None]:
# Select one M value to compare (fall back to the last M run if 10 was not run)
M_compare = 10 if 10 in M_values else M_values[-1]

ll_true = problem.log_likelihood(c_true)
ll_traces = results['log_likelihood']

# Put every trace on a common iteration axis. Uniform row t is the state after
# iteration t + 1 (its chain stores no initial state); angular/euclidean row t
# is the state after iteration t (row 0 is c_init).
ll_uniform = ll_traces['uniform'][M_compare]
traces = [('Uniform', np.arange(1, len(ll_uniform) + 1), ll_uniform)]
if M_compare > 1:
    for key, label in (('angular', 'Angular'), ('euclidean', 'Euclidean')):
        ll = ll_traces[key][M_compare]
        traces.append((label, np.arange(len(ll)), ll))

fig, axes = plt.subplots(1, 2, figsize=(16, 5))
fig.suptitle(f'Log-Likelihood Traces (M={M_compare})', fontsize=16, fontweight='bold')

# Left: Full trace
ax = axes[0]
for label, iters, ll in traces:
    ax.plot(iters, ll, label=label, alpha=0.7, linewidth=1)
ax.axhline(ll_true, color='red', linestyle='--', linewidth=2, label='True')
ax.axvline(burn_in, color='black', linestyle=':', linewidth=1, alpha=0.5)
ax.set_xlabel('Iteration', fontsize=12)
ax.set_ylabel('Log-Likelihood', fontsize=12)
ax.set_title('Full Trace', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# Right: Post burn-in (iterations after burn_in, the samples used for ESS/RMSE)
ax = axes[1]
for label, iters, ll in traces:
    ax.plot(ll[iters > burn_in], label=label, alpha=0.7, linewidth=1)
ax.axhline(ll_true, color='red', linestyle='--', linewidth=2, label='True')
ax.set_xlabel('Iteration (post burn-in)', fontsize=12)
ax.set_ylabel('Log-Likelihood', fontsize=12)
ax.set_title('Post Burn-in', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Summary Statistics Table

In [None]:
# Print summary table


def _print_summary_row(method, key, M, time_s):
    """Print one formatted summary-table row for results variant ``key`` at ``M``."""
    mean_int = results[key]['statistics'][M]['mean_intervals']
    ess_vals = results['ess'][key][M]
    msjd_vals = results['msjd'][key][M]
    rmse = results['rmse'][key][M]
    print(f"{method:<12} {M:<6} {mean_int:<12.4f} {np.min(ess_vals):<12.2f} {np.mean(ess_vals):<12.2f} {np.mean(msjd_vals):<12.6f} {rmse:<12.6f} {time_s:<12.2f}")


print("=" * 100)
print(f"{'Method':<12} {'M':<6} {'Mean Int.':<12} {'ESS (min)':<12} {'ESS (mean)':<12} {'MSJD (mean)':<12} {'RMSE':<12} {'Time (s)':<12}")
print("=" * 100)

# Timing lists are appended once per run in M_values order. The angular and
# euclidean runs skip M == 1, so their timings are indexed by a separate
# counter of runs actually made, not by the position in M_values. (Using
# i - 1 is only correct when M = 1 is the first entry of M_values.)
lp_run = 0
for i, M in enumerate(M_values):
    _print_summary_row('Uniform', 'uniform', M, results['uniform']['times'][i])

    if M == 1:
        continue  # no angular/euclidean run (or timing) exists for M == 1

    if M > 1:
        _print_summary_row('Angular', 'angular', M, results['angular']['times'][lp_run])
        _print_summary_row('Euclidean', 'euclidean', M, results['euclidean']['times'][lp_run])
        print("-" * 100)
    lp_run += 1

print("=" * 100)