In [1]:

"""
Full CROSS-VALIDATION (ALGORITHM 2) FOR CONSTRAINED HEMORRHAGE DIAGNOSIS & TREATMENT
WITH A 50% CAP ON SICK PATIENTS).

Requirements:
  pip install numpy pandas scikit-learn catboost
"""
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Sklearn models, metrics, etc.
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import ParameterGrid
# CatBoost
from catboost import CatBoostClassifier

###############################################################################
# GLOBAL PARAMETERS (same as your original)
###############################################################################
FP_COST = 10    # False positive cost
FN_COST = 50    # False negative cost
D_COST  = 1     # Delay cost per time
T_MAX   = 21    # maximum discrete time steps (0..T_MAX-1)

# We'll interpret mu as the discount factor gamma
MU_CANDIDATES = [0.95, 0.99]

# Example hyperparameter grids
RF_PARAM_GRID = {
    'n_estimators': [50, 100],
    'max_depth': [3, 5]
}
GB_PARAM_GRID = {
    'n_estimators': [50, 100],
    'learning_rate': [0.05, 0.1],
    'max_depth': [3, 5]
}
CATBOOST_PARAM_GRID = {
    'iterations': [50, 100],
    'learning_rate': [0.05, 0.1],
    'depth': [3, 5]
}

###############################################################################
# HELPERS FOR SPLITTING DATA, AUC, ETC.
###############################################################################
def split_into_nplus1_groups(df, n=4, seed=0):
    """Shuffle patient IDs and split ~evenly into (n+1) groups: G1..G_{n+1}."""
    rng = np.random.RandomState(seed)
    unique_pids = df['patient_id'].unique()
    rng.shuffle(unique_pids)
    N = len(unique_pids)
    group_size = int(np.ceil(N/(n+1)))
    
    groups = []
    start_idx = 0
    for i in range(n+1):
        end_idx = min(start_idx + group_size, N)
        group_pids = unique_pids[start_idx:end_idx]
        group_df   = df[df['patient_id'].isin(group_pids)].copy()
        groups.append(group_df)
        start_idx = end_idx
    return groups

def compute_auc_score(y_true, y_prob):
    """Compute AUC safely. If only one class, return 0.5 to avoid errors."""
    if len(np.unique(y_true)) < 2:
        return 0.5
    return roc_auc_score(y_true, y_prob)

###############################################################################
# POLICY SIMULATION (UNCONSTRAINED) - REFERENCE ONLY
###############################################################################
def simulate_policy(df, policy_func):
    """
    df must contain columns: patient_id, time, label, risk_score
    policy_func(patient_rows) -> treat_time (int) or None
    
    Returns dict of: cost, precision, recall, avg_treatment_time
    """
    results = []
    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values('time')
        label = grp['label'].iloc[0]
        treat_time = policy_func(grp)
        
        if treat_time is None:  # never treat
            if label == 1:
                cost = FN_COST
                tp   = 0
            else:
                cost = 0
                tp   = 0
            fp = 0
            treat_flag = 0
            ttime = None
        else:
            treat_flag = 1
            if label == 1:
                cost = D_COST * treat_time
                tp   = 1
                fp   = 0
            else:
                cost = FP_COST
                tp   = 0
                fp   = 1
            ttime = treat_time
        
        results.append({
            'patient_id': pid,
            'label': label,
            'treated': treat_flag,
            'treat_time': ttime,
            'cost': cost,
            'tp': tp,
            'fp': fp
        })
    
    df_res = pd.DataFrame(results)
    total_cost = df_res['cost'].sum()
    
    treated_df = df_res[df_res['treated']==1]
    tp_sum = treated_df['tp'].sum()
    fp_sum = treated_df['fp'].sum()
    
    precision = 0.0
    if len(treated_df) > 0:
        precision = tp_sum / (tp_sum + fp_sum)
    
    sick_df = df_res[df_res['label']==1]
    total_sick = len(sick_df)
    recall = 0.0
    if total_sick > 0:
        recall = tp_sum / total_sick
    
    avg_tt = 0.0
    if len(treated_df) > 0:
        valid_tt = treated_df['treat_time'].dropna()
        if len(valid_tt) > 0:
            avg_tt = valid_tt.mean()
    
    return {
        'cost': total_cost,
        'precision': precision,
        'recall': recall,
        'avg_treatment_time': avg_tt
    }

###############################################################################
# POLICY SIMULATION WITH SICK CAPACITY (IMPORTANT!)
###############################################################################
def simulate_policy_with_sick_capacity(df, policy_func, capacity_frac=0.5):
    """
    We enforce that at most (capacity_frac) fraction of the *sick* patients
    can be treated. Everyone else recommended (healthy) is treated with no limit.

    Steps:
      1. Identify which patients are "recommended" for treatment by the policy.
      2. Separate recommended patients into "sick recommended" vs. "healthy recommended".
      3. Among the recommended *sick* patients, we can only treat up to
         floor(capacity_frac * total_sick). We'll choose the top (by risk_score).
      4. We treat all recommended *healthy* patients with no limit.
      5. Everyone else is not treated => FN cost if sick, 0 cost if healthy.
    """
    # 1) Collect recommendation info
    results = []
    recommended_sick = []     # (pid, label=1, time_treated, risk_score)
    recommended_healthy = []  # (pid, label=0, time_treated, risk_score)
    
    # Count total sick
    all_sick_df = df[df['label']==1]
    num_sick = all_sick_df['patient_id'].nunique()
    # capacity (max number of sick we can treat)
    capacity_num = int(np.floor(capacity_frac * num_sick)) if num_sick > 0 else 0
    
    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values('time')
        label = grp['label'].iloc[0]
        
        # The policy returns the time we "recommend" to treat or None
        treat_time = policy_func(grp)
        
        if treat_time is None:
            # not recommended => cost to fill in later
            results.append({
                'patient_id': pid,
                'label': label,
                'treated': 0,
                'treat_time': None,
                'cost': None,  # fill in later
                'tp': 0,
                'fp': 0
            })
        else:
            # recommended => record separately
            row_t = grp[grp['time']==treat_time].iloc[0]
            recommended_risk = row_t['risk_score']
            if label == 1:
                recommended_sick.append((pid, label, treat_time, recommended_risk))
            else:
                recommended_healthy.append((pid, label, treat_time, recommended_risk))
    
    # 2) Sort recommended sick by descending risk_score
    recommended_sick.sort(key=lambda x: x[3], reverse=True)
    # 3) Actually treat only top capacity_num from recommended sick
    treat_sick_subset = recommended_sick[:capacity_num]
    not_treat_sick_subset = recommended_sick[capacity_num:]
    
    # 4) We treat ALL recommended healthy (no limit)
    treat_healthy_subset = recommended_healthy
    
    # Build final result records for the actually-treated subsets
    treat_results = []
    
    # 4a) For SICK actually treated
    for (pid, label, ttime, rsk) in treat_sick_subset:
        # label==1 => sick => cost = D_COST * ttime
        treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 1,
            'treat_time': ttime,
            'cost': D_COST * ttime,
            'tp': 1,
            'fp': 0
        })
    
    # 4b) For HEALTHY actually treated
    for (pid, label, ttime, rsk) in treat_healthy_subset:
        treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 1,
            'treat_time': ttime,
            'cost': FP_COST,  # healthy => false positive cost
            'tp': 0,
            'fp': 1
        })
    
    # 5) Build final records for not-treated subsets
    not_treat_results = []
    # (a) Sick recommended but not treated
    for (pid, label, ttime, rsk) in not_treat_sick_subset:
        not_treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 0,
            'treat_time': None,
            'cost': FN_COST,  # sick not treated => FN
            'tp': 0,
            'fp': 0
        })
    # (b) Those who were never recommended
    for row in results:
        if row['cost'] is None:
            if row['label'] == 1:
                row['cost'] = FN_COST
            else:
                row['cost'] = 0
            not_treat_results.append(row)
    
    df_res = pd.DataFrame(treat_results + not_treat_results)
    
    # Compute stats
    total_cost = df_res['cost'].sum()
    treated_df = df_res[df_res['treated']==1]
    tp_sum = treated_df['tp'].sum()
    fp_sum = treated_df['fp'].sum()
    
    precision = 0.0
    if len(treated_df) > 0:
        precision = tp_sum / (tp_sum + fp_sum)
    
    sick_df = df_res[df_res['label']==1]
    total_sick = len(sick_df)
    recall = 0.0
    if total_sick > 0:
        recall = tp_sum / total_sick
    
    avg_tt = 0.0
    if len(treated_df) > 0:
        valid_tt = treated_df['treat_time'].dropna()
        if len(valid_tt) > 0:
            avg_tt = valid_tt.mean()
    
    return {
        'cost': total_cost,
        'precision': precision,
        'recall': recall,
        'avg_treatment_time': avg_tt
    }

