# Understanding Decision Tree Classifier

In [1]:
'''
    To classify for fish or not fish 
    0 = Can survive without coming to surface?
    1 = Has flippers?
    2 = Fish?
'''
def create_dataset():
    """Return the toy fish / not-fish samples and their feature names.

    Each row is ``[no_surfacing, flippers, class_label]``: the two features
    are 0/1 flags and the class label is ``'yes'`` or ``'no'``.

    Returns:
        tuple: ``(dataset, labels)`` -- a freshly built list of five sample
        rows and the list of feature names in column order.
    """
    feature_names = ['no surfacing', 'flippers']
    rows = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    return rows, feature_names

In [2]:
# Build the toy dataset and the matching list of feature names.
dataset, labels = create_dataset()

In [3]:
# Calculate Shannon entropy
from math import log

def calcShanonEnt(dataSet):
    """Return the Shannon entropy (in bits) of the class labels in *dataSet*.

    H = -sum_i p(x_i) * log2(p(x_i)), where p(x_i) is the fraction of rows
    whose last element (the class label) equals x_i.

    Args:
        dataSet: sequence of rows; the last element of each row is its
            class label.

    Returns:
        float: the entropy; 0.0 for an empty dataset or a single class.
    """
    # Use the argument, not the module-level ``dataset``, so that subsets
    # produced by splitDataset are measured correctly.
    numEntries = len(dataSet)
    if numEntries == 0:
        # No rows means no uncertainty; also avoids a division by zero.
        return 0.0
    labelCount = {}  # class label -> number of rows carrying it
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCount[currentLabel] = labelCount.get(currentLabel, 0) + 1
    shanonEnt = 0.0
    for count in labelCount.values():
        prob = float(count) / numEntries  # P(x_i) = n_i / n
        shanonEnt -= prob * log(prob, 2)
    return shanonEnt

In [4]:
def splitDataset(dataSet, axis, value):
    '''
        Return the rows of *dataSet* whose column *axis* equals *value*,
        each with that column removed.

        dataSet -> list of rows (lists) to filter
        axis    -> index of the feature column to split on
        value   -> feature value a row must have to be kept

        The input rows are not modified; every returned row is a new list.
    '''
    def _drop_column(row):
        # Slicing copies, so extending the prefix never touches *row*.
        trimmed = row[:axis]
        trimmed.extend(row[axis + 1:])
        return trimmed

    return [_drop_column(row) for row in dataSet if row[axis] == value]

In [5]:
# Rows whose 'no surfacing' flag is 0, with that column removed.
splitDataset(dataset, 0, 0)

[[1, 'no'], [1, 'no']]

In [6]:
def chooseBestFeatureToSplit(dataset):
    """Return the index of the feature with the highest information gain.

    For every feature column the rows are partitioned by that column's
    values. The weighted entropy of the partitions is subtracted from the
    entropy of the whole dataset to obtain that feature's information gain.

    Args:
        dataset: list of rows; every column except the last is a feature,
            the last is the class label.

    Returns:
        int: index of the feature with the largest strictly positive gain
        (ties keep the lowest index), or -1 when the dataset is empty, has
        no feature columns, or no feature yields a positive gain.
    """
    if not dataset:
        return -1
    numFeatures = len(dataset[0]) - 1
    baseEntropy = calcShanonEnt(dataset)
    bestInfoGain = 0.0
    bestFeature = -1
    numRows = float(len(dataset))
    for i in range(numFeatures):
        uniqueVals = set(example[i] for example in dataset)
        # The weighted entropy must be summed over ALL values of feature i
        # before the gain is meaningful.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataset = splitDataset(dataset, i, value)
            prob = len(subDataset) / numRows
            newEntropy += prob * calcShanonEnt(subDataset)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            # Track the running best so later features must beat it.
            bestInfoGain = infoGain
            bestFeature = i
    # Return only after every feature has been evaluated.
    return bestFeature

In [7]:
import operator
def majorityCount(classList):
    """Return the most frequent label in *classList*.

    On a tie, the label seen first in *classList* wins. This holds because
    the sort below is stable even with ``reverse=True``, and dict order
    follows insertion on Python 3.7+. Raises IndexError when *classList*
    is empty.
    """
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    byFrequency = sorted(tally.items(), key=operator.itemgetter(1), reverse=True)
    winner, _count = byFrequency[0]
    return winner

In [8]:
def createTree(dataset, lables):
    """Recursively build an ID3-style decision tree.

    Args:
        dataset: non-empty list of rows; all columns but the last are
            feature values, the last is the class label.
        lables: feature names, one per feature column of *dataset*. The
            list is NOT modified.

    Returns:
        Either a class label (a leaf) or a nested dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataset]
    # Stop: every row has the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop: no features left to split on (each row holds only its label).
    if len(dataset[0]) == 1:
        return majorityCount(classList)
    bestFeat = chooseBestFeatureToSplit(dataset)
    # Stop: no feature gives a positive information gain. -1 must not be
    # used as an index, since it would select the class-label column.
    if bestFeat < 0:
        return majorityCount(classList)
    bestFeatLabel = lables[bestFeat]
    tree = {bestFeatLabel: {}}
    # Names of the remaining features. Build a new list rather than
    # deleting from the caller's (or the global) list.
    subLabels = lables[:bestFeat] + lables[bestFeat + 1:]
    uniqueVals = set(example[bestFeat] for example in dataset)
    for value in uniqueVals:
        tree[bestFeatLabel][value] = createTree(
            splitDataset(dataset, bestFeat, value), subLabels)
    return tree
    

In [9]:
# Build the decision tree from the full toy dataset.
createTree(dataset, labels)

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}