In [59]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes

def _masked_variance(targets, mask):
    """Per-column sample count and population variance of ``targets`` under ``mask``.

    ``targets`` is an (n, 1) column and ``mask`` a boolean (n, k) array whose
    column ``j`` marks the samples belonging to partition ``j``. Returns
    ``(counts, variances)``, each of shape (k,). Empty partitions get variance 0.
    """
    counts = mask.sum(axis=0)
    # Dividing by max(count, 1) leaves empty partitions at exactly 0 and avoids
    # 0/0 warnings, so no np.errstate / np.where guard is required.
    safe_counts = np.maximum(counts, 1)
    means = np.where(mask, targets, 0).sum(axis=0) / safe_counts
    # Membership must come from the boolean mask, NOT from "masked value == 0":
    # a sample whose target is exactly 0 still belongs to its partition and has
    # to contribute its squared deviation from the mean.
    sq_dev = np.where(mask, (targets - means.reshape(1, -1)) ** 2, 0)
    return counts, sq_dev.sum(axis=0) / safe_counts


def split_by_min_thresh(X, y, feat_index):
    """Find the threshold on one feature that minimises the weighted target variance.

    Every distinct value ``t`` of ``X[:, feat_index]`` is tried as a threshold:
    samples with feature value ``<= t`` go left, all others go right. The cost
    of a threshold is ``(n_left / n) * var_left + (n_right / n) * var_right``,
    using the population variance of ``y`` on each side (an empty side
    contributes 0). All thresholds are evaluated at once on (n, k) arrays,
    where k is the number of distinct feature values.

    Parameters
    ----------
    X : 2-D array, indexed as ``X[:, feat_index]``.
    y : array of targets with one entry per row of ``X``.
    feat_index : column of ``X`` to split on.

    Returns
    -------
    left_indices : 1-D array of row indices with feature value ``<= min_thresh``.
    right_indices : 1-D array of the remaining row indices.
    min_cost : lowest weighted variance over all candidate thresholds.
    min_thresh : threshold achieving ``min_cost`` (the smallest one on ties,
        since ``np.argmin`` returns the first minimum).

    Raises
    ------
    ValueError
        If ``X`` has no rows, or ``y`` does not have one target per row.
    """
    feat_values = X[:, feat_index].reshape(-1, 1)
    total_samples = feat_values.shape[0]
    if total_samples == 0:
        raise ValueError('cannot split an empty dataset')

    targets = np.asarray(y).reshape(-1, 1)
    if targets.shape[0] != total_samples:
        raise ValueError('wrong total samples')

    # Candidate thresholds: the sorted distinct values of the feature.
    thresh = np.unique(feat_values)
    # is_left[i, j] is True when sample i falls on the left of threshold j.
    is_left = feat_values <= thresh.reshape(1, -1)
    is_right = ~is_left

    n_samples_left, var_left = _masked_variance(targets, is_left)
    n_samples_right, var_right = _masked_variance(targets, is_right)

    ratio_left = n_samples_left / total_samples
    ratio_right = n_samples_right / total_samples
    thresh_cost = ratio_left * var_left + ratio_right * var_right

    min_thresh_index = np.argmin(thresh_cost)
    min_cost = thresh_cost[min_thresh_index]
    left_indices = np.nonzero(is_left[:, min_thresh_index])[0]
    right_indices = np.nonzero(is_right[:, min_thresh_index])[0]
    min_thresh = thresh[min_thresh_index]

    return left_indices, right_indices, min_cost, min_thresh

In [65]:
# Load scikit-learn's diabetes dataset: feature matrix and target vector.
data = load_diabetes()
X,y = data['data'], data['target']

feat_count = X.shape[1]
total_data = X.shape[0]

# Scan every feature and keep the split with the lowest cost.
# min_cost stays None until the first feature has been evaluated, so feature 0
# goes through the same code path (and the same sanity check) as all others.
min_cost = None
for feat_index in range(feat_count):
    left_indices, right_indices, cost, thresh = split_by_min_thresh(X, y, feat_index)
    # Every sample must land on exactly one side of the split.
    assert left_indices.shape[0] + right_indices.shape[0] == total_data, 'fail'
    # Strict comparison: on equal cost the lower feature index is kept.
    if min_cost is None or min_cost > cost:
        min_left_indices = left_indices
        min_right_indices = right_indices
        min_cost = cost
        min_thresh = thresh
        min_feat_index = feat_index

