In [1]:
import warnings
from collections import Counter

import cvxpy as cp
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from numba import njit
from numpy.random import default_rng
from scipy import optimize
from scipy.optimize import root

from nldg.utils import max_mse, min_xplvar, gen_data_v2, gen_data_v3

In [54]:
# Training data: features X, response Y (centred below) and environment label E.
dtr = gen_data_v3(n=1000, setting=2)
#dtr = gen_data_v2(n=1000)
Xtr = np.array(dtr.drop(columns=['E', 'Y']))
# Later cells flatten X and argsort it, which is only meaningful for a single feature column.
assert Xtr.shape[1] == 1, "the split search below expects exactly one feature column"
Ytr = np.array(dtr['Y'])
Ytr = Ytr - np.mean(Ytr)
dtr['Y'] = Ytr
Etr = np.array(dtr['E'])
Xtr_sorted = np.sort(Xtr, axis=0)
# Seeded bootstrap resample of row indices, same size as the data set.
all_idx = default_rng(0).choice(np.arange(0, Xtr.shape[0]), size=Ytr.shape[0], replace=True)

Y_sample = Ytr[all_idx].flatten()
X_sample = Xtr[all_idx].flatten()
E_sample = Etr[all_idx].flatten()
unique_envs = np.unique(E_sample)

# Baseline fit: one constant c minimising the worst per-environment MSE, in epigraph form
# (minimise t subject to mean((Y_e - c)^2) <= t for every environment e).
c_init = cp.Variable(1)
t = cp.Variable(nonneg=True)
constraints = []
for env in unique_envs:
    Y_e = Y_sample[E_sample == env]
    constraints.append(cp.mean(cp.square(Y_e - c_init)) <= t)
objective = cp.Minimize(t)
problem = cp.Problem(objective, constraints)
problem.solve()
# `c_init.value` has shape (1,); store a flat (n,) vector so `best_preds[idx]` is 1-D.
# An (n, 1) array breaks the per-environment broadcast used for the "remaining" loss.
best_preds = np.full(Y_sample.shape[0], c_init.value.item())
# Weight of the split-balance penalty: 10% of the baseline minimax loss.
alpha = 0.1 * objective.value

In [55]:
# Sufficient statistics of the current node for every candidate split along the feature.
indices = all_idx

# 1. Sort the node's samples by the splitting feature.
order = np.argsort(Xtr[indices].flatten())
sorted_indices = indices[order]
sorted_X = Xtr[sorted_indices]
sorted_Y = Ytr[sorted_indices]
sorted_E = Etr[sorted_indices]

# 2. Per-environment prefix statistics: column i covers sorted samples [0, i) (left child),
#    hence the leading zero column.
N = len(sorted_indices)
n_envs = len(unique_envs)
# `env_idx` compares labels with 0..n_envs-1, so the labels must be exactly those values.
assert np.array_equal(unique_envs, np.arange(n_envs)), "environment labels must be 0..n_envs-1"
env_idx = np.arange(n_envs)[:, None]
mask_mat = (sorted_E[None, :] == env_idx)

prefix_count_mat = np.concatenate(
    [np.zeros((n_envs, 1), int), np.cumsum(mask_mat, axis=1)], axis=1
)

Y_mat = sorted_Y[None, :] * mask_mat
prefix_sum_mat = np.concatenate(
    [np.zeros((n_envs, 1)), np.cumsum(Y_mat, axis=1)], axis=1
)
prefix_sq_mat = np.concatenate(
    [np.zeros((n_envs, 1)), np.cumsum(Y_mat * sorted_Y[None, :], axis=1)], axis=1
)

