In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import pandas as pd
from catboost import CatBoostClassifier
from sklearn.metrics import roc_auc_score
import warnings

warnings.simplefilter("ignore", category=UserWarning)

###############################################################################
# 1. GLOBAL PARAMETERS & SETTINGS
###############################################################################
FP_COST = 10     # Penalty for false positive
FN_COST = 50     # Penalty for false negative (never treat but was sick)
D_COST  = 1      # Penalty per time-step of delay if the patient is sick and untreated
GAMMA   = 0.99   # Discount factor
T_MAX   = 20     # Time horizon (0..T_MAX-1)
FEATURE_COLS = ["time","EIT","NIRS","EIS"]  # Adjust for your dataset
RANDOM_SEED = 42

###############################################################################
# 2. SPLIT DATA & HELPER FUNCTIONS
###############################################################################
def split_patients_kfold(df, n_splits=4, seed=0):
    """Shuffle unique patient IDs, then split into n_splits+1 groups."""
    rng = np.random.RandomState(seed)
    unique_pts = df['patient_id'].unique()
    rng.shuffle(unique_pts)
    
    n = len(unique_pts)
    splits = {}
    for i in range(n_splits + 1):
        start_idx = int(i * n / (n_splits + 1))
        end_idx   = int((i + 1) * n / (n_splits + 1))
        group_name = f"G{i+1}"
        pid_subset = unique_pts[start_idx:end_idx]
        splits[group_name] = set(pid_subset)
    return splits

def filter_by_group(df, pid_set):
    """Return rows of df whose patient_id is in pid_set."""
    return df[df['patient_id'].isin(pid_set)].copy()

###############################################################################
# 3. TRAIN & PREDICT CATBOOST
###############################################################################
def train_and_predict_model(depth_val, lr_val, df_train, df_val, feature_cols=FEATURE_COLS):
    """Train CatBoost and return predicted risk scores on df_val."""
    X_train = df_train[feature_cols]
    y_train = df_train["label"]
    params = {
        "iterations": 50,
        "depth": int(round(depth_val)),
        "learning_rate": lr_val,
        "verbose": False,
        "random_seed": 42
    }
    model = CatBoostClassifier(**params)
    model.fit(X_train, y_train)
    risk_scores = model.predict_proba(df_val[feature_cols])[:,1]
    return risk_scores

###############################################################################
# 4. COST EVALUATION BY SIMULATION
###############################################################################
def simulate_policy(df, policy_func):
    """
    df must have columns: [patient_id, time, label, predicted_risk, ...].
    policy_func(patient_rows) -> integer time step to treat, or None if never treat.
    Returns dict with total cost, precision, recall, etc.
    """
    results = []
    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values("time")
        label = grp["label"].iloc[0]  # 0 or 1 at patient-level
        treat_time = policy_func(grp)
        
        if treat_time is None:
            # never treated
            if label == 1:
                cost = FN_COST
                tp = 0
                fp = 0
            else:
                cost = 0
                tp = 0
                fp = 0
            treated_flag = 0
            tt = None
        else:
            # treated at treat_time
            treated_flag = 1
            if label == 1:
                # cost is D_COST * treat_time
                cost = D_COST * treat_time
                tp   = 1
                fp   = 0
            else:
                cost = FP_COST
                tp   = 0
                fp   = 1
            tt = treat_time
        
        results.append({
            "patient_id": pid,
            "label": label,
            "treated": treated_flag,
            "treat_time": tt,
            "cost": cost,
            "tp": tp,
            "fp": fp
        })
    df_res = pd.DataFrame(results)
    total_cost = df_res["cost"].sum()
    treated_df = df_res[df_res["treated"] == 1]
    tp_sum = treated_df["tp"].sum()
    fp_sum = treated_df["fp"].sum()
    
    if len(treated_df) > 0:
        precision = tp_sum / (tp_sum + fp_sum)
    else:
        precision = 0.0
    
    sick_df = df_res[df_res["label"] == 1]
    total_sick = len(sick_df)
    if total_sick > 0:
        recall = tp_sum / total_sick
    else:
        recall = 0.0
    
    if len(treated_df) > 0:
        valid_tt = treated_df["treat_time"].dropna()
        avg_tt = valid_tt.mean() if len(valid_tt)>0 else 0.0
    else:
        avg_tt = 0.0
    
    return {
        "cost": total_cost,
        "precision": precision,
        "recall": recall,
        "avg_treatment_time": avg_tt
    }

###############################################################################
# 5. BENCHMARK POLICY FUNCTIONS (PARAMETRIC) + COST EVAL WRAPPERS
###############################################################################
## A) CONSTANT THRESHOLD
def policy_func_constant_threshold(df_patient, thr):
    """Return first time t where predicted_risk >= thr, else None."""
    for _, row in df_patient.iterrows():
        if row["predicted_risk"] >= thr:
            return int(row["time"])
    return None

def evaluate_cost_constant_threshold(params, df_train, df_val):
    """
    params = (depth, lr, thr)
    1) Train CatBoost on df_train
    2) Predict risk on df_val
    3) Evaluate constant-thr policy
    """
    depth, lr, thr = params
    risk_val = train_and_predict_model(depth, lr, df_train, df_val)
    df_val_eval = df_val.copy()
    df_val_eval["predicted_risk"] = risk_val
    
    def policy_func(grp):
        return policy_func_constant_threshold(grp, thr)
    
    stats = simulate_policy(df_val_eval, policy_func)
    return stats["cost"]

## B) DYNAMIC THRESHOLD
def policy_func_dynamic_threshold(df_patient, thr_vec):
    """
    thr_vec is length T_MAX. For each row, if risk >= thr_vec[t], treat at time t.
    """
    for _, row in df_patient.iterrows():
        t = int(row["time"])
        if t < len(thr_vec):
            if row["predicted_risk"] >= thr_vec[t]:
                return t
    return None

def evaluate_cost_dynamic_threshold(params, df_train, df_val):
    """
    params = (depth, lr, thr_0, ..., thr_{T-1}) => total length T_MAX+2
    """
    depth = params[0]
    lr    = params[1]
    thr_vec = params[2:]  # T_MAX thresholds
    
    risk_val = train_and_predict_model(depth, lr, df_train, df_val)
    df_val_eval = df_val.copy()
    df_val_eval["predicted_risk"] = risk_val
    
    def policy_func(grp):
        return policy_func_dynamic_threshold(grp, thr_vec)
    
    stats = simulate_policy(df_val_eval, policy_func)
    return stats["cost"]

## C) LINEAR THRESHOLD
def policy_func_linear_threshold(df_patient, A, B):
    """
    threshold = clip(A*t + B, 0, 1)
    Return first time with predicted_risk >= threshold.
    """
    for _, row in df_patient.iterrows():
        t = row["time"]
        thr = A * t + B
        thr = np.clip(thr, 0, 1)
        if row["predicted_risk"] >= thr:
            return int(t)
    return None

def evaluate_cost_linear_threshold(params, df_train, df_val):
    """
    params = (depth, lr, A, B)
    """
    depth, lr, A, B = params
    risk_val = train_and_predict_model(depth, lr, df_train, df_val)
    df_val_eval = df_val.copy()
    df_val_eval["predicted_risk"] = risk_val
    
    def policy_func(grp):
        return policy_func_linear_threshold(grp, A, B)
    
    stats = simulate_policy(df_val_eval, policy_func)
    return stats["cost"]

## D) WAIT-TILL-END
def policy_func_wait_till_end(df_patient, thr):
    """
    Check only the final time step's predicted_risk; if >= thr => treat at final time.
    """
    final_row = df_patient.iloc[-1]  # after sorting by time
    if final_row["predicted_risk"] >= thr:
        return int(final_row["time"])
    return None

def evaluate_cost_wait_till_end(params, df_train, df_val):
    """
    params = (depth, lr, thr)
    """
    depth, lr, thr = params
    risk_val = train_and_predict_model(depth, lr, df_train, df_val)
    df_val_eval = df_val.copy()
    df_val_eval["predicted_risk"] = risk_val
    
    def policy_func(grp):
        return policy_func_wait_till_end(grp, thr)
    
    stats = simulate_policy(df_val_eval, policy_func)
    return stats["cost"]

###############################################################################
# 6. DATA-DRIVEN DP (SPSA-catboost)
###############################################################################
def assign_buckets(prob, n_buckets=5):
    edges = np.linspace(0, 1, n_buckets+1)
    for b in range(n_buckets):
        if prob >= edges[b] and prob < edges[b+1]:
            return b
    return n_buckets-1

def estimate_transition_and_sick_probs(df_train, T=20, n_buckets=5):
    """
    df_train: must have columns [patient_id, time, predicted_risk, label].
    Returns p_trans[t,b,b'] and p_sick[t,b].
    """
    transition_counts = np.zeros((T-1, n_buckets, n_buckets))
    bucket_counts = np.zeros((T, n_buckets))
    sick_counts   = np.zeros((T, n_buckets))
    
    df_sorted = df_train.sort_values(["patient_id","time"])
    for pid, grp in df_sorted.groupby("patient_id"):
        grp = grp.sort_values("time")
        rows = grp.to_dict("records")
        for i, row in enumerate(rows):
            t = int(row["time"])
            b = int(row["risk_bucket"])
            lbl= int(row["label"])
            if t < T:
                bucket_counts[t,b] += 1
                sick_counts[t,b]   += lbl
            if i < len(rows)-1:
                row_next = rows[i+1]
                t_next   = int(row_next["time"])
                b_next   = int(row_next["risk_bucket"])
                if (t_next == t+1) and (t < T-1):
                    transition_counts[t, b, b_next] += 1.0
    
    p_trans = np.zeros((T-1, n_buckets, n_buckets))
    for t_ in range(T-1):
        for b_ in range(n_buckets):
            denom = transition_counts[t_, b_, :].sum()
            if denom > 0:
                p_trans[t_, b_, :] = transition_counts[t_, b_, :] / denom
            else:
                # if no data, assume self-loop
                p_trans[t_, b_, b_] = 1.0
    
    p_sick = np.zeros((T, n_buckets))
    for t_ in range(T):
        for b_ in range(n_buckets):
            denom = bucket_counts[t_, b_]
            if denom > 0:
                p_sick[t_, b_] = sick_counts[t_, b_] / denom
            else:
                p_sick[t_, b_] = 0.0
    
    return p_trans, p_sick

