In [None]:
import json
import math
import os
import re
import warnings

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn

from tqdm import tqdm

from rl import make_df, make_transition_test, Model

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
# CUDA_LAUNCH_BLOCKING only takes effect if set before CUDA is initialised; it makes
# kernel errors surface at the failing call instead of asynchronously.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Environment report.
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.version.cuda)
# get_device_name(0) raises on machines without a CUDA device, so only query it when one exists.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
print(device)

torch.set_printoptions(precision=4, sci_mode=False)
torch.backends.cudnn.enabled = False
# NOTE(review): with cuDNN disabled above, this cuDNN TF32 flag presumably has no effect — confirm intent.
torch.backends.cudnn.allow_tf32 = True

In [None]:
# ID tables for each split (presumably trajectory/patient IDs consumed by make_df);
# the first CSV column is used as the index.
train_id = pd.read_csv('processed/train_id.csv', index_col=0)
valid_id = pd.read_csv('processed/val_id.csv', index_col=0)
test_id = pd.read_csv('processed/test_id.csv', index_col=0)
temporal_id = pd.read_csv('processed/temporal_id.csv', index_col=0)

Data

In [None]:
def cal_thres(true_label, prob):
    """Return the ROC threshold that maximises Youden's J statistic (TPR - FPR).

    Parameters
    ----------
    true_label : array-like
        Observed labels; the largest distinct value is treated as the positive class.
    prob : array-like
        Scores for each sample (higher means more positive).

    Returns
    -------
    The score threshold at the ROC point with the largest TPR - FPR.
    NOTE(review): scikit-learn's first ROC threshold is a sentinel above every score
    (np.inf in recent versions); it is returned if J is maximal there (degenerate scores).
    """
    positive_class = max(np.unique(true_label))
    fpr, tpr, candidate_thresholds = roc_curve(true_label, prob, pos_label=positive_class)
    youden_j = tpr - fpr
    return candidate_thresholds[np.argmax(youden_j)]

In [None]:
def _natural_sort_key(s):
    """Case-insensitive sort key that orders embedded integers numerically ('trial2' < 'trial10')."""
    return [int(text) if text.isdigit() else text.lower() for text in re.split(r'(\d+)', s)]


def _select_best_checkpoint(path):
    """Pick the (trial, epoch) with the highest validation AUROC_p(gat) + AUROC_p(med) under ``path``.

    Each sub-directory whose name contains 'trial' must hold ``valid_auroc_p_gat.pth`` and
    ``valid_auroc_p_med.pth`` (presumably one value per epoch). Returned epochs are 1-based,
    matching the ``{name}_{epoch}.pth`` checkpoint naming used by ``test``.

    Raises
    ------
    FileNotFoundError
        If ``path`` contains no trial directories.
    """
    trials = sorted((f for f in os.listdir(path) if 'trial' in f.lower()), key=_natural_sort_key)
    if not trials:
        raise FileNotFoundError(f"No trial directories found under {path}")

    best_trial, best_epoch, best_value = None, None, float('-inf')
    for trial in trials:
        valid_auroc_p_gat = torch.load(f"{path}/{trial}/valid_auroc_p_gat.pth")
        valid_auroc_p_med = torch.load(f"{path}/{trial}/valid_auroc_p_med.pth")
        auroc_values = [x + y for x, y in zip(valid_auroc_p_gat, valid_auroc_p_med)]
        # Tuple comparison: ties on the summed AUROC resolve to the later epoch.
        max_value, max_epoch = max((v, i + 1) for i, v in enumerate(auroc_values))
        # Strict '>' keeps the earliest trial (in natural order) on cross-trial ties.
        if max_value > best_value:
            best_value, best_epoch, best_trial = max_value, max_epoch, trial
    return best_trial, best_epoch


def _bootstrap_auroc_ci(y_true, y_pred, n_bootstraps=1000, ci=95, seed=42):
    """Percentile bootstrap confidence interval for ROC AUC, rounded to 3 decimals.

    Resamples that contain a single class are redrawn, because AUROC is undefined for them.
    NOTE(review): this loops forever if ``y_true`` itself has only one class; callers
    compute a plain AUROC first, which fails earlier in that case.
    """
    rng = np.random.RandomState(seed)
    scores = []
    for _ in range(n_bootstraps):
        while True:
            indices = rng.randint(0, len(y_true), len(y_true))
            if len(np.unique(y_true[indices])) >= 2:
                break
        scores.append(roc_auc_score(y_true[indices], y_pred[indices]))
    tail = (100 - ci) / 2
    lower = round(np.percentile(scores, tail), 3)
    upper = round(np.percentile(scores, 100 - tail), 3)
    return lower, upper


def test(target, algorithm, version, test_transition):
    """Load the best checkpoint for one experiment and evaluate its Q-values on ``test_transition``.

    The best (trial, epoch) under ``experiments/{target.lower()}/{algorithm}/{version}`` is
    the one maximising validation AUROC_p(gat) + AUROC_p(med). Rewards, patient outcomes and
    Q-values are clamped to the version's range: '_negative' -> [-1, 0],
    '_positive' -> [0, 1], anything else -> [-1, 1].

    Parameters
    ----------
    target, algorithm, version : str
        Locate the experiment directory (see above).
    test_transition : DataLoader-like
        Iterable of ``(s, a, r, rp)`` batches whose ``.dataset.tensors[0]`` are observations
        and ``.dataset.tensors[1]`` actions (which appear to be 1-based; see ``- 1`` below).

    Returns
    -------
    tuple
        ``(network, target_network, q_values, rewards, patients, actions,
        q_max, q_min, q_med, q_gat,
        auroc_p_max, auroc_p_min, auroc_p_med, auroc_p_gat,
        [auroc_max, auroc_min, auroc_med, auroc_gat,]   # omitted when version == '_both'
        thresholds, auroc_ci, params)``.
        Summaries and AUROCs follow the order max, min, med, gat; callers must unpack in that order.
    """
    path = f"experiments/{target.lower()}/{algorithm}/{version}"
    best_trial, best_epoch = _select_best_checkpoint(path)

    with open(f"{path}/{best_trial}/params.json") as fh:
        params = json.load(fh)

    valid_keys = {'mlp_size', 'mlp_num_layers', 'activation_type'}
    filtered_params = {k: v for k, v in params.items() if k in valid_keys}

    obs_dim = test_transition.dataset.tensors[0].shape[1]
    # Cast to a plain int so Model receives a size, not a 0-d tensor.
    # NOTE(review): this assumes the highest action id occurs in this split; if the model was
    # trained with a fixed action count, a split missing that action would mis-size the head.
    nb_actions = int(max(test_transition.dataset.tensors[1]))
    network = Model(obs_dim=obs_dim, nb_actions=nb_actions, **filtered_params).to(device)
    target_network = Model(obs_dim=obs_dim, nb_actions=nb_actions, **filtered_params).to(device)

    for model, name in ((network, "network"), (target_network, "target_network")):
        model_path = f"{path}/{best_trial}/{name}_{best_epoch}.pth"
        model.load_state_dict(torch.load(model_path, map_location=device))
        # Inference mode: makes any dropout/batch-norm layers deterministic during evaluation.
        model.eval()

    clamp_lo, clamp_hi = {'_negative': (-1.0, 0.0), '_positive': (0.0, 1.0)}.get(version, (-1.0, 1.0))

    q_values, rewards, patients, actions = [], [], [], []
    with torch.no_grad():
        for s, a, r, rp in test_transition:
            s = s.to(device)
            a = a.to(device).long() - 1  # shift stored (1-based) actions to 0-based column indices
            q = network(s)
            q_values.append(torch.clamp(q, min=clamp_lo, max=clamp_hi).cpu().numpy())
            rewards.append(torch.clamp(r, min=clamp_lo, max=clamp_hi).cpu().numpy())
            patients.append(torch.clamp(rp, min=clamp_lo, max=clamp_hi).cpu().numpy())
            actions.append(a.cpu().numpy())

    q_values, rewards, patients, actions = map(np.concatenate, [q_values, rewards, patients, actions])

    # Insertion order (max, min, med, gat) defines the order of every derived dict and the return tuple.
    q_metrics = {
        "max": np.max(q_values, axis=1),
        "min": np.min(q_values, axis=1),
        "med": np.median(q_values, axis=1),
    }
    q_metrics["gat"] = q_values[np.arange(q_values.shape[0]), actions]  # Q of the action actually taken

    # Reward-based metrics are skipped for '_both' (presumably its clamped rewards are not binary).
    has_reward_auroc = version != '_both'

    aurocs = {f"auroc_p_{k}": round(roc_auc_score(patients, v), 3) for k, v in q_metrics.items()}
    if has_reward_auroc:
        aurocs.update({f"auroc_{k}": round(roc_auc_score(rewards, v), 3) for k, v in q_metrics.items()})

    thresholds = {f"threshold_p_{k}": cal_thres(patients, v) for k, v in q_metrics.items()}
    if has_reward_auroc:
        thresholds.update({f"threshold_{k}": cal_thres(rewards, v) for k, v in q_metrics.items()})

    auroc_ci = {}
    for k, v in q_metrics.items():
        auroc_ci[f"auroc_p_{k}_ci"] = _bootstrap_auroc_ci(patients, v)
        if has_reward_auroc:
            auroc_ci[f"auroc_{k}_ci"] = _bootstrap_auroc_ci(rewards, v)

    return (network, target_network, q_values, rewards, patients, actions,
            *q_metrics.values(), *aurocs.values(), thresholds, auroc_ci, params)

In [None]:
def process_baseline(train_path, test_df):
    """Return a deep copy of ``test_df`` with every listed state feature overwritten by a fixed reference value.

    For each source column ``col`` the target column is ``f's:{col}'``:
      * 'norm' columns -> standardised raw value 0, i.e. ``-(mean / std)``;
      * 'log' columns  -> ``(log(0.1) - mean) / std``;
      * 'bin' columns  -> ``-0.5``.
    ``mean``/``std`` come from a StandardScaler fit on that column of the CSV at ``train_path``.

    NOTE(review): the scaler is fit on the column exactly as stored in ``train_path``; if those
    'log' columns are un-logged raw values, mixing log(0.1) with raw-scale mean/std is
    inconsistent — confirm what train_df_RAW.csv contains.
    """
    train_RAW = pd.read_csv(train_path)
    test_df_baseline = test_df.copy(deep=True)
    scaler = StandardScaler()

    norm_columns = ['age', 'Weight', 'GCS', 'Heartrate', 'Systolic_BP', 'Diastolic_BP', 'Mean_BP', 'Resprate', 'Temperature', 'FiO2',
                    'Potassium', 'Sodium', 'Chloride', 'Glucose', 'Magnesium', 'Calcium', 'Hemoglobin', 'WBC', 'Platelets', 'PTT', 'PT',
                    'Arterial_ph', 'PaO2', 'PaCO2', 'BaseExcess', 'Bicarbonate', 'Lactate', 'SOFA', 'SIRS', 'Shock_Index', 'PaO2/FiO2',
                    'Cumulated_balance', 'elixhauser']
    log_columns = ['SpO2', 'BUN', 'SCr', 'SGOT', 'SGPT', 'Total_Bilirubin', 'INR', 'output_total', 'output_4hr']
    binary_columns = ['gender', 're_admission', 'MV'] + ['action_{}_prev'.format(i) for i in range(1, 10)]

    def train_mean_std(column):
        # StandardScaler exposes the training mean and (population) scale for a single column.
        scaler.fit(train_RAW[[column]])
        return scaler.mean_[0], scaler.scale_[0]

    for column in norm_columns:
        mu, sigma = train_mean_std(column)
        test_df_baseline[f's:{column}'] = -(mu / sigma)

    for column in log_columns:
        mu, sigma = train_mean_std(column)
        test_df_baseline[f's:{column}'] = (np.log(0.1) - mu) / sigma

    for column in binary_columns:
        test_df_baseline[f's:{column}'] = -0.5

    return test_df_baseline

In [None]:
def process_transitions(target, algorithm, version, transitions, data_dict):
    """Evaluate one (target, algorithm, version) on each transition set and store the results.

    Parameters
    ----------
    target, algorithm, version : str
        Passed to ``test`` to locate the experiment, and used as keys into ``data_dict``.
    transitions : dict
        Maps a split name (e.g. 'test', 'temporal') to the transitions passed to ``test``.
    data_dict : dict
        Nested results dict; ``data_dict[target][algorithm][version]`` must already exist.

    Returns
    -------
    dict
        ``data_dict`` itself, updated in place with one entry per split.
    """
    has_reward_auroc = version != "_both"

    for transition_type, transition in transitions.items():
        results = test(target, algorithm, version, transition)

        # test() returns the Q summaries and their AUROCs in q_metrics insertion order
        # (max, min, med, gat). The names below follow that order exactly; unpacking in any
        # other order silently files each AUROC under the wrong key.
        if has_reward_auroc:
            (network, target_network, q_value, r_arr, patient, action_space,
             q_max, q_min, q_median, q_gather,
             test_auroc_p_max, test_auroc_p_min, test_auroc_p_med, test_auroc_p_gat,
             test_auroc_max, test_auroc_min, test_auroc_med, test_auroc_gat,
             thresholds, auroc_ci, params) = results
        else:
            (network, target_network, q_value, r_arr, patient, action_space,
             q_max, q_min, q_median, q_gather,
             test_auroc_p_max, test_auroc_p_min, test_auroc_p_med, test_auroc_p_gat,
             thresholds, auroc_ci, params) = results

        entry = {
            "q_value": q_value,
            "reward": r_arr,
            "patient": patient,
            "action_space": action_space,
            "q_max": q_max,
            "q_min": q_min,
            "q_median": q_median,
            "q_gather": q_gather,
            "auroc_p_gat": test_auroc_p_gat,
            "auroc_p_med": test_auroc_p_med,
            "auroc_p_min": test_auroc_p_min,
            "auroc_p_max": test_auroc_p_max,
            "thresholds": thresholds,
            "auroc_ci": auroc_ci,
            'network': network,
            'target_network': target_network,
            'params': params
        }
        if has_reward_auroc:
            entry.update({
                "auroc_gat": test_auroc_gat,
                "auroc_med": test_auroc_med,
                "auroc_min": test_auroc_min,
                "auroc_max": test_auroc_max
            })
        data_dict[target][algorithm][version][transition_type] = entry

        print(f"Target: {target}, algorithm: {algorithm}, version: {version}, transition: {transition_type}")

        summaries = ["gat", "med", "min", "max"]
        auroc_keys = [f"auroc_p_{k}" for k in summaries]
        auroc_ci_keys = [f"auroc_p_{k}_ci" for k in summaries]
        if has_reward_auroc:
            auroc_keys += [f"auroc_{k}" for k in summaries]
            auroc_ci_keys += [f"auroc_{k}_ci" for k in summaries]

        for key in auroc_keys:
            print(f"{key}: {entry[key]}")
        for key in auroc_ci_keys:
            ci_value = entry["auroc_ci"].get(key, 'N/A')
            print(f"{key}: {ci_value}")

        print("-" * 50)

    return data_dict


