# In[76]:
import pandas as pd
import numpy as np
from pprint import pprint

# In[3]:
class my_dt:
    """Minimal CART-style decision-tree classifier that splits on Gini impurity.

    The tree is built by ``treeAlgorithm`` from a ``pandas.DataFrame`` whose
    last column is the class label and whose other columns are numeric
    features (thresholds are arithmetic midpoints, so features must support
    ``+`` and ``/``).  The result is either a bare label (a leaf) or a nested
    dict ``{"<feature> <= <threshold>": [yes_subtree, no_subtree]}``.

    Parameters
    ----------
    max_depth : int
        Maximum number of splits along any root-to-leaf path.
    min_impurity_decrease : float
        A node is split only if the weighted impurity decrease
        ``N_t / N * (gini(node) - weighted_gini(children))`` is at least this
        value (``N_t`` = rows in the node, ``N`` = rows at the root), the same
        weighting scikit-learn uses.  The default ``0`` never blocks a split.
    min_samples : int
        Nodes with fewer rows than this become leaves.
    """

    # Tolerance when comparing an impurity decrease with min_impurity_decrease,
    # so float round-off never makes a zero-gain split look "negative".
    _IMPURITY_EPS = 1e-12

    def __init__(self, max_depth=8, min_impurity_decrease=0, min_samples=2):
        self.max_depth = int(max_depth)
        self.min_impurity_decrease = min_impurity_decrease
        self.min_samples = int(min_samples)
        # Row count of the dataset given to the root call of treeAlgorithm;
        # used to weight impurity decreases by node size.
        self._n_root_samples = None

    def check_purity(self, training_data):
        """Return True if all rows share one label (last column).

        An array with no rows is reported as impure (zero distinct labels).
        """
        return len(np.unique(training_data[:, -1])) == 1

    def classify_data(self, training_data):
        """Return the most frequent label (last column) in ``training_data``.

        Ties go to the label that sorts first: ``np.unique`` returns sorted
        labels and ``argmax`` picks the first maximum.

        Raises
        ------
        ValueError
            If ``training_data`` has no rows.
        """
        labels = training_data[:, -1]
        if labels.size == 0:
            raise ValueError("cannot classify an empty set of rows")
        unique_labels, counts = np.unique(labels, return_counts=True)
        return unique_labels[counts.argmax()]

    def calculate_gini(self, training_data):
        """Return the Gini impurity ``1 - sum(p_k ** 2)`` of the label column."""
        _, counts = np.unique(training_data[:, -1], return_counts=True)
        probabilities = counts / counts.sum()
        return 1 - sum(probabilities ** 2)

    def calculate_overall_gini(self, data_above, data_below):
        """Return the row-count-weighted average Gini of two partitions.

        The result is symmetric in its two arguments.
        """
        n = len(data_above) + len(data_below)
        return (len(data_above) / n * self.calculate_gini(data_above)
                + len(data_below) / n * self.calculate_gini(data_below))

    def get_potential_splits(self, data):
        """Return candidate thresholds for every feature column.

        For each column except the last (the label), the candidates are the
        midpoints between consecutive distinct sorted values; a column with a
        single distinct value gets an empty list.

        Returns
        -------
        dict
            Maps column index -> list of thresholds.
        """
        potential_splits = {}
        for column_index in range(data.shape[1] - 1):  # last column is the label
            unique_values = np.unique(data[:, column_index])
            midpoints = (unique_values[1:] + unique_values[:-1]) / 2
            potential_splits[column_index] = list(midpoints)
        return potential_splits

    def split_data(self, training_data, split_column, split_value):
        """Partition rows by ``row[split_column] <= split_value``.

        Returns
        -------
        tuple
            ``(data_below, data_above)``: rows satisfying the test, then the rest.
        """
        column_values = training_data[:, split_column]
        data_below = training_data[column_values <= split_value]
        data_above = training_data[column_values > split_value]
        return data_below, data_above

    def determine_best_split(self, training_data, potential_splits):
        """Return ``(column_index, threshold)`` minimising the weighted Gini.

        On ties the candidate examined last wins (``<=`` comparison).

        Raises
        ------
        ValueError
            If ``potential_splits`` contains no thresholds at all.
        """
        best_gini = np.inf
        best_split_column = best_split_value = None

        for column_index, values in potential_splits.items():
            for value in values:
                data_below, data_above = self.split_data(
                    training_data, split_column=column_index, split_value=value)
                current_overall_gini = self.calculate_overall_gini(data_above, data_below)

                if current_overall_gini <= best_gini:
                    best_gini = current_overall_gini
                    best_split_column = column_index
                    best_split_value = value

        if best_split_column is None:
            raise ValueError("potential_splits contains no candidate thresholds")
        return best_split_column, best_split_value

    def _weighted_impurity_decrease(self, data, data_above, data_below):
        """Return ``N_t / N * (gini(data) - weighted_gini(children))``.

        ``N`` is the root row count recorded by treeAlgorithm; if it was never
        recorded, the node's own size is used (i.e. weight 1).
        """
        n_node = len(data)
        n_root = getattr(self, "_n_root_samples", None) or n_node
        children_gini = self.calculate_overall_gini(data_above, data_below)
        return n_node / n_root * (self.calculate_gini(data) - children_gini)

    def treeAlgorithm(self, data_train, counter=0):
        """Recursively build the decision tree and return it.

        Parameters
        ----------
        data_train : pandas.DataFrame or numpy.ndarray
            At the root call (``counter == 0``) a DataFrame whose last column
            is the label; its ``.columns`` supply the feature names.  Recursive
            calls pass the NumPy rows of a sub-node.
        counter : int
            Current depth; callers normally leave it at 0.

        Returns
        -------
        label or dict
            A leaf label, or ``{"<feature> <= <threshold>": [yes, no]}``
            where ``yes`` is the subtree for rows satisfying the question.
            A split whose two branches come out identical is collapsed into
            that single branch.

        Notes
        -----
        The root call sets the module-level global ``COLUMN_HEADERS`` (kept for
        backward compatibility) and ``self._n_root_samples``.
        """
        # data preparations
        if counter == 0:
            global COLUMN_HEADERS
            COLUMN_HEADERS = data_train.columns
            data = data_train.values
            self._n_root_samples = len(data)
        else:
            data = data_train

        # base cases
        if (self.check_purity(data) or len(data) < self.min_samples
                or counter == self.max_depth):
            return self.classify_data(data)

        potential_splits = self.get_potential_splits(data)
        if not any(potential_splits.values()):
            # Every feature column is constant, so rows with differing labels
            # cannot be separated: fall back to a majority-vote leaf.
            return self.classify_data(data)

        split_column, split_value = self.determine_best_split(data, potential_splits)
        data_below, data_above = self.split_data(data, split_column, split_value)

        gain = self._weighted_impurity_decrease(data, data_above, data_below)
        if gain + self._IMPURITY_EPS < self.min_impurity_decrease:
            return self.classify_data(data)

        # recursive part
        counter += 1
        feature_name = COLUMN_HEADERS[split_column]
        question = "{} <= {}".format(feature_name, split_value)

        yes_answer = self.treeAlgorithm(data_below, counter)
        no_answer = self.treeAlgorithm(data_above, counter)

        # Identical answers (possible when min_samples / max_depth / the
        # impurity threshold forced leaves) make the question pointless.
        if yes_answer == no_answer:
            return yes_answer
        return {question: [yes_answer, no_answer]}