In [3]:
"""
ALGORITHM 4 (SPSA) + BENCHMARK TABLE
WITH 50% CAP ON SICK PATIENTS .

Requirements:
  pip install numpy pandas scikit-learn catboost
"""

import numpy as np
import pandas as pd
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

# Sklearn models, metrics, etc.
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import ParameterGrid
# CatBoost
from catboost import CatBoostClassifier

###############################################################################
# 1. GLOBAL PARAMETERS
###############################################################################
FP_COST = 10
FN_COST = 50
D_COST  = 1
T_MAX   = 21   # maximum discrete time steps (0..T_MAX-1)

###############################################################################
# 2. HELPER FUNCTIONS
###############################################################################
def split_into_four_groups(df, seed=0):
    rng = np.random.RandomState(seed)
    unique_pids = df['patient_id'].unique()
    rng.shuffle(unique_pids)
    
    n = len(unique_pids)
    i1 = int(0.25 * n)
    i2 = int(0.50 * n)
    i3 = int(0.75 * n)
    
    G1_pids = unique_pids[: i1]
    G2_pids = unique_pids[i1 : i2]
    G3_pids = unique_pids[i2 : i3]
    G4_pids = unique_pids[i3 : ]
    
    G1 = df[df['patient_id'].isin(G1_pids)].copy()
    G2 = df[df['patient_id'].isin(G2_pids)].copy()
    G3 = df[df['patient_id'].isin(G3_pids)].copy()
    G4 = df[df['patient_id'].isin(G4_pids)].copy()
    return G1, G2, G3, G4

def k_plus_1_splits(df, k=3, seed=0):
    """
    For Algorithm 5 (SPSA), we often split the data into k+1 groups:
      G1, G2, ..., Gk, G_{k+1}.
    This is a k-fold style partition plus one final holdout.
    """
    rng = np.random.RandomState(seed)
    unique_pids = df['patient_id'].unique()
    rng.shuffle(unique_pids)
    
    n = len(unique_pids)
    fold_size = int(n/(k+1))
    
    groups = []
    start = 0
    for i in range(k):
        pids_fold = unique_pids[start : start+fold_size]
        start += fold_size
        g = df[df['patient_id'].isin(pids_fold)].copy()
        groups.append(g)
    
    # final group is everything leftover
    pids_fold = unique_pids[start:]
    g = df[df['patient_id'].isin(pids_fold)].copy()
    groups.append(g)
    return groups

def compute_auc_score(y_true, y_prob):
    if len(np.unique(y_true)) < 2:
        return 0.5
    return roc_auc_score(y_true, y_prob)

###############################################################################
# 3. POLICY SIMULATION (WITH & WITHOUT CAPACITY)
###############################################################################
def simulate_policy(df, policy_func):
    """
    Unconstrained simulation (no limit on how many sick get treated).
    df must have columns: patient_id, time, risk_score, label.
    policy_func(patient_rows) -> treat_time (int) or None
    """
    results = []
    
    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values('time')
        label = grp['label'].iloc[0]
        
        treat_time = policy_func(grp)
        
        if treat_time is None:
            # never treated
            if label == 1:
                cost = FN_COST
                tp   = 0
            else:
                cost = 0
                tp   = 0
            fp = 0
            treat_flag = 0
            ttime = None
        else:
            treat_flag = 1
            if label == 1:
                cost = D_COST * treat_time
                tp   = 1
                fp   = 0
            else:
                cost = FP_COST
                tp   = 0
                fp   = 1
            ttime = treat_time
        
        results.append({
            'patient_id': pid,
            'label': label,
            'treated': treat_flag,
            'treat_time': ttime,
            'cost': cost,
            'tp': tp,
            'fp': fp
        })
    
    df_res     = pd.DataFrame(results)
    total_cost = df_res['cost'].sum()
    
    treated_df = df_res[df_res['treated']==1]
    tp_sum = treated_df['tp'].sum()
    fp_sum = treated_df['fp'].sum()
    
    if len(treated_df) > 0:
        precision = tp_sum / (tp_sum + fp_sum)
    else:
        precision = 0.0
    
    sick_df = df_res[df_res['label']==1]
    total_sick = len(sick_df)
    if total_sick > 0:
        recall = tp_sum / total_sick
    else:
        recall = 0.0
    
    if len(treated_df) > 0:
        valid_tt = treated_df['treat_time'].dropna()
        avg_tt   = valid_tt.mean() if len(valid_tt) > 0 else 0.0
    else:
        avg_tt = 0.0
    
    return {
        'cost': total_cost,
        'precision': precision,
        'recall': recall,
        'avg_treatment_time': avg_tt
    }


def simulate_policy_with_sick_capacity(df, policy_func, capacity_frac=0.5):
    """
    Enforce that at most (capacity_frac) fraction of the *sick* patients
    can be treated (as in Algorithm 0 for G4).
    
    Steps:
      1. Identify which patients are "recommended" for treatment by the policy_func.
      2. Separate them into recommended_sick vs. recommended_healthy.
      3. Among recommended_sick, only treat the top floor(capacity_frac * total_sick) 
         by risk_score. The rest remain untreated.
      4. All recommended_healthy are treated (no limit).
      5. Everyone else is not treated, incurring FN cost if sick, 0 if healthy.
    """
    results = []
    recommended_sick = []     # (pid, label=1, time_treated, risk_score)
    recommended_healthy = []  # (pid, label=0, time_treated, risk_score)
    
    # 1) Count total sick
    all_sick_df = df[df['label']==1]
    num_sick = all_sick_df['patient_id'].nunique()
    capacity_num = int(np.floor(capacity_frac * num_sick)) if num_sick>0 else 0
    
    # Gather recommended patients
    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values('time')
        label = grp['label'].iloc[0]
        
        treat_time = policy_func(grp)
        
        if treat_time is None:
            # not recommended -> fill cost later
            results.append({
                'patient_id': pid,
                'label': label,
                'treated': 0,
                'treat_time': None,
                'cost': None,
                'tp': 0,
                'fp': 0
            })
        else:
            # recommended
            row_t = grp[grp['time']==treat_time].iloc[0]
            recommended_risk = row_t['risk_score']
            if label == 1:
                recommended_sick.append((pid, label, treat_time, recommended_risk))
            else:
                recommended_healthy.append((pid, label, treat_time, recommended_risk))
    
    # 2) Sort recommended sick by descending risk_score
    recommended_sick.sort(key=lambda x: x[3], reverse=True)
    
    # 3) Actually treat only top capacity_num from recommended_sick
    treat_sick_subset = recommended_sick[:capacity_num]
    not_treat_sick_subset = recommended_sick[capacity_num:]
    
    # 4) Treat all recommended healthy
    treat_healthy_subset = recommended_healthy
    
    # Build final records for the "treated" subsets
    treat_results = []
    
    # SICK actually treated
    for (pid, label, ttime, rsk) in treat_sick_subset:
        cost_ = D_COST * ttime   # sick => cost is D * ttime
        treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 1,
            'treat_time': ttime,
            'cost': cost_,
            'tp': 1,
            'fp': 0
        })
    
    # HEALTHY actually treated
    for (pid, label, ttime, rsk) in treat_healthy_subset:
        cost_ = FP_COST  # healthy => false positive cost
        treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 1,
            'treat_time': ttime,
            'cost': cost_,
            'tp': 0,
            'fp': 1
        })
    
    # Build final records for "not-treated" subsets
    not_treat_results = []
    
    # recommended sick but not in top capacity
    for (pid, label, ttime, rsk) in not_treat_sick_subset:
        cost_ = FN_COST  # sick but not treated
        not_treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 0,
            'treat_time': None,
            'cost': cost_,
            'tp': 0,
            'fp': 0
        })
    
    # never recommended (cost=None above)
    for row in results:
        if row['cost'] is None:
            if row['label'] == 1:
                row['cost'] = FN_COST
            else:
                row['cost'] = 0
            not_treat_results.append(row)
    
    # Combine
    df_res = pd.DataFrame(treat_results + not_treat_results)
    
    # Compute final stats
    total_cost = df_res['cost'].sum()
    
    treated_df = df_res[df_res['treated']==1]
    tp_sum = treated_df['tp'].sum()
    fp_sum = treated_df['fp'].sum()
    
    if len(treated_df) > 0:
        precision = tp_sum / (tp_sum + fp_sum)
    else:
        precision = 0.0
    
    sick_df = df_res[df_res['label']==1]
    total_sick = len(sick_df)
    if total_sick > 0:
        recall = tp_sum / total_sick
    else:
        recall = 0.0
    
    if len(treated_df) > 0:
        valid_tt = treated_df['treat_time'].dropna()
        avg_tt   = valid_tt.mean() if len(valid_tt) > 0 else 0.0
    else:
        avg_tt = 0.0
    
    return {
        'cost': total_cost,
        'precision': precision,
        'recall': recall,
        'avg_treatment_time': avg_tt
    }

