In [None]:
import sys
import numpy as np
import pandas as pd
from PredictionPipelineV3 import *

In [None]:
# Read the validation cohort and, in the same expression, rename the ICU-day
# column to the spaced spelling ('days since icu').
merged_df_val = (
    pd.read_csv("/manitou/pmg/users/mc5672/post_processing_data/merged_df_val_clr_for_testing.csv")
    .rename(columns={'days_since_icu': 'days since icu'})
)
# Last expression of the cell: preview the first rows.
merged_df_val.head()

In [None]:
def include_upto_infection_val(group):
    """Keep only the samples taken on or before the group's first infection date.

    Parameters
    ----------
    group : pandas.DataFrame
        Rows for one group (the caller groups by ``'id'``). Must contain the
        columns ``'date_of_sample'`` and ``'infectiondate1'``.

    Returns
    -------
    pandas.DataFrame
        A new frame with ``'date_of_sample'`` converted to datetime. If the
        group has at least one non-null ``'infectiondate1'``, only rows with
        ``date_of_sample <= earliest infectiondate1`` are returned; otherwise
        every row is returned. The input frame is never modified.
    """
    # Build a new frame instead of assigning into `group`: writing into the
    # object handed over by groupby.apply is discouraged by pandas and would
    # also alter the caller's frame when this function is called directly.
    group = group.assign(date_of_sample=pd.to_datetime(group['date_of_sample']))

    # Convert before taking the minimum so the comparison is chronological.
    # A minimum over raw strings is lexicographic ('01/15/2020' < '12/01/2019')
    # and can fail when NaN and strings are mixed in the same group.
    first_infection_date = pd.to_datetime(group['infectiondate1']).min()

    if pd.isnull(first_infection_date):
        return group  # keep all if no infection recorded
    return group[group['date_of_sample'] <= first_infection_date]

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import metrics
import seaborn as sns

# -------- Parameters --------
# Directory the model files are loaded from, and directory the ROC figures
# are written to (created if missing).
base_path = "/burg/pmg/users/mc5672/preds_new/preds_new"
save_path = "/burg/pmg/users/mc5672/preds_new/images"
os.makedirs(save_path, exist_ok=True)

# Map for prediction targets (some filenames differ from the column name):
# key = outcome tag used in the model filename, value = label column read
# from the validation frame.
target_map = {
    "infection_0": "infection",
    "infection_any": "infection",
    "infection_next7": "infection_next7",
    "infection_next10": "infection_next10",
    "death_0": "death",
    "death_any": "death",
    "death_next7": "death_next7",
    "death_next10": "death_next10",
}

outcomes = list(target_map)


# -------- Helpers --------
def _plot_title(outcome, actual_target):
    """Return the figure title for one outcome tag."""
    if outcome.endswith("_0"):
        return f"{actual_target.title()} Day 0 Prediction"
    title_target = actual_target.replace("_", " ").title()
    return f"{title_target} Prediction"


def _validation_cohort(outcome):
    """Return the rows of ``merged_df_val`` used to score ``outcome``.

    The result depends only on ``outcome``, so it is computed once per outcome
    and shared by every model instead of being rebuilt inside the model loop.
    """
    cohort = merged_df_val.copy()
    if outcome in ['infection_next7', 'infection_next10']:
        # Drop each patient's samples taken after their first infection date.
        cohort = cohort.groupby('id', group_keys=False).apply(include_upto_infection_val)
    elif outcome in ['infection_0', 'death_0']:
        cohort = cohort[cohort['days since icu'] <= 1]  # NB: Setting days <= 1 instead of 0
    return cohort


def _score_original(pred):
    """Return y_test / y_pred from the prediction object's nested evaluation."""
    nested_eval = pred.run_nested_evaluation(metric='auc', plot=False)
    return pd.DataFrame({'y_test': nested_eval['y_test'], 'y_pred': nested_eval['y_pred']})


def _score_validation(pred, cohort, target_col):
    """Refit the selected pipeline on the original data and score ``cohort``."""
    val_X = cohort[pred.X.columns]
    # The pipeline is taken from the first row of the nested-evaluation table.
    top_iteration = pd.DataFrame(pred.nested_evals).iloc[0]['top_iteration']
    best_model = pred.pipes[top_iteration]
    best_model.fit(pred.X, pred.y)
    val_y_pred = best_model.predict(val_X)
    # NOTE(review): column 1 of the prediction frame is presumably the
    # positive-class score -- confirm against PredictionPipelineV3.
    return pd.DataFrame({
        # Align labels on the prediction index.
        'y_test': cohort.loc[val_y_pred.index, target_col].values,
        'y_pred': val_y_pred[1].values,
    })


# -------- Evaluation and Plotting --------
# One figure per outcome: left panel = original cohort, right panel =
# validation cohort, one ROC curve per feature set.
for outcome in outcomes:
    actual_target = target_map[outcome]
    title = _plot_title(outcome, actual_target)

    model_info = {
        "SOFA":        f"{outcome}_sofa_nested_logit_logistic_optimized_5_10_3_1476",
        "SOFA+ASV":    f"{outcome}_sofaasv_nested_logit_logistic_optimized_5_10_3_1476",
        "ASV":         f"{outcome}_asv_nested_logit_logistic_optimized_5_10_3_1476",
    }

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    cohorts = ["Original", "Validation"]

    for ax, is_orig, cohort_label in zip(axes, [True, False], cohorts):
        # Loop-invariant across models: build the filtered validation frame once.
        val_cohort = None if is_orig else _validation_cohort(outcome)

        for label, path_stub in model_info.items():
            # Loaded separately for each cohort so validation scoring starts
            # from a freshly loaded object.
            # NOTE(review): a single load per model would be cheaper if
            # run_nested_evaluation does not modify `pred` -- confirm.
            pred = load_pred(os.path.join(base_path, f"{path_stub}.pkl"))

            if is_orig:
                res = _score_original(pred)
            else:
                res = _score_validation(pred, val_cohort, actual_target)

            fpr, tpr, _ = metrics.roc_curve(res['y_test'], res['y_pred'])
            auc = metrics.roc_auc_score(res['y_test'], res['y_pred'])
            ax.plot(fpr, tpr, label=f"{label} (AUROC={auc:.2f})")

        ax.plot([0, 1], [0, 1], 'k--')  # chance diagonal
        ax.set_xlim(-0.01, 1.01)
        ax.set_ylim(-0.01, 1.01)
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title(f"{cohort_label} Cohort", fontsize=12, fontweight='bold')
        ax.legend(loc='lower right', fontsize=9)
        ax.tick_params(labelsize=11)

    fig.suptitle(title, fontsize=14, fontweight='bold')
    # Use the explicit figure handle rather than pyplot's "current figure", so
    # the right figure is laid out, saved and released even if a project call
    # above touched pyplot state.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.savefig(os.path.join(save_path, f"{outcome}_roc.png"), dpi=300)
    plt.close(fig)