In [1]:
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from typing import Tuple


In [2]:
# Load scikit-learn's breast-cancer dataset bundle, then unpack the feature
# matrix and the label vector from it.
data = load_breast_cancer()
X, y = data.data, data.target


In [3]:
# Sanity check: display the dimensions of the feature matrix and label vector.
X.shape, y.shape


((569, 30), (569,))

In [4]:
# Hold out 20% of the samples for evaluation; the fixed random_state keeps the
# split reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, random_state=42
)


In [5]:
class Node:
    """One node of the binary tree built by ``DecisionTree``.

    A node whose ``split_attribute`` is not ``None`` is an internal node:
    ``DecisionTree`` sends a sample to ``left`` when its value at
    ``split_attribute`` is below ``split_threshold`` and to ``right``
    otherwise. A node whose ``split_attribute`` is ``None`` is a leaf.

    Args:
        depth: Distance of this node from the root (the root has depth 0).
        split_attribute: Column index used for the split, or ``None`` for a leaf.
        split_threshold: Threshold compared against the split column, or ``None``.
        entropy: Entropy of the labels that reached this node, or ``None``.
        samples: Number of training rows that reached this node, or ``None``.
        values: Per-class counts of the rows that reached this node, or ``None``.
        label: Class predicted when traversal stops at this node, or ``None``.
    """

    def __init__(
        self,
        depth: int,
        split_attribute: int = None,
        split_threshold: float = None,
        entropy: float = None,
        samples: int = None,
        values: np.ndarray = None,
        label: int = None,
    ) -> None:
        self.depth = depth
        self.split_attribute = split_attribute
        self.split_threshold = split_threshold
        self.entropy = entropy
        self.samples = samples
        self.values = values
        self.label = label
        # Children are attached by the tree builder after construction.
        self.right = None
        self.left = None

    def __str__(self) -> str:
        """Return a one-line, human-readable summary of the node."""
        # ``entropy`` defaults to None, and None does not support the ``.3f``
        # format spec, so format it separately instead of raising TypeError.
        entropy_text = "None" if self.entropy is None else f"{self.entropy:.3f}"
        return (
            f"Depth: {self.depth}, "
            f"Split attribute index: {self.split_attribute}, "
            f"threshold: {self.split_threshold}, "
            f"entropy: {entropy_text}, "
            f"samples: {self.samples}, "
            f"values: {self.values}, "
            f"label: {self.label}"
        )


