In [1]:
import numpy as np

In [5]:
def load_data(path="decision_tree_data.txt", test_start=10):
    """Load the comma-separated integer dataset used to grow the tree.

    Each non-blank line of `path` is one sample of comma-separated integers;
    the last value is the label (the split/Gini code treats the last column
    as ``y``). The number of columns is taken from the file rather than
    being hard-coded.

    Parameters
    ----------
    path : str
        File to read; defaults to the notebook's original data file.
    test_start : int
        Index of the first row returned in `test_x` (default 10).

    Returns
    -------
    (x, test_x) : tuple of 2-D int arrays
        `x` holds every row; `test_x` is ``x[test_start:]``. Note that
        `test_x` is a slice of `x`, not a disjoint hold-out set, so a tree
        grown on `x` has already seen those rows.

    Raises
    ------
    ValueError
        If the file contains no data rows, a field is not an integer, or the
        rows do not all have the same number of columns.
    """
    with open(path) as f:
        # Skip blank lines (e.g. a trailing empty line) instead of failing on them.
        rows = [[int(v) for v in line.split(",")] for line in f if line.strip()]

    if not rows:
        raise ValueError("no data rows in %r" % path)

    n_cols = len(rows[0])
    if any(len(row) != n_cols for row in rows):
        raise ValueError("rows in %r do not all have %d columns" % (path, n_cols))

    x = np.array(rows, dtype=int)
    test_x = x[test_start:]

    return x, test_x

In [7]:
# x holds every row of the data file; test_x is the tail x[10:], i.e. a subset
# of x rather than a separate hold-out set.
x, test_x = load_data()

In [9]:
def get_gini(_x, col, value):
    """Weighted Gini impurity of splitting `_x` on ``_x[:, col] == value``.

    The last column of `_x` is the label. Rows are partitioned into an "eq"
    side (feature equals `value`) and a "neq" side. Each side's impurity,
    ``1 - sum_k p_k**2`` over the label values present on that side, is
    weighted by the fraction of rows on that side.

    Parameters
    ----------
    _x : 2-D numpy array whose last column holds the labels.
    col : index of the feature column to test.
    value : value the feature is compared against.

    Returns
    -------
    float
        Weighted Gini impurity (0 for a perfect split). Each side that
        receives no rows adds a penalty of 1e20, so degenerate splits are
        never preferred and no 0/0 division happens.
    """
    mask = _x[:, col] == value
    gini = 0

    for sub_x in (_x[mask], _x[~mask]):
        if len(sub_x) == 0:
            # Degenerate split: penalise it instead of dividing by zero below.
            gini += 1e20
            continue

        prob = len(sub_x) / len(_x)

        # Subtract each label's squared probability in ascending label order;
        # for 0/1 labels this is exactly 1 - p0**2 - p1**2.
        _, counts = np.unique(sub_x[:, -1], return_counts=True)
        impurity = 1
        for count in counts:
            impurity -= np.power(count / len(sub_x), 2)

        gini += prob * impurity

    return gini

In [10]:
# Sanity check: weighted Gini of splitting the full dataset on column 0 == 0.
get_gini(x, 0, 0)

0.49732620320855625

In [17]:
def get_split_col_value(_x):
    """Find the (column, value) equality test with the lowest weighted Gini.

    Every feature column (all but the last, which holds y) is tried against
    each value it takes in `_x`. A candidate that would send all rows to the
    same side is skipped because it does not actually split the data. On
    ties, the first candidate encountered wins.

    Returns
    -------
    (col, value) of the best split, or (None, None) when no feature column
    splits the rows into two non-empty groups.
    """
    best_col, best_value, best_gini = None, None, 1e20
    n_rows = len(_x)

    for col in range(_x.shape[1] - 1):  # last column is y, not a feature
        column = _x[:, col]
        for value in set(column):
            matches = np.sum(column == value)
            # A value shared by every row (a single-valued feature) cannot split.
            if not 0 < matches < n_rows:
                continue

            gini = get_gini(_x, col, value)
            if gini < best_gini:
                best_col, best_value, best_gini = col, value, gini

    return best_col, best_value

