# SHAPXplain: Example Usage

This notebook demonstrates how to use the SHAPXplain package to integrate SHAP explanations with LLMs, including both synchronous and asynchronous approaches.

## LLM Setup and Environment Variables

SHAPXplain uses [pydantic-ai](https://github.com/pydantic/pydantic-ai) to interact with LLMs. To use your preferred LLM, you'll need to set the appropriate environment variables:

- For OpenAI: `OPENAI_API_KEY`
- For Anthropic: `ANTHROPIC_API_KEY`
- For other providers, refer to the [pydantic-ai documentation](https://ai.pydantic.dev/)

You can set these directly in your environment or use a `.env` file in your project root. Below is an example of how to set them programmatically:

In [1]:
# Uncomment to set environment variables programmatically
# import os
# os.environ["OPENAI_API_KEY"] = "your-api-key-here"  # Replace with your actual API key

# Or load from .env file
# from dotenv import load_dotenv
# load_dotenv()  # This will load variables from .env file in the current directory

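Before creating an agent it can help to confirm that the key for the chosen provider is actually present. The following is a minimal sketch; `PROVIDER_ENV_VARS`, `required_env_var`, and `check_api_key` are hypothetical helpers written for this notebook, not part of pydantic-ai or SHAPXplain, and the mapping only covers the two providers listed above.

```python
import os
from typing import Mapping, Optional

# Hypothetical mapping from provider prefix to the environment variable it needs.
PROVIDER_ENV_VARS = {
    "openai": "OPENAI_API_KEY",
    "anthropic": "ANTHROPIC_API_KEY",
}


def required_env_var(model_name: str) -> Optional[str]:
    """Return the env var name for a 'provider:model' string, or None if unknown."""
    provider = model_name.split(":", 1)[0]
    return PROVIDER_ENV_VARS.get(provider)


def check_api_key(model_name: str, env: Optional[Mapping[str, str]] = None) -> bool:
    """Return True if the key for this provider is present and non-empty."""
    env = os.environ if env is None else env
    var = required_env_var(model_name)
    return var is not None and bool(env.get(var))


if not check_api_key("openai:gpt-4o"):
    print("OPENAI_API_KEY is not set; set it or load it from a .env file first")
```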

In [2]:
import numpy as np
import asyncio
import time
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from shap import TreeExplainer
from shapxplain import ShapLLMExplainer
from pydantic_ai import Agent
import nest_asyncio

# This is needed for running async code in Jupyter notebooks
nest_asyncio.apply()

print("✅ All imports successful")

✅ All imports successful


## 1. Preparing the Data and Model

First, let's load the Iris dataset, train a random forest classifier, and generate SHAP values.

In [3]:
# Load data and train model
data = load_iris()
X, y = data.data, data.target
rf_model = RandomForestClassifier(random_state=42)
rf_model.fit(X, y)

# Generate SHAP values
explainer = TreeExplainer(rf_model)
shap_values = explainer.shap_values(X)

print(f"✅ Model trained on {len(X)} samples")
print(f"✅ SHAP values generated with shape: {np.array(shap_values).shape}")
print(f"✅ Iris classes: {data.target_names}")

✅ Model trained on 150 samples
✅ SHAP values generated with shape: (150, 4, 3)
✅ Iris classes: ['setosa' 'versicolor' 'virginica']
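The `(150, 4, 3)` shape reads as (samples, features, classes). Some older SHAP releases instead return a list with one `(samples, features)` array per class, which would print as `(3, 150, 4)` here. The sketch below normalizes both layouts to the one this notebook indexes into; `to_sample_feature_class` is a hypothetical helper, and whether you need it depends on your installed SHAP version.

```python
import numpy as np


def to_sample_feature_class(shap_values):
    """Return SHAP values as an array shaped (n_samples, n_features, n_classes)."""
    if isinstance(shap_values, list):
        # Older layout: one (n_samples, n_features) array per class.
        return np.stack(shap_values, axis=-1)
    return np.asarray(shap_values)


# Stand-in data with the Iris dimensions: 150 samples, 4 features, 3 classes.
old_style = [np.zeros((150, 4)) for _ in range(3)]
new_style = np.zeros((150, 4, 3))

print(to_sample_feature_class(old_style).shape)
print(to_sample_feature_class(new_style).shape)
```

With the values in this layout, `shap_values[i, :, c]` selects the four feature contributions for sample `i` and class `c`.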


## 2. Creating the ShapLLMExplainer

Now we'll create a `ShapLLMExplainer`, configuring its retry logic and query cache along the way.

In [4]:
# Create an LLM agent; pydantic-ai reads the API key from the environment if it is set
llm_agent = Agent(
    model="openai:gpt-4o"
)  # Can also use other models like "anthropic:claude-3-opus-20240229"

print(f"✅ Created LLM agent with model: {llm_agent.model}")

# Instantiate the enhanced SHAPXplain explainer with new parameters
llm_explainer = ShapLLMExplainer(
    model=rf_model,
    llm_agent=llm_agent,
    feature_names=data.feature_names,
    significance_threshold=0.1,
    max_retries=3,  # New: Number of retries for failed API calls
    retry_delay=1.0,  # New: Base delay between retries in seconds
    cache_size=1000,  # New: Size of the LRU cache for LLM queries
)

print(f"✅ ShapLLMExplainer created with features: {data.feature_names}")
print(
    f"✅ Error handling: {llm_explainer.max_retries} retries with {llm_explainer.retry_delay}s base delay"
)

✅ Created LLM agent with model: OpenAIModel(model_name='gpt-4o')
✅ ShapLLMExplainer created with features: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
✅ Error handling: 3 retries with 1.0s base delay
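To make the `max_retries` and `retry_delay` parameters concrete, here is one common way a "base delay" is used: exponential backoff, where each wait doubles. This is a sketch under that assumption; `call_with_retries` is a hypothetical function, and SHAPXplain's internal retry schedule may differ.

```python
import time


def call_with_retries(func, max_retries=3, retry_delay=1.0, sleep=time.sleep):
    """Call func(); on failure wait retry_delay * 2**attempt and try again.

    Makes one initial attempt plus up to max_retries retries.
    """
    last_error = None
    for attempt in range(max_retries + 1):
        try:
            return func()
        except Exception as exc:  # A real client would catch narrower errors.
            last_error = exc
            if attempt < max_retries:
                sleep(retry_delay * (2 ** attempt))
    raise RuntimeError(f"Failed after {max_retries} retries") from last_error
```

Passing `sleep` as a parameter keeps the function easy to test, since a test can record the delays instead of actually waiting.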


## 3. Preparing a Single Data Point for Explanation

We'll select the first sample, show the model's prediction for it, and extract the SHAP values for the predicted class.

In [5]:
# Select a data point (index 0)
data_point_index = 0
data_point = X[data_point_index]

# Print selected data point as a table
print("\n🔹 Selected Data Point:\n")
print(f"{'Feature':<20}{'Value':<10}")
print("-" * 30)
for feature, value in zip(data.feature_names, data_point):
    print(f"{feature:<20}{value:<10.2f}")

# Get predictions
prediction_probs = rf_model.predict_proba(data_point.reshape(1, -1))[0]
predicted_class_idx = rf_model.predict(data_point.reshape(1, -1))[0]
prediction_class = data.target_names[predicted_class_idx]

# Print prediction results
print(f"\n✅ Predicted class: {prediction_class} ({prediction_probs[predicted_class_idx]:.2%} confidence)")
print("📊 Class probabilities:")
for i, prob in enumerate(prediction_probs):
    print(f"  - {data.target_names[i]}: {prob:.2%}")

# Extract SHAP values for a specific class and data point
class_shap_values = shap_values[data_point_index, :, predicted_class_idx]

# Print SHAP values as a table without pandas
print(f"\n🔍 SHAP Values for Predicted Class ({prediction_class}):\n")
print(f"{'Feature':<20}{'SHAP Value':<15}")
print("-" * 35)
for feature, shap_value in zip(data.feature_names, class_shap_values):
    print(f"{feature:<20}{shap_value:<15.4f}")



🔹 Selected Data Point:

Feature             Value     
------------------------------
sepal length (cm)   5.10      
sepal width (cm)    3.50      
petal length (cm)   1.40      
petal width (cm)    0.20      

✅ Predicted class: setosa (100.00% confidence)
📊 Class probabilities:
  - setosa: 100.00%
  - versicolor: 0.00%
  - virginica: 0.00%

🔍 SHAP Values for Predicted Class (setosa):

Feature             SHAP Value     
-----------------------------------
sepal length (cm)   0.0732         
sepal width (cm)    0.0025         
petal length (cm)   0.2974         
petal width (cm)    0.2936         
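SHAP values are additive: the base value plus the per-feature contributions should reproduce the model output for that class. The sketch below checks this with the four values printed above. The base value of 1/3 is an assumption (the three Iris classes are balanced); the exact figure would come from `explainer.expected_value`.

```python
# SHAP values for the predicted class (setosa), copied from the table above.
shap_for_setosa = {
    "sepal length (cm)": 0.0732,
    "sepal width (cm)": 0.0025,
    "petal length (cm)": 0.2974,
    "petal width (cm)": 0.2936,
}

# Assumption: the base value for setosa is about 1/3 because classes are balanced.
base_value = 1 / 3

total_contribution = sum(shap_for_setosa.values())
reconstructed = base_value + total_contribution

print(f"Sum of SHAP values: {total_contribution:.4f}")
print(f"Base + contributions: {reconstructed:.4f}")

# Share of the total contribution carried by each feature.
for feature, value in shap_for_setosa.items():
    print(f"{feature:<20}{value / total_contribution:.1%}")
```

The contributions sum to roughly 0.667, which together with the assumed base value lands close to the 100% setosa probability reported by the model. The two petal features carry most of that total.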


## 4. Using Data Contracts for Enhanced Explanations

One of the most powerful features of SHAPXplain is the ability to provide domain-specific context through `additional_context`, effectively creating a "data contract" that guides the LLM. 

This data contract provides the LLM with:
- Domain-specific terminology and context
- Feature descriptions and normal ranges
- Target class characteristics
- Units of measurement and application context

This helps generate more accurate, relevant, and actionable explanations.

In [6]:
# Create a comprehensive data contract with domain context
iris_context = {
    "dataset": "Iris",
    "domain": "botany",
    "feature_descriptions": {
        "sepal length": "Length of the sepal in cm. Ranges from 4.3 to 7.9 cm across species.",
        "sepal width": "Width of the sepal in cm. Ranges from 2.0 to 4.4 cm across species.",
        "petal length": "Length of the petal in cm. Ranges from 1.0 to 6.9 cm across species.",
        "petal width": "Width of the petal in cm. Ranges from 0.1 to 2.5 cm across species.",
    },
    "species_characteristics": {
        "setosa": "Characterized by small petals, both in length and width. Sepals tend to be wider.",
        "versicolor": "Has medium-sized petals and sepals.",
        "virginica": "Typically has the largest petals and longer sepals.",
    },
    "measurement_units": "centimeters",
    "application": "Species classification for botanical research",
}

print("📝 Data contract created with the following elements:")
for key, value in iris_context.items():
    if isinstance(value, dict):
        print(f"  - {key}: {len(value)} entries")
    else:
        print(f"  - {key}: {value}")

📝 Data contract created with the following elements:
  - dataset: Iris
  - domain: botany
  - feature_descriptions: 4 entries
  - species_characteristics: 3 entries
  - measurement_units: centimeters
  - application: Species classification for botanical research
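Note that the keys in `feature_descriptions` omit the `(cm)` suffix that the model's feature names carry. How SHAPXplain relates these keys to feature names is not shown in this notebook, so a quick coverage check can be reassuring. The sketch below uses a hypothetical helper, `undescribed_features`, that matches by prefix.

```python
feature_names = [
    "sepal length (cm)",
    "sepal width (cm)",
    "petal length (cm)",
    "petal width (cm)",
]
described_keys = ["sepal length", "sepal width", "petal length", "petal width"]


def undescribed_features(feature_names, described_keys):
    """Return feature names that no description key matches by prefix."""
    return [
        name
        for name in feature_names
        if not any(name.startswith(key) for key in described_keys)
    ]


missing = undescribed_features(feature_names, described_keys)
print("All features described" if not missing else f"Missing: {missing}")
```

In the notebook itself you could pass `data.feature_names` and `iris_context["feature_descriptions"]` instead of the literal lists.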


In [7]:
# Generate explanation using the data contract
print("⏳ Generating explanation using the data contract and LLM...")
t0 = time.time()

explanation = llm_explainer.explain(
    shap_values=class_shap_values,
    data_point=data_point,
    prediction=prediction_probs[predicted_class_idx],
    prediction_class=prediction_class,
    additional_context=iris_context,
)

elapsed = time.time() - t0
print(f"✅ Explanation generated in {elapsed:.2f} seconds")

# Access the explanation components
print("\n📋 Explanation Summary:")
print(explanation.summary)

print("\n📋 Detailed Explanation:")
print(explanation.detailed_explanation)

print("\n📋 Recommendations:")
for i, rec in enumerate(explanation.recommendations, 1):
    print(f"  {i}. {rec}")

print("\n📋 Confidence Level:")
print(explanation.confidence_level)

print("\n📋 Feature Interactions:")
if explanation.feature_interactions:
    for interaction, desc in explanation.feature_interactions.items():
        print(f"  - {interaction}: {desc}")
else:
    print("  No significant feature interactions identified")

⏳ Generating explanation using the data contract and LLM...
✅ Explanation generated in 8.73 seconds

📋 Explanation Summary:
The prediction that the flower is setosa is primarily driven by the lengths and widths of its petals and sepals, which are characteristic of this species.

📋 Detailed Explanation:
In this case, the flower's petals are short and narrow, which is typical for the setosa species. Additionally, the sepal length and width contribute to this classification, with wider sepals being another indicator of setosa. The combination of small petal size and specific sepal dimensions helps distinguish setosa from other species, which usually have larger petals.

📋 Recommendations:
  1. Ensure future analyses include detailed measurements of petal and sepal dimensions as they are crucial for accurate species classification.
  2. Consider using these identified characteristic features to design educational materials or tools that help in the identification of setosa in field researc
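The same five fields (`summary`, `detailed_explanation`, `recommendations`, `confidence_level`, `feature_interactions`) are printed again in later cells, so a small formatting helper can cut the repetition. This is a sketch: `format_explanation` is a hypothetical helper, and the `SimpleNamespace` object is a stand-in with made-up values rather than a real SHAPXplain response.

```python
from types import SimpleNamespace


def format_explanation(explanation) -> str:
    """Render the explanation fields used in this notebook as plain text."""
    lines = [
        "Summary:",
        explanation.summary,
        "",
        "Detailed explanation:",
        explanation.detailed_explanation,
        "",
        "Recommendations:",
    ]
    for i, rec in enumerate(explanation.recommendations, 1):
        lines.append(f"  {i}. {rec}")
    lines += ["", f"Confidence level: {explanation.confidence_level}", ""]
    lines.append("Feature interactions:")
    if explanation.feature_interactions:
        for name, desc in explanation.feature_interactions.items():
            lines.append(f"  - {name}: {desc}")
    else:
        lines.append("  No significant feature interactions identified")
    return "\n".join(lines)


# Stand-in object with made-up values; a real response comes from the explainer.
demo = SimpleNamespace(
    summary="Short petals point to setosa.",
    detailed_explanation="Petal length and width carry most of the weight.",
    recommendations=["Measure petals carefully."],
    confidence_level="high",
    feature_interactions={},
)
print(format_explanation(demo))
```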

## 5. Asynchronous Explanation Generation

Now let's demonstrate SHAPXplain's asynchronous API, which is particularly useful in web applications and services where you don't want to block the event loop while waiting for the LLM response.

In [8]:
# Define an async function to get an explanation
async def get_async_explanation():
    print("⏳ Generating explanation asynchronously...")
    t0 = time.time()

    # Using the same data point and context as before
    async_explanation = await llm_explainer.explain_async(
        shap_values=class_shap_values,
        data_point=data_point,
        prediction=prediction_probs[predicted_class_idx],
        prediction_class=prediction_class,
        additional_context=iris_context,
    )

    elapsed = time.time() - t0
    print(f"✅ Async explanation generated in {elapsed:.2f} seconds")
    return async_explanation


# Run the async function
async_explanation = asyncio.run(get_async_explanation())

# Access the explanation components
print("\n📋 Explanation Summary:")
print(async_explanation.summary)

print("\n📋 Detailed Explanation:")
print(async_explanation.detailed_explanation)

print("\n📋 Recommendations:")
for i, rec in enumerate(async_explanation.recommendations, 1):
    print(f"  {i}. {rec}")

print("\n📋 Confidence Level:")
print(async_explanation.confidence_level)

print("\n📋 Feature Interactions:")
if async_explanation.feature_interactions:
    for interaction, desc in async_explanation.feature_interactions.items():
        print(f"  - {interaction}: {desc}")
else:
    print("  No significant feature interactions identified")

⏳ Generating explanation asynchronously...
✅ Async explanation generated in 8.44 seconds

📋 Explanation Summary:
The prediction that the flower is of the species 'setosa' is mainly driven by its small petal length and width, which are highly characteristic of this species.

📋 Detailed Explanation:
The flower being analyzed has certain physical traits that match well with what is typically seen in the 'setosa' species. The petals are quite short and narrow, aligning with the known characteristics of 'setosa', which is recognized for having smaller petals. Additionally, the sepal measurements also support this classification; while the sepal length is relatively average, the width is on the broader side, another trait typical of 'setosa'. These combined factors strongly suggest that the flower belongs to the 'setosa' species, which is commonly distinguished by these smaller, distinctively sized features.

📋 Recommendations:
  1. Ensure the database of known 'setosa' features is kept upda
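Since a single explanation takes several seconds here, a service will usually want an upper bound on how long it waits. The sketch below wraps any awaitable in `asyncio.wait_for`; `with_timeout` and `slow_call` are hypothetical names, and `slow_call` merely stands in for an LLM request.

```python
import asyncio


async def with_timeout(coro, timeout_seconds: float, fallback=None):
    """Await coro, returning fallback if it takes longer than timeout_seconds."""
    try:
        return await asyncio.wait_for(coro, timeout=timeout_seconds)
    except asyncio.TimeoutError:
        return fallback


async def slow_call(delay: float) -> str:
    """Stand-in for an LLM request; not part of SHAPXplain."""
    await asyncio.sleep(delay)
    return "explanation"


print(asyncio.run(with_timeout(slow_call(0.01), timeout_seconds=1.0)))
print(asyncio.run(with_timeout(slow_call(0.5), timeout_seconds=0.01, fallback="timed out")))
```

In this notebook the same wrapper could be applied to the coroutine returned by `llm_explainer.explain_async(...)`, with a timeout chosen to suit your provider's typical latency.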

## 6. Batch Processing - Synchronous vs Asynchronous

Let's compare synchronous and asynchronous batch processing. Each explanation spends most of its time waiting on the LLM, so the asynchronous version can issue several requests concurrently and finish a large batch noticeably sooner.

In [9]:
# Prepare a small batch of data for demonstration
batch_size = 5
batch_indices = range(batch_size)

# Prepare the batch data, running the model once over the whole batch
batch_X = X[:batch_size]
batch_probs = rf_model.predict_proba(batch_X)
batch_class_idx = rf_model.predict(batch_X)

data_points = [X[i] for i in batch_indices]
# Probability of the predicted class for each sample
predictions = [batch_probs[i][batch_class_idx[i]] for i in batch_indices]
prediction_classes = [data.target_names[batch_class_idx[i]] for i in batch_indices]
# SHAP values for the predicted class of each sample
shap_values_batch = [
    shap_values[i, :, batch_class_idx[i]] for i in batch_indices
]

print(f"✅ Prepared batch data for {batch_size} samples")
print(f"📊 Sample classes: {prediction_classes}")

✅ Prepared batch data for 5 samples
📊 Sample classes: [np.str_('setosa'), np.str_('setosa'), np.str_('setosa'), np.str_('setosa'), np.str_('setosa')]


In [10]:
# 6.1 Synchronous batch processing
print("⏳ Starting synchronous batch processing...")
start_time = time.time()

sync_batch_response = llm_explainer.explain_batch(
    shap_values_batch=shap_values_batch,
    data_points=data_points,
    predictions=predictions,
    prediction_classes=prediction_classes,
    additional_context=iris_context,
)

sync_time = time.time() - start_time
print(f"✅ Synchronous batch processing completed in {sync_time:.2f} seconds")
print(f"📊 Generated {len(sync_batch_response.responses)} explanations")
print(f"📝 First explanation summary: {sync_batch_response.responses[0].summary}")

⏳ Starting synchronous batch processing...
✅ Synchronous batch processing completed in 38.40 seconds
📊 Generated 5 explanations
📝 First explanation summary: The prediction that the flower is setosa is primarily driven by the lengths and widths of its petals and sepals, which are characteristic of this species.
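The first summary here matches the one from section 4 word for word, which would be consistent with the query being served from the LRU cache configured via `cache_size`, though the notebook does not confirm this. The sketch below shows the general idea with `functools.lru_cache`; `cached_query` is a hypothetical stand-in and not SHAPXplain's actual caching code.

```python
from functools import lru_cache

call_count = 0


@lru_cache(maxsize=1000)
def cached_query(prompt: str) -> str:
    """Stand-in for an LLM call; counts how often the 'model' is really hit."""
    global call_count
    call_count += 1
    return f"response to: {prompt}"


cached_query("explain sample 0")
cached_query("explain sample 0")  # identical prompt, served from the cache
cached_query("explain sample 1")

print(f"Real calls: {call_count}")
print(cached_query.cache_info())
```

A cache like this only helps when the prompt is identical, so any change to the data point, SHAP values, or context would trigger a fresh request.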


In [11]:
# 6.2 Asynchronous batch processing
async def process_batch_async():
    print("⏳ Starting asynchronous batch processing...")
    start_time = time.time()

    async_batch_response = await llm_explainer.explain_batch_async(
        shap_values_batch=shap_values_batch,
        data_points=data_points,
        predictions=predictions,
        prediction_classes=prediction_classes,
        additional_context=iris_context,
    )

    async_time = time.time() - start_time
    print(f"✅ Asynchronous batch processing completed in {async_time:.2f} seconds")
    return async_batch_response, async_time


async_batch_response, async_time = asyncio.run(process_batch_async())

print(f"📊 Generated {len(async_batch_response.responses)} explanations")
print(f"📝 First explanation summary: {async_batch_response.responses[0].summary}")

# Performance comparison
speedup = sync_time / async_time
print("\n⚡ Performance comparison:")
print(f"  - Synchronous: {sync_time:.2f} seconds")
print(f"  - Asynchronous: {async_time:.2f} seconds")
print(f"  - Speedup: {speedup:.2f}x faster with asynchronous processing")
print(
    f"  - Efficiency: {speedup / batch_size:.2%} of theoretical maximum ({batch_size}x)"
)

⏳ Starting asynchronous batch processing...
✅ Asynchronous batch processing completed in 16.81 seconds
📊 Generated 5 explanations
📝 First explanation summary: The prediction was primarily driven by the size of the petals, which are notably small, and this is a strong indication of the setosa species.

⚡ Performance comparison:
  - Synchronous: 38.40 seconds
  - Asynchronous: 16.81 seconds
  - Speedup: 2.28x faster with asynchronous processing
  - Efficiency: 45.68% of theoretical maximum (5x)
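The figures above follow from simple arithmetic: 38.40 / 16.81 ≈ 2.28, and dividing by the batch size of 5 gives roughly 46% of the ideal speedup. The sketch below reproduces the sequential versus concurrent pattern with `asyncio.gather` and a semaphore that caps the number of in-flight requests. `fake_llm_call` is a stand-in that just sleeps, and how `explain_batch_async` schedules its requests internally is not shown in this notebook.

```python
import asyncio
import time


async def fake_llm_call(i: int, delay: float = 0.05) -> int:
    """Stand-in for one LLM request; sleeps instead of calling an API."""
    await asyncio.sleep(delay)
    return i


async def run_sequential(n: int):
    """Await each call before starting the next one."""
    return [await fake_llm_call(i) for i in range(n)]


async def run_concurrent(n: int, max_concurrency: int = 5):
    """Start all calls at once, with at most max_concurrency in flight."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def limited(i: int) -> int:
        async with semaphore:
            return await fake_llm_call(i)

    return list(await asyncio.gather(*(limited(i) for i in range(n))))


def timed(coro_factory):
    """Run a coroutine factory to completion and return (result, seconds)."""
    start = time.perf_counter()
    result = asyncio.run(coro_factory())
    return result, time.perf_counter() - start


seq_result, seq_time = timed(lambda: run_sequential(5))
conc_result, conc_time = timed(lambda: run_concurrent(5))
print(f"Sequential: {seq_time:.2f}s, concurrent: {conc_time:.2f}s")
```

With real APIs the speedup rarely reaches the batch size, since rate limits, variable response times, and any non-concurrent steps all eat into the ideal figure.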


## 7. Error Handling

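SHAPXplain retries failed LLM calls and validates its inputs. The example below triggers an input validation error by passing SHAP values and a data point of different lengths, and shows how to catch both validation errors (`ValueError`) and LLM query failures (`RuntimeError`).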

In [12]:
# Example of handling invalid inputs
try:
    print("⚠️ Deliberately creating an error situation...")
    # Deliberately create a mismatch between shap_values and data_point lengths
    invalid_explanation = llm_explainer.explain(
        shap_values=class_shap_values[:3],  # Only using first 3 values
        data_point=data_point,  # Using all 4 features
        prediction=prediction_probs[predicted_class_idx],
        prediction_class=prediction_class,
    )
except ValueError as e:
    print(f"✅ Successfully handled input validation error: {e}")
except RuntimeError as e:
    print(f"✅ Successfully handled LLM query error: {e}")

⚠️ Deliberately creating an error situation...
✅ Successfully handled input validation error: Length mismatch: shap_values (3) != data_point (4)
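The check behind that message is easy to reproduce in your own pre-processing code. The sketch below is a hypothetical `validate_inputs` function that mirrors the error text printed above; it is not SHAPXplain's actual implementation.

```python
def validate_inputs(shap_values, data_point) -> None:
    """Raise ValueError if the two sequences differ in length."""
    if len(shap_values) != len(data_point):
        raise ValueError(
            f"Length mismatch: shap_values ({len(shap_values)}) "
            f"!= data_point ({len(data_point)})"
        )


try:
    validate_inputs([0.07, 0.00, 0.30], [5.1, 3.5, 1.4, 0.2])
except ValueError as e:
    print(f"Caught: {e}")
```

Validating before the request is sent avoids spending an LLM call, and any retries, on input that can never succeed.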


## 8. Batch Insights

One of the key advantages of batch processing with SHAPXplain is the ability to get cross-case insights. These insights identify common patterns, feature distributions, and general recommendations across the entire batch.

In [13]:
# Examine the batch response summary statistics and insights
print("📊 Batch Summary Statistics:")
print(f"  - Total processed: {async_batch_response.summary_statistics['total_processed']}")
print(f"  - Confidence distribution: {async_batch_response.summary_statistics['confidence_summary']}")

print("\n📊 Common Features:")
if async_batch_response.summary_statistics.get("common_features"):
    for feature, count in async_batch_response.summary_statistics["common_features"]:
        print(f"  - {feature}: appears in {count}/{batch_size} cases ({count / batch_size:.0%})")
else:
    print("  No common features found across the batch")

print("\n📊 Batch Insights:")
for i, insight in enumerate(async_batch_response.batch_insights, 1):
    print(f"  {i}. {insight}")

📊 Batch Summary Statistics:
  - Total processed: 5
  - Confidence distribution: {<SignificanceLevel.HIGH: 'high'>: 5, <SignificanceLevel.MEDIUM: 'medium'>: 0, <SignificanceLevel.LOW: 'low'>: 0}

📊 Common Features:
  - petal length (cm): appears in 3/5 cases (60%)
  - petal width (cm): appears in 3/5 cases (60%)
  - sepal length (cm): appears in 3/5 cases (60%)

📊 Batch Insights:
  1. The model consistently predicts the same outcome with high confidence, indicating robust performance but potentially limited variability in the test data.
  2. The importance of petal length, petal width, and sepal length is consistent across the majority of predictions, suggesting these features are critical in the decision-making process of the model.
  3. There are no outliers or unusual predictions, as all predictions are uniformly high, showing no deviation in the model's output.
  4. Petal length and petal width frequently interact together, suggesting a strong relationship between these features in 
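The "Common Features" list is a sequence of `(feature, count)` pairs. One plausible way to build such a list is to count how many cases flag each feature as significant and keep those above a share threshold. The sketch below does that with `collections.Counter`; the per-case lists are made-up data, and `common_features` is a hypothetical function, not the library's own aggregation.

```python
from collections import Counter

# Hypothetical per-case lists of significant features (made-up data).
significant_by_case = [
    ["petal length (cm)", "petal width (cm)", "sepal length (cm)"],
    ["petal length (cm)", "petal width (cm)"],
    ["petal length (cm)", "sepal length (cm)"],
    ["petal width (cm)"],
    ["sepal length (cm)"],
]


def common_features(cases, min_share=0.5):
    """Return (feature, count) pairs for features in at least min_share of cases."""
    # set(case) ensures a feature is counted at most once per case.
    counts = Counter(feature for case in cases for feature in set(case))
    threshold = min_share * len(cases)
    ranked = sorted(counts.items(), key=lambda fc: (-fc[1], fc[0]))
    return [(feature, count) for feature, count in ranked if count >= threshold]


for feature, count in common_features(significant_by_case):
    share = count / len(significant_by_case)
    print(f"  - {feature}: appears in {count}/{len(significant_by_case)} cases ({share:.0%})")
```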

## 9. Medical Example: Using a Different Data Contract

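Let's demonstrate how the data contract can be customized for a different domain - in this case, medical diagnostics. The model and data are still the Iris classifier, so the output only illustrates how the context changes the vocabulary of the explanation; it is not a clinically meaningful result.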

In [14]:
# For demonstration, we'll reuse the same model but with a medical context
# In a real-world scenario, you would use an actual medical model

# Create a medical data contract
medical_context = {
    "domain": "medical_diagnosis",
    "feature_descriptions": {
        "sepal length": "Patient age in years. Normal range varies by condition.",
        "sepal width": "Blood glucose level in mg/dL. Normal range: 70-99 mg/dL fasting.",
        "petal length": "Systolic blood pressure in mmHg. Normal range: <120 mmHg.",
        "petal width": "Body Mass Index. Normal range: 18.5-24.9.",
    },
    "reference_ranges": {
        "blood_glucose": {
            "low": "<70",
            "normal": "70-99",
            "prediabetes": "100-125",
            "diabetes": ">=126",
        },
        "blood_pressure": {
            "normal": "<120",
            "elevated": "120-129",
            "stage1": "130-139",
            "stage2": ">=140",
        },
    },
    "diagnostic_categories": {
        "setosa": "Type 1 Diabetes",
        "versicolor": "Type 2 Diabetes",
        "virginica": "Gestational Diabetes",
    },
    "patient_context": "65-year-old male with family history of type 2 diabetes",
}

print("📝 Medical data contract created")
print("⏳ Generating medical context explanation...")

# Generate explanation using the medical contract
medical_explanation = llm_explainer.explain(
    shap_values=class_shap_values,
    data_point=data_point,
    prediction=prediction_probs[predicted_class_idx],
    prediction_class=prediction_class,
    additional_context=medical_context,
)

print("\n📋 Medical Context Summary:")
print(medical_explanation.summary)

print("\n📋 Medical Recommendations:")
for i, rec in enumerate(medical_explanation.recommendations, 1):
    print(f"  {i}. {rec}")

📝 Medical data contract created
⏳ Generating medical context explanation...

📋 Medical Context Summary:
The prediction indicates a diagnosis of Type 1 Diabetes, largely driven by elevated blood pressure (petal length) and a higher BMI (petal width), with age and blood glucose level also playing supportive roles.

📋 Medical Recommendations:
  1. Encourage regular monitoring of blood pressure and BMI to maintain them within optimal ranges.
  2. Schedule regular check-ups to ensure early detection and management of any diabetic symptoms.
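The `reference_ranges` in a data contract are plain strings for the LLM to read, but the same cut-offs can also be applied in code. The sketch below encodes the fasting blood glucose ranges as a hypothetical `classify_fasting_glucose` function. It treats the value as continuous and assumes the diabetes category starts at 126 mg/dL, so a reading like 99.5 falls under "normal".

```python
def classify_fasting_glucose(mg_dl: float) -> str:
    """Map a fasting blood glucose reading (mg/dL) to a reference category."""
    if mg_dl < 70:
        return "low"
    if mg_dl < 100:
        return "normal"
    if mg_dl < 126:
        return "prediabetes"
    return "diabetes"


for reading in (65, 85, 110, 126):
    print(f"{reading} mg/dL -> {classify_fasting_glucose(reading)}")
```

Applied to this demo, the Iris sepal width of 3.50 that stands in for blood glucose would be classified as "low", which underlines that the medical output above is illustrative only.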


## 10. Conclusion

This notebook demonstrated the key features of SHAPXplain:

1. **Data Contracts**: Providing domain-specific context to enhance explanations
2. **Asynchronous API**: Processing explanations in parallel for improved performance
3. **Batch Processing**: Efficiently handling multiple predictions with both sync and async methods
4. **Error Handling**: Using retry logic to ensure robust operation
5. **Batch Insights**: Getting cross-case patterns and recommendations
6. **Domain Flexibility**: Adapting explanations to different contexts such as botany or medicine

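The performance comparison showed that async processing can be noticeably faster for batch operations. In this run, the asynchronous batch of 5 finished in 16.81 seconds versus 38.40 seconds synchronously, a 2.28x speedup, or about 46% of the theoretical maximum of N times faster for N concurrent items.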

SHAPXplain helps bridge the gap between complex SHAP values and human-understandable insights, making machine learning models more interpretable and trustworthy.