###############################################################################
# 4. BENCHMARK POLICIES (Threshold-based, DP, etc.)
###############################################################################
def constant_threshold_search(df, thresholds=None):
    if thresholds is None:
        thresholds = np.linspace(0,1,21)
    best_thr, best_cost, best_stats = None, float('inf'), None
    for thr in thresholds:
        def policy_func(patient_rows):
            for _, row in patient_rows.iterrows():
                if row['risk_score'] >= thr:
                    return int(row['time'])
            return None
        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_thr  = thr
            best_stats= stats
    return best_thr, best_stats

def make_constant_threshold_policy(thr):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            if row['risk_score'] >= thr:
                return int(row['time'])
        return None
    return policy_func

def dynamic_threshold_random_search(df,
                                    time_steps=20,
                                    threshold_candidates=[0.0,0.2,0.4,0.6,0.8,1.0],
                                    n_samples=200,
                                    seed=0):
    rng = np.random.RandomState(seed)
    best_vec = None
    best_cost= float('inf')
    best_stats=None
    
    for _ in range(n_samples):
        thr_vec = rng.choice(threshold_candidates, size=time_steps)
        
        def policy_func(patient_rows):
            for _, row in patient_rows.iterrows():
                t = int(row['time'])
                if t < time_steps and row['risk_score'] >= thr_vec[t]:
                    return t
            return None
        
        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_vec  = thr_vec.copy()
            best_stats= stats
    return best_vec, best_stats

def make_dynamic_threshold_policy(thr_vec):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            if t < len(thr_vec):
                if row['risk_score'] >= thr_vec[t]:
                    return t
            return None
        return None
    return policy_func

def linear_threshold_search(df,
                            A_candidates=np.linspace(-0.05, 0.01, 7),
                            B_candidates=np.linspace(0,0.8,7)):
    best_A, best_B = None, None
    best_cost, best_stats = float('inf'), None
    
    for A in A_candidates:
        for B in B_candidates:
            def policy_func(patient_rows):
                for _, row in patient_rows.iterrows():
                    t = row['time']
                    thr = A*t + B
                    thr = np.clip(thr,0,1)
                    if row['risk_score'] >= thr:
                        return int(t)
                return None
            stats = simulate_policy(df, policy_func)
            if stats['cost'] < best_cost:
                best_cost = stats['cost']
                best_A    = A
                best_B    = B
                best_stats= stats
    return (best_A,best_B), best_stats

def make_linear_threshold_policy(A,B):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = row['time']
            thr = A*t + B
            thr = np.clip(thr,0,1)
            if row['risk_score'] >= thr:
                return int(t)
        return None
    return policy_func

def wait_till_end_search(df, thresholds=None):
    if thresholds is None:
        thresholds = np.linspace(0,1,21)
    best_thr, best_cost, best_stats = None, float('inf'), None
    
    for thr in thresholds:
        def policy_func(patient_rows):
            final_t = patient_rows['time'].max()
            final_row = patient_rows[patient_rows['time']==final_t].iloc[0]
            if final_row['risk_score'] >= thr:
                return int(final_t)
            return None
        
        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_thr  = thr
            best_stats= stats
    return best_thr, best_stats

def make_wait_till_end_policy(thr):
    def policy_func(patient_rows):
        final_t = patient_rows['time'].max()
        final_row = patient_rows[patient_rows['time']==final_t].iloc[0]
        if final_row['risk_score'] >= thr:
            return int(final_t)
        return None
    return policy_func

###############################################################################
# 5. DATA-DRIVEN DP (Unconstrained)
###############################################################################
def to_bucket(prob):
    """Map a probability into a bucket [0..4]."""
    b = int(prob * 5)
    return min(b, 4)

def estimate_transition_and_sick_probs(df_train, T=20, n_buckets=5):
    transition_counts = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    bucket_counts     = np.zeros((T, n_buckets), dtype=float)
    sick_counts       = np.zeros((T, n_buckets), dtype=float)
    
    df_sorted = df_train.sort_values(['patient_id','time'])
    for pid, grp in df_sorted.groupby('patient_id'):
        grp = grp.sort_values('time')
        rows= grp.to_dict('records')
        
        for i, row in enumerate(rows):
            t = int(row['time'])
            b = int(row['risk_bucket'])
            lbl = row['label']
            
            if t < T:
                bucket_counts[t,b] += 1
                sick_counts[t,b]   += lbl
            
            if i < len(rows)-1:
                nxt = rows[i+1]
                t_next = nxt['time']
                b_next = nxt['risk_bucket']
                if (t_next == t+1) and (t < T-1):
                    transition_counts[t,b,b_next] += 1
    
    p_trans = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    for t_ in range(T-1):
        for b_ in range(n_buckets):
            denom = transition_counts[t_,b_,:].sum()
            if denom > 0:
                p_trans[t_,b_,:] = transition_counts[t_,b_,:] / denom
            else:
                p_trans[t_,b_,b_] = 1.0
    
    p_sick = np.zeros((T, n_buckets), dtype=float)
    for t_ in range(T):
        for b_ in range(n_buckets):
            denom = bucket_counts[t_,b_]
            if denom>0:
                p_sick[t_,b_] = sick_counts[t_,b_] / denom
            else:
                p_sick[t_,b_] = 0.0
    return p_trans, p_sick

def train_data_driven_dp_unconstrained(p_trans, p_sick,
                                       FP=10, FN=50, D=1,
                                       gamma=0.99, T=20):
    """
    Standard DP for unconstrained scenario:
      V[t,b] = min( cost_treat_now, cost_wait )
    """
    n_buckets = p_sick.shape[1]
    V = np.zeros((T+1, n_buckets))
    pi_ = np.zeros((T, n_buckets), dtype=int)
    
    # boundary at t=T
    for b in range(n_buckets):
        cost_treat   = p_sick[T-1,b]*(D*(T-1)) + (1-p_sick[T-1,b])*FP
        cost_notreat = p_sick[T-1,b]*FN
        V[T,b] = min(cost_treat, cost_notreat)
    
    # fill from T-1 down to 0
    for t in reversed(range(T)):
        for b in range(n_buckets):
            # treat now
            cost_treat = p_sick[t,b]*(D*t) + (1-p_sick[t,b])*FP
            
            # wait
            if t == T-1:
                cost_wait = gamma * V[T,b]
            else:
                exp_future = 0.0
                for b_next in range(n_buckets):
                    exp_future += p_trans[t,b,b_next]*V[t+1,b_next]
                cost_wait = gamma * exp_future
            
            if cost_treat <= cost_wait:
                V[t,b]   = cost_treat
                pi_[t,b] = 1
            else:
                V[t,b]   = cost_wait
                pi_[t,b] = 0
    return V, pi_

def make_dp_policy(V, pi_, T=20):
    """Return a policy function that treats if pi[t,b]==1 at time t."""
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            if t < T:
                b = int(row['risk_bucket'])
                if pi_[t,b] == 1:
                    return t
        return None
    return policy_func

