# In [2]:
# ==============================================================================
#        LEGAL CASE OUTCOME PREDICTION PIPELINE (final version)
# ==============================================================================
#
# This script implements a machine learning pipeline for predicting legal case outcomes,
# specifically custody decisions ("mother", "father", "shared"). It handles data loading,
# preprocessing, feature encoding, judge-specific bucketing, flexible data balancing,
# model training (XGBoost, RF, LogReg, SVM), hyperparameter tuning, evaluation,
# and comprehensive result exports.

# ==========================
# Standard library imports
# ==========================
import os
import json
import logging
import sys
import io
import gc
import time
import traceback
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any, Union, Callable
from datetime import datetime
from collections import Counter
import warnings

# ==========================
# Third-party imports
# ==========================
import matplotlib.pyplot as plt # type: ignore
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

from xgboost import XGBClassifier # type: ignore
from sklearn.ensemble import RandomForestClassifier # type: ignore
from sklearn.linear_model import LogisticRegression # type: ignore
from sklearn.svm import SVC # type: ignore
from sklearn.model_selection import StratifiedKFold, RandomizedSearchCV # type: ignore
from sklearn.preprocessing import StandardScaler, OneHotEncoder # type: ignore
from sklearn.compose import ColumnTransformer # type: ignore
from sklearn.pipeline import Pipeline as SklearnPipeline # type: ignore
from sklearn.metrics import ( # type: ignore
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    confusion_matrix,
    ConfusionMatrixDisplay,
)

from sklearn.model_selection import (
    StratifiedKFold,
    RandomizedSearchCV,
    train_test_split,
)

from tabulate import tabulate # type: ignore

try:
    from imblearn.over_sampling import RandomOverSampler # type: ignore
    from imblearn.under_sampling import RandomUnderSampler # type: ignore
    # Imblearn Pipeline is not strictly needed if we apply samplers directly
except ImportError:
    RandomOverSampler = RandomUnderSampler = None
    logging.warning("imbalanced-learn library not found. Sampling-based balancing will not be available.")

# Suppress common warnings for cleaner output
warnings.filterwarnings('ignore', category=UserWarning, module='sklearn')
warnings.filterwarnings('ignore', category=FutureWarning, module='sklearn')
warnings.filterwarnings('ignore', category=UserWarning, module='xgboost')

# ------------------------------------------------------------------ #
# 1️⃣  IMPORTS  (extend the existing sklearn.model_selection import) #
# ------------------------------------------------------------------ #


# ==========================
# Constants
# ==========================
# Seed reused wherever reproducible randomness is needed (e.g. the CV random state).
RANDOM_STATE: int = 42

# Default locations / file names for the artefacts of one run.
DEFAULT_OUTPUT_BASE_DIR: str = "pipeline_outputs"
CONFIG_FILENAME: str = "run_config.json"
LOG_FILENAME: str = "pipeline_run.log"
RESULTS_XLSX_FILENAME: str = "pipeline_results.xlsx"
CM_SUBDIR_NAME: str = "confusion_matrices"

# Target classes: the position of a name in this list is its numeric code (0, 1, 2).
TARGET_CLASS_NAMES: List[str] = ["mother", "father", "shared"]
TARGET_CLASS_MAP: Dict[str, int] = dict(zip(TARGET_CLASS_NAMES, range(len(TARGET_CLASS_NAMES))))
CLASS_LABELS_NUMERIC: List[int] = [TARGET_CLASS_MAP[label] for label in TARGET_CLASS_NAMES]

# Class-balancing strategies offered by the interactive configuration.
BALANCING_METHODS: List[str] = ["none", "sampling", "weighting"]

# Default number of features kept when exporting.
TOP_K_FEATURES: int = 15

# ==========================
# Global Configuration Store
# ==========================
# Mutable module-level dict that the configuration functions fill in.
CONFIG: Dict[str, Any] = {}


# ==========================
# Logging Setup
# ==========================
def setup_logging(log_path: Union[str, Path]) -> None:
    """Configure root logging to write to ``log_path`` and to stdout.

    Handlers already attached to the root logger are removed (and closed)
    first.  ``logging.basicConfig`` does nothing when the root logger already
    has handlers, which happens as soon as any module-level ``logging.*`` call
    has run before this function; without the reset the file handler below
    would never be installed.

    Args:
        log_path: Destination log file.  It is opened in ``'w'`` mode, so a
            previous log at the same path is overwritten.  The parent
            directory must already exist.
    """
    root_logger = logging.getLogger()
    for stale_handler in list(root_logger.handlers):
        root_logger.removeHandler(stale_handler)
        stale_handler.close()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)-7s] [%(filename)s:%(lineno)d] %(message)s",
        handlers=[
            logging.FileHandler(log_path, mode='w', encoding='utf-8'),  # overwrite log for new run
            logging.StreamHandler(sys.stdout),  # console output goes to stdout, not stderr
        ],
    )
    logging.info(f"Logging initialized. Log file: {log_path}")
    # Non-ASCII smoke test: shows immediately whether the handlers cope with UTF-8.
    logging.info("Test log with special characters: éàçüö €")


# ==========================
# Console Encoding Fix (from user script)
# ==========================
def fix_console_encoding() -> None:
    """Best-effort attempt to make the Windows console and std streams UTF-8.

    A no-op on every platform other than ``win32``.  On Windows the console
    code page is switched with ``chcp 65001`` and ``sys.stdout`` /
    ``sys.stderr`` are re-wrapped as UTF-8 text streams that replace
    unencodable characters.  If that raises, ``PYTHONIOENCODING`` is set to
    ``utf-8`` unless the variable is already defined.
    """
    if sys.platform != "win32":
        return

    try:
        # Switch the console code page to UTF-8; the command's output is discarded.
        os.system("chcp 65001 > nul") # type: ignore
        # Re-wrap stdout first, then stderr, on top of their binary buffers.
        for stream_name in ("stdout", "stderr"):
            raw_buffer = getattr(sys, stream_name).buffer
            setattr(sys, stream_name, io.TextIOWrapper(raw_buffer, encoding='utf-8', errors='replace'))
        logging.info("Attempted to set console to UTF-8 (chcp 65001) and wrap stdout/stderr.")
    except Exception as e:
        logging.warning(f"Could not fully set console to UTF-8 or wrap streams: {e}. Using PYTHONIOENCODING as fallback if set.")
        # Reached e.g. when a stream has no ``.buffer`` attribute; respect an
        # existing PYTHONIOENCODING value and only fill the variable in when absent.
        if "PYTHONIOENCODING" in os.environ:
            return
        os.environ["PYTHONIOENCODING"] = "utf-8"
        logging.info("Set PYTHONIOENCODING=utf-8 as a fallback.")

# Call encoding fix early
fix_console_encoding()

# ==========================
# Utility Functions
# ==========================
def create_output_directory(base_dir: str = DEFAULT_OUTPUT_BASE_DIR) -> str:
    """Create and return ``<base_dir>/experiment_<YYYYmmdd_HHMMSS>``.

    Args:
        base_dir: Directory under which the run folder is created; missing
            parent directories are created as well.

    Returns:
        The path of the run folder as a string.  An already existing folder
        with the same timestamp is reused rather than treated as an error.
    """
    run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    experiment_dir = os.path.join(base_dir, "experiment_" + run_stamp)
    experiment_path = Path(experiment_dir)
    experiment_path.mkdir(parents=True, exist_ok=True)
    return experiment_dir

def print_section_header(title: str) -> None:
    """Print ``title`` centred between two 80-character ``=`` rules.

    The banner is preceded and followed by one blank line.
    """
    width = 80
    rule = "=" * width
    banner_lines = ["", rule, title.center(width), rule, ""]
    print("\n".join(banner_lines))

def safe_input(prompt: str, default: Optional[str] = None, type_caster: Optional[Callable] = None, choices: Optional[List[str]] = None) -> Any:
    """Prompt on the console until the user supplies an acceptable value.

    Args:
        prompt: Text shown to the user (default and options are appended).
        default: Value used when the user just presses Enter.  It is passed
            through ``str()`` first, so non-string defaults are validated and
            cast exactly like typed input.
        type_caster: Optional callable applied to the (string) answer.  A
            ``ValueError`` or ``TypeError`` from it triggers a re-prompt.
        choices: Optional list of accepted answers, matched case-insensitively.

    Returns:
        ``type_caster(answer)`` when a caster is given, otherwise the answer
        string.  When ``choices`` is given the answer is normalised to the
        spelling used in ``choices`` (so ``"Sampling"`` is returned as
        ``"sampling"``), which spares callers from re-normalising the case.
    """
    # The prompt text never changes between retries, so build it once.
    default_str = f" (default: {default})" if default is not None else ""
    choice_str = f" (options: {', '.join(choices)})" if choices else ""
    full_prompt = f"{prompt}{default_str}{choice_str}: "

    # lower-case answer -> canonical spelling; the first occurrence wins.
    canonical_choices: Dict[str, str] = {}
    for choice in choices or []:
        canonical_choices.setdefault(choice.lower(), choice)

    while True:
        user_input = input(full_prompt).strip()

        if not user_input and default is not None:
            user_input = str(default)  # ensure default is a string for processing

        if choices:
            matched = canonical_choices.get(user_input.lower())
            if matched is None:
                print(f"Invalid choice. Please select from: {', '.join(choices)}")
                continue
            user_input = matched

        if not type_caster:
            return user_input  # no caster: return the (normalised) string

        try:
            return type_caster(user_input)
        except (ValueError, TypeError):
            caster_name = getattr(type_caster, '__name__', 'specific type')
            print(f"Invalid input type. Expected {caster_name}.")

def get_yes_no_input(prompt_message: str, default_yes: bool = True) -> bool:
    """Ask a yes/no question and keep asking until the answer is recognisable.

    Args:
        prompt_message: Question shown to the user.
        default_yes: Result returned when the user just presses Enter; it also
            decides whether the hint reads ``(Y/n)`` or ``(y/N)``.

    Returns:
        ``True`` for ``y``/``yes``, ``False`` for ``n``/``no`` (case-insensitive).
    """
    recognised_answers = {'y': True, 'yes': True, 'n': False, 'no': False}
    hint = "(Y/n)" if default_yes else "(y/N)"
    question = f"{prompt_message} {hint}: "

    while True:
        answer = input(question).strip().lower()
        if answer == "":
            return default_yes
        if answer in recognised_answers:
            return recognised_answers[answer]
        print("Invalid input. Please enter 'y' or 'n'.")

# Column-listing helper used by the interactive column-selection prompts.
def display_df_columns(df_to_display: pd.DataFrame, description: str) -> None:
    """Print the columns of ``df_to_display`` as a grid with 0-based indices.

    Args:
        df_to_display: DataFrame whose column names are listed; only the
            column labels are used, so a frame without rows works as well.
        description: Title printed as a section header above the table.
    """
    print_section_header(description)

    column_names = list(df_to_display.columns)
    if df_to_display.empty and not column_names:
        # Nothing at all to show: no rows and no columns.
        print(f"  ({description} is empty or has no columns to display)")
        return

    # The index shown is the position within *this* list, starting at 0.
    indexed_rows = [[position, name] for position, name in enumerate(column_names)]
    if not indexed_rows:
        print(f"  (No valid column names to display for {description})")
        return

    table_text = tabulate(indexed_rows, headers=["Index (for this list)", "Column Name"], tablefmt="grid")
    print(table_text)
    print(f"  Total columns in this list: {len(column_names)}\n")

class NumpyJSONEncoder(json.JSONEncoder):
    """JSON encoder that also understands NumPy, pandas and ``pathlib`` values.

    Anything the standard encoder rejects and that is not covered by the
    converter table below is serialised as ``str(obj)`` instead of raising.
    """

    # (type, converter) pairs, checked in order.
    _CONVERTERS = (
        (np.integer, int),                              # every NumPy integer width
        (np.floating, float),                           # every NumPy float width
        (np.ndarray, lambda array: array.tolist()),
        (np.bool_, bool),
        (Path, str),
        (pd.Timestamp, lambda stamp: stamp.isoformat()),
    )

    def default(self, obj: Any) -> Any:
        """Return a JSON-serialisable replacement for ``obj``."""
        for handled_type, convert in self._CONVERTERS:
            if isinstance(obj, handled_type):
                return convert(obj)
        try:
            # Let the base class have the final say (it raises TypeError).
            return super().default(obj)
        except TypeError:
            # Last resort: fall back to the string representation.
            return str(obj)


# ==========================
# Configuration Management
# ==========================
def _clean_path_input(path_str: str) -> str:
    """Return ``path_str`` without surrounding whitespace and quote characters.

    Whitespace is trimmed first, then double quotes, then single quotes, each
    from both ends (useful for paths pasted together with their quotes).
    """
    cleaned = path_str.strip()
    for quote_char in ('"', "'"):
        cleaned = cleaned.strip(quote_char)
    return cleaned

def get_file_paths_config() -> Dict[str, Any]:
    """Interactively collect the output directory and the input file paths.

    Returns:
        A dict with the keys
        ``output_base_dir`` (falls back to ``DEFAULT_OUTPUT_BASE_DIR`` when the
        answer is empty after cleaning),
        ``cases_file_path`` (required: the prompt repeats until it is non-empty)
        and ``judges_file_path`` (``None`` when skipped).
    """
    print_section_header("File Paths Configuration")
    cfg: Dict[str, Any] = {}

    # An answer such as `""` cleans to an empty string; use the default then.
    cfg['output_base_dir'] = (
        _clean_path_input(safe_input("Enter base directory for outputs", DEFAULT_OUTPUT_BASE_DIR))
        or DEFAULT_OUTPUT_BASE_DIR
    )

    # The cases file has no default, so an empty answer must not be accepted:
    # it would only fail much later when the file is opened.
    while True:
        cases_path = _clean_path_input(safe_input("Enter path to cases XLSX file"))
        if cases_path:
            break
        print("This field is required. Please provide an input.")
    cfg['cases_file_path'] = cases_path

    cfg['judges_file_path'] = _clean_path_input(safe_input("Enter path to judges XLSX file (optional, press Enter to skip)", default="")) or None
    return cfg

def _select_columns_interactive(df_columns: List[str], prompt_message: str, allow_multiple: bool = False, is_required: bool = True) -> Union[Optional[str], List[str]]:
    """Let the user pick one or several columns by name or 0-based index.

    Args:
        df_columns: Candidate column names; indices typed by the user refer to
            positions in this list.
        prompt_message: Prompt text passed to ``safe_input``.
        allow_multiple: When true the answer is split on commas and a list is
            returned; otherwise the whole answer is treated as one item.
        is_required: When true an empty answer is rejected and the prompt repeats.

    Returns:
        A list of column names when ``allow_multiple`` is true (empty if the
        optional prompt was skipped), otherwise a single name or ``None``.
        Repeated picks are returned once, in the order first entered.

    Notes:
        A purely numeric item is always interpreted as an index, even if a
        column with that exact name exists.  Blank items produced by stray
        commas (``"a,,b"`` or ``"a,b,"``) are ignored.
    """
    from difflib import get_close_matches  # only needed for "did you mean" hints

    if not df_columns:
        print(f"Warning: No columns available for selection for '{prompt_message}'.")
        return [] if allow_multiple else None

    # difflib can only compare strings; non-string labels would make it raise.
    suggestion_pool = [name for name in df_columns if isinstance(name, str)]

    while True:
        user_input_str = safe_input(prompt_message).strip()

        if allow_multiple:
            items = [piece.strip() for piece in user_input_str.split(',')]
            items = [piece for piece in items if piece]  # drop blanks from stray commas
        else:
            items = [user_input_str] if user_input_str else []

        if not items:
            if is_required:
                print("This field is required. Please provide an input.")
                continue
            return [] if allow_multiple else None

        selected_col_names: List[str] = []
        valid_selection = True
        for item in items:
            if item.isdecimal():  # index-based selection (int() accepts every decimal digit)
                idx = int(item)
                if not 0 <= idx < len(df_columns):
                    print(f"Error: Index {idx} is out of range (0-{len(df_columns)-1}).")
                    valid_selection = False
                    break
                resolved = df_columns[idx]
            elif item in df_columns:  # name-based selection
                resolved = item
            else:
                print(f"Error: Column '{item}' not found.")
                matches = get_close_matches(item, suggestion_pool, n=3, cutoff=0.6)
                if matches:
                    print(f"Did you mean one of these: {', '.join(matches)}?")
                valid_selection = False
                break

            # Selecting the same column twice must not duplicate it downstream.
            if resolved not in selected_col_names:
                selected_col_names.append(resolved)

        if valid_selection:
            return selected_col_names if allow_multiple else selected_col_names[0]

