# Figure 2: Loss Due to Sampling Bias

This notebook generates Figure 2 from the paper "Fighting Sampling Bias".

Five panels showing how sampling bias propagates:
- **(a) Bias in Data**: Feature distributions (Population vs Accepts vs Rejects)
- **(b) Bias in Model**: LinearRegression surrogate coefficients on XGB predictions
- **(c) Bias in Predictions**: P(BAD) score distributions (KDE curves)
- **(d) Impact on Evaluation**: ABR (average bad rate) over iterations (Bayesian vs Accepts-only)
- **(e) Impact on Training**: ABR over iterations (BASL vs Accepts-only)

In [None]:
import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import List, Optional, Tuple
from scipy.stats import gaussian_kde

# Style settings for paper-quality figures.
# These are written into matplotlib's global rcParams, so they apply to every
# figure created after this cell has run (not just Figure 2).
plt.rcParams.update({
    'font.size': 10,
    'axes.labelsize': 11,
    'axes.titlesize': 11,
    'legend.fontsize': 9,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'figure.dpi': 100,
    'axes.grid': True,  # grid drawn on every axes by default
    'grid.alpha': 0.3,  # keep the grid faint so it does not compete with the data
})

# Directory searched for experiment outputs. The path is relative, so it is
# resolved against the current working directory of the notebook kernel.
EXPERIMENTS_DIR = Path('../experiments')


def reflected_kde_density(scores: np.ndarray, x_grid: np.ndarray, bw_method: str = 'scott') -> np.ndarray:
    """Compute a KDE density on [0, 1] with boundary correction via reflection.

    A plain Gaussian KDE places part of each kernel outside [0, 1], which
    flattens peaks near 0 and 1. To compensate, the sample is mirrored at
    both boundaries before fitting, and the resulting curve is rescaled so
    that it integrates to 1 over ``x_grid``.

    Args:
        scores: Array-like of scores. Values outside [0, 1] (and NaNs, which
            fail both comparisons) are dropped before fitting.
        x_grid: Grid points to evaluate the density on (should lie in [0, 1]).
        bw_method: Bandwidth selector forwarded to ``gaussian_kde``
            ('scott' or 'silverman').

    Returns:
        Density values at the ``x_grid`` points. They are normalized with the
        trapezoidal rule to integrate to 1 over ``x_grid``; if the integral
        is not positive the raw values are returned unchanged.

    Raises:
        ValueError: If no score lies inside [0, 1].
    """
    s = np.asarray(scores, dtype=float)
    s = s[(s >= 0) & (s <= 1)]  # Keep only valid probabilities
    if s.size == 0:
        # gaussian_kde would also raise ValueError here, but with a message
        # about dataset size that hides the real cause.
        raise ValueError("No scores fall inside [0, 1]; cannot fit a KDE.")

    # Reflect at both boundaries: -s (mirror at 0), 2-s (mirror at 1)
    s_reflect = np.concatenate([s, -s, 2 - s])

    # Fit KDE on augmented data.
    # NOTE(review): the bandwidth rule is applied to the augmented sample
    # (three times as many points, wider spread), so the kernel width differs
    # from what the rule would pick for `s` alone - confirm this is intended.
    kde = gaussian_kde(s_reflect, bw_method=bw_method)
    density = kde(x_grid)

    # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid; prefer
    # the new name and fall back to the old one on older NumPy releases.
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    # Renormalize so density integrates to 1 over the evaluation grid
    area = integrate(density, x_grid)
    if area > 0:
        density = density / area

    return density

## Helper Functions

In [None]:
def load_all_seeds(exp_dir: Path) -> List[dict]:
    """Load every per-seed result file found directly inside ``exp_dir``.

    Files are matched with the glob ``figure2_unified_seed*.json`` and read
    in sorted path order. The ordering is lexicographic, so e.g. ``seed10``
    sorts before ``seed2``.

    Args:
        exp_dir: Experiment directory containing the per-seed JSON files.

    Returns:
        One parsed JSON document per matching file; an empty list when
        nothing matches.
    """
    results = []
    for seed_file in sorted(exp_dir.glob("figure2_unified_seed*.json")):
        # JSON is UTF-8 by specification; do not depend on the platform's
        # default text encoding.
        with open(seed_file, encoding="utf-8") as fp:
            results.append(json.load(fp))
    return results


