# Threshold Tuning and Model Explainability
## Transaction Fraud Detection System

This notebook covers:
1. Threshold optimization under business constraints
2. Precision-recall trade-off and business scenario analysis
3. Precision@K and Recall@K for ranked review queues
4. SHAP explainability (global and local)
5. Fraud score generation and export
6. Production recommendations

In [None]:
# Import libraries
import sys
sys.path.append('../src')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import lightgbm as lgb
import shap

from data_loader import load_raw_data, clean_data
from features import engineer_features, get_feature_columns
from train import time_based_split, load_model
from evaluate import precision_at_k, recall_at_k
from sklearn.metrics import precision_recall_curve

# Plot defaults shared by every figure in this notebook
sns.set_style(style='whitegrid')
plt.rcParams.update({'figure.figsize': (12, 6)})

# Suppress all warnings so cell output stays readable
import warnings
warnings.filterwarnings(action='ignore')

# Initialize SHAP's JavaScript visualisation support for the notebook
shap.initjs()

print("Libraries imported successfully!")

## 1. Load Data and Model

In [None]:
# Pipeline: raw load -> cleaning -> feature engineering
df = load_raw_data(nrows=500000)  # Adjust based on your system
df_clean = clean_data(df)
df_featured, _ = engineer_features(df_clean)

# Columns the model consumes
feature_cols = get_feature_columns(df_featured)

# Hold out 20% of the rows using the project's time-based split
train_df, test_df = time_based_split(df_featured, test_size=0.2)

# Feature matrices and 'isFraud' targets for each split
X_train, y_train = train_df[feature_cols], train_df['isFraud']
X_test, y_test = test_df[feature_cols], test_df['isFraud']

print(f"Test set: {len(X_test):,} transactions")
print(f"Test fraud rate: {y_test.mean()*100:.4f}%")

In [None]:
# Load the trained LightGBM model; if loading fails, train and save a new one.
try:
    model = load_model('../models/lightgbm_best.txt', model_type='lightgbm')
    print("Loaded existing model")
except Exception as load_error:
    # `except Exception` instead of a bare `except:` lets KeyboardInterrupt and
    # SystemExit propagate, so interrupting the kernel no longer triggers a retrain.
    print("Model not found. Training new model...")
    # Surface the actual cause so a corrupt file or import error is not mistaken
    # for a simply missing model file.
    print(f"  (load failed with {type(load_error).__name__}: {load_error})")
    from train import train_lightgbm, save_model
    # NOTE(review): the test split is passed as the 3rd/4th arguments. If
    # train_lightgbm uses them for validation or early stopping, every metric
    # computed below on the test set is optimistic — confirm against train.py.
    model = train_lightgbm(X_train, y_train, X_test, y_test)
    save_model(model, '../models/lightgbm_best.txt', model_type='lightgbm')

# Score the test set.
# NOTE(review): the rest of the notebook treats this output as a fraud
# probability per row — confirm that is what model.predict returns here.
y_pred_proba = model.predict(X_test)
print(f"\nPredictions generated for {len(y_pred_proba):,} transactions")

## 2. Threshold Optimization

Find candidate decision thresholds under three criteria: maximum F1 score, precision ≥ 0.9, and recall ≥ 0.7.

In [None]:
# Precision and recall at every candidate threshold on the test set
precision, recall, thresholds = precision_recall_curve(y_test, y_pred_proba)

# F1 at each curve point; the small epsilon guards against 0/0
f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10)

# Criterion 1: the threshold with the highest F1.
# The index is bounds-checked against `thresholds`; 0.5 is the fallback.
optimal_idx_f1 = np.argmax(f1_scores)
if optimal_idx_f1 < len(thresholds):
    optimal_threshold_f1 = thresholds[optimal_idx_f1]
else:
    optimal_threshold_f1 = 0.5

