In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from catboost import CatBoostClassifier
#!pip install shap --upgrade
import shap
print(shap.__version__)

0.42.0


Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)


# Definizione modello

Utilizziamo un dataset fornito con la libreria SHAP. In realta' si tratta dello stesso dataset utilizzato nelle sessioni precedenti (Adult Census Income), utilizziamo quello fornito da SHAP perche' la libreria non supporta direttamente operazioni su pd.DataFrame


In [2]:
X = pd.read_csv('adult_training.csv')
x_train, x_val = train_test_split(X, test_size=0.1, random_state=42)

y_train = x_train['income']
y_val = x_val['income']

x_train = x_train.drop(['income'], axis=1)
x_val = x_val.drop(['income'], axis=1)

Identificazione colonne categoriche e allenamento modello

In [5]:
categorical_features = list(x_train.select_dtypes(include="object").columns)

In [6]:
categorical_features

['work_class',
 'education',
 'marital_status',
 'occupation',
 'relationship',
 'race',
 'sex',
 'native_country']

In [7]:
clf = CatBoostClassifier(n_estimators=100).fit(x_train, y_train, cat_features= categorical_features)

Learning rate set to 0.348433
0:	learn: 0.5459735	total: 65.5ms	remaining: 6.49s
1:	learn: 0.4711175	total: 74.6ms	remaining: 3.66s
2:	learn: 0.4164643	total: 81.3ms	remaining: 2.63s
3:	learn: 0.3854485	total: 87.1ms	remaining: 2.09s
4:	learn: 0.3651656	total: 92.4ms	remaining: 1.75s
5:	learn: 0.3474597	total: 99ms	remaining: 1.55s
6:	learn: 0.3373062	total: 105ms	remaining: 1.4s
7:	learn: 0.3290484	total: 111ms	remaining: 1.28s
8:	learn: 0.3237030	total: 116ms	remaining: 1.17s
9:	learn: 0.3204212	total: 120ms	remaining: 1.08s
10:	learn: 0.3180089	total: 126ms	remaining: 1.02s
11:	learn: 0.3152022	total: 131ms	remaining: 961ms
12:	learn: 0.3120860	total: 136ms	remaining: 911ms
13:	learn: 0.3100019	total: 141ms	remaining: 867ms
14:	learn: 0.3083003	total: 146ms	remaining: 826ms
15:	learn: 0.3065585	total: 151ms	remaining: 795ms
16:	learn: 0.3053146	total: 157ms	remaining: 765ms
17:	learn: 0.3042755	total: 161ms	remaining: 735ms
18:	learn: 0.3035536	total: 166ms	remaining: 709ms
19:	lear

In [8]:
from sklearn.metrics import classification_report

print(classification_report(y_val, clf.predict(x_val)))

              precision    recall  f1-score   support

       <=50K       0.88      0.95      0.91      2274
        >50K       0.79      0.62      0.70       743

    accuracy                           0.87      3017
   macro avg       0.84      0.78      0.81      3017
weighted avg       0.86      0.87      0.86      3017



# Spiegazioni con SHAP

Usiamo una lambda per passare il classicatore in un formato input->output compatibile con SHAP

In [9]:
# questo metodo inizializza il tool di visualizzazione basato su JavaScript
shap.initjs()


In [10]:
# Inizializziamo l'explainer utilizzando il metodo KernelExplainer
# shap.kmeans viene usato per definire il backgroud da cui fare il sampling dei punti fittizi
explainer = shap.TreeExplainer(clf) 

In [11]:
# Generiamo una spiegazione per l'input definito dalla prima riga di x_test
shap_values = explainer.shap_values(x_val.iloc[0:1])

In [12]:
shap.force_plot(explainer.expected_value, shap_values[0], feature_names=x_train.columns)

In [13]:
shap_values = explainer.shap_values(x_val.sample(n=1000))

In [14]:
shap.force_plot(explainer.expected_value, shap_values, feature_names=x_train.columns)

# Analisi robustezza spiegazioni

Selezioniamo un'istanza e vediamo come cambia la spiegazione cambiando leggermente l'input

In [32]:
index_to_expl = 132

shap_values_row = explainer.shap_values(x_val.iloc[index_to_expl:index_to_expl+1])
print(x_val.iloc[index_to_expl])
shap.force_plot(explainer.expected_value, shap_values_row, feature_names=x_train.columns)


age                             24
work_class                 Private
education         High School grad
marital_status       Never-Married
occupation             Blue-Collar
relationship         Not-in-family
race                         White
sex                           Male
capital_gain                     0
capital_loss                     0
hours_per_week                  48
native_country       United-States
Name: 18403, dtype: object


Perturbiamo l'input cambiando il valore della feature 'Capital Loss'

In [37]:
import copy
x_perturb = copy.deepcopy(x_val)
x_perturb['age']=32
shap_values_row = explainer.shap_values(x_perturb.iloc[index_to_expl:index_to_expl+1])
print(x_perturb.iloc[0])
print(shap_values_row.shape)
shap.force_plot(explainer.expected_value, shap_values_row, feature_names=x_train.columns)

age                          32
work_class              Private
education           Prof-School
marital_status    Never-Married
occupation         Professional
relationship      Not-in-family
race                      White
sex                        Male
capital_gain                  0
capital_loss                  0
hours_per_week               55
native_country    United-States
Name: 217, dtype: object
(1, 12)


# Clustering delle spiegazioni

Andiamo a dividere le istanze del test set in base ai tipi di errori commessi

In [43]:
ypred = clf.predict(x_val)

TP = np.where((y_val=='>50K')&(ypred=='>50K'))
TN = np.where((y_val!='>50K')&(ypred!='>50K'))
FN = np.where((y_val=='>50K')&(ypred!='>50K'))
FP = np.where((y_val!='>50K')&(ypred=='>50K'))

In [46]:
print(TP[0].shape,FP[0].shape,TN[0].shape,FN[0].shape)

(462,) (121,) (2153,) (281,)


Una volta isolati i False Positives usiamo il tool Force plot per raggruppare le spiegazioni in base alle similitudini

In [48]:
shap.initjs()
shap_values = explainer.shap_values(x_val.iloc[FN])

In [50]:
shap.plots.force(explainer.expected_value, shap_values, feature_names=x_train.columns)

In [51]:
explainer.expected_value

-2.1574556925131896

In [53]:
shap_values = explainer.shap_values(x_val.iloc[FN[0][273]:FN[0][273]+1])
shap.force_plot(explainer.expected_value, shap_values, feature_names=x_train.columns)