def check_panel_c_constraints(seed_data: dict) -> Tuple[bool, bool, float, dict]:
    """Evaluate the panel (c) sanity constraints for a single seed.

    The mean predicted score of each model is taken from
    ``seed_data['panel_c']`` (keys ``fa_scores``, ``fo_scores``,
    ``fc_scores``) and compared:

      C1 (hard): mean(f_a) < mean(f_o)              -- correct bias direction
      C2 (hard): mean(f_a) < mean(f_c) < mean(f_o)  -- BASL between, no overshoot
      C3 (soft): mean(f_o) - mean(f_a), larger = clearer picture

    Returns:
        ``(c1_pass, c2_pass, gap, details)`` where ``details`` repeats the
        three means, the gap and both pass flags under string keys.
    """
    scores = seed_data['panel_c']

    # Same lookup order as the models are listed above: f_a, f_o, f_c.
    mean_accepts, mean_oracle, mean_basl = (
        np.mean(np.array(scores[key]))
        for key in ('fa_scores', 'fo_scores', 'fc_scores')
    )

    direction_ok = mean_accepts < mean_oracle
    between_ok = mean_accepts < mean_basl < mean_oracle
    separation = mean_oracle - mean_accepts

    summary = dict(
        fa_mean=mean_accepts,
        fo_mean=mean_oracle,
        fc_mean=mean_basl,
        gap=separation,
        c1_pass=direction_ok,
        c2_pass=between_ok,
    )

    return direction_ok, between_ok, separation, summary


