In [25]:
"""
STANDARD VALIDATION (ALGORITHM 0) FOR UNCONSTRAINED HEMORRHAGE DIAGNOSIS & TREATMENT

Requirements:
  pip install numpy pandas scikit-learn catboost
"""

import numpy as np
import pandas as pd
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

# Sklearn models, metrics, etc.
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import ParameterGrid
# CatBoost
from catboost import CatBoostClassifier

###############################################################################
# 1. GLOBAL PARAMETERS
###############################################################################
# Costs charged per patient by simulate_policy and used by the DP.
FP_COST = 10  # treated although label != 1
FN_COST = 50  # label == 1 but never treated
D_COST = 3    # multiplied by the treatment time for a treated label == 1 patient
T_MAX = 21    # maximum discrete time steps (0..T_MAX-1)
GAMMA_CANDIDATES = [1]  # DP discount factors tried during tuning

# Small, demonstration-sized hyperparameter grids, one per ML model family.
RF_PARAM_GRID = dict(
    n_estimators=[50, 100],
    max_depth=[3, 5],
)
GB_PARAM_GRID = dict(
    n_estimators=[50, 100],
    learning_rate=[0.05, 0.1],
    max_depth=[3, 5],
)
CATBOOST_PARAM_GRID = dict(
    iterations=[50, 100],
    learning_rate=[0.05, 0.1],
    depth=[3, 5],
)

###############################################################################
# 2. HELPER FUNCTIONS
###############################################################################
def split_into_four_groups(df, seed=0):
    """
    Partition the rows of `df` into four patient-disjoint groups.

    The unique values of df['patient_id'] are shuffled with a RandomState
    seeded by `seed` and cut at 25%, 50% and 75% of the patient count
    (cut indices are truncated with int(), so the last group absorbs any
    remainder). Used for Algorithm 0 (Standard Validation).

    Returns: (G1, G2, G3, G4), each an independent copy of the matching rows.
    """
    shuffled_ids = df['patient_id'].unique()
    np.random.RandomState(seed).shuffle(shuffled_ids)

    total = len(shuffled_ids)
    cuts = [0, int(0.25 * total), int(0.50 * total), int(0.75 * total), total]

    groups = []
    for lo, hi in zip(cuts[:-1], cuts[1:]):
        members = df['patient_id'].isin(shuffled_ids[lo:hi])
        groups.append(df[members].copy())
    return tuple(groups)

def filter_by_group(df, pid_set):
    """Return a copy of the rows of `df` whose patient_id is in `pid_set`."""
    keep = df['patient_id'].isin(pid_set)
    return df.loc[keep].copy()

def compute_auc_score(y_true, y_prob):
    """
    ROC AUC of `y_prob` against `y_true`, guarded against degenerate labels.

    When `y_true` holds fewer than two distinct values, 0.5 is returned
    instead of calling roc_auc_score.
    """
    has_both_classes = np.unique(y_true).size >= 2
    return roc_auc_score(y_true, y_prob) if has_both_classes else 0.5

def train_and_select_best_model(X_train, y_train, X_val, y_val):
    """
    Fit every candidate (RandomForest, GradientBoosting, CatBoost) over its
    small hyperparameter grid on the training data and keep the one with
    the highest validation AUC.

    Candidates are tried family by family in the order listed below; a
    later candidate replaces the incumbent only on a strictly higher AUC.

    Returns: (best_model, best_auc, best_model_name)
    """
    # (name prefix, estimator factory, hyperparameter grid), in evaluation order
    families = (
        ("RandomForest",
         lambda p: RandomForestClassifier(random_state=0, **p),
         RF_PARAM_GRID),
        ("GradientBoosting",
         lambda p: GradientBoostingClassifier(random_state=0, **p),
         GB_PARAM_GRID),
        ("CatBoost",
         lambda p: CatBoostClassifier(verbose=0, random_state=0, **p),
         CATBOOST_PARAM_GRID),
    )

    winner = (None, -1.0, None)  # (model, auc, name)
    for prefix, build, grid in families:
        for params in ParameterGrid(grid):
            candidate = build(params)
            candidate.fit(X_train, y_train)
            positive_prob = candidate.predict_proba(X_val)[:, 1]
            score = compute_auc_score(y_val, positive_prob)
            if score > winner[1]:
                winner = (candidate, score, f"{prefix}_{params}")
    return winner

###############################################################################
# 3. SIMULATE POLICY (Unconstrained)
###############################################################################
def simulate_policy(df, policy_func):
    """
    Replay `policy_func` on every patient in `df` and aggregate the outcome.

    df must contain:
      - patient_id
      - time
      - label (0 or 1; taken from the patient's first row after sorting by time)
      - any column policy_func itself reads (e.g. risk_score)

    policy_func(patient_rows) -> treat_time (int) or None
      Called once per patient with that patient's rows sorted by time.

    Per-patient cost:
      - never treated, label == 1 : FN_COST
      - never treated, otherwise  : 0
      - treated, label == 1       : D_COST * treat_time
      - treated, otherwise        : FP_COST

    Return dict of cost, precision, recall, avg_treatment_time.
    precision is TP / treated patients, recall is TP / label == 1 patients,
    avg_treatment_time is the mean treat_time over treated patients; each
    falls back to 0.0 when its denominator is empty. A `df` without any
    patient yields an all-zero result instead of raising.
    """
    results = []

    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values('time')
        label = grp['label'].iloc[0]

        treat_time = policy_func(grp)
        treated = treat_time is not None
        is_sick = label == 1

        if not treated:
            cost = FN_COST if is_sick else 0
        elif is_sick:
            cost = D_COST * treat_time
        else:
            cost = FP_COST

        results.append({
            'patient_id': pid,
            'label': label,
            'treated': int(treated),
            'treat_time': treat_time,  # None when never treated
            'cost': cost,
            'tp': int(treated and is_sick),
            'fp': int(treated and not is_sick)
        })

    if not results:
        # No patients: a column-less DataFrame would make the lookups below fail.
        return {
            'cost': 0,
            'precision': 0.0,
            'recall': 0.0,
            'avg_treatment_time': 0.0
        }

    df_res = pd.DataFrame(results)
    total_cost = df_res['cost'].sum()

    treated_df = df_res[df_res['treated'] == 1]
    tp_sum = treated_df['tp'].sum()
    fp_sum = treated_df['fp'].sum()

    # Every treated patient is either a TP or an FP, so the denominator is
    # non-zero whenever someone was treated.
    precision = tp_sum / (tp_sum + fp_sum) if len(treated_df) > 0 else 0.0

    total_sick = len(df_res[df_res['label'] == 1])
    recall = tp_sum / total_sick if total_sick > 0 else 0.0

    valid_tt = treated_df['treat_time'].dropna()
    avg_tt = valid_tt.mean() if len(valid_tt) > 0 else 0.0

    return {
        'cost': total_cost,
        'precision': precision,
        'recall': recall,
        'avg_treatment_time': avg_tt
    }