def get_operational_columns_config(df_final: pd.DataFrame, current_config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Interactively choose the bucketing ID, target and feature columns.

    The judge ID and the target are selected from the full column list of
    ``df_final`` (indices refer to that list).  Features are selected from a
    second list that excludes the judge ID and the target and has its own
    0-based indexing.  The target may not be the same column as the judge ID.

    Args:
        df_final: The merged DataFrame whose columns are offered.
        current_config: Accepted for interface compatibility; not read here.

    Returns:
        A dict with the keys ``judge_id_col_final_bucketing`` (``None`` when
        bucketing is declined), ``target_col`` and ``feature_cols`` (possibly
        an empty list, which is logged as an error).
    """
    print_section_header("Select Operational Columns from Merged Data")

    # Show ALL columns once; judge ID and target are picked from this list.
    display_df_columns(df_final, "Full List of Columns in Final Merged Data")

    cfg: Dict[str, Any] = {}
    df_cols_list_full = list(df_final.columns)

    # --- 1. Select Judge ID for Bucketing (Optional) ---
    if get_yes_no_input("Will you use judge-specific bucketing?", default_yes=True):
        print("\n--- Select Judge ID for Bucketing ---")
        cfg['judge_id_col_final_bucketing'] = _select_columns_interactive(
            df_cols_list_full,
            "Enter the column NAME or INDEX (from the full list above) for Judge ID (for bucketing)",
            allow_multiple=False,
            is_required=True,  # the user opted in, so a column must be named
        )
    else:
        cfg['judge_id_col_final_bucketing'] = None
    judge_id_col = cfg['judge_id_col_final_bucketing']
    logging.info(f"Judge ID for bucketing set to: {cfg['judge_id_col_final_bucketing']}")

    # --- 2. Select Target Column ---
    print("\n--- Select Target Column ---")
    # The prompt keeps the indices of the full list shown above, so the judge
    # ID column cannot simply be removed from the candidates; instead a target
    # equal to the judge ID is rejected and the question is asked again.
    while True:
        target_col = _select_columns_interactive(
            df_cols_list_full,
            "Enter the TARGET column NAME or INDEX (from the full list above, e.g., 'custody_outcome')",
            allow_multiple=False,
            is_required=True,
        )
        if target_col is None or judge_id_col is None or target_col != judge_id_col:
            break
        print(f"Error: '{target_col}' is already selected as the Judge ID for bucketing. Please choose a different target column.")
    cfg['target_col'] = target_col
    logging.info(f"Target column set to: {cfg['target_col']}")

    # --- 3. Prepare List of Potential Feature Columns ---
    # Everything except the bucketing ID (if any) and the target is a candidate.
    cols_to_exclude_for_features = [judge_id_col, target_col]
    potential_feature_cols = [
        col for col in df_cols_list_full
        if col not in cols_to_exclude_for_features and col is not None
    ]

    if not potential_feature_cols:
        logging.error("No potential feature columns remain after excluding ID and target. Cannot proceed.")
        # Return an empty feature list; callers have to deal with this state.
        cfg['feature_cols'] = []
        logging.info(f"Operational columns selected: Bucketing ID='{cfg.get('judge_id_col_final_bucketing')}', Target='{cfg['target_col']}', Features=0.")
        return cfg

    # --- 4. Select Feature Columns ---
    print_section_header("Select Feature Columns")
    logging.info(f"Presenting {len(potential_feature_cols)} potential features for selection.")

    # A column-only DataFrame lets display_df_columns number this list from 0,
    # matching the indices accepted by the selection prompt below.
    df_potential_features_for_display = pd.DataFrame(columns=potential_feature_cols)
    display_df_columns(df_potential_features_for_display, "POTENTIAL FEATURES (select from this list using its 0-based index or name)")

    if get_yes_no_input(f"Use all {len(potential_feature_cols)} listed POTENTIAL features?", default_yes=True):
        cfg['feature_cols'] = potential_feature_cols
    else:
        selected_features_from_potential_list = _select_columns_interactive(
            potential_feature_cols,  # the list the user is looking at
            "Enter FEATURE column NAMES or INDICES (comma-separated, from the POTENTIAL FEATURES list above)",
            allow_multiple=True,
            is_required=True,  # manual selection must name at least one column
        )
        cfg['feature_cols'] = selected_features_from_potential_list if selected_features_from_potential_list else []

    # --- 5. Confirm Selected Features ---
    if cfg.get('feature_cols'):
        print_section_header("Confirm Final Selected Features")
        print("You have selected the following features for the model:")
        selected_features_table = [[i, name] for i, name in enumerate(cfg['feature_cols'])]
        print(tabulate(selected_features_table, headers=["# (For Info)", "Selected Feature Name"], tablefmt="grid"))

        if not get_yes_no_input("Proceed with these features?", default_yes=True):
            # Deliberately simple: the selection is kept; changing it requires a restart.
            logging.warning("User did not confirm the selected features. The pipeline will proceed with the current selection. To change, restart configuration.")
    else:
        # Critical but non-fatal here: the pipeline may fail later without features.
        logging.error("No features were selected. At least one feature is generally required for modeling.")

    num_selected_features = len(cfg.get('feature_cols', []))
    logging.info(f"Operational columns configuration complete: Bucketing ID='{cfg.get('judge_id_col_final_bucketing')}', Target='{cfg['target_col']}', Features selected={num_selected_features}.")
    if num_selected_features > 0:
        logging.info(f"Selected feature names: {cfg.get('feature_cols')}")

    return cfg

def _suggest_feature_type(column: pd.Series) -> Tuple[str, str, int]:
    """Print a short profile of ``column`` and return (suggested type, reason, non-NaN unique count)."""
    if column.isnull().all():
        print("  Column contains only NaN values.")
        # No information to go on; propose categorical and let the user decide.
        return "categorical", "Column is all NaN values.", 0

    unique_values = column.dropna().unique()  # NaNs are excluded from the sample and the count
    nunique = len(unique_values)
    print(f"  Unique Values (sample, non-NaN): {unique_values[:5]}")
    print(f"  Number of Unique Values (non-NaN): {nunique}")

    if not pd.api.types.is_numeric_dtype(column):
        return "categorical", "Detected as non-numeric dtype.", nunique
    if nunique > 7:
        return "numeric", "Detected as numeric dtype with >7 unique values.", nunique
    # Few distinct numeric values usually mean codes rather than measurements.
    return "categorical", "Numeric dtype, but <=7 unique values suggests it might be categorical (e.g., codes).", nunique


def validate_feature_types_interactive(df: pd.DataFrame, feature_cols: List[str]) -> Tuple[List[str], List[str]]:
    """
    Ask the user to confirm each feature as numeric or categorical.

    For every feature a suggestion is derived from the dtype and the number of
    distinct non-NaN values (numeric dtype with more than 7 distinct values ->
    numeric, everything else -> categorical); the user may accept or override it.

    Args:
        df: DataFrame holding the feature columns.
        feature_cols: Names of the features to classify.  Names missing from
            ``df`` are skipped with a warning.

    Returns:
        ``(numerical_features, categorical_features)`` in the order the
        features were processed.  A feature the user declares numeric is put
        in the numeric list even if it cannot be converted with
        ``pd.to_numeric``; in that case only a notice is printed.
    """
    print_section_header("Feature Type Validation")
    numerical_features: List[str] = []
    categorical_features: List[str] = []

    if not feature_cols:
        logging.warning("No feature columns provided for type validation.")
        return numerical_features, categorical_features

    for col in feature_cols:
        if col not in df.columns:
            logging.warning(f"Feature column '{col}' not found in DataFrame. Skipping validation for this column.")
            continue

        print(f"\n--- Validating Feature: '{col}' ---")
        print(f"  Original Data Type: {df[col].dtype}")

        default_type, suggestion_reason, nunique = _suggest_feature_type(df[col])
        print(f"  Suggestion: Treat as {default_type.upper()}. Reason: {suggestion_reason}")

        user_choice = safe_input(
            f"  Confirm type for '{col}' (NUMERIC or CATEGORICAL)?",
            default=default_type,
            choices=["numeric", "categorical"]
        ).lower()

        if user_choice == "numeric":
            numerical_features.append(col)
            # Try the conversion now so that an impossible choice is visible
            # immediately instead of surfacing inside a transformer later.
            try:
                temp_numeric_col = pd.to_numeric(df[col], errors='raise')
                if temp_numeric_col.notna().any():
                    col_stats = temp_numeric_col.agg(['min', 'max', 'mean'])
                    print(f"    Selected as NUMERIC. Stats: Min={col_stats['min']:.2f}, Max={col_stats['max']:.2f}, Mean={col_stats['mean']:.2f}")
                else:
                    print("    Selected as NUMERIC, but column became all NaNs after attempted numeric conversion or was already all NaNs.")
            except (ValueError, TypeError):
                print(f"    Selected as NUMERIC, but failed to convert column '{col}' to a numeric type for stats calculation. Original dtype: {df[col].dtype}. It will be included in numerical_features list, but ensure it's truly numeric for transformers.")
        else:  # user_choice == "categorical"
            categorical_features.append(col)
            print(f"    Selected as CATEGORICAL.")

        logging.info(f"Feature '{col}': Original dtype {df[col].dtype}, Unique values {nunique}, Validated by user as {user_choice.upper()}.")

    logging.info(f"Final Numerical features ({len(numerical_features)}): {numerical_features}")
    logging.info(f"Final Categorical features ({len(categorical_features)}): {categorical_features}")
    return numerical_features, categorical_features

def get_encoding_config_interactive(df: pd.DataFrame, target_col: str) -> Dict[str, Any]:
    """Ask the user which raw target label corresponds to each standard class.

    For every entry of ``TARGET_CLASS_MAP`` the user types the label used in
    their data.  The typed text is converted to the dtype of the target column
    (numeric or boolean) where possible; otherwise it is kept as a string.

    Args:
        df: DataFrame containing the target column.
        target_col: Name of the target column in ``df``.

    Returns:
        ``{'target_mapping_raw_to_numeric': {raw_label: class_code, ...}}``.
        Keys are native Python values (NumPy scalars are unwrapped), because
        ``json`` only accepts str/int/float/bool/None as dict keys and does
        not consult a custom encoder for keys.

    Notes:
        A label that does not occur in the data, or that was already mapped to
        another class, is accepted but logged as a warning.
    """
    print_section_header("Data Encoding Configuration")
    cfg: Dict[str, Any] = {}
    print(f"\nTarget Variable: '{target_col}'")

    target_series = df[target_col]
    original_dtype = target_series.dtype  # constant for the whole loop below
    unique_targets_original = target_series.dropna().unique()
    print(f"  Unique values found in '{target_col}': {unique_targets_original}")
    print(f"  Expected target classes: {', '.join(TARGET_CLASS_NAMES)} (mapped to {CLASS_LABELS_NUMERIC})")

    # bool counts as a numeric dtype for pandas, so it has to be singled out.
    is_bool_target = pd.api.types.is_bool_dtype(original_dtype)
    is_numeric_target = pd.api.types.is_numeric_dtype(original_dtype) and not is_bool_target
    observed_labels = list(unique_targets_original)

    target_map_raw_to_numeric: Dict[Any, int] = {}
    print("\nPlease map your raw target labels to the standard numeric classes:")
    for std_name, numeric_val in TARGET_CLASS_MAP.items():
        raw_val_str = safe_input(f"  Enter raw label in your data that corresponds to '{std_name}' (Class {numeric_val})")
        try:
            if is_numeric_target:
                raw_val_typed: Any = original_dtype.type(raw_val_str)
            elif is_bool_target:
                lowered = raw_val_str.lower()
                if lowered in ['true', '1', 't', 'yes', 'y']:
                    raw_val_typed = True
                elif lowered in ['false', '0', 'f', 'no', 'n']:
                    raw_val_typed = False
                else:
                    raise ValueError("Invalid boolean string")
            else:  # string or object column: keep the text as typed
                raw_val_typed = raw_val_str
        except ValueError:
            logging.warning(f"Could not convert '{raw_val_str}' to dtype {original_dtype}. Storing as string.")
            raw_val_typed = raw_val_str

        # Unwrap NumPy scalars: equal value and hash for lookups, but usable as a JSON key.
        if isinstance(raw_val_typed, np.generic):
            raw_val_typed = raw_val_typed.item()

        if raw_val_typed not in observed_labels:
            logging.warning(f"Label {raw_val_typed!r} given for '{std_name}' does not occur in column '{target_col}'.")
        if raw_val_typed in target_map_raw_to_numeric:
            logging.warning(
                f"Label {raw_val_typed!r} was already mapped to class {target_map_raw_to_numeric[raw_val_typed]}; "
                f"it is now mapped to class {numeric_val} ('{std_name}')."
            )
        target_map_raw_to_numeric[raw_val_typed] = numeric_val

    cfg['target_mapping_raw_to_numeric'] = target_map_raw_to_numeric
    logging.info(f"Target encoding map (raw to numeric): {target_map_raw_to_numeric}")
    return cfg

def get_bucketing_config_interactive() -> Dict[str, Any]:
    """Ask whether judge-specific bucketing is wanted and, if so, its threshold.

    Returns:
        A dict with ``use_judge_bucketing`` (bool) and
        ``min_samples_per_judge_bucket`` (int; ``0`` when bucketing is off).
    """
    print_section_header("Judge Bucketing Configuration")

    bucketing_enabled = get_yes_no_input("Enable judge-specific bucketing?", default_yes=True)
    if not bucketing_enabled:
        min_samples = 0  # placeholder; the value is not used without bucketing
    else:
        answer = safe_input("Min samples per judge for dedicated bucket (others to 'generic')", default=30, type_caster=int)
        min_samples = int(answer)

    cfg = {
        'use_judge_bucketing': bucketing_enabled,
        'min_samples_per_judge_bucket': min_samples,
    }
    logging.info(f"Bucketing config: {cfg}")
    return cfg

def _parse_class_triplet(raw_text: str) -> Optional[Tuple[float, float, float]]:
    """Parse ``"a,b,c"`` into three finite, non-negative floats with a positive sum, else ``None``."""
    try:
        values = [float(part) for part in raw_text.split(",")]
    except ValueError:
        return None
    if len(values) != 3:
        return None
    if not all(np.isfinite(value) and value >= 0 for value in values):
        return None
    if sum(values) <= 0:
        return None
    return values[0], values[1], values[2]


def get_balancing_config_interactive(
    class_counter: Optional[Counter] = None
) -> Dict[str, Any]:
    """
    Interactively choose a class-balancing method: 'none', 'sampling' or 'weighting'.

    If sampling  → asks for target PERCENTAGES per class; they are rescaled to
                   sum to 100 and stored under ``sampling_target_percentages``.
    If weighting → asks for relative class weights; they are stored as entered
                   under ``class_weight_ratios``.

    Answers that are not three comma-separated, finite, non-negative numbers
    with a positive sum are rejected and replaced by the defaults
    (33,33,34 for sampling, 1,1,1 for weighting).  An all-zero answer used to
    crash with a division by zero.

    Args:
        class_counter: Optional class counts; when given (and non-empty) the
            current distribution is printed before asking.

    Returns:
        A dict with ``balancing_method`` plus the method-specific key above
        (class codes 0, 1, 2 as keys of the nested dict).
    """
    print_section_header("Class Balancing Configuration")
    cfg: Dict[str, Any] = {}

    # — current distribution preview —
    if class_counter:
        print("Current global distribution:")
        print_class_distribution(class_counter, "GLOBAL")

    # — choose method —
    method = safe_input(
        "Balancing method",
        default="sampling",
        choices=BALANCING_METHODS
    ).lower()
    cfg["balancing_method"] = method

    if method == "sampling":
        pct_str = safe_input(
            "Target class percentages for (mother,father,shared)",
            default="33,33,34",
        )
        percentages = _parse_class_triplet(pct_str)
        if percentages is None:
            print("Invalid format – falling back to default 33,33,34.")
            percentages = (33.0, 33.0, 34.0)
        total = sum(percentages)  # > 0 is guaranteed by the parser / the defaults
        cfg["sampling_target_percentages"] = {
            class_code: 100.0 * share / total
            for class_code, share in enumerate(percentages)
        }
    elif method == "weighting":
        w_str = safe_input(
            "Relative class weights for (mother,father,shared)",
            default="1,1,1",
        )
        weights = _parse_class_triplet(w_str)
        if weights is None:
            print("Invalid format – using equal weights.")
            weights = (1.0, 1.0, 1.0)
        cfg["class_weight_ratios"] = dict(enumerate(weights))

    logging.info(f"Balancing configuration: {cfg}")
    return cfg

def _load_hyperparameter_grid_file() -> Optional[Dict[str, Any]]:
    """
    Optionally load a hyper-parameter search space from a JSON file.

    The file is expected to contain a JSON object of the form
    ``{model_name -> param_distributions}``.

    Returns:
        The loaded dict, or ``None`` when the user declines, the path is not
        an existing regular file, the file cannot be read or parsed, or its
        top-level value is not a JSON object.  Callers treat ``None`` as
        "use the built-in grids".
    """
    print_section_header("Hyper-parameter Grid (optional)")
    wants_file = get_yes_no_input(
        "Load hyper-parameter search space from a JSON file?",
        default_yes=False,
    )
    if not wants_file:
        return None

    path_str = safe_input("Path to hyper-param JSON file")
    path = Path(_clean_path_input(path_str))
    # is_file() rather than exists(): an empty answer becomes Path('.'), which
    # "exists" but is a directory and cannot be opened as a file.
    if not path.is_file():
        print(f"⚠️  File '{path}' not found. Falling back to built-in grids.")
        logging.warning(f"Hyper-param grid file missing: {path}")
        return None

    try:
        with open(path, "r", encoding="utf-8") as fp:
            grids = json.load(fp)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"⚠️  Failed to load hyper-parameter JSON: {e}")
        logging.error(f"Hyper-param JSON load error: {e}")
        return None

    if not isinstance(grids, dict):
        # A list/number/string at the top level cannot be indexed by model name.
        print(f"⚠️  Failed to load hyper-parameter JSON: expected an object mapping model names to parameter grids, got {type(grids).__name__}")
        logging.error(f"Hyper-param JSON load error: top-level value in {path} is {type(grids).__name__}, not an object")
        return None

    print("✅  Custom hyper-parameter grid loaded.")
    logging.info(f"Hyper-param grids loaded from {path}")
    return grids

def get_cv_model_hyperparam_config() -> Dict[str, Any]:
    """
    Build the cross-validation, model-list and hyper-parameter-grid settings.

    The user is asked for the models to run, the number of CV folds and the
    number of randomized-search iterations, and may point to a JSON file with
    custom hyper-parameter grids (built-in grids are used otherwise).

    Validation:
        * model names are matched case-insensitively against the available
          models and stored in their canonical spelling, without duplicates;
          unknown names cause the question to be asked again;
        * the number of folds must be at least 2 and the number of search
          iterations at least 1.

    Returns:
        A dict with the keys ``models_to_run``, ``cv_folds``,
        ``cv_random_state``, ``hyperparameter_tuning_iterations``,
        ``hyperparameter_grids`` and ``hyperparameter_grids_loaded_from``
        (``True`` when the grids came from a file).
    """
    cfg: Dict[str, Any] = {}

    # ── models to run ──
    all_models = ["RandomForest", "LogisticRegression", "SVC", "XGB"]
    canonical_names = {name.lower(): name for name in all_models}
    while True:
        use_models = safe_input(
            f"Models to run (comma separated, available {all_models})",
            default=",".join(all_models)
        )
        requested = [m.strip() for m in use_models.split(",") if m.strip()]
        unknown = [m for m in requested if m.lower() not in canonical_names]
        if requested and not unknown:
            break
        if unknown:
            print(f"Unknown model(s): {', '.join(unknown)}. Available: {', '.join(all_models)}")
        else:
            print("Please select at least one model.")
    models_to_run: List[str] = []
    for m in requested:
        canonical = canonical_names[m.lower()]
        if canonical not in models_to_run:  # keep first occurrence only
            models_to_run.append(canonical)
    cfg["models_to_run"] = models_to_run

    # ── CV settings ──
    while True:
        cv_folds = int(safe_input("Number of CV folds", default="5", type_caster=int))
        if cv_folds >= 2:  # k-fold splitting needs at least two folds
            break
        print("Number of CV folds must be at least 2.")
    cfg["cv_folds"] = cv_folds
    cfg["cv_random_state"] = RANDOM_STATE

    while True:
        n_iterations = int(safe_input("RandomizedSearch iterations", default="50", type_caster=int))
        if n_iterations >= 1:
            break
        print("RandomizedSearch iterations must be at least 1.")
    cfg["hyperparameter_tuning_iterations"] = n_iterations

    # ── hyper-parameter grid ──
    external_grid = _load_hyperparameter_grid_file()
    if external_grid:
        cfg["hyperparameter_grids"] = external_grid
        cfg["hyperparameter_grids_loaded_from"] = True
    else:
        # Built-in fall-back grids, one entry per available model.
        cfg["hyperparameter_grids"] = {
            "RandomForest": {
                "n_estimators": [100, 300, 500],
                "max_depth": [None, 10, 20],
                "min_samples_split": [2, 5, 10],
            },
            "LogisticRegression": {
                "C": np.logspace(-3, 3, 7),
                "penalty": ["l2"],
                "solver": ["lbfgs"],
            },
            "SVC": {
                "C": np.logspace(-2, 2, 5),
                "kernel": ["linear", "rbf"],
                "gamma": ["scale", "auto"],
            },
            "XGB": {
                "n_estimators": [200, 400],
                "max_depth": [3, 6, 9],
                "learning_rate": [0.01, 0.1, 0.2],
                "subsample": [0.8, 1.0],
            },
        }
        cfg["hyperparameter_grids_loaded_from"] = False
    return cfg

def load_or_request_config(output_dir: str) -> None:
    """
    Populate the global CONFIG, either from a JSON file or interactively.

    If the user opts to load from a file and loading succeeds, CONFIG is
    replaced by the file's contents (with the output directory entries
    overridden for the current run) and the function returns.  On any load
    error -- including a file whose top-level JSON value is not an object --
    CONFIG is left untouched and the interactive setup runs instead.

    The interactive path updates CONFIG section by section and finally saves
    it into the experiment directory.

    Args:
        output_dir: The already-created (timestamped) experiment directory.

    Raises:
        SystemExit: In interactive mode, when no data could be loaded/merged
            for column selection.
    """
    global CONFIG
    config_file_path = Path(output_dir) / CONFIG_FILENAME

    if get_yes_no_input("Load configuration from a file?", default_yes=config_file_path.exists()):
        if config_file_path.exists():
            chosen_path_str = safe_input("Enter path to configuration JSON file", default=str(config_file_path))
        else:
            chosen_path_str = safe_input("Enter path to configuration JSON file (no default found)")

        chosen_path = Path(_clean_path_input(chosen_path_str))
        try:
            with open(chosen_path, 'r', encoding='utf-8') as f:
                loaded_config = json.load(f)
            # Validate BEFORE touching the global: a list/scalar here would
            # otherwise clobber CONFIG and break the interactive fallback.
            if not isinstance(loaded_config, dict):
                raise ValueError(
                    f"expected a JSON object at the top level, got {type(loaded_config).__name__}"
                )
        except Exception as e:
            logging.error(f"Error loading config file '{chosen_path}': {e}. Proceeding with interactive setup.")
            # Fall through to interactive setup
        else:
            # Output locations always describe the current run, regardless of
            # what the loaded file contained.
            # NOTE(review): output_base_dir is stored as a Path while
            # output_dir_actual is a str -- confirm consumers/JSON encoder cope.
            loaded_config['output_base_dir'] = Path(output_dir).parent  # Parent of the timestamped experiment dir
            loaded_config['output_dir_actual'] = str(output_dir)  # The current timestamped dir
            CONFIG = loaded_config
            logging.info(f"Configuration loaded from {chosen_path}")
            return

    # Interactive Setup
    print_section_header("Pipeline Configuration (Interactive Mode)")
    # 1. File paths (may change output_base_dir if the user enters a new one)
    CONFIG.update(get_file_paths_config())
    # The timestamped experiment directory was created at startup; it is kept
    # as-is even if the user changed output_base_dir above.
    CONFIG['output_dir_actual'] = str(output_dir)

    # 2. Load data for column selection
    df_cases_preview = load_data_from_path(CONFIG['cases_file_path'], "Cases Data")
    df_judges_preview = None
    if CONFIG.get('judges_file_path'):
        df_judges_preview = load_data_from_path(CONFIG['judges_file_path'], "Judges Data")

    df_merged_preview = initial_merge_for_column_selection(df_cases_preview, df_judges_preview, CONFIG)
    if df_merged_preview is None or df_merged_preview.empty:
        logging.error("Failed to load or merge data for column selection. Exiting.")
        sys.exit("Cannot proceed without data for column selection.")

    # 3. Column selections
    CONFIG.update(get_column_config_interactive(df_merged_preview))

    # 4. Feature type validation
    numerical_cols, categorical_cols = validate_feature_types_interactive(df_merged_preview, CONFIG['feature_cols'])
    CONFIG['numerical_features'] = numerical_cols
    CONFIG['categorical_features'] = categorical_cols

    # 5. Encoding (target variable)
    CONFIG.update(get_encoding_config_interactive(df_merged_preview, CONFIG['target_col']))

    # 6. Bucketing
    CONFIG.update(get_bucketing_config_interactive())

    # 7. Balancing
    CONFIG.update(get_balancing_config_interactive())

    # 8. CV, Models, Hyperparams
    CONFIG.update(get_cv_model_hyperparam_config())

    logging.info("Interactive configuration complete.")
    save_config(CONFIG, Path(CONFIG['output_dir_actual']) / CONFIG_FILENAME)


def save_config(config_data: Dict[str, Any], filepath: Path) -> None:
    """
    Save the configuration dictionary to a JSON file.

    Errors are logged rather than raised, so a failed save never aborts the
    pipeline.

    Args:
        config_data: Configuration to serialize (encoded with NumpyJSONEncoder).
        filepath: Destination JSON file.
    """
    try:
        # Serialize before opening the file: if encoding fails, an existing
        # config file is not truncated or left half-written.
        serialized = json.dumps(config_data, cls=NumpyJSONEncoder, indent=2)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(serialized)
        logging.info(f"Configuration successfully saved to {filepath}")
    except Exception as e:
        logging.error(f"Error saving configuration to {filepath}: {e}")

# ==========================
# Data Loading and Initial Merge for Column Selection
# ==========================
def load_data_from_path(file_path: Optional[str], description: str) -> Optional[pd.DataFrame]:
    """
    Read the first sheet of an Excel workbook into a DataFrame.

    Args:
        file_path: Location of the workbook; a falsy value skips loading.
        description: Human-readable label used in log messages.

    Returns:
        The loaded DataFrame, or None when no path was given or reading
        failed (the failure is logged, not raised).
    """
    if not file_path:
        logging.info(f"{description} file path not provided. Skipping load.")
        return None

    logging.info(f"Loading {description} from: {file_path}")
    try:
        frame = pd.read_excel(file_path, sheet_name=0)  # Always first sheet
    except FileNotFoundError:
        logging.error(f"Error: File not found at {file_path} for {description}")
        return None
    except Exception as exc:
        logging.error(f"Error loading {description} from {file_path}: {exc}")
        return None

    logging.info(f"Successfully loaded {description}. Shape: {frame.shape}")
    return frame

def initial_merge_for_column_selection(
    df_cases: Optional[pd.DataFrame],
    df_judges: Optional[pd.DataFrame],
    current_config: Dict[str, Any]
) -> Optional[pd.DataFrame]:
    """
    Build a preview left-merge of cases and judges so the user can inspect
    the columns of both tables before the final merge keys are configured.

    Keys are taken from ``current_config`` ('judge_id_col_cases' /
    'judge_id_col_judges') when both are set and present in their frames;
    otherwise the user is prompted for temporary keys.

    Returns:
        None when there is no cases data; a copy of the cases data when
        there is no judges data, the keys are unusable, or the merge fails;
        otherwise the merged preview frame.
    """
    if df_cases is None:
        return None
    if df_judges is None:
        # Nothing to merge with: the preview is just the cases table.
        return df_cases.copy()

    def _keys_usable(left_key: Any, right_key: Any) -> bool:
        # Both keys must be set and exist as columns of their frames.
        return bool(
            left_key and right_key
            and left_key in df_cases.columns
            and right_key in df_judges.columns
        )

    # 1. Prefer the configured keys when they are valid.
    case_key = current_config.get('judge_id_col_cases')
    judge_key = current_config.get('judge_id_col_judges')

    temp_keys_used = not _keys_usable(case_key, judge_key)
    if temp_keys_used:
        print_section_header("Temporary Merge Key for Column Preview")
        logging.info("Configured merge keys missing; prompting for temporary keys.")
        display_df_columns(df_cases,  "Cases Data (temp merge key)")
        case_key = _select_columns_interactive(
            list(df_cases.columns),
            "Enter temporary merge key from CASES data",
            is_required=True
        )
        display_df_columns(df_judges, "Judges Data (temp merge key)")
        judge_key = _select_columns_interactive(
            list(df_judges.columns),
            "Enter temporary merge key from JUDGES data",
            is_required=True
        )
        if not _keys_usable(case_key, judge_key):
            logging.error("Invalid temporary merge keys. Using cases data only.")
            return df_cases.copy()

    try:
        preview = df_cases.merge(
            df_judges,
            how='left',
            left_on=str(case_key),
            right_on=str(judge_key),
            suffixes=('_case', '_judge')
        )
    except Exception as exc:
        logging.error(f"Error in pre-merge preview: {exc}")
        return df_cases.copy()

    logging.info(
        f"Pre-merge preview done (keys: {case_key}/{judge_key}, "
        f"temp_used={temp_keys_used}). Shape: {preview.shape}"
    )
    return preview

# ==========================
# Data Preparation (Full Merge & Preprocessing)
# ==========================
def full_data_merge(df_cases: pd.DataFrame, df_judges: Optional[pd.DataFrame], config: Dict[str, Any]) -> pd.DataFrame:
    """
    Join cases with judges using the keys stored in ``config['merge_config']``.

    Both key columns are cast to ``str`` on copies of the inputs before the
    join, so the caller's frames are not modified.  The join type comes from
    ``merge_config['how']`` (default 'left').

    Returns:
        The merged frame, or a copy of ``df_cases`` when there is no judges
        data, no 'merge_config' entry, or the merge itself fails.
    """
    merge_unavailable = df_judges is None or 'merge_config' not in config
    if merge_unavailable:
        logging.info("Judges data not available or merge not configured. Using only cases data for pipeline run.")
        return df_cases.copy()

    settings = config['merge_config']
    left_key = str(settings['cases_link_col'])
    right_key = str(settings['judges_link_col'])
    join_type = settings.get('how', 'left')

    try:
        left = df_cases.copy()
        right = df_judges.copy()
        # Normalize both key columns to strings so they compare equal on join.
        left[left_key] = left[left_key].astype(str)
        right[right_key] = right[right_key].astype(str)

        merged = left.merge(
            right,
            how=join_type,
            left_on=left_key,
            right_on=right_key,
            suffixes=('_case', '_judge'),
        )
        logging.info(f"Pipeline data merge successful. Merged shape: {merged.shape}")
        return merged
    except KeyError as exc:
        logging.error(f"KeyError during pipeline merge: {exc}. Check config for merge keys. Using unmerged cases data.")
    except Exception as exc:
        logging.error(f"Unexpected error during pipeline data merge: {exc}. Using unmerged cases data.")
    # Either handler above lands here: fall back to the unmerged cases data.
    return df_cases.copy()

def preprocess_data(df: pd.DataFrame, config: Dict[str, Any]) -> pd.DataFrame:
    """
    Apply the target mapping and basic feature imputation.

    Steps (all on a copy; the input frame is not modified):
      1. Map the raw target values to numeric codes using
         ``config['target_mapping_raw_to_numeric']``.
      2. Drop rows whose target is NaN after mapping (unmapped values).
      3. For every column in ``config['feature_cols']`` containing NaNs, fill
         with the median if the column is listed in
         ``config['numerical_features']``, otherwise with the mode (or the
         literal "__MISSING__" when no mode exists).

    Args:
        df: Merged input data.
        config: Pipeline configuration; must contain 'target_col' and
            'target_mapping_raw_to_numeric'.

    Returns:
        The preprocessed copy of ``df``.

    Raises:
        ValueError: If the target column is not present in ``df``.
    """
    logging.info("Starting data preprocessing...")
    df_processed = df.copy()
    target_col = config['target_col']
    raw_to_numeric_map = config['target_mapping_raw_to_numeric']

    # Apply target mapping
    if target_col not in df_processed.columns:
        logging.error(f"Target column '{target_col}' not found in DataFrame. Preprocessing cannot continue.")
        raise ValueError(f"Target column '{target_col}' not found.")

    df_processed[target_col] = df_processed[target_col].map(raw_to_numeric_map)

    original_rows = len(df_processed)
    df_processed = df_processed.dropna(subset=[target_col])  # Drop rows where target is NaN after mapping
    rows_dropped = original_rows - len(df_processed)
    if rows_dropped > 0:
        logging.warning(f"Dropped {rows_dropped} rows due to unmapped/NaN target values after mapping.")

    # Impute missing values for features
    feature_cols = config.get('feature_cols', [])
    numerical_features = config.get('numerical_features', [])

    for col in feature_cols:
        if not df_processed[col].isnull().any():
            continue

        # NOTE: the filled column is assigned back explicitly.  Calling
        # `df[col].fillna(..., inplace=True)` mutates a temporary Series and
        # has no effect on the frame under pandas Copy-on-Write.
        if col in numerical_features:
            fill_value = df_processed[col].median()
            df_processed[col] = df_processed[col].fillna(fill_value)
            logging.info(f"Imputed NaNs in NUMERIC column '{col}' with median ({fill_value:.2f}).")
            continue

        # Categorical
        modes = df_processed[col].mode()
        if not modes.empty:
            fill_value = modes[0]
            df_processed[col] = df_processed[col].fillna(fill_value)
            logging.info(f"Imputed NaNs in CATEGORICAL column '{col}' with mode ('{fill_value}').")
        else:
            # Fallback if mode is empty (e.g. the column is entirely NaN)
            df_processed[col] = df_processed[col].fillna("__MISSING__")
            logging.warning(f"Column '{col}' had no mode for imputation; filled NaNs with '__MISSING__'.")

    logging.info(f"Preprocessing complete. Data shape: {df_processed.shape}")
    return df_processed

def get_feature_transformer(config: Dict[str, Any]) -> ColumnTransformer:
    """
    Build the ColumnTransformer used to preprocess features.

    Columns in ``config['numerical_features']`` are standard-scaled and those
    in ``config['categorical_features']`` are one-hot encoded (unknown
    categories ignored, dense output).  All other columns are passed through
    unchanged.  When neither list has entries, a pure pass-through
    transformer is returned.
    """
    steps = []

    numeric_cols = config.get('numerical_features', [])
    if numeric_cols:
        steps.append(('num', StandardScaler(), numeric_cols))

    categorical_cols = config.get('categorical_features', [])
    if categorical_cols:
        encoder = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
        steps.append(('cat', encoder, categorical_cols))

    if steps:
        # remainder='passthrough' keeps columns not listed above; callers are
        # expected to hand in only the feature columns.
        preprocessor = ColumnTransformer(
            transformers=steps,
            remainder='passthrough',
            verbose_feature_names_out=False,
        )
        logging.info("Feature transformer (preprocessor) created.")
        return preprocessor

    # No features to transform, passthrough all
    return ColumnTransformer(transformers=[], remainder='passthrough')

# ==========================
# Judge Bucketing
# ==========================
def create_judge_buckets(df: pd.DataFrame, config: Dict[str, Any]) -> Dict[str, pd.DataFrame]:
    """
    Split the data into judge-specific buckets plus one 'generic' bucket.

    A judge gets a dedicated bucket (keyed by ``str(judge_id)``) when it has
    at least ``config['min_samples_per_judge_bucket']`` rows (default 30).
    Rows of smaller judges, and rows whose judge ID is missing (NaN), are
    pooled into the 'generic' bucket so that no row is silently dropped.

    Returns:
        Mapping of bucket name to DataFrame.  A single ``{'generic': df}``
        mapping is returned when bucketing is disabled, the judge column is
        not configured or not present, or no bucket could be formed.
    """
    if not config.get('use_judge_bucketing', False) or not config.get('judge_id_col_final_bucketing'):
        logging.info("Judge bucketing not enabled or no final judge ID column. Using one 'generic' bucket.")
        return {'generic': df.copy()}

    judge_col = str(config['judge_id_col_final_bucketing'])
    min_samples = config.get('min_samples_per_judge_bucket', 30)

    if judge_col not in df.columns:
        logging.warning(f"Judge column '{judge_col}' for bucketing not found. Using one 'generic' bucket.")
        return {'generic': df.copy()}

    judge_series = df[judge_col]
    # dropna=False: rows without a judge ID must still be counted, otherwise
    # they would disappear from every bucket.
    judge_counts = judge_series.value_counts(dropna=False)
    judge_buckets: Dict[str, pd.DataFrame] = {}
    generic_bucket_list: List[pd.DataFrame] = []

    for judge_id_val, count in judge_counts.items():
        if pd.isna(judge_id_val):
            # A missing ID is not a single judge; always pool these rows.
            generic_bucket_list.append(df[judge_series.isna()].copy())
            logging.info(f"{count} samples with missing judge ID added to generic bucket.")
            continue

        judge_id_str = str(judge_id_val)  # Ensure string key
        judge_df = df[judge_series == judge_id_val].copy()  # Use original value for filtering
        if count >= min_samples:
            judge_buckets[judge_id_str] = judge_df
            logging.info(f"Created bucket for judge '{judge_id_str}' with {count} samples.")
        else:
            generic_bucket_list.append(judge_df)
            logging.info(f"Judge '{judge_id_str}' with {count} samples added to generic bucket.")

    if generic_bucket_list:
        df_generic = pd.concat(generic_bucket_list, ignore_index=True)
        if not df_generic.empty:
            judge_buckets['generic'] = df_generic
            logging.info(f"Generic bucket created with {len(df_generic)} samples from {len(generic_bucket_list)} judges/groups.")

    if not judge_buckets:
        # Nothing was bucketed at all (e.g. empty input): fall back to one bucket.
        logging.warning("No specific judge buckets created and no data for generic bucket. All data might be in one implicit 'generic' bucket if bucketing failed.")
        return {'generic': df.copy()}

    return judge_buckets

# ==========================
# Data Balancing 
# ==========================

from collections import Counter

def print_class_distribution(
    y_or_counter: Union[pd.Series, Counter],
    context_message: str = ""
) -> Counter:
    """
    Print per-class counts and percentages as a table.

    Args:
        y_or_counter: Either a pandas Series of class labels (counted here)
            or an already-computed Counter of label -> count.
        context_message: Optional heading printed above the table.

    Returns:
        The Counter of label -> count that was displayed.
    """
    if isinstance(y_or_counter, Counter):
        counts = y_or_counter
    else:
        counts = Counter(y_or_counter)

    total = sum(counts.values())
    table = []
    for c in CLASS_LABELS_NUMERIC:
        n_samples = counts.get(c, 0)
        # Guard against empty input: report 0 % instead of dividing by zero.
        pct = 100.0 * n_samples / total if total else 0.0
        table.append([TARGET_CLASS_NAMES[c], n_samples, f"{pct:6.2f} %"])

    if context_message:
        print(f"\n{context_message}")

    print(tabulate(table, headers=["Class", "Count", "%"], tablefmt="pretty"))
    return counts

def _determine_auto_sampling_targets(current_counts: Counter, config: Dict[str, Any]) -> Dict[int, int]:
    """
    Compute per-class target sample counts for 'auto' sampling (two stages).

    Stage 1 shifts ``d1`` samples between class 0 and class 1 so that
    N1/N0 equals ``config['sampling_target_ratio_1_vs_0']`` (default 1.0)
    while N0+N1 stays constant.  Stage 2 shifts ``d2`` samples between the
    combined (0+1) group and class 2 so that N2/(N0+N1) equals
    ``config['sampling_target_ratio_2_vs_01']`` (default 0.5); the change to
    the combined group is split between classes 0 and 1 proportionally.

    Classes that had no samples keep a target of 0; classes that had samples
    never end up below 1.

    Args:
        current_counts: Counter of numeric class label (0, 1, 2) -> count.
        config: Pipeline configuration holding the two target ratios.

    Returns:
        Dict mapping class label 0/1/2 to its target sample count.
    """
    N0_orig, N1_orig, N2_orig = current_counts.get(0, 0), current_counts.get(1, 0), current_counts.get(2, 0)
    r10_target = config.get('sampling_target_ratio_1_vs_0', 1.0)
    r2_01_target = config.get('sampling_target_ratio_2_vs_01', 0.5)

    logging.info(f"Auto-sampling: Initial counts N0={N0_orig}, N1={N1_orig}, N2={N2_orig}")
    logging.info(f"Auto-sampling: Target R1/R0={r10_target}, Target R2/(R0+R1)={r2_01_target}")

    N0_inter, N1_inter = N0_orig, N1_orig

    # Stage 1: Balance Class 0 (mother) and Class 1 (father)
    if N0_orig == 0 and N1_orig > 0 and r10_target > 0:
        logging.warning("Auto-sampling Stage 1: Class 0 is empty, cannot use as reference for Class 1. Class 1 count remains unchanged.")
    elif N0_orig > 0:
        # Solve (N1 + d) / (N0 - d) = r10_target for d:
        #   d = (r10_target * N0 - N1) / (1 + r10_target)
        # d is the amount to ADD to N1 and SUBTRACT from N0.
        if (1 + r10_target) == 0:  # Avoid division by zero
            d1 = 0
        else:
            d1 = (r10_target * N0_orig - N1_orig) / (1 + r10_target)

        N0_inter = round(N0_orig - d1)
        N1_inter = round(N1_orig + d1)

    # Ensure non-negative and preserve original zeros
    N0_inter = max(0, N0_inter) if N0_orig > 0 else 0
    N1_inter = max(0, N1_inter) if N1_orig > 0 else 0
    logging.info(f"Auto-sampling Stage 1: Intermediate counts N0_inter={N0_inter}, N1_inter={N1_inter}")

    # Stage 2: Balance Class 2 (shared) against (N0_inter + N1_inter)
    N01_inter_sum = N0_inter + N1_inter
    # Defaults for every path below.  Without these, the warning branch left
    # N0_final/N1_final unassigned and the function crashed.
    N0_final, N1_final, N2_final = N0_inter, N1_inter, N2_orig

    if N01_inter_sum == 0 and N2_orig > 0 and r2_01_target > 0:
        logging.warning("Auto-sampling Stage 2: Sum of (Class 0 + Class 1) is zero. Cannot use as reference for Class 2. Class 2 count remains unchanged.")
    elif N01_inter_sum > 0:
        # Solve (N2 + d) / (N01 - d) = r2_01_target for d:
        #   d = (r2_01_target * N01 - N2) / (1 + r2_01_target)
        # d is the amount to ADD to N2 and SUBTRACT from the (0+1) group.
        if (1 + r2_01_target) == 0:
            d2 = 0
        else:
            d2 = (r2_01_target * N01_inter_sum - N2_orig) / (1 + r2_01_target)

        N2_final = round(N2_orig + d2)
        N01_sum_final = round(N01_inter_sum - d2)

        if N01_sum_final != N01_inter_sum:
            # Distribute the new combined total back to N0 and N1 proportionally.
            N0_final = round(N01_sum_final * (N0_inter / N01_inter_sum))
            N1_final = round(N01_sum_final * (N1_inter / N01_inter_sum))
            # Absorb any rounding drift in the originally larger class so the
            # two parts still sum to N01_sum_final.
            if N0_final + N1_final != N01_sum_final:
                if N0_orig >= N1_orig:
                    N0_final = N01_sum_final - N1_final
                else:
                    N1_final = N01_sum_final - N0_final

    # Ensure non-negative and preserve original zeros
    target_counts = {
        0: max(0, N0_final) if N0_orig > 0 else 0,
        1: max(0, N1_final) if N1_orig > 0 else 0,
        2: max(0, N2_final) if N2_orig > 0 else 0,
    }

    # Sanity check: a class that had samples must keep at least one.
    for cls_idx in [0, 1, 2]:
        if current_counts.get(cls_idx, 0) > 0 and target_counts[cls_idx] == 0:
            logging.warning(f"Auto-sampling for class {TARGET_CLASS_NAMES[cls_idx]} resulted in 0 samples (was {current_counts.get(cls_idx,0)}). Setting target to 1 to preserve class.")
            target_counts[cls_idx] = 1

    logging.info(f"Auto-sampling Final Targets: N0={target_counts[0]}, N1={target_counts[1]}, N2={target_counts[2]}")
    return target_counts

def _prompt_manual_sampling_targets(current_counts: Counter, bucket_name: str) -> Dict[int, int]:
    """
    Ask the user for a target sample count per class for one bucket.

    The current distribution is shown first; each prompt defaults to the
    class's current count.

    Args:
        current_counts: Counter of numeric class label -> current count.
        bucket_name: Name of the bucket, used in headings and log messages.

    Returns:
        Dict mapping each numeric class label to the requested target count.
    """
    print_section_header(f"Manual Sampling Configuration for Bucket: {bucket_name}")
    # Pass the Counter itself: wrapping it in a Series would make
    # print_class_distribution tally the count VALUES instead of the labels.
    print_class_distribution(current_counts, f"Current distribution for {bucket_name}")

    target_counts: Dict[int, int] = {}
    for cls_numeric, cls_name in enumerate(TARGET_CLASS_NAMES):
        default_val = current_counts.get(cls_numeric, 0)
        target_counts[cls_numeric] = int(safe_input(
            f"  Enter target #samples for Class {cls_numeric} ({cls_name})",
            default=default_val, type_caster=int
        ))
    logging.info(f"Manual sampling for {bucket_name}: Target counts {target_counts}")
    return target_counts

def _apply_imblearn_sampling(X: pd.DataFrame, y: pd.Series, target_counts: Dict[int, int]) -> Tuple[pd.DataFrame, pd.Series]:
    """
    Resample (X, y) with imblearn so each class approaches ``target_counts``.

    Strategy:
      1. Undersample every class whose current count exceeds its (positive)
         target.
      2. Oversample every class whose count after step 1 is below its target.

    Classes with a target of 0 are left untouched, and a class with no
    samples in ``y`` cannot be oversampled (a warning is logged).  If a
    sampler raises ValueError, that step is skipped and the data from before
    the step is used.

    Returns:
        (X_resampled, y_resampled); the inputs are returned unchanged when
        they are empty or imbalanced-learn is unavailable.
    """
    if X.empty or y.empty:
        logging.warning("Skipping imblearn sampling for empty X or y.")
        return X, y
    if RandomOverSampler is None or RandomUnderSampler is None:
        logging.error("imbalanced-learn is not installed. Cannot perform sampling.")
        return X, y

    current_counts = Counter(y)
    X_resampled, y_resampled = X.copy(), y.copy()

    # --- Undersampling: classes above a positive target ---
    under_targets = {
        cls: max(1, count)  # never shrink an existing class below one sample
        for cls, count in target_counts.items()
        if count > 0 and current_counts.get(cls, 0) > count
    }
    if under_targets:
        try:
            under_sampler = RandomUnderSampler(sampling_strategy=under_targets, random_state=RANDOM_STATE)
            X_resampled, y_resampled = under_sampler.fit_resample(X_resampled, y_resampled)
            logging.info(f"Applied undersampling. New distribution: {Counter(y_resampled)}")
        except ValueError as e:
            logging.warning(f"Undersampling failed: {e}. Proceeding with data before undersampling for oversampling step.")
            X_resampled, y_resampled = X.copy(), y.copy()  # Reset to before this attempt

    current_counts_after_under = Counter(y_resampled)

    # --- Oversampling: classes still below a positive target ---
    over_targets: Dict[int, int] = {}
    for cls, target_count in target_counts.items():
        if target_count <= 0:
            continue
        if current_counts.get(cls, 0) == 0:
            # Random oversampling duplicates existing rows, so a class with
            # no rows cannot be created.
            logging.warning(f"Cannot oversample class {cls} as it had 0 samples originally. Skipping oversample for this class.")
            continue
        if target_count > current_counts_after_under.get(cls, 0):
            over_targets[cls] = target_count

    if over_targets:
        try:
            over_sampler = RandomOverSampler(sampling_strategy=over_targets, random_state=RANDOM_STATE)
            X_resampled, y_resampled = over_sampler.fit_resample(X_resampled, y_resampled)
            logging.info(f"Applied oversampling. Final distribution: {Counter(y_resampled)}")
        except ValueError as e:
            # X_resampled / y_resampled still hold the post-undersampling data.
            logging.warning(f"Oversampling failed: {e}. Using data after any undersampling.")

    return X_resampled, y_resampled

def _calculate_class_weights(y_series: pd.Series, config: Dict[str, Any]) -> Optional[Dict[int, float]]:
    """
    Derive per-class weights relative to class 0.

    Class 0 always has weight 1.0.  For class k in (1, 2) the weight is
    ``(N0 / Nk) * ratio_k`` where ``ratio_k`` is
    ``config['weighting_target_ratio_<k>_vs_0']`` (default 1.0).  A class
    with no samples but a positive ratio gets the heuristic weight 100.0.

    Returns:
        Dict of class label -> weight, or None when class 0 has no samples
        (ratio-based weights cannot be computed).
    """
    tally = Counter(y_series)
    reference_count = tally.get(0, 0)

    if not reference_count > 0:
        logging.warning("Class 0 (mother) has 0 samples. Cannot compute ratio-based class weights. Using default weights [1,1,1].")
        # Signal to the caller that no custom weights could be computed.
        return None

    weights = {0: 1.0, 1: 1.0, 2: 1.0}  # Default
    ratio_keys = (
        (1, 'weighting_target_ratio_1_vs_0'),
        (2, 'weighting_target_ratio_2_vs_0'),
    )
    for cls, ratio_key in ratio_keys:
        desired_ratio = config.get(ratio_key, 1.0)
        class_count = tally.get(cls, 0)
        if class_count > 0:
            weights[cls] = (reference_count / class_count) * desired_ratio
        elif desired_ratio > 0:
            # Class is absent but should carry weight: use a large heuristic value.
            weights[cls] = 100.0
            logging.warning(f"Class {TARGET_CLASS_NAMES[cls]} has 0 samples, but weighting target ratio > 0. Assigning large weight.")

    logging.info(f"Calculated class weights for model: {weights}")
    return weights

def _coerce_class_keys(mapping: Dict[Any, Any]) -> Dict[Any, Any]:
    """Return a copy of ``mapping`` with int-like keys (e.g. "1", 1.0) as ints."""
    coerced: Dict[Any, Any] = {}
    for key, value in mapping.items():
        try:
            coerced[int(key)] = value
        except (TypeError, ValueError):
            coerced[key] = value  # leave non-numeric keys untouched
    return coerced


def balance_data_for_fold(
    X: pd.DataFrame,
    y: pd.Series,
    cfg: Dict[str, Any],
    *,
    tag: str = "",
    verbose: bool = True
) -> Tuple[pd.DataFrame, pd.Series, Optional[np.ndarray], Counter, Counter]:
    """
    Balance one fold according to ``cfg['balancing_method']``.

    Methods:
      * "none"      -- data returned unchanged.
      * "weighting" -- data unchanged; a per-sample weight array is built by
                       mapping each label through ``cfg['class_weight_ratios']``.
      * anything else is treated as sampling -- classes are under-/over-sampled
                       towards ``cfg['sampling_target_percentages']`` of the
                       fold's current size.

    Class-keyed config dicts may use int or string keys; JSON-loaded configs
    carry string keys, which are converted to ints here.

    Args:
        X: Feature frame of the fold.
        y: Class labels of the fold.
        cfg: Pipeline configuration.
        tag: Label prefixed to the printed BEFORE/AFTER headings.
        verbose: When True, print the class distribution before and after.

    Returns:
        (X_bal, y_bal, sample_weight or None, Counter_before, Counter_after)
    """
    # ------------------------------------------------------------------ #
    # BEFORE
    # ------------------------------------------------------------------ #
    before_cnt = Counter(y)
    if verbose:
        print_class_distribution(before_cnt, f"{tag} – BEFORE")

    method = cfg["balancing_method"]

    # ------------------------------------------------------------------ #
    # NONE  ->  no changes
    # ------------------------------------------------------------------ #
    if method == "none":
        after_cnt = before_cnt  # unchanged
        if verbose:
            print_class_distribution(after_cnt, f"{tag} – AFTER (no change)")
        return X, y, None, before_cnt, after_cnt

    # ------------------------------------------------------------------ #
    # WEIGHTING  ->  compute sample_weight, no resampling
    # ------------------------------------------------------------------ #
    if method == "weighting":
        ratios = _coerce_class_keys(cfg["class_weight_ratios"])
        sample_w = y.map(ratios).to_numpy(dtype=float)
        after_cnt = before_cnt  # distribution unchanged
        if verbose:
            print_class_distribution(after_cnt, f"{tag} – AFTER (weighted)")
        return X, y, sample_w, before_cnt, after_cnt

    # ------------------------------------------------------------------ #
    # SAMPLING  ->  under-/over-sample to hit target percentages
    # ------------------------------------------------------------------ #
    target_pct = _coerce_class_keys(cfg["sampling_target_percentages"])

    # desired counts given the current total size
    total_after = len(y)
    target_counts = {
        cls: int(round(total_after * target_pct[cls] / 100.0))
        for cls in CLASS_LABELS_NUMERIC
    }

    # -- UNDER-sample classes above target ------------------------------
    to_under = {cls: target_counts[cls] for cls, cnt in before_cnt.items()
                if cnt > target_counts.get(cls, cnt)}
    if to_under:
        rus = RandomUnderSampler(
            sampling_strategy=to_under,
            random_state=RANDOM_STATE
        )
        X, y = rus.fit_resample(X, y)

    # -- OVER-sample classes below target -------------------------------
    after_under_cnt = Counter(y)
    to_over: Dict[int, int] = {}
    for cls in CLASS_LABELS_NUMERIC:
        present = after_under_cnt.get(cls, 0)
        if present >= target_counts[cls]:
            continue
        if present == 0:
            # Random oversampling can only duplicate existing rows; asking
            # for an absent class makes the sampler raise.
            logging.warning(f"{tag}: class {cls} has no samples in this fold; cannot oversample it.")
            continue
        to_over[cls] = target_counts[cls]
    if to_over:
        ros = RandomOverSampler(
            sampling_strategy=to_over,
            random_state=RANDOM_STATE
        )
        X, y = ros.fit_resample(X, y)

    after_cnt = Counter(y)
    if verbose:
        print_class_distribution(after_cnt, f"{tag} – AFTER (sampled)")

    return X, y, None, before_cnt, after_cnt

# ==========================
# Model Training and Evaluation
# ==========================

def get_feature_importances(model, feature_names: List[str]) -> pd.Series:
    """
    Extract a per-feature importance score from a fitted model.

    Uses ``feature_importances_`` when the model has it, otherwise the
    absolute value of ``coef_`` (averaged over rows when ``coef_`` is not
    one-dimensional).  Models exposing neither yield all zeros and a logged
    warning.

    Returns:
        pd.Series of importances indexed by ``feature_names``.
    """
    if hasattr(model, "feature_importances_"):
        return pd.Series(model.feature_importances_, index=feature_names)

    if hasattr(model, "coef_"):
        magnitudes = np.abs(model.coef_)
        # One row per class for multi-class models: collapse to one value per feature.
        scores = magnitudes if model.coef_.ndim == 1 else magnitudes.mean(axis=0)
        return pd.Series(scores, index=feature_names)

    logging.warning("Model has no native feature importance; "
                    "returning zeros.")
    return pd.Series(np.zeros(len(feature_names)), index=feature_names)

def get_model_instance(model_name: str,
                       model_params: Optional[Dict] = None,
                       random_state: int = RANDOM_STATE) -> Any:
    """
    Return a fresh, unfitted model instance for the given shorthand.

    Matching is case-insensitive.  Accepted names: rf / randomforest /
    randomforestclassifier, lr / logreg / logisticregression, svc / svm,
    xgb / xgboost / xgbclassifier.

    Each model has a set of default keyword arguments; entries in
    ``model_params`` override those defaults, so tuned values such as
    ``max_iter`` or ``n_jobs`` no longer clash with the hard-coded ones.

    Args:
        model_name: Shorthand or class name of the model.
        model_params: Extra/overriding constructor keyword arguments.
        random_state: Seed passed to the model unless overridden.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    overrides = dict(model_params or {})
    name = model_name.lower()

    if name in {"rf", "randomforest", "randomforestclassifier"}:
        defaults = {"random_state": random_state}
        return RandomForestClassifier(**{**defaults, **overrides})

    if name in {"lr", "logreg", "logisticregression"}:
        defaults = {"random_state": random_state, "max_iter": 1000, "n_jobs": -1}
        return LogisticRegression(**{**defaults, **overrides})

    if name in {"svc", "svm"}:
        # probability=True enables predict_proba on the fitted model.
        defaults = {"probability": True, "random_state": random_state}
        return SVC(**{**defaults, **overrides})

    if name in {"xgb", "xgboost", "xgbclassifier"}:
        defaults = {
            "random_state": random_state,
            "n_jobs": -1,
            "objective": "multi:softprob",
            "eval_metric": "mlogloss",
        }
        return XGBClassifier(**{**defaults, **overrides})

    raise ValueError(f"Unsupported model: {model_name}")

