In [10]:
import numpy as np
import pandas as pd

def create_dataset():
    """Return the ten-row toy purchase table used throughout this notebook.

    Returns:
        pandas.DataFrame: integer columns ``Age``, ``Salary`` and ``Buy``
        (in that order) with the default 0..9 index. ``Buy`` holds 0 or 1.
    """
    column_names = ['Age', 'Salary', 'Buy']
    # One tuple per person: (Age, Salary, Buy).
    records = [
        (25, 50000, 0),
        (45, 90000, 1),
        (35, 60000, 0),
        (50, 80000, 1),
        (23, 55000, 0),
        (43, 95000, 1),
        (21, 45000, 0),
        (35, 70000, 0),
        (55, 105000, 1),
        (22, 48000, 0),
    ]
    return pd.DataFrame(records, columns=column_names)

# Build the toy frame at module level and print it in full (it has ten rows).
dataset = create_dataset()
print(dataset)


   Age  Salary  Buy
0   25   50000    0
1   45   90000    1
2   35   60000    0
3   50   80000    1
4   23   55000    0
5   43   95000    1
6   21   45000    0
7   35   70000    0
8   55  105000    1
9   22   48000    0


In [8]:
class DecisionTree:
    """Classification tree grown by exhaustive search for the lowest Gini split.

    Every decision node is a dict ``{'index', 'value', 'left', 'right'}``:
    rows whose feature ``index`` is strictly less than ``value`` go left,
    all others go right. A child is either another such dict or a leaf
    holding the predicted class label.

    Parameters:
        min_samples_split: a child group is split further only when it has
            MORE than this many rows; groups with ``len(group) <=
            min_samples_split`` become leaves.
            NOTE(review): this is stricter by one than the usual reading of
            "minimum samples required to split"; kept as-is so existing
            trees do not change.
        max_depth: number of decision-node levels allowed, counting the root
            as level 1. Nodes at ``depth >= max_depth`` get leaf children.

    Attributes:
        tree: root node dict after ``fit``; ``None`` before.
        classes_: sorted distinct training labels after ``fit``; ``None``
            before.
    """

    def __init__(self, min_samples_split=2, max_depth=2):
        self.min_samples_split = min_samples_split
        self.max_depth = max_depth
        self.tree = None
        self.classes_ = None

    def fit(self, X, y):
        """Grow the tree from training features ``X`` and labels ``y``.

        Labels are encoded as positions in ``classes_`` before being packed
        next to the features. Packing raw labels into the same array would
        coerce them to the feature dtype (integer labels turning into floats,
        or string labels turning numeric features into strings).

        Raises:
            ValueError: if ``y`` is empty or ``X`` supplies no feature column.
        """
        labels = np.asarray(y).ravel()
        if labels.size == 0:
            raise ValueError("Cannot fit a DecisionTree on an empty training set")
        classes, encoded = np.unique(labels, return_inverse=True)
        # Last column is the encoded class index; all others are features.
        dataset = np.c_[X, np.reshape(encoded, -1)]
        if dataset.shape[1] < 2:
            raise ValueError("X must contain at least one feature column")
        # _to_terminal decodes through classes_, so set it before building.
        self.classes_ = classes
        self.tree = self._build_tree(dataset)

    def predict(self, X):
        """Return an array with one predicted label per row of ``X``.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.tree is None:
            raise RuntimeError("This DecisionTree has not been fitted yet; call fit() first")
        predictions = [self._make_prediction(x, self.tree) for x in X]
        return np.array(predictions)

    def _gini_index(self, groups, classes):
        """Size-weighted Gini impurity of ``groups`` (0.0 means every group is pure)."""
        n_instances = sum(len(group) for group in groups)
        gini = 0.0
        for group in groups:
            size = len(group)
            if size == 0:
                # An empty side contributes nothing and would divide by zero.
                continue
            # Extract the label column once per group, not once per class.
            group_labels = [row[-1] for row in group]
            score = 0.0
            for class_val in classes:
                p = group_labels.count(class_val) / size
                score += p * p
            gini += (1.0 - score) * (size / n_instances)
        return gini

    def _test_split(self, index, value, dataset):
        """Partition rows into (feature < value, feature >= value) lists."""
        left, right = [], []
        for row in dataset:
            if row[index] < value:
                left.append(row)
            else:
                right.append(row)
        return left, right

    def _get_split(self, dataset):
        """Try every (feature, observed value) pair and keep the lowest-Gini split.

        Ties keep the first candidate found (strict ``<`` comparison).
        Returns a node dict with keys ``index``, ``value`` and ``groups``.
        """
        class_values = list(set(row[-1] for row in dataset))
        b_index, b_value, b_score, b_groups = None, None, float('inf'), None
        for index in range(len(dataset[0]) - 1):
            for row in dataset:
                groups = self._test_split(index, row[index], dataset)
                gini = self._gini_index(groups, class_values)
                if gini < b_score:
                    b_index, b_value, b_score, b_groups = index, row[index], gini, groups
        return {'index': b_index, 'value': b_value, 'groups': b_groups}

    def _to_terminal(self, group):
        """Return the majority class label of ``group``.

        The last column of each row is a class index into ``classes_`` (set
        by ``fit``). Ties go to the lowest class index so the result does not
        depend on set iteration order.
        """
        outcomes = [int(row[-1]) for row in group]
        majority = max(sorted(set(outcomes)), key=outcomes.count)
        return self.classes_[majority]

    def _split(self, node, depth):
        """Recursively replace ``node['groups']`` with ``left``/``right`` children."""
        left, right = node.pop('groups')
        # A split that sent every row to one side separates nothing.
        if not left or not right:
            node['left'] = node['right'] = self._to_terminal(left + right)
            return
        # Depth budget exhausted: both children become leaves.
        if depth >= self.max_depth:
            node['left'], node['right'] = self._to_terminal(left), self._to_terminal(right)
            return
        if len(left) <= self.min_samples_split:
            node['left'] = self._to_terminal(left)
        else:
            node['left'] = self._get_split(left)
            self._split(node['left'], depth + 1)
        if len(right) <= self.min_samples_split:
            node['right'] = self._to_terminal(right)
        else:
            node['right'] = self._get_split(right)
            self._split(node['right'], depth + 1)

    def _build_tree(self, train):
        """Pick the root split on ``train`` and grow the tree below it (root is depth 1)."""
        root = self._get_split(train)
        self._split(root, 1)
        return root

    def _make_prediction(self, row, node):
        """Walk ``row`` down from ``node`` and return the label of the leaf reached."""
        branch = 'left' if row[node['index']] < node['value'] else 'right'
        child = node[branch]
        if isinstance(child, dict):
            return self._make_prediction(row, child)
        return child


In [9]:
if __name__ == "__main__":
    # Rebuild the toy frame (this rebinds the module-level `dataset`).
    dataset = create_dataset()

    # Feature matrix (Age, Salary) and target vector (Buy) as NumPy arrays.
    X = dataset[['Age', 'Salary']].values
    y = dataset['Buy'].values

    model = DecisionTree(max_depth=3)
    model.fit(X, y)

    # Predictions are made on the same rows used for fitting, so the printed
    # comparison shows fit on the training data only, not held-out accuracy.
    predictions = model.predict(X)
    print("Predictions:", predictions)
    print("Actual:", y)


Predictions: [0 1 0 1 0 1 0 0 1 0]
Actual: [0 1 0 1 0 1 0 0 1 0]
