# SHAP Interpretability with py-tidymodels

This notebook demonstrates model interpretability using SHAP (SHapley Additive exPlanations).

**Topics Covered:**
1. SHAP explanations for different model types (tree, linear)
2. Auto-explainer selection
3. Global feature importance
4. Local explanations (single observations)
5. SHAP with workflows and recipes
6. Grouped model SHAP (per-group feature importance)
7. Identifying prediction errors with SHAP
8. SHAP dependence plots (feature value vs. impact)

**Use Case:** Customer churn prediction with explanations

In [None]:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from py_parsnip import linear_reg, rand_forest, decision_tree
from py_workflows import workflow
from py_recipes import recipe
from py_interpret import ShapEngine

# Fix NumPy's global RNG so the synthetic data below is identical on every run
SEED = 42
np.random.seed(SEED)

# One consistent seaborn look for every figure in the notebook
sns.set_style('whitegrid')

print("Imports successful!")

## 1. Generate Customer Churn Data

Create realistic customer data with:
- Demographics (age, tenure)
- Engagement metrics (login frequency, feature usage)
- Service attributes (plan type, support calls)
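- Targets: a binary `churn` flag, and a noisy continuous `churn_risk` score (the churn probability plus Gaussian noise) that the regression models below predict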

In [None]:
# Generate synthetic customer churn data
n = 300

# Customer demographics
age = np.random.randint(18, 70, n)
tenure_months = np.random.randint(1, 60, n)

# Engagement metrics
logins_per_month = np.random.poisson(10, n)
features_used = np.random.randint(0, 20, n)
avg_session_minutes = np.random.exponential(30, n)  # not used in churn_logit -> pure noise feature

# Service attributes
plan_type = np.random.choice(['Basic', 'Premium', 'Enterprise'], n, p=[0.5, 0.3, 0.2])
support_calls = np.random.poisson(2, n)

# Create churn probability (hidden true relationship)
# Higher churn if:
# - Short tenure
# - Low engagement
# - Many support calls
# - Basic plan
# plan_type is an array, so the plan effect must be computed element-wise:
# a scalar `x if plan_type == 'Basic' else y` raises "truth value of an array
# ... is ambiguous".
plan_effect = np.where(plan_type == 'Basic', 1.0, -0.5)  # Basic plan = higher churn

churn_logit = (
    -2.0 +  # Baseline
    -0.05 * tenure_months +  # Longer tenure = lower churn
    -0.1 * logins_per_month +  # More logins = lower churn
    -0.05 * features_used +  # More features = lower churn
    0.3 * support_calls +  # More support = higher churn
    -0.02 * age +  # Older customers = lower churn
    plan_effect  # Intercept shift only: no interaction with other features
)

churn_prob = 1 / (1 + np.exp(-churn_logit))
churn = (np.random.rand(n) < churn_prob).astype(float)

# Create DataFrame
data = pd.DataFrame({
    'churn': churn,
    'age': age,
    'tenure_months': tenure_months,
    'logins_per_month': logins_per_month,
    'features_used': features_used,
    'avg_session_minutes': avg_session_minutes,
    'plan_type': plan_type,
    'support_calls': support_calls
})

# For regression demo, use continuous target
data['churn_risk'] = churn_prob + np.random.randn(n) * 0.1

# Split into train/test (rows are i.i.d. draws, so a positional split is random)
train_data = data.iloc[:240]
test_data = data.iloc[240:]

# Sanity checks so a broken generation step fails loudly here, not in a later cell
assert churn_logit.shape == (n,)
assert len(train_data) + len(test_data) == len(data) == n
assert data['churn'].isin([0.0, 1.0]).all()

print(f"Training data: {len(train_data)} observations")
print(f"Test data: {len(test_data)} observations")
print(f"\nChurn rate: {data['churn'].mean():.1%}")
print("\nData summary:")
print(data.describe())

In [None]:
# Visualize feature distributions by churn
features_to_plot = ['age', 'tenure_months', 'logins_per_month',
                    'features_used', 'avg_session_minutes', 'support_calls']

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Row masks do not depend on the feature, so build them once
retained_mask = train_data['churn'] == 0
churned_mask = train_data['churn'] == 1

# axes.flat walks the 2x3 grid row by row, matching the order of features_to_plot
for ax, feature in zip(axes.flat, features_to_plot):
    ax.hist(train_data.loc[retained_mask, feature], alpha=0.5, label='Retained', bins=20)
    ax.hist(train_data.loc[churned_mask, feature], alpha=0.5, label='Churned', bins=20)
    ax.set_xlabel(feature)
    ax.set_ylabel('Count')
    pretty_name = feature.replace("_", " ").title()
    ax.set_title(f'{pretty_name} Distribution')
    ax.legend()

plt.tight_layout()
plt.show()

## 2. Basic SHAP with Linear Regression

The explainer is chosen automatically: linear models get SHAP's `LinearExplainer`, which is fast and exact for this model class.

In [None]:
# Fit a linear regression on the numeric predictors
linear_formula = (
    'churn_risk ~ age + tenure_months + logins_per_month'
    ' + features_used + support_calls'
)
spec_linear = linear_reg()
fit_linear = spec_linear.fit(train_data, linear_formula)

print("Linear model fitted!")
fitted_spec = fit_linear.spec
print(f"Model type: {fitted_spec.model_type}")
print(f"Engine: {fitted_spec.engine}")

In [None]:
# SHAP values for the test set; the explainer type (LinearExplainer) is picked automatically
shap_linear = fit_linear.explain(test_data, check_additivity=False)

column_names = shap_linear.columns.tolist()
print("SHAP values computed!")
print(f"\nSHAP DataFrame shape: {shap_linear.shape}")
print(f"Columns: {column_names}")
print("\nFirst 10 rows:")
print(shap_linear.head(10))

In [None]:
# Global feature importance = mean |SHAP| per variable, largest first
importance_linear = (
    shap_linear
    .groupby('variable')['abs_shap']
    .mean()
    .sort_values(ascending=False)
)

print("\nGlobal Feature Importance (Linear Model):")
for var_name, mean_abs in importance_linear.items():
    print(f"  {var_name:<25} {mean_abs:.4f}")

# Horizontal bars; inverting the y-axis puts the most important feature on top
fig, ax = plt.subplots(figsize=(10, 6))
importance_linear.plot(kind='barh', ax=ax, color='steelblue')
ax.set_xlabel('Mean |SHAP Value|')
ax.set_ylabel('Feature')
ax.set_title('Global Feature Importance (Linear Model)')
ax.invert_yaxis()
plt.tight_layout()
plt.show()

## 3. SHAP with Random Forest (TreeExplainer)

TreeExplainer is fast and exact for tree-based models.

In [None]:
# Random forest regressor on the same predictors as the linear model
rf_formula = 'churn_risk ~ age + tenure_months + logins_per_month + features_used + support_calls'
spec_rf = rand_forest(trees=100, min_n=5).set_mode('regression')
fit_rf = spec_rf.fit(train_data, rf_formula)

print("Random Forest model fitted!")
n_trees = fit_rf.spec.args.get('trees', 100)
print(f"Trees: {n_trees}")

In [None]:
# SHAP for the random forest; TreeExplainer is chosen automatically for tree ensembles
shap_rf = fit_rf.explain(test_data, check_additivity=False)

print("SHAP values computed with TreeExplainer!")
print(f"Shape: {shap_rf.shape}")

# Global importance: average absolute contribution per feature, largest first
importance_rf = shap_rf.groupby('variable')['abs_shap'].mean()
importance_rf = importance_rf.sort_values(ascending=False)

print("\nFeature Importance (Random Forest):")
for feat, score in importance_rf.items():
    print(f"  {feat:<25} {score:.4f}")

In [None]:
# Compare feature importance: Linear vs Random Forest.
# Building a DataFrame from two Series aligns their indexes, which may re-sort
# the rows alphabetically. Sort explicitly so that, after invert_yaxis(), the
# features the random forest relies on most appear at the top.
comparison_df = (
    pd.DataFrame({
        'Linear': importance_linear,
        'Random_Forest': importance_rf
    })
    .fillna(0)  # a feature missing from one model contributes nothing there
    .sort_values('Random_Forest', ascending=False)
)

fig, ax = plt.subplots(figsize=(12, 6))
comparison_df.plot(kind='barh', ax=ax, width=0.8)
ax.set_xlabel('Mean |SHAP Value|')
ax.set_ylabel('Feature')
ax.set_title('Feature Importance Comparison: Linear vs Random Forest')
ax.legend(title='Model Type')
ax.invert_yaxis()
plt.tight_layout()
plt.show()

print("\nKey Observations:")
print("- Random Forest captures non-linear relationships")
print("- Different models may emphasize different features")
print("- SHAP provides model-agnostic comparison")

## 4. Local Explanations (Single Observation)

Explain individual predictions with SHAP waterfall plots.

In [None]:
# Score the test set and pick the customer with the highest predicted churn risk
preds_rf = fit_rf.predict(test_data, type='numeric')
test_with_preds = test_data.assign(pred_risk=preds_rf['.pred'].values)

high_risk_idx = test_with_preds['pred_risk'].idxmax()
high_risk_customer = test_data.loc[[high_risk_idx]]  # double brackets keep a 1-row DataFrame
high_risk_score = test_with_preds.loc[high_risk_idx, 'pred_risk']

print("High-Risk Customer Profile:")
print(high_risk_customer.T)
print(f"\nPredicted churn risk: {high_risk_score:.3f}")

In [None]:
# Compute SHAP for this customer
shap_single = fit_rf.explain(high_risk_customer, check_additivity=False)

base = shap_single['base_value'].iloc[0]
pred = shap_single['prediction'].iloc[0]

print(f"\nSHAP Explanation for Customer {high_risk_idx}:")
print(f"Base value (average prediction): {base:.4f}")
print(f"Final prediction: {pred:.4f}")
print("\nFeature Contributions:")

# Sort by absolute impact (largest first)
shap_sorted = shap_single.sort_values('abs_shap', ascending=False)

for _, row in shap_sorted.iterrows():
    direction = "↑ increases" if row['shap_value'] >= 0 else "↓ decreases"
    # The `+` format flag signs positive and negative values inside the same
    # fixed-width field, so the column of numbers stays aligned (a manual "+"
    # prefix made positive entries one character wider than negative ones).
    print(f"  {row['variable']:<25} {row['shap_value']:>+9.4f}  {direction} risk (value={row['feature_value']:.2f})")

# Verify additivity: base value + sum of SHAP values should reproduce the prediction
total_shap = shap_single['shap_value'].sum()
reconstructed = total_shap + base
print("\nAdditivity Check:")
print(f"  sum(SHAP) + base = {reconstructed:.4f}")
print(f"  prediction        = {pred:.4f}")
print(f"  difference        = {abs(reconstructed - pred):.6f}")

In [None]:
# Visualize waterfall plot
def plot_waterfall(shap_df, customer_id, max_label_len=20):
    """Draw a SHAP waterfall plot for a single observation.

    Starting from the base value, each feature's SHAP value moves the running
    total left (negative, red bar) or right (positive, green bar). Dashed
    vertical lines mark the base value and the final prediction.

    Args:
        shap_df: SHAP rows for ONE observation with columns 'variable',
            'feature_value', 'shap_value', 'base_value' and 'prediction'
            (base_value and prediction are read from the first row).
        customer_id: Identifier shown in the plot title.
        max_label_len: Variable names longer than this are truncated in the
            y-axis labels (default 20, as before).

    Raises:
        ValueError: If ``shap_df`` has no rows.
    """
    if shap_df.empty:
        raise ValueError("shap_df is empty - nothing to plot")

    shap_sorted = shap_df.sort_values('shap_value', ascending=True)

    base_value = shap_sorted['base_value'].iloc[0]
    prediction = shap_sorted['prediction'].iloc[0]

    fig, ax = plt.subplots(figsize=(10, 8))

    cumulative = base_value
    y_labels = []

    for y_pos, (_, row) in enumerate(shap_sorted.iterrows()):
        shap_val = row['shap_value']
        next_cumulative = cumulative + shap_val

        # Bar spans from the running total before this feature to the total after it
        ax.barh(
            y_pos,
            abs(shap_val),
            left=min(cumulative, next_cumulative),
            color='red' if shap_val < 0 else 'green',
            alpha=0.7,
            edgecolor='black'
        )

        # Signed contribution printed in the middle of the bar
        ax.text(
            cumulative + shap_val / 2,
            y_pos,
            f"{shap_val:+.3f}",
            ha='center',
            va='center',
            fontweight='bold',
            fontsize=9
        )

        y_labels.append(f"{row['variable'][:max_label_len]}\n{row['feature_value']:.2f}")
        cumulative = next_cumulative

    # Use y tick labels for feature names. Text placed at ax.get_xlim()[0] inside
    # the loop used whatever x-limits existed at that moment; later bars
    # re-autoscale the axis, so earlier labels ended up inside the plot.
    ax.set_yticks(range(len(y_labels)))
    ax.set_yticklabels(y_labels, fontsize=8)

    # Add base and prediction lines
    ax.axvline(base_value, color='blue', linestyle='--', label=f'Base: {base_value:.3f}', linewidth=2)
    ax.axvline(prediction, color='red', linestyle='--', label=f'Prediction: {prediction:.3f}', linewidth=2)

    ax.set_xlabel('Churn Risk')
    ax.set_title(f'SHAP Waterfall Plot - Customer {customer_id}')
    ax.legend(loc='best')
    ax.grid(True, alpha=0.3, axis='x')

    plt.tight_layout()
    plt.show()

plot_waterfall(shap_single, high_risk_idx)

## 5. SHAP with Workflows and Recipes

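SHAP also works with preprocessing pipelines: the workflow applies the recipe before the model, so the explanation is computed on the preprocessed (here, normalized) features. Normalization is a monotone per-feature rescaling, so for a tree model it should leave the learned splits, and therefore the importance ranking, largely unchanged. Differences beyond normal random-forest variation are worth investigating, for example by checking which columns the recipe actually passes to the model.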

In [None]:
# Create workflow with recipe: a normalization step followed by the same
# random-forest spec used in Section 3, fitted on the full training frame.
# NOTE(review): no formula is passed to recipe() or to the workflow, so which
# columns become the outcome and which become predictors is decided inside
# py_recipes / py_workflows. Confirm this. If every non-outcome column is used,
# then `churn` (drawn from the same churn_prob that generates `churn_risk`)
# would leak target information, and `plan_type` (a string column) would reach
# the model unencoded. Verify before comparing these importances to the
# formula-based fits.
rec = (
    recipe()
    .step_normalize()  # Normalize numeric features
)

wf = workflow().add_recipe(rec).add_model(
    rand_forest(trees=100, min_n=5).set_mode('regression')
)

wf_fit = wf.fit(train_data)
print("Workflow with recipe fitted!")

In [None]:
# SHAP on the preprocessed features (the fitted recipe is applied automatically)
shap_wf = wf_fit.explain(test_data, check_additivity=False)

importance_wf = (
    shap_wf.groupby('variable')['abs_shap']
    .mean()
    .sort_values(ascending=False)
)

print("Feature Importance (with normalization):")
for feat, score in importance_wf.items():
    print(f"  {feat:<25} {score:.4f}")

# Side-by-side with the formula-based random forest from Section 3
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

panels = [
    (axes[0], importance_rf, 'steelblue', 'Without Normalization'),
    (axes[1], importance_wf, 'coral', 'With Normalization'),
]
for panel_ax, importances, bar_color, panel_title in panels:
    importances.plot(kind='barh', ax=panel_ax, color=bar_color)
    panel_ax.set_xlabel('Mean |SHAP|')
    panel_ax.set_title(panel_title)
    panel_ax.invert_yaxis()

plt.suptitle('Effect of Normalization on SHAP Importance', fontsize=14)
plt.tight_layout()
plt.show()

## 6. Grouped Model SHAP (Per-Group Feature Importance)

Analyze feature importance separately for different customer segments.

In [None]:
# Work on copies so the segment analysis cannot mutate the original splits
grouped_data = train_data.copy()
grouped_test = test_data.copy()

plan_counts = grouped_data['plan_type'].value_counts()
print(f"Groups (plan types): {grouped_data['plan_type'].unique()}")
print("\nGroup sizes:")
print(plan_counts)

In [None]:
# One random forest per plan type, all sharing the same specification
spec_nested = rand_forest(trees=100, min_n=5).set_mode('regression')
nested_formula = 'churn_risk ~ age + tenure_months + logins_per_month + features_used + support_calls'

nested_fit = spec_nested.fit_nested(grouped_data, nested_formula, group_col='plan_type')

n_group_models = len(nested_fit.group_fits)
print(f"Fitted {n_group_models} models (one per plan type)")

In [None]:
# Explain each test row with the model fitted for its own plan type
shap_grouped = nested_fit.explain(grouped_test, check_additivity=False)

print("\nSHAP values computed for all groups!")
print(f"Shape: {shap_grouped.shape}")
print(f"Groups: {shap_grouped['group'].unique()}")

# Mean |SHAP| per (group, variable), reshaped to one row per group
importance_by_group = (
    shap_grouped
    .groupby(['group', 'variable'])['abs_shap']
    .mean()
    .unstack()
)

print("\nFeature Importance by Plan Type:")
print(importance_by_group)

In [None]:
# Visualize group differences: one panel per plan type
plan_order = ['Basic', 'Premium', 'Enterprise']
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for idx, group in enumerate(plan_order):
    ax = axes[idx]
    if group not in importance_by_group.index:
        # No SHAP rows for this segment: say so instead of leaving a blank panel
        ax.text(0.5, 0.5, 'No observations in test split',
                ha='center', va='center', transform=ax.transAxes)
        ax.axis('off')
        ax.set_title(f'{group} Plan')
        continue

    importance_by_group.loc[group].sort_values(ascending=True).plot(
        kind='barh',
        ax=ax,
        color='steelblue',
        alpha=0.7
    )
    ax.set_title(f'{group} Plan')
    ax.set_xlabel('Mean |SHAP|')
    if idx == 0:
        ax.set_ylabel('Feature')

plt.suptitle('Feature Importance by Customer Segment', fontsize=14)
plt.tight_layout()
plt.show()

# Derive the observations from the computed importances instead of hard-coding
# them. In the simulated data, plan_type only shifts churn_logit's intercept
# (there are no interaction terms), so ranking differences between segments come
# from the non-linear logistic link, noise and small per-segment sample sizes.
print("\nKey Observations:")
for group in plan_order:
    if group not in importance_by_group.index:
        continue
    ranked = importance_by_group.loc[group].dropna().sort_values(ascending=False)
    if ranked.empty:
        continue
    top_two = ", ".join(str(name) for name in ranked.index[:2])
    print(f"- {group} plan: top drivers = {top_two} "
          f"(mean |SHAP| of #1 = {ranked.iloc[0]:.4f})")

## 7. Identifying Prediction Errors with SHAP

Use SHAP to understand when and why the model makes mistakes.

In [None]:
# Residual analysis on the test split (error = actual - predicted)
preds_test = fit_rf.predict(test_data, type='numeric')
test_analysis = test_data.copy()
test_analysis['prediction'] = preds_test['.pred'].values
test_analysis['actual'] = test_data['churn_risk'].values

residual = test_analysis['actual'] - test_analysis['prediction']
test_analysis['error'] = residual
test_analysis['abs_error'] = residual.abs()

# The three largest absolute errors
worst_idx = test_analysis['abs_error'].nlargest(3).index

report_cols = ['actual', 'prediction', 'error', 'abs_error']
print("Worst Predictions:")
print(test_analysis.loc[worst_idx, report_cols])

In [None]:
# Explain the worst predictions
worst_customers = test_data.loc[worst_idx]
shap_worst = fit_rf.explain(worst_customers, check_additivity=False)

# Assumes the first index level of the SHAP frame is the observation id (TODO confirm)
obs_ids = shap_worst.index.get_level_values(0)
divider = '=' * 70

for customer_id in worst_idx:
    customer_shap = shap_worst[obs_ids.isin([customer_id])]

    actual = test_analysis.loc[customer_id, 'actual']
    predicted = test_analysis.loc[customer_id, 'prediction']
    error = test_analysis.loc[customer_id, 'error']

    print(f"\n{divider}")
    print(f"Customer {customer_id}")
    print(divider)
    print(f"Actual: {actual:.3f}  |  Predicted: {predicted:.3f}  |  Error: {error:+.3f}")
    print("\nTop SHAP Contributors:")

    for _, contrib in customer_shap.nlargest(3, 'abs_shap').iterrows():
        arrow = "↑" if contrib['shap_value'] >= 0 else "↓"
        print(f"  {contrib['variable']:<25} {arrow} {contrib['shap_value']:>8.4f}  (value={contrib['feature_value']:.2f})")

## 8. SHAP Dependence Plots

Visualize how feature values affect predictions.

In [None]:
# Create SHAP dependence plot
def plot_shap_dependence(shap_df, feature):
    """Scatter one feature's SHAP values against its raw values.

    Args:
        shap_df: Long-format SHAP frame with 'variable', 'feature_value',
            'shap_value' and 'abs_shap' columns.
        feature: Name of the variable to plot.
    """
    rows = shap_df.loc[shap_df['variable'] == feature].copy()

    _, ax = plt.subplots(figsize=(10, 6))

    points = ax.scatter(
        rows['feature_value'],
        rows['shap_value'],
        c=rows['abs_shap'],
        cmap='viridis',
        alpha=0.6,
        s=50
    )

    # Points above the zero line push the prediction up; points below push it down
    ax.axhline(y=0, color='red', linestyle='--', alpha=0.5)
    ax.set_xlabel(f'{feature} (Feature Value)')
    ax.set_ylabel('SHAP Value (Impact on Prediction)')
    ax.set_title(f'SHAP Dependence Plot: {feature}')
    ax.grid(True, alpha=0.3)

    colorbar = plt.colorbar(points, ax=ax)
    colorbar.set_label('|SHAP Value|')

    plt.tight_layout()
    plt.show()

# Plot for key features
key_features = ['tenure_months', 'support_calls', 'logins_per_month']
for feature in key_features:
    plot_shap_dependence(shap_rf, feature)

## 9. Summary and Best Practices

**Key Takeaways:**

1. **SHAP Advantages:**
   - Model-agnostic explanations
   - Theoretically grounded (Shapley values)
   - Additivity property
   - Local and global interpretability

2. **Explainer Types:**
   - TreeExplainer: Fast, exact for tree models
   - LinearExplainer: Fast, exact for linear models
   - KernelExplainer: Slow, model-agnostic fallback

3. **Use Cases:**
   - Global importance: Which features matter most?
   - Local explanations: Why this prediction?
   - Model comparison: rank features across different models on a common scale (mean |SHAP|)
   - Error analysis: When does model fail?
   - Grouped models: Heterogeneous feature importance

4. **Best Practices:**
   - Use auto-selection (fast explainers when possible)
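   - Verify additivity (base value + sum of SHAP values = prediction). This notebook passes `check_additivity=False` but checks additivity manually in Section 4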
   - Compare multiple models
   - Analyze error cases
   - Consider grouped models for heterogeneous data

5. **Integration:**
   - Works with ModelFit and WorkflowFit
   - Handles preprocessing automatically
   - Supports nested/grouped models
   - Returns tidy DataFrame format

In [None]:
# Final summary visualization (2x2 dashboard of the results above)
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# 1. Global importance comparison
comparison_df.plot(kind='barh', ax=axes[0, 0], width=0.7)
axes[0, 0].set_title('Model Comparison: Feature Importance')
axes[0, 0].set_xlabel('Mean |SHAP|')
axes[0, 0].invert_yaxis()

# 2. SHAP distribution for the top-5 random-forest features.
# Each SHAP row holds one scalar shap_value per (observation, variable), so a
# per-variable selection is all the boxplot needs.
top_features = list(importance_rf.index[:5])
shap_distributions = [
    shap_rf.loc[shap_rf['variable'] == feat, 'shap_value'].tolist()
    for feat in top_features
]

box_ax = axes[0, 1]
box_ax.boxplot(shap_distributions)
# Set tick labels explicitly: boxplot's `labels=` argument was renamed to
# `tick_labels` (deprecated in Matplotlib 3.9). set_xticklabels works on all
# versions. Boxes sit at positions 1..n by default.
box_ax.set_xticks(range(1, len(top_features) + 1))
box_ax.set_xticklabels(top_features)
box_ax.set_title('SHAP Value Distribution (Top 5 Features)')
box_ax.set_ylabel('SHAP Value')
box_ax.axhline(y=0, color='red', linestyle='--', alpha=0.5)
box_ax.tick_params(axis='x', rotation=45)

# 3. Predicted vs actual, coloured by each observation's total |SHAP|.
# x, y and colour are all looked up by the same index, so they stay aligned.
# (Assumes level 0 of shap_rf's index is the test_data row label, as the
# original .loc lookup did - TODO confirm.)
test_shap_summary = shap_rf.groupby(level=0)['abs_shap'].sum()
obs_ids = test_shap_summary.index
pred_risk = test_with_preds.loc[obs_ids, 'pred_risk']
actual_risk = test_with_preds.loc[obs_ids, 'churn_risk']

scatter = axes[1, 0].scatter(
    pred_risk,
    actual_risk,
    c=test_shap_summary.to_numpy(),
    cmap='viridis',
    alpha=0.6,
    s=50
)
# Identity line over the data range (churn_risk includes noise and can leave [0, 1])
lo = min(pred_risk.min(), actual_risk.min())
hi = max(pred_risk.max(), actual_risk.max())
axes[1, 0].plot([lo, hi], [lo, hi], 'k--', alpha=0.5)
axes[1, 0].set_xlabel('Predicted Risk')
axes[1, 0].set_ylabel('Actual Risk')
axes[1, 0].set_title('Predictions (colored by total |SHAP|)')
cbar = plt.colorbar(scatter, ax=axes[1, 0])
cbar.set_label('Total |SHAP|')

# 4. Group-wise importance heatmap (drawn whenever any segment has results)
heat_ax = axes[1, 1]
if not importance_by_group.empty:
    im = heat_ax.imshow(importance_by_group.values, cmap='YlOrRd', aspect='auto')
    heat_ax.set_xticks(range(len(importance_by_group.columns)))
    heat_ax.set_xticklabels(importance_by_group.columns, rotation=45, ha='right')
    heat_ax.set_yticks(range(len(importance_by_group.index)))
    heat_ax.set_yticklabels(importance_by_group.index)
    heat_ax.set_title('Feature Importance Heatmap by Segment')
    plt.colorbar(im, ax=heat_ax, label='Mean |SHAP|')
else:
    heat_ax.axis('off')

plt.tight_layout()
plt.show()

print("\n" + "="*70)
print("DEMO COMPLETE")
print("="*70)
print("SHAP provides powerful model interpretability for any model type!")