def train_data_driven_dp(p_trans, p_sick, FP=FP_COST, FN=FN_COST, D=D_COST, gamma=GAMMA, T=T_MAX):
    """
    DP to handle cost = false positives, false negatives, and per-step delay D.
    V[t,b] = minimal future cost if not yet treated at time t, bucket b.
    """
    n_buckets = p_sick.shape[1]
    V = np.zeros((T+1, n_buckets))
    pi_ = np.zeros((T, n_buckets), dtype=int)
    
    # Terminal boundary
    for b in range(n_buckets):
        cost_treat_T   = (1.0 - p_sick[T-1, b])*FP
        cost_notreat_T = p_sick[T-1, b]*FN
        V[T,b] = min(cost_treat_T, cost_notreat_T)
    
    # backward recursion
    for t in reversed(range(T)):
        for b in range(n_buckets):
            cost_treat = (1 - p_sick[t,b])*FP
            cost_wait_immediate = p_sick[t,b]*D
            if t == T-1:
                future_wait = V[T,b]
            else:
                future_wait = 0
                for b_next in range(n_buckets):
                    future_wait += p_trans[t,b,b_next]*V[t+1,b_next]
            cost_wait = cost_wait_immediate + gamma * future_wait
            
            if cost_treat <= cost_wait:
                V[t,b] = cost_treat
                pi_[t,b] = 1
            else:
                V[t,b] = cost_wait
                pi_[t,b] = 0
    return V, pi_

def make_data_driven_dp_policy(V, pi_, p_sick, T=T_MAX):
    """
    If pi_[t,b] = 1 => treat immediately at time t.
    If we reach t=T with no action, do final treat-or-not by comparing 
    cost_treat_T vs. cost_notreat_T at the final bucket.
    """
    def policy_func(df_patient):
        df_patient = df_patient.sort_values("time")
        last_row = None
        for _, row in df_patient.iterrows():
            t = int(row["time"])
            if t < T:
                b = int(row["risk_bucket"])
                if pi_[t,b] == 1:
                    return t
            last_row = row
        
        # If no treat action up to T-1, handle final step:
        if last_row is not None:
            b_final = int(last_row["risk_bucket"])
            cost_treat_T   = (1.0 - p_sick[T-1, b_final])*FP_COST
            cost_notreat_T = p_sick[T-1, b_final]*FN_COST
            if cost_treat_T <= cost_notreat_T:
                return int(last_row["time"])
            else:
                return None
        return None
    return policy_func

def evaluate_cost_data_driven_dp(params, df_train, df_val):
    """
    params = (depth, lr)
    1) Train catboost on df_train
    2) Predict on df_train => risk buckets => estimate DP transitions
    3) Build DP, then apply policy to df_val
    """
    depth, lr = params
    
    # 1) train on df_train
    risk_train = train_and_predict_model(depth, lr, df_train, df_train)
    df_train_dp = df_train.copy()
    df_train_dp["predicted_risk"] = risk_train
    df_train_dp["risk_bucket"]    = df_train_dp["predicted_risk"].apply(assign_buckets)
    
    # get transitions, p_sick
    p_trans, p_sick = estimate_transition_and_sick_probs(df_train_dp, T=T_MAX, n_buckets=5)
    V, pi_ = train_data_driven_dp(p_trans, p_sick, FP=FP_COST, FN=FN_COST, D=D_COST, gamma=GAMMA, T=T_MAX)
    
    # 2) predict on df_val
    risk_val = train_and_predict_model(depth, lr, df_train, df_val)
    df_val_eval = df_val.copy()
    df_val_eval["predicted_risk"] = risk_val
    df_val_eval["risk_bucket"]    = df_val_eval["predicted_risk"].apply(assign_buckets)
    
    # 3) apply DP policy
    dp_policy = make_data_driven_dp_policy(V, pi_, p_sick, T=T_MAX)
    stats = simulate_policy(df_val_eval, dp_policy)
    return stats["cost"]

###############################################################################
# 7. GENERIC SPSA OPTIMIZATION
###############################################################################
def spsa_optimization(
    cost_func,         # cost_func(params, df_train, df_val) -> scalar cost
    df_train, df_val,
    param_dim,         # dimension of param vector
    param_bounds,      # list of (lower, upper) for each dimension
    n_iterations=20,
    alpha=0.602,
    gamma=0.101,
    a0=0.2,
    c0=0.1,
    seed=42
):
    """
    A general SPSA routine that operates on param_dim-dimensional parameter vectors.
    param_bounds[i] = (lb_i, ub_i).
    Returns best_params, best_cost.
    """
    rng = np.random.RandomState(seed)
    
    # Initialize param in the middle of each bound
    init_params = []
    for i in range(param_dim):
        lb, ub = param_bounds[i]
        init_params.append(0.5*(lb+ub))
    params = np.array(init_params)
    
    best_params = params.copy()
    best_cost   = float('inf')
    
    for k in range(1, n_iterations+1):
        a_k = a0 / (k**alpha)
        c_k = c0 / (k**gamma)
        
        delta = rng.choice([-1,1], size=param_dim)
        params_plus  = params + c_k * delta
        params_minus = params - c_k * delta
        
        # clip
        for i in range(param_dim):
            lb, ub = param_bounds[i]
            params_plus[i]  = np.clip(params_plus[i], lb, ub)
            params_minus[i] = np.clip(params_minus[i], lb, ub)
        
        # Evaluate cost
        cost_plus  = cost_func(params_plus,  df_train, df_val)
        cost_minus = cost_func(params_minus, df_train, df_val)
        
        # Approx gradient
        g_k = (cost_plus - cost_minus)/(2.0*c_k) * delta
        
        # Update
        params_new = params - a_k*g_k
        # clip
        for i in range(param_dim):
            lb, ub = param_bounds[i]
            params_new[i] = np.clip(params_new[i], lb, ub)
        
        cost_new = cost_func(params_new, df_train, df_val)
        if cost_new < best_cost:
            best_cost   = cost_new
            best_params = params_new.copy()
        params = params_new
    
    return best_params, best_cost

###############################################################################
# 8. CROSS-VALIDATION FOR EACH APPROACH, THEN TEST ON HOLDOUT (ALGORITHM 5)
###############################################################################
def cross_val_sum_of_cost(cost_func, candidate_params, group_dfs, n_splits=4):
    """
    Re-evaluates a given set of hyperparameters across *all* folds G1..Gn, 
    summing up the cost on each fold's validation set.
    """
    total_cost = 0.0
    for val_i in range(1, n_splits+1):
        df_val_i = group_dfs[f"G{val_i}"]
        # Training is all groups except G{val_i}
        train_sets_i = []
        for j in range(1, n_splits+1):
            if j != val_i:
                train_sets_i.append(group_dfs[f"G{j}"])
        df_train_i = pd.concat(train_sets_i, ignore_index=True)
        
        fold_cost = cost_func(candidate_params, df_train_i, df_val_i)
        total_cost += fold_cost
    
    return total_cost

def run_experiment_algorithm5_fair(df_all, n_splits=4, seed=42, n_spsa_iters=20):
    """
    We'll do the following for each of the 5 methods:
      1) Split into n folds + 1 holdout
      2) For each fold i=1..n, run SPSA to find best_params_i
      3) Then pick (lambda*, mu*) that yields the *lowest sum of cost* over all folds
      4) Retrain on union of folds (all training) and evaluate on holdout
    """
    # 1) Split data
    splits = split_patients_kfold(df_all, n_splits=n_splits, seed=seed)
    group_dfs = {}
    for g, pidset in splits.items():
        group_dfs[g] = filter_by_group(df_all, pidset)
    test_name = f"G{n_splits+1}"
    df_test = group_dfs[test_name]
    
    # We'll define the 5 approaches & their param space
    # 1) constant threshold -> param = (depth, lr, thr) => dim=3
    # 2) dynamic threshold -> param = (depth, lr, thr_0..thr_{T-1}) => dim=2+T_MAX
    # 3) linear threshold -> param = (depth, lr, A, B) => dim=4
    # 4) wait till end -> param = (depth, lr, thr) => dim=3
    # 5) data-driven dp -> param = (depth, lr) => dim=2
    
    method_defs = [
      {
        "name": "Constant Threshold",
        "cost_func": evaluate_cost_constant_threshold,
        "dim": 3,
        "bounds": [(2.0, 8.0), (0.01, 0.3), (0.0, 1.0)],
      },
      {
        "name": "Dynamic Threshold",
        "cost_func": evaluate_cost_dynamic_threshold,
        "dim": 2 + T_MAX,
        "bounds": ([(2.0, 8.0), (0.01, 0.3)] + [(0.0,1.0)]*T_MAX),
      },
      {
        "name": "Linear Threshold",
        "cost_func": evaluate_cost_linear_threshold,
        "dim": 4,
        "bounds": [(2.0,8.0), (0.01,0.3), (-0.1,0.1), (0.0,1.0)],
      },
      {
        "name": "Wait Till End",
        "cost_func": evaluate_cost_wait_till_end,
        "dim": 3,
        "bounds": [(2.0, 8.0), (0.01, 0.3), (0.0, 1.0)],
      },
      {
        "name": "Data-Driven DP (SPSA-catboost)",
        "cost_func": evaluate_cost_data_driven_dp,
        "dim": 2,
        "bounds": [(2.0,8.0), (0.01,0.3)],
      },
    ]
    
    final_records = []
    
    # For each method, do cross-validation with SPSA
    for mdef in method_defs:
        method_name = mdef["name"]
        cost_func   = mdef["cost_func"]
        param_dim   = mdef["dim"]
        param_bounds= mdef["bounds"]
        
        # Step 2: For each fold i in 1..n, run SPSA
        fold_records = []
        for i_val in range(1, n_splits+1):
            df_val = group_dfs[f"G{i_val}"]
            # train folds = all except i_val
            train_sets = []
            for j in range(1, n_splits+1):
                if j != i_val:
                    train_sets.append(group_dfs[f"G{j}"])
            df_train_fold = pd.concat(train_sets, ignore_index=True)
            
            # run SPSA on this fold
            best_params_fold, best_cost_fold = spsa_optimization(
                cost_func,
                df_train_fold, df_val,
                param_dim, param_bounds,
                n_iterations=n_spsa_iters,
                seed=42
            )
            fold_records.append({
                "fold": i_val,
                "best_params": best_params_fold,
                "val_cost": best_cost_fold
            })
        
        # We now have n sets of best_params (one from each fold).
        # Step 3: Re-evaluate each candidate on *all* folds, sum the cost, pick best.
        sum_costs = []
        for rec in fold_records:
            candidate_params = rec["best_params"]
            total_cv_cost = cross_val_sum_of_cost(cost_func, candidate_params, group_dfs, n_splits)
            sum_costs.append(total_cv_cost)
        
        # Find the param with minimal sum-of-CV-cost
        best_idx = np.argmin(sum_costs)
        chosen_params = fold_records[best_idx]["best_params"]
        chosen_sum_cost = sum_costs[best_idx]
        
        # Step 4: Retrain on union of G1..G_n, evaluate on G_{n+1}
        train_all = []
        for i in range(1, n_splits+1):
            train_all.append(group_dfs[f"G{i}"])
        df_train_all = pd.concat(train_all, ignore_index=True)
        
        # Evaluate final holdout cost
        final_test_cost = cost_func(chosen_params, df_train_all, df_test)
        
        # If you want precision/recall, let's do the actual simulation:
        stats = simulate_method_on_test(
            method_name, chosen_params,
            df_train_all, df_test
        )
        
        final_records.append({
            "Method": method_name,
            "Cost": stats["cost"],
            "Precision (%)": 100*stats["precision"],
            "Recall (%)": 100*stats["recall"],
            "Avg. Treat Time": stats["avg_treatment_time"],
            "Sum-of-CV-Cost": chosen_sum_cost  # just to see how good it was on CV
        })
    
    df_final = pd.DataFrame(final_records)
    return df_final

