# Classification Trees

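Gini impurity, where $\hat{p}_{mk}$ is the fraction of observations in region $\mathcal{R}_m$ that belong to class $k$: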

$$L(\mathcal{R}_m) = \sum_{k = 1}^K \hat{p}_{mk}(1-\hat{p}_{mk})$$

Cross entropy:

$$L(\mathcal{R}_m) = -\sum_{k = 1}^K \hat{p}_{mk}\log_2(\hat{p}_{mk}) $$

Choose the split that minimizes the size-weighted average loss of the two child regions $\mathcal{R}_{c1}$ and $\mathcal{R}_{c2}$:

$$
\frac{|\mathcal{R}_{c1}|L(\mathcal{R}_{c1}) + |\mathcal{R}_{c2}|L(\mathcal{R}_{c2})}{|\mathcal{R}_{c1}| + |\mathcal{R}_{c2}|}
$$

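Note: one shortcut is taken. `combinations` from `itertools` is used to enumerate the candidate groupings of a categorical predictor's values instead of writing that enumeration by hand.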

In [75]:
import numpy as np 
from itertools import combinations
import matplotlib.pyplot as plt
import seaborn as sns


In [76]:
# Load the penguins data, keep only the rows without missing values, and
# separate the species label from the remaining predictor columns.
penguins = sns.load_dataset('penguins').dropna()
y = np.array(penguins['species'])
X = np.array(penguins.drop(columns = 'species'))

In [77]:
def gini_impurity(y):
    """Return the Gini impurity of the labels in `y`.

    Computes sum_k p_k * (1 - p_k), where p_k is the share of entries of
    `y` equal to the k-th distinct label.
    """
    _, tallies = np.unique(y, return_counts = True)
    shares = tallies/len(y)
    return np.sum(shares*(1 - shares))
     
def cross_entropy(y):
    """Return the cross entropy (in bits) of the labels in `y`.

    Computes -sum_k p_k * log2(p_k), where p_k is the share of entries of
    `y` equal to the k-th distinct label.
    """
    _, tallies = np.unique(y, return_counts = True)
    shares = tallies/len(y)
    return -np.sum(shares*np.log2(shares))

def split_loss(child1, child2, loss = cross_entropy):
    """Return the size-weighted average of `loss` over the two children.

    `child1` and `child2` are the label collections of the two child nodes;
    `loss` is applied to each one separately.
    """
    n1, n2 = len(child1), len(child2)
    weighted_total = n1*loss(child1) + n2*loss(child2)
    return weighted_total/(n1 + n2)

In [79]:
def all_possible_splits(x):
    """Return every non-empty proper subset of `x` as a numpy array.

    Subsets are generated by counting from 1 to 2**len(x) - 2 and reading
    each number's binary digits (left-padded to len(x)) as a boolean mask
    over `x`. Note that a later definition with the same name in this file
    replaces this one.
    """
    n = len(x)
    masks = [
        [digit == '1' for digit in format(code, 'b').zfill(n)]
        for code in range(1, 2**n - 1)
    ]
    return [np.array(x)[mask] for mask in masks]

In [80]:
def all_rows_equal(X):
    """Return whether every row of the array `X` matches its first row."""
    first_row = X[0]
    matches_first = X == first_row
    return matches_first.all()

def all_possible_splits(x):
    """Return every non-empty proper subset of `x` as a list of lists.

    Subsets are ordered by size (1 up to len(x) - 1) and, within a size, in
    the order produced by `itertools.combinations`.
    """
    return [
        list(subset)
        for size in range(1, len(x))
        for subset in combinations(x, size)
    ]

# Quick check: list every non-empty proper subset of a three-element list.
all_possible_splits([1,3,4])

[[1], [3], [4], [1, 3], [1, 4], [3, 4]]

In [81]:
class Node:
    """A single tree node plus the training rows and labels that reached it.

    A node starts out as a leaf by default; `size` is the number of labels
    it holds.
    """

    def __init__(self, Xsub, ysub, ID, depth = 0, parent_ID = None, leaf = True):
        # Where the node sits in the tree
        self.ID = ID
        self.parent_ID = parent_ID
        self.depth = depth
        self.leaf = leaf
        # The data held by the node
        self.Xsub = Xsub
        self.ysub = ysub
        self.size = len(ysub)
        

class Splitter:
    """Holds the best split seen so far during one search.

    A fresh instance has infinite loss and every split field set to None,
    so it can be inspected safely even if no candidate split is ever
    recorded.
    """

    def __init__(self):
        self.loss = np.inf
        # Split description; stays None until replace_split() is called.
        self.parent_ID = None
        self.d = None
        self.dtype = None
        self.t = None
        self.L_values = None

    def replace_split(self, loss, parent_ID, d, dtype = 'quant', t = None, L_values = None):
        """Overwrite the stored split.

        loss: loss value of the split being recorded.
        parent_ID: ID of the node being split.
        d: index of the predictor used.
        dtype: 'quant' for a threshold split, 'cat' for a value-set split.
        t: threshold, for 'quant' splits.
        L_values: values sent to the left child, for 'cat' splits.
        """
        self.loss = loss
        self.parent_ID = parent_ID
        self.d = d
        self.dtype = dtype
        self.t = t
        self.L_values = L_values

In [82]:
class DecisionTreeClassifier:
    """Classification tree grown greedily, one split per round.

    Each round scans every eligible leaf, every predictor and every
    candidate split, and applies the single split with the lowest weighted
    child loss. Quantitative predictors are split on a threshold
    (``x <= t`` goes left); categorical predictors are split on membership
    in a set of values. A prediction is the most frequent training label
    in the leaf an observation reaches.
    """

    #############################
    ######## 1. TRAINING ########
    #############################

    ######### FIT ##########
    def fit(self, X, y, loss_func = cross_entropy, max_depth = 100, min_size = 2):
        """Grow the tree on the training data.

        X: 2-D numpy array of predictors (rows are observations).
        y: numpy array of labels, one per row of X.
        loss_func: function mapping an array of labels to a loss value.
        max_depth: nodes at this depth or deeper are never split.
        min_size: nodes with fewer observations than this are never split.
        """

        ## Add data
        self.X = X
        self.y = y
        self.N, self.D = self.X.shape
        # Infer each column's type from its values (X itself may be a single
        # object array). Any integer or floating dtype counts as quantitative,
        # not only the platform defaults that compare equal to int/float.
        self.dtypes = []
        for d in range(self.D):
            col_dtype = np.array(list(self.X[:,d])).dtype
            is_numeric = (np.issubdtype(col_dtype, np.integer) or
                          np.issubdtype(col_dtype, np.floating))
            self.dtypes.append('quant' if is_numeric else 'cat')

        ## Add model parameters
        self.loss_func = loss_func
        self.max_depth = max_depth
        self.min_size = min_size

        ## Initialize nodes
        self.nodes_dict = {}
        self.current_ID = 0
        initial_node = Node(Xsub = X, ysub = y, ID = self.current_ID, parent_ID = None)
        self.nodes_dict[self.current_ID] = initial_node
        self.current_ID += 1

        # Build
        self.build()

        # Calculate leaf modes
        self.get_leaf_modes()


    ###### FIND SPLIT ######
    def find_split(self, eligible_parents):
        """Search `eligible_parents` ({ID: Node}) for the lowest-loss split.

        The result is stored in ``self.splitter``; its loss stays infinite
        if no candidate split exists.
        """

        ## Instantiate splitter
        splitter = Splitter()

        ## For each eligible parent node...
        for parent_ID, parent in eligible_parents.items():
            ysub = parent.ysub

            ## For each predictor...
            for d in range(self.D):
                Xsub_d = parent.Xsub[:,d]
                values = np.unique(Xsub_d)
                # A predictor with a single distinct value cannot be split on
                if len(values) < 2:
                    continue

                ## For each value...
                if self.dtypes[d] == 'quant':
                    # The largest value is skipped: it would leave the right child empty
                    for t in values[:-1]:
                        # Right child is the complement of the left mask,
                        # exactly as make_split partitions the rows.
                        goes_L = Xsub_d <= t
                        loss = split_loss(ysub[goes_L], ysub[~goes_L], loss = self.loss_func)
                        if loss < splitter.loss:
                            splitter.replace_split(loss, parent_ID, d, 'quant', t = t)
                else:
                    for L_values in all_possible_splits(values):
                        goes_L = np.isin(Xsub_d, L_values)
                        loss = split_loss(ysub[goes_L], ysub[~goes_L], loss = self.loss_func)
                        if loss < splitter.loss:
                            splitter.replace_split(loss, parent_ID, d, 'cat', L_values = L_values)
        ## Save splitter
        self.splitter = splitter

    ###### MAKE SPLIT ######
    def make_split(self):
        """Apply the split held in ``self.splitter``, adding two child nodes."""
        ## Update parent nodes
        parent_node = self.nodes_dict[self.splitter.parent_ID]
        parent_node.leaf = False
        parent_node.child_L = self.current_ID
        parent_node.child_R = self.current_ID + 1
        parent_node.d = self.splitter.d
        parent_node.dtype = self.splitter.dtype
        parent_node.t = self.splitter.t
        parent_node.L_values = self.splitter.L_values

        ## Get X and y data for children
        if parent_node.dtype == 'quant':
            L_condition = parent_node.Xsub[:,parent_node.d] <= parent_node.t
        else:
            L_condition = np.isin(parent_node.Xsub[:,parent_node.d], parent_node.L_values)
        Xchild_L = parent_node.Xsub[L_condition]
        ychild_L = parent_node.ysub[L_condition]
        Xchild_R = parent_node.Xsub[~L_condition]
        ychild_R = parent_node.ysub[~L_condition]

        ## Create child nodes
        child_node_L = Node(Xchild_L, ychild_L, depth = parent_node.depth + 1,
                            ID = self.current_ID, parent_ID = parent_node.ID)
        child_node_R = Node(Xchild_R, ychild_R, depth = parent_node.depth + 1,
                            ID = self.current_ID+1, parent_ID = parent_node.ID)
        self.nodes_dict[self.current_ID] = child_node_L
        self.nodes_dict[self.current_ID + 1] = child_node_R
        self.current_ID += 2

    ###### ELIGIBILITY #####
    def _is_eligible(self, node):
        """Return True if `node` may still be split.

        Uses `and`/`not` rather than bitwise `&`/`~` so the result does not
        depend on all_rows_equal returning a numpy boolean.
        """
        return bool(node.leaf
                    and node.depth < self.max_depth
                    and node.size >= self.min_size
                    and not all_rows_equal(node.Xsub))

    ###### BUILD TREE ######
    def build(self):
        """Split repeatedly until no leaf is eligible or no split can be found.

        Eligibility is checked before every split, including the first, so
        the root also respects max_depth and min_size and a root whose rows
        are all identical is left as a single leaf.
        """

        while True:

            ## Find eligible nodes for this iteration
            eligible_parents = {ID:node for (ID, node) in self.nodes_dict.items()
                                if self._is_eligible(node)}

            ## Quit if no more eligible parents
            if not eligible_parents:
                break

            ## Find split among eligible parent nodes
            self.find_split(eligible_parents)

            ## Quit if the search recorded no split (loss is still infinite)
            if not np.isfinite(self.splitter.loss):
                break

            ## Make split
            self.make_split()


    ###### LEAF MODES ######
    def get_leaf_modes(self):
        """Store the most frequent label of each leaf in ``self.leaf_modes``."""
        self.leaf_modes = {}
        for node_ID, node in self.nodes_dict.items():
            if node.leaf:
                values, counts = np.unique(node.ysub, return_counts=True)
                self.leaf_modes[node_ID] = values[np.argmax(counts)]

    #############################
    ####### 2. PREDICTING #######
    #############################

    ####### PREDICT ########
    def predict(self, X_test):
        """Return an array with the predicted label for each row of `X_test`.

        Each row is routed from the root (node 0) down to a leaf, and the
        leaf's most frequent training label is the prediction. For a
        categorical split, a value not in the left set goes right.
        """
        yhat = []
        for x in X_test:
            node = self.nodes_dict[0]
            while not node.leaf:
                if node.dtype == 'quant':
                    if x[node.d] <= node.t:
                        node = self.nodes_dict[node.child_L]
                    else:
                        node = self.nodes_dict[node.child_R]
                else:
                    if x[node.d] in node.L_values:
                        node = self.nodes_dict[node.child_L]
                    else:
                        node = self.nodes_dict[node.child_R]
            yhat.append(self.leaf_modes[node.ID])
        return np.array(yhat)
            


In [83]:
# Hold out a random quarter of the rows for testing; train on the rest.
test_frac = 0.25
test_size = int(len(y)*test_frac)
test_idxs = np.random.choice(np.arange(len(y)), test_size, replace = False)
# `~` on an integer index array is bitwise NOT (i -> -i-1), not a set
# complement, so the training rows are taken by deleting the test rows.
X_train = np.delete(X, test_idxs, axis = 0)
y_train = np.delete(y, test_idxs)
X_test = X[test_idxs]
y_test = y[test_idxs]


In [84]:
# Grow a tree on the training rows, with max_depth and min_size limiting
# which nodes may be split further, then predict the held-out rows.
tree = DecisionTreeClassifier()
tree.fit(X_train, y_train, max_depth = 10, min_size = 10)
y_test_hat = tree.predict(X_test)

In [85]:
# Test accuracy: share of held-out rows whose predicted label equals the true label.
np.mean(y_test_hat == y_test)

0.963855421686747