In [1]:
### Notebook 5: Model Interpretation & Explainability
### Project: Churn Prevention System
### This notebook explains model predictions using SHAP values

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import shap
import warnings
warnings.filterwarnings('ignore')

print("=" * 70)
print("MODEL INTERPRETATION & EXPLAINABILITY WITH SHAP")
print("=" * 70)

### =============================================================================
### PART 1: LOAD MODEL AND DATA
### =============================================================================

print("\n" + "=" * 70)
print("PART 1: LOADING TRAINED MODEL & DATA")
print("=" * 70)

### Load trained model
model = joblib.load(r'../ML Models/churn_model.pkl')
scaler = joblib.load(r'../ML Models/scaler.pkl')
model_info = joblib.load(r'../ML Models/model_info.pkl')

print(f"✅ Loaded model: {model_info['model_type']}")
print(f"   AUC Score: {model_info['auc_score']:.4f}")

### Load engineered data
df = pd.read_csv('../Datasets/customer_churn_engineered.csv')
print(f"✅ Loaded dataset: {len(df):,} customers")

### Prepare features
feature_cols = model_info['feature_columns']
X = df[feature_cols]
y = df['churned']

print(f"✅ Features: {len(feature_cols)}")

### =============================================================================
### PART 2: SHAP EXPLAINER SETUP
### =============================================================================

print("\n" + "=" * 70)
print("PART 2: INITIALIZING SHAP EXPLAINER")
print("=" * 70)

print("\nCreating SHAP explainer...")
print("(This may take 1-2 minutes for large datasets)")

### Use a sample for faster computation
sample_size = min(1000, len(X))
X_sample = X.sample(n=sample_size, random_state=42)
y_sample = y[X_sample.index]

### Create SHAP explainer
explainer = shap.LinearExplainer(model, X_sample)
shap_values = explainer.shap_values(X_sample)


### Handle binary classification (some models return list)
if isinstance(shap_values, list):
    shap_values = shap_values[1]  # Churn class

print(f"✅ SHAP values calculated for {sample_size:,} samples")
print(f"   Shape: {shap_values.shape}")

### =============================================================================
### PART 3: GLOBAL FEATURE IMPORTANCE
### =============================================================================

print("\n" + "=" * 70)
print("PART 3: GLOBAL FEATURE IMPORTANCE ANALYSIS")
print("=" * 70)

### Calculate mean absolute SHAP values
mean_abs_shap = np.abs(shap_values).mean(axis=0)
feature_importance_shap = pd.DataFrame({
    'feature': feature_cols,
    'importance': mean_abs_shap
}).sort_values('importance', ascending=False)

print("\n🎯 Top 15 Most Important Features (by SHAP):")
print("-" * 70)
for i, row in feature_importance_shap.head(15).iterrows():
    print(f"{i+1:2d}. {row['feature']:35s} {row['importance']:.4f}")

### Compare with model's feature importance (if available)
if hasattr(model, 'feature_importances_'):
    model_importance = pd.DataFrame({
        'feature': feature_cols,
        'model_importance': model.feature_importances_
    }).sort_values('model_importance', ascending=False)
    
    print("\n📊 Comparison: SHAP vs Model Feature Importance")
    print("-" * 70)
    print(f"{'Feature':<35} {'SHAP':>10} {'Model':>10}")
    print("-" * 70)
    
    for feature in feature_importance_shap.head(10)['feature']:
        shap_imp = feature_importance_shap[
            feature_importance_shap['feature']==feature
        ]['importance'].values[0]
        model_imp = model_importance[
            model_importance['feature']==feature
        ]['model_importance'].values[0]
        print(f"{feature:<35} {shap_imp:>10.4f} {model_imp:>10.4f}")

### =============================================================================
### PART 4: SHAP SUMMARY VISUALIZATIONS
### =============================================================================

print("\n" + "=" * 70)
print("PART 4: CREATING SHAP VISUALIZATIONS")
print("=" * 70)