data_dict = {}

for target in tqdm(['Dead_icu', 'Dead_hosp', 'Dead_90', 'AKI_rrt', 'AKI_48', 'AKI_24', 'AKI_12', 'Septic_shock']):

    data = pd.read_csv(f'processed/df_{target}.csv')

    if 'Septic' in target:
        reward = 'r:reward_septic_shock'
    elif 'Dead' in target:
        reward = 'r:reward_dead'
    else:
        reward = 'r:reward_aki'

    # Within each trajectory, a terminal reward of 0 is replaced by 1; other rewards are unchanged.
    data[reward] = data.groupby('traj')[reward].transform(lambda x: x[:-1].tolist() + ([1] if x.iloc[-1] == 0 else [x.iloc[-1]]))

    train_df, valid_df, test_df = make_df(data, reward, train_id, valid_id, test_id)
    # make_df is re-run with the temporal IDs only to obtain the temporal split.
    # NOTE(review): train/valid are overwritten by this call; presumably identical to the first — confirm.
    train_df, valid_df, temporal_df = make_df(data, reward, train_id, valid_id, temporal_id)
    test_transition = make_transition_test(test_df, reward, rolling_size=1)
    temporal_transition = make_transition_test(temporal_df, reward, rolling_size=1)

    test_df_baseline = process_baseline('processed/train_df_RAW.csv', test_df)
    test_df_baseline_transition = make_transition_test(test_df_baseline, reward, rolling_size=1)

    data_dict[target] = {
        "train": train_df,
        "valid": valid_df,
        "test": test_df,
        "temporal": temporal_df,
        "test_transition": test_transition,
        "temporal_transition": temporal_transition,
        "test_df_baseline_transition": test_df_baseline_transition
    }

    # The evaluated splits are the same for every algorithm/version of this target.
    transitions = {
        'test': data_dict[target]['test_transition'],
        'temporal': data_dict[target]['temporal_transition']
    }

    for algorithm in ['_ddqn', '_cql', '_iql', '_bcq']:
        data_dict[target][algorithm] = {}
        for version in ['_negative', '_positive', '_both']:
            data_dict[target][algorithm][version] = {}
            data_dict = process_transitions(target, algorithm, version, transitions, data_dict)

In [None]:
# Persist all evaluation results (DataFrames, transitions, loaded networks) for reuse in later sessions.
torch.save(obj=data_dict, f='processed/data_dict.pth')

In [None]:
# weights_only=False is needed because the dict holds arbitrary pickled objects (DataFrames, nn.Modules);
# it executes pickle code on load, so only load files this notebook wrote itself.
data_dict = torch.load(f='processed/data_dict.pth', weights_only=False)

Distribution

In [None]:
POS_BAR_COLOR = 'blue'
POS_MEAN_COLOR = 'navy'
POS_STD_COLOR = 'dodgerblue'

NEG_BAR_COLOR = 'red'
NEG_MEAN_COLOR = 'darkred'
NEG_STD_COLOR = 'salmon'

# Panel order of the 3x3 grid; also used as column names for the stored Q-value matrix,
# so it must match the network's action-column order.
action = [
    'No-Medication', 'Cephalosporin', 'Glycopeptide',
    'Beta-lactam', 'Carbapenem', 'Penicillin',
    'Minor Antibiotic', 'Selective Antimicrobial', 'Combination'
]


def _draw_q_histogram(ax, values, bar_color, mean_color, std_color, label, annotate):
    """Draw a 50-bin relative-frequency histogram of ``values`` plus its mean line and ±1σ band.

    ``annotate`` gives the mean/σ artists legend labels (only the first panel does, so the
    shared figure legend lists each once). Returns the per-bin ratios (sum to 1).
    """
    counts, bins = np.histogram(values, bins=50)
    bin_width = bins[1] - bins[0]
    midpoints = 0.5 * (bins[1:] + bins[:-1])
    ratio = counts / counts.sum()
    ax.bar(midpoints, ratio, width=bin_width, color=bar_color, alpha=0.3, label=f"{label}")
    mean, std = values.mean(), values.std()
    ax.fill_betweenx(y=[0, 1], x1=mean - std, x2=mean + std, color=std_color, alpha=0.3,
                     label=f"{label} ±1σ" if annotate else None)
    ax.axvline(mean, color=mean_color, linestyle='--', label=f"{label} Mean" if annotate else None)
    return ratio


for target in tqdm(['Dead_icu', 'AKI_rrt', 'Septic_shock']):
    # NOTE(review): `data` is not used below in this cell; the read is kept so the global is unchanged.
    data = pd.read_csv(f'processed/df_{target}.csv', index_col=0)

    if 'Septic' in target:
        reward = 'r:reward_septic_shock'
    elif 'Dead' in target:
        reward = 'r:reward_dead'
    else:
        reward = 'r:reward_aki'

    left_label, right_label = 'Positive', 'Negative'

    # Everything below up to the algorithm loop depends only on the target, so compute it once.
    test_frame = data_dict[target]['test']
    unique_conditions = np.unique(test_frame[reward])
    positive_outcome, negative_outcome = max(unique_conditions), min(unique_conditions)

    # A trajectory is positive/negative by the reward on its final row.
    positive_traj = (
        test_frame
        .groupby('traj')
        .filter(lambda x: x.iloc[-1][reward] in [positive_outcome])['traj']
        .drop_duplicates()
        .reset_index(drop=True)
    )
    negative_traj = (
        test_frame
        .groupby('traj')
        .filter(lambda x: x.iloc[-1][reward] in [negative_outcome])['traj']
        .drop_duplicates()
        .reset_index(drop=True)
    )
    # Drop each trajectory's terminal row so rows can be aligned positionally with the Q-values.
    # NOTE(review): assumes make_transition_test emits transitions in this same row order — confirm.
    data_ = (
        test_frame
        .groupby('traj', group_keys=False)
        .apply(lambda x: x.iloc[:-1])
        .reset_index(drop=True)
    )

    for algorithm in ['_iql']:
        for version in ['_negative', '_positive', '_both']:
            version_results = data_dict[target][algorithm][version]['test']
            q_value = version_results['q_value']
            unique_conditions_range = np.unique(version_results['reward'])

            q_value_df = pd.DataFrame(q_value, columns=action)
            combined_df = pd.concat([data_, q_value_df], axis=1)

            negative_df = pd.merge(combined_df, negative_traj, on='traj', how='inner')
            positive_df = pd.merge(combined_df, positive_traj, on='traj', how='inner')

            name = {'_positive': 'R-network', '_negative': 'D-network', '_both': 'C-network'}[version]

            # x-range: observed clamped-reward range for single-sign networks, fixed [-1, 1] for '_both'.
            # (Versions carry a leading underscore; comparing against 'both' never matched.)
            if version != '_both':
                min_x = min(unique_conditions_range)
                max_x = max(unique_conditions_range)
            else:
                min_x = -1.0
                max_x = +1.0
            start_tick = math.floor(min_x * 5) / 5.0
            end_tick = math.ceil(max_x * 5) / 5.0
            x_ticks = np.arange(start_tick, end_tick + 0.001, 0.2)
            y_ticks = np.arange(0.05, 0.31, 0.05)

            fig, axes = plt.subplots(3, 3, figsize=(16, 12))
            fig.suptitle(
                f"{name}\nBlue: {left_label} | Red: {right_label}",
                fontsize=20, fontweight='bold'
            )

            for i, c in enumerate(action):
                ax = axes[i // 3, i % 3]

                neg_data_col = negative_df[c].dropna()
                pos_data_col = positive_df[c].dropna()

                if len(neg_data_col) > 0:
                    version_results[f'ratio_neg_{c}'] = _draw_q_histogram(
                        ax, neg_data_col, NEG_BAR_COLOR, NEG_MEAN_COLOR, NEG_STD_COLOR, right_label, i == 0)
                if len(pos_data_col) > 0:
                    version_results[f'ratio_pos_{c}'] = _draw_q_histogram(
                        ax, pos_data_col, POS_BAR_COLOR, POS_MEAN_COLOR, POS_STD_COLOR, left_label, i == 0)

                ax.set_title(c, fontsize=18)

                ax.set_xticks(x_ticks)
                ax.set_xticklabels([f"{tick:.1f}" for tick in x_ticks], fontsize=12)
                ax.set_xlim(min_x, max_x)

                ax.set_ylim(0, 0.31)
                ax.set_yticks(y_ticks)
                ax.set_yticklabels(y_ticks, fontsize=12)
                ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda val, _: f"{val*100:.1f}%"))
                ax.grid(False)

            if version == '_positive':
                fig.text(0.03, 0.5, "Relative proportion", va='center', rotation='vertical', fontsize=18)

            handles, labels = axes[0, 0].get_legend_handles_labels()
            legend_dict = dict(zip(labels, handles))  # de-duplicate legend entries by label

            if target == 'Septic_shock':
                fig.text(0.5, 0.04, "Q-value", ha='center', fontsize=18)
                fig.legend(legend_dict.values(), legend_dict.keys(), loc='lower center', bbox_to_anchor=(0.5, 0.00),
                           ncol=6, frameon=True, fontsize=15)

            plt.tight_layout(rect=[0.05, 0.06, 0.95, 0.98])
            plt.savefig(f'figure/{target}{algorithm}{version}.png', dpi=300)
            plt.show()
            plt.close(fig)  # free memory; this loop creates many 300-dpi figures


First Flag

In [None]:
def first_flag_plot(
    positive_df, negative_df,
    positive_df_RAW, negative_df_RAW,
    left_label, right_label,
    algorithm, threshold,
    data_dict, target
):
    """Plot mean ± std of selected variables against time steps relative to the first flag.

    One panel per variable is drawn. Treatments and selected state columns come first, then the
    four V/Q summaries. The last grid cell holds the legend. Step x is labelled ``4*x`` hours.

    Parameters
    ----------
    positive_df, negative_df : pandas.DataFrame
        Per-step frames with 'first_flag', 'survivor', state columns and q_* columns. Rows
        used: positive with survivor == 0, negative with survivor == -1. They are NOT modified.
    positive_df_RAW, negative_df_RAW : pandas.DataFrame
        Supply the values plotted for the treatment columns (overriding those in the frames above).
    left_label, right_label : str
        Legend names of the two groups; ``right_label`` also names the saved file.
    algorithm : str
        Suffix of the saved file name.
    threshold :
        Unused here; kept for call-site compatibility.
    data_dict : dict
        ``data_dict[target]['test']`` supplies the candidate 's:' state columns.
    target : str
        Outcome key; the x-axis caption is drawn only for 'Septic_shock'.
    """
    t = [-6, -5, -4, -3, -2, -1, 0, +1, +2, +3, +4]
    unlabeled_steps = {-5, -4, -2, 0, 2, 4, 5}  # ticks drawn without a text label to reduce clutter

    predefined_vars = [
        'q_positive_median',
        'q_negative_median',
        'q_positive_gather',
        'q_negative_gather'
    ]

    s_col = [x for x in data_dict[target]['test'] if x.startswith('s:')]
    exclude_vars = ['prev', 'age', 'Weight', 're_admission', 'elixhauser', 'gender', 'MV', 'Glucose', 'FiO2', 'SGPT', 'SGOT', 'output_total', 'output_4hr', 'Cumulated_balance', 'Arterial_ph', 'WBC']
    include_vars = ['Lactate', 'BUN', 'Diastolic_BP', 'Heartrate', 'INR', 'Resprate', 'SpO2', 'Systolic_BP', 'Temperature', 'GCS', 'SOFA', 'SIRS']
    # Substring matching: drop any column containing an excluded token, then keep only included ones.
    s_col = [x for x in s_col if not any(ex in x for ex in exclude_vars)]
    s_col = [x for x in s_col if any(ex in x for ex in include_vars)]

    treatments = ['input_4hr', 'max_vaso', 'AKI_max_stage']
    all_vars = predefined_vars + treatments + s_col

    def map_variable_name(var_name):
        """Return (display title, original column name); q summaries become V/Q with _R/_D subscripts."""
        if var_name.startswith('s:'):
            title = var_name[2:].replace('_', ' ')
        else:
            title = var_name.replace('_', ' ')
        title = re.sub(r'\bSpo2\b', 'SpO$_2$', title, flags=re.IGNORECASE)
        title = title[0].upper() + title[1:]
        if 'median' in var_name and 'vaso' not in var_name:
            base = 'V'
        elif 'gather' in var_name:
            base = 'Q'
        else:
            base = title
        if 'positive' in var_name:
            suffix = '_R'
        elif 'negative' in var_name:
            suffix = '_D'
        else:
            suffix = ''
        if suffix:
            new_label = f"{base}$_{{{suffix[1]}}}$"
        else:
            new_label = base
        return new_label, var_name

    mapped_vars = [map_variable_name(var) for var in all_vars]

    special_mapped_vars = [mv for mv in mapped_vars if mv[1] in predefined_vars]
    main_mapped_vars = [mv for mv in mapped_vars if mv[1] not in predefined_vars]

    num_main = len(main_mapped_vars)
    num_special = len(special_mapped_vars)

    plots_per_row = 5
    rows_main = math.ceil(num_main / plots_per_row) if num_main > 0 else 0
    total_rows = rows_main + (1 if num_special > 0 else 0)

    fig_width = 4 * plots_per_row * 0.5
    fig_height = 3 * total_rows * 0.5
    fig, axes = plt.subplots(total_rows, plots_per_row, figsize=(fig_width, fig_height), sharex=True, sharey=False)
    axes = np.asarray(axes).reshape(-1)

    # The grid has exactly plots_per_row * total_rows cells; the last one is reserved for the legend.
    legend_ax = axes[-1]

    name = [left_label, right_label]
    colors = ['blue', 'orange']

    # Work on copies: treatment columns are replaced by their RAW values once, without mutating
    # the caller's DataFrames (previously this assignment was repeated for every panel, in place).
    df = []
    for frame, frame_raw in ((positive_df, positive_df_RAW), (negative_df, negative_df_RAW)):
        frame = frame.copy()
        frame[treatments] = frame_raw[treatments]
        df.append(frame)

    ytick_values = {
        'BaseExcess': [-5, 0, 5], 'Lactate': [0, 5], 'Arterial_ph': [7.3, 7.4],
        'BUN': [25, 50, 75], 'Calcium': [8, 9], 'Chloride': [100, 110], 'SCr': [0, 2.5],
        'Diastolic BP': [50, 75], 'FiO2': [0.5, 0.75], 'GCS': [5, 10, 15], 'Glucose': [100, 200],
        'Bicarbonate': [20, 30], 'Heartrate': [75, 100], 'Hemoglobin': [10.0, 12.5], 'INR': [1, 3],
        'Magnesium': [2, 2.5], 'Mean_BP': [60, 80], 'PT': [10, 30], 'PTT': [25, 75],
        'PaO2/FiO2': [0, 500], 'Platelets': [100, 200, 300], 'Potassium': [3.5, 4.0, 4.5],
        'Resprate': [20, 30], 'SGOT': [0, 1000], 'SGPT': [0, 1000], 'SIRS': [0, 3],
        'SOFA': [5, 10, 15], 'Shock_Index': [0.75, 1.00, 1.25], 'Sodium': [135, 140, 145],
        'SpO$_2$': [90, 100], 'Systolic BP': [100, 125], 'Temperature': [36, 38], 'Total_Bilirubin': [0, 10],
        'WBC': [0, 20], 'Cumulated_balance': [0, 25000], 'MV': [0, 1],
        'output_4hr': [0, 500], 'output_total': [0, 25000], 'PaCO2': [30, 40, 50], 'PaO2': [100, 200],
        'Input 4hr': [0, 1000], 'input_total': [0, 30000], 'median_vaso': [0, 2], 'Max vaso': [0, 5], 'AKI max stage': [0, 3],
        'V$_{D}$': [-0.25, 0.00], 'V$_{R}$': [0.75, 1.00], 'Q$_{D}$': [-0.25, 0.00], 'Q$_{R}$': [0.75, 1.00]
    }

    def get_yticks(variable):
        """Fixed y ticks for a display title; empty (no ticks) when the title is not listed."""
        return ytick_values.get(variable, [])

    def plot_variable(ax, var_title, var_name):
        """Draw the per-step mean ± std of ``var_name`` for both groups on ``ax``."""
        for j in range(2):
            condition = (df[j]['survivor'] == -j) & df[j]['first_flag'].isin(t)
            grouped = df[j][condition].groupby('first_flag')[var_name]
            mean = grouped.mean().reindex(t).values
            std = grouped.std().reindex(t).values
            ax.plot(t, mean, label=name[j], color=colors[j], linewidth=2)
            ax.fill_between(t, mean - std, mean + std, alpha=0.3, color=colors[j])
        ax.set_xlim(min(t), max(t))
        ax.set_xticks(t)
        ax.set_xticklabels([f'{4 * x}h' if x not in unlabeled_steps else '' for x in t], fontsize=8)
        ax.axvline(0, color='gray', linestyle='--')
        ax.set_title(var_title, fontsize=8)
        ax.set_ylabel('')
        ax.set_yticks(get_yticks(var_title))
        ax.set_yticklabels(ax.get_yticks(), fontsize=8)

    ax_idx = 0
    last_plotted_ax = None
    for var_title, var_name in main_mapped_vars + special_mapped_vars:
        if ax_idx >= len(axes) - 1:
            break
        last_plotted_ax = axes[ax_idx]
        plot_variable(last_plotted_ax, var_title, var_name)
        ax_idx += 1

    if last_plotted_ax is not None:
        handles, labels = last_plotted_ax.get_legend_handles_labels()
    else:
        handles, labels = [], []
    legend_ax.axis('off')
    legend_ax.legend(handles, labels, loc='center', fontsize=12, frameon=False)

    for idx in range(ax_idx, len(axes) - 1):
        axes[idx].axis('off')

    if target == 'Septic_shock':
        fig.text(0.5, 0.04, 'Time from First Flag', ha='center', fontsize=10)

    # Lay out BEFORE saving so the written file matches the displayed figure.
    plt.tight_layout(rect=[0.02, 0.02, 0.98, 0.98])
    plt.savefig(f'figure/{right_label}{algorithm}.png', dpi=300)
    plt.show()
    plt.close(fig)