def simulate_method_on_test(method_name, params, df_train_all, df_test):
    """
    For each method, we do the same steps as in 'evaluate_cost_...'
    but then actually run 'simulate_policy' to get precision, recall, etc.
    """
    if method_name == "Constant Threshold":
        depth, lr, thr = params
        risk_test = train_and_predict_model(depth, lr, df_train_all, df_test)
        df_test_eval = df_test.copy()
        df_test_eval["predicted_risk"] = risk_test
        
        def policy_func(grp):
            return policy_func_constant_threshold(grp, thr)
        stats = simulate_policy(df_test_eval, policy_func)
        return stats
    
    elif method_name == "Dynamic Threshold":
        depth = params[0]
        lr    = params[1]
        thr_vec = params[2:]
        risk_test = train_and_predict_model(depth, lr, df_train_all, df_test)
        df_test_eval = df_test.copy()
        df_test_eval["predicted_risk"] = risk_test
        
        def policy_func(grp):
            return policy_func_dynamic_threshold(grp, thr_vec)
        stats = simulate_policy(df_test_eval, policy_func)
        return stats
    
    elif method_name == "Linear Threshold":
        depth, lr, A, B = params
        risk_test = train_and_predict_model(depth, lr, df_train_all, df_test)
        df_test_eval = df_test.copy()
        df_test_eval["predicted_risk"] = risk_test
        
        def policy_func(grp):
            return policy_func_linear_threshold(grp, A, B)
        stats = simulate_policy(df_test_eval, policy_func)
        return stats
    
    elif method_name == "Wait Till End":
        depth, lr, thr = params
        risk_test = train_and_predict_model(depth, lr, df_train_all, df_test)
        df_test_eval = df_test.copy()
        df_test_eval["predicted_risk"] = risk_test
        
        def policy_func(grp):
            return policy_func_wait_till_end(grp, thr)
        stats = simulate_policy(df_test_eval, policy_func)
        return stats
    
    elif method_name == "Data-Driven DP (SPSA-catboost)":
        depth, lr = params
        # train on df_train_all
        risk_train = train_and_predict_model(depth, lr, df_train_all, df_train_all)
        df_train_dp = df_train_all.copy()
        df_train_dp["predicted_risk"] = risk_train
        df_train_dp["risk_bucket"]    = df_train_dp["predicted_risk"].apply(assign_buckets)
        p_trans, p_sick = estimate_transition_and_sick_probs(df_train_dp, T=T_MAX, n_buckets=5)
        V, pi_ = train_data_driven_dp(p_trans, p_sick, FP=FP_COST, FN=FN_COST, D=D_COST, gamma=GAMMA, T=T_MAX)
        
        # predict on df_test
        risk_test = train_and_predict_model(depth, lr, df_train_all, df_test)
        df_test_eval = df_test.copy()
        df_test_eval["predicted_risk"] = risk_test
        df_test_eval["risk_bucket"]    = df_test_eval["predicted_risk"].apply(assign_buckets)
        
        dp_policy = make_data_driven_dp_policy(V, pi_, p_sick, T=T_MAX)
        stats = simulate_policy(df_test_eval, dp_policy)
        return stats
    
    else:
        raise ValueError(f"Unknown method: {method_name}")

###############################################################################
# 9. MAIN
###############################################################################
def main():
    df_all = pd.read_csv("dp_favoring_synthetic_patients.csv")  # your dataset
    
    # run 1 replication with a given random seed
    n_replications = 1
    all_tables = []
    
    for rep in range(n_replications):
        seed_val = rep
        print(f"\n=== RUNNING REPLICATION {rep+1}/{n_replications}, seed={seed_val} ===")
        
        df_final = run_experiment_algorithm5_fair(
            df_all, 
            n_splits=4,
            seed=seed_val,
            n_spsa_iters=20  # can be increased
        )
        print("\nResults on final holdout:")
        print(df_final)
        all_tables.append(df_final)
    
    # If multiple replications, aggregate them
    if n_replications > 1:
        df_agg = aggregate_results(all_tables)
        print("\n=== AGGREGATED RESULTS (Mean ± Std) ===")
        print(df_agg.to_string(index=False))

def aggregate_results(list_of_tables):
    """
    Aggregates multiple final_table DataFrames by computing mean ± std.
    """
    from collections import defaultdict
    data_accum = defaultdict(lambda: {
        "Cost": [],
        "Precision": [],
        "Recall": [],
        "Time": []
    })
    
    for df_table in list_of_tables:
        for idx in range(len(df_table)):
            row = df_table.iloc[idx]
            method = row["Method"]
            data_accum[method]["Cost"].append(row["Cost"])
            data_accum[method]["Precision"].append(row["Precision (%)"])
            data_accum[method]["Recall"].append(row["Recall (%)"])
            data_accum[method]["Time"].append(row["Avg. Treat Time"])
    
    results = []
    method_order = [
        "Constant Threshold",
        "Dynamic Threshold",
        "Linear Threshold",
        "Wait Till End",
        "Data-Driven DP (SPSA-catboost)"
    ]
    for m in method_order:
        arr_cost = np.array(data_accum[m]["Cost"])
        arr_prec = np.array(data_accum[m]["Precision"])
        arr_rec  = np.array(data_accum[m]["Recall"])
        arr_time = np.array(data_accum[m]["Time"])
        
        cost_mean, cost_std   = arr_cost.mean(), arr_cost.std()
        prec_mean, prec_std   = arr_prec.mean(), arr_prec.std()
        rec_mean,  rec_std    = arr_rec.mean(),  arr_rec.std()
        time_mean, time_std   = arr_time.mean(), arr_time.std()
        
        results.append({
            "Method": m,
            "Cost": f"{cost_mean:.2f} ± {cost_std:.2f}",
            "Precision (%)": f"{prec_mean:.2f} ± {prec_std:.2f}",
            "Recall (%)": f"{rec_mean:.2f} ± {rec_std:.2f}",
            "Avg. Treat Time": f"{time_mean:.2f} ± {time_std:.2f}"
        })
    
    return pd.DataFrame(results)

if __name__ == "__main__":
    main()


=== RUNNING REPLICATION 1/1, seed=0 ===

Results on final holdout:
                           Method  Cost  Precision (%)  Recall (%)  \
0              Constant Threshold   960      20.000000       100.0   
1               Dynamic Threshold  1008      20.000000       100.0   
2                Linear Threshold   960      20.000000       100.0   
3                   Wait Till End   456     100.000000       100.0   
4  Data-Driven DP (SPSA-catboost)   345      85.714286       100.0   

   Avg. Treat Time  Sum-of-CV-Cost  
0             0.00          3840.0  
1             2.00          4032.0  
2             0.00          3840.0  
3            19.00          1855.0  
4            12.75          1389.0  


In [13]:
"""
ALGORITHM 5 (SPSA) + BENCHMARK TABLE (Unconstrained Hemorrhage Diagnosis & Treatment)

Requirements:
  pip install numpy pandas scikit-learn catboost
"""

import numpy as np
import pandas as pd
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

# Sklearn models, metrics, etc.
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import ParameterGrid
# CatBoost
from catboost import CatBoostClassifier

###############################################################################
# 1. GLOBAL PARAMETERS
###############################################################################
FP_COST = 10
FN_COST = 50
D_COST  = 1
T_MAX   = 21   # maximum discrete time steps (0..T_MAX-1)

###############################################################################
# 2. HELPER FUNCTIONS (DATA SPLIT, METRICS, ETC.)
###############################################################################
def split_into_four_groups(df, seed=0):
    rng = np.random.RandomState(seed)
    unique_pids = df['patient_id'].unique()
    rng.shuffle(unique_pids)
    
    n = len(unique_pids)
    i1 = int(0.25 * n)
    i2 = int(0.50 * n)
    i3 = int(0.75 * n)
    
    G1_pids = unique_pids[: i1]
    G2_pids = unique_pids[i1 : i2]
    G3_pids = unique_pids[i2 : i3]
    G4_pids = unique_pids[i3 : ]
    
    G1 = df[df['patient_id'].isin(G1_pids)].copy()
    G2 = df[df['patient_id'].isin(G2_pids)].copy()
    G3 = df[df['patient_id'].isin(G3_pids)].copy()
    G4 = df[df['patient_id'].isin(G4_pids)].copy()
    return G1, G2, G3, G4


def k_plus_1_splits(df, k=3, seed=0):
    """
    For Algorithm 5 (SPSA), we often split the data into k+1 groups:
      G1, G2, ..., Gk, G_{k+1}.
    This is a k-fold style partition plus one final holdout.
    """
    rng = np.random.RandomState(seed)
    unique_pids = df['patient_id'].unique()
    rng.shuffle(unique_pids)
    
    n = len(unique_pids)
    fold_size = int(n/(k+1))
    
    groups = []
    start = 0
    for i in range(k):
        pids_fold = unique_pids[start : start+fold_size]
        start += fold_size
        g = df[df['patient_id'].isin(pids_fold)].copy()
        groups.append(g)
    
    # final group is everything leftover
    pids_fold = unique_pids[start:]
    g = df[df['patient_id'].isin(pids_fold)].copy()
    groups.append(g)
    return groups


def compute_auc_score(y_true, y_prob):
    if len(np.unique(y_true)) < 2:
        return 0.5
    return roc_auc_score(y_true, y_prob)