In [6]:
class DecisionTree:
    def __init__(self, max_depth: int, min_entropy_diff: float) -> None:
        self.root = None
        self.max_depth = max_depth
        self.min_entropy_diff = min_entropy_diff

    def _entropy(self, s: np.ndarray) -> float:
        y = s[:, -1]
        positive_ratio = np.sum(y) / len(y)
        return -positive_ratio * np.log(positive_ratio + 1e-3) - (
            1 - positive_ratio
        ) * np.log(1 - positive_ratio + 1e-3)

    def _find_split(
        self, s: np.ndarray
    ) -> Tuple[int, float, float, np.ndarray, np.ndarray]:
        min_split_entropy = np.inf
        best_split_attr_idx = None
        best_threshold = None
        best_left_subset = None
        best_right_subset = None

        for attribute_idx in range(s.shape[1] - 1):
            for threshold in np.unique(s[:, attribute_idx]):

                left_subset = np.array(
                    [row for row in s if row[attribute_idx] < threshold]
                )
                right_subset = np.array(
                    [row for row in s if row[attribute_idx] >= threshold]
                )

                left_entropy = (
                    (len(left_subset) / len(s)) * self._entropy(left_subset)
                    if left_subset.shape[0] > 0
                    else 0
                )
                right_entropy = (
                    (len(right_subset) / len(s)) * self._entropy(right_subset)
                    if right_subset.shape[0] > 0
                    else 0
                )

                split_entropy = left_entropy + right_entropy
                if split_entropy < min_split_entropy:
                    min_split_entropy = split_entropy
                    best_split_attr_idx = attribute_idx
                    best_threshold = threshold
                    best_left_subset = left_subset
                    best_right_subset = right_subset

        return (
            best_split_attr_idx,
            best_threshold,
            min_split_entropy,
            best_left_subset,
            best_right_subset,
        )

    def _build_id3(self, dataset: np.ndarray, depth: int) -> Node:
        if dataset.shape[0] == 0:
            return None

        X, y = dataset[:, :-1], dataset[:, -1].astype("int64")

        # all examples classified correctly
        if np.all(y == 0.0):
            return Node(
                depth=depth,
                values=[np.bincount(y)[0], 0],
                entropy=0.0,
                label=np.argmax([np.bincount(y)[0], 0]),
                samples=dataset.shape[0],
            )
        if np.all(y == 1.0):
            return Node(
                depth=depth,
                values=np.bincount(y),
                entropy=0.0,
                label=np.argmax(np.bincount(y)),
                samples=dataset.shape[0],
            )

        # no attributes to split upon
        if X.shape[1] == 0:
            return Node(
                depth=depth,
                values=np.bincount(np.squeeze(y, axis=1)),
                entropy=self._entropy(dataset),
                label=np.argmax(np.bincount(np.squeeze(y, axis=1))),
                samples=X.shape[0],
            )

        (
            best_split_attr_idx,
            best_threshold,
            min_split_entropy,
            best_left_subset,
            best_right_subset,
        ) = self._find_split(dataset)

        # decide about splitting
        if (
            depth < self.max_depth
            and (self._entropy(dataset) - min_split_entropy) > self.min_entropy_diff
        ):
            root = Node(
                depth=depth,
                split_attribute=best_split_attr_idx,
                split_threshold=best_threshold,
                entropy=self._entropy(dataset),
                samples=dataset.shape[0],
                values=np.bincount(np.squeeze(y)),
                label=np.argmax(np.bincount(np.squeeze(y))),
            )
            root.left = self._build_id3(best_left_subset, depth=depth + 1)
            root.right = self._build_id3(best_right_subset, depth=depth + 1)

        else:
            root = Node(
                depth=depth,
                entropy=self._entropy(dataset),
                samples=dataset.shape[0],
                values=np.bincount(np.squeeze(y)),
                label=np.argmax(np.bincount(np.squeeze(y))),
            )

        return root

    def visualize(self) -> None:
        queue = list()
        queue.append(self.root)

        while queue:
            v = queue.pop(0)
            print(v)
            if v.left is not None:
                queue.append(v.left)
            if v.left is not None:
                queue.append(v.right)

    def fit(self, X: np.ndarray, y: np.ndarray) -> None:
        if len(y.shape) == 1:
            y = np.expand_dims(y, axis=1)

        dataset = np.concatenate([X, y], axis=1)
        self.root = self._build_id3(dataset, depth=0)

    def _predict_sample(self, sample: np.ndarray) -> int:
        current_node = self.root
        current_prediction = current_node.label
        while current_node.split_attribute is not None:
            if sample[current_node.split_attribute] < current_node.split_threshold:
                current_prediction = current_node.left.label
                current_node = current_node.left
            else:
                current_prediction = current_node.right.label
                current_node = current_node.right

        return current_prediction

    def predict(self, X: np.ndarray) -> np.ndarray:
        if self.root is not None:
            return np.array([self._predict_sample(sample) for sample in X])
        else:
            raise Exception("Decision Tree is not trained")


In [7]:
# Grow a tree of depth at most 3 that only accepts a split when it lowers the
# entropy by more than 0.01, using the training portion of the data.
tree = DecisionTree(max_depth=3, min_entropy_diff=0.01)
tree.fit(X_train, y_train)


In [8]:
# Print one summary line per node of the fitted tree.
tree.visualize()


Depth: 0, Split attribute index: 7, threshold: 0.05182, entropy: 0.658, samples: 455, values: [169 286], label: 1
Depth: 1, Split attribute index: 20, threshold: 16.89, entropy: 0.216, samples: 282, values: [ 16 266], label: 1
Depth: 1, Split attribute index: 22, threshold: 114.6, entropy: 0.356, samples: 173, values: [153  20], label: 0
Depth: 2, Split attribute index: 10, threshold: 0.645, entropy: 0.092, samples: 263, values: [  5 258], label: 1
Depth: 2, Split attribute index: 1, threshold: 16.68, entropy: 0.679, samples: 19, values: [11  8], label: 0
Depth: 2, Split attribute index: 21, threshold: 25.84, entropy: 0.687, samples: 44, values: [24 20], label: 0
Depth: 2, Split attribute index: None, threshold: None, entropy: 0.000, samples: 129, values: [129, 0], label: 0
Depth: 3, Split attribute index: None, threshold: None, entropy: 0.061, samples: 260, values: [  3 257], label: 1
Depth: 3, Split attribute index: None, threshold: None, entropy: 0.635, samples: 3, values: [2 1], la

In [9]:
# Predict labels for the held-out test samples.
predictions = tree.predict(X_test)


In [10]:
# Compare the predictions with the true test labels (per-class precision, recall, F1).
print(classification_report(y_test, predictions))


              precision    recall  f1-score   support

           0       0.97      0.91      0.94        43
           1       0.95      0.99      0.97        71

    accuracy                           0.96       114
   macro avg       0.96      0.95      0.95       114
weighted avg       0.96      0.96      0.96       114