Continuous Flag

In [None]:
def flag_plot(flag_dict, q_values, left_label, right_label, algorithm, threshold, target=None):
    """Stacked-bar chart of model-vs-clinician flag agreement per time step, one panel per group.

    Parameters
    ----------
    flag_dict : dict
        ``'step'`` -> sequence of x-axis labels. ``'positive'``/``'negative'`` -> {condition: ratio per step}.
    q_values : dict
        ``'positive'``/``'negative'`` -> {condition: mean Q per step}; NaN marks steps with no rows.
    left_label, right_label : str
        Panel titles; ``right_label`` also names the saved file.
    algorithm, threshold :
        Used only in the saved file name.
    target : str, optional
        Outcome key; x labels and legends are drawn only for 'Septic_shock'. When omitted, the
        notebook-global ``target`` is used (the original behaviour, which relied on hidden state).
    """
    if target is None:
        target = globals().get('target')
    show_axis_text = target == 'Septic_shock'

    fig, ax = plt.subplots(1, 2, figsize=(21, 9), dpi=300, constrained_layout=True)
    step = flag_dict['step']
    xaxis = np.arange(len(step), dtype=float)

    # Leave a half-step gap after the 'Last\nSepsis' column; a divider is drawn in the gap below.
    idx_last_sepsis = step.index('Last\nSepsis') if 'Last\nSepsis' in step else None
    has_gap = idx_last_sepsis is not None and idx_last_sepsis + 1 < len(step)
    if has_gap:
        xaxis[idx_last_sepsis + 1:] += 0.5

    # Bottom-to-top stacking order, with matching colours and opacities.
    conditions = ['V_flag_Clinician_flag', 'V_flag_Clinician_no', 'V_no_Clinician_flag', 'V_no_Clinician_no']
    colors = ['red', 'yellow', 'darkgray', 'lightgray']
    alphas = [0.8, 0.6, 0.4, 0.2]
    names = [left_label, right_label]

    for i, traj_type in enumerate(['positive', 'negative']):
        bottom = np.zeros(len(step))
        vflag_vals = None
        for cond, color, alpha in zip(conditions, colors, alphas):
            vals = flag_dict[traj_type][cond]
            q_means = q_values[traj_type][cond]

            bars = ax[i].bar(xaxis, vals, bottom=bottom, color=color, alpha=alpha, label=cond, antialiased=True, rasterized=True)
            bottom += vals

            # Annotate segments as "ratio\n(mean Q)"; skip segments under 5% or without a Q value.
            labels = [f'{v:.1%}\n({qv:.2f})' if not (np.isnan(qv) or v < 0.05) else '' for v, qv in zip(vals, q_means)]
            ax[i].bar_label(bars, labels=labels, label_type='center', fontsize=10)

            if cond == 'V_no_Clinician_no':
                # Share of rows flagged by the model and/or the clinician.
                vflag_vals = [1 - v for v in vals]

        if vflag_vals is not None and len(step) > 2:
            ax[i].plot(xaxis[2:], vflag_vals[2:], color='black', marker='o', linestyle='-', markersize=5)

        ax[i].set_xticks(xaxis)
        ax[i].set_xticklabels('', fontsize=15)
        if show_axis_text:
            ax[i].set_xticklabels(step, fontsize=15)

        ax[i].set_title(names[i], fontsize=20)

        if show_axis_text:
            ax[i].set_xlabel("Time from Terminal state", fontsize=15)

        ax[i].set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
        ax[i].set_yticklabels(['20%', '40%', '60%', '80%', '100%'], fontsize=15)
        ax[i].set_ylim(0, 1)

    if show_axis_text:
        # Each panel gets its own legend (fresh patches per axes), placed below the plot.
        for panel in ax:
            patches = [plt.Rectangle((0, 0), 1, 1, facecolor=color, alpha=a) for color, a in zip(colors, alphas)]
            panel.legend(patches, conditions, loc='upper center', bbox_to_anchor=(0.5, -0.12), ncol=4, fontsize=10)

    if has_gap:
        mid_point = (xaxis[idx_last_sepsis] + xaxis[idx_last_sepsis + 1]) / 2
        ax[0].axvline(x=mid_point, color='black', linewidth=0.75)
        ax[1].axvline(x=mid_point, color='black', linewidth=0.75)

    fig.savefig(f'figure/Flag_analysis_{right_label}{algorithm}_{threshold}.png', dpi=300)
    plt.show()
    plt.close(fig)

def compute_flag_data(positive_df, negative_df, selected_steps):
    """Summarise the four value/clinician flag categories at selected time points.

    For each trajectory group ('positive' -> ``positive_df``, 'negative' ->
    ``negative_df``) and each entry ``t`` of ``selected_steps`` the rows are
    selected as follows:

    * ``'Sepsis'``         - all rows with ``sepsis == 1``;
    * ``'Last\\nSepsis'``   - the last ``sepsis == 1`` row of every trajectory
      (by row order within ``traj``);
    * anything else        - rows whose ``Time`` column equals ``t``.

    For each 0/1 indicator column ``V_no_Clinician_no``, ``V_no_Clinician_flag``,
    ``V_flag_Clinician_no`` and ``V_flag_Clinician_flag`` it records:

    * the fraction of selected rows where the indicator is 1 (0 when no row is
      selected), and
    * the mean of the group's median Q column (``q_positive_median`` for the
      positive group, ``q_negative_median`` for the negative group) over those
      rows (NaN when there are none).

    Returns
    -------
    flag_dict : dict
        ``{'step': selected_steps, 'positive': {cond: [ratio, ...]},
        'negative': {cond: [ratio, ...]}}``, one list entry per selected step.
    q_values : dict
        ``{'positive': {cond: [mean_q, ...]}, 'negative': {cond: [mean_q, ...]}}``.
    """
    conditions = ['V_no_Clinician_no', 'V_no_Clinician_flag', 'V_flag_Clinician_no', 'V_flag_Clinician_flag']

    # group name -> (frame, Q column summarised for that group)
    groups = {
        'positive': (positive_df, 'q_positive_median'),
        'negative': (negative_df, 'q_negative_median'),
    }

    flag_dict = {'step': selected_steps}
    q_values = {}

    for traj_type, (df_traj, q_col) in groups.items():
        sepsis_rows = df_traj[df_traj['sepsis'] == 1]
        # Last sepsis row per trajectory. tail(1) keeps the original dtypes,
        # unlike groupby.apply(lambda x: x.iloc[-1]), which rebuilds each row as an
        # object-dtype Series and operates on the grouping column (deprecated).
        last_sepsis_rows = sepsis_rows.groupby('traj').tail(1)

        flag_dict[traj_type] = {c: [] for c in conditions}
        q_values[traj_type] = {c: [] for c in conditions}

        for t in selected_steps:
            if t == 'Sepsis':
                tmp = sepsis_rows
            elif t == 'Last\nSepsis':
                tmp = last_sepsis_rows
            else:
                tmp = df_traj[df_traj['Time'] == t]

            total = len(tmp)

            for cond in conditions:
                subset = tmp[tmp[cond] == 1]
                count = len(subset)
                ratio = count / total if total > 0 else 0
                mean_q = subset[q_col].mean() if count > 0 else np.nan

                flag_dict[traj_type][cond].append(ratio)
                q_values[traj_type][cond].append(mean_q)

    return flag_dict, q_values

In [None]:
# Flag analysis: for each target, split the test trajectories by their final
# reward, derive value ('V', from median_flag) and clinician (from gather_flag)
# flags for every non-terminal row, and plot the flag mix at selected time points.
for target in tqdm(['Dead_icu', 'AKI_rrt', 'Septic_shock']):

    # Reward column that encodes this target's outcome.
    if 'Septic' in target : reward = 'r:reward_septic_shock'
    elif 'Dead' in target : reward = 'r:reward_dead'
    else : reward = 'r:reward_aki'

    # Display labels; right_label also appears in the saved figure's file name.
    if 'Dead' in target:
        left_label, right_label = 'Discharge', target
    elif 'rrt' in target:
        left_label, right_label = 'No-RRT', 'RRT'
    elif 'AKI' in target:
        left_label, right_label = 'No-C'+target, 'C'+target
    elif 'Septic' in target:
        left_label, right_label = 'No-Shock', 'Shock'
    else:
        left_label, right_label = 'Positive', 'Negative'

    for algorithm in ['_iql']:

        for threshold in ['threshold_p']:

            for source in ['test']:

                # NOTE: `data` is the frame stored in data_dict (no copy), so the
                # max_step / step_before / survivor columns written below also
                # appear in data_dict[target][source].
                data = data_dict[target][source]
                unique_conditions = np.unique(data[reward])
                # Trajectories whose last reward is the maximum / minimum observed value.
                positive_traj = data.groupby('traj').filter(lambda x: x.iloc[-1][reward] in [max(unique_conditions)])['traj'].drop_duplicates().reset_index(drop=True)
                negative_traj = data.groupby('traj').filter(lambda x: x.iloc[-1][reward] in [min(unique_conditions)])['traj'].drop_duplicates().reset_index(drop=True)

                data['max_step'] = data.groupby('traj')['step'].transform('max')
                data['step_before'] = data['max_step'] - data['step']

                # survivor: 0 when a trajectory's rewards sum to exactly 1, else -1.
                # transform() keeps the result aligned with data's own index; a
                # groupby.apply followed by reset_index only lines up when data is
                # sorted by traj and has a RangeIndex.
                traj_reward_sum = data.groupby('traj')[reward].transform('sum')
                data['survivor'] = np.where(traj_reward_sum == 1, 0, -1)

                # Drop the terminal row of each trajectory; the per-row Q arrays
                # assigned below must have the same length as the result.
                data = data.groupby('traj', group_keys=False).apply(lambda x: x.iloc[:-1]).reset_index(drop=True)

                pos_res = data_dict[target][algorithm]['_positive'][source]
                neg_res = data_dict[target][algorithm]['_negative'][source]

                data['q_negative_median'] = neg_res['q_median']
                data['q_positive_median'] = pos_res['q_median']

                data['q_negative_gather'] = neg_res['q_gather']
                data['q_positive_gather'] = pos_res['q_gather']

                # A row is flagged for a statistic when BOTH the positive and the
                # negative network's value is at or below that network's threshold.
                for flag_col, q_key, thr_suffix in [('median_flag', 'q_median', '_med'),
                                                    ('min_flag', 'q_min', '_min'),
                                                    ('max_flag', 'q_max', '_max'),
                                                    ('gather_flag', 'q_gather', '_gat')]:
                    data[flag_col] = ((pos_res[q_key] <= pos_res['thresholds'][threshold + thr_suffix]) &
                                      (neg_res[q_key] <= neg_res['thresholds'][threshold + thr_suffix])).astype(int)

                # One-hot combination of the value flag (median_flag) and the
                # clinician flag (gather_flag).
                data['V_no_Clinician_no'] = ((data['median_flag'] == 0) & (data['gather_flag'] == 0)).astype(int)
                data['V_no_Clinician_flag'] = ((data['median_flag'] == 0) & (data['gather_flag'] == 1)).astype(int)
                data['V_flag_Clinician_no'] = ((data['median_flag'] == 1) & (data['gather_flag'] == 0)).astype(int)
                data['V_flag_Clinician_flag'] = ((data['median_flag'] == 1) & (data['gather_flag'] == 1)).astype(int)

                action = ['No-Medication', 'Penicillin', 'Beta-lactam', 'Cephalosporin', 'Carbapenem', 'Glycopeptide', 'Selective Antimicrobial', 'Minor Antibiotic', 'Combination']

                # First step with median_flag == 1 in each trajectory, and every
                # row's offset from it (NaN for trajectories never flagged).
                first_flag_steps = (data.query('median_flag==1')
                                    .groupby('traj')['step'].min()
                                    .rename('first_flag_step')
                                    .reset_index())
                data = data.merge(first_flag_steps, on='traj', how='left')
                data['first_flag'] = data['step'] - data['first_flag_step']

                # Replace the 's:' state columns with the matching columns of the
                # _RAW file (presumably unscaled values), after dropping terminal rows
                # the same way; assignment aligns on the row index.
                test_df_RAW = pd.read_csv(f'processed/df_{target}_RAW.csv')
                test_df_RAW = test_df_RAW.groupby('traj', group_keys=False).apply(lambda x: x.iloc[:-1]).reset_index(drop=True)

                s_col = [x for x in data if x.startswith('s:') and 'prev' not in x]
                o_col = [col.replace('s:', '') for col in s_col if 'prev' not in col]
                data[s_col] = test_df_RAW[o_col]

                # 'Time' counts back from the trajectory end in 4h steps: -4h, -8h, ...
                negative_df = pd.merge(data, negative_traj, on='traj', how='inner')
                negative_df['Time'] = negative_df.groupby('traj').cumcount(ascending=False).apply(lambda x: f'-{4 * (x + 1)}h')
                negative_df_RAW = pd.merge(test_df_RAW, negative_traj, on='traj', how='inner')

                positive_df = pd.merge(data, positive_traj, on='traj', how='inner')
                positive_df['Time'] = positive_df.groupby('traj').cumcount(ascending=False).apply(lambda x: f'-{4 * (x + 1)}h')
                positive_df_RAW = pd.merge(test_df_RAW, positive_traj, on='traj', how='inner')

                # Sanity check: raw and processed frames should have equal row counts.
                print(len(negative_df_RAW),len(negative_df))
                print(len(positive_df_RAW),len(positive_df))

                selected_steps = ['Sepsis', 'Last\nSepsis', '-48h', '-36h', '-24h', '-12h', '-8h','-4h']

                flag_dict_res, q_values_res = compute_flag_data(positive_df, negative_df, selected_steps)
                flag_plot(flag_dict_res, q_values_res, left_label, right_label, algorithm, threshold)

