# Decision tree from scratch using numpy

### Information Gain $IG = I(parent) - wt.avg \big( I(child) \big)$   
Info $I$ can either be represented by Entropy or GINI where:  
 - Entropy $E(x) = - \sum{ p(x) \log_2 \big( p(x) \big) }$  
 - GINI $G(x) = 1 - \sum{ p(x)^2 }$  
Where $p(x) = \dfrac{\#x}n$  
  
Stopping Criteria: max depth, min samples, min impurity decrease

In [1]:
import numpy as np
from collections import Counter

In [2]:
class Node:
  """A single vertex of the decision tree.

  Internal vertices carry a split description (`feature`, `threshold`) and
  two children (`left`, `right`); leaves carry a prediction in `value`.
  `value` is keyword-only so it cannot be confused with the positional
  split arguments.
  """

  def __init__(self, feature=None, threshold=None, left=None, right=None, *, value=None):
    # Split description (unused on leaves).
    self.feature, self.threshold = feature, threshold
    # Sub-trees (unused on leaves).
    self.left, self.right = left, right
    # Stored prediction; stays None on internal vertices.
    self.value = value

  def isLeafNode(self):
    """Return True when this vertex stores a prediction (`value` is not None)."""
    if self.value is None:
      return False
    return True

In [3]:

class DecisionTree:
  """Binary classification tree grown greedily on numeric features.

  Every internal node tests `x[feature] <= threshold`; the split is chosen to
  maximise information gain (parent impurity minus the size-weighted impurity
  of the two children).

  Labels must be non-negative integers because the impurity measures count
  classes with `np.bincount`.

  Args:
    minSampleSplit: a node holding fewer samples than this becomes a leaf.
    maxDepth: nodes at this depth become leaves.
    nFeatures: number of features sampled (without replacement) at each
      node; a falsy value means "use all features". `fit` overwrites this
      attribute with the effective count.
    criterion: impurity measure, either "gini" or "entropy".

  Raises:
    ValueError: if `criterion` is not one of the supported names.
  """

  def __init__(self, minSampleSplit=2, maxDepth=100, nFeatures=None,
               criterion="gini"):
    if criterion not in ("gini", "entropy"):
      raise ValueError(
          f"criterion must be 'gini' or 'entropy', got {criterion!r}")
    self.minSampleSplit = minSampleSplit
    self.maxDepth = maxDepth
    self.nFeatures = nFeatures
    self.criterion = criterion
    self.root = None

  def fit(self, X, y):
    """Grow the tree from samples `X` (2-D, one row per sample) and labels `y`.

    Raises:
      ValueError: if `y` is empty.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    if len(y) == 0:
      raise ValueError("cannot fit a decision tree on an empty dataset")

    # Never ask for more features than the data has.
    self.nFeatures = (
        X.shape[1] if not self.nFeatures else min(
            X.shape[1], self.nFeatures)
    )
    self.root = self._growTree(X, y)

  def _growTree(self, X, y, depth=0):
    """Recursively build and return the subtree for the samples in `X`, `y`."""
    nSamples, nFeatures = X.shape
    nLabels = len(np.unique(y))

    if depth >= self.maxDepth or nLabels == 1 or nSamples < self.minSampleSplit:
      return Node(value=self._mostCommonLabel(y))

    featureIdxs = np.random.choice(
        nFeatures, self.nFeatures, replace=False)
    bestThreshold, bestIdx = self._bestSplit(X, y, featureIdxs)

    # Every sampled feature is constant here, so no threshold can separate
    # the samples: stop instead of recursing on an empty partition.
    if bestIdx is None:
      return Node(value=self._mostCommonLabel(y))

    leftIdxs, rightIdxs = self._split(X[:, bestIdx], bestThreshold)
    left = self._growTree(X[leftIdxs, :], y[leftIdxs], depth + 1)
    right = self._growTree(X[rightIdxs, :], y[rightIdxs], depth + 1)

    return Node(bestIdx, bestThreshold, left, right)

  def _mostCommonLabel(self, y):
    """Return the most frequent label in the non-empty sequence `y`."""
    counter = Counter(y)
    return counter.most_common(1)[0][0]

  def _bestSplit(self, X, y, featureIdxs):
    """Return `(threshold, featureIdx)` of the highest-gain split.

    Returns `(None, None)` when none of the candidate features has at least
    two distinct values, i.e. when no split with two non-empty sides exists.
    """
    bestGain = -1
    splitIdx, splitThreshold = None, None

    for idx in featureIdxs:
      Xcol = X[:, idx]
      # np.unique returns sorted values. The largest one is skipped because
      # `Xcol <= max` keeps every sample on the left and leaves the right
      # side empty, which is not a split.
      thresholds = np.unique(Xcol)[:-1]

      for thr in thresholds:
        gain = self._informationGain(y, Xcol, thr)

        if gain > bestGain:
          bestGain = gain
          splitIdx = idx
          splitThreshold = thr

    return splitThreshold, splitIdx

  def _informationGain(self, y, Xcol, threshold):
    """Return the impurity decrease from splitting `y` at `Xcol <= threshold`.

    A split that leaves one side empty has a gain of 0.
    """
    parentInfo = self._impurity(y)

    leftIdxs, rightIdxs = self._split(Xcol, threshold)

    if len(leftIdxs) == 0 or len(rightIdxs) == 0:
      return 0

    n = len(y)
    nL, nR = len(leftIdxs), len(rightIdxs)
    iL, iR = self._impurity(y[leftIdxs]), self._impurity(y[rightIdxs])

    # Children are weighted by the share of samples they receive.
    childInfo = (nL / n) * iL + (nR / n) * iR

    return parentInfo - childInfo

  def _impurity(self, y):
    """Return the impurity of `y` under the configured `criterion`."""
    if self.criterion == "entropy":
      return self._entropy(y)
    return self._gini(y)

  def _entropy(self, y):
    """Return the Shannon entropy of `y` in bits (base-2 logarithm)."""
    ps = np.bincount(y) / len(y)
    # Classes that do not occur contribute nothing and would hit log2(0).
    ps = ps[ps > 0]
    return -np.sum(ps * np.log2(ps))

  def _gini(self, y):
    """Return the Gini impurity `1 - sum(p**2)` of `y`."""
    ps = np.bincount(y) / len(y)
    impurity = 1 - np.sum(ps**2)
    return impurity

  def _split(self, Xcol, splitThreshold):
    """Return index arrays of the samples with `Xcol <= t` and `Xcol > t`."""
    leftIdxs = np.argwhere(Xcol <= splitThreshold).flatten()
    rightIdxs = np.argwhere(Xcol > splitThreshold).flatten()
    return leftIdxs, rightIdxs

  def predict(self, X):
    """Return an array with the predicted label for every row of `X`.

    Raises:
      RuntimeError: if called before `fit`.
    """
    if self.root is None:
      raise RuntimeError("DecisionTree.predict called before fit")
    return np.array([self._traverseTree(x, self.root) for x in X])

  def _traverseTree(self, x, node: "Node"):
    """Follow the split tests from `node` down to a leaf; return its value."""
    while not node.isLeafNode():
      node = node.left if x[node.feature] <= node.threshold else node.right
    return node.value

In [4]:
from sklearn import datasets
from sklearn.model_selection import train_test_split

# Demo: train the tree on scikit-learn's bundled breast-cancer dataset.
breastCancer = datasets.load_breast_cancer()
X, y = breastCancer.data, breastCancer.target

# test_size=0.2 requests a 20% hold-out split; random_state=0 fixes the
# shuffle so the split is the same on every run.
XTrain, XTest, YTrain, YTest = train_test_split(
    X, y, test_size=0.2, random_state=0)

# Default hyper-parameters: minSampleSplit=2, maxDepth=100, all features.
classifier = DecisionTree()
classifier.fit(XTrain, YTrain)
predictions = classifier.predict(XTest)


def accuracy(yTest, yPred):
  """Return the fraction of entries in `yPred` that equal `yTest`.

  Both arguments may be NumPy arrays or plain sequences of equal length.
  """
  # Convert first: `==` on two plain lists yields a single bool instead of
  # an element-wise comparison, which would silently give a wrong score.
  matches = np.asarray(yTest) == np.asarray(yPred)
  return np.sum(matches) / len(yTest)


# Share of held-out samples whose predicted label equals the true label.
acc = accuracy(YTest, predictions)

In [5]:
acc

np.float64(0.9122807017543859)