In [79]:
#导入工具包、预处理后的数据集
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline

# Load the full preprocessed dataset. In the code visible here it only serves
# to provide the feature names and to fit the label encoder.
df = pd.read_csv('dataset/breast-cancer-wisconsin/breast-cancer-wisconsin.csv')

# Separate the feature columns from the label column
# (convention: upper case for matrices, lower case for vectors).
X = df.drop(columns='diagnosis')
feature_names = X.columns

# Encode the class labels as integers. LabelEncoder maps target values to
# 0..n_classes-1; it is meant for the target y, not for the inputs X.
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
y = le.fit_transform(df['diagnosis'])

# Load the pre-split training and test sets; 'diagnosis' is the label column.
df_train = pd.read_csv('dataset/breast-cancer-wisconsin/train_data.csv')
X_train = df_train.drop('diagnosis', axis=1)
y_train = df_train['diagnosis']

df_test = pd.read_csv('dataset/breast-cancer-wisconsin/test_data.csv')
X_test = df_test.drop('diagnosis', axis=1)
y_test = df_test['diagnosis']

# Data augmentation: append the samples from the LIME augmentation file to
# the training set.
df_enhan = pd.read_csv('data_enhan/lime/lime_enhan_5.csv')
X_enhan = df_enhan.drop('diagnosis', axis=1)
y_enhan = df_enhan['diagnosis']

# pd.concat aligns on column names and fills missing ones with NaN, so a
# schema mismatch would otherwise go unnoticed until much later.
if set(X_enhan.columns) != set(X_train.columns):
    raise ValueError(
        "Augmented data columns do not match the training columns: "
        f"{sorted(set(X_enhan.columns) ^ set(X_train.columns))}"
    )

X_train = pd.concat([X_train, X_enhan], ignore_index=True)
# The original called y_enhan.to_numpy() and discarded the result; convert
# explicitly at the point of use instead. y_train becomes a NumPy array.
y_train = np.concatenate((y_train.to_numpy(), y_enhan.to_numpy()))

from sklearn.preprocessing import StandardScaler
# Standardise every feature to zero mean and unit variance so that features
# with large magnitudes do not dominate the predictions. The scaler is fitted
# on the (augmented) training data and then applied to both splits.
transform = StandardScaler().fit(X_train)
X_train = transform.transform(X_train)
X_test = transform.transform(X_test)

# Obtain the random forest trained on the augmented data: reuse the saved
# model when it exists, otherwise train it with the same recipe and save it,
# so the notebook also runs from a fresh checkout.
import os
import joblib
from sklearn.ensemble import RandomForestClassifier

MODEL_PATH = 'saved_model/breast-cancer-wisconsin/RF/enhance_model/lime_enhan_5_model.pkl'

if os.path.exists(MODEL_PATH):
    # NOTE(security): joblib.load unpickles the file and can execute arbitrary
    # code; only load model files from a trusted source.
    model = joblib.load(MODEL_PATH)
else:
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    joblib.dump(model, MODEL_PATH)

# Hard class predictions and class probabilities for the test set.
y_pred = model.predict(X_test)
y_pred_proba = model.predict_proba(X_test)

# Evaluate on the test set: confusion matrix first, then per-class metrics.
from sklearn.metrics import classification_report, confusion_matrix

confusion_matrix_model = confusion_matrix(y_test, y_pred)
print("Confusion Matrix:\n", confusion_matrix_model)

# The two display names are applied to the labels in sorted order (0, then 1).
print(
    classification_report(
        y_test,
        y_pred,
        target_names=['良性', '恶性'],
    )
)

Confusion Matrix:
 [[107   1]
 [  4  59]]
              precision    recall  f1-score   support

          良性       0.96      0.99      0.98       108
          恶性       0.98      0.94      0.96        63

    accuracy                           0.97       171
   macro avg       0.97      0.96      0.97       171
weighted avg       0.97      0.97      0.97       171



In [80]:
# Display the ground-truth labels of the test set.
y_test

0      0
1      1
2      1
3      0
4      0
      ..
166    0
167    0
168    0
169    1
170    0
Name: diagnosis, Length: 171, dtype: int64

In [81]:
# Display the labels the model predicted for the test set.
y_pred

array([0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0,
       1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0,
       1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0,
       1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1,
       0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0,
       0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], dtype=int64)

In [82]:
# Boolean mask over the test set: True where the prediction equals the label.
compare = (y_test == y_pred).to_numpy()
compare