# 3. Suffix statistics: column j covers sorted samples [j, N) (right child); column N is the
#    empty suffix and stays zero. Reversed cumsums add terms in the same order as a backward
#    loop would, without an O(N * n_envs) Python loop.
suffix_count_mat = np.zeros((n_envs, N + 1))
suffix_sum_mat = np.zeros((n_envs, N + 1))
suffix_sq_mat = np.zeros((n_envs, N + 1))
suffix_count_mat[:, :N] = np.cumsum(mask_mat[:, ::-1], axis=1)[:, ::-1]
suffix_sum_mat[:, :N] = np.cumsum(Y_mat[:, ::-1], axis=1)[:, ::-1]
suffix_sq_mat[:, :N] = np.cumsum((Y_mat * sorted_Y[None, :])[:, ::-1], axis=1)[:, ::-1]
# Every split partitions the node: left count + right count is constant per environment.
assert np.all(prefix_count_mat + suffix_count_mat == prefix_count_mat[:, [N]])

# Per-environment sample count over the full sample `all_idx` (float, used as a divisor).
total_count = (Etr[all_idx][None, :] == env_idx).sum(axis=1).astype(float)

# 4. Squared error, per environment, of samples in `all_idx` but not in `indices`, predicted
#    by `best_preds` (empty when indices is all_idx).
diff = Counter(all_idx) - Counter(indices)
remaining = np.array(list(diff.elements()), dtype=np.intc)
rem_loss_vec = np.zeros(n_envs)
if remaining.size > 0:
    mask_rem = (Etr[remaining][None, :] == env_idx)
    Y_rem    = Ytr[remaining][None, :] * mask_rem
    # ravel so both a flat and an (n, 1) `best_preds` broadcast against mask_rem correctly
    P_rem    = np.ravel(best_preds)[remaining][None, :] * mask_rem
    errs_sq  = (Y_rem - P_rem) ** 2
    rem_loss_vec = errs_sq.sum(axis=1)

# 5. Build the CVXPY problem once; each candidate split only updates Parameter values.
cL, cR = cp.Variable(1), cp.Variable(1)  # constant prediction in the left / right child
t = cp.Variable()                        # epigraph variable bounding every environment's loss
max_diff = cp.Parameter(nonneg=True)     # split-balance penalty, set per split

# Per-environment statistics of the left (_L) and right (_R) child:
# S0 = sum of squared responses, S1 = sum of responses, n = sample count.
# R = squared error of the "remaining" samples (in all_idx but not in indices).
S0_L, S1_L, n_L = cp.Parameter(n_envs, nonneg=True), cp.Parameter(n_envs), cp.Parameter(n_envs, nonneg=True)
S0_R, S1_R, n_R = cp.Parameter(n_envs, nonneg=True), cp.Parameter(n_envs), cp.Parameter(n_envs, nonneg=True)
R = cp.Parameter(n_envs, nonneg=True)

def quad_term(c, S0, S1, n):
    """Sum of squared residuals of a constant prediction `c`, from sufficient statistics.

    With S0 = sum(y**2), S1 = sum(y) and n = count, sum((y - c)**2) = S0 - 2*c*S1 + n*c**2.
    Built from arithmetic operators only, so it accepts plain numbers as well as the CVXPY
    Variables/Parameters it is called with below.
    """
    cross = 2 * c * S1
    return S0 - cross + n * c**2

# Epigraph constraints: each environment's mean squared error (left child + right child +
# remaining samples, divided by its total count) must not exceed t.
loss_vec = (
    quad_term(cL, S0_L, S1_L, n_L)
    + quad_term(cR, S0_R, S1_R, n_R)
    + R
) / total_count
constraints = [loss_vec <= t]
# max_diff is a Parameter: it shifts the objective value but not the optimal (cL, cR).
objective = cp.Minimize(t + alpha * max_diff)
problem = cp.Problem(objective, constraints)

# 6. Best split seen so far: score, split index and (cL, cR, t) values.
best_score, best_split, best_vals = np.inf, None, None

#@njit
#def compute_mean_diff(prefix_count, prefix_sum, suffix_count, suffix_sum, i):
#    K = prefix_count.shape[0]
#    s = 0.0
#    c = 0
#    for e in range(K):
#        nL = prefix_count[e, i]
#        nR = suffix_count[e, i]
#        if nL > 0 and nR > 0:
#            mL = prefix_sum[e, i] / nL
#            mR = suffix_sum[e, i] / nR
#            s += abs(mR - mL)
#            c += 1
#    return s / c if c > 0 else 0.0

