In [None]:
from typing import Callable, Tuple

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats


class RejectionControl:
    """
    Implementation of the Rejection Control algorithm from Exercise 5.

    Combines rejection sampling with importance weighting to reduce variance:
    a draw X ~ q with raw weight w(X) = π(X)/q(X) is kept with probability
    min{1, w(X)/c}; a kept draw carries the adjusted weight max{c, w(X)}.
    Draws whose raw weight is at least c are therefore always kept with their
    weight unchanged, while low-weight draws are thinned and re-weighted to c.
    """

    def __init__(
        self,
        target_pdf: Callable,  # π(x) - unnormalized
        proposal_pdf: Callable,  # q(x) - normalized
        proposal_sampler: Callable,  # samples from q
        threshold: float = 1.0,
    ):  # c parameter
        """
        Initialize rejection control sampler.

        Args:
            target_pdf: Unnormalized target density π̃(x)
            proposal_pdf: Normalized proposal density q(x)
            proposal_sampler: Function that returns samples from q
            threshold: Threshold parameter c for rejection control (must be > 0)

        Raises:
            ValueError: If ``threshold`` is not strictly positive (this also
                rejects NaN).
        """
        # c divides the raw weight in the acceptance test: c == 0 is a
        # division by zero and c < 0 makes the acceptance probability
        # non-positive, so sample() could never finish. Fail early instead.
        if not threshold > 0:
            raise ValueError(f"threshold must be positive, got {threshold!r}")

        self.target_pdf = target_pdf
        self.proposal_pdf = proposal_pdf
        self.proposal_sampler = proposal_sampler
        self.c = threshold

    def sample(self, n_samples: int) -> Tuple[np.ndarray, np.ndarray, dict]:
        """
        Generate samples using rejection control.

        Proposals are drawn one at a time until ``n_samples`` of them have
        been accepted; the loop does not terminate if no proposal can ever be
        accepted (e.g. the target is zero wherever the proposal samples).

        Args:
            n_samples: Number of accepted samples to return. A value <= 0
                returns empty arrays without drawing any proposal.

        Returns:
            samples: Accepted samples
            weights: Importance weights w*(x) for each accepted sample
            info: Dictionary with algorithm statistics: "acceptance_rate"
                (0.0 when nothing was proposed), "n_proposed", "n_accepted"
        """
        samples = []
        weights = []
        n_proposed = 0

        while len(samples) < n_samples:
            # Step a: Generate X ~ q, U ~ U[0,1]
            x = self.proposal_sampler()
            u = np.random.uniform()
            n_proposed += 1

            # Calculate w(x) = π(x)/q(x)
            w_x = self.target_pdf(x) / self.proposal_pdf(x)

            # Accept if U ≤ min{1, w(X)/c}
            if u <= min(1.0, w_x / self.c):
                samples.append(x)
                # Adjusted weight of an accepted draw:
                # w*(x) = w(x) * c / min{c, w(x)} = max{c, w(x)}
                weights.append(max(self.c, w_x))

        n_accepted = len(samples)
        info = {
            # Guard the ratio: with n_samples <= 0 no proposal is ever drawn.
            "acceptance_rate": n_accepted / n_proposed if n_proposed else 0.0,
            "n_proposed": n_proposed,
            "n_accepted": n_accepted,
        }

        return np.array(samples), np.array(weights), info


