In [1]:
"""
SEQUENTIAL OPTIMIZATION (ALGORITHM 3) FOR HEMORRHAGE DIAGNOSIS & TREATMENT
WITH A 50% CAP ON SICK PATIENTS.

Requirements:
  pip install numpy pandas scikit-learn catboost
"""

import numpy as np
import pandas as pd
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

# Sklearn models, metrics, etc.
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import ParameterGrid
# CatBoost
from catboost import CatBoostClassifier

###############################################################################
# 1. GLOBAL PARAMETERS
###############################################################################
FP_COST = 10
FN_COST = 50
D_COST  = 1
T_MAX   = 21   # maximum discrete time steps (0..T_MAX-1)
GAMMA_CANDIDATES = [0.95, 0.99]  # Example DP discount factors to try

# Hyperparameter grids for ML models
RF_PARAM_GRID = {
    'n_estimators': [50, 100],
    'max_depth': [3, 5]
}
GB_PARAM_GRID = {
    'n_estimators': [50, 100],
    'learning_rate': [0.05, 0.1],
    'max_depth': [3, 5]
}
CATBOOST_PARAM_GRID = {
    'iterations': [50, 100],
    'learning_rate': [0.05, 0.1],
    'depth': [3, 5]
}

###############################################################################
# 2. HELPER FUNCTIONS (DATA SPLITS, MODEL TRAINING, ETC.)
###############################################################################
def split_into_nplus1_groups(df, n=4, seed=0):
    """
    Shuffle patient IDs and split ~evenly into (n+1) groups: G1, G2, ..., G_{n+1}.
    Example usage: n=4 => 5 groups total.
    """
    rng = np.random.RandomState(seed)
    unique_pids = df['patient_id'].unique()
    rng.shuffle(unique_pids)
    
    # We'll cut into (n+1) roughly-equal slices
    # For n=4 => 5 slices
    N = len(unique_pids)
    group_size = int(np.ceil(N/(n+1)))
    
    groups = []
    start_idx = 0
    for i in range(n+1):
        end_idx = min(start_idx+group_size, N)
        group_pids = unique_pids[start_idx:end_idx]
        group_df   = df[df['patient_id'].isin(group_pids)].copy()
        groups.append(group_df)
        start_idx = end_idx
    return groups  # list of dataframes: G[0], G[1], ..., G[n]

def compute_auc_score(y_true, y_prob):
    """Compute AUC safely. If only one class, return 0.5."""
    if len(np.unique(y_true)) < 2:
        return 0.5
    return roc_auc_score(y_true, y_prob)

def train_and_select_best_model(X_train, y_train, X_val, y_val):
    """
    Trains multiple models (RandomForest, GB, CatBoost)
    over small hyperparam grids, picks best by AUC.
    
    Returns: (best_model, best_auc, best_model_name)
    """
    best_auc = -1.0
    best_model = None
    best_name  = None
    
    # 1) RandomForest
    for params in ParameterGrid(RF_PARAM_GRID):
        rf = RandomForestClassifier(random_state=0, **params)
        rf.fit(X_train, y_train)
        val_prob = rf.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, val_prob)
        if auc_val > best_auc:
            best_auc   = auc_val
            best_model = rf
            best_name  = f"RandomForest_{params}"
    
    # 2) GradientBoosting
    for params in ParameterGrid(GB_PARAM_GRID):
        gb = GradientBoostingClassifier(random_state=0, **params)
        gb.fit(X_train, y_train)
        val_prob = gb.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, val_prob)
        if auc_val > best_auc:
            best_auc   = auc_val
            best_model = gb
            best_name  = f"GradientBoosting_{params}"
    
    # 3) CatBoost
    for params in ParameterGrid(CATBOOST_PARAM_GRID):
        cb = CatBoostClassifier(verbose=0, random_state=0, **params)
        cb.fit(X_train, y_train)
        val_prob = cb.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, val_prob)
        if auc_val > best_auc:
            best_auc   = auc_val
            best_model = cb
            best_name  = f"CatBoost_{params}"
    
    return best_model, best_auc, best_name


###############################################################################
# 3. POLICY SIMULATION FUNCTIONS
###############################################################################
def simulate_policy(df, policy_func):
    """
    df must contain:
      - patient_id
      - time
      - risk_score
      - label (0 or 1)
    
    policy_func(patient_rows) -> treat_time (int) or None
    
    Returns a dictionary of relevant metrics: total cost, precision, recall, etc.
    (Unconstrained treatment for all recommended patients.)
    """
    results = []
    
    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values('time')
        label = grp['label'].iloc[0]
        
        treat_time = policy_func(grp)
        
        if treat_time is None:
            # never treated
            if label == 1:
                cost = FN_COST
                tp   = 0
            else:
                cost = 0
                tp   = 0
            fp = 0
            treat_flag = 0
            ttime = None
        else:
            treat_flag = 1
            if label == 1:
                cost = D_COST * treat_time
                tp   = 1
                fp   = 0
            else:
                cost = FP_COST
                tp   = 0
                fp   = 1
            ttime = treat_time
        
        results.append({
            'patient_id': pid,
            'label': label,
            'treated': treat_flag,
            'treat_time': ttime,
            'cost': cost,
            'tp': tp,
            'fp': fp
        })
    
    df_res     = pd.DataFrame(results)
    total_cost = df_res['cost'].sum()
    
    treated_df = df_res[df_res['treated']==1]
    tp_sum = treated_df['tp'].sum()
    fp_sum = treated_df['fp'].sum()
    
    if len(treated_df) > 0:
        precision = tp_sum / (tp_sum + fp_sum)
    else:
        precision = 0.0
    
    sick_df = df_res[df_res['label']==1]
    total_sick = len(sick_df)
    if total_sick > 0:
        recall = tp_sum / total_sick
    else:
        recall = 0.0
    
    if len(treated_df) > 0:
        valid_tt = treated_df['treat_time'].dropna()
        avg_tt   = valid_tt.mean() if len(valid_tt) > 0 else 0.0
    else:
        avg_tt = 0.0
    
    return {
        'cost': total_cost,
        'precision': precision,
        'recall': recall,
        'avg_treatment_time': avg_tt
    }

