# CART from Scratch for the Iris Dataset
**[For theory, click here.](./1_scratch.ipynb)**

## Load data

In [1]:
from sklearn.datasets import load_iris

In [2]:
# Load the Iris dataset bundled with scikit-learn.
data = load_iris()

In [3]:
# List the fields available on the loaded dataset object.
data.keys()

dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])

In [4]:
# Feature matrix: one row per sample.
X = data.data

In [5]:
# Preview the first five feature rows.
X[:5,:]

array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2]])

In [6]:
# Class label for each sample.
y = data.target

In [7]:
# Preview the first five labels.
y[:5]

array([0, 0, 0, 0, 0])

In [8]:
from sklearn.model_selection import train_test_split

In [9]:
# Hold out 30% of the samples as a test set; the fixed random_state keeps
# the partition reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=101)

In [10]:
import numpy as np

In [11]:
# Append the labels as the last column so that every training row ends with
# its class value; the tree functions read the label as row[-1].
dataset = np.column_stack((X_train, y_train))

## Implement Tree

In [12]:
class CartNode:
    """Decision node of a binary classification tree.

    Attributes:
        index: Column index of the feature tested at this node.
        value: Threshold; rows with ``row[index] < value`` go to the left.
        groups: The ``(left, right)`` pair of row groups produced by the test.
        left: Left child; ``None`` until assigned.
        right: Right child; ``None`` until assigned.
    """

    def __init__(self, index, value, groups):
        self.index, self.value, self.groups = index, value, groups
        # Children start unset and are filled in after construction.
        self.left = self.right = None

In [13]:
def gini_index(groups, classes):
    """Return the size-weighted Gini impurity of a candidate split.

    Args:
        groups: Sequence of row groups (e.g. the ``(left, right)`` pair of a
            split). The last element of every row is its class label.
        classes: Class values to score against.

    Returns:
        The sum over non-empty groups of ``(1 - sum(p_c ** 2)) * size / total``,
        where ``p_c`` is the fraction of the group's rows labelled ``c``.
        Returns ``0`` when every group is empty.
    """
    n_instances = sum(len(group) for group in groups)
    gini = 0
    for group in groups:
        size = len(group)
        # Empty groups contribute nothing (and would divide by zero below).
        if size == 0:
            continue
        # Extract the labels once per group instead of once per class.
        labels = [row[-1] for row in group]
        score = 0
        for class_val in classes:
            p = labels.count(class_val) / size
            score += p * p
        # Weight the group's impurity by its share of all rows.
        gini += (1 - score) * size / n_instances
    return gini

In [14]:
def test_split(index, value, dataset):
    left, right = [], []
    for row in dataset:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    return left, right

In [15]:
def get_split(dataset, log=False):
    """Exhaustively search for the split of ``dataset`` with the lowest Gini index.

    Every value of every feature column (all columns except the last, which
    holds the class label) is tried as a threshold. The first candidate with
    the strictly lowest Gini index wins.

    Args:
        dataset: Non-empty sequence of rows; ``dataset[0]`` is used to
            determine the number of columns.
        log: When true, print each candidate and its Gini index.

    Returns:
        A ``CartNode`` carrying the winning feature index, threshold and
        ``(left, right)`` groups. If no candidate scores below the initial
        score of 1, the node holds index ``-1``, value ``0`` and groups ``None``.
    """
    labels = list({sample[-1] for sample in dataset})
    best_feature, best_threshold, best_gini, best_partition = -1, 0, 1, None
    n_features = len(dataset[0]) - 1
    for feature in range(n_features):
        for sample in dataset:
            threshold = sample[feature]
            partition = test_split(feature, threshold, dataset)
            impurity = gini_index(partition, labels)
            if log:
                print('X%d < %.3f Gini=%.3f' % (feature, threshold, impurity))
            # Strict comparison: on ties the earlier candidate is kept.
            if impurity < best_gini:
                best_feature, best_threshold = feature, threshold
                best_gini, best_partition = impurity, partition
    return CartNode(best_feature, best_threshold, best_partition)

In [16]:
def to_terminal(group):
    """Return the most frequent class label (last element of each row) in ``group``."""
    labels = [sample[-1] for sample in group]
    distinct = set(labels)
    return max(distinct, key=labels.count)

