In [87]:
from math import log

[list.sort() vs sorted()](https://docs.python.org/3/howto/sorting.html#sortinghowto)

In [88]:
class ID3:
    """A minimal ID3 decision-tree learner for categorical features.

    A dataset is a list of rows. Each row is a list whose last element is the
    class label and whose preceding elements are the feature values.
    """

    def createDataSet(self):
        """Return a small toy dataset together with its feature names.

        :return: ``(dataSet, labels)`` where ``labels[i]`` names feature
            column ``i`` of ``dataSet``.
        """
        dataSet = [[1, 1, 'yes'],
                   [1, 1, 'yes'],
                   [1, 0, 'no'],
                   [0, 1, 'no'],
                   [0, 1, 'no']]
        labels = ['no surfacing', 'flippers']
        return dataSet, labels

    def calcShannonEnt(self, dataSet):
        """Calculate the Shannon entropy (base 2) of a dataset's class labels.

        :param dataSet: Rows whose last element is the class label.
        :return: The entropy in bits; ``0`` for an empty dataset.
        """
        numEntries = len(dataSet)
        labelCounts = {}
        for featVec in dataSet:
            currentLabel = featVec[-1]
            labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
        return -sum((count / numEntries) * log(count / numEntries, 2)
                    for count in labelCounts.values())

    def cond_entropy(self, dataset, feature_idx):
        """Calculate the conditional entropy of the labels given one feature.

        The dataset is partitioned by the value of the feature, and the
        entropies of the partitions are averaged, weighted by partition size.

        :param dataset: The dataset to partition.
        :param feature_idx: The index of the feature used to split the dataset.
        :return: The size-weighted entropy of the partitions.
        """
        dataset_size = len(dataset)
        sub_datasets = {}
        for data in dataset:
            sub_datasets.setdefault(data[feature_idx], []).append(data)
        # Entropy only reads the label column, so the full rows can be passed
        # on without dropping the feature column first.
        return sum((len(sub_dataset) / dataset_size) * self.calcShannonEnt(sub_dataset)
                   for sub_dataset in sub_datasets.values())

    def info_gain(self, entropy, cond_entropy):
        """Return the information gain of a split.

        :param entropy: The entropy of the dataset before the split.
        :param cond_entropy: The weighted entropy after the split.
        :return: ``entropy - cond_entropy``.
        """
        return entropy - cond_entropy

    def splitDataSet(self, dataSet, axis, value):
        """Select the rows with a given feature value and drop that feature.

        :param dataSet: The dataset to filter.
        :param axis: The index of the specified feature.
        :param value: The value of the specified feature.
        :return: A new list of the matching rows with column ``axis`` removed;
            the input rows are not modified.
        """
        return [featVec[:axis] + featVec[axis + 1:]
                for featVec in dataSet
                if featVec[axis] == value]

    def chooseBestFeatureToSplit(self, dataSet):
        """Choose the feature whose split gives the largest information gain.

        :param dataSet: A non-empty dataset.
        :return bestFeature: The index of the best feature to split on, or
            ``-1`` when no feature yields a strictly positive gain.
        """
        numFeatures = len(dataSet[0]) - 1
        baseEntropy = self.calcShannonEnt(dataSet)
        bestInfoGain = 0.0
        bestFeature = -1
        for i in range(numFeatures):
            infoGain = self.info_gain(baseEntropy, self.cond_entropy(dataSet, i))
            # Strict comparison: on ties the lowest feature index wins.
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature

    @staticmethod
    def majorityCnt(classList):
        """Find the class that occurs with the greatest frequency.

        Declared as a static method so it can be called both on the class and
        on an instance.

        :param classList: A list of class names.
        :return: The class that occurs with the greatest frequency (the one
            seen first among ties), or ``None`` if ``classList`` is empty.
        """
        classCount = {}
        for vote in classList:
            classCount[vote] = classCount.get(vote, 0) + 1
        return max(classCount, key=classCount.get, default=None)

    def createTree(self, dataSet, labels):
        """Recursively build an ID3 decision tree.

        :param dataSet: A non-empty dataset (an empty one raises
            ``IndexError``).
        :param labels: The feature names, aligned with the feature columns of
            ``dataSet``. The list is not modified.
        :return: A class label for a leaf, otherwise a nested dict of the form
            ``{feature_name: {feature_value: subtree, ...}}``.
        """
        # Stop when all classes are equal.
        classList = [example[-1] for example in dataSet]
        if classList.count(classList[0]) == len(classList):
            return classList[0]
        # When no more features, return majority.
        if len(dataSet[0]) == 1:
            return self.majorityCnt(classList)
        bestFeat = self.chooseBestFeatureToSplit(dataSet)
        # No feature reduces the entropy, so splitting further is pointless.
        if bestFeat == -1:
            return self.majorityCnt(classList)
        bestFeatLabel = labels[bestFeat]
        # Build a new list so the caller's labels are left untouched.
        subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
        branches = {}
        for value in set(example[bestFeat] for example in dataSet):
            subDataSet = self.splitDataSet(dataSet, bestFeat, value)
            branches[value] = self.createTree(subDataSet, subLabels)
        return {bestFeatLabel: branches}
        
        
        

## Creating the dataset

In [89]:
# Instantiate the learner and load the toy dataset with its feature names.
# The bare `myDat` on the last line makes the notebook echo the rows.
trees = ID3()
myDat, labels = trees.createDataSet()
myDat

[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

## Calculating the Shannon entropy 

In [90]:
# Entropy of the class labels (last column of each row) over the whole dataset.
trees.calcShannonEnt(myDat)

0.9709505944546686

## Splitting the dataset

In [91]:
# Keep the rows whose feature 0 equals 1, with that feature column removed.
trees.splitDataSet(myDat, 0, 1)

[[1, 'yes'], [1, 'yes'], [0, 'no']]

In [92]:
# Keep the rows whose feature 0 equals 0, with that feature column removed.
trees.splitDataSet(myDat, 0, 0)

[[1, 'no'], [1, 'no']]

## Choosing the best feature to split on

In [96]:
# Index of the feature whose split gives the largest information gain.
trees.chooseBestFeatureToSplit(myDat)

0