def simulate_policy_with_sick_capacity(df, policy_func, capacity_frac=0.5):
    """
    We enforce that at most (capacity_frac) fraction of the *sick* patients
    in this fold can be treated.

    Steps:
      1. Identify which patients are "recommended" for treatment by the policy_func.
      2. Separate recommended patients into "sick recommended" vs. "healthy recommended".
      3. Among the recommended *sick* patients, we can only treat up to
         floor(capacity_frac * total_sick_in_this_fold). We'll choose the top (by risk_score).
      4. We treat all recommended *healthy* patients with no limit.
      5. Everyone else is not treated (FN cost if sick, 0 cost if healthy).
    """
    results = []
    recommended_sick = []     # (pid, label=1, time_treated, risk_score)
    recommended_healthy = []  # (pid, label=0, time_treated, risk_score)
    
    # Count total sick in this fold
    all_sick_df = df[df['label']==1]
    num_sick = all_sick_df['patient_id'].nunique()
    # capacity = max number of sick we can treat
    capacity_num = int(np.floor(capacity_frac * num_sick)) if num_sick>0 else 0
    
    # 1) Check policy recommendation
    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values('time')
        label = grp['label'].iloc[0]
        
        treat_time = policy_func(grp)
        
        if treat_time is None:
            # not recommended => cost is assigned later
            results.append({
                'patient_id': pid,
                'label': label,
                'treated': 0,
                'treat_time': None,
                'cost': None,
                'tp': 0,
                'fp': 0
            })
        else:
            # recommended => store in recommended_sick or recommended_healthy
            row_t = grp[grp['time']==treat_time].iloc[0]
            recommended_risk = row_t['risk_score']
            if label == 1:
                recommended_sick.append((pid, label, treat_time, recommended_risk))
            else:
                recommended_healthy.append((pid, label, treat_time, recommended_risk))
    
    # 2) Sort recommended sick by descending risk_score
    recommended_sick.sort(key=lambda x: x[3], reverse=True)
    # 3) Actually treat only top capacity_num from recommended sick
    treat_sick_subset = recommended_sick[:capacity_num]
    not_treat_sick_subset = recommended_sick[capacity_num:]
    
    # 4) We treat ALL recommended healthy, no limit
    treat_healthy_subset = recommended_healthy
    
    # Build final result records for the treated subsets
    treat_results = []
    
    # 4a) SICK actually treated
    for (pid, label, ttime, rsk) in treat_sick_subset:
        cost_ = D_COST * ttime  # treat_time * D
        treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 1,
            'treat_time': ttime,
            'cost': cost_,
            'tp': 1,
            'fp': 0
        })
    
    # 4b) HEALTHY actually treated
    for (pid, label, ttime, rsk) in treat_healthy_subset:
        cost_ = FP_COST  # healthy => false positive cost
        treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 1,
            'treat_time': ttime,
            'cost': cost_,
            'tp': 0,
            'fp': 1
        })
    
    # 5) Build final result records for not-treated subsets
    not_treat_results = []
    
    # (a) SICK recommended but not treated (exceeds capacity)
    for (pid, label, ttime, rsk) in not_treat_sick_subset:
        # label==1 => sick
        cost_ = FN_COST
        not_treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 0,
            'treat_time': None,
            'cost': cost_,
            'tp': 0,
            'fp': 0
        })
    
    # (b) Those who were never recommended
    for row in results:
        if row['cost'] is None:
            # not recommended
            if row['label'] == 1:
                row['cost'] = FN_COST
            else:
                row['cost'] = 0
            not_treat_results.append(row)
    
    df_res = pd.DataFrame(treat_results + not_treat_results)
    
    # Compute final stats
    total_cost = df_res['cost'].sum()
    treated_df = df_res[df_res['treated']==1]
    
    tp_sum = treated_df['tp'].sum()
    fp_sum = treated_df['fp'].sum()
    
    if len(treated_df) > 0:
        precision = tp_sum / (tp_sum + fp_sum)
    else:
        precision = 0.0
    
    sick_df = df_res[df_res['label']==1]
    total_sick = len(sick_df)
    if total_sick > 0:
        recall = tp_sum / total_sick
    else:
        recall = 0.0
    
    if len(treated_df) > 0:
        valid_tt = treated_df['treat_time'].dropna()
        avg_tt   = valid_tt.mean() if len(valid_tt) > 0 else 0.0
    else:
        avg_tt = 0.0
    
    return {
        'cost': total_cost,
        'precision': precision,
        'recall': recall,
        'avg_treatment_time': avg_tt
    }

###############################################################################
# 4. BENCHMARK THRESHOLD-BASED POLICIES
###############################################################################
def constant_threshold_search(df, thresholds=None):
    """
    Search over possible constant thresholds in [0, 1].
    Returns best_thr, best_stats.
    """
    if thresholds is None:
        thresholds = np.linspace(0,1,21)
    best_thr, best_cost, best_stats = None, float('inf'), None
    
    for thr in thresholds:
        def policy_func(patient_rows):
            for _, row in patient_rows.iterrows():
                if row['risk_score'] >= thr:
                    return int(row['time'])
            return None
        
        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_thr  = thr
            best_stats= stats
    return best_thr, best_stats

def make_constant_threshold_policy(thr):
    """Returns a policy function that treats as soon as risk_score >= thr."""
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            if row['risk_score'] >= thr:
                return int(row['time'])
        return None
    return policy_func

def dynamic_threshold_random_search(df,
                                    time_steps=20,
                                    threshold_candidates=[0.0,0.2,0.4,0.6,0.8,1.0],
                                    n_samples=200,
                                    seed=0):
    """
    We sample random threshold vectors across time_steps
    and pick the one with minimal cost on 'df'.
    """
    rng = np.random.RandomState(seed)
    best_vec = None
    best_cost= float('inf')
    best_stats=None
    
    for _ in range(n_samples):
        thr_vec = rng.choice(threshold_candidates, size=time_steps)
        
        def policy_func(patient_rows):
            for _, row in patient_rows.iterrows():
                t = int(row['time'])
                if t < time_steps and row['risk_score'] >= thr_vec[t]:
                    return t
            return None
        
        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_vec  = thr_vec.copy()
            best_stats= stats
    return best_vec, best_stats

def make_dynamic_threshold_policy(thr_vec):
    """Returns a policy function using a time-dependent threshold vector thr_vec."""
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            if t < len(thr_vec):
                if row['risk_score'] >= thr_vec[t]:
                    return t
        return None
    return policy_func

def linear_threshold_search(df,
                            A_candidates=np.linspace(-0.05, 0.01, 7),
                            B_candidates=np.linspace(0,0.6,2)):
    """
    Search policies of the form threshold(t) = clip(A*t + B, 0, 1).
    """
    best_A, best_B = None, None
    best_cost, best_stats = float('inf'), None
    
    for A in A_candidates:
        for B in B_candidates:
            def policy_func(patient_rows):
                for _, row in patient_rows.iterrows():
                    t = row['time']
                    thr = A*t + B
                    thr = np.clip(thr,0,1)
                    if row['risk_score'] >= thr:
                        return int(t)
                return None
            
            stats = simulate_policy(df, policy_func)
            if stats['cost'] < best_cost:
                best_cost = stats['cost']
                best_A    = A
                best_B    = B
                best_stats= stats
    return (best_A,best_B), best_stats

def make_linear_threshold_policy(A,B):
    """Returns a policy function threshold(t) = clip(A*t + B, 0, 1)."""
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            thr = A*t + B
            thr = np.clip(thr,0,1)
            if row['risk_score'] >= thr:
                return t
        return None
    return policy_func

def wait_till_end_search(df, thresholds=None):
    """
    Policy: wait until the final time for each patient, treat if final risk >= thr.
    """
    if thresholds is None:
        thresholds = np.linspace(0,1,21)
    best_thr, best_cost, best_stats = None, float('inf'), None
    
    for thr in thresholds:
        def policy_func(patient_rows):
            final_t = patient_rows['time'].max()
            final_row = patient_rows[patient_rows['time']==final_t].iloc[0]
            if final_row['risk_score'] >= thr:
                return int(final_t)
            return None
        
        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_thr  = thr
            best_stats= stats
    return best_thr, best_stats

def make_wait_till_end_policy(thr):
    """Returns a policy function that treats only at final time if risk_score >= thr."""
    def policy_func(patient_rows):
        final_t = patient_rows['time'].max()
        final_row = patient_rows[patient_rows['time']==final_t].iloc[0]
        if final_row['risk_score'] >= thr:
            return int(final_t)
        return None
    return policy_func


###############################################################################
# 5. DATA-DRIVEN DP (UNCONSTRAINED)
###############################################################################
def to_bucket(prob):
    """Simple function to map prob into a 5-bucket scale [0..4]."""
    b = int(prob * 5)
    return min(b, 4)

def estimate_transition_and_sick_probs(df_train, T=20, n_buckets=5):
    """
    p_trans[t,b,b_next], p_sick[t,b]
    df_train has columns: patient_id, time, risk_bucket, label
    """
    transition_counts = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    bucket_counts     = np.zeros((T, n_buckets), dtype=float)
    sick_counts       = np.zeros((T, n_buckets), dtype=float)
    
    df_sorted = df_train.sort_values(['patient_id','time'])
    for pid, grp in df_sorted.groupby('patient_id'):
        grp = grp.sort_values('time')
        rows= grp.to_dict('records')
        
        for i, row in enumerate(rows):
            t = int(row['time'])
            b = int(row['risk_bucket'])
            lbl = row['label']
            
            if t < T:
                bucket_counts[t,b] += 1
                sick_counts[t,b]   += lbl
            
            if i < len(rows)-1:
                nxt = rows[i+1]
                t_next = nxt['time']
                b_next = nxt['risk_bucket']
                if (t_next == t+1) and (t < T-1):
                    transition_counts[t,b,b_next] += 1
    
    p_trans = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    for t_ in range(T-1):
        for b_ in range(n_buckets):
            denom = transition_counts[t_,b_,:].sum()
            if denom>0:
                p_trans[t_,b_,:] = transition_counts[t_,b_,:] / denom
            else:
                # if no data, assume self-transition
                p_trans[t_,b_,b_] = 1.0
    
    p_sick = np.zeros((T, n_buckets), dtype=float)
    for t_ in range(T):
        for b_ in range(n_buckets):
            denom = bucket_counts[t_,b_]
            if denom>0:
                p_sick[t_,b_] = sick_counts[t_,b_] / denom
            else:
                p_sick[t_,b_] = 0.0
    return p_trans, p_sick

