# In [4]:
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
from math import log
import operator
import pickle

# In [5]:
def createDataset():
    """Return the built-in sample data set and its feature names.

    Returns:
        A ``(dataSet, labels)`` tuple. ``dataSet`` is a list of 15 rows, each a
        list of four integer-coded feature values followed by a ``'yes'``/``'no'``
        class label. ``labels`` holds the four feature names, in column order.
        Both lists are freshly built on every call.
    """
    labels = ['F1_AGE', 'F2_WORK', 'F3_HOME', 'F4_LOAN']
    # One tuple per sample: (F1_AGE, F2_WORK, F3_HOME, F4_LOAN, class label).
    samples = (
        (0, 0, 0, 0, 'no'),
        (0, 0, 0, 1, 'no'),
        (0, 1, 0, 1, 'yes'),
        (0, 1, 1, 0, 'yes'),
        (0, 0, 0, 0, 'no'),
        (1, 0, 0, 0, 'no'),
        (1, 0, 0, 1, 'no'),
        (1, 1, 1, 1, 'yes'),
        (1, 0, 1, 2, 'yes'),
        (1, 0, 1, 2, 'yes'),
        (2, 0, 1, 2, 'yes'),
        (2, 0, 1, 1, 'yes'),
        (2, 1, 0, 1, 'yes'),
        (2, 1, 0, 2, 'yes'),
        (2, 0, 0, 0, 'no'),
    )
    # Rows are handed out as lists because the rest of the file slices and
    # extends them.
    dataSet = [list(sample) for sample in samples]
    return dataSet, labels
# 创建决策树
def createTree(dataset, labels, featureLabels):
    """Recursively build a decision tree as nested dictionaries.

    Args:
        dataset: Non-empty list of rows. The last element of every row is the
            class label; the elements before it are feature values.
        labels: Feature names, where ``labels[i]`` names column ``i`` of
            ``dataset``. The list is not modified.
        featureLabels: Output list. The name of every feature chosen for a
            split is appended to it, in the order the splits are made.

    Returns:
        A class label when the node is a leaf, otherwise a dict of the form
        ``{featureName: {featureValue: subtree, ...}}``.

    Raises:
        IndexError: If ``dataset`` is empty.
    """
    classList = [example[-1] for example in dataset]

    # Every sample has the same class: this node is a pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    # Only the label column is left (all features already used for splits),
    # so fall back to the most common class.
    if len(dataset[0]) == 1:
        return majorityCnt(classList)

    # Pick the feature to split on.
    bestFeat = chooseBestFeatureToSplit(dataset)
    if bestFeat < 0:
        # chooseBestFeatureToSplit returns -1 when no feature has a positive
        # gain; indexing with -1 would split on the label column.
        return majorityCnt(classList)

    bestFeatLabel = labels[bestFeat]
    featureLabels.append(bestFeatLabel)

    # Work on a copy so the caller's label list is left untouched.
    subLabels = list(labels)
    del subLabels[bestFeat]

    myTree = {bestFeatLabel: {}}
    # Build one subtree per distinct value of the chosen feature.
    for value in {example[bestFeat] for example in dataset}:
        myTree[bestFeatLabel][value] = createTree(
            splitDataset(dataset, bestFeat, value), subLabels, featureLabels
        )
    return myTree


# In [6]:
def majorityCnt(classList):
    """Return the class label that occurs most often in ``classList``.

    When several labels share the highest count, the one that appears first
    in ``classList`` wins (the sort is stable and the counts are stored in
    first-seen order).

    Args:
        classList: Non-empty sequence of hashable class labels.

    Returns:
        The most frequent label.

    Raises:
        IndexError: If ``classList`` is empty.
    """
    classCount = {}
    for vote in classList:
        # Start from 0 so the first occurrence counts once, not twice.
        classCount[vote] = classCount.get(vote, 0) + 1
    # Sort by count, largest first, and take the winner.
    sortedClassCount = sorted(
        classCount.items(), key=operator.itemgetter(1), reverse=True
    )
    return sortedClassCount[0][0]

# 计算熵值，找到最佳分类特征
def chooseBestFeatureToSplit(dataset):
    """Return the index of the feature with the largest information gain.

    The information gain of a feature is the entropy of ``dataset`` minus the
    size-weighted entropy of the subsets obtained by splitting on each distinct
    value of that feature.

    Args:
        dataset: Non-empty list of rows. The last element of every row is the
            class label; the elements before it are feature values.

    Returns:
        The column index of the best feature, or ``-1`` when no feature has a
        strictly positive information gain (including when there are no
        feature columns). On equal gain the lowest index wins.

    Raises:
        IndexError: If ``dataset`` is empty.
    """
    numFeatures = len(dataset[0]) - 1  # the last column is the class label
    baseEntropy = calcShannonEnt(dataset)  # entropy before any split
    totalRows = float(len(dataset))
    bestInfoGain = 0
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = {example[i] for example in dataset}
        # Weighted entropy of the subsets produced by splitting on feature i.
        newEntropy = 0
        for val in uniqueVals:
            subDataset = splitDataset(dataset, i, val)  # rows with feature i == val
            prop = len(subDataset) / totalRows
            newEntropy += prop * calcShannonEnt(subDataset)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            # Remember the gain as well as the index; without this update the
            # last feature with any positive gain would win instead of the best.
            bestInfoGain = infoGain
            bestFeature = i

    return bestFeature

def splitDataset(dataset, axis, val):
    """Select the rows whose ``axis`` column equals ``val`` and drop that column.

    Args:
        dataset: List of rows, each row a list.
        axis: Index of the column to filter on and remove.
        val: Value the column must equal for a row to be kept.

    Returns:
        A new list of new row lists, in the original order; the rows of
        ``dataset`` are not modified.
    """
    def _dropColumn(row):
        # Everything before the column, then everything after it.
        remaining = row[:axis]
        remaining.extend(row[axis + 1:])
        return remaining

    return [_dropColumn(row) for row in dataset if row[axis] == val]

# 计算熵值
def calcShannonEnt(dataset):
    """Compute the Shannon entropy (base 2) of the class labels in ``dataset``.

    Args:
        dataset: List of rows; the last element of every row is the class label.

    Returns:
        The entropy in bits. An empty ``dataset`` yields ``0``.
    """
    total = len(dataset)

    # Count how many rows carry each class label.
    tally = {}
    for row in dataset:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1

    entropy = 0
    for count in tally.values():
        fraction = count / total  # share of the rows that belong to this class
        entropy -= fraction * log(fraction, 2)
    return entropy