# In [None]:
###############################################
# File: adp_with_random_forest_markov_two_matrices.py
#
# Illustrative end-to-end code for:
#   - Algorithm 0 (Standard Validation)
#   - Random Forest ML model
#   - Approximate Dynamic Programming (ADP) with capacity constraints
#   - Time-varying risk buckets via TWO Markov transition matrices 
#       (one for healthy, one for sick)
#   - Final metrics on G4: cost, average treatment time, recall, precision
###############################################

import numpy as np
import pandas as pd

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim

###############################################################################
# STEP 0: Data Import
# Assume we have a CSV with 600 synthetic patients:
#   columns = [patient_id, time, EIT, NIRS, EIS, label, ...]
#
# NOTE(review): the captured cell output reports 3150 rows in each of G1..G4
# (12,600 rows total). The CSV therefore presumably holds several rows
# (time points) per patient_id. Confirm against the data file.
###############################################################################
df = pd.read_csv("synthetic_patients_with_features.csv")

# Fixed seeds for the three train_test_split calls below, so the partition
# is reproducible from run to run.
RANDOM_SEED_1 = 42
RANDOM_SEED_2 = 999
RANDOM_SEED_3 = 123

###############################################################################
# STEP 1: Data Partition -> (G1, G2, G3, G4)  (Algorithm 0)
###############################################################################
# We'll do a simple 50-50 split for (G12 vs G34), then 50-50 again for
# (G1 vs G2), and 50-50 for (G3 vs G4). Each split is stratified on 'label'.
# Roles downstream:
#   G1 = Random Forest training
#   G2 = hyperparameter selection
#   G3 = ADP (Q-learning) training
#   G4 = final held-out evaluation
#
# NOTE(review): these splits are over ROWS, not patients. If a patient_id has
# several rows, that patient can land in several groups. This leaks
# information between training and evaluation. Consider splitting on the
# unique patient_id values instead.
###############################################################################

G12, G34 = train_test_split(df, test_size=0.50, random_state=RANDOM_SEED_1, stratify=df['label'])
G1, G2 = train_test_split(G12, test_size=0.50, random_state=RANDOM_SEED_2, stratify=G12['label'])
G3, G4 = train_test_split(G34, test_size=0.50, random_state=RANDOM_SEED_3, stratify=G34['label'])

# Sanity check: with 50/50 splits the four groups should be (nearly) equal.
print(f"G1 size: {len(G1)}")
print(f"G2 size: {len(G2)}")
print(f"G3 size: {len(G3)}")
print(f"G4 size: {len(G4)}")

###############################################################################
# STEP 2: Train Random Forest on G1, pick best hyperparams by AUC on G2
###############################################################################
def compute_features(subdf):
    """
    Build the feature matrix used by the Random Forest.

    Parameters
    ----------
    subdf : pandas.DataFrame
        Must contain the columns 'EIT', 'NIRS', 'EIS' and 'time'.

    Returns
    -------
    numpy.ndarray
        One row per row of ``subdf``, columns ordered EIT, NIRS, EIS, time.
        Edit ``selected`` to change which features are used.
    """
    selected = ['EIT', 'NIRS', 'EIS', 'time']
    return subdf.loc[:, selected].values

def prepare_data_for_ml(df_input):
    """
    Split ``df_input`` into a feature matrix and a label vector.

    Returns ``(X, y)``:
    - ``X`` comes from ``compute_features``.
    - ``y`` is the 'label' column as a NumPy array.
    """
    features = compute_features(df_input)
    labels = df_input['label'].values
    return features, labels

# Feature/label arrays for the RF training set (G1) and validation set (G2).
X1, y1 = prepare_data_for_ml(G1)
X2, y2 = prepare_data_for_ml(G2)

# Exhaustive grid: 2 x 3 x 2 = 12 Random Forest fits.
param_grid = {
    'n_estimators': [100, 200],
    'max_depth': [5, 10, None],
    'min_samples_leaf': [1, 5]
}
best_auc = -np.inf
best_params = None

# Fit on G1, score by ROC AUC on G2. The strict '>' keeps the first
# combination (in the loop order below) when two AUCs tie.
for n_est in param_grid['n_estimators']:
    for md in param_grid['max_depth']:
        for msl in param_grid['min_samples_leaf']:
            rf_model = RandomForestClassifier(n_estimators=n_est,
                                              max_depth=md,
                                              min_samples_leaf=msl,
                                              random_state=0)
            rf_model.fit(X1, y1)
            preds_proba = rf_model.predict_proba(X2)[:,1]
            auc_val = roc_auc_score(y2, preds_proba)
            if auc_val > best_auc:
                best_auc = auc_val
                best_params = (n_est, md, msl)

print(f"[ML] Best AUC on G2 = {best_auc:.4f}, best params = {best_params}")

# Refit the selected hyperparameters on G1+G2 combined.
# G3 and G4 remain unseen by the classifier.
X12 = np.vstack([X1, X2])
y12 = np.hstack([y1, y2])

# best_params = (n_estimators, max_depth, min_samples_leaf)
final_rf_model = RandomForestClassifier(n_estimators=best_params[0],
                                        max_depth=best_params[1],
                                        min_samples_leaf=best_params[2],
                                        random_state=0)
final_rf_model.fit(X12, y12)

###############################################################################
# STEP 3: Generate risk scores for G3 (and discretize into buckets)
###############################################################################
def get_risk_scores(df_input, model):
    """
    Return a copy of ``df_input`` with a new 'risk_score' column.

    The score is ``model.predict_proba(X)[:, 1]``, computed on the features
    from ``prepare_data_for_ml``. That helper also reads 'label', so
    ``df_input`` must have that column even though labels are unused here.
    The input frame itself is not modified.
    """
    features, _ = prepare_data_for_ml(df_input)
    positive_proba = model.predict_proba(features)[:, 1]
    scored = df_input.copy()
    scored['risk_score'] = positive_proba
    return scored

def bucket_risk(score):
    """
    Map a probability-like risk score to one of five buckets (0..4).

    Buckets are half-open intervals [0, 0.2), [0.2, 0.4), [0.4, 0.6),
    [0.6, 0.8) and [0.8, inf). Any score that fails every ``<`` test,
    including NaN, lands in the top bucket 4.
    """
    upper_edges = (0.2, 0.4, 0.6, 0.8)
    for bucket_index, edge in enumerate(upper_edges):
        if score < edge:
            return bucket_index
    return len(upper_edges)

# Score G3 with the final RF and discretise each score into a bucket (0..4).
G3_scored = get_risk_scores(G3, final_rf_model)
G3_scored['bucket'] = G3_scored['risk_score'].apply(bucket_risk)