def train_data_driven_dp_unconstrained(p_trans, p_sick, 
                                       FP=10, FN=50, D=1, gamma=0.99, T=20):
    """
    Standard DP for unconstrained scenario:
      V[t,b] = min( cost_treat_now, gamma * expected_future_if_wait )
    """
    n_buckets = p_sick.shape[1]
    V = np.zeros((T+1, n_buckets))
    pi_ = np.zeros((T, n_buckets), dtype=int)
    
    # boundary at t=T
    # we only have valid decisions up to t = T-1
    for b in range(n_buckets):
        # cost if treat exactly at T-1:
        cost_treat   = p_sick[T-1,b]*(D*(T-1)) + (1-p_sick[T-1,b])*FP
        # cost if never treated:
        cost_notreat = p_sick[T-1,b]*FN
        V[T,b] = min(cost_treat, cost_notreat)
    
    # fill from T-1 down to 0
    for t in reversed(range(T)):
        for b in range(n_buckets):
            # treat now
            cost_treat = p_sick[t,b]*(D*t) + (1-p_sick[t,b])*FP
            
            # wait
            if t == T-1:
                cost_wait = gamma * V[T,b]
            else:
                exp_future = 0.0
                for b_next in range(n_buckets):
                    exp_future += p_trans[t,b,b_next]*V[t+1,b_next]
                cost_wait = gamma * exp_future
            
            if cost_treat <= cost_wait:
                V[t,b]   = cost_treat
                pi_[t,b] = 1
            else:
                V[t,b]   = cost_wait
                pi_[t,b] = 0
    
    return V, pi_

def make_dp_policy(V, pi_, T=20):
    """
    Return a policy function that treats if pi[t,b]==1 at time t
    for the bucket b of the risk score.
    """
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            if t < T:
                b = int(row['risk_bucket'])
                if pi_[t,b] == 1:
                    return t
        return None
    return policy_func

