In [1]:
import numpy as np

def load_data(path='decision_tree_data.txt', n_train=10):
    """Read a comma-separated integer dataset and split it into train/test rows.

    Every non-blank line of the file is one sample made of integer fields
    separated by commas. The number of columns is taken from the data.

    Args:
        path: Text file to read. Defaults to the file name this notebook
            originally hard-coded.
        n_train: Number of leading rows returned as the training split; the
            remaining rows form the test split. Defaults to the original 10.

    Returns:
        Tuple ``(x, test_x)`` of 2-D integer arrays (views of one array).

    Raises:
        ValueError: If a field is not an integer or rows differ in length.
    """
    with open(path) as fr:
        # Blank lines (for instance a trailing empty line) are skipped;
        # previously they made the per-row assignment raise.
        rows = [
            [int(field) for field in line.split(',')]
            for line in fr
            if line.strip()
        ]

    data = np.array(rows, dtype=int)
    if not rows:
        # Keep the result two-dimensional even when the file has no samples.
        data = data.reshape(0, 0)

    return data[:n_train], data[n_train:]

# Load the dataset: `x` is the training split, `test_x` holds the remaining rows.
x, test_x = load_data()
# Bare tuple as the cell's last expression so that both arrays are displayed.
x, test_x

(array([[0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 1, 1, 0],
        [1, 1, 0, 1, 1, 1, 0],
        [0, 2, 2, 0, 2, 1, 1],
        [2, 1, 1, 1, 0, 0, 1],
        [1, 1, 0, 0, 1, 1, 1],
        [2, 0, 0, 2, 2, 0, 1],
        [0, 0, 1, 1, 1, 0, 1]]),
 array([[0, 0, 1, 0, 0, 0, 0],
        [2, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 1, 0, 0],
        [1, 1, 1, 1, 1, 0, 1],
        [2, 2, 2, 2, 2, 0, 1],
        [2, 0, 0, 2, 2, 1, 1],
        [0, 1, 0, 1, 0, 0, 1]]))

In [2]:
def get_entropy(_x):
    """Return the Shannon entropy (base 2) of the last column of ``_x``.

    The last column is read as non-negative integer labels. Label values
    that never occur are ignored. With no rows the result is 0.
    """
    # How many times each label value occurs.
    label_counts = np.bincount(_x[:, -1])
    sample_total = len(_x)

    result = 0
    for hits in label_counts[label_counts > 0]:
        share = hits / sample_total
        result = result - share * np.log2(share)

    return result

# Entropy of the label column over the training split.
get_entropy(x)

1.0

In [3]:
def get_gain(_x, col):
    """Return the information gain obtained by splitting ``_x`` on column ``col``.

    The gain is the label entropy of ``_x`` minus the size-weighted label
    entropy of the subsets that share one value in ``col``. Larger is better.
    """
    feature = _x[:, col]
    n_rows = len(_x)

    # Weighted entropy after the split, accumulated one feature value at a time.
    weighted_entropy = 0
    for level in set(feature):
        subset = _x[feature == level]
        weighted_entropy += (len(subset) / n_rows) * get_entropy(subset)

    # Entropy drop caused by the split.
    return get_entropy(_x) - weighted_entropy

# Information gain from splitting the training split on column 0.
get_gain(x, 0)

0.2754887502163468

In [4]:
def get_split_col(_x):
    """Return the index of the feature column with the highest information gain.

    Every column except the last one (the label) is a candidate. Ties keep
    the earliest column. Returns -1 when no column has a gain above 0.
    """
    winner, top_gain = -1, 0
    n_features = _x.shape[1] - 1  # the last column is the label

    for idx in range(n_features):
        # Information gain (entropy drop) for this column; larger is better.
        candidate = get_gain(_x, idx)
        if not candidate > top_gain:
            continue
        winner, top_gain = idx, candidate

    return winner

# Column with the highest information gain on the training split.
get_split_col(x)

0

In [6]:
class Node:
    """Internal decision-tree node that branches on one feature column.

    Instance fields:
      col      -- index of the column this node splits on.
      children -- dict mapping a value of that column to the child reached
                  for it; starts empty and is filled by the tree builder.
    """

    def __init__(self, col):
        self.col = col
        self.children = {}

    def __str__(self):
        return 'Node col=%d' % self.col

    def __repr__(self):
        # Without this, displaying a node (or a dict holding nodes) as a
        # cell result shows only the default "<... object at 0x...>" text.
        return 'Node(col=%d)' % self.col


class Leaf:
    """Terminal decision-tree node holding the predicted label.

    Instance fields:
      y -- the integer label returned for any sample that reaches this leaf.
    """

    def __init__(self, y):
        self.y = y

    def __str__(self):
        return 'Leaf y=%d' % self.y

    def __repr__(self):
        # Without this, displaying a leaf (or a dict holding leaves) as a
        # cell result shows only the default "<... object at 0x...>" text.
        return 'Leaf(y=%d)' % self.y

# Smoke-test both __str__ methods. NOTE(review): joining the two print()
# calls with a comma builds a tuple of their return values, which is why the
# cell also echoes (None, None) after the printed lines.
print(Node(0)), print(Leaf(1))

Node col=0
Leaf y=1


(None, None)

