In [45]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes

def _masked_variance(targets, mask, counts):
    """Per-column population variance of `targets` over the rows selected by `mask`.

    targets: (n_samples, 1) column of target values.
    mask:    (n_samples, n_thresh) boolean membership matrix.
    counts:  (n_thresh,) number of True entries in each mask column.

    Columns with no selected rows get a variance of 0.
    """
    # Dividing by max(count, 1) yields 0 for empty columns (their sums are 0)
    # without triggering divide-by-zero warnings.
    safe_counts = np.maximum(counts, 1)
    means = np.sum(np.where(mask, targets, 0), axis=0) / safe_counts
    # Mask on membership, NOT on the masked target value: a sample whose
    # target is exactly 0 still belongs to its partition and must contribute
    # its squared deviation from the partition mean.
    sq_dev = np.where(mask, (targets - means.reshape(1, -1)) ** 2, 0)
    return np.sum(sq_dev, axis=0) / safe_counts


def split_by_min_thresh(X, y, feat_index):
    """Find the threshold on one feature that minimises the weighted target variance.

    Every unique value of column `feat_index` of `X` is tried as a threshold.
    Samples with feature value <= threshold go left, the rest go right. The
    cost of a threshold is

        (n_left / n) * var(y_left) + (n_right / n) * var(y_right)

    using the population variance (division by the partition size). All
    thresholds are evaluated at once through an (n_samples, n_thresholds)
    boolean membership matrix.

    Parameters
    ----------
    X : 2-D numpy array, indexed as X[:, feat_index]
    y : 1-D array-like of targets, one per row of X
    feat_index : int, column of X to split on

    Returns
    -------
    left_indices : 1-D int array of row indices with feature <= min_thresh
    right_indices : 1-D int array of the remaining row indices
    min_cost : lowest weighted variance found
    min_thresh : threshold achieving it (first one on ties, per np.argmin)

    Raises
    ------
    ValueError : if X has no rows.

    Note: the largest unique value is also a candidate; it puts every sample
    on the left and leaves the right side empty (contributing zero variance).
    """
    selected_feat = X[:, feat_index].reshape(-1, 1)
    total_samples = selected_feat.shape[0]
    if total_samples == 0:
        raise ValueError('cannot split an empty dataset')

    thresh = np.unique(selected_feat).reshape(1, -1)
    targets = np.asarray(y).reshape(-1, 1)

    # Column j of each sampler marks membership for threshold j.
    is_left_sampler = selected_feat <= thresh
    is_right_sampler = ~is_left_sampler

    n_samples_left = np.sum(is_left_sampler, axis=0)
    n_samples_right = np.sum(is_right_sampler, axis=0)

    var_left = _masked_variance(targets, is_left_sampler, n_samples_left)
    var_right = _masked_variance(targets, is_right_sampler, n_samples_right)

    ratio_left = n_samples_left / total_samples
    ratio_right = n_samples_right / total_samples
    thresh_cost = ratio_left * var_left + ratio_right * var_right

    min_thresh_index = np.argmin(thresh_cost)
    min_cost = thresh_cost[min_thresh_index]
    left_indices = np.nonzero(is_left_sampler[:, min_thresh_index])[0]
    right_indices = np.nonzero(is_right_sampler[:, min_thresh_index])[0]
    min_thresh = thresh.reshape(-1)[min_thresh_index]

    return left_indices, right_indices, min_cost, min_thresh

In [46]:
# Load the diabetes regression dataset and pull out features and targets.
data = load_diabetes()
X = data['data']
y = data['target']

# Each column of X is a candidate feature to split on.
feat_count = X.shape[1]

# For every feature, find its lowest-cost threshold split and report the
# partition sizes, the cost, and the chosen threshold.
for feat_index in range(feat_count):
    left_indices, right_indices, min_cost, min_thresh = split_by_min_thresh(
        X, y, feat_index
    )
    print(f'[{feat_index}/{feat_count}] left_count = {left_indices.shape[0]}, right_count = {right_indices.shape[0]} min_cost = {min_cost}, min_thresh = {min_thresh}')

[0/10] left_count = 227, right_count = 215 min_cost = 5700.035157086954, min_thresh = 0.005383060374248237
[1/10] left_count = 235, right_count = 207 min_cost = 5918.888899586022, min_thresh = -0.044641636506989144
[2/10] left_count = 277, right_count = 165 min_cost = 4279.164764194539, min_thresh = 0.008883414898524095
[3/10] left_count = 307, right_count = 135 min_cost = 4919.231731635831, min_thresh = 0.0218723855140367
[4/10] left_count = 259, right_count = 183 min_cost = 5572.695496316517, min_thresh = 0.005310804470794357
[5/10] left_count = 294, right_count = 148 min_cost = 5658.358681622334, min_thresh = 0.017161881819363848
[6/10] left_count = 180, right_count = 262 min_cost = 5046.367625854413, min_thresh = -0.01762938102341632
[7/10] left_count = 173, right_count = 269 min_cost = 4866.073277556504, min_thresh = -0.014400620678474476
[8/10] left_count = 218, right_count = 224 min_cost = 4201.076466066314, min_thresh = -0.00422151393810765
[9/10] left_count = 348, right_count 