In [1]:
import numpy as np
from dataclasses import dataclass

@dataclass
class SplitData:
    """
    A dataclass for representing the result of a data split at a Decision Tree node, including details necessary for propagating splits across the tree.

    Attributes:
        left_partition_mask (np.ndarray): A boolean mask indicating samples allocated to the left child node based on the split.
        right_partition_mask (np.ndarray): A boolean mask indicating samples allocated to the right child node based on the split.
        left_count (int): Number of samples in the left partition, facilitating efficient calculation of statistics or further splits.
        right_count (int): Number of samples in the right partition, facilitating efficient calculation of statistics or further splits.
        feat_index (int): Index of the feature used for splitting, enabling traceability of the decision path within the tree.
        cost (float): The cost or impurity reduction achieved by this split, used to assess the quality of the split.
        thresh (float): Threshold value for the feature at this split, determining the boundary between left and right partitions.
        
    This structure is crucial for the iterative construction and analysis of Decision Trees, where each node's decision to split data affects subsequent nodes' decisions.
    """
    left_partition_mask: np.ndarray
    right_partition_mask: np.ndarray
    left_count: int
    right_count: int 
    feat_index: int 
    cost: float 
    thresh: float 

In [4]:
class TreeNode:
    def __init__(self, 
                    X : np.array, 
                    y : np.array,
                    max_depth=5, 
                    min_leaves=1,
                    depth=0,
                    partition=None
                 ):
        
        self.depth = depth
        self.max_depth = max_depth 
        self.min_leaves = min_leaves

        split_data = self.split_by_min_feat_thresh(X, y, partition)
        
        is_not_max_depth = self.max_depth <= next_depth        
        left_is_not_max_leaves = min_leaves < split_data.left_count
        right_is_not_max_leaves = min_leaves < split_data.right_count
        
        grow_left = is_not_max_depth and left_is_not_max_leaves
        grow_right = is_not_max_depth and right_is_not_max_leaves

        next_depth = self.depth + 1

        print(f'[{self.depth}/{self.max_depth}] cost={split_data.cost}, feature_index={split_data.feat_index}, left_count={split_data.left_count}, right_count={split_data.right_count}')
        if grow_left:
            self.left_node = TreeNode(
                                X, y, 
                                max_depth=max_depth, 
                                min_leaves=min_leaves,
                                depth=next_depth,
                                partition=split_data.left_partition_mask
                            )

        if grow_right:
            self.right_node = TreeNode(
                                X,y,
                                max_depth=max_depth,
                                min_leaves=min_leaves,
                                depth=next_depth,
                                partition=split_data.right_partition_mask
                            )
    
    def split_by_min_thresh(self, X, y, feat_index, partition_mask=None):
        thresh = np.unique(X[:,feat_index]).reshape(1,-1)
        selected_feat = X[:,feat_index].reshape(-1,1)
        is_left_sampler = (selected_feat <= thresh)
        is_right_sampler = ~is_left_sampler

        if partition_mask is not None:
            partition_mask = np.hstack(is_left_sampler.shape[1] * [ partition_mask.reshape(-1,1) ])
            is_left_sampler = is_left_sampler & partition_mask
            is_right_sampler = is_right_sampler & partition_mask

        stacked_targets = np.hstack([y.reshape(-1,1)] * is_left_sampler.shape[1])
        left_sampled, right_sampled = is_left_sampler * stacked_targets, is_right_sampler * stacked_targets

        n_samples_left, n_samples_right = np.sum(is_left_sampler, axis=0), np.sum(is_right_sampler, axis=0)

        with np.errstate(divide='ignore', invalid='ignore'):
            left_samples_mean = np.where(n_samples_left == 0, 0, np.sum(left_sampled, axis=0) / n_samples_left)
            right_samples_mean = np.where(n_samples_right == 0, 0, np.sum(right_sampled, axis=0) / n_samples_right)
        
        var_left = is_left_sampler * (left_sampled - left_samples_mean.reshape(1,-1)) ** 2
        var_right = is_right_sampler * (right_sampled - right_samples_mean.reshape(1,-1)) ** 2

        with np.errstate(divide='ignore', invalid='ignore'):
            var_left = np.where(n_samples_left == 0, 0, np.sum(var_left, axis=0) / n_samples_left)
            var_right = np.where(n_samples_right == 0, 0, np.sum(var_right, axis=0) / n_samples_right)

        total_samples = left_sampled.shape[0]
        assert total_samples == right_sampled.shape[0], 'wrong total samples'

        ratio_left, ratio_right = n_samples_left / total_samples, n_samples_right / total_samples
        thresh_cost = ratio_left * var_left + ratio_right * var_right
        min_thresh_index = np.argmin(thresh_cost)

        return SplitData(
            left_partition_mask=is_left_sampler[:, min_thresh_index],
            right_partition_mask=is_right_sampler[:, min_thresh_index],
            left_count=np.sum(is_left_sampler[:, min_thresh_index]),
            right_count=np.sum(is_right_sampler[:, min_thresh_index]),
            feat_index=feat_index,
            cost=thresh_cost,
            thresh=thresh.reshape(-1)[min_thresh_index] 
        )

    def split_by_min_feat_thresh(self, X, y, partition):
        feat_count = X.shape[1]
        min_split_data = self.split_by_min_thresh(X, y, 0)

        for feat_index in range(1, feat_count):
            split_data = self.split_by_min_thresh(X, y, feat_index, partition)
            if min_split_data.cost > split_data.cost:
                min_split_data = split_data

        return min_split_data


In [5]:
from sklearn.datasets import load_diabetes
import pandas as pd

data = load_diabetes()
X,y = data['data'], data['target']

node = TreeNode(X,y)

ValueError: operands could not be broadcast together with shapes (58,) (2,) 