In [146]:
import sys
sys.path.append("..")
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, f1_score, precision_score, recall_score, mean_absolute_percentage_error, mean_absolute_error
from src.families import line1_families, get_alphabetical_family
import optuna
from time import perf_counter
import click
import pandas as pd
import numpy as np
from sklearn.metrics import r2_score
from time import perf_counter
import xgboost as xgb
from xgboost import XGBClassifier, XGBRegressor
from fir.printer import colour_printer
from fir.core import  FeatureImportanceRanker
import shap

In [None]:
ranker = FeatureImportanceRanker(
    filepath="/Users/issaodeh/Documents/GitHub/intern-project-FE/dataset_uv_line1.csv",
    target_column="QM_database.scope_scrap_rate",
    columns_to_drop=["__index_level_0__", "QM_database.cast_scrap_rate", "QM_database.sand_scrap_rate"],
)

###################    DISCLAIMER      ######################
 #set your data & file paths in the following global variables
# DATA_PATH (PATH) : the exact location on your machine of the tabular data
# you wish to study using this tool. eg. "~/Documents/data/dataset_uv_line1.csv"

# TARGET_COl (STR) : the exact column that you wish to make predictions on eg. "scrap_flag" or similar

#BATCH_COL (STR): the name of the column that corresponds to the batch number/Letter if present.eg. "uv_batch_info.cim_batch.pattern_number"

#DEL_COLS (LIST): A list of  columsn that you wish to be deleted from the dataset to make manipulation easier.

DATA_PATH   = "/Users/issaodeh/Documents/GitHub/intern-project-FE/dataset_uv_line1.csv"      #  e.g. "../dataset_1.csv"
TARGET_COL  = "QM_database.scope_scrap_rate"   #  e.g. "scrap_flag" or similar
BATCH_COL = "uv_batch_info.cim_batch.pattern_number" # name of batch col if none leave as: None
DEL_COLS = ["__index_level_0__", "QM_database.cast_scrap_rate","QM_database.sand_scrap_rate"]
CLASSIFICATION = False # True or False
###################    DISCLAIMER      ######################


def split_by_batch_letter(data_frame, batch_col, target_col):
    """Split a DataFrame into (X, y) pairs for each batch “family letter.”

    This function:
      1. Adds a column “corresponding_batch_family_letter” by mapping each batch number
         (in batch_col) to its family letter via get_alphabetical_family(...) if the
         batch is in line1_families; otherwise it becomes None.
      2. Groups the rows by that family letter.
      3. For each letter, returns:
         - X_sub: all columns except batch_col, target_col, and the new family letter column.
         - y_sub: a copy of the Series `data_frame[target_col]` for that group.
    Args:
        data_frame (pd.DataFrame): Cleaned data containing at least `batch_col` and `target_col`.
        batch_col (str): Column name holding the batch number.
        target_col (str): Column name of the target variable.
    Returns:
        dict[str, (pd.DataFrame, pd.Series)]:
            Keys are each family letter (or None), values are (X_sub, y_sub).
    Notes:
      • The method mutates data_frame by adding “corresponding_batch_family_letter.”
      • If a batch isnt found in line1_families, it lands under key None.
      • Ensure you call sub[target_col].copy() (not .copy) to actually copy the Series.
    """
    #initializing result dictionary
    result = {}
    #create a new column where the correspondinh family letter is assigned to each row of data. Uses a function and key dictionary defined in families.py to divide them up
    data_frame["corresponding_batch_family_letter"] = data_frame[batch_col].map(lambda batch_num: get_alphabetical_family(batch_num) if batch_num in line1_families else None)

    #groups by letter and then copies and drops the neccesairy data to the result dict and then returns the result dict
    for letter, sub in data_frame.groupby("corresponding_batch_family_letter"):
        y_sub = sub[target_col].copy()
        X_sub = sub.drop(columns= [batch_col, target_col, "corresponding_batch_family_letter"])

        result[letter] = (X_sub, y_sub)
    return result

def load_and_clean_data(data_path, target_col, split_by_group=False, del_cols=None, batch_col=BATCH_COL):
    """
    Load a CSV file, remove unwanted columns and rows with missing data, drop constant columns,
    and return feature and target sets (optionally split by batch family).

    Args:
        data_path (str):
            Path to the CSV file.
        forced_str (list[int] or list[str]):
            Column indices or names to coerce to string dtype upon loading. (ONLY USE IF ERRORS ARRISE WITH displaying / loading the data)
        target_col (str):
            Name of the target column to extract as y.
        del_cols (list[int] or list[str] or None, optional):
            Columns to drop immediately after loading (defaults to None).
        batch_col (str, optional):
            Name of the batch ID column used when FAM_SPLIT is True

    Returns:
        If FAM_SPLIT is False:
            tuple[pd.DataFrame, pd.Series]:
                X — DataFrame of all remaining features (no NaNs, no constant columns, no batch_col).
                y — Series containing the target_col values.
        If FAM_SPLIT is True:
            dict[str, (pd.DataFrame, pd.Series)]:
                A mapping from each batch family letter to its (X_sub, y_sub) pair,
                where X_sub has batch_col, target_col, and helper columns removed.

    Raises:
        KeyError:
            If target_col is not present after dropping rows with any NaN.
    """
    df = pd.read_csv(
    data_path,
    low_memory=False
    )

    #replaces all NA values with pandas NaN
    df.replace(to_replace = ["N/A", "NA", "na", "NaN", "nan", "null", "NULL",'\\N', "","//n"], value = np.nan, inplace=True)
    #drop the specified columns if they exist
    if del_cols is not None:
        df.drop(columns = del_cols, errors = "ignore", inplace = True)

    df.dropna(axis=0, how = "any", inplace=True) #drops any row that has NaN in it.
    df.reset_index(drop = True, inplace = True) #resets indeces after dropping all the rows and prevents pandas from adding a new columns of the old indices

    #checks if any column value has is constant accross all the rows if so removes it
    constant_cols = [col for col in df.columns if df[col].nunique(dropna=False) == 1]
    df = df.drop(columns = constant_cols)

    #raises KeyError if the target column is not in the dataframe
    if target_col not in df.columns:
        raise KeyError( f" the target column {target_col}, does not exist in the dataframe after cleaning.")

    if split_by_group:
        return split_by_batch_letter(df, BATCH_COL, TARGET_COL)
    else:
        y = df[target_col].copy()
        if batch_col is not None:
            X = df.drop(columns = [target_col,batch_col])
        else:
            X = df.drop(columns = [target_col])
        return {"global_data":(X, y)}

def split_data(data_tuple, train_size, test_size, random_state=42):
    """
    Split data into train and test sets, optionally stratifying on the target.

    Depending on `letter`, this function calls `get_appropriate_data_and_target(letter)` to fetch
    (X_full, y_full). It then applies sklearns `train_test_split`, using stratification only if
    every class in `data_y` appears at least twice.

    Args:
        data_y (pd.Series):
            Target series corresponding to X_full. Used to decide stratification.
        test_size (float, optional):
            Fraction of data to allocate to the test set (default 0.2).
        letter (str or None, optional):
            If a family letter is provided, split that subset; otherwise split the global (X, y).
        random_state (int, optional):
            Seed for reproducibility (default 0).

    Returns:
        tuple[
            pd.DataFrame,  # X_train
            pd.DataFrame,  # X_test
            pd.Series,     # y_train
            pd.Series      # y_test
        ]

    Raises:
        KeyError: If `letter` is not None but not found in `batch_dict`.
    """
    data_y = data_tuple[1]
    counts = data_y.value_counts()
    if len(counts) > 1 and counts.min() >= 2:
        stratify_arg = data_y
    else:
        stratify_arg = None

    X_train, X_test, y_train, y_test = train_test_split(
        *data_tuple, train_size=train_size, test_size=test_size, random_state=random_state, shuffle=True, stratify=stratify_arg)

    return X_train, X_test, y_train, y_test

def split_data_dict(data_grouping_dict, train_size, test_size, random_state=42):
    return {group_label: split_data(data_tuple, train_size, test_size, random_state) for group_label, data_tuple in data_grouping_dict.items()}


In [148]:

###################    DISCLAIMER      ######################
#   set your data and target paths in the required data_cleaning_tools code box above
###################    DISCLAIMER      ######################

batch_data_dict = load_and_clean_data(DATA_PATH,TARGET_COL,del_cols=DEL_COLS, split_by_group=True, batch_col=BATCH_COL)


In [149]:
global_data = load_and_clean_data(DATA_PATH,TARGET_COL,del_cols=DEL_COLS, split_by_group=False, batch_col=BATCH_COL)

In [150]:
global_data_X, global_data_y = global_data["global_data"]

In [151]:
# batch_data_dict["S"] is a tuple: (X, y) for batch "S"
X_S, y_S = batch_data_dict["S"]