###############################################################################
# 6. SPSA ALGORITHM + COST FUNCTION
###############################################################################
def train_and_select_best_model(X_train, y_train, X_val, y_val):
    """
    (Reference from Algorithm 0)
    Trains multiple models (RF, GB, CatBoost) over small hyperparam grids,
    picks best by AUC.
    """
    best_auc = -1.0
    best_model = None
    best_name  = None
    
    RF_PARAM_GRID = {
        'n_estimators': [50, 100],
        'max_depth': [3, 5]
    }
    GB_PARAM_GRID = {
        'n_estimators': [50, 100],
        'learning_rate': [0.05, 0.1],
        'max_depth': [3, 5]
    }
    CATBOOST_PARAM_GRID = {
        'iterations': [50, 100],
        'learning_rate': [0.05, 0.1],
        'depth': [3, 5]
    }
    
    for params in ParameterGrid(RF_PARAM_GRID):
        rf = RandomForestClassifier(random_state=0, **params)
        rf.fit(X_train, y_train)
        val_prob = rf.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, val_prob)
        if auc_val > best_auc:
            best_auc   = auc_val
            best_model = rf
            best_name  = f"RandomForest_{params}"
    
    for params in ParameterGrid(GB_PARAM_GRID):
        gb = GradientBoostingClassifier(random_state=0, **params)
        gb.fit(X_train, y_train)
        val_prob = gb.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, val_prob)
        if auc_val > best_auc:
            best_auc   = auc_val
            best_model = gb
            best_name  = f"GradientBoosting_{params}"
    
    for params in ParameterGrid(CATBOOST_PARAM_GRID):
        cb = CatBoostClassifier(verbose=0, random_state=0, **params)
        cb.fit(X_train, y_train)
        val_prob = cb.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, val_prob)
        if auc_val > best_auc:
            best_auc   = auc_val
            best_model = cb
            best_name  = f"CatBoost_{params}"
    
    return best_model, best_auc, best_name

def spsa_optimize(cost_func, param_init, n_iter=20,
                  alpha=0.602, gamma=0.101, a=0.1, c=0.1, seed=0):
    """
    Generic SPSA optimizer to minimize a cost_func over a real vector space.
    """
    rng = np.random.RandomState(seed)
    p = param_init.copy()
    best_p = p.copy()
    best_cost = cost_func(p)
    
    for k in range(1, n_iter+1):
        ak = a / (k**alpha)
        ck = c / (k**gamma)
        delta = rng.choice([-1,1], size=len(p))
        
        p_plus  = p + ck*delta
        p_minus = p - ck*delta
        
        cost_plus  = cost_func(p_plus)
        cost_minus = cost_func(p_minus)
        
        g_approx = (cost_plus - cost_minus)/(2.0*ck) * delta
        
        p = p - ak*g_approx
        
        current_cost = cost_func(p)
        if current_cost < best_cost:
            best_cost = current_cost
            best_p = p.copy()
    
    return best_p, best_cost

def parse_spsa_params(params):
    """
    Map a real vector `params` into discrete hyperparams + DP gamma:
      params[0]: model_type (0=RF,1=GB,2=CB)
      params[1]: n_estimators [10..200]
      params[2]: learning_rate [0.01..0.2]
      params[3]: max_depth [2..10]
      params[4]: gamma [0.90..0.999]
    """
    p0 = int(round(np.clip(params[0], 0, 2)))       # model_type
    p1 = int(round(np.clip(params[1], 10, 200)))    # n_estimators
    p2 = float(np.clip(params[2], 0.01, 0.2))       # learning_rate
    p3 = int(round(np.clip(params[3], 2, 10)))      # max_depth
    p4 = float(np.clip(params[4], 0.90, 0.999))      # gamma
    return (p0, p1, p2, p3, p4)

def spsa_cost_function(param_vector, df_train, df_val):
    """
    Decision-aware cost function used by SPSA:
      1) Parse param_vector => (model_type, n_estimators, learning_rate, max_depth, gamma)
      2) Train model on df_train
      3) Predict risk on df_val
      4) Build DP with gamma (from param_vector) using transitions from df_train
      5) Evaluate cost on df_val (unconstrained).
    """
    model_type, n_est, lr, m_depth, gamma_ = parse_spsa_params(param_vector)

    X_tr = df_train[['EIT','NIRS','EIS']].values
    y_tr = df_train['label'].values
    
    # Build the model
    if model_type == 0:
        clf = RandomForestClassifier(n_estimators=n_est, max_depth=m_depth, random_state=0)
    elif model_type == 1:
        clf = GradientBoostingClassifier(n_estimators=n_est, learning_rate=lr,
                                         max_depth=m_depth, random_state=0)
    else:
        clf = CatBoostClassifier(iterations=n_est, learning_rate=lr, depth=m_depth,
                                 verbose=0, random_state=0)
    clf.fit(X_tr, y_tr)
    
    # Evaluate on df_val with DP policy
    X_val = df_val[['EIT','NIRS','EIS']].values
    prob_val = clf.predict_proba(X_val)[:,1]
    df_val_ = df_val.copy()
    df_val_['risk_score'] = prob_val
    
    # DP transitions from df_train
    df_tr_ = df_train.copy()
    train_probs = clf.predict_proba(df_tr_[['EIT','NIRS','EIS']])[:,1]
    df_tr_['risk_score'] = train_probs
    df_tr_['risk_bucket'] = df_tr_['risk_score'].apply(to_bucket)
    
    p_trans, p_sick = estimate_transition_and_sick_probs(df_tr_, T=T_MAX, n_buckets=5)
    V, pi_ = train_data_driven_dp_unconstrained(
        p_trans, p_sick,
        FP=FP_COST, FN=FN_COST, D=D_COST,
        gamma=gamma_, T=T_MAX
    )
    
    # Evaluate that DP policy on df_val
    df_val_['risk_bucket'] = df_val_['risk_score'].apply(to_bucket)
    dp_policy = make_dp_policy(V, pi_, T=T_MAX)
    stats = simulate_policy(df_val_, dp_policy)  # unconstrained cost
    return stats['cost']


