<a href="https://colab.research.google.com/github/nkrj01/Models-from-scratch/blob/main/decision_tree_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris

## **Helper Functions**

In [9]:
def checkPurity(data):
    """
    Args: np array
    Returns: bool

    Tells whether the data is pure, i.e. whether every row carries
    the same label. Labels are read from the last column of data.
    """
    labels = data[:, -1]
    # Pure means exactly one distinct label is present.
    return len(np.unique(labels)) == 1


def classify(data):
    """
    Args: np array
    Returns: int

    Picks the label for a leaf node: the class that occurs most often
    in the last column of data. On a tie, the smallest class value wins
    because np.unique returns classes sorted and argmax takes the first
    maximum.
    """
    labels, counts = np.unique(data[:, -1], return_counts=True)
    majority_label = labels[counts.argmax()]
    return int(majority_label)


def getPotentialSplits(data):
    """
    Args: np array
    Returns: dictionary

    Builds the candidate thresholds for every feature column (all
    columns except the last, which holds the labels). The result maps
    column index -> list of split values, where each split value is the
    midpoint between two consecutive distinct values of that column.
    A column with a single distinct value maps to an empty list.
    """
    n_features = data.shape[1] - 1  # last column is the label
    candidates = {}

    for feature in range(n_features):
        distinct = np.unique(data[:, feature])  # sorted ascending
        # Pair each distinct value with its successor and take the midpoint.
        candidates[feature] = [
            (lower + upper) / 2
            for lower, upper in zip(distinct[:-1], distinct[1:])
        ]

    return candidates


def splitData(data, split_column, split_value):
    """
    Args: np array, int, float
    Returns: np array, np array

    Partitions the rows of data on one column. The first returned array
    holds the rows whose value in split_column is <= split_value, the
    second holds the rows whose value is > split_value.
    """
    column = data[:, split_column]
    # Two explicit comparisons (rather than negating one mask) so a row
    # that satisfies neither, such as a NaN value, lands in neither half.
    at_or_below = column <= split_value
    strictly_above = column > split_value
    return data[at_or_below], data[strictly_above]


def getEntropy(data):
    """
    Args: np array
    Returns: float

    Computes the entropy (in bits, base-2 logarithm) of the label
    distribution found in the last column of data.
    """
    labels = data[:, -1]
    counts = np.unique(labels, return_counts=True)[1]
    probability = counts / counts.sum()
    surprisal = -np.log2(probability)
    # Entropy is the probability-weighted sum of the surprisals.
    return sum(probability * surprisal)


def getTotalEntropy(data_below, data_above):
    """
    Args: np array, np array
    Returns: float

    Computes the overall entropy of a split: the entropy of each half
    (via getEntropy) weighted by the fraction of rows in that half.
    """
    n_below = data_below.shape[0]
    n_above = data_above.shape[0]
    n_total = n_above + n_below

    weight_below = n_below / n_total
    weight_above = n_above / n_total

    return weight_below * getEntropy(data_below) + weight_above * getEntropy(data_above)


def getBestSplit(data, potential_splits):
    """
    Args: np array, dictionary
    Returns: int, float
    Raises: ValueError when no usable split is found

    Given the data and a dictionary mapping column index -> candidate
    split values, evaluates every candidate with splitData and
    getTotalEntropy and returns the column index and split value that
    give the lowest total entropy of the two resulting halves.

    When several candidates tie for the lowest entropy, the one visited
    last (in dictionary order, then list order) is returned.
    """
    best_split_column = None
    best_split_value = None
    # Start from +infinity so any real entropy value is accepted,
    # instead of relying on an arbitrary large sentinel.
    lowest_entropy = float("inf")

    for col_index, candidate_values in potential_splits.items():
        for value in candidate_values:
            data_below, data_above = splitData(data, col_index, value)
            current_total_entropy = getTotalEntropy(data_below, data_above)

            # "<=" rather than "<" keeps the tie-breaking rule described
            # in the docstring (last candidate wins).
            if current_total_entropy <= lowest_entropy:
                lowest_entropy = current_total_entropy
                best_split_column = col_index
                best_split_value = value

    # Without this check an empty set of candidates would fall through
    # to the return with the result names never assigned.
    if best_split_column is None:
        raise ValueError("getBestSplit: no usable split found in potential_splits")

    return best_split_column, best_split_value


## **Data Import**

In [3]:
# Load the iris dataset and append the class labels as the final column,
# giving one array laid out as (feature columns..., label).
iris = load_iris()
X = iris.data
y = iris.target.reshape(-1, 1)  # column vector so it can be joined to X
data = np.concatenate((X, y), axis=1)

# **Decision Tree Algorithm**

In [12]:
def decision_tree(data):
    """
    Args: np array
    Returns: int or dictionary
    Raises: ValueError if data has no rows

    Recursively builds a decision tree from data, whose last column
    holds the class labels and whose other columns hold the features.

    A leaf is returned as an int class label. An internal node is
    returned as a one-entry dictionary whose key describes the split and
    whose value is the two-element list [subtree_below, subtree_above].
    NOTE: the key text reads "value<threshold", but rows are routed to
    the first subtree when their value is <= threshold (see splitData).
    """
    if data.shape[0] == 0:
        raise ValueError("decision_tree requires at least one row of data")

    # exit condition: every remaining row has the same label
    if checkPurity(data):
        return classify(data)

    potential_splits = getPotentialSplits(data)

    # exit condition: labels differ but no feature has two distinct
    # values, so no threshold can separate the rows. Fall back to a
    # majority-class leaf instead of failing inside getBestSplit.
    if not any(potential_splits.values()):
        return classify(data)

    # get the best split, then split the data there
    split_column, split_value = getBestSplit(data, potential_splits)
    data_below, data_above = splitData(data, split_column, split_value)

    # recursion
    question = f"feature #{split_column}, value<{split_value}"
    sub_tree = {question: [decision_tree(data_below), decision_tree(data_above)]}

    return sub_tree


In [13]:
tree = decision_tree(data)
print(tree)

{'feature #3, value<0.8': [0, {'feature #3, value<1.75': [{'feature #2, value<4.95': [{'feature #3, value<1.65': [1, 2]}, {'feature #3, value<1.55': [2, {'feature #2, value<5.449999999999999': [1, 2]}]}]}, {'feature #2, value<4.85': [{'feature #1, value<3.1': [2, 1]}, 2]}]}]}
