# 决策树
## 优点
计算复杂度不高，输出结果易于理解，对中间值的缺失不敏感，可以处理不相关特征数据
## 缺点
可能会产生过度匹配（即过拟合，overfitting）的问题

In [1]:
import numpy as np
import operator

In [11]:
def calc_shannon_ent(dataset):
    """Compute the Shannon entropy (in bits) of the class labels in a dataset.

    The class label of every row must be its last element.

    Args:
        dataset: Sized iterable of rows; ``row[-1]`` is the row's class
            label and must be hashable.

    Returns:
        ``-sum(p * log2(p))`` over the relative frequency ``p`` of each
        distinct label; ``0.0`` for an empty dataset.
    """
    num_entries = len(dataset)
    label_count = {}
    for fea_vec in dataset:
        current_label = fea_vec[-1]
        # Count every occurrence of the label, not only the first one seen.
        label_count[current_label] = label_count.get(current_label, 0) + 1
    shannon_ent = 0.0
    for count in label_count.values():
        prob = float(count) / num_entries
        shannon_ent -= prob * np.log2(prob)
    return shannon_ent

def split_data_set(dataset, axis, value):
    """Select the rows whose column ``axis`` equals ``value``, minus that column.

    Args:
        dataset: Iterable of list rows.
        axis: Index of the column to test and then drop.
        value: Value the column must equal for a row to be kept.

    Returns:
        A new list of new rows; the input rows are left untouched.
    """
    matches = []
    for row in dataset:
        if row[axis] != value:
            continue
        # Everything before the tested column, then everything after it.
        trimmed = row[:axis]
        trimmed.extend(row[axis + 1:])
        matches.append(trimmed)
    return matches

def choose_best_feature_to_split(dataset):
    """Pick the feature column whose split gives the largest information gain.

    Information gain is the entropy of the whole dataset minus the weighted
    average entropy of the subsets obtained by splitting on one feature.

    Args:
        dataset: Non-empty list of list rows; the last element of each row
            is the class label, all preceding elements are feature values.

    Returns:
        The index of the best feature column, or ``-1`` when no feature
        yields a strictly positive information gain.
    """
    num_features = len(dataset[0]) - 1  # the last column is the class label
    base_entropy = calc_shannon_ent(dataset)
    num_entries = float(len(dataset))
    best_info_gain = 0.0
    best_feature = -1
    for i in range(num_features):
        unique_values = {example[i] for example in dataset}
        # Conditional entropy: each subset's entropy weighted by its share
        # of the rows. It must be accumulated with "+", otherwise the gain
        # below turns into base + conditional and the worst feature wins.
        new_entropy = 0.0
        for value in unique_values:
            sub_dataset = split_data_set(dataset, i, value)
            prob = len(sub_dataset) / num_entries
            new_entropy += prob * calc_shannon_ent(sub_dataset)
        info_gain = base_entropy - new_entropy
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature = i
    return best_feature

In [27]:
def majority_cnt(class_list):
    """Return the class label that occurs most often in ``class_list``.

    Args:
        class_list: Non-empty iterable of hashable class labels.

    Returns:
        The most frequent label. Ties are resolved by the stable sort, i.e.
        by the iteration order of the counting dict.

    Raises:
        IndexError: If ``class_list`` is empty.
    """
    class_count = {}
    for vote in class_list:
        class_count[vote] = class_count.get(vote, 0) + 1
    # items() instead of iteritems(): the latter does not exist on Python 3.
    sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
    return sorted_class_count[0][0]

def create_tree(dataset, labels):
    """Recursively build a decision tree as nested dicts.

    Args:
        dataset: Non-empty list of list rows; the last element of each row
            is the class label, all preceding elements are feature values.
        labels: Names of the feature columns, in column order. The sequence
            is not modified.

    Returns:
        A class label when the node is a leaf, otherwise
        ``{feature_label: {feature_value: subtree, ...}}``.
    """
    class_list = [example[-1] for example in dataset]
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]  # every row has the same class
    if len(dataset[0]) == 1:
        # Only the class column is left: fall back to a majority vote.
        return majority_cnt(class_list)
    best_feat = choose_best_feature_to_split(dataset)
    if best_feat < 0:
        # No feature separates the classes any further; without this guard
        # index -1 would treat the class column itself as a feature.
        return majority_cnt(class_list)
    best_feat_label = labels[best_feat]
    # Build the reduced label list as a new object so the caller's labels
    # stay intact and can be reused after the tree is built.
    sub_labels = labels[:best_feat] + labels[best_feat + 1:]
    my_tree = {best_feat_label: {}}
    unique_vals = {example[best_feat] for example in dataset}
    for value in unique_vals:
        my_tree[best_feat_label][value] = create_tree(
            split_data_set(dataset, best_feat, value), sub_labels)
    return my_tree

In [36]:
# Index of the feature column chosen for the first split of X; X is assigned
# in the create_dataset() cell further down, which has to be run first.
choose_best_feature_to_split(X)

0

In [28]:
def create_dataset():
    """Return a small hard-coded sample set together with its feature names.

    Returns:
        A ``(dataset, features_labels)`` tuple. ``dataset`` is a list of
        five ``[feature_0, feature_1, class_label]`` rows and
        ``features_labels`` names the two feature columns. Fresh list
        objects are created on every call.
    """
    features_labels = ['no surfacing', 'flippers']
    samples = (
        (1, 1, 'yes'),
        (1, 1, 'yes'),
        (1, 0, 'no'),
        (0, 1, 'no'),
        (0, 1, 'no'),
    )
    dataset = [list(sample) for sample in samples]
    return dataset, features_labels

In [29]:
# Build the toy dataset, then grow a decision tree from it; the tree comes
# back as nested dicts keyed by feature name and then by feature value.
X, labels = create_dataset()
create_tree(X, labels)

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}