In [3]:
import numpy as np
from sklearn import tree
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

In [4]:
class DecisionTree():
    def createDataSet(self):
        data = []
        labels = []
        with open(r"..\data\Ch03\data.txt", "r") as ifile:
            for line in ifile:
                tokens = line.strip().split(" ")
                data.append([float(tk) for tk in tokens[:-1]])
                labels.append(tokens[-1])
        
        x = np.array(data)
        labels = np.array(labels)
        y = np.zeros(labels.shape)
        
        """convert "fat" or "thin" into 1 and 0 """
        y[labels == "fat"] = 1
        print(data, '-------', x, '-------', labels, '-------', y)
        return x, y
    
    def predict_train(self, x_train, y_train):
        clf = tree.DecisionTreeClassifier(criterion = 'entropy')
        clf.fit(x_train, y_train)
        
        print('feature_importances_: %s' % clf.feature_importances_)
        
        y_pre = clf.predict(x_train)
        print(y_pre)
        print(y_train)
        print(np.mean(y_pre == y_train))
        return y_pre, clf
    
    def show_precision_recall(self, x, y, clf, y_train, y_pre):
        precision, recall, thresholds = precision_recall_curve(y_train, y_pre)
        answer = clf.predict_proba(x)[:, 1]
        target_names = ['thin', 'fat']
        print(classification_report(y, answer, target_names=target_names))
        print(answer)
        print(y)           

In [5]:
if __name__ == '__main__':
    DT = DecisionTree()
    x, y = DT.createDataSet()

    ''' 拆分训练数据与测试数据， 80%做训练 20%做测试 '''
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
    print('拆分数据：', x_train, x_test, y_train, y_test)

    # 得到训练的预测结果集
    y_pre, clf = DT.predict_train(x_train, y_train)

    # 展现 准确率与召回率
    DT.show_precision_recall(x, y, clf, y_train, y_pre)

[[1.5, 50.0], [1.5, 60.0], [1.6, 40.0], [1.6, 60.0], [1.7, 60.0], [1.7, 80.0], [1.8, 60.0], [1.8, 90.0], [1.9, 70.0], [1.9, 80.0]] ------- [[ 1.5 50. ]
 [ 1.5 60. ]
 [ 1.6 40. ]
 [ 1.6 60. ]
 [ 1.7 60. ]
 [ 1.7 80. ]
 [ 1.8 60. ]
 [ 1.8 90. ]
 [ 1.9 70. ]
 [ 1.9 80. ]] ------- ['thin' 'fat' 'thin' 'fat' 'thin' 'fat' 'thin' 'fat' 'thin' 'thin'] ------- [0. 1. 0. 1. 0. 1. 0. 1. 0. 0.]
拆分数据： [[ 1.7 80. ]
 [ 1.5 50. ]
 [ 1.6 40. ]
 [ 1.9 80. ]
 [ 1.6 60. ]
 [ 1.7 60. ]
 [ 1.8 90. ]
 [ 1.5 60. ]] [[ 1.8 60. ]
 [ 1.9 70. ]] [1. 0. 0. 0. 1. 0. 1. 1.] [0. 0.]
feature_importances_: [0.58187775 0.41812225]
[1. 0. 0. 0. 1. 0. 1. 1.]
[1. 0. 0. 0. 1. 0. 1. 1.]
1.0
              precision    recall  f1-score   support

        thin       1.00      1.00      1.00         6
         fat       1.00      1.00      1.00         4

    accuracy                           1.00        10
   macro avg       1.00      1.00      1.00        10
weighted avg       1.00      1.00      1.00        10

[0. 1. 0. 1. 