In [1]:
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt

### 导入计算基尼系数的函数

In [2]:
def gini(y):
    """ 计算基尼系数 """
    counter = Counter(y)
    result = 0
    for v in counter.values():
        # 计算每个种类概率的平方和
        result += (v/len(y))**2
    return 1-result

In [3]:
def cut(X, y, d, v):
    """ 按d,v值切割X,y数据集 """
    
    ind_left = (X[:, d] <= v)
    ind_right = (X[:, d] > v)
    return X[ind_left], X[ind_right], y[ind_left], y[ind_right]


def try_split(X, y):
    """ 对于X,y数据集，按最佳的基尼系数分割点进行分割，并返回最佳的d,v,g """
    
    best_g = 1
    best_d = -1
    best_v = -1
    
    for d in range(X.shape[1]):
        sorted_index = np.argsort(X[:, d])
        for i in range(len(X)-1):
            
            if X[sorted_index[i], d] == X[sorted_index[i+1], d]:
                continue
            
            v = (X[sorted_index[i], d] + X[sorted_index[i+1], d]) / 2
            #print('d={}, v={}'.format(d, v))
            
            X_left, X_right, y_left, y_right = cut(X, y, d, v)
            g_all = gini(y_left) + gini(y_right)
            
            #print('d={}, v={}, g={}'.format(d, v, g_all))
            
            if g_all < best_g:
                best_g = g_all
                best_d = d
                best_v = v
                
    return best_d, best_v, best_g

In [4]:
X = np.loadtxt('data/iris/x.txt')
X = X[:, 2:]  # 只取后两个维度的数据
y = np.loadtxt('data/iris/y.txt')
try_split(X, y)

(0, 2.45, 0.5)

### 定义 Node 节点类

In [5]:
class Node():
    def __init__(self, d=None, v=None, g=None, l=None):
        self.dim   = d
        self.value = v
        self.gini  = g
        self.label = l
        
        self.children_left  = None
        self.children_right = None
        
    def __repr__(self):
        return 'Node(d={}, v={}, g={}, l={})'.format(self.dim, self.value, self.gini, self.label)

### 构建决策树

In [6]:
def create_tree(X, y):
    """ 递归构建决策树 """
    
    d, v, g = try_split(X, y)
    
    if d==-1 or g==0:
        return None
    
    node = Node(d, v, g)
    
    X_left, X_right, y_left, y_right = cut(X, y, d, v)
    
    node.children_left = create_tree(X_left, y_left)
    if node.children_left is None:
        lable = Counter(y_left).most_common(1)[0][0]
        node.children_left = Node(l=lable)
    
    node.children_right = create_tree(X_right, y_right)
    if node.children_right is None:
        lable = Counter(y_right).most_common(1)[0][0]
        node.children_right = Node(l=lable)
    
    return node

In [7]:
tree = create_tree(X, y)
tree

Node(d=0, v=2.45, g=0.5, l=None)

### 使用决策树进行预测

In [8]:
def predict(x, node):
    """ 返回数据点x在决策树node中的预测类别 """
    
    if node.label is not None:
        return node.label
    
    if x[node.dim] <= node.value:
        # left
        return predict(x, node.children_left)
    else:
        # right
        return predict(x, node.children_right)

测试

In [9]:
predict(X[0], tree)

0.0

In [10]:
predict(X[66], tree)

1.0

In [11]:
predict(X[102], tree)

2.0