# Bayesian Inference with CMC

This notebook demonstrates the complete Bayesian analysis workflow:

1. NLSQ warm-start optimization
2. CMC (Consensus Monte Carlo) configuration
3. Running NUTS sampling via NumPyro
4. Posterior analysis with ArviZ
5. Comparing NLSQ and CMC uncertainty estimates

**When to use CMC:** for publication-quality uncertainty estimates, multi-modal
posteriors, or uncertainty propagation into derived quantities.

**Always run NLSQ first** as a warm-start — this typically reduces divergence
rates from ~28% (cold start) to below 5%.

---

## 1. Setup

In [None]:
import os
import tempfile
from pathlib import Path

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm

from homodyne.config import ConfigManager
from homodyne.optimization.cmc import fit_mcmc_jax
from homodyne.optimization.cmc.config import CMCConfig
from homodyne.optimization.cmc.sampler import SamplingPlan
from homodyne.optimization.nlsq import fit_nlsq_jax
from homodyne.utils.logging import get_logger, log_phase

# Notebook-level logger, plus a quick check of which ArviZ version the
# posterior-analysis cells (az.summary, az.plot_trace) will run against.
logger = get_logger(__name__)
print("ArviZ version:", az.__version__)

## 2. Generate Synthetic Data

In [None]:
rng = np.random.default_rng(seed=777)

# Ground-truth model parameters used to synthesize the correlation data.
TRUE = {
    "D0": 1500.0,  # Å²/s
    "alpha": -0.4,
    "D_offset": 0.03,
    "contrast": 0.10,
    "offset": 1.0,
}

# Scattering geometry and time grid.
q = 0.054  # Å⁻¹
n_t = 35
n_phi = 6
dt = 0.1
t = dt * np.arange(n_t)
phi_deg = np.linspace(0, 300, n_phi)

# Fill the upper triangle (t2 >= t1) of every angle's c2 matrix and mirror it,
# so (t1, t2) and (t2, t1) share one noise draw. The loop order fixes the order
# of RNG draws, which is what makes the dataset reproducible from the seed.
power = TRUE["alpha"] + 1
c2 = np.zeros((n_phi, n_t, n_t))
for phi_idx in range(n_phi):
    for row in range(n_t):
        for col in range(row, n_t):
            t_lo, t_hi = t[row], t[col]
            # Integral of D0 * t**alpha + D_offset over [t_lo, t_hi].
            drift = TRUE["D0"] * (t_hi**power - t_lo**power) / power
            integral = drift + TRUE["D_offset"] * (t_hi - t_lo)
            clean = TRUE["offset"] + TRUE["contrast"] * np.exp(-2 * q**2 * integral)
            sample = clean + 0.003 * rng.standard_normal()
            c2[phi_idx, row, col] = sample
            c2[phi_idx, col, row] = sample

# Bundle everything in the dict layout fit_nlsq_jax consumes.
data = dict(
    c2_exp=c2,
    t1=t,
    t2=t,
    phi_angles_list=phi_deg,
    wavevector_q_list=np.array([q]),
    sigma=0.003 * np.ones_like(c2),
    L=5.0e6,
    dt=dt,
)

print(f"Dataset: {c2.size:,} correlation values  ({n_phi} angles × {n_t}² times)")

## 3. Step 1: NLSQ Warm-Start

In [None]:
config_yaml = """
data:
  file_path: "dummy.h5"
  q_value: 0.054
  dt: 0.1

analysis:
  mode: "static"

optimization:
  method: "nlsq"
  nlsq:
    anti_degeneracy:
      per_angle_mode: "auto"

parameter_space:
  D0:
    initial: 1000.0
    bounds: [0.1, 1.0e5]
  alpha:
    initial: -0.5
    bounds: [-2.0, 1.0]
  D_offset:
    initial: 0.1
    bounds: [0.0, 100.0]
"""

# ConfigManager loads from a path, so persist the YAML to a temp file
# (delete=False keeps it after closing; the final cell removes it).
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as yaml_file:
    yaml_file.write(config_yaml)
config_path = yaml_file.name

config = ConfigManager.from_yaml(config_path)

print("Running NLSQ warm-start...")
with log_phase("NLSQ Warm-Start"):
    nlsq_result = fit_nlsq_jax(data, config)

# Point estimates and their 1-sigma uncertainties, indexed D0, alpha, D_offset.
best_fit = nlsq_result.parameters
best_fit_err = nlsq_result.uncertainties

print("\nNLSQ Result:")
print(f"  Convergence:  {nlsq_result.convergence_status}")
print(f"  chi^2_nu:     {nlsq_result.reduced_chi_squared:.4f}")
print(f"  D0:           {best_fit[0]:.1f} ± {best_fit_err[0]:.1f}")
print(f"  alpha:        {best_fit[1]:.3f} ± {best_fit_err[1]:.3f}")
print(f"  D_offset:     {best_fit[2]:.4f} ± {best_fit_err[2]:.4f}")

## 4. Prepare Pooled Data for CMC

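CMC requires flat (pooled) arrays, not the (n_phi, n_t1, n_t2) format.
Because c2 is symmetric in (t1, t2), only the upper triangle (t2 ≥ t1, diagonal
included) is kept, so each off-diagonal value is passed to the sampler once.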

In [None]:
# Pool every (phi, t1, t2) sample from all angles into element-aligned 1-D arrays.
c2_arr = data["c2_exp"]  # (n_phi, n_t1, n_t2)
phi_arr = data["phi_angles_list"]  # (n_phi,)
t1_arr = data["t1"]
t2_arr = data["t2"]

PHI, T1, T2 = np.meshgrid(phi_arr, t1_arr, t2_arr, indexing="ij")

# c2 is symmetric in (t1, t2): keep t2 >= t1 (diagonal included) so no
# off-diagonal value is counted twice.
keep = T2.ravel() >= T1.ravel()
c2_pooled, phi_pooled, t1_pooled, t2_pooled = (
    grid.ravel()[keep] for grid in (c2_arr, PHI, T1, T2)
)

# Scalar geometry inputs for fit_mcmc_jax, with demo defaults if absent.
q_val = float(data["wavevector_q_list"][0])
L_val = float(data.get("L", 5.0e6))
dt_val = float(data.get("dt", 0.1))

print(
    f"Pooled data: {len(c2_pooled):,} points (from {c2_arr.size:,} total, upper triangle only)"
)

## 5. CMC Configuration

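Configure CMC for this dataset size and analysis mode, then preview the warmup
and sample counts that `SamplingPlan` would choose for a representative shard.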

In [None]:
# CMC configuration: demo-sized sampling settings, with production values noted inline.
cmc_config = {
    "max_points_per_shard": "auto",  # ALWAYS use auto
    "sharding_strategy": "stratified",
    "num_warmup": 200,  # Small for demo; use 500+ for production
    "num_samples": 500,  # Small for demo; use 1500+ for production
    "num_chains": 2,  # 2 for demo; use 4 for production
    "max_tree_depth": 8,
    "adaptive_sampling": True,
    "per_angle_mode": "auto",  # Match NLSQ setting
    "validation": {
        "max_divergence_rate": 0.15,  # Slightly relaxed for demo
    },
}

# Preview what SamplingPlan would use for a representative shard.
cfg_obj = CMCConfig.from_dict(cmc_config)

# Representative shard size: at most 5000 points, or the whole pooled dataset
# when it is smaller.
PREVIEW_SHARD_CAP = 5000
# NOTE(review): 5 presumably counts the static-mode model parameters
# (contrast, offset, D0, alpha, D_offset) — confirm against the CMC model.
N_PARAMS_PREVIEW = 5

shard_size_estimate = min(PREVIEW_SHARD_CAP, len(c2_pooled))
plan = SamplingPlan.from_config(
    cfg_obj, shard_size=shard_size_estimate, n_params=N_PARAMS_PREVIEW
)
print(f"SamplingPlan for shard_size={shard_size_estimate}:")
print(f"  Warmup:   {plan.n_warmup}")
print(f"  Samples:  {plan.n_samples}")
print(f"  Adapted:  {plan.was_adapted}")

## 6. Run CMC

This is the main CMC call. The NLSQ result is passed as `nlsq_result` for
warm-start priors.

**Note:** On a real dataset this takes minutes to hours. For this demo notebook
with small n_samples, it should complete in a few minutes.

In [None]:
# Parameter space and starting values come from the same YAML config NLSQ used.
parameter_space = config.get_parameter_space()
initial_values = config.get_initial_parameters()

run_banner = (
    "Running CMC analysis...",
    f"  Dataset: {len(c2_pooled):,} pooled points",
    "  Mode:    static",
    f"  Chains:  {cmc_config['num_chains']}",
    f"  Warmup:  {cmc_config['num_warmup']}, Samples: {cmc_config['num_samples']}",
    "",
)
for banner_line in run_banner:
    print(banner_line)

# All inputs for the CMC run: pooled 1-D arrays, scalar geometry, sampler
# settings, and the NLSQ result as warm-start.
mcmc_kwargs = {
    "data": c2_pooled,
    "t1": t1_pooled,
    "t2": t2_pooled,
    "phi": phi_pooled,
    "q": q_val,
    "L": L_val,
    "analysis_mode": "static",
    "cmc_config": cmc_config,
    "initial_values": initial_values,
    "parameter_space": parameter_space,
    "dt": dt_val,
    "nlsq_result": nlsq_result,  # NLSQ warm-start
    "progress_bar": True,
}
with log_phase("CMC"):
    cmc_result = fit_mcmc_jax(**mcmc_kwargs)

print("\nCMC Result:")
print(f"  Convergence:  {cmc_result.convergence_status}")
print(f"  Divergences:  {cmc_result.divergences}")
print(f"  Time:         {cmc_result.execution_time:.1f} s")

## 7. Convergence Diagnostics

In [None]:
# Report the three standard NUTS health checks: R-hat, bulk ESS, divergences.
print("Convergence Diagnostics")
print("=" * 50)

print("\nR-hat (should be < 1.05 for all parameters):")
all_ok = True
for param, rhat in cmc_result.r_hat.items():
    # Evaluate the pass condition once, so a NaN R-hat (which fails
    # `rhat < 1.05`) is flagged both per parameter AND in the overall verdict.
    rhat_ok = rhat < 1.05
    if not rhat_ok:
        all_ok = False
    status = "OK" if rhat_ok else "WARNING"
    print(f"  {param:<25} {rhat:.4f}  [{status}]")
print(
    f"  → {'All R-hat OK' if all_ok else 'Some R-hat elevated (increase num_warmup)'}"
)

print("\nBulk ESS (should be > 400):")
for param, ess in cmc_result.ess_bulk.items():
    status = "OK" if ess >= 400 else "LOW"
    print(f"  {param:<25} {ess:.0f}  [{status}]")

# Divergence rate as a percentage of post-warmup transitions.
# NOTE(review): assumes cmc_result.divergences is counted over the same
# n_chains * n_samples transitions — confirm for sharded CMC runs.
n_total_transitions = cmc_result.n_chains * cmc_result.n_samples
div_rate = cmc_result.divergences / max(n_total_transitions, 1) * 100
# The "acceptable" limit follows this run's configured validation threshold
# (a fraction in cmc_config) rather than a separately hard-coded 15%.
max_div_pct = 100 * cmc_config.get("validation", {}).get("max_divergence_rate", 0.15)
print(
    f"\nDivergences: {cmc_result.divergences}/{n_total_transitions} = {div_rate:.1f}%"
)
if div_rate < 5:
    print("  → Excellent: < 5% divergences")
elif div_rate < max_div_pct:
    print(f"  → Acceptable: < {max_div_pct:g}% divergences")
else:
    print(
        "  → High divergence rate: check that the NLSQ warm-start converged, "
        "increase num_warmup or max_tree_depth"
    )

## 8. Posterior Analysis with ArviZ

In [None]:
idata = cmc_result.inference_data

# Summarize only parameters present in the posterior group; if none of
# cmc_result.param_names match (e.g. reparameterized variable names), fall back
# to all posterior variables (var_names=None) instead of raising KeyError.
summary_vars = [
    name for name in cmc_result.param_names if name in idata.posterior
] or None

print("Posterior Summary")
print("=" * 60)
# hdi_prob is pinned so the "hdi_3%"/"hdi_97%" columns selected below exist
# regardless of the user's ArviZ rcParams.
summary = az.summary(idata, var_names=summary_vars, hdi_prob=0.94)
print(summary[["mean", "sd", "hdi_3%", "hdi_97%", "r_hat", "ess_bulk"]].to_string())

In [None]:
# Trace plots give a visual check of chain mixing. Only the standard physical
# parameter names are plotted, and only those present in the posterior group.
trace_vars = [name for name in ("D0", "alpha", "D_offset") if name in idata.posterior]
if not trace_vars:
    print(
        "No standard parameter names found in inference data (using reparameterized names)"
    )
    available = list(idata.posterior.data_vars)
    print(f"Available: {available[:5]}")
else:
    az.plot_trace(idata, var_names=trace_vars, compact=True)
    plt.suptitle("MCMC Trace Plots (visual convergence check)", y=1.02)
    plt.tight_layout()
    plt.show()

In [None]:
# Per physical parameter, overlay the CMC posterior histogram, the NLSQ
# Gaussian approximation (mean ± std), and the ground-truth value.
fig, axes = plt.subplots(1, 3, figsize=(13, 4))

# (NLSQ parameter index, variable name, axis label, true value).
# NOTE(review): indices assume nlsq_result.parameters is ordered
# (D0, alpha, D_offset) — confirm against fit_nlsq_jax's output layout.
param_display = [
    (0, "D0", "D₀ (Å²/s)", TRUE["D0"]),
    (1, "alpha", "α", TRUE["alpha"]),
    (2, "D_offset", "D_offset (Å²/s)", TRUE["D_offset"]),
]

for ax, (nlsq_idx, param_name, label, true_val) in zip(axes, param_display):
    # CMC posterior samples (skipped if the name is not in the sample dict).
    if param_name in cmc_result.samples:
        samples = cmc_result.samples[param_name].ravel()
        cmc_mean = np.mean(samples)
        ax.hist(
            samples,
            bins=40,
            density=True,
            alpha=0.6,
            color="orange",
            label="CMC posterior",
        )
        ax.axvline(
            cmc_mean,
            color="orange",
            linestyle="-",
            linewidth=2,
            label=f"CMC mean: {cmc_mean:.3g}",
        )

    # NLSQ Gaussian approximation; a zero or non-finite std has no density,
    # so mark the point estimate instead of plotting a NaN curve.
    nlsq_mean = nlsq_result.parameters[nlsq_idx]
    nlsq_std = nlsq_result.uncertainties[nlsq_idx]
    if np.isfinite(nlsq_std) and nlsq_std > 0:
        x_range = np.linspace(nlsq_mean - 5 * nlsq_std, nlsq_mean + 5 * nlsq_std, 200)
        ax.plot(
            x_range,
            norm.pdf(x_range, nlsq_mean, nlsq_std),
            "b-",
            linewidth=2,
            label=f"NLSQ: {nlsq_mean:.3g} ± {nlsq_std:.3g}",
        )
    else:
        ax.axvline(
            nlsq_mean,
            color="b",
            linestyle="-",
            linewidth=2,
            label=f"NLSQ: {nlsq_mean:.3g} (no std)",
        )

    # True value used to generate the synthetic data.
    ax.axvline(
        true_val,
        color="green",
        linestyle="--",
        linewidth=2,
        label=f"True: {true_val:.3g}",
    )

    ax.set_xlabel(label)
    ax.set_ylabel("Probability density")
    ax.set_title(f"Posterior: {param_name}")
    ax.legend(fontsize=8)

fig.suptitle("NLSQ Gaussian Approximation vs CMC Posterior", fontsize=12)
fig.tight_layout()
plt.show()

## 9. NLSQ vs CMC Comparison

In [None]:
# Side-by-side table of point estimates and spreads, plus the CMC/NLSQ std
# ratio that the interpretation notes below refer to.
col_header = (
    f"{'Parameter':<15} {'True':>10} {'NLSQ mean':>12} {'NLSQ std':>10} "
    f"{'CMC mean':>12} {'CMC std':>10} {'std ratio':>10}"
)
# Rules are sized from the header so they always span the full table width.
rule_width = len(col_header)

print("Parameter Comparison: NLSQ vs CMC")
print("=" * rule_width)
print(col_header)
print("-" * rule_width)

# (parameter name, index into the NLSQ result arrays)
param_map = [("D0", 0), ("alpha", 1), ("D_offset", 2)]

for param_name, idx in param_map:
    true_val = TRUE[param_name]
    nlsq_m = nlsq_result.parameters[idx]
    nlsq_s = nlsq_result.uncertainties[idx]

    if param_name in cmc_result.samples:
        samples = cmc_result.samples[param_name].ravel()
        cmc_m = np.mean(samples)
        cmc_s = np.std(samples)
    else:
        # NOTE(review): fallback assumes cmc_result.parameters shares the NLSQ
        # index for this parameter — confirm the CMC parameter ordering.
        cmc_m = cmc_result.parameters[idx]
        cmc_s = cmc_result.uncertainties[idx]

    # NaN when NLSQ reports no usable (positive) uncertainty.
    std_ratio = cmc_s / nlsq_s if nlsq_s > 0 else float("nan")

    print(
        f"{param_name:<15} {true_val:>10.4g} {nlsq_m:>12.4g} {nlsq_s:>10.4g} "
        f"{cmc_m:>12.4g} {cmc_s:>10.4g} {std_ratio:>10.3g}"
    )

print()
print("Interpretation:")
print("  - CMC std / NLSQ std ratio > 2 → NLSQ underestimates uncertainty")
print("  - Consistent means → posterior is approximately Gaussian (NLSQ sufficient)")
print("  - Inconsistent means → multi-modal or non-Gaussian posterior")

## 10. Save Posterior

In [None]:
output_dir = Path("bayesian_results")
output_dir.mkdir(exist_ok=True)

# Save ArviZ NetCDF (recommended format for posterior)
nc_path = output_dir / "cmc_posterior.nc"
cmc_result.inference_data.to_netcdf(str(nc_path))
print(f"Saved posterior to: {nc_path}")

# Reload and verify. Load eagerly: with ArviZ's default lazy loading the NetCDF
# file stays open, and re-running this cell then fails when to_netcdf tries to
# overwrite a file that is still open for reading.
with az.rc_context(rc={"data.load": "eager"}):
    idata_loaded = az.from_netcdf(str(nc_path))
print(f"Reloaded posterior: {list(idata_loaded.posterior.data_vars)[:5]}")
print("Posterior saved and reloaded successfully.")

## 11. Summary

Key takeaways from Bayesian inference:

- **Always use NLSQ warm-start** (`nlsq_result=nlsq_result`) to reduce divergences
- **R-hat < 1.05** for all parameters before trusting posterior summaries
- **ESS > 400** for reliable uncertainty estimates
- **CMC provides**: full posterior, multi-modal detection, proper uncertainty quantification
- **For production**: use `num_warmup=500, num_samples=1500, num_chains=4`
- **ArviZ** is the standard tool for posterior analysis and visualization

### Recommended production settings:

```yaml
optimization:
  cmc:
    sharding:
      max_points_per_shard: "auto"
    per_shard_mcmc:
      num_warmup: 500
      num_samples: 1500
      num_chains: 4
      max_tree_depth: 10
      chain_method: "parallel"
      adaptive_sampling: true
    per_angle_mode: "auto"
    validation:
      max_divergence_rate: 0.10
```

In [None]:
# Remove the temporary YAML config written before the NLSQ fit. The existence
# check makes this cell safe to re-run once the file is already gone.
if os.path.exists(config_path):
    os.unlink(config_path)