In [1]:
import numpy as np
# Sentinel used as the initial "best" Gini score in select_best_split;
# it is far larger than any score a real split can produce.
LARGE_VALUE = 10 ** 8

In [2]:
def get_gini_score(branches, labels):
    """Return the size-weighted Gini impurity of a candidate split.

    Parameters
    ----------
    branches : sequence of 2-D arrays
        The groups of samples produced by a split. The class label of each
        sample is read from the last column of its group.
    labels : iterable
        The class labels to count within each group.

    Returns
    -------
    float
        Sum over the non-empty groups of ``1 - sum(p_label ** 2)``, each
        weighted by the group's share of all samples. Empty groups
        contribute nothing.
    """
    total_count = sum(len(group) for group in branches)
    weighted_impurity = 0.0
    for group in branches:
        group_count = len(group)
        if group_count == 0:
            # Nothing to score, and the proportion below would divide by zero.
            continue
        group_labels = group[:, -1]
        squared_share_sum = 0.0
        for label in labels:
            # Number of samples in this group carrying `label` ...
            matches = len(group[group_labels == label])
            # ... expressed as a fraction of the group.
            share = matches / group_count
            squared_share_sum += share ** 2
        weighted_impurity += (1.0 - squared_share_sum) * (group_count / total_count)
    return weighted_impurity

In [3]:
def try_potential_split(samples, split_value, feature_id):
    """Partition `samples` on one feature column at a candidate threshold.

    Parameters
    ----------
    samples : 2-D array
        One sample per row.
    split_value : scalar
        The threshold to split at.
    feature_id : int
        Index of the column whose values are compared with `split_value`.

    Returns
    -------
    tuple of (left, right) arrays
        `left` holds the rows whose feature value is strictly below
        `split_value`; `right` holds the rows whose value is greater than or
        equal to it. Whether the split is any good is decided by the caller.
    """
    feature_column = samples[:, feature_id]
    below_threshold = feature_column < split_value
    at_or_above_threshold = feature_column >= split_value
    return samples[below_threshold], samples[at_or_above_threshold]

In [4]:
def select_best_split(samples):
    """Exhaustively search for the split of `samples` with the lowest Gini score.

    Every value of every feature column is tried as a threshold. The last
    column of `samples` holds the class labels and is not used as a feature.

    Returns
    -------
    dict
        ``'feature_id'``   -- column index of the winning feature,
        ``'split_value'``  -- the winning threshold,
        ``'two_branches'`` -- the (left, right) arrays produced by that split.
        All three stay ``None`` if no candidate was evaluated.
    """
    class_labels = np.unique(samples[:, -1])
    # Every column except the last one is a feature.
    feature_count = len(samples[-1]) - 1

    # Running best, kept as (score, feature_id, split_value, branches).
    winner = (LARGE_VALUE, None, None, None)
    for feature_id in range(feature_count):
        for row in samples:
            threshold = row[feature_id]
            branches = try_potential_split(samples, threshold, feature_id)
            score = get_gini_score(branches, class_labels)
            # Strict comparison: on ties the earliest candidate is kept.
            if score < winner[0]:
                winner = (score, feature_id, threshold, branches)

    _, best_feature_id, best_split_value, best_two_branches = winner
    return {
        'feature_id': best_feature_id,
        'split_value': best_split_value,
        'two_branches': best_two_branches,
    }

In [5]:
def leaf_node_label(branch):
    """Return the most frequent class label among the samples in `branch`.

    The label is the last element of each sample, converted with ``int()``
    before counting. When several labels are equally frequent the smallest
    one is returned, because ``argmax`` reports the first maximum.
    """
    votes = []
    for row in branch:
        votes.append(int(row[-1]))
    # bincount()[k] is the number of votes for label k.
    return np.bincount(votes).argmax()

In [6]:
def recursive_build_tree(node, max_depth, min_size, depth):
    """Grow the subtree below `node` in place.

    `node` is a dict as returned by select_best_split. Its 'two_branches'
    entry is consumed (removed) and replaced by 'left_label' and
    'right_label', each of which is either a class label (leaf) or another
    node dict (subtree).

    Parameters
    ----------
    node : dict
        The node to expand; modified in place.
    max_depth : int
        Once `depth` reaches this value both children become leaves.
    min_size : int
        A branch with at most this many samples becomes a leaf.
    depth : int
        Depth of `node`.
    """
    left_branch, right_branch = node['two_branches']
    del node['two_branches']

    left_is_empty = len(left_branch) == 0
    right_is_empty = len(right_branch) == 0

    # Exactly one side is empty: the split separated nothing, so both
    # children get the majority label of the side that holds the samples.
    if left_is_empty != right_is_empty:
        populated = right_branch if left_is_empty else left_branch
        majority = leaf_node_label(populated)
        node['left_label'] = majority
        node['right_label'] = majority
        return

    # Depth limit reached: stop growing and label both sides.
    if depth >= max_depth:
        node['left_label'] = leaf_node_label(left_branch)
        node['right_label'] = leaf_node_label(right_branch)
        return

    # Otherwise handle the left side completely, then the right side.
    for key, subset in (('left_label', left_branch), ('right_label', right_branch)):
        if len(subset) <= min_size:
            node[key] = leaf_node_label(subset)
        else:
            child = select_best_split(subset)
            node[key] = child
            recursive_build_tree(child, max_depth, min_size, depth + 1)
        

