In [42]:
from sklearn.datasets import make_classification

x, y = make_classification(n_classes=5, n_informative=7, n_features=10, n_redundant=2)

In [43]:
import numpy as np
from IPython.core.debugger import Tracer
from collections import namedtuple

SplitResult = namedtuple('SplitResult', 'left right')
Dataset = namedtuple('Dataset', 'x y')

def split_dataset(x, y, index, value):
    """Splits matrix in two parts based on row `value` at `index` parameter
    
    Returns
    -------
    split
        Matrix with the first row containing a subset of x and y where x[:, index] < value 
        and a complement of this subset in the second row"""

    index_set = np.argwhere(x[:, index] < value).flatten()
    index_set_compl = np.argwhere(x[:, index] >= value).flatten()
    return SplitResult(left = Dataset(x = x[index_set], y = y[index_set]), 
                       right = Dataset(x = x[index_set_compl], y = y[index_set_compl]))

assert len(split_dataset(x, y, 1, 1).left.x) + len(split_dataset(x, y, 1, 1).right.x) == len(x)

In [6]:
from collections import Counter

def gini_impurity(split, classes):
    """Calculates Gini impurify coefficient.
    
    Parameters
    ----------
    split
        intended to be passed as a result of the `split_dataset` function
    
    See Also
    --------
    split_dataset(x, y, index, value)
    https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity"""
    
    gini = 0
    
    for class_val in classes:
        for group in split:
            group_size = len(group.y)
            
            if group_size == 0:
                continue
            
            counts = Counter(group.y)
            proportion = counts[class_val] / group_size
            gini += proportion * (1.0 - proportion)
            
    return gini
            
assert gini_impurity([Dataset(x = [1, 1], y = [0, 1]), Dataset(x = [1, 1], y = [0, 1])], [0, 1]) == 1.0
assert gini_impurity([Dataset(x =[1, 1], y = [1, 1]), Dataset(x = [1, 1], y = [0, 0])], [0, 1]) == 0.0
gini_impurity([Dataset(x =[1, 1], y = [1, 1]), Dataset(x = [1, 1], y = [1, 1])], [0, 1])

0.0

In [83]:
from collections import namedtuple

BestSplit = namedtuple('BestSplit', 'data split_index split_value gini')

def unison_shuffled_copies(a, b):
    assert len(a) == len(b)
    p = np.random.permutation(len(a))
    return a[p], b[p]

def best_split(x, y, min_samples_split = 2):
    """Finds best split of a dataset based on gini index.
    
    Returns
    -------
    result : dict
        a dictionary containing best gini impurity and best split"""
    
    classes = np.unique(y)
    
    result = BestSplit(data = None, split_index = None, split_value = None, gini = 2)
    x_shuffled, y_shuffled = unison_shuffled_copies(x, y)
    
#     range(0, x.shape[1])
    for feature_index in np.random.choice(x_shuffled.shape[1], x_shuffled.shape[1], replace=False):
        for row in x_shuffled:
            split = split_dataset(x_shuffled, y_shuffled, feature_index, row[feature_index])
            gini = gini_impurity(split, classes)
            
            if gini <= result.gini and len(split.left.x) > min_samples_split and len(split.right.x) > min_samples_split:
                result = BestSplit(data = split,
                                    split_index = feature_index,
                                    split_value = row[feature_index],
                                    gini = gini)
    
    if not result.data:
        result = BestSplit(data = SplitResult(left = Dataset(x = x, y = y), right = Dataset(x = np.array([]), y = np.array([]))), split_index = None, split_value = None, gini = 2)
    
    return result

%time best_split(x,y, min_samples_split=5)

CPU times: user 184 ms, sys: 4 ms, total: 188 ms
Wall time: 188 ms


