# Lab session week 5 - extra: Implementing your own decision tree

In this **extra** lab session you will build your own implementation of a CART decision tree.

In [1]:
from IPython.display import HTML, display
import matplotlib
import pandas as pd
import numpy.testing as npt

## Implement the CART decision tree algorithm
See the slides for the explanation of both gini impurity and information gain.

We will implement this algorithm in the following steps:
1. Write functions for calculating gini impurity and information gain
2. Write a function that is able to find the best next question (using the previously implemented methods)
3. Use the construct tree method to build the tree
4. Write a classify method to use the tree for classification

This is the dataset that we will use. Note that the categorical variable color has been one-hot encoded (converted into boolean columns).

The dataset is small enough that you can easily calculate the probabilities by hand to check whether your implementation is correct:

In [2]:
headings = ['isRed', 'isBlack', 'isYellow', 'seats', 'label']

data_with_labels = pd.DataFrame( 
       [[1, 0, 0, 5, 'car'],
        [0, 1, 0, 2, 'car'],
        [0, 1, 0, 2, 'motorbike'],
        [0, 0, 1, 25, 'train'],
        [0, 0, 1, 30, 'train']], columns=headings)

display(data_with_labels)

| | isRed | isBlack | isYellow | seats | label |
|---|---|---|---|---|---|
| 0 | 1 | 0 | 0 | 5 | car |
| 1 | 0 | 1 | 0 | 2 | car |
| 2 | 0 | 1 | 0 | 2 | motorbike |
| 3 | 0 | 0 | 1 | 25 | train |
| 4 | 0 | 0 | 1 | 30 | train |
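
The hand calculation can be sketched in a few lines of plain Python, without pandas. The snippet assumes the usual Gini impurity definition, 1 minus the sum of the squared label probabilities. The slides remain the authoritative source for the formula, but this definition does reproduce the 0.64 that the test cell later in this notebook expects. The variable names are illustrative only.

```python
from collections import Counter

# Labels copied from the 'label' column of the dataset above
labels = ['car', 'car', 'motorbike', 'train', 'train']

# Probability of each label = its count divided by the number of rows
counts = Counter(labels)
probabilities = {label: n / len(labels) for label, n in counts.items()}

# Assumed definition: Gini impurity = 1 - sum of squared probabilities
gini = 1 - sum(p ** 2 for p in probabilities.values())

print(probabilities)
print(round(gini, 2))
```

By hand this is 1 - (0.4² + 0.4² + 0.2²) = 1 - 0.36 = 0.64.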


We will use the following class to make our implementation easier. Each instance of Question represents one possible question. In this example we will use only numerical questions. Each question contains the name of a column in the original dataset and the value. The question will always use the operator >=.

**For example: Question('isBlack', 5) represents the question dataset['isBlack'] >= 5**

A question object can also partition a dataset: its partition method returns the rows for which the question is true and the rows for which it is false.

In [3]:
class Question:
    """Class that represents a question
    Each question contains the column name and the value
    The question will then represent dataset[col_name] >= value
    """
    
    def __init__(self, col_name, value):
        self.col_name = col_name
        self.value = value

    def partition(self, data):
        """Partitions the dataset using this question
        Returns the true partition and the false partition, as one tuple with two dataframes
        """
        return data[data[self.col_name] >= self.value], data[data[self.col_name] < self.value]

    def answer(self, sample):
        return sample[self.col_name] >= self.value
    
    def __repr__(self):
        """toString() function for Python
        """
        return str(self.col_name) + ' >= ' + str(self.value)

We already provide a function that returns a list of all possible (remaining) questions for a given dataset:

In [4]:
def all_possible_questions(data):
    """ Returns a list of all possible questions
    Each question consists of the column name and the value
    The question is then: dataset[column] >= value
    The last column will be skipped since it contains the labels!!!
    """
    result = list()
    
    # Loop over all columns except for the last column (since it contains the labels)
    for col_name in data.iloc[:,:-1].columns:
        # Take column from data set
        column = data[col_name]

        # For each unique value in the data set, create a possible question object
        # Each question stores the column name and the value (question: data[col_name] >= value)
        for val in column.unique():
            result.append(Question(col_name, val))
        
    return result

In [5]:
### Examples

# Example 1: get all questions:
questions = all_possible_questions(data_with_labels)
print(questions)

# Example 2: partition the dataset using one question (question 6: seats >= 5)
print("Question: ", questions[6])
left, right = questions[6].partition(data_with_labels)

print("Partition True:")
display(left)

print("Partition False:")
display(right)

[isRed >= 1, isRed >= 0, isBlack >= 0, isBlack >= 1, isYellow >= 0, isYellow >= 1, seats >= 5, seats >= 2, seats >= 25, seats >= 30]
Question:  seats >= 5
Partition True:


| | isRed | isBlack | isYellow | seats | label |
|---|---|---|---|---|---|
| 0 | 1 | 0 | 0 | 5 | car |
| 3 | 0 | 0 | 1 | 25 | train |
| 4 | 0 | 0 | 1 | 30 | train |


Partition False:


| | isRed | isBlack | isYellow | seats | label |
|---|---|---|---|---|---|
| 1 | 0 | 1 | 0 | 2 | car |
| 2 | 0 | 1 | 0 | 2 | motorbike |


### 1.1 Implement gini_impurity and information_gain
The slides contain information about both metrics. Please implement them in Python (this is a good exercise to practice some math, understand metrics and work comfortably with Pandas).

In [6]:
def gini_impurity(data):
    """ Calculate the Gini Impurity for a dataframe (make sure the last column contains the labels!!!)
    Please see the slides for the formula
    """
    # TODO: implement method
    return 0

def information_gain(left_data, right_data, current_gini):
    """ Calculate the information gain for the current split
    left_data and right_data are both sets of rows (dataframes).
    current_gini is the gini impurity of the current node
    """
    # TODO: implement method
    return 0
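
One possible way to fill in these stubs is sketched below. It is not the reference solution: it assumes that Gini impurity is 1 minus the sum of the squared label probabilities, and that information gain is the current impurity minus the size-weighted impurity of the two partitions. Both assumptions agree with the values expected by the tests (0.64, 0.444..., 0.5 and 0.1733...). The `_sketch` suffix is only there to avoid overwriting your own functions, and the dataset is repeated so that the block runs on its own.

```python
import pandas as pd

def gini_impurity_sketch(data):
    """Gini impurity of the labels in the last column (assumed: 1 - sum(p_i^2))."""
    # Relative frequency of every label in the last column
    probabilities = data.iloc[:, -1].value_counts(normalize=True)
    return 1 - (probabilities ** 2).sum()

def information_gain_sketch(left_data, right_data, current_gini):
    """Assumed: current impurity minus the size-weighted impurity of both partitions."""
    total = len(left_data) + len(right_data)
    weighted = (len(left_data) / total * gini_impurity_sketch(left_data)
                + len(right_data) / total * gini_impurity_sketch(right_data))
    return current_gini - weighted

# The lab dataset, repeated so the block is self-contained
data = pd.DataFrame(
    [[1, 0, 0, 5, 'car'],
     [0, 1, 0, 2, 'car'],
     [0, 1, 0, 2, 'motorbike'],
     [0, 0, 1, 25, 'train'],
     [0, 0, 1, 30, 'train']],
    columns=['isRed', 'isBlack', 'isYellow', 'seats', 'label'])

# Split on seats >= 5, the same question the test cell uses
left, right = data[data['seats'] >= 5], data[data['seats'] < 5]
current = gini_impurity_sketch(data)
gain = information_gain_sketch(left, right, current)
print(current, gain)
```

Worked by hand for this split: 0.64 - (3/5 · 0.444... + 2/5 · 0.5) = 0.64 - 0.4667 ≈ 0.1733.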

In [7]:
###################
### Test methods: #
###################

# Test 1: gini impurity (should give value 0.64)
current_gini = gini_impurity(data_with_labels)
npt.assert_almost_equal(current_gini, 0.64)

# Tests part 2: gini impurity and information gain (using questions[6]: seats >= 5)
questions = all_possible_questions(data_with_labels)
left, right = questions[6].partition(data_with_labels)

# Test 2: gini impurity for partial sets should be: 0.444... and 0.5
npt.assert_almost_equal(gini_impurity(left), 0.444, decimal=3)
npt.assert_almost_equal(gini_impurity(right), 0.5)

# Test 3: information_gain should be: 0.17333...
ig = information_gain(left, right, current_gini)
npt.assert_almost_equal(ig, 0.1733, decimal=4)

AssertionError: 
Arrays are not almost equal to 7 decimals
 ACTUAL: 0
 DESIRED: 0.64

### 1.2 Determining the next best question for an arbitrary dataset
Now we will implement the method that finds the best next question. This method should work as follows:
- Calculate the gini impurity of the whole dataset
- Loop over all the questions and partition the dataset using the question's partition method
- If the question does not split the dataset in two (one of the partitions is empty), skip this question
- Otherwise calculate the information gain
- Keep track of the question with the best information gain and return this question and its information gain
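
The steps above can be sketched as follows. This is a hedged illustration rather than the reference solution: `best_question_sketch` and the compact `gini` helper are hypothetical names, the `Question` class is a trimmed copy of the one defined earlier, and the impurity and gain formulas are the assumed ones (1 minus the sum of squared probabilities, and current impurity minus the size-weighted impurity of the partitions).

```python
import pandas as pd

class Question:
    """Trimmed copy of the lab's Question class: data[col_name] >= value."""
    def __init__(self, col_name, value):
        self.col_name, self.value = col_name, value

    def partition(self, data):
        mask = data[self.col_name] >= self.value
        return data[mask], data[~mask]

    def __repr__(self):
        return f'{self.col_name} >= {self.value}'

def gini(data):
    # Assumed definition: 1 - sum of squared label probabilities
    p = data.iloc[:, -1].value_counts(normalize=True)
    return 1 - (p ** 2).sum()

def best_question_sketch(questions, data):
    current = gini(data)
    best_question, best_gain = None, 0
    for question in questions:
        true_data, false_data = question.partition(data)
        # Skip questions that leave one side empty (no real split)
        if len(true_data) == 0 or len(false_data) == 0:
            continue
        weighted = (len(true_data) * gini(true_data)
                    + len(false_data) * gini(false_data)) / len(data)
        gain = current - weighted
        # Strict '>' keeps the first question when several tie
        if gain > best_gain:
            best_question, best_gain = question, gain
    return best_question, best_gain

data = pd.DataFrame(
    [[1, 0, 0, 5, 'car'],
     [0, 1, 0, 2, 'car'],
     [0, 1, 0, 2, 'motorbike'],
     [0, 0, 1, 25, 'train'],
     [0, 0, 1, 30, 'train']],
    columns=['isRed', 'isBlack', 'isYellow', 'seats', 'label'])

# Same question order as all_possible_questions: per column, per unique value
questions = [Question(col, val)
             for col in data.columns[:-1]
             for val in data[col].unique()]
q, ig = best_question_sketch(questions, data)
print(q, ig)
```

Note that `seats >= 25` produces exactly the same split as `isYellow >= 1` on this dataset, so both have the same gain. The strict comparison is a design choice that returns whichever comes first in the list.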

In [8]:
def find_best_next_question(questions, data):
    """Find the next best question using the information gain (see slides)
    This method should return the question and the information gain
    """
    
    # TODO: implement method
    return None, 0

In [9]:
###################
### Test methods: #
###################
questions = all_possible_questions(data_with_labels)

# Should give "isYellow" with information gain of 0.37333...
q, ig = find_best_next_question(questions, data_with_labels)

npt.assert_equal(str(q), 'isYellow >= 1')
npt.assert_almost_equal(ig, 0.3733, decimal=4)

AssertionError: 
Items are not equal:
 ACTUAL: 'None'
 DESIRED: 'isYellow >= 1'

### 1.3 Constructing the tree
The tree construction algorithm uses your implemented methods and two additional classes:

In [10]:
class Tree:
    """Normal branched node with a question and two branches
    """

    def __init__(self, question, true_node, false_node):
        self.question = question
        self.true_node = true_node
        self.false_node = false_node
    
    def __repr__(self):
        """ToString method for the tree (will generate dot-format string)
        """
        return 'digraph g {\n' + self.print_tree_dot() + '}'
        
    def print_tree_dot(self, prefix=''):
        """Helper method to print this tree object as a dot-string
        """
        result = '\tn{} [label="{}" shape=box];\n'.format(prefix, str(self.question))
        
        # Recursively add nodes
        result += self.true_node.print_tree_dot(prefix + 't')
        result += self.false_node.print_tree_dot(prefix + 'f')
        
        # Add connections
        result += '\tn{} -> n{} [label="true"];\n'.format(prefix, prefix + 't')
        result += '\tn{} -> n{} [label="false"];\n'.format(prefix, prefix + 'f')
        
        return result
        
class Leaf:
    """ Terminal node for a tree (each leaf contains data)
    """

    def __init__(self, data):
        self.data = data
    
    def probabilities(self):
        """Give a list of probabilities per label and the total label count
        """
        val_counts = self.data.iloc[:,-1].value_counts()
        prob = val_counts / val_counts.sum()
        return prob, val_counts.sum()
    
    def __repr__(self):
        """ToString method for the tree (will generate dot-format string)
        """
        return 'digraph g {\n' + self.print_tree_dot() + '}'
    
    def print_tree_dot(self, prefix=''):
        """Helper method to print this tree object as a dot-string
        """
        prob, counts = self.probabilities()
        return '\tn{} [label="{}"];\n'.format(prefix, 'Counts: ' + str(counts) + '\\n' + str(prob.to_dict()))

The classes represent the two types of nodes: a normal node with two branches, and the leaves. The code might look complex because of the print functions. For this exercise only the constructors matter, since they show what data is stored in the tree.

Construct the tree using the following method and print the result. The tree is printed in Graphviz (dot) format, so you can paste the output into an online Graphviz visualizer to view it.

In [11]:
def construct_tree(data):
    """ This method will construct the decision tree
    Result will be a Tree object. A Tree object has a default print operation in the .dot format.
    This format can be visualized using Graphviz or one of the online graphviz websites
    """
    
    # Get all possible questions
    questions = all_possible_questions(data)
    
    # Get information gain and best question
    question, info_gain = find_best_next_question(questions, data)
    
    # When information gain = 0, no splits needed anymore
    if info_gain == 0:
        return Leaf(data)

    # Split and recursively repeat
    true_data, false_data = question.partition(data)
    
    true_branch = construct_tree(true_data)
    false_branch = construct_tree(false_data)

    # Return tree
    return Tree(question, true_branch, false_branch)

In [12]:
# Construct tree and print the graphviz result
tree = construct_tree(data_with_labels)
print(tree)

digraph g {
	n [label="Counts: 5\n{'car': 0.4, 'train': 0.4, 'motorbike': 0.2}"];
}


### 1.4 Classification
Build a classify method that processes one sample and predicts its label by following the questions in the tree until it reaches a leaf.

In [13]:
def classify(tree, sample):
    """
    This method should return the outcome labels and their probabilities for a given sample
    """
    if isinstance(tree, Leaf):
        # TODO: Your leaf code here...
        return 'some-label'
    else:
        # TODO: Your tree branch code here...
        return 'some-recursion'
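
A minimal sketch of how the recursion could look is shown below. It is only one possible approach: the classes are trimmed stand-ins for the `Question`, `Tree` and `Leaf` classes above (printing code removed), the tree is built by hand with a single `isYellow >= 1` split instead of through `construct_tree`, and returning the leaf probabilities as a dictionary is an assumption about the desired output format.

```python
import pandas as pd

class Question:
    def __init__(self, col_name, value):
        self.col_name, self.value = col_name, value

    def answer(self, sample):
        return sample[self.col_name] >= self.value

class Leaf:
    def __init__(self, data):
        self.data = data

    def probabilities(self):
        val_counts = self.data.iloc[:, -1].value_counts()
        return val_counts / val_counts.sum(), val_counts.sum()

class Tree:
    def __init__(self, question, true_node, false_node):
        self.question = question
        self.true_node, self.false_node = true_node, false_node

def classify_sketch(tree, sample):
    # A leaf ends the recursion: report its label probabilities
    if isinstance(tree, Leaf):
        probabilities, _ = tree.probabilities()
        return probabilities.to_dict()
    # Otherwise answer the node's question and follow the matching branch
    if tree.question.answer(sample):
        return classify_sketch(tree.true_node, sample)
    return classify_sketch(tree.false_node, sample)

data = pd.DataFrame(
    [[1, 0, 0, 5, 'car'],
     [0, 1, 0, 2, 'car'],
     [0, 1, 0, 2, 'motorbike'],
     [0, 0, 1, 25, 'train'],
     [0, 0, 1, 30, 'train']],
    columns=['isRed', 'isBlack', 'isYellow', 'seats', 'label'])

# Hand-built one-level tree, for illustration only
tree = Tree(Question('isYellow', 1),
            Leaf(data[data['isYellow'] >= 1]),
            Leaf(data[data['isYellow'] < 1]))

result = classify_sketch(tree, data.iloc[0, :-1])
print(result)
```

With this shallow tree the first row ends up in the leaf that holds two cars and one motorbike, so `car` is the most probable label. A fully grown tree would split that leaf further.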

In [14]:
# Take the first row out of the dataframe and remove the label (should classify as car)
sample = data_with_labels.iloc[0,:-1]

classify(tree, sample)

'some-label'