I came across an excellent blog post that explains how to build a decision tree from scratch in Python: http://www.patricklamle.com/Tutorials/Decision%20tree%20python/tuto_decision%20tree.html

To test my own understanding, I wrote the following code to reproduce its results. All credit goes to the author of that blog post.

In [43]:
import numbers
import sys

import numpy as np

In [44]:
# Set the interpreter recursion limit used by the recursive buildTree /
# showTreeRrecursive calls below.
# NOTE(review): 1000 is CPython's usual default, so this is likely a no-op;
# raise it if growing deep trees hits RecursionError.
sys.setrecursionlimit(1000)

In [25]:
def countTargets(rows):
    """Tally how often each target label occurs in *rows*.

    The target label of a row is its last element. Returns a plain dict
    mapping label -> count, with labels in the order they are first seen.
    An empty *rows* yields an empty dict.
    """
    tally = {}
    for label in (row[-1] for row in rows):
        tally[label] = tally.get(label, 0) + 1
    return tally
    
def entropy(rows):
    """Return the base-2 Shannon entropy of the target labels in *rows*.

    Label frequencies come from ``countTargets``; each label with share p
    contributes ``-p * log2(p)``. An empty *rows* yields 0.0.
    """
    counts = countTargets(rows)
    total = float(len(rows))
    result = 0.0
    for count in counts.values():
        share = float(count) / total
        result -= share * np.log2(share)
    return result

def split(rows, col_idx, col_val):
    """Partition *rows* on the value in column *col_idx*.

    A numeric *col_val* (any ``numbers.Real``: int, float, bool, and NumPy
    integer/float scalars) acts as a threshold: rows with
    ``row[col_idx] >= col_val`` go to the true set. Any other *col_val*
    (e.g. a string) is matched by equality.

    Args:
        rows: iterable of indexable rows.
        col_idx: index of the column to test.
        col_val: threshold (numeric) or category (anything else).

    Returns:
        ``(true_set, false_set)``: two lists, each preserving input order.
    """
    # numbers.Real also covers NumPy integer scalars (np.int64, ...), which a
    # plain isinstance(col_val, (int, float)) check would wrongly send to the
    # equality branch.
    is_threshold = isinstance(col_val, numbers.Real)

    true_set, false_set = [], []
    # Single pass: evaluate the criterion once per row and iterate rows once.
    for row in rows:
        cell = row[col_idx]
        hit = cell >= col_val if is_threshold else cell == col_val
        (true_set if hit else false_set).append(row)
    return (true_set, false_set)


In [20]:
class TreeNode(object):
    """A node of a binary decision tree.

    Attributes:
        col: index of the column used to split at this node (-1 by default).
        val: value the column is compared against when splitting.
        data: for a leaf, the target-label counts of the rows that reached
            it; None for internal nodes.
        tb: true branch -- subtree for rows that meet the splitting criterion.
        fb: false branch -- subtree for rows that do not meet it.
    """

    def __init__(self, col=-1, val=None, data=None, tb=None, fb=None):
        # Splitting criterion.
        self.col, self.val = col, val
        # Leaf payload (None unless this is a leaf).
        self.data = data
        # Child subtrees.
        self.tb, self.fb = tb, fb
        

In [46]:
def buildTree(rows, score_function=entropy):
    """Recursively grow a decision tree from *rows*.

    Each row is a sequence whose last element is the target label and whose
    other elements are features. At every node, each (column, value) pair
    present in the data is tried as a split via ``split``. The split with the
    largest positive gain under *score_function* is kept, and both sides are
    built recursively with the same *score_function*. When no split gives a
    positive gain, a leaf is returned whose ``data`` holds ``countTargets(rows)``.

    Args:
        rows: list of rows; an empty list yields a bare ``TreeNode()``.
        score_function: impurity measure taking a list of rows and returning
            a number, where lower means purer; defaults to ``entropy``.

    Returns:
        The root ``TreeNode`` of the (sub)tree.
    """
    if not rows:
        return TreeNode()

    current_score = score_function(rows)
    n_rows = float(len(rows))

    best_gain = 0
    best_criteria = None
    best_split = None

    n_features = len(rows[0]) - 1
    for col in range(n_features):
        # dict.fromkeys de-duplicates while keeping first-seen order, which
        # fixes the tie-breaking order between equal-gain splits.
        for val in dict.fromkeys(row[col] for row in rows):
            true_set, false_set = split(rows, col, val)
            # A split that sends every row to one side cannot be chosen; skip
            # it before scoring (also spares score_function an empty input).
            if not true_set or not false_set:
                continue
            w1 = len(true_set) / n_rows
            gain = (current_score
                    - w1 * score_function(true_set)
                    - (1.0 - w1) * score_function(false_set))
            if gain > best_gain:
                best_gain = gain
                best_criteria = (col, val)
                best_split = (true_set, false_set)

    if best_gain > 0:
        # Pass score_function down so the whole tree uses the same impurity
        # measure, not just the root.
        true_branch = buildTree(best_split[0], score_function)
        false_branch = buildTree(best_split[1], score_function)
        return TreeNode(col=best_criteria[0], val=best_criteria[1],
                        tb=true_branch, fb=false_branch)
    return TreeNode(data=countTargets(rows))