def tune_hyperparameters(
    X: pd.DataFrame,
    y: pd.Series,
    model_name: str,
    param_distributions: Dict[str, Any],
    *,
    cv_folds_inner: int,
    n_iter: int,
    sample_weight_train: Optional[np.ndarray] = None,
) -> Tuple[Any, Dict[str, Any]]:
    """
    Run a RandomizedSearchCV for one model and return the winner.

    The search scores candidates by macro F1, uses ``cv_folds_inner`` inner
    folds, samples ``n_iter`` parameter settings and refits the best setting
    on all of (X, y).  ``sample_weight_train`` is forwarded to ``fit``.

    Returns:
        (best_estimator, best_params) from the fitted search.
    """
    search_options = {
        "estimator": get_model_instance(model_name),
        "param_distributions": param_distributions,
        "n_iter": n_iter,
        "cv": cv_folds_inner,
        "scoring": "f1_macro",
        "random_state": RANDOM_STATE,
        "n_jobs": -1,
        "refit": True,
        "verbose": 0,
    }
    search = RandomizedSearchCV(**search_options)
    search.fit(X, y, sample_weight=sample_weight_train)

    best_model = search.best_estimator_
    best_settings = search.best_params_
    return best_model, best_settings


def evaluate_model_on_test_set(
    model: Any, X_test: pd.DataFrame, y_test: pd.Series, 
    class_labels_numeric: List[int] = CLASS_LABELS_NUMERIC
) -> Tuple[Dict[str, Any], np.ndarray]:
    """
    Score a fitted model on one evaluation set.

    Returns
    -------
    metrics : dict
        Keys: 'accuracy', 'per_class', 'auc_macro_ovr', 'macro_precision',
        'macro_recall', 'macro_f1'. 'per_class' maps every entry of
        ``class_labels_numeric`` to a dict with 'precision', 'recall',
        'f1_score' and 'auc_ovr'. AUC values stay NaN when they cannot be
        computed (no ``predict_proba``, a single class in ``y_test``,
        mismatching probability shape, or a ValueError from the scorer).
    cm_array : np.ndarray
        Confusion matrix over ``class_labels_numeric``.

    If ``X_test`` or ``y_test`` is empty, a warning is logged and
    ``({}, np.array([]))`` is returned.
    """
    if X_test.empty or y_test.empty:
        logging.warning("Skipping evaluation due to empty X_test or y_test.")
        return {}, np.array([])

    n_labels = len(class_labels_numeric)

    predictions = model.predict(X_test)
    probabilities = None
    if hasattr(model, "predict_proba"):
        probabilities = model.predict_proba(X_test)

    metrics: Dict[str, Any] = {
        'accuracy': accuracy_score(y_test, predictions),
        'per_class': {},
        'auc_macro_ovr': np.nan,  # filled in below when AUC is computable
    }

    # Always score against the predefined label list so the report has the
    # same structure for every bucket; zero_division=0 yields 0 (and no
    # warning) for classes without predictions/support.
    shared_kwargs = dict(labels=class_labels_numeric, zero_division=0)
    cls_precision, cls_recall, cls_f1, _ = precision_recall_fscore_support(
        y_test, predictions, average=None, **shared_kwargs
    )
    avg_precision, avg_recall, avg_f1, _ = precision_recall_fscore_support(
        y_test, predictions, average='macro', **shared_kwargs
    )

    metrics['macro_precision'] = avg_precision
    metrics['macro_recall'] = avg_recall
    metrics['macro_f1'] = avg_f1

    for label_num, p_val, r_val, f_val in zip(
        class_labels_numeric, cls_precision, cls_recall, cls_f1
    ):
        metrics['per_class'][label_num] = {
            'precision': p_val,
            'recall': r_val,
            'f1_score': f_val,
            'auc_ovr': np.nan,  # filled in below when AUC is computable
        }

    # One-vs-rest ROC AUC (per class and macro).
    if probabilities is not None and y_test.nunique() > 1:
        try:
            if probabilities.ndim == 1:
                # A 1-D score vector is only usable for a two-class setup:
                # expand it to the (n_samples, 2) layout.
                if n_labels == 2:
                    probabilities = np.vstack([1 - probabilities, probabilities]).T
                else:
                    logging.warning(f"1D y_proba for {len(class_labels_numeric)}-class problem. AUC might be incorrect.")
                    probabilities = None

            shape_matches = probabilities is not None and probabilities.shape[1] == n_labels
            if not shape_matches:
                logging.warning(f"y_proba shape mismatch for AUC calculation. Expected {len(class_labels_numeric)} columns, got {probabilities.shape[1] if probabilities is not None else 'None'}.")
            else:
                per_class_auc = roc_auc_score(
                    y_test, probabilities, multi_class='ovr',
                    average=None, labels=class_labels_numeric,
                )
                metrics['auc_macro_ovr'] = roc_auc_score(
                    y_test, probabilities, multi_class='ovr',
                    average='macro', labels=class_labels_numeric,
                )
                for position, label_num in enumerate(class_labels_numeric):
                    metrics['per_class'][label_num]['auc_ovr'] = per_class_auc[position]
        except ValueError as auc_error:
            logging.warning(f"Could not calculate ROC AUC: {auc_error}. y_test unique: {y_test.unique().tolist()}")

    return metrics, confusion_matrix(y_test, predictions, labels=class_labels_numeric)

