# Geometric Theory of Learning Under Class Imbalance - Full Experiment Suite

This notebook runs all experiments and generates all figures for the paper **"A Geometric Theory of Learning Under Class Imbalance"**.

## Overview

We validate the following key claims:
1. Under label shift, only threshold/logit-offset updates are needed (no retraining)
2. Ranking metrics (AUC) are invariant under prevalence changes  
3. Reweighting/rebalancing reduces effective sample size and increases instability
4. Under concept drift, offset alone fails and retraining is necessary

## 1. Setup and Imports

In [None]:
# Standard imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import warnings
# NOTE(review): this silences *every* warning for the whole session (including
# convergence / deprecation warnings from the libraries used below) — consider
# narrowing the filter to specific categories.
warnings.filterwarnings('ignore')

# Add package to path
# Makes the parent directory importable so `src.geomimb` resolves; assumes the
# notebook is executed from a subdirectory of the repo root (e.g. notebooks/)
# — TODO confirm. Must run before the `src.geomimb` imports below.
import sys
sys.path.insert(0, '..')

# Import our modules
from src.geomimb.utils.logging import setup_logging
from src.geomimb.utils.io import ensure_output_dirs, save_metadata
from src.geomimb.utils.checks import run_all_checks
from src.geomimb.plotting.plots import create_all_experiment_plots

# Import all experiments
# Each experiment module exposes `run_experiment`; aliased so later cells can
# call run_exp1 ... run_exp5.
from src.geomimb.experiments.exp1_label_shift_offset import run_experiment as run_exp1
from src.geomimb.experiments.exp2_auc_invariance import run_experiment as run_exp2
from src.geomimb.experiments.exp3_weighting_neff_instability import run_experiment as run_exp3
from src.geomimb.experiments.exp4_operating_point_metrics import run_experiment as run_exp4
from src.geomimb.experiments.exp5_concept_drift_control import run_experiment as run_exp5

# Setup
setup_logging(level='INFO')
# Presumably creates the outputs/ directory tree (tables, figures, metadata)
# that later cells write into — confirm against src/geomimb/utils/io.py.
ensure_output_dirs()

%matplotlib inline
plt.style.use('default')

print("Setup complete!")

## 2. Configuration Overview

In [None]:
# Print the configuration values that parameterize the experiments below.
from src.geomimb import config

config_items = (
    ("Training prevalence", config.PI_TRAIN),
    ("Test prevalences", config.PI_TEST_GRID),
    ("Cost settings", config.COSTS_GRID),
    ("Number of seeds", len(config.SEEDS)),
    ("Synthetic dimension", config.SYNTH_DIM),
    ("Alpha values (Exp 3)", config.ALPHA_VALUES),
)

print("Experiment Configuration:")
for label, value in config_items:
    print(f"- {label}: {value}")

## 3. Data Generation Preview

Let's visualize the synthetic data generation process to understand the Gaussian mixture model.

In [None]:
# Preview synthetic data
from src.geomimb.data.synthetic import sample_synthetic_dataset, get_synthetic_params

# Get parameters
params = get_synthetic_params()
print("Synthetic Data Parameters:")
print(f"- Dimension: {params['d']}")
print(f"- mu0 (class 0): {params['mu0'][:5]}... (first 5 dims)")
print(f"- mu1 (class 1): {params['mu1'][:5]}... (first 5 dims)")
print(f"- Theoretical LLR coefficient: {params['theoretical_llr_coef'][:5]}...")
print(f"- Theoretical LLR constant: {params['theoretical_llr_const']:.3f}")

# Sample and visualize
X_sample, y_sample = sample_synthetic_dataset(pi=0.3, n=1000, seed=42)


def plot_class_histograms(ax, X, y, feature_idx, xlabel, title, bins=30):
    """Overlay per-class, density-normalized histograms of one feature on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    X : 2-D feature array indexed as ``X[rows, feature_idx]``.
    y : binary label array (0/1) aligned with the rows of ``X``.
    feature_idx : column of ``X`` to plot.
    xlabel, title : axis label and panel title.
    bins : number of histogram bins.

    Returns
    -------
    The same ``ax``, for chaining.
    """
    for label in (0, 1):
        ax.hist(X[y == label, feature_idx], bins=bins, alpha=0.5,
                label=f'Class {label}', density=True)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Density')
    ax.legend()
    ax.set_title(title)
    return ax