###############################################################################
# THRESHOLD-based Policies
###############################################################################
def make_constant_threshold_policy(thr):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            if row['risk_score'] >= thr:
                return int(row['time'])
        return None
    return policy_func

def constant_threshold_search(df, thresholds=None):
    if thresholds is None:
        thresholds = np.linspace(0,1,21)
    best_thr, best_cost = None, float('inf')
    best_stats = None
    for thr in thresholds:
        policy = make_constant_threshold_policy(thr)
        stats = simulate_policy(df, policy)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_thr  = thr
            best_stats= stats
    return best_thr, best_stats

def make_dynamic_threshold_policy(thr_vec):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            if t < len(thr_vec) and row['risk_score'] >= thr_vec[t]:
                return t
        return None
    return policy_func

def dynamic_threshold_random_search(df,
                                    time_steps=20,
                                    threshold_candidates=[0.0,0.2,0.4,0.6,0.8,1.0],
                                    n_samples=200,
                                    seed=0):
    rng = np.random.RandomState(seed)
    best_vec = None
    best_cost= float('inf')
    best_stats=None
    
    for _ in range(n_samples):
        thr_vec = rng.choice(threshold_candidates, size=time_steps)
        policy  = make_dynamic_threshold_policy(thr_vec)
        stats   = simulate_policy(df, policy)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_vec  = thr_vec.copy()
            best_stats= stats
    return best_vec, best_stats

def make_linear_threshold_policy(A, B):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = row['time']
            thr = A*t + B
            thr = np.clip(thr, 0, 1)
            if row['risk_score'] >= thr:
                return int(t)
        return None
    return policy_func

def linear_threshold_search(df,
                            A_candidates=np.linspace(-0.05, 0.01, 7),
                            B_candidates=np.linspace(0,0.6,4)):
    best_A, best_B = None, None
    best_cost = float('inf')
    best_stats= None
    for A in A_candidates:
        for B in B_candidates:
            policy = make_linear_threshold_policy(A, B)
            stats  = simulate_policy(df, policy)
            if stats['cost'] < best_cost:
                best_cost = stats['cost']
                best_A    = A
                best_B    = B
                best_stats= stats
    return (best_A, best_B), best_stats

def make_wait_till_end_policy(thr):
    def policy_func(patient_rows):
        final_t = patient_rows['time'].max()
        final_row = patient_rows[patient_rows['time'] == final_t].iloc[0]
        if final_row['risk_score'] >= thr:
            return int(final_t)
        return None
    return policy_func

def wait_till_end_search(df, thresholds=None):
    if thresholds is None:
        thresholds = np.linspace(0,1,21)
    best_thr, best_cost = None, float('inf')
    best_stats = None
    for thr in thresholds:
        policy = make_wait_till_end_policy(thr)
        stats  = simulate_policy(df, policy)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_thr  = thr
            best_stats= stats
    return best_thr, best_stats

###############################################################################
# DP-related helpers
###############################################################################
def to_bucket(prob):
    """Map a probability into 5 discrete buckets [0..4]."""
    b = int(prob * 5)
    return min(b, 4)

def estimate_transition_and_sick_probs(df_train, T=20, n_buckets=5):
    """
    Estimate p_trans[t,b,b_next] and p_sick[t,b].
    df_train must have columns: 'patient_id','time','risk_bucket','label'
    """
    transition_counts = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    bucket_counts     = np.zeros((T, n_buckets), dtype=float)
    sick_counts       = np.zeros((T, n_buckets), dtype=float)
    
    df_sorted = df_train.sort_values(['patient_id','time'])
    for pid, grp in df_sorted.groupby('patient_id'):
        grp = grp.sort_values('time')
        rows = grp.to_dict('records')
        for i, row in enumerate(rows):
            t   = int(row['time'])
            b   = int(row['risk_bucket'])
            lbl = row['label']
            if t < T:
                bucket_counts[t,b] += 1
                sick_counts[t,b]   += lbl
            if i < len(rows)-1:
                nxt = rows[i+1]
                t_next = nxt['time']
                b_next = nxt['risk_bucket']
                if (t_next == t+1) and (t < T-1):
                    transition_counts[t,b,b_next] += 1
    
    p_trans = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    for t_ in range(T-1):
        for b_ in range(n_buckets):
            denom = transition_counts[t_,b_,:].sum()
            if denom > 0:
                p_trans[t_,b_,:] = transition_counts[t_,b_,:] / denom
            else:
                p_trans[t_,b_,b_] = 1.0
    
    p_sick = np.zeros((T, n_buckets), dtype=float)
    for t_ in range(T):
        for b_ in range(n_buckets):
            denom = bucket_counts[t_,b_]
            if denom > 0:
                p_sick[t_,b_] = sick_counts[t_,b_] / denom
            else:
                p_sick[t_,b_] = 0.0
    
    return p_trans, p_sick

def train_data_driven_dp_unconstrained(p_trans, p_sick, 
                                       FP=10, FN=50, D=1, gamma=0.99, T=20):
    """
    Standard unconstrained DP for each bucket b at each time t:
      V[t,b] = min( cost_treat_now, cost_wait )
    """
    n_buckets = p_sick.shape[1]
    V  = np.zeros((T+1, n_buckets))
    pi = np.zeros((T,   n_buckets), dtype=int)
    
    # boundary at t = T
    for b in range(n_buckets):
        # if we "treat" at the last time (T-1):
        cost_treat   = p_sick[T-1,b] * (D*(T-1)) + (1 - p_sick[T-1,b])*FP
        # if we "do not treat" at all:
        cost_notreat = p_sick[T-1,b] * FN
        V[T,b] = min(cost_treat, cost_notreat)
    
    # fill from t = T-1 down to 0
    for t_ in reversed(range(T)):
        for b in range(n_buckets):
            cost_treat = p_sick[t_,b]*(D*t_) + (1 - p_sick[t_,b])*FP
            # cost_wait:
            if t_ == T-1:
                cost_wait = gamma * V[T,b]
            else:
                exp_future = 0.0
                for b_next in range(n_buckets):
                    exp_future += p_trans[t_,b,b_next] * V[t_+1,b_next]
                cost_wait = gamma * exp_future
            if cost_treat <= cost_wait:
                V[t_,b] = cost_treat
                pi[t_,b] = 1  # treat now
            else:
                V[t_,b] = cost_wait
                pi[t_,b] = 0  # wait
    return V, pi

def make_dp_policy(V, pi_, T=20):
    """
    Return a policy function that treats at time t if pi_[t,bucket] == 1
    """
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            if t < T:
                b = int(row['risk_bucket'])
                if pi_[t,b] == 1:
                    return t
        return None
    return policy_func