def select_best_seed_for_panel_c(all_seeds: List[dict], verbose: bool = True) -> Tuple[int, dict]:
    """Pick the seed whose panel (c) best matches the paper's expectations.

    Selection procedure:
      1. Drop seeds that fail C1 (wrong bias direction).
      2. Drop seeds that fail C2 (BASL mean not strictly between f_a and f_o).
      3. Among the survivors, take the one with the largest gap (C3).

    If nothing survives step 2, the C1-passing seed with the largest gap is
    used instead; if no seed passes C1 either, the first seed is returned.

    Args:
        all_seeds: Per-seed result dicts (each must satisfy
            ``check_panel_c_constraints`` and carry a ``'seed'`` entry when
            ``verbose`` is on).
        verbose: Print the step-by-step selection tables.

    Returns:
        ``(index into all_seeds, selected seed dict)``.

    Raises:
        ValueError: If ``all_seeds`` is empty.
    """
    if not all_seeds:
        raise ValueError("No seed data provided")

    thin_rule = "-" * 80
    thick_rule = "=" * 80

    # One row per seed: (index, seed dict, c1, c2, gap, details)
    evaluated = [
        (position, candidate) + tuple(check_panel_c_constraints(candidate))
        for position, candidate in enumerate(all_seeds)
    ]

    if verbose:
        print("Paper-Faithful Seed Selection for Panel (c)")
        print(thick_rule)

        print("\nStep 1 - C1: mean(f_a) < mean(f_o) [Correct bias direction]")
        print(thin_rule)
        print(f"{'Seed':<6} {'fa_mean':<10} {'fo_mean':<10} {'Status':<15}")
        print(thin_rule)
        for row in evaluated:
            candidate, ok_c1, stats = row[1], row[2], row[5]
            seed_id = candidate['seed']
            verdict = "PASS" if ok_c1 else "FAIL (invalid)"
            print(f"{seed_id:<6} {stats['fa_mean']:<10.4f} {stats['fo_mean']:<10.4f} {verdict:<15}")

        passed_c1 = [row for row in evaluated if row[2]]
        print(f"\nSeeds passing C1: {len(passed_c1)}/{len(evaluated)}")

        print("\nStep 2 - C2: mean(f_a) < mean(f_c) < mean(f_o) [BASL between, no overshoot]")
        print(thin_rule)
        print(f"{'Seed':<6} {'fa_mean':<10} {'fc_mean':<10} {'fo_mean':<10} {'Status':<20}")
        print(thin_rule)
        for row in passed_c1:
            candidate, ok_c2, stats = row[1], row[3], row[5]
            seed_id = candidate['seed']
            verdict = "PASS" if ok_c2 else "FAIL (fc overshoots)"
            print(f"{seed_id:<6} {stats['fa_mean']:<10.4f} {stats['fc_mean']:<10.4f} {stats['fo_mean']:<10.4f} {verdict:<20}")

        passed_both = [row for row in passed_c1 if row[3]]
        print(f"\nSeeds passing C1 AND C2: {len(passed_both)}/{len(evaluated)}")

    def by_gap(entry):
        # Entries are (index, seed dict, gap, details)
        return entry[2]

    finalists = [
        (position, candidate, separation, stats)
        for position, candidate, ok_c1, ok_c2, separation, stats in evaluated
        if ok_c1 and ok_c2
    ]

    if not finalists:
        if verbose:
            print("\nWARNING: No seeds satisfy both C1 and C2!")
            print("Falling back to best C1-only seed with largest gap...")
        runners_up = [
            (position, candidate, separation, stats)
            for position, candidate, ok_c1, ok_c2, separation, stats in evaluated
            if ok_c1
        ]
        if not runners_up:
            return 0, all_seeds[0]
        # Stable descending sort: ties keep their original seed order.
        top = sorted(runners_up, key=by_gap, reverse=True)[0]
        return top[0], top[1]

    finalists = sorted(finalists, key=by_gap, reverse=True)
    winner_index, winner = finalists[0][0], finalists[0][1]

    if verbose:
        print("\nStep 3 - Rank by Visual Clarity (larger gap = better)")
        print(thin_rule)
        print(f"{'Rank':<6} {'Seed':<6} {'Gap':<10} {'fa_mean':<10} {'fc_mean':<10} {'fo_mean':<10}")
        print(thin_rule)
        for rank, (_, candidate, separation, stats) in enumerate(finalists, 1):
            seed_id = candidate['seed']
            print(f"{rank:<6} {seed_id:<6} {separation:<10.4f} {stats['fa_mean']:<10.4f} {stats['fc_mean']:<10.4f} {stats['fo_mean']:<10.4f}")

        chosen_seed = winner['seed']
        print(thin_rule)
        print(f"\nSELECTED SEED: {chosen_seed}")
        print(thick_rule)

    return winner_index, winner


def load_unified_figure2(exp_dir: Optional[Path] = None, select_best_panel_c: bool = True) -> Optional[dict]:
    """Load unified Figure 2 data from a single acceptance loop run.

    Unified data guarantees all panels use the SAME holdout H,
    training snapshot (D_a, D_r), and models f_a, f_o, f_c.

    Lookup order inside the experiment directory:
      1. Per-seed files (``figure2_unified_seed*.json``). With more than one
         seed and ``select_best_panel_c`` enabled, the seed is chosen by
         ``select_best_seed_for_panel_c``; otherwise the first seed is used.
      2. If there are no per-seed files, the first (in sorted order) file
         matching ``figure2_unified_*.json``.

    Args:
        exp_dir: Experiment directory. When ``None``, the last directory in
            sorted name order matching ``EXPERIMENTS_DIR/figure2_unified*``
            is used.
        select_best_panel_c: Enable constraint-based seed selection when
            several seeds are available.

    Returns:
        The selected result dict, or ``None`` when no experiment directory or
        no result file is found.
    """
    if exp_dir is None:
        # Only directories can hold result files; skip stray files whose
        # names happen to share the prefix.
        candidates = sorted(
            (p for p in EXPERIMENTS_DIR.glob("figure2_unified*") if p.is_dir()),
            reverse=True,
        )
        if not candidates:
            return None
        exp_dir = candidates[0]

    all_seeds = load_all_seeds(exp_dir)

    if all_seeds:
        if len(all_seeds) > 1 and select_best_panel_c:
            print(f"Loaded {len(all_seeds)} seeds from: {exp_dir.name}")
            _, data = select_best_seed_for_panel_c(all_seeds, verbose=True)
            print(f"\nUsing seed {data['seed']} for Figure 2")
        else:
            data = all_seeds[0]
            print(f"Loaded unified Figure 2: {exp_dir.name}")
            print(f"  Seed: {data['seed']}")
    else:
        # Reached only when no per-seed file exists. Sort so the chosen file
        # does not depend on filesystem enumeration order.
        unified_files = sorted(exp_dir.glob("figure2_unified_*.json"))
        if not unified_files:
            return None

        with open(unified_files[0], encoding="utf-8") as f:
            data = json.load(f)

        print(f"Loaded unified Figure 2: {exp_dir.name}")

    # Summary lines common to every successful load path.
    print(f"  Panel snapshot iteration: {data['panel_snapshot_iter']}")
    print(f"  Iterations tracked: {len(data['iteration_data'])}")

    return data


