<a href="https://colab.research.google.com/github/jingjieyuan573-bite/Composite_Distribution_analysis/blob/main/Composite_Monte_Threshold_Sensitivity.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
"""
Composite Monte Carlo Simulation (Fixed thresholds)

- Runs for four fixed thresholds: Baseline, Lower, Higher, Extreme
- Computes Bias, MSE, Avg LOGLIKE, Tail KS with 95% bootstrap CI
- Prints results in aligned table form
- Saves summary CSV and LaTeX
"""

import time
import numpy as np
import pandas as pd
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

# ---------- RNG helper ----------
def ensure_rng(rs=None):
    """Return a ``numpy.random.RandomState`` derived from *rs*.

    Parameters
    ----------
    rs : None, int, numpy integer or numpy.random.RandomState, optional
        ``None`` yields a generator seeded by NumPy's default (unseeded)
        mechanism, an integer is used as the seed, and an existing
        ``RandomState`` is returned unchanged so that callers can share a
        single random stream.

    Returns
    -------
    numpy.random.RandomState

    Raises
    ------
    TypeError
        If *rs* is of any other type.  Falling back to an unseeded generator
        here would silently discard the caller's seed and make a run that
        looks seeded irreproducible.
    """
    if isinstance(rs, np.random.RandomState):
        return rs
    if rs is None:
        return np.random.RandomState(None)
    if isinstance(rs, (int, np.integer)):
        return np.random.RandomState(int(rs))
    raise TypeError(
        "rs must be None, an int, or a numpy.random.RandomState, "
        f"got {type(rs).__name__}"
    )

# ---------- Composite sampling ----------
def composite_rvs(size=1, theta1=-0.3, theta2=0.3, left_frac=0.08, right_frac=0.08, random_state=None):
    """Draw samples from the three-component composite distribution.

    Each observation is assigned to one component:

    * with probability ``left_frac``: standard normal conditioned on
      ``x <= theta1``;
    * with probability ``right_frac``: standard normal conditioned on
      ``x >= theta2``;
    * otherwise: uniform on the open interval ``(theta1, theta2)``.

    Parameters
    ----------
    size : int
        Number of observations to draw.
    theta1, theta2 : float
        Lower and upper thresholds; ``theta1 < theta2`` is required.
    left_frac, right_frac : float
        Mixture weights of the two tails; both must be non-negative and sum
        to at most 1.
    random_state : None, int or numpy.random.RandomState
        Forwarded to ``ensure_rng``.

    Returns
    -------
    numpy.ndarray
        1-D float array of length ``size``.

    Raises
    ------
    ValueError
        If the thresholds are not strictly increasing, the weights are not
        valid probabilities, or a requested tail has no representable
        standard-normal mass.
    """
    if not theta1 < theta2:
        raise ValueError(f"theta1 must be strictly less than theta2, got {theta1} and {theta2}")
    if not (left_frac >= 0.0 and right_frac >= 0.0 and left_frac + right_frac <= 1.0):
        raise ValueError(
            "left_frac and right_frac must be non-negative and sum to at most 1, "
            f"got {left_frac} and {right_frac}"
        )

    rng = ensure_rng(random_state)
    n = int(size)
    samples = np.empty(n, dtype=float)

    # Component label per observation: 0 = left tail, 1 = middle, 2 = right tail.
    draw = rng.rand(n)
    component = np.where(draw < left_frac, 0, np.where(draw < 1.0 - right_frac, 1, 2))
    left_idx = np.flatnonzero(component == 0)
    mid_idx = np.flatnonzero(component == 1)
    right_idx = np.flatnonzero(component == 2)

    # The tails use inverse-CDF sampling rather than rejection from N(0, 1):
    # rejection wastes almost every draw once a threshold sits far out in the
    # tail and never terminates when the acceptance region is empty.
    if left_idx.size:
        mass = stats.norm.cdf(theta1)
        if not mass > 0.0:
            raise ValueError(f"theta1={theta1} leaves no representable N(0,1) mass in the left tail")
        v = 1.0 - rng.rand(left_idx.size)  # in (0, 1], so ppf never sees 0
        # The clamp absorbs round-off in ppf(cdf(theta1)).
        samples[left_idx] = np.minimum(stats.norm.ppf(v * mass), theta1)

    if mid_idx.size:
        x = theta1 + (theta2 - theta1) * rng.rand(mid_idx.size)
        # Keep the draws strictly inside (theta1, theta2).
        lo = np.nextafter(float(theta1), float(theta2))
        hi = np.nextafter(float(theta2), float(theta1))
        samples[mid_idx] = np.clip(x, lo, hi)

    if right_idx.size:
        mass = stats.norm.sf(theta2)
        if not mass > 0.0:
            raise ValueError(f"theta2={theta2} leaves no representable N(0,1) mass in the right tail")
        v = 1.0 - rng.rand(right_idx.size)  # in (0, 1], so isf never sees 0
        # sf/isf keep precision in the upper tail; the clamp absorbs round-off.
        samples[right_idx] = np.maximum(stats.norm.isf(v * mass), theta2)

    return samples

