In [1]:
import numpy as np


# Load the dataset
def load_data(path='决策树数据.txt', test_start=10):
    """Read a comma-separated integer table and return it with a test slice.

    Args:
        path: Text file with one sample per line, values separated by commas.
            Defaults to the file this notebook was written for.
        test_start: Row index at which the returned test slice begins.

    Returns:
        ``(x, test_x)`` where ``x`` is a 2-D int array holding every row of
        the file and ``test_x`` is ``x[test_start:]``. The test rows are NOT
        removed from ``x``; they are also part of the full dataset.

    Raises:
        ValueError: If a field is not an integer or the rows have different
            numbers of fields.
    """
    with open(path) as fr:
        # Skip blank lines (e.g. a trailing newline at the end of the file),
        # which would otherwise fail the integer conversion.
        rows = [[int(field) for field in line.split(',')]
                for line in fr if line.strip()]

    if rows:
        # The column count is inferred from the data instead of being fixed.
        x = np.array(rows, dtype=int)
    else:
        x = np.empty((0, 0), dtype=int)

    test_x = x[test_start:]

    return x, test_x


# Load the data; test_x is the rows of x from index 10 onward (x still contains them).
x, test_x = load_data()
# Echo both arrays as the cell output.
x, test_x

(array([[0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 1, 1, 0],
        [1, 1, 0, 1, 1, 1, 0],
        [0, 2, 2, 0, 2, 1, 1],
        [2, 1, 1, 1, 0, 0, 1],
        [1, 1, 0, 0, 1, 1, 1],
        [2, 0, 0, 2, 2, 0, 1],
        [0, 0, 1, 1, 1, 0, 1],
        [0, 0, 1, 0, 0, 0, 0],
        [2, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 1, 0, 0],
        [1, 1, 1, 1, 1, 0, 1],
        [2, 2, 2, 2, 2, 0, 1],
        [2, 0, 0, 2, 2, 1, 1],
        [0, 1, 0, 1, 0, 0, 1]]),
 array([[0, 0, 1, 0, 0, 0, 0],
        [2, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 1, 0, 0],
        [1, 1, 1, 1, 1, 0, 1],
        [2, 2, 2, 2, 2, 0, 1],
        [2, 0, 0, 2, 2, 1, 1],
        [0, 1, 0, 1, 0, 0, 1]]))

In [2]:
# CART scores a split with the Gini index rather than entropy.
def get_gini(_x, col, value):
    """Return the weighted Gini impurity of splitting ``_x`` on ``col == value``.

    The rows are divided into those where column ``col`` equals ``value`` and
    those where it does not. Each side contributes
    ``P(side) * (1 - P(y=0)^2 - P(y=1)^2)``, where the label ``y`` is the last
    column and is expected to be 0 or 1.

    Args:
        _x: 2-D array whose last column is the label.
        col: Index of the feature column to split on.
        value: Feature value that defines the "equal" side.

    Returns:
        The summed weighted impurity of both sides. An empty side adds a
        penalty of 1e20 instead of an impurity, so a split that does not
        actually separate the data always scores worse than any real split.
    """
    gini = 0

    is_eq = _x[:, col] == value

    # Split the rows into those equal to `value` and those that are not.
    for mask in (is_eq, _x[:, col] != value):
        sub_x = _x[mask]

        if len(sub_x) == 0:
            # Penalise the empty side and skip it: computing label shares
            # for it would divide by zero and turn the result into NaN.
            gini += 1e20
            continue

        # Probability of a row falling into this side.
        prob = len(sub_x) / len(_x)

        # Share of each label within this side.
        prob_y0 = np.sum(sub_x[:, -1] == 0) / len(sub_x)
        prob_y1 = np.sum(sub_x[:, -1] == 1) / len(sub_x)

        # Weighted Gini of this side:
        # side probability * (1 - P(y=0)^2 - P(y=1)^2)
        gini += prob * (1 - np.power(prob_y0, 2) - np.power(prob_y1, 2))

    # The Gini of the split is the sum of both weighted sides.
    return gini


# Example: weighted Gini impurity of splitting the full dataset on column 0 == 0.
get_gini(x, 0, 0)

0.49732620320855625