###############################################################################
# ALGORITHM 2 (FULL MATCH) BUT WITH SICK CAP IN FINAL FOLD
###############################################################################
def run_algorithm2_full_match(df_all, n=4, seed=0):
    """
    Implements Algorithm 2 from your snippet exactly, but in the *final evaluation*
    (G_{n+1}) we enforce a 50% cap on sick patients. 
    """
    # Keep only time < T_MAX
    df_all = df_all[df_all['time'] < T_MAX].copy()
    
    # 1) Split => G1..Gn, G_{n+1}
    groups = split_into_nplus1_groups(df_all, n=n, seed=seed)
    G_cv   = groups[:-1] # G1..Gn
    G_test = groups[-1]  # G_{n+1}
    
    # Build "lambda_candidates" as tuples: (model_type, (sorted param items)) so they're hashable
    lambda_candidates = []
    for p in ParameterGrid(RF_PARAM_GRID):
        param_tuple = tuple(sorted(p.items()))
        lambda_candidates.append(('rf', param_tuple))
    for p in ParameterGrid(GB_PARAM_GRID):
        param_tuple = tuple(sorted(p.items()))
        lambda_candidates.append(('gb', param_tuple))
    for p in ParameterGrid(CATBOOST_PARAM_GRID):
        param_tuple = tuple(sorted(p.items()))
        lambda_candidates.append(('cat', param_tuple))
    
    def dict_from_param_tuple(param_tuple):
        return dict(param_tuple)
    
    # Helper function to train + compute "AUCCost = 1 - AUC"
    def compute_AUCCost(lambda_cand, train_df, val_df):
        model_type, param_tuple = lambda_cand
        param_dict = dict_from_param_tuple(param_tuple)
        
        X_train = train_df[['EIT','NIRS','EIS']].values
        y_train = train_df['label'].values
        X_val   = val_df[['EIT','NIRS','EIS']].values
        y_val   = val_df['label'].values
        
        if model_type == 'rf':
            mdl = RandomForestClassifier(random_state=0, **param_dict)
        elif model_type == 'gb':
            mdl = GradientBoostingClassifier(random_state=0, **param_dict)
        else:
            mdl = CatBoostClassifier(verbose=0, random_state=0, **param_dict)
        
        mdl.fit(X_train, y_train)
        prob_val = mdl.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, prob_val)
        return 1.0 - auc_val  # "AUCCost"
    
    # Helper to train final model + build DP => "ActualCost"
    def compute_ActualCost(lambda_cand, mu, train_df, val_df):
        """
        mu = discount factor
        1) train model on train_df
        2) get risk scores on val_df
        3) make DP policy with discount=mu
        4) simulate => cost
        """
        model_type, param_tuple = lambda_cand
        param_dict = dict_from_param_tuple(param_tuple)
        
        X_train = train_df[['EIT','NIRS','EIS']].values
        y_train = train_df['label'].values
        
        if model_type == 'rf':
            mdl = RandomForestClassifier(random_state=0, **param_dict)
        elif model_type == 'gb':
            mdl = GradientBoostingClassifier(random_state=0, **param_dict)
        else:
            mdl = CatBoostClassifier(verbose=0, random_state=0, **param_dict)
        mdl.fit(X_train, y_train)
        
        # Risk scores
        val_df_ = val_df.copy()
        X_val   = val_df_[['EIT','NIRS','EIS']].values
        prob_val= mdl.predict_proba(X_val)[:,1]
        val_df_['risk_score'] = prob_val
        val_df_['risk_bucket']= val_df_['risk_score'].apply(to_bucket)
        
        # Build DP from train
        train_df_ = train_df.copy()
        prob_tr   = mdl.predict_proba(train_df_[['EIT','NIRS','EIS']].values)[:,1]
        train_df_['risk_score']  = prob_tr
        train_df_['risk_bucket'] = train_df_['risk_score'].apply(to_bucket)
        p_trans, p_sick = estimate_transition_and_sick_probs(train_df_,
                                                             T=T_MAX,
                                                             n_buckets=5)
        V, pi_ = train_data_driven_dp_unconstrained(
            p_trans, p_sick,
            FP=FP_COST, FN=FN_COST, D=D_COST,
            gamma=mu, T=T_MAX
        )
        dp_policy = make_dp_policy(V, pi_, T=T_MAX)
        
        stats = simulate_policy(val_df_, dp_policy)
        return stats['cost']
    
    nfolds = len(G_cv)  # = n
    ###################################################################
    # 2) Outer loop j = 1..n
    ###################################################################
    best_lambda_for_j = [None]*n
    cost_j_mu = {j: {} for j in range(n)}
    
    for j in range(n):
        G_outer_val = G_cv[j]
        train_list = [G_cv[m] for m in range(n) if m != j]
        G_outer_train = pd.concat(train_list, ignore_index=True)
        
        # (a) Inner cross among i != j => sum of AUCCost
        sum_auccost = {lam: 0.0 for lam in lambda_candidates}
        i_indices = [x for x in range(n) if x != j]
        
        for i_ in i_indices:
            G_inner_val = G_cv[i_]
            inner_train_list = [G_cv[m] for m in range(n) if (m != j) and (m != i_)]
            G_inner_train = pd.concat(inner_train_list, ignore_index=True)
            
            for lam in lambda_candidates:
                cost_ij = compute_AUCCost(lam, G_inner_train, G_inner_val)
                sum_auccost[lam] += cost_ij
        
        # pick lambda_j^*
        best_lam_j = None
        best_val = float('inf')
        for lam in lambda_candidates:
            if sum_auccost[lam] < best_val:
                best_val = sum_auccost[lam]
                best_lam_j = lam
        best_lambda_for_j[j] = best_lam_j
        
        # (b) Retrain on G_outer_train with lambda_j^*, measure ActualCost for each mu
        for mu_ in MU_CANDIDATES:
            c_j_mu = compute_ActualCost(best_lam_j, mu_, G_outer_train, G_outer_val)
            cost_j_mu[j][mu_] = c_j_mu
    
    # (c) pick final mu^* by summing cost_j_mu across j
    mu_star = None
    best_sum_cost = float('inf')
    for mu_ in MU_CANDIDATES:
        sum_c = 0.0
        for j in range(n):
            sum_c += cost_j_mu[j][mu_]
        if sum_c < best_sum_cost:
            best_sum_cost = sum_c
            mu_star = mu_
    
    ###################################################################
    # 4) "Full cross" pass #1: pick final lambda^*
    ###################################################################
    sum_auccost_all_i = {lam: 0.0 for lam in lambda_candidates}
    
    for i_ in range(n):
        G_val_i = G_cv[i_]
        G_train_i = pd.concat([G_cv[k] for k in range(n) if k != i_], ignore_index=True)
        for lam in lambda_candidates:
            c_auccost = compute_AUCCost(lam, G_train_i, G_val_i)
            sum_auccost_all_i[lam] += c_auccost
    
    lambda_star = None
    best_val2 = float('inf')
    for lam in lambda_candidates:
        if sum_auccost_all_i[lam] < best_val2:
            best_val2 = sum_auccost_all_i[lam]
            lambda_star = lam
    
    ###################################################################
    # 5) "Full cross" pass #2: pick final mu^*
    #    In practice, we already have mu_star from step (3).
    #    We'll still do it to match snippet's structure if needed.
    ###################################################################
    cost_i_mu_2 = {mu_: 0.0 for mu_ in MU_CANDIDATES}
    
    for i_ in range(n):
        G_val_i = G_cv[i_]
        G_train_i = pd.concat([G_cv[k] for k in range(n) if k != i_], ignore_index=True)
        for mu_ in MU_CANDIDATES:
            c_i_mu = compute_ActualCost(lambda_star, mu_, G_train_i, G_val_i)
            cost_i_mu_2[mu_] += c_i_mu
    
    mu_star_final = None
    best_cost_2 = float('inf')
    for mu_ in MU_CANDIDATES:
        if cost_i_mu_2[mu_] < best_cost_2:
            best_cost_2 = cost_i_mu_2[mu_]
            mu_star_final = mu_
    
    final_lambda = lambda_star
    final_mu     = mu_star_final
    
    ###################################################################
    # 6) Retrain final model on G1..Gn with final_lambda,
    #    build DP with final_mu, evaluate on G_{n+1} with capacity=50% on sick
    ###################################################################
    G_cv_concat = pd.concat(G_cv, ignore_index=True)
    
    model_type_final, param_tuple_final = final_lambda
    param_dict_final = dict_from_param_tuple(param_tuple_final)
    if model_type_final == 'rf':
        final_mdl = RandomForestClassifier(random_state=0, **param_dict_final)
    elif model_type_final == 'gb':
        final_mdl = GradientBoostingClassifier(random_state=0, **param_dict_final)
    else:
        final_mdl = CatBoostClassifier(verbose=0, random_state=0, **param_dict_final)
    
    X_cv_all = G_cv_concat[['EIT','NIRS','EIS']].values
    y_cv_all = G_cv_concat['label'].values
    final_mdl.fit(X_cv_all, y_cv_all)
    
    # Build final DP policy
    df_dp_train = G_cv_concat.copy()
    prob_dp_tr  = final_mdl.predict_proba(df_dp_train[['EIT','NIRS','EIS']].values)[:,1]
    df_dp_train['risk_score']  = prob_dp_tr
    df_dp_train['risk_bucket'] = df_dp_train['risk_score'].apply(to_bucket)
    
    p_trans_final, p_sick_final = estimate_transition_and_sick_probs(df_dp_train,
                                                                     T=T_MAX,
                                                                     n_buckets=5)
    V_final, pi_final = train_data_driven_dp_unconstrained(
        p_trans_final, p_sick_final,
        FP=FP_COST, FN=FN_COST, D=D_COST,
        gamma=final_mu, T=T_MAX
    )
    dp_policy_final = make_dp_policy(V_final, pi_final, T=T_MAX)
    
    # Evaluate on G_{n+1} with *50% capacity on sick*
    G_test_eval = G_test.copy()
    prob_test   = final_mdl.predict_proba(G_test_eval[['EIT','NIRS','EIS']].values)[:,1]
    G_test_eval['risk_score'] = prob_test
    G_test_eval['risk_bucket'] = G_test_eval['risk_score'].apply(to_bucket)
    
    stats_dp = simulate_policy_with_sick_capacity(G_test_eval, dp_policy_final, capacity_frac=0.5)
    
    ###################################################################
    # OPTIONAL: Evaluate threshold-based policies on G_{n+1} with 50% sick cap
    ###################################################################
    # We tune thresholds on G1..Gn, then evaluate on G_{n+1} with capacity limit
    prob_cv_final = final_mdl.predict_proba(G_cv_concat[['EIT','NIRS','EIS']].values)[:,1]
    G_cv_concat['risk_score'] = prob_cv_final
    
    # 1) Constant threshold
    best_thr_const, _ = constant_threshold_search(G_cv_concat)
    policy_const = make_constant_threshold_policy(best_thr_const)
    stats_const  = simulate_policy_with_sick_capacity(G_test_eval, policy_const, capacity_frac=0.5)
    
    # 2) Dynamic threshold (random search)
    best_thr_vec, _ = dynamic_threshold_random_search(G_cv_concat,
                    time_steps=T_MAX-1,
                    threshold_candidates=[0.0,0.2,0.4,0.6,0.8,1.0],
                    n_samples=200, seed=123)
    policy_dyn = make_dynamic_threshold_policy(best_thr_vec)
    stats_dyn  = simulate_policy_with_sick_capacity(G_test_eval, policy_dyn, capacity_frac=0.5)
    
    # 3) Linear threshold
    (A_lin, B_lin), _ = linear_threshold_search(G_cv_concat)
    policy_lin = make_linear_threshold_policy(A_lin, B_lin)
    stats_lin  = simulate_policy_with_sick_capacity(G_test_eval, policy_lin, capacity_frac=0.5)
    
    # 4) Wait till end
    best_thr_wte, _ = wait_till_end_search(G_cv_concat)
    policy_wte = make_wait_till_end_policy(best_thr_wte)
    stats_wte  = simulate_policy_with_sick_capacity(G_test_eval, policy_wte, capacity_frac=0.5)
    
    # Build final table
    final_table = pd.DataFrame({
        'Method': [
            'Constant Threshold',
            'Dynamic Threshold-R',
            'Linear Threshold',
            'Wait Till End',
            'Dynamic Threshold-DP'
        ],
        'Cost': [
            stats_const['cost'],
            stats_dyn['cost'],
            stats_lin['cost'],
            stats_wte['cost'],
            stats_dp['cost']
        ],
        'Precision (%)': [
            100*stats_const['precision'],
            100*stats_dyn['precision'],
            100*stats_lin['precision'],
            100*stats_wte['precision'],
            100*stats_dp['precision']
        ],
        'Recall (%)': [
            100*stats_const['recall'],
            100*stats_dyn['recall'],
            100*stats_lin['recall'],
            100*stats_wte['recall'],
            100*stats_dp['recall']
        ],
        'Avg Treat Time': [
            stats_const['avg_treatment_time'],
            stats_dyn['avg_treatment_time'],
            stats_lin['avg_treatment_time'],
            stats_wte['avg_treatment_time'],
            stats_dp['avg_treatment_time']
        ]
    })
    
    print("=== Final chosen hyperparams (lambda^*) ===")
    print(final_lambda)
    print("=== Final chosen discount factor (mu^*) ===")
    print(final_mu)
    
    return final_table