def extract_unified_series(iteration_data: list, metric: str = 'abr') -> dict:
    """Turn per-iteration records into per-series lists for one metric.

    Args:
        iteration_data: List of dicts, one per iteration. Each must contain
            ``'iteration'``; metric values are stored under keys of the form
            ``'<series>_<metric>'`` (e.g. ``'fa_H_abr'``).
        metric: Metric suffix to extract.

    Returns:
        A dict with ``'iteration'`` plus one list per available series among
        ``fo_H``, ``fa_H``, ``fc_H``, ``fa_DaVal`` and ``bayesian``. A series
        is included only when its key is present in the first record. For
        empty input the result is ``{'iteration': []}``.
    """
    series = {'iteration': [record['iteration'] for record in iteration_data]}

    if not iteration_data:
        # Nothing to probe for available series; avoid indexing record 0.
        return series

    first_record = iteration_data[0]
    for name in ('fo_H', 'fa_H', 'fc_H', 'fa_DaVal', 'bayesian'):
        data_key = f'{name}_{metric}'
        # Availability is decided by the first record only.
        if data_key in first_record:
            series[name] = [record[data_key] for record in iteration_data]

    return series

## Load Experiment Data

In [None]:
# Fetch the unified Figure 2 payload (every panel comes from one acceptance
# loop) and report whether anything was found.
unified_data = load_unified_figure2()

print(
    "\n*** Using UNIFIED Figure 2 data (all panels from single loop) ***"
    if unified_data
    else "ERROR: No unified data found. Run: python scripts/run_figure2_unified.py"
)

## Plot Figure 2

In [None]:
def _padded_ylim(values, pad_fraction=0.1):
    """Return (y_min, y_max) padding the range of ``values`` on both sides.

    The padding is ``pad_fraction`` of the data range and the lower limit is
    clipped at 0. A flat series gets a small non-zero padding so that the two
    limits never coincide.
    """
    lo, hi = min(values), max(values)
    span = hi - lo
    pad = pad_fraction * span
    if span == 0:
        pad = pad_fraction * abs(hi) if hi != 0 else pad_fraction
    return max(0, lo - pad), hi + pad