###############################################################################
# 3. POLICY SIMULATION & BENCHMARK POLICY METHODS
###############################################################################
def simulate_policy(df, policy_func):
    """
    df must contain columns:
      - patient_id
      - time
      - risk_score
      - label (0 or 1)

    policy_func(patient_rows) -> treat_time (int) or None
    Returns dict of cost, precision, recall, avg_treatment_time
    """
    results = []
    
    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values('time')
        label = grp['label'].iloc[0]
        
        treat_time = policy_func(grp)
        
        if treat_time is None:
            # never treated
            if label == 1:
                cost = FN_COST
                tp   = 0
            else:
                cost = 0
                tp   = 0
            fp = 0
            treat_flag = 0
            ttime = None
        else:
            treat_flag = 1
            if label == 1:
                cost = D_COST * treat_time
                tp   = 1
                fp   = 0
            else:
                cost = FP_COST
                tp   = 0
                fp   = 1
            ttime = treat_time
        
        results.append({
            'patient_id': pid,
            'label': label,
            'treated': treat_flag,
            'treat_time': ttime,
            'cost': cost,
            'tp': tp,
            'fp': fp
        })
    
    df_res     = pd.DataFrame(results)
    total_cost = df_res['cost'].sum()
    
    treated_df = df_res[df_res['treated']==1]
    tp_sum = treated_df['tp'].sum()
    fp_sum = treated_df['fp'].sum()
    
    if len(treated_df) > 0:
        precision = tp_sum / (tp_sum + fp_sum)
    else:
        precision = 0.0
    
    sick_df = df_res[df_res['label']==1]
    total_sick = len(sick_df)
    if total_sick > 0:
        recall = tp_sum / total_sick
    else:
        recall = 0.0
    
    if len(treated_df) > 0:
        valid_tt = treated_df['treat_time'].dropna()
        avg_tt   = valid_tt.mean() if len(valid_tt) > 0 else 0.0
    else:
        avg_tt = 0.0
    
    return {
        'cost': total_cost,
        'precision': precision,
        'recall': recall,
        'avg_treatment_time': avg_tt
    }

# Benchmark threshold-based searches
def constant_threshold_search(df, thresholds=None):
    if thresholds is None:
        thresholds = np.linspace(0,1,21)
    best_thr, best_cost, best_stats = None, float('inf'), None
    for thr in thresholds:
        def policy_func(patient_rows):
            for _, row in patient_rows.iterrows():
                if row['risk_score'] >= thr:
                    return int(row['time'])
            return None
        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_thr  = thr
            best_stats= stats
    return best_thr, best_stats

def make_constant_threshold_policy(thr):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            if row['risk_score'] >= thr:
                return int(row['time'])
        return None
    return policy_func


def dynamic_threshold_random_search(df,
                                    time_steps=20,
                                    threshold_candidates=[0.0,0.2,0.4,0.6,0.8,1.0],
                                    n_samples=200,
                                    seed=0):
    rng = np.random.RandomState(seed)
    best_vec = None
    best_cost= float('inf')
    best_stats=None
    
    for _ in range(n_samples):
        thr_vec = rng.choice(threshold_candidates, size=time_steps)
        
        def policy_func(patient_rows):
            for _, row in patient_rows.iterrows():
                t = int(row['time'])
                if t < time_steps and row['risk_score'] >= thr_vec[t]:
                    return t
            return None
        
        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_vec  = thr_vec.copy()
            best_stats= stats
    return best_vec, best_stats

def make_dynamic_threshold_policy(thr_vec):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            if t < len(thr_vec):
                if row['risk_score'] >= thr_vec[t]:
                    return t
        return None
    return policy_func

def linear_threshold_search(df,
                            A_candidates=np.linspace(-0.05, 0.01, 7),
                            B_candidates=np.linspace(0,0.8,7)):
    best_A, best_B = None, None
    best_cost, best_stats = float('inf'), None
    
    for A in A_candidates:
        for B in B_candidates:
            def policy_func(patient_rows):
                for _, row in patient_rows.iterrows():
                    t = row['time']
                    thr = A*t + B
                    thr = np.clip(thr,0,1)
                    if row['risk_score'] >= thr:
                        return int(t)
                return None
            stats = simulate_policy(df, policy_func)
            if stats['cost'] < best_cost:
                best_cost = stats['cost']
                best_A    = A
                best_B    = B
                best_stats= stats
    return (best_A,best_B), best_stats

def make_linear_threshold_policy(A,B):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = row['time']
            thr = A*t + B
            thr = np.clip(thr,0,1)
            if row['risk_score'] >= thr:
                return int(t)
        return None
    return policy_func

def wait_till_end_search(df, thresholds=None):
    if thresholds is None:
        thresholds = np.linspace(0,1,21)
    best_thr, best_cost, best_stats = None, float('inf'), None
    
    for thr in thresholds:
        def policy_func(patient_rows):
            final_t = patient_rows['time'].max()
            final_row = patient_rows[patient_rows['time']==final_t].iloc[0]
            if final_row['risk_score'] >= thr:
                return int(final_t)
            return None
        
        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_thr  = thr
            best_stats= stats
    return best_thr, best_stats

def make_wait_till_end_policy(thr):
    def policy_func(patient_rows):
        final_t = patient_rows['time'].max()
        final_row = patient_rows[patient_rows['time']==final_t].iloc[0]
        if final_row['risk_score'] >= thr:
            return int(final_t)
        return None
    return policy_func


###############################################################################
# 4. DATA-DRIVEN DP (UNCONSTRAINED)
###############################################################################
def to_bucket(prob):
    """Map a probability into a bucket 0..4."""
    b = int(prob * 5)
    return min(b, 4)

def estimate_transition_and_sick_probs(df_train, T=20, n_buckets=5):
    transition_counts = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    bucket_counts     = np.zeros((T, n_buckets), dtype=float)
    sick_counts       = np.zeros((T, n_buckets), dtype=float)
    
    df_sorted = df_train.sort_values(['patient_id','time'])
    for pid, grp in df_sorted.groupby('patient_id'):
        grp = grp.sort_values('time')
        rows= grp.to_dict('records')
        
        for i, row in enumerate(rows):
            t = int(row['time'])
            b = int(row['risk_bucket'])
            lbl = row['label']
            
            if t < T:
                bucket_counts[t,b] += 1
                sick_counts[t,b]   += lbl
            
            if i < len(rows)-1:
                nxt = rows[i+1]
                t_next = nxt['time']
                b_next = nxt['risk_bucket']
                if (t_next == t+1) and (t < T-1):
                    transition_counts[t,b,b_next] += 1
    
    p_trans = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    for t_ in range(T-1):
        for b_ in range(n_buckets):
            denom = transition_counts[t_,b_,:].sum()
            if denom>0:
                p_trans[t_,b_,:] = transition_counts[t_,b_,:] / denom
            else:
                p_trans[t_,b_,b_] = 1.0
    
    p_sick = np.zeros((T, n_buckets), dtype=float)
    for t_ in range(T):
        for b_ in range(n_buckets):
            denom = bucket_counts[t_,b_]
            if denom>0:
                p_sick[t_,b_] = sick_counts[t_,b_] / denom
            else:
                p_sick[t_,b_] = 0.0
    return p_trans, p_sick

def train_data_driven_dp_unconstrained(p_trans, p_sick, 
                                       FP=10, FN=50, D=1, gamma=0.99, T=20):
    """
    Standard DP for unconstrained scenario:
      V[t,b] = min( cost_treat_now, cost_wait )
    """
    n_buckets = p_sick.shape[1]
    V = np.zeros((T+1, n_buckets))
    pi_ = np.zeros((T, n_buckets), dtype=int)
    
    # boundary at t=T
    for b in range(n_buckets):
        cost_treat   = p_sick[T-1,b]*(D*(T-1)) + (1-p_sick[T-1,b])*FP
        cost_notreat = p_sick[T-1,b]*FN
        V[T,b] = min(cost_treat, cost_notreat)
    
    # fill from T-1 down to 0
    for t in reversed(range(T)):
        for b in range(n_buckets):
            # treat now
            cost_treat = p_sick[t,b]*(D*t) + (1-p_sick[t,b])*FP
            
            # wait
            if t == T-1:
                cost_wait = gamma * V[T,b]
            else:
                exp_future = 0.0
                for b_next in range(n_buckets):
                    exp_future += p_trans[t,b,b_next]*V[t+1,b_next]
                cost_wait = gamma * exp_future
            
            if cost_treat <= cost_wait:
                V[t,b]   = cost_treat
                pi_[t,b] = 1
            else:
                V[t,b]   = cost_wait
                pi_[t,b] = 0
    return V, pi_

def make_dp_policy(V, pi_, T=20):
    """Return a policy function that treats if pi[t,b]==1 at time t."""
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            if t < T:
                b = int(row['risk_bucket'])
                if pi_[t,b] == 1:
                    return t
        return None
    return policy_func


###############################################################################
# 5. (REFERENCE) ALGORITHM 0 FOR COMPARISON
###############################################################################
def train_and_select_best_model(X_train, y_train, X_val, y_val):
    """
    Trains multiple models (RF, GB, CatBoost) over small hyperparam grids,
    picks best by AUC. (Used by Algorithm 0 only.)
    """
    best_auc = -1.0
    best_model = None
    best_name  = None
    
    # Quick small grids:
    RF_PARAM_GRID = {
        'n_estimators': [50, 100],
        'max_depth': [3, 5]
    }
    GB_PARAM_GRID = {
        'n_estimators': [50, 100],
        'learning_rate': [0.05, 0.1],
        'max_depth': [3, 5]
    }
    CATBOOST_PARAM_GRID = {
        'iterations': [50, 100],
        'learning_rate': [0.05, 0.1],
        'depth': [3, 5]
    }
    
    for params in ParameterGrid(RF_PARAM_GRID):
        rf = RandomForestClassifier(random_state=0, **params)
        rf.fit(X_train, y_train)
        val_prob = rf.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, val_prob)
        if auc_val > best_auc:
            best_auc   = auc_val
            best_model = rf
            best_name  = f"RandomForest_{params}"
    
    for params in ParameterGrid(GB_PARAM_GRID):
        gb = GradientBoostingClassifier(random_state=0, **params)
        gb.fit(X_train, y_train)
        val_prob = gb.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, val_prob)
        if auc_val > best_auc:
            best_auc   = auc_val
            best_model = gb
            best_name  = f"GradientBoosting_{params}"
    
    for params in ParameterGrid(CATBOOST_PARAM_GRID):
        cb = CatBoostClassifier(verbose=0, random_state=0, **params)
        cb.fit(X_train, y_train)
        val_prob = cb.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, val_prob)
        if auc_val > best_auc:
            best_auc   = auc_val
            best_model = cb
            best_name  = f"CatBoost_{params}"
    
    return best_model, best_auc, best_name