Feature Importance

In [None]:
class SHAPModelWrapper(nn.Module):
    """Thin ``nn.Module`` that exposes one slice of a wrapped network's output.

    Calling the wrapper runs ``model`` and keeps only position ``index`` of the
    output's third dimension (``out[:, :, index]``), so attribution tools can
    explain a single output channel.

    Args:
        model: callable (typically an ``nn.Module``) whose output supports
            ``out[:, :, index]`` -- presumably a 3-D tensor; TODO confirm shape.
        index: position along the third output dimension to return (default 0).
    """

    def __init__(self, model, index=0):
        super().__init__()
        self.model = model  # registered as a submodule when it is an nn.Module
        self.index = index

    def forward(self, s):
        output = self.model(s)
        selected = output[:, :, self.index]
        return selected

In [None]:
from captum.attr import IntegratedGradients

# Integrated-gradients attributions of each network's output for every action,
# averaged over the test transitions and stored under
# data_dict[target][algorithm][version]['test'].
for target in tqdm(['Dead_icu', 'Dead_hosp', 'Dead_90', 'AKI_rrt', 'AKI_48', 'AKI_24', 'AKI_12', 'Septic_shock']):

    s_col = [x for x in data_dict[target]['test'] if x[:2] == 's:']
    s_col = [x[2:] if x.startswith('s:') else x for x in s_col]

    # Inputs to explain and the baseline inputs the path integral starts from.
    background = data_dict[target]['test_transition'].dataset.tensors[0].to(device)
    baseline = data_dict[target]['test_df_baseline_transition'].dataset.tensors[0].to(device)

    for algorithm in (['_ddqn', '_cql', '_iql', '_bcq']):
        for version in (['_negative', '_positive', '_both']):
            test_res = data_dict[target][algorithm][version]['test']

            # Results are stored under ['test'], so create the stores there (the
            # key must be checked at this level, not one level up).
            ig_store = test_res.setdefault('integrated_gradients', {})
            delta_store = test_res.setdefault('delta', {})

            # The network and the explainer do not depend on the action index, so
            # set them up once per (algorithm, version).
            model = test_res['network'].to(device)
            model.eval()

            def model_forward(inputs):
                return model(inputs)

            ig = IntegratedGradients(model_forward)

            for idx, action in enumerate(['No-Medication', 'Penicillin', 'Beta-lactam', 'Cephalosporin',
                                            'Carbapenem', 'Glycopeptide', 'Selective Antimicrobial', 'Minor Antibiotic', 'Combination']):
                # target=idx selects output index idx (assumes one output per action
                # along dim 1 -- TODO confirm against the network definition).
                attr, delta = ig.attribute(background, baseline, target=idx, return_convergence_delta=True)
                # Average the attributions over dim 0 (the rows of `background`).
                ig_store[action] = torch.mean(attr, dim=0).cpu().numpy()
                delta_store[action] = delta

In [None]:
# Some stored attribution dicts may use other action names; translate them to
# the names used by the plotting cells.
rename_map = {
    "No-Nephrotoxic drug": "No-Medication",
    "Antimicrobial": "Selective Antimicrobial",
    "Antibiotic": "Minor Antibiotic"
}


def _with_renamed_keys(mapping):
    """Return a new dict whose keys are translated through rename_map (other keys kept)."""
    return {rename_map.get(name, name): value for name, value in mapping.items()}


for target in data_dict.keys():
    for algorithm in ['_ddqn', '_cql', '_iql', '_bcq']:
        for version in ['_negative', '_positive', '_both']:
            test_res = data_dict[target][algorithm][version]['test']
            if 'integrated_gradients' not in test_res:
                continue
            # Build both renamed dicts before writing either back.
            renamed_ig = _with_renamed_keys(test_res['integrated_gradients'])
            renamed_delta = _with_renamed_keys(test_res['delta'])
            test_res['integrated_gradients'] = renamed_ig
            test_res['delta'] = renamed_delta

In [None]:
actions = ['No-Medication', 'Penicillin', 'Beta-lactam','Cephalosporin', 'Carbapenem', 'Glycopeptide', 'Selective Antimicrobial', 'Minor Antibiotic', 'Combination']

algorithms = ['_iql']
versions = ['_negative', '_positive', '_both']


def replace_action_name(feature_name):
    """Rename an 'action_<k>_prev' feature to '<actions[k-1]>_prev'.

    k is 1-based. Names without exactly this three-part form, or whose k is out
    of range, are returned unchanged.
    """
    parts = feature_name.split('_')
    if len(parts) == 3 and parts[0] == 'action' and parts[2] == 'prev':
        action_idx = int(parts[1])
        if 1 <= action_idx <= len(actions):
            return f"{actions[action_idx-1]}_prev"
    return feature_name


def convert_subscripts(label):
    """Replace underscores with spaces and typeset FiO2/PaO2/PaCO2/SpO2 with mathtext subscripts."""
    label = label.replace('_', ' ')
    # The combined ratio must be rewritten first: once FiO2 / PaO2 are replaced
    # individually, 'PaO2/FiO2' no longer occurs in the label.
    label = re.sub(r'(\bPaO2/FiO2\b)', r'$\\mathrm{PaO_2/FiO_2}$', label)
    label = re.sub(r'(\bFiO2\b)', r'$\\mathrm{FiO_2}$', label)
    label = re.sub(r'(\bPaO2\b)', r'$\\mathrm{PaO_2}$', label)
    label = re.sub(r'(\bPaCO2\b)', r'$\\mathrm{PaCO_2}$', label)
    label = re.sub(r'(\bSpO2\b)', r'$\\mathrm{SpO_2}$', label)
    return label


