# 🔍 Model Interpretation & Explainability with Google Drive

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kstawiski/OmicSelector2/blob/main/examples/06_model_interpretation.ipynb)

**What You'll Learn:**
- 📊 Comprehensive model evaluation metrics
- 🎯 SHAP values for feature importance
- 📈 ROC curves, PR curves, calibration plots
- 🧠 Feature interactions and dependencies
- 💾 Save all interpretation results to Drive

**Estimated Time**: 25-30 minutes  
**Prerequisites**: Basic familiarity with training machine-learning models (e.g., with scikit-learn)

---

In [None]:
# Setup
!pip install -q git+https://github.com/kstawiski/OmicSelector2.git
!pip install -q shap lime

from google.colab import drive
import os

drive.mount('/content/drive', force_remount=False)
BASE_DIR = '/content/drive/MyDrive/OmicSelector2'
os.makedirs(f'{BASE_DIR}/results/interpretation', exist_ok=True)
os.makedirs(f'{BASE_DIR}/plots/interpretation', exist_ok=True)

print(f'✅ Drive mounted: {BASE_DIR}')

In [None]:
# Core numerics, plotting, SHAP and scikit-learn helpers used by every
# later cell in this notebook.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import shap
# calibration_curve is provided by sklearn.calibration; importing it from
# sklearn.metrics raises ImportError and aborts the whole cell.
from sklearn.calibration import calibration_curve
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    roc_auc_score, roc_curve, precision_recall_curve,
    confusion_matrix, classification_report,
)
import pickle

# Notebook-wide plotting defaults.
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)

print('✅ Libraries loaded!')

## 📊 Load Data and Train Model

In [None]:
# Generate a synthetic binary-classification dataset with gene-style column
# names, hold out 20% of it (stratified by label), and fit a Random Forest.
np.random.seed(42)
raw_matrix, y = make_classification(
    n_samples=500,
    n_features=50,
    n_informative=20,
    n_redundant=10,
    random_state=42,
)
gene_names = [f'GENE_{idx:04d}' for idx in range(raw_matrix.shape[1])]
X = pd.DataFrame(raw_matrix, columns=gene_names)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Fit the classifier on the training split only.
print('🤖 Training Random Forest...')
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Hard labels plus the predicted probability of class 1 on the held-out
# split; the probability drives the AUC and all curve plots later on.
y_pred = model.predict(X_test)
class_probabilities = model.predict_proba(X_test)
y_pred_proba = class_probabilities[:, 1]
auc = roc_auc_score(y_test, y_pred_proba)

print(f'✅ Model trained! Test AUC: {auc:.3f}')

## 📈 Comprehensive Evaluation Metrics

In [None]:
# Print a text summary of test-set performance, then draw a 2x3 grid of
# diagnostic plots, save the figure to Drive, and display it.
print('📊 Model Evaluation\n' + '=' * 60 + '\n')

class_names = ['Non-responder', 'Responder']

print('Classification Report:')
print(classification_report(y_test, y_pred, target_names=class_names))

cm = confusion_matrix(y_test, y_pred)

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
(ax_cm, ax_roc, ax_pr), (ax_hist, ax_cal, ax_imp) = axes

# Panel 1: confusion matrix (rows = true label, columns = predicted label).
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    ax=ax_cm,
    xticklabels=class_names,
    yticklabels=class_names,
)
ax_cm.set(ylabel='True Label', xlabel='Predicted Label')
ax_cm.set_title('Confusion Matrix', fontweight='bold')

# Panel 2: ROC curve with the chance diagonal for reference.
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
ax_roc.plot(fpr, tpr, linewidth=2, label=f'AUC = {auc:.3f}')
ax_roc.plot([0, 1], [0, 1], 'k--', linewidth=1)
ax_roc.set(xlabel='False Positive Rate', ylabel='True Positive Rate')
ax_roc.set_title('ROC Curve', fontweight='bold')
ax_roc.legend()

# Panel 3: precision-recall curve.
precision, recall, _ = precision_recall_curve(y_test, y_pred_proba)
ax_pr.plot(recall, precision, linewidth=2)
ax_pr.set(xlabel='Recall', ylabel='Precision')
ax_pr.set_title('Precision-Recall Curve', fontweight='bold')

# Panel 4: predicted-probability histograms, one per true class.
for class_label in (0, 1):
    ax_hist.hist(
        y_pred_proba[y_test == class_label],
        bins=30,
        alpha=0.6,
        label=f'Class {class_label}',
        edgecolor='black',
    )