def main():
    df_all = pd.read_csv("synthetic_patients_with_features.csv")
    final_table = run_algorithm2_full_match(df_all, n=4, seed=42)
    print("\n=== ALGORITHM 2 (FULL MATCH) WITH 50% CAP ON FINAL FOLD (SICK ONLY) ===")
    print(final_table.to_string(index=False))

if __name__ == "__main__":
    main()

=== Final chosen hyperparams (lambda^*) ===
('cat', (('depth', 3), ('iterations', 50), ('learning_rate', 0.05)))
=== Final chosen discount factor (mu^*) ===
0.99

=== ALGORITHM 2 (FULL MATCH) WITH 50% CAP ON FINAL FOLD (SICK ONLY) ===
              Method  Cost  Precision (%)  Recall (%)  Avg Treat Time
  Constant Threshold   734      40.000000   47.619048        1.400000
 Dynamic Threshold-R   678      52.631579   47.619048        6.578947
    Linear Threshold   881      25.000000   47.619048        0.825000
       Wait Till End   750     100.000000   47.619048       20.000000
Dynamic Threshold-DP   634      83.333333   47.619048        7.000000


In [3]:
"""
Full CROSS-VALIDATION (ALGORITHM 2) FOR CONSTRAINED HEMORRHAGE DIAGNOSIS & TREATMENT
WITH A 50% CAP ON SICK PATIENTS.  REPEATED 30x (by default) WITH MEAN & STD REPORTED.

Requirements:
  pip install numpy pandas scikit-learn catboost
"""
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Sklearn models, metrics, etc.
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import ParameterGrid
# CatBoost
from catboost import CatBoostClassifier

###############################################################################
# GLOBAL PARAMETERS
###############################################################################
FP_COST = 10    # False positive cost
FN_COST = 50    # False negative cost
D_COST  = 1     # Delay cost per time
T_MAX   = 21    # maximum discrete time steps (0..T_MAX-1)