# ==========================
# Results Export
# ==========================

# =============================================================
#  Model Persistence Utilities
#  (NEW – place e.g. in util_io.py or wherever helpers live)
# =============================================================
import copy
import joblib
from datetime import datetime as dt
from pathlib import Path
from typing import Any, Dict

def export_model_bundle(
    model: Any,
    fitted_transformer: Any,
    meta: Dict[str, Any],
    save_path: Path,
) -> Path:
    """
    Serialise a trained model together with its fitted transformer and some
    lightweight metadata into a single joblib file, so the bundle can be
    re-loaded later (see ``load_model_bundle``).

    Parameters
    ----------
    model : the trained estimator to store.
    fitted_transformer : the transformer that was fitted for this model.
    meta : dict
        Arbitrary information to store (bucket_id, model_name, etc.).
        A ``"saved_at"`` timestamp is added automatically; a ``"saved_at"``
        key supplied in ``meta`` takes precedence. ``None`` is treated as
        an empty dict.
    save_path : pathlib.Path (a plain string path is accepted as well)
        Full filename ('.joblib' recommended). Missing parent directories
        are created.

    Returns
    -------
    pathlib.Path – the path written to disk.
    """
    # Accept str as well as Path so `.parent` below always works.
    target = Path(save_path)

    # No defensive deep copy is needed: joblib.dump serialises the objects
    # during the call, so later .fit() calls on the live transformer cannot
    # change what has already been written to disk.
    bundle = {
        "model": model,
        "transformer": fitted_transformer,
        "meta": {
            "saved_at": dt.now().isoformat(timespec="seconds"),
            **(meta or {}),
        },
    }

    target.parent.mkdir(parents=True, exist_ok=True)

    joblib.dump(bundle, target)
    logging.info(f"Model bundle exported to {target}")

    return target


