## Usage Example

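Let's load the Boston housing dataset and turn it into a binary classification task: a sample is labelled 1 when its target (median home value) exceeds 21, and 0 otherwise.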

In [5]:
import numpy as np
import pandas as pd

# Binarize the regression target: label 1 when the Boston median home value is
# above this threshold, 0 otherwise. 21 splits the 506 samples almost evenly
# (257 positives vs. 249 negatives).
PRICE_THRESHOLD = 21

try:
    # `load_boston` was deprecated in scikit-learn 1.0 and removed in 1.2
    # (importing it now raises ImportError), so it only works on older versions.
    from sklearn.datasets import load_boston

    dataset = load_boston()
except ImportError:
    # Fallback: rebuild the same dataset from the original CMU StatLib source,
    # as recommended by scikit-learn's deprecation notice. Each record spans two
    # physical lines: 11 values on the first, 3 on the second (the last being
    # the target), hence the even/odd row slicing below.
    from sklearn.utils import Bunch

    boston_raw = pd.read_csv(
        "http://lib.stat.cmu.edu/datasets/boston",
        sep=r"\s+",
        skiprows=22,
        header=None,
    )
    dataset = Bunch(
        data=np.hstack([boston_raw.values[::2, :], boston_raw.values[1::2, :2]]),
        target=boston_raw.values[1::2, 2],
        feature_names=np.array([
            "CRIM", "ZN", "INDUS", "CHAS", "NOX", "RM", "AGE",
            "DIS", "RAD", "TAX", "PTRATIO", "B", "LSTAT",
        ]),
    )

X = dataset.data
y = (dataset.target > PRICE_THRESHOLD).astype(int)

Let's train an XGBoost classifier with its default hyperparameters.

In [6]:
from xgboost import XGBClassifier

# Requirement: any model plugged into this pipeline must expose `predict`,
# `predict_proba` and `get_booster`.
model = XGBClassifier().fit(X, y)  # `fit` returns the fitted estimator itself
model





XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
              colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,
              importance_type='gain', interaction_constraints='',
              learning_rate=0.300000012, max_delta_step=0, max_depth=6,
              min_child_weight=1, missing=nan, monotone_constraints='()',
              n_estimators=100, n_jobs=4, num_parallel_tree=1, random_state=0,
              reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1,
              tree_method='exact', validate_parameters=1, verbosity=None)

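Let's evaluate the model. Note that it is scored on the same data it was trained on, so the perfect scores below measure in-sample fit, not how well it generalizes.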

In [7]:
from sklearn.metrics import classification_report

# In-sample evaluation: predictions are made on the training data itself.
y_pred = model.predict(X)
print(classification_report(y_true=y, y_pred=y_pred))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       249
           1       1.00      1.00      1.00       257

    accuracy                           1.00       506
   macro avg       1.00      1.00      1.00       506
weighted avg       1.00      1.00      1.00       506



### Explanations

The following example shows how to use the package to generate feature-importance explanations, using the 100 nearest neighbours (100-NN) as the counterfactual generator and SHAP's TreeExplainer as the feature-importance estimator.

Let's set up the explainers.

In [8]:
from cfshap.utils.preprocessing import EfficientQuantileTransformer
from cfshap.counterfactuals import KNNCounterfactuals
from cfshap.attribution import TreeExplainer, CompositeExplainer
from cfshap.trend import TrendEstimator

# Sample cap shared by the counterfactual generator and the attribution estimator.
MAX_SAMPLES = 10000
# Number of nearest neighbours used as counterfactuals (the "100-NN").
N_NEIGHBORS = 100

# The counterfactual generator needs a scaler fitted on the input space;
# here a quantile transformer fitted on the full dataset.
scaler = EfficientQuantileTransformer()
scaler.fit(X)

# 1) Background / counterfactual generator: k-NN search with the
#    'cityblock' distance, using the scaler fitted above.
background_generator = KNNCounterfactuals(
    model=model, X=X, n_neighbors=N_NEIGHBORS, distance='cityblock',
    scaler=scaler, max_samples=MAX_SAMPLES,
)

# 2) Feature-importance estimator: the tree explainer requires a trend
#    estimator, configured here with the 'mean' strategy. `data=None` is
#    passed explicitly (no fixed background dataset is given here).
trend_estimator = TrendEstimator(strategy='mean')
importance_estimator = TreeExplainer(
    model, data=None, trend_estimator=trend_estimator, max_samples=MAX_SAMPLES,
)

# 3) The final explainer combines the background generator with the
#    importance estimator.
explainer = CompositeExplainer(background_generator, importance_estimator)

Let's compute the explanations.

In [9]:
# Explain only the first 10 samples of the dataset.
X_to_explain = X[:10]
explanations = explainer(X_to_explain)

Let's inspect the feature importances as a table: one row per explained sample, one column per feature.

In [13]:
import pandas as pd

# One row per explained sample, one column per input feature.
importances = pd.DataFrame(data=explanations.values, columns=dataset.feature_names)
importances

,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT
0,-0.368411,0.001084,0.016953,6.6e-05,-0.071173,4.944348,0.451436,0.599951,-0.12442,0.468224,2.5668,0.262139,4.089339
1,-0.688685,0.025537,0.117732,-0.005397,0.482204,4.693255,-0.37566,0.659881,-0.073497,0.737982,2.195861,0.174194,1.450351
2,-0.620587,0.052311,0.077167,-0.000131,0.274878,5.873381,0.460579,0.618844,-0.067338,0.596424,1.719293,1.09438,3.774376
3,-0.564064,0.052311,-0.024788,-0.000697,-0.027789,6.080283,1.029735,0.026832,-0.09411,0.537722,0.824205,0.410816,4.969478
4,0.184132,0.040598,-0.017089,-0.000951,-0.041633,6.125439,1.021442,-0.010025,-0.09891,0.545943,0.753446,0.025192,5.016416
5,-0.679237,0.055874,-0.0083,-0.004868,0.023068,4.939248,0.722469,0.14656,-0.073102,0.533288,0.847074,0.365709,5.043894
6,0.520886,0.033432,0.165588,-0.004216,0.376512,0.846783,1.195629,0.915779,0.055242,0.292543,3.424313,0.603873,0.481862
7,0.987279,0.015798,0.337859,-0.016991,0.358317,2.485108,-0.892284,0.692895,-0.002049,0.2147,3.336585,0.318524,0.27225
8,-0.301995,0.057322,-0.046974,-0.038744,-0.163927,-3.900997,-2.276446,-2.736532,-0.115068,-0.372046,0.581672,-0.156959,-3.37851
9,0.220481,0.065478,0.046834,-0.066511,-0.208922,-2.391635,-2.432828,-3.073034,-0.103017,-0.273681,0.769095,-0.363471,-2.573689