###############################################################################
# 6. ALGORITHM 3 (SEQUENTIAL OPTIMIZATION) - BUT FINAL FOLD HAS 50% CAP
###############################################################################
def run_algorithm3_sequential_optimization_with_cap(df_all, n=4, seed=0, capacity_frac=0.5):
    """
    Implements Algorithm 3 (Sequential Optimization) for hemorrhage diagnosis,
    unconstrained on G1..G_n, BUT with a capacity_frac limit on how many sick 
    can be treated in the final holdout fold (G_{n+1}).

    1) We split the data into (n+1) groups: G1,...,G_{n+1}.
         Let G_{n+1} be the final holdout/test set.
         G_cv = [G1..G_n] is used for cross-validation.
    2) Stage 1: Choose best ML hyperparams by CV (maximize AUC).
    3) Retrain final ML model on G1..G_n with best hyperparams.
    4) Stage 2: With ML model fixed, choose best policy hyperparams by CV (min cost).
    5) Evaluate final chosen methods on G_{n+1} with capacity_frac *SICK* limit.
    """
    # 0) Filter df_all if needed
    df_all = df_all[df_all['time'] < T_MAX].copy()
    
    # 1) Split
    groups = split_into_nplus1_groups(df_all, n=n, seed=seed)
    G_test = groups[-1]
    G_cv   = groups[:-1]
    
    # Combine G1..G_n for final retraining
    G_cv_concat = pd.concat(G_cv, ignore_index=True)
    
    ###########################################################################
    # STAGE 1: ML hyperparam selection by CV (maximize AUC)
    ###########################################################################
    def cv_auc_for_ml(params, model_type):
        total_auc = 0.0
        for i_cv in range(n):
            # training = G_cv except G_cv[i_cv]
            train_df_list = [G_cv[j] for j in range(n) if j != i_cv]
            train_df = pd.concat(train_df_list, ignore_index=True)
            val_df   = G_cv[i_cv]
            
            X_train = train_df[['EIT','NIRS','EIS']].values
            y_train = train_df['label'].values
            
            X_val   = val_df[['EIT','NIRS','EIS']].values
            y_val   = val_df['label'].values
            
            if model_type == 'rf':
                mdl = RandomForestClassifier(random_state=0, **params)
            elif model_type == 'gb':
                mdl = GradientBoostingClassifier(random_state=0, **params)
            else:
                mdl = CatBoostClassifier(verbose=0, random_state=0, **params)
            
            mdl.fit(X_train, y_train)
            val_prob = mdl.predict_proba(X_val)[:,1]
            auc_val  = compute_auc_score(y_val, val_prob)
            total_auc += auc_val
        return total_auc
    
    best_overall_auc = -1.0
    best_overall_params = None
    best_model_type = None
    
    # 1) RF
    for params in ParameterGrid(RF_PARAM_GRID):
        sum_auc = cv_auc_for_ml(params, 'rf')
        if sum_auc > best_overall_auc:
            best_overall_auc = sum_auc
            best_overall_params = params
            best_model_type = 'rf'
    
    # 2) GB
    for params in ParameterGrid(GB_PARAM_GRID):
        sum_auc = cv_auc_for_ml(params, 'gb')
        if sum_auc > best_overall_auc:
            best_overall_auc = sum_auc
            best_overall_params = params
            best_model_type = 'gb'
    
    # 3) CatBoost
    for params in ParameterGrid(CATBOOST_PARAM_GRID):
        sum_auc = cv_auc_for_ml(params, 'cat')
        if sum_auc > best_overall_auc:
            best_overall_auc = sum_auc
            best_overall_params = params
            best_model_type = 'cat'
    
    # Retrain final ML on entire G_cv
    if best_model_type == 'rf':
        best_model = RandomForestClassifier(random_state=0, **best_overall_params)
    elif best_model_type == 'gb':
        best_model = GradientBoostingClassifier(random_state=0, **best_overall_params)
    else:
        best_model = CatBoostClassifier(verbose=0, random_state=0, **best_overall_params)
    
    X_cv_final = G_cv_concat[['EIT','NIRS','EIS']].values
    y_cv_final = G_cv_concat['label'].values
    best_model.fit(X_cv_final, y_cv_final)
    
    ###########################################################################
    # STAGE 2: With ML model fixed, pick best policy hyperparams by CV (min cost)
    ###########################################################################
    def evaluate_policy_cost_cv(policy_maker_func, param):
        """
        policy_maker_func is a callable that, given 'param', returns policy_func.
        We sum the cost across G1..G_n, each with risk scores from best_model.
        """
        total_cost = 0.0
        for i_cv in range(n):
            G_i = G_cv[i_cv].copy()
            X_i = G_i[['EIT','NIRS','EIS']].values
            prob_i = best_model.predict_proba(X_i)[:,1]
            G_i['risk_score'] = prob_i
            
            policy_func = policy_maker_func(param)
            stats_i = simulate_policy(G_i, policy_func)  # unconstrained in CV
            total_cost += stats_i['cost']
        return total_cost
    
    # (A) Constant threshold
    possible_thresholds = np.linspace(0,1,21)
    best_thr_const = None
    best_const_cost = float('inf')
    for thr_candidate in possible_thresholds:
        cost_cv = evaluate_policy_cost_cv(make_constant_threshold_policy, thr_candidate)
        if cost_cv < best_const_cost:
            best_const_cost = cost_cv
            best_thr_const  = thr_candidate
    
    # (B) Dynamic threshold (random search)
    rng = np.random.RandomState(42)
    threshold_candidates = [0.0,0.2,0.4,0.6,0.8,1.0]
    dynamic_param_candidates = []
    N_SAMPLES = 30
    for _ in range(N_SAMPLES):
        thr_vec = rng.choice(threshold_candidates, size=T_MAX-1)
        dynamic_param_candidates.append(tuple(thr_vec))
    
    best_thr_vec = None
    best_dyn_cost = float('inf')
    for candidate_vec in dynamic_param_candidates:
        cost_cv = evaluate_policy_cost_cv(make_dynamic_threshold_policy, candidate_vec)
        if cost_cv < best_dyn_cost:
            best_dyn_cost = cost_cv
            best_thr_vec  = candidate_vec
    
    # (C) Linear threshold
    A_candidates = np.linspace(-0.05, 0.01, 7)
    B_candidates = np.linspace(0,0.6,2)
    best_lin = None
    best_lin_cost = float('inf')
    for A_ in A_candidates:
        for B_ in B_candidates:
            cost_cv = 0.0
            for i_cv in range(n):
                G_i = G_cv[i_cv].copy()
                X_i = G_i[['EIT','NIRS','EIS']].values
                prob_i = best_model.predict_proba(X_i)[:,1]
                G_i['risk_score'] = prob_i
                
                policy_lin = make_linear_threshold_policy(A_, B_)
                stats_i = simulate_policy(G_i, policy_lin)
                cost_cv += stats_i['cost']
            if cost_cv < best_lin_cost:
                best_lin_cost = cost_cv
                best_lin = (A_, B_)
    A_lin, B_lin = best_lin
    
    # (D) Wait-till-end threshold
    best_thr_wte = None
    best_wte_cost= float('inf')
    for thr_candidate in possible_thresholds:
        cost_cv = evaluate_policy_cost_cv(make_wait_till_end_policy, thr_candidate)
        if cost_cv < best_wte_cost:
            best_wte_cost   = cost_cv
            best_thr_wte    = thr_candidate
    
    # (E) Data-driven DP: pick best gamma by cross-validation
    def evaluate_dp_gamma_cv(gamma_val):
        total_cost = 0.0
        for i_cv in range(n):
            # train transition model on all folds except i_cv
            train_df_list = [G_cv[j] for j in range(n) if j != i_cv]
            train_df = pd.concat(train_df_list, ignore_index=True)
            
            X_train = train_df[['EIT','NIRS','EIS']].values
            prob_train = best_model.predict_proba(X_train)[:,1]
            train_df['risk_score'] = prob_train
            train_df['risk_bucket'] = train_df['risk_score'].apply(to_bucket)
            
            p_trans, p_sick = estimate_transition_and_sick_probs(train_df, T=T_MAX, n_buckets=5)
            V_temp, pi_temp = train_data_driven_dp_unconstrained(
                p_trans, p_sick,
                FP=FP_COST, FN=FN_COST, D=D_COST,
                gamma=gamma_val, T=T_MAX
            )
            dp_policy_temp = make_dp_policy(V_temp, pi_temp, T=T_MAX)
            
            # Evaluate on G_cv[i_cv]
            G_i = G_cv[i_cv].copy()
            X_i = G_i[['EIT','NIRS','EIS']].values
            prob_i = best_model.predict_proba(X_i)[:,1]
            G_i['risk_score'] = prob_i
            G_i['risk_bucket'] = G_i['risk_score'].apply(to_bucket)
            
            stats_i = simulate_policy(G_i, dp_policy_temp)  # unconstrained in CV
            total_cost += stats_i['cost']
        return total_cost
    
    best_gamma = None
    best_dp_cost= float('inf')
    for gamma_ in GAMMA_CANDIDATES:
        cost_cv_gamma = evaluate_dp_gamma_cv(gamma_)
        if cost_cv_gamma < best_dp_cost:
            best_dp_cost = cost_cv_gamma
            best_gamma   = gamma_
    
    # Train final DP on G_cv_concat with best_gamma
    G_cv_concat_dp = G_cv_concat.copy()
    X_dp = G_cv_concat_dp[['EIT','NIRS','EIS']].values
    prob_dp = best_model.predict_proba(X_dp)[:,1]
    G_cv_concat_dp['risk_score'] = prob_dp
    G_cv_concat_dp['risk_bucket'] = G_cv_concat_dp['risk_score'].apply(to_bucket)
    
    p_trans_final, p_sick_final = estimate_transition_and_sick_probs(
        G_cv_concat_dp, T=T_MAX, n_buckets=5
    )
    V_final, pi_final = train_data_driven_dp_unconstrained(
        p_trans_final, p_sick_final,
        FP=FP_COST, FN=FN_COST, D=D_COST,
        gamma=best_gamma, T=T_MAX
    )
    dp_policy_final = make_dp_policy(V_final, pi_final, T=T_MAX)
    
    ###########################################################################
    # 7) Evaluate all final chosen methods on G_{n+1} with 50% capacity
    ###########################################################################
    G_test_eval = G_test.copy()
    X_test = G_test_eval[['EIT','NIRS','EIS']].values
    prob_test = best_model.predict_proba(X_test)[:,1]
    G_test_eval['risk_score'] = prob_test
    
    # We now simulate with the capacity constraint
    policy_const = make_constant_threshold_policy(best_thr_const)
    stats_const  = simulate_policy_with_sick_capacity(G_test_eval, policy_const, capacity_frac=capacity_frac)
    
    policy_dyn = make_dynamic_threshold_policy(best_thr_vec)
    stats_dyn  = simulate_policy_with_sick_capacity(G_test_eval, policy_dyn, capacity_frac=capacity_frac)
    
    policy_lin = make_linear_threshold_policy(A_lin, B_lin)
    stats_lin  = simulate_policy_with_sick_capacity(G_test_eval, policy_lin, capacity_frac=capacity_frac)
    
    policy_wte = make_wait_till_end_policy(best_thr_wte)
    stats_wte  = simulate_policy_with_sick_capacity(G_test_eval, policy_wte, capacity_frac=capacity_frac)
    
    # For DP, need bucket info
    G_test_eval_dp = G_test_eval.copy()
    G_test_eval_dp['risk_bucket'] = G_test_eval_dp['risk_score'].apply(to_bucket)
    stats_dp = simulate_policy_with_sick_capacity(G_test_eval_dp, dp_policy_final, capacity_frac=capacity_frac)
    
    # Build final table
    table = pd.DataFrame({
        'Method': [
            'Constant Threshold',
            'Dynamic Threshold-R',
            'Linear Threshold',
            'Wait Till End',
            'DP-based Policy'
        ],
        'Cost': [
            stats_const['cost'],
            stats_dyn['cost'],
            stats_lin['cost'],
            stats_wte['cost'],
            stats_dp['cost']
        ],
        'Precision (%)': [
            100*stats_const['precision'],
            100*stats_dyn['precision'],
            100*stats_lin['precision'],
            100*stats_wte['precision'],
            100*stats_dp['precision']
        ],
        'Recall (%)': [
            100*stats_const['recall'],
            100*stats_dyn['recall'],
            100*stats_lin['recall'],
            100*stats_wte['recall'],
            100*stats_dp['recall']
        ],
        'Avg Treat Time': [
            stats_const['avg_treatment_time'],
            stats_dyn['avg_treatment_time'],
            stats_lin['avg_treatment_time'],
            stats_wte['avg_treatment_time'],
            stats_dp['avg_treatment_time']
        ]
    })
    
    return table

###############################################################################
# 7. MAIN
###############################################################################
def main():
    df_all = pd.read_csv("synthetic_patients_with_features.csv")
    
    final_table = run_algorithm3_sequential_optimization_with_cap(
        df_all, 
        n=4,            
        seed=4,
        capacity_frac=0.5  
    )
    
    print("\n=== ALGORITHM 3 (SEQUENTIAL OPTIMIZATION) WITH 50% CAP ON FINAL FOLD SICK PATIENTS ===")
    print(final_table.to_string(index=False))

