# 03 - Model Diagnostics & Convergence Troubleshooting

Bayesian inference via MCMC only produces valid results when the sampler **converges**.
Non-convergent chains explore different regions of the posterior and give contradictory estimates.
This notebook:

1. Loads the fitted model from notebook 02
2. Runs a full diagnostic suite (R-hat, ESS, divergences, max tree depth)
3. Interprets each metric visually
4. Diagnoses the root cause of the convergence failure
5. Implements remediation (tighter priors + better sampling settings)
6. Refits a v2 model and compares results

## Convergence Thresholds (CLAUDE.md)

| Metric | Threshold | Meaning if violated |
|--------|-----------|--------------------|
| **R-hat** | < 1.01 | Chains disagree, often because the parameter is poorly identified |
| **ESS (bulk)** | > 400 | Too few independent samples — estimates unreliable |
| **Divergences** | = 0 | Sampler hit a pathological region of the posterior |
| **Max tree depth hits** | < 5% of draws | NUTS trajectories are cut short, usually because the posterior is highly correlated or badly scaled |
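
The table can be read as a single pass/fail rule. Below is a minimal standalone sketch that assumes the summary numbers have already been extracted. `ConvergenceReport` and `failures` are hypothetical names and do not describe how `mmm_demo.diagnostics.check_convergence` is implemented. The 5% tree-depth cutoff is also an assumption, chosen to match the warning printed later in this notebook.

```python
from dataclasses import dataclass

# Thresholds from the table above (the 5% tree-depth cutoff is an assumption)
RHAT_MAX = 1.01
ESS_MIN = 400
MAX_TREEDEPTH_FRAC = 0.05


@dataclass
class ConvergenceReport:  # hypothetical container, not mmm_demo's class
    max_rhat: float
    min_ess_bulk: float
    divergences: int
    frac_at_max_treedepth: float

    def failures(self) -> list:
        """Return the name of every threshold this fit violates."""
        failed = []
        if self.max_rhat >= RHAT_MAX:
            failed.append("r_hat")
        if self.min_ess_bulk <= ESS_MIN:
            failed.append("ess_bulk")
        if self.divergences != 0:
            failed.append("divergences")
        if self.frac_at_max_treedepth >= MAX_TREEDEPTH_FRAC:
            failed.append("tree_depth")
        return failed

    @property
    def passed(self) -> bool:
        return not self.failures()


good = ConvergenceReport(max_rhat=1.003, min_ess_bulk=1850, divergences=0, frac_at_max_treedepth=0.0)
bad = ConvergenceReport(max_rhat=4.1, min_ess_bulk=4, divergences=12, frac_at_max_treedepth=0.30)
print(good.passed, bad.failures())
```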

In [None]:
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

warnings.filterwarnings("ignore")
sns.set_theme(style="whitegrid")
%matplotlib inline

from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation
from pymc_marketing.prior import Prior

from mmm_demo.config import OUTPUTS_DIR, ModelConfig
from mmm_demo.data import load_mmm_weekly_data
from mmm_demo.diagnostics import ESS_THRESHOLD, RHAT_THRESHOLD, check_convergence
from mmm_demo.model import fit_model

---
## 1. Load the Saved Model

> **Prerequisite:** Run notebook 02 first to generate `outputs/models/mmm_fit_*.nc`.

In [None]:
# Load weekly data (needed for the root-cause analysis and for refitting the v2 model)
df = load_mmm_weekly_data()
config = ModelConfig()
feature_cols = [config.date_column, *config.channel_columns, *config.control_columns]
X = df[feature_cols]
y = df[config.target_column]

# Find most recent saved model
model_dir = OUTPUTS_DIR / "models"
model_files = sorted(model_dir.glob("mmm_fit_*.nc"))
if not model_files:
    raise FileNotFoundError(
        "No saved model found in outputs/models/. Run notebook 02 first."
    )

model_path = model_files[-1]
print(f"Loading: {model_path.name}")
mmm = MMM.load(str(model_path))
idata = mmm.idata

print("Model loaded.")
print(f"Chains: {idata.posterior.sizes['chain']}, Draws: {idata.posterior.sizes['draw']}")
print(f"Parameters: {list(idata.posterior.data_vars)}")

---
## 2. Convergence Diagnostics

In [None]:
diag = check_convergence(idata)

print("=" * 50)
print("CONVERGENCE DIAGNOSTIC SUMMARY")
print("=" * 50)
print(f"Overall:     {'PASSED' if diag.passed else 'FAILED'}")
print()
print(
    f"R-hat        max={diag.max_rhat:.4f}  threshold<{RHAT_THRESHOLD}  {'OK' if diag.rhat_ok else 'FAIL'}"
)
print(
    f"ESS (bulk)   min={diag.min_ess:.0f}   threshold>{ESS_THRESHOLD}  {'OK' if diag.ess_ok else 'FAIL'}"
)
print(
    f"Divergences  {diag.divergences}        threshold=0      {'OK' if diag.divergences == 0 else 'FAIL'}"
)

In [None]:
# Full parameter summary — sorted by R-hat descending
summary = az.summary(idata)

# Structural parameters only (exclude derived mu[date] rows)
param_mask = ~summary.index.str.startswith("mu[")
param_summary = summary[param_mask].copy()
param_summary_sorted = param_summary.sort_values("r_hat", ascending=False)

print(f"Total parameters in model:       {len(summary)}")
print(f"Structural parameters:           {len(param_summary)}")
print(
    f"With R-hat > {RHAT_THRESHOLD}:           {(param_summary['r_hat'] > RHAT_THRESHOLD).sum()}"
)
print(
    f"With ESS < {ESS_THRESHOLD}:             {(param_summary['ess_bulk'] < ESS_THRESHOLD).sum()}"
)
print()
print("Top 10 worst parameters by R-hat:")
param_summary_sorted[["mean", "sd", "hdi_3%", "hdi_97%", "ess_bulk", "r_hat"]].head(10)

### R-hat (Gelman-Rubin Statistic)

R-hat measures **agreement between chains**. With 4 chains:
- R-hat ≈ **1.00**: chains explored the same distribution (converged)
- R-hat **1.01–1.1**: mild disagreement, borderline
- R-hat **> 1.1**: chains are exploring different regions, often because the parameter is **poorly identified** by the data
- R-hat **> 2.0**: catastrophic non-convergence (common with unidentifiable models)

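An R-hat of 4.1 means the pooled posterior variance estimate is about 4.1² ≈ 17 times the average
within-chain variance. The spread *between* chain means dwarfs the spread *within* each chain, so the four
chains are wandering in largely separate parts of the parameter space.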
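
To make the variance ratio concrete, here is a minimal NumPy sketch of the classic (non-split) Gelman-Rubin statistic, run on simulated chains. ArviZ's `az.summary` reports a rank-normalized, split-chain variant that is more robust, so its values will not match this formula exactly.

```python
import numpy as np


def gelman_rubin_rhat(draws: np.ndarray) -> float:
    """Classic R-hat for an array shaped (n_chains, n_draws)."""
    n = draws.shape[1]
    chain_means = draws.mean(axis=1)
    w = draws.var(axis=1, ddof=1).mean()   # mean within-chain variance W
    b = n * chain_means.var(ddof=1)        # between-chain variance B
    var_plus = (n - 1) / n * w + b / n     # pooled posterior variance estimate
    return float(np.sqrt(var_plus / w))


rng = np.random.default_rng(0)
mixed = rng.normal(0.0, 1.0, size=(4, 1000))               # 4 chains, same target
stuck = mixed + np.array([[0.0], [5.0], [10.0], [15.0]])   # 4 chains, 4 separate regions

print(f"mixed chains: R-hat = {gelman_rubin_rhat(mixed):.3f}")
print(f"stuck chains: R-hat = {gelman_rubin_rhat(stuck):.3f}")
```

Squaring the result gives R-hat² ≈ 1 + B/(nW), which is why an R-hat near 4 implies that the between-chain term is roughly 16 times the within-chain variance.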

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

rhat_colors = [
    "tomato" if r > RHAT_THRESHOLD else "steelblue"
    for r in param_summary_sorted["r_hat"]
]
param_summary_sorted["r_hat"].plot.bar(
    ax=axes[0], color=rhat_colors, edgecolor="black", linewidth=0.4
)
axes[0].axhline(
    RHAT_THRESHOLD,
    color="red",
    linestyle="--",
    linewidth=1.5,
    label=f"Threshold ({RHAT_THRESHOLD})",
)
axes[0].set_title("R-hat per Structural Parameter")
axes[0].set_ylabel("R-hat")
axes[0].tick_params(axis="x", rotation=90, labelsize=7)
axes[0].legend()

ess_colors = [
    "tomato" if e < ESS_THRESHOLD else "steelblue"
    for e in param_summary_sorted["ess_bulk"]
]
param_summary_sorted["ess_bulk"].plot.bar(
    ax=axes[1], color=ess_colors, edgecolor="black", linewidth=0.4
)
axes[1].axhline(
    ESS_THRESHOLD,
    color="red",
    linestyle="--",
    linewidth=1.5,
    label=f"Threshold ({ESS_THRESHOLD})",
)
axes[1].set_title("ESS (bulk) per Structural Parameter")
axes[1].set_ylabel("ESS")
axes[1].tick_params(axis="x", rotation=90, labelsize=7)
axes[1].legend()

plt.suptitle("Convergence Metrics by Parameter (red = failed threshold)", fontsize=13)
plt.tight_layout()
plt.show()

### Effective Sample Size (ESS)

MCMC produces **autocorrelated** samples. ESS is the equivalent number of **independent** samples.
With 4 chains × 1000 draws = 4000 total, ideal ESS ≈ 4000. Our ESS of ~4 means the sampler is
almost completely stuck — consecutive samples are nearly identical (autocorrelation ≈ 1.0).

- **ESS > 400** per parameter: reliable posterior estimates
- **ESS < 100**: estimates are noisy and quantiles are unreliable
- **ESS < 10**: the sampler is stuck — results are essentially meaningless
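
The sketch below shows how autocorrelation shrinks ESS. It uses a simplified single-chain estimator: N divided by the integrated autocorrelation time, with the autocorrelation sum truncated at the first non-positive lag. This is an illustration only. ArviZ's `ess_bulk` rank-normalizes the draws and combines chains, so its numbers will differ. For an AR(1) chain with coefficient φ, the theoretical ESS is N(1 − φ)/(1 + φ).

```python
import numpy as np


def ess_single_chain(x: np.ndarray) -> float:
    """Simplified ESS: N / (1 + 2 * sum of positive-lag autocorrelations)."""
    n = len(x)
    x = x - x.mean()
    # Autocovariance via FFT, zero-padded to avoid circular wrap-around
    f = np.fft.rfft(x, n=2 * n)
    acov = np.fft.irfft(f * np.conj(f))[:n] / n
    rho = acov / acov[0]
    tau = 1.0
    for lag in range(1, n):
        if rho[lag] <= 0:  # stop at the first non-positive autocorrelation
            break
        tau += 2.0 * rho[lag]
    return n / tau


def ar1_chain(phi: float, n: int, rng: np.random.Generator) -> np.ndarray:
    """Simulate x_t = phi * x_{t-1} + noise, a stand-in for an MCMC trace."""
    x = np.empty(n)
    x[0] = rng.normal()
    for t in range(1, n):
        x[t] = phi * x[t - 1] + rng.normal()
    return x


rng = np.random.default_rng(1)
n = 4000
results = {}
for phi in (0.0, 0.9, 0.99):
    est = ess_single_chain(ar1_chain(phi, n, rng))
    theory = n * (1 - phi) / (1 + phi)
    results[phi] = (est, theory)
    print(f"phi={phi:<5} ESS estimate={est:8.1f}   theory={theory:8.1f}")
```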

In [None]:
print("ESS (bulk) statistics for structural parameters:")
print(f"  Min:  {param_summary['ess_bulk'].min():.0f}")
print(f"  Mean: {param_summary['ess_bulk'].mean():.0f}")
print(f"  Max:  {param_summary['ess_bulk'].max():.0f}")
print()
print("Parameters with lowest ESS:")
param_summary.sort_values("ess_bulk")[
    ["mean", "sd", "ess_bulk", "ess_tail", "r_hat"]
].head(10)

### Divergences and Maximum Tree Depth

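**Divergences** occur when the leapfrog integrator inside NUTS becomes numerically unstable. The Hamiltonian
(total energy) should stay nearly constant along a trajectory, but here it blows up. This usually means the
step size is too large for a region of high curvature, such as a narrow funnel or a sharply bent ridge in the posterior.

**Max tree depth**: NUTS builds each trajectory as a binary tree and doubles its length until the path starts to
turn back on itself. If the tree hits the depth limit first, the trajectory is cut short. Each draw then moves less
far than NUTS intended, and autocorrelation rises. PyMC's default `max_treedepth` is 10, i.e. up to about 2^10 = 1024
leapfrog steps per draw.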
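
Divergences can be reproduced on the simplest possible target. The sketch below runs the leapfrog integrator, which is the building block of NUTS, on a 1-D Gaussian with standard deviation σ and tracks the error in the Hamiltonian. For this target, leapfrog is stable only when the step size is below 2σ. A step that is too large for a narrow direction therefore makes the energy error explode. The cutoff of 1000 used to label a trajectory "divergent" is an assumption, chosen to resemble the default energy-change limit used by PyMC and Stan.

```python
import numpy as np


def leapfrog_energy_error(sigma: float, step: float, n_steps: int) -> float:
    """Max |H(q, p) - H(q0, p0)| along a leapfrog trajectory for N(0, sigma^2)."""

    def grad_u(q):  # gradient of U(q) = q^2 / (2 sigma^2)
        return q / sigma**2

    def hamiltonian(q, p):  # potential + kinetic energy (unit mass)
        return q**2 / (2 * sigma**2) + p**2 / 2

    q, p = 1.0, 1.0
    h0 = hamiltonian(q, p)
    max_err = 0.0
    for _ in range(n_steps):
        p -= 0.5 * step * grad_u(q)  # half step in momentum
        q += step * p                # full step in position
        p -= 0.5 * step * grad_u(q)  # half step in momentum
        max_err = max(max_err, abs(hamiltonian(q, p) - h0))
    return max_err


sigma = 0.1  # a "narrow" direction of the posterior
errors = {}
for step in (0.05, 0.15, 0.25):  # stability limit is 2 * sigma = 0.2
    errors[step] = leapfrog_energy_error(sigma, step, n_steps=64)
    print(f"step={step:.2f}  max energy error={errors[step]:.3g}  divergent={errors[step] > 1000}")
```

Shrinking the step size, which is what a higher `target_accept` does, removes the instability. The cost is that more leapfrog steps (a deeper tree) are needed to cover the same distance.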

In [None]:
if hasattr(idata, "sample_stats"):
    ss = idata.sample_stats

    # Divergences
    if "diverging" in ss:
        total_div = int(ss["diverging"].sum().values)
        div_per_chain = ss["diverging"].sum(dim="draw").values
        print(f"Divergent transitions: {total_div} total")
        print(f"  Per chain: {div_per_chain}")
    else:
        print("No divergence info in sample_stats")

    print()

    # Tree depth
    if "tree_depth" in ss:
        max_depth_reached = int(ss["tree_depth"].max().values)
        mean_depth = float(ss["tree_depth"].mean().values)
        if "reached_max_treedepth" in ss:
            # PyMC records a per-draw flag, so the configured limit need not be assumed
            pct_at_max = float(ss["reached_max_treedepth"].mean().values) * 100
        else:
            # Fall back to PyMC's default max_treedepth of 10
            default_max = 10
            pct_at_max = float((ss["tree_depth"] >= default_max).mean().values) * 100
        print("Tree depth statistics:")
        print(f"  Max depth reached: {max_depth_reached}")
        print(f"  Mean depth: {mean_depth:.1f}")
        print(f"  % draws at max depth: {pct_at_max:.1f}%")
        if pct_at_max > 5:
            print(
                "  WARNING: >5% draws hit max_treedepth; NUTS trajectories are being truncated"
            )
            print("  Remedy: increase max_treedepth, reparameterize, or tighten priors")
            print(
                "  (a higher target_accept shrinks the step size and usually INCREASES tree depth)"
            )

---
## 3. Visual Diagnostics

Numbers tell us *that* convergence failed. Plots tell us *how* it failed.

In [None]:
# Trace plots: left = marginal KDE per chain, right = samples over time
# Converged: KDEs overlap, traces look like 'fuzzy caterpillars'
# Non-converged: KDEs are separated, traces drift or get stuck
key_params = ["intercept", "adstock_alpha", "saturation_lam", "saturation_beta"]

axes = az.plot_trace(idata, var_names=key_params, compact=True, figsize=(12, 10))
plt.suptitle(
    "Trace Plots — Key Parameters\n"
    "(LEFT: KDE per chain should overlap | RIGHT: trace should look like fuzzy caterpillar)",
    y=1.02,
    fontsize=11,
)
plt.tight_layout()
plt.show()

In [None]:
# Rank plots: uniformly distributed ranks = converged chains
# Non-uniform bars = one chain dominates a region of the posterior
axes = az.plot_rank(idata, var_names=key_params, kind="bars", figsize=(12, 8))
plt.suptitle(
    "Rank Plots — Key Parameters\n"
    "(Uniform bars = converged | Peaked/sloped bars = chain imbalance)",
    y=1.02,
    fontsize=11,
)
plt.tight_layout()
plt.show()
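
The rank plot can also be computed directly: pool all draws, rank them, and histogram each chain's ranks. The minimal NumPy sketch below does this on simulated chains, with one chain deliberately shifted. `rank_histograms` is a hypothetical helper, not an ArviZ function.

```python
import numpy as np


def rank_histograms(draws: np.ndarray, n_bins: int = 10) -> np.ndarray:
    """Per-chain histogram of pooled ranks for draws shaped (n_chains, n_draws)."""
    n_chains, n_draws = draws.shape
    # argsort of argsort gives the 0-based rank of every draw among all pooled draws
    ranks = draws.ravel().argsort().argsort().reshape(n_chains, n_draws)
    edges = np.linspace(0, n_chains * n_draws, n_bins + 1)
    return np.stack([np.histogram(r, bins=edges)[0] for r in ranks])


rng = np.random.default_rng(2)
draws = rng.normal(size=(4, 1000))
draws[3] += 2.0  # chain 3 is stuck in a different region
counts = rank_histograms(draws)
print(counts)  # with well-mixed chains every bin would hold roughly 100 draws
```

In the output, chain 3's row piles up in the highest-rank bins and is nearly empty in the lowest. This is the "sloped bars" pattern the rank plot flags.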

In [None]:
# Energy plot: BFMI (Bayesian Fraction of Missing Information)
# Marginal energy distribution should match the transition energy distribution
# A mismatch suggests the sampler cannot explore the full energy landscape
ax = az.plot_energy(idata, figsize=(10, 4))
plt.title(
    "Energy Plot (BFMI Diagnostic)\n"
    "Marginal energy and energy transition distributions should overlap.\n"
    "BFMI < 0.3 indicates the sampler cannot freely explore the posterior."
)
plt.tight_layout()
plt.show()

try:
    bfmi = az.bfmi(idata)
    print(f"BFMI per chain: {bfmi.round(3)}")
    print("(Rule of thumb: BFMI < 0.3 is problematic)")
except Exception as e:
    print(f"BFMI calculation: {e}")
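
E-BFMI compares the typical energy change between successive iterations with the overall spread of the energy. The sketch below applies the formula on simulated energies rather than real sampler output: the mean squared successive difference divided by the variance, computed per chain. This is our understanding of what `az.bfmi` computes.

```python
import numpy as np


def e_bfmi(energy: np.ndarray) -> np.ndarray:
    """E-BFMI per chain for energies shaped (n_chains, n_draws)."""
    energy = np.atleast_2d(energy)
    num = np.square(np.diff(energy, axis=1)).mean(axis=1)  # successive energy jumps
    den = np.var(energy, axis=1)                           # marginal energy spread
    return num / den


rng = np.random.default_rng(3)
independent = rng.normal(size=(2, 2000))                # energy resampled freely each step
sticky = np.cumsum(rng.normal(size=(2, 2000)), axis=1)  # energy drifts slowly (random walk)
print("independent:", e_bfmi(independent).round(2))     # close to 2
print("sticky:     ", e_bfmi(sticky).round(3))          # far below 0.3
```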

---
## 4. Root Cause Analysis

The diagnostic numbers are clear: the model has not converged. But **why?**

Let's investigate the data and model structure to understand the fundamental problem.

In [None]:
# ROOT CAUSE 1: Monthly patterns repeat — effective N is only 12
# Media spend is distributed pro-rata across weeks within each month,
# so all weeks in the same month have IDENTICAL channel spend values.

weekly_media = df[[config.date_column, *config.channel_columns]].copy()
weekly_media["month"] = pd.to_datetime(weekly_media[config.date_column]).dt.strftime("%Y-%m")

print("First 12 rows of weekly media spend:")
print("Note how spend values repeat identically within each month.")
print()
print(weekly_media.head(12).to_string(index=False))

print()
unique_patterns = weekly_media[config.channel_columns].drop_duplicates()
print(f"Total rows:              {len(weekly_media)}")
print(f"Unique media patterns:   {len(unique_patterns)}")
print(f"Effective N (for media): {len(unique_patterns)}")
print()
print(
    f"The model sees {len(weekly_media)} observations but only "
    f"{len(unique_patterns)} DISTINCT media inputs."
)
print("This is the most likely primary driver of non-convergence.")
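
The repeated-pattern effect is easy to reproduce. The toy sketch below uses made-up monthly spend for a single channel. It assumes each week belongs to the month of its start date and receives an equal share of that month's spend; the allocation rule actually used upstream in `load_mmm_weekly_data` may differ.

```python
import random
from collections import Counter
from datetime import date, timedelta

random.seed(4)
weeks = [date(2015, 7, 6) + timedelta(weeks=i) for i in range(52)]
month_of = [d.strftime("%Y-%m") for d in weeks]

# Hypothetical monthly spend for one channel
monthly_spend = {m: round(random.uniform(1e5, 1e6), 2) for m in sorted(set(month_of))}
weeks_per_month = Counter(month_of)

# Pro-rata: every week in a month receives the same share of that month's spend
weekly_spend = [monthly_spend[m] / weeks_per_month[m] for m in month_of]

print(f"weekly rows:           {len(weekly_spend)}")
print(f"distinct spend values: {len(set(weekly_spend))}")
print(f"distinct months:       {len(monthly_spend)}")
```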

In [None]:
# ROOT CAUSE 2: Too many parameters for too little data
n_channels = len(config.channel_columns)
n_params_per_channel = 3  # adstock_alpha, saturation_lam, saturation_beta
channel_params = n_channels * n_params_per_channel
control_params = len(config.control_columns)
other_params = 2  # intercept, sigma
total_params = channel_params + control_params + other_params

n_obs = len(df)
n_effective = len(unique_patterns)  # distinct media patterns from the cell above

print("Parameter count vs effective observations")
print("=" * 45)
print(f"Channels:              {n_channels}")
print(
    f"Params per channel:    {n_params_per_channel} (adstock_alpha, saturation_lam, saturation_beta)"
)
print(f"Channel params:        {channel_params}")
print(f"Control params:        {control_params}")
print(f"Intercept + sigma:     {other_params}")
print(f"Total structural:      {total_params}")
print()
print(f"Total observations:    {n_obs} (weekly rows)")
print(f"Effective media obs:   {n_effective} (unique monthly patterns)")
print(
    f"Params / effective N:  {total_params} / {n_effective} = {total_params/n_effective:.1f}"
)
print()
print(
    f"At {total_params / n_effective:.1f} parameters per effective observation, priors carry most of the"
)
print("information; the likelihood cannot identify all parameters.")
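
For intuition, consider the linear analogue. Suppose the 52 rows of a design matrix are copies of only 12 distinct rows. The matrix then has rank at most 12, so no amount of repetition lets least squares pin down 17 coefficients. The MMM is nonlinear (adstock and saturation), so this is only an analogy. The matrix below is random and purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(5)
n_distinct, n_params = 12, 17

distinct_rows = rng.normal(size=(n_distinct, n_params))
# 52 weekly rows, each a copy of one of the 12 "monthly" rows (4 or 5 weeks per month)
row_index = np.repeat(np.arange(n_distinct), [5, 4, 4, 5, 4, 4, 5, 4, 4, 5, 4, 4])
X_weekly = distinct_rows[row_index]

rank = np.linalg.matrix_rank(X_weekly)
print(X_weekly.shape)  # (52, 17)
print(rank)            # 12: only 12 independent directions
print(n_params - rank, "parameter directions unconstrained by the data")
```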

In [None]:
# ROOT CAUSE 3: Extreme variance in GMV creates a difficult likelihood landscape
print("GMV distribution characteristics:")
print(f"  Min:  {y.min():>15,.0f}")
print(f"  Max:  {y.max():>15,.0f}")
print(f"  Mean: {y.mean():>15,.0f}")
print(f"  Std:  {y.std():>15,.0f}")
print(f"  CV:   {y.std()/y.mean():.2f} (coefficient of variation)")
print(f"  Max/Min ratio: {y.max()/y.min():.0f}x")
print()
print(f"A {y.max() / y.min():,.0f}x range from min to max creates a highly irregular likelihood.")
print("The two extreme outlier months (Oct 2015, Mar 2016) dominate the fit.")

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(df[config.date_column], y, marker="o", linewidth=2)
axes[0].set_title("Weekly GMV Over Time")
axes[0].set_ylabel("GMV")
axes[0].tick_params(axis="x", rotation=45)

axes[1].hist(np.log10(y + 1), bins=15, edgecolor="black")
axes[1].set_title("log10(GMV) Distribution")
axes[1].set_xlabel("log10(GMV + 1)")
axes[1].set_ylabel("Count")

plt.tight_layout()
plt.show()

---
## 5. Remediation Strategies

Given the root causes, here are the available levers:

| Strategy | What it does | Expected impact |
|----------|-------------|----------------|
| **Increase `target_accept`** (0.95 → 0.99) | Smaller NUTS steps | Fewer divergences; longer trajectories, so more tree-depth hits and slower sampling |
| **Increase `tune`** (2000 → 3000) | More adaptation iterations to find good step size | Better mass matrix estimate |
| **Tighter `saturation_beta` prior** (HalfNormal(2) → HalfNormal(0.5)) | Constrains channel contributions to plausible range | Reduces posterior volume |
| **Reduce `adstock_max_lag`** (4 → 2) | Shorter carryover window; `GeometricAdstock` still has one `alpha` per channel, so the parameter count is unchanged | Not tried (current value is 4) |
| **Use monthly data** | 12 rows, 1:1 with media | Lose weekly variation |

### Why tighten `saturation_beta`?

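PyMC-Marketing scales the channel spend and the target by their maximum absolute value (scikit-learn's
`MaxAbsScaler`). After scaling:
- The target `total_gmv` is in [0, 1]
- `LogisticSaturation` outputs values below 1, so `saturation_beta` is roughly a channel's maximum contribution
  on this scale. The intercept plus all channel contributions should therefore stay around 1.0 or below
- `HalfNormal(sigma=2)` puts about 62% of its prior mass above 1.0, which is far too wide
- `HalfNormal(sigma=0.5)` puts about 68% of its mass in [0, 0.5] and under 5% above 1.0, which is much more realistic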
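
The sketch below shows what max-abs scaling does to a target with the range described in section 4. The numbers are synthetic (log-uniform over a 25,000x range), not the notebook's data. Dividing by the largest absolute value mirrors what scikit-learn's `MaxAbsScaler` does.

```python
import numpy as np

rng = np.random.default_rng(6)
# Synthetic weekly GMV, log-uniform between 1e3 and 2.5e7
gmv = 10 ** rng.uniform(3, np.log10(2.5e7), size=52)
gmv[[0, -1]] = [1e3, 2.5e7]  # pin the extremes so the range is exactly 25,000x

scaled = gmv / np.abs(gmv).max()  # max-abs scaling: divide by the largest |value|
print(f"scaled range: [{scaled.min():.2e}, {scaled.max():.2f}]")
print(f"share of weeks below 0.1: {(scaled < 0.1).mean():.0%}")
# With the scaled target capped at 1, a channel beta well above 1 would claim
# more than the largest week's entire GMV at full saturation.
```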

In [None]:
# Visualize prior comparison: HalfNormal(2) vs HalfNormal(0.5)
from scipy.stats import halfnorm

x_prior = np.linspace(0, 5, 300)
fig, ax = plt.subplots(figsize=(10, 4))

ax.plot(
    x_prior,
    halfnorm.pdf(x_prior, scale=2),
    label="HalfNormal(sigma=2) — v1",
    color="tomato",
    linewidth=2,
)
ax.plot(
    x_prior,
    halfnorm.pdf(x_prior, scale=0.5),
    label="HalfNormal(sigma=0.5) — v2",
    color="steelblue",
    linewidth=2,
)
ax.axvline(
    1.0, color="gray", linestyle="--", alpha=0.7, label="Max scaled contribution = 1.0"
)

ax.set_title("Prior for saturation_beta: v1 vs v2")
ax.set_xlabel("saturation_beta value")
ax.set_ylabel("Prior density")
ax.legend()
ax.set_xlim(0, 5)
plt.tight_layout()
plt.show()

print(f"P(beta > 1.0) under HalfNormal(2):   {halfnorm.sf(1.0, scale=2):.1%}")
print(f"P(beta > 1.0) under HalfNormal(0.5): {halfnorm.sf(1.0, scale=0.5):.1%}")

In [None]:
# Build V2 model with improved settings
config_v2 = ModelConfig(
    target_accept=0.99,  # was 0.95
    tune=3000,  # was 2000
    draws=1000,
)

# Override saturation_beta prior only
model_config_v2 = config_v2.get_model_config()
model_config_v2["saturation_beta"] = Prior("HalfNormal", sigma=0.5)

mmm_v2 = MMM(
    date_column=config_v2.date_column,
    channel_columns=config_v2.channel_columns,
    control_columns=config_v2.control_columns,
    adstock=GeometricAdstock(l_max=config_v2.adstock_max_lag),
    saturation=LogisticSaturation(),
    model_config=model_config_v2,
)

print("V2 config:")
print(f"  target_accept:    {config_v2.target_accept}  (was 0.95)")
print(f"  tune:             {config_v2.tune}   (was 2000)")
print("  saturation_beta:  HalfNormal(sigma=0.5)  (was HalfNormal(sigma=2))")
print()
print("V2 model initialized. Running MCMC sampling (approx 4-6 minutes)...")

In [None]:
%%time
idata_v2 = fit_model(mmm_v2, X, y, config_v2)
print("V2 sampling complete.")

In [None]:
# V2 diagnostics
diag_v2 = check_convergence(idata_v2)

print("V2 DIAGNOSTICS:")
print(f"  Passed:     {'PASSED' if diag_v2.passed else 'FAILED'}")
print(f"  Max R-hat:  {diag_v2.max_rhat:.4f}")
print(f"  Min ESS:    {diag_v2.min_ess:.0f}")
print(f"  Divergences: {diag_v2.divergences}")

print()
axes = az.plot_trace(idata_v2, var_names=key_params, compact=True, figsize=(12, 10))
plt.suptitle("V2 Trace Plots", y=1.02, fontsize=13)
plt.tight_layout()
plt.show()

In [None]:
# Before / after comparison
diag_v1 = check_convergence(idata)

comparison = pd.DataFrame(
    {
        "V1 (original)": {
            "max R-hat": f"{diag_v1.max_rhat:.4f}",
            "min ESS": f"{diag_v1.min_ess:.0f}",
            "divergences": diag_v1.divergences,
            "passed": diag_v1.passed,
            "target_accept": 0.95,
            "tune": 2000,
            "saturation_beta prior": "HalfNormal(sigma=2)",
        },
        "V2 (improved)": {
            "max R-hat": f"{diag_v2.max_rhat:.4f}",
            "min ESS": f"{diag_v2.min_ess:.0f}",
            "divergences": diag_v2.divergences,
            "passed": diag_v2.passed,
            "target_accept": config_v2.target_accept,
            "tune": config_v2.tune,
            "saturation_beta prior": "HalfNormal(sigma=0.5)",
        },
    }
).T

print("BEFORE vs AFTER COMPARISON:")
print(comparison.to_string())
print()

if diag_v2.passed:
    print("V2 model passed all convergence checks; notebooks 04-05 can build on it.")
else:
    print(
        "V2 still shows convergence issues, which likely reflects fundamental data limitations."
    )
    change = "improved" if diag_v2.max_rhat < diag_v1.max_rhat else "did not improve"
    print(
        f"Max R-hat {diag_v2.max_rhat:.3f} (v2) vs {diag_v1.max_rhat:.3f} (v1): {change}."
    )
    print("Posteriors are likely driven more by the priors than by the data.")
    print("Interpretations in notebooks 04-05 must be treated with caution.")

In [None]:
from datetime import datetime

date_str = datetime.now().strftime("%Y-%m-%d")
model_path_v2 = OUTPUTS_DIR / "models" / f"mmm_fit_{date_str}_v2.nc"
model_path_v2.parent.mkdir(parents=True, exist_ok=True)
mmm_v2.save(str(model_path_v2))
print(f"V2 model saved to: {model_path_v2}")
print()
print("Notebooks 04 and 05 will load the most recent .nc file (this v2 model).")

---
## 6. Conclusions

### What we found

The V1 model failed convergence severely (R-hat up to 4.26, ESS as low as 4). The V2 changes target these
failures, and the before/after comparison above shows how much they helped on this run. Whether V2 fully
converges depends on the specific data and random seed.

### Root causes (ranked by severity)

1. **Effective N = 12** — media spend only has 12 unique monthly patterns driving 52 weekly rows.
   The likelihood cannot identify 17 structural parameters from only 12 distinct media inputs.
2. **GMV variance is extreme** — the 25,000x range between min and max weeks creates a
   difficult likelihood landscape that NUTS struggles to explore.
3. **Prior too wide** — `saturation_beta ~ HalfNormal(2)` allows contributions far outside the
   plausible range for MaxAbsScaled data, inflating the posterior volume.

### What the v2 changes did

- **Tighter prior** reduces the volume the sampler needs to explore
- **Higher `target_accept`** forces smaller, more careful NUTS steps
- **More tuning** gives the mass matrix estimator more data to adapt to the posterior shape

### What we cannot fix

With only 12 effective observations for 17 parameters, the posterior will always be
**prior-dominated** to some degree. This is not a modeling failure — it is an honest
reflection of the data. Bayesian priors encode domain knowledge precisely for this situation.
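
To make "prior-dominated" concrete, consider the conjugate normal model. There, the posterior mean is a precision-weighted average of the prior mean and the sample mean. The sketch below is a toy model with known noise, not the MMM, and the prior and noise scales are arbitrary choices. It shows how much weight the prior keeps at n = 12 compared with larger samples.

```python
def prior_weight(n: int, prior_sd: float, noise_sd: float) -> float:
    """Weight on the prior mean in the conjugate normal posterior mean."""
    prior_precision = 1.0 / prior_sd**2
    data_precision = n / noise_sd**2
    return prior_precision / (prior_precision + data_precision)


for n in (12, 52, 365):
    w = prior_weight(n, prior_sd=1.0, noise_sd=4.0)
    print(f"n={n:>3}: posterior mean = {w:.3f} * prior mean + {1 - w:.3f} * sample mean")
```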

### Recommendation for production

- Use **daily data with daily media spend** (if available) to get effective N > 100
- Use **hierarchical priors** across channels to share information
- Use **informative priors** grounded in industry benchmarks for adstock and saturation
- Consider **log-transforming the target** to reduce the extreme variance

> **For this demo:** we proceed with the V2 model in notebooks 04 and 05,
> clearly caveating that results are prior-dominated.