# Diffusion-Assisted MCMC Sampling

This notebook demonstrates the diffusion-assisted Markov Chain Monte Carlo algorithm from:

> Hunt-Smith et al. (2023), "Accelerating Markov Chain Monte Carlo sampling with diffusion models", arXiv:2309.01454v1

## Key Idea

Traditional MCMC methods such as Metropolis-Hastings use local proposals (e.g., Gaussian steps from the current point). This works well for unimodal distributions but struggles with:

1. **Multi-modal distributions**: Local steps rarely cross the low-probability regions between distant modes
2. **Long-tailed distributions**: The chain can get stuck far from the mode

The solution: **Augment MCMC with a diffusion model** that:
- Learns the shape of the posterior from collected samples
- Can propose non-local jumps to important regions
- Is periodically retrained as sampling progresses
- Doesn't require gradient information (works with black-box likelihoods)

## Algorithm Overview

**Algorithm 1: Diffusion-Assisted Metropolis-Hastings**

```
for i = 1, ..., n_samples:
    if random() < p_diff:
        # Global proposal from the diffusion model, with proposal density Q
        theta' ~ DiffusionModel()
        accept with probability min(1, P(theta') / P(theta) * Q(theta) / Q(theta'))
    else:
        # Local Gaussian proposal
        theta' ~ Normal(theta, sigma^2)
        accept with probability min(1, P(theta') / P(theta))
    
    if i % retrain_interval == 0:
        retrain diffusion model on existing samples
```
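
The accept/reject logic of Algorithm 1 can be sketched in runnable form. The snippet below is a minimal illustration, not the paper's or `mc_lab`'s implementation: a fixed wide Gaussian (`log_q`, with the hypothetical parameter `scale`) stands in for the diffusion model so that the proposal density $Q$ is available in closed form, and the retraining step is omitted. All function names here are assumptions made for the example.

```python
import numpy as np

def log_target(theta):
    """Unnormalized log-density of a standard 2D Gaussian (illustrative target)."""
    return -0.5 * np.sum(theta**2)

def log_q(theta, scale=1.5):
    """Log-density of the stand-in global proposal N(0, scale^2 I), up to a constant."""
    return -0.5 * np.sum(theta**2) / scale**2

def mh_step(theta, rng, p_diff=0.3, sigma=1.0, scale=1.5):
    """One step of the mixed local/global kernel; returns (new_theta, accepted)."""
    if rng.random() < p_diff:
        # Global (independence) proposal: it ignores the current state,
        # so the Hastings correction Q(theta) / Q(theta') is required.
        prop = rng.normal(0.0, scale, size=theta.shape)
        log_alpha = (log_target(prop) - log_target(theta)
                     + log_q(theta, scale) - log_q(prop, scale))
    else:
        # Local symmetric Gaussian proposal: the Q terms cancel.
        prop = theta + rng.normal(0.0, sigma, size=theta.shape)
        log_alpha = log_target(prop) - log_target(theta)
    # Accept with probability min(1, exp(log_alpha)), compared in log space.
    if np.log(rng.random()) < log_alpha:
        return prop, True
    return theta, False

rng = np.random.default_rng(0)
theta = np.zeros(2)
chain = np.empty((5000, 2))
n_accepted = 0
for i in range(len(chain)):
    theta, ok = mh_step(theta, rng)
    chain[i] = theta
    n_accepted += ok

print(f"acceptance rate: {n_accepted / len(chain):.2f}")
print("sample std:", chain.std(axis=0).round(2))
```

Working with log-densities avoids overflow and underflow in the ratio, and the normalizing constants of both the target and the proposal cancel, so neither needs to be known.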

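The diffusion model is trained and used in three stages:
1. **Forward process**: Progressively add Gaussian noise to the collected samples until they resemble pure noise
2. **Reverse process**: Fit the model's parameters so that it learns to undo each noising step
3. **Sampling**: Start from pure noise and apply the learned reverse process to generate a proposal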
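
To make the forward process concrete, here is a minimal sketch of closed-form noising under an assumed DDPM-style variance-preserving schedule (linear `betas`, `T = 1000`). The schedule, the toy data, and the helper name `forward_noise` are illustrative assumptions; the paper and `mc_lab` may parametrize the process differently.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "samples collected so far": two tight clusters in 1D.
x0 = np.concatenate([rng.normal(-3.0, 0.3, 2000), rng.normal(3.0, 0.3, 2000)])

# Assumed linear variance schedule beta_1..beta_T.
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)  # fraction of signal variance kept at each step

def forward_noise(x0, t, rng):
    """Draw x_t given x_0 in closed form: sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps."""
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps

for t in (0, 250, T - 1):
    xt = forward_noise(x0, t, rng)
    print(f"t={t:4d}  mean={xt.mean():+.2f}  std={xt.std():.2f}")

x_T = forward_noise(x0, T - 1, rng)
```

By the final step the two clusters are statistically indistinguishable from a standard normal, which is why sampling can start from pure noise and rely on the learned reverse process to recover the structure.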

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

from mc_lab.diffusion_mcmc import diffusion_assisted_mcmc

# Set random seed for reproducibility
np.random.seed(42)

## Example 1: Standard Bivariate Gaussian

Let's start with a simple 2D Gaussian: $\theta \sim N(0, I)$

In [None]:
# Define log-posterior
def log_posterior_gaussian(theta):
    """Log-posterior for standard Gaussian, up to an additive constant: -0.5 * ||theta||^2"""
    return -0.5 * np.sum(theta**2)

# Sample using diffusion-assisted MCMC
initial = np.array([0.0, 0.0])
samples, accepted, info = diffusion_assisted_mcmc(
    log_posterior_gaussian,
    initial,
    n_samples=5000,
    p_diff=0.3,
    sigma_mh=1.0,
    retrain_interval=500,
    random_state=42,
    verbose=True,
)

print(f"\nOverall acceptance rate: {np.mean(accepted):.1%}")
print(f"Diffusion acceptance rate: {info['diffusion_acceptance_rate']:.1%}")
print(f"Gaussian acceptance rate: {info['gaussian_acceptance_rate']:.1%}")

In [None]:
# Visualize results
burn_in = 1000
samples_burned = samples[burn_in:]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Trace plots
axes[0].plot(samples[:, 0], alpha=0.5, color="#117733")
axes[0].axvline(burn_in, color="#CC6677", linestyle="--", label="Burn-in")
axes[0].set_xlabel("Iteration")
axes[0].set_ylabel(r"$\theta_1$")
axes[0].set_title("Trace Plot (Dimension 1)")
axes[0].legend()
axes[0].grid(alpha=0.3)

# 2D scatter
axes[1].scatter(samples_burned[:, 0], samples_burned[:, 1], alpha=0.3, s=1, color="#44AA99")
axes[1].set_xlabel(r"$\theta_1$")
axes[1].set_ylabel(r"$\theta_2$")
axes[1].set_title("Samples (after burn-in)")
axes[1].grid(alpha=0.3)
axes[1].axis("equal")

# Marginal distributions
axes[2].hist(samples_burned[:, 0], bins=50, alpha=0.6, color="#117733", density=True, label=r"$\theta_1$")
axes[2].hist(samples_burned[:, 1], bins=50, alpha=0.6, color="#44AA99", density=True, label=r"$\theta_2$")
x = np.linspace(-4, 4, 100)
axes[2].plot(x, stats.norm.pdf(x), "k--", label="True N(0,1)")
axes[2].set_xlabel(r"$\theta$")
axes[2].set_ylabel("Density")
axes[2].set_title("Marginal Distributions")
axes[2].legend()
axes[2].grid(alpha=0.3)

plt.tight_layout()
plt.show()

# Print statistics
print(f"Sample mean: [{np.mean(samples_burned[:, 0]):.3f}, {np.mean(samples_burned[:, 1]):.3f}]")
print(f"Sample std:  [{np.std(samples_burned[:, 0]):.3f}, {np.std(samples_burned[:, 1]):.3f}]")
print("True mean:   [0.000, 0.000]")
print("True std:    [1.000, 1.000]")

## Example 2: Bimodal Distribution

Now let's try a more challenging bimodal distribution: a mixture of two Gaussians.

$$P(\theta) = 0.5 \cdot N(\theta; [-3, 0], I) + 0.5 \cdot N(\theta; [3, 0], I)$$

This is difficult for standard MCMC because the modes are six standard deviations apart, so a chain making small local steps must cross a region of very low probability to move from one mode to the other.
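
A quick calculation shows how deep the valley between the modes is. The sketch below evaluates the mixture density at the midpoint $(0, 0)$ relative to the mode at $(3, 0)$; the helper name `log_mixture` is introduced only for this illustration.

```python
import numpy as np
from scipy import stats
from scipy.special import logsumexp

def log_mixture(theta):
    """Log-density of 0.5 * N([-3, 0], I) + 0.5 * N([3, 0], I)."""
    log_p = [stats.multivariate_normal.logpdf(theta, mean=m) for m in ([-3, 0], [3, 0])]
    return logsumexp(log_p) + np.log(0.5)

ratio = np.exp(log_mixture([0.0, 0.0]) - log_mixture([3.0, 0.0]))
print(f"P(midpoint) / P(mode) = {ratio:.4f}")  # about 0.0222, i.e. 2 * exp(-4.5)
```

The density at the midpoint is roughly 2% of its peak value, so a random walk with a small step size spends very little time there and rarely completes a crossing.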

In [None]:
def log_posterior_bimodal(theta):
    """Log-posterior for bimodal Gaussian mixture."""
    # Two modes at [-3, 0] and [3, 0]
    log_p1 = stats.multivariate_normal.logpdf(theta, mean=[-3, 0])
    log_p2 = stats.multivariate_normal.logpdf(theta, mean=[3, 0])
    
    # Log of the equal-weight mixture, computed stably with log-sum-exp
    return np.logaddexp(log_p1, log_p2) + np.log(0.5)

# Standard MH (for comparison)
print("Standard Metropolis-Hastings:")
samples_mh, accepted_mh, info_mh = diffusion_assisted_mcmc(
    log_posterior_bimodal,
    np.array([0.0, 0.0]),
    n_samples=3000,
    p_diff=0.0,  # No diffusion
    sigma_mh=0.5,
    random_state=42,
)

# Diffusion-assisted MH (seed with samples near both modes)
print("\nDiffusion-Assisted MCMC:")
seed_samples = np.array([[-3.0, 0.0], [3.0, 0.0], [-3.0, 0.0], [3.0, 0.0]])  # Seed both modes (float dtype)
samples_diff, accepted_diff, info_diff = diffusion_assisted_mcmc(
    log_posterior_bimodal,
    np.array([0.0, 0.0]),
    n_samples=3000,
    p_diff=0.5,
    sigma_mh=0.5,
    seed_samples=seed_samples,
    retrain_interval=300,
    random_state=43,
)

In [None]:
# Visualize comparison
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

burn_in = 500

# Standard MH
axes[0, 0].plot(samples_mh[:, 0], alpha=0.7, color="#332288", linewidth=0.5)
axes[0, 0].axvline(burn_in, color="#CC6677", linestyle="--", alpha=0.7)
axes[0, 0].set_xlabel("Iteration")
axes[0, 0].set_ylabel(r"$\theta_1$")
axes[0, 0].set_title("Standard MH: Trace Plot")
axes[0, 0].grid(alpha=0.3)

samples_mh_burned = samples_mh[burn_in:]
axes[0, 1].scatter(samples_mh_burned[:, 0], samples_mh_burned[:, 1], 
                   alpha=0.4, s=2, color="#332288")
axes[0, 1].set_xlabel(r"$\theta_1$")
axes[0, 1].set_ylabel(r"$\theta_2$")
axes[0, 1].set_title("Standard MH: Samples")
axes[0, 1].grid(alpha=0.3)
axes[0, 1].axis("equal")

# Count mode visits
mh_left = np.sum(samples_mh_burned[:, 0] < 0)
mh_right = np.sum(samples_mh_burned[:, 0] > 0)
axes[0, 1].text(0.02, 0.98, f"Left mode: {mh_left}\nRight mode: {mh_right}",
                transform=axes[0, 1].transAxes, verticalalignment="top",
                bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))

# Diffusion-assisted
axes[1, 0].plot(samples_diff[:, 0], alpha=0.7, color="#117733", linewidth=0.5)
axes[1, 0].axvline(burn_in, color="#CC6677", linestyle="--", alpha=0.7)
axes[1, 0].set_xlabel("Iteration")
axes[1, 0].set_ylabel(r"$\theta_1$")
axes[1, 0].set_title("Diffusion-Assisted: Trace Plot")
axes[1, 0].grid(alpha=0.3)

samples_diff_burned = samples_diff[burn_in:]
axes[1, 1].scatter(samples_diff_burned[:, 0], samples_diff_burned[:, 1], 
                   alpha=0.4, s=2, color="#117733")
axes[1, 1].set_xlabel(r"$\theta_1$")
axes[1, 1].set_ylabel(r"$\theta_2$")
axes[1, 1].set_title("Diffusion-Assisted: Samples")
axes[1, 1].grid(alpha=0.3)
axes[1, 1].axis("equal")

# Count mode visits
diff_left = np.sum(samples_diff_burned[:, 0] < 0)
diff_right = np.sum(samples_diff_burned[:, 0] > 0)
axes[1, 1].text(0.02, 0.98, f"Left mode: {diff_left}\nRight mode: {diff_right}",
                transform=axes[1, 1].transAxes, verticalalignment="top",
                bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))

plt.tight_layout()
plt.show()

print("\n" + "="*60)
print("Mode Exploration Comparison:")
print("="*60)
print("Standard MH:")
print(f"  Left mode:  {mh_left:4d} samples ({mh_left/len(samples_mh_burned)*100:.1f}%)")
print(f"  Right mode: {mh_right:4d} samples ({mh_right/len(samples_mh_burned)*100:.1f}%)")
print(f"  Balance: {min(mh_left, mh_right)/max(mh_left, mh_right):.2f}")
print("\nDiffusion-Assisted:")
print(f"  Left mode:  {diff_left:4d} samples ({diff_left/len(samples_diff_burned)*100:.1f}%)")
print(f"  Right mode: {diff_right:4d} samples ({diff_right/len(samples_diff_burned)*100:.1f}%)")
print(f"  Balance: {min(diff_left, diff_right)/max(diff_left, diff_right):.2f}")
print("\nNote: Balance closer to 1.0 indicates better exploration of both modes.")

## Example 3: Himmelblau Function (4 Modes)

The Himmelblau function is a well-known optimization test function with four local minima of equal depth ($f = 0$ at each):

$$f(x, y) = (x^2 + y - 11)^2 + (x + y^2 - 7)^2$$

The four minima (modes for our posterior) are located approximately at:
- $(3.0, 2.0)$
- $(-2.8, 3.1)$
- $(-3.8, -3.3)$
- $(3.6, -1.8)$

For MCMC, we define the unnormalized log-posterior as $\log P(\theta) = -f(\theta)$, so each minimum of $f$ becomes a mode of $P$.
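
The rounded coordinates above can be checked numerically. The following sketch refines each one with a local optimizer (`scipy.optimize.minimize` using Nelder-Mead) and confirms that $f$ is close to zero at all four points; the variable names are chosen for this illustration only.

```python
import numpy as np
from scipy.optimize import minimize

def himmelblau(p):
    """Himmelblau function f(x, y)."""
    x, y = p
    return (x**2 + y - 11)**2 + (x + y**2 - 7)**2

# Rounded mode locations quoted in the text.
starts = [(3.0, 2.0), (-2.8, 3.1), (-3.8, -3.3), (3.6, -1.8)]

# Nelder-Mead is a local, derivative-free search, so each run stays in its own basin.
refined = [minimize(himmelblau, s, method="Nelder-Mead") for s in starts]

for s, res in zip(starts, refined):
    print(f"start {s} -> minimum {np.round(res.x, 3)}, "
          f"f(min) = {res.fun:.1e}, f(start) = {himmelblau(s):.3f}")
```

Because the posterior is $\exp(-f)$ and $f$ grows quickly away from each minimum, the four modes are narrow and separated by regions of negligible probability.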

In [None]:
def log_posterior_himmelblau(theta):
    """Log-posterior based on Himmelblau function (4 modes)."""
    x, y = theta
    val = (x**2 + y - 11)**2 + (x + y**2 - 7)**2
    return -val  # Negative for log-posterior (higher is better)

# Seed with points near all four modes
himmelblau_modes = np.array([
    [3.0, 2.0],
    [-2.8, 3.1],
    [-3.8, -3.3],
    [3.6, -1.8]
])

# Add some noise to seed samples
seed_samples = himmelblau_modes + np.random.randn(*himmelblau_modes.shape) * 0.5

print("Sampling from Himmelblau function (4 modes)...")
samples_himmel, accepted_himmel, info_himmel = diffusion_assisted_mcmc(
    log_posterior_himmelblau,
    himmelblau_modes[0],  # Start at first mode
    n_samples=5000,
    seed_samples=seed_samples,
    p_diff=0.8,  # High probability of diffusion proposals
    sigma_mh=0.15,
    retrain_interval=500,
    random_state=42,
    verbose=True,
)

In [None]:
# Visualize Himmelblau sampling
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

burn_in = 1000
samples_himmel_burned = samples_himmel[burn_in:]

# Create contour plot of Himmelblau function
x_range = np.linspace(-5, 5, 200)
y_range = np.linspace(-5, 5, 200)
X, Y = np.meshgrid(x_range, y_range)
Z = (X**2 + Y - 11)**2 + (X + Y**2 - 7)**2  # f(x, y) evaluated on the grid

# Plot 1: Samples with contours
contour = axes[0].contourf(X, Y, Z, levels=20, cmap="viridis", alpha=0.6)
axes[0].scatter(samples_himmel_burned[:, 0], samples_himmel_burned[:, 1],
                c="#CC6677", s=1, alpha=0.5, label="Samples")
axes[0].scatter(himmelblau_modes[:, 0], himmelblau_modes[:, 1],
                c="white", s=100, marker="*", edgecolors="black", linewidths=2,
                label="True modes", zorder=5)
axes[0].set_xlabel("x")
axes[0].set_ylabel("y")
axes[0].set_title("Samples on Himmelblau Function")
axes[0].legend()
axes[0].grid(alpha=0.3)
plt.colorbar(contour, ax=axes[0], label="f(x, y)")

# Plot 2: Trace plot showing mode jumps
axes[1].plot(samples_himmel[:, 0], alpha=0.7, color="#117733", linewidth=0.5)
axes[1].axvline(burn_in, color="#CC6677", linestyle="--", alpha=0.7, label="Burn-in")
for mode in himmelblau_modes:
    axes[1].axhline(mode[0], color="#332288", linestyle=":", alpha=0.5)
axes[1].set_xlabel("Iteration")
axes[1].set_ylabel("x")
axes[1].set_title("Trace Plot (showing mode jumps)")
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

# Count visits to each mode
mode_counts = []
mode_regions = [
    (2, 5, 1, 3),      # Mode 1: x in [2,5], y in [1,3]
    (-4, -2, 2, 4),    # Mode 2: x in [-4,-2], y in [2,4]
    (-5, -3, -4, -2),  # Mode 3: x in [-5,-3], y in [-4,-2]
    (2, 5, -3, 0),     # Mode 4: x in [2,5], y in [-3,0]
]

print("\n" + "="*60)
print("Mode Visit Statistics:")
print("="*60)
for i, (x_min, x_max, y_min, y_max) in enumerate(mode_regions, 1):
    mask = (samples_himmel_burned[:, 0] >= x_min) & (samples_himmel_burned[:, 0] <= x_max) & \
           (samples_himmel_burned[:, 1] >= y_min) & (samples_himmel_burned[:, 1] <= y_max)
    count = np.sum(mask)
    percentage = count / len(samples_himmel_burned) * 100
    print(f"Mode {i} at {himmelblau_modes[i-1]}: {count:4d} samples ({percentage:5.1f}%)")
    mode_counts.append(count)

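modes_visited = sum(c > 50 for c in mode_counts)
print(f"\nNumber of modes with >50 samples: {modes_visited}/4")
if modes_visited == len(mode_regions):
    print("\nAll four modes were visited, so the chain moved between well-separated modes.")
else:
    print("\nNot every mode reached 50 samples in this run.")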

## Key Takeaways

1. **Multi-modal sampling**: Diffusion-assisted MCMC can effectively jump between distant modes, which standard MCMC struggles with.

2. **Adaptive learning**: The diffusion model improves as sampling progresses, learning the shape of the posterior.

3. **No gradients needed**: Unlike gradient-based methods such as HMC and MALA, this needs only evaluations of the log-posterior, so it works with black-box posteriors.

4. **Asymptotic exactness**: The diffusion proposal is only an approximation of the posterior, but the Metropolis-Hastings acceptance step corrects for the mismatch, so the chain still targets the true posterior.

5. **Hyperparameter tuning**: Key parameters are:
   - `p_diff`: Balance between local and global proposals (the examples above use values from 0.3 to 0.8)
   - `sigma_mh`: Step size for local proposals
   - `retrain_interval`: How often to update the diffusion model (trade-off with compute time)
   - `seed_samples`: For multi-modal problems, seeding near all modes helps
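
To give a feel for how `sigma_mh` behaves, here is a small self-contained sketch that measures the acceptance rate of a plain random-walk Metropolis sampler on a standard 2D Gaussian for several step sizes. It does not call `mc_lab`; the helper `rw_acceptance` is a hypothetical stand-in for the local-proposal branch only.

```python
import numpy as np

def rw_acceptance(sigma, n_steps=5000, seed=0):
    """Acceptance rate of random-walk Metropolis on a standard 2D Gaussian."""
    rng = np.random.default_rng(seed)
    theta = np.zeros(2)
    log_p = -0.5 * theta @ theta
    n_acc = 0
    for _ in range(n_steps):
        prop = theta + sigma * rng.standard_normal(2)
        log_p_prop = -0.5 * prop @ prop
        # Symmetric proposal: accept with probability min(1, P(prop) / P(theta)).
        if np.log(rng.random()) < log_p_prop - log_p:
            theta, log_p = prop, log_p_prop
            n_acc += 1
    return n_acc / n_steps

for sigma in (0.1, 0.5, 1.0, 2.5, 5.0):
    print(f"sigma_mh={sigma:<4}  acceptance={rw_acceptance(sigma):.2f}")
```

Very small steps are accepted almost always but barely move the chain, while very large steps are mostly rejected. A commonly cited rule of thumb is to aim for an acceptance rate of roughly 20-50% for the local proposals, then let the diffusion proposals handle the long-range moves.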

## When to Use This Algorithm

**Good for:**
- Multi-modal posteriors
- Posteriors with well-separated regions of high probability
- Black-box likelihoods (no gradient information)
- Cases where a learned approximation of the posterior is useful for fast sampling

**May not be needed for:**
- Simple unimodal Gaussians (standard MCMC is fine)
- Problems where gradient information is available (consider HMC/NUTS instead)
- Very high dimensions (diffusion model training becomes expensive)

## References

- Hunt-Smith et al. (2023), "Accelerating Markov Chain Monte Carlo sampling with diffusion models", arXiv:2309.01454v1
- Implementation available at: https://github.com/NickHunt-Smith/MCMC-diffusion