In [1]:
import numpy as np
import matplotlib.pyplot as plt

## Decision Tree class

In [2]:
class TreeNode:
    """One node of a binary decision tree.

    Every field starts out as ``None`` and is filled in by the code that
    builds the tree: a splitting node gets ``feature``, ``threshold`` and both
    children, while a leaf only gets ``value``.
    """

    def __init__(self):
        # Split definition: column index and the cut-off it is compared with.
        self.feature = self.threshold = None
        # Prediction stored in the node; stays None unless the node is a leaf.
        self.value = None
        # Child subtrees.
        self.left = self.right = None
    

In [3]:
class DecisionTree:
    """Binary decision-tree classifier that splits on information gain (entropy).

    Every internal node tests ``row[feature] < threshold``: rows that satisfy
    the test go to the left child, all other rows go to the right child.
    Leaves store the majority class of the training rows that reached them.

    Parameters
    ----------
    max_depth : int or None, default=None
        Maximum number of splits on any root-to-leaf path. ``None`` keeps
        splitting until a node is pure or no threshold separates its rows.

    Attributes
    ----------
    root : TreeNode or None
        Root of the fitted tree; ``None`` until ``fit`` has been called.
    """

    def __init__(self, max_depth=None):
        self.root = None
        self.max_depth = max_depth

    def fit(self, X, y):
        """Build the tree from training data and store it in ``self.root``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Feature matrix.
        y : array-like of shape (n_samples,)
            Class labels; any labels that ``np.unique`` can sort.

        Raises
        ------
        ValueError
            If ``y`` is empty.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if y.size == 0:
            raise ValueError("Cannot fit a decision tree on an empty training set.")
        self.root = self._build_tree(X, y)

    def predict(self, X):
        """Predict a class label for every row of ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        numpy.ndarray
            One predicted label per row.

        Raises
        ------
        AttributeError
            If the tree has not been fitted.
        """
        if self.root is None:
            # AttributeError matches what dereferencing the missing root would
            # raise, so existing handlers keep working; only the message is new.
            raise AttributeError("DecisionTree is not fitted yet; call fit() first.")

        X = np.asarray(X)
        predictions = []
        for row in X:
            node = self.root
            # Only leaves carry a value; walk down until one is reached.
            while node.value is None:
                if row[node.feature] < node.threshold:
                    node = node.left
                else:
                    node = node.right
            predictions.append(node.value)
        return np.asarray(predictions)

    @staticmethod
    def _entropy(y):
        """Return the Shannon entropy, in bits, of the labels in ``y``.

        Best case (0): all labels belong to one class. Worst case for two
        classes (1): the labels are split 50/50. An empty ``y`` yields 0.
        """
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / len(y)
        # Every class reported by np.unique occurs at least once, so no
        # probability is zero and the logarithm needs no epsilon guard.
        # Writing the term as p * log2(1 / p) keeps it non-negative, so a pure
        # node evaluates to exactly 0.
        return np.sum(probabilities * np.log2(1.0 / probabilities))

    @classmethod
    def _information_gain(cls, y, y_left, y_right, parent_entropy=None):
        """Return the entropy reduction achieved by splitting ``y`` in two.

        Parameters
        ----------
        y : array
            Labels before the split; must not be empty.
        y_left, y_right : array
            Labels of the two partitions.
        parent_entropy : float, optional
            Precomputed entropy of ``y``; computed here when omitted.

        Returns
        -------
        float
            ``entropy(y)`` minus the size-weighted entropy of both partitions.
        """
        if parent_entropy is None:
            parent_entropy = cls._entropy(y)

        left_weight = len(y_left) / len(y)
        right_weight = len(y_right) / len(y)

        weighted_entropy = (left_weight * cls._entropy(y_left)
                            + right_weight * cls._entropy(y_right))
        return parent_entropy - weighted_entropy

    def _best_split(self, X, y):
        """Find the ``feature < threshold`` test with the highest information gain.

        Candidate thresholds are the distinct values observed in each column.

        Returns
        -------
        dict
            Always contains ``'max_gain'``. When no usable split exists (fewer
            than two classes, or no threshold separates the rows) it is ``-1``
            and no other key is present. Otherwise the dict also holds
            ``'feature'``, ``'threshold'``, ``'X_left'``, ``'y_left'``,
            ``'X_right'`` and ``'y_right'`` for the winning split.
        """
        best_split = {'max_gain': -1}

        # A node with fewer than two classes cannot be improved by splitting.
        if len(np.unique(y)) < 2:
            return best_split

        # The parent entropy is the same for every candidate; compute it once.
        parent_entropy = self._entropy(y)

        for feature in range(X.shape[1]):
            feature_values = X[:, feature]

            # The smallest distinct value is skipped: nothing is strictly
            # below it, so its left partition would always be empty.
            for threshold in np.unique(feature_values)[1:]:
                mask = feature_values < threshold
                y_left, y_right = y[mask], y[~mask]

                # Still required for NaN thresholds, where every comparison is
                # False and the left partition ends up empty.
                if not len(y_left) or not len(y_right):
                    continue

                information_gain = self._information_gain(
                    y, y_left, y_right, parent_entropy)
                # Strict '>' keeps the first candidate among equal gains.
                if information_gain > best_split['max_gain']:
                    best_split['max_gain'] = information_gain
                    best_split['feature'] = feature
                    best_split['threshold'] = threshold

        # Partition X only once, for the winning split, instead of copying it
        # for every candidate threshold.
        if 'feature' in best_split:
            mask = X[:, best_split['feature']] < best_split['threshold']
            best_split['X_left'], best_split['X_right'] = X[mask], X[~mask]
            best_split['y_left'], best_split['y_right'] = y[mask], y[~mask]
        return best_split

    def _build_tree(self, X, y, depth=0):
        """Recursively grow the subtree for the rows ``X`` with labels ``y``.

        Parameters
        ----------
        X, y : numpy.ndarray
            Rows and labels that reached this node.
        depth : int, default=0
            Number of splits between the root and this node.

        Returns
        -------
        TreeNode
            A splitting node with both children attached, or a leaf whose
            ``value`` is the most frequent label in ``y``.
        """
        node = TreeNode()

        may_split = self.max_depth is None or depth < self.max_depth
        best_split = self._best_split(X, y) if may_split else None

        if best_split is not None and best_split['max_gain'] != -1:
            node.feature = best_split['feature']
            node.threshold = best_split['threshold']
            node.left = self._build_tree(
                best_split['X_left'], best_split['y_left'], depth + 1)
            node.right = self._build_tree(
                best_split['X_right'], best_split['y_right'], depth + 1)
        else:
            # Majority vote via np.unique works for any sortable label type;
            # ties go to the smallest label because np.unique returns sorted
            # labels and argmax picks the first maximum.
            labels, counts = np.unique(y, return_counts=True)
            node.value = labels[counts.argmax()]

        return node
 

## Test

In [4]:
from sklearn.datasets import make_classification

# Synthetic 3-class dataset: 1,000 rows, 5 features (3 informative, none
# redundant), one cluster per class, fixed seed for reproducibility.
X, y = make_classification(
    n_samples=1_000,
    n_features=5,
    n_informative=3,
    n_redundant=0,
    n_classes=3,
    n_clusters_per_class=1,
    random_state=42,
)

In [5]:
# Hold out the last 20% of the rows as the test set.
test_split = int(X.shape[0] * 0.2)

# Slice from a non-negative boundary. With negative indices a test_split of 0
# would turn into X[:-0] (empty training set) and X[-0:] (every row in test).
train_end = X.shape[0] - test_split

X_train, X_test = X[:train_end], X[train_end:]
y_train, y_test = y[:train_end], y[train_end:]

print(X_train.shape, X_test.shape)

(800, 5) (200, 5)


In [6]:
# Fit the from-scratch decision tree on the training split.
dt_clf = DecisionTree()
dt_clf.fit(X_train, y_train)

In [7]:
# Predict a class label for every held-out row.
pred = dt_clf.predict(X_test)

In [8]:
# accuracy: fraction of test rows whose predicted label equals the true label
# (np.mean is vectorised; the builtin sum would loop over the array in Python)
np.mean(y_test == pred)

0.84