In [1]:
import numpy as np
from sklearn.datasets import load_iris

In [2]:
# Load the Iris bunch once and unpack the pieces the rest of the notebook uses.
dataset = load_iris()
x, y = dataset.data, dataset.target   # sample matrix and the per-sample class targets
features = dataset.feature_names      # column names, used later to label a split feature

In [3]:
class Node:
    """One node of the decision tree.

    :param value: payload stored on the node; create_tree puts either a
                  feature identifier (internal node) or a class label (leaf) here
    :param type:  free-form tag for the node, 'decision' unless the caller
                  passes something else
    """

    def __init__(self, value, type='decision'):
        # `type` shadows the builtin, but it is part of the public signature.
        # Each instance gets its own dict; create_tree keys it by feature value.
        self.children = {}
        self.type, self.value = type, value

In [4]:
def calc_shannon_ent(data, labels):
    """Pick the column of ``data`` with the highest information gain.

    Every distinct value in a column is treated as its own category: the
    conditional entropy of ``labels`` given that column is subtracted from
    the entropy of ``labels`` over the whole set.

    :param data:   2-D ndarray, one row per sample and one column per feature
    :param labels: ndarray of class labels, aligned with the rows of ``data``
    :return: index (column position) of the feature with the largest gain
    """
    def _class_sizes(subset):
        # Number of samples in each class present in `subset`.
        return np.array([subset[subset == cls].size for cls in set(subset)])

    n_rows = float(data.shape[0])
    n_cols = data.shape[1]

    # Entropy of the labels before any split.
    priors = _class_sizes(labels) / n_rows
    base_ent = -np.sum(priors * np.log2(priors))

    gains = []
    for col in range(n_cols):
        column = data[:, col]
        cond_ent = 0
        # Weighted sum of the label entropy inside each value-partition.
        for val in set(column):
            mask = column == val
            n_val = float(column[mask].size)
            probs = _class_sizes(labels[mask]) / n_val
            cond_ent += (n_val / n_rows) * np.sum(-probs * np.log2(probs))
        gains.append(base_ent - cond_ent)

    return np.argmax(np.array(gains))

In [5]:
def split_dataset(data, labels, feature, value):
    """Select the samples whose ``feature`` column equals ``value``.

    :param data:    dataset, 2-D ndarray (one row per sample)
    :param labels:  label array, ndarray aligned with the rows of ``data``
    :param feature: column index of the feature to split on
    :param value:   feature value a row must have to be kept
    :return: tuple ``(rows, row_labels)`` -- the matching rows with the
             ``feature`` column removed, and their labels
    """
    mask = data[:, feature] == value
    kept_rows = data[mask]
    # The split column is dropped so it cannot be chosen again further down.
    return np.delete(kept_rows, feature, axis=1), labels[mask]

In [6]:
def voting_label(labels):
    """Return the majority class label.

    :param labels: array-like of class labels; must not be empty
    :return: the label that occurs most often. On a tie the smallest label
             wins, because np.unique returns sorted values and np.argmax
             picks the first maximum.
    :raises ValueError: if ``labels`` is empty
    """
    labels = np.asarray(labels)
    if labels.size == 0:
        raise ValueError("voting_label() requires at least one label")
    # The earlier version compared `label == label` (always True), ordered the
    # candidates by label instead of by count, and returned the count instead
    # of the label, so it always produced 1. Count occurrences properly here.
    values, counts = np.unique(labels, return_counts=True)
    return values[np.argmax(counts)]

In [7]:
def create_tree(data, labels, features):
    """Recursively build a decision tree and return its root Node.

    At each level the column chosen by calc_shannon_ent becomes the split
    feature, and one child is created for every distinct value in that column.

    :param data:     2-D ndarray, one row per sample and one column per
                     remaining feature
    :param labels:   ndarray of class labels, aligned with the rows of ``data``
    :param features: sequence of feature identifiers, aligned with the
                     columns of ``data``
    :return: Node whose ``value`` is the chosen feature identifier (internal
             node) or a class label (leaf). Leaves are built with Node's
             default ``type``, so they have no children but are not tagged
             differently.
    """
    # Leaf: nothing left to split on, so fall back to the voted label.
    if len(features) == 0:
        return Node(voting_label(labels))
    # Leaf: every remaining sample already has the same label.
    if len(set(labels)) == 1:
        return Node(labels[0])

    split_col = calc_shannon_ent(data, labels)
    tree = Node(features[split_col])
    # The chosen feature is used up; sub-trees only see the remaining ones.
    remaining = np.delete(features, split_col)

    # One branch per distinct value found in the chosen column.
    for val in set(data[:, split_col]):
        subset, subset_labels = split_dataset(data, labels, split_col, val)
        tree.children[val] = create_tree(subset, subset_labels, remaining)
    return tree

In [8]:
# Build the tree over every column; features are identified by column index.
root = create_tree(data=x, labels=y, features=list(range(x.shape[1])))

In [9]:
# Identifier of the feature the root splits on (a column index, because
# create_tree was given range(n_columns) as its feature list).
root.value

2

In [10]:
# Translate that column index into the dataset's feature name.
features[root.value]

'petal length (cm)'

In [11]:
# One child per distinct value of the split feature, keyed by that value.
root.children

{1.7: <__main__.Node at 0x203a7aa9588>,
 1.4: <__main__.Node at 0x203a7a5ee80>,
 1.6: <__main__.Node at 0x203a7a5eba8>,
 1.3: <__main__.Node at 0x203a7aa9278>,
 1.5: <__main__.Node at 0x203a7aa9780>,
 1.1: <__main__.Node at 0x203a7aa9748>,
 1.2: <__main__.Node at 0x203a7aa94e0>,
 1.0: <__main__.Node at 0x203a7aa9128>,
 1.9: <__main__.Node at 0x203a7aa9828>,
 4.7: <__main__.Node at 0x203a7aa97b8>,
 4.5: <__main__.Node at 0x203a7aa9860>,
 4.9: <__main__.Node at 0x203a7aa9160>,
 4.0: <__main__.Node at 0x203a7aa9b70>,
 5.0: <__main__.Node at 0x203a7aa9a58>,
 6.0: <__main__.Node at 0x203a7aa9e80>,
 3.5: <__main__.Node at 0x203a7aa9e48>,
 3.0: <__main__.Node at 0x203a7aa93c8>,
 4.6: <__main__.Node at 0x203a7aa9e10>,
 4.4: <__main__.Node at 0x203a7aa9c88>,
 4.1: <__main__.Node at 0x203a7aa9d30>,
 5.1: <__main__.Node at 0x203a7aa9cf8>,
 5.9: <__main__.Node at 0x203a7aa9ba8>,
 5.6: <__main__.Node at 0x203a7aa9d68>,
 5.5: <__main__.Node at 0x203a7aa9f28>,
 5.4: <__main__.Node at 0x203a7aa9fd0>,


In [12]:
# Follow the branch for value 1.7 (an exact float dictionary-key match).
# NOTE(review): the child's value is a class label if the child is a leaf and a
# feature identifier otherwise; the Node itself does not record which.
root.children[1.7].value

0