In [51]:
def showTree(root):
    """Print the tree rooted at *root*, one node per line, indented by depth."""
    showTreeRrecursive(root, 0)


def showTreeRrecursive(root, depth):
    """Print the subtree rooted at *root*, indented four spaces per *depth*.

    Leaf nodes (``data`` is not None) print their target counts. Internal
    nodes print their split column and value, followed by the ``T->`` (true
    branch) and ``F->`` (false branch) subtrees one level deeper. A missing
    subtree, or a node with neither data nor branches (such as the bare
    ``TreeNode()`` that ``buildTree([])`` returns), prints ``<empty>``
    instead of raising ``AttributeError``.
    """
    indent = '    ' * depth
    if root is None or (root.data is None and root.tb is None and root.fb is None):
        print(indent + '<empty>')
        return
    if root.data is not None:
        print(indent + str(root.data))
        return
    print(indent + 'col_index: {0:}, splitting value: {1:}'.format(root.col, root.val))
    print(indent + 'T->')
    showTreeRrecursive(root.tb, depth + 1)
    print(indent + 'F->')
    showTreeRrecursive(root.fb, depth + 1)
        

In [23]:
# Toy training set: each row is [feature0, feature1, feature2, feature3, target].
# The last column is the target label tallied by countTargets; note that
# 'None' there is a string label, not Python None.
# Column 3 holds ints, so split() treats its values as ">=" thresholds; the
# other feature columns are strings and are compared for equality.
# NOTE(review): the features look like referrer site, country, a yes/no flag
# and a numeric count -- column names are not given here, confirm against the
# source tutorial.
my_data=[['slashdot','USA','yes',18,'None'],
        ['google','France','yes',23,'Premium'],
        ['digg','USA','yes',24,'Basic'],
        ['kiwitobes','France','yes',23,'Basic'],
        ['google','UK','no',21,'Premium'],
        ['(direct)','New Zealand','no',12,'None'],
        ['(direct)','UK','no',21,'Basic'],
        ['google','USA','no',24,'Premium'],
        ['slashdot','France','yes',19,'None'],
        ['digg','USA','no',18,'None'],
        ['google','UK','no',18,'None'],
        ['kiwitobes','UK','no',19,'None'],
        ['digg','New Zealand','yes',12,'Basic'],
        ['slashdot','UK','no',21,'None'],
        ['google','UK','yes',18,'Basic'],
        ['kiwitobes','France','yes',19,'Basic']]

In [47]:
# Grow the tree on the toy data with the default score function (entropy).
tree = buildTree(my_data)

In [52]:
# Print the tree; "T->" / "F->" introduce the true / false branch of each split.
showTree(tree)

col_index: 0, splitting value: google
T->
    col_index: 3, splitting value: 21
    T->
        {'Premium': 3}
    F->
        col_index: 2, splitting value: no
        T->
            {'None': 1}
        F->
            {'Basic': 1}
F->
    col_index: 0, splitting value: slashdot
    T->
        {'None': 3}
    F->
        col_index: 2, splitting value: yes
        T->
            {'Basic': 4}
        F->
            col_index: 3, splitting value: 21
            T->
                {'Basic': 1}
            F->
                {'None': 3}


In [32]:
# Overall target-label distribution of the training set.
print(countTargets(my_data))

{'None': 7, 'Basic': 6, 'Premium': 3}


In [31]:
# Target-label distribution on each side of the split "column 3 >= 20":
# first the rows meeting the criterion (true set), then the rest (false set).
print(countTargets(split(my_data, 3, 20)[0]))
print('')
print(countTargets(split(my_data, 3, 20)[1]))

{'None': 1, 'Basic': 3, 'Premium': 3}

{'None': 6, 'Basic': 3}


In [33]:
# Entropy of each side of the "column 3 >= 20" split versus the whole data set;
# the final tuple expression is what the notebook displays as the cell output.
set1, set2 = split(my_data, 3, 20)
entropy(set1), entropy(set2), entropy(my_data)


(1.4488156357251847, 0.91829583405448956, 1.5052408149441479)