MU_CANDIDATES = [0.95, 0.99]

# Example hyperparameter grids
RF_PARAM_GRID = {
    'n_estimators': [50, 100],
    'max_depth': [3, 5]
}
GB_PARAM_GRID = {
    'n_estimators': [50, 100],
    'learning_rate': [0.05, 0.1],
    'max_depth': [3, 5]
}
CATBOOST_PARAM_GRID = {
    'iterations': [50, 100],
    'learning_rate': [0.05, 0.1],
    'depth': [3, 5]
}

###############################################################################
# HELPERS FOR SPLITTING DATA, AUC, ETC.
###############################################################################
def split_into_nplus1_groups(df, n=4, seed=0):
    """Shuffle patient IDs and split ~evenly into (n+1) groups: G1..G_{n+1}."""
    rng = np.random.RandomState(seed)
    unique_pids = df['patient_id'].unique()
    rng.shuffle(unique_pids)
    N = len(unique_pids)
    group_size = int(np.ceil(N/(n+1)))
    
    groups = []
    start_idx = 0
    for i in range(n+1):
        end_idx = min(start_idx + group_size, N)
        group_pids = unique_pids[start_idx:end_idx]
        group_df   = df[df['patient_id'].isin(group_pids)].copy()
        groups.append(group_df)
        start_idx = end_idx
    return groups

def compute_auc_score(y_true, y_prob):
    """Compute AUC safely. If only one class, return 0.5 to avoid errors."""
    if len(np.unique(y_true)) < 2:
        return 0.5
    return roc_auc_score(y_true, y_prob)

###############################################################################
# POLICY SIMULATION (UNCONSTRAINED) - REFERENCE ONLY
###############################################################################
def simulate_policy(df, policy_func):
    """
    df must contain columns: patient_id, time, label, risk_score
    policy_func(patient_rows) -> treat_time (int) or None
    
    Returns dict of: cost, precision, recall, avg_treatment_time
    """
    results = []
    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values('time')
        label = grp['label'].iloc[0]
        treat_time = policy_func(grp)
        
        if treat_time is None:  # never treat
            if label == 1:
                cost = FN_COST
                tp   = 0
            else:
                cost = 0
                tp   = 0
            fp = 0
            treat_flag = 0
            ttime = None
        else:
            treat_flag = 1
            if label == 1:
                cost = D_COST * treat_time
                tp   = 1
                fp   = 0
            else:
                cost = FP_COST
                tp   = 0
                fp   = 1
            ttime = treat_time
        
        results.append({
            'patient_id': pid,
            'label': label,
            'treated': treat_flag,
            'treat_time': ttime,
            'cost': cost,
            'tp': tp,
            'fp': fp
        })
    
    df_res = pd.DataFrame(results)
    total_cost = df_res['cost'].sum()
    
    treated_df = df_res[df_res['treated']==1]
    tp_sum = treated_df['tp'].sum()
    fp_sum = treated_df['fp'].sum()
    
    precision = 0.0
    if len(treated_df) > 0:
        precision = tp_sum / (tp_sum + fp_sum)
    
    sick_df = df_res[df_res['label']==1]
    total_sick = len(sick_df)
    recall = 0.0
    if total_sick > 0:
        recall = tp_sum / total_sick
    
    avg_tt = 0.0
    if len(treated_df) > 0:
        valid_tt = treated_df['treat_time'].dropna()
        if len(valid_tt) > 0:
            avg_tt = valid_tt.mean()
    
    return {
        'cost': total_cost,
        'precision': precision,
        'recall': recall,
        'avg_treatment_time': avg_tt
    }

###############################################################################
# POLICY SIMULATION WITH SICK CAPACITY
###############################################################################
def simulate_policy_with_sick_capacity(df, policy_func, capacity_frac=0.5):
    """
    We enforce that at most (capacity_frac) fraction of the *sick* patients
    can be treated. Everyone else recommended (healthy) is treated with no limit.
    """
    # 1) Collect recommendation info
    results = []
    recommended_sick = []     # (pid, label=1, time_treated, risk_score)
    recommended_healthy = []  # (pid, label=0, time_treated, risk_score)
    
    # Count total sick
    all_sick_df = df[df['label']==1]
    num_sick = all_sick_df['patient_id'].nunique()
    # capacity (max number of sick we can treat)
    capacity_num = int(np.floor(capacity_frac * num_sick)) if num_sick > 0 else 0
    
    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values('time')
        label = grp['label'].iloc[0]
        
        # The policy returns the time we "recommend" to treat or None
        treat_time = policy_func(grp)
        
        if treat_time is None:
            # not recommended => cost to fill in later
            results.append({
                'patient_id': pid,
                'label': label,
                'treated': 0,
                'treat_time': None,
                'cost': None,  # fill in later
                'tp': 0,
                'fp': 0
            })
        else:
            # recommended => record separately
            row_t = grp[grp['time']==treat_time].iloc[0]
            recommended_risk = row_t['risk_score']
            if label == 1:
                recommended_sick.append((pid, label, treat_time, recommended_risk))
            else:
                recommended_healthy.append((pid, label, treat_time, recommended_risk))
    
    # 2) Sort recommended sick by descending risk_score
    recommended_sick.sort(key=lambda x: x[3], reverse=True)
    # 3) Actually treat only top capacity_num from recommended sick
    treat_sick_subset = recommended_sick[:capacity_num]
    not_treat_sick_subset = recommended_sick[capacity_num:]
    
    # 4) We treat ALL recommended healthy (no limit)
    treat_healthy_subset = recommended_healthy
    
    # Build final result records for the actually-treated subsets
    treat_results = []
    
    # 4a) For SICK actually treated
    for (pid, label, ttime, rsk) in treat_sick_subset:
        # label==1 => sick => cost = D_COST * ttime
        treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 1,
            'treat_time': ttime,
            'cost': D_COST * ttime,
            'tp': 1,
            'fp': 0
        })
    
    # 4b) For HEALTHY actually treated
    for (pid, label, ttime, rsk) in treat_healthy_subset:
        treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 1,
            'treat_time': ttime,
            'cost': FP_COST,  # healthy => false positive cost
            'tp': 0,
            'fp': 1
        })
    
    # 5) Build final records for not-treated subsets
    not_treat_results = []
    # (a) Sick recommended but not treated
    for (pid, label, ttime, rsk) in not_treat_sick_subset:
        not_treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 0,
            'treat_time': None,
            'cost': FN_COST,  # sick not treated => FN
            'tp': 0,
            'fp': 0
        })
    # (b) Those who were never recommended
    for row in results:
        if row['cost'] is None:
            if row['label'] == 1:
                row['cost'] = FN_COST
            else:
                row['cost'] = 0
            not_treat_results.append(row)
    
    df_res = pd.DataFrame(treat_results + not_treat_results)
    
    # Compute stats
    total_cost = df_res['cost'].sum()
    treated_df = df_res[df_res['treated']==1]
    tp_sum = treated_df['tp'].sum()
    fp_sum = treated_df['fp'].sum()
    
    precision = 0.0
    if len(treated_df) > 0:
        precision = tp_sum / (tp_sum + fp_sum)
    
    sick_df = df_res[df_res['label']==1]
    total_sick = len(sick_df)
    recall = 0.0
    if total_sick > 0:
        recall = tp_sum / total_sick
    
    avg_tt = 0.0
    if len(treated_df) > 0:
        valid_tt = treated_df['treat_time'].dropna()
        if len(valid_tt) > 0:
            avg_tt = valid_tt.mean()
    
    return {
        'cost': total_cost,
        'precision': precision,
        'recall': recall,
        'avg_treatment_time': avg_tt
    }

