# Decision Tree Classifier (CART) Implementation

In [52]:
# Toy binary-classification dataset.
# Each row is [feature_0, feature_1, class_label]; the first five rows are
# class 0 and the last five are class 1.
dataset = [
    [2.771244718, 1.784783929, 0],
    [1.728571309, 1.169761413, 0],
    [3.678319846, 2.81281357, 0],
    [3.961043357, 2.61995032, 0],
    [2.999208922, 2.209014212, 0],
    [7.497545867, 3.162953546, 1],
    [9.00220326, 3.339047188, 1],
    [7.444542326, 0.476683375, 1],
    [10.12493903, 3.234550982, 1],
    [6.642287351, 3.319983761, 1],
]

In [53]:
def test_split(data, index, value):
    left, right = [], []
    
    for row in data:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
            
    return [left, right]

In [90]:
def gini(groups):
    """Return the size-weighted Gini impurity of a candidate split.

    For each non-empty group the impurity is ``1 - sum(p_k ** 2)``, where
    ``p_k`` is the proportion of rows in that group carrying class ``k``.
    The per-group impurities are combined as a weighted average, each group
    weighted by its share of the total row count. 0.0 means every group is
    pure (single class); larger values mean more mixing.

    Args:
        groups: Sequence of groups, typically ``[left, right]`` as returned
            by ``test_split``. Each group is a list of rows whose last
            element is the class label. Any number of groups and any
            hashable class labels are supported (not just 0 and 1).

    Returns:
        The weighted Gini impurity as a float. Empty groups contribute
        nothing; if every group is empty the result is 0.0.
    """
    from collections import Counter

    total = float(sum(len(group) for group in groups))
    if not total:
        # No rows at all: report zero impurity instead of dividing by zero.
        return 0.0

    score = 0.0
    for group in groups:
        size = len(group)
        if not size:
            continue  # an empty group has zero weight in the average
        # Count every label present rather than only 0/1, so multi-class
        # (or non-integer) labels are scored correctly.
        counts = Counter(row[-1] for row in group)
        purity = sum((count / size) ** 2 for count in counts.values())
        score += (1.0 - purity) * size / total
    return score

In [93]:
def best_split(train, lowest_gini=100.0, gini_score=0.0):
    """Find the single-feature threshold split with the lowest Gini impurity.

    Every value of every feature column (all columns except the last,
    which holds the class label) is tried as a threshold via
    ``test_split``, and each resulting split is scored with ``gini``. The
    first candidate with a strictly lower score than the best so far wins.

    Args:
        train: List of rows; the last element of each row is the class
            label.
        lowest_gini: Score a split must strictly beat to be selected. The
            default of 100.0 exceeds any Gini impurity, so with the default
            the best split over a non-empty dataset with at least one
            feature column is always found.
        gini_score: Not read; it is overwritten before use. Kept only so
            existing callers that pass it keep working.

    Returns:
        A dict with keys ``"index"`` (feature column), ``"val"`` (threshold
        value) and ``"sub-tree"`` (the ``[left, right]`` groups from
        ``test_split``). If no candidate beats ``lowest_gini`` (including
        when ``train`` is empty or has no feature columns), all three values
        are ``None`` instead of raising ``UnboundLocalError``.
    """
    best_index, best_value, best_group = None, None, None

    n_features = len(train[0]) - 1 if train else 0
    for col in range(n_features):
        # A repeated value produces the identical split and score, and can
        # never strictly beat the score its first occurrence already set,
        # so evaluating each distinct value once (in first-seen order)
        # selects the same winner with fewer Gini evaluations.
        for candidate in dict.fromkeys(row[col] for row in train):
            groups = test_split(train, col, candidate)
            gini_score = gini(groups)

            if gini_score < lowest_gini:
                lowest_gini = gini_score
                best_index, best_value, best_group = col, candidate, groups

    return {"index": best_index, "val": best_value, "sub-tree": best_group}

In [56]:
def build_tree():
    """Placeholder for recursive CART tree construction (not implemented).

    Currently performs no work and returns ``None``.
    """

In [57]:
def fit():
    """Placeholder for training the classifier (not implemented).

    Currently performs no work and returns ``None``.
    """

In [58]:
def predict():
    """Placeholder for predicting class labels (not implemented).

    Currently performs no work and returns ``None``.
    """

In [59]:
# Collect the class label (last element) of every row in the dataset.
p = [row[-1] for row in dataset]

In [60]:
# Display the collected class labels.
p

[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

In [61]:
# Number of rows labelled class 0.
p.count(0)

5

In [62]:
# Scratch check of conditional flow: the message prints only when x is not 9.
x = 9
if x != 9:
    print("got it")

In [86]:
# Split the dataset on feature column 0 at threshold 6.642287351.
gr = test_split(data=dataset, index=0, value=6.642287351)

In [87]:
# Weighted Gini impurity of the `gr` split (0.0 means both groups are pure).
gini(gr)

1.0 0.0 0.0 1.0


0.0

In [88]:
# Left group of `gr`: rows whose feature 0 is below 6.642287351.
gr[0]

[[2.771244718, 1.784783929, 0],
 [1.728571309, 1.169761413, 0],
 [3.678319846, 2.81281357, 0],
 [3.961043357, 2.61995032, 0],
 [2.999208922, 2.209014212, 0]]

In [89]:
# Right group of `gr`: rows whose feature 0 is not below 6.642287351.
gr[1]

[[7.497545867, 3.162953546, 1],
 [9.00220326, 3.339047188, 1],
 [7.444542326, 0.476683375, 1],
 [10.12493903, 3.234550982, 1],
 [6.642287351, 3.319983761, 1]]

In [94]:
# Search every feature/value threshold for the lowest-Gini split of the dataset.
best_split(dataset)

{'index': 0,
 'val': 6.642287351,
 'sub-tree': [[[2.771244718, 1.784783929, 0],
   [1.728571309, 1.169761413, 0],
   [3.678319846, 2.81281357, 0],
   [3.961043357, 2.61995032, 0],
   [2.999208922, 2.209014212, 0]],
  [[7.497545867, 3.162953546, 1],
   [9.00220326, 3.339047188, 1],
   [7.444542326, 0.476683375, 1],
   [10.12493903, 3.234550982, 1],
   [6.642287351, 3.319983761, 1]]]}