###############################################################################
# 4. BENCHMARK THRESHOLD-BASED POLICIES
###############################################################################
def constant_threshold_search(df, thresholds=None):
    """
    Grid-search one time-independent risk threshold on `df`.

    For each candidate (default: 21 evenly spaced values in [0, 1]) the
    policy "treat at the first row whose risk_score >= threshold" is
    simulated; the candidate with the strictly lowest total cost wins, so
    ties keep the earliest candidate.

    Returns: (best_thr, best_stats), where best_stats is the simulate_policy
    dict of the winner.
    """
    candidates = np.linspace(0, 1, 21) if thresholds is None else thresholds

    def first_crossing(patient_rows, cutoff):
        # Time of the first row, in the given order, that reaches the cutoff.
        for t, score in zip(patient_rows['time'], patient_rows['risk_score']):
            if score >= cutoff:
                return int(t)
        return None

    winner = (None, None)
    lowest_cost = float('inf')
    for thr in candidates:
        stats = simulate_policy(
            df, lambda rows, cutoff=thr: first_crossing(rows, cutoff)
        )
        if stats['cost'] < lowest_cost:
            lowest_cost = stats['cost']
            winner = (thr, stats)
    return winner

def make_constant_threshold_policy(thr):
    """
    Build a policy that treats at the first row whose risk_score >= `thr`.

    The returned policy_func(patient_rows) scans rows in the order given and
    returns that row's time as an int, or None if no row reaches `thr`.
    """
    def policy_func(patient_rows):
        crossings = (
            int(t)
            for t, score in zip(patient_rows['time'], patient_rows['risk_score'])
            if score >= thr
        )
        return next(crossings, None)
    return policy_func

def dynamic_threshold_random_search(df,
                                    time_steps=20,
                                    threshold_candidates=None,
                                    n_samples=200,
                                    seed=0):
    """
    Random search over per-time-step threshold vectors.

    Draws `n_samples` vectors of length `time_steps`, every entry sampled
    from `threshold_candidates` with RandomState(seed).choice, simulates
    "treat at the first row with time t < time_steps and
    risk_score >= vector[t]" on 'df', and keeps the vector with the
    strictly lowest total cost.

    threshold_candidates defaults to [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]; the
    default is created per call rather than shared as a mutable default
    argument.

    Returns: (best_vec, best_stats); both are None when n_samples == 0.
    """
    if threshold_candidates is None:
        threshold_candidates = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

    rng = np.random.RandomState(seed)
    best_vec, best_cost, best_stats = None, float('inf'), None

    for _ in range(n_samples):
        thr_vec = rng.choice(threshold_candidates, size=time_steps)

        # Bind this draw as a default so the policy never sees a later one.
        def policy_func(patient_rows, thr_vec=thr_vec):
            for raw_t, score in zip(patient_rows['time'],
                                    patient_rows['risk_score']):
                t = int(raw_t)
                if t < time_steps and score >= thr_vec[t]:
                    return t
            return None

        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_vec = thr_vec  # fresh array per draw, no copy needed
            best_stats = stats
    return best_vec, best_stats

def make_dynamic_threshold_policy(thr_vec):
    """
    Build a policy from a per-time-step threshold vector.

    The returned policy_func(patient_rows) treats at the first row whose
    integer time t satisfies t < len(thr_vec) and risk_score >= thr_vec[t];
    rows at or beyond the vector length are skipped. Returns t or None.
    """
    def policy_func(patient_rows):
        horizon = len(thr_vec)
        for raw_t, score in zip(patient_rows['time'], patient_rows['risk_score']):
            t = int(raw_t)
            if t >= horizon:
                continue
            if score >= thr_vec[t]:
                return t
        return None
    return policy_func

def linear_threshold_search(df,
                            A_candidates=np.linspace(-0.05, 0.01, 7),
                            B_candidates=np.linspace(0,0.8,7)):
    """
    Grid-search a threshold that is linear in time: thr(t) = clip(A*t + B, 0, 1).

    Every (A, B) pair is simulated on `df` with the policy "treat at the
    first row whose risk_score >= thr(time)"; the pair with the strictly
    lowest total cost wins (A varies in the outer loop, B in the inner).

    Returns: ((best_A, best_B), best_stats)
    """
    def build_policy(slope, intercept):
        def policy_func(patient_rows):
            for t, score in zip(patient_rows['time'], patient_rows['risk_score']):
                if score >= np.clip(slope * t + intercept, 0, 1):
                    return int(t)
            return None
        return policy_func

    best_pair = (None, None)
    best_cost, best_stats = float('inf'), None
    for pair in ((a, b) for a in A_candidates for b in B_candidates):
        stats = simulate_policy(df, build_policy(*pair))
        if stats['cost'] < best_cost:
            best_cost, best_pair, best_stats = stats['cost'], pair, stats
    return best_pair, best_stats

def make_linear_threshold_policy(A,B):
    """
    Build a policy whose threshold is linear in time.

    The returned policy_func(patient_rows) treats at the first row whose
    risk_score >= clip(A*time + B, 0, 1) and returns int(time), or None if
    no row qualifies.
    """
    def policy_func(patient_rows):
        crossings = (
            int(t)
            for t, score in zip(patient_rows['time'], patient_rows['risk_score'])
            if score >= np.clip(A * t + B, 0, 1)
        )
        return next(crossings, None)
    return policy_func

def wait_till_end_search(df, thresholds=None):
    """
    Grid-search the threshold of the "decide only at the last step" policy.

    For each candidate (default: 21 evenly spaced values in [0, 1]) the
    policy looks only at the patient's row with the largest time and
    treats then if its risk_score >= threshold. The candidate with the
    strictly lowest total cost on `df` wins.

    Returns: (best_thr, best_stats)
    """
    if thresholds is None:
        thresholds = np.linspace(0, 1, 21)

    def decide_at_last_step(patient_rows, cutoff):
        times = patient_rows['time']
        last_t = times.max()
        last_score = patient_rows[times == last_t]['risk_score'].iloc[0]
        return int(last_t) if last_score >= cutoff else None

    winner = (None, None)
    lowest_cost = float('inf')
    for thr in thresholds:
        stats = simulate_policy(
            df, lambda rows, cutoff=thr: decide_at_last_step(rows, cutoff)
        )
        if stats['cost'] < lowest_cost:
            lowest_cost = stats['cost']
            winner = (thr, stats)
    return winner

def make_wait_till_end_policy(thr):
    """
    Build a policy that decides only at the patient's last time step.

    The returned policy_func(patient_rows) takes the first row carrying the
    maximum time and returns int(that time) if its risk_score >= `thr`,
    otherwise None.
    """
    def policy_func(patient_rows):
        times = patient_rows['time']
        last_t = times.max()
        last_score = patient_rows[times == last_t]['risk_score'].iloc[0]
        return int(last_t) if last_score >= thr else None
    return policy_func