@njit
def compute_max_prop_balance(prefix_count, suffix_count, total_count, i):
    """Imbalance of a split measured on per-environment sample shares.

    For each environment with a positive total, take the share of its samples that fall
    left (prefix_count[:, i]) and right (suffix_count[:, i]) of split i. Return the absolute
    difference between the largest left share and the largest right share. Environments
    with total <= 0 are ignored; the result is 0.0 if none qualify.
    """
    best_left = 0.0
    best_right = 0.0
    for g in range(prefix_count.shape[0]):
        denom = total_count[g]
        if not denom > 0:
            continue
        best_left = max(best_left, prefix_count[g, i] / denom)
        best_right = max(best_right, suffix_count[g, i] / denom)
    return abs(best_left - best_right)

# Solve the CVXPY problem at each candidate split and keep the lowest objective value.
# NOTE(review): only split 371 is evaluated (debugging); the full search is `range(1, N)`.
#   Later cells read n_L.value, S0_L.value, ... i.e. the parameters of the *last* solved
#   split, so widening the range would also change what those cells inspect.
for i in [371]:
    max_diff.value = compute_max_prop_balance(prefix_count_mat, suffix_count_mat, total_count, i)

    # Left child = sorted samples [0, i), right child = sorted samples [i, N).
    for _param, _stats in ((S0_L, prefix_sq_mat), (S1_L, prefix_sum_mat), (n_L, prefix_count_mat),
                           (S0_R, suffix_sq_mat), (S1_R, suffix_sum_mat), (n_R, suffix_count_mat)):
        _param.value = _stats[:, i]
    R.value = rem_loss_vec

    problem.solve(warm_start=True)

    if not problem.value < best_score:
        continue
    best_score = problem.value
    best_split = i
    best_vals = (cL.value, cR.value, t.value)

In [56]:
# Left-child constant from the most recent `problem.solve` in the split loop above.
cL.value

array([0.67837939])

In [57]:
# Right-child constant from the most recent `problem.solve` in the split loop above.
cR.value

array([1.06966161])

In [58]:
# Index (in feature-sorted order) of the best split found by the CVXPY loop.
best_split

371

In [59]:
# Per-environment statistics (environments 0, 1, 2) read back from the CVXPY Parameters,
# i.e. for the split that was solved last.
n1, n2, n3 = (np.sum(E_sample == env) for env in range(3))
n1_L, n2_L, n3_L = (n_L.value[k] for k in range(3))
n1_R, n2_R, n3_R = (n_R.value[k] for k in range(3))
K1, K2, K3 = (rem_loss_vec[k] for k in range(3))
# Total squared response per environment (left child + right child).
sum_sq_1, sum_sq_2, sum_sq_3 = (S0_L.value[k] + S0_R.value[k] for k in range(3))
# Per-environment child means.
mu1_L, mu2_L, mu3_L = (S1_L.value[k] / cnt for k, cnt in enumerate((n1_L, n2_L, n3_L)))
mu1_R, mu2_R, mu3_R = (S1_R.value[k] / cnt for k, cnt in enumerate((n1_R, n2_R, n3_R)))

In [60]:
# Epigraph variable t (upper bound on every environment's mean loss) from the last CVXPY solve.
t.value

array(23.86100775)

In [61]:
# Environment 0's mean loss at the CVXPY constants; equals t.value when its constraint is active.
1 / n1 * (sum_sq_1 + n1_L * cL.value ** 2 - 2 * n1_L * cL.value * mu1_L + n1_R * cR.value ** 2 - 2 * n1_R * cR.value * mu1_R + K1)

array([23.86100775])

