In [2]:
from math import log 

def calc_shannon_ent(dataset):
    """Return the Shannon entropy, in bits, of the class labels in ``dataset``.

    Args:
        dataset: sequence of records; the last element of each record is
            used as its class label.

    Returns:
        The entropy as a float. An empty ``dataset`` yields ``0.0`` because
        there are no classes to sum over.
    """
    total = len(dataset)

    # Tally how many records fall into each class.
    tally = {}
    for record in dataset:
        label = record[-1]
        tally[label] = tally.get(label, 0) + 1

    # H = -sum(p * log2(p)) over the observed classes.
    entropy = 0.0
    for count in tally.values():
        p = float(count) / total    # relative frequency of this class
        entropy -= p * log(p, 2)
    return entropy

In [3]:
def create_dataset():
    """Build the small sample data set used throughout the notebook.

    Returns:
        A ``(records, feature_names)`` tuple. Each record is
        ``[feature_0, feature_1, class_label]``.
    """
    # These are the names of the two feature columns, not the class labels
    # stored in the last column of each record -- do not confuse the two.
    feature_names = ['no surfacing', 'flippers']
    records = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    return records, feature_names

In [4]:
# Build the sample records and the list of feature names.
my_dat, labels = create_dataset()

In [5]:
# Display the records; the last element of each row is the class label.
my_dat

