# XAI Metrics based on Surrogate Model

In [None]:
from holisticai.utils import BinaryClassificationProxy
from sklearn.ensemble import RandomForestClassifier

from holisticai.datasets import load_dataset

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm


from holisticai.inspection import compute_permutation_importance, compute_conditional_permutation_feature_importance
from holisticai.explainability.metrics import classification_surrogate_explainability_metrics
from holisticai.utils.data_preprocessor import simple_preprocessor
from holisticai.efficacy.metrics import classification_efficacy_metrics

# Binary Classification (Adult Dataset)

In [None]:
# Fetch the Adult dataset without built-in preprocessing and immediately carve
# out a reproducible split that holds back 20% of the rows for testing.
dataset = load_dataset("adult", preprocessed=False).train_test_split(
    test_size=0.2,
    random_state=42,
)
dataset

We use a simple preprocessor to normalize the data.

In [4]:
# Pull the two partitions out of the split dataset.
train, test = dataset["train"], dataset["test"]

# Run the shared preprocessing step over the features and targets of both
# partitions; the transformed outputs carry the `t` suffix (Xt_*, yt_*).
Xt_train, Xt_test, yt_train, yt_test = simple_preprocessor(
    train["X"],
    test["X"],
    train["y"],
    test["y"],
)

**Define a Proxy Model**: A proxy model enables standardized use of your model across multiple functions. You simply provide the essential components required for each type of proxy. For instance, in binary classification, only the `predict` function, the `predict_proba` function, and the class labels (`classes`) are needed.

In [None]:
# Train the model that will be explained; the seed is fixed so reruns of the
# notebook produce the same forest.
model = RandomForestClassifier(random_state=42).fit(Xt_train, yt_train)

# Wrap the fitted model's prediction callables and class labels in a proxy.
proxy = BinaryClassificationProxy(
    predict=model.predict,
    predict_proba=model.predict_proba,
    classes=model.classes_,
)
proxy

### Efficacy Metrics

In [None]:
# Predict on the held-out test features through the proxy, then compare those
# predictions with the true test labels.
y_pred_test = proxy.predict(Xt_test)
efficacy_metrics = classification_efficacy_metrics(yt_test, y_pred_test)
# Bare expression so the notebook renders the metrics as the cell output.
efficacy_metrics

### Surrogate Model

We create a surrogate model. Supported types include `shallow_tree` (a tree model with depth=3) and `tree`.

In [None]:
from holisticai.utils.surrogate_models import create_surrogate_model

# The surrogate is fitted on the proxy's own training-set predictions rather
# than on the true labels, i.e. it is trained to imitate the proxy.
y_train_pred = proxy.predict(Xt_train)
surrogate = create_surrogate_model(
    Xt_train,
    y_train_pred,
    surrogate_type="shallow_tree",
)
surrogate

### Surrogate XAI Metrics

**Accuracy Degradation**: Measures how much accuracy with respect to the true labels is lost when the surrogate model is used in place of the original model.

**Surrogate Fidelity**: Assesses how closely the surrogate model’s predictions align with those of the original model.

**Surrogate Feature Stability**: Measures the consistency of features used in the surrogate tree across multiple bootstraps.

In [None]:
# Reuse the `surrogate` fitted in the "Surrogate Model" cell above instead of
# refitting an identical one here: this avoids redundant training and ensures
# the metrics describe the exact surrogate that was displayed.
y_test_pred = proxy.predict(Xt_test)

# Arguments, in order: test features, true test labels, the proxy's test
# predictions, and the surrogate model under evaluation.
surrogate_metrics = classification_surrogate_explainability_metrics(
    Xt_test,
    yt_test,
    y_test_pred,
    surrogate,
)
surrogate_metrics