# SHAPXplain: Example Usage

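This notebook demonstrates how to use the SHAPXplain package to turn SHAP values for a scikit-learn model into natural-language explanations generated by an LLM, using a random forest classifier trained on the Iris dataset.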

In [1]:
import numpy as np 
import os
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from shap import TreeExplainer
from shapxplain import ShapLLMExplainer
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from openai import OpenAI
import nest_asyncio
nest_asyncio.apply()  # Allow nested event loops so pydantic-ai's calls work inside Jupyter's already-running loop

In [2]:
# Load data and train model
data = load_iris()
X, y = data.data, data.target
rf_model = RandomForestClassifier(random_state=42)
rf_model.fit(X, y)

# Generate SHAP values
explainer = TreeExplainer(rf_model)
shap_values = explainer.shap_values(X)
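
The indexing later in this notebook assumes `shap_values` is a single 3D NumPy array. Recent shap releases return that layout for multi-class tree models, while older releases return a list of per-class 2D arrays. The helper below, `to_3d_shap_array`, is a hypothetical convenience (not part of shap or SHAPXplain) that stacks the list form into the 3D form. It is a sketch that assumes only those two layouts.

```python
import numpy as np

def to_3d_shap_array(shap_values):
    """Hypothetical helper: return multi-class SHAP values as (instances, features, classes).

    Assumes `shap_values` is either a list with one (instances, features) array per
    class (older shap releases) or an already-stacked 3D array (newer releases).
    """
    if isinstance(shap_values, list):
        # Stack the per-class arrays along a new trailing "classes" axis.
        return np.stack(shap_values, axis=-1)
    arr = np.asarray(shap_values)
    if arr.ndim != 3:
        raise ValueError(f"expected a 3D array or a list of 2D arrays, got shape {arr.shape}")
    return arr

# Synthetic example: 5 instances, 4 features, 3 classes (class c filled with the value c).
per_class = [np.full((5, 4), float(c)) for c in range(3)]
stacked = to_3d_shap_array(per_class)
print(stacked.shape)     # (5, 4, 3)
print(stacked[0, :, 2])  # [2. 2. 2. 2.]
```

In this notebook it would be applied as `shap_values = to_3d_shap_array(explainer.shap_values(X))`, after which `shap_values[i, :, k]` selects data point `i` and class `k` under either layout.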

In [3]:
# Create an LLM agent; pydantic-ai reads the OPENAI_API_KEY environment variable if it is set
llm_agent = Agent(model="openai:gpt-4o")

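# Alternative: use DeepSeek through its OpenAI-compatible API (see https://ai.pydantic.dev/models/#deepseek)
#
# # Read the DeepSeek API key from the DEEPSEEK_API_KEY environment variable
# deepseek_api_key = os.getenv('DEEPSEEK_API_KEY')
# # Point the OpenAI-compatible model at DeepSeek's endpoint so the key and base URL are actually used;
# # the exact constructor arguments depend on your pydantic-ai version (newer releases use a provider object)
# deepseek_model = OpenAIModel(
#     model_name="deepseek-reasoner",
#     base_url="https://api.deepseek.com",
#     api_key=deepseek_api_key,
# )
# llm_agent = Agent(model=deepseek_model)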


# Instantiate the SHAPXplain explainer
llm_explainer = ShapLLMExplainer(
    model=rf_model,
    llm_agent=llm_agent,
    feature_names=data.feature_names,
    significance_threshold=0.1
)
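
The notebook does not show how `significance_threshold=0.1` is applied inside `ShapLLMExplainer`. One plausible reading, assumed here and not confirmed against the SHAPXplain source, is that features whose absolute SHAP value falls below the threshold are treated as minor. The hypothetical `split_by_significance` function below illustrates that reading with plain NumPy.

```python
import numpy as np

def split_by_significance(shap_values, feature_names, threshold=0.1):
    """Hypothetical: split features into significant and minor by |SHAP| >= threshold.

    This illustrates an *assumed* meaning of ShapLLMExplainer's significance_threshold;
    it is not taken from the SHAPXplain source.
    """
    shap_values = np.asarray(shap_values, dtype=float)
    # Visit features from the largest to the smallest absolute contribution.
    order = np.argsort(-np.abs(shap_values))
    significant, minor = [], []
    for i in order:
        entry = (feature_names[i], float(shap_values[i]))
        if abs(shap_values[i]) >= threshold:
            significant.append(entry)
        else:
            minor.append(entry)
    return significant, minor

# Synthetic values and feature names, not the Iris SHAP values.
sig, minor = split_by_significance([0.02, -0.15, 0.30, 0.05], ["f0", "f1", "f2", "f3"])
print(sig)    # [('f2', 0.3), ('f1', -0.15)]
print(minor)  # [('f3', 0.05), ('f0', 0.02)]
```

Using the absolute value means strongly negative contributions (evidence against a class) count as significant just like strongly positive ones.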

In [4]:
# Select the data point to explain (index 0) and get the model's prediction for it
data_point = X[0]
# scikit-learn expects a 2D array, so reshape the single sample to (1, n_features)
prediction_probs = rf_model.predict_proba(data_point.reshape(1, -1))[0]
predicted_class_idx = rf_model.predict(data_point.reshape(1, -1))[0]
prediction_class = data.target_names[predicted_class_idx]

In [5]:
# For multi-class models, recent shap releases return a 3D array (instances, features, classes);
# older releases return a list with one (instances, features) array per class instead
data_point_index = 0  # Index of the data point to explain (matches X[0] above)
predicted_class_idx = np.argmax(prediction_probs)  # Same class as rf_model.predict above, taken from the probabilities

# Extract the SHAP values of this data point for the predicted class (assumes the 3D layout)
class_shap_values = shap_values[data_point_index, :, predicted_class_idx]

# Verify that there is exactly one SHAP value per feature
print("Class SHAP values length:", len(class_shap_values))  # Should match `data_point`
print("Data point length:", len(data_point))  # Should match `class_shap_values`

# Ensure the dimensions match
assert len(class_shap_values) == len(data_point), "SHAP values and data point dimensions do not match!"

Class SHAP values length: 4
Data point length: 4
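
A stronger sanity check than matching lengths is SHAP's local-accuracy property: for a given class, the explainer's base value plus the sum of a data point's SHAP values should approximately reproduce the model output for that class, which for a scikit-learn random forest explained with `TreeExplainer` defaults is the predicted probability. The NumPy-only helper below (`check_additivity` is an illustrative name, not a shap function) performs that comparison.

```python
import numpy as np

def check_additivity(base_value, shap_values, model_output, atol=1e-4):
    """Return True if base_value + sum(shap_values) matches model_output within atol."""
    reconstructed = float(base_value) + float(np.sum(shap_values))
    return bool(np.isclose(reconstructed, model_output, atol=atol))

# Synthetic numbers, not the Iris results: a base value of 0.33 plus contributions summing to 0.67.
contributions = np.array([0.10, 0.02, 0.30, 0.25])
print(check_additivity(0.33, contributions, 1.0))  # True
print(check_additivity(0.33, contributions, 0.9))  # False
```

Assuming `explainer.expected_value` holds one base value per class, the call in this notebook would be `check_additivity(explainer.expected_value[predicted_class_idx], class_shap_values, prediction_probs[predicted_class_idx])`.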


In [6]:
explanation = llm_explainer.explain(
    shap_values=class_shap_values,  # SHAP values for the predicted class
    data_point=data_point,
    prediction=prediction_probs[predicted_class_idx],
    prediction_class=prediction_class,
    additional_context={
        "dataset": "Iris",
        # Keys match data.feature_names exactly so each description lines up with its feature
        "feature_descriptions": {
            "sepal length (cm)": "Length of the sepal in cm",
            "sepal width (cm)": "Width of the sepal in cm",
            "petal length (cm)": "Length of the petal in cm",
            "petal width (cm)": "Width of the petal in cm"
        }
    }
)

print("Summary:", explanation.summary)
print("\nDetailed Explanation:", explanation.detailed_explanation)
print("\nRecommendations:", explanation.recommendations)
print("\nConfidence Level:", explanation.confidence_level)

Summary: The prediction that the flower is of the species 'setosa' is primarily influenced by the petal length and width, with significant contributions also from the sepal length.

Detailed Explanation: In determining that the flower is a 'setosa', the prediction is largely influenced by the physical dimensions of the petals, where both the length and width distinctly align with common characteristics of the setosa species. The combination of relatively short and narrow petals is crucial here, as these are definitive traits of setosa flowers. Additionally, the comparatively modest sepal length further supports this classification. Together, these factors account for the unique physical appearance typical of the setosa species.

Recommendations: ['Ensure data is accurately measured and recorded to maintain prediction reliability.', 'Consider gathering more samples of setosa to reinforce the strength of classification features.']

Confidence Level: high
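
The returned explanation exposes four fields: `summary`, `detailed_explanation`, `recommendations` (a list of strings) and `confidence_level`. If you want a single report string instead of four prints, a small formatter like the hypothetical `format_explanation` below works with any object carrying those attribute names; it assumes nothing else about SHAPXplain's result type.

```python
from types import SimpleNamespace

def format_explanation(explanation, prediction_class):
    """Hypothetical helper: render the four explanation fields as a plain-text report."""
    lines = [
        f"Predicted class: {prediction_class} (confidence: {explanation.confidence_level})",
        "",
        f"Summary: {explanation.summary}",
        "",
        explanation.detailed_explanation,
        "",
        "Recommendations:",
    ]
    # One indented bullet per recommendation string.
    lines.extend(f"  - {rec}" for rec in explanation.recommendations)
    return "\n".join(lines)

# Stand-in object with the same attribute names; the text is placeholder, not model output.
demo = SimpleNamespace(
    summary="Short summary.",
    detailed_explanation="Longer explanation.",
    recommendations=["First recommendation.", "Second recommendation."],
    confidence_level="high",
)
print(format_explanation(demo, "setosa"))
```

In this notebook the call would be `print(format_explanation(explanation, prediction_class))`.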