In [62]:
# Environment 1's mean loss at the CVXPY constants; equals t.value when its constraint is active.
1 / n2 * (sum_sq_2 + n2_L * cL.value ** 2 - 2 * n2_L * cL.value * mu2_L + n2_R * cR.value ** 2 - 2 * n2_R * cR.value * mu2_R + K2)

array([23.86100775])

In [63]:
# Environment 2's mean loss at the CVXPY constants. In the recorded output it is below t.value,
# i.e. this environment's constraint is slack at the optimum.
1 / n3 * (sum_sq_3 + n3_L * cL.value ** 2 - 2 * n3_L * cL.value * mu3_L + n3_R * cR.value ** 2 - 2 * n3_R * cR.value * mu3_R + K3)

array([14.39344718])

In [64]:
def triple_binding(lambdas):
    """Residuals of the "all three environments have equal loss" conditions.

    `lambdas` = (li, lj); the third weight is 1 - li - lj. The child constants cL, cR are the
    count-weighted combinations of the per-environment child means, built from the notebook's
    global statistics (n*_L, mu*_L, n*_R, mu*_R). Returns [f_1 - f_2, f_2 - f_3], where f_e is
    environment e's mean loss at (cL, cR); a root makes all three losses equal.
    """
    w1, w2 = lambdas
    w3 = 1 - w1 - w2

    c_left = (w1*n1_L*mu1_L + w2*n2_L*mu2_L + w3*n3_L*mu3_L) / (w1*n1_L + w2*n2_L + w3*n3_L)
    c_right = (w1*n1_R*mu1_R + w2*n2_R*mu2_R + w3*n3_R*mu3_R) / (w1*n1_R + w2*n2_R + w3*n3_R)

    losses = [
        1 / n1 * (sum_sq_1 + n1_L * c_left ** 2 - 2 * n1_L * c_left * mu1_L + n1_R * c_right ** 2 - 2 * n1_R * c_right * mu1_R + K1),
        1 / n2 * (sum_sq_2 + n2_L * c_left ** 2 - 2 * n2_L * c_left * mu2_L + n2_R * c_right ** 2 - 2 * n2_R * c_right * mu2_R + K2),
        1 / n3 * (sum_sq_3 + n3_L * c_left ** 2 - 2 * n3_L * c_left * mu3_L + n3_R * c_right ** 2 - 2 * n3_R * c_right * mu3_R + K3),
    ]
    return [losses[0] - losses[1], losses[1] - losses[2]]

# Solve triple_binding(lambdas) = 0, i.e. find the weights at which all three environment
# losses are equal, starting from equal weights.
init = np.array([1/3, 1/3])

sol = root(triple_binding, init, method='hybr')
if not sol.success:
    # root() still returns its last iterate in sol.x; warn instead of using it silently below.
    warnings.warn(f"root(triple_binding) did not converge: {sol.message}")
li, lj = sol.x
lk     = 1 - li - lj

In [65]:
# Child constants implied by the equal-loss weights (li, lj, lk) solved above.
# NOTE(review): this rebinds the globals cL / cR, which earlier cells use for the CVXPY
# Variables; after this cell, `cL.value` / `cR.value` no longer work. TODO: use distinct names.
cL = (li*n1_L*mu1_L + lj*n2_L*mu2_L + lk*n3_L*mu3_L)/(li*n1_L + lj*n2_L + lk*n3_L)
cR = (li*n1_R*mu1_R + lj*n2_R*mu2_R + lk*n3_R*mu3_R)/(li*n1_R + lj*n2_R + lk*n3_R)

In [66]:
# Environment 0's mean loss at the equal-loss constants (float cL, cR from the cell above).
1 / n1 * (sum_sq_1 + n1_L * cL ** 2 - 2 * n1_L * cL * mu1_L + n1_R * cR ** 2 - 2 * n1_R * cR * mu1_R + K1)

np.float64(27.456302043683404)

In [67]:
# Environment 1's mean loss at the equal-loss constants (float cL, cR from the cell above).
1 / n2 * (sum_sq_2 + n2_L * cL ** 2 - 2 * n2_L * cL * mu2_L + n2_R * cR ** 2 - 2 * n2_R * cR * mu2_R + K2)

