# Day 38: Model Explainability with SHAP

In this lab, we will use **SHAP (SHapley Additive exPlanations)** to explain the predictions of a machine learning model.
We will see which features contributed positively or negatively to a specific prediction.

In [None]:
import sys
import os
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

# Add root directory to sys.path so the project-local `src` package imported below can be found.
# The relative path is resolved against the current working directory, so this
# presumably expects the kernel to start in the notebook's own folder, two levels
# below the repository root — TODO confirm.
sys.path.append(os.path.abspath('../../'))

from src.observability.shap_wrapper import ShapExplainer

## 1. Train Model

We use the classic Iris dataset, restricted to its first two classes so that the task is a simple binary classification.

In [None]:
# Load Iris and keep only the samples whose label is 0 or 1, turning the task
# into binary classification (optional, SHAP handles multi-class too).
data = load_iris()
mask = data.target < 2
X = pd.DataFrame(data.data, columns=data.feature_names)[mask]
y = data.target[mask]

# Hold out 20% of the filtered samples for evaluation; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Small, seeded forest: fit on the training split, then report accuracy on the held-out split.
model = RandomForestClassifier(n_estimators=10, random_state=42)
model.fit(X_train, y_train)

accuracy = model.score(X_test, y_test)
print("Model Accuracy:", accuracy)

## 2. Explain Predictions

We initialize our `ShapExplainer` and explain a test instance.

In [None]:
# Build the explainer from the trained model and a small background sample.
# A handful of rows keeps things fast if the wrapper uses a KernelExplainer
# underneath (presumably it does — verify in src.observability.shap_wrapper).
background = X_train.head(10)

explainer = ShapExplainer()
explainer.fit(model, background)

# Pick the first test row, kept as a one-row DataFrame so column names are preserved.
instance = X_test.head(1)
print("Explaining instance:")
print(instance)

shap_values = explainer.explain_local(instance)

print("SHAP Values (Feature contributions):")
# The return type may vary with the SHAP version: either a single array, or a
# list holding one array per class ([values_for_class_0, values_for_class_1]).
if not isinstance(shap_values, list):
    print(shap_values)
else:
    class_labels = ("Class 0 contribution:", "Class 1 contribution:")
    for class_index, label in enumerate(class_labels):
        print(label, shap_values[class_index])

## 3. Visualize

We compute SHAP values for a batch of up to 20 test instances and draw a summary plot of feature importance.

In [None]:
# Explain multiple instances for the summary plot.
# Slice once so the SHAP values and the feature matrix always cover the same rows.
summary_instances = X_test.iloc[:20]
shap_values_batch = explainer.explain_local(summary_instances)

# Reduce a multi-output result to the positive class (index 1) before plotting.
# Two layouts are handled:
#   * a list with one 2-D array per class (older SHAP releases);
#   * a single 3-D array with a trailing class axis (SHAP >= 0.45 layout).
# NOTE(review): the actual return type depends on ShapExplainer.explain_local,
# which is not visible here — confirm the trailing axis is the class axis.
if isinstance(shap_values_batch, list):
    positive_class_values = shap_values_batch[1]
elif (
    isinstance(shap_values_batch, np.ndarray)
    and shap_values_batch.ndim == 3
    and shap_values_batch.shape[-1] > 1
):
    positive_class_values = shap_values_batch[:, :, 1]
else:
    # Already a single (samples, features) matrix or another type: pass through unchanged.
    positive_class_values = shap_values_batch

explainer.plot_summary(positive_class_values, summary_instances)