def plot_figure_2(unified_data: dict):
    """Plot Figure 2: Complete 5-panel visualization.

    Uses unified_data from run_figure2_unified.py which guarantees
    all panels use the SAME holdout H, training snapshot (D_a, D_r),
    and models f_a, f_o, f_c.

    Args:
        unified_data: Result dict with the entries ``panel_a``, ``panel_b``,
            ``panel_c`` and ``iteration_data``.

    Returns:
        The created matplotlib figure, or ``None`` (after printing a hint)
        when ``unified_data`` is empty/None.

    Raises:
        KeyError: If an ABR series needed by panels (d)/(e) is missing.
        ValueError: If no iteration greater than 0 is available.
    """
    if not unified_data:
        print("No data available. Run: python scripts/run_figure2_unified.py")
        return None

    # Prepare the iteration series first so malformed data is rejected
    # before a figure is created (no half-drawn figure is left open).
    series = extract_unified_series(unified_data['iteration_data'], 'abr')
    missing = [name for name in ('fa_H', 'fa_DaVal', 'bayesian', 'fo_H', 'fc_H')
               if name not in series]
    if missing:
        raise KeyError(f"iteration_data is missing ABR series: {missing}")

    # Filter out iteration 0 (outlier)
    mask = [i > 0 for i in series['iteration']]

    def _kept(name):
        return [v for v, m in zip(series[name], mask) if m]

    iterations = _kept('iteration')
    fa_H = _kept('fa_H')
    fa_DaVal = _kept('fa_DaVal')
    bayesian = _kept('bayesian')
    fo_H = _kept('fo_H')
    fc_H = _kept('fc_H')
    if not iterations:
        raise ValueError("iteration_data contains no iterations > 0 to plot")

    fig = plt.figure(figsize=(16, 10))
    gs = fig.add_gridspec(2, 6, hspace=0.35, wspace=0.45)

    # Panel (a): Bias in Data - bureau score x_v = -X1 distributions
    ax_a = fig.add_subplot(gs[0, 0:2])
    panel_a = unified_data['panel_a']
    # Stored values are negated so that higher means better on the x-axis.
    ref_xv = -np.array(panel_a['ref_xv'])
    Da_xv = -np.array(panel_a['Da_xv'])
    Dr_xv = -np.array(panel_a['Dr_xv'])

    # Shared grid spanning all three samples so the curves are comparable.
    all_xv = np.concatenate([ref_xv, Da_xv, Dr_xv])
    x_min, x_max = all_xv.min(), all_xv.max()
    x_grid = np.linspace(x_min, x_max, 400)

    for data, label, color in [
        (ref_xv, 'Population (H)', 'gray'),
        (Da_xv, 'Accepts', 'blue'),
        (Dr_xv, 'Rejects', 'red'),
    ]:
        kde = gaussian_kde(data, bw_method='scott')
        density = kde(x_grid)
        ax_a.plot(x_grid, density, color=color, linewidth=2, label=label)
        ax_a.fill_between(x_grid, density, alpha=0.3, color=color)

    ax_a.set_xlabel('Bureau score $x_v$ (higher = better)')
    ax_a.set_ylabel('Density')
    ax_a.set_title('(a) Bias in Data')
    ax_a.legend()

    # Panel (b): Bias in Model - LR surrogate coefficients
    ax_b = fig.add_subplot(gs[0, 2:4])
    panel_b = unified_data['panel_b']
    feature_names = panel_b['feature_names']

    accepts_coefs = np.array(panel_b['fa']['coefs'])
    oracle_coefs = np.array(panel_b['fo']['coefs'])
    basl_coefs = np.array(panel_b['fc']['coefs'])

    # "\u00b2" is the superscript two; written as an escape so the message
    # survives re-encoding of the source file.
    print(f"Panel (b) R\u00b2 values: fa={panel_b['fa']['r2']:.3f}, fo={panel_b['fo']['r2']:.3f}, fc={panel_b['fc']['r2']:.3f}")

    def normalize_coefs_paper(coefs):
        # Absolute coefficients rescaled to sum to 1 (left as-is if all zero).
        mags = np.abs(coefs)
        total = mags.sum()
        return mags / total if total > 0 else mags

    accepts_coefs_norm = normalize_coefs_paper(accepts_coefs)
    oracle_coefs_norm = normalize_coefs_paper(oracle_coefs)
    basl_coefs_norm = normalize_coefs_paper(basl_coefs)

    x_pos = np.arange(len(feature_names))
    width = 0.25

    ax_b.bar(x_pos - width, accepts_coefs_norm, width,
             label='Accepts-only (fa)', alpha=0.7, color='red')
    ax_b.bar(x_pos, oracle_coefs_norm, width,
             label='Oracle (fo)', alpha=0.7, color='green')
    ax_b.bar(x_pos + width, basl_coefs_norm, width,
             label='BASL (fc)', alpha=0.7, color='blue')

    ax_b.set_xlabel('Coefficient')
    ax_b.set_ylabel('Normalized coefficient magnitude')
    ax_b.set_title('(b) Bias in Model (LR surrogate on XGB)')
    ax_b.set_xticks(x_pos)
    ax_b.set_xticklabels(feature_names, rotation=45, ha='right')
    ax_b.legend()

    # Panel (c): Bias in Predictions - P(BAD) distributions with reflection KDE
    ax_c = fig.add_subplot(gs[0, 4:6])
    panel_c = unified_data['panel_c']
    x_grid = np.linspace(0, 1, 1024)

    for scores, label, color in [
        (panel_c['fa_scores'], 'Accepts-only (fa)', 'red'),
        (panel_c['fo_scores'], 'Oracle (fo)', 'green'),
        (panel_c['fc_scores'], 'BASL (fc)', 'blue'),
    ]:
        scores_arr = np.array(scores)
        density = reflected_kde_density(scores_arr, x_grid, bw_method='scott')
        ax_c.plot(x_grid, density, color=color, linewidth=2, label=label)
        ax_c.fill_between(x_grid, density, alpha=0.3, color=color)

    ax_c.set_xlabel('Predicted P(BAD)')
    ax_c.set_ylabel('Density')
    ax_c.set_title('(c) Bias in Predictions')
    ax_c.legend()
    ax_c.set_xlim(0, 1)

    # Panel (d): Impact on Evaluation - Bayesian vs Accepts-only
    ax_d = fig.add_subplot(gs[1, 0:3])

    ax_d.plot(iterations, fa_H, 'k-', linewidth=2, label='Oracle (f_a on H)')
    ax_d.plot(iterations, fa_DaVal, 'r-', linewidth=2, label='Accepts-only (f_a on D_a_val)')
    ax_d.plot(iterations, bayesian, 'b-', linewidth=2, label='Bayesian')

    ax_d.set_xlabel('Acceptance Loop Iteration')
    ax_d.set_ylabel('ABR (Average Bad Rate)')
    ax_d.set_title('(d) Impact on Evaluation: Bayesian vs Accepts-only')
    ax_d.legend(loc='best')

    ax_d.set_ylim(*_padded_ylim(fa_H + fa_DaVal + bayesian))
    ax_d.grid(True, alpha=0.3)

    # Panel (e): Impact on Training - BASL convergence
    ax_e = fig.add_subplot(gs[1, 3:6])

    ax_e.plot(iterations, fo_H, 'g-', linewidth=2, label='Oracle (fo)')
    ax_e.plot(iterations, fa_H, 'r-', linewidth=2, label='Accepts-only (fa)')
    ax_e.plot(iterations, fc_H, 'b-', linewidth=2, label='BASL (fc)')

    print(f"Panel (e) fa_H ABR: {fa_H[0]:.4f} -> {fa_H[-1]:.4f}")
    print(f"Panel (e) fc_H ABR: {fc_H[0]:.4f} -> {fc_H[-1]:.4f}")

    ax_e.set_xlabel('Iteration')
    ax_e.set_ylabel('ABR')
    ax_e.set_title('(e) Impact on Training: BASL vs Accepts-only')
    ax_e.legend(loc='best', fontsize=8)

    ax_e.set_ylim(*_padded_ylim(fo_H + fa_H + fc_H))
    ax_e.grid(True, alpha=0.3)

    fig.text(0.99, 0.01, "Data: Unified (single AcceptanceLoop)", 
             ha='right', va='bottom', fontsize=8, style='italic')

    plt.suptitle('Figure 2: Loss Due to Sampling Bias', fontsize=14, fontweight='bold', y=0.995)
    plt.show()

    return fig


# Plot Figure 2. `fig2` holds the value returned by plot_figure_2: the figure
# object, or None when `unified_data` is empty (a hint is printed instead).
fig2 = plot_figure_2(unified_data)