###############################################################################
# THRESHOLD-based Policies
###############################################################################
def make_constant_threshold_policy(thr):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            if row['risk_score'] >= thr:
                return int(row['time'])
        return None
    return policy_func

def constant_threshold_search(df, thresholds=None):
    if thresholds is None:
        thresholds = np.linspace(0,1,21)
    best_thr, best_cost = None, float('inf')
    best_stats = None
    for thr in thresholds:
        policy = make_constant_threshold_policy(thr)
        stats = simulate_policy(df, policy)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_thr  = thr
            best_stats= stats
    return best_thr, best_stats

def make_dynamic_threshold_policy(thr_vec):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            if t < len(thr_vec) and row['risk_score'] >= thr_vec[t]:
                return t
        return None
    return policy_func

def dynamic_threshold_random_search(df,
                                    time_steps=20,
                                    threshold_candidates=[0.0,0.2,0.4,0.6,0.8,1.0],
                                    n_samples=200,
                                    seed=0):
    rng = np.random.RandomState(seed)
    best_vec = None
    best_cost= float('inf')
    best_stats=None
    
    for _ in range(n_samples):
        thr_vec = rng.choice(threshold_candidates, size=time_steps)
        policy  = make_dynamic_threshold_policy(thr_vec)
        stats   = simulate_policy(df, policy)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_vec  = thr_vec.copy()
            best_stats= stats
    return best_vec, best_stats

def make_linear_threshold_policy(A, B):
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = row['time']
            thr = A*t + B
            thr = np.clip(thr, 0, 1)
            if row['risk_score'] >= thr:
                return int(t)
        return None
    return policy_func

def linear_threshold_search(df,
                            A_candidates=np.linspace(-0.05, 0.01, 7),
                            B_candidates=np.linspace(0,0.6,4)):
    best_A, best_B = None, None
    best_cost = float('inf')
    best_stats= None
    for A in A_candidates:
        for B in B_candidates:
            policy = make_linear_threshold_policy(A, B)
            stats  = simulate_policy(df, policy)
            if stats['cost'] < best_cost:
                best_cost = stats['cost']
                best_A    = A
                best_B    = B
                best_stats= stats
    return (best_A, best_B), best_stats

def make_wait_till_end_policy(thr):
    def policy_func(patient_rows):
        final_t = patient_rows['time'].max()
        final_row = patient_rows[patient_rows['time'] == final_t].iloc[0]
        if final_row['risk_score'] >= thr:
            return int(final_t)
        return None
    return policy_func

def wait_till_end_search(df, thresholds=None):
    if thresholds is None:
        thresholds = np.linspace(0,1,21)
    best_thr, best_cost = None, float('inf')
    best_stats = None
    for thr in thresholds:
        policy = make_wait_till_end_policy(thr)
        stats  = simulate_policy(df, policy)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_thr  = thr
            best_stats= stats
    return best_thr, best_stats

###############################################################################
# DP-related helpers
###############################################################################
def to_bucket(prob):
    """Map a probability into 5 discrete buckets [0..4]."""
    b = int(prob * 5)
    return min(b, 4)

def estimate_transition_and_sick_probs(df_train, T=20, n_buckets=5):
    """
    Estimate p_trans[t,b,b_next] and p_sick[t,b].
    df_train must have columns: 'patient_id','time','risk_bucket','label'
    """
    transition_counts = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    bucket_counts     = np.zeros((T, n_buckets), dtype=float)
    sick_counts       = np.zeros((T, n_buckets), dtype=float)
    
    df_sorted = df_train.sort_values(['patient_id','time'])
    for pid, grp in df_sorted.groupby('patient_id'):
        grp = grp.sort_values('time')
        rows = grp.to_dict('records')
        for i, row in enumerate(rows):
            t   = int(row['time'])
            b   = int(row['risk_bucket'])
            lbl = row['label']
            if t < T:
                bucket_counts[t,b] += 1
                sick_counts[t,b]   += lbl
            if i < len(rows)-1:
                nxt = rows[i+1]
                t_next = nxt['time']
                b_next = nxt['risk_bucket']
                if (t_next == t+1) and (t < T-1):
                    transition_counts[t,b,b_next] += 1
    
    p_trans = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    for t_ in range(T-1):
        for b_ in range(n_buckets):
            denom = transition_counts[t_,b_,:].sum()
            if denom > 0:
                p_trans[t_,b_,:] = transition_counts[t_,b_,:] / denom
            else:
                # If we never saw transitions from (t,b), assume self-loop
                p_trans[t_,b_,b_] = 1.0
    
    p_sick = np.zeros((T, n_buckets), dtype=float)
    for t_ in range(T):
        for b_ in range(n_buckets):
            denom = bucket_counts[t_,b_]
            if denom > 0:
                p_sick[t_,b_] = sick_counts[t_,b_] / denom
            else:
                p_sick[t_,b_] = 0.0
    
    return p_trans, p_sick

def train_data_driven_dp_unconstrained(p_trans, p_sick, 
                                       FP=10, FN=50, D=1, gamma=0.99, T=20):
    """
    Standard unconstrained DP for each bucket b at each time t:
      V[t,b] = min( cost_treat_now, cost_wait )
    """
    n_buckets = p_sick.shape[1]
    V  = np.zeros((T+1, n_buckets))
    pi = np.zeros((T,   n_buckets), dtype=int)
    
    # boundary at t = T: cost of deciding at the last time-step
    # (though we effectively only go up to T-1 in times)
    for b in range(n_buckets):
        # if we "treat" at the last time (T-1):
        cost_treat   = p_sick[T-1,b] * (D*(T-1)) + (1 - p_sick[T-1,b])*FP
        # if we "do not treat" at all:
        cost_notreat = p_sick[T-1,b] * FN
        V[T,b] = min(cost_treat, cost_notreat)
    
    # fill from t = T-1 down to t=0
    for t_ in reversed(range(T)):
        for b in range(n_buckets):
            cost_treat = p_sick[t_,b]*(D*t_) + (1 - p_sick[t_,b])*FP
            # cost_wait:
            if t_ == T-1:
                cost_wait = gamma * V[T,b]
            else:
                exp_future = 0.0
                for b_next in range(n_buckets):
                    exp_future += p_trans[t_,b,b_next] * V[t_+1,b_next]
                cost_wait = gamma * exp_future
            
            if cost_treat <= cost_wait:
                V[t_,b] = cost_treat
                pi[t_,b] = 1  # treat now
            else:
                V[t_,b] = cost_wait
                pi[t_,b] = 0  # wait
    return V, pi

def make_dp_policy(V, pi_, T=20):
    """
    Return a policy function that treats at time t if pi_[t,bucket] == 1
    """
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            if t < T:
                b = int(row['risk_bucket'])
                if pi_[t,b] == 1:
                    return t
        return None
    return policy_func