BestSplit(data=SplitResult(left=Dataset(x=array([[ 0.58759549, -6.94341498,  2.53882828,  3.60975326, -1.61388014,
         1.39415801,  3.53219557, -1.6982606 ,  2.13210981, -0.54818773],
       [ 0.29224537, -8.60359534,  0.50526225,  2.3466257 , -1.79625822,
         3.09083045,  1.53324366, -1.60230614,  3.62556566, -4.08966898],
       [ 0.69356403, -6.48001981,  2.7347695 ,  2.44066031, -0.0695815 ,
         2.75726073,  3.3482061 , -1.70481741,  2.20948264,  2.18424294],
       [ 0.08127334, -5.93375665, -0.80917876,  3.22089041, -0.06811336,
         1.50529453,  0.97895794, -2.58596952,  2.5378343 , -2.59100528],
       [-1.04981844, -7.34857147,  3.12623415, -1.1753704 , -0.53394615,
         4.95152394, -4.47126995, -2.65177403,  0.67338379, -2.78484908],
       [ 0.54254684, -5.8492297 ,  1.1050804 , -1.07280896,  1.86876675,
         3.47585174, -5.60668287, -3.01405845,  0.07874924, -4.22813783]]), y=array([2, 2, 0, 2, 2, 0])), right=Dataset(x=array([[ -2.54924486e+00,   

In [93]:
class Node(object):
    def __init__(self, parent, is_leaf, depth, split = None, left = None, right = None):
        self.split = split
        self.parent = parent
        self.is_leaf = is_leaf
        self.left = left
        self.right = right
        self.depth = depth
        
    def create_child(self, is_leaf, left, min_samples_split=2):
        """Create a child of this node
        
        Parameters
        ----------
        is_leaf : bool
            the child should be a leaf
        left : bool
            left or right child
        """
        result = Node(parent = self,
                      is_leaf = is_leaf,
                      depth = self.depth + 1)
        
        if not is_leaf:
            if left:
                result.split = best_split(self.split.data.left.x, self.split.data.left.y, min_samples_split=min_samples_split)
            elif not left:
                result.split = best_split(self.split.data.right.x, self.split.data.right.y, min_samples_split=min_samples_split)
            
#             print(result.split)
            
            # if the best split is a no-split then we have a terminal node
            if (result.split.data.left.x.size == 0) or (result.split.data.right.x.size == 0):
                result.is_leaf = True

        return result
    
    @property
    def is_leaf(self):
        return self.__is_leaf
    
    @is_leaf.setter
    def is_leaf(self, value):
        self.__is_leaf = value
        
        if value == True:
            self._finalize_leaf()
    
    def _finalize_leaf(self):
        data = np.concatenate([self.split.data.left.y, self.split.data.right.y], axis = 0).astype(int)
        self.leaf_value = np.bincount(data).argmax()

class ClassificationTree(object):
    def __init__(self, max_depth, min_samples_split):
        self.max_depth=max_depth
        self.min_samples_split=min_samples_split
        self.root_node = None
        
    def fit(self, x, y):
        node_stack = []
        
        # create root node and push it to the recusion stack
        # TODO no-split check
        self.root_node = Node(split = best_split(x, y, min_samples_split=self.min_samples_split),
                    parent = None, 
                    is_leaf = False, 
                    depth = 0)
        
        node_stack.append(self.root_node)
        
        while node_stack:
            node = node_stack.pop()
            
            # if node became a leaf as a result of a no-split (see Node.create_child)
            if node.is_leaf:
                continue
            
            # stop if we have reached maximum depth
            if node.depth >= self.max_depth:
                node.is_leaf = True
                continue
            
            # Create children. Child should be a leaf if min_node_size constraint is unfitfulled
            # TODO - BUG do not use is_leaf but call finalize, need to terminate CURRENT node
            if (len(node.split.data.left.x) >= self.min_samples_split):
                node.left = node.create_child(is_leaf=False, 
                                              left=True, min_samples_split=self.min_samples_split)
                node_stack.append(node.left)
            else:
                node.is_leaf = True
            
            if (len(node.split.data.right.x) >= self.min_samples_split):
                node.right = node.create_child(is_leaf=False, left=False, min_samples_split=self.min_samples_split)
                node_stack.append(node.right)
            else:
                node.is_leaf = True
            
            
            # TODO make BestSplit a class so we can delete redundant split data
            # del node.split
            
    def predict_vector(self, x):
        if not self.root_node:
            raise Exception('You should call fit(x, y) first')
            
        node_stack = [self.root_node]
        
        while node_stack:
            node = node_stack.pop()
            
            if node.is_leaf:
                return node.leaf_value
            
            if x[node.split.split_index] < node.split.split_value:
                node_stack.append(node.left)
            else:
                node_stack.append(node.right)
                
    def predict(self, x):
        results = []
        
        for row in x:
            results.append(self.predict_vector(row))
            
        return results
        
    def print_tree(self):
        node_stack = [self.root_node]
        
        while node_stack:
            node = node_stack.pop()
            
            if node.is_leaf:
                print("%s terminal - class %d" % ('-' * node.depth, node.leaf_value))
            else:
                print("%s feature[%d] > %f" % ('-' * node.depth, node.split.split_index, node.split.split_value))
                node_stack.append(node.left)
                node_stack.append(node.right)
                
            
tree = ClassificationTree(max_depth=20, min_samples_split=5)
tree.fit(x, y)
tree.print_tree()

0
1
2
2
3
3
4
5
5
6
7
8
8
7
8
9
10
11
12
12
13
13
11
10
9
8
6
4
1
 feature[1] > -5.456948
- feature[9] > 2.509265
-- terminal - class 3
-- feature[1] > 3.416294
--- terminal - class 1
--- feature[8] > -1.989855
---- feature[3] > 3.251977
----- terminal - class 0
----- feature[2] > -2.303831
------ feature[4] > 1.118338
------- feature[2] > -0.061067
-------- terminal - class 4
-------- terminal - class 2
------- feature[1] > -2.965528
-------- feature[8] > -0.521344
--------- feature[9] > -1.091357
---------- feature[2] > -1.055419
----------- feature[3] > 0.203131
------------ terminal - class 3
------------ feature[0] > -0.440908
------------- terminal - class 4
------------- terminal - class 3
----------- terminal - class 0
---------- terminal - class 1
--------- terminal - class 4
-------- terminal - class 0
------ terminal - class 1
---- terminal - class 0
- terminal - class 2


In [99]:
from sklearn.metrics import accuracy_score, classification_report
from sklearn.cross_validation import train_test_split

tree = ClassificationTree(max_depth=2000, min_samples_split=0)
tree.fit(x, y)
print(accuracy_score(y, tree.predict(x)))
print(classification_report(y, tree.predict(x)))

0
1
2
2
3
4
5
6
6
7
7
8
8
9
9
10
10
11
12
12
11
12
13
13
14
14
15
15
16
16
17
17
18
18
19
19
20
20
21
22
23
24
25
26
27
27
28
28
29
30
30
31
32
32
33
33
34
34
35
35
36
36
37
38
39
40
41
42
43
43
44
44
45
46
47
48
48
49
49
50
50
51
51
52
53
54
55
56
57
58
59
60
61
62
63
64
64
65
66
67
67
68
68
69
69
70
70
71
71
72
73
73
74
74
75
75
76
77
78
78
79
79
80
80
81
81
82
82
83
83
84
85
85
86
86
87
87
84
77
76
72
66
67
67
65
66
66
67
67
63
62
61
62
62
60
59
58
57
56
55
54
53
52
47
46
45
42
41
40
39
40
40
38
37
31
29
26
25
24
23
22
23
23
21
12
13
13
5
6
6
4
5
5
6
6
3
1
2
2
1.0
             precision    recall  f1-score   support

          0       1.00      1.00      1.00        21
          1       1.00      1.00      1.00        21
          2       1.00      1.00      1.00        20
          3       1.00      1.00      1.00        20
          4       1.00      1.00      1.00        18

avg / total       1.00      1.00      1.00       100



In [60]:
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier()
clf.fit(x, y)
print(accuracy_score(y, clf.predict(x)))
print(classification_report(y, clf.predict(x)))

1.0
             precision    recall  f1-score   support

          0       1.00      1.00      1.00        21
          1       1.00      1.00      1.00        21
          2       1.00      1.00      1.00        20
          3       1.00      1.00      1.00        20
          4       1.00      1.00      1.00        18

avg / total       1.00      1.00      1.00       100

