In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor, plot_tree

from joblib import Parallel, delayed

In [2]:
from QuadraticConstraintModel import get_leaf_samples

from QuadraticConstraintModel import constrained_optimization_gurobi

from QuadraticConstraintModel import predict_from_COF

from QuadraticConstraintModel import  get_h_from_COF

In [3]:
def normalized_root_mean_square_error(y_true, y_pred):
    """
    Computes the Normalized Root Mean Square Error (NRMSE) between y_true and y_pred.
    If the range of y_true is zero, it normalizes by the number of samples * outputs.

    Parameters:
        y_true (np.ndarray): Ground truth values, shape (n_samples, n_outputs)
        y_pred (np.ndarray): Predicted values, shape (n_samples, n_outputs)

    Returns:
        float: NRMSE value
    """
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    
    # Compute RMSE
    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    
    # Compute range
    y_range = np.max(y_true) - np.min(y_true)
    
    if y_range != 0:
        # Normalize by range
        return rmse / y_range
    else:
        # Normalize by n_samples * n_outputs
        n_samples, n_outputs = y_true.shape
        return np.sqrt(np.sum((y_true - y_pred) ** 2) / (n_samples * n_outputs))


In [4]:
# Function to load DataSet
def load_dataset(file_path, num_attributes=2, num_classes=2):
    data = pd.read_csv(file_path)
    X = data.iloc[:, 0 :  num_attributes].values
    y = data.iloc[:,  num_attributes:  num_attributes + num_classes].values
    # y = data.iloc[:, 9:10].values
    return X, y

In [22]:
sys_name = "navigation_old"
n_samples = 500000
X, y = load_dataset(f"Dataset/{sys_name}/{sys_name}_{n_samples}/data_{sys_name}_{n_samples}.csv",num_attributes=4, num_classes=4)

In [23]:
X_train, X_test, y_train, y_test = train_test_split(X, y,  test_size=0.5)
print(f" Shape of X_Training = {X_train.shape} \n Shape of X_Testing = {X_test.shape}")
print(f" Shape of Y_Training = {y_train.shape} \n Shape of Y_Testing = {y_test.shape}")


 Shape of X_Training = (249999, 4) 
 Shape of X_Testing = (250000, 4)
 Shape of Y_Training = (249999, 4) 
 Shape of Y_Testing = (250000, 4)


In [14]:
def printTreeStats(tree):
    # Access the tree_ object
    n_nodes = tree.tree_.node_count
    n_leaves = tree.get_n_leaves()
    children_left = tree.tree_.children_left
    children_right = tree.tree_.children_right
    feature = tree.tree_.feature
    threshold = tree.tree_.threshold
    impurity = tree.tree_.impurity
    n_node_samples = tree.tree_.n_node_samples
    values = tree.tree_.value
    
    print(f"Total nodes: {n_nodes}\n")
    
    print("Number of leaves:", n_leaves)
    
    for node in range(n_nodes):
        if children_left[node] == children_right[node]:  # leaf
            print(f"🌿 Leaf {node}:")
        else:
            print(f"📌 Node {node}: feature {feature[node]} ≤ {threshold[node]:.4f}")
    
        print(f"   Samples: {n_node_samples[node]}")
        print(f"   Impurity (MSE): {impurity[node]:.4f}")
        print(f"   Value (mean target): {values[node].ravel()[0]:.4f}")
        print("-" * 40)

In [8]:
def prune_tree_hmin(tree, leaf_COFS_dict, h_min, alpha=2):
    """
    Prune COF tree based on h_min threshold.
    
    Args:
        tree: Trained COF tree object
        leaf_COFS_dict: dict storing leaf information (indices, M, m0, h)
        h_min: Minimum h value threshold
        alpha: pruning parameter (optional)
        
    Returns:
        pruned_leaf_dict: updated leaf_COFS_dict after pruning
    """
    pruned_leaf_dict = leaf_COFS_dict.copy()
    pruning_happened = True
    
    while pruning_happened:
        pruning_happened = False
        for node, leaf_info in list(pruned_leaf_dict.items()):
            h = leaf_info["CO_Model"]["h"]
            if h < h_min:
                # Prune node by merging with parent
                parent = tree.get_parent(node)
                if parent is None:
                    continue
                
                combined_indices = leaf_info["indices"] + pruned_leaf_dict[parent]["indices"]
                
                # Recompute M, m0, h using optimizer on combined indices
                M, m0, new_h = optimize_leaf(tree, combined_indices)
                
                # Update parent
                pruned_leaf_dict[parent] = {
                    "leaf_id": parent,
                    "CO_Model": {"h": new_h, "M": M, "m0": m0},
                    "indices": combined_indices,
                    "no_samples": len(combined_indices)
                }
                
                # Delete child node
                del pruned_leaf_dict[node]
                
                pruning_happened = True
                print(f"🌳 Pruning triggered at node {node} with h={h:.6f}, alpha={alpha}")
    
    print(f"➡️ Pruning completed. Remaining leaves: {len(pruned_leaf_dict)}")
    return pruned_leaf_dict