# ---------- Empirical CDF ----------
def empirical_cdf_factory(sample):
    """Build the empirical CDF of *sample*.

    Parameters
    ----------
    sample : array-like
        Observations defining the step function.

    Returns
    -------
    callable
        ``cdf(x)`` giving, for a scalar or array-like ``x``, the fraction of
        observations less than or equal to each value.
    """
    ordered = np.sort(np.asarray(sample))
    total = float(len(ordered))

    def cdf(x):
        # side='right' counts the observations <= x, ties included.
        points = np.asarray(x)
        return ordered.searchsorted(points, side='right') / total

    return cdf

# ---------- Tail KS ----------
def compute_tailks(cdf_func, data_sorted, upper_pct, lower_pct):
    """Largest gap between a fitted CDF and the empirical CDF in the tails.

    Parameters
    ----------
    cdf_func : callable
        Fitted CDF; called once with an array of the tail observations.
    data_sorted : numpy.ndarray
        Observations in ascending order (``np.searchsorted`` is applied to
        them directly).
    upper_pct, lower_pct : float
        Percentiles bounding the tails: observations at or above the
        ``upper_pct`` percentile and at or below the ``lower_pct`` percentile
        are evaluated.  Note the argument order: upper first.

    Returns
    -------
    float
        Maximum absolute difference over the tail observations, or NaN when
        no observation falls in either tail.
    """
    lo_cut = np.percentile(data_sorted, lower_pct)
    hi_cut = np.percentile(data_sorted, upper_pct)

    lower_tail = data_sorted[data_sorted <= lo_cut]
    upper_tail = data_sorted[data_sorted >= hi_cut]
    tail_points = np.concatenate((lower_tail, upper_tail))
    if tail_points.size == 0:
        return np.nan

    model_vals = np.asarray(cdf_func(tail_points))
    # Right-continuous empirical CDF of the full data set at the same points.
    ecdf_vals = np.searchsorted(data_sorted, tail_points, side='right') / float(len(data_sorted))
    gaps = np.abs(model_vals - ecdf_vals)
    return float(gaps.max())

# ---------- Monte Carlo per threshold ----------
def monte_carlo_sim(M=50, n=2000, theta1=-0.3, theta2=0.3, mc_sample_size=20000, seed=None):
    """Run ``M`` Monte Carlo replications for one pair of thresholds.

    Each replication draws ``n`` observations with ``composite_rvs``, builds
    a simulated "fitted" sample whose three parts are sized by the observed
    share of data in each region, and scores that sample against the data.

    Parameters
    ----------
    M : int
        Number of replications.
    n : int
        Observations per replication.
    theta1, theta2 : float
        Lower and upper thresholds of the composite model.
    mc_sample_size : int
        Target size of the simulated sample behind the fitted CDF.
    seed : None, int or numpy.random.RandomState
        Forwarded to ``ensure_rng``; one generator drives every draw.

    Returns
    -------
    pandas.DataFrame
        One row per replication with columns ``rep``, ``model``,
        ``avg_loglike``, ``tailks``, ``bias_mean`` and ``mse_mean``.
    """
    rng = ensure_rng(seed)
    width = theta2 - theta1
    reference_mean = 0.0
    records = []

    for rep in range(int(M)):
        data = composite_rvs(size=n, theta1=theta1, theta2=theta2, random_state=rng)
        data_sorted = np.sort(data)

        # Observed counts in the left and middle regions; the right-hand part
        # of the simulated sample is sized as the remainder further down.
        n_left = np.sum(data <= theta1)
        n_mid = np.sum((data > theta1) & (data < theta2))

        # Every part keeps at least one draw.
        n_left_s = max(1, int(mc_sample_size * n_left / n))
        n_mid_s = max(1, int(mc_sample_size * n_mid / n))
        n_right_s = max(1, mc_sample_size - n_left_s - n_mid_s)

        # NOTE(review): both tail parts come from an *untruncated* N(0, 1),
        # whereas composite_rvs truncates its tails at the thresholds.
        # Confirm that this is the intended fitted model.
        left_part = stats.norm.rvs(loc=0, scale=1, size=n_left_s, random_state=rng)
        mid_part = stats.uniform.rvs(loc=theta1, scale=width, size=n_mid_s, random_state=rng)
        right_part = stats.norm.rvs(loc=0, scale=1, size=n_right_s, random_state=rng)
        param_sample = np.concatenate([left_part, mid_part, right_part])

        fitted_cdf = empirical_cdf_factory(param_sample)
        tailks = compute_tailks(fitted_cdf, data_sorted, 95, 5)

        # Moments of a size-n resample of the simulated sample, taken about 0.
        # NOTE(review): "mse_mean" is therefore a raw second moment of the
        # draws, not the MSE of an estimator. Confirm the intended definition.
        sim_draw = rng.choice(param_sample, size=n, replace=True)
        bias_mean = float(sim_draw.mean())
        mse_mean = float(np.mean((sim_draw - reference_mean) ** 2))

        # NOTE(review): this averages the log of fitted *CDF* values (floored
        # at 1e-12), not the log of a density. Confirm the "loglike" naming.
        cdf_at_data = np.maximum(fitted_cdf(data), 1e-12)
        avg_loglike = float(np.mean(np.log(cdf_at_data)))

        records.append({
            'rep': rep,
            'model': 'Composite',
            'avg_loglike': avg_loglike,
            'tailks': tailks,
            'bias_mean': bias_mean,
            'mse_mean': mse_mean,
        })

    return pd.DataFrame(records)

