In [1]:
import numpy as np
import pandas as pd
from math import log


In [2]:
class Tree:
    """One node of a binary decision tree.

    Attributes
    ----------
    value : threshold compared against the feature at ``col`` (internal nodes).
    trueBranch : child followed when the feature is ``>= value``.
    falseBranch : child followed otherwise.
    results : label counts stored on a leaf; ``None`` on internal nodes.
    col : index of the feature column this node splits on (``-1`` by default).
    summary : free-form description of the node (impurity / sample count).
    data : rows that ended up in a leaf.

    All arguments are stored as given; nothing is validated or copied.
    """

    def __init__(self, value=None, trueBranch=None, falseBranch=None, results=None, col=-1, summary=None, data=None):
        # Split threshold and the two children
        self.value, self.trueBranch, self.falseBranch = value, trueBranch, falseBranch
        # Leaf payload and the split column
        self.results, self.col = results, col
        # Bookkeeping attached to the node
        self.summary, self.data = summary, data

In [3]:
# Load the data set
def loadData(path="../Pima.csv"):
    """Read a header-less CSV file and return its first ten columns as an array.

    Parameters
    ----------
    path : str or path-like, optional
        Location of the CSV file. Defaults to ``"../Pima.csv"``, the path
        that used to be hard-coded here, so existing ``loadData()`` calls
        behave exactly as before.

    Returns
    -------
    numpy.ndarray
        One row per CSV line. Columns beyond the tenth are dropped; a file
        with fewer than ten columns is returned whole.
    """
    pima = pd.read_csv(path, header=None)
    # Round-trip through nested Python lists so NumPy infers the dtype from
    # the plain values (kept from the original implementation).
    dataMat = np.array(pima.iloc[:, 0:10].values.tolist())
    return dataMat

In [4]:
# Compute the entropy
def calShannon(dataSet):
    """Return the base-2 Shannon entropy of the labels in ``dataSet``.

    The label of a row is its last element. An empty ``dataSet`` yields 0.0.
    """
    # Build a dict counting the rows for each label
    counts = {}
    for row in dataSet:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1

    total = len(dataSet)
    entropy = 0.0
    # Subtract p * log2(p) for every label, in first-seen order
    for count in counts.values():
        prob = float(count / total)
        entropy -= prob * log(prob, 2)
    return entropy

In [5]:
# Compute the Gini index. Gini measures impurity: the smaller it is the purer
# the set, the larger it is the more mixed the set.
def calGini(dataSet):
    """Return the Gini impurity of the labels in ``dataSet``.

    The label of a row is its last element. The result is
    ``1 - sum(p_k ** 2)`` over the label frequencies ``p_k``; an empty
    ``dataSet`` therefore yields 1.0.
    """
    # Build a dict counting the rows for each label
    tally = {}
    for row in dataSet:
        label = row[-1]
        tally.setdefault(label, 0)
        tally[label] += 1

    # Start from 1 and remove each squared label frequency
    total = len(dataSet)
    impurity = 1.0
    for label in tally:
        impurity -= (tally[label] / total) ** 2
    return impurity

In [6]:
# Split dataSet into two data sets according to `value` on feature column `col`
def splitData(dataSet,col,value):
    """Partition the rows of ``dataSet`` by a threshold on one column.

    Returns a pair ``(rows with row[col] >= value, remaining rows)``; both are
    lists and keep the rows in their original order.
    """
    at_or_above, below = [], []
    for row in dataSet:
        bucket = at_or_above if row[col] >= value else below
        bucket.append(row)
    return at_or_above, below

In [7]:
def calculateDiffCount(datas):
    """Summarise the input rows as a count per label.

    The label of a row is its last element. Returns a dict of the form
    ``{label_1: count_1, label_2: count_2, ...}`` with labels in first-seen
    order.
    """
    results = {}
    for row in datas:
        label = row[-1]  # the last element is the row's label
        results[label] = results.get(label, 0) + 1
    return results

