## Decision Tree Classifier Implementation from Scratch

In [1]:
import numpy as np
import pandas as pd
import operator

from sklearn.metrics import classification_report
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

In [2]:
ops = {
    '<=' : operator.le,
    '>' : operator.gt
}

In [3]:
iris_sklearn = load_iris(as_frame=True)
print(iris_sklearn.DESCR)

.. _iris_dataset:

Iris plants dataset
--------------------

**Data Set Characteristics:**

:Number of Instances: 150 (50 in each of three classes)
:Number of Attributes: 4 numeric, predictive attributes and the class
:Attribute Information:
    - sepal length in cm
    - sepal width in cm
    - petal length in cm
    - petal width in cm
    - class:
            - Iris-Setosa
            - Iris-Versicolour
            - Iris-Virginica

:Summary Statistics:

                Min  Max   Mean    SD   Class Correlation
sepal length:   4.3  7.9   5.84   0.83    0.7826
sepal width:    2.0  4.4   3.05   0.43   -0.4194
petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)

:Missing Attribute Values: None
:Class Distribution: 33.3% for each of 3 classes.
:Creator: R.A. Fisher
:Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
:Date: July, 1988

The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken
from Fis

In [4]:
iris: pd.DataFrame= iris_sklearn.frame
iris.head()

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
0,5.1,3.5,1.4,0.2,0
1,4.9,3.0,1.4,0.2,0
2,4.7,3.2,1.3,0.2,0
3,4.6,3.1,1.5,0.2,0
4,5.0,3.6,1.4,0.2,0


In [5]:
iris.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 150 entries, 0 to 149
Data columns (total 5 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   sepal length (cm)  150 non-null    float64
 1   sepal width (cm)   150 non-null    float64
 2   petal length (cm)  150 non-null    float64
 3   petal width (cm)   150 non-null    float64
 4   target             150 non-null    int64  
dtypes: float64(4), int64(1)
memory usage: 6.0 KB


In [6]:
iris.describe()

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
count,150.0,150.0,150.0,150.0,150.0
mean,5.843333,3.057333,3.758,1.199333,1.0
std,0.828066,0.435866,1.765298,0.762238,0.819232
min,4.3,2.0,1.0,0.1,0.0
25%,5.1,2.8,1.6,0.3,0.0
50%,5.8,3.0,4.35,1.3,1.0
75%,6.4,3.3,5.1,1.8,2.0
max,7.9,4.4,6.9,2.5,2.0


In [7]:
X: pd.DataFrame = iris.drop("target", axis=1)
y: pd.Series = iris["target"]

In [8]:
X: np.ndarray = X.to_numpy()
y: np.ndarray = y.to_numpy()

In [9]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)

In [10]:
class Gini:    
    def impurity(self, values: np.ndarray) -> float:
        gini = 1 - np.sum((values / np.sum(values))**2)
        return gini
    
    def loss(self, left_samples: int, left_gini: float, right_samples: int, right_gini: float) -> float:
        loss = (left_samples*left_gini + right_samples*right_gini) / (left_samples + right_samples)
        return loss

In [11]:
class Node:
    def __init__(self, gini: float = None, samples: int = None, values: np.ndarray = None, cls: int = None, feature: int = None, threshold: float = None):
        self.feature = feature
        self.threshold = threshold
        self.gini = gini
        self.samples = samples
        self.values = values
        self.cls = cls
               
        self.left: Node = None
        self.right: Node = None
        
    def __repr__(self) -> str:
        if self.feature is None:
            return f'Leaf: samples={self.samples}, value = {self.values}, class = {self.cls}'
        return f'X[{self.feature}] <= {self.threshold:.3f}, gini={self.gini:.3f}, samples={self.samples}, value = {self.values}, class = {self.cls}'