def prune_tree_cost(tree, leaf_COFS_dict, alpha=2):
    """
    Prune COF tree based on cost function: sum(h) + alpha * number_of_leaves
    
    Args:
        tree: Trained COF tree object
        leaf_COFS_dict: dict storing leaf info (indices, M, m0, h)
        alpha: regularization parameter
        
    Returns:
        pruned_leaf_dict: updated leaf_COFS_dict after cost-based pruning
    """
    pruned_leaf_dict = leaf_COFS_dict.copy()
    pruning_happened = True
    
    while pruning_happened:
        pruning_happened = False
        min_cost_reduction = 0
        leaf_to_prune = None
        merge_with = None
        
        # Evaluate cost reduction for each leaf
        for node, leaf_info in pruned_leaf_dict.items():
            parent = tree.get_parent(node)
            if parent is None:
                continue
            
            combined_indices = leaf_info["indices"] + pruned_leaf_dict[parent]["indices"]
            M, m0, new_h = optimize_leaf(tree, combined_indices)
            
            # Compute current cost and merged cost
            current_cost = sum([info["CO_Model"]["h"] for info in pruned_leaf_dict.values()]) + alpha * len(pruned_leaf_dict)
            merged_cost = sum([info["CO_Model"]["h"] for n, info in pruned_leaf_dict.items() if n not in [node, parent]] + [new_h]) + alpha * (len(pruned_leaf_dict)-1)
            
            cost_reduction = current_cost - merged_cost
            if cost_reduction > min_cost_reduction:
                min_cost_reduction = cost_reduction
                leaf_to_prune = node
                merge_with = parent
        
        if leaf_to_prune is not None:
            # Perform the merge
            combined_indices = pruned_leaf_dict[leaf_to_prune]["indices"] + pruned_leaf_dict[merge_with]["indices"]
            M, m0, new_h = optimize_leaf(tree, combined_indices)
            
            pruned_leaf_dict[merge_with] = {
                "leaf_id": merge_with,
                "CO_Model": {"h": new_h, "M": M, "m0": m0},
                "indices": combined_indices,
                "no_samples": len(combined_indices)
            }
            del pruned_leaf_dict[leaf_to_prune]
            pruning_happened = True
            print(f"🌳 Cost-based pruning: merged node {leaf_to_prune} into parent {merge_with}, cost reduction={min_cost_reduction:.6f}")
    
    print(f"➡️ Cost-based pruning completed. Remaining leaves: {len(pruned_leaf_dict)}")
    return pruned_leaf_dict


In [7]:
# ----------------------------------------
# Your least squares optimizer
# ----------------------------------------
def least_squares_solution(X_leaf, y_leaf):
    n_samples, n_features = X_leaf.shape

    # Step 1: augment X with a column of ones for intercept
    X_aug = np.hstack([np.ones((n_samples, 1)), X_leaf])  # (n × (p+1))

    # Step 2: closed-form least squares
    XtX = X_aug.T @ X_aug
    XtY = X_aug.T @ y_leaf
    Theta = np.linalg.pinv(XtX) @ XtY   # pseudo-inverse for safety

    # Step 3: extract intercepts and coefficients
    m0 = Theta[0, :] if Theta.ndim > 1 else Theta[0]       # (n_outputs,) or scalar
    M = Theta[1:, :].T if Theta.ndim > 1 else Theta[1:].reshape(1, -1)  # (n_outputs × n_features) or (1 × n_features)

    # Step 4: compute residual sum of squares
    Y_hat = X_aug @ Theta
    residuals = y_leaf - Y_hat
    h_val = np.sum(residuals**2)

    return M, m0, h_val

# ----------------------------------------
# Updated prune function using least squares
# ----------------------------------------
def train_and_prune_COF_tree(X_train, y_train, 
                             initial_tree_params=None,
                             alpha=1,
                             h_min=1,
                             pruning_mode="hmin"):
    
    tree = DecisionTreeRegressor(**(initial_tree_params or {}))
    tree.fit(X_train, y_train)
    
    # Leaf info
    leaf_COFS_dict = {}
    leaf_ids = tree.apply(X_train)
    unique_leaves = np.unique(leaf_ids)
    
    for leaf in unique_leaves:
        indices = np.where(leaf_ids == leaf)[0]
        X_leaf = X_train[indices]
        y_leaf = y_train[indices]
        M, m0, h = least_squares_solution(X_leaf, y_leaf)
        leaf_COFS_dict[leaf] = {"indices": indices, "M": M, "m0": m0, "h": h}
    
    # Start pruning
    pruning_done = True
    while pruning_done:
        pruning_done = False
        # Sort leaves by h
        leaves_sorted = sorted(leaf_COFS_dict.items(), key=lambda x: x[1]["h"])
        
        for leaf, info in leaves_sorted:
            h_leaf = info["h"]
            
            if pruning_mode == "hmin" and h_leaf <= h_min:
                # Prune leaf
                print(f"🌳 h_min pruning: Removing leaf {leaf} with h={h_leaf:.6f}")
                del leaf_COFS_dict[leaf]
                pruning_done = True
                break
                
            elif pruning_mode == "cost":
                current_cost = sum([v["h"] for v in leaf_COFS_dict.values()]) + alpha * len(leaf_COFS_dict)
                temp_dict = leaf_COFS_dict.copy()
                del temp_dict[leaf]
                new_cost = sum([v["h"] for v in temp_dict.values()]) + alpha * len(temp_dict)
                
                if new_cost < current_cost:
                    print(f"🌳 Cost pruning: Removing leaf {leaf}, cost {current_cost:.4f} -> {new_cost:.4f}")
                    leaf_COFS_dict = temp_dict
                    pruning_done = True
                    break
    
    return tree, leaf_COFS_dict