In [8]:
# Recursively pick the best feature and the best split value within that feature
def BuildCartDecisionTree(dataSet):
    """Grow a binary CART tree from ``dataSet`` by greedy Gini reduction.

    Every distinct value of every feature column (all columns except the
    last, which holds the label) is tried as a threshold via ``splitData``.
    The split with the largest drop in weighted Gini impurity is kept and both
    halves are grown recursively. When no split lowers the impurity the node
    becomes a leaf.

    Parameters
    ----------
    dataSet : sequence of rows
        Each row is indexable and carries its label as the last element.
        An empty ``dataSet`` produces a leaf with empty ``results`` instead
        of raising ``IndexError``.

    Returns
    -------
    Tree
        Internal nodes carry ``col``, ``value``, ``trueBranch`` (rows with
        feature ``>= value``) and ``falseBranch``; leaves carry ``results``
        (label counts) and ``data`` (the rows). Every node carries a
        ``summary`` dict with the formatted impurity and sample count.
    """
    # Impurity of the node before any split
    currentgini = calGini(dataSet)
    # Number of rows (samples)
    rows_length = len(dataSet)
    # Number of columns. With no rows there is no first row to measure, so
    # treat the set as feature-less and let it fall through to a leaf.
    column_length = len(dataSet[0]) if rows_length else 0

    # Best impurity reduction found so far, with the split that produced it
    best_gini_gain = 0.0
    best_value = None
    best_set = None

    for col in range(column_length-1):
        values = set([x[col] for x in dataSet])
        for value in values:
            data1,data2 = splitData(dataSet,col,value)
            # A split that leaves one side empty reproduces the parent node,
            # so its gain is exactly 0 and it can never be selected.
            if len(data1) == 0 or len(data2) == 0:
                continue
            p = len(data1)/rows_length
            gini = p*calGini(data1)+(1-p)*calGini(data2)
            gain = currentgini-gini
            # Strict comparison: the first split reaching a given gain wins
            if(gain > best_gini_gain):
                best_gini_gain = gain
                best_value = (col,value)
                best_set = (data1,data2)

    dcY = {'impurity' : '%.3f' % currentgini,
           'samples' : '%d' % rows_length}

    if(best_gini_gain > 0.0):
        trueBranch = BuildCartDecisionTree(best_set[0])
        falseBranch = BuildCartDecisionTree(best_set[1])
        return Tree(col=best_value[0],
                    value=best_value[1],
                    trueBranch=trueBranch,
                    falseBranch=falseBranch,
                    summary=dcY)
    else:
        # No split improves purity: keep the label counts and rows on a leaf
        return Tree(results=calculateDiffCount(dataSet),
                    summary=dcY,
                    data=dataSet)

In [14]:
# Classify a sample with the tree
def classify(data,tree):
    """Route ``data`` down ``tree`` and return the ``results`` of the leaf reached.

    Parameters
    ----------
    data : indexable row of feature values
    tree : node exposing ``results``, ``col``, ``value``, ``trueBranch`` and
        ``falseBranch``

    Returns
    -------
    The ``results`` attribute of the first node on the path whose ``results``
    is not ``None``.
    """
    node = tree
    # Identity test rather than `!= None`: a leaf payload that overloads
    # comparison (e.g. a NumPy array) must still be recognised as a leaf.
    # Walking iteratively also keeps deep trees clear of the recursion limit.
    while node.results is None:
        if data[node.col] >= node.value:
            node = node.trueBranch
        else:
            node = node.falseBranch
    return node.results

In [12]:
# Load the data set and grow the CART tree on every row of it
dataMat = loadData()
tree = BuildCartDecisionTree(dataMat)

In [15]:
# Route the second row through the tree; the cell shows the label counts of the leaf it reaches
classify(dataMat[1],tree)


{0.0: 30}

In [16]:
# Show the second row itself (its last element is the label)
dataMat[1]

array([ 1.   , 85.   , 66.   , 29.   ,  0.   , 26.6  ,  0.351, 31.   ,
        0.   ])