array([ True,  True,  True,  True,  True,  True,  True,  True, False,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True, False,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True, False,  True,  True,  True,
        True, False,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,

In [83]:
# Positional indices of the correctly classified test samples.
# np.nonzero returns a 1-tuple whose only element is the index array.
y_to_explain = np.nonzero(compare)
y_to_explain

(array([  0,   1,   2,   3,   4,   5,   6,   7,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,
         41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,
         54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,
         67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  78,  79,  80,
         81,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,
         95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107,
        108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
        121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133,
        134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146,
        147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,
        160, 161, 162, 163, 165, 166, 167, 168, 169, 170], dtype=int64),)

In [76]:
# Initialise the LIME tabular explainer, explain every test sample, and
# overwrite each sample's top-ranked features with a value taken from
# bound.csv plus a fixed offset.
import lime
from lime import lime_tabular

LIME_RANDOM_STATE = 42  # fixes LIME's sampling so explanations are repeatable
TOP_K_FEATURES = 3      # number of highest-ranked features perturbed per sample
BOUND_ROW = 1           # row of bound.csv supplying the replacement value
                        # (presumably the upper bound — TODO confirm)
BOUND_OFFSET = 10       # constant added on top of the bound value

explainer = lime_tabular.LimeTabularExplainer(
    training_data=np.array(X_train),  # training features; must be a NumPy array
    feature_names=feature_names,      # feature column names
    class_names=['良性', '恶性'],       # display names of the predicted classes
    mode='classification',            # classification mode
    random_state=LIME_RANDOM_STATE,
)

test_data = pd.read_csv('dataset/breast-cancer-wisconsin/test_data.csv')
bound = pd.read_csv('dataset/breast-cancer-wisconsin/bound.csv')

# Iterate over the actual number of test rows instead of a hard-coded 171.
for idx in range(len(X_test)):
    exp = explainer.explain_instance(
        data_row=X_test[idx],
        predict_fn=model.predict_proba,
        num_features=30,
    )
    # as_map()[1] holds (feature_index, weight) pairs for class 1 in the same
    # order as as_list(). Looking the column up by index avoids parsing the
    # discretised rule text, which broke on names containing a space
    # (the "concave points_*" columns).
    for feature_idx, _weight in exp.as_map()[1][:TOP_K_FEATURES]:
        column = feature_names[feature_idx]
        test_data.at[idx, column] = bound.at[BOUND_ROW, column] + BOUND_OFFSET


In [77]:
# Keep only the perturbed rows of correctly classified samples.
# y_to_explain is the 1-tuple returned by np.where and holds positional
# indices, so unpack the array and select by position rather than by label.
lime_explain = test_data.iloc[y_to_explain[0]]
#lime_explain.to_csv('explain_set_new/data_enhan/upper_bound/RF/lime/lime_explain_5.csv',index=False)

In [78]:
# Display the perturbed rows selected for explanation.
lime_explain

Unnamed: 0,diagnosis,radius_mean,texture_mean,perimeter_mean,area_mean,smoothness_mean,compactness_mean,concavity_mean,concave points_mean,symmetry_mean,...,radius_worst,texture_worst,perimeter_worst,area_worst,smoothness_worst,compactness_worst,concavity_worst,concave points_worst,symmetry_worst,fractal_dimension_worst
0,0,12.470,18.600,81.09,1336.3,0.09965,0.10580,0.08005,0.03821,0.1925,...,37.46,24.64,197.335,677.9,0.1426,0.23780,0.2671,0.10150,0.3014,0.08750
1,1,18.940,21.310,123.60,1336.3,0.09009,0.10290,0.10800,0.07951,0.1582,...,37.46,26.58,197.335,1866.0,0.1193,0.23360,0.2687,0.17890,0.2551,0.06589
2,1,15.460,19.480,101.70,748.9,0.10920,0.12230,10.28241,0.08087,0.1931,...,37.46,26.00,124.900,1156.0,0.1546,0.23940,0.3791,0.15140,0.2837,0.08019
3,0,12.400,17.680,81.47,1336.3,0.10540,0.13160,0.07741,0.02799,0.1811,...,37.46,22.91,197.335,515.8,0.1450,0.26290,0.2403,0.07370,0.2556,0.09359
4,0,11.540,14.440,74.65,402.9,0.09984,0.11200,0.06737,0.02594,0.1818,...,37.46,52.68,197.335,457.8,0.1345,0.21180,0.1797,0.06918,0.2329,0.08134
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
166,0,12.780,16.490,81.37,502.5,0.09831,0.05234,10.28241,0.02864,0.1590,...,37.46,19.76,197.335,554.9,0.1296,0.07061,0.1039,0.05882,0.2383,0.06410
167,0,14.740,40.245,94.70,668.6,0.08275,0.07214,10.28241,0.03027,0.1840,...,16.51,52.68,107.400,826.4,0.1060,0.13760,0.1611,0.10950,0.2722,0.06956
168,0,9.904,18.060,64.60,302.4,0.09699,0.12940,10.28241,0.03716,0.1669,...,37.46,24.39,197.335,390.2,0.1301,0.29500,0.3486,0.09910,0.2614,0.11620
169,1,13.820,24.490,92.33,595.9,0.11620,0.16810,10.28241,0.06759,0.2275,...,16.01,52.68,106.000,788.0,0.1794,0.39660,0.3381,0.15210,0.3651,0.11830