In [17]:
def split(node, max_depth, min_size, depth):
    """Recursively grow the subtree below ``node``, modifying it in place.

    ``node.groups`` supplies the ``(left, right)`` row groups. Each child slot
    ends up holding either a class label (leaf) or a further ``CartNode``.

    Args:
        node: Node whose ``left``/``right`` children are to be filled in.
        max_depth: Once ``depth`` reaches this value both children become leaves.
        min_size: A group with at most this many rows becomes a leaf.
        depth: Depth of ``node``.
    """
    left_rows, right_rows = node.groups

    # One side is empty, so the test separated nothing: both children become
    # the same leaf, the majority class of all the rows.
    if not (left_rows and right_rows):
        leaf = to_terminal(left_rows + right_rows)
        node.left = leaf
        node.right = leaf
        return

    # Depth limit reached: stop growing and turn both sides into leaves.
    if depth >= max_depth:
        left_leaf = to_terminal(left_rows)
        right_leaf = to_terminal(right_rows)
        node.left, node.right = left_leaf, right_leaf
        return

    # Left child: split again only if the group is larger than min_size.
    if len(left_rows) > min_size:
        node.left = get_split(left_rows)
        split(node.left, max_depth, min_size, depth + 1)
    else:
        node.left = to_terminal(left_rows)

    # Right child: same rule.
    if len(right_rows) > min_size:
        node.right = get_split(right_rows)
        split(node.right, max_depth, min_size, depth + 1)
    else:
        node.right = to_terminal(right_rows)

In [18]:
def build_tree(train, max_depth, min_size):
    """Build a classification tree from ``train`` and return its root node.

    Args:
        train: Training rows; the last element of each row is the class label.
        max_depth: Depth limit passed on to ``split``.
        min_size: Group-size limit passed on to ``split``.
    """
    tree_root = get_split(train)
    # The root node is at depth 1.
    split(tree_root, max_depth, min_size, depth=1)
    return tree_root

In [19]:
def print_tree(node, depth=0):
    """Print the tree rooted at ``node``, indenting one space per level.

    Decision nodes are shown as ``[X<index> < <value>]``; anything that is not
    a ``CartNode`` is treated as a leaf and shown as ``[<label>]``.
    """
    indent = depth * ' '
    if not isinstance(node, CartNode):
        print('%s[%s]' % (indent, node))
        return
    print('%s[X%d < %.3f]' % (indent, node.index, node.value))
    for child in (node.left, node.right):
        print_tree(child, depth + 1)

In [20]:
# Grow a tree on the training rows with max_depth=3 and min_size=1, then
# print its structure.
tree = build_tree(dataset, 3, 1)
print_tree(tree)

[X2 < 3.000]
 [X0 < 4.300]
  [0.0]
  [0.0]
 [X3 < 1.700]
  [X2 < 5.000]
   [1.0]
   [2.0]
  [X2 < 4.900]
   [2.0]
   [2.0]


## Prediction

In [21]:
def predict(node, row):
    """Return the class label the tree rooted at ``node`` assigns to ``row``.

    At each decision node the row goes left when ``row[node.index]`` is
    strictly less than ``node.value`` and right otherwise. The walk ends at
    the first child that is not a ``CartNode``; that child is the prediction.
    """
    current = node
    while True:
        if row[current.index] < current.value:
            branch = current.left
        else:
            branch = current.right
        if not isinstance(branch, CartNode):
            return branch
        current = branch

In [22]:
# Classify every test row by walking the fitted tree.
pred = [predict(tree, row) for row in X_test]

In [23]:
from sklearn.metrics import classification_report, confusion_matrix

In [24]:
# Compare the predictions with the true test labels: confusion matrix first,
# then the per-class precision / recall / F1 report.
print(confusion_matrix(y_test, pred))
print(classification_report(y_test, pred))

[[13  0  0]
 [ 0 19  1]
 [ 0  0 12]]
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        13
           1       1.00      0.95      0.97        20
           2       0.92      1.00      0.96        12

    accuracy                           0.98        45
   macro avg       0.97      0.98      0.98        45
weighted avg       0.98      0.98      0.98        45

