# Model Explainability with SHAP

This notebook demonstrates how to interpret the fraud detection model using SHAP (SHapley Additive exPlanations).

In [None]:
import os
import sys
sys.path.append('../src')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import shap
import joblib

from features import load_and_prepare_data, get_feature_names
from explain import (
    create_shap_explainer,
    plot_shap_summary,
    plot_shap_waterfall,
    explain_prediction
)

# Initialize SHAP's JavaScript visualization
# NOTE(review): the force plots in this notebook pass matplotlib=True, so this
# JS initialization may not be needed for them — confirm before removing.
shap.initjs()

# Notebook-wide matplotlib style (the 'seaborn-v0_8-*' names require matplotlib >= 3.6)
plt.style.use('seaborn-v0_8-whitegrid')
# Render figures inline in the notebook output
%matplotlib inline

## 1. Load Model and Data

In [None]:
# Load the trained model and the fitted preprocessor.
# joblib.load unpickles its input — only load model files you trust.
print("Loading model and preprocessor...")
model = joblib.load('../models/fraud_model.pkl')
preprocessor = joblib.load('../models/preprocessor.pkl')
print("✓ Model and preprocessor loaded")

# Load test data; only the test split is used in this notebook.
print("\nLoading test data...")
_, X_test, _, y_test = load_and_prepare_data('../data/raw/transactions.csv')
X_test_processed = preprocessor.transform(X_test)
print(f"✓ Test set: {X_test_processed.shape}")

# Names of the *transformed* columns. Every SHAP plot below labels columns of
# X_test_processed with these names, so they must line up 1:1.
feature_names = get_feature_names(preprocessor)
print(f"✓ Features: {len(feature_names)}")

# Fail fast on misalignment instead of producing silently mislabelled plots.
if len(feature_names) != X_test_processed.shape[1]:
    raise ValueError(
        f"{len(feature_names)} feature names for "
        f"{X_test_processed.shape[1]} transformed columns"
    )
if len(y_test) != X_test_processed.shape[0]:
    raise ValueError(
        f"y_test has {len(y_test)} rows but X_test_processed has "
        f"{X_test_processed.shape[0]}"
    )

## 2. Create SHAP Explainer

In [None]:
# Build a TreeExplainer for the trained model, then compute SHAP values on a
# bounded leading slice of the processed test set to keep later cells fast.
print("Creating SHAP explainer...")
print("Note: This may take a few minutes for tree-based models")

explainer = shap.TreeExplainer(model)
print("✓ SHAP explainer created")

# At most 1000 rows; a smaller test set is used in full.
MAX_SHAP_ROWS = 1000
n_test_rows = len(X_test)
sample_size = n_test_rows if n_test_rows < MAX_SHAP_ROWS else MAX_SHAP_ROWS
X_sample = X_test_processed[:sample_size]

print(f"\nCalculating SHAP values for {sample_size} samples...")
shap_values = explainer.shap_values(X_sample)
print("✓ SHAP values calculated")

## 3. Global Feature Importance - Summary Plot

In [None]:
# SHAP summary (beeswarm) plot: global importance plus direction of each effect.
# summary_plot resizes the current figure itself, so the desired size is passed
# through plot_size — a figsize given to plt.figure() would be overridden.
print("Generating SHAP summary plot...")
plt.figure()
shap.summary_plot(
    shap_values,
    X_sample,
    feature_names=feature_names,
    max_display=15,
    plot_size=(10, 8),
    show=False
)
plt.title('SHAP Feature Importance Summary', fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
# The reports directory may not exist on a fresh checkout; savefig would raise.
os.makedirs('../reports', exist_ok=True)
plt.savefig('../reports/shap_summary.png', dpi=300, bbox_inches='tight')
plt.show()
print("✓ SHAP summary plot saved")

### Understanding the Summary Plot

The SHAP summary plot shows:
- **Y-axis**: Features ordered by importance (top = most important)
- **X-axis**: SHAP value (impact on model output)
  - Positive values push prediction toward fraud (right)
  - Negative values push prediction toward normal (left)
- **Color**: Feature value (red = high, blue = low)

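**Key Insights** (observed with the current trained model; re-check after retraining):
- High transaction amounts (red dots on the right) push predictions toward fraud
- Country mismatches are a strong fraud signal
- Online channels carry higher fraud risk than point-of-sale (POS) transactions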

In [None]:
# Bar plot of mean absolute SHAP value per feature (global importance, no direction).
# summary_plot resizes the current figure itself, so the size goes through plot_size.
plt.figure()
shap.summary_plot(
    shap_values,
    X_sample,
    feature_names=feature_names,
    plot_type='bar',
    max_display=15,
    plot_size=(10, 8),
    show=False
)
plt.title('Mean Absolute SHAP Values (Global Feature Importance)',
          fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 4. Dependence Plots - Feature Interactions

In [None]:
# Dependence plot for transaction amount: SHAP value of amount_usd versus its
# value, coloured by an automatically chosen interaction feature.
# dependence_plot opens its own figure unless an Axes is passed, so create one
# explicitly (a bare plt.figure() here would show up as an extra empty figure).
# The preprocessor may rename columns, so guard like the country-mismatch cell.
if 'amount_usd' in feature_names:
    fig, ax = plt.subplots(figsize=(10, 6))
    shap.dependence_plot(
        'amount_usd',
        shap_values,
        X_sample,
        feature_names=feature_names,
        ax=ax,
        show=False
    )
    ax.set_title('SHAP Dependence: Transaction Amount', fontsize=14, fontweight='bold')
    fig.tight_layout()
    plt.show()

    print("Interpretation: Shows how transaction amount affects fraud prediction")
    print("Color indicates interaction with another feature")
else:
    print("'amount_usd' not in feature_names; skipping dependence plot")

In [None]:
# Dependence plot for country mismatch (only if the preprocessor kept that name).
# Pass an explicit Axes: without one, dependence_plot opens its own figure and a
# preceding plt.figure(...) would render as an extra empty figure.
if 'country_mismatch' in feature_names:
    fig, ax = plt.subplots(figsize=(10, 6))
    shap.dependence_plot(
        'country_mismatch',
        shap_values,
        X_sample,
        feature_names=feature_names,
        ax=ax,
        show=False
    )
    ax.set_title('SHAP Dependence: Country Mismatch', fontsize=14, fontweight='bold')
    fig.tight_layout()
    plt.show()
else:
    print("'country_mismatch' not in feature_names; skipping dependence plot")

## 5. Local Explanations - Individual Predictions

### Example 1: High-Risk Fraudulent Transaction

In [None]:
# Score the whole test set once; y_pred_proba is reused by the cells below.
# NOTE(review): column 1 is treated as the fraud class — assumes model.classes_ == [0, 1].
y_pred_proba = model.predict_proba(X_test_processed)[:, 1]

# Among truly fraudulent transactions, pick the one with the highest fraud score.
fraud_indices = np.where(y_test == 1)[0]
# np.argmax on an empty selection raises an opaque error; fail with a clear one.
if fraud_indices.size == 0:
    raise ValueError("Test set contains no fraudulent transactions (y_test == 1)")
high_fraud_idx = fraud_indices[np.argmax(y_pred_proba[fraud_indices])]

print(f"Analyzing transaction at index {high_fraud_idx}")
print(f"True label: {y_test.iloc[high_fraud_idx]} (Fraud)")
print(f"Predicted probability: {y_pred_proba[high_fraud_idx]:.4f}")
print("\nTransaction details:")
print(X_test.iloc[high_fraud_idx])

In [None]:
# Force plot for the selected fraud case: how each feature moves the model
# output away from the explainer's expected (base) value.
fraud_row = slice(high_fraud_idx, high_fraud_idx + 1)  # keeps the 2-D (1, n) shape
x_fraud_row = X_test_processed[fraud_row]
shap_values_single = explainer.shap_values(x_fraud_row)

shap.force_plot(
    explainer.expected_value,
    shap_values_single,
    x_fraud_row,
    feature_names=feature_names,
    matplotlib=True,
    show=False
)
fraud_p = y_pred_proba[high_fraud_idx]
plt.title(f'SHAP Force Plot - Fraud Transaction (P={fraud_p:.3f})',
          fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

for line in (
    "\nInterpretation:",
    "- Red features push prediction toward fraud (increase probability)",
    "- Blue features push prediction toward normal (decrease probability)",
    "- Arrow width indicates feature impact strength",
):
    print(line)

In [None]:
# Waterfall view of the same fraud case: per-feature contributions from the
# explainer's base value to the model output (at most 15 features listed).
fraud_explanation = shap.Explanation(
    values=shap_values_single[0],
    base_values=explainer.expected_value,
    data=X_test_processed[high_fraud_idx],
    feature_names=feature_names,
)

plt.figure(figsize=(10, 8))
shap.waterfall_plot(fraud_explanation, max_display=15, show=False)
plt.title('Waterfall Plot - Feature Contributions', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

### Example 2: Normal Transaction Correctly Classified

In [None]:
# Among truly normal transactions, pick the one with the lowest fraud score
# (y_pred_proba is computed in the high-fraud cell above).
normal_indices = np.where(y_test == 0)[0]
# np.argmin on an empty selection raises an opaque error; fail with a clear one.
if normal_indices.size == 0:
    raise ValueError("Test set contains no normal transactions (y_test == 0)")
low_fraud_idx = normal_indices[np.argmin(y_pred_proba[normal_indices])]

print(f"Analyzing transaction at index {low_fraud_idx}")
print(f"True label: {y_test.iloc[low_fraud_idx]} (Normal)")
print(f"Predicted probability: {y_pred_proba[low_fraud_idx]:.4f}")
print("\nTransaction details:")
print(X_test.iloc[low_fraud_idx])

In [None]:
# Force plot for the selected normal case, mirroring the fraud force plot above.
normal_row = slice(low_fraud_idx, low_fraud_idx + 1)  # keeps the 2-D (1, n) shape
x_normal_row = X_test_processed[normal_row]
shap_values_normal = explainer.shap_values(x_normal_row)

shap.force_plot(
    explainer.expected_value,
    shap_values_normal,
    x_normal_row,
    feature_names=feature_names,
    matplotlib=True,
    show=False
)
normal_p = y_pred_proba[low_fraud_idx]
plt.title(f'SHAP Force Plot - Normal Transaction (P={normal_p:.3f})',
          fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

### Example 3: False Positive Analysis

In [None]:
# False-positive analysis: a normal transaction (y_test == 0) that the model
# scores above the decision threshold. The most confidently mis-scored case
# (highest fraud probability) is shown, so the choice does not depend on the
# row order of the test set.
FP_THRESHOLD = 0.5
false_positive_mask = (np.asarray(y_test) == 0) & (y_pred_proba > FP_THRESHOLD)
if false_positive_mask.any():
    false_positive_indices = np.flatnonzero(false_positive_mask)
    fp_idx = false_positive_indices[np.argmax(y_pred_proba[false_positive_indices])]

    print(f"Analyzing FALSE POSITIVE at index {fp_idx}")
    print(f"True label: {y_test.iloc[fp_idx]} (Normal)")
    print(f"Predicted probability: {y_pred_proba[fp_idx]:.4f}")
    print(f"False positives at this threshold: {false_positive_indices.size}")
    print("\nTransaction details:")
    print(X_test.iloc[fp_idx])

    # SHAP explanation for the selected row
    shap_values_fp = explainer.shap_values(X_test_processed[fp_idx:fp_idx+1])

    plt.figure(figsize=(10, 8))
    shap.waterfall_plot(
        shap.Explanation(
            values=shap_values_fp[0],
            base_values=explainer.expected_value,
            data=X_test_processed[fp_idx],
            feature_names=feature_names
        ),
        max_display=15,
        show=False
    )
    plt.title('False Positive - Why was this flagged as fraud?',
              fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

    print("\nInsight: These features made the model suspicious despite being normal")
else:
    print(f"No false positives found at {FP_THRESHOLD} threshold")

## 6. Feature Interaction Analysis

In [None]:
# SHAP interaction values (computationally intensive), on a small leading sample.
# Disabled by default so Restart & Run All stays fast; set RUN_INTERACTIONS = True
# to compute and plot them.
RUN_INTERACTIONS = False
small_sample_size = 100
X_small = X_test_processed[:small_sample_size]

if RUN_INTERACTIONS:
    print(f"Calculating SHAP interaction values for {X_small.shape[0]} samples...")
    print("Note: This may take several minutes")
    shap_interaction_values = explainer.shap_interaction_values(X_small)
    shap.summary_plot(shap_interaction_values, X_small, feature_names=feature_names)
else:
    print("Interaction analysis skipped for performance "
          "(set RUN_INTERACTIONS = True to enable)")

## 7. Decision Plot - Multiple Predictions

In [None]:
# Decision plot for a reproducible random subset of test transactions.
# A seeded Generator makes re-runs show the same rows; the subset is capped
# at the test-set size so small test sets do not raise.
DECISION_SAMPLE_SIZE = 20
DECISION_SEED = 42
decision_rng = np.random.default_rng(DECISION_SEED)
n_test_rows = X_test_processed.shape[0]
n_decision = min(DECISION_SAMPLE_SIZE, n_test_rows)
sample_indices = decision_rng.choice(n_test_rows, size=n_decision, replace=False)
X_decision_sample = X_test_processed[sample_indices]
shap_values_decision = explainer.shap_values(X_decision_sample)

plt.figure(figsize=(12, 8))
shap.decision_plot(
    explainer.expected_value,
    shap_values_decision,
    X_decision_sample,
    feature_names=feature_names,
    feature_display_range=slice(-1, -16, -1),  # Top 15 features
    auto_size_plot=False,  # keep the 12x8 figure created above
    show=False
)
plt.title(f'SHAP Decision Plot - {n_decision} Random Transactions',
          fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("- Each line represents one transaction")
print("- Lines start at base value and show cumulative feature effects")
# TreeExplainer explains the model's raw output, which is log-odds for
# gradient-boosted models and probability for some others.
print("- Final position is the model output for that transaction "
      "(log-odds or probability, depending on the model)")

## 8. Export Explanations for Dashboard

In [None]:
# Persist the explainer plus a small slice of precomputed SHAP values and the
# matching processed rows for the Streamlit dashboard.
# Anything reading this file back with pickle executes its contents — load trusted files only.
import pickle

N_EXPORT = 100  # first 100 samples
shap_data = dict(
    explainer=explainer,
    sample_shap_values=shap_values[:N_EXPORT],
    sample_data=X_sample[:N_EXPORT],
    feature_names=feature_names,
)

export_path = '../models/shap_explainer.pkl'
with open(export_path, 'wb') as out_file:
    pickle.dump(shap_data, out_file)

print("✓ SHAP explainer saved to ../models/shap_explainer.pkl")
print("This can be loaded in the Streamlit dashboard for interactive explanations")

## Summary: Key Findings

**Most Important Fraud Indicators:**
1. **Transaction Amount**: Higher amounts push predictions strongly toward fraud
2. **Country Mismatch**: Transactions flagged with a country mismatch are markedly riskier
3. **Channel**: Online transactions have higher fraud rates
4. **Card Not Present**: Card-not-present (CNP) transactions are riskier
5. **Customer Behavior**: Deviations from a customer's normal spending patterns raise risk

**Model Transparency:**
- SHAP provides clear explanations for every prediction
- Analysts can see exactly which features drove each decision
- Helps identify model biases and areas for improvement
- Enables regulatory compliance and audit trails

**Business Value:**
- Fraud investigators can prioritize cases by model score and see which features drove each one
- Model decisions are interpretable and defensible
- Helps identify new fraud patterns as they emerge