Find the Best Gini-Based Split for a Binary Decision Tree
Medium
Machine Learning

Implement a function that scans every feature and threshold in a small data set and returns the split that minimises the weighted Gini impurity. Your implementation should support binary class labels (0 or 1) and break ties deterministically.
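For reference, the quantities involved can be written as follows for binary labels, where $p_S$ is the fraction of class-1 samples in a node $S$. The left/right convention ($X_{ij} \le t$ goes left) is an assumption taken from the solution code in this notebook.

```latex
G(S) = 1 - \left(p_S^2 + (1 - p_S)^2\right) = 2\,p_S\,(1 - p_S),
\qquad p_S = \frac{1}{|S|}\sum_{i \in S} y_i

G_{\text{split}}(j, t) = \frac{|L|}{n}\,G(L) + \frac{|R|}{n}\,G(R),
\qquad L = \{\, i : X_{ij} \le t \,\},\quad R = \{\, i : X_{ij} > t \,\}
```

A pure node ($p_S = 0$ or $p_S = 1$) has $G(S) = 0$; a 50/50 node has the binary maximum $G(S) = 0.5$.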

You will write one function:

find_best_split(X: np.ndarray, y: np.ndarray) -> tuple[int, float]
X is an $n \times d$ NumPy array of numeric features.
y is a length-$n$ NumPy array of 0/1 labels.
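The function returns (best_feature_index, best_threshold) for the split with the lowest weighted Gini impurity. A split on feature j at threshold t sends every sample with X[i, j] <= t to the left child and every other sample to the right.
If several splits share the same impurity, return the first one you encounter when scanning features in increasing index order and, within each feature, candidate thresholds in increasing order.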
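The scan order that the tie-breaking rule refers to can be made explicit with a small generator. `candidate_splits` is a hypothetical helper introduced only for illustration; it assumes, as the solution code does, that candidate thresholds are the distinct observed values of each feature in ascending order, and that the largest value is skipped because it would leave the right child empty.

```python
import numpy as np
from typing import Iterator, Tuple


def candidate_splits(X: np.ndarray) -> Iterator[Tuple[int, float]]:
    """Yield (feature_index, threshold) pairs in scan order (hypothetical helper).

    Features are visited by increasing index; within a feature, thresholds are
    the sorted distinct values, minus the largest one (x <= max(x) would leave
    the right child empty).
    """
    for j in range(X.shape[1]):
        for t in np.unique(X[:, j])[:-1]:  # np.unique returns sorted values
            yield j, float(t)


X = np.array([[3.0, 1.0],
              [1.0, 1.0],
              [2.0, 5.0]])
print(list(candidate_splits(X)))  # [(0, 1.0), (0, 2.0), (1, 1.0)]
```

Under this order, a tie between two features is resolved in favour of the lower feature index, and a tie within one feature in favour of the smaller threshold.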
Example:
Input:
import numpy as np
X = np.array([[2.5],[3.5],[1.0],[4.0]])
y = np.array([0,1,0,1])
print(find_best_split(X, y))
Output:
(0, 2.5)
Reasoning:
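Splitting on feature 0 at threshold 2.5 sends the samples with values 1.0 and 2.5 (both labelled 0) to the left child and those with values 3.5 and 4.0 (both labelled 1) to the right. Both leaves are pure, so the weighted Gini impurity is 0, the minimum possible.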
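To see why the other thresholds lose, the short check below computes the weighted Gini impurity of every non-degenerate candidate threshold in this example. It assumes the `x <= t` goes-left convention, and `weighted_gini` is a local helper introduced here for illustration.

```python
import numpy as np


def weighted_gini(x: np.ndarray, y: np.ndarray, t: float) -> float:
    """Weighted Gini impurity of the split x <= t vs. x > t for 0/1 labels."""
    def gini(s: np.ndarray) -> float:
        if s.size == 0:
            return 0.0
        p = s.mean()  # fraction of class 1
        return 2.0 * p * (1.0 - p)

    left, right = y[x <= t], y[x > t]
    n = y.size
    return left.size / n * gini(left) + right.size / n * gini(right)


x = np.array([2.5, 3.5, 1.0, 4.0])
y = np.array([0, 1, 0, 1])
for t in np.unique(x)[:-1]:  # t = 4.0 is skipped: it would leave the right side empty
    print(f"t = {float(t)}: weighted Gini = {weighted_gini(x, y, t):.4f}")
# t = 1.0: weighted Gini = 0.3333
# t = 2.5: weighted Gini = 0.0000
# t = 3.5: weighted Gini = 0.3333
```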

import numpy as np
from typing import Tuple

def _gini(y: np.ndarray) -> float:
    """Gini impurity of a set of 0/1 labels; an empty set counts as pure."""
    if y.size == 0:
        return 0.0
    p = y.mean()  # proportion of class 1
    return 2.0 * p * (1.0 - p)  # equals 1 - (p^2 + (1-p)^2) for two classes

def find_best_split(X: np.ndarray, y: np.ndarray) -> Tuple[int, float]:
    """Return (feature_index, threshold) minimizing weighted Gini impurity.
    Split rule: left = x_j <= threshold, right = x_j > threshold.
    Ties are broken by first encountered (feature, threshold) during scan.
    """
    n, d = X.shape
    best_imp = float("inf")
    best_feat = -1
    best_thr = np.nan

    for j in range(d):
        vals = X[:, j]
        # candidate thresholds: observed values (sorted unique)
        thrs = np.unique(vals)
        for t in thrs:
            left = y[vals <= t]
            right = y[vals > t]
            if left.size == 0 or right.size == 0:
                continue  # skip degenerate split
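            imp = (left.size / n) * _gini(left) + (right.size / n) * _gini(right)
            # Require a strict improvement (with a small tolerance for float
            # round-off) so that the first-encountered split wins on ties.
            if imp < best_imp - 1e-12:
                best_imp = imp
                best_feat = j
                best_thr = float(t)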

    # fallback if no valid split found (e.g., all identical values)
    if best_feat == -1:
        best_feat = 0
        best_thr = float(X[0, 0])

    return best_feat, best_thr

# Example
if __name__ == "__main__":
    X = np.array([[2.5],[3.5],[1.0],[4.0]])
    y = np.array([0,1,0,1])
    print(find_best_split(X, y))  # (0, 2.5)
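The solution above re-masks `y` for every candidate threshold, which costs O(n) per threshold and O(d·n²) overall. A common alternative, sketched here as a hypothetical `find_best_split_sorted` (not part of the required interface), sorts each feature once and reads class counts off a cumulative sum, for O(d·n log n). It assumes the same `x <= t` rule, observed-value thresholds, skipped degenerate splits, first-encountered tie-breaking and fallback as the solution; the 1e-12 tie tolerance is an assumption of this sketch.

```python
import numpy as np
from typing import Tuple


def find_best_split_sorted(X: np.ndarray, y: np.ndarray) -> Tuple[int, float]:
    """Hypothetical sort-and-sweep variant of find_best_split (binary 0/1 labels)."""
    n, d = X.shape
    y = np.asarray(y, dtype=float)
    total_ones = y.sum()
    best_imp, best_feat, best_thr = np.inf, 0, float(X[0, 0])  # fallback: constant data
    for j in range(d):
        order = np.argsort(X[:, j], kind="stable")
        xs, ys = X[order, j], y[order]
        # A valid threshold is xs[i] with xs[i] < xs[i + 1]: the last copy of each
        # distinct value, excluding the maximum (which would empty the right side).
        idx = np.nonzero(xs[:-1] < xs[1:])[0]
        if idx.size == 0:
            continue  # constant feature: no non-degenerate split
        n_left = idx + 1.0
        n_right = n - n_left
        ones_left = np.cumsum(ys)[idx]
        p_l = ones_left / n_left
        p_r = (total_ones - ones_left) / n_right
        imp = (n_left * 2 * p_l * (1 - p_l) + n_right * 2 * p_r * (1 - p_r)) / n
        # First (smallest) threshold whose impurity is within tolerance of the minimum.
        k = int(np.flatnonzero(imp <= imp.min() + 1e-12)[0])
        if imp[k] < best_imp - 1e-12:  # strict improvement keeps earlier features on ties
            best_imp, best_feat, best_thr = imp[k], j, float(xs[idx[k]])
    return best_feat, best_thr


X = np.array([[2.5], [3.5], [1.0], [4.0]])
y = np.array([0, 1, 0, 1])
print(find_best_split_sorted(X, y))  # (0, 2.5)
```

Because thresholds within a feature are visited in ascending order and features by increasing index, this sweep should pick the same split as the masking loop, up to the floating-point tolerance.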