In [2]:
import numpy as np

def load_data():
    with open("decision_tree_data.txt") as f:
        lines = f.readlines()
        
    x = np.empty((len(lines), 7), dtype=int)
    
    for i in range(len(lines)):
        line = lines[i].strip().split(',')
        x[i] = line
        
    test_x = x[10:]
    x = x[:10]
    
    return x, test_x

In [3]:
x, test_x = load_data()
x.shape, test_x.shape

((10, 7), (7, 7))

In [4]:
def get_entropy(_x):
    entropy = 0.0
    
    y = _x[:, -1]
    
    bincount = np.bincount(y)
    for count in bincount:
        if count == 0:
            continue
            
        prob = count / len(_x)
        
        entropy -= prob * np.log2(prob)
        
    return entropy

get_entropy(x)

1.0

In [5]:
# 计算增益的方式与 ID3 有差异
def get_gain(_x, col):
    col_entropy = 0
    iv = 1e-20
    
    for value in set(_x[:, col]):
        x_by_col_and_value = _x[_x[:, col] == value]
        
        prob = len(x_by_col_and_value) / len(_x)
        entropy = get_entropy(x_by_col_and_value)
        col_entropy += prob * entropy
        
        iv -= prob * np.log2(prob)
        
    gain = get_entropy(_x) - col_entropy
    
    return gain / iv
        
get_gain(x, 0)

0.1810129868433342

In [6]:
def get_split_col(_x):
    best_col = -1
    best_gain = 0
    
    # 最后一列是 y
    for col in range(_x.shape[1] - 1):
        gain = get_gain(_x, col)
        
        if gain > best_gain:
            best_gain = gain
            best_col = col
            
    return best_col

get_split_col(x)        

0

In [7]:
class Node():
    def __init__(self, col):
        self.col = col
        self.children = {}
        
    def __str__(self):
        return 'Node col=%d' % self.col
    
class Leaf():
    def __init__(self, y):
        self.y = y

    def __str__(self):
        return 'Leaf y=%d' % self.y

In [8]:
def print_tree(node, prefix='', subfix=''):
    prefix += '-' * 4
    print(prefix, node, subfix)
    if isinstance(node, Leaf):
        return
    for i in node.children:
        subfix = 'value=' + str(i)
        print_tree(node.children[i], prefix, subfix)

In [9]:
def create_children(_x, parent_node):
    for split_value in np.unique(_x[:, parent_node.col]):
        sub_x = _x[_x[:, parent_node.col] == split_value]
        print(sub_x)
        print("-----")
        
        unique_y = np.unique(sub_x[:, -1])
        
        if len(unique_y) == 1:
            parent_node.children[split_value] = Leaf(unique_y)
            continue
            
        split_col = get_split_col(sub_x)
        
        parent_node.children[split_value] = Node(col=split_col)


In [10]:
# 在完整数据集上求信息增益最大的列
split_col = get_split_col(x)
split_col

0

In [11]:
# 根据上面的结果，创建根节点
root = Node(0)
create_children(x, root)
print_tree(root)

[[0 0 0 0 0 0 0]
 [0 1 0 0 1 1 0]
 [0 2 2 0 2 1 1]
 [0 0 1 1 1 0 1]]
-----
[[1 0 1 0 0 0 0]
 [1 0 0 0 0 0 0]
 [1 1 0 1 1 1 0]
 [1 1 0 0 1 1 1]]
-----
[[2 1 1 1 0 0 1]
 [2 0 0 2 2 0 1]]
-----
---- Node col=0 
-------- Node col=2 value=0
-------- Node col=1 value=1
-------- Leaf y=1 value=2


In [12]:
#继续 0=0 节点的下一层
x_0_0 = x[x[:, 0] == 0]
create_children(x_0_0, root.children[0])

print_tree(root)

[[0 0 0 0 0 0 0]
 [0 1 0 0 1 1 0]]
-----
[[0 0 1 1 1 0 1]]
-----
[[0 2 2 0 2 1 1]]
-----
---- Node col=0 
-------- Node col=2 value=0
------------ Leaf y=0 value=0
------------ Leaf y=1 value=1
------------ Leaf y=1 value=2
-------- Node col=1 value=1
-------- Leaf y=1 value=2


In [13]:
# 创建 0=1 的下一层
x_0_1 = x[x[:, 0] == 1]
create_children(x_0_1, root.children[1])

print_tree(root)

[[1 0 1 0 0 0 0]
 [1 0 0 0 0 0 0]]
-----
[[1 1 0 1 1 1 0]
 [1 1 0 0 1 1 1]]
-----
---- Node col=0 
-------- Node col=2 value=0
------------ Leaf y=0 value=0
------------ Leaf y=1 value=1
------------ Leaf y=1 value=2
-------- Node col=1 value=1
------------ Leaf y=0 value=0
------------ Node col=3 value=1
-------- Leaf y=1 value=2


In [14]:
# 创建 0=1,1=1 的下一层
x_0_1_and_1_1 = x_0_1[x_0_1[:, 1] == 1]
create_children(x_0_1_and_1_1, root.children[1].children[1])

print_tree(root)

[[1 1 0 0 1 1 1]]
-----
[[1 1 0 1 1 1 0]]
-----
---- Node col=0 
-------- Node col=2 value=0
------------ Leaf y=0 value=0
------------ Leaf y=1 value=1
------------ Leaf y=1 value=2
-------- Node col=1 value=1
------------ Leaf y=0 value=0
------------ Node col=3 value=1
---------------- Leaf y=1 value=0
---------------- Leaf y=0 value=1
-------- Leaf y=1 value=2


In [15]:
# 测试
def pred(_x, node):
    col_value = _x[node.col]
    node = node.children[col_value]

    if isinstance(node, Leaf):
        return node.y

    return pred(_x, node)


correct = 0
for i in x:
    if pred(i, root) == i[-1]:
        correct += 1

print(correct / len(x))

print('-------------------------')

correct = 0
for i in test_x:
    if pred(i, root) == i[-1]:
        correct += 1

print(correct / len(test_x))

1.0
-------------------------
0.2857142857142857