# ---------- Bootstrap CI ----------
def bootstrap_ci(df, model, col, B=500, alpha=0.05, rng=None):
    """Percentile-bootstrap confidence interval for the mean of one column.

    Parameters
    ----------
    df : pandas.DataFrame
        Replication-level results; must contain a ``model`` column and *col*.
    model : str
        Only rows whose ``model`` equals this value are used.
    col : str
        Column whose mean is bootstrapped.
    B : int
        Number of bootstrap resamples.
    alpha : float
        One minus the confidence level (0.05 gives a 95% interval).
    rng : None, int or numpy.random.RandomState
        Forwarded to ``ensure_rng``.

    Returns
    -------
    tuple of float
        ``(lower, upper)`` percentile bounds, or ``(nan, nan)`` when no row
        matches *model*.
    """
    rng = ensure_rng(rng)
    values = df.loc[df['model'] == model, col].values
    count = len(values)
    if count == 0:
        return (np.nan, np.nan)

    boot_means = []
    for _ in range(int(B)):
        resample = rng.choice(values, size=count, replace=True)
        boot_means.append(np.mean(resample))

    q_lo = 100*alpha/2
    q_hi = 100*(1-alpha/2)
    lo, hi = (float(np.percentile(boot_means, q)) for q in (q_lo, q_hi))
    return lo, hi