In [12]:
class DecisionTree:
    def __init__(self, max_depth: int = None, max_features: int = None):
        self.root: Node = Node()
        self.classes: np.ndarray = None
        self.gini = Gini()
        self.max_depth = max_depth
        self.max_features = max_features
        
    def create_best_split(self, parent_node: Node, X: np.ndarray, y: np.ndarray):
        best_gini = parent_node.gini
        n_features = X.shape[1]     
        features = np.random.choice(n_features, self.max_features, replace=False)
        n_classes = len(self.classes)
        
        for n in features:
            sorted_ids = np.argsort(X[:, n])
            thresholds = X[:, n][sorted_ids]
            labels = y[sorted_ids]
            
            left_values = np.zeros(n_classes)
            right_values = np.bincount(labels, minlength=n_classes)
            left_samples = 0
            right_samples = len(labels)
            
            for m in range(0, len(thresholds)-1):
                cls = labels[m]
                left_values[cls] += 1
                right_values[cls] -= 1
                left_samples += 1
                right_samples -= 1
            
                if thresholds[m+1] == thresholds[m]:
                    continue
                
                left_gini = self.gini.impurity(left_values)
                right_gini = self.gini.impurity(right_values)
                new_gini = self.gini.loss(left_samples, left_gini, right_samples, right_gini)
                
                if new_gini < best_gini:
                    best_gini = new_gini
                    parent_node.feature = n
                    parent_node.threshold = (thresholds[m] + thresholds[m+1]) / 2
                    parent_node.left = Node(left_gini, left_samples, left_values.copy(), self.classes[np.argmax(left_values)])
                    parent_node.right = Node(right_gini, right_samples, right_values.copy(), self.classes[np.argmax(right_values)])
                    
    def create_nodes(self, parent_node: Node, X: np.ndarray, y: np.ndarray, max_depth: int):
        if max_depth == 0 or len(np.unique(y)) == 1 or parent_node.samples <= 1:
            return
        
        self.create_best_split(parent_node, X, y)

        if parent_node.feature is None or parent_node.threshold is None:
            return
              
        X_left = X[X[:, parent_node.feature] <= parent_node.threshold]
        y_left = y[X[:, parent_node.feature] <= parent_node.threshold]
        
        X_right = X[X[:, parent_node.feature] > parent_node.threshold]
        y_right = y[X[:, parent_node.feature] > parent_node.threshold]
        
        self.create_nodes(parent_node.left, X_left, y_left, None if max_depth is None else max_depth - 1)
        self.create_nodes(parent_node.right, X_right, y_right, None if max_depth is None else max_depth - 1)
                  
    def build_tree(self, X: np.ndarray, y: np.ndarray):
        self.classes, values = np.unique(y, return_counts=True)
        samples = len(y)
        gini = self.gini.impurity(values)
        self.root = Node(gini, samples, values, self.classes[np.argmax(values)])
        self.create_nodes(self.root, X, y, self.max_depth)
    
    def check_class(self, x: np.ndarray) -> int:
        current_node = self.root

        while current_node.feature is not None:
            if x[current_node.feature] <= current_node.threshold:
                current_node = current_node.left
            else:
                current_node = current_node.right
   
        return current_node.cls
    
    def pre_order_traversal(self, node: Node, depth: int):
        if node is None:
            return
        indent = "  " * depth
        print(indent, node)
        self.pre_order_traversal(node.left, depth + 1)
        self.pre_order_traversal(node.right, depth + 1)    
        
    def print_tree(self):
        self.pre_order_traversal(self.root, 0) 

In [13]:
class DecisionTreeClassifier(DecisionTree):
    def __init__(self, max_depth: int = None, max_features = None):
        super().__init__(max_features=max_features)
        self.max_depth = max_depth
    
    def _resolve_max_features(self, X: np.ndarray) -> int:
        n_features = X.shape[1]

        if self.max_features is None:
            return n_features
        elif isinstance(self.max_features, str):
            if self.max_features == "sqrt":
                return max(1, int(np.sqrt(n_features)))
            elif self.max_features == "log2":
                return max(1, int(np.log2(n_features)))
            else:
                raise ValueError(f"Unknown max_features string: {self.max_features}")
        elif isinstance(self.max_features, int):
            return max(1, min(self.max_features, n_features))
        elif isinstance(self.max_features, float):
            if not (0.0 < self.max_features <= 1.0):
                raise ValueError("If max_features is float, it must be in [0,1].")
            return max(1, int(self.max_features * n_features))
        else:
            raise TypeError("max_features must be None, int, float, or str {'sqrt', 'log2'}.")
           
    def fit(self, X: np.ndarray, y: np.ndarray):
        self.max_features = self._resolve_max_features(X)
        self.build_tree(X, y)
    
    def predict(self, X: np.ndarray) -> np.ndarray:
        y_pred = np.array([self.check_class(x) for x in X])
        return y_pred

In [14]:
model = DecisionTreeClassifier(max_depth=3)

In [15]:
model.fit(X_train, y_train)

In [16]:
model.print_tree()

 X[2] <= 2.600, gini=0.665, samples=112, value = [34 39 39], class = 1
   Leaf: samples=34, value = [34.  0.  0.], class = 0
   X[3] <= 1.750, gini=0.500, samples=78, value = [ 0 39 39], class = 1
     X[2] <= 5.350, gini=0.136, samples=41, value = [ 0. 38.  3.], class = 1
       Leaf: samples=39, value = [ 0. 38.  1.], class = 1
       Leaf: samples=2, value = [0 0 2], class = 2
     X[2] <= 4.850, gini=0.053, samples=37, value = [ 0  1 36], class = 2
       Leaf: samples=2, value = [0. 1. 1.], class = 1
       Leaf: samples=35, value = [ 0  0 35], class = 2


In [17]:
y_pred = model.predict(X_test)

In [18]:
target_names = ['Setosa', 'Versicolor', 'Virginica']
print(classification_report(y_test, y_pred, target_names=target_names))

              precision    recall  f1-score   support

      Setosa       1.00      1.00      1.00        16
  Versicolor       0.79      1.00      0.88        11
   Virginica       1.00      0.73      0.84        11

    accuracy                           0.92        38
   macro avg       0.93      0.91      0.91        38
weighted avg       0.94      0.92      0.92        38

