In [1]:
import numpy as np

def load_data():
    with open('decision_tree_data.txt') as fr:
        lines = fr.readlines()

    x = np.empty((len(lines), 7), dtype=int)

    for i in range(len(lines)):
        line = lines[i].strip().split(',')
        x[i] = line

    test_x = x[10:]
    x = x[:10]

    return x, test_x


x, test_x = load_data()
x, test_x

(array([[0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 1, 1, 0],
        [1, 1, 0, 1, 1, 1, 0],
        [0, 2, 2, 0, 2, 1, 1],
        [2, 1, 1, 1, 0, 0, 1],
        [1, 1, 0, 0, 1, 1, 1],
        [2, 0, 0, 2, 2, 0, 1],
        [0, 0, 1, 1, 1, 0, 1]]),
 array([[0, 0, 1, 0, 0, 0, 0],
        [2, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 1, 0, 0],
        [1, 1, 1, 1, 1, 0, 1],
        [2, 2, 2, 2, 2, 0, 1],
        [2, 0, 0, 2, 2, 1, 1],
        [0, 1, 0, 1, 0, 0, 1]]))

In [22]:
class Node():
    def __init__(self, col):
        self.col = col
        self.children = {}

    def __str__(self):
        return 'Node col=%d' % self.col


class Leaf():
    def __init__(self, y):
        self.y = y

    def __str__(self):
        return 'Leaf y=%d' % self.y


print(Node(0)), print(Leaf(1))

Node col=0
Leaf y=1


(None, None)

In [23]:
def print_tree(node, prefix='', subfix=''):
    prefix += '-' * 4
    print(prefix, node, subfix)
    if isinstance(node, Leaf):
        return
    for i in node.children:
        subfix = 'value=' + str(i)
        print_tree(node.children[i], prefix, subfix)


print_tree(Node(0))

---- Node col=0 


In [24]:
import pickle

with open('tree.dump', 'rb') as fr:
    root = pickle.load(fr)

print_tree(root)

---- Node col=0 
-------- Node col=2 value=0
------------ Leaf y=0 value=0
------------ Leaf y=1 value=1
------------ Leaf y=1 value=2
-------- Node col=1 value=1
------------ Leaf y=0 value=0
------------ Node col=3 value=1
---------------- Leaf y=1 value=0
---------------- Leaf y=0 value=1
-------- Leaf y=1 value=2


In [25]:
# 第 0 列，值为 1
x_0_1 = x[x[:, 0] == 1]
# 第 1 列，值为 1
x_0_1_and_1_1 = x_0_1[x_0_1[:, 1] == 1]

test_x_0_1 = test_x[test_x[:, 0] == 1]
test_x_0_1_and_1_1 = test_x_0_1[test_x_0_1[:, 1] == 1]

node = root.children[1].children[1]

print_tree(node)

---- Node col=3 
-------- Leaf y=1 value=0
-------- Leaf y=0 value=1


In [26]:
def post_cut(node, _x, test_x):
    after_correct = 0
    for split_value in np.unique(test_x[:, node.col]):
        sub_test_x = test_x[test_x[:, node.col] == split_value]
        sub_test_y = sub_test_x[:, -1]

        after_correct += np.sum(sub_test_y == node.children[split_value].y)

    _y = _x[:, -1]
    test_y = test_x[:, -1]

    # 求众数
    vote_y = np.bincount(_y).argmax()

    pre_correct = np.sum(test_y == vote_y)
    
    
    if after_correct <= pre_correct:
        return Leaf(y=vote_y)

    return node


root.children[1].children[1] = post_cut(node, x_0_1_and_1_1,
                                         test_x_0_1_and_1_1)
print_tree(root)

---- Node col=0 
-------- Node col=2 value=0
------------ Leaf y=0 value=0
------------ Leaf y=1 value=1
------------ Leaf y=1 value=2
-------- Node col=1 value=1
------------ Leaf y=0 value=0
------------ Leaf y=0 value=1
-------- Leaf y=1 value=2


In [27]:
x_0_1 = x[x[:, 0] == 1]

test_x_0_1 = test_x[test_x[:, 0] == 1]

node = root.children[1]

print_tree(node)

---- Node col=1 
-------- Leaf y=0 value=0
-------- Leaf y=0 value=1


In [28]:
root.children[1] = post_cut(node, x_0_1, test_x_0_1)
print_tree(root)

---- Node col=0 
-------- Node col=2 value=0
------------ Leaf y=0 value=0
------------ Leaf y=1 value=1
------------ Leaf y=1 value=2
-------- Leaf y=0 value=1
-------- Leaf y=1 value=2


In [29]:
x_0_0 = x[x[:, 0] == 0]

test_x_0_0 = test_x[test_x[:, 0] == 0]

node = root.children[0]

print_tree(node)

---- Node col=2 
-------- Leaf y=0 value=0
-------- Leaf y=1 value=1
-------- Leaf y=1 value=2


In [30]:
root.children[0] = post_cut(node, x_0_0, test_x_0_0)
print_tree(root)

---- Node col=0 
-------- Leaf y=0 value=0
-------- Leaf y=0 value=1
-------- Leaf y=1 value=2


In [31]:
def pred(_x, node):
    col_value = _x[node.col]
    node = node.children[col_value]

    if isinstance(node, Leaf):
        return node.y

    return pred(_x, node)


correct = 0
for i in x:
    if pred(i, root) == i[-1]:
        correct += 1

print(correct / len(x))

print('-------------------------')

correct = 0
for x in test_x:
    if pred(x, root) == i[-1]:
        correct += 1

print(correct / len(test_x))

0.7
-------------------------
0.42857142857142855
