In [8]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes

def _column_variance(targets, membership, counts):
    """Population variance of ``targets`` within each boolean column of ``membership``.

    ``targets`` has shape (n_samples, 1) and broadcasts against ``membership``
    (n_samples, n_candidates); ``counts`` is ``membership.sum(axis=0)``.
    Columns with no members get a variance of 0 instead of NaN.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        means = np.where(counts == 0, 0, np.sum(membership * targets, axis=0) / counts)
        squared_dev = membership * (targets - means.reshape(1, -1)) ** 2
        return np.where(counts == 0, 0, np.sum(squared_dev, axis=0) / counts)


def split_by_min_thresh(X, y, feat_index, partition_mask=None):
    """Find the threshold on one feature that minimises the weighted target variance.

    Every distinct value ``t`` of ``X[:, feat_index]`` (among the samples in the
    partition) is tried as a candidate split: ``x <= t`` goes left, everything
    else goes right. A candidate's cost is
    ``n_left / n * var_left + n_right / n * var_right``, where ``n`` is the number
    of samples in the partition and the variances are population variances of
    ``y``. The cheapest candidate wins; the first one wins on ties.

    Parameters
    ----------
    X : np.ndarray, shape (n_samples, n_features)
    y : array-like, shape (n_samples,)
        Regression targets.
    feat_index : int
        Column of ``X`` to split on.
    partition_mask : boolean array-like, shape (n_samples,), optional
        Restricts the search to the samples where the mask is True, e.g. a mask
        returned by a previous call. ``None`` (default) uses all samples.

    Returns
    -------
    left_partition_mask, right_partition_mask : np.ndarray of bool, shape (n_samples,)
        Samples of the partition going left / right at the best threshold.
        Samples outside the partition are False in both.
    left_count, right_count : int
        Number of samples in each side.
    min_cost : float
        Weighted variance of the best split.
    min_thresh : float
        The best threshold value.

    Raises
    ------
    ValueError
        If ``partition_mask`` has the wrong length or selects no samples.
    """
    feature = X[:, feat_index]
    if partition_mask is None:
        in_partition = np.ones(feature.shape[0], dtype=bool)
    else:
        in_partition = np.asarray(partition_mask, dtype=bool).reshape(-1)
        if in_partition.shape[0] != feature.shape[0]:
            raise ValueError('partition_mask must have one entry per sample')

    thresholds = np.unique(feature[in_partition])
    if thresholds.size == 0:
        raise ValueError('partition_mask selects no samples')

    # Rows are samples, columns are candidate thresholds. The partition mask is
    # a column vector so it broadcasts along the sample axis.
    column = feature.reshape(-1, 1)
    member = in_partition.reshape(-1, 1)
    goes_left = column <= thresholds.reshape(1, -1)
    # ~goes_left (rather than column > t) keeps NaN feature values on the right.
    is_left_sampler = goes_left & member
    is_right_sampler = ~goes_left & member

    targets = np.asarray(y).reshape(-1, 1)
    n_samples_left = np.sum(is_left_sampler, axis=0)
    n_samples_right = np.sum(is_right_sampler, axis=0)
    # Weights are relative to the partition, not to all rows of X.
    total_samples = np.sum(in_partition)

    var_left = _column_variance(targets, is_left_sampler, n_samples_left)
    var_right = _column_variance(targets, is_right_sampler, n_samples_right)

    ratio_left = n_samples_left / total_samples
    ratio_right = n_samples_right / total_samples
    thresh_cost = ratio_left * var_left + ratio_right * var_right
    min_thresh_index = np.argmin(thresh_cost)

    # Copy the chosen columns so callers don't keep the whole candidate matrix alive.
    left_partition_mask = is_left_sampler[:, min_thresh_index].copy()
    right_partition_mask = is_right_sampler[:, min_thresh_index].copy()
    left_count = n_samples_left[min_thresh_index]
    right_count = n_samples_right[min_thresh_index]
    min_cost = thresh_cost[min_thresh_index]
    min_thresh = thresholds[min_thresh_index]

    return left_partition_mask, right_partition_mask, left_count, right_count, min_cost, min_thresh

def split_by_min_feat_thresh(X, y, partition_mask=None, verbose=True):
    """Search every feature of ``X`` for the single best variance-reducing split.

    Calls ``split_by_min_thresh`` once per column and keeps the feature whose best
    threshold has the lowest cost. The earliest feature wins ties.

    Parameters
    ----------
    X : np.ndarray, shape (n_samples, n_features)
    y : array-like, shape (n_samples,)
        Regression targets.
    partition_mask : boolean array-like, shape (n_samples,), optional
        Forwarded to ``split_by_min_thresh`` to restrict the search to a subset
        of samples. ``None`` (default) searches all samples.
    verbose : bool, default True
        Print the best cost and threshold found for each feature.

    Returns
    -------
    tuple
        ``(left_count, right_count, feat_index, cost, thresh)`` of the best split.

    Raises
    ------
    ValueError
        If ``X`` has no feature columns.
    """
    feat_count = X.shape[1]
    if feat_count == 0:
        raise ValueError('X has no features to split on')
    if partition_mask is None:
        data_count = X.shape[0]
    else:
        data_count = int(np.count_nonzero(partition_mask))

    best = None
    for feat_index in range(feat_count):
        # split_by_min_thresh returns 6 values; the two masks are not needed here.
        _, _, left_count, right_count, cost, thresh = split_by_min_thresh(
            X, y, feat_index, partition_mask=partition_mask)
        if verbose:
            print(f'[{feat_index}/{feat_count}] cost = {cost}, thresh = {thresh}')
        assert left_count + right_count == data_count, 'left_count + right_count != data_count'
        # Strict '<' keeps the earliest feature on ties.
        if best is None or cost < best[3]:
            best = (left_count, right_count, feat_index, cost, thresh)

    return best
    

In [9]:
data = load_diabetes()
X, y = data['data'], data['target']

# Best single split over all features:
# (left_count, right_count, feat_index, cost, thresh)
split_by_min_feat_thresh(X, y)



[1/10] cost = 5918.888899586022, thresh = -0.044641636506989144
[2/10] cost = 4279.164764194539, thresh = 0.008883414898524095
[3/10] cost = 4919.231731635831, thresh = 0.0218723855140367
[4/10] cost = 5572.695496316517, thresh = 0.005310804470794357
[5/10] cost = 5658.358681622334, thresh = 0.017161881819363848
[6/10] cost = 5046.367625854413, thresh = -0.01762938102341632
[7/10] cost = 4866.073277556504, thresh = -0.014400620678474476
[8/10] cost = 4201.076466066314, thresh = -0.00422151393810765
[9/10] cost = 5157.83877572983, thresh = 0.03205915781820968


(218, 224, 8, 4201.076466066314, -0.00422151393810765)