In [3]:
import numpy as np
import pandas as pd


In [122]:
class DecisionTreeNode:
    def __init__(self):
        
        self.cat_features = None

        self.left = None
        self.right = None
        self.isleaf = False
        
        self.split_val = None
        self.split_feature = None
        self.classification = None
        
        self.depth = 0

    def is_pure(self, y):
        # check whether the branch is pure() having same class )
        
        # compare all class label with the class of first row. 
#         print(y)
        for i in y:
            if i != y[0]:
                return False
        return True

    def find_split_val(self, X, y, cat_features):
        # Find best split value
        class_values = list(set(y))
        b_feature, b_value, b_score, b_groups = None, None, None, None
        for feature in range(len(X[0])):
            for row in X:                
                groups = self.split(feature, row[feature], cat_features, X, y)
                gini = self.gini_index([groups[0]['y'], groups[1]['y']], class_values)
#                 print('X%d < %.3f Gini=%.3f' % ((feature + 1), row[feature], gini))
                if b_score is None or gini < b_score:
                    b_index, b_value, b_score, b_groups = feature, row[feature], gini, groups
        return {'index': b_index, 'value': b_value, 'groups': b_groups, 'gini': b_score}

    def split(self, feature, val, cat_features, X, y):
        # split data according to split criteria
        left_X, left_y, right_X, right_y = list(), list(), list(), list()
        if feature in cat_features:
            # categorical feature
            # data with feature value equal to val goes to left ( feature == val )
            # the rest goes to right
            for idx, row in enumerate(X):
                if row[feature] == val:
                    left_X.append(row)
                    left_y.append(y[idx])
                else:
                    right_X.append(row) 
                    right_y.append(y[idx])
        #             idx = X[feature] == val
        #             left_X = X.loc[idx]
        #             left_y = y.loc[idx]
        #             idx = ~idx
        #             right_X = X.loc[idx] 
        #             right_y = y.loc[idx] 
        else:
            # numerical feature
            # data with feature value smaller than val goes to left ( feature < val )
            # the rest goes to right
            for idx, row in enumerate(X):
                if row[feature] < val:
                    left_X.append(row)
                    left_y.append(y[idx])
                else:
                    right_X.append(row) 
                    right_y.append(y[idx])


        return [{'X':left_X, 'y':left_y}, {'X': right_X, 'y': right_y}]

    def gini_index(self, groups, classes):
        # calculte Gini index of a split
        # sum up gini index i(s,t) of both children trees, ex. i_left + i_right

        n_instances = 0

        for gr in groups:
            n_instances += len(gr)

        # sum weighted Gini index for each group
        gini = 0.0
        for gr in groups:
            size = len(gr)
            # avoid divide by zero
            if size == 0:
                continue
            score = 0.0

            # score the group based on the score for each class
            for class_val in classes:
                p = 0.0
                for v in gr:
                    if v == class_val:
                        p += 1
                p = p/ size
                score += p * p
            # weight the group score by its relative size
            gini += (1.0 - score) * (size / n_instances) # *0.5
        return gini

    def grow(self, X, y, cat_features, max_depth, depth):
        
        if self.is_pure(y):
            self.terminate(y, depth)
            return
        else:
            
            best_split = self.find_split_val(X, y, cat_features)
            self.split_val = best_split['value']
            self.split_feature = best_split['index']
            [left, right] = best_split['groups']

#             print("{}X{} < {} gini={}".format(self.depth*' ',self.split_feature+1, self.split_val, best_split['gini']))
#             print("left - Class 0:{},class 1:{}".format(np.sum(left.iloc[:,-1]==0),np.sum(left.iloc[:,-1]==1)))
#             print("right - Class 0:{},class 1:{}".format(np.sum(right.iloc[:,-1]==0),np.sum(right.iloc[:,-1]==1)))
            
            self.left = DecisionTreeNode()
            self.right = DecisionTreeNode()
            self.left.grow(left['X'], left['y'], cat_features, max_depth, depth+1)
            self.right.grow(right['X'], right['y'], cat_features, max_depth, depth+1)

        self.depth = depth
        return
        
    def terminate(self, y, depth):
        # define leaf node
        # most frequent class in the data as class label of this node
        self.classification = max(set(y), key=y.count)
        self.isleaf = True
        self.depth = depth
    
    def train(self, X, y, cat_features, max_depth):
        # grow a tree
        X = X.values.tolist()
        y = y.values.tolist()
        self.grow(X, y, cat_features, max_depth, 1)

    def print_tree(self, cat_features):
        if not self.isleaf:
            if self.split_feature in cat_features:
                print("{}X{} = {} ".format(self.depth*' ',self.split_feature, self.split_val))
            else:
                print("{}X{} < {} ".format(self.depth*' ',self.split_feature, self.split_val))
            self.left.print_tree(cat_features)
            self.right.print_tree(cat_features)
        else:
            print("{}[{}]".format(self.depth*' ', self.classification))
        
    def predict_iterate(self, row, cat_features):
        
        if self.isleaf:
            # is leaf node
            return self.classification
        else:
            # not leaf node
            if self.split_feature in cat_features:
                # predict categorical feature 
                if row[self.split_feature] == self.split_val:
                    return self.left.predict_iterate(row, cat_features)
                else:
                    return self.right.predict_iterate(row,cat_features)
            else:
                # predict numerical feature 
                if row[self.split_feature] < self.split_val:
                    return self.left.predict_iterate(row,cat_features)
                else:
                    return self.right.predict_iterate(row,cat_features)
    
    def predict(self, X, cat_features):
        num_rows = X.shape[0]
        prediction = []
