In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes
from pydantic import BaseModel
from typing import Optional

class SplitData(BaseModel):
     left_indices: np.array
     right_indices: np.array
     feat_index: np.int32
     cost: np.float32
     thresh: np.float32
     

class TreeNode:
    def __init__(self, 
                 X : np.array, 
                 y : np.array,
                 split_data
                 depth, max_depth, 
                 max_leaves):
        self.depth = depth
        
        self.max_depth = max_depth 
        self.max_leaves = self.max_leaves

        if self.max_depth > self.depth + 1:
            self.left_node = TreeNode()

    
    def split_by_min_thresh(self, X, y, feat_index):
        thresh = np.unique(X[:,feat_index]).reshape(1,-1)
        selected_feat = X[:,feat_index].reshape(-1,1)
        is_left_sampler = (selected_feat <= thresh)
        stacked_targets = np.hstack([y.reshape(-1,1)] * is_left_sampler.shape[1])
        left_sampled, right_sampled = is_left_sampler * stacked_targets, (~is_left_sampler) * stacked_targets

        n_samples_left, n_samples_right = np.sum(is_left_sampler, axis=0), np.sum(~is_left_sampler, axis=0)

        with np.errstate(divide='ignore', invalid='ignore'):
            left_samples_mean = np.where(n_samples_left == 0, 0, np.sum(left_sampled, axis=0) / n_samples_left)
            right_samples_mean = np.where(n_samples_right == 0, 0, np.sum(right_sampled, axis=0) / n_samples_right)
        
        var_left = is_left_sampler * (left_sampled - left_samples_mean.reshape(1,-1)) ** 2
        var_right = (~is_left_sampler) * (right_sampled - right_samples_mean.reshape(1,-1)) ** 2

        with np.errstate(divide='ignore', invalid='ignore'):
            var_left = np.where(n_samples_left == 0, 0, np.sum(var_left, axis=0) / n_samples_left)
            var_right = np.where(n_samples_right == 0, 0, np.sum(var_right, axis=0) / n_samples_right)

        total_samples = left_sampled.shape[0]
        assert total_samples == right_sampled.shape[0], 'wrong total samples'

        ratio_left, ratio_right = n_samples_left / total_samples, n_samples_right / total_samples
        thresh_cost = ratio_left * var_left + ratio_right * var_right
        min_thresh_index = np.argmin(thresh_cost)

        return SplitData(
            left_indices=np.nonzero(is_left_sampler[:, min_thresh_index])[0],
            right_indices=np.nonzero(~is_left_sampler[:, min_thresh_index])[0],
            feat_index=feat_index,
            cost=np.min(thresh_cost),
            thresh = thresh.reshape(-1)[min_thresh_index],
        ) 

    def split_by_min_feat_thresh(self, X, y):
        feat_count = X.shape[1]
        min_split_data = self.split_by_min_thresh(X, y, 0)

        for feat_index in range(1, feat_count):
            split_data = self.split_by_min_thresh(X, y, feat_index)
            if min_split_data.cost > split_data.cost:
                min_split_data = split_data

        return min_split_data


PydanticSchemaGenerationError: Unable to generate pydantic-core schema for <built-in function array>. Set `arbitrary_types_allowed=True` in the model_config to ignore this error or implement `__get_pydantic_core_schema__` on your type to fully support it.

If you got this error by calling handler(<some type>) within `__get_pydantic_core_schema__` then you likely need to call `handler.generate_schema(<some type>)` since we do not call `__get_pydantic_core_schema__` on `<some type>` otherwise to avoid infinite recursion.

For further information visit https://errors.pydantic.dev/2.6/u/schema-for-unknown-type