In [10]:
# H_min pruning
tree_h, leaves_h = train_and_prune_COF_tree(
    X_train, y_train, initial_tree_params={"min_samples_leaf":1000}, 
    h_min=0.5, pruning_mode="hmin"
)

print(f"Final leaf count (h_min): {len(leaves_h)}")
print("h values:", [v["h"] for v in leaves_h.values()])

# Cost-based pruning
tree_c, leaves_c = train_and_prune_COF_tree(
    X_train, y_train, initial_tree_params={"min_samples_leaf":10000}, 
    alpha=2, pruning_mode="cost"
)

print(f"Final leaf count (cost): {len(leaves_c)}")
print("h values:", [v["h"] for v in leaves_c.values()])

🌳 h_min pruning: Removing leaf 30 with h=0.000000
🌳 h_min pruning: Removing leaf 29 with h=0.000000
🌳 h_min pruning: Removing leaf 70 with h=0.000000
🌳 h_min pruning: Removing leaf 8 with h=0.000000
🌳 h_min pruning: Removing leaf 52 with h=0.000000
🌳 h_min pruning: Removing leaf 9 with h=0.000000
🌳 h_min pruning: Removing leaf 53 with h=0.000000
🌳 h_min pruning: Removing leaf 71 with h=0.000000
🌳 h_min pruning: Removing leaf 10 with h=0.000000
🌳 h_min pruning: Removing leaf 31 with h=0.000000
🌳 h_min pruning: Removing leaf 50 with h=0.000000
🌳 h_min pruning: Removing leaf 72 with h=0.000000
🌳 h_min pruning: Removing leaf 14 with h=0.000000
🌳 h_min pruning: Removing leaf 37 with h=0.000000
🌳 h_min pruning: Removing leaf 38 with h=0.000000
🌳 h_min pruning: Removing leaf 34 with h=0.000000
🌳 h_min pruning: Removing leaf 13 with h=0.000000
🌳 h_min pruning: Removing leaf 17 with h=0.000001
🌳 h_min pruning: Removing leaf 16 with h=0.000001
🌳 h_min pruning: Removing leaf 35 with h=0.000001
🌳 

In [15]:
printTreeStats(tree_h)

Total nodes: 347

Number of leaves: 174
📌 Node 0: feature 1 ≤ 1.4658
   Samples: 249999
   Impurity (MSE): 0.4648
   Value (mean target): 1.5064
----------------------------------------
📌 Node 1: feature 0 ≤ 1.1519
   Samples: 122305
   Impurity (MSE): 0.2945
   Value (mean target): 1.1703
----------------------------------------
📌 Node 2: feature 2 ≤ 0.0015
   Samples: 60571
   Impurity (MSE): 0.2009
   Value (mean target): 0.5968
----------------------------------------
📌 Node 3: feature 3 ≤ 0.0013
   Samples: 30521
   Impurity (MSE): 0.1513
   Value (mean target): 0.5532
----------------------------------------
📌 Node 4: feature 1 ≤ 0.6804
   Samples: 15303
   Impurity (MSE): 0.1013
   Value (mean target): 0.5543
----------------------------------------
📌 Node 5: feature 0 ≤ 0.5707
   Samples: 8166
   Impurity (MSE): 0.0710
   Value (mean target): 0.5442