# Criterion 2: first curve point whose precision reaches 0.9 (fallback 0.9)
high_precision_idx = np.where(precision >= 0.9)[0]
optimal_threshold_precision = 0.9
if len(high_precision_idx) > 0 and high_precision_idx[0] < len(thresholds):
    optimal_threshold_precision = thresholds[high_precision_idx[0]]

# Criterion 3: last curve point whose recall is still at least 0.7 (fallback 0.3)
high_recall_idx = np.where(recall >= 0.7)[0]
optimal_threshold_recall = 0.3
if len(high_recall_idx) > 0 and high_recall_idx[-1] < len(thresholds):
    optimal_threshold_recall = thresholds[high_recall_idx[-1]]

print("Optimal Thresholds:")
print("="*70)
print(f"Maximum F1 Score:        {optimal_threshold_f1:.4f} (F1={f1_scores[optimal_idx_f1]:.4f})")
print(f"High Precision (≥0.9):   {optimal_threshold_precision:.4f}")
print(f"High Recall (≥0.7):      {optimal_threshold_recall:.4f}")

In [None]:
# Two panels: the precision-recall curve (left) and F1 versus threshold (right)
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))
pr_ax, f1_ax = axes

# Left panel: PR curve, max-F1 operating point, and the 0.9 / 0.7 reference lines
pr_ax.plot(recall, precision, linewidth=2, color='steelblue')
pr_ax.scatter(
    [recall[optimal_idx_f1]],
    [precision[optimal_idx_f1]],
    color='red',
    s=100,
    zorder=5,
    label=f'Max F1 (threshold={optimal_threshold_f1:.3f})',
)
pr_ax.axhline(y=0.9, color='green', linestyle='--', alpha=0.5, label='Precision=0.9')
pr_ax.axvline(x=0.7, color='orange', linestyle='--', alpha=0.5, label='Recall=0.7')
pr_ax.set_title('Precision-Recall Curve with Optimal Threshold', fontsize=14, fontweight='bold')
pr_ax.set_xlabel('Recall', fontsize=12)
pr_ax.set_ylabel('Precision', fontsize=12)
pr_ax.legend(loc='best')
pr_ax.grid(alpha=0.3)

# Right panel: F1 per threshold. `f1_scores` has one more entry than
# `thresholds`, so the final entry is dropped to align the two arrays.
f1_ax.plot(thresholds, f1_scores[:-1], linewidth=2, color='purple')
f1_ax.scatter(
    [optimal_threshold_f1],
    [f1_scores[optimal_idx_f1]],
    color='red',
    s=100,
    zorder=5,
    label=f'Max F1={f1_scores[optimal_idx_f1]:.3f}',
)
f1_ax.set_title('F1 Score vs Threshold', fontsize=14, fontweight='bold')
f1_ax.set_xlabel('Threshold', fontsize=12)
f1_ax.set_ylabel('F1 Score', fontsize=12)
f1_ax.legend(loc='best')
f1_ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('../outputs/threshold_optimization.png', dpi=300, bbox_inches='tight')
plt.show()

## 3. Business Scenario Analysis

Compare precision, recall, F1, and flagged volume at thresholds suited to different business needs.

In [None]:
# Evaluate a handful of named operating thresholds on the test set.
scenarios = {
    'Conservative (High Precision)': 0.9,
    'Balanced (Max F1)': optimal_threshold_f1,
    'Aggressive (High Recall)': 0.3,
    'Default': 0.5
}

# Plain array of labels so comparisons with the prediction array are positional.
y_true_values = np.asarray(y_test)

scenario_results = []