if __name__ == "__main__":
    main()


=== ALGORITHM 3 (SEQUENTIAL OPTIMIZATION) WITH 50% CAP ON FINAL FOLD SICK PATIENTS ===
             Method  Cost  Precision (%)  Recall (%)  Avg Treat Time
 Constant Threshold  1004      37.142857   48.148148        2.400000
Dynamic Threshold-R  1179      23.214286   48.148148        2.857143
   Linear Threshold  1099      27.083333   48.148148        1.020833
      Wait Till End   960     100.000000   48.148148       20.000000
    DP-based Policy   831     100.000000   48.148148       10.076923


In [3]:
"""
SEQUENTIAL OPTIMIZATION (ALGORITHM 3) FOR HEMORRHAGE DIAGNOSIS & TREATMENT
WITH A 50% CAP ON SICK PATIENTS, RUN MULTIPLE REPLICATES AND REPORT MEAN/STD.

Requirements:
  pip install numpy pandas scikit-learn catboost
"""

import numpy as np
import pandas as pd
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

# Sklearn models, metrics, etc.
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import ParameterGrid
# CatBoost
from catboost import CatBoostClassifier


###############################################################################
# 1. GLOBAL PARAMETERS
###############################################################################
FP_COST = 10
FN_COST = 50
D_COST  = 1
T_MAX   = 21   # maximum discrete time steps (0..T_MAX-1)
GAMMA_CANDIDATES = [0.95, 0.99]  # Example DP discount factors to try

# Hyperparameter grids for ML models
RF_PARAM_GRID = {
    'n_estimators': [50, 100],
    'max_depth': [3, 5]
}
GB_PARAM_GRID = {
    'n_estimators': [50, 100],
    'learning_rate': [0.05, 0.1],
    'max_depth': [3, 5]
}
CATBOOST_PARAM_GRID = {
    'iterations': [50, 100],
    'learning_rate': [0.05, 0.1],
    'depth': [3, 5]
}

###############################################################################
# 2. HELPER FUNCTIONS (DATA SPLITS, MODEL TRAINING, ETC.)
###############################################################################
def split_into_nplus1_groups(df, n=4, seed=0):
    """
    Shuffle patient IDs and split ~evenly into (n+1) groups: G1, G2, ..., G_{n+1}.
    Example usage: n=4 => 5 groups total.
    """
    rng = np.random.RandomState(seed)
    unique_pids = df['patient_id'].unique()
    rng.shuffle(unique_pids)
    
    N = len(unique_pids)
    num_groups = n + 1
    group_size = int(np.ceil(N / num_groups))
    
    groups = []
    start_idx = 0
    for i in range(num_groups):
        end_idx = min(start_idx + group_size, N)
        group_pids = unique_pids[start_idx:end_idx]
        group_df   = df[df['patient_id'].isin(group_pids)].copy()
        groups.append(group_df)
        start_idx = end_idx
    return groups

def compute_auc_score(y_true, y_prob):
    """Compute AUC safely. If only one class, return 0.5."""
    if len(np.unique(y_true)) < 2:
        return 0.5
    return roc_auc_score(y_true, y_prob)

def train_and_select_best_model(X_train, y_train, X_val, y_val):
    """
    Trains multiple models (RandomForest, GB, CatBoost)
    over small hyperparam grids, picks best by AUC.
    
    Returns: (best_model, best_auc, best_model_name)
    """
    best_auc = -1.0
    best_model = None
    best_name  = None
    
    # 1) RandomForest
    for params in ParameterGrid(RF_PARAM_GRID):
        rf = RandomForestClassifier(random_state=0, **params)
        rf.fit(X_train, y_train)
        val_prob = rf.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, val_prob)
        if auc_val > best_auc:
            best_auc   = auc_val
            best_model = rf
            best_name  = f"RandomForest_{params}"
    
    # 2) GradientBoosting
    for params in ParameterGrid(GB_PARAM_GRID):
        gb = GradientBoostingClassifier(random_state=0, **params)
        gb.fit(X_train, y_train)
        val_prob = gb.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, val_prob)
        if auc_val > best_auc:
            best_auc   = auc_val
            best_model = gb
            best_name  = f"GradientBoosting_{params}"
    
    # 3) CatBoost
    for params in ParameterGrid(CATBOOST_PARAM_GRID):
        cb = CatBoostClassifier(verbose=0, random_state=0, **params)
        cb.fit(X_train, y_train)
        val_prob = cb.predict_proba(X_val)[:,1]
        auc_val  = compute_auc_score(y_val, val_prob)
        if auc_val > best_auc:
            best_auc   = auc_val
            best_model = cb
            best_name  = f"CatBoost_{params}"
    
    return best_model, best_auc, best_name


###############################################################################
# 3. POLICY SIMULATION FUNCTIONS
###############################################################################
def simulate_policy(df, policy_func):
    """
    df must contain:
      - patient_id
      - time
      - risk_score
      - label (0 or 1)
    
    policy_func(patient_rows) -> treat_time (int) or None
    
    Returns a dictionary of relevant metrics: total cost, precision, recall, etc.
    (Unconstrained treatment for all recommended patients.)
    """
    results = []
    
    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values('time')
        label = grp['label'].iloc[0]
        
        treat_time = policy_func(grp)
        
        if treat_time is None:
            # never treated
            if label == 1:
                cost = FN_COST
                tp   = 0
            else:
                cost = 0
                tp   = 0
            fp = 0
            treat_flag = 0
            ttime = None
        else:
            treat_flag = 1
            if label == 1:
                cost = D_COST * treat_time
                tp   = 1
                fp   = 0
            else:
                cost = FP_COST
                tp   = 0
                fp   = 1
            ttime = treat_time
        
        results.append({
            'patient_id': pid,
            'label': label,
            'treated': treat_flag,
            'treat_time': ttime,
            'cost': cost,
            'tp': tp,
            'fp': fp
        })
    
    df_res     = pd.DataFrame(results)
    total_cost = df_res['cost'].sum()
    
    treated_df = df_res[df_res['treated']==1]
    tp_sum = treated_df['tp'].sum()
    fp_sum = treated_df['fp'].sum()
    
    if len(treated_df) > 0:
        precision = tp_sum / (tp_sum + fp_sum)
    else:
        precision = 0.0
    
    sick_df = df_res[df_res['label']==1]
    total_sick = len(sick_df)
    if total_sick > 0:
        recall = tp_sum / total_sick
    else:
        recall = 0.0
    
    if len(treated_df) > 0:
        valid_tt = treated_df['treat_time'].dropna()
        avg_tt   = valid_tt.mean() if len(valid_tt) > 0 else 0.0
    else:
        avg_tt = 0.0
    
    return {
        'cost': total_cost,
        'precision': precision,
        'recall': recall,
        'avg_treatment_time': avg_tt
    }