def run_algorithm0_unconstrained(df_all, seed=0):
    G1, G2, G3, G4 = split_into_four_groups(df_all, seed=seed)

    X_train = G1[['EIT','NIRS','EIS']].values
    y_train = G1['label'].values
    X_val   = G2[['EIT','NIRS','EIS']].values
    y_val   = G2['label'].values
    
    best_model, best_auc, best_name = train_and_select_best_model(
        X_train, y_train, X_val, y_val
    )
    
    G12 = pd.concat([G1, G2], ignore_index=True)
    X_12 = G12[['EIT','NIRS','EIS']].values
    y_12 = G12['label'].values
    best_model.fit(X_12, y_12)
    
    # Prepare G3, G12
    G3 = G3.copy()
    prob_3 = best_model.predict_proba(G3[['EIT','NIRS','EIS']])[:,1]
    G3['risk_score'] = prob_3
    
    G12 = G12.copy()
    prob_12 = best_model.predict_proba(G12[['EIT','NIRS','EIS']])[:,1]
    G12['risk_score'] = prob_12
    
    # TUNE threshold-based on G3
    thr_const_g3, _ = constant_threshold_search(G3)
    thr_vec_g3, _    = dynamic_threshold_random_search(G3, time_steps=T_MAX)
    (A_lin_g3, B_lin_g3), _ = linear_threshold_search(G3)
    thr_wte_g3, _    = wait_till_end_search(G3)
    
    # TUNE DP discount factor on G3
    G12['risk_bucket'] = G12['risk_score'].apply(to_bucket)
    p_trans, p_sick    = estimate_transition_and_sick_probs(G12, T=T_MAX, n_buckets=5)
    
    G3_dp = G3.copy()
    G3_dp['risk_bucket'] = G3_dp['risk_score'].apply(to_bucket)
    
    GAMMA_CANDIDATES = [0.95, 0.99]
    best_gamma = None
    best_cost_dp = float('inf')
    best_V = None
    best_pi= None
    
    for gamma_ in GAMMA_CANDIDATES:
        V_temp, pi_temp = train_data_driven_dp_unconstrained(
            p_trans, p_sick,
            FP=FP_COST, FN=FN_COST, D=D_COST,
            gamma=gamma_, T=T_MAX
        )
        dp_policy_temp = make_dp_policy(V_temp, pi_temp, T=T_MAX)
        stats_temp = simulate_policy(G3_dp, dp_policy_temp)
        
        if stats_temp['cost'] < best_cost_dp:
            best_cost_dp = stats_temp['cost']
            best_gamma   = gamma_
            best_V       = V_temp
            best_pi      = pi_temp
    
    # Evaluate all on G4
    G4 = G4.copy()
    prob_4 = best_model.predict_proba(G4[['EIT','NIRS','EIS']])[:,1]
    G4['risk_score'] = prob_4
    
    policy_const = make_constant_threshold_policy(thr_const_g3)
    stats_const  = simulate_policy(G4, policy_const)
    
    policy_dyn   = make_dynamic_threshold_policy(thr_vec_g3)
    stats_dyn    = simulate_policy(G4, policy_dyn)
    
    policy_lin   = make_linear_threshold_policy(A_lin_g3, B_lin_g3)
    stats_lin    = simulate_policy(G4, policy_lin)
    
    policy_wte   = make_wait_till_end_policy(thr_wte_g3)
    stats_wte    = simulate_policy(G4, policy_wte)
    
    dp_policy_final = make_dp_policy(best_V, best_pi, T=T_MAX)
    G4_dp = G4.copy()
    G4_dp['risk_bucket'] = G4_dp['risk_score'].apply(to_bucket)
    stats_dp = simulate_policy(G4_dp, dp_policy_final)
    
    table = pd.DataFrame({
        'Method': [
            'Constant Threshold',
            'Dynamic Threshold-R',
            'Linear Threshold',
            'Wait Till End',
            'Dynamic Threshold-DP'
        ],
        'Cost': [
            stats_const['cost'],
            stats_dyn['cost'],
            stats_lin['cost'],
            stats_wte['cost'],
            stats_dp['cost']
        ],
        'Precision (%)': [
            100*stats_const['precision'],
            100*stats_dyn['precision'],
            100*stats_lin['precision'],
            100*stats_wte['precision'],
            100*stats_dp['precision']
        ],
        'Recall (%)': [
            100*stats_const['recall'],
            100*stats_dyn['recall'],
            100*stats_lin['recall'],
            100*stats_wte['recall'],
            100*stats_dp['recall']
        ],
        'Avg Treat Time': [
            stats_const['avg_treatment_time'],
            stats_dyn['avg_treatment_time'],
            stats_lin['avg_treatment_time'],
            stats_wte['avg_treatment_time'],
            stats_dp['avg_treatment_time']
        ]
    })
    return table


###############################################################################
# 6. ALGORITHM 5 (SPSA) IMPLEMENTATION
###############################################################################
def spsa_optimize(cost_func, param_init, n_iter=20,
                  alpha=0.602, gamma=0.101, a=0.1, c=0.1, seed=0):
    rng = np.random.RandomState(seed)
    p = param_init.copy()
    best_p = p.copy()
    best_cost = cost_func(p)
    
    for k in range(1, n_iter+1):
        ak = a / (k**alpha)
        ck = c / (k**gamma)
        delta = rng.choice([-1,1], size=len(p))
        
        p_plus  = p + ck*delta
        p_minus = p - ck*delta
        
        cost_plus  = cost_func(p_plus)
        cost_minus = cost_func(p_minus)
        
        g_approx = (cost_plus - cost_minus)/(2.0*ck) * delta
        
        p = p - ak*g_approx
        
        current_cost = cost_func(p)
        if current_cost < best_cost:
            best_cost = current_cost
            best_p = p.copy()
    
    return best_p, best_cost


def parse_spsa_params(params):
    """
    Map a real vector `params` into discrete hyperparams + DP gamma:
      params[0]: model_type (0=RF,1=GB,2=CB)
      params[1]: n_estimators [10..200]
      params[2]: learning_rate [0.01..0.2]
      params[3]: max_depth [2..10]
      params[4]: gamma [0.90..0.999]
    """
    p0 = int(round(np.clip(params[0], 0, 2)))
    p1 = int(round(np.clip(params[1], 10, 200)))
    p2 = float(np.clip(params[2], 0.01, 0.2))
    p3 = int(round(np.clip(params[3], 2, 10)))
    p4 = float(np.clip(params[4], 0.90, 0.999))
    return (p0, p1, p2, p3, p4)


def spsa_cost_function(param_vector, df_train, df_val):
    """
    Decision-aware cost:
      1) Parse param_vector => (model_type, n_estimators, learning_rate, max_depth, gamma)
      2) Train model on df_train
      3) Predict risk on df_val
      4) Build DP with gamma (from param_vector) using transitions from df_train
      5) Evaluate cost on df_val
    """
    model_type, n_est, lr, m_depth, gamma_ = parse_spsa_params(param_vector)

    X_tr = df_train[['EIT','NIRS','EIS']].values
    y_tr = df_train['label'].values
    
    if model_type == 0:
        clf = RandomForestClassifier(n_estimators=n_est, max_depth=m_depth, random_state=0)
    elif model_type == 1:
        clf = GradientBoostingClassifier(n_estimators=n_est, learning_rate=lr,
                                         max_depth=m_depth, random_state=0)
    else:
        clf = CatBoostClassifier(iterations=n_est, learning_rate=lr, depth=m_depth,
                                 verbose=0, random_state=0)
    clf.fit(X_tr, y_tr)
    
    X_val = df_val[['EIT','NIRS','EIS']].values
    prob_val = clf.predict_proba(X_val)[:,1]
    df_val_ = df_val.copy()
    df_val_['risk_score'] = prob_val
    
    # DP transitions from df_train
    df_tr_ = df_train.copy()
    train_probs = clf.predict_proba(df_tr_[['EIT','NIRS','EIS']])[:,1]
    df_tr_['risk_score'] = train_probs
    df_tr_['risk_bucket'] = df_tr_['risk_score'].apply(to_bucket)
    
    p_trans, p_sick = estimate_transition_and_sick_probs(df_tr_, T=T_MAX, n_buckets=5)
    V, pi_ = train_data_driven_dp_unconstrained(
        p_trans, p_sick,
        FP=FP_COST, FN=FN_COST, D=D_COST,
        gamma=gamma_, T=T_MAX
    )
    
    df_val_['risk_bucket'] = df_val_['risk_score'].apply(to_bucket)
    dp_policy = make_dp_policy(V, pi_, T=T_MAX)
    stats = simulate_policy(df_val_, dp_policy)
    return stats['cost']