###############################################################################
# 5. DATA-DRIVEN DP (UNCONSTRAINED)
###############################################################################
def to_bucket(prob, n_buckets=5):
    """
    Map a probability to an integer bucket in [0 .. n_buckets-1].

    The bucket is int(prob * n_buckets), clamped on both sides: prob == 1.0
    lands in the top bucket, and an out-of-range negative value lands in
    bucket 0 rather than producing a negative bucket (which would wrap
    around when used as an array index). n_buckets defaults to the 5-bucket
    scale, so single-argument calls such as Series.apply(to_bucket) behave
    as before for probabilities in [0, 1].
    """
    bucket = int(prob * n_buckets)
    return max(0, min(bucket, n_buckets - 1))

def estimate_transition_and_sick_probs(df_train, T=20, n_buckets=5):
    """
    Estimate the empirical quantities the DP needs from training trajectories.

    df_train has columns: patient_id, time, risk_bucket, label

    Returns: (p_trans, p_sick)
      p_trans[t, b, b_next] : shape (T-1, n_buckets, n_buckets). Fraction of
          observed moves from bucket b at time t to bucket b_next at time
          t+1. Only consecutive rows of one patient whose times differ by
          exactly 1 are counted. A (t, b) pair with no observed move gets a
          self-transition (probability 1 of staying in b) to avoid NaNs.
      p_sick[t, b] : shape (T, n_buckets). Mean label of the rows observed
          in bucket b at time t; 0.0 where nothing was observed.

    Rows with time >= T are ignored. Bucket values are cast to int before
    indexing, so a float-typed risk_bucket column is accepted.
    """
    transition_counts = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    bucket_counts     = np.zeros((T, n_buckets), dtype=float)
    sick_counts       = np.zeros((T, n_buckets), dtype=float)

    for _, grp in df_train.groupby('patient_id'):
        rows = grp.sort_values('time').to_dict('records')

        for i, row in enumerate(rows):
            t = int(row['time'])
            b = int(row['risk_bucket'])

            if t < T:
                bucket_counts[t, b] += 1
                sick_counts[t, b]   += row['label']

            if i + 1 < len(rows):
                nxt = rows[i + 1]
                # Cast like `b` above: a float bucket is not a valid index.
                b_next = int(nxt['risk_bucket'])
                if (nxt['time'] == t + 1) and (t < T - 1):
                    transition_counts[t, b, b_next] += 1

    # Normalise each (t, b) row of counts; rows without data stay zero here.
    row_totals = transition_counts.sum(axis=2, keepdims=True)
    p_trans = np.divide(transition_counts, row_totals,
                        out=np.zeros_like(transition_counts),
                        where=row_totals > 0)
    # if no data, assume self-transition to avoid NaNs
    unseen_t, unseen_b = np.nonzero(row_totals[..., 0] == 0)
    p_trans[unseen_t, unseen_b, unseen_b] = 1.0

    p_sick = np.divide(sick_counts, bucket_counts,
                       out=np.zeros_like(sick_counts),
                       where=bucket_counts > 0)
    return p_trans, p_sick

def train_data_driven_dp_unconstrained(p_trans, p_sick, 
                                       FP=10, FN=50, D=1, gamma=0.99, T=20):
    """
    Standard DP for unconstrained scenario:
      V[t,b] = min( cost_treat_now, cost_wait )

    with, for bucket b at time t (valid times are 0..T-1):
      cost_treat_now = p_sick[t,b]*(D*t) + (1 - p_sick[t,b])*FP
      cost_wait      = gamma * sum_b' p_trans[t,b,b'] * V[t+1,b']   for t < T-1
      cost_wait      = gamma * V[T,b]                               for t = T-1

    V[T,b] is the cost of never having treated: p_sick[T-1,b]*FN. It does
    not include a "treat at T-1" option, because that choice is already
    evaluated as cost_treat_now at t = T-1. Including it twice makes the
    wait cost at T-1 a discounted copy of the treat cost, so for gamma < 1
    the policy would never treat at the final step even when treating is
    far cheaper than a miss. For gamma == 1 the decisions and values for
    t < T are mathematically the same either way.

    Args:
      p_trans: array indexed [t, b, b_next], at least T-1 time slices.
      p_sick:  array indexed [t, b], at least T time slices.
      FP, FN, D: false-positive cost, false-negative cost, per-step delay cost.
      gamma: discount factor applied to the cost of waiting.
      T: number of time steps.

    Returns: (V, pi_) with V of shape (T+1, n_buckets) and pi_ an int array
    of shape (T, n_buckets), where pi_[t,b] == 1 means "treat now". Ties
    between treating and waiting resolve to treating.
    """
    n_buckets = p_sick.shape[1]
    V = np.zeros((T+1, n_buckets))
    pi_ = np.zeros((T, n_buckets), dtype=int)

    # boundary at t=T: still untreated after the last step -> never treated
    V[T] = p_sick[T-1] * FN

    # fill from T-1 down to 0, all buckets at once
    for t in reversed(range(T)):
        cost_treat = p_sick[t] * (D * t) + (1 - p_sick[t]) * FP

        if t == T-1:
            exp_future = V[T]
        else:
            exp_future = p_trans[t] @ V[t+1]
        cost_wait = gamma * exp_future

        treat_now = cost_treat <= cost_wait
        pi_[t] = treat_now.astype(int)
        V[t] = np.where(treat_now, cost_treat, cost_wait)
    return V, pi_

def make_dp_policy(V, pi_, T=20):
    """
    Return a policy function that treats if pi[t,b]==1 at time t.

    The policy walks the patient's rows in order, ignores rows with
    time >= T, and returns the first integer time t whose
    pi_[t, risk_bucket] equals 1, or None. `V` is accepted for symmetry
    with the DP solver's output but is not read.
    """
    def policy_func(patient_rows):
        for raw_t, raw_b in zip(patient_rows['time'], patient_rows['risk_bucket']):
            t = int(raw_t)
            if t < T and pi_[t, int(raw_b)] == 1:
                return t
        return None
    return policy_func

