# Task 3: Model Explainability with SHAP

This notebook implements Task 3: Interpreting model predictions using SHAP to understand what drives fraud detection.

## Objectives:
1. Extract built-in feature importance from the ensemble model
2. Generate SHAP Summary Plot (global feature importance)
3. Generate SHAP Force Plots for individual predictions (TP, FP, FN)
4. Compare SHAP importance with built-in feature importance
5. Identify top 5 drivers of fraud predictions
6. Provide business recommendations based on SHAP insights


In [None]:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sys
import os
import warnings
import shap
# Suppress all warnings for the rest of the session to keep cell output readable.
warnings.filterwarnings('ignore')

# Make the project's src/ directory importable. The path is derived from the
# parent of the current working directory, so the kernel must be started from
# a directory that sits next to src/.
project_root = os.path.dirname(os.getcwd())
src_dir = os.path.join(project_root, 'src')
sys.path.append(src_dir)

# Project-local helpers (resolvable only after the sys.path tweak above).
from modeling import load_model
from shap_explainability import (
    create_shap_explainer,
    get_feature_importance,
    get_top_fraud_drivers,
    plot_feature_importance,
    plot_shap_force_plot,
    plot_shap_summary,
)

# Plot defaults shared by every figure in this notebook.
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (12, 6)})

# Load SHAP's JavaScript so interactive plots can render in the notebook.
shap.initjs()

print("Libraries imported successfully!")


## Part 1: Fraud_Data.csv Model Explainability


In [None]:
# Load best model and test data
import glob

print("=" * 60)
print("Loading Best Model for Fraud_Data")
print("=" * 60)

# Load test data. Squeezing along the column axis only guarantees a Series for
# the labels even when the file holds a single row (a bare .squeeze() would
# collapse that case to a scalar and break .value_counts() below).
X_test_fraud = pd.read_csv('../data/processed/fraud_X_test.csv')
y_test_fraud = pd.read_csv('../data/processed/fraud_y_test.csv').squeeze("columns")

# Locate the saved best model. glob returns matches in file-system order, so
# sort them to make the choice reproducible across machines and re-runs.
model_files = sorted(glob.glob('../models/best_model_fraud_*.pkl'))
if model_files:
    if len(model_files) > 1:
        print(f"Found {len(model_files)} candidate model files; using the first in sorted order.")
    best_model_fraud = load_model(model_files[0])
    print(f"Loaded model: {model_files[0]}")
else:
    print("Model file not found. Please run modeling.ipynb first.")
    # Downstream cells check for None and skip their analysis.
    best_model_fraud = None

print(f"Test data shape: {X_test_fraud.shape}")
print("Test class distribution:")
print(y_test_fraud.value_counts())


### 1. Feature Importance Baseline


In [None]:
# Built-in (model-native) feature importance, used as the baseline that the
# SHAP ranking is compared against later in the notebook.
if best_model_fraud is None:
    print("Please load a trained model first.")
else:
    fraud_feature_names = X_test_fraud.columns.tolist()
    feature_importance_fraud = get_feature_importance(best_model_fraud, fraud_feature_names)

    print("Top 10 Most Important Features (Built-in):")
    print(feature_importance_fraud.head(10))

    # Bar chart of the same top-10 ranking.
    plot_feature_importance(
        feature_importance_fraud,
        top_n=10,
        title="Top 10 Feature Importance - Fraud_Data Model",
    )


### 2. SHAP Analysis


In [None]:
# Create SHAP explainer and compute SHAP values for a sample of the test set
if best_model_fraud is not None:
    # Sample data for SHAP (use subset for efficiency); fixed seed for reproducibility
    sample_size = min(1000, len(X_test_fraud))
    X_sample_fraud = X_test_fraud.sample(n=sample_size, random_state=42)

    # Models exposing feature_importances_ are treated as tree models,
    # everything else as linear.
    model_type = 'tree' if hasattr(best_model_fraud, 'feature_importances_') else 'linear'

    print(f"Creating SHAP explainer (model_type: {model_type})...")
    explainer_fraud = create_shap_explainer(best_model_fraud, X_sample_fraud, model_type=model_type)

    # Calculate SHAP values
    print("Calculating SHAP values (this may take a while)...")
    shap_values_fraud = explainer_fraud.shap_values(X_sample_fraud)

    # Reduce to the positive (fraud) class for binary classification.
    # Older SHAP releases return a list with one array per class; newer
    # releases return a single (n_samples, n_features, n_outputs) array.
    if isinstance(shap_values_fraud, list):
        shap_values_fraud = shap_values_fraud[1]  # Use positive class
    shap_values_fraud = np.asarray(shap_values_fraud)
    if shap_values_fraud.ndim == 3:
        shap_values_fraud = shap_values_fraud[:, :, 1]  # Use positive class

    print(f"SHAP values shape: {shap_values_fraud.shape}")