###############################################################################
# STEP 4: Train ADP (Q-learning) on G3
#
# We define *two* Markov transition matrices, one for healthy (label=0)
# and one for sick (label=1).
# Then we proceed with the standard Q-learning setup.
###############################################################################

# Example cost parameters. The environment's step() reads FP, FN and D as
# module-level globals, so they must be defined before training/evaluation.
FP = 10   # cost for false positive
FN = 50   # cost for false negative
D  = 1    # cost for per-step delay (if sick)
gamma = 0.99  # discount factor for the cost-to-go
T_max = 20    # decision steps per episode

# Suppose these are your two Markov transitions.
# Rows = current bucket, columns = next bucket; every row sums to 1.
# (A) For healthy patients (label=0):
transition_matrix_healthy = np.array([
    [0.60, 0.25, 0.10, 0.05, 0.00],  # from bucket 0
    [0.10, 0.50, 0.30, 0.10, 0.00],  # from bucket 1
    [0.05, 0.10, 0.50, 0.25, 0.10],  # from bucket 2
    [0.00, 0.05, 0.20, 0.50, 0.25],  # from bucket 3
    [0.00, 0.00, 0.10, 0.20, 0.70]   # from bucket 4
])

# (B) For sick patients (label=1):
transition_matrix_sick = np.array([
    [0.30, 0.30, 0.20, 0.15, 0.05],  # from bucket 0 (sick rarely in 0, but example)
    [0.10, 0.30, 0.35, 0.20, 0.05],  # from bucket 1
    [0.05, 0.05, 0.50, 0.30, 0.10],  # from bucket 2
    [0.00, 0.05, 0.20, 0.40, 0.35],  # from bucket 3
    [0.00, 0.00, 0.05, 0.25, 0.70]   # from bucket 4
])

# Number of sick ROWS in G3 (each row is treated as one patient downstream).
num_sick_g3 = G3_scored['label'].sum()
# For demonstration: capacity is half of the number of sick.
# (Interpretation: we can only treat up to N_c patients in total over the
# entire horizon.)
N_c = int(0.5 * num_sick_g3) if num_sick_g3 > 0 else 0
# Defensive no-op: N_c is already >= 0 here.
N_c = max(0, N_c)

###############################################################################
# (A) Define the environment with TWO Markov transitions
###############################################################################
class HemorrhageEnvAggregatedWithTransitions:
    """
    Patient-level simulator whose observation is aggregated bucket counts.

    Each row of ``df_patients`` is simulated as one patient with:
    - a risk bucket (0..4);
    - a label (0 = healthy, any other value = sick for transitions);
    - a treated flag.

    On each step:
    - the agent says how many patients to treat;
    - they are taken from the highest bucket downward (ties in original row
      order), limited by the remaining capacity;
    - treated patients are frozen;
    - every still-untreated patient moves to a new bucket, sampled from the
      healthy or sick Markov transition matrix.

    Costs per step:
    - FP for each newly treated healthy patient;
    - D for every sick (label == 1) patient still untreated after treatment;
    - FN for every such patient at the final step.

    Unless overridden in the constructor, the costs are read from the
    module-level globals ``FP``, ``FN`` and ``D`` at step time.

    ``reset()`` restores the initial bucket of every patient, so each episode
    starts from the same population.

    NOTE: one row == one "patient". If the input has several rows per
    patient_id, each row is simulated independently.
    """

    N_BUCKETS = 5  # risk buckets 0..4

    def __init__(self, df_patients, capacity, max_time=20,
                 transition_matrix_healthy=None,
                 transition_matrix_sick=None,
                 cost_fp=None, cost_fn=None, cost_delay=None):
        """
        df_patients: must have columns [patient_id, bucket, label, ...].
            Buckets are integers in 0..4.
        capacity: maximum number of patients treated over the whole horizon
            (capacity is never freed).
        max_time: number of decision steps per episode.
        transition_matrix_healthy: 5x5 row-stochastic array for healthy
            patients (row = current bucket, column = next bucket).
        transition_matrix_sick: 5x5 row-stochastic array for sick patients.
        cost_fp / cost_fn / cost_delay: optional cost overrides. When None,
            the globals FP / FN / D are used.
        """
        self.df = df_patients.copy()
        self.capacity = capacity
        self.max_time = max_time

        # Two separate transition matrices
        self.transition_matrix_healthy = transition_matrix_healthy
        self.transition_matrix_sick = transition_matrix_sick

        # Optional per-instance costs; None => fall back to the globals.
        self.cost_fp = cost_fp
        self.cost_fn = cost_fn
        self.cost_delay = cost_delay

        # Per-patient working table, mutated by step() and read by callers.
        self.patients = self.df[['patient_id', 'bucket', 'label']].copy()
        self.patients['treated'] = 0
        self.patients['treat_time'] = -1
        # Snapshot of the starting buckets so reset() can undo Markov drift.
        self._initial_buckets = self.patients['bucket'].to_numpy().copy()

        self.t = 0
        self.done = False

    def reset(self):
        """Restore initial buckets, clear treatments and time; return the state."""
        self.patients['bucket'] = self._initial_buckets.copy()
        self.patients['treated'] = 0
        self.patients['treat_time'] = -1
        self.t = 0
        self.done = False
        return self._get_aggregated_state()

    def _costs(self):
        """Return (fp, fn, delay) costs, reading globals only when not overridden."""
        fp = FP if self.cost_fp is None else self.cost_fp
        fn = FN if self.cost_fn is None else self.cost_fn
        delay = D if self.cost_delay is None else self.cost_delay
        return fp, fn, delay

    def _capacity_remaining(self):
        """Capacity not yet used (never negative)."""
        treated_count = int(self.patients['treated'].sum())
        return max(0, self.capacity - treated_count)

    def _get_aggregated_state(self):
        """
        Return the 7-tuple (untreated count in bucket 0..4, cap_rem, t).
        """
        untreated = self.patients['treated'].to_numpy() == 0
        buckets = self.patients['bucket'].to_numpy()[untreated].astype(np.int64)
        counts = np.bincount(buckets, minlength=self.N_BUCKETS)[:self.N_BUCKETS]
        return tuple(int(c) for c in counts) + (self._capacity_remaining(), self.t)

    def _transition_untreated(self):
        """Resample the bucket of every untreated patient from its Markov row."""
        untreated = self.patients['treated'].to_numpy() == 0
        if not untreated.any():
            return
        current = self.patients['bucket'].to_numpy()[untreated].astype(np.int64)
        healthy = self.patients['label'].to_numpy()[untreated] == 0

        probs = np.empty((current.size, self.N_BUCKETS), dtype=float)
        if healthy.any():
            probs[healthy] = np.asarray(self.transition_matrix_healthy,
                                        dtype=float)[current[healthy]]
        if (~healthy).any():
            probs[~healthy] = np.asarray(self.transition_matrix_sick,
                                         dtype=float)[current[~healthy]]

        cdf = np.cumsum(probs, axis=1)
        row_totals = cdf[:, -1]
        if not np.allclose(row_totals, 1.0, atol=1e-6):
            raise ValueError("transition matrix rows must sum to 1")
        # Vectorised inverse-CDF sampling: bucket j is drawn when
        # cdf[j-1] <= u < cdf[j]. Scaling u by the row total keeps it strictly
        # below the last CDF value. A zero-probability bucket therefore cannot
        # be drawn through float rounding.
        u = np.random.random_sample(current.size) * row_totals
        new_buckets = (cdf <= u[:, None]).sum(axis=1)
        self.patients.loc[untreated, 'bucket'] = new_buckets

    def step(self, action):
        """
        Advance the simulation by one decision step.

        1) Treat up to ``action`` patients, highest bucket first, limited by
           the remaining capacity and the number of untreated patients.
        2) Compute the immediate cost:
           - FP for each newly treated healthy patient;
           - per-step delay for each sick patient still untreated;
           - FN for those same patients if this is the final step.
        3) Apply the Markov transition to all still-untreated patients, using
           the matrix that matches their label (skipped once the episode ends).

        Returns (next_state, immediate_cost, done). Once done, further calls
        return the current state with cost 0.0.
        """
        if self.done:
            return self._get_aggregated_state(), 0.0, True

        fp_cost, fn_cost, delay_cost = self._costs()

        # 1) Treat the highest-risk untreated patients first.
        untreated = self.patients['treated'].to_numpy() == 0
        candidates = self.patients.index[untreated]
        cap_rem = self._capacity_remaining()
        num_to_treat = int(min(action, cap_rem, len(candidates)))

        newly_treated_healthy = 0
        if num_to_treat > 0:
            cand_buckets = self.patients['bucket'].to_numpy()[untreated].astype(np.int64)
            # Stable sort on -bucket: highest bucket first, ties keep row order.
            order = np.argsort(-cand_buckets, kind='stable')
            to_treat = candidates[order[:num_to_treat]]
            self.patients.loc[to_treat, 'treated'] = 1
            self.patients.loc[to_treat, 'treat_time'] = self.t
            newly_treated_healthy = int((self.patients.loc[to_treat, 'label'] == 0).sum())

        # 2) Immediate cost.
        label = self.patients['label'].to_numpy()
        still_untreated = self.patients['treated'].to_numpy() == 0
        sick_waiting = int(np.sum((label == 1) & still_untreated))

        cost_fp = fp_cost * newly_treated_healthy
        cost_delay = delay_cost * sick_waiting
        cost_fn = 0.0
        done_next = False
        if self.t == (self.max_time - 1):
            # Final step: every sick patient never treated is a false negative.
            done_next = True
            cost_fn = fn_cost * sick_waiting
        immediate_cost = float(cost_fp + cost_delay + cost_fn)

        # 3) Advance time; move untreated patients unless the episode ended.
        self.t += 1
        if self.t >= self.max_time:
            done_next = True
        if not done_next:
            self._transition_untreated()

        self.done = done_next
        return self._get_aggregated_state(), immediate_cost, done_next

