In [2]:
import numpy as np
from sklearn import datasets

In [6]:
# Load the Iris dataset: X keeps every feature column, y holds the targets.
iris = datasets.load_iris()
X = iris.data[:,:]
y = iris.target
X.shape

(150, 4)

In [7]:
# d is the index of the feature (column) to split on; v is the threshold value.
def split(X, y, d, v):
    """Partition samples and labels by thresholding a single feature column.

    Parameters
    ----------
    X : 2-D array indexed as ``X[:, d]``; one row per sample.
    y : array of labels indexed with the same boolean row masks as ``X``.
    d : index of the column of ``X`` to compare against the threshold.
    v : threshold; rows with ``X[:, d] <= v`` go left, rows with
        ``X[:, d] > v`` go right.

    Returns
    -------
    (X_left, X_right, y_left, y_right)
    """
    # Use the ``v`` argument: the previous body referenced an undefined name
    # ``value``, which raised NameError (or silently picked up a stale global).
    index_a = (X[:, d] <= v)
    index_b = (X[:, d] > v)

    return X[index_a], X[index_b], y[index_a], y[index_b]

In [16]:
from collections import Counter
from math import log

def entropy(y):
    """Return the Shannon entropy of the labels in ``y``, measured in nats.

    Every distinct label contributes ``-p * ln(p)``, where ``p`` is the
    fraction of entries in ``y`` carrying that label. When ``y`` has no
    entries the result is ``0.0``.
    """
    acc = 0.0
    for _label, count in Counter(y).items():
        share = count / len(y)
        # Subtracting p * ln(p) is the same as adding -p * ln(p).
        acc -= share * log(share)
    return acc

In [17]:
def try_split(X, y):
    """Exhaustively search for the single-feature threshold with lowest entropy.

    For every column of ``X`` the rows are visited in ascending order of that
    column, and the midpoint between each pair of adjacent, distinct values is
    tried as a threshold. A candidate is scored by the size-weighted sum of
    the entropies of the two label partitions it produces.

    Returns
    -------
    (best_entropy, best_d, best_v)
        The lowest score found, the column index and the threshold achieving
        it. If no candidate exists the result is ``(inf, -1, -1)``.
    """
    # d is the feature (column) index, v is the threshold for that feature.
    best_entropy, best_d, best_v = float('inf'), -1, -1

    for d in range(X.shape[1]):
        # Row order that sorts this feature column ascending.
        order = np.argsort(X[:, d])
        for i in range(1, len(X)):
            hi = X[order[i], d]
            lo = X[order[i - 1], d]
            # Equal neighbours give no usable midpoint; skip them.
            if not (hi != lo):
                continue

            # Candidate threshold: mean of the two adjacent sorted values.
            v = (hi + lo) / 2
            X_l, X_r, y_l, y_r = split(X, y, d, v)

            # Weight each side's entropy by its share of the samples.
            p_l = len(X_l) / len(X)
            p_r = len(X_r) / len(X)
            e = p_l * entropy(y_l) + p_r * entropy(y_r)

            if e < best_entropy:
                best_entropy, best_d, best_v = e, d, v

    return best_entropy, best_d, best_v

In [18]:
# First split: search the whole dataset for the lowest-entropy threshold
# and print the resulting score, feature index and threshold.
best_entropy, best_d, best_v = try_split(X, y)
print("best_entropy =", best_entropy)
print("best_d =", best_d)
print("best_v =", best_v)

best_entropy = 0.46209812037329684
best_d = 2
best_v = 2.45


In [19]:
# Apply the first split: *_l holds rows at or below best_v, *_r the rows above.
X1_l, X1_r, y1_l, y1_r = split(X, y, best_d, best_v)

In [20]:
# Entropy of the labels in the left partition of the first split.
entropy(y1_l)

0.0

In [21]:
# Entropy of the labels in the right partition of the first split.
entropy(y1_r)

0.6931471805599453

### 二次划分

In [22]:
# Second split: repeat the search on the right partition of the first split
# and print the resulting score, feature index and threshold.
best_entropy2, best_d2, best_v2 = try_split(X1_r, y1_r)
print("best_entropy =", best_entropy2)
print("best_d =", best_d2)
print("best_v =", best_v2)

best_entropy = 0.2147644654371359
best_d = 3
best_v = 1.75


In [23]:
# Apply the second split to the right partition of the first split.
X2_l, X2_r, y2_l, y2_r = split(X1_r, y1_r, best_d2, best_v2)

In [24]:
# Entropy of the labels in the left partition of the second split.
entropy(y2_l)

0.30849545083110386

In [25]:
# Entropy of the labels in the right partition of the second split.
entropy(y2_r)

0.10473243910508653