In [1]:
from math import log
import operator

def calcShannonEntropy(dataset):
    numEntries = len(dataset)
    
    labelCounts = {}
    
    for featureVector in dataset:
        currentLabel = featureVector[-1]
        
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        
        labelCounts[currentLabel] += 1
        
    shannonEntropy = 0.0
    
    for key in labelCounts:
        probability = float(labelCounts[key]) / numEntries
        shannonEntropy -= probability * log(probability,2)
            
    return shannonEntropy
    

In [2]:
def createDataset():
    dataset = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
    
    labels = ['no surfacing', 'flippers']
    
    return dataset, labels 

In [3]:
def splitDataset(dataset, axis, value):
    result = []
    
    for featureVector in dataset:
        if featureVector[axis] == value:
            reducedFeatureVector = featureVector[:axis]
            reducedFeatureVector.extend(featureVector[axis+1:])
            
            result.append(reducedFeatureVector)
    
    return result

In [4]:
def chooseBestFeatureToSplit(dataset):
    numFeatures = len(dataset[0]) - 1
    
    baseEntropy = calcShannonEntropy(dataset)
    
    bestInfoGain = baseEntropy;
    bestFeature = -1
    
    for feature in range(numFeatures):
        featureValuesList = [example[feature] for example in dataset]
        
        uniqueFeatureValuesList = set(featureValuesList)
        
        newEntropy = 0.0
        
        for featureValue in uniqueFeatureValuesList:
            subDataset = splitDataset(dataset, feature, featureValue)
            
            probability = len(subDataset) / float(len(dataset))
            newEntropy += probability + calcShannonEntropy(subDataset)
        
        if newEntropy > bestInfoGain:
            bestInfoGain = newEntropy
            bestFeature = feature
    
    return bestFeature

In [5]:
def majorityCount(classList):
    classCount = {}
    
    for item in classList:  
        if item not in classCount.keys():
            classCount[item] = 0
        
        classCount[item] += 1
        
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    
    return sortedClassCount[0][0]

In [26]:
def createTree(dataset, labels):
    classList = [example[-1] for example in dataset]
    
    if len(set(classList)) == 1:
        return classList[0]
    
    if len(dataset[0]) == 1:
        return majorityCount(classList)
    
    bestFeature = chooseBestFeatureToSplit(dataset)
    print("Dataset:", dataset, "\nBest Feature:", bestFeature)
    
    bestFeatureLabel = labels[bestFeature]
    
    myTree = { bestFeatureLabel: {}}
    
    del(labels[bestFeature])
    
    featureValues = [example[bestFeature] for example in dataset]
    uniqueFeatureValues = set(featureValues)
    
    for value in uniqueFeatureValues:
        subLabels = labels[:]
        
        myTree[bestFeatureLabel][value] = createTree(splitDataset(dataset, bestFeature, value), subLabels)
        
    return myTree

In [28]:
myDataset, labels = createDataset()

In [30]:
createTree(myDataset, labels)

# REPASAR createTree y chooseBestFeatureToSplit

Dataset: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']] 
Best Feature: 1


IndexError: list index out of range

In [18]:
myTree

{'flippers': {0: 'no', 1: {'no surfacing': {'no': 'no', 'yes': 'yes'}}}}