In [29]:
import matplotlib.pyplot as plt
import math
from copy import deepcopy

In [6]:
def CreateDateSet():
    """Return the toy loan-approval dataset and its feature names.

    Returns:
        tuple: ``(dataset, labels)`` where ``dataset`` is a list of samples,
        each a list of four integer-coded feature values followed by the
        class string ``'yes'`` or ``'no'``, and ``labels`` is the list of the
        four feature names in column order. Fresh lists are built on every
        call.
    """
    feature_names = ['F1-AGE', 'F2-WORK', 'F3-HOME', 'F4-LOAN']

    # One tuple per sample: (age, work, home, loan, class).
    records = (
        (0, 0, 0, 0, 'no'),
        (0, 0, 0, 1, 'no'),
        (0, 1, 0, 1, 'yes'),
        (0, 1, 1, 0, 'yes'),
        (0, 0, 0, 0, 'no'),
        (1, 0, 0, 0, 'no'),
        (1, 0, 0, 1, 'no'),
        (1, 1, 1, 1, 'yes'),
        (1, 0, 1, 2, 'yes'),
        (1, 0, 1, 2, 'yes'),
        (2, 0, 1, 1, 'yes'),
        (2, 1, 0, 1, 'yes'),
        (2, 1, 0, 2, 'yes'),
        (2, 0, 0, 0, 'no'),
    )

    # Callers receive mutable list rows, exactly like a literal list of lists.
    samples = [list(record) for record in records]
    return samples, feature_names

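- 特征的索引由数据集的列位置决定：第 `i` 个特征对应每个样本的第 `i` 列，最后一列是类别标签。
- 如果在 `split_dataset` 中直接对原始样本执行 `del feat_vec[feat_val]`，会就地删除原数据集中该列的特征值。`choose_best_feature_to_split` 在循环开始前就已经计算了 `num_features = len(dataset[0]) - 1`，而循环过程中数据集的实际列数却在不断减少，之后仍按原始列数迭代，列索引就会错位甚至越界。因此应先复制样本，再在副本上删除该列。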

In [42]:
class CreatTree:
    """Build a decision tree by repeatedly splitting on the highest-gain feature.

    A dataset is a list of samples. Each sample is a list whose last element
    is the class label and whose preceding elements are feature values that
    are compared with ``==`` and used as dictionary keys in the tree.
    """

    def __init__(self, dataset, labels):
        """Keep references to the training samples and the feature names.

        Args:
            dataset: list of samples (feature values followed by the class).
            labels: feature names, one per feature column of ``dataset``.
        """
        self.dataset = dataset
        self.labels = labels

    def creat_tree(self, dataset, labels, feat_labels):
        """Recursively build the tree for ``dataset``.

        Args:
            dataset: non-empty list of samples; the last element of each
                sample is its class.
            labels: names of the feature columns currently present in
                ``dataset``. This list is not modified.
            feat_labels: output list; the name of every feature chosen for a
                split is appended to it in the order the splits are made.

        Returns:
            A class value when the node is a leaf, otherwise a nested dict of
            the form ``{feature_name: {feature_value: subtree}}``.
        """
        classlist = [example[-1] for example in dataset]

        # Every sample has the same class: this is a pure leaf.
        if classlist.count(classlist[0]) == len(classlist):
            return classlist[0]

        # Only the class column is left: fall back to the majority class.
        if len(dataset[0]) == 1:
            return CreatTree.majority_ent(classlist)

        # Column index of the feature with the largest information gain.
        best_feat = self.choose_best_feature_to_split(dataset)
        if best_feat is None:
            # No feature separates the classes any further, so splitting
            # would be pointless; answer with the majority class instead of
            # indexing ``labels`` with None.
            return CreatTree.majority_ent(classlist)

        best_feat_label = labels[best_feat]
        feat_labels.append(best_feat_label)

        # Build the reduced name list as a new object so the caller's list
        # keeps describing the caller's columns.
        sub_labels = labels[:best_feat] + labels[best_feat + 1:]

        my_tree = {best_feat_label: {}}
        for value in set(example[best_feat] for example in dataset):
            sub_dataset = CreatTree.split_dataset(dataset, best_feat, value)
            my_tree[best_feat_label][value] = self.creat_tree(
                sub_dataset, sub_labels, feat_labels
            )
        return my_tree

    @staticmethod
    def majority_ent(classlist):
        """Return the most frequent value in ``classlist``.

        Ties are resolved in favour of the value that appears first, because
        the counting dict keeps insertion order and the sort is stable.

        Args:
            classlist: non-empty list of class values.
        """
        class_count = {}
        for vote in classlist:
            class_count[vote] = class_count.get(vote, 0) + 1
        # Keys of the counting dict, ordered by descending count.
        class_counted = sorted(class_count, key=lambda x: class_count[x], reverse=True)
        return class_counted[0]

    def choose_best_feature_to_split(self, dataset):
        """Return the column index of the feature with the largest gain.

        The gain of a feature is the entropy of ``dataset`` minus the
        size-weighted entropy of the subsets obtained by splitting on it.

        Args:
            dataset: non-empty list of samples.

        Returns:
            The index of the first feature whose gain is the largest strictly
            positive one, or ``None`` when no feature has a positive gain.
        """
        num_features = len(dataset[0]) - 1
        sample_count = len(dataset)
        ent = CreatTree.calculate_ent(dataset)
        best_information_gain = 0
        best_feat = None
        for i in range(num_features):
            # Distinct values taken by the current column.
            unique_value = set(example[i] for example in dataset)
            gain = ent
            for value in unique_value:
                sub_dataset = self.split_dataset(dataset, i, value)
                prop = len(sub_dataset) / sample_count
                gain -= prop * CreatTree.calculate_ent(sub_dataset)
            # Strict comparison: on equal gains the earlier column wins.
            if gain > best_information_gain:
                best_information_gain = gain
                best_feat = i
        return best_feat

    @staticmethod
    def calculate_ent(dataset):
        """Return the entropy of the class column of ``dataset``.

        The natural logarithm is used, so the result is expressed in nats.
        An empty dataset yields 0.

        Args:
            dataset: list of samples; the last element of each is its class.
        """
        sum_number = len(dataset)
        # Number of samples per class.
        label_count = {}
        for feat_vec in dataset:
            current_label = feat_vec[-1]
            label_count[current_label] = label_count.get(current_label, 0) + 1
        ent = 0
        for count in label_count.values():
            prop = count / sum_number
            ent -= prop * math.log(prop)
        return ent

    @staticmethod
    def split_dataset(dataset, feat_val, value):
        """Select the samples whose column ``feat_val`` equals ``value``.

        Args:
            dataset: list of samples.
            feat_val: index of the column to test and then drop.
            value: feature value a sample must have to be selected.

        Returns:
            A new list holding a copy of every matching sample with column
            ``feat_val`` removed. The samples in ``dataset`` are left intact.
        """
        sub_dataset = []
        for feat_vec in dataset:
            if feat_vec[feat_val] == value:
                # Delete from a copy of the row; deleting from the row itself
                # would shrink the caller's dataset while it is still in use.
                reduced_feat_vec = list(feat_vec)
                del reduced_feat_vec[feat_val]
                sub_dataset.append(reduced_feat_vec)
        return sub_dataset

In [43]:
dataset, labels = CreateDateSet()
creat_my_tree = CreatTree(dataset, labels)

# Filled by creat_tree with the names of the features it splits on, in order.
feat_labels = []
# Hand the builder a copy of the names so that `labels` keeps one entry per
# dataset column no matter what tree construction does with the list it gets.
creat_my_tree.creat_tree(dataset, labels[:], feat_labels)

{'F2-WORK': {0: {'F3-HOME': {0: 'no', 1: 'yes'}}, 1: 'yes'}}