In [7]:
def decisionTree(data, max_depth, min_samples_leaf, min_samples_split=None):
    """Build a decision tree from `data` and return its root node.

    Parameters
    ----------
    data : array-like, 2-D
        One sample per row; the last column holds the class label. Plain
        nested lists are accepted as well as NumPy arrays.
    max_depth : int
        Passed to recursive_build_tree as the depth limit (the root is
        depth 1).
    min_samples_leaf : int
        Passed to recursive_build_tree as `min_size`: a branch with at most
        this many samples is turned into a leaf.
    min_samples_split : optional
        Currently unused; kept for interface compatibility.

    Returns
    -------
    dict
        The root node, with nested 'left_label' / 'right_label' entries.
    """
    # The split helpers index with `samples[:, col]`, which a plain list does
    # not support. asarray converts array-likes and returns ndarrays as-is
    # (no copy), so existing callers are unaffected.
    samples = np.asarray(data)
    root = select_best_split(samples)
    recursive_build_tree(root, max_depth, min_samples_leaf, 1)
    return root

In [8]:
def print_tree_helper(depth):
    """Print the indentation prefix for a tree line: `depth + 1` dashes, no newline."""
    print("-" * len(range(depth + 1)), end="")

def print_tree(node, depth=0):
    """Pretty-print a tree (or subtree) to stdout, one line per node.

    `node` is either a dict with 'feature_id' and 'split_value' (an internal
    node, whose 'left_label' / 'right_label' children are printed one level
    deeper when present) or anything else, which is printed as a leaf label.
    """
    # Every line starts with the dash prefix for its depth.
    print_tree_helper(depth)

    if not isinstance(node, dict):
        print('-> [', node, ']')
        return

    header = '> f{:d} < {:2.3f}'.format(node['feature_id'], node['split_value'])
    print(header)

    for child_key in ('left_label', 'right_label'):
        if child_key in node:
            print_tree(node[child_key], depth + 1)

In [9]:
# Toy data set: two feature columns followed by the class label. Taken from
# https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/
dataset = [
    [2.771244718, 1.784783929, 0],
    [1.728571309, 1.169761413, 0],
    [3.678319846, 2.81281357, 0],
    [3.961043357, 2.61995032, 0],
    [2.999208922, 2.209014212, 0],
    [7.497545867, 3.162953546, 1],
    [9.00220326, 3.339047188, 1],
    [7.444542326, 0.476683375, 1],
    [10.12493903, 3.234550982, 1],
    [6.642287351, 3.319983761, 1],
]

# Build a tree on the toy data and display it.
tree = decisionTree(
    np.array(dataset),
    max_depth=3,
    min_samples_leaf=1,
    min_samples_split=1,
)
print_tree(tree)

-> f0 < 6.642
--> f0 < 2.771
----> [ 0 ]
---> f0 < 2.771
-----> [ 0 ]
-----> [ 0 ]
--> f0 < 7.498
---> f0 < 7.445
-----> [ 1 ]
-----> [ 1 ]
---> f0 < 7.498
-----> [ 1 ]
-----> [ 1 ]


In [10]:
# Generate a synthetic data set: uniform random values in [0, 10), with the
# last column overwritten by a random integer class label in 0..9.
# A fixed seed makes the data (and therefore the printed tree) identical on
# every Restart & Run All; change SEED to draw a different data set.
SEED = 0
N_SAMPLES = 200
N_COLUMNS = 15  # 14 feature columns + 1 label column
N_CLASSES = 10

rng = np.random.RandomState(SEED)
dataset = 10 * rng.rand(N_SAMPLES, N_COLUMNS)
# One vectorised draw replaces the per-sample loop; kept as a plain list of ints.
labels = rng.randint(0, N_CLASSES, size=N_SAMPLES).tolist()
dataset[:, -1] = labels

# Test and print decision tree
tree = decisionTree(dataset, 12, 3)
print_tree(tree)

-> f8 < 9.669
--> f10 < 4.252
---> f0 < 7.763
----> f11 < 9.108
-----> f5 < 7.320
------> f4 < 5.658
-------> f6 < 5.967
--------> f8 < 2.020
----------> [ 6 ]
---------> f8 < 7.000
----------> f3 < 3.966
------------> [ 1 ]
-----------> f0 < 6.526
-------------> [ 8 ]
-------------> [ 0 ]
----------> f3 < 8.974
-----------> f3 < 3.537
-------------> [ 8 ]
------------> f0 < 7.198
--------------> [ 2 ]
--------------> [ 2 ]
------------> [ 4 ]
--------> f9 < 3.473
---------> f2 < 5.467
-----------> [ 1 ]
----------> f4 < 3.288
------------> [ 0 ]
------------> [ 6 ]
---------> f0 < 0.701
-----------> [ 9 ]
----------> f0 < 0.701
------------> [ 9 ]
------------> [ 9 ]
-------> f2 < 5.670
--------> f8 < 2.119
---------> f5 < 3.532
-----------> [ 9 ]
-----------> [ 4 ]
---------> f5 < 4.812
----------> f12 < 1.654
------------> [ 9 ]
-----------> f0 < 5.287
------------> f0 < 0.094
--------------> [ 7 ]
--------------> [ 7 ]
-------------> [ 7 ]
-----------> [ 4 ]
--------> f3 < 6.943
--