In [None]:
import warnings
from typing import Callable, Dict, Optional, Tuple

import numpy as np
from scipy.stats import norm


def maximal_coupling_gaussian(
    mu_x: np.ndarray, mu_y: np.ndarray, cov: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Reflection-maximal coupling of N(mu_x, cov) and N(mu_y, cov).

    Draws a pair (x, y) such that x ~ N(mu_x, cov) and y ~ N(mu_y, cov)
    marginally, while the event ``x == y`` has the largest probability any
    coupling of the two distributions can achieve. When the draws are not
    made equal, the whitened noise of ``y`` is the reflection of the whitened
    noise of ``x`` across the hyperplane orthogonal to the mean difference.

    Args:
        mu_x: Mean of the first Gaussian, shape (d,).
        mu_y: Mean of the second Gaussian, shape (d,).
        cov: Shared positive-definite covariance matrix, shape (d, d).

    Returns:
        Tuple ``(x, y)``. When the draws are coupled both entries are the
        very same array object, so ``x - y`` is exactly zero.

    Raises:
        numpy.linalg.LinAlgError: If ``cov`` is not positive definite
            (raised by the Cholesky factorisation).
    """
    d = len(mu_x)

    # Work in whitened coordinates: x = mu_x + L @ xi with xi ~ N(0, I).
    L = np.linalg.cholesky(cov)
    xi = np.random.randn(d)
    x = mu_x + L @ xi

    # Whitened mean difference; its squared norm is the Mahalanobis distance.
    z = np.linalg.solve(L, mu_x - mu_y)
    dist_squared = z @ z

    if dist_squared == 0.0:
        # Identical distributions: the maximal coupling is simply y = x.
        return x, x

    # Accept y = x with probability min(1, s(xi + z) / s(xi)), where s is the
    # standard normal density. The acceptance must depend on the draw itself;
    # a fixed coupling probability would distort the marginal law of y.
    log_ratio = -(xi @ z) - 0.5 * dist_squared
    if np.log(np.random.rand()) < log_ratio:
        y = x  # Coupled
    else:
        # Reflect the whitened noise about the hyperplane orthogonal to z.
        e = z / np.sqrt(dist_squared)
        y = mu_y + L @ (xi - 2.0 * (e @ xi) * e)

    return x, y


def coupled_rwmh_step(
    x: np.ndarray,
    y: np.ndarray,
    log_target: Callable,
    proposal_cov: np.ndarray,
    common_random: Optional[float] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Advance two chains by one coupled Random Walk Metropolis-Hastings move.

    The two proposals are drawn jointly (centred at ``x`` and ``y`` with
    covariance ``proposal_cov``) and both accept/reject decisions share a
    single uniform variate.

    Args:
        x: Current state of the first chain.
        y: Current state of the second chain.
        log_target: Callable returning the log density of the target.
        proposal_cov: Covariance matrix of the random-walk proposal.
        common_random: Uniform variate to use for both accept/reject
            decisions; drawn from ``np.random.rand()`` when None.

    Returns:
        Tuple ``(x_new, y_new)`` holding the next state of each chain. A
        rejected move returns a copy of the current state.
    """
    # Joint draw of the two proposals
    proposal_x, proposal_y = maximal_coupling_gaussian(x, y, proposal_cov)

    # One shared uniform drives both accept/reject decisions
    if common_random is None:
        u = np.random.rand()
    else:
        u = common_random

    # Log acceptance probabilities, capped at zero
    log_accept_x = min(0, log_target(proposal_x) - log_target(x))
    log_accept_y = min(0, log_target(proposal_y) - log_target(y))

    log_u = np.log(u)

    if log_u < log_accept_x:
        next_x = proposal_x
    else:
        next_x = x.copy()

    if log_u < log_accept_y:
        next_y = proposal_y
    else:
        next_y = y.copy()

    return next_x, next_y


def _rwmh_step(x, log_target, proposal_cov):
    """Advance a single chain by one Random Walk Metropolis-Hastings step."""
    x_prop = np.random.multivariate_normal(x, proposal_cov)
    log_alpha = min(0, log_target(x_prop) - log_target(x))
    if np.log(np.random.rand()) < log_alpha:
        return x_prop
    return x.copy()


def _simulate_lagged_pair(
    log_target, proposal_cov, x_init, dim, lag, horizon, max_coupled_steps
):
    """
    Simulate a pair of lagged coupled chains.

    Returns ``(xs, ys, tau)`` where ``xs[t]`` is X_t, ``ys[t]`` is Y_t and
    ``tau`` is the first index t with X_t == Y_{t-lag}, or None when the
    chains did not meet before X reached index
    ``max(horizon, lag + max_coupled_steps)``. X is always simulated at
    least up to index ``horizon``; Y is only stored up to the meeting.
    """
    xs = [x_init + 0.1 * np.random.randn(dim)]  # Small perturbation

    # Lag phase: X runs alone for `lag` steps
    for _ in range(lag):
        xs.append(_rwmh_step(xs[-1], log_target, proposal_cov))

    ys = [x_init + 0.1 * np.random.randn(dim)]

    # Coupled phase: run until the chains meet (or the step budget runs out)
    tau = None
    last_index = max(horizon, lag + max_coupled_steps)
    while tau is None and len(xs) - 1 < last_index:
        x_new, y_new = coupled_rwmh_step(xs[-1], ys[-1], log_target, proposal_cov)
        xs.append(x_new)
        ys.append(y_new)
        if np.linalg.norm(x_new - y_new) < 1e-10:
            tau = len(xs) - 1

    # After meeting the chains move together, so only X has to be extended.
    while len(xs) - 1 < horizon:
        xs.append(_rwmh_step(xs[-1], log_target, proposal_cov))

    return xs, ys, tau


def _unbiased_estimate(h, xs, ys, tau, k, m, lag):
    """Compute the time-averaged unbiased estimator H_{k:m} from one chain pair."""
    n_avg = m - k + 1
    mcmc_avg = np.mean([h(xs[t]) for t in range(k, m + 1)], axis=0)

    # Without a meeting, use every simulated term (truncated correction).
    stop = tau if tau is not None else len(xs)

    correction = 0.0
    for t in range(k + lag, stop):
        # Number of start times l in [k, m] with t = l + j * lag for some
        # j >= 1, i.e. floor((t - k) / lag) - max(1, ceil((t - m) / lag)) + 1.
        # For lag == 1 this equals min(t - k, m - k + 1).
        weight = (t - k) // lag - max(1, -((m - t) // lag)) + 1
        correction = correction + weight * (h(xs[t]) - h(ys[t - lag]))

    # The weights are counts, so the correction is normalised exactly once.
    return mcmc_avg + correction / n_avg


def unbiased_mcmc_estimator(
    log_target: Callable,
    h: Callable,
    dim: int,
    n_estimators: int = 100,
    k: Optional[int] = None,
    m: Optional[int] = None,
    proposal_cov: Optional[np.ndarray] = None,
    x_init: Optional[np.ndarray] = None,
    lag: int = 1,
    pilot_runs: int = 50,
) -> Dict:
    """
    Complete unbiased MCMC estimation with automatic parameter tuning.

    Each replicate simulates a pair of lagged, coupled random-walk
    Metropolis-Hastings chains until they meet (and at least up to time
    ``m``), then forms the time-averaged estimator: the usual MCMC average
    of ``h`` over times k..m plus a bias-correction sum over the
    differences h(X_t) - h(Y_{t-lag}) for t < meeting time.

    Args:
        log_target: Log density of target distribution
        h: Test function to estimate E[h(X)]; may return a scalar or an array
        dim: Dimension of the problem
        n_estimators: Number of independent unbiased estimators
        k: Burn-in parameter (90th percentile of pilot meeting times if None)
        m: Time horizon (10 * k if None)
        proposal_cov: Proposal covariance ((2.38**2 / dim) * I if None)
        x_init: Initial value (standard normal draw if None)
        lag: Lag between chains (default 1)
        pilot_runs: Number of pilot runs for tuning (used only when k or m
            is None)

    Returns:
        Dictionary with keys "mean", "std_error", "ci_95" (lower, upper),
        "estimates", "meeting_times", "k", "m" and "n_estimators". For
        array-valued ``h`` the statistics are computed per component.

    Raises:
        ValueError: If ``lag < 1``, if tuning is required but
            ``pilot_runs < 1``, or if the resulting parameters do not
            satisfy ``0 <= k <= m``.
    """
    if lag < 1:
        raise ValueError("lag must be a positive integer")

    # Initialize
    if x_init is None:
        x_init = np.random.randn(dim)

    if proposal_cov is None:
        # Start with scaled identity
        proposal_cov = (2.38**2 / dim) * np.eye(dim)

    # Budget of coupled steps allowed while waiting for the chains to meet
    max_iter = 5000

    # Pilot runs to estimate meeting times and tune parameters
    if k is None or m is None:
        if pilot_runs < 1:
            raise ValueError("pilot_runs must be >= 1 when k or m is None")

        print(f"Running {pilot_runs} pilot runs for parameter tuning...")
        meeting_times = []

        for _ in range(pilot_runs):
            xs, _ys, tau = _simulate_lagged_pair(
                log_target, proposal_cov, x_init, dim, lag, 0, max_iter
            )
            if tau is None:
                warnings.warn(f"Chains didn't meet in {max_iter} iterations")
                tau = len(xs) - 1  # Censored at the last simulated index
            meeting_times.append(tau)

        meeting_times = np.array(meeting_times)

        if k is None:
            k = int(np.percentile(meeting_times, 90))
        if m is None:
            m = k * 10

        print("Meeting time statistics:")
        print(f"  Median: {np.median(meeting_times):.0f}")
        print(f"  90th percentile: {np.percentile(meeting_times, 90):.0f}")
        print(f"  99th percentile: {np.percentile(meeting_times, 99):.0f}")
        print(f"Selected k={k}, m={m}")

    if k < 0 or m < k:
        raise ValueError(f"Need 0 <= k <= m, got k={k}, m={m}")

    # Generate unbiased estimators
    estimates = []
    meeting_times_actual = []

    for i in range(n_estimators):
        xs, ys, tau = _simulate_lagged_pair(
            log_target, proposal_cov, x_init, dim, lag, m, max_iter
        )

        if tau is None:
            warnings.warn(
                "Chains didn't meet within the simulation budget; "
                "this replicate uses a truncated bias correction and may be biased"
            )
            # Sentinel one past the last simulated index
            meeting_times_actual.append(len(xs))
        else:
            meeting_times_actual.append(tau)

        estimates.append(_unbiased_estimate(h, xs, ys, tau, k, m, lag))

        if (i + 1) % 10 == 0:
            print(f"  Completed {i + 1}/{n_estimators} estimators")

    estimates = np.array(estimates)

    # Compute statistics (per component when h is array-valued)
    mean_est = np.mean(estimates, axis=0)
    ddof = 1 if n_estimators > 1 else 0  # Sample standard deviation
    std_err = np.std(estimates, axis=0, ddof=ddof) / np.sqrt(n_estimators)
    ci_lower = mean_est - 1.96 * std_err
    ci_upper = mean_est + 1.96 * std_err

    return {
        "mean": mean_est,
        "std_error": std_err,
        "ci_95": (ci_lower, ci_upper),
        "estimates": estimates,
        "meeting_times": meeting_times_actual,
        "k": k,
        "m": m,
        "n_estimators": n_estimators,
    }


# Test Case 1: Gaussian with known mean and variance
def test_gaussian():
    """Check the estimator on a 1-D Gaussian target whose moments are known."""
    rule = "=" * 60
    print("\n" + rule)
    print("TEST 1: Gaussian(mean=3, variance=2)")
    print(rule)

    # Target: N(3, 2)
    true_mean = 3.0
    true_var = 2.0

    def log_target(x):
        return -0.5 * np.sum((x - true_mean) ** 2) / true_var

    def first_moment(x):
        return x[0]

    def second_moment(x):
        return x[0] ** 2

    def report(header, truth, outcome):
        # Print the estimate, its 95% interval and whether it covers `truth`.
        lower, upper = outcome["ci_95"][0], outcome["ci_95"][1]
        print(header)
        print(f"  True value:  {truth:.4f}")
        print(f"  Estimate:    {outcome['mean']:.4f}")
        print(f"  95% CI:      [{lower:.4f}, {upper:.4f}]")
        print(f"  Covers true: {lower <= truth <= upper}")

    # E[X], with k and m tuned from pilot runs
    mean_result = unbiased_mcmc_estimator(
        log_target,
        first_moment,
        dim=1,
        n_estimators=100,
        proposal_cov=np.array([[1.5]]),
        pilot_runs=20,
    )
    report("\nEstimating E[X]:", true_mean, mean_result)

    # E[X²], reusing the k and m selected above so no pilot runs are needed
    square_result = unbiased_mcmc_estimator(
        log_target,
        second_moment,
        dim=1,
        n_estimators=100,
        k=mean_result["k"],
        m=mean_result["m"],
        proposal_cov=np.array([[1.5]]),
        pilot_runs=0,
    )
    true_second_moment = true_var + true_mean**2
    report("\nEstimating E[X²]:", true_second_moment, square_result)


# Test Case 2: Bayesian logistic regression
def test_logistic_regression():
    """Test on a 2D logistic regression posterior.

    Simulates 30 observations from a logistic model with coefficients
    (1.0, -0.5), then estimates the posterior mean of the coefficients and
    prints it with a per-coefficient 95% half-width.
    """
    print("\n" + "=" * 60)
    print("TEST 2: Bayesian Logistic Regression (2D)")
    print("=" * 60)

    # Simple dataset
    np.random.seed(123)
    n = 30
    X = np.random.randn(n, 2)
    true_beta = np.array([1.0, -0.5])
    y = (np.random.rand(n) < 1 / (1 + np.exp(-X @ true_beta))).astype(int)

    def log_target(beta):
        """Log posterior with N(0, 5) prior."""
        logits = X @ beta
        # logaddexp(0, t) == log(1 + exp(t)) but does not overflow for large t
        log_lik = np.sum(y * logits - np.logaddexp(0.0, logits))
        log_prior = -0.5 * np.sum(beta**2) / 5
        return log_lik + log_prior

    def h_mean(beta):
        return beta

    result = unbiased_mcmc_estimator(
        log_target, h_mean, dim=2, n_estimators=100, pilot_runs=30
    )

    # One half-width per coefficient; multiplying by ones(2) also accepts a
    # scalar standard error by broadcasting it to both coefficients.
    half_widths = 1.96 * result["std_error"] * np.ones(2)

    print("\nPosterior mean estimate:")
    print(f"  Beta[0]: {result['mean'][0]:.4f} ± {half_widths[0]:.4f}")
    print(f"  Beta[1]: {result['mean'][1]:.4f} ± {half_widths[1]:.4f}")
    print(
        f"  Meeting times: median={np.median(result['meeting_times']):.0f}, "
        f"max={np.max(result['meeting_times']):.0f}"
    )


if __name__ == "__main__":
    # Script entry point: run the Gaussian sanity check.
    test_gaussian()
    # The logistic-regression example is currently disabled.
    # test_logistic_regression()


TEST 1: Gaussian(mean=3, variance=2)
Running 20 pilot runs for parameter tuning...
Meeting time statistics:
  Median: 3
  90th percentile: 6
  99th percentile: 12
Selected k=6, m=60
  Completed 10/100 estimators
  Completed 20/100 estimators
  Completed 30/100 estimators
  Completed 40/100 estimators
  Completed 50/100 estimators
  Completed 60/100 estimators
  Completed 70/100 estimators
  Completed 80/100 estimators
  Completed 90/100 estimators
  Completed 100/100 estimators

Estimating E[X]:
  True value:  3.0000
  Estimate:    2.9235
  95% CI:      [2.8118, 3.0353]
  Covers true: True
  Completed 10/100 estimators
  Completed 20/100 estimators
  Completed 30/100 estimators
  Completed 40/100 estimators
  Completed 50/100 estimators
  Completed 60/100 estimators
  Completed 70/100 estimators
  Completed 80/100 estimators
  Completed 90/100 estimators
  Completed 100/100 estimators

Estimating E[X²]:
  True value:  11.0000
  Estimate:    10.9174
  95% CI:      [10.2861, 11.5487]
  