In [18]:
class Node:
    """Internal decision node that tests ``sample[col] == value``.

    `children` starts empty; create_children fills it with the branch
    symbol ('eq' or 'neq') mapped to a child Node or Leaf.
    """

    def __init__(self, col, value):
        self.col, self.value = col, value
        self.children = dict()

    def __str__(self):
        return 'Node col=%d value=%d' % (self.col, self.value)


class Leaf:
    """Terminal node that always predicts the label `y`."""

    def __init__(self, y):
        self.y = y

    def __str__(self):
        label = self.y
        return 'Leaf y=%d' % label

In [19]:
def print_tree(node, prefix='', subfix=''):
    """Print `node` and its subtree, one node per line.

    Each level adds four dashes to the line prefix. `subfix` is printed after
    the node; for children it names the branch ('symbol=<key>') that led to
    them. Recursion stops at Leaf nodes.
    """
    indent = prefix + '----'
    print(indent, node, subfix)
    if not isinstance(node, Leaf):
        for symbol, child in node.children.items():
            print_tree(child, indent, 'symbol=' + str(symbol))


print_tree(Node(0, 0))

---- Node col=0 value=0 


In [20]:
# Best (col, value) equality split for the root, searched over the full dataset.
get_split_col_value(x)

(3, 0)

In [21]:
# Build the root from the best split instead of re-typing the (col, value)
# printed by the previous cell, so the root stays correct if the data changes.
root_col, root_value = get_split_col_value(x)
root = Node(root_col, root_value)
print(root)

Node col=3 value=0


In [25]:
def _majority_label(rows):
    """Return the most frequent label (last column) of `rows`.

    Ties go to the smallest label, because np.unique returns labels sorted
    and np.argmax picks the first maximum. Raises ValueError if `rows` is empty.
    """
    labels, counts = np.unique(rows[:, -1], return_counts=True)
    return labels[np.argmax(counts)]


def create_children(_x, parent_node, max_depth=1):
    """Attach an 'eq' and a 'neq' child to `parent_node` using the rows `_x`.

    `_x` is split by ``_x[:, parent_node.col] == parent_node.value``. For each
    side the child is:
      * a Leaf with the majority label of `_x` if no rows reach that side;
      * a Leaf with the single label if that side is pure;
      * a Leaf with the side's majority label if no feature can split it
        (instead of a Node with col=None, which print_tree cannot format);
      * otherwise a Node on the best (col, value) from get_split_col_value.

    Parameters
    ----------
    _x : 2-D array whose last column is the label; the rows reaching
        `parent_node`.
    parent_node : Node whose `children` dict is filled in (mutated in place).
    max_depth : int or None, default 1
        Number of levels to grow. 1 attaches only the direct children (the
        original one-level behaviour); larger values recurse into each new
        Node; None grows until every branch ends in a Leaf.

    Raises
    ------
    ValueError
        If `_x` is empty (there is no label to fall back on).
    """
    mask = _x[:, parent_node.col] == parent_node.value
    for symbol, sub_x in (('eq', _x[mask]), ('neq', _x[~mask])):
        if len(sub_x) == 0:
            parent_node.children[symbol] = Leaf(_majority_label(_x))
            continue

        unique_y = np.unique(sub_x[:, -1])
        if len(unique_y) == 1:
            parent_node.children[symbol] = Leaf(unique_y[0])
            continue

        split_col, split_value = get_split_col_value(sub_x)
        if split_col is None:
            # Mixed labels but every feature is constant on these rows.
            parent_node.children[symbol] = Leaf(_majority_label(sub_x))
            continue

        child = Node(col=split_col, value=split_value)
        parent_node.children[symbol] = child
        if max_depth is None or max_depth > 1:
            create_children(sub_x, child,
                            None if max_depth is None else max_depth - 1)
        
# Grow one level below the root on the full dataset, then display the tree.
create_children(x, root)
print_tree(root)