###############################################################################
# 6. ALGORITHM 0 (STANDARD VALIDATION)
###############################################################################
def run_algorithm0_unconstrained(df_all, seed=0):
    """
    Run Algorithm 0 (standard validation) end to end and tabulate the results.

    1) Split df_all -> G1, G2, G3, G4 (patient-disjoint, controlled by `seed`)
    2) ML hyperparam search on (G1->G2)
    3) Retrain best ML on G1+G2 and score G1+G2, G3 and G4 with it
    4) On G3, tune:
         - the threshold-based parameters (constant, random dynamic,
           linear, wait-till-end)
         - the DP discount factor gamma, using transition / sick
           probabilities estimated on the scored G1+G2
    5) Evaluate all tuned policies on G4 (nothing is re-tuned on G4)

    df_all must contain patient_id, time, label and the feature columns
    EIT, NIRS, EIS.

    Returns: DataFrame with one row per method and columns
    Method, Cost, Precision (%), Recall (%), Avg Treat Time.

    Raises: ValueError if no DP policy could be selected on G3 (e.g.
    GAMMA_CANDIDATES is empty).

    NOTE(review): the DP probabilities are estimated from risk scores the
    model produces on its own training data (G1+G2) - confirm this is the
    intended protocol.
    """
    feature_cols = ['EIT', 'NIRS', 'EIS']

    def features(frame):
        # Always pass plain arrays: the model is fitted on arrays, and mixing
        # in DataFrames at predict time triggers sklearn's feature-name warning.
        return frame[feature_cols].values

    # Step 1: Split
    G1, G2, G3, G4 = split_into_four_groups(df_all, seed=seed)

    # Step 2: ML hyperparam search on (G1->G2)
    best_model, _best_auc, _best_name = train_and_select_best_model(
        features(G1), G1['label'].values,
        features(G2), G2['label'].values
    )

    # Step 3: Retrain best ML on G1+G2
    G12 = pd.concat([G1, G2], ignore_index=True)
    best_model.fit(features(G12), G12['label'].values)

    def with_scores(frame):
        # Copy of `frame` with the final model's risk score and its DP bucket.
        scored = frame.copy()
        scored['risk_score'] = best_model.predict_proba(features(scored))[:, 1]
        scored['risk_bucket'] = scored['risk_score'].apply(to_bucket)
        return scored

    G12 = with_scores(G12)  # DP transition estimation
    G3 = with_scores(G3)    # tuning
    G4 = with_scores(G4)    # final evaluation

    # Step 4 (A): tune threshold-based methods on G3
    thr_const_g3, _ = constant_threshold_search(G3)
    thr_vec_g3, _ = dynamic_threshold_random_search(G3, time_steps=T_MAX)
    (A_lin_g3, B_lin_g3), _ = linear_threshold_search(G3)
    thr_wte_g3, _ = wait_till_end_search(G3)

    # Step 4 (B): tune DP discount factor on G3
    p_trans, p_sick = estimate_transition_and_sick_probs(G12, T=T_MAX, n_buckets=5)

    best_cost_dp = float('inf')
    best_V = None
    best_pi = None
    for gamma_ in GAMMA_CANDIDATES:
        V_temp, pi_temp = train_data_driven_dp_unconstrained(
            p_trans, p_sick,
            FP=FP_COST, FN=FN_COST, D=D_COST,
            gamma=gamma_, T=T_MAX
        )
        stats_temp = simulate_policy(G3, make_dp_policy(V_temp, pi_temp, T=T_MAX))
        if stats_temp['cost'] < best_cost_dp:
            best_cost_dp = stats_temp['cost']
            best_V = V_temp
            best_pi = pi_temp

    if best_pi is None:
        raise ValueError("No DP policy was selected on G3; check GAMMA_CANDIDATES.")

    # Step 5: evaluate on G4 with the parameters found on G3.
    #         We do NOT re-search on G4, to avoid cheating.
    tuned_policies = [
        ('Constant Threshold', make_constant_threshold_policy(thr_const_g3)),
        ('Dynamic Threshold-R', make_dynamic_threshold_policy(thr_vec_g3)),
        ('Linear Threshold ', make_linear_threshold_policy(A_lin_g3, B_lin_g3)),
        ('Wait Till End ', make_wait_till_end_policy(thr_wte_g3)),
        ('Dynamic Threshold-DP', make_dp_policy(best_V, best_pi, T=T_MAX)),
    ]

    # Build final table of results
    table_rows = []
    for method, policy in tuned_policies:
        stats = simulate_policy(G4, policy)
        table_rows.append({
            'Method': method,
            'Cost': stats['cost'],
            'Precision (%)': 100*stats['precision'],
            'Recall (%)': 100*stats['recall'],
            'Avg Treat Time': stats['avg_treatment_time']
        })

    table = pd.DataFrame(
        table_rows,
        columns=['Method', 'Cost', 'Precision (%)', 'Recall (%)', 'Avg Treat Time']
    )

    return table


###############################################################################
# 7. MAIN
###############################################################################
def main():
    """
    Load the synthetic patient CSV, run Algorithm 0 and print the result table.

    Reads "synthetic_patients_with_features.csv" from the working directory,
    keeps the rows with time < T_MAX, and runs the pipeline with seed=4.

    Raises: ValueError if the CSV lacks any of the required columns.
    """
    df_all = pd.read_csv("synthetic_patients_with_features.csv")

    # Check required columns first, so a missing 'time' column is reported
    # through the message below rather than as a KeyError from the filter.
    required = {'patient_id','time','EIT','NIRS','EIS','label'}
    if not required.issubset(df_all.columns):
        raise ValueError(f"Your CSV must have columns at least: {required}. Found: {df_all.columns}")

    # Keep only time steps inside the modelled horizon (0..T_MAX-1).
    df_all = df_all[df_all['time'] < T_MAX].copy()

    # Run Algorithm 0
    final_table = run_algorithm0_unconstrained(df_all, seed=4)

    print("\n=== ALGORITHM 0 (STANDARD VALIDATION) RESULTS (Unconstrained) ===")
    print(final_table.to_string(index=False))

if __name__ == "__main__":
    main()


=== ALGORITHM 0 (STANDARD VALIDATION) RESULTS (Unconstrained) ===
              Method  Cost  Precision (%)  Recall (%)  Avg Treat Time
  Constant Threshold  1057      53.731343       100.0        3.731343
 Dynamic Threshold-R   825      70.588235       100.0        8.019608
   Linear Threshold   1024      47.368421       100.0        2.750000
      Wait Till End   1800       0.000000         0.0        0.000000
Dynamic Threshold-DP   855      75.000000       100.0        6.291667


In [64]:
"""
STANDARD VALIDATION (ALGORITHM 0) FOR UNCONSTRAINED HEMORRHAGE DIAGNOSIS & TREATMENT
Replicated multiple times, then aggregates results as mean ± std.

Requirements:
  pip install numpy pandas scikit-learn catboost
"""

import numpy as np
import pandas as pd
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

# Sklearn models, metrics, etc.
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import ParameterGrid
# CatBoost
from catboost import CatBoostClassifier

###############################################################################
# 1. GLOBAL PARAMETERS
###############################################################################
# Costs charged per patient by simulate_policy and used by the DP.
FP_COST = 10  # treated although label != 1
FN_COST = 50  # label == 1 but never treated
D_COST = 1    # multiplied by the treatment time for a treated label == 1 patient
T_MAX = 21    # maximum discrete time steps (0..T_MAX-1)
GAMMA_CANDIDATES = [0.95, 0.99]  # DP discount factors tried during tuning

# Small, demonstration-sized hyperparameter grids, one per ML model family.
RF_PARAM_GRID = dict(
    n_estimators=[50, 100],
    max_depth=[3, 5],
)
GB_PARAM_GRID = dict(
    n_estimators=[50, 100],
    learning_rate=[0.05, 0.1],
    max_depth=[3, 5],
)
CATBOOST_PARAM_GRID = dict(
    iterations=[50, 100],
    learning_rate=[0.05, 0.1],
    depth=[3, 5],
)

