# CART

定义Gini指数的计算函数

In [None]:
def gini(nums):
    '''
    Return the Gini impurity of the labels in ``nums``.

    Computed as sum_k p_k * (1 - p_k) over the distinct values k, where p_k is
    the fraction of entries equal to k (equivalently 1 - sum_k p_k ** 2).
    An empty list gives 0.
    '''
    n_total = len(nums)
    # class proportions, one per distinct label
    shares = (nums.count(label) / n_total for label in set(nums))
    return sum(share * (1 - share) for share in shares)

读入数据并计算标签的Gini指数

In [None]:
import numpy as np
import pandas as pd
from math import log

# Load the example dataset. 'windy' is forced to string dtype so its values
# stay text labels (presumably True/False in the CSV, which pandas would
# otherwise parse as booleans; confirm against the data file).
df = pd.read_csv(
    '../data/example_data.csv',
    dtype={'windy': 'str'},
)
# Gini index of the target column 'play' before any split.
gini(df['play'].tolist())

定义根据特征分割数据框的函数

In [None]:
def split_dataframe(data, col):
    '''
    function: split a pandas dataframe into sub-dataframes by the values of one column.
    input: dataframe, column name.
    output: dict mapping each distinct value of ``data[col]`` (in order of first
            appearance) to the rows of ``data`` that have that value.
            A missing-value key (NaN/None) selects every row whose value is missing.
    '''
    column = data[col]
    result_dict = {}
    for value in column.unique():
        # NaN never compares equal to anything (itself included), so
        # ``column == value`` would select no rows for a missing value and
        # silently drop them; match missing values with isna() instead.
        mask = column.isna() if pd.isna(value) else column == value
        result_dict[value] = data[mask]
    return result_dict

根据温度特征对数据进行划分

In [None]:
# Split the dataset on the 'temp' feature: one sub-dataframe per distinct value.
split_dataframe(data=df, col='temp')

根据按特征划分后标签的条件Gini指数（各子集Gini指数按样本占比加权求和）选择最优特征：条件Gini指数最小的特征即为最优划分特征。定义函数如下

In [None]:
def choose_best_col(df, label):
    '''
    function: choose the column whose split gives the smallest weighted Gini index.
    input: dataframe, label column name.
    output: (min weighted Gini index, best column,
             dict of sub-dataframes split on the best column).
            When ``df`` has no column besides ``label``, returns (999, None, None).
    '''
    # candidate feature columns: everything except the label
    cols = [col for col in df.columns if col != label]
    # 999 is a sentinel larger than any Gini index (which is always < 1)
    min_value, best_col = 999, None
    min_splited = None
    for col in cols:
        splited_set = split_dataframe(df, col)
        # weighted Gini index of the label after splitting on `col`:
        # sum over subsets D_i of |D_i| / |D| * Gini(D_i)
        gini_DA = 0
        for subset in splited_set.values():
            gini_DA += len(subset) / len(df) * gini(subset[label].tolist())
        # Compare only after every subset of this column has been accumulated.
        # Strict < keeps the earliest column on ties.
        if gini_DA < min_value:
            min_value, best_col = gini_DA, col
            min_splited = splited_set
    return min_value, best_col, min_splited

定义CART分类树的构建过程

In [None]:
class CartTree:
    '''
    Categorical decision tree grown greedily by minimum weighted Gini index.

    Each internal node is named after the column it splits on and has one child
    per distinct value of that column; each leaf is named after a class label.
    Uses the module-level ``choose_best_col`` helper to pick splits.
    '''

    class Node:
        '''A tree node: a display name plus children keyed by branch label.'''

        def __init__(self, name):
            self.name = name
            # branch label (a feature value, or "" below the root) -> child Node
            self.connections = {}

        def connect(self, label, node):
            '''Attach ``node`` as the child reached via branch ``label``.'''
            self.connections[label] = node

    def __init__(self, data, label):
        '''
        data: training DataFrame holding the feature columns and ``label``.
        label: name of the target column.
        '''
        self.columns = data.columns
        self.data = data
        self.label = label
        self.root = self.Node("Root")

    def print_tree(self, node, tabs):
        '''Print the subtree under ``node``, prefixing each line with ``tabs``.'''
        # f-strings instead of ``+`` so non-string node names or branch values
        # (ints, bools, numpy scalars) print instead of raising TypeError.
        print(f"{tabs}{node.name}")
        for connection, child_node in node.connections.items():
            print(f"{tabs}\t({connection})")
            self.print_tree(child_node, tabs + "\t\t")

    def construct_tree(self):
        '''Build the tree from ``self.data`` and hang it off ``self.root``.'''
        self.construct(self.root, "", self.data, self.columns)

    def construct(self, parent_node, parent_connection_label, input_data, columns):
        '''
        Recursively grow the subtree for ``input_data`` and attach it to
        ``parent_node`` under the branch ``parent_connection_label``.

        ``columns`` are the columns still available at this node (label included).
        '''
        labels = input_data[self.label]
        best_col, min_splited = None, None
        # A node whose rows all share one label is already pure: no split needed.
        if labels.nunique(dropna=False) > 1:
            _, best_col, min_splited = choose_best_col(input_data[columns], self.label)

        # `is None` rather than truthiness: a column may legitimately be named 0 or "".
        if best_col is None:
            # Leaf: predict the majority label (for a pure node, its only label).
            node = self.Node(labels.value_counts().idxmax())
            parent_node.connect(parent_connection_label, node)
            return

        node = self.Node(best_col)
        parent_node.connect(parent_connection_label, node)

        # each feature is used at most once along a path
        new_columns = [col for col in columns if col != best_col]
        # Recursively build one subtree per value of the chosen column
        for splited_value, splited_data in min_splited.items():
            self.construct(node, splited_value, splited_data, new_columns)

执行代码生成CART分类树

In [None]:
# Build and print the CART tree for the 'play' label.
# (Previously instantiated ID3Tree, which is not defined in this notebook.)
treel = CartTree(df, 'play')
treel.construct_tree()
treel.print_tree(treel.root, "")