# Combine X and y for display
df_S = X_S.copy()
df_S["target"] = y_S
display(df_S)

def split_by_target_quartiles(df, target_col):
    """
    Splits the DataFrame into four sub-DataFrames based on the quartiles of the target column.
    Returns a dict: {'Q1': df_q1, 'Q2': df_q2, 'Q3': df_q3, 'Q4': df_q4}
    """
    # Ensure target column is numeric
    target_numeric = pd.to_numeric(df[target_col], errors="coerce")
    q1 = target_numeric.quantile(0.25)
    q2 = target_numeric.quantile(0.50)
    q3 = target_numeric.quantile(0.75)

    df_q1 = df[target_numeric <= q1]
    df_q2 = df[(target_numeric > q1) & (target_numeric <= q2)]
    df_q3 = df[(target_numeric > q2) & (target_numeric <= q3)]
    df_q4 = df[target_numeric > q3]

    return {'Q1': df_q1, 'Q2': df_q2, 'Q3': df_q3, 'Q4': df_q4}

# Example usage with your S batch DataFrame:
batch_S_quartile_dfs = split_by_target_quartiles(df_S, "target")
for q, subdf in batch_S_quartile_dfs.items():
    print(f"{q} shape: {subdf.shape}")


Unnamed: 0,ABB.ATAS.C,ABB.ATAS.Si,ABB.laddle_event.ladle_event_tap_temperature,ABB.progelta.Weight_REAL_Cu,ABB.progelta.Weight_REAL_Fe,ABB.progelta.Weight_REAL_Mg,ABB.progelta.mapped_ABB_laddle_event_weight,moulding.CIM3Data.blow_off_in_operation_3_swing_plate_bottom_stop,moulding.CIM3Data.blow_off_in_operation_3_swing_plate_front_stop,moulding.CIM3Data.blow_off_in_operation_6_swing_plate_bottom_start,...,pouring.ATASW.Si_equivalent,sand_plant.Sand_Mixer.Compactability_mean_value,sand_plant.Sand_Mixer.Correction_moisture,sand_plant.Sand_Mixer.Moisture_return_sand,sand_plant.Sand_Mixer.S1C1_Act_return_sand,sand_plant.Sand_Mixer.S3C1_Act_water,sand_plant.Sand_Mixer.Shearing_strength_mean_value,sand_plant.Sand_Mixer.Temperature_return_sand,sand_plant.Sand_Mixer.bentonite_and_coal_dust_mix_additions,target
4,3.78348,2.87077,1505.192138671875,1.3328859806,21.5749816895,17.771812439,2541.7777777777774,1550.0,1550.0,1550.0,...,3.17189,34.807964325,-0.8569321632,1.5047121048,4920.7177734375,66.6999969482,5.4496355057,40.7696762085,41.015623,0.11661166116611661
5,3.78348,2.87077,1505.192138671875,1.3328859806,21.5749816895,17.771812439,2541.7777777777774,1550.0,1550.0,1550.0,...,3.17189,36.9298782349,-0.8214690685,1.48983109,4901.6201171875,69.200012207,5.6758379936,40.3935203552,41.449654,0.11661166116611661
6,3.78348,2.87077,1505.192138671875,1.3328859806,21.5749816895,17.771812439,2541.7777777777774,1550.0,1550.0,1550.0,...,3.17189,36.9298782349,-0.8214690685,1.48983109,4901.6201171875,69.200012207,5.6758379936,40.3935203552,41.449654,0.11661166116611661
7,3.78348,2.87077,1505.192138671875,1.3328859806,21.5749816895,17.771812439,2541.7777777777774,1550.0,1550.0,1550.0,...,3.17189,36.9298782349,-0.8214690685,1.48983109,4901.6201171875,69.200012207,5.6758379936,40.3935203552,41.449654,0.11661166116611661
8,3.78348,2.87077,1505.192138671875,1.3328859806,21.5749816895,17.771812439,2541.7777777777774,1550.0,1550.0,1550.0,...,3.17189,36.9298782349,-0.8214690685,1.48983109,4901.6201171875,69.200012207,5.6758379936,40.3935203552,41.449654,0.11661166116611661
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1027849,3.72818,2.91767,1504.6392822265625,1.8304966688,21.3261737823,17.7895851135,2473.0297619047615,1550.0,1550.0,1550.0,...,3.08896,36.5199623108,-0.7506318688,1.4451882839,4938.2236328125,76,5.5641102791,42.390045166,75.347218,0.10869086908690868
1027850,3.72818,2.91767,1504.6392822265625,1.8304966688,21.3261737823,17.7895851135,2473.0297619047615,1550.0,1550.0,1550.0,...,3.08896,36.5199623108,-0.7506318688,1.4451882839,4938.2236328125,76,5.5641102791,42.390045166,75.347218,0.10869086908690868
1027851,3.72818,2.91767,1504.6392822265625,1.8304966688,21.3261737823,17.7895851135,2473.0297619047615,1550.0,1550.0,1550.0,...,3.08896,36.5199623108,-0.7506318688,1.4451882839,4938.2236328125,76,5.5641102791,42.390045166,75.347218,0.10869086908690868
1027852,3.72818,2.91767,1504.6392822265625,1.8304966688,21.3261737823,17.7895851135,2473.0297619047615,1550.0,1550.0,1550.0,...,3.08896,36.5199623108,-0.7506318688,1.4451882839,4938.2236328125,76,5.5641102791,42.390045166,75.347218,0.10869086908690868


Q1 shape: (42882, 68)
Q2 shape: (41954, 68)
Q3 shape: (41965, 68)
Q4 shape: (42254, 68)


In [152]:
df_q1 = batch_S_quartile_dfs["Q1"]
X_q1 = df_q1.drop(columns=["target"])
y_q1 = df_q1["target"]

df_q3 = batch_S_quartile_dfs["Q3"]
X_q3 = df_q1.drop(columns=["target"])
y_q3 = df_q1["target"]

X_train, X_test, y_train, y_test = split_data((X_S, y_S), train_size=0.8, test_size=0.2)
X_train_q1, X_test_q1, y_train_q1, y_test_q1 = split_data((X_q1, y_q1), train_size=0.8, test_size=0.2)
X_train = X_train.apply(pd.to_numeric)
X_test = X_test.apply(pd.to_numeric)

X_train_q1 = X_train_q1.apply(pd.to_numeric)
X_test_q1 = X_test_q1.apply(pd.to_numeric)


X_train_q3, X_test_q3, y_train_q1, y_test_q1 = split_data((X_q3, y_q3), train_size=0.8, test_size=0.2)


X_train_q3 = X_train_q3.apply(pd.to_numeric)
X_test_q3 = X_test_q3.apply(pd.to_numeric)

In [153]:
display(df_q1)
display(batch_S_quartile_dfs["Q3"])

Unnamed: 0,ABB.ATAS.C,ABB.ATAS.Si,ABB.laddle_event.ladle_event_tap_temperature,ABB.progelta.Weight_REAL_Cu,ABB.progelta.Weight_REAL_Fe,ABB.progelta.Weight_REAL_Mg,ABB.progelta.mapped_ABB_laddle_event_weight,moulding.CIM3Data.blow_off_in_operation_3_swing_plate_bottom_stop,moulding.CIM3Data.blow_off_in_operation_3_swing_plate_front_stop,moulding.CIM3Data.blow_off_in_operation_6_swing_plate_bottom_start,...,pouring.ATASW.Si_equivalent,sand_plant.Sand_Mixer.Compactability_mean_value,sand_plant.Sand_Mixer.Correction_moisture,sand_plant.Sand_Mixer.Moisture_return_sand,sand_plant.Sand_Mixer.S1C1_Act_return_sand,sand_plant.Sand_Mixer.S3C1_Act_water,sand_plant.Sand_Mixer.Shearing_strength_mean_value,sand_plant.Sand_Mixer.Temperature_return_sand,sand_plant.Sand_Mixer.bentonite_and_coal_dust_mix_additions,target
4870,3.91084,1.87932,1501.1123860677083,4.8339328766,21.5394363403,17.8073558807,2467.4285714285716,1550.0,1550.0,1550.0,...,2.16385,38.0872840881,-0.5221194625,1.5391863585,4922.3090820312,86.0999908447,5.6776690483,48.0324058533,75.303818,0
4871,3.91084,1.87932,1501.1123860677083,4.8339328766,21.5394363403,17.8073558807,2467.4285714285716,1550.0,1550.0,1550.0,...,2.16385,38.0872840881,-0.5221194625,1.5391863585,4922.3090820312,86.0999908447,5.6776690483,48.0324058533,75.303818,0
4872,3.91084,1.87932,1501.1123860677083,4.8339328766,21.5394363403,17.8073558807,2467.4285714285716,1550.0,1550.0,1550.0,...,2.16385,38.6901016235,-0.5326652527,1.5693204403,4917.5346679688,84.8000183105,5.6456160545,48.5821762085,74.565971,0
4873,3.91084,1.87932,1501.1123860677083,4.8339328766,21.5394363403,17.8073558807,2467.4285714285716,1550.0,1550.0,1550.0,...,2.16385,38.6901016235,-0.5326652527,1.5693204403,4917.5346679688,84.8000183105,5.6456160545,48.5821762085,74.565971,0
4874,3.91084,1.87932,1501.1123860677083,4.8339328766,21.5394363403,17.8073558807,2467.4285714285716,1550.0,1550.0,1550.0,...,2.16385,38.6901016235,-0.5326652527,1.5693204403,4917.5346679688,84.8000183105,5.6456160545,48.5821762085,74.565971,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1024609,3.92754,1.88602,1507.7097981770833,1.012993335723877,22.67683219909668,18.926979064941406,2546.3333333333335,1550.0,1550.0,1550.0,...,2.22127,38.0993423462,-0.600710392,1.6801834106,4935.0405273438,75.8999938965,5.5815105438,48.234954834,78.385414,0
1024610,3.92754,1.88602,1507.7097981770833,1.012993335723877,22.67683219909668,18.926979064941406,2546.3333333333335,1550.0,1550.0,1550.0,...,2.22127,38.0993423462,-0.600710392,1.6801834106,4935.0405273438,75.8999938965,5.5815105438,48.234954834,78.385414,0
1024611,3.92754,1.88602,1507.7097981770833,1.012993335723877,22.67683219909668,18.926979064941406,2546.3333333333335,1550.0,1550.0,1550.0,...,2.22127,38.0993423462,-0.600710392,1.6801834106,4935.0405273438,75.8999938965,5.5815105438,48.234954834,78.385414,0
1024612,3.92754,1.88602,1507.7097981770833,1.012993335723877,22.67683219909668,18.926979064941406,2546.3333333333335,1550.0,1550.0,1550.0,...,2.22127,38.0993423462,-0.600710392,1.6801834106,4935.0405273438,75.8999938965,5.5815105438,48.234954834,78.385414,0