###############################################################################
# 2. HELPER FUNCTIONS
###############################################################################
def split_into_four_groups(df, seed=0):
    """
    Partition the rows of `df` into four patient-disjoint groups.

    The unique values of df['patient_id'] are shuffled with a RandomState
    seeded by `seed` and cut at 25%, 50% and 75% of the patient count
    (cut indices are truncated with int(), so the last group absorbs any
    remainder). Used for Algorithm 0 (Standard Validation).

    Returns: (G1, G2, G3, G4), each an independent copy of the matching rows.
    """
    shuffled_ids = df['patient_id'].unique()
    np.random.RandomState(seed).shuffle(shuffled_ids)

    total = len(shuffled_ids)
    cuts = [0, int(0.25 * total), int(0.50 * total), int(0.75 * total), total]

    groups = []
    for lo, hi in zip(cuts[:-1], cuts[1:]):
        members = df['patient_id'].isin(shuffled_ids[lo:hi])
        groups.append(df[members].copy())
    return tuple(groups)

def filter_by_group(df, pid_set):
    """Return a copy of the rows of `df` whose patient_id is in `pid_set`."""
    keep = df['patient_id'].isin(pid_set)
    return df.loc[keep].copy()

def compute_auc_score(y_true, y_prob):
    """
    ROC AUC of `y_prob` against `y_true`, guarded against degenerate labels.

    When `y_true` holds fewer than two distinct values, 0.5 is returned
    instead of calling roc_auc_score.
    """
    has_both_classes = np.unique(y_true).size >= 2
    return roc_auc_score(y_true, y_prob) if has_both_classes else 0.5

def train_and_select_best_model(X_train, y_train, X_val, y_val):
    """
    Fit every candidate (RandomForest, GradientBoosting, CatBoost) over its
    small hyperparameter grid on the training data and keep the one with
    the highest validation AUC.

    Candidates are tried family by family in the order listed below; a
    later candidate replaces the incumbent only on a strictly higher AUC.

    Returns: (best_model, best_auc, best_model_name)
    """
    # (name prefix, estimator factory, hyperparameter grid), in evaluation order
    families = (
        ("RandomForest",
         lambda p: RandomForestClassifier(random_state=0, **p),
         RF_PARAM_GRID),
        ("GradientBoosting",
         lambda p: GradientBoostingClassifier(random_state=0, **p),
         GB_PARAM_GRID),
        ("CatBoost",
         lambda p: CatBoostClassifier(verbose=0, random_state=0, **p),
         CATBOOST_PARAM_GRID),
    )

    winner = (None, -1.0, None)  # (model, auc, name)
    for prefix, build, grid in families:
        for params in ParameterGrid(grid):
            candidate = build(params)
            candidate.fit(X_train, y_train)
            positive_prob = candidate.predict_proba(X_val)[:, 1]
            score = compute_auc_score(y_val, positive_prob)
            if score > winner[1]:
                winner = (candidate, score, f"{prefix}_{params}")
    return winner

###############################################################################
# 3. SIMULATE POLICY (Unconstrained)
###############################################################################
def simulate_policy(df, policy_func):
    """
    Replay `policy_func` on every patient in `df` and aggregate the outcome.

    df must contain:
      - patient_id
      - time
      - label (0 or 1; taken from the patient's first row after sorting by time)
      - any column policy_func itself reads (e.g. risk_score)

    policy_func(patient_rows) -> treat_time (int) or None
      Called once per patient with that patient's rows sorted by time.

    Per-patient cost:
      - never treated, label == 1 : FN_COST
      - never treated, otherwise  : 0
      - treated, label == 1       : D_COST * treat_time
      - treated, otherwise        : FP_COST

    Return dict of cost, precision, recall, avg_treatment_time.
    precision is TP / treated patients, recall is TP / label == 1 patients,
    avg_treatment_time is the mean treat_time over treated patients; each
    falls back to 0.0 when its denominator is empty. A `df` without any
    patient yields an all-zero result instead of raising.
    """
    results = []

    for pid, grp in df.groupby('patient_id'):
        grp = grp.sort_values('time')
        label = grp['label'].iloc[0]

        treat_time = policy_func(grp)
        treated = treat_time is not None
        is_sick = label == 1

        if not treated:
            cost = FN_COST if is_sick else 0
        elif is_sick:
            cost = D_COST * treat_time
        else:
            cost = FP_COST

        results.append({
            'patient_id': pid,
            'label': label,
            'treated': int(treated),
            'treat_time': treat_time,  # None when never treated
            'cost': cost,
            'tp': int(treated and is_sick),
            'fp': int(treated and not is_sick)
        })

    if not results:
        # No patients: a column-less DataFrame would make the lookups below fail.
        return {
            'cost': 0,
            'precision': 0.0,
            'recall': 0.0,
            'avg_treatment_time': 0.0
        }

    df_res = pd.DataFrame(results)
    total_cost = df_res['cost'].sum()

    treated_df = df_res[df_res['treated'] == 1]
    tp_sum = treated_df['tp'].sum()
    fp_sum = treated_df['fp'].sum()

    # Every treated patient is either a TP or an FP, so the denominator is
    # non-zero whenever someone was treated.
    precision = tp_sum / (tp_sum + fp_sum) if len(treated_df) > 0 else 0.0

    total_sick = len(df_res[df_res['label'] == 1])
    recall = tp_sum / total_sick if total_sick > 0 else 0.0

    valid_tt = treated_df['treat_time'].dropna()
    avg_tt = valid_tt.mean() if len(valid_tt) > 0 else 0.0

    return {
        'cost': total_cost,
        'precision': precision,
        'recall': recall,
        'avg_treatment_time': avg_tt
    }

###############################################################################
# 4. BENCHMARK THRESHOLD-BASED POLICIES
###############################################################################
def constant_threshold_search(df, thresholds=None):
    """
    Grid-search one time-independent risk threshold on `df`.

    For each candidate (default: 21 evenly spaced values in [0, 1]) the
    policy "treat at the first row whose risk_score >= threshold" is
    simulated; the candidate with the strictly lowest total cost wins, so
    ties keep the earliest candidate.

    Returns: (best_thr, best_stats), where best_stats is the simulate_policy
    dict of the winner.
    """
    candidates = np.linspace(0, 1, 21) if thresholds is None else thresholds

    def first_crossing(patient_rows, cutoff):
        # Time of the first row, in the given order, that reaches the cutoff.
        for t, score in zip(patient_rows['time'], patient_rows['risk_score']):
            if score >= cutoff:
                return int(t)
        return None

    winner = (None, None)
    lowest_cost = float('inf')
    for thr in candidates:
        stats = simulate_policy(
            df, lambda rows, cutoff=thr: first_crossing(rows, cutoff)
        )
        if stats['cost'] < lowest_cost:
            lowest_cost = stats['cost']
            winner = (thr, stats)
    return winner