else:
    print("Please load a trained model first.")


In [None]:
# Global view: SHAP summary plot over the sampled test rows.
shap_inputs_ready = best_model_fraud is not None and 'shap_values_fraud' in locals()
if not shap_inputs_ready:
    print("Please run previous cells first.")
else:
    banner = "=" * 60
    print(banner)
    print("SHAP Summary Plot - Global Feature Importance")
    print(banner)

    # Show at most the 20 highest-ranked features.
    plot_shap_summary(shap_values_fraud, X_sample_fraud, max_display=20)


In [None]:
# Score the full test set and bucket rows into TP / FP / FN so that one
# example of each can be explained with a force plot.
if best_model_fraud is None:
    print("Please load a trained model first.")
else:
    y_pred_fraud = best_model_fraud.predict(X_test_fraud)
    y_pred_proba_fraud = best_model_fraud.predict_proba(X_test_fraud)[:, 1]

    # Boolean masks over the test rows.
    actual_fraud = y_test_fraud == 1
    actual_legit = y_test_fraud == 0
    flagged_fraud = y_pred_fraud == 1
    flagged_legit = y_pred_fraud == 0

    # Positional row indices for each outcome type.
    tp_indices = np.where(actual_fraud & flagged_fraud)[0]  # correctly identified fraud
    fp_indices = np.where(actual_legit & flagged_fraud)[0]  # legitimate flagged as fraud
    fn_indices = np.where(actual_fraud & flagged_legit)[0]  # missed fraud

    outcome_counts = (
        ("True Positives", tp_indices),
        ("False Positives", fp_indices),
        ("False Negatives", fn_indices),
    )
    for outcome_label, outcome_rows in outcome_counts:
        print(f"{outcome_label}: {len(outcome_rows)}")

    # Pick the first row of each non-empty bucket as the example to explain.
    if tp_indices.size > 0:
        tp_idx = tp_indices[0]
        print(f"\nTrue Positive example at index {tp_idx}")
    if fp_indices.size > 0:
        fp_idx = fp_indices[0]
        print(f"False Positive example at index {fp_idx}")
    if fn_indices.size > 0:
        fn_idx = fn_indices[0]
        print(f"False Negative example at index {fn_idx}")


### 3. SHAP Force Plots for Individual Predictions


In [None]:
# SHAP Force Plot for True Positive (Correctly identified fraud)
if best_model_fraud is not None and len(tp_indices) > 0:
    print("=" * 60)
    print("SHAP Force Plot: True Positive (Correctly Identified Fraud)")
    print("=" * 60)

    # Keep the row as a one-row DataFrame so the explainer sees 2-D input.
    tp_instance = X_test_fraud.iloc[[tp_indices[0]]]

    # The explainer's base value may be a scalar, a list, or an array with one
    # entry per class. Use the positive-class entry when there is more than
    # one, otherwise the single value (indexing [1] on a length-1 array fails).
    expected_value = np.atleast_1d(explainer_fraud.expected_value).ravel()
    expected_value = float(expected_value[1] if expected_value.size > 1 else expected_value[0])

    # SHAP values come back either as a list with one array per class (older
    # SHAP) or as one (n_samples, n_features, n_outputs) array (newer SHAP);
    # reduce both to the positive class.
    shap_values_tp = explainer_fraud.shap_values(tp_instance)
    if isinstance(shap_values_tp, list):
        shap_values_tp = shap_values_tp[1]
    shap_values_tp = np.asarray(shap_values_tp)
    if shap_values_tp.ndim == 3:
        shap_values_tp = shap_values_tp[:, :, 1]

    shap.force_plot(
        expected_value,
        shap_values_tp[0],
        tp_instance.iloc[0],
        matplotlib=True,
        show=False
    )
    plt.title("True Positive: Correctly Identified Fraud")
    plt.tight_layout()
    plt.show()

    print("Actual: Fraud (1), Predicted: Fraud (1)")
    print(f"Prediction probability: {y_pred_proba_fraud[tp_indices[0]]:.4f}")