### 1. Summary Plot (Feature Importance + Impact Direction)
print("\nCreating SHAP summary plot...")
plt.figure(figsize=(12, 10))
shap.summary_plot(shap_values, X_sample, feature_names=feature_cols, show=False)
plt.title('SHAP Summary Plot: Feature Impact on Churn Prediction', 
          fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig('../Datasets/shap_summary_plot.png', dpi=300, bbox_inches='tight')
print("✅ Saved: ../data/shap_summary_plot.png")
plt.close()

### 2. Bar Plot (Feature Importance Ranking)
print("\nCreating SHAP importance bar plot...")
plt.figure(figsize=(10, 8))
shap.summary_plot(shap_values, X_sample, feature_names=feature_cols,
                 plot_type="bar", show=False)
plt.title('SHAP Feature Importance Ranking',
          fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig('../Datasets/shap_importance_bar.png', dpi=300, bbox_inches='tight')
print("✅ Saved: ../data/shap_importance_bar.png")
plt.close()

### =============================================================================
### PART 5: INDIVIDUAL PREDICTION EXPLANATIONS
### =============================================================================

print("\n" + "=" * 70)
print("PART 5: INDIVIDUAL PREDICTION EXPLANATIONS")
print("=" * 70)

### Select interesting cases for explanation
print("\nSelecting representative customers for analysis...")

### Case 1: High risk customer
high_risk_idx = df[df['churned']==1].sample(1, random_state=42).index[0]
high_risk_customer = df.loc[high_risk_idx]

### Case 2: Low risk customer
low_risk_idx = df[df['churned']==0].sample(1, random_state=42).index[0]
low_risk_customer = df.loc[low_risk_idx]

### Case 3: Borderline case
subset = df[(df['churned']==1) & (df['health_score'] > df['health_score'].median())]

if subset.empty:
    print("⚠️ No rows found above median health_score — selecting any churned customer instead.")
    subset = df[df['churned']==1]

borderline_idx = subset.sample(1, random_state=42).index[0]
borderline_customer = df.loc[borderline_idx]

print(f"\n📋 Selected Cases:")
print(f"   1. High Risk: {high_risk_customer['customer_id']}")
print(f"   2. Low Risk: {low_risk_customer['customer_id']}")
print(f"   3. Borderline: {borderline_customer['customer_id']}")

### Explain each case
cases = [
    (high_risk_idx, high_risk_customer, "High Risk Customer"),
    (low_risk_idx, low_risk_customer, "Low Risk Customer"),
    (borderline_idx, borderline_customer, "Borderline Case")
]

for case_idx, (idx, customer, title) in enumerate(cases, 1):
    print(f"\n{'-'*70}")
    print(f"Case {case_idx}: {title} ({customer['customer_id']})")
    print(f"{'-'*70}")
    
    ### Get prediction
    X_case = X.loc[idx:idx]
    prediction = model.predict(X_case)[0]
    probability = model.predict_proba(X_case)[0][1]
    
    print(f"\nPrediction: {'CHURN' if prediction == 1 else 'RETAIN'}")
    print(f"Churn Probability: {probability*100:.1f}%")
    print(f"Actual: {'CHURNED' if customer['churned'] == 1 else 'RETAINED'}")
    
    ### Customer profile
    print(f"\nCustomer Profile:")
    print(f"  • Subscription: {customer['subscription_tier']}")
    print(f"  • Monthly Reoccuring Revenue: ${customer['monthly_reoccuring_revenue']:.0f}")
    print(f"  • Tenure: {customer['tenure_days']} days")
    print(f"  • Health Score: {customer['health_score']:.1f}/100")
    print(f"  • Logins (30d): {customer['logins_30d']}")
    print(f"  • Support Tickets: {customer['support_tickets_30d']}")
    
    ### Calculate SHAP for this customer
    shap_case = explainer.shap_values(X_case)
    if isinstance(shap_case, list):
        shap_case = shap_case[1]
    
    ### Create waterfall plot
    plt.figure(figsize=(10, 8))
    shap.waterfall_plot(
        shap.Explanation(
            values=shap_case[0],
            base_values=explainer.expected_value if not isinstance(
                explainer.expected_value, np.ndarray
            ) else explainer.expected_value[1],
            data=X_case.iloc[0].values,
            feature_names=feature_cols
        ),
        show=False
    )
    plt.title(f'{title}: SHAP Waterfall Explanation',
             fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f'../Datasets/shap_waterfall_case{case_idx}.png', dpi=300, bbox_inches='tight')
    print(f"✅ Saved waterfall plot: ../data/shap_waterfall_case{case_idx}.png")
    plt.close()
    
    ### Top contributing factors
    shap_contributions = pd.DataFrame({
        'feature': feature_cols,
        'shap_value': shap_case[0]
    })
    shap_contributions['abs_shap'] = shap_contributions['shap_value'].abs()
    top_factors = shap_contributions.nlargest(5, 'abs_shap')
    
    print(f"\n🔍 Top 5 Factors Influencing Prediction:")
    for i, row in top_factors.iterrows():
        direction = "increases" if row['shap_value'] > 0 else "decreases"
        impact = "🔴" if row['shap_value'] > 0 else "🟢"
        print(f"   {impact} {row['feature']}: {direction} churn risk by {abs(row['shap_value']):.3f}")

### =============================================================================
### PART 6: DEPENDENCE PLOTS
### =============================================================================

print("\n" + "=" * 70)
print("PART 6: SHAP DEPENDENCE PLOTS")
print("=" * 70)

print("\nCreating dependence plots for top features...")

### Select top 6 features
top_features = feature_importance_shap.head(6)['feature'].tolist()

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.ravel()

for idx, feature in enumerate(top_features):
    feature_idx = feature_cols.index(feature)
    
    shap.dependence_plot(
        feature_idx,
        shap_values,
        X_sample,
        feature_names=feature_cols,
        ax=axes[idx],
        show=False
    )
    axes[idx].set_title(f'SHAP Dependence: {feature}', 
                       fontsize=12, fontweight='bold')

plt.tight_layout()
plt.savefig('../Datasets/shap_dependence_plots.png', dpi=300, bbox_inches='tight')
print("✅ Saved: ../Datasets/shap_dependence_plots.png")
plt.close()

### =============================================================================
### PART 7: FORCE PLOTS (HTML)
### =============================================================================

print("\n" + "=" * 70)
print("PART 7: CREATING INTERACTIVE FORCE PLOTS")
print("=" * 70)

print("\nGenerating interactive HTML force plots...")

### Single prediction force plot
idx_sample = X_sample.index[0]
shap_single = explainer.shap_values(X.loc[idx_sample:idx_sample])
if isinstance(shap_single, list):
    shap_single = shap_single[1]

expected_value = (explainer.expected_value if not isinstance(explainer.expected_value, np.ndarray) 
                 else explainer.expected_value[1])

force_plot = shap.force_plot(
    expected_value,
    shap_single[0],
    X.loc[idx_sample],
    feature_names=feature_cols
)

### Save as HTML
shap.save_html('../Datasets/shap_force_plot.html', force_plot)
print("✅ Saved: ../Datasets/shap_force_plot.html")

### Multiple predictions force plot
print("\nGenerating multi-customer force plot...")
force_plot_multi = shap.force_plot(
    expected_value,
    shap_values[:100],
    X_sample.iloc[:100],
    feature_names=feature_cols
)

shap.save_html('../Datasets/shap_force_plot_multi.html', force_plot_multi)
print("✅ Saved: ../Datasets/shap_force_plot_multi.html")

### =============================================================================
### PART 8: INTERACTION EFFECTS
### =============================================================================

print("\n" + "=" * 70)
print("PART 8: FEATURE INTERACTION ANALYSIS")
print("=" * 70)

print("\nAnalyzing feature interactions...")
print("(This may take a few minutes)")

### Calculate interaction values for top 2 features
top_2_features = feature_importance_shap.head(2)['feature'].tolist()
print(f"\nAnalyzing interaction between:")
print(f"  • {top_2_features[0]}")
print(f"  • {top_2_features[1]}")

### Use smaller sample for interactions (computational intensive)
X_interaction_sample = X_sample.sample(n=min(300, len(X_sample)), random_state=42)

if hasattr(explainer, "shap_interaction_values"):
    shap_interaction_values = explainer.shap_interaction_values(X_interaction_sample)
    if isinstance(shap_interaction_values, list):
        shap_interaction_values = shap_interaction_values[1]
else:
    print("⚠️ SHAP interaction values not supported for this model type (LinearExplainer). Skipping interaction analysis.")
    shap_interaction_values = None

### Identify top 2 interacting features (fallback if not available)
if shap_interaction_values is not None:
    # Sum absolute interactions
    interaction_strength = np.abs(shap_interaction_values).sum(axis=0)
    np.fill_diagonal(interaction_strength, 0)
    top_2_indices = np.unravel_index(np.argmax(interaction_strength), interaction_strength.shape)
    top_2_features = [feature_cols[top_2_indices[0]], feature_cols[top_2_indices[1]]]

    print(f"📊 Top interacting features: {top_2_features}")
else:
    top_2_features = feature_cols[:2]  # fallback to first two features
    print(f"⚠️ Using fallback features for plot: {top_2_features}")

### Plot interaction (only if available)

if shap_interaction_values is not None:
    ### Plot interaction
    feature1_idx = feature_cols.index(top_2_features[0])
    feature2_idx = feature_cols.index(top_2_features[1])

    plt.figure(figsize=(10, 8))
    shap.dependence_plot(
        (feature1_idx, feature2_idx),
        shap_interaction_values,
        X_interaction_sample,
        feature_names=feature_cols,
        show=False
    )
    plt.title(f'Interaction: {top_2_features[0]} & {top_2_features[1]}',
             fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig('../Datasets/shap_interaction_plot.png', dpi=300, bbox_inches='tight')
    print("✅ Saved: ../Datasets/shap_interaction_plot.png")
    plt.close()
else:
    print("⏩ Skipped interaction plot (not supported for linear models).")

### =============================================================================
### PART 9: COHORT ANALYSIS
### =============================================================================

print("\n" + "=" * 70)
print("PART 9: SHAP ANALYSIS BY CUSTOMER COHORTS")
print("=" * 70)

print("\nAnalyzing SHAP values by subscription tier...")

### Split by subscription tier
subscription_Tiers = df.loc[X_sample.index, 'subscription_tier'].unique()

for subscription_tier in subscription_Tiers:
    subscription_tier_mask = df.loc[X_sample.index, 'subscription_tier'] == subscription_tier
    subscription_tier_indices = X_sample.index[subscription_tier_mask]
    
    if len(subscription_tier_indices) < 10:
        continue
    
    print(f"\n{subscription_tier.upper()} TIER ({len(subscription_tier_indices)} customers):")
    
    ### Calculate mean SHAP for this tier
    subscription_tier_shap = shap_values[subscription_tier_mask]
    mean_subscription_tier_shap = np.abs(subscription_tier_shap).mean(axis=0)
    
    subscription_tier_importance = pd.DataFrame({
        'feature': feature_cols,
        'importance': mean_subscription_tier_shap
    }).sort_values('importance', ascending=False)
    
    print("  Top 5 Features:")
    for i, row in subscription_tier_importance.head(5).iterrows():
        print(f"    {i+1}. {row['feature']}: {row['importance']:.4f}")

### =============================================================================
### PART 10: ACTIONABLE INSIGHTS
### =============================================================================

print("\n" + "=" * 70)
print("PART 10: ACTIONABLE INSIGHTS FROM SHAP ANALYSIS")
print("=" * 70)

print("\n💡 KEY INSIGHTS FOR PRODUCT & CS TEAMS:")

### Get top positive and negative SHAP contributors
avg_shap_by_feature = pd.DataFrame({
    'feature': feature_cols,
    'avg_shap': shap_values.mean(axis=0)
}).sort_values('avg_shap')

top_churn_drivers = avg_shap_by_feature.tail(5)
top_retention_factors = avg_shap_by_feature.head(5)

print("\n🔴 TOP CHURN DRIVERS (What Increases Risk):")
for i, row in top_churn_drivers.iterrows():
    print(f"   {i+1}. {row['feature']}: +{row['avg_shap']:.4f}")
    
    ### Provide actionable recommendations
    if 'days_since_last_login' in row['feature']:
        print("      → Action: Implement re-engagement campaigns for inactive users")
    elif 'logins' in row['feature']:
        print("      → Action: Send usage tips and feature highlights to boost engagement")
    elif 'support' in row['feature']:
        print("      → Action: Proactive CS outreach for users with support issues")
    elif 'payment' in row['feature']:
        print("      → Action: Automated billing issue resolution")
    elif 'health_score' in row['feature']:
        print("      → Action: Monitor health score trends, intervene early")

print("\n🟢 TOP RETENTION FACTORS (What Decreases Risk):")
for i, row in top_retention_factors.iterrows():
    print(f"   {i+1}. {row['feature']}: {row['avg_shap']:.4f}")
    
    ### Provide actionable recommendations
    if 'engagement' in row['feature']:
        print("      → Strategy: Encourage high engagement through gamification")
    elif 'features_used' in row['feature']:
        print("      → Strategy: Feature adoption programs and tutorials")
    elif 'premium' in row['feature'] or 'tier' in row['feature']:
        print("      → Strategy: Incentivize upgrades with trial periods")
    elif 'net_profit_score' in row['feature']:
        print("      → Strategy: Focus on customer delight and advocacy programs")

### =============================================================================
### PART 11: MODEL INTERPRETATION SUMMARY
### =============================================================================

print("\n" + "=" * 70)
print("PART 11: INTERPRETATION SUMMARY REPORT")
print("=" * 70)

### Create summary report
summary_report = f"""
╔══════════════════════════════════════════════════════════════════════╗
║           MODEL INTERPRETATION & EXPLAINABILITY REPORT               ║
╚══════════════════════════════════════════════════════════════════════╝

MODEL INFORMATION:
  • Model Type: {model_info['model_type']}
  • AUC Score: {model_info['auc_score']:.4f}
  • Features: {len(feature_cols)}
  • Samples Analyzed: {len(X_sample):,}

TOP 5 MOST IMPORTANT FEATURES (GLOBAL):
"""

for i, row in feature_importance_shap.head(5).iterrows():
    summary_report += f"  {i+1}. {row['feature']:<35} {row['importance']:.4f}\n"

summary_report += f"""
PREDICTION DRIVERS:
  • Primary churn drivers: {', '.join(top_churn_drivers.tail(3)['feature'].tolist())}
  • Primary retention factors: {', '.join(top_retention_factors.head(3)['feature'].tolist())}

INTERPRETABILITY BENEFITS:
  ✓ Model decisions are explainable to stakeholders
  ✓ Customer-specific interventions can be tailored
  ✓ Feature engineering can be guided by SHAP insights
  ✓ Compliance and audit requirements can be met
  ✓ Trust in AI system increased through transparency

BUSINESS APPLICATIONS:
  1. Customer Success: Personalized intervention strategies
  2. Product: Feature prioritization based on retention impact
  3. Marketing: Target messaging based on risk factors
  4. Executive: Clear explanation of churn drivers

FILES CREATED:
  • shap_summary_plot.png (feature impact overview)
  • shap_importance_bar.png (feature ranking)
  • shap_waterfall_case[1-3].png (individual explanations)
  • shap_dependence_plots.png (feature relationships)
  • shap_force_plot.html (interactive single prediction)
  • shap_force_plot_multi.html (interactive multi-prediction)
  • shap_interaction_plot.png (feature interactions)
"""

print(summary_report)

# Save report
with open('../Datasets/interpretation_report.txt', 'w') as f:
    f.write(summary_report)
print("\n✅ Saved: ../data/interpretation_report.txt")

### =============================================================================
### PART 12: EXPORT SHAP VALUES
### =============================================================================

print("\n" + "=" * 70)
print("PART 12: EXPORTING SHAP VALUES FOR DASHBOARD")
print("=" * 70)

### Create dataframe with SHAP values
shap_df = pd.DataFrame(shap_values, columns=feature_cols)
shap_df['customer_id'] = df.loc[X_sample.index, 'customer_id'].values
shap_df['actual_churn'] = y_sample.values
shap_df['predicted_churn'] = model.predict(X_sample)
shap_df['churn_probability'] = model.predict_proba(X_sample)[:, 1]

### Add top contributing features for each customer
def get_top_features(row, n=3):
    feature_impacts = []
    for col in feature_cols:
        feature_impacts.append({
            'feature': col,
            'impact': abs(row[col])
        })
    feature_impacts.sort(key=lambda x: x['impact'], reverse=True)
    return [f['feature'] for f in feature_impacts[:n]]

shap_df['top_risk_factors'] = shap_df[feature_cols].apply(
    lambda row: ', '.join(get_top_features(row, n=3)), axis=1
)

### Save
shap_export_file = '../Datasets/shap_values_export.csv'
shap_df.to_csv(shap_export_file, index=False)
print(f"\n✅ Saved SHAP values: {shap_export_file}")
print(f"   Columns: {list(shap_df.columns[:5])}... + {len(shap_df.columns)-5} more")

### =============================================================================
### SUMMARY
### =============================================================================

print("\n" + "=" * 70)
print("✅ MODEL INTERPRETATION COMPLETE!")
print("=" * 70)

print(f"""
Interpretation Analysis Summary:
  • Samples analyzed: {len(X_sample):,}
  • Features evaluated: {len(feature_cols)}
  • Individual cases explained: 3
  • Cohorts analyzed: {len(subscription_Tiers)}
  
Key Findings:
  • Most important feature: {feature_importance_shap.iloc[0]['feature']}
  • Strongest churn driver: {top_churn_drivers.iloc[-1]['feature']}
  • Strongest retention factor: {top_retention_factors.iloc[0]['feature']}
  
Files Created:
  1. shap_summary_plot.png (global importance)
  2. shap_importance_bar.png (ranking)
  3. shap_waterfall_case[1-3].png (individual explanations)
  4. shap_dependence_plots.png (6 feature dependencies)
  5. shap_force_plot.html (interactive visualization)
  6. shap_force_plot_multi.html (multi-customer view)
  7. shap_interaction_plot.png (feature interactions)
  8. interpretation_report.txt (summary report)
  9. shap_values_export.csv (SHAP values dataset)

Business Value:
  ✓ Explainable AI for stakeholder trust
  ✓ Personalized intervention recommendations
  ✓ Regulatory compliance capability
  ✓ Product insight generation
  ✓ Feature development prioritization
""")

print("\n🎉 All analysis notebooks complete!")
print("🚀 Ready to deploy dashboard with full explainability!")
print("=" * 70)

  from .autonotebook import tqdm as notebook_tqdm


MODEL INTERPRETATION & EXPLAINABILITY WITH SHAP

PART 1: LOADING TRAINED MODEL & DATA
✅ Loaded model: Logistic Regression
   AUC Score: 1.0000
✅ Loaded dataset: 5,000 customers
✅ Features: 44

PART 2: INITIALIZING SHAP EXPLAINER

Creating SHAP explainer...
(This may take 1-2 minutes for large datasets)
✅ SHAP values calculated for 1,000 samples
   Shape: (1000, 44)

PART 3: GLOBAL FEATURE IMPORTANCE ANALYSIS

🎯 Top 15 Most Important Features (by SHAP):
----------------------------------------------------------------------
 1. tenure_days                         49.2829
33. login_feature_interaction           41.2018
31. engagement_tenure                   31.6946
 4. session_duration_avg                11.3413
18. health_score                        10.7999
30. estimated_ltv                       9.2785
 3. logins_30d                          8.5867
13. engagement_score                    5.5050
28. value_realization                   3.3770
 7. days_since_last_login               3.30