# Decision Trees: A Primer

Decision trees are well suited to classification tasks, and can be adapted to perform regression as well.

In this notebook, we go through the steps required to implement a decision tree classifier that determines whether a banknote is counterfeit.

This sets up the basic framework we will build on when implementing Random Forests in the next notebook.

In [1]:
import numpy as np
from time import time

In [2]:
## Import our banknote dataset
data = np.loadtxt("data_banknote_authentication.csv",dtype=np.float32, delimiter=",")

In [3]:
## Partition the data into features X and class labels y (y stays 2-D, with shape (n, 1))
X = data[:,0:4]
y = data[:,4:]
del data

## What should we split on?

At each node, a decision tree splits the data points into two groups based on the value of a single feature.

For example, if you were constructing a tree to predict whether an individual owns a smartphone, a reasonable feature to split on could be their age.

However, we need an objective criterion for how good a split is, which helps our model decide:

1. which feature to use for splitting, and
2. what value of that feature to split on.

A commonly used metric is the Gini impurity, which measures how "pure" a group of samples is. For a group whose class proportions are $p_c$, the Gini impurity is $G = \sum_c p_c (1 - p_c) = 1 - \sum_c p_c^2$. A completely pure group, where every sample belongs to the same class, has a score of 0.

We evaluate the quality of a given split by averaging the Gini impurities of the two resulting groups, weighted by the number of points in each: $\frac{n_L G_L + n_R G_R}{n_L + n_R}$.
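
As a worked check of these definitions, here is a minimal standalone sketch that computes the Gini impurity as 1 minus the sum of squared class proportions (equivalent to the $\sum_c p_c(1-p_c)$ form used in `gini_impurity` below) and the size-weighted impurity of a split. The helper names `gini` and `split_impurity` are hypothetical, and unlike the notebook's functions they take flat 1-D label arrays.

```python
import numpy as np

def gini(labels):
    """Gini impurity of a 1-D array of class labels: 1 - sum_c p_c^2."""
    labels = np.asarray(labels)
    if labels.size == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / labels.size
    return float(1.0 - np.sum(p ** 2))

def split_impurity(left, right):
    """Gini impurity of each side, weighted by the number of points on that side."""
    n_left, n_right = len(left), len(right)
    return (n_left * gini(left) + n_right * gini(right)) / (n_left + n_right)

# A perfectly mixed two-class group scores 0.5; a pure group scores 0.
print(gini([0, 0, 1, 1]))              # 0.5
print(split_impurity([0, 0], [1, 1]))  # 0.0 (a perfect split)
print(split_impurity([0, 0, 1], [1]))  # about 0.333 (3/4 * 4/9 + 1/4 * 0)
```

The perfect split scores lower than the lopsided one, which is exactly the ordering the split search relies on.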

In [4]:
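def gini_impurity(y):
    ## y is a 2-D array of labels with shape (n, 1)
    n = len(y)
    classes = set(np.squeeze(y,axis=1))
    ## Start from a length-1 array so that impurity[0] also works for an empty group
    impurity = np.zeros(1)
    for c in classes:
        ## Proportion of points in class c (a length-1 array, one entry per label column)
        p = np.sum(y == c, axis=0)/n
        impurity += p*(1-p)
    return (impurity, n)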

In [5]:
def weighted_impurity(y_left,y_right):
    ## Average the Gini impurity of each side, weighted by the number of points on that side
    gini_left, n_left = gini_impurity(y_left)
    gini_right, n_right = gini_impurity(y_right)
    impurity = ((gini_left[0]*n_left)+(gini_right[0]*n_right))/(n_left+n_right)
    return impurity

## Searching for a split point

We exhaustively loop through every feature, and for each feature we try every value observed in the data as a candidate split point, assessing the quality of each split with the weighted Gini impurity defined above. At each node, we pick the split that gives the lowest impurity. This continues until we hit a stopping criterion.

In [6]:
## Split the data based on a feature and its value,
## returning boolean masks for the left (< value) and right (>= value) groups
def split_data(X, i, value):
    left = X[:,i] < value
    right = X[:,i] >= value
    return (left,right)

In [7]:
## Evaluate each feature in the data to search for the best split
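def search_split(X,y):
    best_i = 0
    best_value = 0
    ## Gini impurity is always below 1, so any valid split improves on this starting value
    best_impurity = 1
    ## Every value of every feature in the data is a candidate split point
    for row in X:
        for i,value in enumerate(row):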
            left, right = split_data(X,i,value)
            if left.sum()==0 or right.sum()==0:
                continue
            imp = weighted_impurity(y[left],y[right])
            if imp < best_impurity:
                best_i = i
                best_value = value
                best_impurity = imp
    return (best_i, best_value, best_impurity)
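
The loop above evaluates one candidate threshold per row and per feature, so a feature value that appears many times is evaluated many times. One possible refinement, sketched below, is to iterate over each feature's unique values instead. It assumes that tie-breaking between equally good splits does not matter (the visiting order differs from the original loop). The function name `search_split_unique` is hypothetical, and the sketch uses its own flat-label Gini helper so that it runs on its own rather than as a drop-in replacement.

```python
import numpy as np

def gini(labels):
    # Gini impurity 1 - sum_c p_c^2 of a non-empty 1-D label array
    _, counts = np.unique(labels, return_counts=True)
    p = counts / len(labels)
    return 1.0 - np.sum(p ** 2)

def search_split_unique(X, y):
    """Exhaustive search that tries each distinct value of each feature once."""
    best_i, best_value, best_impurity = None, None, np.inf
    n = len(y)
    for i in range(X.shape[1]):
        for value in np.unique(X[:, i]):
            left = X[:, i] < value          # same rule as split_data
            n_left = left.sum()
            if n_left == 0 or n_left == n:  # skip splits with an empty side
                continue
            imp = (n_left * gini(y[left]) + (n - n_left) * gini(y[~left])) / n
            if imp < best_impurity:
                best_i, best_value, best_impurity = i, value, imp
    return best_i, best_value, best_impurity

# Toy data: feature 1 separates the classes perfectly, feature 0 does not.
X_toy = np.array([[1.0, 2.0], [2.0, 3.0], [1.5, 6.0], [2.5, 7.0]])
y_toy = np.array([0, 0, 1, 1])
i, value, imp = search_split_unique(X_toy, y_toy)
print(i, value, imp)  # 1 6.0 0.0
```

Because thresholds are taken from the data and the rule is `X[:, i] < value`, the perfect split on feature 1 is found at the observed value 6.0.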

## Stopping Criteria

Given free rein, a decision tree will keep splitting until every leaf is "pure", i.e. contains only one class. This is undesirable because it can cause our model to overfit the data, especially deep in the tree, where splits may be made on very small groups.

Therefore, when building our tree, we define parameters that constrain it: the minimum number of data points a node needs before it can be split further (`min_n`), and the maximum depth of the tree (`max_depth`).
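
These stopping rules can be collected into a single predicate. The sketch below mirrors the three conditions listed in the comments of `build_tree` (only one class left, fewer than `min_n` points, maximum depth reached); the name `should_stop` is hypothetical, and the labels are assumed to be a flat 1-D array. It uses `depth >= max_depth` as a slightly more defensive form of an equality check.

```python
import numpy as np

def should_stop(labels, depth, min_n, max_depth):
    """Return True if a node holding `labels` at `depth` should become a leaf."""
    labels = np.asarray(labels)
    is_pure = len(np.unique(labels)) <= 1  # only one class left (or no points)
    too_small = len(labels) < min_n        # not enough points to split further
    too_deep = depth >= max_depth          # maximum depth reached
    return bool(is_pure or too_small or too_deep)

print(should_stop([0, 1] * 10, depth=1, min_n=10, max_depth=3))  # False: mixed, 20 points, shallow
print(should_stop([1] * 20, depth=1, min_n=10, max_depth=3))     # True: pure
print(should_stop([0, 1] * 3, depth=1, min_n=10, max_depth=3))   # True: only 6 points
print(should_stop([0, 1] * 10, depth=3, min_n=10, max_depth=3))  # True: at max depth
```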

In [8]:
## Create a leaf node that predicts the proportion of each class among its points
def set_leaf(y):
    n = len(y)
    node = {}
    node["leaf"] = True
    classes = set(np.squeeze(y,axis=1))
    node["prediction"] = {}
    for c in classes:
        node["prediction"][c] = (sum(y==c)/n)[0]
    return node
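
The class proportions stored in a leaf can also be computed in one vectorized call with `np.unique(..., return_counts=True)`. The sketch below is a hypothetical alternative, `leaf_proportions`, that takes a flat 1-D label array and returns the same kind of class-to-proportion dictionary as `set_leaf`, with plain Python floats as keys and values.

```python
import numpy as np

def leaf_proportions(labels):
    """Map each class present in `labels` to the fraction of points in that class."""
    classes, counts = np.unique(labels, return_counts=True)
    return {c.item(): float(cnt) / len(labels) for c, cnt in zip(classes, counts)}

print(leaf_proportions(np.array([0.0, 1.0, 1.0, 1.0])))  # {0.0: 0.25, 1.0: 0.75}
```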

In [9]:
## Use recursion to build the tree until a stopping criteria is met
def build_tree(X,y,min_n,max_depth,depth=1):
    node = {}
    i, value, imp = search_split(X,y)
    node["i"] = i
    node["value"] = value
    node["impurity"] = imp
    node["depth"] = depth
    left, right = split_data(X, i, value)

    ## Stopping conditions for each child:
    ##  1) only 1 class left (impurity of 0);
    ##  2) fewer than min_n points;
    ##  3) maximum depth of the tree reached
    left_stop = (gini_impurity(y[left])[0][0]==0) or (left.sum()<min_n) or (depth==max_depth)
    right_stop = (gini_impurity(y[right])[0][0]==0) or (right.sum()<min_n) or (depth==max_depth)
    
    if left_stop:
        node["left"] = set_leaf(y[left])
    else:
        node["left"] = build_tree(X[left],y[left],min_n,max_depth,depth=depth+1)
        
    if right_stop:
        node["right"] = set_leaf(y[right])
    else:
        node["right"] = build_tree(X[right],y[right],min_n,max_depth,depth=depth+1)
    return node

In [10]:
tic = time()
tree = build_tree(X,y,min_n=10, max_depth=2)
toc = time()
print("Tree built in {:.2f} seconds.".format(toc-tic))

Tree built in 22.43 seconds.


## Using our Tree for Predictions

The output of our tree is a dictionary mapping each class to a value between 0 and 1: the proportion of training points in the reached leaf that belong to that class. We interpret this as the model's confidence that the input belongs to that class.

In [11]:
## Use recursion once again to navigate the tree; here X is a single row of features
def tree_predict(node, X):
    if node.get("leaf"):
        return node["prediction"]
    i = node["i"]
    value = node["value"]
    if X[i] < value:
        return tree_predict(node["left"],X)
    else:
        return tree_predict(node["right"],X)

In [12]:
prediction = tree_predict(tree,X[100])
## Each value is a proportion between 0 and 1, so format it as a percentage
print("Our model predicts: " + ", ".join(
        "Class {} - {:.2%}".format(c, p) for c, p in prediction.items()))

Our model predicts: Class 0.0 - 93.31%, Class 1.0 - 6.69%
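
To turn the dictionary returned by `tree_predict` into a single class label, one common choice (assumed here, not part of the notebook) is to take the class with the highest proportion; repeating this over every row gives a simple accuracy score. The sketch below uses a small hand-written tree with the same keys as the one built above (`leaf`, `prediction`, `i`, `value`, `left`, `right`) so that it runs on its own; `predict_label` and `accuracy` are hypothetical names.

```python
import numpy as np

def tree_predict(node, x):
    # Same traversal as in the notebook: follow the splits until a leaf is reached
    if node.get("leaf"):
        return node["prediction"]
    branch = "left" if x[node["i"]] < node["value"] else "right"
    return tree_predict(node[branch], x)

def predict_label(node, x):
    # Pick the class with the largest predicted proportion at the leaf
    prediction = tree_predict(node, x)
    return max(prediction, key=prediction.get)

def accuracy(node, X, y):
    # Fraction of rows whose predicted label matches the true label
    labels = np.array([predict_label(node, row) for row in X])
    return float(np.mean(labels == np.ravel(y)))

# Hypothetical one-split tree: feature 0 below 0.5 -> mostly class 1.0
toy_tree = {
    "i": 0, "value": 0.5,
    "left": {"leaf": True, "prediction": {0.0: 0.1, 1.0: 0.9}},
    "right": {"leaf": True, "prediction": {0.0: 0.8, 1.0: 0.2}},
}
X_toy = np.array([[0.2], [0.9], [0.4], [0.7]])
y_toy = np.array([[1.0], [0.0], [0.0], [0.0]])  # 2-D, like the notebook's y
print(predict_label(toy_tree, X_toy[0]))  # 1.0
print(accuracy(toy_tree, X_toy, y_toy))   # 0.75
```

`np.ravel` flattens the `(n, 1)` label array so the comparison is element-wise rather than broadcast into an `n x n` matrix.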


## Visualizing our Tree

In [13]:
import pprint
pp = pprint.PrettyPrinter()
pp.pprint(tree)

{'depth': 1,
 'i': 0,
 'impurity': 0.24679933491786996,
 'left': {'depth': 2,
          'i': 1,
          'impurity': 0.15961960854754184,
          'left': {'leaf': True,
                   'prediction': {0.0: 0.070652173913043473,
                                  1.0: 0.92934782608695654}},
          'right': {'leaf': True,
                    'prediction': {0.0: 0.80952380952380953,
                                   1.0: 0.19047619047619047}},
          'value': 7.6273999},
 'right': {'depth': 2,
           'i': 2,
           'impurity': 0.13876960087955628,
           'left': {'leaf': True,
                    'prediction': {0.0: 0.23809523809523808,
                                   1.0: 0.76190476190476186}},
           'right': {'leaf': True,
                     'prediction': {0.0: 0.93313521545319467,
                                    1.0: 0.066864784546805348}},
           'value': -4.3839002},
 'value': 0.32229999}
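
`pprint` shows the raw dictionary, which gets hard to read as the tree grows. A small recursive formatter can render the same structure as indented if/else rules. The sketch below assumes the dictionary layout produced by `build_tree` and `set_leaf`; `format_tree` is a hypothetical helper, and the example tree is hand-written rather than the one trained above.

```python
def format_tree(node, indent=""):
    """Return a tree dict as a list of indented if/else rule lines."""
    if node.get("leaf"):
        # Leaves list each class with its predicted proportion
        probs = ", ".join("{}: {:.2f}".format(c, p)
                          for c, p in sorted(node["prediction"].items()))
        return [indent + "predict {" + probs + "}"]
    lines = [indent + "if X[{}] < {:.4f}:".format(node["i"], node["value"])]
    lines += format_tree(node["left"], indent + "    ")
    lines.append(indent + "else:")
    lines += format_tree(node["right"], indent + "    ")
    return lines

example_tree = {
    "i": 0, "value": 0.32,
    "left": {"leaf": True, "prediction": {0.0: 0.25, 1.0: 0.75}},
    "right": {"leaf": True, "prediction": {0.0: 0.9, 1.0: 0.1}},
}
print("\n".join(format_tree(example_tree)))
```

For the example tree this prints `if X[0] < 0.3200:`, an indented `predict {0.0: 0.25, 1.0: 0.75}`, then `else:` and an indented `predict {0.0: 0.90, 1.0: 0.10}`. Calling `print("\n".join(format_tree(tree)))` on the tree built above would render its two levels of splits the same way.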