def make_constant_threshold_policy(thr):
    """
    Build a policy that treats at the first row whose risk_score >= `thr`.

    The returned policy_func(patient_rows) scans rows in the order given and
    returns that row's time as an int, or None if no row reaches `thr`.
    """
    def policy_func(patient_rows):
        crossings = (
            int(t)
            for t, score in zip(patient_rows['time'], patient_rows['risk_score'])
            if score >= thr
        )
        return next(crossings, None)
    return policy_func

def dynamic_threshold_random_search(df,
                                    time_steps=20,
                                    threshold_candidates=None,
                                    n_samples=200,
                                    seed=0):
    """
    Random search over per-time-step threshold vectors.

    Draws `n_samples` vectors of length `time_steps`, every entry sampled
    from `threshold_candidates` with RandomState(seed).choice, simulates
    "treat at the first row with time t < time_steps and
    risk_score >= vector[t]" on 'df', and keeps the vector with the
    strictly lowest total cost.

    threshold_candidates defaults to [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]; the
    default is created per call rather than shared as a mutable default
    argument.

    Returns: (best_vec, best_stats); both are None when n_samples == 0.
    """
    if threshold_candidates is None:
        threshold_candidates = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

    rng = np.random.RandomState(seed)
    best_vec, best_cost, best_stats = None, float('inf'), None

    for _ in range(n_samples):
        thr_vec = rng.choice(threshold_candidates, size=time_steps)

        # Bind this draw as a default so the policy never sees a later one.
        def policy_func(patient_rows, thr_vec=thr_vec):
            for raw_t, score in zip(patient_rows['time'],
                                    patient_rows['risk_score']):
                t = int(raw_t)
                if t < time_steps and score >= thr_vec[t]:
                    return t
            return None

        stats = simulate_policy(df, policy_func)
        if stats['cost'] < best_cost:
            best_cost = stats['cost']
            best_vec = thr_vec  # fresh array per draw, no copy needed
            best_stats = stats
    return best_vec, best_stats

def make_dynamic_threshold_policy(thr_vec):
    """
    Build a policy from a per-time-step threshold vector.

    The returned policy_func(patient_rows) treats at the first row whose
    integer time t satisfies t < len(thr_vec) and risk_score >= thr_vec[t];
    rows at or beyond the vector length are skipped. Returns t or None.
    """
    def policy_func(patient_rows):
        horizon = len(thr_vec)
        for raw_t, score in zip(patient_rows['time'], patient_rows['risk_score']):
            t = int(raw_t)
            if t >= horizon:
                continue
            if score >= thr_vec[t]:
                return t
        return None
    return policy_func

def linear_threshold_search(df,
                            A_candidates=np.linspace(-0.05, 0.01, 7),
                            B_candidates=np.linspace(0,0.6,2)):
    """
    Grid-search a threshold that is linear in time: thr(t) = clip(A*t + B, 0, 1).

    Every (A, B) pair is simulated on `df` with the policy "treat at the
    first row whose risk_score >= thr(time)"; the pair with the strictly
    lowest total cost wins (A varies in the outer loop, B in the inner).

    Returns: ((best_A, best_B), best_stats)
    """
    def build_policy(slope, intercept):
        def policy_func(patient_rows):
            for t, score in zip(patient_rows['time'], patient_rows['risk_score']):
                if score >= np.clip(slope * t + intercept, 0, 1):
                    return int(t)
            return None
        return policy_func

    best_pair = (None, None)
    best_cost, best_stats = float('inf'), None
    for pair in ((a, b) for a in A_candidates for b in B_candidates):
        stats = simulate_policy(df, build_policy(*pair))
        if stats['cost'] < best_cost:
            best_cost, best_pair, best_stats = stats['cost'], pair, stats
    return best_pair, best_stats

def make_linear_threshold_policy(A,B):
    """
    Build a policy whose threshold is linear in time.

    At each row the threshold is A*time + B, clipped to [0, 1]. The policy
    returns int(time) of the first row whose 'risk_score' reaches that
    threshold, or None if no row does.
    """
    def threshold_at(t):
        return np.clip(A*t + B, 0, 1)

    def policy_func(patient_rows):
        # Lazily scan the rows; next() stops at the first qualifying one.
        hits = (
            int(obs['time'])
            for _, obs in patient_rows.iterrows()
            if obs['risk_score'] >= threshold_at(obs['time'])
        )
        return next(hits, None)
    return policy_func

def wait_till_end_search(df, thresholds=None):
    """
    Pick the best threshold for the "decide only at the last time" policy.

    For each candidate threshold (21 evenly spaced values on [0, 1] when
    'thresholds' is None) the policy looks only at the patient's row with
    the largest 'time' and treats there if its 'risk_score' reaches the
    threshold. Each candidate is scored with simulate_policy on 'df'; the
    strictly lowest 'cost' wins.

    Returns (best_thr, best_stats).
    """
    candidates = np.linspace(0,1,21) if thresholds is None else thresholds

    def policy_for(cutoff):
        def policy_func(patient_rows):
            last_t = patient_rows['time'].max()
            last_row = patient_rows[patient_rows['time']==last_t].iloc[0]
            return int(last_t) if last_row['risk_score'] >= cutoff else None
        return policy_func

    best_thr = None
    best_cost = float('inf')
    best_stats = None
    for thr in candidates:
        stats = simulate_policy(df, policy_for(thr))
        if stats['cost'] < best_cost:
            best_thr, best_cost, best_stats = thr, stats['cost'], stats
    return best_thr, best_stats

def make_wait_till_end_policy(thr):
    """
    Build a policy that only ever decides at the patient's last time step.

    The returned function finds the largest 'time', takes the first row
    with that time, and returns int(time) if that row's 'risk_score' is at
    least 'thr'; otherwise it returns None.
    """
    def policy_func(patient_rows):
        last_t = patient_rows['time'].max()
        at_last = patient_rows[patient_rows['time'] == last_t]
        last_row = at_last.iloc[0]
        return int(last_t) if last_row['risk_score'] >= thr else None
    return policy_func


###############################################################################
# 5. DATA-DRIVEN DP (UNCONSTRAINED)
###############################################################################
def to_bucket(prob, n_buckets=5):
    """
    Map a probability onto an integer bucket index in [0, n_buckets - 1].

    Bucket k covers [k / n_buckets, (k + 1) / n_buckets); prob == 1.0 (or
    anything larger) lands in the top bucket. Values below 0 are clamped to
    bucket 0 so the result is always a valid, non-wrapping array index.

    prob      : numeric score, expected in [0, 1]. NaN raises ValueError
                (from int()).
    n_buckets : number of buckets; defaults to 5.
    """
    bucket = int(prob * n_buckets)
    # Clamp both ends: a negative index would silently wrap around when
    # used to index a NumPy array.
    return max(0, min(bucket, n_buckets - 1))

