# 06 - SHAP Explainability Analysis
## Osteoporosis Risk Prediction Model
**DSGP Group 40** | Student: Isum Gamage (ID: 20242052)

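This notebook generates SHAP explainability visualizations for model interpretability.

**Prerequisites:** the trained model files `osteoporosis_male_model.pkl` and `osteoporosis_female_model.pkl` (produced by `04_Model_Training.ipynb`) must be in the working directory. The notebook also expects the test features (`X_test_male`, `X_test_female`) and test predictions (`y_pred_male`, `y_pred_female`) to be defined in the kernel; it does not load them itself.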


## Step 1: Install and Import Libraries

In [None]:
# Install SHAP (compatible with latest XGBoost)
!pip install shap --upgrade

import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import shap
import joblib
import xgboost as xgb

print(f"✓ SHAP version: {shap.__version__}")
print(f"✓ XGBoost version: {xgb.__version__}")
print("✓ All libraries imported successfully!")

## Step 2: Load Models and Data

In [None]:
print("Step 2: Loading Models and Test Data")
print("=" * 60)

# NOTE: joblib.load unpickles the file, which can execute arbitrary code.
# Only load model files that were produced by your own training notebook.
try:
    male_model = joblib.load('osteoporosis_male_model.pkl')
    female_model = joblib.load('osteoporosis_female_model.pkl')
except FileNotFoundError as e:
    print(f"ERROR: {e}")
    print("Please run 04_Model_Training.ipynb first.")
    # Stop here: every later step needs both models, so continuing would only
    # produce harder-to-read NameErrors further down.
    raise

print("✓ Male model loaded")
print("✓ Female model loaded")

# Show the concrete model classes so the reader can confirm they are XGBoost.
print(f"\nModel types:")
print(f"  Male model: {type(male_model).__name__}")
print(f"  Female model: {type(female_model).__name__}")

# Later steps read these names, but no cell in this notebook creates them.
# Warn early so a fresh-kernel run fails with an actionable message.
_required_inputs = ('X_test_male', 'X_test_female', 'y_pred_male', 'y_pred_female')
_missing_inputs = [name for name in _required_inputs if name not in globals()]
if _missing_inputs:
    print(f"\nWARNING: not defined in this kernel: {', '.join(_missing_inputs)}")
    print("Load or recreate the test split and predictions before running Step 4.")

## Step 3: Create SHAP Explainers

In [None]:
print("\nStep 3: Creating SHAP Explainers")
print("=" * 60)

# Build one TreeExplainer per cohort model (the models are expected to be
# XGBoost tree ensembles, see Step 2).
try:
    male_explainer = shap.TreeExplainer(male_model)
    print("✓ Male model SHAP explainer created")

    female_explainer = shap.TreeExplainer(female_model)
    print("✓ Female model SHAP explainer created")

except Exception as e:
    print(f"ERROR creating explainers: {e}")
    print("\nTroubleshooting:")
    print("- Ensure XGBoost and SHAP versions are compatible")
    print("- Try updating: pip install --upgrade xgboost shap")
    # Re-raise after printing the hints: without explainers no later step can
    # run, so the notebook should stop at the real cause.
    raise

## Step 4: Calculate SHAP Values

In [None]:
print("\nStep 4: Calculating SHAP Values")
print("=" * 60)


def select_positive_class(shap_output):
    """Reduce an explainer's output to a single (samples, features) array.

    Depending on the model and SHAP version, ``shap_values`` can return:
      * a 2-D array (single model output)            -> returned unchanged
      * a list with one array per class (legacy API) -> element 1 is used
      * a 3-D array (samples, features, classes)     -> class slice 1 is used
        (slice 0 if there is only one class axis entry)
    Index 1 is the positive class for binary classification.
    """
    if isinstance(shap_output, list):
        return shap_output[1]
    if getattr(shap_output, "ndim", 2) == 3:
        class_idx = 1 if shap_output.shape[2] > 1 else 0
        return shap_output[:, :, class_idx]
    return shap_output


try:
    # Calculate SHAP values for both test sets
    print("Computing SHAP values for male cohort...")
    male_shap_values = select_positive_class(male_explainer.shap_values(X_test_male))
    print(f"✓ Male SHAP values shape: {male_shap_values.shape}")

    print("\nComputing SHAP values for female cohort...")
    female_shap_values = select_positive_class(female_explainer.shap_values(X_test_female))
    print(f"✓ Female SHAP values shape: {female_shap_values.shape}")

except Exception as e:
    print(f"ERROR calculating SHAP values: {e}")
    # Re-raise so the notebook stops here with the full traceback instead of
    # failing later on undefined SHAP arrays.
    raise

## Step 5: Feature Importance (Mean Absolute SHAP)

In [None]:
print("\nStep 5: Feature Importance Analysis")
print("=" * 60)

# Global importance of a feature = average magnitude of its SHAP values
# across all rows of the cohort's test set.
male_importance = np.mean(np.abs(male_shap_values), axis=0)
female_importance = np.mean(np.abs(female_shap_values), axis=0)

# One ranking table per cohort, most important feature first.
male_feature_importance = pd.DataFrame(
    {'Feature': X_test_male.columns, 'Importance': male_importance}
).sort_values(by='Importance', ascending=False)

female_feature_importance = pd.DataFrame(
    {'Feature': X_test_female.columns, 'Importance': female_importance}
).sort_values(by='Importance', ascending=False)

for heading, ranking in (
    ("\nTOP 10 FEATURES - MALE MODEL:", male_feature_importance),
    ("\nTOP 10 FEATURES - FEMALE MODEL:", female_feature_importance),
):
    print(heading)
    print(ranking.head(10).to_string(index=False))

## Step 6: SHAP Summary Plots

In [None]:
print("\nStep 6: SHAP Summary Plots")
print("=" * 60)

fig, axes = plt.subplots(1, 2, figsize=(16, 8))

# (axis, SHAP matrix, feature frame, panel title) for each cohort.
summary_panels = (
    (axes[0], male_shap_values, X_test_male,
     'Male Model - Feature Importance (Mean |SHAP|)'),
    (axes[1], female_shap_values, X_test_female,
     'Female Model - Feature Importance (Mean |SHAP|)'),
)

for ax, shap_matrix, features, title in summary_panels:
    # summary_plot draws on the *current* axes, so select the panel first.
    plt.sca(ax)
    # plot_size=None keeps the figure size chosen above; the default "auto"
    # resizes the whole figure for a single plot and squeezes the two panels.
    shap.summary_plot(shap_matrix, features, plot_type="bar",
                      plot_size=None, show=False)
    ax.set_title(title, fontweight='bold', fontsize=12)

plt.tight_layout()
fig.savefig('shap_feature_importance.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Feature importance plots saved!")

## Step 7: SHAP Waterfall Plot (Individual Predictions)

In [None]:
print("\nStep 7: SHAP Waterfall Plots (Individual Predictions)")
print("=" * 60)

WATERFALL_FIGSIZE = (16, 12)


def first_index_with_label(predictions, label):
    """Return the position of the first prediction equal to ``label``.

    Falls back to position 0 when no prediction has that label, so a panel is
    always drawn (the panel title will then not match the case's actual class).
    """
    matches = np.where(np.asarray(predictions) == label)[0]
    return int(matches[0]) if matches.size else 0


def draw_waterfall(ax, explainer, shap_matrix, features, row, title):
    """Draw the SHAP waterfall for test row ``row`` onto ``ax``.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    explainer : SHAP explainer whose ``expected_value`` is the base value.
    shap_matrix : (samples, features) SHAP values for ``features``.
    features : DataFrame holding the rows that were explained.
    row : positional index of the row to explain.
    title : panel title.
    """
    # expected_value may be a scalar or hold one value per class; take index 1
    # in the latter case to match the positive-class SHAP values from Step 4.
    base = np.atleast_1d(explainer.expected_value).ravel()
    base_value = float(base[1] if base.size > 1 else base[0])

    explanation = shap.Explanation(
        values=shap_matrix[row],
        base_values=base_value,
        data=features.iloc[row],
        feature_names=list(features.columns),
    )
    # The waterfall plot draws on the current axes, so select the panel first.
    plt.sca(ax)
    shap.waterfall_plot(explanation, show=False)
    ax.set_title(title, fontweight='bold')


# Select first high-risk and first low-risk cases from each cohort
high_risk_male_idx = first_index_with_label(y_pred_male, 1)
low_risk_male_idx = first_index_with_label(y_pred_male, 0)
high_risk_female_idx = first_index_with_label(y_pred_female, 1)
low_risk_female_idx = first_index_with_label(y_pred_female, 0)

fig, axes = plt.subplots(2, 2, figsize=WATERFALL_FIGSIZE)

waterfall_panels = (
    (axes[0, 0], male_explainer, male_shap_values, X_test_male, high_risk_male_idx,
     f'Male Model - High Risk Case (Index {high_risk_male_idx})'),
    (axes[0, 1], male_explainer, male_shap_values, X_test_male, low_risk_male_idx,
     f'Male Model - Low Risk Case (Index {low_risk_male_idx})'),
    (axes[1, 0], female_explainer, female_shap_values, X_test_female, high_risk_female_idx,
     f'Female Model - High Risk Case (Index {high_risk_female_idx})'),
    (axes[1, 1], female_explainer, female_shap_values, X_test_female, low_risk_female_idx,
     f'Female Model - Low Risk Case (Index {low_risk_female_idx})'),
)
for panel in waterfall_panels:
    draw_waterfall(*panel)

# The waterfall plot resizes the current figure for a single plot on every
# call; restore the multi-panel size before laying out and saving.
fig.set_size_inches(*WATERFALL_FIGSIZE)
plt.tight_layout()
fig.savefig('shap_waterfall_plots.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Waterfall plots saved!")

## Step 8: SHAP Dependence Plots (Top Features)

In [None]:
print("\nStep 8: SHAP Dependence Plots")
print("=" * 60)

# Get top 4 features
top_features_male = male_feature_importance.head(4)['Feature'].tolist()
top_features_female = female_feature_importance.head(4)['Feature'].tolist()


def plot_dependence_grid(top_features, shap_matrix, features, suptitle, out_path):
    """Draw a 2x2 grid of SHAP dependence plots and save it to ``out_path``.

    Parameters
    ----------
    top_features : up to four feature names, one per panel (row-major order).
    shap_matrix : (samples, features) SHAP values for ``features``.
    features : DataFrame holding the rows that were explained.
    suptitle : figure-level title.
    out_path : file name of the saved PNG.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle(suptitle, fontweight='bold', fontsize=14)
    panels = axes.ravel()

    for ax, feature in zip(panels, top_features):
        plt.sca(ax)
        try:
            shap.dependence_plot(feature, shap_matrix, features, ax=ax, show=False)
        except Exception as e:
            # Keep the remaining panels, but report why this one failed.
            print(f"WARNING: could not plot dependence for {feature}: {e}")
            # Axes coordinates keep the message centred whatever the data range.
            ax.text(0.5, 0.5, f'Error plotting {feature}',
                    ha='center', va='center', transform=ax.transAxes)

    # Hide panels that have no feature (fewer than four features available).
    for ax in panels[len(top_features):]:
        ax.set_visible(False)

    plt.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()


plot_dependence_grid(
    top_features_male, male_shap_values, X_test_male,
    'Male Model - SHAP Dependence Plots (Top 4 Features)',
    'shap_dependence_male.png',
)
plot_dependence_grid(
    top_features_female, female_shap_values, X_test_female,
    'Female Model - SHAP Dependence Plots (Top 4 Features)',
    'shap_dependence_female.png',
)

print("✓ Dependence plots saved!")

## Step 9: Summary Statistics

In [None]:
print("\nStep 9: SHAP Summary Statistics")
print("=" * 60)

# (report heading, explainer, test features, importance ranking) per cohort.
cohort_reports = (
    ("MALE", male_explainer, X_test_male, male_feature_importance),
    ("FEMALE", female_explainer, X_test_female, female_feature_importance),
)

for cohort_label, explainer, features, importance_table in cohort_reports:
    # expected_value may be a scalar or hold one value per class; formatting an
    # array with ".4f" raises TypeError, so reduce it to one float first.
    # Index 1 matches the positive class selected for the SHAP values in Step 4.
    base_values = np.atleast_1d(explainer.expected_value).ravel()
    base_value = float(base_values[1] if base_values.size > 1 else base_values[0])

    print("\n" + "="*60)
    print(f"{cohort_label} MODEL SHAP ANALYSIS")
    print("="*60)
    print(f"\nBase value (expected model output): {base_value:.4f}")
    print(f"Number of samples analyzed: {len(features)}")
    print(f"Number of features: {len(features.columns)}")

    print("\nTop 5 Most Important Features:")
    print(importance_table.head(5).to_string(index=False))

print("\n" + "="*60)
print("✓ SHAP Analysis Complete!")
print("="*60)

## Summary

✅ **SHAP Explainability Analysis Complete!**

### Generated Visualizations:
1. **Feature Importance Plots** - Mean |SHAP| values for each feature
2. **Waterfall Plots** - Individual prediction explanations (high-risk and low-risk cases)
3. **Dependence Plots** - Feature-SHAP relationship for top predictors

### Key Insights:
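- **Male Model**: Top predictors are ranked by mean |SHAP| in Step 5 (top-10 table) and shown in `shap_feature_importance.png`
- **Female Model**: Top predictors are ranked by mean |SHAP| in Step 5 (top-10 table) and shown in `shap_feature_importance.png`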
- **Interpretability**: Each prediction can be explained by feature contributions

### Clinical Application:
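- SHAP values enable clinicians to understand:
  - Which factors drive high-risk predictions
  - How each patient's unique profile contributes to their risk score
  - Which modifiable factors could reduce risk
- Note: SHAP values describe how the *model* uses each feature; they are not causal effects, so any suggested change to a modifiable factor should be checked against clinical evidence.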

**All notebooks completed! Models are ready for production deployment.** ✓