###############################################################################
# 7. ALGORITHM 5 (SPSA) WITH 50% CAP ON FINAL HOLDOUT
###############################################################################
def run_algorithm5_spsa_with_capacity(df_all, k=3, seed=0, n_spsa_iter=20, capacity_frac=0.5):
    """
    SPSA Hyper-Parameter Tuning (Decision-Aware) + 50% cap on final holdout.
    
    1) Split data into k+1 groups: G1..Gk, G_{k+1}
    2) For each fold i in [1..k]:
         - define cost function that trains on G\\G_i, evaluates cost on G_i
         - run SPSA to find best param_i
    3) Among param_1..param_k, pick best overall param* with minimal sum-of-costs across folds
    4) Evaluate param* on G_{k+1} (the holdout),
       but we apply the 50%-cap to *sick* patients in the final holdout only.
    5) Build final table with 5 methods:
       - Constant Threshold
       - Dynamic Threshold-R
       - Linear Threshold
       - Wait Till End
       - DP-based (SPSA)
      and each one is simulated with the 50%-cap on sick.
    """
    # Split into k+1 groups
    groups = k_plus_1_splits(df_all, k=k, seed=seed)
    
    # param_init is the starting guess for SPSA hyperparams
    param_init = np.array([1.0, 50.0, 0.05, 3.0, 0.95], dtype=float)
    
    # 2) For each fold i in [1..k], run SPSA to get best param_i
    spsa_solutions = []
    for i in range(1, k+1):
        df_val = groups[i-1]
        df_train_list = [groups[j] for j in range(k) if j != (i-1)]
        df_train_ = pd.concat(df_train_list, ignore_index=True)
        
        def fold_cost_func(p):
            return spsa_cost_function(p, df_train_, df_val)
        
        best_p_fold, best_c_fold = spsa_optimize(
            fold_cost_func,
            param_init,
            n_iter=n_spsa_iter,
            alpha=0.602,
            gamma=0.101,
            a=0.1,
            c=0.1,
            seed=seed+i
        )
        spsa_solutions.append( (best_p_fold, best_c_fold) )
    
    # 3) Among these k solutions, pick best overall param*
    #    with minimal sum-of-costs across all k folds
    k_ = k
    fold_cost_matrix = np.zeros((k_, k_), dtype=float)
    for i in range(k_):
        param_i = spsa_solutions[i][0]
        for j in range(k_):
            df_val_j = groups[j]
            df_train_j_list = [groups[m] for m in range(k_) if m != j]
            df_train_j = pd.concat(df_train_j_list, ignore_index=True)
            c_ij = spsa_cost_function(param_i, df_train_j, df_val_j)
            fold_cost_matrix[i,j] = c_ij
    
    total_cost_per_param = fold_cost_matrix.sum(axis=1)
    best_index = np.argmin(total_cost_per_param)
    best_param = spsa_solutions[best_index][0]
    
    # 4) Evaluate best_param on final holdout G_{k+1}, with 50%-cap
    df_holdout = groups[k]
    df_train_for_holdout = pd.concat(groups[:k], ignore_index=True)
    
    # Parse best_param -> final ML model + DP discount factor
    model_type, n_est, lr, m_depth, gamma_ = parse_spsa_params(best_param)
    
    # Train final classifier on union of k folds
    X_train2 = df_train_for_holdout[['EIT','NIRS','EIS']].values
    y_train2 = df_train_for_holdout['label'].values
    
    if model_type == 0:
        clf_final = RandomForestClassifier(n_estimators=n_est, max_depth=m_depth, random_state=0)
    elif model_type == 1:
        clf_final = GradientBoostingClassifier(n_estimators=n_est, learning_rate=lr,
                                               max_depth=m_depth, random_state=0)
    else:
        clf_final = CatBoostClassifier(iterations=n_est, learning_rate=lr, depth=m_depth,
                                       verbose=0, random_state=0)
    clf_final.fit(X_train2, y_train2)
    
    # Build DP transitions from df_train_for_holdout
    df_train2 = df_train_for_holdout.copy()
    train_probs2 = clf_final.predict_proba(df_train2[['EIT','NIRS','EIS']])[:,1]
    df_train2['risk_score'] = train_probs2
    df_train2['risk_bucket'] = df_train2['risk_score'].apply(to_bucket)
    p_trans2, p_sick2 = estimate_transition_and_sick_probs(df_train2, T=T_MAX, n_buckets=5)
    
    V_final, pi_final = train_data_driven_dp_unconstrained(
        p_trans2, p_sick2,
        FP=FP_COST, FN=FN_COST, D=D_COST,
        gamma=gamma_, T=T_MAX
    )
    dp_policy_final = make_dp_policy(V_final, pi_final, T=T_MAX)
    
    # 5) Tune threshold-based policies on the same train set (df_train2)
    thr_const, _ = constant_threshold_search(df_train2)
    thr_vec, _ = dynamic_threshold_random_search(df_train2, time_steps=T_MAX)
    (A_lin, B_lin), _ = linear_threshold_search(df_train2)
    thr_wte, _ = wait_till_end_search(df_train2)
    
    # Evaluate final policies on holdout with capacity
    df_holdout2 = df_holdout.copy()
    holdout_probs = clf_final.predict_proba(df_holdout2[['EIT','NIRS','EIS']])[:,1]
    df_holdout2['risk_score'] = holdout_probs
    
    # 1) Constant threshold
    pol_const = make_constant_threshold_policy(thr_const)
    stats_const = simulate_policy_with_sick_capacity(df_holdout2, pol_const, capacity_frac=capacity_frac)
    
    # 2) Dynamic threshold - random
    pol_dyn = make_dynamic_threshold_policy(thr_vec)
    stats_dyn = simulate_policy_with_sick_capacity(df_holdout2, pol_dyn, capacity_frac=capacity_frac)
    
    # 3) Linear threshold
    pol_lin = make_linear_threshold_policy(A_lin, B_lin)
    stats_lin = simulate_policy_with_sick_capacity(df_holdout2, pol_lin, capacity_frac=capacity_frac)
    
    # 4) Wait till end
    pol_wte = make_wait_till_end_policy(thr_wte)
    stats_wte = simulate_policy_with_sick_capacity(df_holdout2, pol_wte, capacity_frac=capacity_frac)
    
    # 5) DP-based policy from SPSA
    df_holdout2_dp = df_holdout2.copy()
    df_holdout2_dp['risk_bucket'] = df_holdout2_dp['risk_score'].apply(to_bucket)
    stats_dp = simulate_policy_with_sick_capacity(df_holdout2_dp, dp_policy_final,
                                                  capacity_frac=capacity_frac)
    
    # Build final table
    table = pd.DataFrame({
        'Method': [
            'Constant Threshold',
            'Dynamic Threshold-R',
            'Linear Threshold',
            'Wait Till End',
            'DP-based (SPSA)'
        ],
        'Cost': [
            stats_const['cost'],
            stats_dyn['cost'],
            stats_lin['cost'],
            stats_wte['cost'],
            stats_dp['cost']
        ],
        'Precision (%)': [
            100*stats_const['precision'],
            100*stats_dyn['precision'],
            100*stats_lin['precision'],
            100*stats_wte['precision'],
            100*stats_dp['precision']
        ],
        'Recall (%)': [
            100*stats_const['recall'],
            100*stats_dyn['recall'],
            100*stats_lin['recall'],
            100*stats_wte['recall'],
            100*stats_dp['recall']
        ],
        'Avg Treat Time': [
            stats_const['avg_treatment_time'],
            stats_dyn['avg_treatment_time'],
            stats_lin['avg_treatment_time'],
            stats_wte['avg_treatment_time'],
            stats_dp['avg_treatment_time']
        ]
    })
    
    return table

###############################################################################
# 8. MAIN
###############################################################################
def main():
    # Load synthetic data
    df_all = pd.read_csv("synthetic_patients_with_features.csv")
    
    # Filter to time < T_MAX if needed
    df_all = df_all[df_all['time'] < T_MAX].copy()
    
    # Check required columns
    required = {'patient_id','time','EIT','NIRS','EIS','label'}
    if not required.issubset(df_all.columns):
        raise ValueError(f"CSV must have columns at least: {required}. Found: {df_all.columns}")
    
    # Run ALGORITHM 5 (SPSA) with 50%-cap in final holdout
    table_alg5 = run_algorithm5_spsa_with_capacity(
        df_all,
        k=3,
        seed=42,
        n_spsa_iter=20,
        capacity_frac=0.5
    )
    print("\n=== ALGORITHM 5 (SPSA) + 50% CAP ON SICK, FINAL HOLDOUT ===")
    print(table_alg5.to_string(index=False))


if __name__ == "__main__":
    main()


=== ALGORITHM 5 (SPSA) + 50% CAP ON SICK, FINAL HOLDOUT ===
             Method  Cost  Precision (%)  Recall (%)  Avg Treat Time
 Constant Threshold  1063      34.146341   48.275862        1.048780
Dynamic Threshold-R  1450       0.000000    0.000000        0.000000
   Linear Threshold  1109      29.787234   48.275862        0.829787
      Wait Till End  1030     100.000000   48.275862       20.000000
    DP-based (SPSA)   842      77.777778   48.275862        4.833333


In [5]:
"""
ALGORITHM 5 (SPSA) + BENCHMARK TABLE
WITH 50% CAP ON SICK PATIENTS -- RUNNING MULTIPLE REPLICATES.

Requirements:
  pip install numpy pandas scikit-learn catboost
"""

import numpy as np
import pandas as pd
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

# Sklearn models, metrics, etc.
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import ParameterGrid
# CatBoost
from catboost import CatBoostClassifier

###############################################################################
# 1. GLOBAL PARAMETERS
###############################################################################
FP_COST = 10
FN_COST = 50
D_COST  = 1
T_MAX   = 21   # maximum discrete time steps (0..T_MAX-1)