In [7]:
def print_tree(node, prefix='', subfix=''):
    """Print the tree rooted at ``node``, one line per node.

    Each level adds four dashes to ``prefix``. ``subfix`` is printed after
    the node; for children it is ``value=<v>``, the split value that leads
    to that child.
    """
    indent = prefix + '-' * 4
    print(indent, node, subfix)

    # Leaves have no children to descend into.
    if not isinstance(node, Leaf):
        for value, child in node.children.items():
            print_tree(child, indent, 'value=' + str(value))


# A node without children prints only its own line.
print_tree(Node(0))

---- Node col=0 


In [8]:
# Pick the root split as the column with the largest information gain on the
# whole training split (column 0 for this data) instead of hard-coding it.
root = Node(get_split_col(x))
print(root)

Node col=0


In [14]:
def pre_cut(_x, test_x):
    """Decide (pre-pruning) whether splitting this node helps on the test rows.

    Compares two counts of correctly labelled rows in ``test_x``:
      * before the split: every row gets the majority label of ``_x``;
      * after the split: rows are grouped by the best split column and each
        group gets the majority label of the matching training subset.

    Test rows whose split-column value never occurs in ``_x`` belong to no
    group and are therefore not counted as correct after the split.

    Args:
        _x: Training rows for this node; the last column is the label.
        test_x: Held-out rows for this node, same column layout.

    Returns:
        True-like value when the split classifies strictly more test rows
        correctly than the majority vote, otherwise a false-like value.
    """
    _y = _x[:, -1]
    test_y = test_x[:, -1]

    # Baseline: the most frequent training label predicted for every test row.
    vote_y = np.bincount(_y).argmax()
    pre_correct = np.sum(test_y == vote_y)

    split_col = get_split_col(_x)

    # get_split_col returns -1 when no feature has positive gain. Using -1 as
    # an index would select the label column itself and make the split look
    # perfect, so treat "no informative column" as "do not split".
    if split_col == -1:
        return False

    # Number of test rows classified correctly after splitting on split_col.
    after_correct = 0

    # Walk over the values the split column takes in the training rows.
    for split_value in np.unique(_x[:, split_col]):
        sub_x = _x[_x[:, split_col] == split_value]
        sub_test_x = test_x[test_x[:, split_col] == split_value]

        # Labels of the matching training and test subsets.
        sub_y = sub_x[:, -1]
        sub_test_y = sub_test_x[:, -1]

        sub_vote_y = np.bincount(sub_y).argmax()

        after_correct += np.sum(sub_test_y == sub_vote_y)

    # Split only if it strictly improves the number of correct test rows.
    return after_correct > pre_correct

# Check whether splitting the full training split pays off on the test rows.
pre_cut(x, test_x)

True

In [17]:
def create_children(_x, test_x, parent_node):
    """Grow the subtree below ``parent_node`` with pre-pruning.

    For every value that ``parent_node.col`` takes in ``_x`` a child is stored
    in ``parent_node.children`` under that value:
      * a Leaf with the majority label when the subset has a single label,
        when no feature has positive gain, or when pre_cut rejects a split;
      * otherwise a Node on the best split column, which is then expanded
        recursively with the matching training and test subsets.

    Args:
        _x: Training rows reaching ``parent_node``; last column is the label.
        test_x: Held-out rows reaching ``parent_node``, same column layout.
        parent_node: Node whose ``children`` dict is filled in place.
    """
    col = parent_node.col

    # One child per value of the parent's split column.
    for split_value in np.unique(_x[:, col]):
        sub_x = _x[_x[:, col] == split_value]
        sub_test_x = test_x[test_x[:, col] == split_value]

        _y = sub_x[:, -1]

        # A subset with one label needs no split. Otherwise look for the best
        # column; -1 means none is informative (and must not be used as an
        # index, since it would address the label column).
        is_pure = len(np.unique(_y)) == 1
        split_col = -1 if is_pure else get_split_col(sub_x)

        if split_col == -1 or not pre_cut(sub_x, sub_test_x):
            # Majority label, the same vote pre_cut uses for its baseline.
            parent_node.children[split_value] = Leaf(np.bincount(_y).argmax())
            continue

        child = Node(col=split_col)
        parent_node.children[split_value] = child

        # Expand the new internal node; previously it was left without
        # children, so predicting through it raised KeyError.
        create_children(sub_x, sub_test_x, child)


# Attach one child to the root per distinct value of its split column.
create_children(x, test_x, root)

# Show the resulting tree.
print_tree(root)

---- Node col=0 
-------- Leaf y=0 value=0
-------- Leaf y=0 value=1
-------- Leaf y=1 value=2


In [18]:
def pred(_x, node):
    """Return the label predicted for one sample by walking down the tree.

    Starting at ``node``, the sample's value in the current node's column
    selects the child to follow until a Leaf is reached; its ``y`` is returned.
    A value with no matching child raises KeyError.
    """
    current = node
    while True:
        current = current.children[_x[current.col]]
        if isinstance(current, Leaf):
            return current.y


def get_accuracy(data, node):
    """Return the fraction of rows in ``data`` whose last column equals pred(row, node)."""
    correct = sum(1 for row in data if pred(row, node) == row[-1])
    return correct / len(data)


# Accuracy on the training split.
print(get_accuracy(x, root))

print('-------------------------')

# Accuracy on the test split.
print(get_accuracy(test_x, root))

0.7
-------------------------
0.5714285714285714