Unnamed: 0,ABB.ATAS.C,ABB.ATAS.Si,ABB.laddle_event.ladle_event_tap_temperature,ABB.progelta.Weight_REAL_Cu,ABB.progelta.Weight_REAL_Fe,ABB.progelta.Weight_REAL_Mg,ABB.progelta.mapped_ABB_laddle_event_weight,moulding.CIM3Data.blow_off_in_operation_3_swing_plate_bottom_stop,moulding.CIM3Data.blow_off_in_operation_3_swing_plate_front_stop,moulding.CIM3Data.blow_off_in_operation_6_swing_plate_bottom_start,...,pouring.ATASW.Si_equivalent,sand_plant.Sand_Mixer.Compactability_mean_value,sand_plant.Sand_Mixer.Correction_moisture,sand_plant.Sand_Mixer.Moisture_return_sand,sand_plant.Sand_Mixer.S1C1_Act_return_sand,sand_plant.Sand_Mixer.S3C1_Act_water,sand_plant.Sand_Mixer.Shearing_strength_mean_value,sand_plant.Sand_Mixer.Temperature_return_sand,sand_plant.Sand_Mixer.bentonite_and_coal_dust_mix_additions,target
14322,3.92382,1.9664,1504.4508056640625,6.8421478271,29.1990871429,24.987167358400004,2486.6153846153848,1550.0,1550.0,1550.0,...,2.28508,36.9781036377,-0.6535090208,1.5784968138,4927.0834960938,76,5.5302257538,47.0196762085,57.899303,0.012254901960784314
14323,3.92382,1.9664,1504.4508056640625,6.8421478271,29.1990871429,24.987167358400004,2486.6153846153848,1550.0,1550.0,1550.0,...,2.28508,36.9781036377,-0.6535090208,1.5784968138,4927.0834960938,76,5.5302257538,47.0196762085,57.899303,0.012254901960784314
14324,3.92382,1.9664,1504.4508056640625,6.8421478271,29.1990871429,24.987167358400004,2486.6153846153848,1550.0,1550.0,1550.0,...,2.28508,36.9781036377,-0.6535090208,1.5784968138,4927.0834960938,76,5.5302257538,47.0196762085,57.899303,0.012254901960784314
14325,3.92382,1.9664,1504.4508056640625,6.8421478271,29.1990871429,24.987167358400004,2486.6153846153848,1550.0,1550.0,1550.0,...,2.28508,36.9781036377,-0.6535090208,1.5784968138,4927.0834960938,76,5.5302257538,47.0196762085,57.899303,0.012254901960784314
14326,3.92382,1.9664,1504.4508056640625,6.8421478271,29.1990871429,24.987167358400004,2486.6153846153848,1550.0,1550.0,1550.0,...,2.28508,36.9781036377,-0.6535090208,1.5784968138,4927.0834960938,76,5.5302257538,47.0196762085,57.899303,0.012254901960784314
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1015224,3.9055,2.18078,1501.8114013671875,1.8482685089,21.8415584564,18.0917053223,2555.333333333333,1550.0,1550.0,1550.0,...,2.52114,36.6164169312,-0.7646554708,1.4233630896,4931.8579101562,76.5999908447,5.643784523,43.952545166,56.857637,0.012027491408934709
1015225,3.9055,2.18078,1501.8114013671875,1.8482685089,21.8415584564,18.0917053223,2555.333333333333,1550.0,1550.0,1550.0,...,2.52114,36.6164169312,-0.7646554708,1.4233630896,4931.8579101562,76.5999908447,5.643784523,43.952545166,56.857637,0.012027491408934709
1015226,3.9055,2.18078,1501.8114013671875,1.8482685089,21.8415584564,18.0917053223,2555.333333333333,1550.0,1550.0,1550.0,...,2.52114,36.6164169312,-0.7646554708,1.4233630896,4931.8579101562,76.5999908447,5.643784523,43.952545166,56.857637,0.012027491408934709
1015227,3.9055,2.18078,1501.8114013671875,1.8482685089,21.8415584564,18.0917053223,2555.333333333333,1550.0,1550.0,1550.0,...,2.52114,36.6164169312,-0.7646554708,1.4233630896,4931.8579101562,76.5999908447,5.643784523,43.952545166,56.857637,0.012027491408934709


In [154]:
def tune_xgb(
    X_train,
    X_test,
    y_train,
    y_test,
    n_trials=50,
    seed=42,
    is_classification=False,
    printer=colour_printer,
):
    """
    Tunes an XGBoost model (regressor or classifier) using Optuna for hyperparameter optimization.

    This function validates the input data, checks for numeric features, and ensures the target is not constant (for regression).
    It defines an Optuna objective function that trains and evaluates an XGBoost model with hyperparameters suggested by Optuna.
    The function maximizes either R^2 (regression) or accuracy (classification) on the test set, and returns the Optuna study object
    containing the best hyperparameters and score.

    :param X_train: Training input features as a pandas DataFrame or numpy array. All features must be numeric.
    :param X_test: Testing input features as a pandas DataFrame or numpy array. All features must be numeric.
    :param y_train: Target labels corresponding to X_train.
    :param y_test: Target labels corresponding to X_test.
    :param n_trials: Number of Optuna trials to run for hyperparameter search (default: 50).
    :param seed: Random seed for reproducibility (default: 42).
    :param is_classification: If True, tunes an XGBClassifier; otherwise, an XGBRegressor.
    :param printer: Function to print messages (default: colour_printer).

    :raises ValueError: If input data is empty, features are not numeric, target is constant (for regression), or lengths mismatch.
    :raises TypeError: If input types are incorrect or unknown parameters are provided.

    :return: Optuna study object containing the best hyperparameters and score.
    """
    if not isinstance(X_train, (pd.DataFrame, np.ndarray)):
        raise TypeError("X_train must be a pandas DataFrame or numpy array.")

    if not isinstance(X_test, (pd.DataFrame, np.ndarray)):
        raise TypeError("X_test must be a pandas DataFrame or numpy array.")

    if len(X_train) != len(y_train):
        raise ValueError("X_train and y_train must have the same length.")
    if not isinstance(n_trials, int) or n_trials <= 0:
        raise ValueError("n_trials must be a positive integer.")
    if not isinstance(seed, int):
        raise TypeError("seed must be an integer.")

    # Handle constant regression target
    y_arr = np.asarray(y_train)
    if np.all(y_arr == y_arr[0]):
        printer(
            "Regression target is constant; cannot tune model.", fg="red", bold=True
        )
        raise ValueError("Regression target is constant; cannot tune model.")

    # Handle unseen classes in classification
    if is_classification:
        train_classes = set(np.unique(y_train))
        test_classes = set(np.unique(y_test))
        if not test_classes.issubset(train_classes):
            printer("Unseen classes in test set.", fg="red", bold=True)
            raise ValueError("Unseen classes in test set.")

    def objective(trial):
        trial_number = trial.number
        printer(f"Running trial {trial_number}...", fg="blue")
        # common hyperparameter search space
        param = {
            "n_estimators": trial.suggest_int("n_estimators", 50, 500),
            "max_depth": trial.suggest_int("max_depth", 3, 12),
            "learning_rate": trial.suggest_float("learning_rate", 1e-3, 1e-1),
            "subsample": trial.suggest_float("subsample", 0.5, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.5, 1.0),
            "min_child_weight": trial.suggest_int("min_child_weight", 1, 10),
            "reg_alpha": trial.suggest_float("reg_alpha", 1e-8, 1.0),
            "reg_lambda": trial.suggest_float("reg_lambda", 1e-8, 1.0),
            "random_state": seed,
        }

        if is_classification:
            # classification-specific
            param["objective"] = "binary:logistic"
            model = xgb.XGBClassifier(**param, eval_metric="logloss")
        else:
            # regression-specific
            param["objective"] = "reg:squarederror"
            param["verbosity"] = 0
            param["booster"] = "gbtree"
            model = xgb.XGBRegressor(**param)

        # train and predict
        model.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=False)
        preds = model.predict(X_test)

        # return the metric to maximize
        if is_classification:
            return accuracy_score(y_test, preds)
        else:
            return r2_score(y_test, preds)

    # always maximize (higher R^2 or higher accuracy)
    study = optuna.create_study(
        direction="maximize", sampler=optuna.samplers.TPESampler(seed=seed)
    )
    study.optimize(objective, n_trials=n_trials)

    return study