def simulate_policy_with_sick_capacity(df, policy_func, capacity_frac=0.5):
    """
    We enforce that at most (capacity_frac) fraction of the *sick* patients
    in this fold can be treated.

    Steps:
      1. Identify which patients are "recommended" for treatment by the policy_func.
      2. Separate recommended patients into "sick recommended" vs. "healthy recommended".
      3. Among the recommended *sick* patients, we can only treat up to
         floor(capacity_frac * total_sick_in_this_fold). We'll choose the top (by risk_score).
      4. We treat all recommended *healthy* patients with no limit.
      5. Everyone else is not treated (FN cost if sick, 0 cost if healthy).
    """
    results = []
    recommended_sick = []     # (pid, label=1, time_treated, risk_score)
    recommended_healthy = []  # (pid, label=0, time_treated, risk_score)
    
    # Count total sick in this fold
    all_sick_df = df[df['label']==1]
    num_sick = all_sick_df['patient_id'].nunique()
    capacity_num = int(np.floor(capacity_frac * num_sick)) if num_sick > 0 else 0
    
    # Step 1) Check policy recommendation
    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values('time')
        label = grp['label'].iloc[0]
        
        treat_time = policy_func(grp)
        
        if treat_time is None:
            # not recommended => cost is assigned later
            results.append({
                'patient_id': pid,
                'label': label,
                'treated': 0,
                'treat_time': None,
                'cost': None,
                'tp': 0,
                'fp': 0
            })
        else:
            # recommended => store in recommended_sick or recommended_healthy
            row_t = grp[grp['time']==treat_time].iloc[0]
            recommended_risk = row_t['risk_score']
            if label == 1:
                recommended_sick.append((pid, label, treat_time, recommended_risk))
            else:
                recommended_healthy.append((pid, label, treat_time, recommended_risk))
    
    # Step 2) Sort recommended sick by descending risk_score
    recommended_sick.sort(key=lambda x: x[3], reverse=True)
    # Step 3) Actually treat only top capacity_num from recommended sick
    treat_sick_subset = recommended_sick[:capacity_num]
    not_treat_sick_subset = recommended_sick[capacity_num:]
    
    # Step 4) We treat ALL recommended healthy, no limit
    treat_healthy_subset = recommended_healthy
    
    # Build final result records for the treated subsets
    treat_results = []
    # 4a) SICK actually treated
    for (pid, label, ttime, rsk) in treat_sick_subset:
        cost_ = D_COST * ttime  # treat_time * D
        treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 1,
            'treat_time': ttime,
            'cost': cost_,
            'tp': 1,
            'fp': 0
        })
    # 4b) HEALTHY actually treated
    for (pid, label, ttime, rsk) in treat_healthy_subset:
        cost_ = FP_COST  # healthy => false positive cost
        treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 1,
            'treat_time': ttime,
            'cost': cost_,
            'tp': 0,
            'fp': 1
        })
    
    # Step 5) Build final result records for not-treated subsets
    not_treat_results = []
    # (a) SICK recommended but not treated (exceeds capacity)
    for (pid, label, ttime, rsk) in not_treat_sick_subset:
        cost_ = FN_COST  # label==1 => sick
        not_treat_results.append({
            'patient_id': pid,
            'label': label,
            'treated': 0,
            'treat_time': None,
            'cost': cost_,
            'tp': 0,
            'fp': 0
        })
    # (b) Those who were never recommended
    for row in results:
        if row['cost'] is None:
            # not recommended
            if row['label'] == 1:
                row['cost'] = FN_COST
            else:
                row['cost'] = 0
            not_treat_results.append(row)
    
    df_res = pd.DataFrame(treat_results + not_treat_results)
    
    # Compute final stats
    total_cost = df_res['cost'].sum()
    treated_df = df_res[df_res['treated']==1]
    
    tp_sum = treated_df['tp'].sum()
    fp_sum = treated_df['fp'].sum()
    
    if len(treated_df) > 0:
        precision = tp_sum / (tp_sum + fp_sum)
    else:
        precision = 0.0
    
    sick_df = df_res[df_res['label']==1]
    total_sick = len(sick_df)
    if total_sick > 0:
        recall = tp_sum / total_sick
    else:
        recall = 0.0
    
    if len(treated_df) > 0:
        valid_tt = treated_df['treat_time'].dropna()
        avg_tt   = valid_tt.mean() if len(valid_tt) > 0 else 0.0
    else:
        avg_tt = 0.0
    
    return {
        'cost': total_cost,
        'precision': precision,
        'recall': recall,
        'avg_treatment_time': avg_tt
    }


###############################################################################
# 4. BENCHMARK THRESHOLD-BASED POLICIES
###############################################################################
def make_constant_threshold_policy(thr):
    """Returns a policy function that treats as soon as risk_score >= thr."""
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            if row['risk_score'] >= thr:
                return int(row['time'])
        return None
    return policy_func

def make_dynamic_threshold_policy(thr_vec):
    """Returns a policy function using a time-dependent threshold vector thr_vec."""
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            if t < len(thr_vec):
                if row['risk_score'] >= thr_vec[t]:
                    return t
        return None
    return policy_func

def make_linear_threshold_policy(A,B):
    """Returns a policy function threshold(t) = clip(A*t + B, 0, 1)."""
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            thr = A*t + B
            thr = np.clip(thr,0,1)
            if row['risk_score'] >= thr:
                return t
        return None
    return policy_func

def make_wait_till_end_policy(thr):
    """Returns a policy function that treats only at final time if risk_score >= thr."""
    def policy_func(patient_rows):
        final_t = patient_rows['time'].max()
        final_row = patient_rows[patient_rows['time']==final_t].iloc[0]
        if final_row['risk_score'] >= thr:
            return int(final_t)
        return None
    return policy_func


###############################################################################
# 5. DATA-DRIVEN DP (UNCONSTRAINED)
###############################################################################
def to_bucket(prob):
    """Simple function to map prob into a 5-bucket scale [0..4]."""
    b = int(prob * 5)
    return min(b, 4)

def estimate_transition_and_sick_probs(df_train, T=20, n_buckets=5):
    """
    p_trans[t,b,b_next], p_sick[t,b]
    df_train has columns: patient_id, time, risk_bucket, label
    """
    transition_counts = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    bucket_counts     = np.zeros((T, n_buckets), dtype=float)
    sick_counts       = np.zeros((T, n_buckets), dtype=float)
    
    df_sorted = df_train.sort_values(['patient_id','time'])
    for pid, grp in df_sorted.groupby('patient_id'):
        grp = grp.sort_values('time')
        rows = grp.to_dict('records')
        
        for i, row in enumerate(rows):
            t = int(row['time'])
            b = int(row['risk_bucket'])
            lbl = row['label']
            
            if t < T:
                bucket_counts[t,b] += 1
                sick_counts[t,b]   += lbl
            
            if i < len(rows)-1:
                nxt = rows[i+1]
                t_next = nxt['time']
                b_next = nxt['risk_bucket']
                # Count transitions only if t_next = t+1 (consecutive time)
                if (t_next == t+1) and (t < T-1):
                    transition_counts[t,b,b_next] += 1
    
    p_trans = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    for t_ in range(T-1):
        for b_ in range(n_buckets):
            denom = transition_counts[t_,b_,:].sum()
            if denom>0:
                p_trans[t_,b_,:] = transition_counts[t_,b_,:] / denom
            else:
                # if no data, assume "self-transition"
                p_trans[t_,b_,b_] = 1.0
    
    p_sick = np.zeros((T, n_buckets), dtype=float)
    for t_ in range(T):
        for b_ in range(n_buckets):
            denom = bucket_counts[t_,b_]
            if denom>0:
                p_sick[t_,b_] = sick_counts[t_,b_] / denom
            else:
                p_sick[t_,b_] = 0.0
    return p_trans, p_sick

def train_data_driven_dp_unconstrained(p_trans, p_sick, 
                                       FP=10, FN=50, D=1, gamma=0.99, T=20):
    """
    Standard DP for unconstrained scenario:
      V[t,b] = min( cost_treat_now, gamma * expected_future_if_wait )
    """
    n_buckets = p_sick.shape[1]
    V = np.zeros((T+1, n_buckets))
    pi_ = np.zeros((T, n_buckets), dtype=int)
    
    # boundary at t=T => no more decisions
    for b in range(n_buckets):
        # If we ended up treating exactly at T-1, cost could be:
        #    cost if sick: D*(T-1)
        #    cost if healthy: FP
        # If not treated at all: cost if sick: FN; else 0
        cost_treat   = p_sick[T-1,b]*(D*(T-1)) + (1-p_sick[T-1,b])*FP
        cost_notreat = p_sick[T-1,b]*FN
        V[T,b] = min(cost_treat, cost_notreat)
    
    for t in reversed(range(T)):
        for b in range(n_buckets):
            # treat now
            cost_treat = p_sick[t,b]*(D*t) + (1-p_sick[t,b])*FP
            
            # wait
            if t == T-1:
                cost_wait = gamma * V[T,b]
            else:
                exp_future = 0.0
                for b_next in range(n_buckets):
                    exp_future += p_trans[t,b,b_next]*V[t+1,b_next]
                cost_wait = gamma * exp_future
            
            if cost_treat <= cost_wait:
                V[t,b]   = cost_treat
                pi_[t,b] = 1
            else:
                V[t,b]   = cost_wait
                pi_[t,b] = 0
    
    return V, pi_

