# Notebook 04: MCMC - Metropolis-Hastings & Diagnostics

**Learning Goals:**
- Understand Markov Chain Monte Carlo (MCMC) sampling
- Implement the Metropolis-Hastings algorithm
- Learn proposal tuning (step size, acceptance rate trade-offs)
- Master diagnostics: trace plots, autocorrelation, ESS
- Identify and fix common failure modes

**Runtime:** ~2 minutes

In [None]:
# Setup
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import sys
from scipy.stats import norm, multivariate_normal

repo_root = Path().resolve().parents[2]
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

from modules._import_helper import safe_import_from

set_seed, get_rng = safe_import_from(
    '00_repo_standards.src.mlphys_core.seeding',
    'set_seed', 'get_rng'
)
MetropolisHastings, MCMCDiagnostics = safe_import_from(
    '02_stat_inference_uq.src.mcmc_basics',
    'MetropolisHastings', 'MCMCDiagnostics'
)

set_seed(42)

reports_dir = Path("../reports")
reports_dir.mkdir(exist_ok=True)

print("✅ Setup complete")

## 1. Intuition: Why MCMC?

**The problem:** Many Bayesian posteriors cannot be computed analytically
- Example: $p(\theta | \text{data})$ for complex likelihoods
- Direct sampling is usually infeasible because the normalization constant (the evidence) is unknown
- The required high-dimensional integrals are intractable

**MCMC solution:** Generate samples from the target distribution without knowing its normalization constant
- Build a Markov chain whose stationary distribution is the target
- After "burn-in", samples approximate draws from $p(\theta | \text{data})$
- Use samples to compute expectations: $\mathbb{E}[f(\theta)] \approx \frac{1}{N}\sum_i f(\theta_i)$

**Metropolis-Hastings algorithm:**
1. Propose new state: $\theta' \sim q(\cdot | \theta)$
2. Compute acceptance ratio: $\alpha = \min\left(1, \frac{p(\theta') q(\theta | \theta')}{p(\theta) q(\theta' | \theta)}\right)$
3. Accept $\theta'$ with probability $\alpha$, else stay at $\theta$
4. Repeat

**Key insight:** Only need to compute ratios $p(\theta')/p(\theta)$ → normalization cancels!
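
The four steps above can be sketched in a few lines of NumPy. This is a minimal illustration, not the `MetropolisHastings` class used later in this notebook: the function name `random_walk_metropolis` and its arguments are hypothetical, and it assumes a symmetric Gaussian random-walk proposal, so the $q$ terms in the acceptance ratio cancel. Working with log-densities avoids numerical underflow.

```python
import numpy as np

def random_walk_metropolis(log_prob, x0, n_steps, step_std, rng):
    """Random-walk Metropolis sketch (hypothetical helper, symmetric proposal)."""
    x = np.atleast_1d(np.asarray(x0, dtype=float))
    log_p = log_prob(x)
    samples = np.empty((n_steps, x.size))
    n_accepted = 0
    for t in range(n_steps):
        # Step 1: propose theta' = theta + eps, eps ~ N(0, step_std^2 I)
        proposal = x + step_std * rng.standard_normal(x.size)
        log_p_new = log_prob(proposal)
        # Steps 2-3: accept with probability min(1, pi(theta') / pi(theta)),
        # evaluated in log space
        if np.log(rng.uniform()) < log_p_new - log_p:
            x, log_p = proposal, log_p_new
            n_accepted += 1
        # A rejected proposal repeats the current state in the chain
        samples[t] = x
    return samples, n_accepted / n_steps

def log_target(x):
    """Unnormalized log density of a standard normal."""
    return -0.5 * np.sum(x**2)

rng = np.random.default_rng(0)
samples, accept_rate = random_walk_metropolis(
    log_target, x0=[0.0], n_steps=20_000, step_std=2.4, rng=rng
)
kept = samples[2_000:]  # discard burn-in
print(f"accept rate: {accept_rate:.1%}, mean: {kept.mean():.3f}, std: {kept.std():.3f}")
```

Note that `log_target` never includes the $-\tfrac{1}{2}\log(2\pi)$ constant: it would cancel in `log_p_new - log_p` anyway.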

## 2. Minimal Math

**Target distribution:** $\pi(\theta)$ (unnormalized is fine)

**Proposal distribution:** $q(\theta' | \theta)$ (e.g., random walk: $\theta' = \theta + \epsilon$, $\epsilon \sim \mathcal{N}(0, \sigma^2 I)$)

**For symmetric proposals** ($q(\theta' | \theta) = q(\theta | \theta')$):
$$\alpha = \min\left(1, \frac{\pi(\theta')}{\pi(\theta)}\right)$$

**Acceptance probability:**
- If $\pi(\theta') > \pi(\theta)$ (better state): always accept
- If $\pi(\theta') < \pi(\theta)$ (worse state): accept with probability $\pi(\theta') / \pi(\theta)$
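
As a worked example (assuming the unnormalized standard-normal target $\pi(\theta) \propto e^{-\theta^2/2}$ and a symmetric proposal), a move from $\theta = 1$ to $\theta' = 2$ goes downhill and is accepted with probability
$$\alpha = \min\left(1, \frac{e^{-2^2/2}}{e^{-1^2/2}}\right) = e^{-3/2} \approx 0.22,$$
whereas the reverse move from $\theta = 2$ to $\theta' = 1$ has ratio $e^{3/2} > 1$ and is always accepted.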

**Key diagnostics:**
1. **Acceptance rate**: Fraction of proposals accepted
   - Too high (>80%): proposals too small, slow exploration
   - Too low (<20%): proposals too large, rejecting too often
   - Rule of thumb for random-walk proposals: ~44% in 1D, decreasing toward ~23% in high dimensions

2. **Autocorrelation**: $\rho(k) = \text{Corr}(\theta_t, \theta_{t+k})$
   - High autocorrelation → samples are dependent
   - Want $\rho(k) \to 0$ quickly as $k$ increases

3. **Effective Sample Size (ESS)**: $\text{ESS} \approx \frac{N}{1 + 2\sum_{k=1}^\infty \rho(k)}$
   - Accounts for autocorrelation
   - Higher ESS → more independent information
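
The autocorrelation and ESS formulas above can be estimated directly from a chain. The sketch below is an illustration under stated assumptions, not the `MCMCDiagnostics` implementation: the helper names are hypothetical, and the infinite sum is truncated at the first non-positive autocorrelation estimate, which is one simple heuristic among several. An AR(1) process stands in for MCMC output because its ESS is known in closed form, $N(1-\phi)/(1+\phi)$.

```python
import numpy as np

def autocorrelation(chain, max_lag):
    """Sample autocorrelation rho(k), k = 0..max_lag, of a 1D chain."""
    x = np.asarray(chain, dtype=float)
    x = x - x.mean()
    n = x.size
    denom = np.dot(x, x)  # n times the (biased) sample variance
    return np.array([np.dot(x[:n - k], x[k:]) / denom for k in range(max_lag + 1)])

def effective_sample_size(chain, max_lag=200):
    """ESS = N / (1 + 2 * sum_k rho(k)), truncated at the first rho(k) <= 0."""
    rho = autocorrelation(chain, max_lag)
    tau = 1.0  # integrated autocorrelation time
    for k in range(1, max_lag + 1):
        if rho[k] <= 0:
            break  # beyond this lag the estimates are mostly noise
        tau += 2.0 * rho[k]
    return len(chain) / tau

# AR(1) test chain: x_t = phi * x_{t-1} + noise, so rho(k) = phi**k
rng = np.random.default_rng(0)
phi, n = 0.9, 50_000
noise = rng.standard_normal(n)
chain = np.empty(n)
chain[0] = noise[0]
for t in range(1, n):
    chain[t] = phi * chain[t - 1] + noise[t]

ess = effective_sample_size(chain)
ess_theory = n * (1 - phi) / (1 + phi)
print(f"Estimated ESS: {ess:.0f} (theory: {ess_theory:.0f}, N = {n})")
```

With $\phi = 0.9$ the chain carries only about 5% as much information as the same number of independent draws, which is why ESS, not $N$, should be used to judge precision.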

## 3. Implementation: Sample from 1D Gaussian

In [None]:
# Target: Standard normal N(0, 1)
def log_prob_1d(x):
    """Unnormalized log density of N(0, 1)."""
    return -0.5 * np.sum(x**2)  # Additive constant dropped; sum returns a scalar

# Run MCMC
sampler = MetropolisHastings(
    log_prob_fn=log_prob_1d,
    proposal_std=1.0,
    n_samples=5000,
    n_burn=500,
    random_state=42
)

samples = sampler.sample(x0=np.array([5.0]), verbose=False)

print(f"Acceptance rate: {sampler.accept_rate_:.1%}")
print(f"Sample mean: {samples.mean():.3f} (true: 0.0)")
print(f"Sample std: {samples.std():.3f} (true: 1.0)")
print(f"Number of samples after burn-in: {len(samples)}")

## 4. Experiments: Proposal Tuning & Diagnostics

In [None]:
# Experiment 1: Effect of proposal step size
proposal_stds = [0.1, 0.5, 1.0, 3.0]
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.ravel()

for idx, prop_std in enumerate(proposal_stds):
    ax = axes[idx]
    
    sampler_exp = MetropolisHastings(
        log_prob_fn=log_prob_1d,
        proposal_std=prop_std,
        n_samples=1000,
        n_burn=100,
        random_state=42
    )
    samples_exp = sampler_exp.sample(x0=np.array([5.0]), verbose=False)
    
    # Plot histogram vs true distribution
    ax.hist(samples_exp, bins=30, density=True, alpha=0.6, 
            color='steelblue', edgecolor='black', label='MCMC samples')
    x_range = np.linspace(-4, 4, 200)
    ax.plot(x_range, norm.pdf(x_range), 'r-', linewidth=2, label='True N(0,1)')
    
    ax.set_xlabel('x', fontsize=11)
    ax.set_ylabel('Density', fontsize=11)
    ax.set_title(f'σ_prop = {prop_std:.1f} | Accept rate: {sampler_exp.accept_rate_:.1%}',
                fontsize=12)
    ax.legend(fontsize=9)
    ax.grid(alpha=0.3)
    ax.set_xlim(-4, 4)

plt.tight_layout()
plt.savefig(reports_dir / '04_proposal_tuning.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n✅ Saved: reports/04_proposal_tuning.png")
print("\n📊 Observations (compare with the acceptance rates in the panel titles):")
print("   σ=0.1: High accept (tiny steps, slow exploration)")
print("   σ=3.0: Lower accept (more rejections, chain repeats states)")
print("   σ=1.0: Goldilocks zone (moderate steps, good mixing)")

In [None]:
# Experiment 2: Trace plots and diagnostics
# Sample from 2D Gaussian
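# Precompute the mean and precision matrix once instead of inverting on every call
_MU_2D = np.array([1.0, -0.5])
_PREC_2D = np.linalg.inv(np.array([[1.0, 0.7], [0.7, 1.0]]))

def log_prob_2d(x):
    """Unnormalized log density of a correlated 2D Gaussian."""
    diff = x - _MU_2D
    return -0.5 * diff @ _PREC_2D @ diff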

sampler_2d = MetropolisHastings(
    log_prob_fn=log_prob_2d,
    proposal_std=1.0,
    n_samples=5000,
    n_burn=500,
    random_state=42
)

samples_2d = sampler_2d.sample(x0=np.array([0.0, 0.0]), verbose=False)

print(f"\n2D Sampling Results:")
print(f"Acceptance rate: {sampler_2d.accept_rate_:.1%}")
print(f"Mean: {samples_2d.mean(axis=0)} (true: [1.0, -0.5])")
print(f"Std: {samples_2d.std(axis=0)} (true: [1.0, 1.0])")

# Create diagnostics
diagnostics = MCMCDiagnostics(samples_2d)

# Trace plots
fig = diagnostics.trace_plot()
plt.savefig(reports_dir / '04_trace_plots.png', dpi=150, bbox_inches='tight')
plt.show()
print("\n✅ Saved: reports/04_trace_plots.png")

# Marginal histograms
fig = diagnostics.marginal_histograms(true_mean=np.array([1.0, -0.5]))
plt.savefig(reports_dir / '04_marginal_histograms.png', dpi=150, bbox_inches='tight')
plt.show()
print("✅ Saved: reports/04_marginal_histograms.png")

In [None]:
# Experiment 3: Autocorrelation and ESS
acf_x1 = diagnostics.autocorrelation(max_lag=50, param_idx=0)
acf_x2 = diagnostics.autocorrelation(max_lag=50, param_idx=1)

iat_x1 = diagnostics.integrated_autocorr_time(param_idx=0)
iat_x2 = diagnostics.integrated_autocorr_time(param_idx=1)

ess_x1 = diagnostics.effective_sample_size(param_idx=0)
ess_x2 = diagnostics.effective_sample_size(param_idx=1)

print(f"\nAutocorrelation Diagnostics:")
print(f"Integrated autocorrelation time: [{iat_x1:.2f}, {iat_x2:.2f}]")
print(f"Effective sample size: [{ess_x1:.0f}, {ess_x2:.0f}] (out of {len(samples_2d)})")
print(f"Efficiency: {ess_x1/len(samples_2d):.1%}")

# Plot autocorrelation
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].plot(acf_x1, 'o-', linewidth=2, markersize=4)
axes[0].axhline(0, color='black', linestyle='--', alpha=0.5)
axes[0].set_xlabel('Lag', fontsize=12)
axes[0].set_ylabel('Autocorrelation', fontsize=12)
axes[0].set_title(f'Dimension 1 (IAT={iat_x1:.2f}, ESS={ess_x1:.0f})', fontsize=13)
axes[0].grid(alpha=0.3)

axes[1].plot(acf_x2, 's-', linewidth=2, markersize=4, color='darkorange')
axes[1].axhline(0, color='black', linestyle='--', alpha=0.5)
axes[1].set_xlabel('Lag', fontsize=12)
axes[1].set_ylabel('Autocorrelation', fontsize=12)
axes[1].set_title(f'Dimension 2 (IAT={iat_x2:.2f}, ESS={ess_x2:.0f})', fontsize=13)
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.savefig(reports_dir / '04_autocorrelation.png', dpi=150, bbox_inches='tight')
plt.show()
print("\n✅ Saved: reports/04_autocorrelation.png")

In [None]:
# Experiment 4: Failure modes
# Mode 1: Too small step size (random walk gets stuck)
# Mode 2: Too large step size (high rejection rate)
# Mode 3: Multimodal target (chain doesn't explore all modes)

def multimodal_log_prob(x):
    """Mixture of two Gaussians (bimodal)."""
    # Modes at x=-2 and x=+2
    log_p1 = -0.5 * (x[0] + 2)**2 - 0.5 * x[1]**2
    log_p2 = -0.5 * (x[0] - 2)**2 - 0.5 * x[1]**2
    # Log-sum-exp trick
    max_log_p = max(log_p1, log_p2)
    return max_log_p + np.log(np.exp(log_p1 - max_log_p) + np.exp(log_p2 - max_log_p))

# Try sampling with small step (will get stuck in one mode)
sampler_fail = MetropolisHastings(
    log_prob_fn=multimodal_log_prob,
    proposal_std=0.5,  # Too small to jump between modes
    n_samples=3000,
    n_burn=100,
    random_state=42
)

samples_fail = sampler_fail.sample(x0=np.array([-2.0, 0.0]), verbose=False)

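print(f"\nMultimodal sampling (failure mode):")
print(f"Acceptance rate: {sampler_fail.accept_rate_:.1%}")
print(f"Mean of x1: {samples_fail[:, 0].mean():.2f} (should be ~0 if exploring both modes)")
print(f"Std of x1: {samples_fail[:, 0].std():.2f} (should be ~2.2 if exploring both modes)")
# Share of samples in the right-hand mode; ~50% means both modes are visited equally
frac_right = (samples_fail[:, 0] > 0).mean()
print(f"Fraction of samples with x1 > 0: {frac_right:.1%} (should be ~50%)")
if frac_right < 0.1 or frac_right > 0.9:
    print("⚠️ Chain likely stuck in one mode!")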

# Visualize
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Trace plot
ax1.plot(samples_fail[:, 0], linewidth=0.5)
ax1.axhline(-2, color='r', linestyle='--', label='Left mode')
ax1.axhline(2, color='g', linestyle='--', label='Right mode')
ax1.set_xlabel('Iteration', fontsize=12)
ax1.set_ylabel('x₁', fontsize=12)
ax1.set_title('Trace Plot: Stuck in Left Mode', fontsize=13)
ax1.legend()
ax1.grid(alpha=0.3)

# 2D scatter
ax2.scatter(samples_fail[:, 0], samples_fail[:, 1], s=10, alpha=0.3)
ax2.scatter([-2, 2], [0, 0], s=200, c=['red', 'green'], 
           marker='*', edgecolor='black', linewidth=2, label='True modes')
ax2.set_xlabel('x₁', fontsize=12)
ax2.set_ylabel('x₂', fontsize=12)
ax2.set_title('Samples: Only Exploring Left Mode', fontsize=13)
ax2.legend()
ax2.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(reports_dir / '04_failure_mode_multimodal.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n✅ Saved: reports/04_failure_mode_multimodal.png")
print("\n📊 Lesson: MCMC can fail on multimodal distributions!")
print("   Remedies: parallel tempering; multiple chains from dispersed starts help detect the problem")

## 5. Sanity Checks

In [None]:
# Sanity check 1: Sample mean should converge to true mean
print("Sanity Check 1: Convergence to true mean")
print(f"   Sample mean: {samples_2d.mean(axis=0)}")
print(f"   True mean: [1.0, -0.5]")
print(f"   Error: {np.linalg.norm(samples_2d.mean(axis=0) - np.array([1.0, -0.5])):.4f}")
print(f"   ✅ PASSED" if np.allclose(samples_2d.mean(axis=0), [1.0, -0.5], atol=0.1) else "   ❌ FAILED")

# Sanity check 2: Acceptance rate in reasonable range
print("\nSanity Check 2: Acceptance rate")
print(f"   Rate: {sampler_2d.accept_rate_:.1%}")
reasonable = 0.2 <= sampler_2d.accept_rate_ <= 0.7
print(f"   ✅ PASSED (in [20%, 70%])" if reasonable else "   ⚠️ Outside optimal range")

# Sanity check 3: ESS should be less than total samples
print("\nSanity Check 3: ESS < N")
print(f"   ESS: {ess_x1:.0f}, {ess_x2:.0f}")
print(f"   Total samples: {len(samples_2d)}")
print(f"   ✅ PASSED" if (ess_x1 < len(samples_2d) and ess_x2 < len(samples_2d)) else "   ❌ FAILED")

# Sanity check 4: Autocorrelation should decay
print("\nSanity Check 4: Autocorrelation decay")
print(f"   ACF at lag 0: {acf_x1[0]:.3f} (should be 1.0)")
print(f"   ACF at lag 20: {acf_x1[20]:.3f} (should be < 0.2)")
decays = acf_x1[0] > acf_x1[10] > acf_x1[20]
print(f"   ✅ PASSED (decays)" if decays else "   ⚠️ Not decaying properly")

## 6. Key Takeaways

✅ **MCMC enables sampling from complex distributions** without knowing the normalization constant

✅ **Metropolis-Hastings**: Simple and general MCMC algorithm
   - Only needs to evaluate probability ratios
   - Converges to the target distribution asymptotically (for an irreducible, aperiodic chain), with no guarantee on how fast

✅ **Proposal tuning is critical**:
   - Too small: high acceptance but slow exploration
   - Too large: low acceptance, chain stays at the same state for long stretches
   - Aim for roughly 20-50% acceptance with random-walk proposals

✅ **Diagnostics are essential**:
   - **Trace plots**: Check for convergence and mixing
   - **Autocorrelation**: Quantify sample dependence
   - **ESS**: Effective number of independent samples

✅ **Common pitfalls**:
   - Insufficient burn-in (discard early samples)
   - Multimodal targets (chain may miss modes)
   - High autocorrelation (low ESS: run longer or improve the proposal)

**When to thin vs not thin:**
- ✅ Thin if storage or post-processing cost is the bottleneck
- ❌ Otherwise don't thin: estimates from all samples are at least as precise as those from a thinned subset
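
The thinning advice can be checked numerically. The sketch below is an illustration with arbitrarily chosen parameters: it uses AR(1) chains as a stand-in for MCMC output and compares the spread of the sample mean across many replicate chains, once using every sample and once keeping only every `thin`-th sample.

```python
import numpy as np

rng = np.random.default_rng(0)
phi, n_chains, n_steps, thin = 0.9, 1000, 2000, 20

# Simulate many independent AR(1) chains at once, started from stationarity
x = np.empty((n_chains, n_steps))
x[:, 0] = rng.standard_normal(n_chains) / np.sqrt(1 - phi**2)
for t in range(1, n_steps):
    x[:, t] = phi * x[:, t - 1] + rng.standard_normal(n_chains)

mean_full = x.mean(axis=1)                # estimate from all samples
mean_thinned = x[:, ::thin].mean(axis=1)  # estimate from every thin-th sample

# Spread of each estimator across replicate chains = its standard error
se_full = mean_full.std()
se_thinned = mean_thinned.std()
print(f"Std. error of the mean, all samples:   {se_full:.3f}")
print(f"Std. error of the mean, thinned x{thin}:  {se_thinned:.3f}")
```

Thinning stores 20 times fewer values here, but the estimator built from them is noisier, so it only pays off when storage or per-sample post-processing is the constraint.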

## 7. Exercises

**Exercise 1:** Implement adaptive Metropolis: adjust proposal std during sampling to maintain ~40% acceptance.

**Exercise 2:** Sample from a banana-shaped distribution. How does proposal shape (isotropic vs adapted) affect efficiency?

**Exercise 3:** For the multimodal example, try parallel tempering or running multiple chains from different initializations.

**Exercise 4:** Implement Gelman-Rubin diagnostic (R-hat) to assess convergence across multiple chains.

**Exercise 5:** Compare MCMC to importance sampling for a 1D target. Which is more efficient?

**Exercise 6:** Research Hamiltonian Monte Carlo (HMC). Why is it more efficient than random-walk MH?

In [None]:
# Your solutions here


---

## Solutions
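
One possible approach to Exercise 1 is sketched below. It is not the notebook's `MetropolisHastings` class: the function `adaptive_rw_metropolis`, its arguments, and the decay exponent `0.6` are assumptions made for illustration. The proposal std is adapted on a log scale with a decaying (Robbins-Monro style) gain during an adaptation phase only, then frozen, so the retained samples come from an ordinary Metropolis chain.

```python
import numpy as np

def adaptive_rw_metropolis(log_prob, x0, n_adapt, n_samples,
                           target_accept=0.4, seed=0):
    """Random-walk Metropolis with step-size adaptation (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    x = np.atleast_1d(np.asarray(x0, dtype=float))
    log_p = log_prob(x)
    log_step = 0.0  # log of the proposal std; start at std = 1
    samples = np.empty((n_samples, x.size))
    n_accepted = 0
    for t in range(n_adapt + n_samples):
        proposal = x + np.exp(log_step) * rng.standard_normal(x.size)
        log_p_new = log_prob(proposal)
        accept_prob = np.exp(min(0.0, log_p_new - log_p))
        accepted = rng.uniform() < accept_prob
        if accepted:
            x, log_p = proposal, log_p_new
        if t < n_adapt:
            # Grow the step when accepting too often, shrink it otherwise;
            # the gain decays so the step size settles down
            log_step += (t + 1) ** -0.6 * (accept_prob - target_accept)
        else:
            samples[t - n_adapt] = x
            n_accepted += accepted
    return samples, n_accepted / n_samples, np.exp(log_step)

def log_target(x):
    """Unnormalized log density of a standard normal."""
    return -0.5 * np.sum(x**2)

samples, accept_rate, step = adaptive_rw_metropolis(
    log_target, x0=[5.0], n_adapt=3000, n_samples=5000
)
print(f"Tuned proposal std: {step:.2f}, acceptance rate: {accept_rate:.1%}")
```

Freezing the step size after `n_adapt` iterations is a deliberate choice: adapting forever requires extra conditions (diminishing adaptation) for the chain to keep the right stationary distribution.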

In [None]:
# Solution 4: Gelman-Rubin R-hat
def gelman_rubin(chains):
    """Compute R-hat statistic for multiple chains.
    
    Args:
        chains: List of arrays, each (n_samples, n_dim)
    Returns:
        R_hat per dimension
    """
    n_chains = len(chains)
    n_samples = chains[0].shape[0]
    n_dim = chains[0].shape[1]
    
    chain_means = np.array([chain.mean(axis=0) for chain in chains])  # (n_chains, n_dim)
    grand_mean = chain_means.mean(axis=0)  # (n_dim,)
    
    # Between-chain variance
    B = n_samples / (n_chains - 1) * np.sum((chain_means - grand_mean)**2, axis=0)
    
    # Within-chain variance
    W = np.mean([np.var(chain, axis=0, ddof=1) for chain in chains], axis=0)
    
    # Marginal posterior variance estimate
    var_plus = ((n_samples - 1) / n_samples) * W + (1 / n_samples) * B
    
    # R-hat
    R_hat = np.sqrt(var_plus / W)
    return R_hat

# Run 4 chains from different starting points
n_chains = 4
chains = []
for i in range(n_chains):
    sampler_i = MetropolisHastings(
        log_prob_fn=log_prob_2d,
        proposal_std=1.0,
        n_samples=2000,
        n_burn=200,
        random_state=42 + i
    )
    x0 = np.random.randn(2) * 2  # Random start
    samples_i = sampler_i.sample(x0=x0, verbose=False)
    chains.append(samples_i)

R_hat = gelman_rubin(chains)
print(f"Solution 4: Gelman-Rubin R-hat = {R_hat}")
print(f"   R-hat < 1.1 is a common convergence threshold (stricter practice: < 1.01)")
print(f"   Status: {'✅ Converged' if np.all(R_hat < 1.1) else '⚠️ Not converged'}")

# Solution 6: HMC explanation
print("\nSolution 6: Why HMC is more efficient:")
print("   1. Uses gradient information (not just random walk)")
print("   2. Proposes distant states with high acceptance")
print("   3. Explores posterior along natural curvature")
print("   4. Lower autocorrelation → higher ESS per iteration")
print("   Example: Stan and PyMC (formerly PyMC3) use NUTS, an adaptive HMC variant, by default")

---

**Congratulations!** You've completed the Statistical Inference & UQ module. You now understand:
- Aleatoric vs epistemic uncertainty
- Bayesian regression and posterior predictive distributions
- Calibration diagnostics and temperature scaling
- MCMC sampling and convergence diagnostics

**Next steps:** Apply these concepts to real ML problems in subsequent modules!