In [None]:
# %% [markdown]
# # Particle Feynman Benchmark: Expanded Synthetic Library
# 
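# This notebook compares three ways of fitting a sparse (FITC) Gaussian process on a diverse library of synthetic functions, ranging from smooth multiscale patterns to discontinuous waves and spiky signals:
# - ML-II: type-II maximum likelihood.
# - MAP-II: type-II maximum a posteriori, i.e. the ML-II objective plus a hyperprior.
# - SMC-VFE: annealed sequential Monte Carlo over the hyperparameters and inducing inputs. The particles' predictions are combined by Bayesian model averaging.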

In [None]:
import os
import sys
import time
from pathlib import Path

# Force CPU execution for stability with JAX Metal on macOS, and enable 64-bit floats.
# JAX reads these variables when it is first imported, so they must be set before the
# `import jax` below and only take effect in a fresh kernel. JAX_PLATFORM_NAME is the
# older (deprecated) spelling of JAX_PLATFORMS; both are set so either JAX version agrees.
os.environ.update({
    "JAX_PLATFORM_NAME": "cpu",
    "JAX_PLATFORMS": "cpu",
    "JAX_ENABLE_X64": "True",
})

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

# Ensure the repository root is importable. Walk up from the working directory
# (inclusive) to the first directory that contains the `infodynamics_jax` package.
# If no such directory exists, fall back to the working directory rather than
# silently putting the filesystem root on sys.path.
cwd = Path(os.getcwd())
repo_root = next(
    (candidate for candidate in (cwd, *cwd.parents) if (candidate / 'infodynamics_jax').exists()),
    cwd,
)
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

from infodynamics_jax.core import Phi
from infodynamics_jax.gp.kernels import rbf
from infodynamics_jax.gp.kernels.params import KernelParams
from infodynamics_jax.gp.likelihoods import get as get_likelihood
from infodynamics_jax.gp.predict import predict_typeii
from infodynamics_jax.gp.sparsify import fitc_log_evidence
from infodynamics_jax.inference.optimisation import TypeII, TypeIICFG
from infodynamics_jax.inference.optimisation.vfe import make_vfe_objective
from infodynamics_jax.infodynamics import make_hyperprior
from utils import synthetic, compute_metrics, setup_plot_style, COLORS, plot_with_uncertainty
from utils.smc_array_only import annealed_smc_array

# Select the inline backend first, then apply the project plot style. Activating the
# inline backend may apply IPython's InlineBackend.rc to matplotlib.rcParams, which
# would overwrite settings made by setup_plot_style() if the style were applied first.
matplotlib.use('module://matplotlib_inline.backend_inline')
setup_plot_style()

In [None]:
class CFG:
    """Benchmark settings shared by every cell; read through the ``cfg`` instance below."""

    # Data generation (used by synthetic.sample and the plotting grid)
    N_train: int = 120            # training points per dataset
    N_test: int = 60              # held-out test points per dataset
    noise_std: float = 0.2        # observation-noise standard deviation
    domain: tuple = (-2.5, 2.5)   # input interval (low, high)

    # Sparse GP
    M: int = 20                   # number of inducing inputs

    # Annealed SMC (forwarded to annealed_smc_array)
    n_particles: int = 64
    n_steps: int = 24
    ess_threshold: float = 0.6
    rejuvenation_steps: int = 2
    step_size: float = 0.02       # presumably the leapfrog step of the rejuvenation move
    n_leapfrog: int = 8

    # ML-II / MAP-II optimisation (forwarded to TypeIICFG)
    typeii_steps: int = 300
    typeii_lr: float = 1e-2

cfg = CFG()

# %% [markdown]
# ## Dataset Gallery
# Plot every function in the synthetic library except those in the legacy `periodic`, `linear`, `polynomial` and `smooth` categories.

In [None]:
# Categories excluded from the gallery (the older, simpler functions).
LEGACY_CATEGORIES = {'periodic', 'linear', 'polynomial', 'smooth'}

# NOTE(review): `_functions` is a private attribute of `synthetic`; switch to a public
# listing helper if the module provides one.
all_functions = synthetic._functions.keys()

# Look each function up once; synthetic.get(name) returns (fn, title, _, category).
gallery = []
for name in all_functions:
    fn, title, _, cat = synthetic.get(name)
    if cat not in LEGACY_CATEGORIES:
        gallery.append((name, fn, title, cat))
gallery_functions = [entry[0] for entry in gallery]

cols = 4
# At least one row, so an empty gallery does not create a zero-height figure.
rows = max(1, int(np.ceil(len(gallery) / cols)))
fig, axes = plt.subplots(rows, cols, figsize=(20, rows * 4), squeeze=False)

x_plot = jnp.linspace(cfg.domain[0], cfg.domain[1], 500)
for ax, (name, fn, title, cat) in zip(axes.ravel(), gallery):
    ax.plot(x_plot, fn(x_plot), lw=2)
    ax.set_title(f"{title}\n({cat})", fontsize=10)
    ax.grid(True, alpha=0.3)

# Hide grid cells left over when the gallery size is not a multiple of `cols`.
for ax in axes.ravel()[len(gallery):]:
    ax.set_visible(False)

fig.tight_layout()
plt.show()

# %% [markdown]
# ## Inference Utilities

In [None]:
def unpack_state(theta, shape_z):
    """Split a flat particle vector into GP hyperparameters and inducing inputs.

    Layout of ``theta``: ``[log_ell, log_sf2, log_sn2, *Z.ravel()]``, where
    ``exp(log_ell)`` is the kernel lengthscale, ``exp(log_sf2)`` the kernel variance
    and ``exp(log_sn2)`` the noise variance (as used by ``get_energy_fn``).

    Returns
    -------
    (log_ell, log_sf2, log_sn2, Z)
        The three log-hyperparameters and ``Z`` reshaped to ``shape_z``.
    """
    hypers, z_flat = theta[:3], theta[3:]
    return hypers[0], hypers[1], hypers[2], z_flat.reshape(shape_z)

def get_energy_fn(X, y, hyperprior_fn, shape_z, jitter):
    """Build the SMC target energy over a flat parameter vector.

    The returned closure decodes ``theta`` with ``unpack_state`` and returns the
    negative FITC log evidence of ``(X, y)`` plus ``hyperprior_fn(phi)`` (presumably
    the hyperprior's negative log density).

    Parameters
    ----------
    X, y : training inputs and targets, fixed inside the closure.
    hyperprior_fn : callable taking a ``Phi`` and returning a scalar energy term.
    shape_z : shape of the inducing-input matrix ``Z``.
    jitter : diagonal jitter forwarded to the FITC evidence and to ``Phi``.
    """
    def energy_theta(theta):
        log_ell, log_sf2, log_sn2, Z = unpack_state(theta, shape_z)
        kernel_params = KernelParams(lengthscale=jnp.exp(log_ell), variance=jnp.exp(log_sf2))
        noise_var = jnp.exp(log_sn2)

        neg_log_evidence = -fitc_log_evidence(
            kernel_fn=rbf,
            params=kernel_params,
            X=X,
            y=y,
            Z=Z,
            noise_var=noise_var,
            jitter=jitter,
        )

        phi = Phi(
            kernel_params=kernel_params,
            Z=Z,
            likelihood_params={'noise_var': noise_var},
            jitter=jitter,
        )
        prior_term = hyperprior_fn(phi)
        return neg_log_evidence + prior_term

    return energy_theta

def predict_bma(particles, logw, X_star, X_tr, Y_tr, shape_z, jitter):
    """Bayesian-model-average FITC prediction over weighted SMC particles.

    Each particle is decoded with ``unpack_state`` and turned into a FITC predictive
    mean/variance at ``X_star``. The per-particle Gaussians are then combined as a
    mixture, using the self-normalised weights ``exp(logw - logsumexp(logw))``.

    Returns
    -------
    (mean_bma, std_bma, mus, w)
        The mixture mean and the mixture standard deviation (variance floored at
        1e-12), the stacked per-particle means, and the normalised weights.
    """
    # Normalise the weights in log space for numerical stability.
    log_norm = jax.scipy.special.logsumexp(logw)
    w = jnp.exp(logw - log_norm)

    def _particle_prediction(idx):
        log_len, log_var, log_noise, Z = unpack_state(particles[idx], shape_z)
        phi = Phi(KernelParams(jnp.exp(log_len), jnp.exp(log_var)), Z, {'noise_var': jnp.exp(log_noise)}, jitter)
        return predict_typeii(phi, X_star, X_tr, Y_tr, rbf, residual='fitc')

    per_particle = [_particle_prediction(idx) for idx in range(len(w))]
    mus = jnp.stack([mean for mean, _ in per_particle])
    vars_ = jnp.stack([var for _, var in per_particle])

    # Law of total variance: Var = E_w[var + mu^2] - E_w[mu]^2.
    # Assumes per-particle outputs are 1-D, so mus is (n_particles, n_star) — TODO confirm.
    weights_col = w[:, None]
    mean_bma = (weights_col * mus).sum(axis=0)
    second_moment = (weights_col * (vars_ + mus**2)).sum(axis=0)
    var_bma = second_moment - mean_bma**2
    return mean_bma, jnp.sqrt(jnp.maximum(var_bma, 1e-12)), mus, w

# %% [markdown]
# ## Unified Benchmark Function

In [None]:
def run_full_benchmark(name, key):
    """Fit ML-II, MAP-II and annealed SMC to one synthetic dataset, then plot and score them.

    Steps: sample ``cfg.N_train + cfg.N_test`` noisy points from the named function,
    split them at random into train and test sets, and fit the three methods on the
    training set. Then plot their predictive means against the true function and
    evaluate each method on the test set.

    Parameters
    ----------
    name : str
        Name of a function registered in ``synthetic``.
    key : jax PRNG key
        Master key for this run. It is split into independent subkeys for data
        sampling, the train/test permutation, SMC, and particle initialisation.

    Returns
    -------
    dict
        ``{'ML-II': ..., 'MAP-II': ..., 'SMC-VFE': ...}``, where each value is the
        output of ``compute_metrics`` on the test set.
    """
    print(f"\nEvaluating: {name}...")
    fn, _, _, _ = synthetic.get(name)
    # One subkey per random consumer. Data sampling and the permutation must not
    # share a key: JAX keys must never be reused, and a shared key could correlate
    # the sampled inputs with the train/test assignment.
    key_data, key_perm, key_smc, key_init = jax.random.split(key, 4)

    # --- Data: draw all points, then split them at random into train/test.
    X_all, Y_all, _ = synthetic.sample(name, N=cfg.N_train + cfg.N_test, noise=cfg.noise_std, domain=cfg.domain, key=key_data)
    X_all = X_all[:, None]  # column-vector inputs, matching the (n, 1) layout of Z0 and X_plot
    perm = jax.random.permutation(key_perm, X_all.shape[0])
    train_idx, test_idx = perm[:cfg.N_train], perm[cfg.N_train:]
    X_tr, Y_tr = X_all[train_idx], Y_all[train_idx]
    X_te, Y_te = X_all[test_idx], Y_all[test_idx]
    X_plot = jnp.linspace(cfg.domain[0], cfg.domain[1], 240)[:, None]
    Y_plot = fn(X_plot[:, 0])

    # --- Shared starting point: M evenly spaced inducing inputs, unit lengthscale and
    # signal variance, noise variance at its true value, jitter 1e-6.
    Z0 = jnp.linspace(cfg.domain[0], cfg.domain[1], cfg.M)[:, None]
    phi_init = Phi(KernelParams(jnp.array(1.0), jnp.array(1.0)), Z0, {'noise_var': jnp.array(cfg.noise_std**2)}, 1e-6)
    hyperprior_fn = make_hyperprior(
        kernel_log_lambda=4.0,
        kernel_fields=["lengthscale", "variance"],
        likelihood_log_lambda=4.0,
        likelihood_keys=["noise_var"],
        likelihood_log_mu={'noise_var': jnp.log(cfg.noise_std**2)},
    )

    # --- ML-II: optimise the sparse-GP objective (FITC residual).
    typeii = TypeII(cfg=TypeIICFG(steps=cfg.typeii_steps, lr=cfg.typeii_lr, optimizer='adam', jit=True))
    vfe_obj = make_vfe_objective(kernel_fn=rbf, residual='fitc')
    phi_ml = typeii.run(energy=vfe_obj, phi_init=phi_init, energy_args=(X_tr, Y_tr)).phi

    # --- MAP-II: the same objective plus the hyperprior energy.
    def map_ii_obj(phi, X, y):
        return vfe_obj(phi, X, y) + hyperprior_fn(phi)

    phi_map = typeii.run(energy=map_ii_obj, phi_init=phi_init, energy_args=(X_tr, Y_tr)).phi

    # --- SMC over flat particles [log ell, log sf2, log sn2, vec(Z)] (see unpack_state).
    energy_fn = get_energy_fn(X_tr, Y_tr, hyperprior_fn, Z0.shape, phi_init.jitter)

    def init_particles(k, n):
        """Draw n particles by adding Gaussian noise to phi_init (logs: sd 0.5, Z: sd 0.2)."""
        kl, kv, kn, kz = jax.random.split(k, 4)
        log_l = jnp.log(phi_init.kernel_params.lengthscale) + jax.random.normal(kl, (n,)) * 0.5
        log_v = jnp.log(phi_init.kernel_params.variance) + jax.random.normal(kv, (n,)) * 0.5
        log_n = jnp.log(phi_init.likelihood_params['noise_var']) + jax.random.normal(kn, (n,)) * 0.5
        Z_noisy = Z0[None] + 0.2 * jax.random.normal(kz, (n, *Z0.shape))
        return jnp.concatenate([log_l[:, None], log_v[:, None], log_n[:, None], Z_noisy.reshape(n, -1)], axis=1)

    smc_res = annealed_smc_array(
        key=key_smc,
        init_particles=init_particles(key_init, cfg.n_particles),
        energy_fn=energy_fn,
        n_steps=cfg.n_steps,
        ess_threshold=cfg.ess_threshold,
        step_size=cfg.step_size,
        n_leapfrog=cfg.n_leapfrog,
        rejuvenation_steps=cfg.rejuvenation_steps,
    )

    # --- Predictive means on the plotting grid (ML-II / MAP-II variances are not plotted).
    ml_mean, _ = predict_typeii(phi_ml, X_plot, X_tr, Y_tr, rbf, residual='fitc')
    map_mean, _ = predict_typeii(phi_map, X_plot, X_tr, Y_tr, rbf, residual='fitc')
    smc_mean, smc_std, _, smc_weights = predict_bma(smc_res['particles'], smc_res['logw'], X_plot, X_tr, Y_tr, Z0.shape, phi_init.jitter)

    # --- Comparison plot
    x_grid = X_plot[:, 0]
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.scatter(X_tr[:, 0], Y_tr, s=10, alpha=0.3, color="C7", label="Train")
    ax.plot(x_grid, Y_plot, "k--", alpha=0.5, label="True f")
    ax.plot(x_grid, ml_mean, color="C1", lw=2, label="ML-II")
    ax.plot(x_grid, map_mean, color="C2", lw=2, label="MAP-II")
    ax.plot(x_grid, smc_mean, color="C0", lw=2, label="SMC-VFE")
    ax.fill_between(x_grid, smc_mean - 2 * smc_std, smc_mean + 2 * smc_std, alpha=0.1, color="C0")

    # Inducing-input markers. The SMC marker is the weight-averaged Z over particles.
    # That is only a rough visual summary: nothing here keeps inducing points aligned
    # across particles, so they may swap order between particles.
    smc_Z = smc_res['particles'][:, 3:].reshape(-1, *Z0.shape)
    smc_Z_mean = (smc_weights[:, None, None] * smc_Z).sum(axis=0)
    for Z, color, ls, alpha in ((phi_ml.Z, "C1", ":", 0.2), (phi_map.Z, "C2", "--", 0.2), (smc_Z_mean, "C0", "-", 0.1)):
        for z in Z.flatten():
            ax.axvline(float(z), color=color, ls=ls, alpha=alpha)

    ax.set_title(f"{name}: Methods Comparison")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.tight_layout()
    plt.show()

    # --- Test-set metrics. Variances are floored at 1e-12 before the sqrt, as in
    # predict_bma, so round-off negatives cannot produce NaN standard deviations.
    m_te_smc, s_te_smc, _, _ = predict_bma(smc_res['particles'], smc_res['logw'], X_te, X_tr, Y_tr, Z0.shape, phi_init.jitter)
    m_te_ml, v_te_ml = predict_typeii(phi_ml, X_te, X_tr, Y_tr, rbf, residual='fitc')
    m_te_map, v_te_map = predict_typeii(phi_map, X_te, X_tr, Y_tr, rbf, residual='fitc')
    # NOTE(review): confirm predict_typeii returns the predictive variance of y (including
    # noise); if it is the latent-f variance, NLPD against noisy Y_te is overconfident.
    return {
        'ML-II': compute_metrics(Y_te, m_te_ml, jnp.sqrt(jnp.maximum(v_te_ml, 1e-12))),
        'MAP-II': compute_metrics(Y_te, m_te_map, jnp.sqrt(jnp.maximum(v_te_map, 1e-12))),
        'SMC-VFE': compute_metrics(Y_te, m_te_smc, s_te_smc),
    }

# %% [markdown]
# ## Run Benchmark on Selected Functions
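# We benchmark a representative subset of the library: a non-stationary frequency pattern, a piecewise kink, a step with local variation, and a spike train. Each run plots the three fits against the true function and reports test-set RMSE and NLPD.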

In [None]:
# Datasets to benchmark. Each run gets a fresh subkey derived from a fixed master
# seed, so Restart & Run All should reproduce the same numbers.
benchmark_set = [
    'nonstationary_frequency',
    'piecewise_kink',
    'step_local_variation',
    'spike_train',
]
key = jax.random.key(123)
results = {}

for name in benchmark_set:
    key, subkey = jax.random.split(key)  # advance the master key; each subkey is used once
    results[name] = run_full_benchmark(name, subkey)

# %% [markdown]
# ## Final Summary Table

In [None]:
# Fixed-width text table with one row per (dataset, method) pair.
rule_heavy = "=" * 60
rule_light = "-" * 60
header = f"{'Dataset':<25} | {'Method':<10} | {'RMSE':<8} | {'NLPD':<8}"

print("\n" + rule_heavy)
print(header)
print(rule_light)
for ds, res in results.items():
    for method in ('ML-II', 'MAP-II', 'SMC-VFE'):
        m = res[method]
        print(f"{ds:<25} | {method:<10} | {m['rmse']:<8.4f} | {m['nlpd']:<8.4f}")
print(rule_heavy)