###############################################################################
# 2. HELPER FUNCTIONS
###############################################################################
def split_into_four_groups(df, seed=0):
    rng = np.random.RandomState(seed)
    unique_pids = df['patient_id'].unique()
    rng.shuffle(unique_pids)
    
    n = len(unique_pids)
    i1 = int(0.25 * n)
    i2 = int(0.50 * n)
    i3 = int(0.75 * n)
    
    G1_pids = unique_pids[: i1]
    G2_pids = unique_pids[i1 : i2]
    G3_pids = unique_pids[i2 : i3]
    G4_pids = unique_pids[i3 : ]
    
    G1 = df[df['patient_id'].isin(G1_pids)].copy()
    G2 = df[df['patient_id'].isin(G2_pids)].copy()
    G3 = df[df['patient_id'].isin(G3_pids)].copy()
    G4 = df[df['patient_id'].isin(G4_pids)].copy()
    return G1, G2, G3, G4

def k_plus_1_splits(df, k=3, seed=0):
    """
    For Algorithm 5 (SPSA), we often split the data into k+1 groups:
      G1, G2, ..., Gk, G_{k+1}.
    This is a k-fold style partition plus one final holdout.
    """
    rng = np.random.RandomState(seed)
    unique_pids = df['patient_id'].unique()
    rng.shuffle(unique_pids)
    
    n = len(unique_pids)
    fold_size = int(n/(k+1))
    
    groups = []
    start = 0
    for i in range(k):
        pids_fold = unique_pids[start : start+fold_size]
        start += fold_size
        g = df[df['patient_id'].isin(pids_fold)].copy()
        groups.append(g)
    
    # final group is everything leftover
    pids_fold = unique_pids[start:]
    g = df[df['patient_id'].isin(pids_fold)].copy()
    groups.append(g)
    return groups

def compute_auc_score(y_true, y_prob):
    if len(np.unique(y_true)) < 2:
        return 0.5
    return roc_auc_score(y_true, y_prob)

###############################################################################
# 3. POLICY SIMULATION (WITH & WITHOUT CAPACITY)
###############################################################################
def simulate_policy(df, policy_func):
    """
    Unconstrained simulation (no limit on how many sick get treated).
    df must have columns: patient_id, time, risk_score, label.
    policy_func(patient_rows) -> treat_time (int) or None
    """
    results = []
    
    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values('time')
        label = grp['label'].iloc[0]
        
        treat_time = policy_func(grp)
        
        if treat_time is None:
            # never treated
            if label == 1:
                cost = FN_COST
                tp   = 0
            else:
                cost = 0
                tp   = 0
            fp = 0
            treat_flag = 0
            ttime = None
        else:
            treat_flag = 1
            if label == 1:
                cost = D_COST * treat_time
                tp   = 1
                fp   = 0
            else:
                cost = FP_COST
                tp   = 0
                fp   = 1
            ttime = treat_time
        
        results.append({
            'patient_id': pid,
            'label': label,
            'treated': treat_flag,
            'treat_time': ttime,
            'cost': cost,
            'tp': tp,
            'fp': fp
        })
    
    df_res     = pd.DataFrame(results)
    total_cost = df_res['cost'].sum()
    
    treated_df = df_res[df_res['treated']==1]
    tp_sum = treated_df['tp'].sum()
    fp_sum = treated_df['fp'].sum()
    
    if len(treated_df) > 0:
        precision = tp_sum / (tp_sum + fp_sum)
    else:
        precision = 0.0
    
    sick_df = df_res[df_res['label']==1]
    total_sick = len(sick_df)
    if total_sick > 0:
        recall = tp_sum / total_sick
    else:
        recall = 0.0
    
    if len(treated_df) > 0:
        valid_tt = treated_df['treat_time'].dropna()
        avg_tt   = valid_tt.mean() if len(valid_tt) > 0 else 0.0
    else:
        avg_tt = 0.0
    
    return {
        'cost': total_cost,
        'precision': precision,
        'recall': recall,
        'avg_treatment_time': avg_tt
    }


def simulate_policy_with_sick_capacity(df, policy_func, capacity_frac=0.5):
    """
    Enforce that at most (capacity_frac) fraction of the *sick* patients
    can be treated (as in Algorithm 0 for G4).
    
    Steps:
      1. Identify which patients are "recommended" for treatment by the policy_func.
      2. Separate them into recommended_sick vs. recommended_healthy.
      3. Among recommended_sick, only treat the top floor(capacity_frac * total_sick) 
         by risk_score. The rest remain untreated.
      4. All recommended_healthy are treated (no limit).
      5. Everyone else is not treated, incurring FN cost if sick, 0 if healthy.
    """
    results = []
    recommended_sick = []     # (pid, label=1, time_treated, risk_score)
    recommended_healthy = []  # (pid, label=0, time_treated, risk_score)
    
    # 1) Count total sick
    all_sick_df = df[df['label']==1]
    num_sick = all_sick_df['patient_id'].nunique()
    capacity_num = int(np.floor(capacity_frac * num_sick)) if num_sick>0 else 0
    
    # Gather recommended patients
    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values('time')
        label = grp['label'].iloc[0]
        
        treat_time = policy_func(grp)
        
        if treat_time is None:
            # not recommended -> fill cost later
            results.append({
                'patient_id': pid,
                'label': label,
                'treated': 0,
                'treat_time': None,
                'cost': None,
                'tp': 0,
                'fp': 0
            })
        else:
            # recommended
            row_t = grp[grp['time']==treat_time].iloc[0]
            recommended_risk = row_t['risk_score']
            if label == 1:
                recommended_sick.append((pid, label, treat_time, recommended_risk))
            else:
                recommended_healthy.append((pid, label, treat_time, recommended_risk))
    
    # 2) Sort recommended sick by descending risk_score
    recommended_sick.sort(key=lambda x: x[3], reverse=True)
    
    # 3) Actually treat only top capacity_num from recommended_sick
    treat_sick_subset = recommended_sick[:capacity_num]
    not_treat_sick_subset = recommended_sick[capacity_num:]
    
    # 4) Treat all recommended healthy
    treat_healthy_subset = recommended_healthy
    
    # Build final records for the "treated" subsets
    treat_results = []
    
    # SICK actually treated
    for (pid, label, ttime, rsk) in treat_sick_subset:
        cost_ = D_COST * ttime   # sick => cost is D * ttime
        treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 1,
            'treat_time': ttime,
            'cost': cost_,
            'tp': 1,
            'fp': 0
        })
    
    # HEALTHY actually treated
    for (pid, label, ttime, rsk) in treat_healthy_subset:
        cost_ = FP_COST  # healthy => false positive cost
        treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 1,
            'treat_time': ttime,
            'cost': cost_,
            'tp': 0,
            'fp': 1
        })
    
    # Build final records for "not-treated" subsets
    not_treat_results = []
    
    # recommended sick but not in top capacity
    for (pid, label, ttime, rsk) in not_treat_sick_subset:
        cost_ = FN_COST  # sick but not treated
        not_treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 0,
            'treat_time': None,
            'cost': cost_,
            'tp': 0,
            'fp': 0
        })
    
    # never recommended (cost=None above)
    for row in results:
        if row['cost'] is None:
            if row['label'] == 1:
                row['cost'] = FN_COST
            else:
                row['cost'] = 0
            not_treat_results.append(row)
    
    # Combine
    df_res = pd.DataFrame(treat_results + not_treat_results)
    
    # Compute final stats
    total_cost = df_res['cost'].sum()
    
    treated_df = df_res[df_res['treated']==1]
    tp_sum = treated_df['tp'].sum()
    fp_sum = treated_df['fp'].sum()
    
    if len(treated_df) > 0:
        precision = tp_sum / (tp_sum + fp_sum)
    else:
        precision = 0.0
    
    sick_df = df_res[df_res['label']==1]
    total_sick = len(sick_df)
    if total_sick > 0:
        recall = tp_sum / total_sick
    else:
        recall = 0.0
    
    if len(treated_df) > 0:
        valid_tt = treated_df['treat_time'].dropna()
        avg_tt   = valid_tt.mean() if len(valid_tt) > 0 else 0.0
    else:
        avg_tt = 0.0
    
    return {
        'cost': total_cost,
        'precision': precision,
        'recall': recall,
        'avg_treatment_time': avg_tt
    }