[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

In [6]:
# Entropy of the class column with two classes (2 x 'yes', 3 x 'no').
calc_shannon_ent(my_dat)

0.9709505944546686

In [7]:
# Relabel the first record with a third class to see how the entropy reacts.
# Note: this mutates my_dat in place.
my_dat[0][-1] = 'maybe'
my_dat

[[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

In [8]:
# Entropy recomputed now that the class column holds three distinct labels.
calc_shannon_ent(my_dat)

1.3709505944546687

In [9]:
def split_dataset(dataset, axis, value):
    """Select the records whose feature ``axis`` equals ``value``.

    The matching column is removed from every returned record, so the result
    has one column fewer than the input. The input records are not modified.

    Args:
        dataset: sequence of records (lists, tuples or other sliceable
            sequences).
        axis: index of the feature column to test and drop. Negative indices
            count from the end of each record, as in ordinary indexing.
        value: feature value a record must have to be kept.

    Returns:
        A new list of new lists, one per matching record.

    Raises:
        IndexError: if ``axis`` is out of range for a record.
    """
    ret_dataset = []
    for feat_vec in dataset:
        if feat_vec[axis] == value:
            # Normalise a negative index first: slicing with axis=-1 would
            # give feat_vec[:-1] + feat_vec[0:], i.e. append the whole record
            # again instead of dropping its last column.
            column = axis if axis >= 0 else axis + len(feat_vec)
            # list() keeps the result a list even when records are tuples.
            reduced_feat_vec = list(feat_vec[:column])
            reduced_feat_vec.extend(feat_vec[column + 1:])
            ret_dataset.append(reduced_feat_vec)
    return ret_dataset

In [11]:
# Rebuild the data set so the earlier in-place 'maybe' edit is discarded,
# then display it.
my_dat, labels = create_dataset()
my_dat

[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

In [12]:
# Records whose feature 0 equals 1, with that column removed.
split_dataset(my_dat, 0, 1)

[[1, 'yes'], [1, 'yes'], [0, 'no']]

In [13]:
# Records whose feature 0 equals 0, with that column removed.
split_dataset(my_dat, 0, 0)

[[1, 'no'], [1, 'no']]

In [18]:
def choose_best_feature_to_split(dataset):
    """Return the index of the feature with the highest information gain.

    For every feature column the data set is partitioned by that feature's
    distinct values; the information gain is the entropy of the whole data
    set minus the size-weighted entropy of the partitions.

    Args:
        dataset: non-empty sequence of records whose last element is the
            class label and whose other elements are feature values.

    Returns:
        The index of the feature with the largest gain; the lowest index wins
        ties. ``-1`` is returned only when the records have no feature
        columns at all.

    Raises:
        IndexError: if ``dataset`` is empty.
    """
    num_features = len(dataset[0]) - 1
    num_records = float(len(dataset))
    base_entropy = calc_shannon_ent(dataset)    # entropy before any split

    # Start below every possible gain so that a valid column is chosen even
    # when no feature has strictly positive gain (e.g. XOR-like data).
    # Starting at 0.0 left the -1 sentinel in place in that case, and callers
    # that index with the result would then address the class-label column.
    best_info_gain = float('-inf')
    best_feature = -1
    for i in range(num_features):
        unique_vals = set(example[i] for example in dataset)
        new_entropy = 0.0
        for value in unique_vals:
            sub_dataset = split_dataset(dataset, i, value)
            prob = len(sub_dataset) / num_records
            # Size-weighted entropy of this partition.
            new_entropy += prob * calc_shannon_ent(sub_dataset)
        info_gain = base_entropy - new_entropy
        if info_gain > best_info_gain:   # strict: first best feature wins ties
            best_info_gain = info_gain
            best_feature = i
    return best_feature

In [25]:
# Start again from freshly built sample data and display it.
my_dat, labels = create_dataset()
my_dat

[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

In [26]:
# Index of the feature column selected for the first split.
choose_best_feature_to_split(my_dat)

0

In [27]:
import operator

def majority_cnt(class_list):
    """Return the class label that occurs most often in ``class_list``.

    When several labels share the highest count, the one that appears first
    in ``class_list`` is returned (the sort below is stable).

    Raises:
        IndexError: if ``class_list`` is empty.
    """
    tally = {}
    for vote in class_list:
        tally[vote] = tally.get(vote, 0) + 1

    # Order the (label, count) pairs by count, largest first.
    ranked = list(tally.items())
    ranked.sort(key=operator.itemgetter(1), reverse=True)
    return ranked[0][0]

In [45]:
def create_tree(dataset, labels):
    """Recursively build a decision tree as nested dictionaries.

    Args:
        dataset: non-empty sequence of records whose last element is the
            class label and whose other elements are feature values.
        labels: names of the feature columns, in column order. This sequence
            is only read; it is never modified.

    Returns:
        A class label when the node is a leaf, otherwise a dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.

    Raises:
        IndexError: if ``dataset`` is empty.
    """
    class_list = [example[-1] for example in dataset]
    # Every record has the same class: nothing left to split.
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    # All features are used up: fall back to the most common class.
    if len(dataset[0]) == 1:
        return majority_cnt(class_list)

    best_feat = choose_best_feature_to_split(dataset)
    if best_feat < 0:
        # The selector found no feature to split on. Using a negative value
        # as an index would address the class-label column, so stop here.
        return majority_cnt(class_list)

    best_feat_label = labels[best_feat]
    # Build the reduced name list by slicing instead of `del labels[...]`,
    # so the caller's list survives the call intact.
    sub_labels = labels[:best_feat] + labels[best_feat + 1:]

    branches = {}
    for value in set(example[best_feat] for example in dataset):
        branches[value] = create_tree(
            split_dataset(dataset, best_feat, value), sub_labels)
    return {best_feat_label: branches}

In [46]:
# Rebuild the sample data, grow a decision tree from it, and display the
# resulting nested-dict structure.
my_dat, labels = create_dataset()
my_tree = create_tree(my_dat, labels)
my_tree

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

In [None]:
# 使用文本注释绘制树
import matplotlib.pyplot as plt 

# Text-box and arrow styles for the tree drawing.
# decision_node / leaf_node are presumably passed as `node_type` to
# plot_node (their use is not visible in this part of the notebook).
decision_node = dict(boxstyle='sawtooth', fc='0.8')
leaf_node = dict(boxstyle='round4', fc='0.8')
# Used by plot_node as the `arrowprops` of every annotation.
arrow_args = dict(arrowstyle='<-')

def plot_node(node_text, center_pt, parent_pt, node_type):
    """Annotate the axes stored on ``create_plot.ax1`` with one tree node.

    Args:
        node_text: text shown inside the node's box.
        center_pt: position of the text box, in axes-fraction coordinates.
        parent_pt: point being annotated (the other end of the arrow), in
            axes-fraction coordinates.
        node_type: dict of box properties passed as ``bbox``.
    """
    axes = create_plot.ax1
    annotation_options = dict(
        xy=parent_pt,
        xycoords='axes fraction',
        xytext=center_pt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=node_type,
        arrowprops=arrow_args,   # module-level arrow style
    )
    axes.annotate(node_text, **annotation_options)

def create_plot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    