# Model Interpretability with SHAP

**DOST-ITDI AI Training Workshop**  
**Day 2 - Session 5: Understanding Model Predictions**

---

## Learning Objectives
1. Understand why model interpretability matters
2. Learn SHAP (SHapley Additive exPlanations) fundamentals
3. Interpret predictions from regression models
4. Interpret predictions from classification models
5. Apply interpretability to chemistry problems
6. Make better decisions using model insights

## Why Model Interpretability?

### The Black Box Problem
Machine learning models can make accurate predictions, but often we don't understand **why**.

**Questions we want to answer:**
- Which features are most important?
- Why did the model predict this value?
- Can we trust this prediction?
- What happens if we change a feature?

### Why This Matters in Science
- **Regulatory compliance** - agencies such as the FDA and EPA expect model-supported decisions to be justified
- **Scientific discovery** - Find new insights
- **Trust** - Convince stakeholders
- **Debugging** - Find model errors
- **Ethical AI** - Detect bias

### Example: Drug Development
Model says: "This compound will be 85% bioavailable"

**We need to know:**
- Which molecular features drove this prediction?
- Is the model focusing on the right chemistry?
- Should we trust this for a $10M clinical trial decision?

## Section 1: Setup and Load Data

In [None]:
# Install SHAP
!pip install shap scikit-learn rdkit xgboost -q

print("[OK] Libraries installed!")

In [None]:
# Import libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import shap
import warnings
warnings.filterwarnings('ignore')

# Machine learning
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, classification_report
import xgboost as xgb

# Chemistry
from rdkit import Chem
from rdkit.Chem import Descriptors

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)
np.random.seed(42)

print("[OK] Libraries imported!")

### 1.1 Load ESOL Dataset (Solubility Prediction)

In [None]:
# Load ESOL dataset - same as notebooks 01, 02, 04!
url = "https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/delaney-processed.csv"
df = pd.read_csv(url)

print(f"Loaded {len(df)} molecules")
print(f"\nColumns: {df.columns.tolist()}")
print(f"\nFirst few rows:")
display(df.head())

# Target: measured log solubility
print(f"\nTarget variable: measured log solubility in mols per litre")
print(f"Range: [{df['measured log solubility in mols per litre'].min():.2f}, {df['measured log solubility in mols per litre'].max():.2f}]")

### 1.2 Feature Engineering

In [None]:
# Calculate molecular descriptors using RDKit
def calculate_descriptors(smiles):
    """Calculate molecular descriptors from SMILES"""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    return {
        'MolWt': Descriptors.MolWt(mol),
        'LogP': Descriptors.MolLogP(mol),
        'NumHDonors': Descriptors.NumHDonors(mol),
        'NumHAcceptors': Descriptors.NumHAcceptors(mol),
        'NumRotatableBonds': Descriptors.NumRotatableBonds(mol),
        'NumAromaticRings': Descriptors.NumAromaticRings(mol),
        'TPSA': Descriptors.TPSA(mol),
        'NumAtoms': mol.GetNumAtoms()
    }

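# Calculate for all molecules
print("Calculating molecular descriptors...")
descriptors = []
for smiles in df['smiles']:
    desc = calculate_descriptors(smiles)
    # Unparseable SMILES become an all-NaN row, so rows stay aligned with df
    descriptors.append(desc if desc is not None else {})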

# Create features DataFrame
features_df = pd.DataFrame(descriptors)
features_df['Compound'] = df['Compound ID'].values
features_df['Solubility'] = df['measured log solubility in mols per litre'].values

# Remove any failed conversions
features_df = features_df.dropna()

print(f"\n[OK] Calculated descriptors for {len(features_df)} molecules")
print(f"\nFeatures: {features_df.columns.tolist()[:-2]}")
display(features_df.head())

## Section 2: Train Regression Model

In [None]:
# Prepare data
feature_cols = ['MolWt', 'LogP', 'NumHDonors', 'NumHAcceptors',
                'NumRotatableBonds', 'NumAromaticRings', 'TPSA', 'NumAtoms']

X = features_df[feature_cols]
y = features_df['Solubility']

# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

print(f"Training set: {X_train.shape[0]} molecules")
print(f"Test set: {X_test.shape[0]} molecules")

# Train Random Forest
print("\nTraining Random Forest model...")
rf_model = RandomForestRegressor(n_estimators=100, max_depth=10, random_state=42)
rf_model.fit(X_train, y_train)

# Predictions
y_pred = rf_model.predict(X_test)

# Evaluate
r2 = r2_score(y_test, y_pred)
rmse = np.sqrt(mean_squared_error(y_test, y_pred))

print(f"\n[OK] Model trained!")
print(f"R² Score: {r2:.4f}")
print(f"RMSE: {rmse:.4f}")

# Visualize predictions
plt.figure(figsize=(8, 6))
plt.scatter(y_test, y_pred, alpha=0.6, edgecolors='k')
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--', lw=2)
plt.xlabel('Actual Solubility', fontsize=12)
plt.ylabel('Predicted Solubility', fontsize=12)
plt.title(f'Model Performance (R² = {r2:.4f})', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Section 3: SHAP - Understanding the Basics

### What is SHAP?

**SHAP** = SHapley Additive exPlanations

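Based on cooperative game theory (Shapley values): it fairly distributes "credit" for a prediction among the features by averaging each feature's marginal contribution over every order in which the features could be added.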
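The "fair credit" idea can be computed directly for a tiny example. The sketch below brute-forces exact Shapley values by averaging each player's marginal contribution over all join orders. The `grade` game and its numbers are hypothetical, chosen so the result matches the team-grade analogy in this section. Brute force needs n! evaluations, which is why SHAP relies on faster, model-specific algorithms such as `TreeExplainer`.

```python
from itertools import permutations
from math import factorial

def shapley_values(players, value):
    """Exact Shapley values: average each player's marginal contribution
    over every order in which the players can join the coalition."""
    totals = {p: 0.0 for p in players}
    for order in permutations(players):
        coalition = frozenset()
        for p in order:
            before = value(coalition)
            coalition = coalition | {p}
            totals[p] += value(coalition) - before
    n_orders = factorial(len(players))
    return {p: t / n_orders for p, t in totals.items()}

# Hypothetical "grade" game: a 75-point baseline, individual effects,
# and a 4-point synergy that only appears when A and B work together.
def grade(coalition):
    score = 75.0
    score += 6.0 * ("A" in coalition) + 3.0 * ("B" in coalition) - 3.0 * ("C" in coalition)
    if {"A", "B"} <= coalition:
        score += 4.0
    return score

phi = shapley_values(["A", "B", "C"], grade)
print(phi)  # {'A': 8.0, 'B': 5.0, 'C': -3.0}
```

The A-B synergy is split evenly (A gets 6 + 2, B gets 3 + 2), and the values sum to 85 - 75 = 10, the gap between the full team and the baseline.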

### Key Concepts:

1. **SHAP Value** = How much a feature contributes to a prediction
   - Positive SHAP → increases prediction
   - Negative SHAP → decreases prediction
   - Zero SHAP → no effect

2. **Base Value** = Average model prediction over a reference (background) dataset, typically the training data

3. **Prediction** = Base Value + sum of all SHAP values
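The additivity rule (prediction = base value + sum of SHAP values) can be checked with a brute-force computation. This is a minimal sketch, assuming an *interventional* definition in which "missing" features are filled in from a background dataset; the model, the background data and the `exact_shap` helper are all hypothetical. The `shap` library computes or approximates Shapley values with much faster algorithms and, depending on settings, slightly different definitions of a missing feature.

```python
import itertools
import math
import numpy as np

def exact_shap(model, x, background):
    """Exact interventional SHAP values for one sample x.

    A coalition S keeps the features in S at their values in x and takes the
    rest from the background rows; v(S) is the mean model output over those rows.
    Cost grows as 2**n_features, so this is only for small illustrations."""
    n = x.shape[0]

    def v(S):
        X = background.copy()
        idx = list(S)
        if idx:
            X[:, idx] = x[idx]
        return model(X).mean()

    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for k in range(n):
            for S in itertools.combinations(others, k):
                weight = math.factorial(k) * math.factorial(n - k - 1) / math.factorial(n)
                phi[i] += weight * (v(S + (i,)) - v(S))
    base_value = model(background).mean()
    return base_value, phi

rng = np.random.default_rng(0)
background = rng.normal(size=(50, 3))

def model(X):
    # Hypothetical model with an interaction between features 0 and 2
    return 2.0 * X[:, 0] - X[:, 1] + 0.5 * X[:, 0] * X[:, 2]

x = np.array([1.0, 0.5, -2.0])
base_value, phi = exact_shap(model, x, background)
print(base_value + phi.sum(), model(x[None, :])[0])  # equal up to rounding
```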

### Analogy: Team Project Grade

Your team got 85/100 on a project. How much did each member contribute?

- **Base value**: Average grade = 75
- **Alice**: +8 (excellent research)
- **Bob**: +5 (good presentation)
- **Carlos**: -3 (late submission)
- **Diana**: 0 (average contribution)

**Final**: 75 + 8 + 5 - 3 + 0 = **85**

SHAP does this for features!

### 3.1 Initialize SHAP Explainer

In [None]:
# Create SHAP explainer for Random Forest
print("Creating SHAP explainer...")
explainer = shap.TreeExplainer(rf_model)

# Calculate SHAP values for test set
print("Calculating SHAP values (this may take a minute)...")
shap_values = explainer.shap_values(X_test)

print(f"\n[OK] SHAP values calculated!")
print(f"Shape: {shap_values.shape}")
print(f"  - {shap_values.shape[0]} test samples")
print(f"  - {shap_values.shape[1]} features")

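# Base value (average model prediction on the training data)
# np.ravel(...)[0] handles SHAP versions that return a 1-element array
base_value = float(np.ravel(explainer.expected_value)[0])
print(f"\nBase value (average prediction): {base_value:.4f}")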

### 3.2 Global Feature Importance

In [None]:
# Summary plot - shows overall feature importance
print("Feature Importance (Global View)")
print("="*60)

shap.summary_plot(shap_values, X_test, plot_type="bar", show=False)
plt.tight_layout()
plt.show()

print("\n[INFO] This shows which features are most important OVERALL")
print("Bar length = mean |SHAP value| across all test molecules")

### 3.3 Summary Plot - Feature Effects

In [None]:
# Summary plot - shows how features affect predictions
print("Feature Effects Summary")
print("="*60)
print("How to read this plot:")
print("  - Each dot = one molecule")
print("  - X-axis: SHAP value (impact on prediction)")
print("  - Color: Feature value (red=high, blue=low)")
print("\nExample: if high LogP values (red dots) sit on the left (negative SHAP):")
print("  -> High LogP lowers the predicted solubility")
print("="*60)

shap.summary_plot(shap_values, X_test, show=False)
plt.tight_layout()
plt.show()

## Section 4: Individual Predictions

### 4.1 Explain a Single Prediction

In [None]:
# Pick a molecule to explain
sample_idx = 0
sample = X_test.iloc[sample_idx]
actual_value = y_test.iloc[sample_idx]
predicted_value = rf_model.predict(X_test.iloc[[sample_idx]])[0]  # keep the DataFrame so feature names match

print("Explaining Individual Prediction")
print("="*60)
print(f"Molecule #{sample_idx}")
print(f"\nActual Solubility: {actual_value:.4f}")
print(f"Predicted Solubility: {predicted_value:.4f}")
print(f"Base Value (average): {base_value:.4f}")
print(f"\nMolecular Features:")
for feat, val in sample.items():
    print(f"  {feat:20}: {val:.2f}")
print("="*60)

In [None]:
# Waterfall plot - shows how we get from base value to prediction
print("\nWaterfall Plot - How Features Build the Prediction")
print("="*60)
print("This shows step-by-step how each feature contributes:")
print("  - Start at base value (average)")
print("  - Each bar adds or subtracts from prediction")
print("  - End at final predicted value")
print("="*60)

shap.plots.waterfall(shap.Explanation(
    values=shap_values[sample_idx],
    base_values=base_value,
    data=sample.values,
    feature_names=feature_cols
), show=False)
plt.tight_layout()
plt.show()
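To see the arithmetic behind a waterfall plot, the sketch below prints the same information as text: contributions sorted by magnitude and the running total after each one. The `waterfall_table` helper and the numbers are hypothetical. `shap.plots.waterfall` also groups small contributions into an "other features" bar (controlled by `max_display`), which this sketch omits.

```python
import numpy as np

def waterfall_table(base_value, shap_values, feature_names):
    """List features by |SHAP| (largest first) with the running total after each.

    The last running total equals base_value + sum(shap_values), i.e. the
    model's prediction, because the order of addition does not change the sum."""
    shap_values = np.asarray(shap_values, dtype=float)
    order = np.argsort(-np.abs(shap_values))
    running = float(base_value)
    rows = []
    for i in order:
        running += shap_values[i]
        rows.append((feature_names[i], float(shap_values[i]), float(running)))
    return rows

# Illustrative numbers only, not taken from the ESOL model
rows = waterfall_table(-3.0, [0.5, -1.2, 0.1], ["MolWt", "LogP", "TPSA"])
for name, contribution, total in rows:
    print(f"{name:>6}: {contribution:+.2f} -> {total:.2f}")
```

In the notebook, the equivalent call would be `waterfall_table(base_value, shap_values[sample_idx], feature_cols)`.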

In [None]:
# Force plot - another way to visualize individual prediction
print("\nForce Plot - Visual Breakdown")
print("="*60)
print("Red features = push prediction HIGHER")
print("Blue features = push prediction LOWER")
print("Width = magnitude of effect")
print("="*60)

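# shap.initjs() is only needed for the interactive JavaScript version;
# matplotlib=True draws a static figure instead
shap.force_plot(
    base_value,
    shap_values[sample_idx],
    sample,
    matplotlib=True,
    show=False
)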
plt.tight_layout()
plt.show()

### 4.2 Compare Multiple Predictions

In [None]:
# Compare 3 molecules
# shap.plots.waterfall has no ax argument, so draw one figure per molecule
indices = [0, 10, 20]

for idx in indices:
    sample = X_test.iloc[idx]
    actual = y_test.iloc[idx]
    pred = rf_model.predict(X_test.iloc[[idx]])[0]

    # Wrap this molecule's SHAP values in an Explanation object for plotting
    shap_exp = shap.Explanation(
        values=shap_values[idx],
        base_values=base_value,
        data=sample.values,
        feature_names=feature_cols
    )

    shap.plots.waterfall(shap_exp, show=False)
    plt.title(f'Molecule #{idx} | Actual: {actual:.3f}, Predicted: {pred:.3f}',
              fontweight='bold')
    plt.tight_layout()
    plt.show()

print("\n[INFO] Notice how different features drive predictions for different molecules!")

## Section 5: Feature Dependence

### 5.1 How does LogP affect predictions?

In [None]:
# Dependence plot - shows relationship between feature value and SHAP value
print("Feature Dependence: LogP")
print("="*60)
print("This shows how LogP affects predictions:")
print("  - X-axis: LogP value")
print("  - Y-axis: SHAP value (impact on prediction)")
print("  - Color: Another feature (interaction effect)")
print("="*60)

shap.dependence_plot(
    "LogP",
    shap_values,
    X_test,
    interaction_index="MolWt",
    show=False
)
plt.tight_layout()
plt.show()

print("\n[KEY INSIGHT]")
print("If the plot shows a downward trend:")
print("  -> Higher LogP leads to lower predicted solubility (expected: lipophilic compounds are less water-soluble)")
print("If points with the same LogP separate by color:")
print("  -> The effect of LogP depends on MolWt (interaction)")
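The trend in a dependence plot can also be summarized with a single number, which helps when comparing many features. A minimal sketch, assuming the trend is roughly linear; `dependence_trend` is a hypothetical helper and the synthetic data only stand in for a feature column and its SHAP values.

```python
import numpy as np

def dependence_trend(feature_values, shap_column):
    """Least-squares slope of SHAP value vs. feature value.

    Negative slope: higher feature values push predictions down on average.
    One slope only summarizes a roughly monotonic trend; check the plot too."""
    slope, _intercept = np.polyfit(np.asarray(feature_values, dtype=float),
                                   np.asarray(shap_column, dtype=float), 1)
    return float(slope)

# Synthetic stand-in: SHAP values that fall by 0.6 per unit of the feature
rng = np.random.default_rng(0)
logp = rng.normal(2.0, 1.0, size=200)
shap_logp = -0.6 * (logp - logp.mean()) + rng.normal(0.0, 0.05, size=200)
slope = dependence_trend(logp, shap_logp)
print(f"slope = {slope:.3f}")
```

With the notebook's objects, a call would look like `dependence_trend(X_test["LogP"], shap_values[:, feature_cols.index("LogP")])`.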

In [None]:
# Multiple dependence plots
important_features = ['LogP', 'MolWt', 'TPSA', 'NumAromaticRings']

fig, axes = plt.subplots(2, 2, figsize=(15, 12))
axes = axes.ravel()

for i, feat in enumerate(important_features):
    # Pass ax explicitly; without it dependence_plot opens a new figure
    shap.dependence_plot(
        feat,
        shap_values,
        X_test,
        ax=axes[i],
        show=False
    )
    axes[i].set_title(f'Effect of {feat}', fontweight='bold', fontsize=12)

plt.tight_layout()
plt.show()

## Section 6: Decision Plot - Multiple Molecules

In [None]:
# Decision plot - shows prediction paths for multiple molecules
print("Decision Plot - Prediction Paths")
print("="*60)
print("This shows how predictions are built step by step:")
print("  - Each line = one molecule")
print("  - Lines start at the base value (bottom of the plot)")
print("  - Each feature row shifts the line left or right")
print("  - Lines end at the predicted value (top of the plot)")
print("="*60)

# Select a subset of molecules
subset_idx = np.random.choice(len(X_test), 20, replace=False)

shap.decision_plot(
    base_value,
    shap_values[subset_idx],
    X_test.iloc[subset_idx],
    feature_names=feature_cols,
    show=False
)
plt.tight_layout()
plt.show()
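The lines in a decision plot are cumulative sums of SHAP values. The sketch below computes those paths directly, ordering features from least to most important (mean |SHAP|) so the most important feature is added last, roughly as in the default layout of `shap.decision_plot`. The `decision_paths` helper and the example matrix are hypothetical.

```python
import numpy as np

def decision_paths(base_value, shap_matrix):
    """Cumulative prediction paths for several samples.

    Returns (feature_order, paths): paths[:, 0] is the base value and
    paths[:, -1] is each sample's prediction (base + sum of its SHAP values)."""
    shap_matrix = np.asarray(shap_matrix, dtype=float)
    # Least important first, so the most important feature is the final step
    order = np.argsort(np.abs(shap_matrix).mean(axis=0))
    steps = np.cumsum(shap_matrix[:, order], axis=1)
    start = np.full((shap_matrix.shape[0], 1), float(base_value))
    return order, np.hstack([start, start + steps])

# Illustrative SHAP matrix for 3 molecules x 3 features (not real ESOL values)
example = np.array([[0.2, -1.0, 0.1],
                    [0.4, -0.5, 0.0],
                    [-0.1, 0.8, 0.05]])
order, paths = decision_paths(-3.0, example)
print(order)
print(paths[:, -1])
```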

## Section 7: Practical Applications

### 7.1 Find Unusual Predictions

In [None]:
# Find molecules with large prediction errors
errors = np.abs(y_test.values - y_pred)
worst_idx = np.argmax(errors)

print("Analyzing Worst Prediction")
print("="*60)
print(f"Molecule index: {worst_idx}")
print(f"Actual: {y_test.iloc[worst_idx]:.4f}")
print(f"Predicted: {y_pred[worst_idx]:.4f}")
print(f"Error: {errors[worst_idx]:.4f}")
print("\nWhy did the model get this wrong?")
print("="*60)

# Explain this prediction
sample = X_test.iloc[worst_idx]
shap.plots.waterfall(shap.Explanation(
    values=shap_values[worst_idx],
    base_values=base_value,
    data=sample.values,
    feature_names=feature_cols
), show=False)
plt.title(f'Worst Prediction Analysis (Error: {errors[worst_idx]:.4f})',
          fontweight='bold')
plt.tight_layout()
plt.show()

print("\n[INSIGHT] This helps us understand model limitations!")
print("Possible causes: the descriptors miss a relevant property, or this molecule is unlike the training data.")
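Looking at a single worst case can be misleading, so it may help to scan several of the largest errors at once. This sketch lists the k worst predictions together with the feature that moved each prediction the most (largest |SHAP|). The `worst_predictions` helper and the toy arrays are hypothetical, and the dominant feature is a starting point for investigation, not proof of why the model failed.

```python
import numpy as np

def worst_predictions(y_true, y_pred, shap_matrix, feature_names, k=5):
    """Return (index, abs_error, top_feature, top_shap) for the k largest errors."""
    errors = np.abs(np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float))
    shap_matrix = np.asarray(shap_matrix, dtype=float)
    worst = np.argsort(errors)[::-1][:k]
    report = []
    for i in worst:
        top = int(np.argmax(np.abs(shap_matrix[i])))
        report.append((int(i), float(errors[i]), feature_names[top], float(shap_matrix[i, top])))
    return report

# Toy values for illustration only
y_true = [0.0, 0.0, 0.0, 0.0]
y_pred = [0.1, 2.0, -0.5, 0.0]
toy_shap = [[0.1, 0.0], [0.2, -1.5], [-0.4, 0.1], [0.0, 0.0]]
report = worst_predictions(y_true, y_pred, toy_shap, ["MolWt", "LogP"], k=2)
for row in report:
    print(row)
```

In the notebook, `worst_predictions(y_test, y_pred, shap_values, feature_cols, k=5)` would apply it to the test set.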

### 7.2 Feature Engineering Insights

In [None]:
# Calculate average absolute SHAP values for each feature
mean_abs_shap = np.abs(shap_values).mean(axis=0)

feature_importance = pd.DataFrame({
    'Feature': feature_cols,
    'Importance': mean_abs_shap
}).sort_values('Importance', ascending=False)

print("Feature Importance Ranking")
print("="*60)
print(feature_importance.to_string(index=False))
print("="*60)

# Visualize
plt.figure(figsize=(10, 6))
plt.barh(feature_importance['Feature'], feature_importance['Importance'], color='steelblue')
plt.xlabel('Mean |SHAP Value|', fontsize=12)
plt.title('Feature Importance for Solubility Prediction', fontsize=14, fontweight='bold')
plt.gca().invert_yaxis()
plt.grid(axis='x', alpha=0.3)
plt.tight_layout()
plt.show()

print("\n[RECOMMENDATION]")
print(f"Top 3 features: {', '.join(feature_importance['Feature'].head(3).tolist())}")
print("These drive the model's predictions; treat them as design hypotheses to test, not proven rules.")
print(f"\nLeast important: {feature_importance['Feature'].iloc[-1]}")
print("Before removing it, check whether it is redundant with a correlated feature (e.g., NumAtoms vs MolWt).")
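A low mean |SHAP| does not always mean a feature is useless: when two descriptors carry nearly the same information (molecular weight and atom count, for example), the model can rely on either one, so credit may be split between them or concentrated on one. A quick redundancy check before dropping a feature could look like this sketch; the `correlated_pairs` helper, the 0.9 threshold and the synthetic columns are illustrative choices.

```python
import numpy as np
import pandas as pd

def correlated_pairs(df, threshold=0.9):
    """List feature pairs whose absolute Pearson correlation exceeds threshold,
    strongest first."""
    corr = df.corr().abs()
    cols = corr.columns
    pairs = []
    for i in range(len(cols)):
        for j in range(i + 1, len(cols)):
            if corr.iloc[i, j] > threshold:
                pairs.append((cols[i], cols[j], float(corr.iloc[i, j])))
    return sorted(pairs, key=lambda p: -p[2])

# Synthetic stand-in: two nearly redundant size descriptors and one independent one
rng = np.random.default_rng(0)
size = rng.normal(size=200)
demo = pd.DataFrame({
    "MolWt": 10.0 * size + 100.0,
    "NumAtoms": size + rng.normal(scale=0.05, size=200),
    "TPSA": rng.normal(size=200),
})
pairs = correlated_pairs(demo)
print(pairs)
```

On the real data, `correlated_pairs(X_train)` would run the same check on the training descriptors.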

## Section 8: Summary & Best Practices

### What We Learned:

1. **SHAP Values**
   - Explain how features contribute to predictions
   - Based on solid game theory (Shapley values)
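   - Model-agnostic explainers (e.g., `KernelExplainer`) work with any model; specialized ones (`TreeExplainer`, `DeepExplainer`) are much faster for their model family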

2. **Global Interpretation**
   - Summary plots show overall feature importance
   - Dependence plots show feature effects
   - Helps understand model behavior

3. **Local Interpretation**
   - Waterfall plots explain individual predictions
   - Force plots show feature contributions visually
   - Builds trust in specific decisions

4. **Practical Applications**
   - Debug bad predictions
   - Identify important features
   - Guide feature engineering
   - Discover new insights

### Best Practices:

1. **Always Interpret Important Decisions**
   - High-stakes predictions (drug approval, safety)
   - Unexpected results
   - Regulatory requirements

2. **Check for Sensibility**
   - Do the important features make scientific sense?
   - Are there spurious correlations?
   - Is the model learning the right chemistry?

3. **Use Multiple Views**
   - Global: Summary plots, feature importance
   - Local: Waterfall, force plots
   - Dependence: How features interact

4. **Combine with Domain Knowledge**
   - SHAP explains what the model learned (associations in the data), not causal effects
   - Validate insights with chemistry knowledge
   - Use to generate hypotheses, then test experimentally

### When to Use SHAP:

✅ **Good for:**
- Understanding model predictions
- Debugging models
- Building trust with stakeholders
- Feature selection
- Scientific discovery

❌ **Not sufficient for:**
- Proving causation (use experiments)
- Replacing domain expertise
- Legal compliance alone (consult regulations)

### Connection to Other Notebooks:

| Notebook | Connection |
|----------|------------|
| **02_Regression** | Applied SHAP to regression models |
| **03_Classification** | Can apply same techniques to classifiers |
| **04_PyTorch** | SHAP works with neural networks too |
| **06_CV** | Next: Visualize what CNNs learn |

### Resources:

- [SHAP Documentation](https://shap.readthedocs.io/)
- [Interpretable ML Book](https://christophm.github.io/interpretable-ml-book/)
- [SHAP Paper](https://arxiv.org/abs/1705.07874)

---

**Key Takeaway**: A model you can explain is a model you can trust and improve!