### Refactoring Notebook
- This notebook verifies that the functions a Decision Tree Classifier needs — class probabilities, Gini index, and entropy — are properly abstracted.

In [1]:
from dataclasses import dataclass
from sklearn.datasets import load_iris
import numpy as np
import matplotlib.pyplot as plt

In [5]:
def one_hot_encode(y):
    """One-hot encode a collection of class labels.

    Each label is mapped to a column by its position among the sorted
    distinct labels (via ``np.unique``). Labels therefore do not have to be
    the contiguous integers ``0..k-1``; e.g. ``[0, 2, 2]`` or string labels
    are handled.

    Parameters
    ----------
    y : array_like
        Class labels.

    Returns
    -------
    (numpy.ndarray, int)
        The one-hot float matrix of shape ``np.shape(y) + (num_categories,)``
        and ``num_categories``, the number of distinct labels in ``y``.
    """
    categories, codes = np.unique(y, return_inverse=True)
    num_categories = categories.shape[0]
    # Index np.eye by each label's rank among the distinct labels, not by the
    # raw label value: raw values overflow np.eye(num_categories) whenever
    # the labels are not exactly 0..num_categories-1.
    # The reshape keeps the codes aligned with y's shape across NumPy
    # versions (the inverse's shape for N-d input differs between 1.x and 2.x).
    codes = codes.reshape(np.shape(y))
    return np.eye(num_categories)[codes], num_categories

In [6]:
def unique_thresholds(X, feat_index):
    """Return the distinct values of one feature column as a single row.

    Parameters
    ----------
    X : numpy.ndarray
        2-D feature matrix indexed as ``X[row, feature]``.
    feat_index : int
        Column of ``X`` to read.

    Returns
    -------
    numpy.ndarray
        Shape ``(1, n_unique)``: the distinct values of ``X[:, feat_index]``
        in ascending order (as produced by ``np.unique``).
    """
    column = X[:, feat_index]
    distinct = np.unique(column)
    # Promote the 1-D result to a (1, n_unique) row.
    return distinct[np.newaxis, :]

def class_probabilities(y_one_hot_3d, sampler_mask, num_categories):
    """Compute the class distribution of the samples selected by each mask column.

    Parameters
    ----------
    y_one_hot_3d : numpy.ndarray
        One-hot labels that must broadcast against ``sampler_mask[..., None]``.
        Presumably shaped ``(n_samples, 1, num_categories)`` — TODO confirm
        against the caller.
    sampler_mask : array_like
        Shape ``(n_samples, n_splits)``; truthy where a sample is selected
        by that split.
    num_categories : int
        Number of classes. Kept for interface compatibility; the class axis
        is now taken from ``y_one_hot_3d`` by broadcasting.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n_splits, num_categories)``. Row ``i`` holds
        the class frequencies among the samples selected by column ``i`` of
        ``sampler_mask``. Rows for splits that select no samples are all zeros.
    """
    mask = np.asarray(sampler_mask)
    # A trailing length-1 class axis broadcasts the mask over the classes,
    # replacing the num_categories stacked copies.
    sampled = mask[..., np.newaxis] * y_one_hot_3d
    indicator_sum = np.sum(sampled, axis=0)
    # Keep a trailing length-1 axis so each split's count divides its own row.
    # A bare (n_splits,) vector would broadcast against the class axis
    # instead: an error when n_splits != num_categories, wrong values when equal.
    num_sampled = np.sum(mask, axis=0)[..., np.newaxis]
    # Divide only for non-empty splits; the rest stay 0. np.where would
    # still evaluate 0/0 and emit warnings.
    class_probs = np.zeros(indicator_sum.shape, dtype=float)
    np.divide(indicator_sum, num_sampled, out=class_probs, where=num_sampled > 0)
    return class_probs

def gini_index(class_probs):
    """Return the Gini impurity of every row of class probabilities.

    Parameters
    ----------
    class_probs : numpy.ndarray
        2-D array; axis 1 holds the class probabilities of one split.

    Returns
    -------
    numpy.ndarray
        ``1 - sum_k p_k**2`` evaluated for each row.
    """
    squared = class_probs * class_probs
    return 1 - squared.sum(axis=1)

def entropy(class_probs):
    """Return the Shannon entropy, in bits, of every row of class probabilities.

    Computes ``sum_k p_k * log2(1 / p_k)`` along axis 1. Terms with a
    non-positive ``p_k`` contribute 0, following the ``0 * log 0 = 0``
    convention.

    Parameters
    ----------
    class_probs : numpy.ndarray
        2-D array; axis 1 holds the class probabilities of one split.

    Returns
    -------
    numpy.ndarray
        One entropy value per row.
    """
    # np.where evaluates both branches, so 1/0 and log2(0) are still computed
    # for zero probabilities before being discarded. Silence those
    # divide-by-zero floating-point errors, which would otherwise emit
    # warnings (or raise under np.seterr(divide="raise")) for every pure split.
    with np.errstate(divide="ignore"):
        inv_prob = np.where(class_probs > 0.0, 1 / class_probs, 0.0)
        log_prob = np.where(inv_prob > 0.0, np.log2(inv_prob), 0.0)

    return np.sum(class_probs * log_prob, axis=1)