### Refactoring Notebook
- This notebook verifies that the helper functions needed by the Decision Tree Classifier (class probabilities, Gini index, and entropy) are properly abstracted.

In [1]:
from dataclasses import dataclass
from sklearn.datasets import load_iris
import numpy as np
import matplotlib.pyplot as plt

In [29]:
def get_one_hot_encode(y):
    """One-hot encode an array of class labels.

    Each label is mapped to its rank among the sorted unique labels before
    encoding, so labels do not have to be the contiguous integers
    ``0..K-1`` (for such labels the result is the same as indexing an
    identity matrix with ``y`` directly).

    Parameters
    ----------
    y : array_like
        Class labels, one per sample.

    Returns
    -------
    one_hot : numpy.ndarray
        Float array of shape ``y.shape + (num_categories,)`` holding a single
        ``1.0`` per sample, in the column of that sample's class.
    num_categories : int
        Number of distinct labels found in ``y``.
    """
    y = np.asarray(y)
    classes, class_index = np.unique(y, return_inverse=True)
    num_categories = classes.shape[0]
    # The shape of the inverse for N-d input differs between NumPy releases;
    # normalise it so the output always follows the shape of ``y``.
    class_index = class_index.reshape(y.shape)
    return np.eye(num_categories)[class_index], num_categories

def stack_target_wrt_thresh(y_one_hot, unique_thresholds):
    """Replicate the one-hot targets once per candidate threshold.

    Parameters
    ----------
    y_one_hot : numpy.ndarray
        One-hot encoded targets.
    unique_thresholds : numpy.ndarray
        Array whose second dimension gives the number of thresholds.

    Returns
    -------
    numpy.ndarray
        Copies of ``y_one_hot`` stacked along a new axis 1, one copy per
        threshold.
    """
    copies = [y_one_hot for _ in range(unique_thresholds.shape[1])]
    return np.stack(copies, axis=1)

def get_unique_thresholds(X, feature_index):
    """Return the sorted distinct values of one feature as a row vector.

    Parameters
    ----------
    X : numpy.ndarray
        2-D feature matrix, indexed as ``X[:, feature_index]``.
    feature_index : int
        Column of ``X`` to take candidate thresholds from.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(1, num_unique_values)``.
    """
    feature_values = X[:, feature_index]
    distinct_sorted = np.unique(feature_values)
    # Add a leading axis so the thresholds broadcast against a column vector.
    return distinct_sorted[np.newaxis, :]

def get_left_right_samplers(X, unique_thresholds, feature_index):
    """Build boolean masks assigning every sample to a side of every split.

    Parameters
    ----------
    X : numpy.ndarray
        2-D feature matrix, indexed as ``X[:, feature_index]``.
    unique_thresholds : numpy.ndarray
        Candidate thresholds, broadcast against the feature column.
    feature_index : int
        Column of ``X`` being split on.

    Returns
    -------
    left_sampler : numpy.ndarray
        ``True`` where the feature value is ``<=`` the threshold.
    right_sampler : numpy.ndarray
        Element-wise negation of ``left_sampler``.
    """
    feature_column = np.reshape(X[:, feature_index], (-1, 1))
    goes_left = np.less_equal(feature_column, unique_thresholds)
    return goes_left, np.logical_not(goes_left)

def get_class_probabilities(y_one_hot_3d, sampler_mask, num_categories):
    """Compute per-threshold class probabilities for one side of a split.

    Parameters
    ----------
    y_one_hot_3d : numpy.ndarray
        One-hot targets replicated per threshold, shape
        ``(num_samples, num_thresholds, num_categories)``.
    sampler_mask : numpy.ndarray
        Boolean mask of shape ``(num_samples, num_thresholds)`` selecting the
        samples that fall on this side of each threshold.
    num_categories : int
        Number of classes. Kept for backward compatibility; the class axis
        is now taken from ``y_one_hot_3d`` through broadcasting.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(num_thresholds, num_categories)``. Row ``t``
        holds the class frequencies among the samples selected for threshold
        ``t``; rows for thresholds that select no samples are all zeros.
    """
    sampler_mask = np.asarray(sampler_mask)

    # Broadcasting the mask over the class axis zeroes out unselected samples
    # without materialising one copy of the mask per class.
    sampled = sampler_mask[:, :, np.newaxis] * y_one_hot_3d
    indicator_sum = np.sum(sampled, axis=0)

    # Column vector of selected-sample counts, one entry per threshold.
    num_sampled = np.sum(sampler_mask, axis=0)[:, np.newaxis]

    # Divide only where the partition is non-empty: this keeps zeros for
    # empty partitions and avoids the 0/0 RuntimeWarning that an
    # unconditional division inside np.where would raise.
    class_probs = np.zeros(indicator_sum.shape, dtype=float)
    np.divide(indicator_sum, num_sampled, out=class_probs, where=num_sampled != 0)
    return class_probs

def get_gini_index(class_probs):
    """Return the Gini impurity of each row of class probabilities.

    Parameters
    ----------
    class_probs : numpy.ndarray
        2-D array with one row per candidate split and one column per class.

    Returns
    -------
    numpy.ndarray
        ``1 - sum(p ** 2)`` computed along axis 1.
    """
    squared = class_probs * class_probs
    sum_of_squares = np.sum(squared, axis=1)
    return 1 - sum_of_squares

def get_entropy(class_probs):
    """Return the Shannon entropy (in bits) of each row of probabilities.

    Zero probabilities contribute nothing to the sum, following the usual
    convention ``0 * log2(1 / 0) = 0``.

    Parameters
    ----------
    class_probs : array_like
        2-D array with one row per candidate split and one column per class.

    Returns
    -------
    numpy.ndarray
        ``sum(p * log2(1 / p))`` computed along axis 1.
    """
    class_probs = np.asarray(class_probs, dtype=float)

    # np.where evaluates both branches, so a plain ``1 / class_probs`` would
    # still divide by zero and emit RuntimeWarnings. Masked ufunc calls skip
    # those elements entirely and leave the pre-filled zeros in place.
    inv_prob = np.zeros(class_probs.shape, dtype=float)
    np.divide(1.0, class_probs, out=inv_prob, where=class_probs > 0.0)

    log_prob = np.zeros(class_probs.shape, dtype=float)
    np.log2(inv_prob, out=log_prob, where=inv_prob > 0.0)

    return np.sum(class_probs * log_prob, axis=1)



In [26]:
# Load the Iris dataset and pull out the feature matrix and target labels.
data = load_iris()
X = data['data']
y = data['target']

In [27]:
# Feature (column of X) whose candidate splits are evaluated below.
feature_index = 2

# Encode the targets first; this does not depend on the chosen feature.
y_one_hot, num_categories = get_one_hot_encode(y)

# Candidate thresholds for the feature, and the targets replicated per threshold.
unique_thresholds = get_unique_thresholds(X, feature_index)
y_one_hot_3d = stack_target_wrt_thresh(y_one_hot, unique_thresholds)

# Boolean masks sending each sample left (<= threshold) or right of each split.
left_sampler, right_sampler = get_left_right_samplers(
    X, unique_thresholds, feature_index
)

# print(left_sampler.shape)
class_probs_left = get_class_probabilities(
    y_one_hot_3d, left_sampler, num_categories
)
# Right-branch probabilities are currently disabled:
# class_probs_right = get_class_probabilities(y_one_hot_3d, right_sampler, num_categories)

sampler_mask_3d.shape =  (150, 43, 3)
sampled.shape =  (150, 43, 3)
num_sampled.shape =  (43, 3)
indicator_sum.shape =  (43, 3)
