# A generic tree builder

In [52]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.datasets import load_iris

import warnings

In [3]:
%matplotlib inline


In [103]:
func_dict = {'sum':np.sum,
             'mean':np.mean,
             'median':np.median,
             'min':np.min,
             'max':np.max
            }

spacer_dict = {'spacer1': {'init':'|~~','next':'|  '},
               'spacer2': {'init':'-','next':'-'},
               'spacer3': {'init':'--','next':' -'},
               'spacer4': {'init':'|--','next':'|  '},
               'spacer5': {'init':'|--','next':'+  '}
              }

class TreeNode(object):

    def __init__(self, parent, depth):

        self.parent = parent
        self.depth = depth
        self.children = []
        self.value = []
        self.is_end = False
    
    def expand(self, x, evaluation_func):
        try:
            splits = evaluation_func(x)
        except Exception as e:
            print('Evaluation function is ill defined! Setting splits as None!')
            print('Caugth Exception:',e)
            splits = None
            
        if splits is None:
            self.is_end = True
            
        else:
            for split in splits:
                node = TreeNode(self,depth=self.depth+1)
                node.value = split
                self.children.append(node)            
    
    def expand_recursive(self, x, evaluation_func):
        self.expand(x, evaluation_func)
        if not self.is_leaf:
            for child in self.children:
                child.expand(x, evaluation_func)
    
    def collect(self, collector_func):
        values = 0
        for val in self.value:
            try:
                values += collector_func(val)
            except Exception as e:
                print('Collector function is ill defined! Setting value as 0!')
                values += 0
        return values
    
    def collect_recursive(self, collector_func):
        values = 0
        if not self.is_end:
            values = self.collect_recursive(collector_func)
        values += self.collect(collector_func)
        return values        
    
    def update(self, value, value_update_func):
        try:
            updates = value_update_func(value)
        except Exception as e:
            print('Value Update function is ill defined! Setting updates as None!')
            updates = None
        
        if updates is not None:
            self.value = updates

    def update_recursive(self, value, value_update_func):
        if self.parent:
            self.parent.update_recursive(value, value_update_func)
        self.update(value, value_update_fun)
    
    def predict_node(self, x, prediction_func):
        if self.is_leaf:
            return prediction_func(x)
        predictions = []
        for child in self.children:
            predictions.append(child.predict_node(x, prediction_func))
        
        
    def __repr__(self, spacer_type='spacer1'):
        if self.depth == 0:
            spacer = spacer_dict[spacer_type]['init']
        else:
            spacer = spacer_dict[spacer_type]['next'] * (self.depth) + spacer_dict[spacer_type]['init']

        s = spacer + 'Node at depth {} has {} children and Value: \n{}{}\n'.format((self.depth + 1), self.children, spacer,
                                                                                   self.value)
        if not self.is_leaf:
            for i in self.children:
                s += i.__repr__()
        else:
            s = spacer_dict[spacer_type]['next'] + '*'*(self.depth+1) + spacer_dict[spacer_type]['next']
            return s
        return s

    def is_leaf(self):
        return self.is_end

    def is_root(self):
        return self.parent is None
    
                
class Tree(object):
    
    def __init__(self, split_func=None, agg_func=None, na_func=None):
    
        if split_func is None:
            def random_50_split(x):
                if x.shape[0] < 5:
                    return None
                spl = np.random.choice(x.shape[0], int(np.ceil(x.shape[0]*0.5)))
                return x[spl,:],x[~spl,:]
            self.split_func = random_50_split
        else:
            self.split_func = func

        if agg_func is None:
            self.agg_func = lambda x: np.sum(x)
        else:
            self.agg_func = func
        
        self.tree = None
    
    def fit_data(self, data,  value_update_func = None):        
        if self.tree is None:
            tree = TreeNode(None,0)
        elif self.tree is not None:
            warnings.warn('Some data already fitted in the tree! Refitting new data!')
            tree = TreeNode(None,0)
        tree.expand_recursive(x = data, evaluation_func=self.split_func)
        self.tree = tree
        
    def predict(self, data):
        pass
    
    def print_stats(self):
        pass
    
#     def __repr__(self):
#         return self.tree
    
        

In [104]:
data = load_iris()
df  = pd.concat([pd.DataFrame(data.data),pd.DataFrame(data.target)],axis = 1)
data.feature_names.append('target')
df.columns = data.feature_names
print(df.head())
df.describe()


   sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  \
0                5.1               3.5                1.4               0.2   
1                4.9               3.0                1.4               0.2   
2                4.7               3.2                1.3               0.2   
3                4.6               3.1                1.5               0.2   
4                5.0               3.6                1.4               0.2   

   target  
0       0  
1       0  
2       0  
3       0  
4       0  


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
count,150.0,150.0,150.0,150.0,150.0
mean,5.843333,3.054,3.758667,1.198667,1.0
std,0.828066,0.433594,1.76442,0.763161,0.819232
min,4.3,2.0,1.0,0.1,0.0
25%,5.1,2.8,1.6,0.3,0.0
50%,5.8,3.0,4.35,1.3,1.0
75%,6.4,3.3,5.1,1.8,2.0
max,7.9,4.4,6.9,2.5,2.0


In [105]:
test_tree = Tree()

In [106]:
test_tree.fit_data(df.iloc[:,0:3].values)

In [113]:
test_tree.tree

|  *|  

In [108]:
test_tree

<__main__.Tree at 0x10455c4e0>

In [62]:
test_tree.fit_data(df.iloc[:,0:3].values)

