# Toy Model for Solute Transport: Dimension and M Sweep (Shared Draws)

This notebook runs uniform MESS (for several values of M) and Metropolis–Hastings (MH) across dimensions d. The random quantities (prior sample, initial state, and observation noise) are drawn once at d_max, and each smaller d reuses a subset of that draw.

Notes:
- Chains are saved under estimations/AD_toy_dim_M_sweep_shared_draws/<run_tag>, and figures under reports/AD_toy_dim_M_sweep_shared_draws/<run_tag>.
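- In `prior` mode (use_prior_A = True), a_true and theta_true come from the single d_max draw and are subset for each d; in `nearest_neighbor` mode they are rebuilt at each d. The observation noise and the initial state always come from the d_max draw and are subset for each d.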
- Set d_max >= max(d_list).

## Imports

In [None]:
import os
import sys
import time
import zipfile
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

# Make the repository's `src/` directory importable; the notebook is assumed
# to live one directory below the repository root.
repo_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
src_path = os.path.join(repo_root, 'src')
if not any(entry == src_path for entry in sys.path):
    sys.path.insert(0, src_path)

from mess.problems.advection_diffusion import (
    make_omegas_power,
    make_Astar_nn,
    make_Astar_from_atrue,
    params_from_skew,
    prior_diag_from_powerlaw,
    solve_theta,
 )
from mess.problems.advection_diffusion import AdvectionDiffusionToy
from mess.algorithms.mess import mess_step
from mess.algorithms.mh import mh_chain
from mess.algorithms.effective_sample_size import estimate_effective_sample_size

## Configuration

In [None]:
# --- Sweep configuration ----------------------------------------------------
# Independent seeds for data generation and for the MCMC generators.
seed_data = 0
seed_mcmc = 0
# Chain length, burn-in (in iterations), thinning stride, and max ESS lag.
n_iters = 30000
burn_in = 10000
thin = 1
max_lag = 1500

# Dimensions to sweep, and the dimension at which the shared draws are made.
d_list = [10]  # full sweep: [10, 15, 20, 25, 30, 35, 40, 45, 50]
d_max = 100
# MESS ensemble sizes.
M_list = [10]  # full sweep: [1, 10, 50, 100] (earlier variants also used 200)

# --- Data hyperparameters ---------------------------------------------------
kappa = 0.02
sigma = 0.5
alpha = 3
gamma = 2
tau2 = 2.0
a_mode = 'nearest_neighbor'
# When True, the shared draws use a_mode='prior' and `a_mode` above is ignored.
use_prior_A = True
shared_draws_seed = seed_data

# Every swept dimension must be a sub-problem of the d_max draw.
if d_max < max(d_list):
    raise ValueError('d_max must be >= max(d_list)')

# Observation configuration: get_obs_indices observes the `obs_bandwidth`
# consecutive indices ending at `obs_highest_freq` (clipped to the dimension).
# `obs_config` is only a label that becomes part of run_tag.
obs_highest_freq = 6
obs_bandwidth = 3
obs_config = "central_modes"

# MH proposal covariance ("isotropic" or "prior").
mh_proposal_cov = "prior"
mh_proposal_isotropic_std = 0.000018
mh_proposal_prior_std = 0.105
if mh_proposal_cov == "isotropic":
    mh_proposal_std_chosen = mh_proposal_isotropic_std
elif mh_proposal_cov == "prior":
    mh_proposal_std_chosen = mh_proposal_prior_std
else:
    # Fail here with a clear message instead of a NameError on
    # mh_proposal_std_chosen a few lines below.
    raise ValueError(
        f"mh_proposal_cov must be 'isotropic' or 'prior', got {mh_proposal_cov!r}"
    )

# MH proposal std per d. Currently the same std (tuned for ~23.4% acceptance
# at d=20) is used for every d; replace np.repeat with an explicit list to
# use per-dimension values.
mh_proposal_stds_scaled = np.repeat(mh_proposal_std_chosen, len(d_list))
# Always true for np.repeat; kept as a guard for a hand-written per-d list.
if len(mh_proposal_stds_scaled) != len(d_list):
    raise ValueError('mh_proposal_stds_scaled must match d_list length')


# Run flags: which samplers the sweep below runs. Setting both to False skips
# sampling entirely (only chains already on disk are used afterwards).
run_mess, run_mh = True, False
# Re-run (instead of skipping) chains whose .npz file exists but is unreadable.
recompute_corrupt_chains = True

# Dimension-level parallelism: a process pool is used only when use_parallel
# is True AND max_workers > 1; otherwise the sweep runs serially.
use_parallel = True
max_workers = 1  # Keep small to avoid overloading the machine.

# Cache of generated datasets, keyed by dimension d.
datasets_by_dim = dict()

# Output locations: chains under estimations/, figures/tables/metrics under
# reports/, each in a sub-directory named after the run configuration.
# NOTE(review): run_tag does not encode kappa, alpha, gamma, thin, or the
# observation window (obs_highest_freq / obs_bandwidth). The sweep skips
# existing chain files, so changing only those settings silently reuses
# chains from the previous configuration — confirm this is intended.
run_tag = (
    f"priorA{use_prior_A}_obs_{obs_config}_tau2{tau2}_sigma{sigma}"
    f"_seed{seed_data}_dmax{d_max}_Niters{n_iters}"
)
output_dir = Path(repo_root, 'estimations', "AD_toy_dim_M_sweep_shared_draws", run_tag)
reports_dir = Path(repo_root, 'reports', "AD_toy_dim_M_sweep_shared_draws", run_tag)
output_dir.mkdir(parents=True, exist_ok=True)
reports_dir.mkdir(parents=True, exist_ok=True)

print('Output dir:', output_dir)
print('Reports dir:', reports_dir)

## Helpers

In [None]:
def compute_msjd_per_param(chain):
    """Return the mean squared jump distance of every column of ``chain``.

    ``chain`` is indexed as (row, parameter). With fewer than two rows there
    are no jumps, so a zero vector with one entry per column is returned.
    """
    if len(chain) < 2:
        return np.zeros(chain.shape[1])
    return np.square(np.diff(chain, axis=0)).mean(axis=0)

def compute_ess_per_param(chain, max_lag):
    """Return the effective sample size of every column of ``chain``.

    Delegates to ``estimate_effective_sample_size`` (with ``max_lag``) and
    forces ESS to 0 for zero-variance columns. With fewer than two rows, or
    when every column is constant, zeros are returned without calling the
    estimator.
    """
    if chain.shape[0] < 2:
        return np.zeros(chain.shape[1])
    is_constant = np.var(chain, axis=0) == 0
    if np.all(is_constant):
        return np.zeros(chain.shape[1])
    ess = np.asarray(estimate_effective_sample_size(chain, max_lag=max_lag), dtype=float)
    ess[is_constant] = 0.0
    return ess

def chain_path(output_dir, d, alg, M=None, proposal_std=None, proposal_cov=None):
    """Return the .npz path under ``output_dir`` where a chain is stored.

    MH chains are named by their proposal std (``unknown`` when not given),
    plus a ``_cov<name>`` suffix for any covariance other than None or
    'isotropic'. Every other algorithm is named by its lower-cased key and M.
    """
    key = alg.lower()
    if key != 'mh':
        return output_dir / f'chain_d{d}_{key}_M{M}.npz'
    if proposal_std is None:
        return output_dir / f'chain_d{d}_mh_sigma2unknown.npz'
    suffix = '' if proposal_cov in (None, 'isotropic') else f'_cov{proposal_cov}'
    return output_dir / f'chain_d{d}_mh_sigma2{proposal_std:.6g}{suffix}.npz'

def load_chain(output_dir, d, alg, M=None, proposal_std=None, proposal_cov=None):
    """Load the ``chain`` array saved for (d, alg, M / MH proposal settings).

    When the exact MH file from ``chain_path`` is missing, fall back to the
    first (sorted) MH file for ``d``:

    - with ``proposal_std`` None, any proposal std is accepted;
    - otherwise only a file with exactly that std, with or without a
      ``_cov`` suffix.

    Returns None when no file is found or it cannot be read (a message is
    printed for unreadable files).
    """
    path = chain_path(output_dir, d, alg, M, proposal_std=proposal_std, proposal_cov=proposal_cov)
    if not path.exists() and alg.lower() == 'mh':
        if proposal_std is None:
            matches = sorted(output_dir.glob(f'chain_d{d}_mh_sigma2*.npz'))
        else:
            stem = f'chain_d{d}_mh_sigma2{proposal_std:.6g}'
            # A bare prefix glob would also match e.g. std 0.1055 when asked
            # for 0.105, so accept only the exact std with an optional _cov tag.
            matches = sorted(
                p for p in output_dir.glob(f'{stem}*.npz')
                if p.name == f'{stem}.npz' or p.name.startswith(f'{stem}_cov')
            )
        if matches:
            path = matches[0]
    if not path.exists():
        return None
    try:
        # The context manager closes the underlying zip file; indexing reads
        # the array into memory before that happens.
        with np.load(path) as data:
            return data['chain']
    except (zipfile.BadZipFile, ValueError, KeyError, EOFError, OSError) as exc:
        # EOFError: zero-byte file; OSError: unreadable/truncated file.
        print(f'Corrupt or unreadable chain file: {path.name} ({exc})')
        return None

def is_chain_readable(path):
    """Return True if ``path`` exists and its ``chain`` entry can be loaded.

    Failures typical of a missing, empty, truncated, or foreign file yield
    False instead of raising. These include a bad zip, a zero-byte file, an
    I/O error, non-npz content, and a missing ``chain`` key. The sweep can
    then decide whether to recompute the chain.
    """
    if not path.exists():
        return False
    try:
        with np.load(path) as data:
            _ = data['chain']
    except (zipfile.BadZipFile, ValueError, KeyError, EOFError, OSError):
        # np.load raises EOFError on a zero-byte file, e.g. one left behind
        # by a run that was killed while writing.
        return False
    return True

def save_chain(path, chain, metadata):
    """Save ``chain`` plus ``metadata`` entries to a compressed .npz file.

    As with ``np.savez_compressed``, ``.npz`` is appended when ``path`` lacks
    it. The archive is written to a sibling ``.tmp`` file first and then moved
    into place with ``os.replace``. An interrupted run therefore never leaves
    a truncated file under the final name.
    """
    target = os.fspath(path)
    if not target.endswith('.npz'):
        target += '.npz'
    tmp_target = target + '.tmp'
    try:
        # Writing through a file object stops numpy from appending '.npz'
        # to the temporary name.
        with open(tmp_target, 'wb') as fh:
            np.savez_compressed(fh, chain=chain, **metadata)
        os.replace(tmp_target, target)
    except BaseException:
        # Best-effort cleanup of the partial temp file; re-raise the original.
        try:
            os.remove(tmp_target)
        except OSError:
            pass
        raise

def get_obs_indices(dim_value, highest_freq, bandwidth):
    """Return the observed indices for a problem of size ``dim_value``.

    The window is the ``bandwidth`` consecutive indices ending at
    ``highest_freq``. Both are clipped so that every returned index is a
    valid 0-based position in a length-``dim_value`` vector.
    """
    # Indices are 0-based, so the largest valid index is dim_value - 1.
    # Clipping to dim_value would index one past the end of theta_true.
    highest_freq = min(highest_freq, dim_value - 1)
    bandwidth = min(bandwidth, dim_value)
    start = max(0, highest_freq - bandwidth + 1)
    return np.arange(start, highest_freq + 1, dtype=int)

def get_param_indices_for_dim(dim, shared_draws):
    """Return (and memoize) the flat d_max-parameter indices kept at ``dim``.

    ``shared_draws['param_iju']`` holds the (row, column) index arrays of the
    d_max upper triangle. The selected entries are those whose row and column
    are both below ``dim``, in their original order. Results are cached in
    ``shared_draws['param_indices_cache']``, which is created if absent.
    """
    memo = shared_draws.setdefault('param_indices_cache', {})
    try:
        return memo[dim]
    except KeyError:
        pass
    rows, cols = shared_draws['param_iju']
    memo[dim] = np.flatnonzero(np.logical_and(rows < dim, cols < dim))
    return memo[dim]

def build_shared_draws(
    d_max,
    kappa,
    sigma,
    alpha,
    gamma,
    tau2,
    offset,
    a_mode,
    seed,
 ):
    """Draw every random quantity once at dimension ``d_max``.

    Smaller dimensions reuse subsets of the returned arrays (see
    get_param_indices_for_dim and generate_advection_diffusion_data_shared).

    Args:
        d_max: Largest dimension. The parameter vector has
            m_max = d_max * (d_max - 1) // 2 entries, one per strict
            upper-triangle pair (i < j), matching ``np.triu_indices(d_max, k=1)``.
        kappa, sigma, alpha, gamma, tau2, offset: Hyperparameters, all stored
            in the returned dict. alpha/gamma/tau2/offset are passed to
            prior_diag_from_powerlaw; alpha/gamma/offset also go to
            make_omegas_power in 'nearest_neighbor' mode; kappa goes to
            solve_theta.
        a_mode: 'nearest_neighbor' (A_true built by make_Astar_nn, no random
            draws) or 'prior' (a_true drawn entry-wise with variance prior_diag).
        seed: Seed for ``np.random.default_rng``.

    Returns:
        dict with the hyperparameters, 'param_iju' (upper-triangle index
        arrays), an empty 'param_indices_cache', and the d_max-sized
        'prior_diag', 'a_true', 'A_true', 'g', 'theta_true', 'noise' and
        'a_init'.

    Raises:
        ValueError: if prior_diag_max does not have shape (m_max,) or a_mode
            is not one of the two supported modes.
    """
    rng = np.random.default_rng(seed)
    m_max = d_max * (d_max - 1) // 2
    prior_diag_max = prior_diag_from_powerlaw(
        d_max, alpha=alpha, gamma=gamma, tau2=tau2, offset=offset
    )
    if prior_diag_max.shape != (m_max,):
        raise ValueError(f'prior_diag_max must have shape ({m_max},), got {prior_diag_max.shape}')
    # NOTE: the sequence of rng draws (z_prior only in 'prior' mode, then noise,
    # then z_init) fixes the data for a given seed. Reordering it changes noise
    # and a_init, so saved chains would no longer match the regenerated data.
    if a_mode == 'nearest_neighbor':
        # Deterministic A_true: this branch consumes no random numbers.
        omegas = make_omegas_power(d_max, beta=alpha, c=2.0 ** (-gamma), offset=offset)
        A_true_max = make_Astar_nn(d_max, omegas)
        a_true_max = params_from_skew(A_true_max)
    elif a_mode == 'prior':
        # a_true ~ N(0, diag(prior_diag_max)), drawn entry-wise.
        z_prior = rng.standard_normal(m_max)
        a_true_max = z_prior * np.sqrt(prior_diag_max)
        A_true_max = make_Astar_from_atrue(d_max, a_true_max)
    else:
        raise ValueError("a_mode must be 'nearest_neighbor' or 'prior'")
    # g is the unit vector e_0 handed to solve_theta (presumably the
    # forcing/boundary term — confirm in solve_theta).
    g_max = np.zeros(d_max, dtype=float)
    g_max[0] = 1.0
    theta_true_max = solve_theta(d_max, a_true_max, g_max, kappa)
    # Standard-normal observation noise, one entry per position up to d_max.
    noise_max = rng.standard_normal(d_max)
    # Initial MCMC state, drawn from the same diagonal Gaussian prior.
    z_init = rng.standard_normal(m_max)
    a_init_max = z_init * np.sqrt(prior_diag_max)
    return {
        'd_max': d_max,
        'm_max': m_max,
        'kappa': kappa,
        'sigma': sigma,
        'alpha': alpha,
        'gamma': gamma,
        'tau2': tau2,
        'offset': offset,
        'a_mode': a_mode,
        'param_iju': np.triu_indices(d_max, k=1),
        'param_indices_cache': {},
        'prior_diag': prior_diag_max,
        'a_true': a_true_max,
        'A_true': A_true_max,
        'g': g_max,
        'theta_true': theta_true_max,
        'noise': noise_max,
        'a_init': a_init_max,
    }

In [None]:
def generate_advection_diffusion_data_shared(dim, obs_indices, shared_draws):
    """Build the dataset for dimension ``dim`` from the d_max shared draws.

    Parameter-indexed arrays are restricted to the upper-triangle pairs
    (i, j) with i, j < dim. These are prior_diag, a_init, and a_true in
    'prior' mode. Vector draws are truncated to their first ``dim`` entries.
    These are g, noise, and theta_true in 'prior' mode. In
    'nearest_neighbor' mode A_true, a_true and theta_true are instead
    rebuilt at size ``dim``.

    Observations are ``y = theta_true[obs_indices] + sigma * noise[obs_indices]``.

    NOTE(review): in 'prior' mode theta_true is the first ``dim`` entries of
    the d_max solution, not ``solve_theta(dim, a_true, g, kappa)``. This
    matches the notebook's "subset for each d" design, but the data are then
    not generated by the dim-sized model. Confirm this is intended.

    Raises:
        ValueError: if ``shared_draws['a_mode']`` is not a supported mode.
    """
    mode = shared_draws['a_mode']
    keep = get_param_indices_for_dim(dim, shared_draws)
    prior_diag_d = shared_draws['prior_diag'][keep]
    g_d = shared_draws['g'][:dim]

    if mode == 'prior':
        a_true = shared_draws['a_true'][keep]
        A_true = make_Astar_from_atrue(dim, a_true)
        theta_true = shared_draws['theta_true'][:dim]
    elif mode == 'nearest_neighbor':
        omegas_d = make_omegas_power(
            dim,
            beta=shared_draws['alpha'],
            c=2.0 ** (-shared_draws['gamma']),
            offset=shared_draws['offset'],
        )
        A_true = make_Astar_nn(dim, omegas_d)
        a_true = params_from_skew(A_true)
        theta_true = solve_theta(dim, a_true, g_d, shared_draws['kappa'])
    else:
        raise ValueError("a_mode must be 'nearest_neighbor' or 'prior'")

    noise_d = shared_draws['noise'][:dim]
    y = theta_true[obs_indices] + shared_draws['sigma'] * noise_d[obs_indices]

    return dict(
        dim=dim,
        kappa=shared_draws['kappa'],
        alpha=shared_draws['alpha'],
        gamma=shared_draws['gamma'],
        tau2=shared_draws['tau2'],
        sigma=shared_draws['sigma'],
        obs_indices=obs_indices,
        prior_diag=prior_diag_d,
        a_true=a_true,
        A_true=A_true,
        g=g_d,
        theta_true=theta_true,
        y=y,
        a_init=shared_draws['a_init'][keep],
    )

def get_dataset_for_dim(d, seed=0):
    """Return the dataset for dimension ``d``, building and caching it once.

    The dataset is generated from the module-level ``shared_draws``, using
    the window from get_obs_indices(d, obs_highest_freq, obs_bandwidth). It
    is memoized in ``datasets_by_dim``. ``seed`` is accepted for call-site
    compatibility but unused; all randomness comes from ``shared_draws``.
    """
    if d not in datasets_by_dim:
        window = get_obs_indices(d, obs_highest_freq, obs_bandwidth)
        dataset = generate_advection_diffusion_data_shared(d, window, shared_draws)
        dataset['obs_indices'] = window
        datasets_by_dim[d] = dataset
    return datasets_by_dim[d]

def build_problem_for_dim(d, seed=0):
    """Build the AdvectionDiffusionToy posterior for dimension ``d``.

    Returns ``(problem, a_init, obs_indices, data)``. kappa and sigma are read
    from the dataset rather than from the notebook globals. The likelihood
    therefore matches the values that generated ``y``, even when the
    configuration cell is re-run with new kappa/sigma while the old
    ``shared_draws`` is kept. ``seed`` is forwarded to get_dataset_for_dim.
    """
    data = get_dataset_for_dim(d, seed=seed)
    obs_indices = data['obs_indices']
    problem = AdvectionDiffusionToy(
        dim=d,
        kappa=data['kappa'],
        sigma=data['sigma'],
        y=data['y'],
        obs_indices=obs_indices,
        g=data['g'],
        prior_diag=data['prior_diag'],
    )
    return problem, data['a_init'], obs_indices, data

In [None]:
# Draw all random quantities once at d_max; every d in the sweep reuses
# subsets of these draws. a_mode is forced to 'prior' when use_prior_A is True.
shared_draws = build_shared_draws(
    d_max=d_max,
    kappa=kappa,
    sigma=sigma,
    alpha=alpha,
    gamma=gamma,
    tau2=tau2,
    offset=1.0,
    a_mode='prior' if use_prior_A else a_mode,
    seed=shared_draws_seed,
)
# Datasets cached by get_dataset_for_dim came from any previous shared draws;
# drop them so re-running this cell cannot mix old data with new draws.
datasets_by_dim.clear()

In [None]:
# Visual check: A_true (top row) and theta_true with the noisy observations y
# (bottom row) for each dimension in plot_dims.
plot_dims = [10, 20, 40]
n_cols = len(plot_dims)
# squeeze=False always returns a (2, n_cols) axes grid. With the default
# squeezing, a single column would give shape (2,), and np.atleast_2d would
# turn that into (1, 2), breaking axes[1, col_idx].
fig, axes = plt.subplots(2, n_cols, figsize=(12, 6), squeeze=False)

row_label_size = 14
title_size = 13
tick_size = 11
axis_label_size = 12
cbar_tick_size = 11

last_im = None
for col_idx, d_cur in enumerate(plot_dims):
    data = get_dataset_for_dim(d_cur, seed=seed_data)

    ax_A = axes[0, col_idx]
    ax_obs = axes[1, col_idx]

    last_im = ax_A.imshow(data['A_true'], cmap='coolwarm', aspect='auto')
    if col_idx == 0:
        ax_A.set_ylabel('i', fontsize=axis_label_size)
    ax_A.set_xlabel('j', fontsize=axis_label_size)
    ax_A.tick_params(axis='both', labelsize=tick_size)

    theta_true = data['theta_true']
    obs_indices = data['obs_indices']
    ax_obs.plot(
        np.arange(d_cur),
        theta_true,
        color='tab:blue',
        label=r"$\mathbf{\Theta}(\mathbf{A})$",
    )
    ax_obs.scatter(obs_indices, data['y'], color='tab:orange', s=20, label=r"$y$")
    if col_idx == 0:
        ax_obs.set_ylabel('Value', fontsize=axis_label_size)
    ax_obs.set_xlabel('i', fontsize=axis_label_size)
    ax_obs.grid(alpha=0.2)
    ax_obs.legend(loc='best', fontsize=tick_size)
    ax_obs.tick_params(axis='both', labelsize=tick_size)

    ax_A.set_title(rf"$d={d_cur}$", fontsize=title_size)

if last_im is not None:
    cbar = fig.colorbar(last_im, ax=axes[0, -1], fraction=0.05, pad=0.02)
    cbar.ax.tick_params(labelsize=cbar_tick_size)

fig.text(0.01, 0.73, r"$\mathbf{A}$", rotation=0, va='center', ha='left', fontsize=row_label_size)
fig.text(
    0.01,
    0.27,
    "Obs.",
    rotation=0,
    va='center',
    ha='left',
    fontsize=row_label_size,
)
fig.tight_layout(rect=[0.05, 0.02, 0.95, 0.98])

fig_path = reports_dir / "visual_check_A_theta_y.png"
fig.savefig(fig_path, dpi=600, bbox_inches="tight")
print(f"Saved {fig_path}")

plt.show()

## Tune MH

In [None]:
# # Automated MH tuning (d=20).
# tune_d = 20
# tune_iters = 30000
# target_accept = 0.234
# sigma_center = 0.05
# grid_factors = [2.1, 2.2, 2.3, 2.4, 2.5] # [0.2, 0.35, 0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0]
# sigma_grid = [sigma_center * f for f in grid_factors]

# problem_cur, x0_cur, obs_indices_cur, _ = build_problem_for_dim(tune_d, seed=seed_data)
# print(f'Tuning MH at d={tune_d} with obs_indices={obs_indices_cur.tolist()}')

# results = []
# for proposal_std in sigma_grid:
#     rng = np.random.default_rng(seed_mcmc)
#     chain_mh, acc = mh_chain(
#         x0_cur, problem_cur, rng, tune_iters, proposal_std=proposal_std, proposal_cov=mh_proposal_cov
#     )
#     results.append((proposal_std, acc))
#     print(f'std={proposal_std:.8g} | acc={acc:.4f}')

# best_std, best_acc = min(results, key=lambda item: abs(item[1] - target_accept))
# print('\nBest proposal std (closest to target acceptance):')
# print(f'std={best_std:.8g} | acc={best_acc:.4f}')

In [None]:
# # Profiling: break down MH step costs by component.
# # Adjust dims_to_profile and n_reps for quick checks.
# import time
# import math

# dims_to_profile = [10, 20, 30, 40]
# n_reps = 200
# proposal_cov_to_profile = mh_proposal_cov  # 'prior' or 'isotropic'

# def _timed(fn, n_calls=1):
#     t0 = time.perf_counter()
#     out = None
#     for _ in range(n_calls):
#         out = fn()
#     t1 = time.perf_counter()
#     return (t1 - t0) / max(1, n_calls), out

# def _profile_dim(d_cur):
#     problem_cur, x0_cur, _, _ = build_problem_for_dim(d_cur, seed=seed_data)
#     rng = np.random.default_rng(seed_mcmc)
#     x = x0_cur.copy()

#     # Proposal generation cost.
#     def _proposal_only():
#         if proposal_cov_to_profile == "prior":
#             z = rng.standard_normal(problem_cur.dim)
#             return x + mh_proposal_std_chosen * (problem_cur.L @ z)
#         return x + mh_proposal_std_chosen * rng.standard_normal(problem_cur.dim)

#     # Log-density components.
#     def _log_prior():
#         return problem_cur.log_prior(x)

#     def _log_likelihood():
#         return problem_cur.log_likelihood(x)

#     def _log_posterior():
#         return problem_cur.log_posterior(x)

#     # Theta solve only (often dominant).
#     def _theta_solve():
#         _ = problem_cur.theta_from_params(x)
#         return None

#     proposal_t, _ = _timed(_proposal_only, n_reps)
#     prior_t, _ = _timed(_log_prior, n_reps)
#     like_t, _ = _timed(_log_likelihood, n_reps)
#     post_t, _ = _timed(_log_posterior, n_reps)
#     theta_t, _ = _timed(_theta_solve, n_reps)

#     # Rough per-MH-step estimate: proposal + 2 * log_posterior evaluations.
#     approx_step = proposal_t + 2.0 * post_t

#     return {
#         "d": d_cur,
#         "proposal_ms": 1e3 * proposal_t,
#         "log_prior_ms": 1e3 * prior_t,
#         "log_like_ms": 1e3 * like_t,
#         "log_post_ms": 1e3 * post_t,
#         "theta_solve_ms": 1e3 * theta_t,
#         "approx_step_ms": 1e3 * approx_step,
#     }

# print("MH timing breakdown (per-call milliseconds):")
# for d_cur in dims_to_profile:
#     stats = _profile_dim(d_cur)
#     print(stats)

## Run Sweep and Save Chains

In [None]:
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat

def _should_compute_chain(out_path, recompute_corrupt_chains):
    """Return True if the chain at ``out_path`` must be (re)computed.

    - Missing files are computed.
    - Readable files are skipped.
    - Unreadable files are recomputed only when ``recompute_corrupt_chains``
      is True.

    The file is checked at most once, and the reason for skipping or
    recomputing is printed.
    """
    if not out_path.exists():
        return True
    if is_chain_readable(out_path):
        print(f'Skip existing: {out_path.name}')
        return False
    if recompute_corrupt_chains:
        print(f'Recomputing corrupt chain: {out_path.name}')
        return True
    print(f'Corrupt chain found, skipping: {out_path.name}')
    return False


def _run_dim_task(d_idx, d_cur, config):
    """Run (or skip) every requested chain for the single dimension ``d_cur``.

    ``d_idx`` is the position of ``d_cur`` in d_list; it selects the MH
    proposal std. ``config`` carries the sweep settings:

    - seeds, n_iters_dim, burn_in, thin;
    - M_list and the run flags;
    - output_dir (as str) and the MH proposal settings;
    - recompute_corrupt_chains.

    Chains start from the dataset's a_init and are thinned by ``thin``
    before saving. Burn-in is only recorded in the metadata; it is not
    removed. Existing readable chains are skipped.
    """
    seed_mcmc = config['seed_mcmc']
    seed_data = config['seed_data']
    n_iters_dim = config['n_iters_dim']
    burn_in_cur = config['burn_in']
    thin = config['thin']
    M_list = config['M_list']
    run_mess = config['run_mess']
    run_mh = config['run_mh']
    output_dir = Path(config['output_dir'])
    mh_proposal_stds_scaled = config['mh_proposal_stds_scaled']
    mh_proposal_cov = config['mh_proposal_cov']
    recompute_corrupt_chains = config['recompute_corrupt_chains']

    problem_cur, x0_cur, obs_indices_cur, _ = build_problem_for_dim(d_cur, seed=seed_data)
    print(f'--- d={d_cur} | burn_in={burn_in_cur} | obs_indices={obs_indices_cur.tolist()}')

    # MESS sweep
    if run_mess:
        for M in M_list:
            out_path = chain_path(output_dir, d_cur, 'mess', M=M)
            if not _should_compute_chain(out_path, recompute_corrupt_chains):
                continue
            print(f'\nStart MESS chain: d={d_cur}, M={M}, n_iters={n_iters_dim}, burn_in={burn_in_cur}, thin={thin}')
            # Same seed for every M so the chains differ only by M.
            rng = np.random.default_rng(seed_mcmc)
            chain = np.zeros((n_iters_dim + 1, x0_cur.shape[0]))
            chain[0] = x0_cur  # row 0 is the initial state
            x = x0_cur.copy()
            t0 = time.perf_counter()
            for t in range(n_iters_dim):
                x, _, _ = mess_step(x, problem_cur, rng, M=M, use_lp=False)
                chain[t + 1] = x
            t1 = time.perf_counter()
            post = chain[::thin]
            metadata = {
                'alg': 'mess',
                'M': M,
                'd': d_cur,
                'n_iters': n_iters_dim,
                'burn_in': burn_in_cur,
                'thin': thin,
                'seed_mcmc': seed_mcmc,
                'seed_data': seed_data,
                'runtime_sec': t1 - t0,
            }
            save_chain(out_path, post, metadata)
            print(f'Saved {out_path.name}')

    # MH sweep
    if run_mh:
        mh_std = mh_proposal_stds_scaled[d_idx]
        out_path = chain_path(
            output_dir,
            d_cur,
            'mh',
            proposal_std=mh_std,
            proposal_cov=mh_proposal_cov,
        )
        if not _should_compute_chain(out_path, recompute_corrupt_chains):
            return
        print(
            f'\nStart MH chain: d={d_cur}, n_iters={n_iters_dim}, burn_in={burn_in_cur}, thin={thin}, proposal_std={mh_std}, proposal_cov={mh_proposal_cov}'
        )
        rng = np.random.default_rng(seed_mcmc)
        t0 = time.perf_counter()
        chain_mh, acc = mh_chain(
            x0_cur,
            problem_cur,
            rng,
            n_iters_dim,
            proposal_std=mh_std,
            proposal_cov=mh_proposal_cov,
        )
        t1 = time.perf_counter()
        post = chain_mh[::thin]
        metadata = {
            'alg': 'mh',
            'd': d_cur,
            'n_iters': n_iters_dim,
            'burn_in': burn_in_cur,
            'thin': thin,
            'seed_mcmc': seed_mcmc,
            'seed_data': seed_data,
            'proposal_std': mh_std,
            'proposal_cov': mh_proposal_cov,
            'acceptance': acc,
            'runtime_sec': t1 - t0,
        }
        save_chain(out_path, post, metadata)
        print(f'Saved {out_path.name}')

# Sweep driver: run every dimension in d_list, in a fork-based process pool
# when use_parallel and max_workers > 1, otherwise serially.
if not run_mess and not run_mh:
    print('Set run_mess and/or run_mh to True to generate chains.')
else:
    # High-level run context.
    print('Starting sweep with config:')
    print({
        'seed_mcmc': seed_mcmc,
        'seed_data': seed_data,
        'n_iters_dim': n_iters,
        'burn_in': burn_in,
        'thin': thin,
        'max_lag': max_lag,
        'd_list': d_list,
        'M_list': M_list,
        'output_dir': str(output_dir),
        'run_mess': run_mess,
        'run_mh': run_mh,
        'mh_proposal_cov': mh_proposal_cov,
        'use_parallel': use_parallel,
        'max_workers': max_workers,
        'recompute_corrupt_chains': recompute_corrupt_chains,
    })

    # Picklable settings passed to each worker.
    config = {
        'seed_mcmc': seed_mcmc,
        'seed_data': seed_data,
        'n_iters_dim': n_iters,
        'burn_in': burn_in,
        'thin': thin,
        'M_list': M_list,
        'run_mess': run_mess,
        'run_mh': run_mh,
        'output_dir': str(output_dir),
        'mh_proposal_stds_scaled': mh_proposal_stds_scaled,
        'mh_proposal_cov': mh_proposal_cov,
        'recompute_corrupt_chains': recompute_corrupt_chains,
    }

    tasks = list(enumerate(d_list))
    if use_parallel and max_workers > 1:
        print(f'Parallel sweep by dimension with max_workers={max_workers}')
        # 'fork' lets workers inherit the notebook globals (shared_draws,
        # dataset cache, helper functions); it is unavailable on some platforms.
        try:
            mp_context = mp.get_context('fork')
        except ValueError:
            mp_context = None
        if mp_context is None:
            print('Fork context unavailable; falling back to serial execution.')
            for d_idx, d_cur in tasks:
                _run_dim_task(d_idx, d_cur, config)
        else:
            with ProcessPoolExecutor(max_workers=max_workers, mp_context=mp_context) as executor:
                # list() consumes the results so worker exceptions surface here.
                list(executor.map(_run_dim_task, [t[0] for t in tasks], [t[1] for t in tasks], repeat(config)))
    else:
        print('Serial sweep by dimension')
        for d_idx, d_cur in tasks:
            _run_dim_task(d_idx, d_cur, config)

    # Only report completion when a sweep actually ran.
    print(f"Sweep completed. Chains saved to: {output_dir}")

## Traceplots

In [None]:
# Publication-quality traceplots: one figure per component in plot_components,
# one row per d in plot_dims, and one column per MESS M in plot_Ms plus MH.
trace_iters = 10000
plot_dims = d_list
plot_components = [0, 1, 2, 9]
plot_Ms = [1, 50]
save_figs = True
trace_dir = reports_dir / "traceplots_pub"
trace_dir.mkdir(parents=True, exist_ok=True)

plt.rcParams.update({
    "figure.dpi": 120,
    "savefig.dpi": 600,
    "font.size": 14,
    "axes.titlesize": 15,
    "axes.labelsize": 14,
    "axes.linewidth": 0.9,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "legend.fontsize": 12,
    "lines.linewidth": 0.9,
})

M_color_values = [1, 10, 20, 50, 100, 200, 300]
M_colors = plt.cm.viridis(np.linspace(0.1, 0.9, len(M_color_values)))
M_color_map = {M: color for M, color in zip(M_color_values, M_colors)}

# Column specs: (algorithm key, column title, M, line colour). Titles are built
# from plot_Ms so they stay correct when plot_Ms is edited. M values without an
# entry in M_color_map fall back to a fixed colour cycle.
algorithms = [
    (
        "mess",
        f"MESS (M={M})",
        M,
        M_color_map.get(M, ("#1b9e77", "#d95f02", "#e7298a", "#66a61e")[k % 4]),
    )
    for k, M in enumerate(plot_Ms)
]
algorithms.append(("mh", "MH", None, "#7570b3"))

def _load_trace_chain(d_cur, alg_key, M=None):
    """Load the saved chain for one traceplot panel, or None if unavailable.

    'mh' uses the proposal std configured for ``d_cur`` (so ``d_cur`` must be
    in d_list); any other key loads the MESS chain with ensemble size ``M``.
    """
    if alg_key != "mh":
        return load_chain(output_dir, d_cur, "mess", M=M)
    proposal_std = mh_proposal_stds_scaled[d_list.index(d_cur)]
    return load_chain(
        output_dir, d_cur, "mh", proposal_std=proposal_std, proposal_cov=mh_proposal_cov
    )

def _component_label(dim_value, comp_idx):
    count = 0
    for i in range(dim_value):
        for j in range(i + 1, dim_value):
            if count == comp_idx:
                return f"a_{{{i}{j}}}"
            count += 1
    return f"param_{comp_idx}"

def _warn_if_flat(chain, comp_idx, label):
    if chain is None or chain.size == 0 or comp_idx >= chain.shape[1]:
        return
    series = chain[:trace_iters, comp_idx]
    series_std = float(np.std(series))
    if series_std == 0.0:
        print(f"Warning: flat trace for {label} comp={comp_idx} (std=0).")
    elif series_std < 1e-12:
        print(f"Warning: nearly flat trace for {label} comp={comp_idx} (std={series_std:.2e}).")

# One traceplot figure per flat parameter index in plot_components.
for comp in plot_components:
    fig, axes = plt.subplots(
        nrows=len(plot_dims),
        ncols=len(algorithms),
        figsize=(10.5, 6.5),
        sharex=True,
        constrained_layout=True,
        squeeze=False,  # always a 2-D (rows, cols) grid, even for one row/column
    )
    # The flat index `comp` maps to a different (i, j) pair in each d (the
    # upper triangle is enumerated row by row). Label every row with its own
    # pair, and name the pair in the title only when all rows agree.
    row_labels = {d_cur: _component_label(d_cur, comp) for d_cur in plot_dims}
    distinct_labels = set(row_labels.values())
    if len(distinct_labels) == 1:
        title_label = f"${next(iter(distinct_labels))}$"
    else:
        title_label = f"parameter index {comp}"
    fig.suptitle(f"Traceplots for {title_label} (first {trace_iters} iterations)")

    # Preload chains and compute per-dimension y-limits.
    chain_cache = {}
    y_limits = {}
    for d_cur in plot_dims:
        for alg_key, _, M, _ in algorithms:
            chain = _load_trace_chain(d_cur, alg_key, M=M)
            # Treat empty chains, or chains too small to have this component,
            # as missing rather than crashing on the index.
            if chain is not None and (chain.size == 0 or comp >= chain.shape[1]):
                chain = None
            chain_cache[(d_cur, alg_key, M)] = chain
            if chain is None:
                continue
            series = chain[:trace_iters, comp]
            lo, hi = series.min(), series.max()
            if d_cur in y_limits:
                lo = min(y_limits[d_cur][0], lo)
                hi = max(y_limits[d_cur][1], hi)
            y_limits[d_cur] = [lo, hi]

    for row_idx, d_cur in enumerate(plot_dims):
        for col_idx, (alg_key, alg_label, M, color) in enumerate(algorithms):
            ax = axes[row_idx, col_idx]
            chain = chain_cache[(d_cur, alg_key, M)]
            if chain is None:
                ax.axis("off")
                continue
            ax.plot(chain[:trace_iters, comp], color=color, alpha=0.85)
            ax.set_ylim(y_limits[d_cur])
            if row_idx == 0:
                ax.set_title(alg_label)
            if col_idx == 0:
                ax.set_ylabel(f"d={d_cur}\n${row_labels[d_cur]}$")
            if row_idx == len(plot_dims) - 1:
                ax.set_xlabel("Iteration")

            _warn_if_flat(chain, comp, f"{alg_label} (d={d_cur})")

    if save_figs:
        fig_path = trace_dir / f"traceplots_comp{comp}.png"
        fig.savefig(fig_path, bbox_inches="tight")
        print(f"Saved {fig_path}")

    plt.show()

In [None]:
# Settings for the trace/histogram panels: stacked per-sampler traces on the
# left and overlaid post-burn-in histograms on the right, one figure per
# component in panel_comps, drawn at dimension hist_d.
trace_iters = 30000
hist_bins = 30
panel_comps = [0, 1]
hist_d = d_list[0]
font_size = 18
mess_plot_Ms = [1, 10, 50, 100]
save_figs = True
trace_xticks = [0, 5000, 10000, 15000, 20000]
panel_dir = reports_dir / "trace_hist_panels"
panel_dir.mkdir(parents=True, exist_ok=True)

def get_component_labels(dim_value, comps):
    """Return one mathtext label per flat parameter index in ``comps``.

    Index k is the k-th pair (i, j) of ``np.triu_indices(dim_value, k=1)``
    (strict upper triangle, row by row), rendered as ``a_{i,j}``. The comma
    keeps pairs unambiguous once d > 10. Out-of-range indices give
    ``param_<k>``.
    """
    rows, cols = np.triu_indices(dim_value, k=1)
    return [
        f"a_{{{rows[k]},{cols[k]}}}" if 0 <= k < rows.size else f"param_{k}"
        for k in comps
    ]

def _load_chain_for_plot(d_cur, alg, M=None):
    """Return ``(chain_or_None, description)`` for one sampler at ``d_cur``.

    'mh' loads the MH chain with the proposal std configured for ``d_cur``
    (which must be in d_list); any other ``alg`` loads the MESS chain for M.
    """
    if alg != 'mh':
        return load_chain(output_dir, d_cur, 'mess', M=M), f'MESS d={d_cur} M={M}'
    proposal_std = mh_proposal_stds_scaled[d_list.index(d_cur)]
    mh_samples = load_chain(
        output_dir,
        d_cur,
        'mh',
        proposal_std=proposal_std,
        proposal_cov=mh_proposal_cov,
    )
    return mh_samples, f'MH d={d_cur}'

def plot_trace_hist_panel(d_cur, comp):
    """Plot stacked traces and overlaid histograms of component ``comp`` at ``d_cur``.

    Left column: one trace per sampler, up to ``trace_iters`` post-burn-in
    samples each. The samplers are MESS for each M in ``mess_plot_Ms``,
    largest M at the top, then MH.

    Right column: density histograms of all post-burn-in samples, with a
    dashed line at the true value.

    Prints a message and returns early if any chain is missing or has no
    post-burn-in samples. Saves to ``panel_dir`` when ``save_figs`` is True.
    """
    mess_colors = plt.cm.viridis(np.linspace(0.1, 0.9, len(mess_plot_Ms)))
    mess_color_map = dict(zip(mess_plot_Ms, mess_colors))
    # Spec per sampler: (key, label, colour, M).
    mess_specs = {
        M: (f'mess{M}', f'MESS (M={M})', mess_color_map[M], M) for M in mess_plot_Ms
    }
    mh_spec = ('mh', 'MH', 'black', None)
    # Trace rows: largest M at the top, MH at the bottom.
    algorithms = [mess_specs[M] for M in sorted(mess_plot_Ms, reverse=True)] + [mh_spec]
    # Histogram and legend order: MH first, then increasing M.
    hist_algorithms = [mh_spec] + [mess_specs[M] for M in sorted(mess_plot_Ms)]

    # MESS chains are stored thinned (row r = iteration r * thin, per the
    # sweep; MH presumably likewise). Convert burn-in from iterations to rows
    # and do not thin a second time.
    burn_rows = -(-burn_in // thin)
    chains_dict = {}
    for key, label, _, M in algorithms:
        chain, _ = _load_chain_for_plot(d_cur, 'mh' if key == 'mh' else 'mess', M=M)
        if chain is None or chain.size == 0:
            print(f'Missing chain for {label} at d={d_cur}')
            return
        post_burn = chain[burn_rows:]
        if post_burn.shape[0] == 0:
            print(f'No post-burn-in samples for {label} at d={d_cur}')
            return
        chains_dict[key] = post_burn

    data_hist = get_dataset_for_dim(d_cur, seed=seed_data)
    true_val = float(data_hist['a_true'][comp])

    n_algs = len(algorithms)
    fig = plt.figure(figsize=(13, 5.0))
    gs = fig.add_gridspec(n_algs, 2, width_ratios=[2.0, 1.2], wspace=0.25, hspace=0.35)

    # Shared y-range across all stacked traces.
    trace_series = {key: chains_dict[key][:trace_iters, comp] for key, _, _, _ in algorithms}
    data_min = min(float(np.min(s)) for s in trace_series.values())
    data_max = max(float(np.max(s)) for s in trace_series.values())
    data_range = data_max - data_min if data_max != data_min else 1.0
    y_lims = [data_min - 0.05 * data_range, data_max + 0.05 * data_range]

    # Stacked traces
    for alg_idx, (alg_key, alg_label, color, _) in enumerate(algorithms):
        ax_trace = fig.add_subplot(gs[alg_idx, 0])
        series = trace_series[alg_key]
        ax_trace.plot(series, color=color, linewidth=0.5, label=alg_label)
        ax_trace.set_ylim(y_lims)
        # Post-burn-in chains can be shorter than trace_iters, so clip the
        # x-range to the samples shown. Ticks beyond it are dropped because
        # set_xticks would otherwise widen the view limits again.
        x_max = len(series)
        ax_trace.set_xlim([0, x_max])
        ax_trace.set_xticks([tick for tick in trace_xticks if tick <= x_max])
        ax_trace.grid(alpha=0.2)
        ax_trace.tick_params(labelsize=font_size - 2)
        ax_trace.spines['top'].set_visible(False)
        ax_trace.spines['right'].set_visible(False)
        if alg_idx < n_algs - 1:
            ax_trace.set_xticklabels([])
        else:
            ax_trace.set_xlabel('Iteration', fontsize=font_size)

    # Histogram panel (legend entries collected in hist_algorithms order).
    ax_hist = fig.add_subplot(gs[:, 1])
    legend_handles = []
    legend_labels = []
    for alg_key, alg_label, color, _ in hist_algorithms:
        _, _, patches = ax_hist.hist(
            chains_dict[alg_key][:, comp],
            bins=hist_bins,
            density=True,
            alpha=0.35,
            color=color,
            label=alg_label,
        )
        if patches:
            legend_handles.append(patches[0])
            legend_labels.append(alg_label)
    true_line = ax_hist.axvline(
        true_val,
        color='black',
        linestyle='--',
        linewidth=2,
        alpha=0.8,
        label='True value',
    )
    ax_hist.set_xlabel('Value', fontsize=font_size)
    ax_hist.set_ylabel('Density', fontsize=font_size)
    ax_hist.grid(alpha=0.2)
    ax_hist.tick_params(labelsize=font_size - 2)

    fig.legend(
        legend_handles,
        legend_labels,
        loc='upper center',
        ncol=len(legend_labels),
        bbox_to_anchor=(0.5, 1.04),
        fontsize=font_size - 2,
        frameon=False,
    )
    ax_hist.legend(
        handles=[true_line],
        labels=['True value'],
        loc='lower center',
        bbox_to_anchor=(0.5, 0.98),
        fontsize=font_size - 2,
        frameon=False,
        borderaxespad=0.0,
    )
    plt.tight_layout(rect=[0.0, 0.0, 1.0, 0.86])
    if save_figs:
        fig_path = panel_dir / f"trace_hist_d{d_cur}_comp{comp}.png"
        fig.savefig(fig_path, dpi=600, bbox_inches="tight")
        print(f"Saved {fig_path}")
    plt.show()

for comp in panel_comps:
    plot_trace_hist_panel(hist_d, comp)

## Load Chains and Compute ESS/MSJD

In [None]:
def compute_metrics_by_dim(
    burnin=5000,
    max_lag=1500,
    components=None,
    metrics_path=None,
    force_recompute=False,
    save_every_update=True,
):
    """Compute, or incrementally fill in, ESS and MSJD metrics for every d and M.

    For each dimension in ``d_list`` this loads the MESS chain for every ``M`` in
    ``M_list`` and the MH chain (proposal std ``mh_proposal_stds_scaled[d_idx]``,
    proposal covariance ``mh_proposal_cov``). It then drops the first ``burnin``
    samples and computes per-component ESS (``compute_ess_per_param``) and MSJD
    (``compute_msjd_per_param``) on the selected ``components``. The per-chain
    summary is the NaN-ignoring mean over those components (NaN if all are NaN).

    Results are cached in ``metrics_path`` (an ``.npz`` file). When the cache
    exists and ``force_recompute`` is False, only entries whose mean ESS or MSJD
    is missing (None/NaN) are recomputed. A cache whose stored ``components``
    differ from the requested ``components`` is discarded and rebuilt.

    Args:
        burnin: number of leading samples dropped from each chain. ``None`` or
            ``<= 0`` keeps the whole chain.
        max_lag: maximum lag passed to ``compute_ess_per_param``.
        components: parameter indices to summarize. ``None`` means every index
            below ``max(d_list)``. Indices outside a chain's dimension yield NaN.
        metrics_path: path-like cache file (needs ``.exists()``), or ``None`` to
            disable caching.
        force_recompute: ignore any existing cache.
        save_every_update: rewrite the cache after every newly computed entry.

    Returns:
        ``(metrics, did_update)``.

        ``metrics`` holds ``'components'`` and the following series:

        * ``'ess_by_M'`` / ``'msjd_by_M'``: dict mapping M to one value per d.
        * ``'ess_mh'`` / ``'msjd_mh'``: one value per d.
        * The ``*_components`` variants of the above: one list of per-component
          values per d.

        ``did_update`` is True when at least one chain was loaded and its
        metrics were (re)computed.
    """

    def _unwrap_loaded(value):
        # np.savez stores non-array values (e.g. the per-M dicts) as 0-d object
        # arrays. Only 0-d arrays are unwrapped, so a length-1 object array
        # still comes back as a one-element list rather than its element.
        if isinstance(value, np.ndarray) and value.dtype == object and value.ndim == 0:
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        return value

    def _fill_copy(fill_value):
        # Copy list fill values so padded slots never share one mutable list.
        return list(fill_value) if isinstance(fill_value, list) else fill_value

    def _ensure_list(values, length, fill_value=np.nan):
        # Truncate or pad `values` to exactly `length` entries.
        values = list(values)[:length] if values is not None else []
        values.extend(_fill_copy(fill_value) for _ in range(length - len(values)))
        return values

    def _ensure_by_M(metrics_dict, fill_value=np.nan):
        # One len(d_list) series per M in M_list. Extra cached M keys are kept.
        metrics_dict = metrics_dict or {}
        for M in M_list:
            metrics_dict[M] = _ensure_list(metrics_dict.get(M, []), len(d_list), fill_value=fill_value)
        return metrics_dict

    def _ensure_by_M_components(metrics_dict, components_count):
        # Same as _ensure_by_M, but each missing slot is a row of NaNs.
        return _ensure_by_M(metrics_dict, fill_value=[np.nan] * components_count)

    def _is_missing(val):
        # None or NaN means "not computed yet". Non-numeric values count as present.
        if val is None:
            return True
        try:
            return bool(np.isnan(val))
        except TypeError:
            return False

    def _apply_burnin(chain):
        # Drop the first `burnin` rows. A chain no longer than burnin becomes empty.
        if chain is None:
            return None
        if burnin is None or burnin <= 0:
            return chain
        if chain.shape[0] <= burnin:
            return chain[:0]
        return chain[burnin:]

    def _normalize_components(components_in, dim):
        # Default to every index below `dim`; otherwise coerce to ints.
        if components_in is None:
            return list(range(dim))
        return [int(c) for c in components_in]

    def _split_components(components_idx, dim):
        # Return (indices that exist in a `dim`-wide chain, their positions in components_idx).
        valid_indices = []
        valid_positions = []
        for pos, idx in enumerate(components_idx):
            if 0 <= idx < dim:
                valid_indices.append(idx)
                valid_positions.append(pos)
        return valid_indices, valid_positions

    def _compute_from_chain(chain, components_idx):
        # Per-component ESS/MSJD aligned with components_idx. Unavailable entries stay NaN.
        components_count = len(components_idx)
        selected_ess = [np.nan] * components_count
        selected_msjd = [np.nan] * components_count
        if chain is None or chain.size == 0:
            return selected_ess, selected_msjd
        valid_indices, valid_positions = _split_components(components_idx, chain.shape[1])
        if not valid_indices:
            return selected_ess, selected_msjd
        chain_sel = chain[:, valid_indices]
        ess_vals = compute_ess_per_param(chain_sel, max_lag=max_lag)
        msjd_vals = compute_msjd_per_param(chain_sel)
        for pos, val in zip(valid_positions, ess_vals):
            selected_ess[pos] = float(val)
        for pos, val in zip(valid_positions, msjd_vals):
            selected_msjd[pos] = float(val)
        return selected_ess, selected_msjd

    def _mean_selected(selected):
        # NaN-ignoring mean. NaN when there is nothing to average.
        arr = np.array(selected, dtype=float)
        if arr.size == 0 or np.all(np.isnan(arr)):
            return np.nan
        return float(np.nanmean(arr))

    def _save_metrics(metrics_dict):
        # Write the cache. d_list/M_list are always stored from the current globals.
        if metrics_path is None:
            return
        metrics_to_save = dict(metrics_dict)
        metrics_to_save['components'] = np.array(components_list, dtype=int)
        metrics_to_save.pop('d_list', None)
        metrics_to_save.pop('M_list', None)
        np.savez_compressed(
            metrics_path,
            **metrics_to_save,
            d_list=np.array(d_list),
            M_list=np.array(M_list),
        )
        print(f'Saved metrics to {metrics_path}')

    def _align_metrics_to_d_list(metrics_dict):
        # Reorder cached per-d series so index i corresponds to d_list[i].
        # Dimensions absent from the cache are filled with NaN.
        cached_d_list = metrics_dict.get('d_list')
        if cached_d_list is None:
            return metrics_dict
        cached_d_list = [int(v) for v in list(cached_d_list)]
        if cached_d_list == list(d_list):
            return metrics_dict
        dim_map = {dim: idx for idx, dim in enumerate(cached_d_list)}
        components_count = len(list(metrics_dict.get('components', [])))
        fill_comp = [np.nan] * components_count
        metrics_dict['d_list'] = list(d_list)

        def _remap_list(values, fill_value):
            values = list(values) if values is not None else []
            remapped = [_fill_copy(fill_value) for _ in d_list]
            for new_idx, dim in enumerate(d_list):
                old_idx = dim_map.get(dim)
                if old_idx is None or old_idx >= len(values):
                    continue
                remapped[new_idx] = values[old_idx]
            return remapped

        def _remap_by_M(metrics_by_M, fill_value):
            metrics_by_M = metrics_by_M or {}
            return {M: _remap_list(series, fill_value) for M, series in metrics_by_M.items()}

        metrics_dict['ess_by_M'] = _remap_by_M(metrics_dict.get('ess_by_M'), np.nan)
        metrics_dict['msjd_by_M'] = _remap_by_M(metrics_dict.get('msjd_by_M'), np.nan)
        metrics_dict['ess_by_M_components'] = _remap_by_M(
            metrics_dict.get('ess_by_M_components'),
            fill_comp,
        )
        metrics_dict['msjd_by_M_components'] = _remap_by_M(
            metrics_dict.get('msjd_by_M_components'),
            fill_comp,
        )
        metrics_dict['ess_mh'] = _remap_list(metrics_dict.get('ess_mh'), np.nan)
        metrics_dict['msjd_mh'] = _remap_list(metrics_dict.get('msjd_mh'), np.nan)
        metrics_dict['ess_mh_components'] = _remap_list(
            metrics_dict.get('ess_mh_components'),
            fill_comp,
        )
        metrics_dict['msjd_mh_components'] = _remap_list(
            metrics_dict.get('msjd_mh_components'),
            fill_comp,
        )
        return metrics_dict

    def _load_cached():
        # Return the cached metrics, or None when they must be recomputed.
        print(f'Using cached metrics at {metrics_path}')
        # allow_pickle is required for the dict-valued entries. This file is written
        # by this function; do not point metrics_path at untrusted files.
        # The context manager closes the underlying NpzFile handle.
        with np.load(metrics_path, allow_pickle=True) as cached:
            loaded = {key: _unwrap_loaded(cached[key]) for key in cached.files}
        if components is not None and 'components' in loaded:
            cached_components = [int(v) for v in list(loaded['components'])]
            if list(components) != cached_components:
                print('Cached metrics components do not match; recomputing.')
                return None
        return _align_metrics_to_d_list(loaded)

    def _summarize_chain(raw_chain):
        # Burn in one chain and compute its per-component ESS/MSJD.
        # Returns None when the chain is unavailable.
        chain = _apply_burnin(raw_chain)
        if chain is None:
            return None
        return _compute_from_chain(chain, components_list)

    did_update = False
    metrics = None
    if metrics_path is not None and metrics_path.exists() and not force_recompute:
        metrics = _load_cached()
    if metrics is None:
        # Fresh start: only the component list is known. The normalization below
        # pads every series with NaN, so the fill loop computes every (d, M) entry.
        metrics = {'components': _normalize_components(components, dim=max(d_list))}

    components_list = [int(v) for v in list(metrics.get('components', []))]
    if not components_list:
        components_list = _normalize_components(components, dim=max(d_list))
        metrics['components'] = components_list
    components_count = len(components_list)

    # Give every series exactly one slot per entry of d_list.
    metrics['ess_by_M'] = _ensure_by_M(metrics.get('ess_by_M'))
    metrics['msjd_by_M'] = _ensure_by_M(metrics.get('msjd_by_M'))
    metrics['ess_by_M_components'] = _ensure_by_M_components(
        metrics.get('ess_by_M_components'),
        components_count,
    )
    metrics['msjd_by_M_components'] = _ensure_by_M_components(
        metrics.get('msjd_by_M_components'),
        components_count,
    )
    metrics['ess_mh'] = _ensure_list(metrics.get('ess_mh', []), len(d_list))
    metrics['msjd_mh'] = _ensure_list(metrics.get('msjd_mh', []), len(d_list))
    metrics['ess_mh_components'] = _ensure_list(
        metrics.get('ess_mh_components', []),
        len(d_list),
        fill_value=[np.nan] * components_count,
    )
    metrics['msjd_mh_components'] = _ensure_list(
        metrics.get('msjd_mh_components', []),
        len(d_list),
        fill_value=[np.nan] * components_count,
    )

    # Compute every entry whose mean ESS or MSJD is still missing.
    for d_idx, d_cur in enumerate(d_list):
        for M in M_list:
            if not (
                _is_missing(metrics['ess_by_M'][M][d_idx])
                or _is_missing(metrics['msjd_by_M'][M][d_idx])
            ):
                continue
            summary = _summarize_chain(load_chain(output_dir, d_cur, 'mess', M=M))
            if summary is None:
                continue
            selected_ess, selected_msjd = summary
            metrics['ess_by_M'][M][d_idx] = _mean_selected(selected_ess)
            metrics['msjd_by_M'][M][d_idx] = _mean_selected(selected_msjd)
            metrics['ess_by_M_components'][M][d_idx] = selected_ess
            metrics['msjd_by_M_components'][M][d_idx] = selected_msjd
            did_update = True
            if save_every_update:
                _save_metrics(metrics)

        if _is_missing(metrics['ess_mh'][d_idx]) or _is_missing(metrics['msjd_mh'][d_idx]):
            mh_std = mh_proposal_stds_scaled[d_idx]
            summary = _summarize_chain(
                load_chain(
                    output_dir,
                    d_cur,
                    'mh',
                    proposal_std=mh_std,
                    proposal_cov=mh_proposal_cov,
                )
            )
            if summary is not None:
                selected_ess, selected_msjd = summary
                metrics['ess_mh'][d_idx] = _mean_selected(selected_ess)
                metrics['msjd_mh'][d_idx] = _mean_selected(selected_msjd)
                metrics['ess_mh_components'][d_idx] = selected_ess
                metrics['msjd_mh_components'][d_idx] = selected_msjd
                did_update = True
                if save_every_update:
                    _save_metrics(metrics)

    return metrics, did_update

recompute_metrics = False
metrics_path = reports_dir / 'effss_msjd.npz'
components = [0, 1, 2, 3, 9, 10, 16, 17]  # Indices into a_true to summarize.
metrics, metrics_updated = compute_metrics_by_dim(
    burnin=10000,
    max_lag=1500,
    components=components,
    metrics_path=metrics_path,
    force_recompute=recompute_metrics,
)
# Write the metrics file unless an existing one was reused without changes.
if recompute_metrics or metrics_updated or not metrics_path.exists():
    # d_list / M_list are always written from the current configuration.
    metrics_to_save = {
        key: value for key, value in metrics.items() if key not in ('d_list', 'M_list')
    }
    metrics_to_save['components'] = np.array(metrics.get('components', components), dtype=int)
    np.savez_compressed(
        metrics_path,
        **metrics_to_save,
        d_list=np.array(d_list),
        M_list=np.array(M_list),
    )
    print(f'Saved metrics to {metrics_path}')
else:
    print(f'Metrics file already exists, not overwriting: {metrics_path}')
metrics.keys()

## Plot ESS/MSJD vs Dimension

In [None]:
def plot_ess_msjd_vs_dim(
    ess_by_M, ess_mh, msjd_by_M, msjd_mh, ylabel_ess, ylabel_msjd, filename, plot_dir, yscale=None,
):
    """Plot ESS (left panel) and MSJD (right panel) against dimension d.

    Each panel draws one line per M in ``M_list``, coloured with viridis, plus a
    dashed black MH line. All series are plotted against ``d_list``.

    Args:
        ess_by_M, msjd_by_M: dict mapping each M in ``M_list`` to one value per d.
        ess_mh, msjd_mh: one MH value per d.
        ylabel_ess, ylabel_msjd: y-axis labels for the two panels.
        filename: name of the saved image inside ``plot_dir``.
        plot_dir: ``pathlib.Path`` output directory. It is created if missing.
        yscale: optional matplotlib y-scale (e.g. ``'log'``) applied to both
            panels. Non-linear scales are appended to the y-labels, e.g.
            ``"MSJD (log)"``.

    Returns:
        The matplotlib Figure (also saved to ``plot_dir / filename`` at 300 dpi).
    """
    font_size = 24
    tick_size = 20
    legend_font_size = 24
    marker_size = 10
    fig, (ax_ess, ax_msjd) = plt.subplots(1, 2, figsize=(14.5, 6.5), sharex=True)
    colors = plt.cm.viridis(np.linspace(0.1, 0.9, len(M_list)))

    for M, color in zip(M_list, colors):
        ax_ess.plot(
            d_list, ess_by_M[M], marker='o', markersize=marker_size, color=color, label=f'M={M}'
        )
        ax_msjd.plot(
            d_list, msjd_by_M[M], marker='o', markersize=marker_size, color=color, label=f'M={M}'
        )

    mh_style = dict(marker='s', markersize=marker_size, color='black', linestyle='--', label='MH')
    ax_ess.plot(d_list, ess_mh, **mh_style)
    ax_msjd.plot(d_list, msjd_mh, **mh_style)

    # Tag the labels only for non-linear scales, so yscale='linear' does not
    # produce a misleading "(log)" label.
    scale_suffix = f" ({yscale})" if yscale and yscale != 'linear' else ""
    for ax, ylabel in ((ax_ess, ylabel_ess), (ax_msjd, ylabel_msjd)):
        if yscale:
            ax.set_yscale(yscale)
        ax.set_xlabel('d', fontsize=font_size)
        ax.set_ylabel(f"{ylabel}{scale_suffix}", fontsize=font_size)
        ax.grid(alpha=0.3)
        ax.tick_params(axis='both', labelsize=tick_size)

    handles, labels = ax_ess.get_legend_handles_labels()
    fig.legend(
        handles, labels,
        loc='upper center',
        bbox_to_anchor=(0.5, 1.02),
        ncol=min(len(labels), 6),
        frameon=False,
        fontsize=legend_font_size,
        columnspacing=1.2,
        handlelength=1.6,
        handletextpad=0.4,
    )
    fig.tight_layout(rect=(0, 0, 1, 0.92))
    plot_dir.mkdir(parents=True, exist_ok=True)
    # The legend is anchored above the figure (y=1.02), so save with a tight
    # bounding box; otherwise the legend's top is clipped in the image file.
    fig.savefig(plot_dir / filename, dpi=300, bbox_inches='tight')
    return fig

def scale_metric_by_M(metric_by_M, scale):
    """Divide every series of a ``{M: values}`` mapping by ``scale``.

    Each series is cast to float (``None`` becomes NaN), divided by ``scale``,
    and returned as a plain Python list under the same key.
    """
    scaled = {}
    for M_key, series in metric_by_M.items():
        scaled[M_key] = np.divide(np.array(series, dtype=float), scale).tolist()
    return scaled

def scale_metric_list(metric_list, scale):
    """Return ``metric_list`` cast to float and divided by ``scale``, as a list."""
    values = np.asarray(metric_list, dtype=float)
    return np.true_divide(values, scale).tolist()

def _component_series(metric_by_M_components, metric_mh_components, comp_pos):
    """Pull out the values of one component position, per dimension.

    Args:
        metric_by_M_components: dict mapping M to a list of per-dimension rows,
            each row holding one value per selected component.
        metric_mh_components: list of per-dimension rows for MH.
        comp_pos: position within each row to extract.

    Returns:
        ``(by_M, mh_series)``: a dict with one series per M in ``M_list`` (an M
        absent from the input yields an empty series) and the MH series. A row
        that is None or too short to contain ``comp_pos`` contributes NaN.
    """
    def _pick(row):
        return np.nan if row is None or len(row) <= comp_pos else row[comp_pos]

    by_M = {M: [_pick(row) for row in metric_by_M_components.get(M, [])] for M in M_list}
    mh_series = [_pick(row) for row in metric_mh_components]
    return by_M, mh_series

ess_ylabel = 'Eff. Sample Size'

# The "*_scaled" ESS series are the raw ESS values; no rescaling is applied here.
ess_mh_scaled = metrics['ess_mh']
ess_by_M_scaled = metrics['ess_by_M']

# At most the first five selected components get their own per-component plots.
components_list = list(map(int, metrics.get('components', [])))
component_positions = list(range(len(components_list)))[:5]
component_labels = [
    f'component {comp_idx}' for comp_idx in components_list[:len(component_positions)]
]

plot_dir = reports_dir / 'ess_msjd_vs_d'

# Mean over the selected components, on linear and then log y-axes.
plot_ess_msjd_vs_dim(
    ess_by_M=ess_by_M_scaled,
    ess_mh=ess_mh_scaled,
    msjd_by_M=metrics['msjd_by_M'],
    msjd_mh=metrics['msjd_mh'],
    ylabel_ess=ess_ylabel,
    ylabel_msjd='MSJD',
    filename='ess_msjd_vs_d_mean.png',
    plot_dir=plot_dir,
)
plot_ess_msjd_vs_dim(
    ess_by_M=ess_by_M_scaled,
    ess_mh=ess_mh_scaled,
    msjd_by_M=metrics['msjd_by_M'],
    msjd_mh=metrics['msjd_mh'],
    ylabel_ess=ess_ylabel,
    ylabel_msjd='MSJD',
    filename='ess_msjd_vs_d_mean_log.png',
    plot_dir=plot_dir,
    yscale='log',
)

In [None]:
# Per-component ESS/MSJD plots: one linear-axis and one log-axis figure per component.
if not component_positions:
    print('No components available for plotting.')
else:
    for comp_pos, comp_label in zip(component_positions, component_labels):
        msjd_by_M_comp, msjd_mh_comp = _component_series(
            metrics['msjd_by_M_components'], metrics['msjd_mh_components'], comp_pos,
        )
        ess_by_M_comp, ess_mh_comp = _component_series(
            metrics['ess_by_M_components'], metrics['ess_mh_components'], comp_pos,
        )
        # Scale factor 1: ESS is plotted unscaled (values are only cast to float).
        ess_by_M_comp_scaled = scale_metric_by_M(ess_by_M_comp, 1)
        ess_mh_comp_scaled = scale_metric_list(ess_mh_comp, 1)
        filename = f"ess_msjd_vs_d_{comp_label.replace(' ', '_')}.png"
        log_filename = f"ess_msjd_vs_d_{comp_label.replace(' ', '_')}_log.png"
        plot_ess_msjd_vs_dim(
            ess_by_M_comp_scaled, ess_mh_comp_scaled, msjd_by_M_comp, msjd_mh_comp,
            ess_ylabel, 'MSJD', filename, plot_dir,
        )
        plot_ess_msjd_vs_dim(
            ess_by_M_comp_scaled, ess_mh_comp_scaled, msjd_by_M_comp, msjd_mh_comp,
            ess_ylabel, 'MSJD', log_filename, plot_dir, yscale='log',
        )


## Pairplots

In [None]:
# Posterior hist grid (make_hist_grid_comps) for a selected chain.
from mess.plotting.diagnostics import make_hist_grid_comps

def get_component_labels(dim_value, comps):
    """Return LaTeX axis labels for flat parameter indices.

    A flat index ``k`` is read as the position of the pair ``(i, j)`` with
    ``i < j`` when the strictly-upper-triangular pairs of a
    ``dim_value x dim_value`` matrix are enumerated row by row:
    ``(0, 1), (0, 2), ..., (0, d-1), (1, 2), ...``.

    Such a pair is labelled ``$a_{ij}$``. When ``j`` has two or more digits, the
    indices are comma-separated (``$a_{1,10}$``) so labels such as (1, 10) and
    (11, 0) cannot collide. Indices outside ``[0, d*(d-1)/2)`` are labelled
    ``$\\mathrm{param}_{k}$``.

    Args:
        dim_value: matrix dimension d.
        comps: iterable of integer flat indices.

    Returns:
        List of label strings, one per entry of ``comps``.
    """
    def _pair_for_index(k):
        # Row i holds the (dim_value - 1 - i) pairs (i, i+1), ..., (i, dim_value-1).
        # Walking the rows is O(d) per index instead of scanning all pairs.
        row_start = 0
        for i in range(dim_value):
            row_len = dim_value - 1 - i
            if row_start <= k < row_start + row_len:
                return i, i + 1 + (k - row_start)
            row_start += row_len
        return None

    labels = []
    for k in comps:
        pair = _pair_for_index(k)
        if pair is None:
            labels.append(f"$\\mathrm{{param}}_{{{k}}}$")
            continue
        i, j = pair
        # j > i, so j >= 10 is the only case where concatenation is ambiguous.
        sep = "," if j >= 10 else ""
        labels.append(f"$a_{{{i}{sep}{j}}}$")
    return labels

hist_d = 10
hist_alg = 'mess'  # 'mess' or 'mh'
hist_M = 100  # only used for MESS
hist_burnin = 10000  # leading samples dropped before plotting
hist_thin = 5  # keep every hist_thin-th sample after burn-in
hist_params = 6
zoom_factor = 0.6  # axis half-width as a fraction of R_full when zooming
share_axes = True
font_size = 18
tick_label_size = font_size - 2
pairplots_dir = reports_dir / "pairplots"
pairplots_dir.mkdir(parents=True, exist_ok=True)

# Load the chain to plot. chain_hist may be None when no matching chain was saved.
if hist_alg == 'mess':
    chain_hist = load_chain(output_dir, hist_d, 'mess', M=hist_M)
    hist_label = f'MESS (M={hist_M}), d={hist_d}'
elif hist_alg == 'mh':
    d_idx = d_list.index(hist_d)
    mh_std = mh_proposal_stds_scaled[d_idx]
    # Select the MH chain by proposal covariance as well as proposal std, as the
    # ESS/MSJD metrics cell does. Without it, load_chain falls back to whatever
    # proposal_cov default it has (not visible here), which may differ from
    # mh_proposal_cov.
    chain_hist = load_chain(
        output_dir,
        hist_d,
        'mh',
        proposal_std=mh_std,
        proposal_cov=mh_proposal_cov,
    )
    hist_label = f'MH d={hist_d} sigma2={mh_std}'
else:
    raise ValueError("hist_alg must be 'mess' or 'mh'")

if chain_hist is None or chain_hist.size == 0:
    print(f'No chain found for {hist_label}')
else:
    post = chain_hist[hist_burnin::hist_thin]
    n_params = min(hist_params, post.shape[1])
    comp_list = np.arange(n_params)
    comp_list10 = [0, 1, 2, 3, 9, 10]
    comp_list10_short = [0, 1, 9]
    comp_list20 = [0, 1, 2, 3, 19, 20]
    comp_list50 = [0, 1, 2, 3, 49, 50]
    comp_list = comp_list10_short
    n_params = len(comp_list)
    label_map = dict(zip(comp_list, get_component_labels(hist_d, comp_list)))
    data_hist = get_dataset_for_dim(hist_d, seed=seed_data)
    prior_diag = data_hist['prior_diag']
    C = np.diag(prior_diag)
    vals = post[:, comp_list]
    max_abs = float(np.max(np.abs(vals))) if vals.size else 1.0
    true_vals = data_hist['a_true']
    max_abs = max(max_abs, float(np.max(np.abs(true_vals[comp_list]))))
    if max_abs == 0.0:
        max_abs = 1.0
    R_full = 0.9 * max_abs
    dr = R_full / 60.0
    M_tag = hist_M if hist_alg == 'mess' else 'NA'
    filename = f"pairplot_alg{hist_alg}_d{hist_d}_M{M_tag}_n{n_params}.png"
    fig = make_hist_grid_comps(
        R_full,
        dr,
        post,
        comp_list,
        save_path=pairplots_dir / filename,
        C=C,
        beta=0.95,
        hide_plot=False,
        label_map=label_map,
        font_size=font_size+8,
        title='', #f'{hist_label}',
        figsize=(12, 12),
        true_values=data_hist['a_true'],
        show_ellipses=False
    )
    for ax in fig.axes:
        ax.tick_params(axis='both', labelsize=tick_label_size)
    if share_axes and zoom_factor and 0 < zoom_factor < 1:
        R_zoom = zoom_factor * R_full
        axes = np.array(fig.axes).reshape(n_params, n_params)
        for i in range(n_params):
            for j in range(n_params):
                ax = axes[i, j]
                ax.set_xlim([-R_zoom, R_zoom])
                if i != j:
                    ax.set_ylim([-R_zoom, R_zoom])