# In [4]:
import numpy as np
from dataclasses import dataclass

@dataclass
class SplitData:
    """Outcome of choosing one (feature, threshold) split for a set of samples.

    In this module's split search, samples whose feature value is
    ``<= thresh`` land on the left side and all others on the right side.
    """

    # Per-sample boolean masks selecting the samples on each side.
    left_partition_mask: np.ndarray
    right_partition_mask: np.ndarray
    # Number of samples on each side of the split.
    left_count: int
    right_count: int
    # Column of the feature matrix the split was made on.
    feat_index: int
    # Quality of the split; candidate splits are compared on it (lower wins).
    cost: float
    # Threshold value applied to the split feature.
    thresh: float

class TreeNode:
    """Node of a regression tree grown greedily by minimising target variance.

    Building a node immediately searches every (feature, threshold) pair of
    the samples it owns for the split with the lowest weighted variance of
    ``y``. It then recursively builds ``left_node`` / ``right_node`` children
    while the depth and leaf-size limits allow.

    Every node receives the *full* ``X`` / ``y`` plus the global row
    ``indices`` it owns. Split masks are computed relative to the node's own
    partition and mapped back through ``indices``, so children always address
    the correct rows of the original data.

    Attributes:
        depth: Depth of this node; the root is built with ``depth=0``.
        max_depth: A child is only built when its depth is ``<= max_depth``.
        min_leaves: A child is only built when its side of the split holds
            strictly more than ``min_leaves`` samples.
        indices: Integer row indices into ``X`` / ``y`` owned by this node.
        value: Mean of ``y`` over this node's samples.
        split_data: Best ``SplitData`` found for this node's samples; its
            masks are aligned with ``indices``.
        left_node: Child built from the left side of the split, or ``None``.
        right_node: Child built from the right side of the split, or ``None``.
    """

    def __init__(self,
                 X: np.ndarray,
                 y: np.ndarray,
                 max_depth=5,
                 min_leaves=1,
                 depth=0,
                 indices=None
                 ):
        """Build this node and, recursively, its subtree.

        Args:
            X: 2-D feature matrix of the whole training set.
            y: 1-D targets, one per row of ``X``.
            max_depth: Maximum depth at which nodes may be created.
            min_leaves: Child partitions must be larger than this to be grown.
            depth: Depth of this node.
            indices: Rows of ``X`` / ``y`` owned by this node, given as integer
                indices or a boolean row mask. ``None`` means all rows.

        Raises:
            ValueError: If the node owns no samples or ``X`` has no columns.
        """
        # Normalise to global integer row indices so child masks can be
        # mapped back onto the original X / y.
        if indices is None:
            indices = np.arange(X.shape[0])
        else:
            indices = np.asarray(indices)
            if indices.dtype == bool:
                indices = np.flatnonzero(indices)
        if indices.size == 0:
            raise ValueError('TreeNode needs at least one sample')

        X_partition = X[indices, :]
        y_partition = y[indices]

        self.depth = depth
        self.max_depth = max_depth
        self.min_leaves = min_leaves
        self.indices = indices
        self.value = float(np.mean(y_partition))
        self.left_node = None
        self.right_node = None

        split_data = self.split_by_min_feat_thresh(X_partition, y_partition)
        self.split_data = split_data

        next_depth = self.depth + 1
        is_not_max_depth = next_depth <= self.max_depth
        # A split that puts every sample on one side would only rebuild this node.
        is_proper_split = split_data.left_count > 0 and split_data.right_count > 0

        grow_left = (is_not_max_depth and is_proper_split
                     and self.min_leaves < split_data.left_count)
        grow_right = (is_not_max_depth and is_proper_split
                      and self.min_leaves < split_data.right_count)

        if grow_left:
            self.left_node = self._grow_child(
                X, y, indices[split_data.left_partition_mask], next_depth)
        if grow_right:
            self.right_node = self._grow_child(
                X, y, indices[split_data.right_partition_mask], next_depth)

    def _grow_child(self, X, y, child_indices, child_depth):
        """Build a child node over ``child_indices`` using this node's limits."""
        return TreeNode(
            X, y,
            max_depth=self.max_depth,
            min_leaves=self.min_leaves,
            depth=child_depth,
            indices=child_indices,
        )

    def split_by_min_thresh(self, X, y, feat_index, partition_mask=None):
        """Find the threshold on one feature that minimises weighted variance.

        Every distinct value ``t`` of ``X[:, feat_index]`` is tried as a
        candidate threshold. Samples with value ``<= t`` go left and the rest
        go right. A candidate's cost is
        ``ratio_left * var_left + ratio_right * var_right``, where ``var_*``
        is the population variance of ``y`` on that side and ``ratio_*`` is
        that side's share of the considered samples.

        Args:
            X: 2-D feature matrix, one row per sample.
            y: 1-D targets, one per row of ``X``.
            feat_index: Column of ``X`` to split on.
            partition_mask: Optional boolean row mask. Rows outside it land on
                neither side and do not count towards the ratios.

        Returns:
            ``SplitData`` for the cheapest threshold. ``cost`` is that
            threshold's scalar cost, so splits on different features can be
            compared directly. Ties go to the smallest threshold.

        Raises:
            ValueError: If ``X`` has no rows.
        """
        column = X[:, feat_index]
        thresh = np.unique(column).reshape(1, -1)
        if thresh.size == 0:
            raise ValueError('cannot split an empty set of samples')

        # Shape (n_samples, n_thresholds): True where a sample falls left.
        is_left_sampler = column.reshape(-1, 1) <= thresh
        is_right_sampler = ~is_left_sampler

        if partition_mask is not None:
            # Broadcast the (n_samples, 1) row mask across every candidate.
            row_mask = np.asarray(partition_mask, dtype=bool).reshape(-1, 1)
            is_left_sampler = is_left_sampler & row_mask
            is_right_sampler = is_right_sampler & row_mask

        targets = y.reshape(-1, 1)
        n_samples_left = np.sum(is_left_sampler, axis=0)
        n_samples_right = np.sum(is_right_sampler, axis=0)
        # Size of the considered partition (equals len(X) without a mask).
        total_samples = n_samples_left + n_samples_right

        # Empty sides / partitions get mean, variance and weight 0; the
        # errstate silences the division warnings np.where evaluates anyway.
        with np.errstate(divide='ignore', invalid='ignore'):
            left_mean = np.where(
                n_samples_left == 0, 0,
                np.sum(is_left_sampler * targets, axis=0) / n_samples_left)
            right_mean = np.where(
                n_samples_right == 0, 0,
                np.sum(is_right_sampler * targets, axis=0) / n_samples_right)
            var_left = np.where(
                n_samples_left == 0, 0,
                np.sum(is_left_sampler * (targets - left_mean) ** 2, axis=0)
                / n_samples_left)
            var_right = np.where(
                n_samples_right == 0, 0,
                np.sum(is_right_sampler * (targets - right_mean) ** 2, axis=0)
                / n_samples_right)
            ratio_left = np.where(total_samples == 0, 0,
                                  n_samples_left / total_samples)
            ratio_right = np.where(total_samples == 0, 0,
                                   n_samples_right / total_samples)

        thresh_cost = ratio_left * var_left + ratio_right * var_right
        min_thresh_index = np.argmin(thresh_cost)
        best_left = is_left_sampler[:, min_thresh_index]
        best_right = is_right_sampler[:, min_thresh_index]

        return SplitData(
            left_partition_mask=best_left,
            right_partition_mask=best_right,
            left_count=int(np.sum(best_left)),
            right_count=int(np.sum(best_right)),
            feat_index=feat_index,
            cost=float(thresh_cost[min_thresh_index]),
            thresh=thresh.reshape(-1)[min_thresh_index],
        )

    def split_by_min_feat_thresh(self, X, y):
        """Return the cheapest single-feature split over all columns of ``X``.

        Ties are broken in favour of the lowest feature index.

        Raises:
            ValueError: If ``X`` has no columns or no rows.
        """
        if X.shape[1] == 0:
            raise ValueError('X has no features to split on')
        candidates = (self.split_by_min_thresh(X, y, feat_index)
                      for feat_index in range(X.shape[1]))
        # min() keeps the first of equal-cost candidates.
        return min(candidates, key=lambda split: split.cost)