def run_algorithm5_spsa_unconstrained(df_all, k=3, seed=0, n_spsa_iter=20):
    """
    SPSA Hyper-Parameter Tuning (Decision-Aware).
    
    1) Split data into k+1 groups: G1..Gk, G_{k+1}
    2) For each fold i in [1..k]:
         - define cost function that trains on G\\G_i, evaluates cost on G_i
         - run SPSA to find best param_i
    3) Among param_1..param_k, pick best overall param* with minimal sum of costs across folds
    4) Evaluate param* on G_{k+1}.
    5) For final report, we also show the five benchmark policies on the holdout
       *using the same final trained model from param* (and its DP discount factor).
    """
    groups = k_plus_1_splits(df_all, k=k, seed=seed)
    # groups[0..k-1] => folds for cross-validation
    # groups[k]      => final holdout
    
    # Step 2: For each fold, run SPSA
    param_init = np.array([1.0, 50.0, 0.05, 3.0, 0.95], dtype=float)
    
    spsa_solutions = []
    for i in range(1, k+1):
        df_val = groups[i-1]
        df_train_list = [groups[j] for j in range(k) if j != (i-1)]
        df_train_ = pd.concat(df_train_list, ignore_index=True)
        
        def fold_cost_func(p):
            return spsa_cost_function(p, df_train_, df_val)
        
        best_p_fold, best_c_fold = spsa_optimize(
            fold_cost_func,
            param_init,
            n_iter=n_spsa_iter,
            alpha=0.602,
            gamma=0.101,
            a=0.1,
            c=0.1,
            seed=seed+i
        )
        spsa_solutions.append( (best_p_fold, best_c_fold) )
    
    # Step 3: Among these k solutions, pick best overall param*
    # with minimal sum of costs across all k folds
    k_ = k
    fold_cost_matrix = np.zeros((k_, k_), dtype=float)
    for i in range(k_):
        param_i = spsa_solutions[i][0]
        for j in range(k_):
            df_val_j = groups[j]
            df_train_j_list = [groups[m] for m in range(k_) if m != j]
            df_train_j = pd.concat(df_train_j_list, ignore_index=True)
            c_ij = spsa_cost_function(param_i, df_train_j, df_val_j)
            fold_cost_matrix[i,j] = c_ij
    
    total_cost_per_param = fold_cost_matrix.sum(axis=1)
    best_index = np.argmin(total_cost_per_param)
    best_param = spsa_solutions[best_index][0]
    
    # Step 4: Evaluate best_param on final holdout G_{k+1}
    df_holdout = groups[k]
    df_train_for_holdout = pd.concat(groups[:k], ignore_index=True)
    
    # We'll train the final ML model + DP discount factor = best_param,
    # then produce a table of 5 methods on the holdout:
    #   1) Constant Threshold
    #   2) Dynamic Threshold-R
    #   3) Linear Threshold
    #   4) Wait Till End
    #   5) Dynamic Threshold-DP (SPSA)
    
    # (A) Train final ML model using best_param on the union of the k folds
    model_type, n_est, lr, m_depth, gamma_ = parse_spsa_params(best_param)

    # Build the classifier
    X_train2 = df_train_for_holdout[['EIT','NIRS','EIS']].values
    y_train2 = df_train_for_holdout['label'].values
    
    if model_type == 0:
        clf_final = RandomForestClassifier(n_estimators=n_est, max_depth=m_depth, random_state=0)
    elif model_type == 1:
        clf_final = GradientBoostingClassifier(n_estimators=n_est, learning_rate=lr,
                                               max_depth=m_depth, random_state=0)
    else:
        clf_final = CatBoostClassifier(iterations=n_est, learning_rate=lr, depth=m_depth,
                                       verbose=0, random_state=0)
    clf_final.fit(X_train2, y_train2)
    
    # (B) For DP transitions, we also need risk scores on df_train_for_holdout
    df_train2 = df_train_for_holdout.copy()
    train_probs2 = clf_final.predict_proba(df_train2[['EIT','NIRS','EIS']])[:,1]
    df_train2['risk_score'] = train_probs2
    df_train2['risk_bucket'] = df_train2['risk_score'].apply(to_bucket)
    p_trans2, p_sick2 = estimate_transition_and_sick_probs(df_train2, T=T_MAX, n_buckets=5)
    
    V_final, pi_final = train_data_driven_dp_unconstrained(
        p_trans2, p_sick2,
        FP=FP_COST, FN=FN_COST, D=D_COST,
        gamma=gamma_, T=T_MAX
    )
    dp_policy_final = make_dp_policy(V_final, pi_final, T=T_MAX)
    
    # (C) Generate risk scores for holdout
    df_holdout2 = df_holdout.copy()
    holdout_probs = clf_final.predict_proba(df_holdout2[['EIT','NIRS','EIS']])[:,1]
    df_holdout2['risk_score'] = holdout_probs
    
    # -- Now we do threshold-based searches on the "training" set or holdout?
    # Typically you'd tune thresholds on a validation set. For demonstration,
    # let's do exactly what Algorithm 0 does: tune threshold on "training",
    # then evaluate on holdout.
    
    # TUNE threshold-based on df_train2:
    thr_const, _ = constant_threshold_search(df_train2)
    thr_vec, _    = dynamic_threshold_random_search(df_train2, time_steps=T_MAX)
    (A_lin, B_lin), _ = linear_threshold_search(df_train2)
    thr_wte, _    = wait_till_end_search(df_train2)
    
    # Evaluate all 5 on holdout
    # 1) Constant Threshold
    pol_const = make_constant_threshold_policy(thr_const)
    stats_const = simulate_policy(df_holdout2, pol_const)
    
    # 2) Dynamic Threshold-R
    pol_dyn = make_dynamic_threshold_policy(thr_vec)
    stats_dyn = simulate_policy(df_holdout2, pol_dyn)
    
    # 3) Linear Threshold
    pol_lin = make_linear_threshold_policy(A_lin, B_lin)
    stats_lin = simulate_policy(df_holdout2, pol_lin)
    
    # 4) Wait Till End
    pol_wte = make_wait_till_end_policy(thr_wte)
    stats_wte = simulate_policy(df_holdout2, pol_wte)
    
    # 5) Dynamic Threshold-DP (SPSA)
    df_holdout2_dp = df_holdout2.copy()
    df_holdout2_dp['risk_bucket'] = df_holdout2_dp['risk_score'].apply(to_bucket)
    stats_dp = simulate_policy(df_holdout2_dp, dp_policy_final)
    
    # Build final table
    table = pd.DataFrame({
        'Method': [
            'Constant Threshold',
            'Dynamic Threshold-R',
            'Linear Threshold',
            'Wait Till End',
            'Dynamic Threshold-DP (SPSA)'
        ],
        'Cost': [
            stats_const['cost'],
            stats_dyn['cost'],
            stats_lin['cost'],
            stats_wte['cost'],
            stats_dp['cost']
        ],
        'Precision (%)': [
            100*stats_const['precision'],
            100*stats_dyn['precision'],
            100*stats_lin['precision'],
            100*stats_wte['precision'],
            100*stats_dp['precision']
        ],
        'Recall (%)': [
            100*stats_const['recall'],
            100*stats_dyn['recall'],
            100*stats_lin['recall'],
            100*stats_wte['recall'],
            100*stats_dp['recall']
        ],
        'Avg Treat Time': [
            stats_const['avg_treatment_time'],
            stats_dyn['avg_treatment_time'],
            stats_lin['avg_treatment_time'],
            stats_wte['avg_treatment_time'],
            stats_dp['avg_treatment_time']
        ]
    })
    
    return table


###############################################################################
# 7. MAIN
###############################################################################
def main():
    # Load synthetic data
    df_all = pd.read_csv("synthetic_patients_with_features.csv")
    
    # Filter to time < T_MAX if needed
    df_all = df_all[df_all['time'] < T_MAX].copy()
    
    # Check required columns
    required = {'patient_id','time','EIT','NIRS','EIS','label'}
    if not required.issubset(df_all.columns):
        raise ValueError(f"Your CSV must have columns at least: {required}. Found: {df_all.columns}")
    

    
    #-----------------------------------------------------------------------
    # RUN ALGORITHM 5 (SPSA) - DECISION-AWARE
    # Produce the same kind of table with 5 policy lines
    #-----------------------------------------------------------------------
    table_alg5 = run_algorithm5_spsa_unconstrained(df_all, k=3, seed=42, n_spsa_iter=20)
    print("\n=== ALGORITHM 5 (SPSA) RESULTS (Unconstrained) ===")
    print(table_alg5.to_string(index=False))


if __name__ == "__main__":
    main()


=== ALGORITHM 5 (SPSA) RESULTS (Unconstrained) ===
                     Method  Cost  Precision (%)  Recall (%)  Avg Treat Time
         Constant Threshold   399      51.785714  100.000000        2.303571
        Dynamic Threshold-R   448      47.540984  100.000000        8.065574
           Linear Threshold   448      46.774194  100.000000        2.064516
              Wait Till End   610     100.000000   96.551724       20.000000
Dynamic Threshold-DP (SPSA)   192      87.878788  100.000000        5.666667


In [15]:
"""
ALGORITHM 5 (SPSA) FOR UNCONSTRAINED HEMORRHAGE DIAGNOSIS & TREATMENT
- Multi-run version, aggregating mean ± std results across replicates.

Requirements:
  pip install numpy pandas scikit-learn catboost
"""

import numpy as np
import pandas as pd
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

# Sklearn models, metrics, etc.
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import ParameterGrid
# CatBoost
from catboost import CatBoostClassifier

###############################################################################
# 1. GLOBAL PARAMETERS
###############################################################################
FP_COST = 10
FN_COST = 50
D_COST  = 1
T_MAX   = 21   # maximum discrete time steps (0..T_MAX-1)

###############################################################################
# 2. HELPER FUNCTIONS (DATA SPLIT, METRICS, ETC.)
###############################################################################
def split_into_four_groups(df, seed=0):
    """
    A quick 4-group split. (Not used by Algorithm 5 directly, 
    but shown here for reference if needed.)
    """
    rng = np.random.RandomState(seed)
    unique_pids = df['patient_id'].unique()
    rng.shuffle(unique_pids)
    
    n = len(unique_pids)
    i1 = int(0.25 * n)
    i2 = int(0.50 * n)
    i3 = int(0.75 * n)
    
    G1_pids = unique_pids[: i1]
    G2_pids = unique_pids[i1 : i2]
    G3_pids = unique_pids[i2 : i3]
    G4_pids = unique_pids[i3 : ]
    
    G1 = df[df['patient_id'].isin(G1_pids)].copy()
    G2 = df[df['patient_id'].isin(G2_pids)].copy()
    G3 = df[df['patient_id'].isin(G3_pids)].copy()
    G4 = df[df['patient_id'].isin(G4_pids)].copy()
    return G1, G2, G3, G4

def k_plus_1_splits(df, k=3, seed=0):
    """
    For Algorithm 5 (SPSA), we often split into (k+1) groups:
      G1, G2, ..., Gk, G_{k+1}.
    The first k groups are used like cross-validation folds;
    the (k+1)-th is the final holdout.
    """
    rng = np.random.RandomState(seed)
    unique_pids = df['patient_id'].unique()
    rng.shuffle(unique_pids)
    
    n = len(unique_pids)
    fold_size = int(n/(k+1))
    
    groups = []
    start = 0
    for i in range(k):
        pids_fold = unique_pids[start : start+fold_size]
        start += fold_size
        g = df[df['patient_id'].isin(pids_fold)].copy()
        groups.append(g)
    
    # final group is everything leftover
    pids_fold = unique_pids[start:]
    g = df[df['patient_id'].isin(pids_fold)].copy()
    groups.append(g)
    return groups

def compute_auc_score(y_true, y_prob):
    if len(np.unique(y_true)) < 2:
        return 0.5
    return roc_auc_score(y_true, y_prob)

###############################################################################
# 3. POLICY SIMULATION & BENCHMARK POLICY METHODS
###############################################################################
def simulate_policy(df, policy_func):
    """
    df must contain columns:
      - patient_id
      - time
      - risk_score
      - label (0 or 1)

    policy_func(patient_rows) -> treat_time (int) or None
    Returns dict of cost, precision, recall, avg_treatment_time
    """
    results = []
    
    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values('time')
        label = grp['label'].iloc[0]
        
        treat_time = policy_func(grp)
        
        if treat_time is None:
            if label == 1:
                cost = FN_COST
                tp   = 0
            else:
                cost = 0
                tp   = 0
            fp = 0
            treat_flag = 0
            ttime = None
        else:
            treat_flag = 1
            if label == 1:
                cost = D_COST * treat_time
                tp   = 1
                fp   = 0
            else:
                cost = FP_COST
                tp   = 0
                fp   = 1
            ttime = treat_time
        
        results.append({
            'patient_id': pid,
            'label': label,
            'treated': treat_flag,
            'treat_time': ttime,
            'cost': cost,
            'tp': tp,
            'fp': fp
        })
    
    df_res     = pd.DataFrame(results)
    total_cost = df_res['cost'].sum()
    
    treated_df = df_res[df_res['treated']==1]
    tp_sum = treated_df['tp'].sum()
    fp_sum = treated_df['fp'].sum()
    
    if len(treated_df) > 0:
        precision = tp_sum / (tp_sum + fp_sum)
    else:
        precision = 0.0
    
    sick_df = df_res[df_res['label']==1]
    total_sick = len(sick_df)
    if total_sick > 0:
        recall = tp_sum / total_sick
    else:
        recall = 0.0
    
    if len(treated_df) > 0:
        valid_tt = treated_df['treat_time'].dropna()
        avg_tt   = valid_tt.mean() if len(valid_tt) > 0 else 0.0
    else:
        avg_tt = 0.0
    
    return {
        'cost': total_cost,
        'precision': precision,
        'recall': recall,
        'avg_treatment_time': avg_tt
    }