np.float64(27.456302043791595)

In [68]:
# Environment 2's mean loss at the equal-loss constants. In the recorded outputs all three losses
# agree but exceed the CVXPY t.value: forcing all three to be equal is worse than the minimax
# optimum when one environment's constraint is slack there.
1 / n3 * (sum_sq_3 + n3_L * cL ** 2 - 2 * n3_L * cL * mu3_L + n3_R * cR ** 2 - 2 * n3_R * cR * mu3_R + K3)

np.float64(27.456302043919763)

In [76]:
# Sufficient statistics of the current node for every candidate split along the feature.
indices = all_idx

# 1. Sort the node's samples by the splitting feature.
order = np.argsort(Xtr[indices].flatten())
sorted_indices = indices[order]
sorted_X = Xtr[sorted_indices]
sorted_Y = Ytr[sorted_indices]
sorted_E = Etr[sorted_indices]

# 2. Per-environment prefix statistics: column i covers sorted samples [0, i) (left child),
#    hence the leading zero column.
N = len(sorted_indices)
n_envs = len(unique_envs)
# `env_idx` compares labels with 0..n_envs-1, so the labels must be exactly those values.
assert np.array_equal(unique_envs, np.arange(n_envs)), "environment labels must be 0..n_envs-1"
env_idx = np.arange(n_envs)[:, None]
mask_mat = (sorted_E[None, :] == env_idx)

prefix_count_mat = np.concatenate(
    [np.zeros((n_envs, 1), int), np.cumsum(mask_mat, axis=1)], axis=1
)

Y_mat = sorted_Y[None, :] * mask_mat
prefix_sum_mat = np.concatenate(
    [np.zeros((n_envs, 1)), np.cumsum(Y_mat, axis=1)], axis=1
)
prefix_sq_mat = np.concatenate(
    [np.zeros((n_envs, 1)), np.cumsum(Y_mat * sorted_Y[None, :], axis=1)], axis=1
)

# 3. Suffix statistics: column j covers sorted samples [j, N) (right child); column N is the
#    empty suffix and stays zero. Reversed cumsums add terms in the same order as a backward
#    loop would, without an O(N * n_envs) Python loop.
suffix_count_mat = np.zeros((n_envs, N + 1))
suffix_sum_mat = np.zeros((n_envs, N + 1))
suffix_sq_mat = np.zeros((n_envs, N + 1))
suffix_count_mat[:, :N] = np.cumsum(mask_mat[:, ::-1], axis=1)[:, ::-1]
suffix_sum_mat[:, :N] = np.cumsum(Y_mat[:, ::-1], axis=1)[:, ::-1]
suffix_sq_mat[:, :N] = np.cumsum((Y_mat * sorted_Y[None, :])[:, ::-1], axis=1)[:, ::-1]
# Every split partitions the node: left count + right count is constant per environment.
assert np.all(prefix_count_mat + suffix_count_mat == prefix_count_mat[:, [N]])

# Per-environment sample count over the full sample `all_idx` (float, used as a divisor).
total_count = (Etr[all_idx][None, :] == env_idx).sum(axis=1).astype(float)

# 4. Squared error, per environment, of samples in `all_idx` but not in `indices`, predicted
#    by `best_preds` (empty when indices is all_idx).
diff = Counter(all_idx) - Counter(indices)
remaining = np.array(list(diff.elements()), dtype=np.intc)
rem_loss_vec = np.zeros(n_envs)
if remaining.size > 0:
    mask_rem = (Etr[remaining][None, :] == env_idx)
    Y_rem    = Ytr[remaining][None, :] * mask_rem
    # ravel so both a flat and an (n, 1) `best_preds` broadcast against mask_rem correctly
    P_rem    = np.ravel(best_preds)[remaining][None, :] * mask_rem
    errs_sq  = (Y_rem - P_rem) ** 2
    rem_loss_vec = errs_sq.sum(axis=1)

