# Model Explainability with SHAP

This notebook demonstrates model interpretability using SHAP (SHapley Additive exPlanations) values.

SHAP helps us understand which metabolites contribute most to the model's predictions.


In [None]:
# Numerical, data-frame and plotting stack used throughout the notebook.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import sys
# Add ../src to the import path. The project modules imported in later cells
# (preprocessing, features, models.baseline) presumably live there -- verify
# against the repository layout.
sys.path.append('../src')

import shap
import warnings
# Silence every warning category for the rest of the session to keep cell
# output readable. NOTE(review): this also hides deprecation and convergence
# warnings from the libraries used below.
warnings.filterwarnings('ignore')

# Default look and size for all figures produced in this notebook.
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

# Load SHAP's JavaScript assets so interactive plots (e.g. force plots) can
# render inside the notebook. This call does not affect warnings.
shap.initjs()


In [None]:
# Load data and preprocess (same as training)
#
# Produces: X_train / X_test / y_train / y_test for the model cell below and
# `selected_metab_cols`, the column labels used when reporting SHAP values.
from sklearn.model_selection import train_test_split

from features import FeatureSelector
from models.baseline import BaselineModels
from preprocessing import MetabolomicsPreprocessor

data_path = Path('../data/synthetic/synthetic_urine_metabolomics.csv')
df = pd.read_csv(data_path)

# Binary target: 1 for any non-control diagnosis, 0 for control.
y = (df['diagnosis_label'] != 'control').astype(int).values

preprocessor = MetabolomicsPreprocessor(
    imputation_method='knn',
    normalization_method='log2',
    batch_correction=True,
    scale_method='zscore',
)

# NOTE(review): preprocessing and the supervised feature selection below are
# fitted on the full dataset *before* the train/test split, so the held-out
# rows influence both steps. That is acceptable for producing explanations,
# but any performance number computed on X_test would be optimistic.
X = preprocessor.fit_transform(df)

feature_selector = FeatureSelector(
    method='univariate',
    n_features=min(200, X.shape[1]),
    variance_threshold=0.01,
)

X_selected = feature_selector.fit_transform(X, y)
n_selected = X_selected.shape[1]

# Get feature names.
# Assumes `selected_features` (when present) is a per-metabolite boolean mask
# in the same order as the `metab_*` columns -- TODO confirm against
# FeatureSelector. Because that cannot be verified here, the recovered names
# are only used when their count matches the number of selected columns;
# otherwise generic names are used so labels are never silently misaligned.
metab_cols = [col for col in df.columns if col.startswith('metab_')]
selection_mask = getattr(feature_selector, 'selected_features', None)

selected_metab_cols = None
if selection_mask is not None and len(selection_mask) >= len(metab_cols):
    candidate_names = [
        name for name, keep in zip(metab_cols, selection_mask) if keep
    ]
    if len(candidate_names) == n_selected:
        selected_metab_cols = candidate_names

if selected_metab_cols is None:
    if selection_mask is not None:
        # print() rather than warnings.warn(): warnings are globally silenced
        # in the setup cell.
        print(
            "Warning: could not align selected_features with the metab_* "
            "columns; falling back to generic feature names."
        )
    selected_metab_cols = [f"feature_{i}" for i in range(n_selected)]

X_train, X_test, y_train, y_test = train_test_split(
    X_selected, y, test_size=0.2, random_state=42, stratify=y
)

print(f"Training set: {X_train.shape}")
print(f"Test set: {X_test.shape}")
print(f"Selected features: {len(selected_metab_cols)}")


In [None]:
# Train a model for explanation (Random Forest)
baseline = BaselineModels()
baseline.train_random_forest(X_train, y_train, n_estimators=100, max_depth=10)

model = baseline.models['random_forest']
print("Model trained for explanation")

# Create SHAP explainer
print("Creating SHAP explainer...")
explainer = shap.TreeExplainer(model)

# Calculate SHAP values (use subset for speed).
# A seeded generator keeps the explained subset -- and therefore the feature
# ranking below -- reproducible across runs, matching random_state=42 used
# for the train/test split.
n_samples = min(100, len(X_test))
rng = np.random.default_rng(42)
sample_indices = rng.choice(len(X_test), size=n_samples, replace=False)
X_explain = X_test[sample_indices]

print(f"Calculating SHAP values for {n_samples} samples...")
shap_values = explainer.shap_values(X_explain)

# For binary classification, use positive class.
# Older SHAP releases return a list with one (samples, features) array per
# class; newer releases return a single (samples, features, classes) array.
# Both are reduced to the 2-D positive-class matrix here, otherwise the
# per-feature mean below would come out 2-D and the ranking would break.
if isinstance(shap_values, list):
    shap_values = shap_values[1]
elif np.ndim(shap_values) == 3:
    shap_values = shap_values[:, :, 1]

print(f"SHAP values shape: {shap_values.shape}")

# Get top features: rank by mean |SHAP| over the explained samples and keep
# the 20 largest (fewer if fewer features exist), in descending order.
mean_abs_shap = np.abs(shap_values).mean(axis=0)
top_indices = np.argsort(mean_abs_shap)[-20:][::-1]

top_features = pd.DataFrame({
    'feature': [selected_metab_cols[i] for i in top_indices],
    'mean_abs_shap': mean_abs_shap[top_indices],
})

print("\nTop 20 Most Important Features:")
print(top_features)