def make_dp_policy(V, pi_, T=20):
    """
    Return a policy function that treats if pi[t,b]==1 at time t
    for the bucket b of the risk score.
    """
    def policy_func(patient_rows):
        for _, row in patient_rows.iterrows():
            t = int(row['time'])
            if t < T:
                b = int(row['risk_bucket'])
                if pi_[t,b] == 1:
                    return t
        return None
    return policy_func

###############################################################################
# 6. ALGORITHM 3 (SEQUENTIAL OPTIMIZATION) - BUT FINAL FOLD HAS 50% CAP
###############################################################################
def run_algorithm3_sequential_optimization_with_cap(df_all, n=4, seed=0, capacity_frac=0.5):
    """
    Implements Algorithm 3 (Sequential Optimization) for hemorrhage diagnosis,
    unconstrained on G1..G_n, BUT with a capacity_frac limit on how many sick 
    can be treated in the final holdout fold (G_{n+1}).

    1) We split the data into (n+1) groups: G1,...,G_{n+1}.
         Let G_{n+1} be the final holdout/test set.
         G_cv = [G1..G_n] is used for cross-validation.
    2) Stage 1: Choose best ML hyperparams by CV (maximize AUC).
    3) Retrain final ML model on G1..G_n with best hyperparams.
    4) Stage 2: With ML model fixed, choose best policy hyperparams by CV (min cost).
    5) Evaluate final chosen methods on G_{n+1} with capacity_frac *SICK* limit.
    """
    # Only consider time < T_MAX
    df_all = df_all[df_all['time'] < T_MAX].copy()
    
    # 1) Split into n+1 groups
    groups = split_into_nplus1_groups(df_all, n=n, seed=seed)
    G_test = groups[-1]
    G_cv   = groups[:-1]
    
    # Combine G1..G_n for final retraining
    G_cv_concat = pd.concat(G_cv, ignore_index=True)
    
    ###########################################################################
    # STAGE 1: ML hyperparam selection by CV (maximize AUC)
    ###########################################################################
    def cv_auc_for_ml(params, model_type):
        total_auc = 0.0
        for i_cv in range(n):
            train_df_list = [G_cv[j] for j in range(n) if j != i_cv]
            train_df = pd.concat(train_df_list, ignore_index=True)
            val_df   = G_cv[i_cv]
            
            X_train = train_df[['EIT','NIRS','EIS']].values
            y_train = train_df['label'].values
            
            X_val   = val_df[['EIT','NIRS','EIS']].values
            y_val   = val_df['label'].values
            
            if model_type == 'rf':
                mdl = RandomForestClassifier(random_state=0, **params)
            elif model_type == 'gb':
                mdl = GradientBoostingClassifier(random_state=0, **params)
            else:
                mdl = CatBoostClassifier(verbose=0, random_state=0, **params)
            
            mdl.fit(X_train, y_train)
            val_prob = mdl.predict_proba(X_val)[:,1]
            auc_val  = compute_auc_score(y_val, val_prob)
            total_auc += auc_val
        return total_auc
    
    best_overall_auc = -1.0
    best_overall_params = None
    best_model_type = None
    
    # 1) RandomForest
    for params in ParameterGrid(RF_PARAM_GRID):
        sum_auc = cv_auc_for_ml(params, 'rf')
        if sum_auc > best_overall_auc:
            best_overall_auc = sum_auc
            best_overall_params = params
            best_model_type = 'rf'
    # 2) GradientBoosting
    for params in ParameterGrid(GB_PARAM_GRID):
        sum_auc = cv_auc_for_ml(params, 'gb')
        if sum_auc > best_overall_auc:
            best_overall_auc = sum_auc
            best_overall_params = params
            best_model_type = 'gb'
    # 3) CatBoost
    for params in ParameterGrid(CATBOOST_PARAM_GRID):
        sum_auc = cv_auc_for_ml(params, 'cat')
        if sum_auc > best_overall_auc:
            best_overall_auc = sum_auc
            best_overall_params = params
            best_model_type = 'cat'
    
    # Retrain final ML on entire G_cv
    if best_model_type == 'rf':
        best_model = RandomForestClassifier(random_state=0, **best_overall_params)
    elif best_model_type == 'gb':
        best_model = GradientBoostingClassifier(random_state=0, **best_overall_params)
    else:
        best_model = CatBoostClassifier(verbose=0, random_state=0, **best_overall_params)
    
    X_cv_final = G_cv_concat[['EIT','NIRS','EIS']].values
    y_cv_final = G_cv_concat['label'].values
    best_model.fit(X_cv_final, y_cv_final)
    
    ###########################################################################
    # STAGE 2: With ML model fixed, pick best policy hyperparams by CV (min cost)
    ###########################################################################
    def evaluate_policy_cost_cv(policy_maker_func, param):
        total_cost = 0.0
        for i_cv in range(n):
            G_i = G_cv[i_cv].copy()
            X_i = G_i[['EIT','NIRS','EIS']].values
            prob_i = best_model.predict_proba(X_i)[:,1]
            G_i['risk_score'] = prob_i
            
            policy_func = policy_maker_func(param)
            stats_i = simulate_policy(G_i, policy_func)  # unconstrained in CV
            total_cost += stats_i['cost']
        return total_cost
    
    # (A) Constant threshold
    possible_thresholds = np.linspace(0,1,21)
    best_thr_const = None
    best_const_cost = float('inf')
    for thr_candidate in possible_thresholds:
        cost_cv = evaluate_policy_cost_cv(make_constant_threshold_policy, thr_candidate)
        if cost_cv < best_const_cost:
            best_const_cost = cost_cv
            best_thr_const  = thr_candidate
    
    # (B) Dynamic threshold (random search)
    rng = np.random.RandomState(42)
    threshold_candidates = [0.0,0.2,0.4,0.6,0.8,1.0]
    dynamic_param_candidates = []
    N_SAMPLES = 30
    for _ in range(N_SAMPLES):
        thr_vec = rng.choice(threshold_candidates, size=T_MAX-1)
        dynamic_param_candidates.append(tuple(thr_vec))
    
    best_thr_vec = None
    best_dyn_cost = float('inf')
    for candidate_vec in dynamic_param_candidates:
        cost_cv = evaluate_policy_cost_cv(make_dynamic_threshold_policy, candidate_vec)
        if cost_cv < best_dyn_cost:
            best_dyn_cost = cost_cv
            best_thr_vec  = candidate_vec
    
    # (C) Linear threshold
    A_candidates = np.linspace(-0.05, 0.01, 7)
    B_candidates = np.linspace(0,0.6,2)
    best_lin = None
    best_lin_cost = float('inf')
    for A_ in A_candidates:
        for B_ in B_candidates:
            cost_cv_fold = 0.0
            for i_cv in range(n):
                G_i = G_cv[i_cv].copy()
                X_i = G_i[['EIT','NIRS','EIS']].values
                prob_i = best_model.predict_proba(X_i)[:,1]
                G_i['risk_score'] = prob_i
                
                policy_lin = make_linear_threshold_policy(A_, B_)
                stats_i = simulate_policy(G_i, policy_lin)
                cost_cv_fold += stats_i['cost']
            if cost_cv_fold < best_lin_cost:
                best_lin_cost = cost_cv_fold
                best_lin = (A_, B_)
    A_lin, B_lin = best_lin
    
    # (D) Wait-till-end threshold
    best_thr_wte = None
    best_wte_cost= float('inf')
    for thr_candidate in possible_thresholds:
        cost_cv = evaluate_policy_cost_cv(make_wait_till_end_policy, thr_candidate)
        if cost_cv < best_wte_cost:
            best_wte_cost   = cost_cv
            best_thr_wte    = thr_candidate
    
    # (E) Data-driven DP: pick best gamma by cross-validation
    def evaluate_dp_gamma_cv(gamma_val):
        total_cost = 0.0
        for i_cv in range(n):
            # train transition model on all folds except i_cv
            train_df_list = [G_cv[j] for j in range(n) if j != i_cv]
            train_df = pd.concat(train_df_list, ignore_index=True)
            
            X_train = train_df[['EIT','NIRS','EIS']].values
            prob_train = best_model.predict_proba(X_train)[:,1]
            train_df['risk_score'] = prob_train
            train_df['risk_bucket'] = train_df['risk_score'].apply(to_bucket)
            
            p_trans, p_sick = estimate_transition_and_sick_probs(train_df, T=T_MAX, n_buckets=5)
            V_temp, pi_temp = train_data_driven_dp_unconstrained(
                p_trans, p_sick,
                FP=FP_COST, FN=FN_COST, D=D_COST,
                gamma=gamma_val, T=T_MAX
            )
            dp_policy_temp = make_dp_policy(V_temp, pi_temp, T=T_MAX)
            
            # Evaluate on hold-out fold i_cv
            G_i = G_cv[i_cv].copy()
            X_i = G_i[['EIT','NIRS','EIS']].values
            prob_i = best_model.predict_proba(X_i)[:,1]
            G_i['risk_score'] = prob_i
            G_i['risk_bucket'] = G_i['risk_score'].apply(to_bucket)
            
            stats_i = simulate_policy(G_i, dp_policy_temp)
            total_cost += stats_i['cost']
        return total_cost
    
    best_gamma = None
    best_dp_cost= float('inf')
    for gamma_ in GAMMA_CANDIDATES:
        cost_cv_gamma = evaluate_dp_gamma_cv(gamma_)
        if cost_cv_gamma < best_dp_cost:
            best_dp_cost = cost_cv_gamma
            best_gamma   = gamma_
    
    # Train final DP on G_cv_concat with best_gamma
    G_cv_concat_dp = G_cv_concat.copy()
    X_dp = G_cv_concat_dp[['EIT','NIRS','EIS']].values
    prob_dp = best_model.predict_proba(X_dp)[:,1]
    G_cv_concat_dp['risk_score'] = prob_dp
    G_cv_concat_dp['risk_bucket'] = G_cv_concat_dp['risk_score'].apply(to_bucket)
    
    p_trans_final, p_sick_final = estimate_transition_and_sick_probs(
        G_cv_concat_dp, T=T_MAX, n_buckets=5
    )
    V_final, pi_final = train_data_driven_dp_unconstrained(
        p_trans_final, p_sick_final,
        FP=FP_COST, FN=FN_COST, D=D_COST,
        gamma=best_gamma, T=T_MAX
    )
    dp_policy_final = make_dp_policy(V_final, pi_final, T=T_MAX)
    
    ###########################################################################
    # 7) Evaluate all final chosen methods on G_{n+1} with capacity constraint
    ###########################################################################
    G_test_eval = G_test.copy()
    X_test = G_test_eval[['EIT','NIRS','EIS']].values
    prob_test = best_model.predict_proba(X_test)[:,1]
    G_test_eval['risk_score'] = prob_test
    
    # Simulate with capacity constraint
    policy_const = make_constant_threshold_policy(best_thr_const)
    stats_const  = simulate_policy_with_sick_capacity(G_test_eval, policy_const, capacity_frac=capacity_frac)
    
    policy_dyn = make_dynamic_threshold_policy(best_thr_vec)
    stats_dyn  = simulate_policy_with_sick_capacity(G_test_eval, policy_dyn, capacity_frac=capacity_frac)
    
    policy_lin = make_linear_threshold_policy(A_lin, B_lin)
    stats_lin  = simulate_policy_with_sick_capacity(G_test_eval, policy_lin, capacity_frac=capacity_frac)
    
    policy_wte = make_wait_till_end_policy(best_thr_wte)
    stats_wte  = simulate_policy_with_sick_capacity(G_test_eval, policy_wte, capacity_frac=capacity_frac)
    
    # DP policy
    G_test_eval_dp = G_test_eval.copy()
    G_test_eval_dp['risk_bucket'] = G_test_eval_dp['risk_score'].apply(to_bucket)
    stats_dp = simulate_policy_with_sick_capacity(G_test_eval_dp, dp_policy_final, capacity_frac=capacity_frac)
    
    # Build final table
    table = pd.DataFrame({
        'Method': [
            'Constant Threshold',
            'Dynamic Threshold-R',
            'Linear Threshold',
            'Wait Till End',
            'DP-based Policy'
        ],
        'Cost': [
            stats_const['cost'],
            stats_dyn['cost'],
            stats_lin['cost'],
            stats_wte['cost'],
            stats_dp['cost']
        ],
        'Precision (%)': [
            100*stats_const['precision'],
            100*stats_dyn['precision'],
            100*stats_lin['precision'],
            100*stats_wte['precision'],
            100*stats_dp['precision']
        ],
        'Recall (%)': [
            100*stats_const['recall'],
            100*stats_dyn['recall'],
            100*stats_lin['recall'],
            100*stats_wte['recall'],
            100*stats_dp['recall']
        ],
        'Avg Treat Time': [
            stats_const['avg_treatment_time'],
            stats_dyn['avg_treatment_time'],
            stats_lin['avg_treatment_time'],
            stats_wte['avg_treatment_time'],
            stats_dp['avg_treatment_time']
        ]
    })
    
    return table