def load_model_bundle(bundle_path: Path) -> Dict[str, Any]:
    """
    Read a bundle file back from disk via ``joblib.load``.

    The returned object is whatever was stored; for files written by
    ``export_model_bundle`` that is a dict with the keys
    'model', 'transformer' and 'meta'.

    NOTE(security): joblib files are pickle-based, so loading one can execute
    arbitrary code. Only load bundles from a trusted source.
    """
    import logging
    import joblib

    bundle = joblib.load(bundle_path)
    logging.info(f"Model bundle loaded from {bundle_path}")
    return bundle

# =============================================================
#  Cross-Bucket / Cross-Model Evaluation
#  (NEW)
# =============================================================
from collections import defaultdict
from typing import Any, Dict, List, Tuple
import numpy as np
import pandas as pd
from pathlib import Path
import logging

def cross_compare_models(
    model_bundles: Dict[Tuple[str, str], Dict[str, Any]],
    test_sets: Dict[str, Tuple[pd.DataFrame, pd.Series]],
    out_dir: Path,
) -> Tuple[List[Dict[str, Any]], List[pd.DataFrame]]:
    """
    Evaluates every trained model (source bucket × model name)
    on **every** available held-out test set (target bucket).

    Parameters
    ----------
    model_bundles : dict
        Key = (src_bucket_id, model_name)
        Value = dict with at least the keys { 'model', 'transformer' }.
    test_sets : dict
        Key = target_bucket_id
        Value = (X_test_raw, y_test) – **un-transformed** dataframe/series;
        each model's own transformer is applied here.
    out_dir : pathlib.Path
        Root experiment directory → confusion-matrix PNGs are placed
        in `<out_dir>/<CM_SUBDIR_NAME>/`.

    Returns
    -------
    cross_metrics_rows : list[dict]
        One dict per successfully evaluated (src_model, target_test_set)
        combination: identifiers plus the metrics dict returned by
        `evaluate_model_on_test_set`.
    cross_cm_tables    : list[pd.DataFrame]
        One confusion-matrix DataFrame per combination with a non-empty
        matrix, augmented with identifier columns.

    Notes
    -----
    A combination whose transform or evaluation step raises is logged and
    skipped, so one failing pair does not abort the remaining comparisons.
    """

    # Ensure CM directory exists
    cm_dir = out_dir / CM_SUBDIR_NAME
    cm_dir.mkdir(exist_ok=True, parents=True)

    cross_metrics_rows: List[Dict[str, Any]] = []
    cross_cm_tables: List[pd.DataFrame] = []

    total_evals = len(model_bundles) * len(test_sets)
    logging.info(f"Cross-evaluation: {total_evals} model×test combinations.")
    print_section_header("CROSS-BUCKET EVALUATION")

    # ──────────────────────────────────────────────────────────────
    for (src_bucket, model_name), bundle in model_bundles.items():
        model = bundle["model"]
        transformer = bundle["transformer"]

        for tgt_bucket, (X_tgt_raw, y_tgt) in test_sets.items():
            tag = f"{src_bucket}/{model_name} → {tgt_bucket}"
            print(f"  ↳ {tag}")

            # 1. Transform with *that model's* fitted transformer
            try:
                X_tgt = pd.DataFrame(
                    transformer.transform(X_tgt_raw),
                    columns=transformer.get_feature_names_out(),
                    index=X_tgt_raw.index,
                )
            except Exception as e_tr:
                logging.error(f"Transformer failure for {tag}: {e_tr}")
                continue  # skip to next pair

            # 2. Evaluate – same best-effort policy as the transform step
            try:
                metrics, cm = evaluate_model_on_test_set(model, X_tgt, y_tgt)
            except Exception as e_eval:
                logging.error(f"Evaluation failure for {tag}: {e_eval}")
                continue  # skip to next pair

            # 3. Store metrics
            cross_metrics_rows.append({
                "ModelBucket": src_bucket,
                "Model":       model_name,
                "TestBucket":  tgt_bucket,
                **metrics,
            })

            # 4. Confusion matrix – save PNG & DF
            if cm.size:
                # Sanitise the file name the same way the aggregated-CV CM
                # export does, so bucket ids containing spaces or path
                # separators do not produce nested/odd paths.
                png_name = (
                    f"{src_bucket}_{model_name}_ON_{tgt_bucket}.png"
                    .replace(" ", "_").replace("/", "_").replace("\\", "_")
                )
                save_cm_png(cm, TARGET_CLASS_NAMES, tag, cm_dir / png_name)

                cm_df = cm_to_dataframe(cm, TARGET_CLASS_NAMES)
                cm_df["ModelBucket"] = src_bucket
                cm_df["Model"]       = model_name
                cm_df["TestBucket"]  = tgt_bucket
                cross_cm_tables.append(cm_df)

    print(f"Completed cross-evaluation for {len(cross_metrics_rows)} combinations.")
    logging.info("Cross-bucket evaluation finished.")

    return cross_metrics_rows, cross_cm_tables

# -----------------------------------------------------------------
# CONFUSION-MATRIX HELPERS
# -----------------------------------------------------------------
def save_cm_png(cm: np.ndarray, labels: List[str], title: str, out_path: Path) -> None:
    """
    Render a confusion matrix as an annotated heatmap and save it as a PNG.

    Parameters
    ----------
    cm : matrix of counts (cells are formatted with ``fmt="d"``).
    labels : tick labels used for both axes.
    title : figure title.
    out_path : destination file; missing parent directories are created.
    """
    # Create the target directory first so a path problem surfaces before
    # any plotting work is done.
    out_path.parent.mkdir(exist_ok=True, parents=True)

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        sns.heatmap(cm, annot=True, fmt="d",
                    xticklabels=labels, yticklabels=labels,
                    cmap="Blues", cbar=False, ax=ax)
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual")
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        # Always release the figure, even if plotting or saving failed;
        # this function is called once per model/test-set pair.
        plt.close(fig)

def cm_to_dataframe(cm: np.ndarray, labels: List[str]) -> pd.DataFrame:
    """
    Wrap a confusion matrix in a DataFrame whose rows are named
    ``Actual_<label>`` and whose columns are named ``Pred_<label>``.
    """
    actual_index = []
    predicted_columns = []
    for label in labels:
        actual_index.append(f"Actual_{label}")
        predicted_columns.append(f"Pred_{label}")
    return pd.DataFrame(data=cm, index=actual_index, columns=predicted_columns)

# -----------------------------------------------------------------
# CLASS-DISTRIBUTION → tidy DataFrame   (for nice XLSX export)
# -----------------------------------------------------------------

def distribution_to_rows(
    counter: Counter,
    bucket: str,
    subset: str
) -> List[Dict[str, Any]]:
    """
    Turn a class counter into one flat row per class in CLASS_LABELS_NUMERIC.

    Each row carries the bucket and subset identifiers, the class name looked
    up in TARGET_CLASS_NAMES, the count (0 when the class is absent from
    ``counter``) and its percentage of the counter total (0.0 when the total
    is zero).
    """
    grand_total = sum(counter.values())

    def _row_for(cls: Any) -> Dict[str, Any]:
        class_count = counter.get(cls, 0)
        if grand_total:
            share = 100.0 * class_count / grand_total
        else:
            share = 0.0
        return {
            "Bucket"     : bucket,
            "Subset"     : subset,
            "Class"      : TARGET_CLASS_NAMES[cls],
            "Count"      : class_count,
            "Percent"    : share,
        }

    return [_row_for(cls) for cls in CLASS_LABELS_NUMERIC]

def export_class_distributions_xlsx(writer: pd.ExcelWriter, global_dist: Optional[Counter], per_bucket_dist: Dict[str, Counter]) -> None:
    """
    Write class distributions to the given Excel writer.

    * 'Global Class Distribution' – one row per class of ``global_dist``
      (written only when ``global_dist`` is non-empty).
    * 'Per Bucket Class Distribution' – one row per (bucket, class) pair of
      ``per_bucket_dist`` (written only when there is at least one row).

    Class names are looked up in TARGET_CLASS_NAMES; classes are listed in
    sorted order of their numeric key.
    """
    if global_dist:
        global_records = [
            {'Class': TARGET_CLASS_NAMES[cls], 'Count': count}
            for cls, count in sorted(global_dist.items())
        ]
        pd.DataFrame(global_records, columns=['Class', 'Count']).to_excel(
            writer, sheet_name='Global Class Distribution', index=False
        )

    bucket_records = [
        {
            'Bucket ID': bucket_id,
            'Class Name': TARGET_CLASS_NAMES[cls_numeric],
            'Class Numeric': cls_numeric,
            'Count': count,
        }
        for bucket_id, dist_counter in per_bucket_dist.items()
        for cls_numeric, count in sorted(dist_counter.items())
    ]
    if bucket_records:
        pd.DataFrame(bucket_records).to_excel(
            writer, sheet_name='Per Bucket Class Distribution', index=False
        )

    logging.info("Exported class distributions to XLSX.")