ax_hist.set(xlabel='Predicted Probability', ylabel='Frequency')
ax_hist.set_title('Prediction Distribution', fontweight='bold')
ax_hist.legend()

# Panel 5: calibration curve against the perfect-calibration diagonal.
prob_true, prob_pred = calibration_curve(y_test, y_pred_proba, n_bins=10)
ax_cal.plot(prob_pred, prob_true, marker='o', linewidth=2, label='Model')
ax_cal.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Perfect')
ax_cal.set(xlabel='Predicted Probability', ylabel='True Probability')
ax_cal.set_title('Calibration Curve', fontweight='bold')
ax_cal.legend()

# The three curve panels share the same light grid.
for curve_ax in (ax_roc, ax_pr, ax_cal):
    curve_ax.grid(True, alpha=0.3)

# Panel 6: the 20 largest Random Forest feature importances, biggest on top.
importance_table = pd.DataFrame({
    'feature': X_train.columns,
    'importance': model.feature_importances_,
})
importances = importance_table.sort_values('importance', ascending=False).head(20)

bar_positions = range(len(importances))
ax_imp.barh(bar_positions, importances['importance'], color='steelblue')
ax_imp.set_yticks(bar_positions)
ax_imp.set_yticklabels(importances['feature'])
ax_imp.invert_yaxis()
ax_imp.set_xlabel('Importance')
ax_imp.set_title('Top 20 Feature Importances', fontweight='bold')

# Save before show(): the figure is written to Drive first, then displayed.
plt.tight_layout()
eval_path = f'{BASE_DIR}/plots/interpretation/comprehensive_evaluation.png'
plt.savefig(eval_path, dpi=300, bbox_inches='tight')
print(f'\n💾 Evaluation plots saved to: {eval_path}')
plt.show()

## 🎯 SHAP Values for Feature Importance

In [None]:
# Compute SHAP values for the held-out samples, plot global summaries
# (bar and beeswarm), and store the per-sample values as CSV on Drive.
print('🎯 Computing SHAP values...')

# Create SHAP explainer
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)

# For binary classification, keep only the contributions towards class 1.
# Depending on the installed SHAP version the classifier output is either
# a list with one (samples, features) array per class, or a single 3-D
# array shaped (samples, features, classes). Both are reduced to a 2-D
# (samples, features) matrix so the plots and the DataFrame below work.
if isinstance(shap_values, list):
    shap_values = shap_values[1]
else:
    shap_values = np.asarray(shap_values)
    if shap_values.ndim == 3:
        shap_values = shap_values[:, :, 1]

print('✅ SHAP values computed!\n')

# Summary plot: mean absolute SHAP value per feature (top 20).
plt.figure(figsize=(12, 8))
shap.summary_plot(shap_values, X_test, plot_type='bar', show=False, max_display=20)
plt.tight_layout()
shap_bar_path = f'{BASE_DIR}/plots/interpretation/shap_summary_bar.png'
plt.savefig(shap_bar_path, dpi=300, bbox_inches='tight')
print(f'💾 SHAP bar plot saved to: {shap_bar_path}')
plt.show()

# Detailed beeswarm plot: one point per sample and feature (top 20).
plt.figure(figsize=(12, 8))
shap.summary_plot(shap_values, X_test, show=False, max_display=20)
plt.tight_layout()
shap_bee_path = f'{BASE_DIR}/plots/interpretation/shap_summary_beeswarm.png'
plt.savefig(shap_bee_path, dpi=300, bbox_inches='tight')
print(f'💾 SHAP beeswarm plot saved to: {shap_bee_path}')
plt.show()

# Save SHAP values: one row per test sample, one column per feature.
shap_df = pd.DataFrame(shap_values, columns=X_test.columns)
shap_path = f'{BASE_DIR}/results/interpretation/shap_values.csv'
shap_df.to_csv(shap_path, index=False)
print(f'💾 SHAP values saved to: {shap_path}')

## 📝 Create Interpretation Report

In [None]:
# Assemble a plain-text interpretation report from the objects computed in
# the earlier cells (model, metrics, confusion matrix, calibration curve,
# SHAP values), print it, and write it to Drive.
import os
import pickle
from datetime import datetime
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Persist the trained model so that the "Model" location listed in the
# report really contains a file; no earlier cell writes the model to Drive.
models_dir = f'{BASE_DIR}/models'
os.makedirs(models_dir, exist_ok=True)
model_path = f'{models_dir}/random_forest_model.pkl'
with open(model_path, 'wb') as f:
    pickle.dump(model, f)
