# In [None]:
# ============================================================
# COMPREHENSIVE LIGHTGBM PIPELINE WITH ABLATION & REPORTING
# Complete analysis including feature importance, ablation, and visualizations
# ============================================================


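# Feature ranges (inclusive column indices); the script selects the layout
# from the total column count in FEATURE CATEGORY IDENTIFICATION:
# - current layout (3449 features):
#   - medical_patterns:    0-42      (43 features)
#   - word_char_features:  43-3042   (3000 features)
#   - sentence_embeddings: 3043-3426 (384 features)
#   - transformer_scores:  3427-3448 (22 features)
# - legacy padded layout (3451 features):
#   - medical_patterns:    0-42      (43 features)
#   - word_char_features:  43-3062   (3020 features)
#   - sentence_embeddings: 3063-3448 (386 features)
#   - transformer_scores:  3449-3450 (2 features)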

import os
import numpy as np
import pandas as pd
import lightgbm as lgb
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

from datasets import load_dataset
from scipy import sparse
from sklearn.metrics import (
    roc_auc_score, f1_score, precision_score, recall_score,
    balanced_accuracy_score, confusion_matrix, classification_report,
    roc_curve, precision_recall_curve, auc
)
from sklearn.model_selection import StratifiedKFold

# ============================================================
# CONFIGURATION
# ============================================================

# Run identifiers and locations.
OUTPUT_DIR = "out-final"      # every CSV/JSON/PNG/TXT artifact is written here
FEATURE_DIR = "features"      # directory holding the precomputed sparse matrices
PREFIX = "split_tr8000_word50_char3000_ng3-7_all-MiniLM-L6-v2_FULL"

# Pinned dataset commit so every run sees identical splits (local cache is reused).
DATASET_REVISION = "ee5e7f3c00400ab56b2aa407c2d9088c9d0b01db"

# Ensure the output directory exists before anything is saved into it.
output_path = Path(OUTPUT_DIR)
output_path.mkdir(exist_ok=True, parents=True)

# Global plotting defaults: white grid, 300 dpi both on screen and when saving.
sns.set_style("whitegrid")
plt.rcParams.update({'figure.dpi': 300, 'savefig.dpi': 300})

# ============================================================
# LOAD DATA
# ============================================================

print("="*80)
print("LOADING DATASET")
print("="*80)

# All three splits come from the same pinned revision; an existing local
# cache is reused instead of re-downloading.
print(f"Using dataset revision: {DATASET_REVISION}")

_split_datasets = {
    split_name: load_dataset(
        "sssohrab/ct-dosing-errors-benchmark",
        split=split_name,
        revision=DATASET_REVISION,
        download_mode="reuse_dataset_if_exists",
    )
    for split_name in ("train", "validation", "test")
}
ds_train = _split_datasets["train"]
ds_val = _split_datasets["validation"]
ds_test = _split_datasets["test"]

# Convert each split to pandas and pull out its "target" label column.
df_train, df_val, df_test = (split.to_pandas() for split in (ds_train, ds_val, ds_test))
y_train, y_val, y_test = (frame["target"].values for frame in (df_train, df_val, df_test))

print(f"Train: {len(y_train)} samples, {y_train.sum()} positive ({y_train.mean()*100:.2f}%)")
print(f"Val:   {len(y_val)} samples, {y_val.sum()} positive ({y_val.mean()*100:.2f}%)")
print(f"Test:  {len(y_test)} samples, {y_test.sum()} positive ({y_test.mean()*100:.2f}%)")

# ============================================================
# LOAD FEATURES
# ============================================================

print("\n" + "="*80)
print("LOADING FEATURES")
print("="*80)

# One precomputed sparse design matrix per split, all sharing PREFIX.
X_train, X_val, X_test = (
    sparse.load_npz(os.path.join(FEATURE_DIR, f"{PREFIX}_X_{split_tag}.npz"))
    for split_tag in ("train", "val", "test")
)

print(f"Train features: {X_train.shape}")
print(f"Val features:   {X_val.shape}")
print(f"Test features:  {X_test.shape}")

# Train and validation are pooled; the 5-fold CV below re-splits them while
# the test split is held back for the final evaluation.
X_all = sparse.vstack((X_train, X_val))
y_all = np.concatenate((y_train, y_val))

print(f"\nCombined for CV: {X_all.shape}")

# ============================================================
# FEATURE CATEGORY IDENTIFICATION
# ============================================================

print("\n" + "="*80)
print("IDENTIFYING FEATURE CATEGORIES")
print("="*80)

n_features = X_train.shape[1]

# Column layouts this script knows, keyed by total column count. Each range
# is half-open: (start, end) covers columns start .. end-1.
_known_layouts = {
    3449: (
        {
            'medical_patterns': (0, 43),
            'word_char_features': (43, 3043),
            'sentence_embeddings': (3043, 3427),
            'transformer_scores': (3427, 3449)
        },
        "Using known feature structure (3449 features)",
    ),
    # Legacy support for old padded version
    3451: (
        {
            'medical_patterns': (0, 43),
            'word_char_features': (43, 3063),
            'sentence_embeddings': (3063, 3449),
            'transformer_scores': (3449, 3451)
        },
        "Using legacy feature structure (3451 features with padding)",
    ),
}

_layout = _known_layouts.get(n_features)
if _layout is None:
    # Unrecognised layout: every category-based analysis below is skipped.
    print(f"[WARNING] Unknown feature structure ({n_features} features)")
    print("   Skipping category-based analysis")
    feature_categories = None
else:
    feature_categories, _layout_message = _layout
    print(_layout_message)

for cat_name, (start, end) in (feature_categories or {}).items():
    n_cat = end - start
    print(f"  {cat_name:25s}: indices {start:4d}-{end-1:4d} ({n_cat:4d} features)")

# ============================================================
# CLASS IMBALANCE
# ============================================================

print("\n" + "="*80)
print("CLASS IMBALANCE ANALYSIS")
print("="*80)

# Class counts over the pooled train+val labels; the negative:positive ratio
# is handed to LightGBM as scale_pos_weight.
neg, pos = ((y_all == cls).sum() for cls in (0, 1))
scale_pos_weight = neg / pos

print(f"Negative samples: {neg} ({neg/len(y_all)*100:.2f}%)")
print(f"Positive samples: {pos} ({pos/len(y_all)*100:.2f}%)")
print(f"Imbalance ratio:  {neg/pos:.2f}:1")
print(f"scale_pos_weight: {scale_pos_weight:.2f}")

# ============================================================
# MODEL PARAMETERS
# ============================================================

print("\n" + "="*80)
print("MODEL CONFIGURATION")
print("="*80)

# Best params from Optuna (trial 18), followed by fixed runtime settings.
params = dict(
    learning_rate=0.0054,
    num_leaves=118,
    max_depth=9,
    min_child_samples=211,
    feature_fraction=0.795,
    bagging_fraction=0.813,
    bagging_freq=1,
    lambda_l1=4.29,
    lambda_l2=4.33,
    n_estimators=4000,
    scale_pos_weight=scale_pos_weight,
    objective="binary",
    metric="auc",
    n_jobs=-1,
    verbosity=-1,
)

# Echo only the tunable hyperparameters; runtime plumbing keys are hidden.
_unreported_keys = {'objective', 'metric', 'n_jobs', 'verbosity'}
print("Hyperparameters:")
for key, value in params.items():
    if key in _unreported_keys:
        continue
    print(f"  {key:25s}: {value}")

# ============================================================
# TRAIN 5-FOLD ENSEMBLE
# ============================================================

print("\n" + "="*80)
print("TRAINING 5-FOLD ENSEMBLE")
print("="*80)

# One LightGBM model per stratified fold. The fold's held-out rows are also
# the early-stopping set, so the per-fold AUCs are slightly optimistic.
models, fold_results = [], []
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

for fold, (tr_idx, va_idx) in enumerate(kf.split(X_all, y_all)):
    print(f"\n{'='*60}")
    print(f"Fold {fold+1}/5")
    print(f"{'='*60}")

    X_tr, y_tr = X_all[tr_idx], y_all[tr_idx]
    X_va, y_va = X_all[va_idx], y_all[va_idx]

    print(f"Train: {X_tr.shape[0]} samples, {y_tr.sum()} positive")
    print(f"Val:   {X_va.shape[0]} samples, {y_va.sum()} positive")

    stop_cb = lgb.early_stopping(200, verbose=False)
    log_cb = lgb.log_evaluation(100)
    model = lgb.LGBMClassifier(**params)
    model.fit(
        X_tr, y_tr,
        eval_set=[(X_va, y_va)],
        eval_metric="auc",
        callbacks=[stop_cb, log_cb],
    )
    models.append(model)

    # Score the held-out fold with the early-stopped model.
    va_auc = roc_auc_score(y_va, model.predict_proba(X_va)[:, 1])
    fold_results.append(
        {'fold': fold + 1, 'val_auc': va_auc, 'n_iterations': model.best_iteration_}
    )

    print(f"Fold {fold+1} Val AUC: {va_auc:.6f}")
    print(f"Best iteration: {model.best_iteration_}")

# Persist the per-fold scores.
fold_df = pd.DataFrame(fold_results)
_fold_csv_path = output_path / 'fold_results.csv'
fold_df.to_csv(_fold_csv_path, index=False)
print(f"\n[OK] Fold results saved to: {_fold_csv_path}")

print("\n" + "="*80)
print("ENSEMBLE TRAINING SUMMARY")
print("="*80)
print(f"Mean Val AUC: {fold_df['val_auc'].mean():.6f} +/- {fold_df['val_auc'].std():.6f}")
print(f"Mean iterations: {fold_df['n_iterations'].mean():.0f} +/- {fold_df['n_iterations'].std():.0f}")

# ============================================================
# ENSEMBLE PREDICTION FUNCTION
# ============================================================

def predict_ensemble(X, ensemble=None):
    """Average the positive-class probability over an ensemble of classifiers.

    Parameters
    ----------
    X : array-like or sparse matrix, shape (n_samples, n_features)
        Rows to score. It must have the same columns the models were trained on.
    ensemble : sequence of fitted classifiers, optional
        Objects exposing ``predict_proba``. Defaults to the module-level
        ``models`` list built by the 5-fold training loop.

    Returns
    -------
    numpy.ndarray, shape (n_samples,)
        Mean of ``predict_proba(X)[:, 1]`` across the ensemble members.

    Raises
    ------
    ValueError
        If the ensemble is empty. Previously this produced NaNs from 0/0.
    """
    if ensemble is None:
        ensemble = models
    if len(ensemble) == 0:
        raise ValueError("predict_ensemble: ensemble is empty")
    preds = np.zeros(X.shape[0])
    for member in ensemble:
        preds += member.predict_proba(X)[:, 1]
    return preds / len(ensemble)

# ============================================================
# OOF PREDICTIONS (NO LEAKAGE)
# ============================================================

print("\n" + "="*80)
print("COMPUTING OUT-OF-FOLD PREDICTIONS")
print("="*80)

# Every pooled row is scored by the single fold model that did not train on
# it. kf has an integer random_state, so split() regenerates the training
# folds exactly. Caveat: that held-out fold was also the early-stopping set,
# so this OOF AUC is mildly optimistic.
oof_preds = np.zeros(X_all.shape[0])
for fold_pos, (_, holdout_rows) in enumerate(kf.split(X_all, y_all)):
    fold_model = models[fold_pos]
    holdout_X = X_all[holdout_rows]
    oof_preds[holdout_rows] = fold_model.predict_proba(holdout_X)[:, 1]

oof_auc = roc_auc_score(y_all, oof_preds)
print(f"OOF ROC-AUC: {oof_auc:.6f}")

# ============================================================
# THRESHOLD OPTIMIZATION
# ============================================================

print("\n" + "="*80)
print("THRESHOLD OPTIMIZATION")
print("="*80)

# Find optimal thresholds for different objectives
def find_optimal_threshold(y_true, y_pred, metric='f1', min_precision=0.3, thresholds=None):
    """Grid-search the probability cut-off that maximises *metric*.

    Parameters
    ----------
    y_true : array-like of {0, 1}
        Ground-truth labels.
    y_pred : array-like of float
        Predicted positive-class scores, compared with ``>= threshold``.
    metric : {'f1', 'recall', 'balanced_accuracy'}
        Objective to maximise.
    min_precision : float
        Thresholds whose precision falls below this floor are skipped.
    thresholds : iterable of float, optional
        Candidate cut-offs. Defaults to 200 evenly spaced values in [0.01, 0.99].

    Returns
    -------
    (float, float)
        ``(best_threshold, best_score)``. On ties the first candidate wins,
        which is the lowest cut-off for the default ascending grid. If no
        candidate meets ``min_precision`` the result is ``(0.5, -1)``.

    Raises
    ------
    ValueError
        If *metric* is not one of the supported names. Previously this
        failed later with an UnboundLocalError on ``score``.
    """
    # Lambdas keep the sklearn lookups lazy; only the chosen scorer is called.
    scorers = {
        'f1': lambda yt, yp: f1_score(yt, yp, zero_division=0),
        'recall': lambda yt, yp: recall_score(yt, yp, zero_division=0),
        'balanced_accuracy': lambda yt, yp: balanced_accuracy_score(yt, yp),
    }
    if metric not in scorers:
        raise ValueError(
            f"Unsupported metric {metric!r}; expected one of {sorted(scorers)}"
        )
    score_fn = scorers[metric]

    y_pred = np.asarray(y_pred)
    if thresholds is None:
        thresholds = np.linspace(0.01, 0.99, 200)

    best_score = -1
    best_threshold = 0.5

    for t in thresholds:
        pred = (y_pred >= t).astype(int)

        # Enforce the precision floor before scoring this candidate.
        prec = precision_score(y_true, pred, zero_division=0)
        if prec < min_precision:
            continue

        score = score_fn(y_true, pred)
        if score > best_score:
            best_score = score
            best_threshold = t

    return best_threshold, best_score

# Optimize for different objectives
# Tune one cut-off per objective on the OOF scores, all under the same
# precision floor of 0.3.
objectives = ['f1', 'recall', 'balanced_accuracy']
optimal_thresholds = {}

print("\nOptimal thresholds (OOF):")
print(f"{'Objective':<20} {'Threshold':>12} {'Score':>12}")
print("-"*50)

for obj in objectives:
    obj_thr, obj_score = find_optimal_threshold(
        y_all, oof_preds, metric=obj, min_precision=0.3
    )
    optimal_thresholds[obj] = obj_thr
    print(f"{obj:<20} {obj_thr:>12.4f} {obj_score:>12.4f}")

# The F1-optimised cut-off is the operating point used downstream.
best_thr = optimal_thresholds['f1']
print(f"\n[OK] Using F1-optimized threshold: {best_thr:.4f}")

# ============================================================
# FEATURE IMPORTANCE ANALYSIS
# ============================================================

print("\n" + "="*80)
print("FEATURE IMPORTANCE ANALYSIS")
print("="*80)

# Per-fold *gain* importance. LGBMClassifier.feature_importances_ follows the
# estimator's importance_type, which defaults to 'split' (split counts), while
# every column and axis label downstream calls these values "gain". Request
# gain from the booster explicitly; with iteration=None LightGBM restricts it
# to the best (early-stopped) iteration when one exists.
all_importances = [
    fold_model.booster_.feature_importance(importance_type='gain')
    for fold_model in models
]

avg_importance = np.mean(all_importances, axis=0)
std_importance = np.std(all_importances, axis=0)

# One row per column, most important first.
importance_df = pd.DataFrame({
    'feature_idx': range(len(avg_importance)),
    'importance_mean': avg_importance,
    'importance_std': std_importance
}).sort_values('importance_mean', ascending=False)

# Add category labels if available
if feature_categories:
    def get_category(idx):
        """Return the category whose [start, end) range contains column *idx*."""
        for cat_name, (start, end) in feature_categories.items():
            if start <= idx < end:
                return cat_name
        return 'unknown'

    importance_df['category'] = importance_df['feature_idx'].apply(get_category)

# Top 50 features
top_50 = importance_df.head(50)
print("\nTop 50 Most Important Features:")
print(top_50.to_string(index=False))

# Save full importance
_importance_csv = output_path / 'feature_importance_complete.csv'
importance_df.to_csv(_importance_csv, index=False)
print(f"\n[OK] Saved: {_importance_csv}")

# Category-wise importance
if feature_categories:
    print("\n" + "="*80)
    print("IMPORTANCE BY CATEGORY")
    print("="*80)

    # Sum / mean / count of the fold-averaged importances within each category.
    category_importance = (
        importance_df.groupby('category')['importance_mean']
        .agg(['sum', 'mean', 'count'])
        .round(6)
    )
    category_importance.columns = ['total_gain', 'avg_gain', 'n_features']
    category_importance = category_importance.sort_values('total_gain', ascending=False)
    _gain_grand_total = category_importance['total_gain'].sum()
    category_importance['total_gain_pct'] = category_importance['total_gain'] / _gain_grand_total * 100

    print(category_importance)

    # Save category importance
    _category_csv = output_path / 'category_importance.csv'
    category_importance.to_csv(_category_csv)
    print(f"\n[OK] Saved: {_category_csv}")

# ============================================================
# ABLATION STUDY
# ============================================================

if feature_categories:
    print("\n" + "="*80)
    print("ABLATION STUDY")
    print("="*80)

    ablation_results = []

    # Baseline (all features). baseline_auc is the pooled OOF AUC and is reused
    # by later sections. Ablation deltas are measured against the *mean fold
    # AUC* so both sides of the comparison use the same aggregation and the
    # same (pandas, ddof=1) standard deviation.
    print("\nBaseline (all features):")
    baseline_auc = oof_auc
    baseline_pred = (oof_preds >= best_thr).astype(int)
    baseline_f1 = f1_score(y_all, baseline_pred)
    baseline_fold_auc_mean = fold_df['val_auc'].mean()
    baseline_fold_auc_std = fold_df['val_auc'].std()

    print(f"  ROC-AUC: {baseline_auc:.6f}")
    print(f"  F1:      {baseline_f1:.6f}")
    print(f"  Mean fold AUC: {baseline_fold_auc_mean:.6f} +/- {baseline_fold_auc_std:.6f}")

    ablation_results.append({
        'configuration': 'baseline (all features)',
        'n_features': X_all.shape[1],
        'removed_category': None,
        'auc_mean': baseline_fold_auc_mean,
        'auc_std': baseline_fold_auc_std,
        'oof_auc': baseline_auc,
        'f1': baseline_f1,
        'auc_delta': 0.0,
        'auc_delta_pct': 0.0
    })

    # Remove each category
    for cat_name, (start, end) in feature_categories.items():
        print(f"\nRemoving: {cat_name} (features {start}-{end-1})")

        # Keep every column outside this category's [start, end) range.
        mask = np.ones(X_all.shape[1], dtype=bool)
        mask[start:end] = False

        X_reduced = X_all[:, mask]
        n_features_kept = X_reduced.shape[1]
        n_features_removed = end - start

        print(f"  Keeping {n_features_kept} features, removing {n_features_removed}")

        # 5-fold CV on the same splits as the baseline. OOF scores are
        # collected as well, so pooled AUC/F1 come at no extra training cost.
        fold_aucs = []
        ablation_oof = np.zeros(X_all.shape[0])
        for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(X_all, y_all)):
            X_tr_red = X_reduced[tr_idx]
            X_va_red = X_reduced[va_idx]
            y_tr = y_all[tr_idx]
            y_va = y_all[va_idx]

            # Train model with same params as baseline
            ablation_model = lgb.LGBMClassifier(**params)
            ablation_model.fit(
                X_tr_red, y_tr,
                eval_set=[(X_va_red, y_va)],
                callbacks=[lgb.early_stopping(200, verbose=False)]
            )

            # Evaluate
            va_pred = ablation_model.predict_proba(X_va_red)[:, 1]
            ablation_oof[va_idx] = va_pred
            fold_auc = roc_auc_score(y_va, va_pred)
            fold_aucs.append(fold_auc)

            print(f"    Fold {fold_idx+1}: {fold_auc:.6f}")

        mean_auc = np.mean(fold_aucs)
        std_auc = np.std(fold_aucs, ddof=1)  # match pandas .std() used for the baseline
        ablation_oof_auc = roc_auc_score(y_all, ablation_oof)
        # best_thr was tuned for the full-feature model; it is not re-tuned here.
        ablation_f1 = f1_score(y_all, (ablation_oof >= best_thr).astype(int), zero_division=0)
        auc_delta = baseline_fold_auc_mean - mean_auc
        auc_delta_pct = (auc_delta / baseline_fold_auc_mean) * 100

        print(f"  Mean AUC: {mean_auc:.6f} +/- {std_auc:.6f}")
        print(f"  OOF AUC:  {ablation_oof_auc:.6f}")
        print(f"  Impact: Delta AUC = {auc_delta:+.6f} ({auc_delta_pct:+.2f}%)")

        ablation_results.append({
            'configuration': f'without {cat_name}',
            'n_features': n_features_kept,
            'removed_category': cat_name,
            'auc_mean': mean_auc,
            'auc_std': std_auc,
            'oof_auc': ablation_oof_auc,
            'f1': ablation_f1,
            'auc_delta': auc_delta,
            'auc_delta_pct': auc_delta_pct
        })

    # Save ablation results
    ablation_df = pd.DataFrame(ablation_results).sort_values('auc_delta', ascending=False)
    ablation_df.to_csv(output_path / 'ablation_results.csv', index=False)
    print(f"\n[OK] Saved: {output_path / 'ablation_results.csv'}")

    print("\n" + "="*80)
    print("ABLATION SUMMARY")
    print("="*80)
    print(ablation_df[['configuration', 'n_features', 'auc_mean', 'auc_std', 'oof_auc', 'f1', 'auc_delta_pct']].to_string(index=False))

# ============================================================
# TOP-K FEATURE ANALYSIS - CORRECTED WITH 5-FOLD CV
# ============================================================

print("\n" + "="*80)
print("TOP-K FEATURE EFFICIENCY ANALYSIS (5-FOLD CV)")
print("="*80)
print("Using proper cross-validation to ensure valid comparisons")
print("="*80)

# Baseline statistics are (re)computed here so this section does not depend
# on the ablation study. That study is skipped when feature_categories is None,
# which used to leave baseline_auc / baseline_f1 undefined (NameError below).
baseline_auc = oof_auc  # pooled OOF AUC of the full-feature ensemble
baseline_f1 = f1_score(y_all, (oof_preds >= best_thr).astype(int))

# The top-k runs report mean +/- sample std (ddof=1) over the SAME folds, so
# the full-feature baseline is aggregated identically from the ensemble folds.
baseline_fold_auc_mean = fold_df['val_auc'].mean()
baseline_fold_auc_std = fold_df['val_auc'].std()  # pandas default ddof=1
baseline_fold_f1_mean = float(np.mean([
    f1_score(y_all[va], (oof_preds[va] >= best_thr).astype(int), zero_division=0)
    for _, va in kf.split(X_all, y_all)
]))

topk_results = []
k_values = [10, 25, 50, 100, 200, 500, 1000, 2000, 3000]

# NOTE: the feature ranking comes from importances of models trained across
# all folds, so the selected columns have "seen" every validation fold and the
# top-k AUCs are mildly optimistic. best_thr was tuned for the full model and
# is not re-tuned per K, so the F1 values are only indicative.
for k in k_values:
    if k >= X_all.shape[1]:
        continue

    print(f"\n{'='*60}")
    print(f"Testing top {k} features with 5-fold CV")
    print(f"{'='*60}")

    # Get top k feature indices (by mean importance)
    top_k_indices = importance_df.head(k)['feature_idx'].values
    X_topk = X_all[:, top_k_indices]

    # 5-fold CV (SAME splits as baseline!)
    fold_aucs = []
    fold_f1s = []
    fold_iterations = []
    topk_oof = np.zeros(X_all.shape[0])

    for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(X_all, y_all)):
        X_tr = X_topk[tr_idx]
        X_va = X_topk[va_idx]
        y_tr = y_all[tr_idx]
        y_va = y_all[va_idx]

        # Train with SAME parameters as baseline (critical!)
        model = lgb.LGBMClassifier(**params)
        model.fit(
            X_tr, y_tr,
            eval_set=[(X_va, y_va)],
            callbacks=[lgb.early_stopping(200, verbose=False)]
        )

        # Evaluate
        va_pred = model.predict_proba(X_va)[:, 1]
        topk_oof[va_idx] = va_pred
        fold_auc = roc_auc_score(y_va, va_pred)

        va_pred_binary = (va_pred >= best_thr).astype(int)
        fold_f1 = f1_score(y_va, va_pred_binary, zero_division=0)

        fold_aucs.append(fold_auc)
        fold_f1s.append(fold_f1)
        fold_iterations.append(model.best_iteration_)

        print(f"  Fold {fold_idx+1}: AUC={fold_auc:.6f}, F1={fold_f1:.6f}, Iter={model.best_iteration_}")

    mean_auc = np.mean(fold_aucs)
    std_auc = np.std(fold_aucs, ddof=1)  # same estimator as the baseline's pandas .std()
    mean_f1 = np.mean(fold_f1s)
    mean_iter = np.mean(fold_iterations)
    topk_oof_auc = roc_auc_score(y_all, topk_oof)

    # Compare like with like: mean fold AUC vs the baseline's mean fold AUC.
    pct_baseline = (mean_auc / baseline_fold_auc_mean) * 100
    auc_delta = baseline_fold_auc_mean - mean_auc

    print(f"\n  Summary:")
    print(f"    Mean AUC:       {mean_auc:.6f} +/- {std_auc:.6f}")
    print(f"    OOF AUC:        {topk_oof_auc:.6f}")
    print(f"    % of baseline:  {pct_baseline:.2f}%")
    print(f"    Delta AUC:      {auc_delta:+.6f}")
    print(f"    Mean F1:        {mean_f1:.6f}")
    print(f"    Mean iterations: {mean_iter:.0f}")

    topk_results.append({
        'k': k,
        'auc_mean': mean_auc,
        'auc_std': std_auc,
        'oof_auc': topk_oof_auc,
        'f1_mean': mean_f1,
        'pct_baseline': pct_baseline,
        'auc_delta': auc_delta,
        'mean_iterations': mean_iter
    })

# Add the all-features baseline, aggregated exactly like the top-k rows.
topk_results.append({
    'k': X_all.shape[1],
    'auc_mean': baseline_fold_auc_mean,
    'auc_std': baseline_fold_auc_std,
    'oof_auc': baseline_auc,
    'f1_mean': baseline_fold_f1_mean,
    'pct_baseline': 100.0,
    'auc_delta': 0.0,
    'mean_iterations': fold_df['n_iterations'].mean()
})

# Save top-k results
topk_df = pd.DataFrame(topk_results).sort_values('k')
topk_df.to_csv(output_path / 'topk_results.csv', index=False)
print(f"\n[OK] Saved: {output_path / 'topk_results.csv'}")

print("\n" + "="*80)
print("TOP-K SUMMARY")
print("="*80)
print(topk_df[['k', 'auc_mean', 'auc_std', 'oof_auc', 'pct_baseline', 'auc_delta']].to_string(index=False))

# Verify monotonicity
print("\n" + "="*80)
print("MONOTONICITY CHECK")
print("="*80)
sorted_topk = topk_df.sort_values('k')
_auc_by_k = sorted_topk['auc_mean'].tolist()
# Non-decreasing check across consecutive K values.
is_monotonic = all(prev <= nxt for prev, nxt in zip(_auc_by_k, _auc_by_k[1:]))
if not is_monotonic:
    print("⚠ Results are NOT perfectly monotonic (expected due to CV variance)")
    print("  This is normal with proper CV - small non-monotonicity is measurement noise")
else:
    print("✓ Results are monotonic (performance increases with more features)")

# ============================================================
# TEST SET EVALUATION
# ============================================================

print("\n" + "="*80)
print("TEST SET EVALUATION")
print("="*80)

test_probs = predict_ensemble(X_test)

# Fixed grid of cut-offs plus the OOF-tuned one, each scored on the test split.
print("\nPerformance at different thresholds:")
print(f"{'Threshold':<12} {'Recall':>10} {'Precision':>12} {'F1':>10} {'Detected':>12}")
print("-"*60)

test_results_by_threshold = []
_candidate_thresholds = [0.5, 0.4, 0.35, 0.3, 0.25, 0.2, 0.15, best_thr]
n_total = y_test.sum()  # positives in the test split; identical for every cut-off

for t in _candidate_thresholds:
    test_pred = (test_probs >= t).astype(int)
    row = {
        'threshold': t,
        'recall': recall_score(y_test, test_pred, zero_division=0),
        'precision': precision_score(y_test, test_pred, zero_division=0),
        'f1': f1_score(y_test, test_pred, zero_division=0),
        'n_detected': (test_pred & y_test).sum(),
        'n_total': n_total,
    }
    print(f"{row['threshold']:<12.4f} {row['recall']:>10.4f} {row['precision']:>12.4f} {row['f1']:>10.4f} {row['n_detected']:>6}/{row['n_total']:<5}")
    test_results_by_threshold.append(row)

# Save threshold analysis
threshold_df = pd.DataFrame(test_results_by_threshold)
threshold_df.to_csv(output_path / 'test_threshold_analysis.csv', index=False)

# Final test evaluation with optimal threshold
test_pred_optimal = (test_probs >= best_thr).astype(int)

test_auc = roc_auc_score(y_test, test_probs)
test_recall = recall_score(y_test, test_pred_optimal, zero_division=0)
test_precision = precision_score(y_test, test_pred_optimal, zero_division=0)
test_f1 = f1_score(y_test, test_pred_optimal, zero_division=0)
test_balanced_acc = balanced_accuracy_score(y_test, test_pred_optimal)

print("\n" + "="*80)
print(f"FINAL TEST RESULTS (threshold={best_thr:.4f})")
print("="*80)
print(f"ROC-AUC:           {test_auc:.6f}")
print(f"F1:                {test_f1:.6f}")
print(f"Precision:         {test_precision:.6f}")
print(f"Recall:            {test_recall:.6f}")
print(f"Balanced Accuracy: {test_balanced_acc:.6f}")

# labels=[0, 1] pins a 2x2 matrix even when the labels or predictions contain
# a single class; otherwise the four-way unpack below raises ValueError.
cm = confusion_matrix(y_test, test_pred_optimal, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()

# Guard the ratios so a degenerate split reports NaN instead of dividing by zero.
specificity = tn / (tn + fp) if (tn + fp) > 0 else float('nan')
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else float('nan')

print(f"\nConfusion Matrix:")
print(f"  True Negatives:  {tn:>6}")
print(f"  False Positives: {fp:>6}")
print(f"  False Negatives: {fn:>6}")
print(f"  True Positives:  {tp:>6}")
print(f"\nSpecificity: {specificity:.4f}")
print(f"Sensitivity: {sensitivity:.4f}")

# Classification report
print(f"\nDetailed Classification Report:")
print(classification_report(y_test, test_pred_optimal, digits=4))

import json

# Save test results. Explicit casts keep the JSON free of numpy scalar types.
test_summary = {
    'threshold': float(best_thr),
    'roc_auc': float(test_auc),
    'f1': float(test_f1),
    'precision': float(test_precision),
    'recall': float(test_recall),
    'balanced_accuracy': float(test_balanced_acc),
    'tn': int(tn),
    'fp': int(fp),
    'fn': int(fn),
    'tp': int(tp),
    'oof_auc': float(oof_auc),
    'mean_fold_auc': float(fold_df['val_auc'].mean()),
    'std_fold_auc': float(fold_df['val_auc'].std())
}

with open(output_path / 'test_results.json', 'w', encoding='utf-8') as f:
    json.dump(test_summary, f, indent=2)

print(f"\n[OK] Saved: {output_path / 'test_results.json'}")

# ============================================================
# VISUALIZATION: ROC CURVE & PR CURVE
# ============================================================

print("\n" + "="*80)
print("GENERATING VISUALIZATIONS")
print("="*80)

fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# 1. ROC Curve
ax = axes[0, 0]
fpr, tpr, _ = roc_curve(y_test, test_probs)
ax.plot(fpr, tpr, linewidth=2, label=f'Test (AUC={test_auc:.4f})')
ax.plot([0, 1], [0, 1], 'k--', alpha=0.3, label='Random')
ax.set_xlabel('False Positive Rate', fontsize=12)
ax.set_ylabel('True Positive Rate (Recall)', fontsize=12)
ax.set_title('ROC Curve', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# 2. Precision-Recall Curve (horizontal line = positive prevalence)
ax = axes[0, 1]
precision_curve, recall_curve, _ = precision_recall_curve(y_test, test_probs)
pr_auc = auc(recall_curve, precision_curve)
ax.plot(recall_curve, precision_curve, linewidth=2, label=f'PR AUC={pr_auc:.4f}')
ax.axhline(y_test.mean(), color='r', linestyle='--', alpha=0.3, label=f'Baseline={y_test.mean():.3f}')
ax.set_xlabel('Recall', fontsize=12)
ax.set_ylabel('Precision', fontsize=12)
ax.set_title('Precision-Recall Curve', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# 3. Feature Importance (top 30, or fewer if the matrix has fewer columns;
#    a fixed range(30) would mismatch the bar values in that case)
ax = axes[1, 0]
top_30 = importance_df.head(30)
n_top = len(top_30)
colors = plt.cm.viridis(np.linspace(0, 1, n_top))
ax.barh(range(n_top), top_30['importance_mean'].values, color=colors)
ax.set_yticks(range(n_top))
ax.set_yticklabels([f"Feature {idx}" for idx in top_30['feature_idx'].values], fontsize=8)
ax.set_xlabel('Importance (Gain)', fontsize=12)
ax.set_title(f'Top {n_top} Features by Importance', fontsize=14, fontweight='bold')
ax.invert_yaxis()
ax.grid(True, alpha=0.3, axis='x')

# 4. Confusion Matrix Heatmap. labels=[0, 1] keeps it 2x2 so it always matches
#    the two tick labels, even for single-class predictions.
ax = axes[1, 1]
cm_display = confusion_matrix(y_test, test_pred_optimal, labels=[0, 1])
sns.heatmap(cm_display, annot=True, fmt='d', cmap='Blues', ax=ax,
            xticklabels=['Negative', 'Positive'],
            yticklabels=['Negative', 'Positive'])
ax.set_xlabel('Predicted', fontsize=12)
ax.set_ylabel('Actual', fontsize=12)
ax.set_title(f'Confusion Matrix (threshold={best_thr:.3f})', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.savefig(output_path / 'test_evaluation_summary.png', dpi=300, bbox_inches='tight')
print(f"[OK] Saved: {output_path / 'test_evaluation_summary.png'}")
plt.close()

# ============================================================
# VISUALIZATION: CATEGORY IMPORTANCE
# ============================================================

if feature_categories and 'category' in importance_df.columns:
    fig, (pie_ax, bar_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Total fold-averaged importance per category, largest first.
    category_totals = (
        importance_df.groupby('category')['importance_mean']
        .sum()
        .sort_values(ascending=False)
    )
    n_cats = len(category_totals)
    colors_pie = plt.cm.Set3(np.linspace(0, 1, n_cats))

    # Left panel: share of the total per category.
    wedges, texts, autotexts = pie_ax.pie(
        category_totals.values,
        labels=category_totals.index,
        autopct='%1.1f%%',
        colors=colors_pie,
        startangle=90,
    )
    pie_ax.set_title('Feature Importance by Category', fontsize=14, fontweight='bold')

    # Right panel: the same totals as bars, with matching colours.
    bar_ax.bar(range(n_cats), category_totals.values, color=colors_pie)
    bar_ax.set_xticks(range(n_cats))
    bar_ax.set_xticklabels(category_totals.index, rotation=45, ha='right')
    bar_ax.set_ylabel('Total Importance (Gain)', fontsize=12)
    bar_ax.set_title('Category Importance Comparison', fontsize=14, fontweight='bold')
    bar_ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    _category_png = output_path / 'category_importance.png'
    plt.savefig(_category_png, dpi=300, bbox_inches='tight')
    print(f"[OK] Saved: {_category_png}")
    plt.close()

# ============================================================
# VISUALIZATION: ABLATION IMPACT
# ============================================================

if feature_categories and len(ablation_results) > 1:
    fig, ax = plt.subplots(figsize=(10, 6))

    # Only the "without <category>" rows; the baseline row has no removed_category.
    removed_rows = ablation_df[ablation_df['removed_category'].notna()].copy()
    bar_positions = range(len(removed_rows))
    deltas = removed_rows['auc_delta'].values

    # Red bars: removing the category lowered AUC (positive delta); green otherwise.
    bar_colors = ['red' if d > 0 else 'green' for d in removed_rows['auc_delta']]

    # Bars carry the per-category fold std as error bars.
    ax.barh(bar_positions, deltas,
            xerr=removed_rows['auc_std'].values, color=bar_colors, alpha=0.7, capsize=5)
    ax.set_yticks(bar_positions)
    ax.set_yticklabels(removed_rows['removed_category'].values)
    ax.set_xlabel('AUC Impact (Baseline - Without Category)', fontsize=12)
    ax.set_title('Ablation Study: Impact of Removing Each Category', fontsize=14, fontweight='bold')
    ax.axvline(0, color='black', linestyle='-', linewidth=0.5)
    ax.grid(True, alpha=0.3, axis='x')

    plt.tight_layout()
    _ablation_png = output_path / 'ablation_impact.png'
    plt.savefig(_ablation_png, dpi=300, bbox_inches='tight')
    print(f"[OK] Saved: {_ablation_png}")
    plt.close()

# ============================================================
# VISUALIZATION: TOP-K PERFORMANCE WITH ERROR BARS
# ============================================================

if len(topk_results) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Draw the reference line from topk_df's own all-features row so it is
    # aggregated exactly like the plotted K points. That row is always
    # appended by the top-k section; the global is only a fallback.
    _full_rows = topk_df.loc[topk_df['k'] == X_all.shape[1], 'auc_mean']
    topk_baseline_auc = _full_rows.iloc[0] if len(_full_rows) else baseline_auc

    # AUC vs K with error bars
    ax = axes[0]
    ax.errorbar(topk_df['k'], topk_df['auc_mean'], yerr=topk_df['auc_std'],
                marker='o', linewidth=2, capsize=5, label='Top-K (5-fold CV)')
    ax.axhline(topk_baseline_auc, color='r', linestyle='--', label='Baseline (all features)')
    ax.set_xscale('log')  # K spans 10 .. several thousand; linear axis crushes small K
    ax.set_xlabel('Number of Features (K)', fontsize=12)
    ax.set_ylabel('ROC-AUC', fontsize=12)
    ax.set_title('Performance vs Number of Features', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # % of Baseline vs K
    ax = axes[1]
    ax.plot(topk_df['k'], topk_df['pct_baseline'], marker='o', linewidth=2, color='green')
    ax.axhline(100, color='r', linestyle='--', label='100% baseline')
    ax.axhline(99, color='orange', linestyle=':', label='99% baseline')
    ax.axhline(95, color='gray', linestyle=':', label='95% baseline', alpha=0.5)
    ax.set_xscale('log')
    ax.set_xlabel('Number of Features (K)', fontsize=12)
    ax.set_ylabel('% of Baseline Performance', fontsize=12)
    ax.set_title('Feature Efficiency Analysis (5-Fold CV)', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path / 'topk_performance.png', dpi=300, bbox_inches='tight')
    print(f"[OK] Saved: {output_path / 'topk_performance.png'}")
    plt.close()

# ============================================================
# VISUALIZATION: THRESHOLD ANALYSIS
# ============================================================

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Sweep 100 cut-offs and collect recall, precision and flagged-row counts
# on the test split.
thresholds_fine = np.linspace(0.05, 0.95, 100)
recalls, precisions, flagged_counts = [], [], []
for cut in thresholds_fine:
    hard_pred = (test_probs >= cut).astype(int)
    recalls.append(recall_score(y_test, hard_pred, zero_division=0))
    precisions.append(precision_score(y_test, hard_pred, zero_division=0))
    flagged_counts.append((test_probs >= cut).sum())


def _finish_threshold_axis(target_ax, ylabel, title):
    """Mark the chosen cut-off, then label, title, legend and grid *target_ax*."""
    target_ax.axvline(best_thr, color='r', linestyle='--', label=f'Optimal ({best_thr:.3f})')
    target_ax.set_xlabel('Threshold', fontsize=12)
    target_ax.set_ylabel(ylabel, fontsize=12)
    target_ax.set_title(title, fontsize=14, fontweight='bold')
    target_ax.legend()
    target_ax.grid(True, alpha=0.3)


# Left: recall and precision against the cut-off.
axes[0].plot(thresholds_fine, recalls, label='Recall', linewidth=2)
axes[0].plot(thresholds_fine, precisions, label='Precision', linewidth=2)
_finish_threshold_axis(axes[0], 'Score', 'Recall & Precision vs Threshold')

# Right: how many test rows each cut-off flags as positive.
axes[1].plot(thresholds_fine, flagged_counts, linewidth=2, color='purple')
_finish_threshold_axis(axes[1], 'Number of Samples Flagged', 'Detection Volume vs Threshold')

plt.tight_layout()
_threshold_png = output_path / 'threshold_analysis.png'
plt.savefig(_threshold_png, dpi=300, bbox_inches='tight')
print(f"[OK] Saved: {_threshold_png}")
plt.close()

# ============================================================
# FINAL SUMMARY REPORT
# ============================================================

print("\n" + "="*80)
print("GENERATING FINAL REPORT")
print("="*80)

# Guarded so a test split without positives reports NaN rather than dividing by zero.
_n_pos_test = tp + fn
_detected_pct = tp / _n_pos_test * 100 if _n_pos_test else float('nan')
_missed_pct = fn / _n_pos_test * 100 if _n_pos_test else float('nan')

report_lines = [
    "="*80,
    "COMPREHENSIVE ANALYSIS REPORT - CORRECTED VERSION",
    "="*80,
    "",
    "DATASET SUMMARY",
    "-"*80,
    f"Train samples:      {len(y_train):>8,}  ({y_train.mean()*100:.2f}% positive)",
    f"Validation samples: {len(y_val):>8,}  ({y_val.mean()*100:.2f}% positive)",
    f"Test samples:       {len(y_test):>8,}  ({y_test.mean()*100:.2f}% positive)",
    f"Total features:     {X_all.shape[1]:>8,}",
    "",
    "CROSS-VALIDATION RESULTS",
    "-"*80,
    f"OOF AUC:            {oof_auc:.6f}",
    f"Mean Fold AUC:      {fold_df['val_auc'].mean():.6f} +/- {fold_df['val_auc'].std():.6f}",
    "",
    "TEST SET PERFORMANCE",
    "-"*80,
    f"Optimal Threshold:  {best_thr:.4f}",
    f"ROC-AUC:            {test_auc:.6f}",
    f"F1-Score:           {test_f1:.6f}",
    f"Precision:          {test_precision:.6f}",
    f"Recall:             {test_recall:.6f}",
    f"Balanced Accuracy:  {test_balanced_acc:.6f}",
    "",
    "CONFUSION MATRIX",
    "-"*80,
    f"True Negatives:     {tn:>8,}",
    f"False Positives:    {fp:>8,}",
    f"False Negatives:    {fn:>8,}",
    f"True Positives:     {tp:>8,}",
    "",
    f"Detected:           {tp}/{tp+fn} errors ({_detected_pct:.1f}%)",
    f"Missed:             {fn}/{tp+fn} errors ({_missed_pct:.1f}%)",
    ""
]

if feature_categories:
    report_lines.extend([
        "FEATURE CATEGORY IMPORTANCE",
        "-"*80,
    ])
    for cat_name in category_importance.index:
        total_pct = category_importance.loc[cat_name, 'total_gain_pct']
        report_lines.append(f"{cat_name:25s}: {total_pct:>6.2f}%")
    report_lines.append("")

report_lines.extend([
    "TOP-K EFFICIENCY (5-FOLD CV)",
    "-"*80,
])
for _, row in topk_df.iterrows():
    # iterrows() upcasts the all-numeric row to float, so cast k back to int
    # (otherwise the report shows "K=  10.0").
    report_lines.append(
        f"K={int(row['k']):>5}: {row['auc_mean']:.6f} +/- {row['auc_std']:.6f} ({row['pct_baseline']:>5.2f}%)"
    )
report_lines.append("")

report_lines.extend([
    "FILES GENERATED",
    "-"*80,
    "[OK] fold_results.csv",
    "[OK] feature_importance_complete.csv",
    "[OK] test_results.json",
    "[OK] test_threshold_analysis.csv",
    "[OK] test_evaluation_summary.png",
])

if feature_categories:
    report_lines.extend([
        "[OK] category_importance.csv",
        "[OK] category_importance.png",
        "[OK] ablation_results.csv",
        "[OK] ablation_impact.png",
    ])

if len(topk_results) > 0:
    report_lines.extend([
        "[OK] topk_results.csv",
        "[OK] topk_performance.png",
    ])

report_lines.extend([
    "[OK] threshold_analysis.png",
    "",
    "="*80,
    "ANALYSIS COMPLETE",
    "="*80,
])

report_text = "\n".join(report_lines)
print(report_text)

# Save report
with open(output_path / 'ANALYSIS_REPORT.txt', 'w', encoding='utf-8') as f:
    f.write(report_text)

print(f"\n[OK] Full report saved to: {output_path / 'ANALYSIS_REPORT.txt'}")
print(f"\n[OK] All outputs saved to: {output_path.absolute()}")

print("\n" + "="*80)
print("SUCCESS - ALL ANALYSES COMPLETE")
print("="*80)
print("\nKEY IMPROVEMENTS IN THIS VERSION:")
print("  ✓ Top-K analysis uses proper 5-fold CV (not single split)")
# Read n_estimators from params so the banner cannot drift from the real config.
print(f"  ✓ Same parameters for all K values (n_estimators={params['n_estimators']})")
print("  ✓ Same folds as baseline (fair comparison)")
print("  ✓ Error bars included in results and plots")
print("  ✓ Results should be monotonic (or nearly so)")
print("="*80)