def export_metrics_xlsx(writer: pd.ExcelWriter, all_fold_results: List[Dict[str, Any]]) -> None:
    """
    Write averaged CV metrics to the 'Model CV Metrics' sheet in long format.

    Every entry of ``all_fold_results`` must provide 'bucket_id',
    'model_name' and 'avg_metrics'. For each entry the macro metrics are
    emitted first, followed by the per-class metrics when 'per_class_mean'
    is present; metrics missing from ``avg_metrics`` are written as NaN.
    Nothing is written when no rows were produced.
    """
    macro_metric_names = (
        ('accuracy_mean', 'Accuracy'),
        ('macro_precision_mean', 'Macro Precision'),
        ('macro_recall_mean', 'Macro Recall'),
        ('macro_f1_mean', 'Macro F1-Score'),
        ('auc_macro_ovr_mean', 'Macro AUC (OvR)'),
    )
    per_class_metric_names = (
        ('precision_mean', 'Precision'),
        ('recall_mean', 'Recall'),
        ('f1_score_mean', 'F1-Score'),
        ('auc_ovr_mean', 'AUC (OvR)'),
    )

    rows: List[Dict[str, Any]] = []
    for result in all_fold_results:
        bucket = result['bucket_id']
        model = result['model_name']
        averaged = result['avg_metrics']

        # Macro-level block
        rows.extend(
            {
                'Bucket ID': bucket, 'Model': model, 'Metric Type': 'Macro',
                'Class': 'N/A', 'Metric Name': display_name,
                'Value': averaged.get(key, np.nan),
            }
            for key, display_name in macro_metric_names
        )

        # Per-class block (optional)
        if 'per_class_mean' not in averaged:
            continue
        for cls_numeric, cls_means in averaged['per_class_mean'].items():
            class_label = f"{TARGET_CLASS_NAMES[cls_numeric]} ({cls_numeric})"
            rows.extend(
                {
                    'Bucket ID': bucket, 'Model': model, 'Metric Type': 'Per-Class',
                    'Class': class_label, 'Metric Name': display_name,
                    'Value': cls_means.get(key, np.nan),
                }
                for key, display_name in per_class_metric_names
            )

    if rows:
        pd.DataFrame(rows).to_excel(writer, sheet_name='Model CV Metrics', index=False)
    logging.info("Exported model CV metrics to XLSX.")

def export_hyperparameters_xlsx(writer: pd.ExcelWriter, all_fold_results: List[Dict[str, Any]]) -> None:
    """
    Write the 'final_best_params' of each result to the 'Best Hyperparameters'
    sheet, one row per (bucket, model, parameter); values are stored as
    strings. Results without (or with empty) 'final_best_params' are skipped,
    and nothing is written when no rows were produced.
    """
    rows: List[Dict[str, Any]] = []
    for result in all_fold_results:
        best_params = result.get('final_best_params')
        if not best_params:
            continue
        rows.extend(
            {
                'Bucket ID': result['bucket_id'], 'Model': result['model_name'],
                'Parameter': name, 'Value': str(value),
            }
            for name, value in best_params.items()
        )

    if rows:
        pd.DataFrame(rows).to_excel(writer, sheet_name='Best Hyperparameters', index=False)
    logging.info("Exported best hyperparameters to XLSX.")

def export_feature_importances_xlsx(writer: pd.ExcelWriter, all_fold_results: List[Dict[str, Any]], feature_names_map: Dict[str, List[str]]) -> None:
    """
    Write feature importances of the final models to the
    'Feature Importances' sheet.

    Parameters
    ----------
    writer : open Excel writer the sheet is added to.
    all_fold_results : list of result dicts providing 'bucket_id',
        'model_name' and (optionally) 'final_trained_model'.
    feature_names_map : bucket id → feature names matching the model's
        importance vector.

    Behaviour
    ---------
    * Models exposing ``feature_importances_`` fill the 'Importance' column.
    * Models exposing ``coef_`` fill the 'Abs Coefficient' column (for a 2-D
      coefficient matrix with more than one row, the mean absolute
      coefficient across rows).
    * Results without feature names, without either attribute, or with a
      name/value length mismatch are logged and skipped.
    * Rows are sorted by bucket, model and then descending value. The sort
      uses whichever of the two value columns is filled for a row, so
      coefficient rows are ordered correctly even when the sheet also
      contains 'Importance' rows.
    """
    importances_list = []
    for res in all_fold_results:
        model = res.get('final_trained_model')  # may be None -> no attributes
        bucket_id = res['bucket_id']
        model_name = res['model_name']

        bucket_feature_names = feature_names_map.get(bucket_id)
        if not bucket_feature_names:
            logging.warning(f"No feature names found for bucket '{bucket_id}'. Skipping FI export for model '{model_name}'.")
            continue

        fi_values = None
        fi_type = "Importance"
        if hasattr(model, 'feature_importances_'):
            fi_values = model.feature_importances_
        elif hasattr(model, 'coef_'):
            if model.coef_.ndim == 2 and model.coef_.shape[0] > 1:
                # One coefficient row per class: average |coef| across rows.
                fi_values = np.mean(np.abs(model.coef_), axis=0)
            else:
                fi_values = np.abs(model.coef_.flatten())
            fi_type = "Abs Coefficient"

        if fi_values is None:
            logging.info(f"Model {model_name} in bucket {bucket_id} does not have standard feature_importances_ or coef_ attribute.")
            continue

        if len(bucket_feature_names) != len(fi_values):
            logging.warning(f"Feature name/importance length mismatch for {model_name} in bucket {bucket_id}. Names: {len(bucket_feature_names)}, FI: {len(fi_values)}. Skipping.")
            continue

        for feat_name, importance_val in zip(bucket_feature_names, fi_values):
            importances_list.append({
                'Bucket ID': bucket_id, 'Model': model_name,
                'Feature': feat_name, fi_type: importance_val
            })

    if importances_list:
        df_importances = pd.DataFrame(importances_list)

        # Each row fills exactly one of the value columns. Build a combined
        # key so rows are ranked by their own value; sorting on "Importance"
        # alone would leave all coefficient rows comparing as NaN.
        value_columns = [
            col for col in ("Importance", "Abs Coefficient")
            if col in df_importances.columns
        ]
        sort_key = df_importances[value_columns[0]]
        for extra_col in value_columns[1:]:
            sort_key = sort_key.fillna(df_importances[extra_col])

        df_importances = (
            df_importances
            .assign(_sort_value=sort_key)
            .sort_values(by=['Bucket ID', 'Model', '_sort_value'],
                         ascending=[True, True, False])
            .drop(columns='_sort_value')
        )
        df_importances.to_excel(writer, sheet_name='Feature Importances', index=False)
    logging.info("Exported feature importances to XLSX.")


def export_confusion_matrices_xlsx_and_png(
    writer: pd.ExcelWriter, 
    all_fold_results: List[Dict[str, Any]], 
    output_dir_path: Path
) -> None:
    """
    Export the aggregated CV confusion matrix of every result both as a table
    on the 'Aggregated CV CMs' sheet and as a PNG heatmap.

    Parameters
    ----------
    writer : open Excel writer; all tables are stacked on one sheet, each
        preceded by a two-row header and followed by blank spacing rows.
    all_fold_results : list of result dicts providing 'bucket_id',
        'model_name' and (optionally) 'aggregated_cv_cm'.
    output_dir_path : experiment directory; PNGs go to
        ``<output_dir_path>/<CM_SUBDIR_NAME>/``.

    A result is skipped with a warning when its matrix is missing, empty or
    not square over CLASS_LABELS_NUMERIC. A plotting failure is logged and
    does not stop the export; the figure is closed in every case.
    """
    cm_png_dir = output_dir_path / CM_SUBDIR_NAME
    cm_png_dir.mkdir(parents=True, exist_ok=True)
    
    current_row_excel = 0
    sheet_name_excel = 'Aggregated CV CMs'

    # Labels are the same for every matrix.
    cm_display_labels_str = TARGET_CLASS_NAMES
    expected_shape = (len(CLASS_LABELS_NUMERIC), len(CLASS_LABELS_NUMERIC))

    for res in all_fold_results:
        cm_array = res.get('aggregated_cv_cm')
        bucket_id = res['bucket_id']
        model_name = res['model_name']

        if cm_array is None or cm_array.size == 0 or cm_array.shape != expected_shape:
            logging.warning(f"Skipping CM export for {model_name} in bucket {bucket_id} due to missing, empty, or malformed CM array.")
            continue

        # --- Export to XLSX Table ---
        df_cm = pd.DataFrame(cm_array, index=cm_display_labels_str, columns=cm_display_labels_str)
        
        header_df = pd.DataFrame([
            [f"Aggregated CV Confusion Matrix: Bucket {bucket_id}, Model {model_name}"],
            ["Actual \\ Predicted"] + cm_display_labels_str
        ])
        header_df.to_excel(writer, sheet_name=sheet_name_excel, startrow=current_row_excel, header=False, index=False)
        current_row_excel += len(header_df)
        # index=True keeps the "actual" labels as the first column
        df_cm.to_excel(writer, sheet_name=sheet_name_excel, startrow=current_row_excel, header=True, index=True)
        current_row_excel += len(df_cm) + 3  # table rows + header row + spacing

        # --- Export to PNG Heatmap ---
        fig = None
        try:
            fig, ax = plt.subplots(figsize=(7, 5))
            disp = ConfusionMatrixDisplay(confusion_matrix=cm_array, display_labels=cm_display_labels_str)
            disp.plot(ax=ax, cmap=plt.cm.Blues, values_format='d')
            ax.set_title(f"Agg. CV CM: {model_name}\nBucket: {bucket_id}", fontsize=10)
            fig.tight_layout()
            png_filename = f"cm_agg_cv_{bucket_id}_{model_name}.png".replace(" ", "_").replace("/", "_").replace("\\", "_")
            fig.savefig(cm_png_dir / png_filename)
        except Exception as e_plot:
            logging.error(f"Error plotting/saving CM PNG for {model_name}, bucket {bucket_id}: {e_plot}")
        finally:
            # Close even when plotting/saving failed; otherwise every failure
            # leaves an open figure behind for the rest of the run.
            if fig is not None:
                plt.close(fig)
            
    if current_row_excel > 0:
        logging.info(f"Exported aggregated CV confusion matrices (tables) to XLSX sheet '{sheet_name_excel}'.")
    logging.info(f"Exported confusion matrix PNGs to {cm_png_dir}.")