print(f'💾 Model saved to: {model_path}')

# Rank features by mean absolute SHAP value so the "(SHAP)" section of the
# report reflects SHAP rather than the Random Forest impurity importances.
shap_ranking = shap_df.abs().mean(axis=0).sort_values(ascending=False)
top_shap_features = '\n'.join(
    f'{rank:2d}. {feat}'
    for rank, feat in enumerate(shap_ranking.head(10).index, start=1)
)

# Qualitative AUC label used in the key-insights section.
if auc > 0.9:
    auc_rating = 'excellent'
elif auc > 0.8:
    auc_rating = 'good'
else:
    auc_rating = 'moderate'

# Judge calibration on the mean absolute gap over all bins of the
# calibration curve instead of only the last (highest-probability) bin.
calibration_gap = float(abs(prob_true - prob_pred).mean())
calibration_note = 'is well-calibrated' if calibration_gap < 0.1 else 'may need calibration'

report = f"""
═══════════════════════════════════════════════════════════════
    MODEL INTERPRETATION REPORT - OmicSelector2
═══════════════════════════════════════════════════════════════

Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
Location: {BASE_DIR}

MODEL INFORMATION
───────────────────────────────────────────────────────────────
Model Type:           Random Forest Classifier
Number of Trees:      {model.n_estimators}
Features Used:        {X_train.shape[1]}
Training Samples:     {X_train.shape[0]}
Test Samples:         {X_test.shape[0]}

PERFORMANCE METRICS
───────────────────────────────────────────────────────────────
Accuracy:             {accuracy_score(y_test, y_pred):.4f}
Precision:            {precision_score(y_test, y_pred):.4f}
Recall:               {recall_score(y_test, y_pred):.4f}
F1 Score:             {f1_score(y_test, y_pred):.4f}
AUC-ROC:              {auc:.4f}

TOP 10 IMPORTANT FEATURES (SHAP)
───────────────────────────────────────────────────────────────
{top_shap_features}

CONFUSION MATRIX
───────────────────────────────────────────────────────────────
                 Predicted
                 Neg    Pos
Actual Neg      {cm[0, 0]:4d}   {cm[0, 1]:4d}
       Pos      {cm[1, 0]:4d}   {cm[1, 1]:4d}

FILES SAVED TO GOOGLE DRIVE
───────────────────────────────────────────────────────────────
Model:                {BASE_DIR}/models/
SHAP Values:          {BASE_DIR}/results/interpretation/
Plots:                {BASE_DIR}/plots/interpretation/

KEY INSIGHTS
───────────────────────────────────────────────────────────────
✓ Model shows {auc_rating} discriminative ability (AUC = {auc:.3f})
✓ SHAP analysis reveals feature importance and interactions
✓ All interpretation results saved to Google Drive
✓ Model {calibration_note}

═══════════════════════════════════════════════════════════════
"""

print(report)

# The report contains box-drawing and check-mark characters, so the file
# encoding is fixed to UTF-8 instead of relying on the platform default.
report_path = f'{BASE_DIR}/results/interpretation/interpretation_report.txt'
with open(report_path, 'w', encoding='utf-8') as f:
    f.write(report)

print(f'\n💾 Report saved to: {report_path}')
print(f'\n✅ All interpretation results saved to Google Drive!')
print(f'📂 Access at: {BASE_DIR}')

## 🎓 Summary

### What You've Learned

✅ **Comprehensive Evaluation**: ROC, PR, calibration curves  
✅ **SHAP Values**: Feature importance and interactions  
✅ **Model Diagnostics**: Confusion matrix, prediction distributions  
✅ **Drive Integration**: All results saved automatically  

### Why Model Interpretation Matters

1. **Trust**: Understand how models make predictions
2. **Debugging**: Identify potential issues
3. **Biological Insights**: Discover meaningful patterns
4. **Clinical Translation**: Explain to stakeholders
5. **Regulatory**: Regulators such as the FDA and EMA increasingly expect transparency and explainability for clinical ML models

### 📚 Next Steps

- **[07_complete_workflow.ipynb](07_complete_workflow.ipynb)** - Full end-to-end pipeline

---

**Explainable AI = trustworthy biomarkers! 🧬🔬**