###############################################################################
# 7. MAIN - RUN REPLICATES AND REPORT MEAN/STD
###############################################################################
def main():
    df_all = pd.read_csv("synthetic_patients_with_features.csv")
    
    NUM_REPLICATES = 30
    all_results = []  # will store results (as DataFrames) for each replicate
    
    for rep in range(NUM_REPLICATES):
        seed = 412 + rep
        print(f"\n=== Running replicate {rep+1}/{NUM_REPLICATES} (seed={seed}) ===")
        
        final_table = run_algorithm3_sequential_optimization_with_cap(
            df_all,
            n=4,  # number of CV folds
            seed=seed,
            capacity_frac=0.5
        )
        # Add replicate info (so we can track across runs)
        final_table['Replicate'] = rep
        all_results.append(final_table)
        
        # Optional: Print or inspect final_table for each run
        print(final_table.to_string(index=False))
    
    # Combine all replicate results
    combined_df = pd.concat(all_results, ignore_index=True)
    
    # Now compute mean & std across replicates, grouped by method
    summary = combined_df.groupby('Method').agg(
        mean_cost=('Cost','mean'),
        std_cost=('Cost','std'),
        mean_precision=('Precision (%)','mean'),
        std_precision=('Precision (%)','std'),
        mean_recall=('Recall (%)','mean'),
        std_recall=('Recall (%)','std'),
        mean_ttime=('Avg Treat Time','mean'),
        std_ttime=('Avg Treat Time','std')
    ).reset_index()
    
    print("\n=== SUMMARY ACROSS ALL REPLICATES ===")
    print(summary.to_string(index=False))


if __name__ == "__main__":
    main()


=== Running replicate 1/30 (seed=412) ===
             Method  Cost  Precision (%)  Recall (%)  Avg Treat Time  Replicate
 Constant Threshold   692      34.615385        50.0        2.769231          0
Dynamic Threshold-R   955      16.363636        50.0        3.363636          0
   Linear Threshold   850      20.454545        50.0        1.181818          0
      Wait Till End   660      75.000000        50.0       20.000000          0
    DP-based Policy   534     100.000000        50.0        9.333333          0

=== Running replicate 2/30 (seed=413) ===
             Method  Cost  Precision (%)  Recall (%)  Avg Treat Time  Replicate
 Constant Threshold   849      33.333333        50.0        2.393939          1
Dynamic Threshold-R  1042      19.642857        50.0        2.910714          1
   Linear Threshold   974      22.916667        50.0        1.125000          1
      Wait Till End   770     100.000000        50.0       20.000000          1
    DP-based Policy   651      91.