# 5. Best split seen so far.
best_score = np.inf
best_split = None
best_vals  = None

@njit
def compute_max_prop_balance(prefix_count, suffix_count, total_count, i):
    """Absolute difference between the largest left-share and largest right-share at split i.

    Each environment with a positive total contributes its left share
    prefix_count[e, i] / total and right share suffix_count[e, i] / total; environments
    with total <= 0 are skipped, giving 0.0 if none qualify.
    NOTE(review): duplicate of the identically named function in an earlier cell; keep the
    two in sync or define it once.
    """
    largest_left = 0.0
    largest_right = 0.0
    for env in range(prefix_count.shape[0]):
        n_env = total_count[env]
        if n_env > 0:
            share_left = prefix_count[env, i] / n_env
            share_right = suffix_count[env, i] / n_env
            largest_left = max(largest_left, share_left)
            largest_right = max(largest_right, share_right)
    return abs(largest_left - largest_right)

def best_lam_pairs(lam):
    """Loss gap f_i - f_j between environments ei and ej at mixing weight `lam`.

    The child constants are (lam, 1 - lam)-weighted combinations of the two environments'
    child means; a child whose weighted count is not positive falls back to 0. The
    per-environment statistics (n_ei, n_ei_L, mu_ei_L, ..., K_ej) are read from global names
    that the split-search loop binds before calling this function. A sign change on [0, 1]
    therefore brackets the weight at which both environments have equal loss.
    """
    denL = lam * n_ei_L + (1 - lam) * n_ej_L
    cL = (lam * n_ei_L * mu_ei_L + (1 - lam) * n_ej_L * mu_ej_L) / denL if denL > 0 else 0

    denR = lam * n_ei_R + (1 - lam) * n_ej_R
    # `> 0` rather than truthiness: guard the right child exactly like the left child (and like
    # the caller's own cR computation), so a non-positive denominator never divides.
    cR = (lam * n_ei_R * mu_ei_R + (1 - lam) * n_ej_R * mu_ej_R) / denR if denR > 0 else 0

    fi = 1/n_ei * (sum_sq_ei + n_ei_L * cL ** 2 - 2 * n_ei_L * cL * mu_ei_L +
                   n_ei_R * cR ** 2 - 2 * n_ei_R * cR * mu_ei_R + K_ei)

    fj = 1/n_ej * (sum_sq_ej + n_ej_L * cL ** 2 - 2 * n_ej_L * cL * mu_ej_L +
                   n_ej_R * cR ** 2 - 2 * n_ej_R * cR * mu_ej_R + K_ej)

    return fi - fj

def best_lam_triplets(lam):
    """Residuals [f_i - f_j, f_j - f_k] for environments ei, ej, ek at weights `lam`.

    `lam` = (li, lj); the third weight is 1 - li - lj. Child constants are the weighted
    combinations of the three environments' child means, falling back to 0 when the weighted
    count is not positive. Statistics are read from the global names (n_ei ... K_ek) that
    the split-search loop binds before calling this function.
    """
    w_i, w_j = lam
    w_k = 1 - w_i - w_j

    den_left = w_i * n_ei_L + w_j * n_ej_L + w_k * n_ek_L
    c_left = 0
    if den_left > 0:
        c_left = (w_i * n_ei_L * mu_ei_L + w_j * n_ej_L * mu_ej_L + w_k * n_ek_L * mu_ek_L) / den_left

    den_right = w_i * n_ei_R + w_j * n_ej_R + w_k * n_ek_R
    c_right = 0
    if den_right > 0:
        c_right = (w_i * n_ei_R * mu_ei_R + w_j * n_ej_R * mu_ej_R + w_k * n_ek_R * mu_ek_R) / den_right

    def loss(n, n_l, n_r, mu_l, mu_r, sum_sq, k_rem):
        # Mean loss of one environment at (c_left, c_right).
        return 1/n * (sum_sq + n_l * c_left ** 2 - 2 * n_l * c_left * mu_l +
                      n_r * c_right ** 2 - 2 * n_r * c_right * mu_r + k_rem)

    f_i = loss(n_ei, n_ei_L, n_ei_R, mu_ei_L, mu_ei_R, sum_sq_ei, K_ei)
    f_j = loss(n_ej, n_ej_L, n_ej_R, mu_ej_L, mu_ej_R, sum_sq_ej, K_ej)
    f_k = loss(n_ek, n_ek_L, n_ek_R, mu_ek_L, mu_ek_R, sum_sq_ek, K_ek)
    return [f_i - f_j, f_j - f_k]