----------------------------------------
📌 Node 6: feature 3 ≤ -0.4880
   Samples: 4011
   Impurity (MSE): 0.0500
   Value (mean t

In [16]:
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from sklearn.tree import DecisionTreeRegressor
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted

class SelectivePruningDecisionTree(BaseEstimator, RegressorMixin):
    """
    A Decision Tree Regressor that allows for selective post-pruning or
    identification of leaves for expansion based on a custom optimization metric.

    This model first trains a standard DecisionTreeRegressor. Then, for each leaf
    in the tree, it solves a constrained optimization problem to find a linear
    model (M, m_0) and an associated error metric (h). The tree can then be
    analyzed to prune leaves with low 'h' values or expand those with high 'h'
    values.

    Parameters
    ----------
    decision_tree_args : dict, optional
        Arguments to be passed to the underlying `DecisionTreeRegressor`.
        This allows for control over pre-pruning parameters like `max_depth`,
        `min_samples_leaf`, etc. Default is None.

    Attributes
    ----------
    tree_ : sklearn.tree.DecisionTreeRegressor
        The initially trained decision tree.
    leaf_models_ : dict
        A dictionary where keys are the leaf node indices and values are
        dictionaries containing the trained linear model ('M', 'm0') and
        the error metric 'h' for that leaf.
    is_fitted_ : bool
        Flag indicating if the model has been fitted.
    """
    def __init__(self, **decision_tree_args):
        """
        Initializes the SelectivePruningDecisionTree.

        Args:
            **decision_tree_args: Arbitrary keyword arguments for the
                                  underlying scikit-learn DecisionTreeRegressor.
                                  e.g., max_depth=5, min_samples_leaf=10
        """
        self.decision_tree_args = decision_tree_args

    def fit(self, X, y):
        """
        Fit the model. This involves three main steps:
        1. Train a standard DecisionTreeRegressor.
        2. Identify the leaf nodes and the data samples belonging to each leaf.
        3. For each leaf, run the constrained optimization to get M, m0, and h.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The training input samples.
        y : array-like of shape (n_samples, n_outputs)
            The target values.

        Returns
        -------
        self : object
            Returns self.
        """
        # 1. Input validation
        X, y = check_X_y(X, y, multi_output=True, y_numeric=True)
        if y.ndim == 1:
            y = y.reshape(-1, 1)

        # 2. Train the initial decision tree
        self.tree_ = DecisionTreeRegressor(**self.decision_tree_args)
        self.tree_.fit(X, y)

        # 3. Calculate leaf assignments for the training data
        leaf_ids = self.tree_.apply(X)
        unique_leaves = np.unique(leaf_ids)

        # 4. Calculate metrics (M, m0, h) for each leaf
        self.leaf_models_ = self._calculate_leaf_metrics(X, y, leaf_ids, unique_leaves)

        self.is_fitted_ = True
        return self

    def _calculate_leaf_metrics(self, X, y, leaf_ids, unique_leaves):
        """
        Iterates through each leaf, segments the data, and runs the
        optimization function.

        Parameters
        ----------
        X : np.ndarray
            The full training dataset.
        y : np.ndarray
            The full target dataset.
        leaf_ids : np.ndarray
            An array where each element is the leaf index for a sample in X.
        unique_leaves : np.ndarray
            An array of unique leaf indices.

        Returns
        -------
        dict
            A dictionary containing the fitted models for each leaf.
        """
        leaf_models = {}
        for leaf in unique_leaves:
            # Find all data points that fall into the current leaf
            mask = leaf_ids == leaf
            X_leaf = X[mask]
            y_leaf = y[mask]

            if X_leaf.shape[0] > 0: # Ensure the leaf is not empty
                M_val, m0_val, h_val = self.constrained_optimization_gurobi(X_leaf, y_leaf)
                leaf_models[leaf] = {
                    'M': M_val,
                    'm0': m0_val,
                    'h': h_val,
                    'n_samples': X_leaf.shape[0]
                }
        return leaf_models

    @staticmethod
    def constrained_optimization_gurobi(X_leaf, y_leaf):
        """
        Solves the constrained optimization problem for a single leaf.

        This function is kept static as it doesn't depend on the state of the
        class instance ('self'), promoting modularity.

        Parameters
        ----------
        X_leaf : np.ndarray
            Data samples belonging to a specific leaf.
        y_leaf : np.ndarray
            Target values for the samples in X_leaf.

        Returns
        -------
        tuple
            A tuple containing (M_val, m0_val, h_val). Returns (None, None, np.inf)
            on optimization failure.
        """
        n_samples, n_features = X_leaf.shape
        n_outputs = y_leaf.shape[1]

        model = gp.Model("constrained_optimization")
        model.setParam("OutputFlag", 0)

        M = model.addVars(n_outputs, n_features, lb=-GRB.INFINITY, name="M")
        m0 = model.addVars(n_outputs, lb=-GRB.INFINITY, name="m0")
        h = model.addVar(lb=0, name="h")

        residuals = gp.QuadExpr()
        for i in range(n_samples):
            for k in range(n_outputs):
                pred_ik = m0[k] + gp.quicksum(M[k, j] * X_leaf[i, j] for j in range(n_features))
                residual = pred_ik - y_leaf[i, k]
                residuals += residual * residual

        model.addConstr(residuals <= h, name="residual_constraint")
        model.setObjective(h, GRB.MINIMIZE)
        model.optimize()

        if model.status in [GRB.OPTIMAL, GRB.SUBOPTIMAL] and model.SolCount > 0:
            M_val = np.array([[M[k, j].X for j in range(n_features)] for k in range(n_outputs)])
            m0_val = np.array([m0[k].X for k in range(n_outputs)])
            h_val = h.X
        else:
            print(f"Optimization failed for a leaf with {n_samples} samples. Status: {model.status}")
            M_val, m0_val, h_val = None, None, np.inf

        return M_val, m0_val, h_val

    def get_leaves_to_expand(self, h_threshold):
        """
        Identifies leaves that have an 'h' value greater than a given threshold.
        These are candidates for further splitting (expansion).

        Parameters
        ----------
        h_threshold : float
            The threshold for the 'h' metric.

        Returns
        -------
        list
            A list of leaf node indices to be considered for expansion.
        """
        check_is_fitted(self)
        leaves_to_expand = []
        for leaf, model_info in self.leaf_models_.items():
            if model_info['h'] > h_threshold:
                leaves_to_expand.append(leaf)
        return leaves_to_expand

    def get_leaves_to_prune(self, h_threshold):
        """
        Identifies leaves that have an 'h' value less than or equal to a
        given threshold. These are candidates for pruning.

        Parameters
        ----------
        h_threshold : float
            The threshold for the 'h' metric.

        Returns
        -------
        list
            A list of leaf node indices to be considered for pruning.
        """
        check_is_fitted(self)
        leaves_to_prune = []
        for leaf, model_info in self.leaf_models_.items():
            if model_info['h'] <= h_threshold:
                leaves_to_prune.append(leaf)
        return leaves_to_prune

    def predict(self, X):
        """
        Predict target values for X.

        For each sample in X, it finds the corresponding leaf in the tree
        and uses the leaf's linear model (M, m0) to make a prediction.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The input samples to predict.

        Returns
        -------
        y_pred : array-like of shape (n_samples, n_outputs)
            The predicted values.
        """
        check_is_fitted(self)
        X = check_array(X)

        leaf_ids = self.tree_.apply(X)
        # Assuming y has at least one output
        n_outputs = self.leaf_models_[list(self.leaf_models_.keys())[0]]['M'].shape[0]
        y_pred = np.zeros((X.shape[0], n_outputs))

        for i, leaf_id in enumerate(leaf_ids):
            model = self.leaf_models_.get(leaf_id)
            if model and model['M'] is not None:
                # Prediction: y = X @ M.T + m0
                y_pred[i, :] = X[i, :] @ model['M'].T + model['m0']
            else:
                # Fallback: use the tree's original prediction for this leaf
                # This handles cases where optimization failed for a leaf.
                y_pred[i, :] = self.tree_.predict(X[i, :].reshape(1, -1))

        return y_pred



In [19]:
# --- Example Usage ---
if __name__ == '__main__':
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import mean_squared_error

    # 1. Generate synthetic data
    np.random.seed(42)
    X = np.random.rand(200, 3) * 10
    # Create a piecewise linear relationship
    y1 = 2 * X[:, 0] + 3 * X[:, 1] + 5
    y2 = -1.5 * X[:, 0] + 0.5 * X[:, 2] - 2
    # Add noise
    y1 += np.random.normal(0, 2, size=y1.shape)
    y2 += np.random.normal(0, 2, size=y2.shape)
    # Different relationship for a different segment of data
    mask = X[:, 0] > 5
    y1[mask] = 4 * X[mask, 0] - 1 * X[mask, 1] + 10 + np.random.normal(0, 2, size=y1[mask].shape)
    y2[mask] = 1 * X[mask, 0] - 2 * X[mask, 2] + 3 + np.random.normal(0, 2, size=y2[mask].shape)

    y = np.vstack([y1, y2]).T

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # 2. Instantiate and fit the model
    # We use max_depth=3 to create a few leaves to analyze
    spt = SelectivePruningDecisionTree(max_depth=3, min_samples_leaf=10, random_state=42)
    spt.fit(X_train, y_train)

    # 3. Analyze the results
    print("Analysis of Fitted Leaf Models:")
    print("-" * 50)
    for leaf, model_info in spt.leaf_models_.items():
        print(f"Leaf Node ID: {leaf}")
        print(f"  - Number of Samples: {model_info['n_samples']}")
        print(f"  - H-value (error): {model_info['h']:.4f}")
        # print(f"  - M matrix:\n{model_info['M']}") # Uncomment for more detail
        # print(f"  - m0 vector: {model_info['m0']}") # Uncomment for more detail
        print("-" * 20)

    # 4. Use the model to identify leaves for expansion/pruning
    # Let's set a threshold. Leaves with error > threshold should be split further.
    # Leaves with error <= threshold are good candidates for pruning.
    h_error_threshold = 1000.0

    leaves_to_expand = spt.get_leaves_to_expand(h_error_threshold)
    leaves_to_prune = spt.get_leaves_to_prune(h_error_threshold)

    print(f"\nBased on h_threshold = {h_error_threshold}:")
    print(f"Leaves identified for EXPANSION (high error): {leaves_to_expand}")
    print(f"Leaves identified for PRUNING (low error): {leaves_to_prune}")
    print("\nThis information can now be used to guide a second, more refined modeling step.")

    # 5. Make predictions and evaluate
    y_pred = spt.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    print(f"\nMean Squared Error on Test Set: {mse:.4f}")


Analysis of Fitted Leaf Models:
--------------------------------------------------
Leaf Node ID: 3
  - Number of Samples: 24
  - H-value (error): 100.2010
--------------------
Leaf Node ID: 4
  - Number of Samples: 10
  - H-value (error): 222.3151
--------------------
Leaf Node ID: 6
  - Number of Samples: 26
  - H-value (error): 723.6372
--------------------
Leaf Node ID: 7
  - Number of Samples: 45
  - H-value (error): 1648.6170
--------------------
Leaf Node ID: 10
  - Number of Samples: 14
  - H-value (error): 100.2937
--------------------
Leaf Node ID: 11
  - Number of Samples: 19
  - H-value (error): 98.4425
--------------------
Leaf Node ID: 13
  - Number of Samples: 12
  - H-value (error): 75.4971
--------------------
Leaf Node ID: 14
  - Number of Samples: 10
  - H-value (error): 15.1974
--------------------

Based on h_threshold = 1000.0:
Leaves identified for EXPANSION (high error): [np.int64(7)]
Leaves identified for PRUNING (low error): [np.int64(3), np.int64(4), np.int64(

In [20]:
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from sklearn.tree import DecisionTreeRegressor
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted

class SelectivePruningDecisionTree(BaseEstimator, RegressorMixin):
    """
    A Decision Tree Regressor that allows for selective post-pruning or
    identification of leaves for expansion based on a custom optimization metric.

    This model first trains a standard DecisionTreeRegressor. Then, for each leaf
    in the tree, it solves a constrained optimization problem to find a linear
    model (M, m_0) and an associated error metric (h). The tree can then be
    analyzed to prune leaves with low 'h' values or expand those with high 'h'
    values.

    Parameters
    ----------
    decision_tree_args : dict, optional
        Arguments to be passed to the underlying `DecisionTreeRegressor`.
        This allows for control over pre-pruning parameters like `max_depth`,
        `min_samples_leaf`, etc. Default is None.

    Attributes
    ----------
    tree_ : sklearn.tree.DecisionTreeRegressor
        The initially trained decision tree.
    leaf_models_ : dict
        A dictionary where keys are the leaf node indices and values are
        dictionaries containing the trained linear model ('M', 'm0') and
        the error metric 'h' for that leaf.
    sub_trees_ : dict
        A dictionary to hold new sub-trees for leaves that have been expanded.
    pruned_leaves_ : set
        A set of leaf node indices that have been marked for pruning.
    is_fitted_ : bool
        Flag indicating if the model has been fitted.
    """
    def __init__(self, **decision_tree_args):
        """
        Initializes the SelectivePruningDecisionTree.

        Args:
            **decision_tree_args: Arbitrary keyword arguments for the
                                  underlying scikit-learn DecisionTreeRegressor.
                                  e.g., max_depth=5, min_samples_leaf=10
        """
        self.decision_tree_args = decision_tree_args
        self.sub_trees_ = {}
        self.pruned_leaves_ = set()

    def fit(self, X, y):
        """
        Fit the model. This involves three main steps:
        1. Train a standard DecisionTreeRegressor.
        2. Identify the leaf nodes and the data samples belonging to each leaf.
        3. For each leaf, run the constrained optimization to get M, m0, and h.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The training input samples.
        y : array-like of shape (n_samples, n_outputs)
            The target values.

        Returns
        -------
        self : object
            Returns self.
        """
        # 1. Input validation
        X, y = check_X_y(X, y, multi_output=True, y_numeric=True)
        if y.ndim == 1:
            y = y.reshape(-1, 1)

        # 2. Train the initial decision tree
        self.tree_ = DecisionTreeRegressor(**self.decision_tree_args)
        self.tree_.fit(X, y)

        # 3. Calculate leaf assignments for the training data
        leaf_ids = self.tree_.apply(X)
        unique_leaves = np.unique(leaf_ids)

        # 4. Calculate metrics (M, m0, h) for each leaf
        self.leaf_models_ = self._calculate_leaf_metrics(X, y, leaf_ids, unique_leaves)

        self.is_fitted_ = True
        return self

    def _calculate_leaf_metrics(self, X, y, leaf_ids, unique_leaves):
        """
        Iterates through each leaf, segments the data, and runs the
        optimization function.

        Parameters
        ----------
        X : np.ndarray
            The full training dataset.
        y : np.ndarray
            The full target dataset.
        leaf_ids : np.ndarray
            An array where each element is the leaf index for a sample in X.
        unique_leaves : np.ndarray
            An array of unique leaf indices.

        Returns
        -------
        dict
            A dictionary containing the fitted models for each leaf.
        """
        leaf_models = {}
        for leaf in unique_leaves:
            # Find all data points that fall into the current leaf
            mask = leaf_ids == leaf
            X_leaf = X[mask]
            y_leaf = y[mask]

            if X_leaf.shape[0] > 0: # Ensure the leaf is not empty
                M_val, m0_val, h_val = self.constrained_optimization_gurobi(X_leaf, y_leaf)
                leaf_models[leaf] = {
                    'M': M_val,
                    'm0': m0_val,
                    'h': h_val,
                    'n_samples': X_leaf.shape[0]
                }
        return leaf_models

    @staticmethod
    def constrained_optimization_gurobi(X_leaf, y_leaf):
        """
        Solves the constrained optimization problem for a single leaf.

        This function is kept static as it doesn't depend on the state of the
        class instance ('self'), promoting modularity.

        Parameters
        ----------
        X_leaf : np.ndarray
            Data samples belonging to a specific leaf.
        y_leaf : np.ndarray
            Target values for the samples in X_leaf.

        Returns
        -------
        tuple
            A tuple containing (M_val, m0_val, h_val). Returns (None, None, np.inf)
            on optimization failure.
        """
        n_samples, n_features = X_leaf.shape
        n_outputs = y_leaf.shape[1]

        model = gp.Model("constrained_optimization")
        model.setParam("OutputFlag", 0)

        M = model.addVars(n_outputs, n_features, lb=-GRB.INFINITY, name="M")
        m0 = model.addVars(n_outputs, lb=-GRB.INFINITY, name="m0")
        h = model.addVar(lb=0, name="h")

        residuals = gp.QuadExpr()
        for i in range(n_samples):
            for k in range(n_outputs):
                pred_ik = m0[k] + gp.quicksum(M[k, j] * X_leaf[i, j] for j in range(n_features))
                residual = pred_ik - y_leaf[i, k]
                residuals += residual * residual

        model.addConstr(residuals <= h, name="residual_constraint")
        model.setObjective(h, GRB.MINIMIZE)
        model.optimize()

        if model.status in [GRB.OPTIMAL, GRB.SUBOPTIMAL] and model.SolCount > 0:
            M_val = np.array([[M[k, j].X for j in range(n_features)] for k in range(n_outputs)])
            m0_val = np.array([m0[k].X for k in range(n_outputs)])
            h_val = h.X
        else:
            print(f"Optimization failed for a leaf with {n_samples} samples. Status: {model.status}")
            M_val, m0_val, h_val = None, None, np.inf

        return M_val, m0_val, h_val

    def get_leaves_to_expand(self, h_threshold):
        """
        Identifies leaves that have an 'h' value greater than a given threshold.
        These are candidates for further splitting (expansion).

        Parameters
        ----------
        h_threshold : float
            The threshold for the 'h' metric.

        Returns
        -------
        list
            A list of leaf node indices to be considered for expansion.
        """
        check_is_fitted(self)
        leaves_to_expand = []
        for leaf, model_info in self.leaf_models_.items():
            if model_info['h'] > h_threshold:
                leaves_to_expand.append(leaf)
        return leaves_to_expand

    def get_leaves_to_prune(self, h_threshold):
        """
        Identifies leaves that have an 'h' value less than or equal to a
        given threshold. These are candidates for pruning.

        Parameters
        ----------
        h_threshold : float
            The threshold for the 'h' metric.

        Returns
        -------
        list
            A list of leaf node indices to be considered for pruning.
        """
        check_is_fitted(self)
        leaves_to_prune = []
        for leaf, model_info in self.leaf_models_.items():
            if model_info['h'] <= h_threshold:
                leaves_to_prune.append(leaf)
        return leaves_to_prune

    def refine(self, X, y, h_threshold, expansion_args=None):
        """
        Refines the tree by expanding high-error leaves and marking low-error
        leaves for pruning.

        Expansion: Trains a new sub-tree on the data points of a high-error leaf.
        Pruning: Marks a leaf to use the original tree's simpler prediction
                 (mean of values) instead of its own linear model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The training input samples, needed for training sub-trees.
        y : array-like of shape (n_samples, n_outputs)
            The target values, needed for training sub-trees.
        h_threshold : float
            The threshold for the 'h' metric to decide on expansion/pruning.
        expansion_args : dict, optional
            Arguments for the DecisionTreeRegressor used for the sub-trees,
            e.g., {'max_depth': 2}. Default is None.
        """
        check_is_fitted(self)
        if expansion_args is None:
            expansion_args = {}

        leaves_to_expand = self.get_leaves_to_expand(h_threshold)
        self.pruned_leaves_ = set(self.get_leaves_to_prune(h_threshold))

        # We need the original data to train sub-trees
        leaf_ids = self.tree_.apply(X)

        print(f"\nRefining tree based on h_threshold = {h_threshold}:")
        print(f" -> {len(leaves_to_expand)} leaves to expand: {leaves_to_expand}")
        print(f" -> {len(self.pruned_leaves_)} leaves to prune: {list(self.pruned_leaves_)}")

        for leaf in leaves_to_expand:
            print(f"  - Expanding leaf {leaf}...")
            mask = (leaf_ids == leaf)
            X_leaf = X[mask]
            y_leaf = y[mask]

            if X_leaf.shape[0] > 1: # Need at least 2 samples to split
                # For simplicity, we use a standard DecisionTreeRegressor for the sub-tree
                sub_tree = DecisionTreeRegressor(**expansion_args)
                sub_tree.fit(X_leaf, y_leaf)
                self.sub_trees_[leaf] = sub_tree
        return self


    def predict(self, X):
        """
        Predict target values for X.

        The prediction logic is as follows:
        1. If a sample falls into a leaf that was expanded, use the corresponding
           sub-tree to make a prediction.
        2. If a sample falls into a leaf that was pruned, use the original
           tree's simple (mean) prediction for that leaf.
        3. Otherwise, use the leaf's specialized linear model (M, m0).

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The input samples to predict.

        Returns
        -------
        y_pred : array-like of shape (n_samples, n_outputs)
            The predicted values.
        """
        check_is_fitted(self)
        X = check_array(X)

        leaf_ids = self.tree_.apply(X)

        # Robustly find n_outputs from the first available leaf model
        first_valid_leaf = next((leaf for leaf in self.leaf_models_.values() if leaf['M'] is not None), None)
        if first_valid_leaf is None:
            raise RuntimeError("No valid leaf models were fitted. Cannot determine number of outputs.")
        n_outputs = first_valid_leaf['M'].shape[0]
        y_pred = np.zeros((X.shape[0], n_outputs))

        for i, leaf_id in enumerate(leaf_ids):
            # Case 1: The leaf was expanded into a sub-tree
            if leaf_id in self.sub_trees_:
                y_pred[i, :] = self.sub_trees_[leaf_id].predict(X[i, :].reshape(1, -1))
                continue

            # Case 2: The leaf was marked for pruning (use simple mean)
            if leaf_id in self.pruned_leaves_:
                y_pred[i, :] = self.tree_.predict(X[i, :].reshape(1, -1))
                continue

            # Case 3: Standard prediction using the leaf's linear model
            model = self.leaf_models_.get(leaf_id)
            if model and model['M'] is not None:
                y_pred[i, :] = X[i, :] @ model['M'].T + model['m0']
            else:
                # Fallback for leaves where optimization failed initially
                y_pred[i, :] = self.tree_.predict(X[i, :].reshape(1, -1))

        return y_pred



In [27]:
# --- Example Usage ---
if __name__ == '__main__':
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import mean_squared_error

    # # 1. Generate synthetic data
    # np.random.seed(42)
    # X = np.random.rand(200, 3) * 10
    # # Create a piecewise linear relationship
    # y1 = 2 * X[:, 0] + 3 * X[:, 1] + 5
    # y2 = -1.5 * X[:, 0] + 0.5 * X[:, 2] - 2
    # # Add noise
    # y1 += np.random.normal(0, 2, size=y1.shape)
    # y2 += np.random.normal(0, 2, size=y2.shape)
    # # Different relationship for a different segment of data
    # mask = X[:, 0] > 5
    # y1[mask] = 4 * X[mask, 0] - 1 * X[mask, 1] + 10 + np.random.normal(0, 2, size=y1[mask].shape)
    # y2[mask] = 1 * X[mask, 0] - 2 * X[mask, 2] + 3 + np.random.normal(0, 2, size=y2[mask].shape)

    # y = np.vstack([y1, y2]).T

    # X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # 2. Instantiate and fit the model
    spt = SelectivePruningDecisionTree(max_depth=5, min_samples_leaf=10000, random_state=42)
    spt.fit(X_train, y_train)

    # 3. Analyze the results
    print("--- Initial Analysis of Fitted Leaf Models ---")
    print("-" * 50)
    for leaf, model_info in spt.leaf_models_.items():
        print(f"Leaf Node ID: {leaf}, Samples: {model_info['n_samples']}, H-value: {model_info['h']:.2f}")

    # 4. Evaluate the model BEFORE refinement
    y_pred_before = spt.predict(X_test)
    mse_before = mean_squared_error(y_test, y_pred_before)
    print(f"\nMean Squared Error on Test Set (BEFORE refinement): {mse_before:.4f}")

    # 5. REFINE the tree based on the h-value threshold
    h_error_threshold = 1.0
    # For sub-trees, let's give them a bit more depth
    expansion_params = {'max_depth': 2, 'min_samples_leaf': 1000, 'random_state': 42}
    spt.refine(X_train, y_train, h_threshold=h_error_threshold, expansion_args=expansion_params)

    # 6. Evaluate the model AFTER refinement
    y_pred_after = spt.predict(X_test)
    mse_after = mean_squared_error(y_test, y_pred_after)
    print(f"\nMean Squared Error on Test Set (AFTER refinement): {mse_after:.4f}")
    
    improvement = ((mse_before - mse_after) / mse_before) * 100
    print(f"Improvement: {improvement:.2f}%")



--- Initial Analysis of Fitted Leaf Models ---
--------------------------------------------------
Leaf Node ID: 4, Samples: 15197, H-value: 22.50
Leaf Node ID: 5, Samples: 14986, H-value: 22.30
Leaf Node ID: 7, Samples: 15869, H-value: 23.69
Leaf Node ID: 8, Samples: 15162, H-value: 22.63
Leaf Node ID: 11, Samples: 15729, H-value: 13.40
Leaf Node ID: 12, Samples: 15776, H-value: 13.14
Leaf Node ID: 14, Samples: 15390, H-value: 13.17
Leaf Node ID: 15, Samples: 15605, H-value: 13.32
Leaf Node ID: 19, Samples: 15755, H-value: 13.10
Leaf Node ID: 20, Samples: 15286, H-value: 13.00
Leaf Node ID: 22, Samples: 15922, H-value: 13.33
Leaf Node ID: 23, Samples: 15914, H-value: 13.47
Leaf Node ID: 26, Samples: 15783, H-value: 23.27
Leaf Node ID: 27, Samples: 15811, H-value: 24.05
Leaf Node ID: 29, Samples: 15932, H-value: 23.75
Leaf Node ID: 30, Samples: 15882, H-value: 24.44

Mean Squared Error on Test Set (BEFORE refinement): 0.0003

Refining tree based on h_threshold = 1.0:
 -> 16 leaves to ex