# Benchmark threshold-based searches
def constant_threshold_search(df, thresholds=None):
    if thresholds is None:
        thresholds = np.linspace(0,1,21)
    best_thr, best_cost, best_stats = None, float('inf'), None
    for thr in thresholds:
        def policy_func(patient_rows):
            for _, row in patient_rows.iterrows():
                if row['risk_score'] >= thr:
                    return int(row['time'])
            return None
        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_thr  = thr
            best_stats= stats
    return best_thr, best_stats

def make_constant_threshold_policy(thr):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            if row['risk_score'] >= thr:
                return int(row['time'])
        return None
    return policy_func

def dynamic_threshold_random_search(df,
                                    time_steps=20,
                                    threshold_candidates=[0.0,0.2,0.4,0.6,0.8,1.0],
                                    n_samples=200,
                                    seed=0):
    rng = np.random.RandomState(seed)
    best_vec = None
    best_cost= float('inf')
    best_stats=None
    
    for _ in range(n_samples):
        thr_vec = rng.choice(threshold_candidates, size=time_steps)
        
        def policy_func(patient_rows):
            for _, row in patient_rows.iterrows():
                t = int(row['time'])
                if t < time_steps and row['risk_score'] >= thr_vec[t]:
                    return t
            return None
        
        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_vec  = thr_vec.copy()
            best_stats= stats
    return best_vec, best_stats

def make_dynamic_threshold_policy(thr_vec):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            if t < len(thr_vec):
                if row['risk_score'] >= thr_vec[t]:
                    return t
        return None
    return policy_func

def linear_threshold_search(df,
                            A_candidates=np.linspace(-0.05, 0.01, 7),
                            B_candidates=np.linspace(0,0.8,7)):
    best_A, best_B = None, None
    best_cost, best_stats = float('inf'), None
    
    for A in A_candidates:
        for B in B_candidates:
            def policy_func(patient_rows):
                for _, row in patient_rows.iterrows():
                    t = row['time']
                    thr = A*t + B
                    thr = np.clip(thr,0,1)
                    if row['risk_score'] >= thr:
                        return int(t)
                return None
            stats = simulate_policy(df, policy_func)
            if stats['cost'] < best_cost:
                best_cost = stats['cost']
                best_A    = A
                best_B    = B
                best_stats= stats
    return (best_A,best_B), best_stats

def make_linear_threshold_policy(A,B):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = row['time']
            thr = A*t + B
            thr = np.clip(thr,0,1)
            if row['risk_score'] >= thr:
                return int(t)
        return None
    return policy_func

def wait_till_end_search(df, thresholds=None):
    if thresholds is None:
        thresholds = np.linspace(0,1,21)
    best_thr, best_cost, best_stats = None, float('inf'), None
    
    for thr in thresholds:
        def policy_func(patient_rows):
            final_t = patient_rows['time'].max()
            final_row = patient_rows[patient_rows['time']==final_t].iloc[0]
            if final_row['risk_score'] >= thr:
                return int(final_t)
            return None
        
        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_thr  = thr
            best_stats= stats
    return best_thr, best_stats

def make_wait_till_end_policy(thr):
    def policy_func(patient_rows):
        final_t = patient_rows['time'].max()
        final_row = patient_rows[patient_rows['time']==final_t].iloc[0]
        if final_row['risk_score'] >= thr:
            return int(final_t)
        return None
    return policy_func

###############################################################################
# 4. DATA-DRIVEN DP (UNCONSTRAINED)
###############################################################################
def to_bucket(prob):
    """Map a probability into a bucket 0..4."""
    b = int(prob * 5)
    return min(b, 4)

def estimate_transition_and_sick_probs(df_train, T=20, n_buckets=5):
    transition_counts = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    bucket_counts     = np.zeros((T, n_buckets), dtype=float)
    sick_counts       = np.zeros((T, n_buckets), dtype=float)
    
    df_sorted = df_train.sort_values(['patient_id','time'])
    for pid, grp in df_sorted.groupby('patient_id'):
        grp = grp.sort_values('time')
        rows= grp.to_dict('records')
        
        for i, row in enumerate(rows):
            t = int(row['time'])
            b = int(row['risk_bucket'])
            lbl = row['label']
            
            if t < T:
                bucket_counts[t,b] += 1
                sick_counts[t,b]   += lbl
            
            if i < len(rows)-1:
                nxt = rows[i+1]
                t_next = nxt['time']
                b_next = nxt['risk_bucket']
                if (t_next == t+1) and (t < T-1):
                    transition_counts[t,b,b_next] += 1
    
    p_trans = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    for t_ in range(T-1):
        for b_ in range(n_buckets):
            denom = transition_counts[t_,b_,:].sum()
            if denom>0:
                p_trans[t_,b_,:] = transition_counts[t_,b_,:] / denom
            else:
                p_trans[t_,b_,b_] = 1.0
    
    p_sick = np.zeros((T, n_buckets), dtype=float)
    for t_ in range(T):
        for b_ in range(n_buckets):
            denom = bucket_counts[t_,b_]
            if denom>0:
                p_sick[t_,b_] = sick_counts[t_,b_] / denom
            else:
                p_sick[t_,b_] = 0.0
    return p_trans, p_sick

def train_data_driven_dp_unconstrained(p_trans, p_sick, 
                                       FP=10, FN=50, D=1, gamma=0.99, T=20):
    """
    Standard DP for unconstrained scenario:
      V[t,b] = min( cost_treat_now, cost_wait )
    """
    n_buckets = p_sick.shape[1]
    V = np.zeros((T+1, n_buckets))
    pi_ = np.zeros((T, n_buckets), dtype=int)
    
    # boundary at t=T
    for b in range(n_buckets):
        cost_treat   = p_sick[T-1,b]*(D*(T-1)) + (1-p_sick[T-1,b])*FP
        cost_notreat = p_sick[T-1,b]*FN
        V[T,b] = min(cost_treat, cost_notreat)
    
    # fill from T-1 down to 0
    for t in reversed(range(T)):
        for b in range(n_buckets):
            # treat now
            cost_treat = p_sick[t,b]*(D*t) + (1-p_sick[t,b])*FP
            
            # wait
            if t == T-1:
                cost_wait = gamma * V[T,b]
            else:
                exp_future = 0.0
                for b_next in range(n_buckets):
                    exp_future += p_trans[t,b,b_next]*V[t+1,b_next]
                cost_wait = gamma * exp_future
            
            if cost_treat <= cost_wait:
                V[t,b]   = cost_treat
                pi_[t,b] = 1
            else:
                V[t,b]   = cost_wait
                pi_[t,b] = 0
    return V, pi_

def make_dp_policy(V, pi_, T=20):
    """Return a policy function that treats if pi[t,b]==1 at time t."""
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            if t < T:
                b = int(row['risk_bucket'])
                if pi_[t,b] == 1:
                    return t
        return None
    return policy_func

###############################################################################
# 5. SPSA IMPLEMENTATION
###############################################################################
def spsa_optimize(cost_func, param_init, n_iter=20,
                  alpha=0.602, gamma=0.101, a=0.1, c=0.1, seed=0):
    """
    Basic SPSA routine for iterative approximate gradient-based optimization.
    """
    rng = np.random.RandomState(seed)
    p = param_init.copy()
    best_p = p.copy()
    best_cost = cost_func(p)
    
    for k in range(1, n_iter+1):
        ak = a / (k**alpha)
        ck = c / (k**gamma)
        delta = rng.choice([-1,1], size=len(p))
        
        p_plus  = p + ck*delta
        p_minus = p - ck*delta
        
        cost_plus  = cost_func(p_plus)
        cost_minus = cost_func(p_minus)
        
        g_approx = (cost_plus - cost_minus)/(2.0*ck) * delta
        
        p = p - ak*g_approx
        
        current_cost = cost_func(p)
        if current_cost < best_cost:
            best_cost = current_cost
            best_p = p.copy()
    
    return best_p, best_cost

def parse_spsa_params(params):
    """
    Map a real vector `params` into discrete hyperparams + DP gamma:
      params[0]: model_type (0=RF,1=GB,2=CB)
      params[1]: n_estimators [10..200]
      params[2]: learning_rate [0.01..0.2]
      params[3]: max_depth [2..10]
      params[4]: gamma [0.90..0.999]
    """
    p0 = int(round(np.clip(params[0], 0, 2)))
    p1 = int(round(np.clip(params[1], 10, 200)))
    p2 = float(np.clip(params[2], 0.01, 0.2))
    p3 = int(round(np.clip(params[3], 2, 10)))
    p4 = float(np.clip(params[4], 0.90, 0.999))
    return (p0, p1, p2, p3, p4)

def spsa_cost_function(param_vector, df_train, df_val):
    """
    Decision-aware cost:
      1) Parse param_vector => (model_type, n_estimators, learning_rate, max_depth, gamma)
      2) Train model on df_train
      3) Predict risk on df_val
      4) Build DP with gamma using transitions from df_train
      5) Evaluate cost on df_val
    """
    model_type, n_est, lr, m_depth, gamma_ = parse_spsa_params(param_vector)

    X_tr = df_train[['EIT','NIRS','EIS']].values
    y_tr = df_train['label'].values
    
    if model_type == 0:
        clf = RandomForestClassifier(n_estimators=n_est, max_depth=m_depth, random_state=0)
    elif model_type == 1:
        clf = GradientBoostingClassifier(n_estimators=n_est, learning_rate=lr,
                                         max_depth=m_depth, random_state=0)
    else:
        clf = CatBoostClassifier(iterations=n_est, learning_rate=lr, depth=m_depth,
                                 verbose=0, random_state=0)
    clf.fit(X_tr, y_tr)
    
    X_val = df_val[['EIT','NIRS','EIS']].values
    prob_val = clf.predict_proba(X_val)[:,1]
    df_val_ = df_val.copy()
    df_val_['risk_score'] = prob_val
    
    # DP transitions from df_train
    df_tr_ = df_train.copy()
    train_probs = clf.predict_proba(df_tr_[['EIT','NIRS','EIS']])[:,1]
    df_tr_['risk_score'] = train_probs
    df_tr_['risk_bucket'] = df_tr_['risk_score'].apply(to_bucket)
    
    p_trans, p_sick = estimate_transition_and_sick_probs(df_tr_, T=T_MAX, n_buckets=5)
    V, pi_ = train_data_driven_dp_unconstrained(
        p_trans, p_sick,
        FP=FP_COST, FN=FN_COST, D=D_COST,
        gamma=gamma_, T=T_MAX
    )
    
    df_val_['risk_bucket'] = df_val_['risk_score'].apply(to_bucket)
    dp_policy = make_dp_policy(V, pi_, T=T_MAX)
    stats = simulate_policy(df_val_, dp_policy)
    return stats['cost']