# ==========================
# Main Pipeline Orchestration
#  Run the full experiment, including logging & exports
# ==========================
def run_pipeline() -> None:
    """
    Executes the full experiment:
      • Merges / preprocesses data
      • Buckets by judge (or generic bucket)
      • Balances data per user config
      • Performs CV with RandomizedSearchCV hyper-param tuning
      • Evaluates on held-out TEST set
      • Cross-evaluates every model on every bucket’s TEST set
      • Persists fitted model+transformer bundles to disk
      • Exports distributions, metrics, best params, feature
        importances, confusion matrices, and cross-metrics
        to XLSX / PNG / .joblib

    Reads the module-level CONFIG (and stores the chosen 'test_size' in it),
    prompts interactively for the balancing setup and the test-set share,
    and terminates via ``sys.exit`` when the cases data is missing, nothing
    is left after preprocessing, or no buckets could be created.
    Buckets containing a single target class are skipped.
    """
    # ─────────────────────────────────────────────────────────────
    global CONFIG
    import copy  # used to snapshot the per-bucket fitted transformer

    out_dir = Path(CONFIG['output_dir_actual'])
    cm_dir  = out_dir / CM_SUBDIR_NAME
    cm_dir.mkdir(exist_ok=True, parents=True)

    # ── directory for serialised models
    model_dir = out_dir / "model_bundles"
    model_dir.mkdir(exist_ok=True, parents=True)

    # ─── containers for cross-evaluation ─────────────────────────
    model_bundles: Dict[Tuple[str, str], Dict[str, Any]] = {}
    test_sets_raw: Dict[str, Tuple[pd.DataFrame, pd.Series]] = {}

    # ───────────────── DATA LOADING ──────────────────────────────
    df_cases  = load_data_from_path(CONFIG['cases_file_path'], "Cases Data")
    df_judges = load_data_from_path(CONFIG.get('judges_file_path'), "Judges Data")
    if df_cases is None:
        sys.exit("Cases data missing – aborting.")

    df_merged    = full_data_merge(df_cases, df_judges, CONFIG)
    df_processed = preprocess_data(df_merged, CONFIG)
    if df_processed.empty:
        sys.exit("No data after preprocessing – aborting.")

    print_section_header("DATA OVERVIEW")
    print(f"Final dataframe shape : {df_processed.shape}")

    # ─────────────── BALANCING CONFIG (needs distribution) ───────
    global_counter = Counter(df_processed[CONFIG['target_col']])
    get_balancing_config_interactive(global_counter)   # Updates CONFIG in‐place

    # ─────────────── TEST-SET SHARE PROMPT ───────────────────────
    cv_folds = CONFIG['cv_folds']
    test_pct = float(
        safe_input("\nTest-set percentage", default="20", type_caster=float)
    )
    CONFIG['test_size'] = test_pct / 100.0
    print(f"Validation share per fold ≈ {100/cv_folds:.1f}%")

    # ─────────────── FEATURE TRANSFORMER & BUCKETS ───────────────
    transformer = get_feature_transformer(CONFIG)
    buckets     = create_judge_buckets(df_processed, CONFIG)
    if not buckets:
        sys.exit("No buckets created – aborting.")

    # ── containers for XLSX export ───────────────────────────────
    dist_rows, agg_rows, test_rows, cm_tables = [], [], [], []
    hp_rows, feat_rows = [], []

    # ─────────────────── MAIN LOOP OVER BUCKETS ───────────────────
    for bucket_id, bucket_df in buckets.items():
        header = f"BUCKET {bucket_id}  ({len(bucket_df)} samples)"
        print_section_header(header)
        logging.info(header)

        if bucket_df[CONFIG['target_col']].nunique() < 2:
            print("Skipped: only one class present.")
            continue

        X_all = bucket_df[CONFIG['feature_cols']]
        y_all = bucket_df[CONFIG['target_col']].astype(int)

        # Split
        X_tr_raw, X_te_raw, y_tr_raw, y_te = train_test_split(
            X_all, y_all,
            test_size=CONFIG['test_size'],
            stratify=y_all,
            random_state=RANDOM_STATE
        )

        # ── cache the raw TEST set once per bucket (for cross eval)
        if bucket_id not in test_sets_raw:
            test_sets_raw[bucket_id] = (X_te_raw.copy(), y_te.copy())

        # Fit the transformer on this bucket's TRAIN split only, then apply
        # it to both splits. From here on X_tr / X_te are *transformed*.
        transformer.fit(X_tr_raw)
        X_tr = pd.DataFrame(transformer.transform(X_tr_raw),
                            columns=transformer.get_feature_names_out(),
                            index=X_tr_raw.index)
        X_te = pd.DataFrame(transformer.transform(X_te_raw),
                            columns=transformer.get_feature_names_out(),
                            index=X_te_raw.index)

        # store raw TRAIN / TEST distributions
        dist_rows += distribution_to_rows(Counter(y_tr_raw), bucket_id, "TRAIN_BEFORE")
        dist_rows += distribution_to_rows(Counter(y_te),      bucket_id, "TEST")

        skf = StratifiedKFold(
            n_splits=cv_folds,
            shuffle=True,
            random_state=CONFIG['cv_random_state']
        )

        # ─────────────────── MODELS ───────────────────────────────
        for model_name in CONFIG['models_to_run']:
            print(f"\nModel: {model_name}")
            fold_metrics = []

            # ── CV loop ───────────────────────────────────────────
            for f, (idx_tr, idx_val) in enumerate(skf.split(X_tr, y_tr_raw), 1):
                subset_tag = f"FOLD{f}_TRAIN"
                X_f_tr, y_f_tr = X_tr.iloc[idx_tr], y_tr_raw.iloc[idx_tr]
                X_f_val, y_f_val = X_tr.iloc[idx_val], y_tr_raw.iloc[idx_val]

                # balance TRAIN fold
                X_bal, y_bal, sw, cnt_bef, cnt_aft = balance_data_for_fold(
                    X_f_tr, y_f_tr, CONFIG, tag=f"{bucket_id}_f{f}"
                )
                # log BEFORE/AFTER rows
                dist_rows += distribution_to_rows(cnt_bef, bucket_id, f"{subset_tag}_BEFORE")
                dist_rows += distribution_to_rows(cnt_aft, bucket_id, f"{subset_tag}_AFTER")

                tuned_model, _ = tune_hyperparameters(
                    X_bal, y_bal, model_name,
                    CONFIG['hyperparameter_grids'][model_name],
                    cv_folds_inner=max(2, cv_folds//2),
                    n_iter=CONFIG['hyperparameter_tuning_iterations'],
                    sample_weight_train=sw
                )
                met, _ = evaluate_model_on_test_set(tuned_model, X_f_val, y_f_val)
                fold_metrics.append(met)
                print(f"  Fold {f}/{cv_folds}  Macro-F1 = {met.get('macro_f1', np.nan):.3f}")

            # aggregate CV – use .get() because the evaluator returns an
            # empty dict for an empty fold; such folds count as NaN.
            keys = ["accuracy", "macro_precision", "macro_recall", "macro_f1"]
            per_key_values = {k: [m.get(k, np.nan) for m in fold_metrics] for k in keys}
            agg = {f"{k}_mean": float(np.nanmean(per_key_values[k])) for k in keys}
            agg.update({f"{k}_std": float(np.nanstd(per_key_values[k])) for k in keys})
            agg_rows.append(dict(Bucket=bucket_id, Model=model_name, **agg))
            print(f"  CV Macro-F1 = {agg['macro_f1_mean']:.3f} ± {agg['macro_f1_std']:.3f}")

            # ── final model on full TRAIN (balanced) ──────────────
            X_bal_fin, y_bal_fin, sw_fin, cnt_bef_fin, cnt_aft_fin = balance_data_for_fold(
                X_tr, y_tr_raw, CONFIG, tag=f"{bucket_id}_final"
            )
            dist_rows += distribution_to_rows(cnt_bef_fin, bucket_id, "TRAIN_FULL_BEFORE")
            dist_rows += distribution_to_rows(cnt_aft_fin, bucket_id, "TRAIN_FULL_AFTER")

            final_model, best_params = tune_hyperparameters(
                X_bal_fin, y_bal_fin, model_name,
                CONFIG['hyperparameter_grids'][model_name],
                cv_folds_inner=cv_folds,
                n_iter=CONFIG['hyperparameter_tuning_iterations'],
                sample_weight_train=sw_fin
            )

            # store best hyper-parameters
            hp_row = {"Bucket": bucket_id, "Model": model_name}
            hp_row.update(best_params)
            hp_rows.append(hp_row)

            # test evaluation
            test_met, test_cm = evaluate_model_on_test_set(final_model, X_te, y_te)
            test_rows.append(dict(Bucket=bucket_id, Model=model_name, **test_met))
            print(f"  TEST Macro-F1 = {test_met.get('macro_f1', np.nan):.3f}")

            # confusion matrix PNG + DF
            if test_cm.size:
                cm_path = cm_dir / f"{bucket_id}_{model_name}_TEST.png"
                save_cm_png(test_cm, TARGET_CLASS_NAMES,
                            f"{bucket_id}-{model_name}-TEST", cm_path)
                cm_df = cm_to_dataframe(test_cm, TARGET_CLASS_NAMES)
                cm_df['Bucket'], cm_df['Model'], cm_df['Fold'] = bucket_id, model_name, "TEST"
                cm_tables.append(cm_df)

            # feature importances
            feat_imp = get_feature_importances(final_model, list(X_bal_fin.columns))
            for rank, (fname, score) in enumerate(feat_imp.sort_values(ascending=False).head(TOP_K_FEATURES).items(), 1):
                feat_rows.append({
                    "Bucket": bucket_id,
                    "Model": model_name,
                    "Rank":  rank,
                    "Feature": fname,
                    "Importance": float(score)
                })

            # ── cache model + transformer for cross-eval.
            # The shared transformer object is re-fitted for the next bucket,
            # so the in-memory bundle needs its own snapshot.
            bundle_meta = {
                "bucket_id": bucket_id,
                "model_name": model_name,
                "n_samples_bucket": len(bucket_df),
            }
            bundle = {
                "model": final_model,
                "transformer": copy.deepcopy(transformer),
                "meta": bundle_meta,
            }
            model_bundles[(bucket_id, model_name)] = bundle

            # ── persist bundle to disk
            bundle_path = model_dir / f"{bucket_id}_{model_name}.joblib"
            export_model_bundle(final_model, transformer, bundle_meta, bundle_path)

    # ───────────── CROSS-BUCKET EVALUATION ────────────────────────
    cross_rows, cross_cm_tables = cross_compare_models(
        model_bundles=model_bundles,
        test_sets=test_sets_raw,
        out_dir=out_dir,
    )

    # ───────────── EXPORT XLSX ────────────────────────────────────
    # sheet_name is passed by keyword: positional use is deprecated in
    # pandas 2.x and no longer accepted once the argument is keyword-only.
    xlsx_path = out_dir / RESULTS_XLSX_FILENAME
    with pd.ExcelWriter(xlsx_path) as writer:
        pd.DataFrame(dist_rows).to_excel(writer, sheet_name="Distributions", index=False)
        pd.DataFrame(agg_rows).to_excel(writer, sheet_name="CV_Aggregated", index=False)
        pd.DataFrame(test_rows).to_excel(writer, sheet_name="Test_Metrics", index=False)
        if cross_rows:
            pd.DataFrame(cross_rows).to_excel(writer, sheet_name="Cross_Test_Metrics", index=False)
        if hp_rows:
            pd.DataFrame(hp_rows).to_excel(writer, sheet_name="Best_Hyperparameters", index=False)
        if feat_rows:
            pd.DataFrame(feat_rows).to_excel(writer, sheet_name="Top_Features", index=False)
        # The confusion-matrix frames carry the actual-class labels in their
        # index; move them into an "Actual" column so the rows stay labelled
        # in the sheet instead of being dropped by index=False.
        if cm_tables:
            (pd.concat(cm_tables)
               .rename_axis("Actual")
               .reset_index()
               .to_excel(writer, sheet_name="Confusion_Matrices", index=False))
        if cross_cm_tables:
            (pd.concat(cross_cm_tables)
               .rename_axis("Actual")
               .reset_index()
               .to_excel(writer, sheet_name="Cross_Confusion_Matrices", index=False))

    print_section_header("Pipeline Completed")
    print(f"Results saved to {xlsx_path}")
    logging.info(f"Run finished. Results saved to {xlsx_path}")

# ==========================
# Entry Point
# ==========================
# Place this function definition before main_interactive_pipeline()

def load_or_trigger_interactive_config(output_dir_path: Path) -> None:
    """
    Populate the global CONFIG from a JSON file or via interactive setup.

    The user is asked whether to load a configuration file. If so, a path is
    prompted for (defaulting to ``CONFIG_FILENAME`` inside ``output_dir_path``
    when that file exists) and parsed as JSON. On success the loaded mapping
    replaces the global CONFIG, with ``'output_dir_actual'`` and
    ``'output_base_dir'`` overridden to point at this run's directory and its
    parent.

    If the user declines, or the file is missing, unreadable, not valid JSON,
    or does not contain a JSON object at the top level, the function falls
    back to ``_interactive_setup_configuration``.

    Args:
        output_dir_path: Output directory for the current run.
    """
    global CONFIG

    # Try to find a default config file in the experiment directory
    default_config_file_in_experiment_dir = output_dir_path / CONFIG_FILENAME

    load_from_file = get_yes_no_input(
        f"Load configuration from a file? (Default config if exists: {default_config_file_in_experiment_dir})",
        default_yes=default_config_file_in_experiment_dir.exists() # Default to yes if config found
    )

    if load_from_file:
        if default_config_file_in_experiment_dir.exists():
            chosen_path_str = safe_input("Enter path to configuration JSON file", default=str(default_config_file_in_experiment_dir))
        else:
            chosen_path_str = safe_input("Enter path to configuration JSON file (no default found in experiment dir)")

        chosen_path = Path(_clean_path_input(chosen_path_str))
        try:
            # Parse into a local first: the global CONFIG must not be replaced
            # until the payload is known to be usable, otherwise a failed load
            # would leave CONFIG in a broken state for the interactive fallback.
            with open(chosen_path, 'r', encoding='utf-8') as f:
                loaded_config = json.load(f)
        except FileNotFoundError:
            logging.error(f"Configuration file not found: {chosen_path}. Proceeding with interactive setup.")
        except json.JSONDecodeError as e:
            logging.error(f"Error decoding JSON from configuration file {chosen_path}: {e}. Proceeding with interactive setup.")
        except Exception as e:
            logging.error(f"An unexpected error occurred loading config file '{chosen_path}': {e}. Proceeding with interactive setup.")
        else:
            if isinstance(loaded_config, dict):
                # The loaded config might carry output paths from an older run.
                # For this run, output_dir_path (the timestamped experiment dir)
                # is fixed, so override both entries before publishing.
                loaded_config['output_dir_actual'] = str(output_dir_path)
                loaded_config['output_base_dir'] = str(output_dir_path.parent) # The parent of the timestamped dir
                CONFIG = loaded_config
                logging.info(f"Configuration successfully loaded from {chosen_path}")
                return # Config loaded successfully
            # Valid JSON, but not an object (e.g. a list or null): unusable as CONFIG.
            logging.error(
                f"Configuration file {chosen_path} must contain a JSON object at the top level, "
                f"got {type(loaded_config).__name__}. Proceeding with interactive setup."
            )
        # If loading failed, fall through to interactive setup

    # If not loading from file, or if loading failed, proceed with full interactive setup
    logging.info("Proceeding with interactive configuration setup.")
    _interactive_setup_configuration(output_dir_path)

# Interactive setup helper; called by load_or_trigger_interactive_config().

def _interactive_setup_configuration(output_dir_path: Path) -> None:
    """
    Build the global CONFIG step by step through interactive prompts.

    Steps, in order: collect file paths, load the raw cases (and optionally
    judges) data, optionally merge the two on user-chosen linking columns,
    save the resulting table into ``output_dir_path``, then gather the
    operational columns, feature types, target encoding, bucketing,
    balancing, and CV/model/hyperparameter settings.

    Exits via ``sys.exit`` when the cases data cannot be loaded or when the
    working table is empty after the merge step.

    Args:
        output_dir_path: Output directory for the current run; recorded in
            CONFIG['output_dir_actual'] and used for the merged-data export.
    """
    global CONFIG

    print_section_header("Pipeline Configuration (Interactive Mode)")

    # --- Step 1: file paths ---
    # get_file_paths_config may also return an output base dir, but the
    # directory for this run is already fixed, so it is re-asserted here.
    CONFIG.update(get_file_paths_config())
    CONFIG['output_dir_actual'] = str(output_dir_path)

    # --- Step 2: raw data ---
    logging.info("Starting raw data load for interactive configuration...")
    cases_df = load_data_from_path(CONFIG.get('cases_file_path'), "Raw Cases Data")
    if cases_df is None:
        logging.critical("Raw cases data could not be loaded. Cannot proceed with interactive setup.")
        sys.exit("Raw cases data is essential for configuration.")

    judges_path = CONFIG.get('judges_file_path')
    judges_df = None
    if judges_path:  # skip when the path is None or an empty string
        judges_df = load_data_from_path(judges_path, "Raw Judges Data")
        if judges_df is None:
            logging.warning("Judges file path was provided, but data could not be loaded. Merge will be skipped.")

    # --- Step 3: optional interactive merge ---
    # The working table starts as a copy of the cases data and is replaced
    # only when a merge actually succeeds.
    working_df = cases_df.copy()

    if judges_df is None:
        logging.info("No judges data provided. Skipping merge configuration. Using raw cases data.")
        CONFIG.pop('merge_config', None)  # drop any stale merge settings
    else:
        print_section_header("Configure Data Merge")

        display_df_columns(cases_df, "RAW CASES Data - Select Linking Column")
        cases_key = _select_columns_interactive(
            list(cases_df.columns),
            "Enter column name from RAW CASES data to link with Judges data",
            allow_multiple=False, is_required=True
        )

        display_df_columns(judges_df, "RAW JUDGES Data - Select Linking Column")
        judges_key = _select_columns_interactive(
            list(judges_df.columns),
            "Enter corresponding column name from RAW JUDGES data",
            allow_multiple=False, is_required=True
        )

        join_how = safe_input(
            "Enter merge type (e.g., 'left', 'inner')",
            default='left',
            choices=['left', 'right', 'inner', 'outer']
        ).lower()

        if not (cases_key and judges_key):
            logging.warning("Linking columns for merge not specified by user. Proceeding with unmerged CASES data.")
            CONFIG.pop('merge_config', None)
        else:
            left_key = str(cases_key)
            right_key = str(judges_key)
            # Recorded so the merge can be reproduced from the saved config.
            CONFIG['merge_config'] = {
                'cases_link_col': left_key,
                'judges_link_col': right_key,
                'how': join_how
            }
            try:
                # Work on copies and cast both keys to str before joining.
                left_df = cases_df.copy()
                right_df = judges_df.copy()
                left_df[left_key] = left_df[left_key].astype(str)
                right_df[right_key] = right_df[right_key].astype(str)

                working_df = pd.merge(
                    left_df, right_df,
                    left_on=left_key,
                    right_on=right_key,
                    how=join_how,
                    suffixes=('_case', '_judge')  # disambiguate overlapping column names
                )
                logging.info(f"Interactive merge successful. Merged data shape: {working_df.shape}")
            except Exception as e:
                logging.error(f"Error during interactive merge: {e}. Proceeding with unmerged CASES data.")
                working_df = cases_df.copy()  # fall back to the unmerged cases
                CONFIG.pop('merge_config', None)  # the recorded merge settings did not work

    # --- Step 4: persist the working table ---
    merged_data_file_path = output_dir_path / "01_merged_data_interactive_setup.xlsx"
    try:
        working_df.to_excel(merged_data_file_path, index=False)
        logging.info(f"Interactively merged data saved to: {merged_data_file_path}")
        CONFIG['merged_data_path'] = str(merged_data_file_path)
    except Exception as e:
        # Best effort only: the table is still available in memory.
        logging.error(f"Failed to save interactively merged data: {e}")

    if working_df.empty:
        logging.critical("Data is empty after merge attempt. Cannot proceed with configuration.")
        sys.exit("Data became empty after merge. Check merge keys and data.")

    # --- Step 5: operational columns ---
    CONFIG.update(get_operational_columns_config(working_df, CONFIG))

    # --- Step 6: feature types ---
    if CONFIG.get('feature_cols'):
        CONFIG['numerical_features'], CONFIG['categorical_features'] = validate_feature_types_interactive(
            working_df,
            CONFIG['feature_cols']
        )
    else:
        logging.warning("No feature columns selected. Skipping feature type validation.")

    # --- Step 7: target encoding ---
    CONFIG.update(get_encoding_config_interactive(
        working_df,
        CONFIG['target_col']
    ))

    # --- Steps 8-10: bucketing, balancing, CV / models / hyperparameters ---
    CONFIG.update(get_bucketing_config_interactive())
    CONFIG.update(get_balancing_config_interactive())
    CONFIG.update(get_cv_model_hyperparam_config())

    # CONFIG is now fully populated; this function does not save it.
    logging.info("Interactive configuration process complete.")

# Main script entry point

def main_interactive_pipeline() -> None:
    """
    Drive one full pipeline run from start to finish.

    Creates the output directory for this run (falling back to a local
    timestamped directory if that fails), sets up logging inside it, obtains
    the configuration (from file or interactively), saves that configuration
    into the run directory, and finally calls ``run_pipeline``.

    ``SystemExit`` and any other exception raised during the configuration or
    pipeline steps are logged and printed instead of being propagated.
    """
    global CONFIG  # filled in further by load_or_trigger_interactive_config

    # 0. Output directory for this specific run, then logging.
    base_dir = DEFAULT_OUTPUT_BASE_DIR

    try:
        run_dir = Path(create_output_directory(base_dir))
    except Exception as e:
        # Directory creation failed (e.g. permissions): use a local fallback.
        print(f"CRITICAL: Failed to create output directory in '{base_dir}': {e}")
        print("Please check permissions or specify a different base directory.")
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        run_dir = Path(f"pipeline_outputs_fallback_{stamp}")
        run_dir.mkdir(parents=True, exist_ok=True)
        print(f"Using fallback output directory: {run_dir.resolve()}")
        base_dir = str(run_dir.parent)  # base dir now reflects the fallback location
        CONFIG['output_dir_actual'] = str(run_dir)
        CONFIG['output_base_dir'] = base_dir

    # From here on, log output goes to a file inside the run directory.
    log_path = run_dir / LOG_FILENAME
    setup_logging(log_path)

    # Record the output locations so they end up in the saved config.
    CONFIG['output_dir_actual'] = str(run_dir)
    CONFIG['output_base_dir'] = base_dir

    try:
        # 1. Populate the global CONFIG (file load or interactive setup).
        load_or_trigger_interactive_config(run_dir)

        # 2. Keep a copy of the exact configuration used next to the results.
        saved_config_path = run_dir / CONFIG_FILENAME
        save_config(CONFIG, saved_config_path)
        logging.info(f"Final configuration for this run saved to: {saved_config_path}")

        # 3. Execute the pipeline; it reads the global CONFIG.
        run_pipeline()

    except SystemExit as exit_signal:
        logging.error(f"Pipeline terminated: {exit_signal}")
        print(f"Pipeline terminated: {exit_signal}")
    except Exception as err:
        logging.critical("An unhandled critical error occurred in the main pipeline driver:")
        logging.critical(traceback.format_exc())  # full traceback goes to the log
        print(f"CRITICAL PIPELINE ERROR: {err}\nSee log file for details: {log_path}")
    finally:
        logging.info("Performing final cleanup if any...")
        gc.collect()
        logging.info("Exiting legal case outcome prediction pipeline application.")

# Start the interactive pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main_interactive_pipeline()



Load configuration from a file? (Default config if exists: pipeline_outputs\experiment_20250630_212223\run_config.json) (y/N):  n



                   Pipeline Configuration (Interactive Mode)                    


                            File Paths Configuration                            



Enter base directory for outputs (default: pipeline_outputs):  
Enter path to cases XLSX file:  "C:\Users\guill\OneDrive\Documentos\MDPI_STATS\EXPERIMENT3\OUTCOME18937.xlsx"
Enter path to judges XLSX file (optional, press Enter to skip) (default: ):  "C:\Users\guill\OneDrive\Documentos\MDPI_STATS\EXPERIMENT2\JUDGES_pseudo.xlsx"



                              Configure Data Merge                              


                     RAW CASES Data - Select Linking Column                     

+-------------------------+------------------------------------------------------------------+
|   Index (for this list) | Column Name                                                      |
|                       0 | source_filename                                                  |
+-------------------------+------------------------------------------------------------------+
|                       1 | majority                                                         |
+-------------------------+------------------------------------------------------------------+
|                       2 | foster care                                                      |
+-------------------------+------------------------------------------------------------------+
|                       3 | name                                          

Enter column name from RAW CASES data to link with Judges data:  0



                    RAW JUDGES Data - Select Linking Column                     

+-------------------------+----------------------------+
|   Index (for this list) | Column Name                |
|                       0 | source_file                |
+-------------------------+----------------------------+
|                       1 | capp_city                  |
+-------------------------+----------------------------+
|                       2 | capp_date                  |
+-------------------------+----------------------------+
|                       3 | CaseID                     |
+-------------------------+----------------------------+
|                       4 | jaf_city                   |
+-------------------------+----------------------------+
|                       5 | sex1                       |
+-------------------------+----------------------------+
|                       6 | judge1                     |
+-------------------------+----------------------------+
|    

Enter corresponding column name from RAW JUDGES data:  0
Enter merge type (e.g., 'left', 'inner') (default: left) (options: left, right, inner, outer):  left



                  Select Operational Columns from Merged Data                   


                   Full List of Columns in Final Merged Data                    

+-------------------------+------------------------------------------------------------------+
|   Index (for this list) | Column Name                                                      |
|                       0 | source_filename                                                  |
+-------------------------+------------------------------------------------------------------+
|                       1 | majority                                                         |
+-------------------------+------------------------------------------------------------------+
|                       2 | foster care                                                      |
+-------------------------+------------------------------------------------------------------+
|                       3 | name                                          

Will you use judge-specific bucketing? (Y/n):  y



--- Select Judge ID for Bucketing ---


Enter the column NAME or INDEX (from the full list above) for Judge ID (for bucketing):  52



--- Select Target Column ---


Enter the TARGET column NAME or INDEX (from the full list above, e.g., 'custody_outcome'):  7



                             Select Feature Columns                             


   POTENTIAL FEATURES (select from this list using its 0-based index or name)   

+-------------------------+------------------------------------------------------------------+
|   Index (for this list) | Column Name                                                      |
|                       0 | source_filename                                                  |
+-------------------------+------------------------------------------------------------------+
|                       1 | majority                                                         |
+-------------------------+------------------------------------------------------------------+
|                       2 | foster care                                                      |
+-------------------------+------------------------------------------------------------------+
|                       3 | name                                          

Use all 58 listed POTENTIAL features? (Y/n):  n
Enter FEATURE column NAMES or INDICES (comma-separated, from the POTENTIAL FEATURES list above):  12,13,14,16,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,34,36,37,38,39,40,41,42,43,44,56,57,50



                        Confirm Final Selected Features                         

You have selected the following features for the model:
+----------------+------------------------------------------------------------------+
|   # (For Info) | Selected Feature Name                                            |
|              0 | child_child_expressed_conflict                                   |
+----------------+------------------------------------------------------------------+
|              1 | child_child_expressed_living_arrangement_preference              |
+----------------+------------------------------------------------------------------+
|              2 | child_during_appeal_father_request_regarding_living_arrangements |
+----------------+------------------------------------------------------------------+
|              3 | child_during_appeal_mother_request_regarding_living_arrangements |
+----------------+------------------------------------------------------------------+
|

Proceed with these features? (Y/n):  y



                            Feature Type Validation                             


--- Validating Feature: 'child_child_expressed_conflict' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['father' 'none' 'mother']
  Number of Unique Values (non-NaN): 3
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'child_child_expressed_conflict' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'child_child_expressed_living_arrangement_preference' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['mother' 'unknown' 'father' 'shared']
  Number of Unique Values (non-NaN): 4
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'child_child_expressed_living_arrangement_preference' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'child_during_appeal_father_request_regarding_living_arrangements' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['none' 'sole' 'shared' 'unknown']
  Number of Unique Values (non-NaN): 4
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'child_during_appeal_father_request_regarding_living_arrangements' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'child_during_appeal_mother_request_regarding_living_arrangements' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['sole' 'shared' 'none' 'unknown']
  Number of Unique Values (non-NaN): 4
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'child_during_appeal_mother_request_regarding_living_arrangements' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'father_parental_fitness' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['unfit' 'fit']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'father_parental_fitness' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'mother_parental_fitness' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['fit' 'unfit']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'mother_parental_fitness' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'father_has_history_of_abuse_against_child' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['no' 'yes']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'father_has_history_of_abuse_against_child' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'mother_has_history_of_abuse_against_child' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['no' 'yes']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'mother_has_history_of_abuse_against_child' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'father_has_history_of_abuse_against_mother' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['no' 'yes']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'father_has_history_of_abuse_against_mother' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'mother_has_history_of_abuse_against_father' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['no' 'yes']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'mother_has_history_of_abuse_against_father' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'father_has_history_of_neglect' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['yes' 'no']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'father_has_history_of_neglect' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'mother_has_history_of_neglect' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['no' 'yes']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'mother_has_history_of_neglect' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'father_has_psych_issues' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['no' 'yes']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'father_has_psych_issues' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'mother_has_psych_issues' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['no' 'yes']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'mother_has_psych_issues' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'father_has_addiction_issues' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['no' 'yes']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'father_has_addiction_issues' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'mother_has_addiction_issues' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['no' 'yes']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'mother_has_addiction_issues' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'father_is_invested_with_child' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['no' 'yes']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'father_is_invested_with_child' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'mother_is_invested_with_child' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['yes' 'no']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'mother_is_invested_with_child' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'father_employment_status' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['unemployed' 'employed']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'father_employment_status' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'mother_employment_status' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['unemployed' 'employed']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'mother_employment_status' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'father_work_availability' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['unknown' 'regular' 'irregular']
  Number of Unique Values (non-NaN): 3
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'father_work_availability' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'mother_work_availability' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['unknown' 'regular' 'irregular']
  Number of Unique Values (non-NaN): 3
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'mother_work_availability' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'father_housing_status' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['unknown' 'adequate' 'inadequate']
  Number of Unique Values (non-NaN): 3
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'father_housing_status' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'mother_housing_status' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['unknown' 'adequate' 'inadequate']
  Number of Unique Values (non-NaN): 3
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'mother_housing_status' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'Parent Proximity' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['no' 'yes']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'Parent Proximity' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'father_lives_near_school' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['no' 'yes' 'unknown']
  Number of Unique Values (non-NaN): 3
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'father_lives_near_school' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'mother_lives_near_school' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['yes' 'no' 'unknown']
  Number of Unique Values (non-NaN): 3
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'mother_lives_near_school' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'father_receives_social_aid' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['no' 'yes']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'father_receives_social_aid' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'mother_receives_social_aid' ---
  Original Data Type: object
  Unique Values (sample, non-NaN): ['no' 'yes']
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Detected as non-numeric dtype.


  Confirm type for 'mother_receives_social_aid' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'mother_benefited_legal_aid' ---
  Original Data Type: int64
  Unique Values (sample, non-NaN): [1 0]
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Numeric dtype, but <=7 unique values suggests it might be categorical (e.g., codes).


  Confirm type for 'mother_benefited_legal_aid' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'father_benefited_legal_aid' ---
  Original Data Type: int64
  Unique Values (sample, non-NaN): [0 1]
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Numeric dtype, but <=7 unique values suggests it might be categorical (e.g., codes).


  Confirm type for 'father_benefited_legal_aid' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

--- Validating Feature: 'sex1' ---
  Original Data Type: float64
  Unique Values (sample, non-NaN): [1. 0.]
  Number of Unique Values (non-NaN): 2
  Suggestion: Treat as CATEGORICAL. Reason: Numeric dtype, but <=7 unique values suggests it might be categorical (e.g., codes).


  Confirm type for 'sex1' (NUMERIC or CATEGORICAL)? (default: categorical) (options: numeric, categorical):  categorical


    Selected as CATEGORICAL.

                          Data Encoding Configuration                           


Target Variable: 'child_immediately_after_appeal_child_living_arrangement'
  Unique values found in 'child_immediately_after_appeal_child_living_arrangement': ['mother' 'father' 'shared']
  Expected target classes: mother, father, shared (mapped to [0, 1, 2])

Please map your raw target labels to the standard numeric classes:


  Enter raw label in your data that corresponds to 'mother' (Class 0):  mother
  Enter raw label in your data that corresponds to 'father' (Class 1):  father
  Enter raw label in your data that corresponds to 'shared' (Class 2):  shared



                         Judge Bucketing Configuration                          



Enable judge-specific bucketing? (Y/n):  y
Min samples per judge for dedicated bucket (others to 'generic') (default: 30):  300



                         Class Balancing Configuration                          



Balancing method (default: sampling) (options: none, sampling, weighting):  sampling
Target class percentages for (mother,father,shared) (default: 33,33,34):  40,30,30
Models to run (comma separated, available ['RandomForest', 'LogisticRegression', 'SVC', 'XGB']) (default: RandomForest,LogisticRegression,SVC,XGB):  RandomForest,SVC,XGB
Number of CV folds (default: 5):  5
RandomizedSearch iterations (default: 50):  18



                        Hyper-parameter Grid (optional)                         



Load hyper-parameter search space from a JSON file? (y/N):  y
Path to hyper-param JSON file:  "C:\Users\guill\OneDrive\Documentos\MDPI_STATS\CODE\CASE OUTCOME PREDICTION\HyperGrid_COMBO3.json"


✅  Custom hyper-parameter grid loaded.

                                 DATA OVERVIEW                                  

Final dataframe shape : (18937, 60)

                         Class Balancing Configuration                          

Current global distribution:

GLOBAL
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother | 12741 | 67.28 % |
| father | 4216  | 22.26 % |
| shared | 1980  | 10.46 % |
+--------+-------+---------+


Balancing method (default: sampling) (options: none, sampling, weighting):  sampling
Target class percentages for (mother,father,shared) (default: 33,33,34):  40,30,30

Test-set percentage (default: 20):  20


Validation share per fold ≈ 20.0%

                         BUCKET anatole  (824 samples)                          


Model: RandomForest

anatole_f1 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  355  | 67.36 % |
| father |  122  | 23.15 % |
| shared |  50   | 9.49 %  |
+--------+-------+---------+

anatole_f1 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  211  | 40.04 % |
| father |  158  | 29.98 % |
| shared |  158  | 29.98 % |
+--------+-------+---------+
  Fold 1/5  Macro-F1 = 0.824

anatole_f2 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  355  | 67.36 % |
| father |  122  | 23.15 % |
| shared |  50   | 9.49 %  |
+--------+-------+---------+

anatole_f2 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  211  | 40.04 % |
| fath



  TEST Macro-F1 = 0.830

Model: XGB

anatole_f1 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  355  | 67.36 % |
| father |  122  | 23.15 % |
| shared |  50   | 9.49 %  |
+--------+-------+---------+

anatole_f1 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  211  | 40.04 % |
| father |  158  | 29.98 % |
| shared |  158  | 29.98 % |
+--------+-------+---------+
  Fold 1/5  Macro-F1 = 0.758

anatole_f2 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  355  | 67.36 % |
| father |  122  | 23.15 % |
| shared |  50   | 9.49 %  |
+--------+-------+---------+

anatole_f2 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  211  | 40.04 % |
| father |  158  | 29.98 % |
| shared |  158  | 29.98 % |
+--------+-------+---------+
  Fold 2/5  Macro-F1 



  TEST Macro-F1 = 0.865

Model: XGB

babiche_f1 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  324  | 65.85 % |
| father |  115  | 23.37 % |
| shared |  53   | 10.77 % |
+--------+-------+---------+

babiche_f1 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  197  | 39.96 % |
| father |  148  | 30.02 % |
| shared |  148  | 30.02 % |
+--------+-------+---------+
  Fold 1/5  Macro-F1 = 0.741

babiche_f2 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  324  | 65.85 % |
| father |  115  | 23.37 % |
| shared |  53   | 10.77 % |
+--------+-------+---------+

babiche_f2 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  197  | 39.96 % |
| father |  148  | 30.02 % |
| shared |  148  | 30.02 % |
+--------+-------+---------+
  Fold 2/5  Macro-F1 



  TEST Macro-F1 = 0.808

Model: XGB

cabeche_f1 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  295  | 63.17 % |
| father |  103  | 22.06 % |
| shared |  69   | 14.78 % |
+--------+-------+---------+

cabeche_f1 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  187  | 40.04 % |
| father |  140  | 29.98 % |
| shared |  140  | 29.98 % |
+--------+-------+---------+
  Fold 1/5  Macro-F1 = 0.821

cabeche_f2 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  295  | 63.17 % |
| father |  103  | 22.06 % |
| shared |  69   | 14.78 % |
+--------+-------+---------+

cabeche_f2 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  187  | 40.04 % |
| father |  140  | 29.98 % |
| shared |  140  | 29.98 % |
+--------+-------+---------+
  Fold 2/5  Macro-F1 



  TEST Macro-F1 = 0.789

Model: XGB

dacrons_f1 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  264  | 65.02 % |
| father |  97   | 23.89 % |
| shared |  45   | 11.08 % |
+--------+-------+---------+

dacrons_f1 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  162  | 39.90 % |
| father |  122  | 30.05 % |
| shared |  122  | 30.05 % |
+--------+-------+---------+
  Fold 1/5  Macro-F1 = 0.840

dacrons_f2 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  264  | 65.02 % |
| father |  97   | 23.89 % |
| shared |  45   | 11.08 % |
+--------+-------+---------+

dacrons_f2 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  162  | 39.90 % |
| father |  122  | 30.05 % |
| shared |  122  | 30.05 % |
+--------+-------+---------+
  Fold 2/5  Macro-F1 



  TEST Macro-F1 = 0.855

Model: XGB

echevin_f1 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  233  | 61.80 % |
| father |  100  | 26.53 % |
| shared |  44   | 11.67 % |
+--------+-------+---------+

echevin_f1 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  151  | 40.05 % |
| father |  113  | 29.97 % |
| shared |  113  | 29.97 % |
+--------+-------+---------+
  Fold 1/5  Macro-F1 = 0.855

echevin_f2 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  233  | 61.80 % |
| father |  100  | 26.53 % |
| shared |  44   | 11.67 % |
+--------+-------+---------+

echevin_f2 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  151  | 40.05 % |
| father |  113  | 29.97 % |
| shared |  113  | 29.97 % |
+--------+-------+---------+
  Fold 2/5  Macro-F1 



  TEST Macro-F1 = 0.880

Model: XGB

gargote_f1 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  275  | 75.34 % |
| father |  65   | 17.81 % |
| shared |  25   | 6.85 %  |
+--------+-------+---------+

gargote_f1 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  146  | 39.89 % |
| father |  110  | 30.05 % |
| shared |  110  | 30.05 % |
+--------+-------+---------+
  Fold 1/5  Macro-F1 = 0.848

gargote_f2 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  275  | 75.34 % |
| father |  65   | 17.81 % |
| shared |  25   | 6.85 %  |
+--------+-------+---------+

gargote_f2 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  146  | 39.89 % |
| father |  110  | 30.05 % |
| shared |  110  | 30.05 % |
+--------+-------+---------+
  Fold 2/5  Macro-F1 



  TEST Macro-F1 = 0.807

Model: XGB

faubers_f1 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  223  | 63.53 % |
| father |  97   | 27.64 % |
| shared |  31   | 8.83 %  |
+--------+-------+---------+

faubers_f1 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  140  | 40.00 % |
| father |  105  | 30.00 % |
| shared |  105  | 30.00 % |
+--------+-------+---------+
  Fold 1/5  Macro-F1 = 0.792

faubers_f2 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  223  | 63.53 % |
| father |  97   | 27.64 % |
| shared |  31   | 8.83 %  |
+--------+-------+---------+

faubers_f2 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  140  | 40.00 % |
| father |  105  | 30.00 % |
| shared |  105  | 30.00 % |
+--------+-------+---------+
  Fold 2/5  Macro-F1 



  TEST Macro-F1 = 0.790

Model: XGB

hauynes_f1 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  203  | 64.04 % |
| father |  80   | 25.24 % |
| shared |  34   | 10.73 % |
+--------+-------+---------+

hauynes_f1 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  127  | 40.06 % |
| father |  95   | 29.97 % |
| shared |  95   | 29.97 % |
+--------+-------+---------+
  Fold 1/5  Macro-F1 = 0.833

hauynes_f2 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  203  | 64.04 % |
| father |  81   | 25.55 % |
| shared |  33   | 10.41 % |
+--------+-------+---------+

hauynes_f2 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  127  | 40.06 % |
| father |  95   | 29.97 % |
| shared |  95   | 29.97 % |
+--------+-------+---------+
  Fold 2/5  Macro-F1 



  TEST Macro-F1 = 0.842

Model: XGB

jobelin_f1 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  165  | 63.46 % |
| father |  62   | 23.85 % |
| shared |  33   | 12.69 % |
+--------+-------+---------+

jobelin_f1 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  104  | 40.00 % |
| father |  78   | 30.00 % |
| shared |  78   | 30.00 % |
+--------+-------+---------+
  Fold 1/5  Macro-F1 = 0.754

jobelin_f2 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  165  | 63.22 % |
| father |  63   | 24.14 % |
| shared |  33   | 12.64 % |
+--------+-------+---------+

jobelin_f2 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  104  | 40.00 % |
| father |  78   | 30.00 % |
| shared |  78   | 30.00 % |
+--------+-------+---------+
  Fold 2/5  Macro-F1 



  TEST Macro-F1 = 0.795

Model: XGB

inconel_f1 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  164  | 63.08 % |
| father |  65   | 25.00 % |
| shared |  31   | 11.92 % |
+--------+-------+---------+

inconel_f1 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  104  | 40.00 % |
| father |  78   | 30.00 % |
| shared |  78   | 30.00 % |
+--------+-------+---------+
  Fold 1/5  Macro-F1 = 0.790

inconel_f2 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  164  | 62.84 % |
| father |  66   | 25.29 % |
| shared |  31   | 11.88 % |
+--------+-------+---------+

inconel_f2 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  104  | 40.00 % |
| father |  78   | 30.00 % |
| shared |  78   | 30.00 % |
+--------+-------+---------+
  Fold 2/5  Macro-F1 



  TEST Macro-F1 = 0.821

Model: XGB

kochias_f1 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  133  | 57.83 % |
| father |  59   | 25.65 % |
| shared |  38   | 16.52 % |
+--------+-------+---------+

kochias_f1 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  92   | 40.00 % |
| father |  69   | 30.00 % |
| shared |  69   | 30.00 % |
+--------+-------+---------+
  Fold 1/5  Macro-F1 = 0.803

kochias_f2 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  133  | 57.83 % |
| father |  59   | 25.65 % |
| shared |  38   | 16.52 % |
+--------+-------+---------+

kochias_f2 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  92   | 40.00 % |
| father |  69   | 30.00 % |
| shared |  69   | 30.00 % |
+--------+-------+---------+
  Fold 2/5  Macro-F1 



  TEST Macro-F1 = 0.881

Model: XGB

labourg_f1 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  146  | 74.49 % |
| father |  26   | 13.27 % |
| shared |  24   | 12.24 % |
+--------+-------+---------+

labourg_f1 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  78   | 39.80 % |
| father |  59   | 30.10 % |
| shared |  59   | 30.10 % |
+--------+-------+---------+
  Fold 1/5  Macro-F1 = 0.849

labourg_f2 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  146  | 74.11 % |
| father |  27   | 13.71 % |
| shared |  24   | 12.18 % |
+--------+-------+---------+

labourg_f2 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother |  79   | 40.10 % |
| father |  59   | 29.95 % |
| shared |  59   | 29.95 % |
+--------+-------+---------+
  Fold 2/5  Macro-F1 



  TEST Macro-F1 = 0.814

Model: XGB

generic_f1 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother | 5198  | 68.39 % |
| father | 1643  | 21.62 % |
| shared |  760  | 10.00 % |
+--------+-------+---------+

generic_f1 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother | 3040  | 40.00 % |
| father | 2280  | 30.00 % |
| shared | 2280  | 30.00 % |
+--------+-------+---------+
  Fold 1/5  Macro-F1 = 0.821

generic_f2 – BEFORE
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother | 5198  | 68.39 % |
| father | 1643  | 21.62 % |
| shared |  760  | 10.00 % |
+--------+-------+---------+

generic_f2 – AFTER (sampled)
+--------+-------+---------+
| Class  | Count |    %    |
+--------+-------+---------+
| mother | 3040  | 40.00 % |
| father | 2280  | 30.00 % |
| shared | 2280  | 30.00 % |
+--------+-------+---------+
  Fold 2/5  Macro-F1 