In [155]:
# — example usage —
# classification = False   # for regression
# classification = True    # for classification
study = tune_xgb(X_train, X_test, y_train, y_test,n_trials=50, is_classification=CLASSIFICATION)
best_params, best_score = study.best_params, study.best_value
print("Best score:", best_score)
print("Best hyperparameters:", best_params)

[I 2025-07-01 14:11:50,127] A new study created in memory with name: no-name-88e33166-6f6f-4fee-96d8-7540ccda14c2


[34mRunning trial 0...[0m


[I 2025-07-01 14:11:52,044] Trial 0 finished with value: 0.9996028808407059 and parameters: {'n_estimators': 218, 'max_depth': 12, 'learning_rate': 0.07346740023932911, 'subsample': 0.7993292420985183, 'colsample_bytree': 0.5780093202212182, 'min_child_weight': 2, 'reg_alpha': 0.05808362158736334, 'reg_lambda': 0.8661761471131737}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 1...[0m


[I 2025-07-01 14:11:54,423] Trial 1 finished with value: 0.8408833425312222 and parameters: {'n_estimators': 321, 'max_depth': 10, 'learning_rate': 0.0030378649352844423, 'subsample': 0.9849549260809971, 'colsample_bytree': 0.9162213204002109, 'min_child_weight': 3, 'reg_alpha': 0.18182497538885092, 'reg_lambda': 0.1834045180193887}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 2...[0m


[I 2025-07-01 14:11:55,740] Trial 2 finished with value: 0.9984161084089088 and parameters: {'n_estimators': 187, 'max_depth': 8, 'learning_rate': 0.04376255684556946, 'subsample': 0.645614570099021, 'colsample_bytree': 0.8059264473611898, 'min_child_weight': 2, 'reg_alpha': 0.2921446556137717, 'reg_lambda': 0.36636184963007323}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 3...[0m


[I 2025-07-01 14:11:57,662] Trial 3 finished with value: 0.9979817953058299 and parameters: {'n_estimators': 255, 'max_depth': 10, 'learning_rate': 0.020767704433677616, 'subsample': 0.7571172192068059, 'colsample_bytree': 0.7962072844310213, 'min_child_weight': 1, 'reg_alpha': 0.6075448558259898, 'reg_lambda': 0.17052413198205027}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 4...[0m


[I 2025-07-01 14:11:58,511] Trial 4 finished with value: 0.9990244520968634 and parameters: {'n_estimators': 79, 'max_depth': 12, 'learning_rate': 0.09659757127438139, 'subsample': 0.9041986740582306, 'colsample_bytree': 0.6523068845866853, 'min_child_weight': 1, 'reg_alpha': 0.6842330296698267, 'reg_lambda': 0.4401524993380763}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 5...[0m


[I 2025-07-01 14:11:59,160] Trial 5 finished with value: 0.5582262169790961 and parameters: {'n_estimators': 105, 'max_depth': 7, 'learning_rate': 0.004404463590406622, 'subsample': 0.954660201039391, 'colsample_bytree': 0.6293899908000085, 'min_child_weight': 7, 'reg_alpha': 0.31171108297230016, 'reg_lambda': 0.5200680259771306}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 6...[0m


[I 2025-07-01 14:12:00,291] Trial 6 finished with value: 0.9964221619443696 and parameters: {'n_estimators': 296, 'max_depth': 4, 'learning_rate': 0.0969888781486913, 'subsample': 0.8875664116805573, 'colsample_bytree': 0.9697494707820946, 'min_child_weight': 9, 'reg_alpha': 0.5978999828320853, 'reg_lambda': 0.9218742358043744}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 7...[0m


[I 2025-07-01 14:12:00,766] Trial 7 finished with value: 0.49327048434957443 and parameters: {'n_estimators': 89, 'max_depth': 4, 'learning_rate': 0.0054775016021432685, 'subsample': 0.6626651653816322, 'colsample_bytree': 0.6943386448447411, 'min_child_weight': 3, 'reg_alpha': 0.8287375108645543, 'reg_lambda': 0.356753333126056}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 8...[0m


[I 2025-07-01 14:12:01,864] Trial 8 finished with value: 0.982912334224783 and parameters: {'n_estimators': 176, 'max_depth': 8, 'learning_rate': 0.014951498272501501, 'subsample': 0.9010984903770198, 'colsample_bytree': 0.5372753218398854, 'min_child_weight': 10, 'reg_alpha': 0.7722447715742098, 'reg_lambda': 0.19871568954701557}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 9...[0m


[I 2025-07-01 14:12:02,439] Trial 9 finished with value: 0.9966479710124755 and parameters: {'n_estimators': 52, 'max_depth': 11, 'learning_rate': 0.0709788770409141, 'subsample': 0.8645035840204937, 'colsample_bytree': 0.8856351733429728, 'min_child_weight': 1, 'reg_alpha': 0.3584657349596153, 'reg_lambda': 0.1158690683664391}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 10...[0m


[I 2025-07-01 14:12:04,493] Trial 10 finished with value: 0.9991349380584483 and parameters: {'n_estimators': 453, 'max_depth': 6, 'learning_rate': 0.06595915000703412, 'subsample': 0.5089809378074099, 'colsample_bytree': 0.5076838686640521, 'min_child_weight': 5, 'reg_alpha': 0.015144246951314472, 'reg_lambda': 0.9761399000965962}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 11...[0m


[I 2025-07-01 14:12:06,581] Trial 11 finished with value: 0.9991880282581829 and parameters: {'n_estimators': 477, 'max_depth': 6, 'learning_rate': 0.06879954897800857, 'subsample': 0.5085930751344794, 'colsample_bytree': 0.5123526693065024, 'min_child_weight': 5, 'reg_alpha': 0.017030099516058217, 'reg_lambda': 0.9684322141265528}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 12...[0m


[I 2025-07-01 14:12:08,739] Trial 12 finished with value: 0.9991826537689695 and parameters: {'n_estimators': 496, 'max_depth': 6, 'learning_rate': 0.07353227647783467, 'subsample': 0.5032837044352211, 'colsample_bytree': 0.5761645627200017, 'min_child_weight': 5, 'reg_alpha': 0.019119734175538517, 'reg_lambda': 0.742192354701299}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 13...[0m


[I 2025-07-01 14:12:09,941] Trial 13 finished with value: 0.9762216847518863 and parameters: {'n_estimators': 385, 'max_depth': 3, 'learning_rate': 0.05218313341751528, 'subsample': 0.776582882187381, 'colsample_bytree': 0.5937663157166836, 'min_child_weight': 7, 'reg_alpha': 0.9993624386479272, 'reg_lambda': 0.7521755622302247}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 14...[0m


[I 2025-07-01 14:12:11,546] Trial 14 finished with value: 0.9993398916178085 and parameters: {'n_estimators': 225, 'max_depth': 9, 'learning_rate': 0.08455403006940831, 'subsample': 0.6370058072961822, 'colsample_bytree': 0.5065701578589041, 'min_child_weight': 4, 'reg_alpha': 0.13881284616445927, 'reg_lambda': 0.7715329689378894}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 15...[0m


[I 2025-07-01 14:12:13,354] Trial 15 finished with value: 0.9992500609114032 and parameters: {'n_estimators': 215, 'max_depth': 12, 'learning_rate': 0.08399511589326394, 'subsample': 0.6644195468693028, 'colsample_bytree': 0.7223988907495709, 'min_child_weight': 3, 'reg_alpha': 0.19321913361300092, 'reg_lambda': 0.7432794655679429}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 16...[0m


[I 2025-07-01 14:12:15,727] Trial 16 finished with value: 0.9992922191036582 and parameters: {'n_estimators': 348, 'max_depth': 10, 'learning_rate': 0.05067730859391659, 'subsample': 0.5968917547218694, 'colsample_bytree': 0.5707921208261756, 'min_child_weight': 4, 'reg_alpha': 0.46132476605433137, 'reg_lambda': 0.6062359521127435}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 17...[0m


[I 2025-07-01 14:12:16,936] Trial 17 finished with value: 0.9992703995625654 and parameters: {'n_estimators': 152, 'max_depth': 9, 'learning_rate': 0.0842131819811029, 'subsample': 0.8189893644369733, 'colsample_bytree': 0.6555901420761158, 'min_child_weight': 7, 'reg_alpha': 0.16410018615385003, 'reg_lambda': 0.8425165602470802}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 18...[0m


[I 2025-07-01 14:12:19,108] Trial 18 finished with value: 0.9990801740034583 and parameters: {'n_estimators': 246, 'max_depth': 12, 'learning_rate': 0.08420548350821593, 'subsample': 0.7180099389878547, 'colsample_bytree': 0.7700475100567971, 'min_child_weight': 4, 'reg_alpha': 0.12096737982700446, 'reg_lambda': 0.6290491924913858}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 19...[0m


[I 2025-07-01 14:12:20,184] Trial 19 finished with value: 0.9979385154445568 and parameters: {'n_estimators': 142, 'max_depth': 9, 'learning_rate': 0.03832260110587146, 'subsample': 0.5860351917118192, 'colsample_bytree': 0.6121569568853539, 'min_child_weight': 2, 'reg_alpha': 0.44632623893794093, 'reg_lambda': 0.8157298391522443}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 20...[0m


[I 2025-07-01 14:12:23,053] Trial 20 finished with value: 0.9994789969023763 and parameters: {'n_estimators': 393, 'max_depth': 11, 'learning_rate': 0.057859749033007765, 'subsample': 0.7157642328075761, 'colsample_bytree': 0.5456670758695035, 'min_child_weight': 6, 'reg_alpha': 0.10087923882719174, 'reg_lambda': 0.6326303964360144}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 21...[0m


[I 2025-07-01 14:12:26,081] Trial 21 finished with value: 0.9994963574353856 and parameters: {'n_estimators': 405, 'max_depth': 11, 'learning_rate': 0.05987998270752033, 'subsample': 0.7148735992133634, 'colsample_bytree': 0.5549473022812605, 'min_child_weight': 6, 'reg_alpha': 0.08815365007946861, 'reg_lambda': 0.6271101128646857}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 22...[0m


[I 2025-07-01 14:12:28,916] Trial 22 finished with value: 0.9993291446190411 and parameters: {'n_estimators': 408, 'max_depth': 11, 'learning_rate': 0.05810064606341804, 'subsample': 0.7148507812795809, 'colsample_bytree': 0.550726687289138, 'min_child_weight': 6, 'reg_alpha': 0.26016000980254717, 'reg_lambda': 0.6441276504574716}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 23...[0m


[I 2025-07-01 14:12:32,268] Trial 23 finished with value: 0.9993240442422359 and parameters: {'n_estimators': 411, 'max_depth': 11, 'learning_rate': 0.03312032047074087, 'subsample': 0.7814154168201571, 'colsample_bytree': 0.6941448424294733, 'min_child_weight': 8, 'reg_alpha': 0.08796389321600329, 'reg_lambda': 0.5420414245560543}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 24...[0m


[I 2025-07-01 14:12:34,608] Trial 24 finished with value: 0.9993835562077139 and parameters: {'n_estimators': 351, 'max_depth': 11, 'learning_rate': 0.058410221488398, 'subsample': 0.8312712399145312, 'colsample_bytree': 0.606348541401837, 'min_child_weight': 6, 'reg_alpha': 0.3831715768935102, 'reg_lambda': 0.6680838798120045}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 25...[0m


[I 2025-07-01 14:12:37,552] Trial 25 finished with value: 0.9993425134172854 and parameters: {'n_estimators': 447, 'max_depth': 12, 'learning_rate': 0.061470423883508776, 'subsample': 0.7072958782974688, 'colsample_bytree': 0.5385577514297871, 'min_child_weight': 8, 'reg_alpha': 0.24225217936979748, 'reg_lambda': 0.8882828769468167}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 26...[0m


[I 2025-07-01 14:12:39,619] Trial 26 finished with value: 0.9994490199251035 and parameters: {'n_estimators': 292, 'max_depth': 10, 'learning_rate': 0.0767480773321835, 'subsample': 0.8211338116630437, 'colsample_bytree': 0.6713684393913659, 'min_child_weight': 6, 'reg_alpha': 0.09190125307512442, 'reg_lambda': 0.448495729440453}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 27...[0m


[I 2025-07-01 14:12:42,515] Trial 27 finished with value: 0.99948599667317 and parameters: {'n_estimators': 368, 'max_depth': 11, 'learning_rate': 0.04367909048624693, 'subsample': 0.7342337192226447, 'colsample_bytree': 0.5535976491991504, 'min_child_weight': 8, 'reg_alpha': 0.07661390194496248, 'reg_lambda': 0.5730531174267831}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 28...[0m


[I 2025-07-01 14:12:45,670] Trial 28 finished with value: 0.9993493391781569 and parameters: {'n_estimators': 353, 'max_depth': 12, 'learning_rate': 0.02819492228052967, 'subsample': 0.7479967063109072, 'colsample_bytree': 0.6265001487830318, 'min_child_weight': 8, 'reg_alpha': 0.2209842515292027, 'reg_lambda': 0.286933282034596}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 29...[0m


[I 2025-07-01 14:12:48,098] Trial 29 finished with value: 0.9988242730777631 and parameters: {'n_estimators': 317, 'max_depth': 10, 'learning_rate': 0.04373114858252437, 'subsample': 0.7998159287778641, 'colsample_bytree': 0.8472147933184168, 'min_child_weight': 10, 'reg_alpha': 0.06001668397270818, 'reg_lambda': 0.5581313412940082}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 30...[0m


[I 2025-07-01 14:12:50,980] Trial 30 finished with value: 0.9992140762558672 and parameters: {'n_estimators': 437, 'max_depth': 9, 'learning_rate': 0.044554800180102036, 'subsample': 0.68793338827598, 'colsample_bytree': 0.7454187658198306, 'min_child_weight': 9, 'reg_alpha': 0.1836436551562699, 'reg_lambda': 0.6970256703053422}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 31...[0m


[I 2025-07-01 14:12:53,775] Trial 31 finished with value: 0.9994406114984636 and parameters: {'n_estimators': 376, 'max_depth': 11, 'learning_rate': 0.05409575164246422, 'subsample': 0.7412992813016962, 'colsample_bytree': 0.5644420107478969, 'min_child_weight': 7, 'reg_alpha': 0.10162987020713936, 'reg_lambda': 0.47034517605864296}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 32...[0m


[I 2025-07-01 14:12:56,155] Trial 32 finished with value: 0.9995290277484941 and parameters: {'n_estimators': 322, 'max_depth': 11, 'learning_rate': 0.06350957101363379, 'subsample': 0.7386519963289075, 'colsample_bytree': 0.541459278100023, 'min_child_weight': 6, 'reg_alpha': 0.06246227560139786, 'reg_lambda': 0.38300695044288213}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 33...[0m


[I 2025-07-01 14:12:58,113] Trial 33 finished with value: 0.9993691201563543 and parameters: {'n_estimators': 275, 'max_depth': 10, 'learning_rate': 0.07808214616601238, 'subsample': 0.6118556467611441, 'colsample_bytree': 0.5847935007961649, 'min_child_weight': 9, 'reg_alpha': 0.32327472991058337, 'reg_lambda': 0.3615978943861362}. Best is trial 0 with value: 0.9996028808407059.


[34mRunning trial 34...[0m


[I 2025-07-01 14:13:00,668] Trial 34 finished with value: 0.9996542723990144 and parameters: {'n_estimators': 315, 'max_depth': 12, 'learning_rate': 0.04335349423060722, 'subsample': 0.8470245247214504, 'colsample_bytree': 0.5001720662662578, 'min_child_weight': 2, 'reg_alpha': 0.00958344029601977, 'reg_lambda': 0.4008657505083054}. Best is trial 34 with value: 0.9996542723990144.


[34mRunning trial 35...[0m


[I 2025-07-01 14:13:02,565] Trial 35 finished with value: 0.9995792135755698 and parameters: {'n_estimators': 320, 'max_depth': 12, 'learning_rate': 0.0676553300963508, 'subsample': 0.9950104287731715, 'colsample_bytree': 0.5260350491479402, 'min_child_weight': 2, 'reg_alpha': 0.0008809342873865589, 'reg_lambda': 0.26837864298888237}. Best is trial 34 with value: 0.9996542723990144.


[34mRunning trial 36...[0m


[I 2025-07-01 14:13:04,568] Trial 36 finished with value: 0.9996045075229415 and parameters: {'n_estimators': 320, 'max_depth': 12, 'learning_rate': 0.06492801640965386, 'subsample': 0.9243454930577013, 'colsample_bytree': 0.5038596276134546, 'min_child_weight': 2, 'reg_alpha': 0.01490955503781112, 'reg_lambda': 0.019196947935674946}. Best is trial 34 with value: 0.9996542723990144.


[34mRunning trial 37...[0m


[I 2025-07-01 14:13:06,220] Trial 37 finished with value: 0.9995973688932912 and parameters: {'n_estimators': 255, 'max_depth': 12, 'learning_rate': 0.07738271172893224, 'subsample': 0.9968562685285973, 'colsample_bytree': 0.515146698281772, 'min_child_weight': 2, 'reg_alpha': 0.01493785687866013, 'reg_lambda': 0.06927781806713118}. Best is trial 34 with value: 0.9996542723990144.


[34mRunning trial 38...[0m


[I 2025-07-01 14:13:07,791] Trial 38 finished with value: 0.999405991466917 and parameters: {'n_estimators': 256, 'max_depth': 12, 'learning_rate': 0.08791408383521938, 'subsample': 0.9524889255846127, 'colsample_bytree': 0.5040933164872313, 'min_child_weight': 1, 'reg_alpha': 0.557626738746278, 'reg_lambda': 0.006963230311992863}. Best is trial 34 with value: 0.9996542723990144.


[34mRunning trial 39...[0m


[I 2025-07-01 14:13:09,375] Trial 39 finished with value: 0.9981873338348228 and parameters: {'n_estimators': 199, 'max_depth': 12, 'learning_rate': 0.09305389110500266, 'subsample': 0.9364666782917456, 'colsample_bytree': 0.9899226356996809, 'min_child_weight': 2, 'reg_alpha': 0.15368930201131553, 'reg_lambda': 0.0366504823192372}. Best is trial 34 with value: 0.9996542723990144.


[34mRunning trial 40...[0m


[I 2025-07-01 14:13:11,080] Trial 40 finished with value: 0.9993458367475008 and parameters: {'n_estimators': 270, 'max_depth': 8, 'learning_rate': 0.07777073748300582, 'subsample': 0.8624215939922876, 'colsample_bytree': 0.5006064065119095, 'min_child_weight': 3, 'reg_alpha': 0.2859381607977133, 'reg_lambda': 0.08500489108259235}. Best is trial 34 with value: 0.9996542723990144.


[34mRunning trial 41...[0m


[I 2025-07-01 14:13:13,073] Trial 41 finished with value: 0.9996338530798172 and parameters: {'n_estimators': 323, 'max_depth': 12, 'learning_rate': 0.06754271968991148, 'subsample': 0.9636303366078827, 'colsample_bytree': 0.5249315655536981, 'min_child_weight': 2, 'reg_alpha': 0.014626655003476173, 'reg_lambda': 0.2306281396072284}. Best is trial 34 with value: 0.9996542723990144.


[34mRunning trial 42...[0m


[I 2025-07-01 14:13:14,791] Trial 42 finished with value: 0.9995898852544823 and parameters: {'n_estimators': 240, 'max_depth': 12, 'learning_rate': 0.07290874901553951, 'subsample': 0.9722300577520261, 'colsample_bytree': 0.52345322610177, 'min_child_weight': 2, 'reg_alpha': 0.03926598445793026, 'reg_lambda': 0.1852432276536574}. Best is trial 34 with value: 0.9996542723990144.


[34mRunning trial 43...[0m


[I 2025-07-01 14:13:16,674] Trial 43 finished with value: 0.9996229016154226 and parameters: {'n_estimators': 281, 'max_depth': 12, 'learning_rate': 0.06516714317380955, 'subsample': 0.9210042985598068, 'colsample_bytree': 0.5285768761122942, 'min_child_weight': 3, 'reg_alpha': 0.0011583040266222346, 'reg_lambda': 0.12535441030507533}. Best is trial 34 with value: 0.9996542723990144.


[34mRunning trial 44...[0m


[I 2025-07-01 14:13:18,633] Trial 44 finished with value: 0.9996546674572004 and parameters: {'n_estimators': 299, 'max_depth': 10, 'learning_rate': 0.06732235538055657, 'subsample': 0.9232897053077487, 'colsample_bytree': 0.5916408483310093, 'min_child_weight': 3, 'reg_alpha': 0.0034210846533113716, 'reg_lambda': 0.29614774407111255}. Best is trial 44 with value: 0.9996546674572004.


[34mRunning trial 45...[0m


[I 2025-07-01 14:13:20,743] Trial 45 finished with value: 0.9995306191556985 and parameters: {'n_estimators': 297, 'max_depth': 10, 'learning_rate': 0.06414211513464371, 'subsample': 0.9275082461499973, 'colsample_bytree': 0.5860393317197359, 'min_child_weight': 1, 'reg_alpha': 0.14148733025644145, 'reg_lambda': 0.23954078404679988}. Best is trial 44 with value: 0.9996546674572004.


[34mRunning trial 46...[0m


[I 2025-07-01 14:13:23,178] Trial 46 finished with value: 0.999504935771129 and parameters: {'n_estimators': 328, 'max_depth': 12, 'learning_rate': 0.05459468461042818, 'subsample': 0.8752873835800248, 'colsample_bytree': 0.6407500495548348, 'min_child_weight': 3, 'reg_alpha': 0.05657620838164173, 'reg_lambda': 0.14058694353001938}. Best is trial 44 with value: 0.9996546674572004.


[34mRunning trial 47...[0m


[I 2025-07-01 14:13:24,943] Trial 47 finished with value: 0.9987696453674346 and parameters: {'n_estimators': 312, 'max_depth': 7, 'learning_rate': 0.04747913186321505, 'subsample': 0.9147113052090038, 'colsample_bytree': 0.6050363457106182, 'min_child_weight': 3, 'reg_alpha': 0.766167927602472, 'reg_lambda': 0.2935132480291498}. Best is trial 44 with value: 0.9996546674572004.


[34mRunning trial 48...[0m


[I 2025-07-01 14:13:26,924] Trial 48 finished with value: 0.9995907697975319 and parameters: {'n_estimators': 287, 'max_depth': 10, 'learning_rate': 0.06880470307450542, 'subsample': 0.964881207596612, 'colsample_bytree': 0.5320201912769342, 'min_child_weight': 4, 'reg_alpha': 0.006344404677310889, 'reg_lambda': 0.23076279823096435}. Best is trial 44 with value: 0.9996546674572004.


[34mRunning trial 49...[0m


[I 2025-07-01 14:13:29,537] Trial 49 finished with value: 0.9995445390300307 and parameters: {'n_estimators': 333, 'max_depth': 11, 'learning_rate': 0.03780460014396347, 'subsample': 0.8885365008067071, 'colsample_bytree': 0.5736652512267625, 'min_child_weight': 1, 'reg_alpha': 0.20650071636516648, 'reg_lambda': 0.14304538993389165}. Best is trial 44 with value: 0.9996546674572004.


Best score: 0.9996546674572004
Best hyperparameters: {'n_estimators': 299, 'max_depth': 10, 'learning_rate': 0.06732235538055657, 'subsample': 0.9232897053077487, 'colsample_bytree': 0.5916408483310093, 'min_child_weight': 3, 'reg_alpha': 0.0034210846533113716, 'reg_lambda': 0.29614774407111255}


In [156]:
def model_initializer(
    X_train, y_train, is_classification=False, params_dict=None, printer=colour_printer
):
    """
    Initializes and trains an XGBoost model for regression or classification.

    This function validates the input data, checks for numeric features, and ensures the target is not constant (for regression).
    It then instantiates an XGBoost model (XGBRegressor or XGBClassifier) with the provided parameters, fits it to the training data,
    and wraps the predict method in classification mode to check for unseen classes in predictions.

    :param X_train: Training input features as a pandas DataFrame or numpy array. All features must be numeric.
    :param y_train: Target labels corresponding to X_train.
    :param is_classification: If True, initializes an XGBClassifier; otherwise, an XGBRegressor.
    :param params_dict: Dictionary of hyperparameters to pass to the XGBoost model.
    :param printer: Function to print messages (default: colour_printer).

    :raises ValueError: If input data is empty, features are not numeric, target is constant (for regression), or lengths mismatch.
    :raises TypeError: If unknown parameters are provided in params_dict.

    :return: Trained XGBoost model (XGBRegressor or XGBClassifier).
    """
    if params_dict is None:
        params_dict = {}

    # Check for empty input
    if X_train is None or y_train is None or len(X_train) == 0 or len(y_train) == 0:
        raise ValueError("Input data cannot be empty.")

    # Check for length mismatch
    if len(X_train) != len(y_train):
        raise ValueError("X_train and y_train must have the same length.")

    # Check for non-numeric features
    if isinstance(X_train, pd.DataFrame):
        if not np.all([np.issubdtype(dtype, np.number) for dtype in X_train.dtypes]):
            raise ValueError("All features must be numeric.")
    else:
        # If not DataFrame, try to convert to float
        try:
            np.asarray(X_train, dtype=float)
        except Exception:
            raise ValueError("All features must be numeric.")

    # Check for constant regression target
    if not is_classification:
        y_arr = np.asarray(y_train)
        if np.all(y_arr == y_arr[0]):
            printer(
                "Regression target is constant; cannot fit model.", fg="red", bold=True
            )
            raise ValueError("Regression target is constant; cannot fit model.")

    # Try to instantiate the model (catch invalid params)
    xgb_type = XGBClassifier if is_classification else XGBRegressor
    valid_params = xgb_type().get_params().keys()
    unknown_params = set(params_dict) - set(valid_params)
    if unknown_params:
        printer(
            f"Unknown parameter(s) for {xgb_type.__name__}: {unknown_params}",
            fg="red",
            bold=True,
        )
        raise TypeError(
            f"Unknown parameter(s) for {xgb_type.__name__}: {unknown_params}"
        )

    xgb_model = xgb_type(**params_dict)

    start = perf_counter()
    xgb_model.fit(X_train, y_train)
    end = perf_counter()
    printer(f"✔ XGB model completed fitting in {end-start:.2f} seconds", fg="green")

    # Wrap predict to check for unseen classes in classification
    if is_classification:
        train_classes = set(np.unique(y_train))
        orig_predict = xgb_model.predict

        def safe_predict(X):
            preds = orig_predict(X)
            pred_classes = set(np.unique(preds))
            if not pred_classes.issubset(train_classes):
                raise ValueError("Unseen classes in test set.")
            return preds

        xgb_model.predict = safe_predict

    return xgb_model


In [157]:
def get_shap_values(model, X_test, target_class=None, printer=colour_printer):
    """
    Compute SHAP values for a tree-based model using SHAP's TreeExplainer.

    This function creates a SHAP TreeExplainer for the provided model and computes SHAP values for the given test data.
    For multi-class models, you can specify a target_class to extract SHAP values for a specific class.
    The function prints progress and timing information using the provided printer.

    :param model: Trained tree-based model.
    :param X_test: Test features as a pandas DataFrame or numpy array.
    :param target_class: (Optional) Integer index or class label for which to extract SHAP values in multi-class problems.
    :param printer: Function to print messages (default: colour_printer).

    :raises ValueError: If target_class is specified but not found in model classes.

    :return: SHAP Explanation object containing SHAP values for the test data (for the specified class if multi-class).
    """
    printer(f"Getting SHAP values for {type(model).__name__}", fg="yellow")
    start = perf_counter()
    shap_explainer = shap.TreeExplainer(model)
    shap_values = shap_explainer(X_test)
    end = perf_counter()
    printer(
        f"✔ Finished getting SHAP values for {type(model).__name__} in {end - start:.2f} seconds",
        fg="green",
    )

    # Handle multi-class SHAP value output
    if hasattr(shap_values, "values") and shap_values.values.ndim == 3:
        # Determine class index
        if target_class is not None:
            # If target_class is a label, convert to index
            if hasattr(model, "classes_"):
                class_labels = list(model.classes_)
                if target_class in class_labels:
                    class_idx = class_labels.index(target_class)
                else:
                    raise ValueError(
                        f"target_class {target_class} not found in model.classes_: {class_labels}"
                    )
            else:
                class_idx = int(target_class)
        else:
            class_idx = 0  # Default to first class

        shap_values = shap.Explanation(
            values=shap_values.values[:, :, class_idx],
            base_values=shap_values.base_values[:, class_idx]
            if shap_values.base_values.ndim == 2
            else shap_values.base_values,
            data=shap_values.data,
            feature_names=shap_values.feature_names,
        )
    return shap_values


def get_tree_rankings(shap_values):
    """
    Computes feature importances based on mean absolute SHAP values from a tree-based model.

    This function calculates the mean absolute SHAP value for each feature, using the SHAP Explanation object.
    It returns a dictionary mapping feature names to their corresponding mean absolute SHAP values, representing feature importance.

    :param shap_values: SHAP Explanation object containing SHAP values, feature names, and associated data.

    :return: Dictionary where keys are feature names (str) and values are the corresponding mean absolute SHAP values (float).
    """
    mean_abs_shap = np.abs(shap_values.values).mean(axis=0)
    rankings = {
        feature_name: mean_abs_shap_val
        for feature_name, mean_abs_shap_val in zip(
            shap_values.feature_names, mean_abs_shap
        )
    }
    return rankings


def combine_rankings(*rankings_dicts):
    """
    Combine multiple feature importance dictionaries into a single normalized ranking.

    This function takes any number of feature importance dictionaries, each mapping feature names to importance values.
    It normalizes the values within each dictionary so they sum to 1, then averages the normalized scores across all models
    to produce a combined ranking. The result is a sorted list of the top 10 features by average normalized importance.

    :param rankings_dicts: One or more dictionaries where each key is a feature name (str) and each value is a raw importance score (float).

    :return: Sorted list of (feature_name, normalized_importance) tuples representing the top 10 features.
    """
    normalized_rankings = []
    for rankings_dict in rankings_dicts:
        rank_vals_sum = sum(rankings_dict.values())
        normalized_rankings.append(
            {
                feature_name: orig_rank / rank_vals_sum
                for feature_name, orig_rank in rankings_dict.items()
            }
        )

    combined_rankings = []
    for feature_name in normalized_rankings[0].keys():
        normalized_ranking = sum(
            [rankings_dict[feature_name] for rankings_dict in normalized_rankings]
        ) / len(normalized_rankings)
        combined_rankings.append((feature_name, normalized_ranking))
    return sorted(combined_rankings, key=lambda x: x[1], reverse=True)[:10]


def find_quantile_ranges(shap_vals, X, features=None, quantile=0.10, direction="low"):
    """
    For each feature, find the range of X values whose raw SHAP contributions lie in the specified tail.

    This function identifies, for each feature, the range of values in X where the SHAP values are in the specified quantile tail.
    It can be used to find value ranges associated with the most positive or negative SHAP contributions.

    :param shap_vals: SHAP Explanation object (with .values, .feature_names).
    :param X: DataFrame of the same rows that produced shap_vals.
    :param features: List of feature names to include (defaults to all).
    :param quantile: Quantile threshold (e.g., 0.10 for bottom 10%).
    :param direction: "low" for shap ≤ quantile cutoff, "high" for shap ≥ (1 - quantile) cutoff.

    :return: Dictionary {feature: (min_value, max_value)} for each feature. If no rows in that tail, returns (None, None).
    """
    vals = shap_vals.values if hasattr(shap_vals, "values") else shap_vals
    names = (
        shap_vals.feature_names
        if hasattr(shap_vals, "feature_names")
        else list(X.columns)
    )

    if features is None:
        features = names

    out = {}
    for feat in features:
        j = names.index(feat)
        one_shap = vals[:, j]

        if direction == "low":
            cutoff = np.quantile(one_shap, quantile)
            mask = one_shap <= cutoff
        else:  # "high"
            cutoff = np.quantile(one_shap, 1 - quantile)
            mask = one_shap >= cutoff

        if mask.any():
            xs = X.iloc[mask, j]
            out[feat] = (float(xs.min()), float(xs.max()))
        else:
            out[feat] = (None, None)
    return out


def intersect_ranges(*ranked_feat_min_max_dicts):
    """
    Take a list of per-model {feature: (min, max)} dictionaries and return the intersection range for each feature.

    This function computes, for each feature, the intersection of the value ranges across all provided models.
    If the intersection is empty, it returns the full range covered by all models for that feature.

    :param ranked_feat_min_max_dicts: One or more dictionaries mapping feature names to (min, max) tuples.

    :return: Dictionary {feature: (min, max)} representing the intersection range for each feature.
    """

    feats = list(ranked_feat_min_max_dicts[0].keys())

    for rd in ranked_feat_min_max_dicts[1:]:
        if set(rd) != set(feats):
            raise ValueError("Feature keys differ among models.")

    final_dict = {}

    for feat in feats:
        global_feat_min_vals = [
            ranked_dict[feat][0] for ranked_dict in ranked_feat_min_max_dicts
        ]
        global_feat_max_vals = [
            ranked_dict[feat][1] for ranked_dict in ranked_feat_min_max_dicts
        ]

        lo = max(global_feat_min_vals)  # lo
        hi = min(global_feat_max_vals)  # hi

        if lo <= hi:

            final_dict[feat] = (lo, hi)
        else:
            true_min_range = min(global_feat_min_vals)
            true_max_range = max(global_feat_max_vals)
            final_dict[feat] = (true_min_range, true_max_range)

    return final_dict

In [158]:

best_model_trained = model_initializer(X_train,y_train,params_dict=study.best_params, is_classification=CLASSIFICATION)

[32m✔ XGB model completed fitting in 1.68 seconds[0m


In [159]:
xgb_shap = get_shap_values(best_model_trained, X_test_q3)
xgb_ranking =  get_tree_rankings(xgb_shap)
combined_ranking = combine_rankings(xgb_ranking)

range_dict = intersect_ranges(find_quantile_ranges(xgb_shap,X_test_q3,[x[0] for x in combined_ranking]))

[33mGetting SHAP values for XGBRegressor[0m
[32m✔ Finished getting SHAP values for XGBRegressor in 5.81 seconds[0m


In [160]:
def performance_report(model, X_test, y_test):
    print(f"Performance report for {type(model).__name__}:")
    y_pred = model.predict(X_test)

    mae = mean_absolute_error(y_test, y_pred)
    print(f"Mean Absolute Error: {mae}")

    mse = mean_squared_error(y_test, y_pred)
    print(f"Mean Squared Error: {mse}")

    rmse = mean_squared_error(y_test, y_pred, squared=False)
    print(f"Root Mean Sqaured Error: {rmse}")

    mape = mean_absolute_percentage_error(y_test, y_pred)
    print(f"Mean Absolute Percentage Error: {mape}")

    r2 = r2_score(y_test, y_pred)
    print(f"R^2: {r2}")
    print(f"----------------------------------------------------------------")
    return mae, mse, rmse, mape, r2

def make_beeswarm(shap_values):
    """Display a beeswarm plot for given shap_values of a model"""
    shap.plots.beeswarm(shap_values)

def display_ranking(rankings):
    for ranking in rankings:
        print(f"{ranking[0]} - {ranking[1]}")


def display_ranking_and_range(rankings, range_dict, top_n=10, printer=colour_printer):
    """
    Displays the top N features based on mean absolute SHAP value, along with their optimal value ranges.

    This function prints a formatted table of the top features, their mean SHAP values, and the corresponding optimal value ranges,
    with columns dynamically sized to fit the longest feature name and aligned decimal points/brackets.

    :param rankings: List of (feature_name, importance) tuples, sorted by importance.
    :param range_dict: Dictionary mapping feature names to (min, max) value ranges.
    :param top_n: Number of top features to display (default: 10).
    :param printer: Function to print messages (default: colour_printer).
    """
    # Determine dynamic column widths
    feature_names = [feat for feat, _ in rankings[:top_n]]
    max_feat_len = max(
        len("FEATURE"), max((len(str(f)) for f in feature_names), default=0)
    )
    shap_col_width = 12  # Enough for 'MEAN SHAP'
    range_strs = []
    for feat in feature_names:
        rng = range_dict.get(feat, (None, None))
        rng_str = f"[{rng[0]:.5f}, {rng[1]:.5f}]" if None not in rng else "N/A"
        range_strs.append(rng_str)
    max_range_len = max(
        len("OPTIMAL RANGE"), max((len(r) for r in range_strs), default=0)
    )

    # Prepare header
    header = (
        f"{'RANK':<5} "
        f"{'FEATURE':<{max_feat_len}} "
        f"{'MEAN SHAP':>{shap_col_width}} "
        f"{'OPTIMAL RANGE':>{max_range_len}}"
    )
    printer("\n" + header, fg="cyan", bold=True)
    printer("-" * len(header), fg="cyan")

    # Prepare data rows with aligned decimals and brackets
    for i, (feat, mean_shap) in enumerate(rankings[:top_n], 1):
        rng_str = range_strs[i - 1]
        shap_str = f"{mean_shap:.5f}"
        printer(
            f"{i:<5} "
            f"{feat:<{max_feat_len}} "
            f"{shap_str:>{shap_col_width}} "
            f"{rng_str:>{max_range_len}}",
            fg="white",
        )

In [165]:

display_ranking_and_range(combined_ranking, range_dict, top_n=10)
performance_report(best_model_trained,X_test,y_test)


[1m[36m
RANK  FEATURE                                            MEAN SHAP        OPTIMAL RANGE[0m
[36m---------------------------------------------------------------------------------------[0m
[37m1     ABB.ATAS.Si                                          0.12569   [1.48758, 2.16068][0m
[37m2     pouring.ATASW.Si_equivalent                          0.10965   [0.64542, 2.48924][0m
[37m3     moulding.CIM3Data.mould_thickness_correction         0.10662 [15.00000, 90.00000][0m
[37m4     moulding.CIM3Data.core_set_mode                      0.07966   [0.00000, 0.00000][0m
[37m5     moulding.CIM3Data.mould_retainer_time_extension      0.07189   [0.10000, 0.20000][0m
[37m6     moulding.CIM3Data.squeeze_pressure                   0.04417  [6.10000, 12.00000][0m
[37m7     pouring.ARL.Mg                                       0.03451   [0.01971, 0.03334][0m
[37m8     pouring.ARL.Cu                                       0.02861   [0.04459, 0.35913][0m
[37m9     pouring.ARL.S 

(0.000170224312048807,
 1.2456282385431477e-06,
 0.0011160771651383015,
 120050106960.92902,
 0.9995094689730281)

In [None]:
#RESULTS for  Q3 (above 75% of the target values) training on the S  group's subset of data!

In [163]:
# RANK  FEATURE                                                 MEAN SHAP           OPTIMAL RANGE
# -----------------------------------------------------------------------------------------------
# 1     moulding.CIM3Data.squeeze_pressure                        0.09630     [6.10000, 16.00000]
# 2     moulding.CIM3Data.swing_plate_spray_position              0.07369 [966.00000, 1500.00000]
# 3     moulding.CIM3Data.mould_thickness_correction              0.06372   [10.00000, 100.00000]
# 4     ABB.ATAS.Si                                               0.04685      [1.31660, 1.97310]
# 5     sand_plant.Sand_Mixer.Temperature_return_sand             0.03728    [27.98032, 49.65278]
# 6     moulding.CIM3Data.core_setter_core_set_time_stage_3       0.03503      [0.10000, 5.00000]
# 7     pouring.ATASW.Si_equivalent                               0.03041      [0.52420, 2.65512]
# 8     pouring.ARL.Ti                                            0.03031      [0.01620, 0.03315]
# 9     moulding.CIM3Data.core_setter_core_set_force_stage_3      0.02675  [145.00000, 500.00000]
# 10    pouring.ATASW.C                                           0.02482      [3.61802, 4.26846]
# Performance report for XGBRegressor:
# Mean Absolute Error: 0.006007528059404573
# Mean Squared Error: 0.00011617187312877407
# Root Mean Sqaured Error: 0.01077830567059471
# Mean Absolute Percentage Error: 3200379425107.812
# R^2: 0.969097461812619

# RESULTS FOR BATCH S

In [164]:
# RANK  FEATURE                                                        MEAN SHAP          OPTIMAL RANGE
# -----------------------------------------------------------------------------------------------------
# 1     moulding.CIM3Mould.compressibility_actual                        0.25387   [20.50028, 24.26221]
# 2     moulding.CIM3Data.closeup_correction                             0.10891   [-5.00000, -3.50000]
# 3     moulding.CIM3Data.core_setter_core_speed_forward_stage_2         0.10200 [100.00000, 100.00000]
# 4     sand_plant.Sand_Mixer.Correction_moisture                        0.09663    [-0.74979, 0.00000]
# 5     pouring.ARL.Ti                                                   0.08685     [0.01572, 0.02418]
# 6     moulding.CIM3Data.core_set_mode                                  0.06714     [0.00000, 0.00000]
# 7     sand_plant.Sand_Mixer.S3C1_Act_water                             0.03420   [77.50000, 98.00000]
# 8     pouring.ARL.S                                                    0.03102     [0.00147, 0.00327]
# 9     sand_plant.Sand_Mixer.bentonite_and_coal_dust_mix_additions      0.02729  [45.05208, 108.85416]
# 10    sand_plant.Sand_Mixer.Moisture_return_sand                       0.02448     [0.00000, 1.62116]
# Performance report for XGBRegressor:
# Mean Absolute Error: 1.7913057109436733e-05
# Mean Squared Error: 5.706850824590615e-10
# Root Mean Sqaured Error: 2.3889015937435797e-05
# Mean Absolute Percentage Error: 39152073439.7227
# R^2: 0.9532946695246253

# RESULT FOR Q1