def standard_importance_sampling(
    target_pdf: Callable,
    proposal_pdf: Callable,
    proposal_sampler: Callable,
    n_samples: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Plain importance sampling, used as the baseline for comparison.

    Draws ``n_samples`` points by calling ``proposal_sampler`` once per point,
    then weights each point by ``target_pdf(x) / proposal_pdf(x)``.

    Returns:
        samples: Array of the drawn points
        weights: Array of the corresponding (unnormalized) importance weights
    """
    # Draw every point first so the sampler is exhausted before any density
    # is evaluated.
    drawn = []
    for _ in range(n_samples):
        drawn.append(proposal_sampler())
    samples = np.array(drawn)

    # Weight each element of the sample array by the density ratio.
    ratios = []
    for point in samples:
        ratios.append(target_pdf(point) / proposal_pdf(point))

    return samples, np.array(ratios)


# Example: Heavy-tailed target with Gaussian proposal
def example_heavy_tailed():
    """
    Compare importance-weight variance: rejection control vs. standard IS.

    Target: Student's t-distribution with 3 degrees of freedom (heavy tails)
    Proposal: Normal distribution with mean 0 and standard deviation 1.5

    For each threshold c, and once for standard importance sampling, 100
    independent runs of 5000 samples are generated; the variance of the
    weights is computed within each run and averaged over the runs. A summary
    is printed.

    Returns:
        results: Dict mapping each threshold c to
            {"variance": ..., "acceptance_rate": ...}
        is_var: Mean within-run weight variance of standard IS
    """
    dof = 3
    scale = 1.5  # proposal standard deviation (wider than a standard normal)

    def student_t_density(x):
        return stats.t.pdf(x, dof)

    def gaussian_density(x):
        return stats.norm.pdf(x, 0, scale)

    def draw_gaussian():
        return np.random.normal(0, scale)

    def mean_run_variance(weight_runs):
        # One row per run: variance within each run, then average over runs.
        return np.mean(np.var(np.array(weight_runs), axis=1))

    n_samples, n_runs = 5000, 100

    print("Comparing Rejection Control with Standard Importance Sampling")
    print("=" * 60)

    # Rejection control, one batch of runs per threshold value.
    results = {}
    for c in (0.5, 1.0, 2.0, 5.0, 10.0):
        sampler = RejectionControl(
            student_t_density, gaussian_density, draw_gaussian, c
        )
        weight_runs, rates = [], []
        for _ in range(n_runs):
            _, run_weights, run_info = sampler.sample(n_samples)
            weight_runs.append(run_weights)
            rates.append(run_info["acceptance_rate"])

        results[c] = {
            "variance": mean_run_variance(weight_runs),
            "acceptance_rate": np.mean(rates),
        }

    # Standard importance sampling baseline (weights only).
    is_runs = [
        standard_importance_sampling(
            student_t_density, gaussian_density, draw_gaussian, n_samples
        )[1]
        for _ in range(n_runs)
    ]
    is_var = mean_run_variance(is_runs)

    print(f"\nStandard IS - Weight variance: {is_var:.2f}")
    print("\nRejection Control Results:")
    print("-" * 40)

    for c, entry in results.items():
        rc_var = entry["variance"]
        rate = entry["acceptance_rate"]
        # Relative change of the weight variance with respect to the baseline.
        reduction = (1 - rc_var / is_var) * 100
        print(
            f"c = {c:4.1f}: Variance = {rc_var:8.2f}, "
            f"Acceptance = {rate:.2%}, "
            f"Variance reduction = {reduction:+.1f}%"
        )

    return results, is_var


def example_exponential_tail():
    """
    Second example: estimate E[X | X > a] for X ~ Exp(λ) using a shifted
    exponential proposal that starts at a.

    For each threshold c, 100 self-normalized estimates (2000 samples each)
    are computed with rejection control and with standard importance
    sampling, and compared against the exact value a + 1/λ. The MSE
    reduction, bias and standard deviation are printed; nothing is returned.
    """
    rate = 1.0  # λ
    a = 3.0  # start of the tail

    # Target: exponential density restricted to x > a (unnormalized).
    def tail_density(x):
        if x > a:
            return rate * np.exp(-rate * x)
        return 0

    # Proposal: exponential density shifted to start at a.
    def shifted_exp_density(x):
        if x > a:
            return rate * np.exp(-rate * (x - a))
        return 0

    def draw_shifted_exp():
        return a + np.random.exponential(1 / rate)

    def weighted_mean(points, point_weights):
        # Self-normalized importance sampling estimate of the mean.
        return np.sum(points * point_weights) / np.sum(point_weights)

    print("\n" + "=" * 60)
    print("Example 2: Exponential Tail Estimation")
    print("=" * 60)

    n_samples, n_runs = 2000, 100

    # Exact value: E[X | X > a] = a + 1/λ
    true_value = a + 1 / rate

    for c in (0.5, 1.0, 2.0):
        sampler = RejectionControl(
            tail_density, shifted_exp_density, draw_shifted_exp, c
        )

        # Rejection control estimates.
        rc_estimates = []
        for _ in range(n_runs):
            xs, ws, _info = sampler.sample(n_samples)
            rc_estimates.append(weighted_mean(xs, ws))

        # Standard IS estimates (drawn afresh for every threshold).
        is_estimates = []
        for _ in range(n_runs):
            xs, ws = standard_importance_sampling(
                tail_density, shifted_exp_density, draw_shifted_exp, n_samples
            )
            is_estimates.append(weighted_mean(xs, ws))

        rc_estimates = np.array(rc_estimates)
        is_estimates = np.array(is_estimates)

        mse_rc = np.mean((rc_estimates - true_value) ** 2)
        mse_is = np.mean((is_estimates - true_value) ** 2)
        mse_reduction = (1 - mse_rc / mse_is) * 100

        rc_bias = np.mean(rc_estimates) - true_value
        is_bias = np.mean(is_estimates) - true_value

        print(f"\nc = {c}: MSE reduction = {mse_reduction:.1f}%")
        print(
            f"  RC: bias = {rc_bias:.4f}, "
            f"std = {np.std(rc_estimates):.4f}"
        )
        print(
            f"  IS: bias = {is_bias:.4f}, "
            f"std = {np.std(is_estimates):.4f}"
        )


def visualize_weights_distribution():
    """
    Visualize the distribution of importance weights.

    Draws 2000 samples with standard importance sampling and 2000 accepted
    samples with rejection control (c = 2) for a Student's t target (df = 3)
    and a Normal(0, 1.5) proposal, then shows three panels:

    1. Overlaid histograms of the weights (x-axis limited to [0, 10]).
    2. Cumulative share of the total weight carried by the largest weights,
       plotted against the fraction of samples included.
    3. Box plots of both weight sets (y-axis limited to [0, 10]).

    Finally prints the mean, standard deviation and maximum of each weight
    set and the rejection-control acceptance rate. Returns nothing.
    """

    # Setup
    df = 3

    def target_pdf(x):
        return stats.t.pdf(x, df)

    def proposal_pdf(x):
        return stats.norm.pdf(x, 0, 1.5)

    def proposal_sampler():
        return np.random.normal(0, 1.5)

    n_samples = 2000

    # Standard IS (only the weights are needed here)
    _, weights_is = standard_importance_sampling(
        target_pdf, proposal_pdf, proposal_sampler, n_samples
    )

    # Rejection Control with c=2
    rc_sampler = RejectionControl(target_pdf, proposal_pdf, proposal_sampler, 2.0)
    _, weights_rc, info = rc_sampler.sample(n_samples)

    # Create visualization
    _, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Plot 1: Weight distributions
    axes[0].hist(weights_is, bins=50, alpha=0.5, label="Standard IS", density=True)
    axes[0].hist(
        weights_rc, bins=50, alpha=0.5, label="Rejection Control", density=True
    )
    axes[0].set_xlabel("Weight value")
    axes[0].set_ylabel("Density")
    axes[0].set_title("Distribution of Importance Weights")
    axes[0].legend()
    axes[0].set_xlim([0, 10])

    # Plot 2: Cumulative weight contribution, largest weights first
    for weights, label in (
        (weights_is, "Standard IS"),
        (weights_rc, "Rejection Control"),
    ):
        descending = np.sort(weights)[::-1]
        cumulative_share = np.cumsum(descending) / np.sum(descending)
        n_points = len(cumulative_share)
        # cumulative_share[k - 1] covers the k largest weights, so it belongs
        # at x = k / n (1/n .. 1), not (k - 1) / n.
        fraction = np.arange(1, n_points + 1) / n_points
        axes[1].plot(fraction, cumulative_share, label=label, linewidth=2)
    axes[1].set_xlabel("Fraction of samples")
    axes[1].set_ylabel("Cumulative weight contribution")
    axes[1].set_title("Weight Concentration")
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    # Plot 3: Box plot of weights
    axes[2].boxplot([weights_is, weights_rc])
    # Label the two boxes through the axis rather than boxplot's `labels`
    # keyword, which newer Matplotlib releases deprecate in favour of
    # `tick_labels`; setting the tick labels works on old and new versions.
    axes[2].set_xticklabels(["Standard IS", "Rejection Control"])
    axes[2].set_ylabel("Weight value")
    axes[2].set_title("Weight Distribution Comparison")
    axes[2].set_ylim([0, 10])

    plt.tight_layout()
    plt.show()

    print("\n" + "=" * 60)
    print("Weight Statistics:")
    print("-" * 40)
    print(
        f"Standard IS:        mean = {np.mean(weights_is):.3f}, "
        f"std = {np.std(weights_is):.3f}, max = {np.max(weights_is):.3f}"
    )
    print(
        f"Rejection Control:  mean = {np.mean(weights_rc):.3f}, "
        f"std = {np.std(weights_rc):.3f}, max = {np.max(weights_rc):.3f}"
    )
    print(f"Acceptance rate: {info['acceptance_rate']:.2%}")


if __name__ == "__main__":
    # Script entry point: run both numerical examples, then the plots.
    print("REJECTION CONTROL ALGORITHM DEMONSTRATION", "=" * 60, sep="\n")

    # Example 1: heavy-tailed target; the per-threshold results and the
    # standard-IS weight variance stay bound at module level.
    results, is_var = example_heavy_tailed()

    # Example 2: conditional mean in the tail of an exponential.
    example_exponential_tail()

    # Weight-distribution figure and summary statistics.
    print(
        "\n" + "=" * 60,
        "Generating visualization of weight distributions...",
        sep="\n",
    )
    visualize_weights_distribution()

REJECTION CONTROL ALGORITHM DEMONSTRATION
Comparing Rejection Control with Standard Importance Sampling


KeyboardInterrupt: 