In [1]:
import math
import operator
import matplotlib.pyplot as plt

In [2]:
#计算信息熵
def clacShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCnt = {}
    for featVec in dataSet:
        tmpLabel = featVec[-1]
        if tmpLabel not in labelCnt.keys():
            labelCnt[tmpLabel] = 0
        labelCnt[tmpLabel] += 1
    entropy = 0.0
    for key in labelCnt:
        prob = float(labelCnt[key])/numEntries
        entropy -= prob * math.log(prob, 2)
    return entropy

In [3]:
#计算信息熵小测试
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing','flippers']
    return dataSet, labels
data, label = createDataSet()
print(clacShannonEnt(data))

0.9709505944546686


In [4]:
#分割数据集
#dataSet 数据集
#axis 第几列
#value  要对比的值
def splitDataSet(dataSet, axis, value):
    ret = []
    for data in dataSet:
        if data[axis] == value:
            tmpVec = data[:axis]
            tmpVec.extend(data[axis+1:])
            ret.append(tmpVec)
    return ret

In [5]:
#分割数据集测试
print(splitDataSet(data,0,1))
print(splitDataSet(data,2,'yes'))

[[1, 'yes'], [1, 'yes'], [0, 'no']]
[[1, 1], [1, 1]]


In [6]:
#信息增益
#选择最好的决策点，基于信息增益来选择
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = clacShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    
    for i in range(numFeatures):
        featList = [data[i] for data in dataSet]   #记录每列上的值
        uniqueVals = set(featList)    #将该列上面的值unique
        tmpEntropy = 0.0
        for value in uniqueVals:         #计算该列上面某一个特征值的信息熵
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            tmpEntropy += prob * clacShannonEnt(subDataSet)
        infoGain = baseEntropy - tmpEntropy
        if(infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

In [7]:
#返回出现次数最多的
def majorityCnt(classList):
    classCnt = {}
    for vote in classList:
        if vote not in classCnt.keys():
            classCnt[vote] = 0
        classCnt[vote] += 1
    sortedClassCnt = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCnt[0][0]

In [8]:
#递归创建决策树
def createTree(dataSet, labels):
    classList = [data[-1] for data in dataSet]
    if  classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
#     print(bestFeatLabel)
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [data[bestFeat] for data in dataSet]
    uniqueVals = set(featValues)
    for val in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][val] = createTree(splitDataSet(dataSet, bestFeat, val), subLabels)
    return myTree

In [9]:
#创建决策树小测试
decisionTree = createTree(data, label)
print(decisionTree)

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}


In [10]:
#predict
#inputTree ：生成好的决策树
#featLabels ： 标签
#testVec ： 测试向量
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                print(secondDict)
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                print(secondDict)
                classLabel = secondDict[key]
    return classLabel

In [11]:
data, label = createDataSet()
classify(decisionTree, label, [1, 0])

{0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}
{0: 'no', 1: 'yes'}


'no'