def run_algorithm5_spsa_unconstrained(df_all, k=3, seed=0, n_spsa_iter=20):
    """
    SPSA Hyper-Parameter Tuning (Decision-Aware).
    
    1) Split data into k+1 groups: G1..Gk, G_{k+1}
    2) For each fold i in [1..k]:
         - define cost function that trains on G\\G_i, evaluates cost on G_i
         - run SPSA to find best param_i
    3) Among param_1..param_k, pick best overall param* with minimal sum of costs across folds
    4) Evaluate param* on G_{k+1}.
    5) For final report, we also show the five benchmark policies on the holdout
       *using the same final trained model from param* (and its DP discount factor).
    """
    # (0) Filter to time < T_MAX if needed
    df_all = df_all[df_all['time'] < T_MAX].copy()
    
    # (1) Create k+1 splits
    groups = k_plus_1_splits(df_all, k=k, seed=seed)
    # groups[0..k-1] => folds for cross-validation
    # groups[k]      => final holdout
    
    # (2) For each fold, run SPSA to get a candidate param_i
    param_init = np.array([1.0, 50.0, 0.05, 3.0, 0.95], dtype=float)
    
    spsa_solutions = []
    for i in range(1, k+1):
        df_val = groups[i-1]
        df_train_list = [groups[j] for j in range(k) if j != (i-1)]
        df_train_ = pd.concat(df_train_list, ignore_index=True)
        
        def fold_cost_func(p):
            return spsa_cost_function(p, df_train_, df_val)
        
        best_p_fold, best_c_fold = spsa_optimize(
            fold_cost_func,
            param_init,
            n_iter=n_spsa_iter,
            alpha=0.602,
            gamma=0.101,
            a=0.1,
            c=0.1,
            seed=seed+i
        )
        spsa_solutions.append( (best_p_fold, best_c_fold) )
    
    # (3) Among these k solutions, pick best overall param*
    # with minimal sum of costs across all k folds
    k_ = k
    fold_cost_matrix = np.zeros((k_, k_), dtype=float)
    for i in range(k_):
        param_i = spsa_solutions[i][0]
        for j in range(k_):
            df_val_j = groups[j]
            df_train_j_list = [groups[m] for m in range(k_) if m != j]
            df_train_j = pd.concat(df_train_j_list, ignore_index=True)
            c_ij = spsa_cost_function(param_i, df_train_j, df_val_j)
            fold_cost_matrix[i,j] = c_ij
    
    total_cost_per_param = fold_cost_matrix.sum(axis=1)
    best_index = np.argmin(total_cost_per_param)
    best_param = spsa_solutions[best_index][0]
    
    # (4) Evaluate best_param on final holdout G_{k+1}
    df_holdout = groups[k]
    df_train_for_holdout = pd.concat(groups[:k], ignore_index=True)
    
    # We'll train the final ML model + DP discount factor = best_param,
    # then produce a table of 5 methods on the holdout
    model_type, n_est, lr, m_depth, gamma_ = parse_spsa_params(best_param)

    # Build the classifier on df_train_for_holdout
    X_train2 = df_train_for_holdout[['EIT','NIRS','EIS']].values
    y_train2 = df_train_for_holdout['label'].values
    
    if model_type == 0:
        clf_final = RandomForestClassifier(n_estimators=n_est, max_depth=m_depth, random_state=0)
    elif model_type == 1:
        clf_final = GradientBoostingClassifier(n_estimators=n_est, learning_rate=lr,
                                               max_depth=m_depth, random_state=0)
    else:
        clf_final = CatBoostClassifier(iterations=n_est, learning_rate=lr, depth=m_depth,
                                       verbose=0, random_state=0)
    clf_final.fit(X_train2, y_train2)
    
    # DP transitions from df_train_for_holdout
    df_train2 = df_train_for_holdout.copy()
    train_probs2 = clf_final.predict_proba(df_train2[['EIT','NIRS','EIS']])[:,1]
    df_train2['risk_score'] = train_probs2
    df_train2['risk_bucket'] = df_train2['risk_score'].apply(to_bucket)
    p_trans2, p_sick2 = estimate_transition_and_sick_probs(df_train2, T=T_MAX, n_buckets=5)
    
    V_final, pi_final = train_data_driven_dp_unconstrained(
        p_trans2, p_sick2,
        FP=FP_COST, FN=FN_COST, D=D_COST,
        gamma=gamma_, T=T_MAX
    )
    dp_policy_final = make_dp_policy(V_final, pi_final, T=T_MAX)
    
    # Generate risk scores for holdout
    df_holdout2 = df_holdout.copy()
    holdout_probs = clf_final.predict_proba(df_holdout2[['EIT','NIRS','EIS']])[:,1]
    df_holdout2['risk_score'] = holdout_probs
    
    # TUNE threshold-based policies on df_train2 for consistency
    thr_const, _ = constant_threshold_search(df_train2)
    thr_vec, _    = dynamic_threshold_random_search(df_train2, time_steps=T_MAX)
    (A_lin, B_lin), _ = linear_threshold_search(df_train2)
    thr_wte, _    = wait_till_end_search(df_train2)
    
    # Evaluate all on holdout
    # 1) Constant Threshold
    pol_const = make_constant_threshold_policy(thr_const)
    stats_const = simulate_policy(df_holdout2, pol_const)
    
    # 2) Dynamic Threshold-R
    pol_dyn = make_dynamic_threshold_policy(thr_vec)
    stats_dyn = simulate_policy(df_holdout2, pol_dyn)
    
    # 3) Linear Threshold
    pol_lin = make_linear_threshold_policy(A_lin, B_lin)
    stats_lin = simulate_policy(df_holdout2, pol_lin)
    
    # 4) Wait Till End
    pol_wte = make_wait_till_end_policy(thr_wte)
    stats_wte = simulate_policy(df_holdout2, pol_wte)
    
    # 5) Dynamic Threshold-DP (SPSA)
    df_holdout2_dp = df_holdout2.copy()
    df_holdout2_dp['risk_bucket'] = df_holdout2_dp['risk_score'].apply(to_bucket)
    stats_dp = simulate_policy(df_holdout2_dp, dp_policy_final)
    
    table = pd.DataFrame({
        'Method': [
            'Constant Threshold',
            'Dynamic Threshold-R',
            'Linear Threshold',
            'Wait Till End',
            'Dynamic Threshold-DP (SPSA)'
        ],
        'Cost': [
            stats_const['cost'],
            stats_dyn['cost'],
            stats_lin['cost'],
            stats_wte['cost'],
            stats_dp['cost']
        ],
        'Precision (%)': [
            100*stats_const['precision'],
            100*stats_dyn['precision'],
            100*stats_lin['precision'],
            100*stats_wte['precision'],
            100*stats_dp['precision']
        ],
        'Recall (%)': [
            100*stats_const['recall'],
            100*stats_dyn['recall'],
            100*stats_lin['recall'],
            100*stats_wte['recall'],
            100*stats_dp['recall']
        ],
        'Avg Treat Time': [
            stats_const['avg_treatment_time'],
            stats_dyn['avg_treatment_time'],
            stats_lin['avg_treatment_time'],
            stats_wte['avg_treatment_time'],
            stats_dp['avg_treatment_time']
        ]
    })
    
    return table

###############################################################################
# 6. MAIN - RUN MULTIPLE REPLICATES AND AGGREGATE
###############################################################################
def main():
    # Number of replicates to run
    NUM_REPLICATES = 30
    
    # Load synthetic data
    df_all = pd.read_csv("synthetic_patients_with_features.csv")
    
    # Optionally filter to time < T_MAX
    df_all = df_all[df_all['time'] < T_MAX].copy()
    
    # Check required columns
    required = {'patient_id','time','EIT','NIRS','EIS','label'}
    if not required.issubset(df_all.columns):
        raise ValueError(
            f"Your CSV must have columns at least: {required}. Found: {df_all.columns}"
        )
    
    # Run ALGORITHM 5 (SPSA) multiple times with different seeds
    all_tables = []
    for rep in range(NUM_REPLICATES):
        seed_for_this_run = 412 + rep
        print(f"\n=== Running replicate {rep+1}/{NUM_REPLICATES} (seed={seed_for_this_run}) ===")
        
        result_table = run_algorithm5_spsa_unconstrained(
            df_all, 
            k=3, 
            seed=seed_for_this_run, 
            n_spsa_iter=20
        )
        all_tables.append(result_table)
    
    # Combine all replicate results into one DataFrame
    combined_df = pd.concat(all_tables, ignore_index=True)
    
    # Group by 'Method' and compute mean ± std for each numeric column
    grouped = combined_df.groupby('Method')
    
    final_rows = []
    for method, group_data in grouped:
        cost_mean = group_data['Cost'].mean()
        cost_std  = group_data['Cost'].std()
        
        prec_mean = group_data['Precision (%)'].mean()
        prec_std  = group_data['Precision (%)'].std()
        
        rec_mean  = group_data['Recall (%)'].mean()
        rec_std   = group_data['Recall (%)'].std()
        
        time_mean = group_data['Avg Treat Time'].mean()
        time_std  = group_data['Avg Treat Time'].std()
        
        final_rows.append({
            'Method': method,
            'Cost': f"{cost_mean:.2f} ± {cost_std:.2f}",
            'Precision (%)': f"{prec_mean:.2f} ± {prec_std:.2f}",
            'Recall (%)': f"{rec_mean:.2f} ± {rec_std:.2f}",
            'Avg Treat Time': f"{time_mean:.2f} ± {time_std:.2f}"
        })
    
    final_df = pd.DataFrame(final_rows)
    
    print("\n=== ALGORITHM 5 (SPSA) - MULTI-REPLICATE RESULTS (Unconstrained) ===")
    print(f"Ran {NUM_REPLICATES} replicates. Aggregated (mean ± std) results:")
    print(final_df.to_string(index=False))

if __name__ == "__main__":
    main()


=== Running replicate 1/30 (seed=412) ===

=== Running replicate 2/30 (seed=413) ===

=== Running replicate 3/30 (seed=414) ===

=== Running replicate 4/30 (seed=415) ===

=== Running replicate 5/30 (seed=416) ===

=== Running replicate 6/30 (seed=417) ===

=== Running replicate 7/30 (seed=418) ===

=== Running replicate 8/30 (seed=419) ===

=== Running replicate 9/30 (seed=420) ===

=== Running replicate 10/30 (seed=421) ===

=== Running replicate 11/30 (seed=422) ===

=== Running replicate 12/30 (seed=423) ===

=== Running replicate 13/30 (seed=424) ===

=== Running replicate 14/30 (seed=425) ===

=== Running replicate 15/30 (seed=426) ===

=== Running replicate 16/30 (seed=427) ===

=== Running replicate 17/30 (seed=428) ===

=== Running replicate 18/30 (seed=429) ===

=== Running replicate 19/30 (seed=430) ===

=== Running replicate 20/30 (seed=431) ===

=== Running replicate 21/30 (seed=432) ===

=== Running replicate 22/30 (seed=433) ===

=== Running replicate 23/30 (seed=434) =