In [148]:
import numpy as np
from collections import Counter

data = np.loadtxt("diabetes.csv", delimiter=",")
data = data[:150]
X = data[:, :-1]
y = data[:, -1]

In [149]:
class Node:
    def __init__(self, ids=[], children=[], entropy=0, depth=0):
        self.ids = ids
        self.children = children
        self.entropy = entropy
        self.depth = depth
        
        self.split_attribute = 0
        self.threshold = 0
        self.label = 0
        
    def set_properties(self, split_attribute, threshold):
        self.split_attribute = split_attribute
        self.threshold = threshold
        
    def set_label(self, label):
        self.label = label

In [150]:
class Tree:
    def __init__(self, max_depth=10, min_samples_split=2, min_gain=1e-4):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_gain = min_gain
        
        self.root = None
        self.Ntrain = 0
        self.attributes = []
        self.labels = []
        self.X = None
        self.y = None
        
    def fit(self, X, y):
        self.X = X
        self.y = y
        self.attributes = list(range(X[0].shape[0]))
        self.labels = np.unique(y)
        self.Ntrain = len(X)
        
        ids = list(range(self.Ntrain))
        self.root = Node(ids=ids, entropy=self._entropy(ids), depth=0)
        
        queue = [self.root]
        while queue:
            node = queue.pop()
            if node.depth < self.max_depth or node.entropy < self.min_gain:
                node.children = self._split(node)
                if not node.children:
                    self._set_label(node)
                queue += node.children
            else:
                self._set_label(node)
        
    def _entropy(self, ids):
        if len(ids) == 0: return 0
        uni, cnt = np.unique(self.y[ids], return_counts="true")
        ratio = cnt/sum(cnt)
        return -np.sum(ratio*np.log(ratio))
    
    def _set_label(self, node):
        uni, cnt = np.unique(self.y[node.ids], return_counts="true")
        most_common = Counter(dict(zip(uni, cnt))).most_common()[0][0]
        node.set_label(most_common)
        
    def _split(self, node):
        ids = node.ids
        best_gain = 0
        best_splits = []
        best_attribute = 0
        best_threshold = 0
        node_X = self.X[ids]
        
        for att in self.attributes:
            best_HxS = 1000
            splits = []
            threshold = 0
            att_X = node_X[:, att]
            
            for val in np.unique(att_X):
                splits = []
                threshold = val
                splits.append([i for i, x in enumerate(att_X) if x < threshold])
                splits.append([i for i, x in enumerate(att_X) if x >= threshold])
                if min(map(len, splits)) < self.min_samples_split: continue
                HxS = 0
                for split in splits:
                    HxS += len(split)*self._entropy(split)/len(ids)
                if HxS < best_HxS: 
                    best_HxS = HxS
                    
            gain = node.entropy - best_HxS
            if gain < self.min_gain: continue
            if gain > best_gain:
                best_gain = gain
                best_splits = splits
                best_attribute = att
                best_threshold = threshold
        
        print(best_splits, best_gain, best_attribute, best_threshold)
        node.set_properties(best_attribute, best_threshold)
        return [Node(ids=split,
                     entropy=self._entropy(split), depth=node.depth+1) for split in best_splits]
                    
        
    def predict(self, X):
        n = len(X)
        predict = []
        
        for i in range(n):
            node = self.root
            while node.children:
                if X[i][node.split_attribute] <= node.threshold:
                    node = node.children[0]
                else:
                    node = node.children[1]
            predict.append(node.label)
        
        return predict
        
    
tree = Tree(max_depth=5)
tree.fit(X[:10], y[:10])
print(tree.predict(X[:5]))
print(y[:5])


[[0, 1, 2, 3, 4, 5, 6, 7, 9], [8]] 0.4228104552401626 1 197.0
[] 0 0 0
[[0, 1, 3, 4, 5, 6, 7, 8], [2]] 0.408960230187219 1 183.0
[] 0 0 0
[[0, 1, 2, 3, 4, 5, 7], [6]] 0.38039566584857787 0 10.0
[] 0 0 0
[[0, 1, 2, 3, 4, 5], [6]] 0.36157373634686696 0 10.0
[] 0 0 0
[[0, 1, 3, 4, 5], [2]] 0.6931471805599453 1 183.0
[] 0 0 0
[0.0, 0.0, 0.0, 0.0, 0.0]
[1. 0. 1. 0. 1.]
