### Explore JAX Bug

In [2]:
from dataclasses import dataclass
from sklearn.datasets import load_diabetes
import jax.numpy as jnp

In [23]:
@dataclass
class SplitData:
    """
    A dataclass for representing the result of a data split at a Decision Tree node, including details necessary for propagating splits across the tree.

    Attributes:
        left_partition_mask (np.ndarray): A boolean mask indicating samples allocated to the left child node based on the split.
        right_partition_mask (np.ndarray): A boolean mask indicating samples allocated to the right child node based on the split.
        left_count (int): Number of samples in the left partition, facilitating efficient calculation of statistics or further splits.
        right_count (int): Number of samples in the right partition, facilitating efficient calculation of statistics or further splits.
        feat_index (int): Index of the feature used for splitting, enabling traceability of the decision path within the tree.
        cost (float): The cost or impurity reduction achieved by this split, used to assess the quality of the split.
        thresh (float): Threshold value for the feature at this split, determining the boundary between left and right partitions.
        
    This structure is crucial for the iterative construction and analysis of Decision Trees, where each node's decision to split data affects subsequent nodes' decisions.
    """
    left_partition_mask: jnp.ndarray
    right_partition_mask: jnp.ndarray
    left_count: int
    right_count: int 
    feat_index: int 
    cost: float 
    thresh: float

    def __str__(self):
        cost_rounded = float(self.cost)
        cost_rounded = round(cost_rounded, 3)

        thresh_rounded = float(self.thresh)
        thresh_rounded = round(thresh_rounded, 3)

        return ', '.join([
            f'cost={cost_rounded}', 
            f'feature_index={self.feat_index}', 
            f'left_count={self.left_count}',
            f'right_count={self.right_count}', 
            f'thresh={thresh_rounded}',
        ])
    

class TreeNode:

    def __init__(self, 
                    X : jnp.ndarray, 
                    y : jnp.ndarray,
                    max_depth=5, 
                    min_leaves=1,
                    depth_index=0,
                    width_index=0,
                    node_type='init_node',
                    feat_index=0,
                    partition=None,
                    thresh=None
                 ):
        
        self.depth_index = depth_index
        self.width_index = width_index

        self.max_depth = max_depth 
        self.min_leaves = min_leaves
        self.node_type = node_type

        if partition is None:
            self.partition = jnp.ones((y.shape[0],))
            self.num_samples = y.shape[0]
            self.thresh = jnp.max(X[:,0])
            self.feat_index = 0
        else:
            self.partition = partition
            self.num_samples = jnp.sum(partition)
            self.thresh = thresh
            self.feat_index = feat_index

        self.out_per_node = jnp.sum(y * self.partition) / self.num_samples
        split_data = self.split_by_min_feat_thresh(X, y, partition)
        
        next_depth = self.depth_index + 1
        is_not_max_depth = self.max_depth >= next_depth        
        left_is_not_max_leaves = min_leaves <= split_data.left_count
        right_is_not_max_leaves = min_leaves <= split_data.right_count
        
        grow_left = (is_not_max_depth and left_is_not_max_leaves) and split_data.cost != 0.0
        grow_right = (is_not_max_depth and right_is_not_max_leaves) and split_data.cost != 0.0

        self.left_node, self.right_node = None, None
        
        print("current_depth = ", self.depth_index)
        print('split_data:\n', split_data)

        if grow_left:
            self.left_node = TreeNode(
                                X, y, 
                                max_depth=max_depth, 
                                min_leaves=min_leaves,
                                depth_index=next_depth,
                                width_index=self.width_index,
                                node_type='left_node',
                                feat_index=split_data.feat_index,
                                partition=split_data.left_partition_mask,
                                thresh=split_data.thresh,
                            )

        if grow_right:
            self.right_node = TreeNode(
                                X,y,
                                max_depth=max_depth,
                                min_leaves=min_leaves,
                                depth_index=next_depth,
                                width_index=self.width_index + 1,
                                node_type='right_node',
                                feat_index=split_data.feat_index,
                                partition=split_data.right_partition_mask,
                                thresh=split_data.thresh
                            )
    
    def split_by_min_thresh(self, X, y, feat_index, partition_mask=None):
        thresh = jnp.unique(X[:,feat_index]).reshape(1,-1)
        selected_feat = X[:,feat_index].reshape(-1,1)
        is_left_sampler = (selected_feat <= thresh)
        is_right_sampler = ~is_left_sampler

        if partition_mask is not None:
            partition_mask = jnp.hstack(is_left_sampler.shape[1] * [ partition_mask.reshape(-1,1) ])
            is_left_sampler = is_left_sampler & partition_mask
            is_right_sampler = is_right_sampler & partition_mask

        stacked_targets = jnp.hstack([y.reshape(-1,1)] * is_left_sampler.shape[1])
        left_sampled, right_sampled = is_left_sampler * stacked_targets, is_right_sampler * stacked_targets

        n_samples_left, n_samples_right = jnp.sum(is_left_sampler, axis=0), jnp.sum(is_right_sampler, axis=0)

        # with np.errstate(divide='ignore', invalid='ignore'):
        left_samples_mean = jnp.where(n_samples_left == 0, 0, jnp.sum(left_sampled, axis=0) / n_samples_left)
        right_samples_mean = jnp.where(n_samples_right == 0, 0, jnp.sum(right_sampled, axis=0) / n_samples_right)
        
        var_left = is_left_sampler * (left_sampled - left_samples_mean.reshape(1,-1)) ** 2
        var_right = is_right_sampler * (right_sampled - right_samples_mean.reshape(1,-1)) ** 2

        # with np.errstate(divide='ignore', invalid='ignore'):
        var_left = jnp.where(n_samples_left == 0, 0, jnp.sum(var_left, axis=0) / n_samples_left)
        var_right = jnp.where(n_samples_right == 0, 0, jnp.sum(var_right, axis=0) / n_samples_right)

        total_samples = left_sampled.shape[0]
        assert total_samples == right_sampled.shape[0], 'wrong total samples'

        ratio_left, ratio_right = n_samples_left / total_samples, n_samples_right / total_samples
        thresh_cost = ratio_left * var_left + ratio_right * var_right
        min_thresh_index = jnp.argmin(thresh_cost)

        return SplitData(
            left_partition_mask=is_left_sampler[:, min_thresh_index],
            right_partition_mask=is_right_sampler[:, min_thresh_index],
            left_count=jnp.sum(is_left_sampler[:, min_thresh_index]),
            right_count=jnp.sum(is_right_sampler[:, min_thresh_index]),
            feat_index=feat_index,
            cost=jnp.min(thresh_cost),
            thresh=thresh.reshape(-1)[min_thresh_index] 
        )

    def split_by_min_feat_thresh(self, X, y, partition):
        feat_count = X.shape[1]
        min_split_data = self.split_by_min_thresh(X, y, 0)
        # print('min_split_data:\n', min_split_data)
        for feat_index in range(1, feat_count):
            split_data = self.split_by_min_thresh(X, y, feat_index, partition)
            # print('min_split_data:\n', min_split_data)
            if min_split_data.cost > split_data.cost:
                min_split_data = split_data

        return min_split_data
    
    def prediction_per_node(self, X, y_preds, prev_mask=None):
        if self.node_type == 'right_node':
            prediction_mask = X[:, self.feat_index] > self.thresh
        else:
            prediction_mask = X[:,self.feat_index] <= self.thresh
        
        if prev_mask is not None:
            prediction_mask = prediction_mask & prev_mask

        prediction_indices = jnp.nonzero(prediction_mask)[0]
        y_preds[prediction_indices] = self.out_per_node
        
        if self.left_node is not None:
            self.left_node.prediction_per_node(X, y_preds, prediction_mask)
        
        if self.right_node is not None:
            self.right_node.prediction_per_node(X, y_preds, prediction_mask)




