In [45]:
from math import log

In [55]:
def calcShannonEnt(dataSet):
    """Return the Shannon entropy, in bits, of the class labels in dataSet.

    The last element of every row is treated as that row's class label.
    An empty dataSet yields 0.0 because there are no labels to sum over.
    """
    total = len(dataSet)

    # Tally how many rows carry each label (taken from the last column).
    tally = {}
    for row in dataSet:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1

    # H = -sum(p * log2(p)) over the observed labels.
    entropy = 0.0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * log(p, 2)
    return entropy

In [56]:
def createDataSet():
    """Return a small hard-coded sample dataset together with a list of names.

    Each row holds two 0/1 values followed by a "yes"/"no" string in the
    last position. The second return value is a two-element list of names
    (presumably the names of the two feature columns).
    """
    names = ['no surfacing', 'flippers']
    rows = [
        [1, 1, "yes"],
        [1, 1, "yes"],
        [1, 0, "no"],
        [0, 1, "no"],
        [0, 1, "no"],
    ]
    return rows, names

In [61]:
def splitDataSet(dataSet, axis, value):
    """Select the rows whose column ``axis`` equals ``value`` and drop that column.

    Args:
        dataSet: Sequence of rows; each row is an indexable sequence
            (a list or a tuple).
        axis: Index of the column to compare. Negative indices count from
            the end of each row, as with ordinary sequence indexing.
        value: Value that the ``axis`` column must equal for the row to be kept.

    Returns:
        A new list of new lists: every matching row with the ``axis`` column
        removed. The input rows are not modified.

    Raises:
        IndexError: If ``axis`` is out of range for a row.
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Normalise a negative index first: with axis == -1 the slice
            # featVec[axis + 1:] would be featVec[0:], i.e. the whole row.
            column = axis if axis >= 0 else axis + len(featVec)
            # list(...) lets tuple rows work; a tuple slice has no .extend.
            reducedFeatVec = list(featVec[:column])
            reducedFeatVec.extend(featVec[column + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

In [62]:
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature column with the largest information gain.

    Every row of dataSet must have the same length, with the class label in
    the last position; all preceding columns are candidate features.

    For each feature column, the dataset is partitioned by that column's
    distinct values, and the entropies of the partitions are combined into a
    weighted sum (each weight is the partition's share of the rows). The
    information gain is the entropy of the whole dataset minus that weighted
    sum. The column with the strictly largest positive gain wins; ties keep
    the earliest column. If no column has a gain greater than 0.0, -1 is
    returned.
    """
    featureCount = len(dataSet[0]) - 1  # the last column is the label
    baseEntropy = calcShannonEnt(dataSet)
    totalRows = float(len(dataSet))

    bestFeature, bestInfoGain = -1, 0.0
    for column in range(featureCount):
        # Distinct values that occur in this feature column.
        distinctValues = set([row[column] for row in dataSet])

        # Weighted entropy remaining after splitting on this column.
        conditionalEntropy = 0.0
        for value in distinctValues:
            subset = splitDataSet(dataSet, column, value)
            weight = len(subset) / totalRows
            conditionalEntropy += weight * calcShannonEnt(subset)

        gain = baseEntropy - conditionalEntropy
        if gain > bestInfoGain:
            bestFeature, bestInfoGain = column, gain
    return bestFeature

In [63]:
# Build the hard-coded sample rows and the accompanying list of names.
dataset, labels = createDataSet()

In [64]:
# Evaluate which feature column of the sample data gives the largest information gain.
chooseBestFeatureToSplit(dataset)

0