# Exhaustive split search without CVXPY: at each split i, the minimax child constants are found
# by enumerating which environments are active (one, two or three with equal, maximal loss).

def _env_loss(stats, c_left, c_right):
    """Mean loss of one environment for child constants (c_left, c_right).

    `stats` is an `env_stats` entry (n, n_L, n_R, mu_L, mu_R, sum_sq, K).
    """
    n, n_l, n_r, mu_l, mu_r, sum_sq, k_rem = stats
    return 1/n * (sum_sq + n_l * c_left ** 2 - 2 * n_l * c_left * mu_l +
                  n_r * c_right ** 2 - 2 * n_r * c_right * mu_r + k_rem)


# Starting weights for the three-environment root solve (loop-invariant).
init = np.array([1/3, 1/3])

for i in range(1, N):
    # Named `prop_balance` (not `max_diff`) so this cell does not overwrite the CVXPY
    # Parameter `max_diff` used by the earlier CVXPY cell.
    prop_balance = compute_max_prop_balance(
        prefix_count_mat, suffix_count_mat, total_count, i
    )

    best_t_it = np.inf
    best_vals_it = None

    # Per-environment statistics for the left child [0, i) and right child [i, N).
    env_stats = []
    for e in range(n_envs):
        n_e = total_count[e]
        n_e_L = prefix_count_mat[e, i]
        n_e_R = suffix_count_mat[e, i]
        mu_e_L = prefix_sum_mat[e, i] / n_e_L if n_e_L > 0 else 0.0
        mu_e_R = suffix_sum_mat[e, i] / n_e_R if n_e_R > 0 else 0.0
        sum_sq_e = prefix_sq_mat[e, i] + suffix_sq_mat[e, i]
        env_stats.append((n_e, n_e_L, n_e_R, mu_e_L, mu_e_R, sum_sq_e, rem_loss_vec[e]))

    # Singletons: environment ei alone is active; its own child means minimise its loss.
    # NOTE(review): if ei has no samples on one side, that side's constant is arbitrary for ei
    # but fixed to 0.0 here, which may reject a valid candidate.
    min_t = np.inf
    best_vals_singletons = None
    for ei in range(n_envs):
        c_left, c_right = env_stats[ei][3], env_stats[ei][4]
        fi = _env_loss(env_stats[ei], c_left, c_right)
        valid = not any(
            _env_loss(env_stats[ek], c_left, c_right) > fi
            for ek in range(n_envs) if ek != ei
        )
        if valid and fi < min_t:
            min_t = fi
            best_vals_singletons = [c_left, c_right]
    if min_t < best_t_it:
        best_t_it = min_t
        best_vals_it = best_vals_singletons

    # Pairs: ei and ej active with equal loss; brentq finds the equalising weight on [0, 1].
    if n_envs >= 2:
        min_t = np.inf
        best_vals_pairs = None
        for ei in range(n_envs):
            # best_lam_pairs reads these global names, so they must be bound here.
            n_ei, n_ei_L, n_ei_R, mu_ei_L, mu_ei_R, sum_sq_ei, K_ei = env_stats[ei]
            for ej in range(ei + 1, n_envs):
                n_ej, n_ej_L, n_ej_R, mu_ej_L, mu_ej_R, sum_sq_ej, K_ej = env_stats[ej]

                fa = best_lam_pairs(0)
                fb = best_lam_pairs(1)
                if fa * fb < 0:
                    l = optimize.root_scalar(best_lam_pairs, bracket=[0, 1], method='brentq').root

                    denL = l * n_ei_L + (1 - l) * n_ej_L
                    c_left = (l * n_ei_L * mu_ei_L + (1 - l) * n_ej_L * mu_ej_L) / denL if denL > 0 else 0

                    denR = l * n_ei_R + (1 - l) * n_ej_R
                    c_right = (l * n_ei_R * mu_ei_R + (1 - l) * n_ej_R * mu_ej_R) / denR if denR > 0 else 0

                    fi = _env_loss(env_stats[ei], c_left, c_right)
                    valid = not any(
                        _env_loss(env_stats[ek], c_left, c_right) > fi
                        for ek in range(n_envs) if ek != ei and ek != ej
                    )
                    if valid and fi < min_t:
                        min_t = fi
                        best_vals_pairs = [c_left, c_right]
        if min_t < best_t_it:
            best_t_it = min_t
            best_vals_it = best_vals_pairs

    # Triplets: ei, ej, ek active with equal loss; solve for the two free weights.
    if n_envs >= 3:
        min_t = np.inf
        best_vals_triplets = None
        for ei in range(n_envs):
            n_ei, n_ei_L, n_ei_R, mu_ei_L, mu_ei_R, sum_sq_ei, K_ei = env_stats[ei]
            for ej in range(ei + 1, n_envs):
                n_ej, n_ej_L, n_ej_R, mu_ej_L, mu_ej_R, sum_sq_ej, K_ej = env_stats[ej]
                for ek in range(ej + 1, n_envs):
                    # best_lam_triplets reads the ei/ej/ek global names bound here.
                    n_ek, n_ek_L, n_ek_R, mu_ek_L, mu_ek_R, sum_sq_ek, K_ek = env_stats[ek]

                    sol = root(best_lam_triplets, init, method='lm')
                    if not sol.success:
                        continue
                    li, lj = sol.x
                    lk = 1 - li - lj
                    # NOTE(review): the weights are not constrained to be non-negative, so the
                    # root may not be a minimax KKT point; its max loss is still achievable.

                    denL = li * n_ei_L + lj * n_ej_L + lk * n_ek_L
                    c_left = (li * n_ei_L * mu_ei_L + lj * n_ej_L * mu_ej_L + lk * n_ek_L * mu_ek_L) / denL if denL > 0 else 0

                    denR = li * n_ei_R + lj * n_ej_R + lk * n_ek_R
                    c_right = (li * n_ei_R * mu_ei_R + lj * n_ej_R * mu_ej_R + lk * n_ek_R * mu_ek_R) / denR if denR > 0 else 0

                    max_f = max(_env_loss(env_stats[e], c_left, c_right) for e in (ei, ej, ek))

                    valid = not any(
                        _env_loss(env_stats[eh], c_left, c_right) > max_f
                        for eh in range(n_envs) if eh not in (ei, ej, ek)
                    )
                    if valid and max_f < min_t:
                        min_t = max_f
                        best_vals_triplets = [c_left, c_right]
        if min_t < best_t_it:
            best_t_it = min_t
            best_vals_it = best_vals_triplets

    score = best_t_it + alpha * prop_balance

    if score < best_score:
        best_score = score
        best_split = i
        best_vals  = best_vals_it

In [77]:
# Best penalised score (best max-environment loss + alpha * balance penalty) over all splits.
best_score

np.float64(23.865290147840547)

In [71]:
# Child constants [cL, cR] at the best split.
# NOTE(review): the recorded output is from execution In[71], before the split-search cell's
# In[76] run — it may be stale; re-run to refresh.
best_vals

[np.float64(0.6783655413461), np.float64(1.0696434267365649)]

In [72]:
# Index (in feature-sorted order) of the best split.
# NOTE(review): the recorded output is from execution In[72], before the split-search cell's
# In[76] run — it may be stale; re-run to refresh.
best_split

371