else:
    print("No True Positive examples found or model not loaded.")


In [None]:
# SHAP Force Plot for False Positive (Legitimate flagged as fraud)
if best_model_fraud is not None and len(fp_indices) > 0:
    print("=" * 60)
    print("SHAP Force Plot: False Positive (Legitimate Flagged as Fraud)")
    print("=" * 60)

    # Keep the row as a one-row DataFrame so the explainer sees 2-D input.
    fp_instance = X_test_fraud.iloc[[fp_indices[0]]]

    # The explainer's base value may be a scalar, a list, or an array with one
    # entry per class. Use the positive-class entry when there is more than
    # one, otherwise the single value (indexing [1] on a length-1 array fails).
    expected_value = np.atleast_1d(explainer_fraud.expected_value).ravel()
    expected_value = float(expected_value[1] if expected_value.size > 1 else expected_value[0])

    # SHAP values come back either as a list with one array per class (older
    # SHAP) or as one (n_samples, n_features, n_outputs) array (newer SHAP);
    # reduce both to the positive class.
    shap_values_fp = explainer_fraud.shap_values(fp_instance)
    if isinstance(shap_values_fp, list):
        shap_values_fp = shap_values_fp[1]
    shap_values_fp = np.asarray(shap_values_fp)
    if shap_values_fp.ndim == 3:
        shap_values_fp = shap_values_fp[:, :, 1]

    shap.force_plot(
        expected_value,
        shap_values_fp[0],
        fp_instance.iloc[0],
        matplotlib=True,
        show=False
    )
    plt.title("False Positive: Legitimate Transaction Flagged as Fraud")
    plt.tight_layout()
    plt.show()

    print("Actual: Legitimate (0), Predicted: Fraud (1)")
    print(f"Prediction probability: {y_pred_proba_fraud[fp_indices[0]]:.4f}")
else:
    print("No False Positive examples found or model not loaded.")


In [None]:
# SHAP Force Plot for False Negative (Missed fraud)
if best_model_fraud is not None and len(fn_indices) > 0:
    print("=" * 60)
    print("SHAP Force Plot: False Negative (Missed Fraud)")
    print("=" * 60)

    # Keep the row as a one-row DataFrame so the explainer sees 2-D input.
    fn_instance = X_test_fraud.iloc[[fn_indices[0]]]

    # The explainer's base value may be a scalar, a list, or an array with one
    # entry per class. Use the positive-class entry when there is more than
    # one, otherwise the single value (indexing [1] on a length-1 array fails).
    expected_value = np.atleast_1d(explainer_fraud.expected_value).ravel()
    expected_value = float(expected_value[1] if expected_value.size > 1 else expected_value[0])

    # SHAP values come back either as a list with one array per class (older
    # SHAP) or as one (n_samples, n_features, n_outputs) array (newer SHAP);
    # reduce both to the positive class.
    shap_values_fn = explainer_fraud.shap_values(fn_instance)
    if isinstance(shap_values_fn, list):
        shap_values_fn = shap_values_fn[1]
    shap_values_fn = np.asarray(shap_values_fn)
    if shap_values_fn.ndim == 3:
        shap_values_fn = shap_values_fn[:, :, 1]

    shap.force_plot(
        expected_value,
        shap_values_fn[0],
        fn_instance.iloc[0],
        matplotlib=True,
        show=False
    )
    plt.title("False Negative: Fraud Transaction Missed")
    plt.tight_layout()
    plt.show()

    print("Actual: Fraud (1), Predicted: Legitimate (0)")
    print(f"Prediction probability: {y_pred_proba_fraud[fn_indices[0]]:.4f}")