---- Node col=3 value=0 
-------- Node col=5 value=0 symbol=eq
-------- Node col=0 value=1 symbol=neq


In [26]:
# Partition x with the root's own test rather than re-typing its (col, value),
# so these subsets always match the branches create_children built.
# The variable names are kept because later cells refer to them.
root_mask = x[:, root.col] == root.value
x_3_eq_0 = x[root_mask]
x_3_neq_0 = x[~root_mask]
create_children(x_3_eq_0, root.children['eq'])
create_children(x_3_neq_0, root.children['neq'])

print_tree(root)

---- Node col=3 value=0 
-------- Node col=5 value=0 symbol=eq
------------ Leaf y=0 symbol=eq
------------ Node col=0 value=0 symbol=neq
-------- Node col=0 value=1 symbol=neq
------------ Node col=2 value=0 symbol=eq
------------ Leaf y=1 symbol=neq


In [27]:
# Grow the next level, taking each filter from the node it belongs to
# instead of hard-coding the column/value pair.
eq_node = root.children['eq']
neq_node = root.children['neq']

x_3_eq_0_and_5_neq_0 = x_3_eq_0[x_3_eq_0[:, eq_node.col] != eq_node.value]
create_children(x_3_eq_0_and_5_neq_0, eq_node.children['neq'])

x_3_neq_0_and_0_eq_1 = x_3_neq_0[x_3_neq_0[:, neq_node.col] == neq_node.value]
create_children(x_3_neq_0_and_0_eq_1, neq_node.children['eq'])

print_tree(root)

---- Node col=3 value=0 
-------- Node col=5 value=0 symbol=eq
------------ Leaf y=0 symbol=eq
------------ Node col=0 value=0 symbol=neq
---------------- Node col=1 value=1 symbol=eq
---------------- Leaf y=1 symbol=neq
-------- Node col=0 value=1 symbol=neq
------------ Node col=2 value=0 symbol=eq
---------------- Leaf y=0 symbol=eq
---------------- Leaf y=1 symbol=neq
------------ Leaf y=1 symbol=neq


In [28]:
# Grow the last undecided branch; its filter comes from the node being split.
split_node = root.children['eq'].children['neq']
x_3_eq_0_and_5_neq_0_and_0_eq_0 = x_3_eq_0_and_5_neq_0[
    x_3_eq_0_and_5_neq_0[:, split_node.col] == split_node.value]
create_children(x_3_eq_0_and_5_neq_0_and_0_eq_0,
                split_node.children['eq'])

print_tree(root)

---- Node col=3 value=0 
-------- Node col=5 value=0 symbol=eq
------------ Leaf y=0 symbol=eq
------------ Node col=0 value=0 symbol=neq
---------------- Node col=1 value=1 symbol=eq
-------------------- Leaf y=0 symbol=eq
-------------------- Leaf y=1 symbol=neq
---------------- Leaf y=1 symbol=neq
-------- Node col=0 value=1 symbol=neq
------------ Node col=2 value=0 symbol=eq
---------------- Leaf y=0 symbol=eq
---------------- Leaf y=1 symbol=neq
------------ Leaf y=1 symbol=neq


In [29]:
def pred(_x, node):
    """Predict the label of one sample by walking the tree from `node`.

    Parameters
    ----------
    _x : a single row of the data, indexed by column.
    node : Node or Leaf to start from. A Leaf is returned immediately, so a
        tree consisting of a single Leaf also works.

    Returns
    -------
    The `y` of the Leaf the sample reaches.

    Raises
    ------
    KeyError
        If the sample reaches a Node whose branch child has not been created.
    """
    # Iterate rather than recurse so deep trees cannot hit the recursion limit.
    while not isinstance(node, Leaf):
        symbol = 'neq' if _x[node.col] != node.value else 'eq'
        node = node.children[symbol]
    return node.y


# Accuracy of the tree on x. The tree was grown on x (and test_x is only a
# slice of x), so this is training accuracy, not held-out performance.
correct = 0
for i in x:
    correct += int(pred(i, root) == i[-1])

print(correct / len(x))

1.0
