### Using SHAP for Feature Drift Analysis
**Description**: Use SHapley Additive exPlanations (SHAP) values to compare feature
importance across time periods; a marked change in a feature's importance is a sign of feature drift.

In [None]:
# write your code from here
import pandas as pd
import numpy as np
import shap
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
import pytest

# -------------------------------
# SHAP Drift Analysis Functions
# -------------------------------

def compute_shap_importance(model, X):
    """
    Compute the mean absolute SHAP value of every feature in ``X``.

    A ``shap.Explainer`` is built from ``model`` with ``X`` as its data and is
    then evaluated on that same ``X``. Absolute SHAP values are averaged over
    the sample axis. When the explanation carries additional trailing axes
    (shap reports one slice per model output for multi-output models such as
    classifiers), those axes are averaged as well, so the result always holds
    exactly one number per feature.

    Args:
        model: Fitted model object accepted by ``shap.Explainer``.
        X: pandas DataFrame of numeric feature columns to explain.

    Returns:
        pd.Series: Mean absolute SHAP value per feature, indexed by
        ``X.columns``.

    Raises:
        TypeError: If input X is not a pandas DataFrame.
        ValueError: If input dataframe is empty or contains non-numeric columns.
    """
    if not isinstance(X, pd.DataFrame):
        raise TypeError("Input X must be a pandas DataFrame.")
    if X.empty:
        raise ValueError("Input dataframe X is empty.")

    # pandas' dtype predicates understand extension dtypes (category, nullable
    # integers, ...). np.issubdtype cannot interpret those and fails with a
    # TypeError instead of the documented ValueError. Boolean columns stay
    # rejected, as they were before.
    is_numeric = pd.api.types.is_numeric_dtype
    is_bool = pd.api.types.is_bool_dtype
    if not all(is_numeric(dtype) and not is_bool(dtype) for dtype in X.dtypes):
        raise ValueError("All columns in input dataframe X must be numeric.")

    explainer = shap.Explainer(model, X)
    explanation = explainer(X)
    abs_values = np.abs(np.asarray(explanation.values))

    # Axis 0 is samples and axis 1 is features. Any further axis (e.g. one
    # slice per class) must be reduced too, otherwise the result is 2-D and
    # cannot be wrapped in a Series indexed by the feature names.
    if abs_values.ndim > 2:
        reduce_axes = (0,) + tuple(range(2, abs_values.ndim))
    else:
        reduce_axes = 0
    mean_abs_shap = abs_values.mean(axis=reduce_axes)
    return pd.Series(mean_abs_shap, index=X.columns)

def plot_shap_drift(train_shap: pd.Series, test_shap: pd.Series):
    """
    Chart train vs. test SHAP importances side by side and return the table.

    The two series are combined into a DataFrame with the columns
    ``train_importance``, ``test_importance`` and ``change`` (test minus
    train), ordered by ``change`` from largest to smallest. The two importance
    columns are drawn as a bar chart and shown via ``plt.show()``.

    Returns:
        pd.DataFrame: The sorted comparison table described above.

    Raises:
        TypeError: If inputs are not pandas Series.
        ValueError: If input series are empty or have mismatched indices.
    """
    inputs = (train_shap, test_shap)
    for candidate in inputs:
        if not isinstance(candidate, pd.Series):
            raise TypeError("Inputs must be pandas Series.")
    if any(candidate.empty for candidate in inputs):
        raise ValueError("Input SHAP Series cannot be empty.")
    indices_match = train_shap.index.equals(test_shap.index)
    if not indices_match:
        raise ValueError("train_shap and test_shap must have the same indices.")

    importance_cols = ['train_importance', 'test_importance']
    comparison = pd.DataFrame(dict(zip(importance_cols, inputs)))
    comparison = comparison.assign(
        change=comparison['test_importance'] - comparison['train_importance']
    ).sort_values(by='change', ascending=False)

    # Only the two importance columns are plotted; 'change' just drives the
    # bar order.
    comparison[importance_cols].plot(
        kind='bar',
        figsize=(12, 6),
        title='Feature Importance Drift (SHAP)',
    )
    plt.ylabel('Mean |SHAP value|')
    plt.xticks(rotation=45)
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    return comparison


# ---------------------------------
# Pytest Unit Tests for the functions
# ---------------------------------

def create_dummy_model_and_data(seed=42):
    """
    Build a small fitted random-forest classifier and its training frame.

    The frame has two columns, ``a`` and ``b``, each holding 10 uniform draws
    from [0, 1). The label is 1 where ``a + b > 1`` and 0 otherwise. The draws
    come from a seeded generator so every test run sees identical data instead
    of a fresh sample from the global NumPy random state.

    Args:
        seed: Seed for the data generator; defaults to 42.

    Returns:
        tuple: ``(model, X)`` - the fitted ``RandomForestClassifier`` and the
        DataFrame it was trained on.
    """
    rng = np.random.default_rng(seed)
    X = pd.DataFrame({
        'a': rng.random(10),
        'b': rng.random(10),
    })
    y = (X['a'] + X['b'] > 1).astype(int)
    model = RandomForestClassifier(n_estimators=10, random_state=42)
    model.fit(X, y)
    return model, X

def test_compute_shap_importance_valid():
    """Valid input yields a Series whose labels are exactly the feature names."""
    fitted_model, features = create_dummy_model_and_data()
    importance = compute_shap_importance(fitted_model, features)
    assert isinstance(importance, pd.Series)
    expected_labels = set(features.columns)
    assert set(importance.index) == expected_labels

def test_compute_shap_importance_empty_df():
    """An empty DataFrame must be rejected with ValueError."""
    fitted_model = create_dummy_model_and_data()[0]
    empty_frame = pd.DataFrame()
    with pytest.raises(ValueError):
        compute_shap_importance(fitted_model, empty_frame)

def test_compute_shap_importance_non_numeric():
    """A DataFrame made only of string columns must be rejected with ValueError."""
    fitted_model = create_dummy_model_and_data()[0]
    text_frame = pd.DataFrame({'a': ['x', 'y'], 'b': ['z', 'w']})
    with pytest.raises(ValueError):
        compute_shap_importance(fitted_model, text_frame)

def test_plot_shap_drift_valid():
    """Two aligned series give a two-row table that has a 'change' column."""
    labels = ['a', 'b']
    train_importance = pd.Series([0.1, 0.2], index=labels)
    test_importance = pd.Series([0.3, 0.1], index=labels)
    drift_table = plot_shap_drift(train_importance, test_importance)
    assert 'change' in drift_table.columns
    row_count = drift_table.shape[0]
    assert row_count == 2

def test_plot_shap_drift_empty_series():
    """Empty input series must be rejected with ValueError."""
    empty_train = pd.Series(dtype=float)
    empty_test = pd.Series(dtype=float)
    with pytest.raises(ValueError):
        plot_shap_drift(empty_train, empty_test)

def test_plot_shap_drift_mismatched_indices():
    """Series labelled with different feature names must raise ValueError."""
    train_only_a = pd.Series([0.1], index=['a'])
    test_only_b = pd.Series([0.2], index=['b'])
    with pytest.raises(ValueError):
        plot_shap_drift(train_only_a, test_only_b)

def test_plot_shap_drift_invalid_type():
    """Plain lists instead of pandas Series must raise TypeError."""
    train_as_list = [0.1, 0.2]
    test_as_list = [0.3, 0.4]
    with pytest.raises(TypeError):
        plot_shap_drift(train_as_list, test_as_list)


# To run tests, save this code in a file named e.g. test_shap_drift.py
# and run: pytest test_shap_drift.py
