# SHAPXplain: Example Usage

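This notebook demonstrates how to use the SHAPXplain package to turn SHAP values into LLM-generated, natural-language explanations. It trains a random forest on the Iris dataset, computes SHAP values for it, and asks an LLM to explain a single prediction.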

In [1]:
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from shap import TreeExplainer
from shapxplain import ShapLLMExplainer
from pydantic_ai import Agent
import nest_asyncio
nest_asyncio.apply()  # Lets pydantic-ai run inside Jupyter's already-running event loop

In [2]:
# Load data and train model
data = load_iris()
X, y = data.data, data.target
model = RandomForestClassifier(random_state=42)
model.fit(X, y)

# Generate SHAP values
explainer = TreeExplainer(model)
shap_values = explainer.shap_values(X)

In [3]:
# Create an LLM agent (the OpenAI backend expects an OPENAI_API_KEY environment variable)
llm_agent = Agent(model="openai:gpt-4o")

# Instantiate the SHAPXplain explainer
llm_explainer = ShapLLMExplainer(
    model=model,
    llm_agent=llm_agent,
    feature_names=data.feature_names,
    significance_threshold=0.1
)

In [4]:
# Select the data point to explain (index 0)
data_point = X[0]
# Get the class probabilities and the predicted class for that point
prediction_probs = model.predict_proba(data_point.reshape(1, -1))[0]
predicted_class_idx = model.predict(data_point.reshape(1, -1))[0]
prediction_class = data.target_names[predicted_class_idx]

In [5]:
# For multi-class problems, recent SHAP releases return a single 3D array of shape
# (instances, features, classes); older releases return a list with one 2D array per class.
# The indexing below assumes the 3D layout.
data_point_index = 0  # Index of the data point to explain (the same row as `data_point`)
predicted_class_idx = np.argmax(prediction_probs)  # Same class index that model.predict returned above

# Extract SHAP values for the data point and the predicted class
class_shap_values = shap_values[data_point_index, :, predicted_class_idx]

# Verify shapes
print("Class SHAP values shape:", len(class_shap_values))  # Should match `data_point`
print("Data point shape:", len(data_point))  # Should match `class_shap_values`

# Ensure the dimensions match
assert len(class_shap_values) == len(data_point), "SHAP values and data point dimensions do not match!"



Class SHAP values shape: 4
Data point shape: 4
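
The indexing in the cell above depends on the layout SHAP returns, which has differed between releases. As a minimal sketch, a small helper can accept either layout. The name `select_class_shap` is hypothetical and not part of SHAP or SHAPXplain. It assumes older releases return a list of per-class `(instances, features)` arrays and newer ones a single `(instances, features, classes)` array. The demo data is synthetic, not real SHAP output.

```python
import numpy as np


def select_class_shap(shap_values, row_idx, class_idx):
    """Return the 1D SHAP vector for one row and one class.

    Hypothetical helper: accepts a list of per-class 2D arrays
    or a single 3D array.
    """
    if isinstance(shap_values, list):
        # List layout: one (instances, features) array per class
        return np.asarray(shap_values[class_idx])[row_idx]
    values = np.asarray(shap_values)
    if values.ndim == 3:
        # Array layout: (instances, features, classes)
        return values[row_idx, :, class_idx]
    if values.ndim == 2:
        # Single-output case: (instances, features), no class axis
        return values[row_idx]
    raise ValueError(f"Unexpected SHAP values shape: {values.shape}")


# Synthetic stand-in for SHAP output: 5 rows, 4 features, 3 classes
demo_3d = np.arange(60, dtype=float).reshape(5, 4, 3)
demo_list = [demo_3d[:, :, k] for k in range(3)]

from_array = select_class_shap(demo_3d, row_idx=0, class_idx=2)
from_list = select_class_shap(demo_list, row_idx=0, class_idx=2)
```

In this notebook, `select_class_shap(shap_values, data_point_index, predicted_class_idx)` would stand in for the direct indexing.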


In [6]:
explanation = llm_explainer.explain(
    shap_values=class_shap_values,  # SHAP values for the predicted class
    data_point=data_point,
    prediction=prediction_probs[predicted_class_idx],
    prediction_class=prediction_class,
    additional_context={
        "dataset": "Iris",
        "feature_descriptions": {
            "sepal length": "Length of the sepal in cm",
            "sepal width": "Width of the sepal in cm",
            "petal length": "Length of the petal in cm",
            "petal width": "Width of the petal in cm"
        }
    }
)

print("Summary:", explanation.summary)
print("\nDetailed Explanation:", explanation.detailed_explanation)
print("\nRecommendations:", explanation.recommendations)
print("\nConfidence Level:", explanation.confidence_level)

Summary: The prediction that the class is 'setosa' is primarily driven by the petal length and petal width, which have measurements typical of this species.

Detailed Explanation: In the Iris dataset, setosa is characterized by smaller petal dimensions compared to other species. The measurements of the petal length (1.4 cm) and width (0.2 cm) are consistent with the species setosa. Additionally, the sepal dimensions also support this classification as they align with commonly observed sepal sizes for setosa (5.1 cm length and 3.5 cm width), although they play a lesser role compared to petal dimensions.

Recommendations: ['To validate this prediction, cross-reference the predicted class with other botanical characteristics unique to setosa.', 'Consider measuring additional features or using a different set of features if there is a need to improve classification accuracy.']

Confidence Level: high
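
Because the explanation text is generated by an LLM, it can help to compare it against the raw attributions. The sketch below ranks features by absolute SHAP value and flags those at or above a cutoff. `rank_features` is a hypothetical helper, and the numbers are illustrative placeholders rather than values computed in this notebook. Treating the cutoff as analogous to `significance_threshold` is an assumption about how SHAPXplain uses that parameter.

```python
def rank_features(feature_names, shap_row, threshold=0.1):
    """Sort features by absolute SHAP value, largest first.

    Returns (name, value, is_significant) tuples; `threshold` is an
    illustrative cutoff applied to |value|.
    """
    pairs = zip(feature_names, (float(v) for v in shap_row))
    ranked = sorted(pairs, key=lambda item: abs(item[1]), reverse=True)
    return [(name, value, abs(value) >= threshold) for name, value in ranked]


# Placeholder attributions, NOT values computed in this notebook
names = [
    "sepal length (cm)",
    "sepal width (cm)",
    "petal length (cm)",
    "petal width (cm)",
]
example_row = [0.02, -0.01, 0.31, 0.29]

ranking = rank_features(names, example_row)
for name, value, significant in ranking:
    print(f"{name:<20} {value:+.2f} significant={significant}")
```

With the real values, the call would be `rank_features(data.feature_names, class_shap_values)`. If the summary highlights features that rank low here, the explanation deserves a closer look.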