In [3]:
def get_split_col_value(_x):
    """Find the ``(column, value)`` test with the lowest weighted Gini impurity.

    Every feature column (all columns except the last, which is the label)
    and every distinct value in it is tried as a ``column == value`` split,
    scored with ``get_gini``.

    Args:
        _x: 2-D array whose last column is the label.

    Returns:
        ``(col, value)`` of the best split, or ``(None, None)`` when no
        column has more than one distinct value (nothing can be split).
        On ties the first candidate wins: lowest column, then lowest value.
    """
    min_col = None
    min_value = None
    min_gini = 1e20

    # Walk every feature column; the last column is y and is not a candidate.
    for col in range(_x.shape[1] - 1):

        # Sorted distinct values give a deterministic candidate order, so
        # tie-breaking does not depend on set hashing or row order.
        for value in np.unique(_x[:, col]):

            len_col_value = np.sum(_x[:, col] == value)

            # A column holding a single value cannot separate the rows.
            if len_col_value == len(_x):
                continue

            # Weighted Gini impurity after the split: lower is better.
            gini = get_gini(_x, col, value)

            # Keep the split with the smallest impurity seen so far.
            if gini < min_gini:
                min_gini = gini
                min_col = col
                min_value = value

    return min_col, min_value


# Best split for the full dataset, as a (column, value) pair.
get_split_col_value(x)

(3, 0)

In [4]:
# Node and Leaf objects used to assemble the tree
class Node:
    """Internal tree node that tests whether a feature column equals a value."""

    def __init__(self, col, value):
        # Column index and the value it is compared against.
        self.col, self.value = col, value
        # Child nodes keyed by branch name; filled in after construction.
        self.children = dict()

    def __str__(self):
        fields = (self.col, self.value)
        return 'Node col=%d value=%d' % fields


class Leaf:
    """Terminal tree node that stores the predicted label ``y``."""

    def __init__(self, y):
        # Label returned for every sample that reaches this leaf.
        self.y = y

    def __str__(self):
        label = self.y
        return 'Leaf y=%d' % label


# Exercise both __str__ methods. The expression is a tuple of two print()
# results, so the cell also echoes (None, None).
print(Node(0, 0)), print(Leaf(1))

Node col=0 value=0
Leaf y=1


(None, None)

In [5]:
# Helper that prints the tree
def print_tree(node, prefix='', subfix=''):
    """Print ``node`` and, recursively, everything below it.

    Each level is indented by four more dashes than its parent, and every
    child line ends with ``symbol=<branch key>`` to show which branch of the
    parent it hangs from.
    """
    indent = prefix + '-' * 4
    print(indent, node, subfix)

    # Only non-leaf nodes carry children to descend into.
    if not isinstance(node, Leaf):
        for branch, child in node.children.items():
            print_tree(child, indent, 'symbol=' + str(branch))


# A node without children prints as a single line.
print_tree(Node(0, 0))

---- Node col=0 value=0 


In [6]:
# First find the best split (lowest weighted Gini) over the whole dataset; the result is column 3, value 0.
get_split_col_value(x)

(3, 0)

In [7]:
# Create the root node from that result: it splits the data on whether column 3 equals 0.
root = Node(3, 0)
print(root)

Node col=3 value=0


In [8]:
# Helpers for attaching child nodes
def _majority_label(labels):
    """Return the most frequent value in ``labels`` (the smallest wins ties)."""
    values, counts = np.unique(labels, return_counts=True)
    return values[np.argmax(counts)]


def create_children(_x, parent_node):
    """Attach the ``'eq'`` and ``'neq'`` children to ``parent_node``.

    ``_x`` is divided by the parent's test (``column == value``). For each
    side the child is:

    * a ``Leaf`` when every row on that side has the same label;
    * a ``Leaf`` with the majority label when the side cannot be split any
      further (no best split exists) or is empty (then the majority label
      of ``_x`` is used);
    * otherwise a ``Node`` holding the best split of that side.

    Args:
        _x: 2-D array of the rows that reach ``parent_node``; the last
            column is the label.
        parent_node: ``Node`` whose ``children`` dict is filled in place.

    Raises:
        ValueError: If ``_x`` has no rows, since no label can be derived.
    """
    if len(_x) == 0:
        raise ValueError('create_children needs at least one row of data')

    is_eq = _x[:, parent_node.col] == parent_node.value

    # One child for each outcome of the parent's equality test.
    for symbol, mask in (('eq', is_eq), ('neq', ~is_eq)):
        # Rows that follow this branch.
        sub_x = _x[mask]

        if len(sub_x) == 0:
            # No row reaches this branch: predict the parent's majority label
            # instead of creating an unusable Node(None, None).
            parent_node.children[symbol] = Leaf(_majority_label(_x[:, -1]))
            continue

        # Distinct labels on this side.
        unique_y = np.unique(sub_x[:, -1])

        # A single label means this side is pure: make it a leaf.
        if len(unique_y) == 1:
            parent_node.children[symbol] = Leaf(unique_y[0])
            continue

        # Otherwise it is a branch node: look for its best split.
        split_col, split_value = get_split_col_value(sub_x)

        if split_col is None:
            # Mixed labels but identical features: nothing left to split on,
            # so fall back to the majority label.
            parent_node.children[symbol] = Leaf(_majority_label(sub_x[:, -1]))
            continue

        # Attach the branch node to the parent.
        parent_node.children[symbol] = Node(col=split_col, value=split_value)