###############################################################################
# ALGORITHM 2 (FULL MATCH) BUT WITH SICK CAP IN FINAL FOLD
###############################################################################
def run_algorithm2_full_match(df_all, n=4, seed=0):
    """
    Implements Algorithm 2 from your snippet exactly, but in the *final evaluation*
    (G_{n+1}) we enforce a 50% cap on sick patients. 
    Returns the final table of results for G_{n+1}.
    """
    # Keep only time < T_MAX
    df_all = df_all[df_all['time'] < T_MAX].copy()
    
    # 1) Split => G1..Gn, G_{n+1}
    groups = split_into_nplus1_groups(df_all, n=n, seed=seed)
    G_cv   = groups[:-1]  # G1..Gn
    G_test = groups[-1]   # G_{n+1}
    
    # Build all hyperparam combinations
    lambda_candidates = []
    for p in ParameterGrid(RF_PARAM_GRID):
        param_tuple = tuple(sorted(p.items()))
        lambda_candidates.append(('rf', param_tuple))
    for p in ParameterGrid(GB_PARAM_GRID):
        param_tuple = tuple(sorted(p.items()))
        lambda_candidates.append(('gb', param_tuple))
    for p in ParameterGrid(CATBOOST_PARAM_GRID):
        param_tuple = tuple(sorted(p.items()))
        lambda_candidates.append(('cat', param_tuple))
    
    def dict_from_param_tuple(param_tuple):
        return dict(param_tuple)
    
    # Helper function to train + compute "AUCCost = 1 - AUC"
    def compute_AUCCost(lambda_cand, train_df, val_df):
        model_type, param_tuple = lambda_cand
        param_dict = dict_from_param_tuple(param_tuple)
        
        X_train = train_df[['EIT','NIRS','EIS']].values
        y_train = train_df['label'].values
        X_val   = val_df[['EIT','NIRS','EIS']].values
        y_val   = val_df['label'].values
        
        if model_type == 'rf':
            mdl = RandomForestClassifier(random_state=0, **param_dict)
        elif model_type == 'gb':
            mdl = GradientBoostingClassifier(random_state=0, **param_dict)
        else:
            mdl = CatBoostClassifier(verbose=0, random_state=0, **param_dict)
        
        mdl.fit(X_train, y_train)
        prob_val = mdl.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, prob_val)
        return 1.0 - auc_val  # "AUCCost"
    
    # Helper to train final model + build DP => "ActualCost"
    def compute_ActualCost(lambda_cand, mu, train_df, val_df):
        model_type, param_tuple = lambda_cand
        param_dict = dict_from_param_tuple(param_tuple)
        
        X_train = train_df[['EIT','NIRS','EIS']].values
        y_train = train_df['label'].values
        
        if model_type == 'rf':
            mdl = RandomForestClassifier(random_state=0, **param_dict)
        elif model_type == 'gb':
            mdl = GradientBoostingClassifier(random_state=0, **param_dict)
        else:
            mdl = CatBoostClassifier(verbose=0, random_state=0, **param_dict)
        mdl.fit(X_train, y_train)
        
        # Risk scores for val
        val_df_ = val_df.copy()
        X_val   = val_df_[['EIT','NIRS','EIS']].values
        prob_val= mdl.predict_proba(X_val)[:,1]
        val_df_['risk_score'] = prob_val
        val_df_['risk_bucket']= val_df_['risk_score'].apply(to_bucket)
        
        # Build DP from train
        train_df_ = train_df.copy()
        prob_tr   = mdl.predict_proba(train_df_[['EIT','NIRS','EIS']].values)[:,1]
        train_df_['risk_score']  = prob_tr
        train_df_['risk_bucket'] = train_df_['risk_score'].apply(to_bucket)
        
        p_trans, p_sick = estimate_transition_and_sick_probs(train_df_,
                                                             T=T_MAX,
                                                             n_buckets=5)
        V, pi_ = train_data_driven_dp_unconstrained(
            p_trans, p_sick,
            FP=FP_COST, FN=FN_COST, D=D_COST,
            gamma=mu, T=T_MAX
        )
        dp_policy = make_dp_policy(V, pi_, T=T_MAX)
        
        stats = simulate_policy(val_df_, dp_policy)
        return stats['cost']
    
    nfolds = len(G_cv)  # should be = n
    ###################################################################
    # 2) Outer loop j = 1..n
    ###################################################################
    best_lambda_for_j = [None]*n
    cost_j_mu = {j: {} for j in range(n)}
    
    for j in range(n):
        # G_outer_val = G_cv[j]
        G_outer_val = G_cv[j]
        train_list = [G_cv[m] for m in range(n) if m != j]
        G_outer_train = pd.concat(train_list, ignore_index=True)
        
        # (a) Inner cross among i != j => sum of AUCCost
        sum_auccost = {lam: 0.0 for lam in lambda_candidates}
        i_indices = [x for x in range(n) if x != j]
        
        for i_ in i_indices:
            G_inner_val = G_cv[i_]
            inner_train_list = [G_cv[m] for m in range(n) if (m != j) and (m != i_)]
            G_inner_train = pd.concat(inner_train_list, ignore_index=True)
            
            for lam in lambda_candidates:
                cost_ij = compute_AUCCost(lam, G_inner_train, G_inner_val)
                sum_auccost[lam] += cost_ij
        
        # pick lambda_j^*
        best_lam_j = None
        best_val = float('inf')
        for lam in lambda_candidates:
            if sum_auccost[lam] < best_val:
                best_val = sum_auccost[lam]
                best_lam_j = lam
        best_lambda_for_j[j] = best_lam_j
        
        # (b) Retrain on G_outer_train with lambda_j^*, measure ActualCost for each mu
        for mu_ in MU_CANDIDATES:
            c_j_mu = compute_ActualCost(best_lam_j, mu_, G_outer_train, G_outer_val)
            cost_j_mu[j][mu_] = c_j_mu
    
    # (c) pick final mu^* by summing cost_j_mu across j
    mu_star = None
    best_sum_cost = float('inf')
    for mu_ in MU_CANDIDATES:
        sum_c = 0.0
        for j in range(n):
            sum_c += cost_j_mu[j][mu_]
        if sum_c < best_sum_cost:
            best_sum_cost = sum_c
            mu_star = mu_
    
    ###################################################################
    # 4) "Full cross" pass #1: pick final lambda^*
    ###################################################################
    sum_auccost_all_i = {lam: 0.0 for lam in lambda_candidates}
    
    for i_ in range(n):
        G_val_i = G_cv[i_]
        G_train_i = pd.concat([G_cv[k] for k in range(n) if k != i_], ignore_index=True)
        for lam in lambda_candidates:
            c_auccost = compute_AUCCost(lam, G_train_i, G_val_i)
            sum_auccost_all_i[lam] += c_auccost
    
    lambda_star = None
    best_val2 = float('inf')
    for lam in lambda_candidates:
        if sum_auccost_all_i[lam] < best_val2:
            best_val2 = sum_auccost_all_i[lam]
            lambda_star = lam
    
    ###################################################################
    # 5) "Full cross" pass #2: pick final mu^*
    #    (In practice we already have mu_star from step (3).)
    ###################################################################
    cost_i_mu_2 = {mu_: 0.0 for mu_ in MU_CANDIDATES}
    for i_ in range(n):
        G_val_i = G_cv[i_]
        G_train_i = pd.concat([G_cv[k] for k in range(n) if k != i_], ignore_index=True)
        for mu_ in MU_CANDIDATES:
            c_i_mu = compute_ActualCost(lambda_star, mu_, G_train_i, G_val_i)
            cost_i_mu_2[mu_] += c_i_mu
    
    mu_star_final = None
    best_cost_2 = float('inf')
    for mu_ in MU_CANDIDATES:
        if cost_i_mu_2[mu_] < best_cost_2:
            best_cost_2 = cost_i_mu_2[mu_]
            mu_star_final = mu_
    
    final_lambda = lambda_star
    final_mu     = mu_star_final
    
    ###################################################################
    # 6) Retrain final model on G1..Gn with final_lambda,
    #    build DP with final_mu, evaluate on G_{n+1} with capacity=50%.
    ###################################################################
    G_cv_concat = pd.concat(G_cv, ignore_index=True)
    
    model_type_final, param_tuple_final = final_lambda
    param_dict_final = dict_from_param_tuple(param_tuple_final)
    
    if model_type_final == 'rf':
        final_mdl = RandomForestClassifier(random_state=0, **param_dict_final)
    elif model_type_final == 'gb':
        final_mdl = GradientBoostingClassifier(random_state=0, **param_dict_final)
    else:
        final_mdl = CatBoostClassifier(verbose=0, random_state=0, **param_dict_final)
    
    X_cv_all = G_cv_concat[['EIT','NIRS','EIS']].values
    y_cv_all = G_cv_concat['label'].values
    final_mdl.fit(X_cv_all, y_cv_all)
    
    # Build final DP policy
    df_dp_train = G_cv_concat.copy()
    prob_dp_tr  = final_mdl.predict_proba(df_dp_train[['EIT','NIRS','EIS']].values)[:,1]
    df_dp_train['risk_score']  = prob_dp_tr
    df_dp_train['risk_bucket'] = df_dp_train['risk_score'].apply(to_bucket)
    
    p_trans_final, p_sick_final = estimate_transition_and_sick_probs(df_dp_train,
                                                                     T=T_MAX,
                                                                     n_buckets=5)
    V_final, pi_final = train_data_driven_dp_unconstrained(
        p_trans_final, p_sick_final,
        FP=FP_COST, FN=FN_COST, D=D_COST,
        gamma=final_mu, T=T_MAX
    )
    dp_policy_final = make_dp_policy(V_final, pi_final, T=T_MAX)
    
    # Evaluate on G_{n+1} with *50% capacity on sick*
    G_test_eval = G_test.copy()
    prob_test   = final_mdl.predict_proba(G_test_eval[['EIT','NIRS','EIS']].values)[:,1]
    G_test_eval['risk_score'] = prob_test
    G_test_eval['risk_bucket']= G_test_eval['risk_score'].apply(to_bucket)
    
    stats_dp = simulate_policy_with_sick_capacity(G_test_eval, dp_policy_final, capacity_frac=0.5)
    
    ###################################################################
    # OPTIONAL: Evaluate threshold-based policies on G_{n+1} with 50% sick cap
    ###################################################################
    # We tune thresholds on G1..Gn, then evaluate on G_{n+1} with capacity limit
    prob_cv_final = final_mdl.predict_proba(G_cv_concat[['EIT','NIRS','EIS']].values)[:,1]
    G_cv_concat['risk_score'] = prob_cv_final
    
    # 1) Constant threshold
    best_thr_const, _ = constant_threshold_search(G_cv_concat)
    policy_const = make_constant_threshold_policy(best_thr_const)
    stats_const  = simulate_policy_with_sick_capacity(G_test_eval, policy_const, capacity_frac=0.5)
    
    # 2) Dynamic threshold (random search)
    best_thr_vec, _ = dynamic_threshold_random_search(G_cv_concat,
                    time_steps=T_MAX-1,
                    threshold_candidates=[0.0,0.2,0.4,0.6,0.8,1.0],
                    n_samples=200, seed=123)
    policy_dyn = make_dynamic_threshold_policy(best_thr_vec)
    stats_dyn  = simulate_policy_with_sick_capacity(G_test_eval, policy_dyn, capacity_frac=0.5)
    
    # 3) Linear threshold
    (A_lin, B_lin), _ = linear_threshold_search(G_cv_concat)
    policy_lin = make_linear_threshold_policy(A_lin, B_lin)
    stats_lin  = simulate_policy_with_sick_capacity(G_test_eval, policy_lin, capacity_frac=0.5)
    
    # 4) Wait till end
    best_thr_wte, _ = wait_till_end_search(G_cv_concat)
    policy_wte = make_wait_till_end_policy(best_thr_wte)
    stats_wte  = simulate_policy_with_sick_capacity(G_test_eval, policy_wte, capacity_frac=0.5)
    
    # Build final table
    final_table = pd.DataFrame({
        'Method': [
            'Constant Threshold',
            'Dynamic Threshold-R',
            'Linear Threshold',
            'Wait Till End',
            'Dynamic Threshold-DP'
        ],
        'Cost': [
            stats_const['cost'],
            stats_dyn['cost'],
            stats_lin['cost'],
            stats_wte['cost'],
            stats_dp['cost']
        ],
        'Precision (%)': [
            100*stats_const['precision'],
            100*stats_dyn['precision'],
            100*stats_lin['precision'],
            100*stats_wte['precision'],
            100*stats_dp['precision']
        ],
        'Recall (%)': [
            100*stats_const['recall'],
            100*stats_dyn['recall'],
            100*stats_lin['recall'],
            100*stats_wte['recall'],
            100*stats_dp['recall']
        ],
        'Avg Treatment Time': [
            stats_const['avg_treatment_time'],
            stats_dyn['avg_treatment_time'],
            stats_lin['avg_treatment_time'],
            stats_wte['avg_treatment_time'],
            stats_dp['avg_treatment_time']
        ]
    })
    
    # Show final chosen hyperparams / discount factor for completeness
    print("=== Final chosen hyperparams (lambda^*) ===")
    print(final_lambda)
    print("=== Final chosen discount factor (mu^*) ===")
    print(final_mu)
    
    return final_table