###############################################################################
# (B) Define Q-Network
###############################################################################
class QNetwork(nn.Module):
    """
    Two-hidden-layer MLP that outputs one value per discrete action.

    Architecture: Linear(state_dim, hidden) -> ReLU -> Linear(hidden, hidden)
    -> ReLU -> Linear(hidden, action_dim). In this file the outputs are
    treated as expected costs, and the policy picks the argmin.
    """

    def __init__(self, state_dim=7, action_dim=21, hidden=32):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, action_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a [batch_size, state_dim] tensor to [batch_size, action_dim]."""
        return self.fc3(self.relu(self.fc2(self.relu(self.fc1(x)))))

###############################################################################
# (C) Q-Learning / ADP Loop
###############################################################################
def state_to_tensor(state):
    """
    Wrap a single environment state as a float32 batch of size one.

    For the 7-tuple state (bucket0..bucket4 counts, cap_rem, t), the result
    has shape (1, 7).
    """
    return torch.tensor(state, dtype=torch.float32).unsqueeze(0)

def choose_action_epsilon_greedy(qnet, state, epsilon, max_action):
    """
    Epsilon-greedy action selection for a cost-minimising Q-network.

    - With probability ``epsilon``: a uniformly random integer in
      [0, max_action].
    - Otherwise: the index of the smallest Q-value (lowest predicted cost).
    """
    explore = np.random.rand() < epsilon
    if explore:
        return np.random.randint(0, max_action + 1)
    with torch.no_grad():
        q_values = qnet(state_to_tensor(state)).detach().numpy().flatten()
    return np.argmin(q_values)

def train_adp_on_G3(df_g3, capacity, max_time=20,
                    transition_matrix_healthy=None,
                    transition_matrix_sick=None,
                    gamma=0.99,
                    episodes=2000,
                    learning_rate=1e-3,
                    epsilon_start=0.2,
                    epsilon_decay=0.999,
                    batch_size=64,
                    replay_size=50000):
    """
    Train a cost-minimising Q-network (ADP) on the patients in ``df_g3``.

    The environment is HemorrhageEnvAggregatedWithTransitions, which has
    separate healthy/sick Markov transition matrices. Actions are 0..capacity
    (how many patients to treat this step).

    Each environment step:
    - the transition is stored in a replay buffer that keeps the most recent
      ``replay_size`` entries;
    - once the buffer holds ``batch_size`` entries, one gradient step is taken
      on a minibatch sampled uniformly without replacement.

    The regression target is ``cost`` if the episode ended, otherwise
    ``cost + gamma * min_a' Q(s', a')``. The same network provides the
    bootstrap, so there is no separate target network. Epsilon decays
    multiplicatively per episode, with a floor of 0.01.

    NOTE(review): raw counts (potentially thousands) and t are fed to the
    network unscaled; normalising the state may stabilise training.

    Returns:
        The trained QNetwork (one output per action 0..capacity).
    """
    from collections import deque

    env = HemorrhageEnvAggregatedWithTransitions(
        df_patients=df_g3,
        capacity=capacity,
        max_time=max_time,
        transition_matrix_healthy=transition_matrix_healthy,
        transition_matrix_sick=transition_matrix_sick
    )

    max_action = capacity  # we allow actions in [0..capacity]
    qnet = QNetwork(state_dim=7, action_dim=max_action + 1, hidden=32)
    optimizer = optim.Adam(qnet.parameters(), lr=learning_rate)
    loss_fn = nn.MSELoss()

    # deque(maxlen=...) evicts the oldest transition in O(1).
    # The previous list.pop(0) was O(len(buffer)) on every step.
    replay_buffer = deque(maxlen=replay_size)
    epsilon = epsilon_start

    def _train_on_minibatch():
        """One gradient step on a uniform minibatch drawn without replacement."""
        picks = np.random.choice(len(replay_buffer), batch_size, replace=False)
        batch = [replay_buffer[i] for i in picks]

        states_t = torch.as_tensor(np.asarray([b[0] for b in batch], dtype=np.float32))
        actions_t = torch.as_tensor(np.asarray([b[1] for b in batch], dtype=np.int64))
        costs_t = torch.as_tensor(np.asarray([b[2] for b in batch], dtype=np.float32))
        next_t = torch.as_tensor(np.asarray([b[3] for b in batch], dtype=np.float32))
        not_done_t = torch.as_tensor(
            np.asarray([0.0 if b[4] else 1.0 for b in batch], dtype=np.float32))

        # One batched forward pass for all bootstrap values.
        # Minimising cost => bootstrap with min over next actions, not max.
        with torch.no_grad():
            next_min = qnet(next_t).min(dim=1).values
            targets_t = costs_t + gamma * not_done_t * next_min

        q_chosen = qnet(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)
        loss = loss_fn(q_chosen, targets_t)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    for _ in range(episodes):
        s = env.reset()
        done = False

        while not done:
            a = choose_action_epsilon_greedy(qnet, s, epsilon, max_action)
            s_next, cost, done = env.step(a)
            replay_buffer.append((s, a, cost, s_next, done))
            s = s_next

            if len(replay_buffer) >= batch_size:
                _train_on_minibatch()

        # Epsilon decay with a floor of 0.01.
        epsilon = max(epsilon * epsilon_decay, 0.01)

    return qnet

# NOTE(review): the captured output of this cell ends right after the next
# print, i.e. training had not finished when the notebook was saved. The
# patient-level environment simulates every G3 row on every step, and
# 2000 episodes x T_max steps is likely slow. Confirm runtime before
# relying on results.
print("[ADP] Training Q-Network on G3 with TWO Markov transitions (healthy vs. sick) ...")
# replay_size is left at the function's default (50000).
qnet_final = train_adp_on_G3(
    df_g3=G3_scored,
    capacity=N_c,
    max_time=T_max,
    transition_matrix_healthy=transition_matrix_healthy,
    transition_matrix_sick=transition_matrix_sick,
    gamma=gamma,
    episodes=2000,
    learning_rate=1e-3,
    epsilon_start=0.2,
    epsilon_decay=0.999,
    batch_size=64
)
print("[ADP] Done training Q-network.")

###############################################################################
# STEP 5: Evaluate final policy (RF risk + ADP) on G4
###############################################################################

# 1) Generate risk scores + bucket for G4, using the same RF and the same
#    bucket thresholds as G3.
G4_scored = get_risk_scores(G4, final_rf_model)
G4_scored['bucket'] = G4_scored['risk_score'].apply(bucket_risk)

def evaluate_policy_on_dataset(df_input, qnet, capacity, max_time=20,
                               transition_matrix_healthy=None,
                               transition_matrix_sick=None):
    """
    Roll out the greedy (argmin-Q) policy once on ``df_input`` and score it.

    A fresh HemorrhageEnvAggregatedWithTransitions is built from
    ``df_input``, which needs patient_id, bucket and label columns. At every
    step the action with the lowest predicted cost is taken, until the
    environment reports the end of the horizon.

    Returns:
      total_cost          -- sum of the per-step costs over the rollout
      avg_treatment_time  -- mean treat_time over treated patients, or -1
                             if nobody was treated
      recall              -- TP / (TP + FN), or 0.0 if that denominator is 0
      precision           -- TP / (TP + FP), or 0.0 if that denominator is 0
    """
    env_eval = HemorrhageEnvAggregatedWithTransitions(
        df_patients=df_input,
        capacity=capacity,
        max_time=max_time,
        transition_matrix_healthy=transition_matrix_healthy,
        transition_matrix_sick=transition_matrix_sick
    )

    state = env_eval.reset()
    total_cost = 0.0
    finished = False
    while not finished:
        with torch.no_grad():
            q_values = qnet(state_to_tensor(state)).detach().numpy().flatten()
            chosen = np.argmin(q_values)  # lowest predicted cost
        state, step_cost, finished = env_eval.step(chosen)
        total_cost += step_cost

    # Classification-style outcomes from the final patient table.
    outcome = env_eval.patients
    y_true = outcome['label'].values
    was_treated = outcome['treated'].values
    treat_steps = outcome['treat_time'].values

    true_pos = np.sum((y_true == 1) & (was_treated == 1))
    false_pos = np.sum((y_true == 0) & (was_treated == 1))
    false_neg = np.sum((y_true == 1) & (was_treated == 0))

    sick_total = true_pos + false_neg
    flagged_total = true_pos + false_pos
    recall = true_pos / sick_total if sick_total > 0 else 0.0
    precision = true_pos / flagged_total if flagged_total > 0 else 0.0

    treated_mask = (was_treated == 1)
    avg_treat_time = treat_steps[treated_mask].mean() if treated_mask.sum() > 0 else -1

    return total_cost, avg_treat_time, recall, precision

# Capacity N_c was derived from G3 (half the sick rows in G3), not from G4.
# The Q-network's action space (0..N_c) is tied to it, so the same value is
# reused here.
# NOTE(review): this is a single stochastic rollout. Bucket transitions are
# sampled from the global NumPy RNG, and no seed is set in the visible code.
# Consider averaging several rollouts.
final_cost_G4, avg_tt_G4, recall_G4, precision_G4 = evaluate_policy_on_dataset(
    df_input=G4_scored,
    qnet=qnet_final,
    capacity=N_c,
    max_time=T_max,
    transition_matrix_healthy=transition_matrix_healthy,
    transition_matrix_sick=transition_matrix_sick
)

print("===== EVALUATION on G4 (Two Markov transitions) =====")
print(f"Cost: {final_cost_G4:.2f}")
# Prints -1.00 when no patient was treated.
print(f"Avg. Treatment Time: {avg_tt_G4:.2f}")
print(f"Recall: {recall_G4:.4f}")
print(f"Precision: {precision_G4:.4f}")

# Optionally compute the final RF AUC on G4. G4 was never used to fit or
# select the RF, so this is an out-of-sample AUC at the row level.
X4, y4 = prepare_data_for_ml(G4)
proba4 = final_rf_model.predict_proba(X4)[:,1]
auc_g4 = roc_auc_score(y4, proba4)
print(f"RandomForest AUC on G4: {auc_g4:.4f}")

# --- Captured cell output (saved before ADP training finished) ---
# G1 size: 3150
# G2 size: 3150
# G3 size: 3150
# G4 size: 3150
# [ML] Best AUC on G2 = 0.9286, best params = (100, 5, 1)
# [ADP] Training Q-Network on G3 with TWO Markov transitions (healthy vs. sick) ...


# In [None]:
###############################################
# File: adp_with_aggregated_markov.py
#
# Illustrative code for:
#   - Algorithm 0 (Standard Validation)
#   - Random Forest ML for risk
#   - Aggregated Markov transitions (bucket-level)
#   - Capacitated ADP (Q-learning) with approximate cost-to-go
#   - Final metrics on G4: cost, treatment time (approx), recall, precision
###############################################

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim

###############################################################################
# STEP 0: Data Import
###############################################################################
df = pd.read_csv("synthetic_patients_with_features.csv")

# Fixed seeds for the three train_test_split calls below (reproducible split).
RANDOM_SEED_1 = 42
RANDOM_SEED_2 = 999
RANDOM_SEED_3 = 123

###############################################################################
# STEP 1: Data Partition -> (G1, G2, G3, G4) (Algorithm 0)
# Three stratified 50/50 splits:
#   G1 = RF training, G2 = RF hyperparameter selection,
#   G3 = ADP training, G4 = final evaluation.
# NOTE(review): splits are row-level. If the CSV holds several rows (time
# points) per patient_id, one patient can appear in several groups, which
# leaks information. Consider splitting on unique patient_id values.
###############################################################################
G12, G34 = train_test_split(df, test_size=0.50, random_state=RANDOM_SEED_1, stratify=df['label'])
G1, G2 = train_test_split(G12, test_size=0.50, random_state=RANDOM_SEED_2, stratify=G12['label'])
G3, G4 = train_test_split(G34, test_size=0.50, random_state=RANDOM_SEED_3, stratify=G34['label'])

print(f"G1 size: {len(G1)}")
print(f"G2 size: {len(G2)}")
print(f"G3 size: {len(G3)}")
print(f"G4 size: {len(G4)}")

###############################################################################
# STEP 2: Train Random Forest on G1, pick best hyperparams by AUC on G2
###############################################################################
def compute_features(subdf):
    """
    Return the raw model features as a 2-D NumPy array.

    Columns are EIT, NIRS, EIS and time, in that order, with one row per
    row of ``subdf``. Adjust ``columns`` to change the feature set.
    """
    columns = ('EIT', 'NIRS', 'EIS', 'time')
    return subdf[list(columns)].values

def prepare_data_for_ml(df_input):
    """Return ``(features, labels)``: compute_features(df_input) and the 'label' column."""
    return compute_features(df_input), df_input['label'].values

# RF training (G1) and validation (G2) arrays.
X1, y1 = prepare_data_for_ml(G1)
X2, y2 = prepare_data_for_ml(G2)

# 2 x 3 x 2 = 12 candidate configurations.
param_grid = {
    'n_estimators': [100, 200],
    'max_depth': [5, 10, None],
    'min_samples_leaf': [1, 5]
}
best_auc = -np.inf
best_params = None

# Fit on G1, score ROC AUC on G2. The strict '>' keeps the first best
# combination (in loop order) on ties.
for n_est in param_grid['n_estimators']:
    for md in param_grid['max_depth']:
        for msl in param_grid['min_samples_leaf']:
            rf_model = RandomForestClassifier(n_estimators=n_est,
                                              max_depth=md,
                                              min_samples_leaf=msl,
                                              random_state=0)
            rf_model.fit(X1, y1)
            preds_proba = rf_model.predict_proba(X2)[:,1]
            auc_val = roc_auc_score(y2, preds_proba)
            if auc_val > best_auc:
                best_auc = auc_val
                best_params = (n_est, md, msl)

print(f"[ML] Best AUC on G2 = {best_auc:.4f}, best params = {best_params}")

# Retrain on G1+G2 with best_params = (n_estimators, max_depth,
# min_samples_leaf). G3/G4 stay unseen by the classifier.
X12 = np.vstack([X1, X2])
y12 = np.hstack([y1, y2])
final_rf_model = RandomForestClassifier(n_estimators=best_params[0],
                                        max_depth=best_params[1],
                                        min_samples_leaf=best_params[2],
                                        random_state=0)
final_rf_model.fit(X12, y12)

###############################################################################
# STEP 3: Generate risk scores + buckets for G3
###############################################################################
def get_risk_scores(df_input, model):
    """
    Score ``df_input`` with ``model`` and return a copy that has 'risk_score'.

    The score is the model's positive-class probability
    (``predict_proba(X)[:, 1]``). ``df_input`` is not modified; it must
    carry 'label' because prepare_data_for_ml reads it.
    """
    X_scored, _labels = prepare_data_for_ml(df_input)
    result = df_input.copy()
    result['risk_score'] = model.predict_proba(X_scored)[:, 1]
    return result

def bucket_risk(score):
    """
    Return the risk bucket (0..4) for ``score``.

    Bucket i covers [0.2*i, 0.2*(i+1)) for i < 4. Bucket 4 takes everything
    else, including scores >= 0.8 and NaN.
    """
    edges = (0.2, 0.4, 0.6, 0.8)
    return next((i for i, upper in enumerate(edges) if score < upper), len(edges))

# Score and bucket G3 with the final RF.
G3_scored = get_risk_scores(G3, final_rf_model)
G3_scored['bucket'] = G3_scored['risk_score'].apply(bucket_risk)

###############################################################################
# STEP 4: ADP (Q-learning) on aggregated Markov transitions
###############################################################################

# Cost parameters. AggregatedMarkovEnv.step() reads FP, FN and D as
# module-level globals.
FP = 10
FN = 50
D  = 1
gamma = 0.99  # discount factor
T_max = 20    # decision steps per episode

# Resource capacity: half the number of sick rows in G3, used as the total
# treatment budget over the horizon. The max(0, ...) line is a defensive
# no-op.
num_sick_g3 = G3_scored['label'].sum()
N_c = int(0.5 * num_sick_g3) if num_sick_g3>0 else 0
N_c = max(0, N_c)

###############################################################################
# (A) Aggregated Markov environment
###############################################################################
class AggregatedMarkovEnv:
    """
    Bucket-level (fluid) approximation of the untreated patient population.

    State:
    - expected untreated healthy mass bH[i] and sick mass bS[i] in each risk
      bucket i = 0..4;
    - remaining capacity;
    - the current time step.

    Counts are floats because both treatment and the Markov transitions move
    *expected* mass:

      new_bH[j] = sum_i bH[i] * P_H[i, j]      (same for bS with P_S)

    Treatment rules:
    - mass is taken from the highest-risk bucket downward;
    - within a bucket, the treated mass is split between healthy and sick in
      proportion to their share of that bucket;
    - capacity is consumed and never freed, and a step never treats more than
      the capacity that remains.

    Unless overridden, costs are read from the module globals FP, FN and D.
    ``reset()`` restores the initial counts, so every episode starts from the
    same population.

    This avoids enumerating individual patients, at the cost of being an
    approximation.
    """

    N_BUCKETS = 5  # risk buckets 0..4

    def __init__(self, df_patients, capacity, max_time=20,
                 transition_mat_healthy=None,
                 transition_mat_sick=None,
                 cost_fp=None, cost_fn=None, cost_delay=None):
        """
        df_patients: must have columns [bucket, label]. Label 0 = healthy and
            1 = sick; rows with any other label are not counted.
        capacity: total resource capacity over the horizon.
        max_time: horizon (number of decision steps).
        transition_mat_healthy: 5x5 row-stochastic matrix for healthy mass.
        transition_mat_sick:    5x5 row-stochastic matrix for sick mass.
        cost_fp / cost_fn / cost_delay: optional cost overrides. When None,
            the globals FP / FN / D are read at step time.
        """
        self.capacity = capacity
        self.max_time = max_time
        self.transition_mat_h = transition_mat_healthy
        self.transition_mat_s = transition_mat_sick
        self.cost_fp = cost_fp
        self.cost_fn = cost_fn
        self.cost_delay = cost_delay

        buckets = df_patients['bucket']
        labels = df_patients['label']
        # Initial aggregated counts, kept so reset() can restore them.
        # _bH0[i] / _bS0[i] = number of healthy / sick rows in bucket i.
        self._bH0 = np.array(
            [((buckets == i) & (labels == 0)).sum() for i in range(self.N_BUCKETS)],
            dtype=float)
        self._bS0 = np.array(
            [((buckets == i) & (labels == 1)).sum() for i in range(self.N_BUCKETS)],
            dtype=float)
        self.bH = self._bH0.copy()
        self.bS = self._bS0.copy()

        # Total (fractional) amount treated so far; capacity is never freed.
        self.treated_so_far = 0.0
        # Running sum of (newly treated amount * step index), used for the
        # approximate average treatment time.
        self.cumulative_treatment_time = 0.0

        self.t = 0
        self.done = False

    def reset(self):
        """Restore the initial bucket counts, capacity and time; return the state."""
        self.bH = self._bH0.copy()
        self.bS = self._bS0.copy()
        self.treated_so_far = 0.0
        self.cumulative_treatment_time = 0.0
        self.t = 0
        self.done = False
        return self._get_state()

    def _costs(self):
        """Return (fp, fn, delay) costs, reading globals only when not overridden."""
        fp = FP if self.cost_fp is None else self.cost_fp
        fn = FN if self.cost_fn is None else self.cost_fn
        delay = D if self.cost_delay is None else self.cost_delay
        return fp, fn, delay

    def _get_state(self):
        """State vector (bH0..bH4, bS0..bS4, cap_rem, t), length 12."""
        cap_rem = max(0, self.capacity - self.treated_so_far)
        return np.concatenate([self.bH, self.bS, [cap_rem, self.t]])

    def step(self, action):
        """
        Advance one decision step.

        1) Treat up to ``action`` mass from the top buckets (4..0), capped by
           the remaining capacity. Within a bucket the treated amount is split
           healthy/sick in proportion to bH[i] and bS[i].
        2) Compute the immediate cost:
           - FP * (healthy mass treated);
           - D * (sick mass still untreated);
           - plus FN * (sick mass still untreated) at the final step.
        3) Apply expected Markov transitions to bH and bS unless the episode
           just ended.

        Returns (next_state, cost, done). Once done, further calls return the
        current state with cost 0.0.
        """
        if self.done:
            return self._get_state(), 0.0, True

        fp_cost, fn_cost, delay_cost = self._costs()

        # 1) Treat from the top bucket down, never beyond remaining capacity.
        cap_rem = max(0.0, self.capacity - self.treated_so_far)
        act_left = min(action, cap_rem)
        actually_treated = 0.0
        treated_healthy = 0.0

        for bucket_i in reversed(range(self.N_BUCKETS)):
            if act_left <= 0:
                break
            total_in_bucket = self.bH[bucket_i] + self.bS[bucket_i]
            if total_in_bucket <= 1e-9:
                continue

            can_treat = min(act_left, total_in_bucket)
            fracH = self.bH[bucket_i] / total_in_bucket
            fracS = self.bS[bucket_i] / total_in_bucket

            treated_H_i = can_treat * fracH
            self.bH[bucket_i] -= treated_H_i
            self.bS[bucket_i] -= can_treat * fracS

            treated_healthy += treated_H_i
            actually_treated += can_treat
            act_left -= can_treat

        self.treated_so_far += actually_treated
        self.cumulative_treatment_time += actually_treated * self.t

        # 2) Immediate cost.
        sum_sick_still_untreated = self.bS.sum()
        cost_fp = fp_cost * treated_healthy
        cost_delay = delay_cost * sum_sick_still_untreated
        cost_fn = 0.0
        done_next = False
        if self.t == (self.max_time - 1):
            done_next = True
            cost_fn = fn_cost * sum_sick_still_untreated
        immediate_cost = cost_fp + cost_delay + cost_fn

        # 3) Expected Markov transitions if the episode continues.
        self.t += 1
        if self.t >= self.max_time:
            done_next = True

        if not done_next:
            # Row-vector @ matrix == sum_i b[i] * P[i, :]. Buckets holding at
            # most 1e-9 mass (float residue from treatment) contribute nothing.
            mass_h = np.where(self.bH > 1e-9, self.bH, 0.0)
            mass_s = np.where(self.bS > 1e-9, self.bS, 0.0)
            self.bH = mass_h @ np.asarray(self.transition_mat_h, dtype=float)
            self.bS = mass_s @ np.asarray(self.transition_mat_s, dtype=float)

        self.done = done_next
        return self._get_state(), immediate_cost, done_next

    def get_avg_treatment_time(self):
        """
        Approximate average treatment step:
        cumulative_treatment_time / total treated. Returns -1 if (almost)
        nothing has been treated.
        """
        total_treated = self.treated_so_far
        if total_treated < 1e-9:
            return -1
        return self.cumulative_treatment_time / total_treated

###############################################################################
# (B) Example transition matrices for healthy vs. sick
# You can estimate these from data or set them manually
###############################################################################
# Rows = current bucket, columns = next bucket; every row sums to 1.
# Healthy mass mostly stays put, with some drift toward neighbouring buckets.
transition_mat_healthy = np.array([
    [0.70, 0.20, 0.10, 0.00, 0.00],  # bucket 0 -> ...
    [0.10, 0.70, 0.15, 0.05, 0.00],  # bucket 1 -> ...
    [0.05, 0.10, 0.65, 0.15, 0.05],  # bucket 2 -> ...
    [0.00, 0.05, 0.15, 0.60, 0.20],  # bucket 3 -> ...
    [0.00, 0.00, 0.05, 0.20, 0.75]   # bucket 4 -> ...
])
# Sick mass: from bucket 2 upward, no probability of moving more than one
# bucket down.
transition_mat_sick = np.array([
    [0.50, 0.30, 0.15, 0.05, 0.00],  # bucket 0 -> ...
    [0.05, 0.50, 0.25, 0.15, 0.05],  # bucket 1 -> ...
    [0.00, 0.10, 0.50, 0.30, 0.10],  # bucket 2 -> ...
    [0.00, 0.00, 0.10, 0.60, 0.30],  # bucket 3 -> ...
    [0.00, 0.00, 0.05, 0.25, 0.70]   # bucket 4 -> ...
])

###############################################################################
# (C) Q-Network
###############################################################################
class QNetwork(nn.Module):
    """
    Two-hidden-layer MLP mapping a state vector to one Q-value per action.

    Args:
        state_dim: length of the input state vector.
        action_dim: number of discrete actions (size of the output).
        hidden: width of both hidden layers.
    """

    def __init__(self, state_dim=12, action_dim=21, hidden=64):
        super().__init__()
        # Layers are registered in this exact order (fc1, fc2, fc3, relu) so
        # parameter initialisation and state_dict keys stay stable.
        self.fc1 = nn.Linear(state_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, action_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return Q-values of shape [batch_size, action_dim] for input ``x``."""
        features = x
        for hidden_layer in (self.fc1, self.fc2):
            features = self.relu(hidden_layer(features))
        # Final layer is linear (no activation): raw Q-value per action.
        return self.fc3(features)

