In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
data = pd.read_csv('airfoil_noise_data.csv')
X = data.iloc[:, :-1].values
Y = data.iloc[:, -1].values.reshape(-1,1)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.2, random_state=41)

In [40]:
class Node():
    def __init__(self, feature_index=None, threshold=None, left=None, right=None, var_red=None, value=None):
        ''' constructor ''' 
        
        # for decision node
        self.feature_index = feature_index
        self.threshold = threshold
        self.left = left
        self.right = right
        self.var_red = var_red
        
        # for leaf node
        self.value = value
        
        
        

In [55]:
class DT_REGRESSOR():
    def __init__(self,min_sample_split,max_depth):
        
        self.min_sample_split = min_sample_split
        self.max_depth = max_depth
        
    def create_tree(self,X,y,depth=0):
        
        #keep growing tree until conditions are met
        #once conditions are met then return leaf value
       
        if depth <= self.max_depth and len(X) > self.min_sample_split:
            #first split
            best_split_ = self.best_split(X,y)
            if best_split_["var"]>0:
                left_tree = self.create_tree(best_split_['left_x'],best_split_['left_y'],depth+1)
                right_tree = self.create_tree(best_split_['left_x'],best_split_['left_y'],depth+1)

                return Node(best_split_['feature_index'],best_split_['threshold'],\
                            left_tree,right_tree,best_split_['var'])
        
        return Node(value = np.mean(y))
            
    def best_split(self,X,y):
        
        # maximize variance reduction
        base_var = np.NINF
        out = {}
        
        for col in range(np.shape(X)[1]):
            thresholds = sorted(list(set(X[:,col])))
            for t in thresholds:
                
                left_split_idx,right_split_idx = np.where(X[:,col] <= t)[0],np.where(X[:,col] > t)[0]
                
                if len(left_split_idx)>0 and len(right_split_idx)>0:
                    curr_var = self.var_(y,y[left_split_idx],y[right_split_idx])

                    if curr_var > base_var:
                        base_var = curr_var
                        out['left_x'] = X[left_split_idx,:]
                        out['left_y'] = y[left_split_idx]

                        out['right_x'] = X[right_split_idx,:]
                        out['right_y'] = y[right_split_idx]

                        out['threshold'] = t
                        out['feature_index'] = col
                        out['var'] = curr_var
        return out
                
    def var_(self,head_arr,left_arr,right_arr):
        
        return np.var(head_arr) - ((len(left_arr)/len(head_arr))*np.var(left_arr) + (len(right_arr)/len(head_arr))*np.var(right_arr))
        
   
        
    def print_tree(self, tree=None, indent=" "):
        ''' function to print the tree '''
        
        if not tree:
            tree = self.root

        if tree.value is not None:
            print(tree.value)

        else:
            print("X_"+str(tree.feature_index), "<=", tree.threshold, "?", tree.var_red)
            print("%sleft:" % (indent), end="")
            self.print_tree(tree.left, indent + indent)
            print("%sright:" % (indent), end="")
            self.print_tree(tree.right, indent + indent)
            
    def fit(self,X,y):
        
        self.root = self.create_tree(X,y)


In [56]:
test = DT_REGRESSOR(3,3)
test.fit(X_train,Y_train)
test.print_tree()

X_0 <= 3150.0 ? 7.132048702017748
 left:X_4 <= 0.033779199999999995 ? 3.5903305690676675
  left:X_3 <= 55.5 ? 1.1789899981318328
    left:X_4 <= 0.00251435 ? 1.614396721819876
        left:128.9919833333333
        right:128.9919833333333
    right:X_4 <= 0.00251435 ? 1.614396721819876
        left:128.9919833333333
        right:128.9919833333333
  right:X_3 <= 55.5 ? 1.1789899981318328
    left:X_4 <= 0.00251435 ? 1.614396721819876
        left:128.9919833333333
        right:128.9919833333333
    right:X_4 <= 0.00251435 ? 1.614396721819876
        left:128.9919833333333
        right:128.9919833333333
 right:X_4 <= 0.033779199999999995 ? 3.5903305690676675
  left:X_3 <= 55.5 ? 1.1789899981318328
    left:X_4 <= 0.00251435 ? 1.614396721819876
        left:128.9919833333333
        right:128.9919833333333
    right:X_4 <= 0.00251435 ? 1.614396721819876
        left:128.9919833333333
        right:128.9919833333333
  right:X_3 <= 55.5 ? 1.1789899981318328
    left:X_4 <= 0.00251435 ? 1