In [1]:
import numpy as np

In [10]:
class Node:
    """
        A single node of the decision tree.

        A node stores the majority class of the samples that reach it and,
        once it has been split, the split rule (feature index + threshold)
        together with its two child nodes.
    """

    def __init__(self, predicted_class, depth=None):
        """
            predicted_class: The majority class of the samples in this node.
            depth: Depth of the node in the tree (optional, defaults to None).
        """
        self.predicted_class = predicted_class  # The majority class
        self.feature_index = None               # Feature index of the split.
        self.threshold = None                   # The threshold value of the split.
        self.depth = depth                      # Current depth of the node.
        self.n_sample = None                    # No. of samples in the associated data.
        self.sample_idx = None                  # Indices of samples that belong to the node.
        self.entropy = None                     # Entropy of the samples which belong to the node.

        # Children stay None until the node has been split.
        self.left = None    # The left child node.
        self.right = None   # The right child node.

In [4]:
class DecisionTreeClassifier:
    def __init__(self, min_sample_split=2, max_depth=5):
        """
            min_sample_spilt: Min no. of samples required to produce a split.
            max_depth: Maximum allowed depth.
        """
        self.min_sample_spilt = min_sample_split
        self.max_depth = max_depth
        self.tree = None    # The root node.
        self.x_train = None # The training predictor variables.
        self.y_train = None # The training response variables.
        self.n_sample = None# No. of training samples.
        self.n_feat = None  # Feature dimension.

    def fit(self, x_train:np.array, y_train:np.array):
        """
            x_train: The predictor values for training.
            y_train: The response values corresponding to the training samples.        
        """
        self.x_train = x_train
        self.y_train = y_train
        self.n_sample, self.n_feat = self.x_train.shape
        self.tree = self._create_root_node()
        self.build_tree(self.tree)

    def _create_root_node(self) -> Node:
        """
            Create the root node, which covers every training sample.

            Reads self.y_train and self.n_sample, so the training data must
            already be stored on the instance.

            return
                The root Node at depth 0, with its majority class, sample
                count, entropy and sample indices filled in.
        """
        # The class with highest frequency of training samples.
        max_class = self.find_most_freq_class(self.y_train)
        root_node = Node(max_class)
        root_node.n_sample = self.n_sample
        root_node.depth = 0
        root_node.entropy = self.find_entropy(self.y_train)
        # The root owns all samples: indices 0 .. n_sample - 1.
        root_node.sample_idx = np.arange(self.n_sample)
        return root_node
    
    def find_entropy(self, y: np.array) -> float:
        """
            Calculate entropy of the system.

            input
                y: response variable

            output:
                Entropy of the system
        """
        _, counts = np.unique(y, return_counts=True)
        n_sample = len(y)
        count_frac = counts / n_sample
        count_frac_log = np.log(count_frac)
        out = -1 * (count_frac * count_frac_log).sum()
        return out

    def find_most_freq_class(self, y: np.array) -> int:
        """
            Class with most frequent sample
        """
        elements, counts = np.unique(y, return_counts=True)
        highest_count = counts.argmax()
        out = elements[highest_count]
        return out


    def get_spilt(self, X, y):
        pass

    def _get_split_info(self, feature_idx, thr, sample_idx):
        left_idx = self.x_train[self.x_train[sample_idx, feature_idx] < thr]
        right_idx = self.x_train[self.x_train[sample_idx, feature_idx] > thr]

        y_left = self.x_train[left_idx], self.y_train[left_idx]
        y_right = self.x_train[right_idx], self.y_train[right_idx]

        left_entropy = self.find_entropy(y_left)
        right_entropy = self.find_entropy(y_right)

        out = (len(y_left)/len(sample_idx)) * left_entropy + (len(y_right)/len(sample_idx)) * right_entropy
        return out

    def get_best_split(self, node: Node) -> "tuple[Node | None, Node | None]":
        """
            For a given node find the best split based on information gain.

            Every distinct value of every feature (restricted to the node's
            samples) is tried as a threshold. Samples with a feature value
            strictly below the threshold go left, the rest go right. Splits
            that would leave one child empty are skipped.

            When a split is found, the chosen feature index and threshold are
            recorded on `node` (node.feature_index, node.threshold). The
            children are returned but not attached to `node`.

            return
            LHS subtree [Node]
            RHS subtree [Node]
            or (None, None) when no threshold separates the node's samples
            into two non-empty groups.
        """
        best_gain = -np.inf
        best_feature_idx = None
        best_thr = None
        best_left_mask = None
        sample_idx = np.asarray(node.sample_idx)

        for feature_idx in range(self.n_feat):
            feature_vals = self.x_train[sample_idx, feature_idx]
            potential_thrs = np.unique(feature_vals)

            for thr in potential_thrs:
                left_mask = feature_vals < thr
                # A split that sends every sample to one side is not a split.
                if not left_mask.any() or left_mask.all():
                    continue

                split_info = self._get_split_info(feature_idx, thr, sample_idx)
                information_gain = node.entropy - split_info

                if information_gain > best_gain:
                    best_gain = information_gain
                    best_feature_idx = feature_idx
                    best_thr = thr
                    best_left_mask = left_mask

        if best_feature_idx is None:
            return None, None

        # Remember the winning rule on the parent so it can route samples later.
        node.feature_index = best_feature_idx
        node.threshold = best_thr

        children = []
        for child_idx in (sample_idx[best_left_mask], sample_idx[~best_left_mask]):
            y_child = self.y_train[child_idx]
            child = Node(self.find_most_freq_class(y_child))
            child.n_sample = len(child_idx)
            child.depth = node.depth + 1
            child.entropy = self.find_entropy(y_child)
            child.sample_idx = child_idx
            children.append(child)

        return children[0], children[1]

                





    def build_tree(self, node: Node):
        """
            Grow the decision tree below `node` by splitting the training data.

            node: The node to split. A split is attempted only when the node
                  holds more than `min_sample_spilt` samples and its depth is
                  not greater than `max_depth`.

            NOTE(review): the docstring this replaces described the build as
            recursive, but the lines visible here only request the best split
            for `node`. Attaching the returned children and recursing into
            them is not shown -- confirm against the rest of the method.
        """
        # Stopping rule: too few samples or too deep means the node is left unsplit.
        if node.n_sample > self.min_sample_spilt and node.depth <= self.max_depth:
            # Ask for the two child nodes of the best split of this node.
            left_tree, right_tree = self.get_best_split(node)