Decision Tree Constructor
--------------------------
This notebook works through the construction of the DecisionTree and node classes created to solve problem 1. (It will eventually be imported into its own python module)

In [16]:
import numpy as np

In [183]:
class DecisionTree:
    """Build and store a decision tree, based on supplied training data. 
       Use this tree to predict classifications."""

    def __init__(self,treedepth=5,params=None):
        self.depth = treedepth
        
    def entropy(self,C,D,c,d):
        """Calculate entropy based on classifications above and below the splitrule"""
        if C != 0:
            Cfactor = -(C/(C+D))*np.log2(C/(C+D))
        else:
            Cfactor = 0
        if D != 0:
            Dfactor = -(D/(C+D))*np.log2(D/(C+D))
        else:
            Dfactor = 0
        if c != 0:
            cfactor = -(c/(c+d))*np.log2(c/(c+d))
        else:
            cfactor = 0
        if d != 0:
            dfactor = -(d/(c+d))*np.log2(d/(c+d))
        else:
            dfactor = 0
        H_left = Cfactor + Dfactor
        H_right = cfactor + dfactor
        H = ((C+D)*H_left + (c+d)*H_right)/(C+D+c+d)
        
        return H
    
    
    def segment(self,data,labels):
        totals = np.bincount(labels)
        
        # Quick safety check
        if len(labels) != len(data):
            print('ERROR: There must be the same number of labels as datapoints.')
        
        # Calculate the initial entropy, used to find info gain
        C,D = 0,0                      # C = in class left of split; D = not in class left of split
        c,d = totals[1],totals[0]      # c = in class right of split; d = not in class right of split
        H_i = self.entropy(C,D,c,d) # the initial entropy, before any splitting
        maxinfogain = 0
        splitrule = ['','']   # the split rule, given as the best feature, followed by the threshold value (<=)
        
        for feature_i in range(len(data)):
            # Order the data for determining ideal splits
            lbldat = np.concatenate(([data[:,feature_i]],[labels]),axis=0)
            fv = np.sort(lbldat.T,axis=0)
            lastfeature = np.array(['',''])
            
            C,D = 0,0                      # Reset the counters
            c,d = totals[1],totals[0]
            
            for point_i in range(len(fv)-1):
                
                # Update C,D,c,d to minmize runtime of entrop calc (keep at O(1) time)
                if fv[point_i,1] == 1:
                    C += 1
                    c -= 1
                elif fv[point_i,1] == 0:
                    D += 1
                    d -= 1
                else:
                    print("ERROR: Classifications can only be 0 or 1.")
                
                # Skip splitting values that are not separable
                if fv[point_i,0] == fv[point_i+1,0]:
                    continue
                else:
                    H_f = self.entropy(C,D,c,d)
                    infogain = H_i-H_f
                    if infogain > maxinfogain:
                        maxinfogain = infogain
                        splitrule = [feature_i,fv[point_i,0]]
        return splitrule
            
#    def train(self,data,labels):
        
#    def predict(self,data):
        
#        return predictions
    
    
    class Node:
        """Store a decision tree node, coupled in series to construct tree;
        includes a left branch, right branch, and splitrule"""
    
        def __init__(self,leftsplit,rightsplit,
                     splitrule,leaflabel=None):
            self.left = leftsplit
            self.right = rightsplit
            if splitrule and not leaflabel:
                self.rule = splitrule
            elif not splitrule and leaflabel:
                self.label = leaflabel
            else:
                print("ERROR: You may not give both a split rule and a leaf label.")


            
        

In [184]:
classifier = DecisionTree()

In [185]:
classifier.segment(np.array([[2,2,2,5],[1,1,1,2],[2,1,5,3],[8,1,2,1]]),np.array([1,1,0,1]))


 [0, 1]