# Attach the first level of children to the root using the full dataset.
create_children(x, root)

print_tree(root)

---- Node col=3 value=0 
-------- Node col=5 value=0 symbol=eq
-------- Node col=0 value=1 symbol=neq


In [9]:
# Build the next level under the root (column 3 == 0): each child is grown
# from the subset of rows that follows its branch.
x_3_eq_0 = x[x[:, 3] == 0]
x_3_neq_0 = x[x[:, 3] != 0]
create_children(x_3_eq_0, root.children['eq'])
create_children(x_3_neq_0, root.children['neq'])

print_tree(root)

---- Node col=3 value=0 
-------- Node col=5 value=0 symbol=eq
------------ Leaf y=0 symbol=eq
------------ Node col=0 value=0 symbol=neq
-------- Node col=0 value=1 symbol=neq
------------ Node col=2 value=0 symbol=eq
------------ Leaf y=1 symbol=neq


In [10]:
# Build the next level for the two children that are still branch nodes.
# Rows with column 3 == 0 and column 5 != 0 reach root -> 'eq' -> 'neq'.
x_3_eq_0_and_5_neq_0 = x_3_eq_0[x_3_eq_0[:, 5] != 0]
create_children(x_3_eq_0_and_5_neq_0, root.children['eq'].children['neq'])

# Rows with column 3 != 0 and column 0 == 1 reach root -> 'neq' -> 'eq'.
x_3_neq_0_and_0_eq_1 = x_3_neq_0[x_3_neq_0[:, 0] == 1]
create_children(x_3_neq_0_and_0_eq_1, root.children['neq'].children['eq'])

print_tree(root)

---- Node col=3 value=0 
-------- Node col=5 value=0 symbol=eq
------------ Leaf y=0 symbol=eq
------------ Node col=0 value=0 symbol=neq
---------------- Node col=1 value=1 symbol=eq
---------------- Leaf y=1 symbol=neq
-------- Node col=0 value=1 symbol=neq
------------ Node col=2 value=0 symbol=eq
---------------- Leaf y=0 symbol=eq
---------------- Leaf y=1 symbol=neq
------------ Leaf y=1 symbol=neq


In [11]:
# Build the last level: rows with column 3 == 0, column 5 != 0 and column 0 == 0
# reach root -> 'eq' -> 'neq' -> 'eq', the only remaining branch node without children.
x_3_eq_0_and_5_neq_0_and_0_eq_0 = x_3_eq_0_and_5_neq_0[
    x_3_eq_0_and_5_neq_0[:, 0] == 0]
create_children(x_3_eq_0_and_5_neq_0_and_0_eq_0,
                root.children['eq'].children['neq'].children['eq'])

print_tree(root)

---- Node col=3 value=0 
-------- Node col=5 value=0 symbol=eq
------------ Leaf y=0 symbol=eq
------------ Node col=0 value=0 symbol=neq
---------------- Node col=1 value=1 symbol=eq
-------------------- Leaf y=0 symbol=eq
-------------------- Leaf y=1 symbol=neq
---------------- Leaf y=1 symbol=neq
-------- Node col=0 value=1 symbol=neq
------------ Node col=2 value=0 symbol=eq
---------------- Leaf y=0 symbol=eq
---------------- Leaf y=1 symbol=neq
------------ Leaf y=1 symbol=neq


In [12]:
# Prediction routine, followed by a quick check
def pred(_x, node):
    """Return the label predicted for one sample ``_x``, starting at ``node``.

    At each node the sample follows the ``'eq'`` branch when its value in the
    node's column equals the node's value and the ``'neq'`` branch otherwise,
    until a ``Leaf`` is reached; that leaf's ``y`` is the prediction.
    """
    while True:
        branch = 'neq' if _x[node.col] != node.value else 'eq'
        node = node.children[branch]

        # Reaching a leaf ends the descent.
        if isinstance(node, Leaf):
            return node.y


# Accuracy of the hand-built tree on x, the same rows the tree was built from
# (test_x is not used here), so this is a training-set score.
correct = 0
for i in x:
    # The last column of each row is the true label.
    if pred(i, root) == i[-1]:
        correct += 1

print(correct / len(x))

1.0