###############################################################################
# (D) Q-Learning / ADP training
###############################################################################
def state_to_tensor(state):
    """
    Convert a single environment state into a float32 batch of size 1.

    The state is expected to be the flat vector
    (bH0..bH4, bS0..bS4, cap_rem, t) => length 10 + 1 + 1 = 12, given either
    as a sequence of numbers or as a numpy array.

    Returns:
        torch.FloatTensor of shape [1, len(state)]; the data is copied, so
        later in-place changes to ``state`` do not affect the tensor.
    """
    # Going through a single contiguous ndarray avoids torch's slow
    # "list of numpy.ndarrays" construction path (and its UserWarning) that
    # torch.tensor([state]) hits when ``state`` is already an ndarray.
    arr = np.asarray(state, dtype=np.float32)
    return torch.tensor(arr).unsqueeze(0)

def choose_action_epsilon_greedy(qnet, state, epsilon, max_action):
    """
    Pick an action index in ``0..max_action`` epsilon-greedily.

    With probability ``epsilon`` a uniformly random action is drawn;
    otherwise the action with the *lowest* Q-value is returned, because the
    network's Q-values are costs to be minimised.
    """
    explore = np.random.rand() < epsilon
    if explore:
        return np.random.randint(0, max_action + 1)

    with torch.no_grad():
        single_batch = torch.tensor([state], dtype=torch.float32)
        q_costs = qnet(single_batch).numpy().flatten()
    return np.argmin(q_costs)