def estimate_transition_and_sick_probs(df_train, T=20, n_buckets=5):
    """
    Estimate empirical bucket-transition and sickness probabilities.

    df_train has columns: patient_id, time, risk_bucket, label.
    'time' and 'risk_bucket' are truncated with int() before use, so
    float-typed columns holding whole numbers are accepted.

    Returns (p_trans, p_sick):
      p_trans[t, b, b_next], shape (T-1, n_buckets, n_buckets):
          share of patients seen in bucket b at time t whose next row is at
          time t+1 in bucket b_next. If (t, b) has no observed transition
          the row is a self-transition (1.0 on b -> b) to avoid NaNs.
      p_sick[t, b], shape (T, n_buckets):
          sum of 'label' over rows at (t, b) divided by the number of such
          rows; 0.0 where (t, b) was never observed.

    Rows with time outside 0..T-1 are ignored.
    """
    transition_counts = np.zeros((T-1, n_buckets, n_buckets), dtype=float)
    bucket_counts     = np.zeros((T, n_buckets), dtype=float)
    sick_counts       = np.zeros((T, n_buckets), dtype=float)

    # One sort is enough: groupby keeps the row order inside each group.
    df_sorted = df_train.sort_values(['patient_id', 'time'])
    for _, grp in df_sorted.groupby('patient_id'):
        times   = [int(v) for v in grp['time']]
        buckets = [int(v) for v in grp['risk_bucket']]
        labels  = grp['label'].tolist()
        n_rows  = len(times)

        for i, (t, b, lbl) in enumerate(zip(times, buckets, labels)):
            if 0 <= t < T:
                bucket_counts[t, b] += 1
                sick_counts[t, b]   += lbl

            # Only count a transition when the next row is the very next
            # time step; gaps in the record carry no transition evidence.
            if i + 1 < n_rows and 0 <= t < T - 1 and times[i + 1] == t + 1:
                transition_counts[t, b, buckets[i + 1]] += 1

    # Normalise each (t, b) row of counts; rows without data stay at zero
    # here and are replaced by a self-transition just below.
    row_totals = transition_counts.sum(axis=2, keepdims=True)
    p_trans = np.divide(transition_counts, row_totals,
                        out=np.zeros_like(transition_counts),
                        where=row_totals > 0)
    empty_t, empty_b = np.nonzero(row_totals[:, :, 0] == 0)
    p_trans[empty_t, empty_b, empty_b] = 1.0

    p_sick = np.divide(sick_counts, bucket_counts,
                       out=np.zeros_like(sick_counts),
                       where=bucket_counts > 0)
    return p_trans, p_sick

def train_data_driven_dp_unconstrained(p_trans, p_sick, 
                                       FP=10, FN=50, D=1, gamma=0.99, T=20):
    """
    Backward-induction DP for the unconstrained scenario.

    For every time t in 0..T-1 and bucket b:
        V[t, b] = min(cost of treating now, discounted cost of waiting)
    where treating at t costs  p_sick[t,b]*(D*t) + (1-p_sick[t,b])*FP  and
    waiting costs gamma times the expected V at t+1 under p_trans[t, b, :].
    Ties are resolved in favour of treating.

    The terminal row V[T, b] is the cheaper of treating at T-1 and never
    treating (p_sick[T-1,b]*FN).
    NOTE(review): because V[T, b] already contains the treat-at-T-1 option,
    with gamma < 1 "waiting" at T-1 is a discounted copy of the best final
    action -- confirm this is the intended boundary condition.

    Returns (V, pi_): V has shape (T+1, n_buckets); pi_ has shape
    (T, n_buckets) with 1 = treat, 0 = wait.
    """
    n_buckets = p_sick.shape[1]
    V = np.zeros((T+1, n_buckets))
    pi_ = np.zeros((T, n_buckets), dtype=int)
    last = T - 1

    def treat_cost(t, b):
        # Delay cost if the patient is sick, false-positive cost otherwise.
        return p_sick[t,b]*(D*t) + (1-p_sick[t,b])*FP

    # Terminal values (valid decision times are 0..T-1).
    for b in range(n_buckets):
        V[T,b] = min(treat_cost(last, b), p_sick[last,b]*FN)

    # Sweep backwards from the last decision time to the first.
    for t in range(last, -1, -1):
        for b in range(n_buckets):
            act_now = treat_cost(t, b)

            if t == last:
                defer = gamma * V[T,b]
            else:
                expected_next = 0.0
                for nb in range(n_buckets):
                    expected_next += p_trans[t,b,nb]*V[t+1,nb]
                defer = gamma * expected_next

            treat = act_now <= defer
            V[t,b]   = act_now if treat else defer
            pi_[t,b] = 1 if treat else 0
    return V, pi_

def make_dp_policy(V, pi_, T=20):
    """
    Wrap a DP decision table as a policy function.

    The returned function scans a patient's rows in order and returns the
    first integer time t < T at which pi_[t, risk_bucket] == 1, or None if
    no such row exists. Rows with time >= T are skipped. 'V' is accepted
    for symmetry with the DP trainer's output but is not consulted.
    """
    def policy_func(patient_rows):
        for _, obs in patient_rows.iterrows():
            step = int(obs['time'])
            if step >= T:
                continue
            if pi_[step, int(obs['risk_bucket'])] == 1:
                return step
        return None
    return policy_func