###############################################################################
# 4. BENCHMARK POLICIES (Threshold-based, DP, etc.)
###############################################################################
def constant_threshold_search(df, thresholds=None):
    if thresholds is None:
        thresholds = np.linspace(0,1,21)
    best_thr, best_cost, best_stats = None, float('inf'), None
    for thr in thresholds:
        def policy_func(patient_rows):
            for _, row in patient_rows.iterrows():
                if row['risk_score'] >= thr:
                    return int(row['time'])
            return None
        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_thr  = thr
            best_stats= stats
    return best_thr, best_stats

def make_constant_threshold_policy(thr):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            if row['risk_score'] >= thr:
                return int(row['time'])
        return None
    return policy_func

def dynamic_threshold_random_search(df,
                                    time_steps=20,
                                    threshold_candidates=[0.0,0.2,0.4,0.6,0.8,1.0],
                                    n_samples=200,
                                    seed=0):
    rng = np.random.RandomState(seed)
    best_vec = None
    best_cost= float('inf')
    best_stats=None
    
    for _ in range(n_samples):
        thr_vec = rng.choice(threshold_candidates, size=time_steps)
        
        def policy_func(patient_rows):
            for _, row in patient_rows.iterrows():
                t = int(row['time'])
                if t < time_steps and row['risk_score'] >= thr_vec[t]:
                    return t
            return None
        
        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_vec  = thr_vec.copy()
            best_stats= stats
    return best_vec, best_stats

def make_dynamic_threshold_policy(thr_vec):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            if t < len(thr_vec):
                if row['risk_score'] >= thr_vec[t]:
                    return t
            return None
        return None
    return policy_func

def linear_threshold_search(df,
                            A_candidates=np.linspace(-0.05, 0.01, 7),
                            B_candidates=np.linspace(0,0.8,7)):
    best_A, best_B = None, None
    best_cost, best_stats = float('inf'), None
    
    for A in A_candidates:
        for B in B_candidates:
            def policy_func(patient_rows):
                for _, row in patient_rows.iterrows():
                    t = row['time']
                    thr = A*t + B
                    thr = np.clip(thr,0,1)
                    if row['risk_score'] >= thr:
                        return int(t)
                return None
            stats = simulate_policy(df, policy_func)
            if stats['cost'] < best_cost:
                best_cost = stats['cost']
                best_A    = A
                best_B    = B
                best_stats= stats
    return (best_A,best_B), best_stats

def make_linear_threshold_policy(A,B):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = row['time']
            thr = A*t + B
            thr = np.clip(thr,0,1)
            if row['risk_score'] >= thr:
                return int(t)
        return None
    return policy_func

def wait_till_end_search(df, thresholds=None):
    if thresholds is None:
        thresholds = np.linspace(0,1,21)
    best_thr, best_cost, best_stats = None, float('inf'), None
    
    for thr in thresholds:
        def policy_func(patient_rows):
            final_t = patient_rows['time'].max()
            final_row = patient_rows[patient_rows['time']==final_t].iloc[0]
            if final_row['risk_score'] >= thr:
                return int(final_t)
            return None
        
        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_thr  = thr
            best_stats= stats
    return best_thr, best_stats

def make_wait_till_end_policy(thr):
    def policy_func(patient_rows):
        final_t = patient_rows['time'].max()
        final_row = patient_rows[patient_rows['time']==final_t].iloc[0]
        if final_row['risk_score'] >= thr:
            return int(final_t)
        return None
    return policy_func

###############################################################################
# 5. DATA-DRIVEN DP (Unconstrained)
###############################################################################
def to_bucket(prob):
    """Map a probability into a bucket [0..4]."""
    b = int(prob * 5)
    return min(b, 4)

def estimate_transition_and_sick_probs(df_train, T=20, n_buckets=5):
    transition_counts = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    bucket_counts     = np.zeros((T, n_buckets), dtype=float)
    sick_counts       = np.zeros((T, n_buckets), dtype=float)
    
    df_sorted = df_train.sort_values(['patient_id','time'])
    for pid, grp in df_sorted.groupby('patient_id'):
        grp = grp.sort_values('time')
        rows= grp.to_dict('records')
        
        for i, row in enumerate(rows):
            t = int(row['time'])
            b = int(row['risk_bucket'])
            lbl = row['label']
            
            if t < T:
                bucket_counts[t,b] += 1
                sick_counts[t,b]   += lbl
            
            if i < len(rows)-1:
                nxt = rows[i+1]
                t_next = nxt['time']
                b_next = nxt['risk_bucket']
                if (t_next == t+1) and (t < T-1):
                    transition_counts[t,b,b_next] += 1
    
    p_trans = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    for t_ in range(T-1):
        for b_ in range(n_buckets):
            denom = transition_counts[t_,b_,:].sum()
            if denom > 0:
                p_trans[t_,b_,:] = transition_counts[t_,b_,:] / denom
            else:
                p_trans[t_,b_,b_] = 1.0
    
    p_sick = np.zeros((T, n_buckets), dtype=float)
    for t_ in range(T):
        for b_ in range(n_buckets):
            denom = bucket_counts[t_,b_]
            if denom>0:
                p_sick[t_,b_] = sick_counts[t_,b_] / denom
            else:
                p_sick[t_,b_] = 0.0
    return p_trans, p_sick

def train_data_driven_dp_unconstrained(p_trans, p_sick,
                                       FP=10, FN=50, D=1,
                                       gamma=0.99, T=20):
    """
    Standard DP for unconstrained scenario:
      V[t,b] = min( cost_treat_now, cost_wait )
    """
    n_buckets = p_sick.shape[1]
    V = np.zeros((T+1, n_buckets))
    pi_ = np.zeros((T, n_buckets), dtype=int)
    
    # boundary at t=T
    for b in range(n_buckets):
        cost_treat   = p_sick[T-1,b]*(D*(T-1)) + (1-p_sick[T-1,b])*FP
        cost_notreat = p_sick[T-1,b]*FN
        V[T,b] = min(cost_treat, cost_notreat)
    
    # fill from T-1 down to 0
    for t in reversed(range(T)):
        for b in range(n_buckets):
            # treat now
            cost_treat = p_sick[t,b]*(D*t) + (1-p_sick[t,b])*FP
            
            # wait
            if t == T-1:
                cost_wait = gamma * V[T,b]
            else:
                exp_future = 0.0
                for b_next in range(n_buckets):
                    exp_future += p_trans[t,b,b_next]*V[t+1,b_next]
                cost_wait = gamma * exp_future
            
            if cost_treat <= cost_wait:
                V[t,b]   = cost_treat
                pi_[t,b] = 1
            else:
                V[t,b]   = cost_wait
                pi_[t,b] = 0
    return V, pi_

def make_dp_policy(V, pi_, T=20):
    """Return a policy function that treats if pi[t,b]==1 at time t."""
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            if t < T:
                b = int(row['risk_bucket'])
                if pi_[t,b] == 1:
                    return t
        return None
    return policy_func