###############################################################################
# MAIN: RUN MULTIPLE REPLICATES AND REPORT MEAN & STD
###############################################################################
def main():
    # Read data
    df_all = pd.read_csv("synthetic_patients_with_features.csv")
    
    NUM_REPLICATES = 30  # Change this if you want more or fewer replicates
    
    all_reps_results = []  # list of DataFrames, each replicate's final_table
    
    for rep in range(NUM_REPLICATES):
        # Different seed each time to randomize grouping
        seed = 412 + rep
        print(f"\n=== Running replicate {rep+1}/{NUM_REPLICATES} (seed={seed}) ===")
        
        final_table = run_algorithm2_full_match(df_all, n=4, seed=seed)
        # Tag each row of final_table with replicate index (just for clarity)
        final_table["Replicate"] = rep
        
        # Store
        all_reps_results.append(final_table)
    
    # Concatenate all replicate tables
    results_df = pd.concat(all_reps_results, ignore_index=True)
    
    # Compute mean & std across replicates, grouped by Method
    summary_df = (
        results_df
        .groupby("Method")
        .agg({
            "Cost": ["mean", "std"],
            "Precision (%)": ["mean", "std"],
            "Recall (%)": ["mean", "std"],
            "Avg Treatment Time": ["mean", "std"]
        })
    )
    # Flatten multi-level columns for readability
    summary_df.columns = ['_'.join(col).strip() for col in summary_df.columns.values]
    summary_df = summary_df.reset_index()
    
    print("\n=== SUMMARY OF 30 REPLICATES (MEAN ± STD) ===")
    print(summary_df.to_string(index=False))

if __name__ == "__main__":
    main()


=== Running replicate 1/30 (seed=412) ===
=== Final chosen hyperparams (lambda^*) ===
('cat', (('depth', 3), ('iterations', 50), ('learning_rate', 0.05)))
=== Final chosen discount factor (mu^*) ===
0.99

=== Running replicate 2/30 (seed=413) ===
=== Final chosen hyperparams (lambda^*) ===
('cat', (('depth', 3), ('iterations', 50), ('learning_rate', 0.05)))
=== Final chosen discount factor (mu^*) ===
0.99

=== Running replicate 3/30 (seed=414) ===
=== Final chosen hyperparams (lambda^*) ===
('cat', (('depth', 5), ('iterations', 50), ('learning_rate', 0.05)))
=== Final chosen discount factor (mu^*) ===
0.99

=== Running replicate 4/30 (seed=415) ===
=== Final chosen hyperparams (lambda^*) ===
('cat', (('depth', 3), ('iterations', 50), ('learning_rate', 0.05)))
=== Final chosen discount factor (mu^*) ===
0.99

=== Running replicate 5/30 (seed=416) ===
=== Final chosen hyperparams (lambda^*) ===
('cat', (('depth', 3), ('iterations', 50), ('learning_rate', 0.05)))
=== Final chosen discoun