else:
    print("No False Negative examples found or model not loaded.")


### 4. Interpretation and Analysis


In [None]:
# Compare SHAP importance with built-in feature importance
if best_model_fraud is not None and 'shap_values_fraud' in locals():
    print("=" * 60)
    print("SHAP vs Built-in Feature Importance Comparison")
    print("=" * 60)

    # Get SHAP-based importance
    top_shap_drivers = get_top_fraud_drivers(
        shap_values_fraud,
        X_sample_fraud.columns.tolist(),
        top_n=10
    )

    print("\nTop 10 Features by SHAP Importance:")
    for i, (feature, importance) in enumerate(top_shap_drivers, 1):
        print(f"{i}. {feature}: {importance:.4f}")

    print("\nTop 10 Features by Built-in Importance:")
    # Number rows by their position in the ranking, not by the DataFrame index
    # label: the label only equals the rank if the frame was re-indexed after
    # sorting, and it need not even be an integer.
    for rank, (_, row) in enumerate(feature_importance_fraud.head(10).iterrows(), 1):
        print(f"{rank}. {row['feature']}: {row['importance']:.4f}")

    # Identify top 5 drivers (kept as a separate variable: the business
    # recommendations cell reads `top_5_drivers`)
    print("\n" + "=" * 60)
    print("TOP 5 DRIVERS OF FRAUD PREDICTIONS")
    print("=" * 60)
    top_5_drivers = get_top_fraud_drivers(
        shap_values_fraud,
        X_sample_fraud.columns.tolist(),
        top_n=5
    )

    for i, (feature, importance) in enumerate(top_5_drivers, 1):
        print(f"{i}. {feature}: {importance:.4f}")
else:
    print("Please run previous cells first.")


### 5. Business Recommendations


In [None]:
# Generate business recommendations based on SHAP insights
print("=" * 60)
print("BUSINESS RECOMMENDATIONS - Fraud_Data")
print("=" * 60)

if best_model_fraud is not None and 'top_5_drivers' in locals():
    recommendations = []

    # Map each top driver to a recommendation by keyword-matching its name.
    # The texts carry no numbering of their own: numbering is added once, at
    # print time, so it always matches the order shown to the reader.
    for feature, importance in top_5_drivers:
        feature_key = feature.lower()
        if 'time_since_signup' in feature_key:
            # NOTE(review): "X hours" is a placeholder threshold - replace it
            # once a concrete cutoff has been chosen.
            recommendations.append(
                f"Transactions within X hours of signup should receive additional verification. "
                f"SHAP analysis shows '{feature}' is a top driver of fraud predictions."
            )
        elif 'transaction' in feature_key and 'velocity' in feature_key:
            recommendations.append(
                f"High transaction velocity (multiple transactions in short time) should trigger "
                f"fraud alerts. Feature '{feature}' is highly predictive of fraud."
            )
        elif 'country' in feature_key:
            recommendations.append(
                f"Transactions from high-risk countries (identified by '{feature}') should "
                f"undergo enhanced screening procedures."
            )
        elif 'purchase_value' in feature_key:
            recommendations.append(
                f"Unusual purchase value patterns (captured by '{feature}') should be flagged "
                f"for manual review."
            )
        else:
            # Generic fallback for drivers without a dedicated rule.
            recommendations.append(
                f"Feature '{feature}' is a strong fraud indicator. Consider implementing "
                f"business rules based on this feature's values."
            )

    # Print recommendations (one per top driver, in SHAP-importance order)
    for i, rec in enumerate(recommendations, 1):
        print(f"\nRecommendation {i}:")
        print(rec)

    print("\n" + "=" * 60)
    print("Additional Insights:")
    print("=" * 60)
    print("- Monitor false positives to understand legitimate transaction patterns")
    print("- Investigate false negatives to identify fraud patterns the model misses")
    print("- Regularly update model with new fraud patterns")
    print("- Combine model predictions with business rules for optimal performance")
else:
    print("Please run previous analysis cells first.")


## Part 2: creditcard.csv Model Explainability

A similar analysis can be performed for the creditcard dataset by repeating the steps of Part 1 with that dataset's trained model and test data.