def _dqn_update(qnet, optimizer, loss_fn, batch, gamma):
    """Run one gradient step of one-step (cost-minimising) Q-learning on a minibatch."""
    states, actions, costs, next_states, dones = zip(*batch)

    # Stack into single ndarrays first: one fast tensor conversion per field
    # instead of torch's slow list-of-ndarray path.
    states_t = torch.as_tensor(np.asarray(states, dtype=np.float32))
    next_states_t = torch.as_tensor(np.asarray(next_states, dtype=np.float32))
    actions_t = torch.as_tensor(np.asarray(actions, dtype=np.int64))
    costs_t = torch.as_tensor(np.asarray(costs, dtype=np.float32))
    dones_t = torch.as_tensor(np.asarray(dones, dtype=bool))

    with torch.no_grad():
        # Bootstrap with the *minimum* next-state Q-value (costs are
        # minimised); terminal transitions use the immediate cost only.
        next_min = qnet(next_states_t).min(dim=1).values
        targets_t = torch.where(dones_t, costs_t, costs_t + gamma * next_min)

    qvals_all = qnet(states_t)  # shape [batch_size, action_dim]
    qvals_chosen = qvals_all.gather(1, actions_t.unsqueeze(1)).squeeze(1)

    loss = loss_fn(qvals_chosen, targets_t)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