###############################################################################
# 6. SPSA ALGORITHM + COST FUNCTION
###############################################################################
def train_and_select_best_model(X_train, y_train, X_val, y_val):
    """
    (Reference from Algorithm 0)
    Trains multiple models (RF, GB, CatBoost) over small hyperparam grids,
    picks best by AUC.
    """
    best_auc = -1.0
    best_model = None
    best_name  = None
    
    RF_PARAM_GRID = {
        'n_estimators': [50, 100],
        'max_depth': [3, 5]
    }
    GB_PARAM_GRID = {
        'n_estimators': [50, 100],
        'learning_rate': [0.05, 0.1],
        'max_depth': [3, 5]
    }
    CATBOOST_PARAM_GRID = {
        'iterations': [50, 100],
        'learning_rate': [0.05, 0.1],
        'depth': [3, 5]
    }
    
    for params in ParameterGrid(RF_PARAM_GRID):
        rf = RandomForestClassifier(random_state=0, **params)
        rf.fit(X_train, y_train)
        val_prob = rf.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, val_prob)
        if auc_val > best_auc:
            best_auc   = auc_val
            best_model = rf
            best_name  = f"RandomForest_{params}"
    
    for params in ParameterGrid(GB_PARAM_GRID):
        gb = GradientBoostingClassifier(random_state=0, **params)
        gb.fit(X_train, y_train)
        val_prob = gb.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, val_prob)
        if auc_val > best_auc:
            best_auc   = auc_val
            best_model = gb
            best_name  = f"GradientBoosting_{params}"
    
    for params in ParameterGrid(CATBOOST_PARAM_GRID):
        cb = CatBoostClassifier(verbose=0, random_state=0, **params)
        cb.fit(X_train, y_train)
        val_prob = cb.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, val_prob)
        if auc_val > best_auc:
            best_auc   = auc_val
            best_model = cb
            best_name  = f"CatBoost_{params}"
    
    return best_model, best_auc, best_name

def spsa_optimize(cost_func, param_init, n_iter=20,
                  alpha=0.602, gamma=0.101, a=0.1, c=0.1, seed=0):
    """
    Generic SPSA optimizer to minimize a cost_func over a real vector space.
    """
    rng = np.random.RandomState(seed)
    p = param_init.copy()
    best_p = p.copy()
    best_cost = cost_func(p)
    
    for k in range(1, n_iter+1):
        ak = a / (k**alpha)
        ck = c / (k**gamma)
        delta = rng.choice([-1,1], size=len(p))
        
        p_plus  = p + ck*delta
        p_minus = p - ck*delta
        
        cost_plus  = cost_func(p_plus)
        cost_minus = cost_func(p_minus)
        
        g_approx = (cost_plus - cost_minus)/(2.0*ck) * delta
        
        p = p - ak*g_approx
        
        current_cost = cost_func(p)
        if current_cost < best_cost:
            best_cost = current_cost
            best_p = p.copy()
    
    return best_p, best_cost

def parse_spsa_params(params):
    """
    Map a real vector `params` into discrete hyperparams + DP gamma:
      params[0]: model_type (0=RF,1=GB,2=CB)
      params[1]: n_estimators [10..200]
      params[2]: learning_rate [0.01..0.2]
      params[3]: max_depth [2..10]
      params[4]: gamma [0.90..0.999]
    """
    p0 = int(round(np.clip(params[0], 0, 2)))       # model_type
    p1 = int(round(np.clip(params[1], 10, 200)))    # n_estimators
    p2 = float(np.clip(params[2], 0.01, 0.2))       # learning_rate
    p3 = int(round(np.clip(params[3], 2, 10)))      # max_depth
    p4 = float(np.clip(params[4], 0.90, 0.999))      # gamma
    return (p0, p1, p2, p3, p4)

def spsa_cost_function(param_vector, df_train, df_val):
    """
    Decision-aware cost function used by SPSA:
      1) Parse param_vector => (model_type, n_estimators, learning_rate, max_depth, gamma)
      2) Train model on df_train
      3) Predict risk on df_val
      4) Build DP with gamma (from param_vector) using transitions from df_train
      5) Evaluate cost on df_val (unconstrained).
    """
    model_type, n_est, lr, m_depth, gamma_ = parse_spsa_params(param_vector)

    X_tr = df_train[['EIT','NIRS','EIS']].values
    y_tr = df_train['label'].values
    
    # Build the model
    if model_type == 0:
        clf = RandomForestClassifier(n_estimators=n_est, max_depth=m_depth, random_state=0)
    elif model_type == 1:
        clf = GradientBoostingClassifier(n_estimators=n_est, learning_rate=lr,
                                         max_depth=m_depth, random_state=0)
    else:
        clf = CatBoostClassifier(iterations=n_est, learning_rate=lr, depth=m_depth,
                                 verbose=0, random_state=0)
    clf.fit(X_tr, y_tr)
    
    # Evaluate on df_val with DP policy
    X_val = df_val[['EIT','NIRS','EIS']].values
    prob_val = clf.predict_proba(X_val)[:,1]
    df_val_ = df_val.copy()
    df_val_['risk_score'] = prob_val
    
    # DP transitions from df_train
    df_tr_ = df_train.copy()
    train_probs = clf.predict_proba(df_tr_[['EIT','NIRS','EIS']])[:,1]
    df_tr_['risk_score'] = train_probs
    df_tr_['risk_bucket'] = df_tr_['risk_score'].apply(to_bucket)
    
    p_trans, p_sick = estimate_transition_and_sick_probs(df_tr_, T=T_MAX, n_buckets=5)
    V, pi_ = train_data_driven_dp_unconstrained(
        p_trans, p_sick,
        FP=FP_COST, FN=FN_COST, D=D_COST,
        gamma=gamma_, T=T_MAX
    )
    
    # Evaluate that DP policy on df_val
    df_val_['risk_bucket'] = df_val_['risk_score'].apply(to_bucket)
    dp_policy = make_dp_policy(V, pi_, T=T_MAX)
    stats = simulate_policy(df_val_, dp_policy)  # unconstrained cost
    return stats['cost']


