In [None]:
import matplotlib.pyplot as plt
import numpy as np


class RobustEllipticalSliceSampler:
    """
    Robust implementation of ESS that handles numerical edge cases.

    Elliptical slice sampling (Murray, Adams & MacKay, 2010) targets a
    posterior proportional to ``exp(log_likelihood_fn(f)) * N(f; 0, prior_cov)``.
    Each step draws an auxiliary prior sample ``nu``, defines the ellipse
    ``f * cos(theta) + nu * sin(theta)`` through the current state, and
    shrinks an angle bracket toward ``theta = 0`` until a proposal above the
    slice threshold is found.

    Parameters
    ----------
    log_likelihood_fn : callable
        Maps a state vector to its scalar log-likelihood.
    prior_cov : array_like, shape (d, d)
        Covariance of the zero-mean Gaussian prior. It is Cholesky-factorised
        here, so ``np.linalg.cholesky`` raises if it is not positive definite.
    max_iterations : int, default 1000
        Maximum number of proposals tried per step before giving up.
    verbose : bool, default True
        When True, print per-step diagnostics and show a diagnostic plot if a
        step fails. Set to False for quiet use inside a sampling loop.
    """

    def __init__(
        self, log_likelihood_fn, prior_cov, max_iterations=1000, verbose=True
    ):
        self.log_likelihood_fn = log_likelihood_fn
        # Accept nested lists as well as arrays; sample_prior relies on `.shape`.
        self.prior_cov = np.asarray(prior_cov, dtype=float)
        self.prior_chol = np.linalg.cholesky(self.prior_cov)
        self.max_iterations = max_iterations
        self.verbose = verbose

    def _log(self, message):
        """Print ``message`` only when the sampler is in verbose mode."""
        if self.verbose:
            print(message)

    def sample_prior(self):
        """Draw one sample from N(0, prior_cov) using the Cholesky factor."""
        z = np.random.standard_normal(self.prior_cov.shape[0])
        return self.prior_chol @ z

    def step(self, f_current):
        """Robust ESS step with detailed diagnostics.

        Parameters
        ----------
        f_current : array_like, shape (d,)
            Current state of the chain.

        Returns
        -------
        numpy.ndarray
            The accepted proposal. If no proposal is accepted within
            ``max_iterations`` tries, ``f_current`` itself is returned as an
            emergency fallback (not a valid MCMC transition; treat it as a
            warning sign).
        """
        f_current = np.asarray(f_current)

        # Step 1: Draw auxiliary variable defining the ellipse.
        nu = self.sample_prior()

        # Step 2: Compute threshold log(y) = log L(f) + log(u).
        current_log_lik = self.log_likelihood_fn(f_current)
        # random() lies in [0, 1); 1 - random() lies in (0, 1], so log(u) is finite.
        u = 1.0 - np.random.random()
        log_y = current_log_lik + np.log(u)

        self._log(f"Current log-likelihood: {current_log_lik:.6f}")
        self._log(f"Random u: {u:.6f}, log(u): {np.log(u):.6f}")
        self._log(f"Threshold log_y: {log_y:.6f}")

        # Step 3: Verify θ=0 is valid (should always be true!). Re-evaluating
        # guards against likelihoods that are not bit-for-bit deterministic.
        f_at_zero = f_current * np.cos(0) + nu * np.sin(0)  # = f_current
        log_lik_at_zero = self.log_likelihood_fn(f_at_zero)

        self._log(f"Log-likelihood at θ=0: {log_lik_at_zero:.6f}")
        self._log(f"Difference from current: {log_lik_at_zero - current_log_lik:.2e}")

        if log_lik_at_zero < log_y - 1e-10:  # Small tolerance for numerical errors
            self._log("⚠️  WARNING: θ=0 should be valid but isn't due to numerical error!")
            self._log(f"   Expected: {log_lik_at_zero:.10f} >= {log_y:.10f}")
            self._log(f"   Difference: {log_lik_at_zero - log_y:.2e}")
            # Fix by slightly lowering threshold so θ→0 is guaranteed to succeed.
            log_y = log_lik_at_zero - 1e-12
            self._log(f"   Adjusted threshold to: {log_y:.10f}")

        # Step 4: Bracket shrinking. The first proposal θ ~ U[0, 2π] sits at the
        # top of the bracket [θ - 2π, θ], which covers the ellipse exactly once
        # (Murray et al., 2010, Algorithm 1). A wider bracket would break the
        # symmetry that makes the update leave the posterior invariant.
        theta = np.random.uniform(0, 2 * np.pi)
        theta_min = theta - 2 * np.pi
        theta_max = theta

        angles_tested = []
        likelihoods = []

        for iteration in range(self.max_iterations):
            # Test current angle
            f_proposal = f_current * np.cos(theta) + nu * np.sin(theta)
            proposal_log_lik = self.log_likelihood_fn(f_proposal)

            angles_tested.append(theta)
            likelihoods.append(proposal_log_lik)

            if proposal_log_lik > log_y:
                self._log(f"✅ Success at iteration {iteration}, θ={theta:.4f}")
                self._log(f"   Final bracket: [{theta_min:.4f}, {theta_max:.4f}]")
                return f_proposal

            # Shrink the bracket toward θ=0, which always lies inside it.
            if theta < 0:
                theta_min = theta
            else:
                theta_max = theta

            # Sample new angle
            theta = np.random.uniform(theta_min, theta_max)

            # Progress update (counted in completed iterations).
            if (iteration + 1) % 100 == 0:
                bracket_size = theta_max - theta_min
                self._log(
                    f"   Iteration {iteration + 1}: bracket size = {bracket_size:.6f}"
                )

        # If we get here, something went wrong
        self._log(f"❌ FAILED after {self.max_iterations} iterations!")
        self._log(f"Final bracket: [{theta_min:.6f}, {theta_max:.6f}]")
        self._log(f"Bracket size: {theta_max - theta_min:.2e}")

        if self.verbose:
            self._plot_failure_diagnostics(
                angles_tested, likelihoods, log_y, f_current, nu, current_log_lik
            )

        # Emergency fallback: return current state (not ideal!)
        self._log("🆘 Emergency fallback: returning current state")
        return f_current

    def _plot_failure_diagnostics(
        self, angles_tested, likelihoods, log_y, f_current, nu, current_log_lik
    ):
        """Plot diagnostics when bracket shrinking fails.

        Draws four panels: the log-likelihood along the ellipse with the tested
        points, a zoom around θ=0, the sequence of tested angles, and a
        histogram of the tested log-likelihoods.
        """

        def ellipse_log_lik(t):
            # Log-likelihood at angle t on the ellipse through f_current and nu.
            return self.log_likelihood_fn(f_current * np.cos(t) + nu * np.sin(t))

        plt.figure(figsize=(12, 8))

        # Plot 1: Likelihood along the ellipse. The bracket can extend to ±2π,
        # so widen the grid to include every tested angle.
        plt.subplot(2, 2, 1)
        lo = min([-np.pi, *angles_tested])
        hi = max([np.pi, *angles_tested])
        theta_fine = np.linspace(lo, hi, 200)
        log_liks_fine = [ellipse_log_lik(t) for t in theta_fine]

        plt.plot(
            theta_fine,
            log_liks_fine,
            "b-",
            linewidth=2,
            label="Likelihood along ellipse",
        )
        plt.axhline(
            log_y, color="red", linestyle="--", label=f"Threshold = {log_y:.3f}"
        )
        plt.axhline(
            current_log_lik,
            color="green",
            linestyle="--",
            label=f"Current = {current_log_lik:.3f}",
        )
        plt.scatter(
            angles_tested,
            likelihoods,
            c="orange",
            s=10,
            alpha=0.7,
            label="Tested points",
        )
        plt.xlabel("Angle θ")
        plt.ylabel("Log-likelihood")
        plt.title("Likelihood Along Ellipse")
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Plot 2: Zoom around θ=0
        plt.subplot(2, 2, 2)
        theta_zoom = np.linspace(-0.1, 0.1, 50)
        log_liks_zoom = [ellipse_log_lik(t) for t in theta_zoom]

        plt.plot(theta_zoom, log_liks_zoom, "b-", linewidth=2)
        plt.axhline(log_y, color="red", linestyle="--")
        plt.axvline(0, color="black", linestyle="-", alpha=0.5, label="θ=0")
        plt.xlabel("Angle θ (near 0)")
        plt.ylabel("Log-likelihood")
        plt.title("Zoom Around θ=0")
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Plot 3: Bracket shrinking progress
        plt.subplot(2, 2, 3)
        plt.scatter(
            range(len(angles_tested)),
            angles_tested,
            c=likelihoods,
            cmap="viridis",
            s=20,
        )
        plt.colorbar(label="Log-likelihood")
        plt.axhline(0, color="black", linestyle="-", alpha=0.5)
        plt.xlabel("Iteration")
        plt.ylabel("Tested angle θ")
        plt.title("Bracket Shrinking Progress")
        plt.grid(True, alpha=0.3)

        # Plot 4: Distribution of tested likelihoods
        plt.subplot(2, 2, 4)
        plt.hist(likelihoods, bins=20, alpha=0.7, edgecolor="black")
        plt.axvline(log_y, color="red", linestyle="--", linewidth=2, label="Threshold")
        plt.axvline(
            current_log_lik, color="green", linestyle="--", linewidth=2, label="Current"
        )
        plt.xlabel("Log-likelihood")
        plt.ylabel("Count")
        plt.title("Distribution of Tested Likelihoods")
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()