###############################################################################
# 6. ALGORITHM 0 (STANDARD VALIDATION)
###############################################################################
def run_algorithm0_unconstrained(df_all, seed=0):
    """
    Run one replicate of Algorithm 0 (standard validation), unconstrained.

    1) Split df_all -> G1, G2, G3, G4
    2) ML hyperparam search on (G1->G2)
    3) Retrain best ML on G1+G2
    4) On G3, tune:
         - threshold-based parameters (constant, dynamic, linear,
           wait-till-end)
         - DP discount factor gamma (the DP's transition and sickness
           probabilities are estimated on G1+G2)
       and keep the lowest-cost setting of each
    5) Evaluate all tuned policies on G4

    df_all : DataFrame with at least patient_id, time, EIT, NIRS, EIS, label.
    seed   : forwarded to split_into_four_groups to randomise the grouping.

    Returns a Pandas DataFrame with columns:
        Method, Cost, Precision (%), Recall (%), Avg Treat Time
    Raises ValueError if no candidate in GAMMA_CANDIDATES yields a usable
    DP policy on G3.
    """
    feature_cols = ['EIT', 'NIRS', 'EIS']

    # Step 1: Split
    G1, G2, G3, G4 = split_into_four_groups(df_all, seed=seed)

    # Step 2: ML hyperparam search on (G1->G2)
    X_train = G1[feature_cols].values
    y_train = G1['label'].values

    X_val   = G2[feature_cols].values
    y_val   = G2['label'].values

    best_model, _best_auc, _best_name = train_and_select_best_model(
        X_train, y_train, X_val, y_val
    )

    # Step 3: Retrain best ML on G1+G2
    G12 = pd.concat([G1, G2], ignore_index=True)
    best_model.fit(G12[feature_cols].values, G12['label'].values)

    def with_risk_scores(group):
        # The model is fitted on plain ndarrays, so always predict on
        # ndarrays too (never on a DataFrame carrying feature names).
        scored = group.copy()
        scored['risk_score'] = best_model.predict_proba(
            scored[feature_cols].values
        )[:, 1]
        return scored

    # G3 drives the tuning; G12 is needed for DP transition estimation.
    G3  = with_risk_scores(G3)
    G12 = with_risk_scores(G12)

    # (A) TUNE THRESHOLD-BASED METHODS ON G3
    thr_const_g3, _ = constant_threshold_search(G3)
    thr_vec_g3, _ = dynamic_threshold_random_search(G3, time_steps=T_MAX)
    (A_lin_g3, B_lin_g3), _ = linear_threshold_search(G3)
    thr_wte_g3, _ = wait_till_end_search(G3)

    # (B) TUNE DP DISCOUNT FACTOR ON G3
    G12['risk_bucket'] = G12['risk_score'].apply(to_bucket)
    p_trans, p_sick = estimate_transition_and_sick_probs(G12, T=T_MAX, n_buckets=5)

    G3_dp = G3.copy()
    G3_dp['risk_bucket'] = G3_dp['risk_score'].apply(to_bucket)

    best_cost_dp = float('inf')
    best_V = None
    best_pi = None

    for gamma_ in GAMMA_CANDIDATES:
        V_temp, pi_temp = train_data_driven_dp_unconstrained(
            p_trans, p_sick,
            FP=FP_COST, FN=FN_COST, D=D_COST,
            gamma=gamma_, T=T_MAX
        )
        stats_temp = simulate_policy(
            G3_dp, make_dp_policy(V_temp, pi_temp, T=T_MAX)
        )

        if stats_temp['cost'] < best_cost_dp:
            best_cost_dp = stats_temp['cost']
            best_V       = V_temp
            best_pi      = pi_temp

    if best_pi is None:
        # Happens when GAMMA_CANDIDATES is empty or every G3 cost failed
        # the '<' comparison (e.g. NaN); fail here with a clear message
        # instead of later inside the policy function.
        raise ValueError(
            "No DP policy could be selected on G3; check GAMMA_CANDIDATES "
            "and the simulated costs."
        )

    # Step 5: EVALUATE ON G4
    G4 = with_risk_scores(G4)

    G4_dp = G4.copy()
    G4_dp['risk_bucket'] = G4_dp['risk_score'].apply(to_bucket)

    # (method label, evaluation frame, tuned policy), in report order.
    evaluations = [
        ('Constant Threshold',   G4,    make_constant_threshold_policy(thr_const_g3)),
        ('Dynamic Threshold-R',  G4,    make_dynamic_threshold_policy(thr_vec_g3)),
        ('Linear Threshold',     G4,    make_linear_threshold_policy(A_lin_g3, B_lin_g3)),
        ('Wait Till End',        G4,    make_wait_till_end_policy(thr_wte_g3)),
        ('Dynamic Threshold-DP', G4_dp, make_dp_policy(best_V, best_pi, T=T_MAX)),
    ]
    results = [
        (label, simulate_policy(frame, policy))
        for label, frame, policy in evaluations
    ]

    # Build final table of results
    table = pd.DataFrame({
        'Method':         [label for label, _ in results],
        'Cost':           [stats['cost'] for _, stats in results],
        'Precision (%)':  [100*stats['precision'] for _, stats in results],
        'Recall (%)':     [100*stats['recall'] for _, stats in results],
        'Avg Treat Time': [stats['avg_treatment_time'] for _, stats in results],
    })

    return table

###############################################################################
# 7. MAIN - RUN MULTIPLE REPLICATES AND AGGREGATE
###############################################################################
def main(csv_path="synthetic_patients_with_features.csv",
         num_replicates=30, base_seed=412):
    """
    Run Algorithm 0 over several random splits and print mean ± std results.

    csv_path       : CSV file with at least the columns patient_id, time,
                     EIT, NIRS, EIS, label.
    num_replicates : number of replicates to run; replicate i uses seed
                     base_seed + i.
    base_seed      : seed of the first replicate.

    Raises ValueError if the CSV lacks any required column.
    """
    # Read data
    df_all = pd.read_csv(csv_path)

    # Check required columns BEFORE touching any of them, so a missing
    # 'time' column produces this error rather than a bare KeyError.
    required = {'patient_id','time','EIT','NIRS','EIS','label'}
    if not required.issubset(df_all.columns):
        raise ValueError(
            f"Your CSV must have columns at least: {required}. Found: {df_all.columns}"
        )

    # Filter to time < T_MAX (if needed)
    df_all = df_all[df_all['time'] < T_MAX].copy()

    # Run the algorithm multiple times, collect results
    all_tables = []
    for rep in range(num_replicates):
        # Use different seed each time to randomize grouping
        seed = base_seed + rep
        print(f"\n=== Running replicate {rep+1}/{num_replicates} (seed={seed}) ===")
        all_tables.append(run_algorithm0_unconstrained(df_all, seed=seed))

    # Each table has one row per method; stack them and summarise every
    # numeric column per method as "mean ± std".
    combined_df = pd.concat(all_tables, ignore_index=True)
    metric_cols = ['Cost', 'Precision (%)', 'Recall (%)', 'Avg Treat Time']

    final_rows = []
    for method, group_data in combined_df.groupby('Method'):
        summary = {'Method': method}
        for col in metric_cols:
            summary[col] = f"{group_data[col].mean():.2f} ± {group_data[col].std():.2f}"
        final_rows.append(summary)

    final_df = pd.DataFrame(final_rows)

    print("\n=== ALGORITHM 0 (STANDARD VALIDATION) RESULTS (Unconstrained) ===")
    print(f"Ran {num_replicates} replicates. Aggregated (mean ± std) results:")
    print(final_df.to_string(index=False))

# Script entry point: run main() only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()


=== Running replicate 1/30 (seed=412) ===

=== Running replicate 2/30 (seed=413) ===

=== Running replicate 3/30 (seed=414) ===

=== Running replicate 4/30 (seed=415) ===

=== Running replicate 5/30 (seed=416) ===

=== Running replicate 6/30 (seed=417) ===

=== Running replicate 7/30 (seed=418) ===

=== Running replicate 8/30 (seed=419) ===

=== Running replicate 9/30 (seed=420) ===

=== Running replicate 10/30 (seed=421) ===

=== Running replicate 11/30 (seed=422) ===

=== Running replicate 12/30 (seed=423) ===

=== Running replicate 13/30 (seed=424) ===

=== Running replicate 14/30 (seed=425) ===

=== Running replicate 15/30 (seed=426) ===

=== Running replicate 16/30 (seed=427) ===

=== Running replicate 17/30 (seed=428) ===

=== Running replicate 18/30 (seed=429) ===

=== Running replicate 19/30 (seed=430) ===

=== Running replicate 20/30 (seed=431) ===

=== Running replicate 21/30 (seed=432) ===

=== Running replicate 22/30 (seed=433) ===

=== Running replicate 23/30 (seed=434) =