###############################################################################
# 7. ALGORITHM 5 (SPSA) WITH 50% CAP ON FINAL HOLDOUT
###############################################################################
def run_algorithm5_spsa_with_capacity(df_all, k=3, seed=0, n_spsa_iter=20, capacity_frac=0.5):
    """
    SPSA Hyper-Parameter Tuning (Decision-Aware) + 50% cap on final holdout.
    
    1) Split data into k+1 groups: G1..Gk, G_{k+1}
    2) For each fold i in [1..k]:
         - define cost function that trains on G\\G_i, evaluates cost on G_i
         - run SPSA to find best param_i
    3) Among param_1..param_k, pick best overall param* with minimal sum-of-costs across folds
    4) Evaluate param* on G_{k+1} (the holdout),
       but we apply the 50%-cap to *sick* patients in the final holdout only.
    5) Build final table with 5 methods:
       - Constant Threshold
       - Dynamic Threshold-R
       - Linear Threshold
       - Wait Till End
       - DP-based (SPSA)
      and each one is simulated with the 50%-cap on sick.
    """
    # Split into k+1 groups
    groups = k_plus_1_splits(df_all, k=k, seed=seed)
    
    # param_init is the starting guess for SPSA hyperparams
    param_init = np.array([1.0, 50.0, 0.05, 3.0, 0.95], dtype=float)
    
    # 2) For each fold i in [1..k], run SPSA to get best param_i
    spsa_solutions = []
    for i in range(1, k+1):
        df_val = groups[i-1]
        df_train_list = [groups[j] for j in range(k) if j != (i-1)]
        df_train_ = pd.concat(df_train_list, ignore_index=True)
        
        def fold_cost_func(p):
            return spsa_cost_function(p, df_train_, df_val)
        
        best_p_fold, best_c_fold = spsa_optimize(
            fold_cost_func,
            param_init,
            n_iter=n_spsa_iter,
            alpha=0.602,
            gamma=0.101,
            a=0.1,
            c=0.1,
            seed=seed+i
        )
        spsa_solutions.append( (best_p_fold, best_c_fold) )
    
    # 3) Among these k solutions, pick best overall param*
    #    with minimal sum-of-costs across all k folds
    k_ = k
    fold_cost_matrix = np.zeros((k_, k_), dtype=float)
    for i in range(k_):
        param_i = spsa_solutions[i][0]
        for j in range(k_):
            df_val_j = groups[j]
            df_train_j_list = [groups[m] for m in range(k_) if m != j]
            df_train_j = pd.concat(df_train_j_list, ignore_index=True)
            c_ij = spsa_cost_function(param_i, df_train_j, df_val_j)
            fold_cost_matrix[i,j] = c_ij
    
    total_cost_per_param = fold_cost_matrix.sum(axis=1)
    best_index = np.argmin(total_cost_per_param)
    best_param = spsa_solutions[best_index][0]
    
    # 4) Evaluate best_param on final holdout G_{k+1}, with 50%-cap
    df_holdout = groups[k]
    df_train_for_holdout = pd.concat(groups[:k], ignore_index=True)
    
    # Parse best_param -> final ML model + DP discount factor
    model_type, n_est, lr, m_depth, gamma_ = parse_spsa_params(best_param)
    
    # Train final classifier on union of k folds
    X_train2 = df_train_for_holdout[['EIT','NIRS','EIS']].values
    y_train2 = df_train_for_holdout['label'].values
    
    if model_type == 0:
        clf_final = RandomForestClassifier(n_estimators=n_est, max_depth=m_depth, random_state=0)
    elif model_type == 1:
        clf_final = GradientBoostingClassifier(n_estimators=n_est, learning_rate=lr,
                                               max_depth=m_depth, random_state=0)
    else:
        clf_final = CatBoostClassifier(iterations=n_est, learning_rate=lr, depth=m_depth,
                                       verbose=0, random_state=0)
    clf_final.fit(X_train2, y_train2)
    
    # Build DP transitions from df_train_for_holdout
    df_train2 = df_train_for_holdout.copy()
    train_probs2 = clf_final.predict_proba(df_train2[['EIT','NIRS','EIS']])[:,1]
    df_train2['risk_score'] = train_probs2
    df_train2['risk_bucket'] = df_train2['risk_score'].apply(to_bucket)
    p_trans2, p_sick2 = estimate_transition_and_sick_probs(df_train2, T=T_MAX, n_buckets=5)
    
    V_final, pi_final = train_data_driven_dp_unconstrained(
        p_trans2, p_sick2,
        FP=FP_COST, FN=FN_COST, D=D_COST,
        gamma=gamma_, T=T_MAX
    )
    dp_policy_final = make_dp_policy(V_final, pi_final, T=T_MAX)
    
    # 5) Tune threshold-based policies on the same train set (df_train2)
    thr_const, _ = constant_threshold_search(df_train2)
    thr_vec, _ = dynamic_threshold_random_search(df_train2, time_steps=T_MAX)
    (A_lin, B_lin), _ = linear_threshold_search(df_train2)
    thr_wte, _ = wait_till_end_search(df_train2)
    
    # Evaluate final policies on holdout with capacity
    df_holdout2 = df_holdout.copy()
    holdout_probs = clf_final.predict_proba(df_holdout2[['EIT','NIRS','EIS']])[:,1]
    df_holdout2['risk_score'] = holdout_probs
    
    # 1) Constant threshold
    pol_const = make_constant_threshold_policy(thr_const)
    stats_const = simulate_policy_with_sick_capacity(df_holdout2, pol_const, capacity_frac=capacity_frac)
    
    # 2) Dynamic threshold - random
    pol_dyn = make_dynamic_threshold_policy(thr_vec)
    stats_dyn = simulate_policy_with_sick_capacity(df_holdout2, pol_dyn, capacity_frac=capacity_frac)
    
    # 3) Linear threshold
    pol_lin = make_linear_threshold_policy(A_lin, B_lin)
    stats_lin = simulate_policy_with_sick_capacity(df_holdout2, pol_lin, capacity_frac=capacity_frac)
    
    # 4) Wait till end
    pol_wte = make_wait_till_end_policy(thr_wte)
    stats_wte = simulate_policy_with_sick_capacity(df_holdout2, pol_wte, capacity_frac=capacity_frac)
    
    # 5) DP-based policy from SPSA
    df_holdout2_dp = df_holdout2.copy()
    df_holdout2_dp['risk_bucket'] = df_holdout2_dp['risk_score'].apply(to_bucket)
    stats_dp = simulate_policy_with_sick_capacity(df_holdout2_dp, dp_policy_final,
                                                  capacity_frac=capacity_frac)
    
    # Build final table
    table = pd.DataFrame({
        'Method': [
            'Constant Threshold',
            'Dynamic Threshold-R',
            'Linear Threshold',
            'Wait Till End',
            'DP-based (SPSA)'
        ],
        'Cost': [
            stats_const['cost'],
            stats_dyn['cost'],
            stats_lin['cost'],
            stats_wte['cost'],
            stats_dp['cost']
        ],
        'Precision (%)': [
            100*stats_const['precision'],
            100*stats_dyn['precision'],
            100*stats_lin['precision'],
            100*stats_wte['precision'],
            100*stats_dp['precision']
        ],
        'Recall (%)': [
            100*stats_const['recall'],
            100*stats_dyn['recall'],
            100*stats_lin['recall'],
            100*stats_wte['recall'],
            100*stats_dp['recall']
        ],
        'Avg Treat Time': [
            stats_const['avg_treatment_time'],
            stats_dyn['avg_treatment_time'],
            stats_lin['avg_treatment_time'],
            stats_wte['avg_treatment_time'],
            stats_dp['avg_treatment_time']
        ]
    })
    
    return table


###############################################################################
# 8. MAIN - RUN MULTIPLE REPLICATES
###############################################################################
def main():
    # Load synthetic data
    df_all = pd.read_csv("synthetic_patients_with_features.csv")
    
    # Filter to time < T_MAX if needed
    df_all = df_all[df_all['time'] < T_MAX].copy()
    
    # Check required columns
    required = {'patient_id','time','EIT','NIRS','EIS','label'}
    if not required.issubset(df_all.columns):
        raise ValueError(f"CSV must have columns at least: {required}. Found: {df_all.columns}")
    
    # Number of replicates to run
    NUM_REPLICATES = 30
    
    # Store each replicate's table in a list
    replicate_tables = []
    
    for rep in range(NUM_REPLICATES):
        # Use different seed each time to randomize grouping
        seed = 412 + rep
        print(f"\n=== Running replicate {rep+1}/{NUM_REPLICATES} (seed={seed}) ===")
        
        # Run ALGORITHM 5 (SPSA) with 50%-cap on final holdout
        table_alg5 = run_algorithm5_spsa_with_capacity(
            df_all,
            k=3,
            seed=seed,
            n_spsa_iter=20,
            capacity_frac=0.5
        )
        
        print(table_alg5.to_string(index=False))
        replicate_tables.append(table_alg5)
    
    # After running all replicates, concatenate them for summary stats
    all_results = pd.concat(replicate_tables, ignore_index=True)
    
    # Group by 'Method' and compute mean/std of each numeric column
    summary = all_results.groupby('Method').agg({
        'Cost': ['mean','std'],
        'Precision (%)': ['mean','std'],
        'Recall (%)': ['mean','std'],
        'Avg Treat Time': ['mean','std']
    })
    
    # Flatten the multi-level column index
    summary.columns = ['_'.join(col).strip() for col in summary.columns.values]
    summary.reset_index(inplace=True)

    print("\n=============================")
    print("FINAL SUMMARY ACROSS REPLICATES")
    print("=============================")
    print(summary.to_string(index=False))


if __name__ == "__main__":
    main()


=== Running replicate 1/30 (seed=412) ===
             Method  Cost  Precision (%)  Recall (%)  Avg Treat Time
 Constant Threshold  1043      43.750000   48.275862        3.562500
Dynamic Threshold-R  1450       0.000000    0.000000        0.000000
   Linear Threshold  1137      30.434783   48.275862        1.500000
      Wait Till End  1040      93.333333   48.275862       20.000000
    DP-based (SPSA)   905      82.352941   48.275862        8.529412

=== Running replicate 2/30 (seed=413) ===
             Method  Cost  Precision (%)  Recall (%)  Avg Treat Time
 Constant Threshold  1149      32.558140   48.275862        2.558140
Dynamic Threshold-R  1450       0.000000    0.000000        0.000000
   Linear Threshold  1195      26.923077   48.275862        1.288462
      Wait Till End  1030     100.000000   48.275862       20.000000
    DP-based (SPSA)   874     100.000000   48.275862        8.857143

=== Running replicate 3/30 (seed=414) ===
             Method  Cost  Precision (%)  R