# ---------- Threshold wrapper ----------
def threshold_sensitivity_composite(thresholds, n=2000, M=50, B_boot=100, seed=1234, mc_sample_size=20000,
                                    out_prefix="threshold_composite_results"):
    """Run the composite Monte Carlo study for several fixed thresholds.

    For every threshold entry the function runs ``monte_carlo_sim``, averages
    the four metrics over replications, attaches bootstrap confidence
    intervals, prints a short report, and finally writes the summary table to
    ``<out_prefix>_composite_thresholds_summary.csv`` and ``.tex``.

    Parameters
    ----------
    thresholds : iterable of dict
        Each dict needs the keys ``'label'``, ``'theta1'`` and ``'theta2'``.
    n : int
        Observations per replication.
    M : int
        Replications per threshold.
    B_boot : int
        Bootstrap resamples per confidence interval.
    seed : None, int or numpy.random.RandomState
        Master seed.  A child generator is derived from it for every
        threshold and drives both the simulation and the bootstrap, so a
        fixed seed reproduces the whole study.
    mc_sample_size : int
        Forwarded to ``monte_carlo_sim``.
    out_prefix : str
        Prefix of the two output files.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        Replication-level results for all thresholds, and the per-threshold
        summary table.

    Raises
    ------
    ValueError
        If *thresholds* is empty.
    """
    thresholds = list(thresholds)
    if not thresholds:
        raise ValueError("thresholds must contain at least one entry")

    # (summary-column prefix, replication-level column) in reporting order.
    metrics = (
        ('avglog', 'avg_loglike'),
        ('tailks', 'tailks'),
        ('bias', 'bias_mean'),
        ('mse', 'mse_mean'),
    )

    rng_master = ensure_rng(seed)
    summaries = []
    rows_all = []

    for th in thresholds:
        t0 = time.perf_counter()
        print(f"\nRunning threshold '{th['label']}': theta1={th['theta1']}, theta2={th['theta2']} (n={n}, M={M}, B={B_boot})")
        rng_n = ensure_rng(rng_master.randint(2**31 - 1))

        # Hand the per-threshold generator to the simulation; passing
        # seed=None here would detach the simulation from `seed`.
        df_n = monte_carlo_sim(M=M, n=n, theta1=th['theta1'], theta2=th['theta2'],
                               seed=rng_n, mc_sample_size=mc_sample_size)
        df_n['threshold_label'] = th['label']
        df_n['theta1'] = th['theta1']
        df_n['theta2'] = th['theta2']
        rows_all.append(df_n)

        df_comp = df_n[df_n['model'] == 'Composite']
        ci_rng = ensure_rng(rng_n.randint(2**31 - 1))

        row = {
            'threshold_label': th['label'],
            'theta1': th['theta1'],
            'theta2': th['theta2'],
        }
        for prefix, column in metrics:
            lo, hi = bootstrap_ci(df_n, 'Composite', column, B=B_boot, rng=ci_rng)
            row[f'{prefix}_mean'] = df_comp[column].mean()
            row[f'{prefix}_lo'] = lo
            row[f'{prefix}_hi'] = hi
        summaries.append(row)

        # ---------- Print in aligned style ----------
        print(f"Composite ({th['label']}) 95% CI:")
        print(f"  Avg LOGLIKE: ({row['avglog_lo']}, {row['avglog_hi']})")
        print(f"  Tail KS   : ({row['tailks_lo']}, {row['tailks_hi']})")
        print(f"  Bias      : ({row['bias_lo']}, {row['bias_hi']})")
        print(f"  MSE       : ({row['mse_lo']}, {row['mse_hi']})")
        print(f"Finished {th['label']} in {time.perf_counter()-t0:.1f}s")

    df_all = pd.concat(rows_all, ignore_index=True)
    df_summary = pd.DataFrame(summaries)

    # ---------- Save CSV ----------
    csv_file = f"{out_prefix}_composite_thresholds_summary.csv"
    df_summary.to_csv(csv_file, index=False)

    # ---------- Save LaTeX ----------
    # Best effort: a failure here is reported but does not discard the run.
    tex_file = f"{out_prefix}_composite_thresholds_summary.tex"
    try:
        with open(tex_file, "w", encoding="utf-8") as f:
            f.write(df_summary.to_latex(index=False, float_format="%.6f"))
    except Exception as e:
        print("Warning: could not write LaTeX file:", e)

    print(f"\nSaved CSV to {csv_file} and LaTeX to {tex_file} (if no error).")

    # ---------- Print full summary table ----------
    print("\nThreshold sensitivity summary (Composite only):")
    col_order = ['threshold_label', 'theta1', 'theta2']
    for prefix, _ in metrics:
        col_order += [f'{prefix}_mean', f'{prefix}_lo', f'{prefix}_hi']
    print(df_summary[col_order].to_string(index=False, float_format="%.6f"))
    return df_all, df_summary

# ---------- Main ----------
if __name__ == "__main__":
    # (label, theta1, theta2) for each fixed-threshold scenario.
    scenarios = (
        ('Baseline', -0.3, 0.3),
        ('Lower', -1.0, 0.2),
        ('Higher', -0.2, 1.0),
        ('Extreme', -2.0, 2.0),
    )
    thresholds = [
        {'theta1': lower, 'theta2': upper, 'label': name}
        for name, lower, upper in scenarios
    ]

    # Settings shared by every scenario.
    run_settings = dict(
        n=2000,
        M=50,
        B_boot=100,
        seed=1234,
        mc_sample_size=20000,
        out_prefix="threshold_composite_results",
    )
    df_all, df_summary = threshold_sensitivity_composite(thresholds, **run_settings)



Running threshold 'Baseline': theta1=-0.3, theta2=0.3 (n=2000, M=50, B=100)
Composite (Baseline) 95% CI:
  Avg LOGLIKE: (-1.0398103048605658, -1.0293002942761744)
  Tail KS   : (0.013627500000000006, 0.014708150000000005)
  Bias      : (-0.0008430394612083925, 0.005624549347923798)
  MSE       : (0.18081600710600518, 0.18844987016737352)
Finished Baseline in 0.2s

Running threshold 'Lower': theta1=-1.0, theta2=0.2 (n=2000, M=50, B=100)
Composite (Lower) 95% CI:
  Avg LOGLIKE: (-1.208265932797394, -1.1928596444869954)
  Tail KS   : (0.033753875, 0.034639025000000004)
  Bias      : (-0.338940525538692, -0.33227448157096556)
  MSE       : (0.3902579554522684, 0.3990323971617788)
Finished Lower in 0.2s

Running threshold 'Higher': theta1=-0.2, theta2=1.0 (n=2000, M=50, B=100)
Composite (Higher) 95% CI:
  Avg LOGLIKE: (-0.9920327913942906, -0.9789177799795598)
  Tail KS   : (0.03322857499999998, 0.034048774999999996)
  Bias      : (0.3321206410907942, 0.34007717435211704)
  MSE       : (0.