for scenario_name, threshold in scenarios.items():
    y_pred = (y_pred_proba >= threshold).astype(int)

    # Confusion-matrix counts needed for the metrics below.
    tp = ((y_pred == 1) & (y_true_values == 1)).sum()
    fp = ((y_pred == 1) & (y_true_values == 0)).sum()
    fn = ((y_pred == 0) & (y_true_values == 1)).sum()

    # Scenario-specific names are used on purpose: the PR-curve cell defines
    # notebook-level arrays called `precision` and `recall`, and rebinding them
    # to scalars here would break that cell's plots on re-run.
    scenario_precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    scenario_recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    if (scenario_precision + scenario_recall) > 0:
        scenario_f1 = (2 * (scenario_precision * scenario_recall)
                       / (scenario_precision + scenario_recall))
    else:
        scenario_f1 = 0

    scenario_results.append({
        'Scenario': scenario_name,
        'Threshold': threshold,
        'Precision': scenario_precision,
        'Recall': scenario_recall,
        'F1': scenario_f1,
        'Flagged': tp + fp,
        'True Positives': tp,
        'False Positives': fp
    })

scenario_df = pd.DataFrame(scenario_results)

print("\nBusiness Scenario Analysis:")
print("="*100)
print(scenario_df.to_string(index=False))
print("\n" + "="*100)

## 4. Precision@K and Recall@K for Different K Values

In [None]:
# Review-queue sizes at which to evaluate the ranking
k_values = [50, 100, 200, 500, 1000, 2000, 5000, 10000]

precision_at_k_values = []
recall_at_k_values = []

for k in k_values:
    # A K larger than the test set cannot be evaluated; record NaN instead
    if k > len(y_test):
        precision_at_k_values.append(np.nan)
        recall_at_k_values.append(np.nan)
        continue
    precision_at_k_values.append(precision_at_k(y_test, y_pred_proba, k))
    recall_at_k_values.append(recall_at_k(y_test, y_pred_proba, k))

# One row per K for display
precision_recall_k_df = pd.DataFrame({
    'K': k_values,
    'Precision@K': precision_at_k_values,
    'Recall@K': recall_at_k_values,
})

print("\nPrecision@K and Recall@K:")
print("="*70)
print(precision_recall_k_df.to_string(index=False))

In [None]:
# Side-by-side Precision@K and Recall@K curves on a log-scaled K axis
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# (axis, y-values, line colour, y-label, title) for each panel
panel_specs = [
    (axes[0], precision_at_k_values, 'steelblue', 'Precision@K',
     'Precision@K: Quality of Top Predictions'),
    (axes[1], recall_at_k_values, 'orange', 'Recall@K',
     'Recall@K: Coverage of Frauds'),
]

for ax, metric_values, line_color, y_label, panel_title in panel_specs:
    ax.plot(k_values, metric_values, marker='o', linewidth=2, markersize=8, color=line_color)
    ax.set_xlabel('K (Number of Top Predictions)', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(panel_title, fontsize=14, fontweight='bold')
    ax.set_xscale('log')
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('../outputs/precision_recall_at_k_detailed.png', dpi=300, bbox_inches='tight')
plt.show()

## 5. SHAP Explainability - Global Feature Importance

In [None]:
# Compute SHAP values on a fixed-seed random sample of the test set
# (at most 1000 rows) to keep the computation fast.
sample_size = min(1000, len(X_test))
X_test_sample = X_test.sample(n=sample_size, random_state=42)

print(f"Creating SHAP explainer with {sample_size} samples...")
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test_sample)

# Depending on the SHAP version, a binary LightGBM model can yield either one
# (n_samples, n_features) array, a list with one array per class, or a 3-D
# array with a trailing class axis. The cells below assume the 2-D form, so
# reduce the other layouts to the last (positive) class.
if isinstance(shap_values, list):
    shap_values = shap_values[-1]
shap_values = np.asarray(shap_values)
if shap_values.ndim == 3:
    shap_values = shap_values[:, :, -1]

# Fail here, with a clear message, rather than in a later plotting cell.
assert shap_values.shape == X_test_sample.shape, (
    f"Unexpected SHAP output shape {shap_values.shape}; "
    f"expected {X_test_sample.shape}"
)

print("SHAP values computed successfully!")