In [24]:
class DecisionTreeRegressor:
    def __init__(
            self,
            max_depth=5, 
            min_leaves=1, 
        ):
        self.max_depth=max_depth
        self.min_leaves=min_leaves

    def fit(self, X, y):
        self.node = TreeNode(X, y, 
                             max_depth=self.max_depth, 
                             min_leaves=self.min_leaves, 
                             depth_index=0
                            )
        
    def predict(self, X):
        y_preds = jnp.zeros((X.shape[0],))
        self.node.prediction_per_node(X, y_preds)
        return y_preds

In [25]:

data = load_diabetes()
X,y = data['data'], data['target']
X,y = jnp.array(X), jnp.array(y)

max_depth = 5 
min_leaves = 1

model = DecisionTreeRegressor(max_depth=max_depth, min_leaves=min_leaves)
model.fit(X,y)

current_depth =  0
split_data:
 cost=4201.077, feature_index=8, left_count=218, right_count=224, thresh=-0.004
current_depth =  1
split_data:
 cost=1262.777, feature_index=2, left_count=171, right_count=47, thresh=0.006
current_depth =  2
split_data:
 cost=766.899, feature_index=6, left_count=87, right_count=84, thresh=0.019
current_depth =  3
split_data:
 cost=480.771, feature_index=4, left_count=85, right_count=2, thresh=0.049
current_depth =  4
split_data:
 cost=451.92, feature_index=1, left_count=39, right_count=46, thresh=-0.045
current_depth =  5
split_data:
 cost=225.88, feature_index=3, left_count=21, right_count=18, thresh=-0.037
current_depth =  5
split_data:
 cost=163.552, feature_index=9, left_count=18, right_count=28, thresh=-0.03
current_depth =  4
split_data:
 cost=0.0, feature_index=1, left_count=1, right_count=1, thresh=-0.045
current_depth =  3
split_data:
 cost=191.536, feature_index=1, left_count=64, right_count=20, thresh=-0.045
current_depth =  4
split_data:
 cost