In [1]:
from scipy import stats
import numpy as np

## Decision Tree

In [6]:
class decision_tree:
    """ID3-style decision tree on categorical features, split by Gini gain.

    The fitted tree is a nested dict ``{feature_index: {feature_value: subtree}}``
    whose leaves are class labels (a bare label if the root itself is pure).
    ``build_tree`` stores the root in ``self.tree``; ``predict`` walks it.
    """

    def __init__(self):
        # Root of the fitted tree (nested dict or a bare label); None until built.
        self.tree = None

    @staticmethod
    def _majority(target):
        """Return the most frequent label in ``target`` (smallest label on ties)."""
        labels, counts = np.unique(target, return_counts=True)
        # np.unique returns sorted labels, so argmax picks the smallest tied label.
        return labels[np.argmax(counts)]

    def gini(self, target):
        """Gini impurity ``1 - sum_k p_k**2`` of the labels in ``target``.

        Returns 0.0 for a pure or empty label array.
        """
        target = np.asarray(target)
        if target.size == 0:
            return 0.0
        _, counts = np.unique(target, return_counts=True)
        proportions = counts / counts.sum()
        return float(1.0 - np.sum(proportions ** 2))

    def find_impurities(self, data, feature, target):
        """Gini gain from splitting ``target`` on column ``feature`` of ``data``.

        gain = gini(target) - sum_v (n_v / n) * gini(target[rows where column == v])

        Despite the name, the return value is a gain (higher is better).
        """
        column = data[:, feature]
        values, counts = np.unique(column, return_counts=True)
        weights = counts / counts.sum()
        # Integer row indices (not a boolean mask) so ``target`` is indexed
        # positionally, exactly as the rows of ``data`` are selected.
        child_ginis = np.array(
            [self.gini(target[np.flatnonzero(column == value)]) for value in values]
        )
        return self.gini(target) - np.sum(weights * child_ginis)

    def build_tree(self, new_data, orig_data, features, new_target, orig_target, parent_class=None):
        """Grow the tree from ``new_data``/``new_target`` and store it in ``self.tree``.

        Parameters
        ----------
        new_data : 2-D array-like of categorical feature values (one row per sample).
        orig_data : unused; kept for backward compatibility.
        features : column indices that may be used for splitting.
        new_target : labels for the rows of ``new_data``.
        orig_target : unused; kept for backward compatibility.
        parent_class : label returned if ``new_data`` has no rows.

        Returns
        -------
        The tree: a nested dict, or a bare class label if all labels agree.

        Raises
        ------
        ValueError
            If ``new_data`` is empty and no ``parent_class`` is given.
        """
        tree = self._grow(np.asarray(new_data), list(features), np.asarray(new_target), parent_class)
        self.tree = tree
        return tree

    def _grow(self, data, features, target, parent_class):
        """Recursive worker for ``build_tree``; returns a subtree without touching ``self.tree``."""
        # An empty node can only be labelled from its parent.
        if len(data) == 0:
            if parent_class is None:
                raise ValueError('build_tree needs at least one row of data')
            return parent_class

        labels = np.unique(target)
        # Pure node: every remaining row carries the same label.
        if len(labels) <= 1:
            return labels[0]

        majority = self._majority(target)
        # No features left to split on but labels still disagree: vote.
        if not features:
            return majority

        gains = [self.find_impurities(data, f, target) for f in features]
        best_feature = features[int(np.argmax(gains))]
        # Each feature is used at most once along any root-to-leaf path.
        remaining = [f for f in features if f != best_feature]

        column = data[:, best_feature]
        branches = {}
        for value in np.unique(column):
            rows = np.flatnonzero(column == value)
            branches[value] = self._grow(data[rows], remaining, target[rows], majority)
        return {best_feature: branches}

    def predict(self, sample, fall_back):
        """Classify one sample by walking ``self.tree`` from the root.

        Parameters
        ----------
        sample : feature values for one sample, either 1-D or 2-D with the
            sample in ``sample[0]``.
        fall_back : value returned when the sample has a feature value that
            was never seen at that node during training.

        Raises
        ------
        RuntimeError
            If ``build_tree`` has not been called yet.
        """
        if self.tree is None:
            raise RuntimeError('build_tree must be called before predict')
        row = np.asarray(sample)
        if row.ndim > 1:
            row = row[0]

        node = self.tree
        while isinstance(node, dict):
            # Every internal node has exactly one key: the feature it splits on.
            feature, branches = next(iter(node.items()))
            value = row[feature]
            if value not in branches:
                return fall_back
            node = branches[value]
        return node

In [8]:
# Toy dataset: each row is one sample with four binary categorical features
# (columns 0-3); the trailing comment is that row's label in `target_data`.
example_data = np.array([[1, 1, 0, 1],   # 1
                         [1, 0, 1, 1],   # 0
                         [0, 1, 1, 1],   # 1
                         [0, 1, 1, 0],   # 0
                         [0, 1, 1, 0],   # 0
                         [1, 1, 1, 1],   # 1
                         [0, 1, 0, 0],   # 1
                         [0, 0, 1, 0]])  # 1

# Exactly one label per row of example_data.
target_data = np.array([1, 0, 1, 0, 0, 1, 1, 1])
assert len(target_data) == len(example_data), "need exactly one label per row"

d = decision_tree()
tree = d.build_tree(example_data, example_data, [0, 1, 2, 3], target_data, target_data)
tree

{2: {0: 1, 1: {3: {0: {1: {0: 1, 1: 0}}, 1: {1: {0: 0, 1: 1}}}}}}