#         prediction = np.zeros((num_rows,1))
        for idx, row in X.iterrows():
            prediction.append(self.predict_iterate(row, cat_features))
            
        return np.array(prediction)

    def prune(self, tree):
        # prune here
        pruned_tree = tree
        return pruned_tree


In [123]:
data = pd.DataFrame([[23.771244718,1.784783929, 'a',0],
[66.728571309,55.169761413, 'a',0],
[3.678319846,3.81281357, 'a',0],
[3.961043357,0.61995032, 'a',0],
[2.999208922,2.209014212, 'b',0],
[7.497545867,3.162953546, 'b',1],
[9.00220326,3.339047188, 'b',1],
[7.444542326,0.476683375, 'b',1],
[10.12493903,3.234550982, 'b',1],
[6.642287351,3.319983761, 'b',1]])

dt = DecisionTreeNode()
dt.train(X=data.iloc[:,:-1],y=data.iloc[:,-1], cat_features=[2], max_depth=5)


In [124]:
dt.print_tree([2])

 X2 = a 
  [0]
  X0 < 6.642287351 
   [0]
   [1]


In [125]:
pred = dt.predict(data,[2])
pred

array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

# Test with Chronic Kidney Disease

In [31]:
df = pd.read_csv("./dataset/chronic_kidney_disease_full.csv", na_values=['?','\t?'])
print(df.shape)
df = df.drop(columns = ['id'])
df.head()

(400, 26)


Unnamed: 0,'age','bp','sg','al','su','rbc','pc','pcc','ba','bgr',...,'pcv','wbcc','rbcc','htn','dm','cad','appet','pe','ane','class'
0,48.0,80.0,1.02,1.0,0.0,,normal,notpresent,notpresent,121.0,...,44.0,7800.0,5.2,yes,yes,no,good,no,no,ckd
1,7.0,50.0,1.02,4.0,0.0,,normal,notpresent,notpresent,,...,38.0,6000.0,,no,no,no,good,no,no,ckd
2,62.0,80.0,1.01,2.0,3.0,normal,normal,notpresent,notpresent,423.0,...,31.0,7500.0,,no,yes,no,poor,no,yes,ckd
3,48.0,70.0,1.005,4.0,0.0,normal,abnormal,present,notpresent,117.0,...,32.0,6700.0,3.9,yes,no,no,poor,yes,yes,ckd
4,51.0,80.0,1.01,2.0,0.0,normal,normal,notpresent,notpresent,106.0,...,35.0,7300.0,4.6,no,no,no,good,no,no,ckd


In [32]:
# Fill NA values with most frequent values in the column
df = df.fillna(df.mode().iloc[0])
df.isnull().sum().sum()

0

In [33]:
new_col = []
for col in df.columns:
    col = col.replace("'",'')
    new_col.append(col)
df.columns = new_col

str_cols = []
str_cols_index = []
num_cols = []
i = 0
for col in df.drop(columns = 'class').columns:
    if df[col].dtype != np.int64 and df[col].dtype != np.float64:
        str_cols.append(col)
        str_cols_index.append(i)
    else:
        num_cols.append(col)
    i += 1
    
print("categorical columns: {}\n".format(str_cols))
print("categorical column index: {}\n".format(str_cols_index))
print("numerical columns: {}\n".format(num_cols))

categorical columns: ['rbc', 'pc', 'pcc', 'ba', 'htn', 'dm', 'cad', 'appet', 'pe', 'ane']

categorical column index: [5, 6, 7, 8, 18, 19, 20, 21, 22, 23]

numerical columns: ['age', 'bp', 'sg', 'al', 'su', 'bgr', 'bu', 'sc', 'sod', 'pot', 'hemo', 'pcv', 'wbcc', 'rbcc']



In [62]:
from sklearn.model_selection import train_test_split
X = df.drop(columns=['class'])
y = df['class']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)


In [76]:
X_train.shape[0]

268

In [127]:
dt = DecisionTreeNode()
%prun dt.train(X=X_train,y=y_train, cat_features=str_cols, max_depth=5)

 

In [128]:
pred = dt.predict(X_test,str_cols)

print(np.sum(pred == y_test)/y_test.shape[0])
dt.print_tree(str_cols)

0.9242424242424242
 X15 < 42.0 
  X14 < 15.5 
   X12 < 145.0 
    X10 < 16.0 
     X13 < 4.4 
      [ckd]
      [notckd]
     X12 < 139.0 
      [ckd]
      X14 < 15.0 
       [ckd]
       X0 < 73.0 
        [notckd]
        [ckd]
    X11 < 1.6 
     [notckd]
     [ckd]
   [notckd]
  X2 < 1.02 
   [ckd]
   X18 < yes 
    X5 < normal 
     [ckd]
     [notckd]
    [ckd]