# Side-by-side comparison: feature 0 vs. feature 5. The panel labels assume
# feature 0 is discriminative and feature 5 is not under get_synthetic_params()
# — NOTE(review): confirm if the synthetic generator changes.
NON_DISCRIMINATIVE_DIM = 5
assert X_sample.shape[1] > NON_DISCRIMINATIVE_DIM, (
    f"Preview expects more than {NON_DISCRIMINATIVE_DIM} features, got {X_sample.shape[1]}"
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
plot_class_histograms(ax1, X_sample, y_sample, 0,
                      'Feature 0 (discriminative)', 'First Discriminative Dimension')
plot_class_histograms(ax2, X_sample, y_sample, NON_DISCRIMINATIVE_DIM,
                      'Feature 5 (non-discriminative)', 'Non-Discriminative Dimension')

plt.tight_layout()
plt.show()

## 4. Train Base Model and Verify Implementation

Before running full experiments, let's verify our implementation with a single model.

In [None]:
# Fit one logistic regression on synthetic training data and compare its weight
# vector with the theoretical log-likelihood-ratio direction.
from src.geomimb.models.sklearn_models import LogisticRegressionWrapper
from src.geomimb.seeds import set_global_seed

set_global_seed(0)

X_train, y_train = sample_synthetic_dataset(pi=0.2, n=10000, seed=0)

model = LogisticRegressionWrapper(random_state=0)
model.fit(X_train, y_train)

coef = model.get_coefficients()
intercept = model.get_intercept()

report_lines = [
    "Trained Logistic Regression:",
    f"- Coefficients (first 5): {coef[:5]}",
    f"- Coefficients (last 5): {coef[-5:]}",
    f"- Intercept: {intercept:.3f}",
]
print("\n".join(report_lines))

# Cosine similarity = <w, w*> / (||w|| * ||w*||); 1.0 means the learned
# direction matches the theoretical LLR direction exactly.
theoretical_coef = params['theoretical_llr_coef']
dot_product = np.dot(coef, theoretical_coef)
norm_product = np.linalg.norm(coef) * np.linalg.norm(theoretical_coef)
cosine_sim = dot_product / norm_product
print(f"\nCosine similarity to theoretical: {cosine_sim:.4f}")

## 5. Test Threshold Selection Function

Verify our precision-constrained threshold selection.

In [None]:
# Sanity-check precision-constrained threshold selection on synthetic scores:
# tune a threshold for precision >= TARGET_PRECISION (maximizing recall), then
# re-apply it independently and assert the constraint actually holds.
from src.geomimb.metrics.calibration import tune_threshold_for_operating_point
from src.geomimb.metrics.classification import apply_threshold, compute_metrics

TARGET_PRECISION = 0.95

# Create synthetic scores and labels: ~30% positives, and class-1 scores are
# shifted up by 1.5 so the two classes overlap but are separable.
np.random.seed(42)
n_test = 1000
y_test = np.random.binomial(1, 0.3, n_test)
scores = np.random.randn(n_test) + y_test * 1.5

threshold, info = tune_threshold_for_operating_point(
    scores, y_test,
    target_metric='precision',
    target_value=TARGET_PRECISION,
    constraint_type='min',
    optimize_metric='recall'
)

print(f"Feasible: {info['feasible']}")
if info['feasible']:
    print(f"Found threshold: {threshold:.3f}")
    print(f"Best recall at precision >= {TARGET_PRECISION}: {info['best_recall']:.3f}")

    # Verify by recomputing metrics at this threshold with the classification helpers.
    y_pred = apply_threshold(scores, threshold)
    metrics = compute_metrics(y_test, y_pred)
    print("\nActual metrics at threshold:")
    print(f"- Precision: {metrics['precision']:.3f}")
    print(f"- Recall: {metrics['recall']:.3f}")

    # A mismatch here means the tuner and apply_threshold disagree (e.g. on
    # `>` vs `>=` at the threshold), which would invalidate later experiments.
    assert metrics['precision'] >= TARGET_PRECISION, (
        f"Realized precision {metrics['precision']:.4f} is below target {TARGET_PRECISION}"
    )
else:
    # Nothing meaningful to verify: no threshold achieves the precision target.
    print(f"No threshold reaches precision >= {TARGET_PRECISION}; skipping verification.")

## 6. Run Experiment 1: Label Shift Offset Correction

This experiment validates that under label shift, offset correction works and retraining is not needed.

In [None]:
%%time
print("Running Experiment 1: Label Shift Offset Correction")
print("="*60)

exp1_results = run_exp1(dataset_name='both', save_results=True)

print(f"\nCompleted with {len(exp1_results)} results")

# Show summary
summary = exp1_results.groupby(['dataset', 'method', 'pi_test']).agg({
    'auc': ['mean', 'std'],
    'risk': ['mean', 'std']
}).round(3)

print("\nAUC by method and prevalence (first dataset):")
print(summary.loc['synthetic', :, 'auc'])

## 7. Run Experiment 2: AUC Invariance

This experiment demonstrates that AUC is invariant to prevalence while PR-AUC depends on it.

In [None]:
%%time
print("Running Experiment 2: AUC Invariance")
print("="*60)

exp2_results = run_exp2(dataset_name='both', save_results=True)

print(f"\nCompleted with {len(exp2_results)} results")

# Check AUC invariance
for (dataset, model), group in exp2_results.groupby(['dataset', 'model']):
    auc_range = group['auc'].max() - group['auc'].min()
    pr_auc_range = group['pr_auc'].max() - group['pr_auc'].min()
    print(f"\n{dataset} - {model}:")
    print(f"  AUC range: {auc_range:.4f}")
    print(f"  PR-AUC range: {pr_auc_range:.4f}")

## 8. Run Experiment 3: Weighting Reduces Neff and Increases Instability

This experiment shows that class weighting reduces effective sample size and increases model instability.

In [None]:
%%time
print("Running Experiment 3: Weighting Neff and Instability")
print("="*60)

exp3_results = run_exp3(dataset_name='both', save_results=True)

print(f"\nCompleted with {len(exp3_results)} results")

# Show Neff by alpha
neff_summary = exp3_results.groupby(['dataset', 'alpha'])['neff'].mean().round(0)
print("\nEffective sample size by alpha:")
print(neff_summary)

## 9. Run Experiment 4: Operating Point Metrics

This experiment shows that operating-point metrics (e.g., precision) drift under label shift when left uncorrected, and are restored by the logit-offset correction.

In [None]:
%%time
print("Running Experiment 4: Operating Point Metrics")
print("="*60)

exp4_results = run_exp4(save_results=True)

print(f"\nCompleted with {len(exp4_results)} results")

# Show precision/recall drift
for method in ['nocorr', 'offset']:
    method_data = exp4_results[exp4_results['method'] == method]
    prec_by_pi = method_data.groupby('pi_test')['precision'].mean()
    print(f"\n{method} - Precision by prevalence:")
    print(prec_by_pi.round(3))

## 10. Run Experiment 5: Concept Drift Control

This experiment demonstrates that under concept drift, the logit offset alone does not help, but retraining does.

In [None]:
%%time
print("Running Experiment 5: Concept Drift Control")
print("="*60)

exp5_results = run_exp5(save_results=True)

print(f"\nCompleted with {len(exp5_results)} results")

# Show performance by method
perf_summary = exp5_results.groupby('method')[['auc', 'risk']].mean().round(3)
print("\nPerformance under concept drift:")
print(perf_summary)

## 11. Run All Acceptance Checks

Verify that all experiments meet the acceptance criteria.

In [None]:
# Collect all results, keyed by experiment id, for the checks and plotting below.
all_results = {
    'exp1': exp1_results,
    'exp2': exp2_results,
    'exp3': exp3_results,
    'exp4': exp4_results,
    'exp5': exp5_results
}

# Run acceptance checks. Returns a dict of per-check result dicts plus an
# aggregate 'overall_passed' entry.
checks = run_all_checks(all_results)

print("Acceptance Check Results:")
print("="*60)

for check_name, check_result in checks.items():
    if check_name == 'overall_passed':
        continue

    # A check that does not report 'passed' is shown as UNKNOWN instead of
    # being silently counted as a pass.
    passed = check_result.get('passed')
    if passed is None:
        status = "? UNKNOWN"
    elif passed:
        status = "✓ PASSED"
    else:
        status = "✗ FAILED"
    print(f"{check_name}: {status}")

    # `not passed` (rather than `is False`) so numpy booleans are handled too.
    if passed is not None and not passed and 'violations' in check_result:
        violations = check_result['violations']
        print(f"  Violations: {len(violations)}")
        # Show first violation
        if violations:
            print(f"  Example: {violations[0]}")

print("\n" + "="*60)
overall = "✓ ALL CHECKS PASSED" if checks['overall_passed'] else "✗ SOME CHECKS FAILED"
print(f"Overall: {overall}")

## 12. Generate All Plots

Create all figures for the paper.

In [None]:
# Create all plots
print("Generating all plots...")
create_all_experiment_plots(all_results)
print("All plots saved to outputs/figures/")

# Display a few key plots inline
from IPython.display import Image, display

# Relative to the notebook's working directory (the setup cell puts '..' on sys.path,
# so the repo root is presumably one level up).
FIGURES_DIR = Path('../outputs/figures')

key_plots = [
    'exp1_auc_vs_pi.png',
    'exp1_risk_vs_pi.png',
    'exp3_neff_vs_alpha.png',
    'exp5_concept_drift.png'
]

for plot_file in key_plots:
    plot_path = FIGURES_DIR / plot_file
    if plot_path.exists():
        print(f"\n{plot_file}:")
        display(Image(str(plot_path), width=600))
    else:
        # Say so explicitly instead of silently skipping, so a missing or
        # renamed figure is noticed.
        print(f"\n{plot_file}: not found at {plot_path} (skipped)")

## 13. Create Paper Summary Table

Generate the final summary table for the paper.

In [None]:
from src.geomimb.utils.io import create_paper_summary

# Combine all results into one frame (row order: exp1 ... exp5).
combined_results = pd.concat(list(all_results.values()), ignore_index=True)

# Create summary
summary_df = create_paper_summary(combined_results)

print("Paper Summary Table:")
print(summary_df.to_string())

# Also save as LaTeX. Create the directory if needed so this cell does not
# depend on earlier setup having created it. Write explicitly as UTF-8 so the
# output does not depend on the platform's default encoding.
latex_table = summary_df.to_latex(index=False, float_format="%.3f")
tables_dir = Path('../outputs/tables')
tables_dir.mkdir(parents=True, exist_ok=True)
(tables_dir / 'paper_summary.tex').write_text(latex_table, encoding='utf-8')

print("\nLaTeX table saved to outputs/tables/paper_summary.tex")

## 14. Final Summary

Summarize the key findings that support the paper's claims.

In [None]:
print("EXPERIMENT SUITE SUMMARY")
print("="*60)

# Key findings: each paper claim is tied to the acceptance check that supports it.
# Claims whose check did not pass are listed explicitly rather than omitted.
print("\nKey Findings:")

key_findings = [
    ('exp1_auc_invariance', "AUC is invariant to prevalence changes (max range < 0.01)"),
    ('exp1_offset_improvement', "Offset correction reduces risk at extreme imbalance"),
    ('exp3_neff_monotonicity', "Weighting monotonically reduces effective sample size"),
    ('exp5_concept_drift', "Under concept drift, offset fails but retraining helps"),
]

for check_key, claim in key_findings:
    if checks.get(check_key, {}).get('passed', False):
        print(f"✓ {claim}")
    else:
        print(f"✗ NOT CONFIRMED ({check_key}): {claim}")

print("\n" + "="*60)
# Only report success when the acceptance checks actually passed.
if checks.get('overall_passed', False):
    print("All experiments completed successfully!")
else:
    print("All experiments ran, but some acceptance checks FAILED (see Section 11).")
print(f"Results saved in: outputs/")
print(f"- Tables: outputs/tables/")
print(f"- Figures: outputs/figures/")
print(f"- Metadata: outputs/metadata/")