# Example of potential failure cases
def demonstrate_edge_cases():
    """Demonstrate when ESS might struggle and how to handle it.

    Runs one sampler step for each of two tricky likelihoods (a very sharply
    peaked one and one with tiny random jitter). Every case uses an identity
    prior and ``max_iterations=50``, prints a header and the start point, and
    reports either the new point or the exception raised by ``step``.
    """

    def peaked_log_likelihood(f):
        # Extremely peaked at origin
        return -1000 * np.sum(f**2)

    def noisy_log_likelihood(f):
        # Standard-normal log density plus ~1e-15 of random jitter per call,
        # so two evaluations at the same point can disagree slightly.
        return -0.5 * np.sum(f**2) + 1e-15 * np.random.randn()

    # (header, log-likelihood, start point, note printed before stepping)
    cases = (
        (
            "Case 1: Very peaked likelihood function",
            peaked_log_likelihood,
            np.array([0.01, 0.01]),  # Start near peak
            "This should work but might take many iterations...\n",
        ),
        (
            "Case 2: Likelihood with numerical precision issues",
            noisy_log_likelihood,
            np.array([1.0, 1.0]),
            "This might show numerical precision warnings...\n",
        ),
    )
    identity_prior = np.eye(2)

    print("=== ESS Edge Case Analysis ===\n")

    for index, (header, log_lik, start, note) in enumerate(cases):
        if index > 0:
            # Separator between consecutive cases only.
            print("\n" + "=" * 50 + "\n")

        print(header)
        ess = RobustEllipticalSliceSampler(
            log_lik, identity_prior, max_iterations=50
        )
        print(f"Starting point: {start}")
        print(note)

        try:
            moved_to = ess.step(start)
            print(f"Success! New point: {moved_to}")
        except Exception as err:
            print(f"Failed with error: {err}")


# Run the edge-case demo only when executed as a script, not when imported.
if __name__ == "__main__":
    demonstrate_edge_cases()

=== ESS Edge Case Analysis ===

Case 1: Very peaked likelihood function
Starting point: [0.01 0.01]
This should work but might take many iterations...

Current log-likelihood: -0.200000
Random u: 0.567762, log(u): -0.566054
Threshold log_y: -0.766054
Log-likelihood at θ=0: -0.200000
Difference from current: 0.00e+00
✅ Success at iteration 6, θ=-0.0677
   Final bracket: [-0.1991, 0.6047]
Success! New point: [ 0.02444545 -0.01123501]


Case 2: Likelihood with numerical precision issues
Starting point: [1. 1.]

Current log-likelihood: -1.000000
Random u: 0.870129, log(u): -0.139113
Threshold log_y: -1.139113
Log-likelihood at θ=0: -1.000000
Difference from current: -3.33e-16
✅ Success at iteration 0, θ=4.0411
   Final bracket: [-2.2420, 10.3243]
Success! New point: [-0.45848177 -0.96529323]