def train_adp_on_G3(df_g3, capacity, max_time=20,
                    trans_mat_h=None, trans_mat_s=None,
                    gamma=0.99,
                    episodes=2000,
                    learning_rate=1e-3,
                    epsilon_start=0.2,
                    epsilon_decay=0.999,
                    batch_size=64,
                    replay_size=20000):
    """
    Train a Q-network on the aggregated Markov environment built from G3.

    One-step Q-learning with experience replay on a *cost* signal: the greedy
    action is the arg-min over Q-values. NOTE(review): the same network is
    used for the bootstrap target (no separate target network).

    Args:
        df_g3: scored G3 data passed to ``AggregatedMarkovEnv``.
        capacity: treatment capacity; actions are the integers 0..capacity.
        max_time: episode horizon forwarded to the environment.
        trans_mat_h: bucket transition matrix for healthy patients.
        trans_mat_s: bucket transition matrix for sick patients.
        gamma: discount factor.
        episodes: number of training episodes.
        learning_rate: Adam learning rate.
        epsilon_start: initial exploration rate.
        epsilon_decay: multiplicative decay applied to epsilon after each
            episode; epsilon is floored at 0.01.
        batch_size: minibatch size sampled (without replacement) from replay.
        replay_size: maximum number of transitions kept in the replay buffer.

    Returns:
        The trained ``QNetwork``.
    """
    from collections import deque

    def make_env():
        # Keyword arguments, exactly like the other construction sites, so we
        # never depend on the positional parameter order of __init__.
        return AggregatedMarkovEnv(df_patients=df_g3,
                                   capacity=capacity,
                                   max_time=max_time,
                                   transition_mat_healthy=trans_mat_h,
                                   transition_mat_sick=trans_mat_s)

    # actions = 0..capacity
    max_action = capacity
    # state_dim=12 => bH[5], bS[5], cap_rem, t
    qnet = QNetwork(state_dim=12, action_dim=max_action + 1, hidden=64)
    optimizer = optim.Adam(qnet.parameters(), lr=learning_rate)
    loss_fn = nn.MSELoss()

    # Bounded FIFO: the oldest transition is evicted automatically (O(1)).
    replay_buffer = deque(maxlen=replay_size)
    epsilon = epsilon_start

    for _ in range(episodes):
        # Fresh environment per episode (replaces re-calling env.__init__).
        env = make_env()
        s = env.reset()
        done = False

        while not done:
            a = choose_action_epsilon_greedy(qnet, s, epsilon, max_action)
            s_next, cost, done = env.step(a)
            replay_buffer.append((s, a, cost, s_next, done))
            s = s_next

            if len(replay_buffer) >= batch_size:
                batch_indices = np.random.choice(len(replay_buffer), batch_size, replace=False)
                batch = [replay_buffer[idx] for idx in batch_indices]
                _dqn_update(qnet, optimizer, loss_fn, batch, gamma)

        # end of episode: epsilon decay with a floor
        epsilon = max(epsilon * epsilon_decay, 0.01)

    return qnet

