In [1]:
# Import commonly used packages
import xgboost as xgb
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
import pandas as pd
from sklearn.metrics import roc_auc_score
from sklearn.feature_selection import SelectFromModel

In [2]:
# Dataset: the breast-cancer data shipped with scikit-learn
cancer = datasets.load_breast_cancer()
X, Y = cancer.data, cancer.target

In [3]:
# Dataset overview (uncomment a line below to inspect shapes / values)
# X.shape
# Y.shape
# X, Y

In [4]:
# Split into training and test sets: 20% held out, fixed seed for a repeatable split
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, random_state=8
)

In [5]:
# Wrap each split, together with its labels, in an xgb.DMatrix
xgb_train = xgb.DMatrix(data=X_train, label=Y_train)
xgb_test = xgb.DMatrix(data=X_test, label=Y_test)

In [6]:
# Booster configuration passed to xgb.train
params = {
    "objective": "binary:logistic",
    "booster": "gbtree",
    "eta": 1,
    "max_depth": 2,
    # Pin the evaluation metric explicitly. XGBoost >= 1.3 defaults to
    # `logloss` for binary:logistic, whereas the recorded training log
    # reports `error`; setting it keeps the log reproducible across versions.
    "eval_metric": "error",
}

# Number of boosting rounds
num_round = 10

In [7]:
# (DMatrix, name) pairs to evaluate during training; the name labels each log column
watchlist = [
    (xgb_test, 'eval'),
    (xgb_train, 'train'),
]

In [8]:
# Train the booster, evaluating every watchlist entry after each round.
# `evals` is passed by keyword because newer XGBoost releases make every
# argument after `num_boost_round` keyword-only (positional use is deprecated).
bst = xgb.train(params, xgb_train, num_boost_round=num_round, evals=watchlist)

[0]	eval-error:0.10526	train-error:0.04835
[1]	eval-error:0.10526	train-error:0.03956
[2]	eval-error:0.07018	train-error:0.02857
[3]	eval-error:0.07895	train-error:0.01758
[4]	eval-error:0.07018	train-error:0.01099
[5]	eval-error:0.04386	train-error:0.01099
[6]	eval-error:0.04386	train-error:0.00879
[7]	eval-error:0.05263	train-error:0.00440
[8]	eval-error:0.03509	train-error:0.00220
[9]	eval-error:0.04386	train-error:0.00220


In [9]:
# SHAP attribution of the predictions for the test samples
pred_contribs = bst.predict(xgb_test, pred_contribs=True)

# Print the first sample: the test set has 30 features, so the output vector
# has 31 entries, the last column being the bias term
print(pred_contribs[0])

[ 0.          0.2905957   0.          0.17352277 -0.04271989  0.3219791
  0.42704344  0.3493424   0.          0.          0.09485918  0.19571848
  0.          0.73449516  0.         -0.598713    0.         -0.21128973
  0.          0.          0.          0.72367036  1.495664    0.82000506
 -0.6888138   0.          0.6567661   0.9051105   0.         -0.01652571
  1.3566536 ]


In [10]:
# Saabas (approximate) attribution of the predictions for the test samples
pred_contribs = bst.predict(xgb_test, pred_contribs=True, approx_contribs=True)

# Same layout as the SHAP output: the contribution of each of the 30 features
# to this sample, followed by the bias term in the last column
print(pred_contribs[0])

[ 0.          0.34236163  0.          0.21349472  0.          0.3800876
  0.          0.40334025  0.          0.          0.13804139  0.28650513
  0.          0.6550807   0.         -0.81003463  0.          0.
  0.          0.          0.          0.5576966   1.8381076   0.6787368
 -0.731996    0.          0.35468978  1.324599    0.          0.
  1.3566536 ]
