# In [1]:
from collections import Counter

import numpy as np
import pandas as pd

# In [134]:
class DecisionTree:
    """Binary-split classification tree using the Gini index.

    Internal nodes are dicts ``{'index': feature column, 'value': threshold,
    'left': subtree-or-label, 'right': subtree-or-label}``. A row goes left
    when ``row[index] < value`` and right otherwise. Leaves are plain class
    labels: the majority label of the training rows that reached them.

    Parameters
    ----------
    max_depth : int
        Depth at which growth stops; the root is depth 1.
    min_size : int, float or None
        A child holding ``<= min_size`` rows becomes a leaf. ``None`` means
        one tenth of the training-set size, recomputed on every ``fit``.
    """

    def __init__(self, max_depth=2, min_size=None):
        self.max_depth = max_depth
        self.min_size = min_size
        self.root = None

    def fit(self, dataset, label):
        """Grow the tree from feature frame ``dataset`` and targets ``label``.

        ``label`` must support ``.unique()`` (e.g. a pandas Series). It is
        appended as the last column of the internal training matrix.
        """
        new_dataset = dataset.copy()
        # NOTE(review): assignment goes through pandas, so ``label`` is
        # aligned on the index; if ``dataset`` already has a 'label' column
        # it is overwritten in place rather than appended last.
        new_dataset['label'] = label

        # ``DataFrame.as_matrix`` was removed in pandas 1.0; ``np.asarray``
        # yields the same 2-D array (object dtype for mixed columns).
        self.dataset = np.asarray(new_dataset)
        self.label = label.unique()

        # Resolve the default without clobbering the ``min_size``
        # hyper-parameter, so a later fit on other data recomputes it.
        if self.min_size is None:
            self._min_size = len(self.dataset) / 10
        else:
            self._min_size = self.min_size

        self.root = self._split_tree(self.dataset)
        self._split(self.root, 1)

    def predict(self, dataset):
        """Return a list with the predicted label for each row of ``dataset``.

        ``dataset`` must have the same feature columns, in the same order,
        as the frame given to ``fit``.

        Raises
        ------
        RuntimeError
            If the tree has not been fitted yet.
        """
        if self.root is None:
            # Raising a bare string is itself a TypeError in Python 3, so
            # wrap the message in a real exception.  ("not fitted yet")
            raise RuntimeError("Decison Tree belum di fit")

        rows = np.asarray(dataset)
        return [self._predict(self.root, row) for row in rows]

    def _predict(self, node, row):
        """Walk from ``node`` down to a leaf and return that leaf's label."""
        while isinstance(node, dict):
            branch = 'left' if row[node['index']] < node['value'] else 'right'
            node = node[branch]
        return node

    def evaluate(self, test_data):
        """Placeholder: not implemented, currently returns None."""
        pass

    def _calculate_gini_index(self, groups):
        """Return the size-weighted Gini impurity of a candidate split.

        ``groups`` is a sequence of row groups whose rows carry their class
        label in the last position. Empty groups contribute nothing.
        """
        instances = sum(len(group) for group in groups)
        gini = 0.0
        for group in groups:
            size = len(group)
            if size == 0:
                continue
            # Count labels once instead of rebuilding the list per class.
            counts = Counter(row[-1] for row in group)
            score = 0.0
            for class_val in self.label:
                p = counts[class_val] / size
                score += p * p
            gini += (1.0 - score) * (size / instances)
        return gini

    def _split_tree(self, dataset):
        """Find the lowest-Gini (feature, threshold) split of ``dataset``.

        Returns ``{'index', 'value', 'groups'}`` where ``groups`` is the
        ``(left, right)`` partition from ``_test_split``. On ties the first
        candidate examined wins.
        """
        b_index, b_value, b_score, b_groups = 999, 999, 999, None
        # The last column holds the label, so it is not a candidate feature.
        for index in range(len(dataset[0]) - 1):
            col_data = [row[index] for row in dataset]

            # A two-valued feature has one non-trivial threshold: its larger
            # value, which separates the two values.
            if len(set(col_data)) == 2:
                value = max(col_data)
                groups = self._test_split(index, value, dataset)
                gini = self._calculate_gini_index(groups)
                if gini < b_score:
                    b_index, b_value, b_score, b_groups = index, value, gini, groups
                continue

            # Otherwise try every observed value as a threshold.
            for row in dataset:
                groups = self._test_split(index, row[index], dataset)
                gini = self._calculate_gini_index(groups)
                if gini < b_score:
                    b_index, b_value, b_score, b_groups = index, row[index], gini, groups
        return {'index': b_index, 'value': b_value, 'groups': b_groups}

    def _test_split(self, index, value, dataset):
        """Partition rows into (``row[index] < value``, the rest)."""
        left, right = list(), list()
        for row in dataset:
            if row[index] < value:
                left.append(row)
            else:
                right.append(row)
        return left, right

    def _split(self, node, depth):
        """Recursively replace ``node['groups']`` with left/right children.

        A side becomes a leaf when the split is degenerate (one side empty),
        when ``depth`` has reached ``max_depth``, or when it holds at most
        the effective min_size rows; otherwise it is split one level deeper.
        """
        left, right = node.pop('groups')
        # Degenerate split: all rows fell on one side, so both branches get
        # the same majority label.
        if not left or not right:
            node['left'] = node['right'] = self._to_terminal(left + right)
            return
        if depth >= self.max_depth:
            node['left'], node['right'] = self._to_terminal(left), self._to_terminal(right)
            return
        # Left is fully processed before right, as before.
        for side, group in (('left', left), ('right', right)):
            if len(group) <= self._min_size:
                node[side] = self._to_terminal(group)
            else:
                node[side] = self._split_tree(group)
                self._split(node[side], depth + 1)

    def _to_terminal(self, group):
        """Return the most common label in ``group``.

        Ties go to the label seen first, which keeps results deterministic.
        """
        return Counter(row[-1] for row in group).most_common(1)[0][0]