print("[ADP] Training Q-Network on G3 (aggregated Markov) ...")
# Train on the scored G3 split; data/capacity passed positionally, the
# environment and optimisation hyperparameters spelled out by keyword.
qnet_final = train_adp_on_G3(
    G3_scored,
    N_c,
    max_time=T_max,
    gamma=gamma,
    trans_mat_h=transition_mat_healthy,
    trans_mat_s=transition_mat_sick,
    episodes=2000,
    batch_size=64,
    learning_rate=1e-3,
    epsilon_start=0.2,
    epsilon_decay=0.999,
)
print("[ADP] Done training Q-network.")

###############################################################################
# STEP 5: Evaluate final policy on G4
###############################################################################
# Score G4 with the final RF model, then map each risk score to its bucket.
G4_scored = get_risk_scores(G4, final_rf_model)
G4_scored["bucket"] = G4_scored["risk_score"].apply(bucket_risk)

def evaluate_policy_aggregated(df_input, qnet, capacity, max_time=20,
                               trans_mat_h=None, trans_mat_s=None):
    """
    Evaluate the learned policy on a new dataset (G4) in aggregated form.

    Rolls out the greedy (arg-min cost) policy of ``qnet`` for one episode of
    ``AggregatedMarkovEnv`` and derives approximate precision / recall from
    how much healthy vs. sick mass left the buckets over the episode.

    Args:
        df_input: scored dataset (with 'bucket' / 'label') for the environment.
        qnet: trained Q-network returning per-action costs.
        capacity: treatment capacity forwarded to the environment.
        max_time: episode horizon forwarded to the environment.
        trans_mat_h: bucket transition matrix for healthy patients.
        trans_mat_s: bucket transition matrix for sick patients.

    Returns:
        (cost, avg_treatment_time, recall, precision). ``avg_treatment_time``
        is ``env.get_avg_treatment_time()`` (-1 when nobody was treated);
        recall and precision are 0.0 when their denominator is ~0.
    """
    env_eval = AggregatedMarkovEnv(df_patients=df_input,
                                   capacity=capacity,
                                   max_time=max_time,
                                   transition_mat_healthy=trans_mat_h,
                                   transition_mat_sick=trans_mat_s)
    s = env_eval.reset()

    # Snapshot the starting healthy / sick mass from the environment itself
    # (as Python floats, before any step can touch bH / bS). Comparing env
    # start vs. env end keeps both totals in the env's own units, instead of
    # re-counting DataFrame rows, which need not match them (e.g. if the data
    # has several rows per patient). NOTE(review): assumes step() removes
    # mass from bH/bS only by treating patients; the bucket transitions
    # themselves preserve mass because the matrices are row-stochastic.
    init_healthy = float(np.sum(env_eval.bH))
    init_sick = float(np.sum(env_eval.bS))

    done = False
    total_cost = 0.0
    while not done:
        with torch.no_grad():
            qvals = qnet(state_to_tensor(s)).numpy().flatten()
        action = int(np.argmin(qvals))  # costs => greedy = arg-min
        s, cost, done = env_eval.step(action)
        total_cost += cost

    final_healthy_remain = float(np.sum(env_eval.bH))
    final_sick_remain = float(np.sum(env_eval.bS))

    # Clamp at 0 so float round-off in the Markov updates cannot produce a
    # negative "treated" amount.
    treated_healthy = max(0.0, init_healthy - final_healthy_remain)
    treated_sick = max(0.0, init_sick - final_sick_remain)

    # precision = TP / (TP + FP), TP = treated sick, FP = treated healthy
    flagged = treated_sick + treated_healthy
    precision = treated_sick / flagged if flagged > 1e-9 else 0.0
    # recall = TP / (TP + FN), FN = sick mass still untreated at the end
    actual_sick = treated_sick + final_sick_remain
    recall = treated_sick / actual_sick if actual_sick > 1e-9 else 0.0

    avg_treat_time = env_eval.get_avg_treatment_time()

    return total_cost, avg_treat_time, recall, precision

final_cost_G4, avg_tt_G4, recall_G4, precision_G4 = evaluate_policy_aggregated(
    df_input=G4_scored,
    qnet=qnet_final,
    capacity=N_c,
    max_time=T_max,
    trans_mat_h=transition_mat_healthy,
    trans_mat_s=transition_mat_sick
)

print("===== EVALUATION on G4 (aggregated Markov) =====")
print(f"Cost: {final_cost_G4:.2f}")
# get_avg_treatment_time() signals "nobody treated" with the sentinel -1;
# report that explicitly instead of printing a misleading "-1.00" time.
if avg_tt_G4 < 0:
    print("Avg. Treatment Time: N/A (no patients treated)")
else:
    print(f"Avg. Treatment Time: {avg_tt_G4:.2f}")
print(f"Recall: {recall_G4:.4f}")
print(f"Precision: {precision_G4:.4f}")

# Optionally, random forest AUC on G4
X4, y4 = prepare_data_for_ml(G4)
proba4 = final_rf_model.predict_proba(X4)[:, 1]
auc_g4 = roc_auc_score(y4, proba4)
print(f"RandomForest AUC on G4: {auc_g4:.4f}")