for target in ['Dead_icu', 'Dead_hosp', 'Dead_90', 'AKI_rrt', 'AKI_48', 'AKI_24', 'AKI_12', 'Septic_shock']:
    s_col = [x[2:] for x in data_dict[target]['train'] if x.startswith('s:')]
    s_col = [replace_action_name(x).replace('_', ' ') for x in s_col]
    categories = [convert_subscripts(label) for label in s_col]

    # squeeze=False keeps `axes` 2-D (algorithm x version) even for one algorithm.
    fig, axes = plt.subplots(
        nrows=len(algorithms),
        ncols=len(versions),
        figsize=(20, 20),
        subplot_kw={'polar': True},
        squeeze=False,
    )

    angles = np.linspace(0, 2*np.pi, len(categories), endpoint=False).tolist()
    angles += [angles[0]]  # repeat the first angle to close each polygon
    # pyplot.get_cmap exists in every matplotlib version; matplotlib.cm.get_cmap
    # (plt.cm.get_cmap) was removed in matplotlib 3.9.
    cmap = plt.get_cmap('tab10')

    for i, algorithm in enumerate(algorithms):
        for j, version in enumerate(versions):
            ax = axes[i, j]
            ig_results = data_dict[target][algorithm][version]['test']['integrated_gradients']

            # Common range over all actions of this panel so they share one scale.
            local_min_val = float('inf')
            local_max_val = float('-inf')
            all_action_data = {}
            for act in actions:
                ig_values = ig_results[act]
                all_action_data[act] = ig_values
                local_min_val = min(local_min_val, ig_values.min())
                local_max_val = max(local_max_val, ig_values.max())

            if local_min_val == local_max_val:
                # avoid a zero-width range in the rescaling below
                local_min_val -= 1e-6
                local_max_val += 1e-6

            ax.set_ylim(-1, 1)

            for idx, act in enumerate(actions):
                # Min-max rescale onto [-1, 1] using the panel-wide range.
                scaled = 2.0*(all_action_data[act] - local_min_val)/(local_max_val - local_min_val) - 1.0
                plot_data = np.append(scaled, scaled[0])
                color = cmap(idx)
                ax.fill(angles, plot_data, color=color, alpha=0.3)
                ax.plot(angles, plot_data, color=color, alpha=0.8, linewidth=2.0, label=act)

            # Category labels are drawn by hand beyond the radial limit (1), and
            # flipped on the left half of the circle so they stay upright.
            label_r = 1.8
            for k, cat in enumerate(categories):
                theta = angles[k]
                angle_deg = np.degrees(theta)
                if 90 < angle_deg < 270:
                    angle_deg += 180
                ax.text(theta, label_r, cat, rotation=angle_deg, rotation_mode='anchor', ha='center', va='center', fontsize=8)

            radial_lines = angles[:-1]
            ax.set_xticks(radial_lines)
            ax.set_xticklabels([''] * len(radial_lines))
            ax.set_yticks([-1.0, -0.5, 0.0, 0.5, 1.0])
            ax.xaxis.grid(True, color='gray', linewidth=1.0, alpha=0.7)
            ax.yaxis.grid(True, color='black', linewidth=1.2)

            subplot_title = f"{target} - {algorithm.replace('_', '')} - {version.replace('_', '')}"
            ax.set_title(subplot_title, fontsize=12, pad=70)

    handles, labels = ax.get_legend_handles_labels()
    fig.legend(
        handles, labels,
        loc='lower center', bbox_to_anchor=(0.5, -0.05),
        ncol=len(actions), fontsize=12, title="Actions"
    )

    fig.subplots_adjust(
        left=0.05, right=0.95,
        bottom=0.05, top=0.95,
        wspace=0.05,
        hspace=0.8
    )

    plt.savefig(f'figure/Feature_Importance_{target}_all_algorithms_versions.png', dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import re

targets = ['Dead_icu', 'AKI_rrt', 'Septic_shock']
versions = ['_negative', '_positive', '_both']
actions = [
    'No-Medication', 'Penicillin', 'Beta-lactam',
    'Cephalosporin', 'Carbapenem', 'Glycopeptide',
    'Selective Antimicrobial', 'Minor Antibiotic', 'Combination'
]

algorithm = '_iql'

target_mapping = {
    'Septic_shock': 'Shock',
    'AKI_rrt': 'RRT',
}
version_mapping = {
    '_negative': 'D-Network',
    '_positive': 'R-Network',
    '_both': 'C-Network'
}


def replace_action_name(feature_name):
    """Rename an 'action_<k>_prev' feature to '<actions[k-1]>_prev'.

    k is 1-based. Names without exactly this three-part form, or whose k is out
    of range, are returned unchanged.
    """
    parts = feature_name.split('_')
    if len(parts) == 3 and parts[0] == 'action' and parts[2] == 'prev':
        action_idx = int(parts[1])
        if 1 <= action_idx <= len(actions):
            return f"{actions[action_idx-1]}_prev"
    return feature_name


def convert_subscripts(label):
    """Typeset FiO2/PaO2/PaCO2/SpO2 (and the PaO2/FiO2 ratio) with mathtext subscripts."""
    # The combined ratio must be rewritten first: once FiO2 / PaO2 are replaced
    # individually, 'PaO2/FiO2' no longer occurs in the label.
    label = re.sub(r'(\bPaO2/FiO2\b)', r'$\\mathrm{PaO_2/FiO_2}$', label)
    label = re.sub(r'(\bFiO2\b)', r'$\\mathrm{FiO_2}$', label)
    label = re.sub(r'(\bPaO2\b)', r'$\\mathrm{PaO_2}$', label)
    label = re.sub(r'(\bPaCO2\b)', r'$\\mathrm{PaCO_2}$', label)
    label = re.sub(r'(\bSpO2\b)', r'$\\mathrm{SpO_2}$', label)
    return label


# One row per target, one column per network version; squeeze=False keeps
# `axes` 2-D even if only one target is listed.
fig, axes = plt.subplots(
    nrows=len(targets),
    ncols=len(versions),
    figsize=(25, 30),
    subplot_kw={'polar': True},
    squeeze=False,
)
# pyplot.get_cmap exists in every matplotlib version; matplotlib.cm.get_cmap
# (plt.cm.get_cmap) was removed in matplotlib 3.9.
cmap = plt.get_cmap('tab10')

for i, target in enumerate(targets):
    s_col = [x[2:] for x in data_dict[target]['train'] if x.startswith('s:')]
    s_col = [replace_action_name(x).replace('_', ' ') for x in s_col]
    s_col = [sc[0].upper() + sc[1:] if sc else sc for sc in s_col]
    categories = [convert_subscripts(x) for x in s_col]

    angles = np.linspace(0, 2*np.pi, len(categories), endpoint=False)
    angles = np.concatenate((angles, [angles[0]]))  # close each polygon

    for j, version in enumerate(versions):
        ax = axes[i, j]
        ig_results = data_dict[target][algorithm][version]['test']['integrated_gradients']

        # Common range over all actions of this panel so they share one scale.
        local_min_val = float('inf')
        local_max_val = float('-inf')
        all_action_data = {}

        for act in actions:
            ig_values = ig_results[act]
            all_action_data[act] = ig_values
            local_min_val = min(local_min_val, ig_values.min())
            local_max_val = max(local_max_val, ig_values.max())

        if local_min_val == local_max_val:
            # avoid a zero-width range in the rescaling below
            local_min_val -= 1e-6
            local_max_val += 1e-6

        ax.set_ylim(-1, 1)

        for idx, act in enumerate(actions):
            # Min-max rescale onto [-1, 1] using the panel-wide range.
            scaled = 2.0 * (all_action_data[act] - local_min_val) / (local_max_val - local_min_val) - 1.0
            plot_data = np.append(scaled, scaled[0])
            color = cmap(idx)
            ax.fill(angles, plot_data, color=color, alpha=0.3)
            # Only the first panel provides legend entries.
            ax.plot(angles, plot_data, color=color, alpha=0.8, linewidth=2.0,
                    label=act if (i == 0 and j == 0) else None)

        # Category labels are drawn by hand beyond the radial limit (1), and
        # flipped on the left half of the circle so they stay upright.
        label_r = 1.8
        for k, cat in enumerate(categories):
            theta = angles[k]
            angle_deg = np.degrees(theta)
            if 90 < angle_deg < 270:
                angle_deg += 180
            ax.text(
                theta, label_r, cat,
                rotation=angle_deg, rotation_mode='anchor',
                ha='center', va='center', fontsize=10
            )

        ax.set_xticks(angles[:-1])
        ax.set_xticklabels([''] * (len(angles)-1))
        ax.set_yticks([-1.0, -0.5, 0.0, 0.5, 1.0])
        ax.xaxis.grid(True, color='gray', linewidth=1.0, alpha=0.7)
        ax.yaxis.grid(True, color='black', linewidth=1.2)

        display_target = target_mapping.get(target, target)
        display_version = version_mapping.get(version, version)
        subplot_title = f"{display_target} - {display_version}"
        ax.set_title(
            subplot_title,
            fontsize=14,
            pad=15
        )

handles, labels = axes[0,0].get_legend_handles_labels()
fig.legend(
    handles, labels,
    loc='lower center',
    bbox_to_anchor=(0.5, -0.02),
    ncol=len(actions),
    fontsize=15,
    title="Actions",
    title_fontsize=20
)

fig.subplots_adjust(
    left=0.05,
    right=0.95,
    bottom=0.05,
    top=0.90,
    wspace=0.6,
    hspace=0.20
)

plt.savefig("figure/sensitivity_investigation.png", dpi=300, bbox_inches='tight')
plt.show()

Advise against A (action pairs flagged together)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch, Rectangle
import colorsys
from matplotlib.colors import to_rgb

def set_saturation(base_color, new_saturation):
    """Return ``base_color`` with its HLS saturation replaced by ``new_saturation``.

    ``base_color`` is an (r, g, b) triple in [0, 1]; hue and lightness are kept.
    The result is a new (r, g, b) tuple.
    """
    red, green, blue = base_color
    hue, lightness, _old_saturation = colorsys.rgb_to_hls(red, green, blue)
    return colorsys.hls_to_rgb(hue, lightness, new_saturation)

def percent_to_saturation(p):
    """Map a percentage ``p`` onto a saturation between 0.3 (p=0) and 1.0 (p=100).

    The mapping is linear (exponent 1.0); values outside 0-100 are not clipped.
    """
    lowest, highest = 0.3, 1.0
    exponent = 1.0  # shape of the curve; 1.0 keeps it linear
    fraction = (p / 100.0) ** exponent
    return lowest + fraction * (highest - lowest)

def advise_against_A(positive_df, negative_df, action_stats, left_label, right_label, algorithm, threshold,
                     show_legend=None):
    """Draw and save the action-pair 'advise against' heatmap.

    Cell (i, j) shows the action pair ``f'{action[i]}-{action[j]}'`` from
    ``action_stats``; pairs missing from ``action_stats`` are left blank.

    * Upper triangle (i < j): counts from ``action_stats[key]['positive']``.
    * Lower triangle (i > j): counts from ``action_stats[key]['negative']``.
    * Diagonal: the sum of both, drawn slightly inset.

    Each cell is split into four quadrants: ``count_positive`` (top-left, blue,
    "R-network"), ``count_negative`` (top-right, red, "D-network"),
    ``count_both`` (bottom-left, green, "C-network") and ``count_combination``
    (bottom-right, yellow, "Full"). Each count is shown as a percentage of the
    largest count in its region. That maximum starts from ``len(positive_df)``,
    ``len(negative_df)`` or their sum, and is replaced by 1e-9 when 0. The same
    percentage drives the quadrant's colour saturation and alpha.

    Parameters
    ----------
    positive_df, negative_df : pandas.DataFrame
        Only their lengths are used, as the initial per-region maxima.
    action_stats : dict
        ``{'A-B': {'positive': {'count_positive': .., 'count_negative': ..,
        'count_both': .., 'count_combination': ..}, 'negative': {...}}}``.
    left_label, right_label : str
        Shown in the title; ``right_label``, ``algorithm`` and ``threshold`` form
        the output file name under ``figure/``.
    show_legend : bool, optional
        Whether to draw the figure legend. When ``None`` (default), the legend is
        drawn only for the septic-shock figure (``right_label == 'Shock'``).
        This replaces a lookup of the notebook-global ``target``.
    """
    quadrant_labels = ['positive', 'negative', 'both', 'combination']
    base_colors = {
        'positive': to_rgb('blue'),
        'negative': to_rgb('red'),
        'both':     to_rgb('green'),
        'combination': to_rgb('yellow'),
    }
    legend_handles = [
        Patch(facecolor='blue',   edgecolor='k', label='R-network'),
        Patch(facecolor='red',    edgecolor='k', label='D-network'),
        Patch(facecolor='green',  edgecolor='k', label='C-network'),
        Patch(facecolor='yellow', edgecolor='k', label='Full'),
    ]

    # Row/column order of the heatmap.
    action = [
        'No-Medication', 'Cephalosporin', 'Glycopeptide',
        'Beta-lactam', 'Carbapenem', 'Penicillin',
        'Minor Antibiotic', 'Selective Antimicrobial', 'Combination'
    ]
    n_action = len(action)

    def iter_cells():
        """Yield (i, j, key) for every heatmap cell that has statistics."""
        for i in range(n_action):
            for j in range(n_action):
                key = f'{action[i]}-{action[j]}'
                if key in action_stats:
                    yield i, j, key

    def cell_counts(key, i, j):
        """Return (region, counts) for cell (i, j); counts follow quadrant_labels."""
        pos_stats = action_stats[key]['positive']
        neg_stats = action_stats[key]['negative']
        if i < j:
            return 'upper', [pos_stats[f'count_{q}'] for q in quadrant_labels]
        if i > j:
            return 'lower', [neg_stats[f'count_{q}'] for q in quadrant_labels]
        return 'diagonal', [pos_stats[f'count_{q}'] + neg_stats[f'count_{q}'] for q in quadrant_labels]

    # Normaliser per region: the largest count seen, starting from the frame sizes.
    region_max = {
        'upper': len(positive_df),
        'lower': len(negative_df),
        'diagonal': len(positive_df) + len(negative_df),
    }
    for i, j, key in iter_cells():
        region, counts = cell_counts(key, i, j)
        region_max[region] = max(region_max[region], max(counts))
    # Guard against division by zero for empty regions.
    region_max = {region: (value if value != 0 else 1e-9) for region, value in region_max.items()}

    fig, ax = plt.subplots(figsize=(12,12))
    fig.suptitle(
        f'{left_label}(upper) vs. {right_label}(lower)',
        fontsize=20, fontweight='bold', y=0.95
    )

    ax.set_xlim(0, n_action)
    ax.set_ylim(0, n_action)

    ax.set_xticks(np.arange(n_action) + 0.5)
    ax.set_yticks(np.arange(n_action) + 0.5)
    ax.set_xticklabels(action, rotation=90, ha='center', va='top', fontsize=12)
    ax.set_yticklabels(action, fontsize=12)

    # Row 0 at the top, as in a matrix.
    ax.invert_yaxis()
    ax.tick_params(axis='x', pad=2)

    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_zorder(0)
    ax.set_axisbelow(False)

    if show_legend is None:
        show_legend = right_label == 'Shock'
    if show_legend:
        fig.legend(
            handles=legend_handles,
            loc='lower center',
            bbox_to_anchor=(0.6, 0.01),
            ncol=4,
            frameon=True,
            fontsize=18
        )

    ax.plot([0, n_action], [0, n_action], color='black', linewidth=2, zorder=0, alpha=0.3)

    # Offset of each quadrant inside a cell (in cell units), in quadrant_labels
    # order: top-left, top-right, bottom-left, bottom-right (y axis is inverted).
    quadrant_offsets = [(0.0, 0.0), (0.5, 0.0), (0.0, 0.5), (0.5, 0.5)]

    for i, j, key in iter_cells():
        region, counts = cell_counts(key, i, j)
        local_max = region_max[region]

        x0, y0 = j, i
        cell_w = cell_h = 1.0
        if region == 'diagonal':
            # Diagonal cells are drawn slightly inset.
            x0 = x0 + 0.075
            y0 = y0 + 0.075
            cell_w = cell_h = 0.85

        for quadrant, count, (dx, dy) in zip(quadrant_labels, counts, quadrant_offsets):
            pct = (count / local_max)*100.0
            # The same value serves as colour saturation and as patch alpha.
            sat = percent_to_saturation(pct)
            ax.add_patch(Rectangle((x0 + dx*cell_w, y0 + dy*cell_h), 0.5*cell_w, 0.5*cell_h,
                                   facecolor=set_saturation(base_colors[quadrant], sat),
                                   edgecolor='k', linewidth=0.7, alpha=sat, zorder=10))
            ax.text(x0 + (dx + 0.25)*cell_w, y0 + (dy + 0.25)*cell_h, f"{pct:.1f}%", ha='center', va='center',
                    fontsize=9, color='black', zorder=11)

        ax.add_patch(Rectangle((x0, y0), cell_w, cell_h,
                               fill=False, edgecolor='black', linewidth=2, zorder=12))

    outer_rect = Rectangle((0,0), n_action, n_action,
                           fill=False, edgecolor='black', linewidth=4, zorder=15)
    ax.add_patch(outer_rect)

    ax.set_xticks(np.arange(n_action), minor=True)
    ax.set_yticks(np.arange(n_action), minor=True)
    ax.grid(which="minor", color="gray", linestyle=':', linewidth=0.5, zorder=1)

    plt.tight_layout(rect=[0, 0.05, 1, 0.95])
    plt.savefig(f'figure/advise_against_A_heatmap_{right_label}{algorithm}_{threshold}.png', dpi=300)
    plt.show()

In [None]:
def calculate_action_stats(df, treatment_A, treatment_B):
    """Count rows of ``df`` in which both treatment_A and treatment_B are advised against.

    For each label in ('positive', 'negative', 'both', 'combination') returns:

    * ``count_<label>``: number of rows whose ``advise_againt_<A>_<label>`` and
      ``advise_againt_<B>_<label>`` columns are both 1;
    * ``mean_<label>``: mean of column ``treatment_A`` over those rows, times 100,
      rounded to one decimal (NaN when there are no such rows).
    """
    stats = {}
    for label in ['positive', 'negative', 'both', 'combination']:
        flag_A = f'advise_againt_{treatment_A}_{label}'
        flag_B = f'advise_againt_{treatment_B}_{label}'
        both_flagged = (df[flag_A] == 1) & (df[flag_B] == 1)
        stats[f'count_{label}'] = df.loc[both_flagged, flag_A].sum()
        stats[f'mean_{label}'] = round(df.loc[both_flagged, treatment_A].mean() * 100, 1)
    return stats


# "Advise against" analysis: for every ordered action pair (A, B) count the rows
# in which the networks flag both A and B, separately for trajectories ending in
# the maximum and the minimum reward, and draw the heatmap.
for target in tqdm(['Dead_icu', 'AKI_rrt', 'Septic_shock']):

    if 'Septic' in target : reward = 'r:reward_septic_shock'
    elif 'Dead' in target : reward = 'r:reward_dead'
    else : reward = 'r:reward_aki'

    if 'Dead' in target:
        left_label, right_label = 'Discharge', target
    elif 'rrt' in target:
        left_label, right_label = 'No-RRT', 'RRT'
    elif 'AKI' in target:
        left_label, right_label = 'No-C'+target, 'C'+target
    elif 'Septic' in target:
        left_label, right_label = 'No-Shock', 'Shock'
    else:
        left_label, right_label = 'Positive', 'Negative'

    for algorithm in ['_iql']:

        for threshold in ['threshold_p']:

            test_df = data_dict[target]['test']
            unique_conditions = np.unique(test_df[reward])

            # Trajectories whose last reward is the maximum / minimum observed value.
            positive_traj = test_df.groupby('traj').filter(lambda x: x.iloc[-1][reward] in [max(unique_conditions)])['traj'].drop_duplicates().reset_index(drop=True)
            negative_traj = test_df.groupby('traj').filter(lambda x: x.iloc[-1][reward] in [min(unique_conditions)])['traj'].drop_duplicates().reset_index(drop=True)
            # Drop the terminal row of each trajectory.
            data = test_df.groupby('traj', group_keys=False).apply(lambda x: x.iloc[:-1]).reset_index(drop=True)

            results = data_dict[target][algorithm]
            med_threshold_key = threshold + '_med'

            # Heatmap order of the actions (advise_against_A uses the same order).
            action = [
                'No-Medication', 'Cephalosporin', 'Glycopeptide',
                'Beta-lactam', 'Carbapenem', 'Penicillin',
                'Minor Antibiotic', 'Selective Antimicrobial', 'Combination'
            ]

            # The flag columns depend only on the action, not on the (A, B) pair,
            # so build them once instead of once per pair.
            # NOTE(review): idx is the position in this heatmap-ordered list, whereas
            # the other cells list actions as No-Medication, Penicillin,
            # Beta-lactam, ... If q_value columns follow that order, these columns
            # are attached to the wrong action names -- confirm.
            for idx, act in enumerate(action):
                for label in ['positive', 'negative', 'both']:
                    net = results[f'_{label}']['test']
                    data[f'advise_againt_{act}_{label}'] = (
                        (net['q_value'][:, idx] <= net['thresholds'][med_threshold_key])
                    ).astype(int)

                # 'combination': flagged by both the positive and the negative network.
                pos_net = results['_positive']['test']
                neg_net = results['_negative']['test']
                data[f'advise_againt_{act}_combination'] = (
                    (pos_net['q_value'][:, idx] <= pos_net['thresholds'][med_threshold_key]) &
                    (neg_net['q_value'][:, idx] <= neg_net['thresholds'][med_threshold_key])
                ).astype(int)

                # The column named after the action holds the '_both' network's Q-value.
                data[f'{act}'] = results['_both']['test']['q_value'][:, idx]

            positive_df = pd.merge(data, positive_traj, on='traj', how='inner')
            negative_df = pd.merge(data, negative_traj, on='traj', how='inner')

            action_stats = {}
            for treatment_A in action:
                for treatment_B in action:
                    action_stats[f'{treatment_A}-{treatment_B}'] = {
                        'positive': calculate_action_stats(positive_df, treatment_A, treatment_B),
                        'negative': calculate_action_stats(negative_df, treatment_A, treatment_B),
                    }

            advise_against_A(positive_df, negative_df, action_stats, left_label, right_label, algorithm, threshold)

Trajectory analysis

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap, BoundaryNorm, to_rgba

# Antibiotic-class action space. List position = action index, i.e. the column used for
# q_value[:, idx] in attach_flags_and_advise (which spells out the same order).
ACTION_LIST = [
    'No-Medication', 'Penicillin', 'Beta-lactam',
    'Cephalosporin', 'Carbapenem', 'Glycopeptide',
    'Selective Antimicrobial', 'Minor Antibiotic', 'Combination',
]
n_drug = len(ACTION_LIST)  # number of actions

def combine_three_targets_into_array(df_dead, df_aki, df_ss, action_list):
    """Stack one trajectory's Dead / AKI / septic-shock outputs into a 2-D array for plotting.

    Row layout (shape ``(16 + len(action_list), T)``, T = number of rows in ``df_dead``):
      0-3    Dead: q_negative_median, q_negative_gather, q_positive_median, q_positive_gather
      4-7    AKI:  same four columns
      8-11   SS:   same four columns
      12     ``a:action`` (taken action, from df_dead)
      13     ``AKI_max_stage`` (from df_dead)
      14 .. 13+len(action_list)
             one row per action: bitmask Dead<<2 | AKI<<1 | SS, where a target's bit is set
             if ``advise_againt_{action}_negative`` or ``..._positive`` equals 1
      last two rows: ``median_flag`` and ``gather_flag`` packed with the same bit order

    With the default 9-action list this is the original 25-row layout; other list lengths
    now get a correctly sized array instead of overlapping/empty rows.

    Parameters
    ----------
    df_dead, df_aki, df_ss : pd.DataFrame
        Per-step frames of the same trajectory, equal length, carrying the columns above.
    action_list : sequence of str
        Action names used to build the ``advise_againt_*`` column names.

    Returns
    -------
    (np.ndarray, np.ndarray)
        The stacked float array (NaN-initialised) and ``df_dead['step'].values``.
    """
    targets = (df_dead, df_aki, df_ss)
    steps = df_dead['step'].values
    T = len(steps)
    n_actions = len(action_list)
    advise_start = 14
    n_rows = advise_start + n_actions + 2
    data_full = np.full((n_rows, T), np.nan)

    q_cols = ('q_negative_median', 'q_negative_gather', 'q_positive_median', 'q_positive_gather')
    for t_idx, frame in enumerate(targets):
        for c_idx, col in enumerate(q_cols):
            data_full[4 * t_idx + c_idx, :] = frame[col].to_numpy()
    data_full[12, :] = df_dead['a:action'].to_numpy()
    data_full[13, :] = df_dead['AKI_max_stage'].to_numpy()

    def _pack(bits_dead, bits_aki, bits_ss):
        # Encode three 0/1 indicators as one integer 0..7 (Dead=4, AKI=2, SS=1).
        return ((np.asarray(bits_dead, dtype=int) << 2)
                | (np.asarray(bits_aki, dtype=int) << 1)
                | np.asarray(bits_ss, dtype=int))

    def _advised_against(frame, drug_name):
        # Either value head (negative or positive) advising against this action counts.
        return ((frame[f'advise_againt_{drug_name}_negative'].values == 1) |
                (frame[f'advise_againt_{drug_name}_positive'].values == 1))

    for i, drug_name in enumerate(action_list):
        data_full[advise_start + i, :] = _pack(*(_advised_against(f, drug_name) for f in targets))

    data_full[n_rows - 2, :] = _pack(*(f['median_flag'].to_numpy() for f in targets))
    data_full[n_rows - 1, :] = _pack(*(f['gather_flag'].to_numpy() for f in targets))
    return data_full, steps

def plot_combined_heatmap(data_full, steps, df_ss=None, action_list=ACTION_LIST,
                          traj_id='', threshold_list=None, threshold='threshold_p',
                          figsize=(12, 10), dataset='test', algorithm='_ddqn'):
    """Draw the combined Dead / AKI / septic-shock heatmap of one trajectory and save it as PNG.

    Expects the row layout produced by ``combine_three_targets_into_array``:
    rows 0-11 value estimates (fixed [-1, 1] 'plasma' scale, annotated), row 12 the taken
    action, row 13 ``AKI_max_stage``, rows 14 .. 13+len(action_list) the advise-against
    bitmasks (Dead=4, AKI=2, SS=1), and the last two rows the median/gather flag bitmasks.

    Parameters
    ----------
    data_full : np.ndarray, shape (n_rows, T)
    steps : array-like of length T, used as x tick labels.
    df_ss : pd.DataFrame, optional
        If it has a ``sepsis`` column, the first step with ``sepsis == 1`` gets an
        ``F`` superscript on its x tick.
    action_list : sequence of str
        One advise-against row is drawn/labelled per entry (must match ``data_full``).
    traj_id, algorithm, threshold, dataset : str
        Used in the title and in the output filename.
    threshold_list : sequence of 6 floats, optional
        (pos_dead, neg_dead, pos_aki, neg_aki, pos_ss, neg_ss) shown in the title;
        any other value shows '-' for every threshold.
    figsize : tuple
        Matplotlib figure size.

    Side effects: writes ``figure/traj_{dataset}/{traj_id}{algorithm}_{threshold}.png``
    (creating the directory if missing), shows the figure, then closes it.
    """
    import os  # local import so this cell does not rely on an earlier cell's imports

    n_rows, T = data_full.shape
    n_actions = len(action_list)
    advise_start = 14          # 12 value rows + action row + AKI-stage row
    flag_start = n_rows - 2    # median / gather flag rows are always the last two

    if (threshold_list is not None) and (len(threshold_list) == 6):
        thr_pos_dead, thr_neg_dead, thr_pos_aki, thr_neg_aki, thr_pos_ss, thr_neg_ss = threshold_list
    else:
        thr_pos_dead = thr_neg_dead = thr_pos_aki = thr_neg_aki = thr_pos_ss = thr_neg_ss = None

    def _fmt_thr(x):
        # Two decimals with trailing zeros stripped; '-' when no threshold was supplied.
        if x is None:
            return '-'
        s = f"{x:.2f}"
        return s.rstrip('0').rstrip('.')

    title_str = (
        f"Traj={traj_id} | {algorithm} | Combined(Dead,AKI,SS) | "
        f"thr(pos)={[ _fmt_thr(thr_pos_dead), _fmt_thr(thr_pos_aki), _fmt_thr(thr_pos_ss) ]}, "
        f"thr(neg)={[ _fmt_thr(thr_neg_dead), _fmt_thr(thr_neg_aki), _fmt_thr(thr_neg_ss) ]}"
    )
    fig, ax = plt.subplots(figsize=figsize)

    # Rows 0-11: value estimates.
    mask_q = np.ones_like(data_full, dtype=bool)
    mask_q[0:12, :] = False
    sns.heatmap(data_full, mask=mask_q, cmap='plasma', vmin=-1, vmax=1,
                annot=True, fmt='.2f', linewidths=1, linecolor='white',
                cbar=False, xticklabels=False, yticklabels=False, ax=ax)

    def _draw_text_row(row, color):
        # Solid-colour row annotated with the integer value of each cell.
        mask = np.ones_like(data_full, dtype=bool)
        mask[row, :] = False
        annot = np.full(data_full.shape, '', dtype=object)
        annot[row, :] = [str(int(v)) for v in data_full[row, :]]
        sns.heatmap(np.zeros_like(data_full), mask=mask, cmap=[color], annot=annot,
                    fmt='s', cbar=False, linewidths=1, linecolor='white',
                    xticklabels=False, yticklabels=False, ax=ax)

    _draw_text_row(12, 'black')     # taken action
    _draw_text_row(13, "#ffd54f")   # AKI_max_stage

    # Advise-against bitmask rows (values 0..7), with the only colour bar of the figure.
    mask_advise = np.ones_like(data_full, dtype=bool)
    mask_advise[advise_start:advise_start + n_actions, :] = False
    boundaries = np.arange(9)
    norm_advise = BoundaryNorm(boundaries, ncolors=8, clip=True)
    advise_cmap = ListedColormap([
        to_rgba("#e0f7da", alpha=0.8), "#ffcc80", "#ffd54f", "#ff9800",
        "#ff8a80", "#d50000", "#a70000", "#000000",
    ])
    g_advise = sns.heatmap(data_full, mask=mask_advise, cmap=advise_cmap,
                           norm=norm_advise, cbar=True,
                           cbar_kws={'boundaries': boundaries, 'ticks': np.arange(8) + 0.5,
                                     'spacing': 'uniform'},
                           linewidths=1, linecolor='grey',
                           xticklabels=False, yticklabels=False, ax=ax)
    cbar_ax = g_advise.figure.axes[-1]
    cbar_ax.yaxis.set_ticks(np.arange(8) + 0.5)
    cbar_ax.yaxis.set_ticklabels(["No", "SS", "AKI", "AKI+SS",
                                  "Dead", "Dead+SS", "Dead+AKI", "Dead+AKI+SS"])
    cbar_ax.tick_params(labelsize=9)

    # Median / gather flag bitmask rows, same 0..7 encoding.
    mask_flag = np.ones_like(data_full, dtype=bool)
    mask_flag[flag_start:, :] = False
    flag_cmap = ListedColormap([
        to_rgba("#e0f7da", alpha=0.8), "#ffcc80", "#ffd54f", "#ff9800",
        "#ff8a80", "#e53935", "#d32f2f", "#000000",
    ])
    sns.heatmap(data_full, mask=mask_flag, cmap=flag_cmap, norm=norm_advise,
                cbar=False, linewidths=1, linecolor='white',
                xticklabels=False, yticklabels=False, ax=ax)

    ax.set_xticks([i + 0.5 for i in range(T)])
    x_tick_labels = [str(s) for s in steps]
    if df_ss is not None and ('sepsis' in df_ss.columns):
        sepsis_df = df_ss[df_ss['sepsis'] == 1].copy()
        if not sepsis_df.empty:
            sepsis_df = sepsis_df.sort_values('step')
            first_step = sepsis_df['step'].iloc[0]
            if first_step in steps:
                i_f = np.where(steps == first_step)[0][0]
                x_tick_labels[i_f] += r"$^{F}$"
    ax.set_xticklabels(x_tick_labels, rotation=0)

    ylabels = [
        r'$V_{D}^{Dead}$', r'$Q_{D}^{Dead}$', r'$V_{R}^{Dead}$', r'$Q_{R}^{Dead}$',
        r'$V_{D}^{AKI}$', r'$Q_{D}^{AKI}$', r'$V_{R}^{AKI}$', r'$Q_{R}^{AKI}$',
        r'$V_{D}^{SS}$', r'$Q_{D}^{SS}$', r'$V_{R}^{SS}$', r'$Q_{R}^{SS}$',
        'Treatment', r'$AKI_{stage}$'
    ]
    for i in range(n_actions):
        ylabels.append(f"Advise: {i+1}")
    ylabels += [r'$V_{flag}$', r'$Q_{flag}$']
    ax.set_yticks([i + 0.5 for i in range(n_rows)])
    ax.set_yticklabels(ylabels, rotation=0)
    ax.set_title(title_str, fontsize=12, pad=10)
    ax.set_xlabel("Time (step=4hrs)")
    plt.tight_layout()

    out_path = f"figure/traj_{dataset}/{traj_id}{algorithm}_{threshold}.png"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures when called in a loop

from tqdm import tqdm

def attach_flags_and_advise(data, data_dict, target, algorithm, threshold, dataset):
    """Attach value estimates, low-value flags and per-action advise-against columns to ``data``.

    ``data`` is modified in place and also returned. Reads the two value heads
    ``data_dict[target][algorithm]['_positive' | '_negative'][dataset]``, each providing
    ``q_median``, ``q_gather``, ``q_value`` (2-D, indexed ``[:, action]``) and a
    ``thresholds`` dict keyed ``f'{threshold}_med'`` / ``f'{threshold}_gat'``.

    Columns written, in this order:
      q_negative_median, q_positive_median, q_negative_gather, q_positive_gather  (rounded to 2 dp)
      median_flag / gather_flag : 1 where BOTH heads are <= their threshold
      per action: advise_againt_{action}_positive, {action}, advise_againt_{action}_negative,
                  advise_againt_{action}_combination (both heads advise against)
    """
    heads = {
        label: data_dict[target][algorithm][f'_{label}'][dataset]
        for label in ('positive', 'negative')
    }
    pos, neg = heads['positive'], heads['negative']
    med_key = threshold + '_med'
    gat_key = threshold + '_gat'

    data['q_negative_median'] = neg['q_median'].round(2)
    data['q_positive_median'] = pos['q_median'].round(2)
    data['q_negative_gather'] = neg['q_gather'].round(2)
    data['q_positive_gather'] = pos['q_gather'].round(2)

    data['median_flag'] = ((pos['q_median'] <= pos['thresholds'][med_key])
                           & (neg['q_median'] <= neg['thresholds'][med_key])).astype(int)
    data['gather_flag'] = ((pos['q_gather'] <= pos['thresholds'][gat_key])
                           & (neg['q_gather'] <= neg['thresholds'][gat_key])).astype(int)

    action_names = [
        'No-Medication', 'Penicillin', 'Beta-lactam', 'Cephalosporin',
        'Carbapenem', 'Glycopeptide', 'Selective Antimicrobial', 'Minor Antibiotic', 'Combination'
    ]
    for col_idx, action in enumerate(action_names):
        below = {
            label: head['q_value'][:, col_idx] <= head['thresholds'][med_key]
            for label, head in heads.items()
        }
        data[f'advise_againt_{action}_positive'] = below['positive'].astype(int)
        # NOTE(review): only the negative head's Q column is kept under the plain action
        # name; confirm that is intended.
        data[f'{action}'] = neg['q_value'][:, col_idx]
        data[f'advise_againt_{action}_negative'] = below['negative'].astype(int)
        data[f'advise_againt_{action}_combination'] = (below['positive'] & below['negative']).astype(int)

    return data

def fails_filter_conditions(df_traj):
    """Return True if a trajectory should be excluded from the trajectory plots.

    A trajectory is rejected when any of the following holds:
      * it has fewer than 3 steps (checked first, so empty frames never reach ``iloc``);
      * ``q_positive_median`` or ``q_negative_median`` is not greater at the first step
        than at the last step;
      * ``median_flag`` or ``gather_flag`` is constant over the trajectory (all 1 or none 1).

    Parameters
    ----------
    df_traj : pd.DataFrame
        Per-step rows of one trajectory with the four columns named above.
    """
    if len(df_traj) < 3:
        return True
    for col in ('q_positive_median', 'q_negative_median'):
        if df_traj[col].iloc[0] <= df_traj[col].iloc[-1]:
            return True
    for flag in ('median_flag', 'gather_flag'):
        n_set = int((df_traj[flag] == 1).sum())
        if n_set == 0 or n_set == len(df_traj):
            return True
    return False

def _trajs_with_terminal_reward(df, reward_col, value=-1):
    """Return the set of ``traj`` ids whose last row has ``df[reward_col] == value``."""
    last_rows = df.groupby('traj').tail(1)
    return set(last_rows.loc[last_rows[reward_col] == value, 'traj'])


def _prepare_target_frame(data_dict, target, dataset, algorithm, threshold, attach_raw):
    """Build one target's per-step frame with Q-values, flags and advise columns attached.

    Adds a 0-based ``step`` counter per trajectory and drops each trajectory's last row.
    When ``attach_raw`` is true, the ``s:`` state columns (excluding ``prev`` ones) and
    ``AKI_max_stage`` are overwritten with the matching columns of
    ``processed/df_{target}_RAW.csv``, restricted to the same trajectories.
    """
    data = data_dict[target][dataset].copy()
    s_col = [x for x in data if x.startswith('s:') and 'prev' not in x] + ['AKI_max_stage']
    o_col = [col.replace('s:', '') for col in s_col if 'prev' not in col]
    data['step'] = data.groupby('traj').cumcount()
    # Drop the terminal row; presumably this matches the length of the Q arrays in
    # data_dict that attach_flags_and_advise assigns as columns — TODO confirm.
    data = data.groupby('traj', group_keys=False).apply(lambda x: x.iloc[:-1]).reset_index(drop=True)
    if attach_raw:
        raw = pd.read_csv(f'processed/df_{target}_RAW.csv')
        # Restrict to this split's trajectories so rows line up with `data` after reset_index.
        raw = pd.merge(raw, data['traj'].drop_duplicates(), on='traj', how='inner').reset_index(drop=True)
        raw = raw.groupby('traj', group_keys=False).apply(lambda x: x.iloc[:-1]).reset_index(drop=True)
        print(data.shape, raw.shape)
        data[s_col] = raw[o_col]
    return attach_flags_and_advise(data, data_dict, target, algorithm, threshold, dataset)


def run_pipeline(data_dict, dataset, algorithms=('_iql',), thresholds=('threshold_p',),
                 traj_ids=(3388,), plot_targets=('Dead_icu', 'AKI_48', 'Septic_shock')):
    """Pick trajectories that end badly for all three outcomes and plot their combined heatmap.

    For every (algorithm, threshold) pair:
      1. Candidates: trajectories whose last-row reward is -1 in ``Dead_icu``
         (``r:reward_dead``), ``AKI_rrt`` (``r:reward_aki``) and ``Septic_shock``
         (``r:reward_septic_shock``).
      2. Drop every candidate for which ``fails_filter_conditions`` is True under any of
         those three targets.
      3. For each survivor listed in ``traj_ids`` (all survivors if ``traj_ids`` is None),
         build the Dead / AKI / SS frames from ``plot_targets``, truncate them to a common
         length and call ``plot_combined_heatmap``.

    Parameters
    ----------
    data_dict : dict
        ``data_dict[target][dataset]`` is the per-step DataFrame; ``data_dict[target][algorithm]``
        holds the model outputs read by ``attach_flags_and_advise``.
    dataset : str
        Split key, e.g. ``'test'`` or ``'temporal'``.
    algorithms, thresholds : iterable of str
        Defaults reproduce the original single run.
    traj_ids : iterable or None
        Trajectories to plot; the default keeps the original single trajectory 3388.
    plot_targets : sequence of 3 str
        Targets for the Dead, AKI and SS panels, in that order.
        NOTE(review): filtering uses ``AKI_rrt`` while plotting defaults to ``AKI_48``
        (kept from the original) — confirm this is intended.
    """
    filter_targets = ('Dead_icu', 'AKI_rrt', 'Septic_shock')
    wanted = None if traj_ids is None else set(traj_ids)

    for algorithm in tqdm(list(algorithms)):
        for threshold in thresholds:
            candidates = (
                _trajs_with_terminal_reward(data_dict['Dead_icu'][dataset], 'r:reward_dead')
                & _trajs_with_terminal_reward(data_dict['AKI_rrt'][dataset], 'r:reward_aki')
                & _trajs_with_terminal_reward(data_dict['Septic_shock'][dataset], 'r:reward_septic_shock')
            )

            # The filter only looks at Q-values and flags, so the raw CSV join is skipped here.
            rejected = set()
            for target in filter_targets:
                frame = _prepare_target_frame(data_dict, target, dataset, algorithm, threshold,
                                              attach_raw=False)
                frame = frame[frame['traj'].isin(candidates)]
                rejected.update(traj for traj, df_traj in frame.groupby('traj')
                                if fails_filter_conditions(df_traj))
            selected = sorted(candidates - rejected)

            prepared = {}  # target -> frame with raw columns, built lazily once per pair
            for traj_id in selected:
                if wanted is not None and traj_id not in wanted:
                    continue
                data_list = []
                threshold_list = []
                for target in plot_targets:
                    if target not in prepared:
                        prepared[target] = _prepare_target_frame(
                            data_dict, target, dataset, algorithm, threshold, attach_raw=True)
                    data = prepared[target]
                    df_one_traj = data[data['traj'] == traj_id]
                    if df_one_traj.empty:
                        continue
                    data_list.append(df_one_traj)
                    for head in ('_positive', '_negative'):
                        threshold_list.append(
                            data_dict[target][algorithm][head][dataset]['thresholds'][threshold + '_med'])

                if len(data_list) < 3:
                    continue

                # Align the three targets on their common prefix of steps.
                min_len = min(len(d) for d in data_list[:3])
                df_dead, df_aki, df_ss = (d.iloc[:min_len] for d in data_list[:3])
                data_full, steps = combine_three_targets_into_array(df_dead, df_aki, df_ss, ACTION_LIST)
                if len(steps) == 0:
                    continue
                plot_combined_heatmap(
                    data_full,
                    steps,
                    df_ss=df_ss,
                    action_list=ACTION_LIST,
                    traj_id=f"{traj_id}",
                    threshold_list=threshold_list,
                    threshold=threshold,
                    dataset=dataset,
                    algorithm=algorithm
                )

In [None]:
run_pipeline(data_dict, dataset='test')

In [None]:
run_pipeline(data_dict, dataset='temporal')

Target trajectory (traj 3388)

In [None]:
# Look up the highlighted trajectory in the temporal split. New names are used on purpose:
# rebinding `temporal_id` here would shrink the temporal cohort used by the
# cohort-statistics cells below to this single trajectory.
temporal_id_raw = pd.read_csv('processed/temporal_id.csv')
temporal_id_traj = temporal_id_raw[temporal_id_raw['traj'] == 3388]
temporal_id_traj

In [None]:
# Per-step rows of the highlighted trajectory, bound to a new name so that any
# notebook-level `temporal_df` is not silently replaced by this one-trajectory slice.
temporal_df_traj = pd.read_csv('processed/temporal_df_Z.csv')
temporal_df_traj = temporal_df_traj[temporal_df_traj['traj'] == 3388]
temporal_df_traj

Cohort statistics

In [None]:
import pandas as pd
import numpy as np

df = pd.read_csv("processed/df_Dead_icu_RAW.csv")

# Development cohort = train + validation + test trajectories; the second cohort is the temporal split.
main_id = pd.concat([train_id, valid_id, test_id], axis=0)

df_main, df_temp = (
    df.merge(ids[["traj"]], on="traj", how="inner").reset_index(drop=True)
    for ids in (main_id, temporal_id)
)

# Cohort sizes are counted in distinct ICU stays.
N_main, N_temp = (frame["stay_id"].nunique() for frame in (df_main, df_temp))

def preprocess_fluids_vaso(data: pd.DataFrame):
    """Add stay-level 0/1 exposure indicators to every row.

    ``fluids_binary`` is 1 when the stay's summed ``input_4hr`` is > 0, and
    ``vaso_binary`` is 1 when the stay's maximum ``median_vaso`` is > 0.
    Both are broadcast back to the rows of each stay with a left merge on ``stay_id``.
    """
    per_stay = (
        data.groupby("stay_id")
        .agg(fluid_total=("input_4hr", "sum"), vaso_peak=("median_vaso", "max"))
        .reset_index()
    )
    per_stay["fluids_binary"] = (per_stay["fluid_total"] > 0).astype(int)
    per_stay["vaso_binary"] = (per_stay["vaso_peak"] > 0).astype(int)
    return data.merge(per_stay[["stay_id", "fluids_binary", "vaso_binary"]], on="stay_id", how="left")

df_main = preprocess_fluids_vaso(df_main)
df_temp = preprocess_fluids_vaso(df_temp)

# Half-open decade bins [lo, hi); ages below 18 fall outside every bin and become NaN.
age_bins = [18, 30, 40, 50, 60, 70, 80, 90, np.inf]
age_labels = ["18-29", "30-39", "40-49", "50-59", "60-69", "70-79", "80-89", "≥90"]
for cohort_frame in (df_main, df_temp):
    cohort_frame["age_bin"] = pd.cut(cohort_frame["age"], bins=age_bins, labels=age_labels, right=False)

# Layout of the LaTeX summary table: section name -> list of (display label, column in
# df_main / df_temp). A dict value nests sub-sections (Laboratory findings -> Hematology /
# Chemistry). Display labels are inserted into LaTeX verbatim, hence the pre-escaped
# "\\%" and "\\textcelsius".
# NOTE(review): "\\micro" is not a core LaTeX command (it likely comes from a package such
# as siunitx); confirm the document preamble defines it.
table_structure = {
    "Demographics": [
        ("Age, years", "age"),
        ("Age range, years", "age_bin"),  # bins created with pd.cut above
        ("Gender", "gender"),
        ("Re-admission", "re_admission")
    ],
    "Physical exam findings": [
        ("Temperature (\\textcelsius)", "Temperature"),
        ("Weight (kg)", "Weight"),
        ("Heart rate (bpm)", "Heartrate"),
        ("Respiratory rate (breaths/min)", "Resprate"),
        ("Systolic blood pressure (mmHg)", "Systolic_BP"),
        ("Diastolic blood pressure (mmHg)", "Diastolic_BP"),
        ("Mean arterial pressure (mmHg)", "Mean_BP"),
        ("Fraction of inspired oxygen (\\%)", "FiO2"),
        ("P/F ratio", "PaO2/FiO2"),
        ("Glasgow Coma Scale", "GCS")
    ],
    "Laboratory findings": {
        "Hematology": [
            ("White blood cells (thousands/\\micro L)", "WBC"),
            ("Platelets (thousands/\\micro L)", "Platelets"),
            ("Hemoglobin (g/dL)", "Hemoglobin"),
            ("Base Excess (mmol/L)", "BaseExcess"),
        ],
        "Chemistry": [
            ("Sodium (mmol/L)", "Sodium"),
            ("Potassium (mmol/L)", "Potassium"),
            ("Chloride (mmol/L)", "Chloride"),
            ("Bicarbonate (mmol/L)", "Bicarbonate"),
            ("Calcium (mg/dL)", "Calcium"),
            ("Magnesium (mg/dL)", "Magnesium"),
            ("Blood urea nitrogen (mg/dL)", "BUN"),
            ("Creatinine (mg/dL)", "SCr"),
            ("Glucose (mg/dL)", "Glucose"),
            ("SGOT (units/L)", "SGOT"),
            ("SGPT (units/L)", "SGPT"),
            ("Lactate (mg/dL)", "Lactate"),
            ("Total bilirubin (mg/dL)", "Total_Bilirubin")
        ]
    },
    "Outcomes": [
        ("Deceased (ICU mortality)", "morta_icu"),
        ("Vasopressors administered", "vaso_binary"),  # added by preprocess_fluids_vaso
        ("Fluids administered", "fluids_binary"),      # added by preprocess_fluids_vaso
        ("Ventilator used", "MV")
    ],
    "Severity Scores": [
        ("SOFA", "SOFA"),
        ("SIRS", "SIRS"),
        ("Shock Index", "Shock_Index")
    ],
    "Coagulation": [
        ("Prothrombin time (sec)", "PT"),
        ("Partial thromboplastin time (sec)", "PTT"),
        ("INR", "INR")
    ],
    "Blood gas": [
        ("pH", "Arterial_ph"),
        ("Oxygen saturation (\\%)", "SpO2"),
        ("Partial pressure of O2 (mmHg)", "PaO2"),
        ("Partial pressure of CO2 (mmHg)", "PaCO2")
    ]
}

# Gender coding used by summarize_single_dataset (0 = Male, 1 = Female).
# NOTE(review): not referenced by the code in this cell.
gender_map = {0: "Male", 1: "Female"}

def summarize_single_dataset(df_local: pd.DataFrame, col_name: str, is_main: bool=True,
                             bin_labels=None) -> str:
    r"""Return the LaTeX cell text summarising ``col_name`` for one cohort.

    Dispatch (first match wins):
      * column missing -> "N/A"; no non-null values -> "No data"
      * ``age_bin``    -> nested tabular with one "label & count (pct)" line per label in
                          ``bin_labels`` that appears in the per-stay value counts
      * ``gender``     -> nested tabular with Male (0) / Female (1) counts per stay
      * values exactly {0, 1} -> "count (pct)" of stays where the column is ever 1
      * numeric        -> "median (Q1--Q3)" over all non-null rows
      * otherwise      -> most frequent per-stay value and its count

    Percentages use the number of distinct ``stay_id`` as denominator and are written as
    ``\%``. A "≥" in bin labels is typeset as ``$\geq$`` because pdfLaTeX may reject the raw
    Unicode character.
    NOTE(review): numeric summaries pool every row (time step), so long stays weigh more,
    even for static variables such as age; confirm this is intended.

    Parameters
    ----------
    df_local : pd.DataFrame
        Cohort rows; must contain ``stay_id``.
    col_name : str
        Column to summarise.
    is_main : bool
        Unused; kept for call-site compatibility.
    bin_labels : list of str, optional
        Bin order for ``age_bin``; defaults to the notebook-level ``age_labels``.
    """
    if col_name not in df_local.columns:
        return "N/A"
    series = df_local[col_name].dropna()
    if series.empty:
        return "No data"
    n_stay = df_local["stay_id"].nunique()

    def _nested(lines):
        # Compact left/right two-column tabular so several lines fit in one outer cell.
        return (
            r"\begin{tabular}[l]{@{}l@{\hspace{1em}}r@{}}"
            + "\n".join(lines)
            + r"\end{tabular}"
        )

    if col_name == "age_bin":
        labels = age_labels if bin_labels is None else bin_labels
        by_stay = df_local.groupby("stay_id")[col_name].first()
        counts = by_stay.value_counts(dropna=False)
        lines = []
        for label in labels:
            if label in counts.index:
                c = counts[label]
                ratio = (c / n_stay) * 100
                tex_label = str(label).replace("≥", r"$\geq$")
                lines.append(f"{tex_label} & {c} ({ratio:.1f}\\%)\\\\")
        return _nested(lines)

    if col_name == "gender":
        by_stay = df_local.groupby("stay_id")[col_name].first()
        total = len(by_stay)
        num_male = (by_stay == 0).sum()
        num_female = (by_stay == 1).sum()
        male_pct = (num_male / total) * 100 if total > 0 else 0
        fem_pct = (num_female / total) * 100 if total > 0 else 0
        return _nested([
            f"Male & {num_male} ({male_pct:.1f}\\%)\\\\",
            f"Female & {num_female} ({fem_pct:.1f}\\%)\\\\",
        ])

    unique_vals = series.unique()
    if len(unique_vals) == 2 and sorted(unique_vals) == [0, 1]:
        # Binary indicator: a stay counts once if the indicator is ever 1.
        s_by_stay = df_local.groupby("stay_id")[col_name].max()
        c1 = (s_by_stay == 1).sum()
        ratio1 = (c1 / n_stay) * 100
        return f"{c1} ({ratio1:.1f}\\%)"

    if pd.api.types.is_numeric_dtype(series):
        med = series.median()
        q1 = series.quantile(0.25)
        q3 = series.quantile(0.75)
        return f"{med:.1f} ({q1:.1f}--{q3:.1f})"

    by_stay = df_local.groupby("stay_id")[col_name].first()
    vc = by_stay.value_counts()
    top_cat = vc.index[0]
    top_cnt = vc.iloc[0]
    return f"{top_cat} ({top_cnt} freq)"

# Opening lines of the three-column summary table (variable | development cohort | temporal cohort).
latex_lines = [
    r"\begin{table}[htbp]",
    r"\centering",
    r"\caption{Dataset Summary}",
    r"\begin{tabular}{p{0.4\textwidth} p{0.3\textwidth} p{0.3\textwidth}}",
    r"\hline",
    rf" \textbf{{Variable}} & \textbf{{Train+Val+Test (n={N_main})}} & \textbf{{Temporal (n={N_temp})}} \\",
    r"\hline",
]

def add_section_header(text, indent=False):
    """Return a bold LaTeX row spanning all 3 columns, optionally indented by one \\quad."""
    cell = rf"\textbf{{{text}}}"
    if indent:
        cell = r"\quad " + cell
    return rf"\multicolumn{{3}}{{l}}{{{cell}}} \\"

def add_row(var_name, val_main, val_temp, indent=False):
    """Return one LaTeX table row 'name & main & temporal \\\\', optionally indented by \\quad."""
    cells = " & ".join(f"{value}" for value in (var_name, val_main, val_temp))
    lead = "\\quad " if indent else ""
    return lead + cells + " \\\\"

def _append_rows(varlist):
    """Append one indented table row per (display name, column): main cohort, then temporal."""
    for disp_name, col_name in varlist:
        val_main = summarize_single_dataset(df_main, col_name)
        val_temp = summarize_single_dataset(df_temp, col_name, is_main=False)
        latex_lines.append(add_row(disp_name, val_main, val_temp, indent=True))


for section_name, items in table_structure.items():
    latex_lines.append(add_section_header(section_name))
    if not isinstance(items, dict):
        _append_rows(items)
        continue
    # Nested section: one indented sub-header per sub-section, then its rows.
    for subsec_name, varlist in items.items():
        latex_lines.append(add_section_header(subsec_name, indent=True))
        _append_rows(varlist)

latex_lines += [r"\hline", r"\end{tabular}", r"\end{table}"]

latex_code = "\n".join(latex_lines)

with open("final_summary_table.tex", "w", encoding="utf-8") as f:
    f.write(latex_code)

print(latex_code)


In [None]:
df = pd.read_csv("processed/df_Dead_icu_RAW.csv")

main_id = pd.concat([train_id, valid_id, test_id], axis=0)


def _rows_for(ids):
    """Rows of ``df`` whose ``traj`` appears in ``ids`` (inner join, fresh RangeIndex)."""
    return pd.merge(df, ids[["traj"]], on="traj", how="inner").reset_index(drop=True)


df_main = _rows_for(main_id)
df_temp = _rows_for(temporal_id)

# Cohort sizes in distinct ICU stays.
N_main = df_main["stay_id"].nunique()
N_temp = df_temp["stay_id"].nunique()

def preprocess_fluids_vaso(data: pd.DataFrame):
    """Attach stay-level 0/1 indicators ``fluids_binary`` and ``vaso_binary`` to every row.

    fluids_binary: summed ``input_4hr`` over the stay is > 0.
    vaso_binary:   maximum ``median_vaso`` over the stay is > 0.
    """
    # (source column, per-stay reduction, indicator column)
    per_stay_rules = [
        ("input_4hr", "sum", "fluids_binary"),
        ("median_vaso", "max", "vaso_binary"),
    ]
    grouped = data.groupby("stay_id", as_index=False)
    result = data
    for source_col, reduction, flag_col in per_stay_rules:
        per_stay = getattr(grouped[source_col], reduction)()
        per_stay[flag_col] = (per_stay[source_col] > 0).astype(int)
        result = result.merge(per_stay[["stay_id", flag_col]], on="stay_id", how="left")
    return result

df_main, df_temp = preprocess_fluids_vaso(df_main), preprocess_fluids_vaso(df_temp)

# Half-open decade bins [lo, hi); ages below 18 are left as NaN by pd.cut.
age_bins = [18, 30, 40, 50, 60, 70, 80, 90, np.inf]
age_labels = ["18-29", "30-39", "40-49", "50-59", "60-69", "70-79", "80-89", "≥90"]
age_cut = dict(bins=age_bins, labels=age_labels, right=False)

df_main["age_bin"] = pd.cut(df_main["age"], **age_cut)
df_temp["age_bin"] = pd.cut(df_temp["age"], **age_cut)

def summarize_variable(df_m: pd.DataFrame, df_t: pd.DataFrame, col_name: str):
    """Return 'Main: <stats>\\quad Temp: <stats>' for ``col_name`` in both cohorts."""
    main_part, temp_part = (compute_stat_string(frame, col_name) for frame in (df_m, df_t))
    return f"Main: {main_part}\\quad Temp: {temp_part}"

def compute_stat_string(df_local: pd.DataFrame, col_name: str, bin_labels=None) -> str:
    r"""Return a one-line, LaTeX-safe summary of ``col_name`` for one cohort.

    Dispatch (first match wins):
      * column missing -> "N/A"; no non-null values -> "No data"
      * ``gender``  -> "Male n (p\%), Female n (p\%)" counted per stay (0 = Male, 1 = Female)
      * ``age_bin`` -> "label: n (p\%); ..." for every label in ``bin_labels`` (zero counts included)
      * values exactly {0, 1} -> "n (p\%)" of stays where the column is ever 1
      * numeric     -> "median (Q1--Q3)" over all non-null rows
      * otherwise   -> most frequent per-stay value and its count

    Percent signs are emitted as ``\%``: a bare ``%`` starts a LaTeX comment and would
    swallow the rest of the output line, including the trailing row break. A "≥" in bin
    labels is typeset as ``$\geq$`` because pdfLaTeX may reject the raw Unicode character.

    Parameters
    ----------
    df_local : pd.DataFrame
        Cohort rows; must contain ``stay_id``.
    col_name : str
        Column to summarise.
    bin_labels : list of str, optional
        Bin order for ``age_bin``; defaults to the notebook-level ``age_labels``.
    """
    if col_name not in df_local.columns:
        return "N/A"

    s = df_local[col_name].dropna()
    if s.empty:
        return "No data"

    if col_name == "gender":
        by_stay = df_local.groupby("stay_id")[col_name].first()
        total = len(by_stay)
        nm = (by_stay == 0).sum()
        nf = (by_stay == 1).sum()
        pm = (nm / total) * 100 if total > 0 else 0
        pf = (nf / total) * 100 if total > 0 else 0
        return f"Male {nm} ({pm:.1f}\\%), Female {nf} ({pf:.1f}\\%)"

    if col_name == "age_bin":
        labels = age_labels if bin_labels is None else bin_labels
        by_stay = df_local.groupby("stay_id")[col_name].first()
        counts = by_stay.value_counts()
        total = len(by_stay)
        parts = []
        for lab in labels:
            c = counts.get(lab, 0)
            ratio = (c / total) * 100 if total > 0 else 0
            tex_lab = str(lab).replace("≥", r"$\geq$")
            parts.append(f"{tex_lab}: {c} ({ratio:.1f}\\%)")
        return "; ".join(parts)

    # Only needed from here on; computed late so unorderable categorical values
    # cannot break the gender / age_bin branches above.
    unique_vals = sorted(s.unique())
    if len(unique_vals) == 2 and unique_vals == [0, 1]:
        # Binary indicator: a stay counts once if the indicator is ever 1.
        st = df_local.groupby("stay_id")[col_name].max()
        c1 = (st == 1).sum()
        ratio1 = (c1 / len(st)) * 100 if len(st) > 0 else 0
        return f"{c1} ({ratio1:.1f}\\%)"

    if pd.api.types.is_numeric_dtype(s):
        med = s.median()
        q1 = s.quantile(0.25)
        q3 = s.quantile(0.75)
        return f"{med:.1f} ({q1:.1f}--{q3:.1f})"

    by_stay = df_local.groupby("stay_id")[col_name].first()
    vc = by_stay.value_counts()
    top_cat = vc.index[0]
    top_cnt = vc.iloc[0]
    return f"{top_cat} ({top_cnt} freq)"

# Opening lines of the single-column "structure" table and the cohort-size banner.
latex_lines = [
    r"\begin{table}[htbp]",
    r"\centering",
    r"\begin{tabular}{l}",
    r"\hline",
    r"\\[-1em]",
    rf"Train+Val+Test (n={N_main}), Temporal (n={N_temp})\\",
    r"\\[-1em]\hline\\[-0.5em]",
]

latex_lines.append(r"\textbf{Demographics}\\")
latex_lines.append(r"\quad Age, years\\")
stat_age_years = summarize_variable(df_main, df_temp, "age")
latex_lines.append(r"\quad\quad " + stat_age_years + r"\\")

latex_lines.append(r"\quad Age range, years\\")
# Stub rows for the age bins, generated from `age_labels` so they stay in sync with the
# bins used by pd.cut. A raw "≥" may fail under pdfLaTeX, so it is typeset as $\geq$.
for age_label in age_labels:
    latex_lines.append(r"\quad\quad " + age_label.replace("≥", r"$\geq$") + r"\\")

# NOTE(review): the per-bin counts are not placed on the stub rows above; they all appear
# on this single combined line.
stat_age_bin = summarize_variable(df_main, df_temp, "age_bin")
latex_lines.append(r"\quad\quad " + stat_age_bin + r"\\")

latex_lines.append(r"\quad Gender\\")
latex_lines.append(r"\quad\quad Male\\")
latex_lines.append(r"\quad\quad Female\\")

stat_gender = summarize_variable(df_main, df_temp, "gender")
latex_lines.append(r"\quad\quad " + stat_gender + r"\\")

latex_lines.append(r"\quad Re-admission\\")
stat_readm = summarize_variable(df_main, df_temp, "re_admission")
latex_lines.append(r"\quad\quad " + stat_readm + r"\\")

SECTION_GAP = r"\\[-0.5em]"

pe_vars = [
    ("Temperature (\\textcelsius)", "Temperature"),
    ("Weight (kg)", "Weight"),
    ("Heart rate (bpm)", "Heartrate"),
    ("Respiratory rate (breaths/min)", "Resprate"),
    ("Systolic blood pressure (mmHg)", "Systolic_BP"),
    ("Diastolic blood pressure (mmHg)", "Diastolic_BP"),
    ("Mean arterial pressure (mmHg)", "Mean_BP"),
    ("Fraction of inspired oxygen (\\%)", "FiO2"),
    ("P/F ratio", "PaO2/FiO2"),
    ("Glasgow Coma Scale", "GCS")
]
hema_vars = [
    ("White blood cells (thousands/\\micro L)", "WBC"),
    ("Platelets (thousands/\\micro L)", "Platelets"),
    ("Hemoglobin (g/dL)", "Hemoglobin"),
    ("Base Excess (mmol/L)", "BaseExcess")
]
chem_vars = [
    ("Sodium (mmol/L)", "Sodium"),
    ("Potassium (mmol/L)", "Potassium"),
    ("Chloride (mmol/L)", "Chloride"),
    ("Bicarbonate (mmol/L)", "Bicarbonate"),
    ("Calcium (mg/dL)", "Calcium"),
    ("Magnesium (mg/dL)", "Magnesium"),
    ("Blood urea nitrogen (mg/dL)", "BUN"),
    ("Creatinine (mg/dL)", "SCr"),
    ("Glucose (mg/dL)", "Glucose"),
    ("SGOT (units/L)", "SGOT"),
    ("SGPT (units/L)", "SGPT"),
    ("Lactate (mg/dL)", "Lactate"),
    ("Total bilirubin (mg/dL)", "Total_Bilirubin")
]
out_vars = [
    ("Deceased (ICU mortality)", "morta_icu"),
    ("Vasopressors administered", "vaso_binary"),
    ("Fluids administered", "fluids_binary"),
    ("Ventilator used", "MV")
]
sev_vars = [
    ("SOFA", "SOFA"),
    ("SIRS", "SIRS"),
    ("Shock Index", "Shock_Index")
]
coag_vars = [
    ("Prothrombin time (sec)", "PT"),
    ("Partial thromboplastin time (sec)", "PTT"),
    ("INR", "INR")
]
bg_vars = [
    ("pH", "Arterial_ph"),
    ("Oxygen saturation (\\%)", "SpO2"),
    ("Partial pressure of O2 (mmHg)", "PaO2"),
    ("Partial pressure of CO2 (mmHg)", "PaCO2")
]

# (heading lines, variables, depth): a variable label is indented `depth` \quads and its
# statistics line one \quad further. Every section is followed by a small vertical gap.
table_sections = [
    ([r"\textbf{Physical exam findings}\\"], pe_vars, 1),
    ([r"\textbf{Laboratory findings}\\", r"\quad \textbf{Hematology}\\"], hema_vars, 2),
    ([r"\quad \textbf{Chemistry}\\"], chem_vars, 2),
    ([r"\textbf{Outcomes}\\"], out_vars, 1),
    ([r"\textbf{Severity Scores}\\"], sev_vars, 1),
    ([r"\textbf{Coagulation}\\"], coag_vars, 1),
    ([r"\textbf{Blood gas}\\"], bg_vars, 1),
]

latex_lines.append(SECTION_GAP)
for heading_lines, section_vars, depth in table_sections:
    latex_lines.extend(heading_lines)
    label_indent = "\\quad" * depth + " "
    value_indent = "\\quad" * (depth + 1) + " "
    for disp, col in section_vars:
        latex_lines.append(label_indent + disp + r"\\")
        stat_val = summarize_variable(df_main, df_temp, col)
        latex_lines.append(value_indent + stat_val + r"\\")
    latex_lines.append(SECTION_GAP)

latex_lines.extend([r"\hline", r"\end{tabular}", r"\end{table}"])

latex_code = "\n".join(latex_lines)

# Persist the table next to the notebook and echo it for quick copy/paste.
with open("final_structure.tex", "w", encoding="utf-8") as f:
    f.write(latex_code)

print(latex_code)