In [None]:
# Rank features by mean |SHAP value| over the sampled rows
shap_importance = pd.DataFrame(
    {
        'feature': feature_cols,
        'importance': np.abs(shap_values).mean(axis=0),
    }
)
shap_importance = shap_importance.sort_values(by='importance', ascending=False)

print("\nTop 20 Features by SHAP Importance:")
print("="*70)
print(shap_importance.head(20).to_string(index=False))

In [None]:
# SHAP summary plot limited to the 20 highest-ranked features; show=False
# keeps the figure open so a title can be added before saving.
global_title_style = {'fontsize': 14, 'fontweight': 'bold', 'pad': 20}

plt.figure(figsize=(12, 10))
shap.summary_plot(
    shap_values,
    X_test_sample,
    max_display=20,
    show=False,
)
plt.title('SHAP Feature Importance (Global)', **global_title_style)
plt.tight_layout()
plt.savefig('../outputs/shap_global_importance.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Violin-style SHAP summary of the 20 highest-ranked features; show=False
# keeps the figure open so a title can be added before saving.
violin_title_style = {'fontsize': 14, 'fontweight': 'bold', 'pad': 20}

plt.figure(figsize=(12, 10))
shap.summary_plot(
    shap_values,
    X_test_sample,
    plot_type='violin',
    max_display=20,
    show=False,
)
plt.title('SHAP Feature Impact Distribution', **violin_title_style)
plt.tight_layout()
plt.savefig('../outputs/shap_feature_impact.png', dpi=300, bbox_inches='tight')
plt.show()

## 6. SHAP Explainability - Local Explanations

Explaining individual fraud predictions.

In [None]:
# Attach the model scores to a copy of the test frame (test_df itself is untouched)
test_df_with_scores = test_df.copy()
test_df_with_scores['fraud_score'] = y_pred_proba

# The ten highest-scoring transactions
top_frauds = test_df_with_scores.nlargest(n=10, columns='fraud_score')

# Columns shown in the printed table
top_fraud_display_cols = ['transaction_id', 'type', 'amount', 'isFraud', 'fraud_score']

print("\nTop 10 Predicted Frauds:")
print("="*100)
print(top_frauds[top_fraud_display_cols].to_string(index=False))

In [None]:
# SHAP waterfall plot for the top predicted fraud.
#
# Three identifiers are kept apart here:
#   * top_fraud_id    - the 'transaction_id' column value, used only for display
#   * top_fraud_label - the DataFrame index label, used to locate the row in
#                       X_test_sample (which shares test_df's index)
#   * idx_in_sample   - the row position inside X_test_sample / shap_values
# The score and label are read from the top_frauds row itself, because
# y_pred_proba is a positional array and cannot be indexed by id or label.
top_fraud_row = top_frauds.iloc[0]
top_fraud_id = top_fraud_row['transaction_id']
top_fraud_label = top_frauds.index[0]

# expected_value may be a scalar or one value per class depending on the SHAP
# version; the waterfall plot needs a single scalar (last = positive class).
base_value = float(np.ravel(explainer.expected_value)[-1])

# Accept either a 2-D SHAP array or the per-class list some versions return.
sample_shap_values = shap_values[-1] if isinstance(shap_values, list) else shap_values

if top_fraud_label in X_test_sample.index:
    # NOTE(review): assumes index labels are unique, so get_loc returns an int.
    idx_in_sample = X_test_sample.index.get_loc(top_fraud_label)

    print(f"\nExplaining transaction {top_fraud_id}:")
    print(f"Fraud score: {top_fraud_row['fraud_score']:.4f}")
    print(f"Actual label: {'FRAUD' if top_fraud_row['isFraud'] == 1 else 'LEGITIMATE'}")

    plot_title = f'SHAP Explanation: Transaction {top_fraud_id}'
    output_file = '../outputs/shap_local_explanation_top_fraud.png'
else:
    # The SHAP sample is a random subset, so the top fraud is often absent;
    # fall back to the first sampled row so the cell still produces a plot.
    print(f"\nTop fraud transaction {top_fraud_id} not in SHAP sample. Using first sample instead.")
    idx_in_sample = 0
    plot_title = 'SHAP Explanation: Sample Transaction'
    output_file = '../outputs/shap_local_explanation_sample.png'

row_shap_values = np.asarray(sample_shap_values)[idx_in_sample]
if row_shap_values.ndim == 2:
    # (n_features, n_classes) layout: keep the positive-class column.
    row_shap_values = row_shap_values[:, -1]

# Create waterfall plot (feature values passed as a plain array, not a Series,
# so they are looked up by position rather than by column label).
shap.waterfall_plot(shap.Explanation(
    values=row_shap_values,
    base_values=base_value,
    data=X_test_sample.iloc[idx_in_sample].to_numpy(),
    feature_names=feature_cols
), max_display=15, show=False)

plt.title(plot_title, fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(output_file, dpi=300, bbox_inches='tight')
plt.show()

## 7. Generate Fraud Score Output

In [None]:
# One row per test transaction (id, model score, true label), ordered from
# highest to lowest score with a fresh 0..n-1 index.
fraud_score_df = (
    pd.DataFrame({
        'transaction_id': test_df['transaction_id'],
        'fraud_score': y_pred_proba,
        'isFraud': y_test,
    })
    .sort_values(by='fraud_score', ascending=False)
    .reset_index(drop=True)
)

# Write the ranked scores to disk
output_path = '../outputs/fraud_score.csv'
fraud_score_df.to_csv(output_path, index=False)

print(f"\nFraud scores saved to: {output_path}")
print(f"Total transactions: {len(fraud_score_df):,}")
print("\nSample of top predictions:")
print(fraud_score_df.head(20))

In [None]:
# Left: score histograms per true label. Right: cumulative fraud-capture curve.
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
hist_ax, capture_ax = axes

# Overlaid histograms of the score, one per class, on a log-scaled count axis
for label_value, label_name, bar_color in [(0, 'Legitimate', 'green'), (1, 'Fraud', 'red')]:
    class_scores = fraud_score_df[fraud_score_df['isFraud'] == label_value]['fraud_score']
    class_scores.hist(bins=50, alpha=0.5, label=label_name, ax=hist_ax, color=bar_color)
hist_ax.set_title('Fraud Score Distribution', fontsize=14, fontweight='bold')
hist_ax.set_xlabel('Fraud Score', fontsize=12)
hist_ax.set_ylabel('Frequency', fontsize=12)
hist_ax.legend()
hist_ax.set_yscale('log')
hist_ax.grid(alpha=0.3)

# Cumulative capture columns. These rely on fraud_score_df already being sorted
# by descending score with a 0..n-1 index, as the export cell leaves it.
total_frauds = fraud_score_df['isFraud'].sum()
fraud_score_df['cumulative_frauds'] = fraud_score_df['isFraud'].cumsum()
fraud_score_df['pct_transactions'] = (fraud_score_df.index + 1) / len(fraud_score_df) * 100
fraud_score_df['pct_frauds_captured'] = fraud_score_df['cumulative_frauds'] / total_frauds * 100

capture_ax.plot(
    fraud_score_df['pct_transactions'],
    fraud_score_df['pct_frauds_captured'],
    linewidth=2,
    color='steelblue',
)
# Diagonal reference: expected capture when reviewing in random order
capture_ax.plot([0, 100], [0, 100], 'r--', alpha=0.5, label='Random')
capture_ax.set_title('Cumulative Fraud Capture Curve', fontsize=14, fontweight='bold')
capture_ax.set_xlabel('% of Transactions Reviewed', fontsize=12)
capture_ax.set_ylabel('% of Frauds Captured', fontsize=12)
capture_ax.legend()
capture_ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('../outputs/fraud_score_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

## 8. Production Recommendations

In [None]:
# Print the recommendations report. Every entry is one output line; only the
# "Recommended" threshold line depends on a computed value.
recommendation_lines = [
    "\n" + "="*100,
    "PRODUCTION RECOMMENDATIONS",
    "="*100,

    "\n1. THRESHOLD SELECTION:",
    "   - For high-value transactions: Use threshold ≥ 0.9 (high precision)",
    "   - For automated blocking: Use threshold ≥ 0.7 (balanced)",
    "   - For manual review queue: Use threshold ≥ 0.3 (high recall)",
    f"   - Recommended: {optimal_threshold_f1:.3f} (maximizes F1 score)",

    "\n2. RANKING APPROACH:",
    "   - Use fraud_score for ranking (don't just use binary threshold)",
    "   - Review top K transactions daily based on investigation capacity",
    "   - Precision@100 shows quality of top predictions",

    "\n3. MONITORING:",
    "   - Track Precision@K and Recall@K over time",
    "   - Monitor fraud score distribution for drift",
    "   - Retrain model quarterly or when performance degrades",

    "\n4. EXPLAINABILITY:",
    "   - Use SHAP values to explain flagged transactions",
    "   - Key fraud indicators: balance errors, transaction type, amount ratios",
    "   - Provide explanations to fraud investigators",

    "\n5. MODEL LIMITATIONS:",
    "   - Model trained on historical patterns (may miss new fraud types)",
    "   - Balance errors are strong signal but not present in all frauds",
    "   - Consider ensemble with rule-based systems",
    "   - Regular retraining needed as fraud patterns evolve",

    "\n" + "="*100,
]

for report_line in recommendation_lines:
    print(report_line)

## 9. Summary Statistics

In [None]:
# Final headline metrics for the test set.
from sklearn.metrics import roc_auc_score, average_precision_score

roc_auc = roc_auc_score(y_test, y_pred_proba)
pr_auc = average_precision_score(y_test, y_pred_proba)
prec_100 = precision_at_k(y_test, y_pred_proba, 100)
rec_1000 = recall_at_k(y_test, y_pred_proba, 1000)

# Convert Precision@100 into a count of frauds among the top 100. Round rather
# than truncate: a product such as 0.29 * 100 evaluates to 28.999999999999996,
# which int() alone would report as 28.
frauds_in_top_100 = int(round(prec_100 * 100))

print("\n" + "="*100)
print("FINAL MODEL PERFORMANCE SUMMARY")
print("="*100)
print("\nDataset:")
print(f"  Test transactions: {len(y_test):,}")
print(f"  Actual frauds: {y_test.sum():,} ({y_test.mean()*100:.4f}%)")
print("\nModel Performance:")
print(f"  ROC-AUC:          {roc_auc:.4f}")
print(f"  PR-AUC:           {pr_auc:.4f}")
print(f"  Precision@100:    {prec_100:.4f}")
print(f"  Recall@1000:      {rec_1000:.4f}")
print("\nBusiness Impact:")
print(f"  By reviewing top 100 transactions, we catch {frauds_in_top_100} frauds")
print(f"  By reviewing top 1000 transactions, we catch {rec_1000*100:.1f}% of all frauds")
print("\n" + "="*100)

In [None]:
# Closing banner listing every artifact this notebook writes to ../outputs/.
# The list includes the Precision@K figure and the local SHAP explanation,
# which are saved by earlier cells but were missing from this summary.
saved_outputs = [
    "fraud_score.csv",
    "threshold_optimization.png",
    "precision_recall_at_k_detailed.png",
    "shap_global_importance.png",
    "shap_feature_impact.png",
    "shap_local_explanation_top_fraud.png (or shap_local_explanation_sample.png)",
    "fraud_score_analysis.png",
]

print("\n" + "="*70)
print("THRESHOLD TUNING AND EXPLAINABILITY COMPLETE!")
print("="*70)
print("\nAll outputs saved to ../outputs/")
for output_name in saved_outputs:
    print(f"  - {output_name}")
print("\nProject complete! Review README.md for full documentation.")