In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import ast
import matplotlib.pyplot as plt
import dill
import torch
import nbimporter
import shap

import os
import sys

# NOTE(review): hard-coded absolute path — the notebook only runs on a machine
# that has the repository checked out at exactly this location.
os.chdir('/data/repos/actin-personalization/prediction')
# Put the project's Python sources on the import path for the imports below.
sys.path.insert(0, os.path.abspath("src/main/python"))

# NOTE(review): the star import hides where names such as ExperimentConfig and
# ModelTrainer (used further down) come from — presumably this module; confirm.
from models import *
from data.data_processing import DataSplitter, DataPreprocessor
from data.lookups import lookup_manager
from utils.settings import settings
from src.main.python.analysis.predictive_algorithms_training import get_data, plot_different_models_survival_curves

# Shared preprocessor instance, configured from the project settings.
preprocessor = DataPreprocessor(settings.db_config_path, settings.db_name)

In [None]:
# Unpack the six objects returned by get_data (imported from predictive_algorithms_training).
df, X_train, X_test, y_train, y_test, encoded_columns = get_data()

In [None]:
def get_preprocessed_data_with_sourceId(preprocessor):
    """Load the raw data, preprocess it and re-attach the ``sourceId`` column.

    Parameters:
    - preprocessor: object providing ``load_data()`` and ``preprocess_data(features, df=...)``.

    Returns:
    - tuple ``(raw_frame, processed_frame, updated_features)`` where
      ``processed_frame`` carries an extra ``sourceId`` column.
    """
    raw_frame = preprocessor.load_data()
    processed_frame, feature_list, _ignored = preprocessor.preprocess_data(
        lookup_manager.features, df=raw_frame
    )
    # Look the ids up by the processed frame's index labels so they stay row-aligned.
    source_ids = raw_frame.loc[processed_frame.index, "sourceId"]
    processed_frame["sourceId"] = source_ids
    return raw_frame, processed_frame, feature_list

# df_all is the preprocessed frame including sourceId; it is the input for the analyses below.
df_raw, df_all, updated_features = get_preprocessed_data_with_sourceId(preprocessor)

## Load trained models

In [None]:
def load_model_outcomes():
    """Read the saved model-outcome table for the configured outcome.

    The file is ``<settings.save_path>/<settings.outcome>_model_outcomes.csv``.

    Returns:
    - DataFrame read from that CSV.

    Raises:
    - FileNotFoundError: when the CSV does not exist.
    """
    csv_file = os.path.join(f"{settings.save_path}", f"{settings.outcome}_model_outcomes.csv")

    # Fail fast when nothing has been saved yet.
    if not os.path.exists(csv_file):
        raise FileNotFoundError(f"No saved outcomes found for {settings.outcome} in {settings.save_path}")

    outcomes = pd.read_csv(csv_file)
    print(f"Loaded model outcomes from {csv_file}")
    return outcomes

In [None]:
def load_trained_model(model_name, model_class, model_kwargs=None):
    """Load one trained model from ``settings.save_path``.

    Models named CoxPH, RandomSurvivalForest, GradientBoosting or AalenAdditive
    are unpickled from ``<outcome>_<model_name>.pkl``. Every other model is
    instantiated from ``model_class(**model_kwargs)`` and its network weights
    are restored from ``<outcome>_<model_name>.pt``.

    Parameters:
    - model_name (str): name used in the saved file name.
    - model_class: callable building a fresh model (only used for the ``.pt`` branch).
    - model_kwargs (dict or None): keyword arguments for ``model_class``;
      ``None`` means no arguments.

    Returns:
    - the loaded model object.
    """
    # A ``{}`` default would be one dict shared by every call; use None instead.
    if model_kwargs is None:
        model_kwargs = {}

    model_file_prefix = os.path.join(settings.save_path, f"{settings.outcome}_{model_name}")
    nn_file = model_file_prefix + ".pt"
    sk_file = model_file_prefix + ".pkl"

    if model_name in ['CoxPH', 'RandomSurvivalForest', 'GradientBoosting', 'AalenAdditive']:
        # NOTE: unpickling executes arbitrary code — only load files you trust.
        with open(sk_file, "rb") as f:
            model = dill.load(f)
        print(f"Model {model_name} loaded from {sk_file}")
        return model

    model = model_class(**model_kwargs)

    # NOTE(review): the checkpoint may hold non-tensor objects (e.g. 'labtrans');
    # newer PyTorch releases default to weights_only=True — confirm this still loads.
    state = torch.load(nn_file, map_location=torch.device('cpu'))

    model.model.net.load_state_dict(state['net_state'])

    if 'labtrans' in state:
        # Restore the label transform and use its cuts as the duration index.
        model.labtrans             = state['labtrans']
        model.model.duration_index = model.labtrans.cuts

    if 'baseline_hazards' in state:
        model.model.baseline_hazards_ = state['baseline_hazards']
        model.model.baseline_cumulative_hazards_ = state['baseline_cumulative_hazards']

        print(f"Baseline hazards loaded for {model_name}.")

    # Inference only: switch the network to evaluation mode.
    model.model.net.eval()
    print(f"Model {model_name} loaded from {nn_file}")

    return model
    
def load_all_trained_models(X_train):
    """Load every model listed in the experiment configuration.

    Parameters:
    - X_train (DataFrame): training features; its column names are passed to
      ``ModelTrainer._set_attention_indices`` for each loaded model.

    Returns:
    - dict mapping model name to loaded model. Loading is best-effort: a model
      that fails is reported together with the reason and then skipped.
    """
    loaded_models = {}
    config_mgr = ExperimentConfig(settings.json_config_file)
    loaded_configs = config_mgr.load_model_configs()
    feature_names = list(X_train.columns)

    for model_name, (model_class, model_kwargs) in loaded_configs.items():
        print(model_name, model_class)
        try:
            loaded_model = load_trained_model(
                model_name=model_name,
                model_class=model_class,
                model_kwargs=model_kwargs
            )
            # Register before setting attention indices, so a failure in that
            # step still leaves the loaded model available (as before).
            loaded_models[model_name] = loaded_model

            ModelTrainer._set_attention_indices(loaded_models[model_name], feature_names)
        except Exception as exc:
            # Catch Exception rather than everything, so Ctrl-C still interrupts,
            # and show why the model failed instead of hiding the error.
            print(f'Could not load: {model_name} ({type(exc).__name__}: {exc})')
            continue
    return loaded_models

In [None]:
# Raises FileNotFoundError when no outcomes CSV has been saved for the configured outcome.
model_outcomes = load_model_outcomes()

In [None]:
# Dict of model name -> loaded model; models that fail to load are reported and skipped.
trained_models = load_all_trained_models(X_train)

## Model Output

In [None]:
import json
# Treatment-combination definitions; the path is relative to the current working directory.
with open('src/main/python/data/treatment_combinations.json', 'r') as f:
    valid_treatment_combinations = json.load(f)

## Analysis

In [None]:
def plot_threshold_analysis_by_treatment(
    model,
    df_all,
    treatment_map: dict = None,
    treatment_prefix: str = "systemicTreatmentPlan",
    horizon_days: int = 365,
    model_name: str = "Model",
    n_thresholds: int = 100,
    palette=plt.cm.tab20,
):
    """Plot actual survival against a sliding cut-off on predicted survival probability.

    For each treatment, predictions at ``horizon_days`` are obtained from
    ``get_preds_for_all_treatments``. For every threshold in [0, 1] the patients
    whose predicted probability lies below the threshold are selected; the top
    panel shows median, inter-quartile band and whiskers (1.5 * IQR, clipped to
    the observed range) of their actual survival time, the bottom panel shows
    how many patients fall below the cut-off.

    Parameters:
    - model: object accepted by ``get_preds_for_all_treatments``.
    - df_all (DataFrame): preprocessed data.
    - treatment_map (dict or None): ``{label: {column: value}}``. When None, a
      "No Treatment" entry plus one entry per column starting with
      ``treatment_prefix`` is generated.
    - treatment_prefix (str): prefix identifying treatment columns.
    - horizon_days (int): prediction horizon in days.
    - model_name (str): name shown in the figure title.
    - n_thresholds (int): number of evenly spaced thresholds.
    - palette: matplotlib colormap used to colour the treatments.

    Raises:
    - ValueError: when ``treatment_map`` is None and no column has the prefix.
    """
    if treatment_map is None:
        cols = [c for c in df_all.columns if c.startswith(treatment_prefix)]
        if not cols:
            raise ValueError(
                f"No columns found with prefix '{treatment_prefix}'. "
            )
        treatment_map = {"No Treatment": {}}
        treatment_map.update({
            col[len(treatment_prefix):] or col: {col: 1}
            for col in cols
        })

    preds_by_tx = get_preds_for_all_treatments(
        model, df_all,
        treatment_map=treatment_map,
        treatment_prefix=treatment_prefix,
        horizon_days=horizon_days,
    )

    thresholds = np.linspace(0.0, 1.0, n_thresholds)
    fig, (ax_top, ax_bot) = plt.subplots(
        2, 1, figsize=(11, 8), sharex=True,
        gridspec_kw={"height_ratios": [3, 1.2]}
    )

    colors = palette(np.linspace(0, 1, len(preds_by_tx)))
    for (tx_label, df_preds), col in zip(preds_by_tx.items(), colors):
        # The prediction helper yields None when no patient has a known outcome
        # at the horizon; skip such treatments instead of crashing on indexing.
        if df_preds is None:
            print(f"Skipping '{tx_label}': no patients with a known outcome at {horizon_days} d")
            continue

        probs = df_preds["predicted_prob_1yr"].values
        all_times = df_preds["actual_survival_time"].values
        med, q1, q3, lo, hi, n = [], [], [], [], [], []

        for t in thresholds:
            times = all_times[probs < t]
            n.append(len(times))
            if len(times) == 0:
                med.append(np.nan); q1.append(np.nan); q3.append(np.nan)
                lo.append(np.nan);  hi.append(np.nan)
                continue
            q1_t, q3_t = np.percentile(times, [25, 75])
            iqr = q3_t - q1_t
            med.append(np.median(times))
            q1.append(q1_t); q3.append(q3_t)
            # Whiskers: 1.5 * IQR beyond the quartiles, clipped to the observed range.
            lo.append(max(times.min(), q1_t - 1.5 * iqr))
            hi.append(min(times.max(), q3_t + 1.5 * iqr))

        ax_top.fill_between(thresholds, q1, q3, color=col, alpha=0.15)
        ax_top.plot(thresholds, med, color=col, lw=2, label=f"{tx_label} – median")
        ax_top.plot(thresholds, lo,  color=col, ls=":")
        ax_top.plot(thresholds, hi,  color=col, ls=":")
        ax_bot.step(thresholds, n, where="post", color=col, lw=1.5, label=tx_label)

    horizon_label = (
        f"{horizon_days//365}-year ({horizon_days} d)"
        if horizon_days % 365 == 0 else f"{horizon_days} days"
    )
    ax_top.axhline(horizon_days, color="black", ls="--", label=horizon_label)
    ax_top.set_ylabel("Actual survival time (days)")
    ax_top.grid(alpha=0.3); ax_top.set_ylim(bottom=0)
    ax_top.legend(fontsize=8)

    # Keep the original wording for the default horizon; otherwise name the real horizon.
    if horizon_days == 365:
        ax_bot.set_xlabel("Predicted 1-year survival probability")
    else:
        ax_bot.set_xlabel(f"Predicted survival probability at {horizon_days} d")
    ax_bot.set_ylabel("Patients\nbelow cut-off")
    ax_bot.grid(alpha=0.3); ax_bot.set_ylim(bottom=0)
    ax_bot.legend(fontsize=8, ncol=2)

    plt.suptitle(
        f"{model_name}: Actual survival vs survival probability "
        f"(survival predicted at {horizon_days} d)",
        y=0.97
    )
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

In [None]:
def get_preds_for_all_treatments(
    model,
    df_all,
    treatment_map: dict,
    treatment_prefix: str = "systemicTreatmentPlan",
    horizon_days: int = 365,
    background_size: int = 150
):
    """Predict survival at ``horizon_days`` for every patient under each treatment.

    For each ``label -> {column: value}`` entry in ``treatment_map`` the
    treatment columns of all patients are overwritten via ``apply_treatment``
    and the predictions are collected with ``get_patient_predictions``.

    Parameters:
    - model: object passed through to ``get_patient_predictions``.
    - df_all (DataFrame): must contain ``sourceId`` plus the duration and event
      columns named in ``settings``.
    - treatment_map (dict): ``{label: {column: value}}``.
    - treatment_prefix (str): prefix identifying treatment columns.
    - horizon_days (int): prediction horizon in days.
    - background_size (int): unused; kept so existing callers keep working.

    Returns:
    - dict ``{label: DataFrame or None}``; an entry is None when
      ``get_patient_predictions`` finds no usable patient.
    """
    durations = df_all[settings.duration_col].values
    events    = df_all[settings.event_col].astype(bool).values
    X_base    = df_all.drop(columns=["sourceId",
                                     settings.event_col,
                                     settings.duration_col])

    treatment_cols = [c for c in X_base if c.startswith(treatment_prefix)]

    # The MSI flag does not depend on the treatment, so build it once. The
    # fallback is aligned to the frame's index rather than a one-element Series.
    if "hasMsi" in df_all:
        msi_flag = df_all["hasMsi"].astype(bool)
    else:
        msi_flag = pd.Series(False, index=df_all.index)

    preds_by_tx = {}

    for label, mapping in treatment_map.items():
        X_mod = apply_treatment(X_base, mapping, treatment_cols, msi_flag=msi_flag)
        preds_by_tx[label] = get_patient_predictions(
            model,
            X_mod, durations, events,
            horizon_days=horizon_days
        )

    return preds_by_tx


In [None]:
def apply_treatment(df, mapping, treatment_cols, msi_flag):
    """Return a copy of ``df`` with the treatment columns set to one scenario.

    Parameters:
    - df (DataFrame): feature frame; it is not modified.
    - mapping (dict): ``{column: value}`` applied after all treatment columns
      are reset to 0; columns absent from the frame are ignored.
    - treatment_cols (list): names of the treatment indicator columns.
    - msi_flag: value(s) written to ``hasMsi`` when that column exists.

    Returns:
    - DataFrame copy with an added/overwritten ``hasTreatment`` column
      (1 when any treatment column sums to more than 0, else 0).
    """
    scenario = df.copy()
    # Start from "no treatment", then switch on what the mapping asks for.
    scenario[treatment_cols] = 0
    for column in mapping:
        if column in scenario:
            scenario[column] = mapping[column]
    if "hasMsi" in scenario:
        scenario["hasMsi"] = msi_flag

    any_treatment = scenario[treatment_cols].sum(axis=1) > 0
    scenario["hasTreatment"] = any_treatment.astype(int)
    return scenario

In [None]:
def get_patient_predictions(model, X, durations, events, horizon_days=365):
    """Predict survival probability at ``horizon_days`` for patients with a known outcome.

    Patients whose duration is at most ``horizon_days`` without an observed
    event are left out, because their status at the horizon is unknown.

    Parameters:
    - model: object with ``predict_survival_function(X)`` returning one callable per row.
    - X (DataFrame): feature rows, positionally aligned with ``durations`` and ``events``.
    - durations (ndarray): follow-up times.
    - events (ndarray of bool): True where the event was observed.
    - horizon_days (int): time at which each survival function is evaluated.

    Returns:
    - DataFrame with columns ``predicted_prob_1yr``, ``actual_survival_time``
      and ``event_observed``, or None when no patient remains.
    """
    censored_before_horizon = (durations <= horizon_days) & (~events)
    keep = ~censored_before_horizon

    X_kept = X.loc[keep]
    if len(X_kept) == 0:
        return None

    survival_curves = model.predict_survival_function(X_kept)
    probabilities = np.array([curve(horizon_days) for curve in survival_curves])

    return pd.DataFrame({
        "predicted_prob_1yr": probabilities,
        "actual_survival_time": durations[keep],
        "event_observed": events[keep]
    })

In [None]:
# Single-entry treatment map: only the 5-FU scenario is evaluated below.
five_fu_map = {"5-FU": valid_treatment_combinations["5-FU"]}

# Note the trailing underscore: the functions' own default prefix has none.
treat_prefix = "systemicTreatmentPlan_"
cols_treat   = [c for c in df_all.columns if c.startswith(treat_prefix)]

col_5fu      = "systemicTreatmentPlan_5-FU"
other_cols   = [c for c in cols_treat if c != col_5fu]

# Select rows where the 5-FU column is 1 and every other treatment column is 0.
mask_5fu = (
    (df_all[col_5fu] == 1) &
    (df_all[other_cols].sum(axis=1) == 0)
)

# Copy so later changes to the cohort cannot write back into df_all.
df_5fu = df_all[mask_5fu].copy()
print(f"{len(df_5fu)} patients with 5-FU")


# df_5fu, five_fu_map and treat_prefix are reused by later cells.
plot_threshold_analysis_by_treatment(
    model            = trained_models["DeepSurv_attention"],
    df_all           = df_5fu,
    treatment_map    = five_fu_map,
    treatment_prefix = treat_prefix,
    horizon_days     = 365,
    model_name       = "DeepSurv + Attention"
)


In [None]:
# Display labels for model feature columns: maps the column name to a short
# human-readable label. create_shap_image falls back to the column name itself
# for any feature that is missing from this mapping.
feature_short_names = {
        "ageAtMetastasisDetection": "Age (metastasis)",
        "albumine": "Albumin",
        "alkalinePhosphatase": "Alk. phosphatase",
        "anorectalVergeDistanceCategory": "Tumor–ARV distance",
        "asaClassificationPreSurgeryOrEndoscopy": "ASA class",
        "carcinoEmbryonicAntigen": "CEA",
        "cci": "CCI (score)",
        "cciHasAids": "CCI: AIDS",
        "cciHasCerebrovascularDisease": "CCI: Cerebrovascular disease",
        "cciHasCollagenosis": "CCI: Collagenosis",
        "cciHasCongestiveHeartFailure": "CCI: Heart failure",
        "cciHasCopd": "CCI: COPD",
        "cciHasDementia": "CCI: Dementia",
        "cciHasDiabetesMellitus": "CCI: Diabetes",
        "cciHasDiabetesMellitusWithEndOrganDamage": "CCI: Diabetes w/ EOD",
        "cciHasHemiplegiaOrParaplegia": "CCI: Hemiplegia",
        "cciHasLiverDisease": "CCI: Liver disease",
        "cciHasMildLiverDisease": "CCI: Mild liver disease",
        "cciHasMyocardialInfarct": "CCI: MI",
        "cciHasOtherMalignancy": "CCI: Other malignancy",
        "cciHasOtherMetastaticSolidTumor": "CCI: Other metastasis",
        "cciHasPeripheralVascularDisease": "CCI: PVD",
        "cciHasRenalDisease": "CCI: Renal disease",
        "cciHasUlcerDisease": "CCI: Ulcer disease",
        "cciNumberOfCategories": "CCI: # categories",
        "distanceToMesorectalFasciaMm": "MRF distance (mm)",
        "hasBrafMutation": "BRAF mut.",
        "hasBrafV600EMutation": "BRAF V600E",
        "hasDoublePrimaryTumor": "Double primary",
        "hasHadPriorTumor": "Prior tumor",
        "hasKrasG12CMutation": "KRAS G12C",
        "hasMsi": "MSI",
        "hasRasMutation": "RAS mut.",
        "investigatedLymphNodesNumber": "# nodes (investigated)",
        "lactateDehydrogenase": "LDH",
        "leukocytesAbsolute": "Leukocytes",
        "maximumSizeOfLiverMetastasisMm": "Max liver met (mm)",
        "mesorectalFasciaIsClear": "MRF clear",
        "neutrophilsAbsolute": "Neutrophils",
        "numberOfLiverMetastases": "# liver mets",
        "positiveLymphNodesNumber": "# positive nodes",
        "presentedWithIleus": "Ileus at presentation",
        "presentedWithPerforation": "Perforation",
        "sex": "Sex",
        "sidedness": "Tumor sidedness",
        "stageCTNM": "cTNM",
        "stagePTNM": "pTNM",
        "stageTNM": "TNM",
        "tumorDifferentiationGrade": "Grade",
        "tumorIncidenceYear": "Incidence year",
        "whoStatusPreTreatmentStart": "WHO status",
        "observedOsFromMetastasisDetectionDays": "OS (days)",

        # Systemic treatment flags
        # NOTE(review): the key below is spelled "panitumab" while its label is
        # "Panitumumab" — confirm the key matches the real column name.
        "systemicTreatmentPlan_5-FU": "5-FU",
        "systemicTreatmentPlan_oxaliplatin": "Oxaliplatin",
        "systemicTreatmentPlan_irinotecan": "Irinotecan",
        "systemicTreatmentPlan_bevacizumab": "Bevacizumab",
        "systemicTreatmentPlan_panitumab": "Panitumumab",
        "systemicTreatmentPlan_pembrolizumab": "Pembrolizumab",
        "systemicTreatmentPlan_nivolumab": "Nivolumab",

        # Metastasis locations
        "metastasisLocationGroupsPriorToSystemicTreatment_BRAIN": "Met: Brain",
        "metastasisLocationGroupsPriorToSystemicTreatment_BRONCHUS_AND_LUNG": "Met: Lung",
        "metastasisLocationGroupsPriorToSystemicTreatment_LIVER_AND_INTRAHEPATIC_BILE_DUCTS": "Met: Liver",
        "metastasisLocationGroupsPriorToSystemicTreatment_LYMPH_NODES": "Met: LN",
        "metastasisLocationGroupsPriorToSystemicTreatment_OTHER": "Met: Other",
        "metastasisLocationGroupsPriorToSystemicTreatment_PERITONEUM": "Met: Peritoneum",

        # Tumor types
        "consolidatedTumorType_CRC_ADENOCARCINOMA": "CRC: Adeno",
        "consolidatedTumorType_CRC_MUCINOUS": "CRC: Mucinous",
        "consolidatedTumorType_CRC_OTHER": "CRC: Other",
        "consolidatedTumorType_CRC_SIGNET_RING_CELL": "CRC: Signet ring",

        # Extra mural invasion
        "extraMuralInvasionCategory_ABOVE_FIVE_MM": "EMI >5mm",
        "extraMuralInvasionCategory_LESS_THAN_FIVE_MM": "EMI <5mm",
        "extraMuralInvasionCategory_NA": "EMI: NA",

        # Basis of diagnosis
        "tumorBasisOfDiagnosis_CLINICAL_AND_DIAGNOSTIC_INVESTIGATION": "Dx: Clinical + diag",
        "tumorBasisOfDiagnosis_CLINICAL_ONLY_INVESTIGATION": "Dx: Clinical only",
        "tumorBasisOfDiagnosis_CYTOLOGICAL_CONFIRMATION": "Dx: Cytology",
        "tumorBasisOfDiagnosis_HISTOLOGICAL_CONFIRMATION": "Dx: Histology",
        "tumorBasisOfDiagnosis_HISTOLOGICAL_CONFIRMATION_METASTASES": "Dx: Histology (met)",
        "tumorBasisOfDiagnosis_SPEC_BIOCHEMICAL_IMMUNOLOGICAL_LAB_INVESTIGATION": "Dx: Biochem/immuno",

        # Tumor location
        "tumorLocation_APPENDIX": "Tumor: Appendix",
        "tumorLocation_ASCENDING_COLON": "Tumor: Asc. colon",
        "tumorLocation_COECUM": "Tumor: Cecum",
        "tumorLocation_COLON_NOS": "Tumor: Colon NOS",
        "tumorLocation_COLON_OVERLAPPING": "Tumor: Overlapping",
        "tumorLocation_DESCENDING_COLON": "Tumor: Desc. colon",
        "tumorLocation_FLEXURA_HEPATICA": "Tumor: Hepatic flexure",
        "tumorLocation_FLEXURA_LIENALIS": "Tumor: Splenic flexure",
        "tumorLocation_OVARY": "Tumor: Ovary",
        "tumorLocation_RECTOSIGMOID": "Tumor: Rectosigmoid",
        "tumorLocation_RECTUM": "Tumor: Rectum",
        "tumorLocation_SIGMOID_COLON": "Tumor: Sigmoid",
        "tumorLocation_TRANSVERSE_COLON": "Tumor: Transverse",
    }

In [None]:
import matplotlib.gridspec as gridspec
from PIL import Image
import io
import pymysql

def compute_risk_at_horizon(model, X, horizon_days=360):
    """Return the predicted event risk (1 - survival) at ``horizon_days`` for each row of ``X``.

    Parameters:
    - model: object with ``predict_survival_function(X)`` whose results expose
      ``.x`` (time points) and ``.y`` (survival values).
    - X: feature rows to predict for.
    - horizon_days: time at which survival is read off, linearly interpolated
      between the curve's time points.

    Returns:
    - ndarray with one risk value per predicted survival function.
    """
    risks = []
    for curve in model.predict_survival_function(X):
        survival_at_horizon = np.interp(horizon_days, curve.x, curve.y)
        risks.append(1.0 - survival_at_horizon)
    return np.array(risks)

def create_shap_image(shap_values, max_display=15, feature_short_names=None):
    """Render a SHAP bar plot into an in-memory PNG and return it as a PIL image.

    Parameters:
    - shap_values: object accepted by ``shap.plots.bar`` that exposes ``feature_names``.
    - max_display (int): maximum number of bars to draw.
    - feature_short_names (dict or None): optional ``{feature: label}`` mapping.
      When given, ``shap_values.feature_names`` is overwritten in place with the
      mapped labels (unmapped names are kept).

    Returns:
    - PIL image opened from the PNG buffer.
    """
    import matplotlib.pyplot as plt
    from PIL import Image
    import io

    if feature_short_names:
        display_names = []
        for original_name in shap_values.feature_names:
            display_names.append(feature_short_names.get(original_name, original_name))
        # Mutates the caller's object.
        shap_values.feature_names = display_names

    fig, ax = plt.subplots(figsize=(16, 5.5))
    shap.plots.bar(shap_values, max_display=max_display, show=False)

    # Let long labels extend past the axes instead of being clipped.
    for tick_label in ax.get_yticklabels():
        tick_label.set_clip_on(False)
        tick_label.set_horizontalalignment("right")
    plt.subplots_adjust(left=0.28)

    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format="png", dpi=600)
    plt.close(fig)
    png_buffer.seek(0)
    return Image.open(png_buffer)


# NOTE(review): this repeats the apply_treatment definition from an earlier
# cell; whichever cell runs last is the one in effect.
def apply_treatment(df, mapping, treatment_cols, msi_flag):
    """Build a copy of ``df`` in which the treatment columns describe one scenario.

    All ``treatment_cols`` are reset to 0, then each ``mapping`` entry whose
    column exists in the frame is written, ``hasMsi`` (if present) is set to
    ``msi_flag`` and ``hasTreatment`` is set to 1 where any treatment column is
    non-zero in sum, else 0. The input frame is left untouched.
    """
    result = df.copy()
    result[treatment_cols] = 0

    for name, value in mapping.items():
        if name not in result:
            continue
        result[name] = value

    if "hasMsi" in result:
        result["hasMsi"] = msi_flag

    treatment_total = result[treatment_cols].sum(axis=1)
    result["hasTreatment"] = (treatment_total > 0).astype(int)
    return result
    

In [None]:
def compute_effective_input(model, X_tensor: torch.Tensor) -> np.ndarray:
    """Return the gated, attention-weighted input of the first row in ``X_tensor``.

    The input is passed through the attention layer's ``_apply_gate`` and the
    result is multiplied element-wise with the output of the layer's ``attn``.

    Parameters:
    - model: object exposing ``model.net.attention`` with ``_apply_gate`` and ``attn``.
    - X_tensor (torch.Tensor): batch of inputs; only row 0 is returned.

    Returns:
    - ndarray: gated input times attention weights for row 0.
    """
    attn_layer = model.model.net.attention

    # Run the gate under no_grad as well: if the gate has trainable parameters,
    # its output would otherwise require grad and ``.numpy()`` would raise.
    with torch.no_grad():
        x_gated = attn_layer._apply_gate(X_tensor.clone())
        weights = attn_layer.attn(x_gated)

    # detach() is a no-op here but keeps the conversion safe in every case.
    x_gated_np = x_gated.detach().cpu().numpy()[0]
    weights_np = weights.detach().cpu().numpy()[0]

    return x_gated_np * weights_np

In [None]:
# NOTE(review): a later cell defines plot_survival_and_shap_grid again; once
# that cell has run, this version is no longer the one being called.
def plot_survival_and_shap_grid(
    model, df_raw, df_all, patient_id: int = None, patient_row: dict = None,
    treatment_map: dict = None, treatment_prefix="systemicTreatmentPlan",
    horizons=(12,), background_size=150, max_shap_display=15, display_stats=('median', 'auc'),
    counter: dict = None, total: int = None
):
    """Compute per-treatment RMST and SHAP values for one patient and save them as CSV.

    No figure is produced here: results are written to
    ``shap_patient_<id>_<treatment>_<H>mo.csv`` and ``rmst_patient_<id>.csv``
    in the current working directory.

    Parameters:
    - model: survival model with ``predict_survival_function``.
    - df_raw: raw frame; only extended with ``patient_row``, not read afterwards.
    - df_all: preprocessed frame containing ``sourceId`` and the settings'
      event/duration columns; also the source of the SHAP background sample.
    - patient_id / patient_row: the patient is selected by ``patient_row['sourceId']``
      when ``patient_row`` is given, otherwise by ``patient_id``.
    - treatment_map: ``{label: {column: value}}``; must not be None (``.items()`` is called on it).
    - treatment_prefix: prefix identifying treatment columns.
    - horizons: horizons in months; each is converted to days as ``H * 30``.
    - background_size: maximum number of background rows sampled for SHAP.
    - max_shap_display, display_stats: not used in this function.
    - counter: optional dict whose ``'processed'`` entry is incremented.
    - total: optional total used only in the progress message.

    Returns:
    - dict ``{label: RMST in months}``, or None when the patient is not in ``df_all``.
    """
    import pandas as pd
    import numpy as np
    import shap
    import torch
    from numpy import trapz

    rmst_results = {}

    if patient_row is not None:
        df_raw = pd.concat([df_raw, pd.DataFrame([patient_row])], ignore_index=True)
        target_id = patient_row['sourceId']
    else:
        target_id = patient_id

    patient_df = df_all[df_all["sourceId"] == target_id]
    if patient_df.empty:
        print(f"Skipping patient {target_id}: not found after preprocessing")
        return

    X_base = patient_df.drop(columns=["sourceId", settings.event_col, settings.duration_col])

    # Read a 0/1 feature of the patient's first row; 0 when absent or entirely NaN.
    def flag(col):
        if col in X_base.columns and not X_base[col].isna().all():
            return int(X_base[col].iloc[0])
        return 0

    treatment_cols = [c for c in X_base if c.startswith(treatment_prefix)]
    survival_fs = model.predict_survival_function(X_base)
    # Common time range covered by all predicted curves.
    time_start = max(sf.x[0] for sf in survival_fs)
    time_end = min(sf.x[-1] for sf in survival_fs)
    time_grid = np.linspace(time_start, time_end, 100)
    risks = {}

    for label, mapping in treatment_map.items():
        X_mod = apply_treatment(X_base, mapping, treatment_cols, flag("hasMsi"))
        # The single-element unpacking requires exactly one predicted curve.
        surv_fn, = model.predict_survival_function(X_mod)
        surv_prob = surv_fn(time_grid)

        # Restricted mean survival time: area under the curve up to tau_days.
        tau_days = 1095
        rmst_mask = time_grid <= tau_days
        rmst_days = trapz(surv_prob[rmst_mask], time_grid[rmst_mask])
        rmst_mo = rmst_days / 30.44
        rmst_results[label] = rmst_mo

        risks[label] = compute_risk_at_horizon(model, X_mod, 12 * 30)

    # Pick the scenarios to explain: no treatment, lowest-risk and highest-risk treatment.
    is_chemo = lambda label: "no treatment" not in label.lower()
    no_tx = next((l for l in risks if "no treatment" in l.lower()), None)
    chemo_candidates = [l for l in risks if is_chemo(l)]
    best_chemo = min(chemo_candidates, key=risks.get) if chemo_candidates else None
    worst_candidates = [l for l in risks if l != no_tx] if no_tx else list(risks.keys())
    worst = max(worst_candidates, key=risks.get) if worst_candidates else None

    # The same label can appear more than once here (e.g. a single-entry treatment_map).
    showcase = [no_tx, best_chemo, worst]
    showcase = [s for s in showcase if s is not None]

    # Fixed seed so the SHAP background sample is the same on every call.
    X_bg = df_all.drop(columns=["sourceId", settings.event_col, settings.duration_col]).sample(min(background_size, len(df_all)), random_state=42)

    for H in horizons:
        for lbl in showcase:
            X_mod = apply_treatment(X_base, treatment_map[lbl], treatment_cols, flag("hasMsi"))
            explainer = shap.Explainer(lambda x: compute_risk_at_horizon(model, x, H * 30), X_bg, feature_names=X_bg.columns)
            sv = explainer(X_mod)[0]

            if hasattr(model.model.net, "attention"):
                # Zero the SHAP value of every feature whose effective input is exactly 0.
                X_tensor = torch.tensor(X_mod.values, dtype=torch.float32)
                x_effective = compute_effective_input(model, X_tensor)
                mask = (x_effective != 0).astype(float)
                sv.values = sv.values * mask

            sv_exp = shap.Explanation(
                sv.values,
                base_values=sv.base_values,
                data=sv.data,
                feature_names=X_bg.columns.tolist()
            )

            save_df = pd.DataFrame({
                "feature": sv_exp.feature_names,
                "shap_value": sv_exp.values,
                "data_value": sv_exp.data
            })
            # insert(0, ...) prepends, so the final column order starts with patient_id.
            save_df.insert(0, "treatment", lbl)
            save_df.insert(0, "horizon_months", H)
            save_df.insert(0, "patient_id", target_id)
            save_df.to_csv(f"shap_patient_{target_id}_{lbl.replace(' ', '_')}_{H}mo.csv", index=False)

    rmst_df = pd.DataFrame.from_dict(rmst_results, orient="index", columns=["RMST_months"])
    rmst_df.insert(0, "patient_id", target_id)
    rmst_df.to_csv(f"rmst_patient_{target_id}.csv")

    if counter is not None:
        counter['processed'] += 1
        if total is not None:
            print(f"Processed SHAP for patient {target_id} ({counter['processed']}/{total})")
        else:
            print(f"Processed SHAP for patient {target_id} (#{counter['processed']})")

    return rmst_results

In [None]:
def plot_survival_and_shap_grid(
    model, df_raw, df_all, patient_id: int = None, patient_row: dict = None,
    treatment_map: dict = None, treatment_prefix="systemicTreatmentPlan",
    horizons=(12,), background_size=150, max_shap_display=15, display_stats=('median', 'auc'),
    counter: dict = None, total: int = None
):
    """Compute per-treatment RMST and SHAP values for one patient and save them as CSV.

    No figure is produced: results are written to
    ``shap_patient_<id>_<treatment>_<H>mo.csv`` and ``rmst_patient_<id>.csv``
    in the current working directory. Both files also record the patient's
    actual survival time and event indicator.

    Parameters:
    - model: survival model with ``predict_survival_function``.
    - df_raw: accepted for backward compatibility; not used.
    - df_all: preprocessed frame containing ``sourceId`` and the settings'
      event/duration columns; also the source of the SHAP background sample.
    - patient_id / patient_row: the patient is selected by ``patient_row['sourceId']``
      when ``patient_row`` is given, otherwise by ``patient_id``.
    - treatment_map: ``{label: {column: value}}``; required.
    - treatment_prefix: prefix identifying treatment columns.
    - horizons: horizons in months; each is converted to days as ``H * 30``.
    - background_size: maximum number of background rows sampled for SHAP.
    - max_shap_display, display_stats: accepted for backward compatibility; not used.
    - counter: optional dict whose ``'processed'`` entry is incremented after
      the patient has been processed.
    - total: optional total, used only in the progress message.

    Returns:
    - dict ``{label: RMST in months}``, or None when the patient is not in ``df_all``.

    Raises:
    - ValueError: when ``treatment_map`` is None.
    """
    import pandas as pd
    import numpy as np
    import shap
    import torch

    if treatment_map is None:
        raise ValueError("treatment_map is required: pass a dict of {label: {column: value}}")

    # np.trapz was renamed to np.trapezoid in NumPy 2.0; use whichever exists.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    rmst_results = {}

    target_id = patient_row['sourceId'] if patient_row is not None else patient_id

    patient_df = df_all[df_all["sourceId"] == target_id]
    if patient_df.empty:
        print(f"Skipping patient {target_id}: not found after preprocessing")
        return

    actual_time = patient_df[settings.duration_col].values[0]
    actual_event = patient_df[settings.event_col].values[0]

    non_feature_cols = ["sourceId", settings.event_col, settings.duration_col]
    X_base = patient_df.drop(columns=non_feature_cols)

    # Read a 0/1 feature of the patient's first row; 0 when absent or entirely NaN.
    def flag(col):
        if col in X_base.columns and not X_base[col].isna().all():
            return int(X_base[col].iloc[0])
        return 0

    msi_flag = flag("hasMsi")

    treatment_cols = [c for c in X_base if c.startswith(treatment_prefix)]
    survival_fs = model.predict_survival_function(X_base)
    # Common time range covered by all predicted curves.
    time_start = max(sf.x[0] for sf in survival_fs)
    time_end = min(sf.x[-1] for sf in survival_fs)
    time_grid = np.linspace(time_start, time_end, 100)
    risks = {}

    # Restricted mean survival time is integrated up to tau_days (3 * 365).
    tau_days = 1095
    rmst_mask = time_grid <= tau_days

    for label, mapping in treatment_map.items():
        X_mod = apply_treatment(X_base, mapping, treatment_cols, msi_flag)
        # The single-element unpacking requires exactly one predicted curve.
        surv_fn, = model.predict_survival_function(X_mod)
        surv_prob = surv_fn(time_grid)

        rmst_days = integrate(surv_prob[rmst_mask], time_grid[rmst_mask])
        rmst_results[label] = rmst_days / 30.44

        risks[label] = compute_risk_at_horizon(model, X_mod, 12 * 30)

    # Pick the scenarios to explain: no treatment, lowest-risk and highest-risk treatment.
    is_chemo = lambda label: "no treatment" not in label.lower()
    no_tx = next((l for l in risks if "no treatment" in l.lower()), None)
    chemo_candidates = [l for l in risks if is_chemo(l)]
    best_chemo = min(chemo_candidates, key=risks.get) if chemo_candidates else None
    worst_candidates = [l for l in risks if l != no_tx] if no_tx else list(risks.keys())
    worst = max(worst_candidates, key=risks.get) if worst_candidates else None

    # Drop missing entries and duplicates (order preserved). With a one-entry
    # treatment_map, best and worst are the same label; without this the SHAP
    # run would be repeated and the same CSV written twice.
    showcase = list(dict.fromkeys(
        s for s in (no_tx, best_chemo, worst) if s is not None
    ))

    # Fixed seed so the SHAP background sample is the same on every call.
    X_bg = df_all.drop(columns=non_feature_cols).sample(
        min(background_size, len(df_all)), random_state=42
    )

    for H in horizons:
        for lbl in showcase:
            X_mod = apply_treatment(X_base, treatment_map[lbl], treatment_cols, msi_flag)
            explainer = shap.Explainer(
                lambda x: compute_risk_at_horizon(model, x, H * 30),
                X_bg,
                feature_names=X_bg.columns
            )
            sv = explainer(X_mod)[0]

            if hasattr(model.model.net, "attention"):
                # Zero the SHAP value of every feature whose effective input is exactly 0.
                X_tensor = torch.tensor(X_mod.values, dtype=torch.float32)
                x_effective = compute_effective_input(model, X_tensor)
                mask = (x_effective != 0).astype(float)
                sv.values = sv.values * mask

            sv_exp = shap.Explanation(
                sv.values,
                base_values=sv.base_values,
                data=sv.data,
                feature_names=X_bg.columns.tolist()
            )

            save_df = pd.DataFrame({
                "feature": sv_exp.feature_names,
                "shap_value": sv_exp.values,
                "data_value": sv_exp.data
            })
            # insert(0, ...) prepends, so the last insert becomes the first column.
            save_df.insert(0, "treatment", lbl)
            save_df.insert(0, "horizon_months", H)
            save_df.insert(0, "patient_id", target_id)
            save_df.insert(0, "event_observed", actual_event)
            save_df.insert(0, "actual_survival_time", actual_time)
            save_df.to_csv(f"shap_patient_{target_id}_{lbl.replace(' ', '_')}_{H}mo.csv", index=False)

    rmst_df = pd.DataFrame.from_dict(rmst_results, orient="index", columns=["RMST_months"])
    rmst_df.insert(0, "patient_id", target_id)
    rmst_df.insert(1, "actual_survival_time", actual_time)
    rmst_df.insert(2, "event_observed", actual_event)
    rmst_df.to_csv(f"rmst_patient_{target_id}.csv")

    # Progress reporting: callers pass counter/total, so honour them.
    if counter is not None:
        counter['processed'] += 1
        if total is not None:
            print(f"Processed SHAP for patient {target_id} ({counter['processed']}/{total})")
        else:
            print(f"Processed SHAP for patient {target_id} (#{counter['processed']})")

    return rmst_results


In [None]:
# Upper bound on how many patients are processed in this run.
MAX_PATIENTS = 1300

patient_ids = df_5fu["sourceId"].unique()[:MAX_PATIENTS]

counter = {'processed': 0}
# Derive the reported total from the same cap that limits the loop, so the
# progress message cannot show more processed patients than the total.
total_patients = len(patient_ids)

for pid in patient_ids:
    plot_survival_and_shap_grid(
        model=trained_models["DeepSurv_attention"],
        df_raw=df_raw,
        df_all=df_5fu,
        treatment_map=five_fu_map,
        treatment_prefix=treat_prefix,
        patient_id=pid,
        counter=counter,
        total=total_patients
    )


In [None]:
import pandas as pd
import glob
import matplotlib.pyplot as plt

def combine_and_plot_shap_summary(folder_path=".", horizon_months=12, treatment_filter=None, output_file="shap_combined.csv", top_n=20):
    """
    Combine all per-patient SHAP CSVs and plot a mean SHAP summary bar chart.

    The mean is the signed mean of ``shap_value`` per feature; features are
    ranked from the largest to the smallest mean.

    Parameters:
    - folder_path (str): Where shap_patient_*.csv files are stored.
    - horizon_months (int): Only files named ``*_<horizon_months>mo.csv`` are read.
    - treatment_filter (str or None): Keep a file only when this text occurs
      (case-insensitively) in the ``treatment`` value of its first row.
    - output_file (str or None): Optional path to save combined data.
    - top_n (int): Number of top-ranked features shown in the chart.

    Returns:
    - DataFrame with columns ``feature`` and ``shap_value`` (mean per feature),
      or an empty DataFrame when no file matched.
    """
    pattern = f"{folder_path}/shap_patient_*_{horizon_months}mo.csv"
    # Sorted so the row order of the combined output is reproducible.
    files = sorted(glob.glob(pattern))

    if not files:
        print("No SHAP files found for horizon:", horizon_months)
        return pd.DataFrame()

    df_list = []
    for f in files:
        df = pd.read_csv(f)
        # A header-only file has no first row to inspect and adds no data.
        if df.empty:
            continue
        if treatment_filter and treatment_filter.lower() not in str(df["treatment"].iloc[0]).lower():
            continue
        df_list.append(df)

    if not df_list:
        print(f"No SHAP files matched treatment='{treatment_filter}' and horizon={horizon_months}")
        return pd.DataFrame()

    df_all = pd.concat(df_list, ignore_index=True)

    summary_df = (
        df_all.groupby("feature")["shap_value"]
        .mean()
        .sort_values(ascending=False)
        .reset_index()
    )

    # Reverse so the top-ranked feature ends up at the top of the horizontal chart.
    top_rows = summary_df.iloc[:top_n][::-1]
    plt.figure(figsize=(10, 6))
    plt.barh(top_rows["feature"], top_rows["shap_value"])
    plt.xlabel("Mean SHAP Value")
    plt.title(f"Top {top_n} Features (Mean SHAP, Horizon={horizon_months}mo)")
    plt.tight_layout()
    plt.show()

    if output_file:
        df_all.to_csv(output_file, index=False)
        print(f"Saved combined SHAP data to '{output_file}'")

    return summary_df


In [None]:
# Aggregate the per-patient SHAP CSVs in the working directory for the 5-FU scenario.
combine_and_plot_shap_summary(
    folder_path=".",
    horizon_months=12,
    treatment_filter="5-FU",
    output_file="shap_combined_5fu.csv"
)


In [None]:
import pandas as pd
import glob
import matplotlib.pyplot as plt
import os

def combine_and_plot_shap_summary_by_survival_group(
    folder_path=".",
    horizon_months=12,
    treatment_filter=None,
    output_prefix="shap_combined",
    top_n=20
):
    """
    Combine per-patient SHAP CSVs, split by survival group, and plot mean SHAP summaries.

    Groups (by ``actual_survival_time`` in days):
    - Short survival: up to and including 183 days
    - Middle: more than 183 and up to 365 days (neither plotted nor saved)
    - Long survival: more than 365 days

    Parameters:
    - folder_path (str): Where shap_patient_*.csv files are stored.
    - horizon_months (int): Only files named ``*_<horizon_months>mo.csv`` are read.
    - treatment_filter (str or None): Keep a file only when this text occurs
      (case-insensitively) in the ``treatment`` value of its first row.
    - output_prefix (str): Prefix for saved output files
      (``<output_prefix>_<group>.csv``).
    - top_n (int): Number of top-ranked features shown per chart.

    Returns:
    - dict mapping "short" / "long" to a DataFrame of mean ``shap_value`` per
      feature; groups without patients are omitted. Empty dict when no file matched.

    Raises:
    - KeyError: when the combined data has no ``actual_survival_time`` column.
    """
    pattern = f"{folder_path}/shap_patient_*_{horizon_months}mo.csv"
    # Sorted so the row order of the saved output is reproducible.
    files = sorted(glob.glob(pattern))

    if not files:
        print("No SHAP files found for horizon:", horizon_months)
        return {}

    df_list = []
    for f in files:
        df = pd.read_csv(f)
        # A header-only file has no first row to inspect and adds no data.
        if df.empty:
            continue
        if treatment_filter and treatment_filter.lower() not in str(df["treatment"].iloc[0]).lower():
            continue
        df["source_file"] = os.path.basename(f)
        df_list.append(df)

    if not df_list:
        print(f"No matching SHAP files for treatment='{treatment_filter}' and horizon={horizon_months}")
        return {}

    df_all = pd.concat(df_list, ignore_index=True)

    if "actual_survival_time" not in df_all.columns:
        # Same exception type as the plain column lookup, with a usable message.
        raise KeyError(
            "actual_survival_time: the SHAP CSVs lack this column, so patients "
            "cannot be grouped by survival"
        )

    df_all["survival_group"] = pd.cut(
        df_all["actual_survival_time"],
        bins=[-1, 183, 365, float("inf")],
        labels=["short", "middle", "long"]
    )

    results = {}

    for group_label in ["short", "long"]:
        group_df = df_all[df_all["survival_group"] == group_label]
        if group_df.empty:
            print(f"No patients in '{group_label}' group.")
            continue

        summary_df = (
            group_df.groupby("feature")["shap_value"]
            .mean()
            .sort_values(ascending=False)
            .reset_index()
        )

        # Reverse so the top-ranked feature ends up at the top of the horizontal chart.
        top_rows = summary_df.iloc[:top_n][::-1]
        plt.figure(figsize=(10, 6))
        plt.barh(top_rows["feature"], top_rows["shap_value"])
        plt.xlabel("Mean SHAP Value")
        plt.title(f"Top {top_n} Features – {group_label.capitalize()} Survival (Horizon={horizon_months}mo)")
        plt.tight_layout()
        plt.show()

        output_file = f"{output_prefix}_{group_label}.csv"
        group_df.to_csv(output_file, index=False)
        print(f"Saved SHAP data for group '{group_label}' to '{output_file}'")

        results[group_label] = summary_df

    return results


In [None]:
# Plot and save the short/long survival SHAP summaries; the returned dict is not kept here.
combine_and_plot_shap_summary_by_survival_group(
    folder_path=".", 
    horizon_months=12,
    treatment_filter="5-FU",
    output_prefix="shap_grouped" 
)


In [None]:
# NOTE(review): this runs the SHAP export for every 5-FU patient, with no cap
# and no progress counter. An earlier cell already loops over the same cohort,
# so this repeats that work — confirm it is intended.
for pid in df_5fu["sourceId"].unique():
    plot_survival_and_shap_grid(
        model=trained_models["DeepSurv_attention"],
        df_raw=df_raw,
        df_all=df_5fu,
        treatment_map=five_fu_map,
        treatment_prefix=treat_prefix,
        patient_id=pid
    )


In [None]:
def plot_shap_difference_bar(summary_short, summary_long, top_n=20, title="SHAP Difference (Long - Short)"):
    """
    Draw a horizontal bar chart of per-feature mean SHAP differences (long minus short).

    The two summaries are inner-joined on "feature", so only features present in
    both groups are compared.

    Parameters:
    - summary_short (DataFrame): "feature" / "shap_value" summary of the short survival group.
    - summary_long (DataFrame): "feature" / "shap_value" summary of the long survival group.
    - top_n (int): How many features, ranked by absolute difference, are drawn.
    - title (str): Plot title.

    Returns:
    - DataFrame: the joined table with "shap_value_long", "shap_value_short" and "diff",
      ordered by absolute difference (largest first). All joined features are
      returned, not only the `top_n` that are plotted.
    """
    paired = summary_long.merge(summary_short, on="feature", suffixes=("_long", "_short"))
    paired["diff"] = paired["shap_value_long"] - paired["shap_value_short"]

    # Order the rows by the magnitude of the difference, largest first.
    ranking = paired["diff"].abs().sort_values(ascending=False).index
    ranked = paired.reindex(ranking)

    # Reverse the plotted slice so the largest difference ends up as the top bar.
    shown = ranked.iloc[:top_n].iloc[::-1]

    plt.figure(figsize=(10, 6))
    plt.barh(shown["feature"], shown["diff"])
    plt.axvline(0, color='gray', linestyle='--')
    plt.xlabel("Mean SHAP Difference (Long - Short)")
    plt.title(title)
    plt.tight_layout()
    plt.show()

    return ranked


In [None]:
# Rebuild the per-group SHAP summaries for the 12-month files whose treatment contains "5-FU".
results = combine_and_plot_shap_summary_by_survival_group(
    treatment_filter="5-FU",
    horizon_months=12,
    folder_path=".",
    output_prefix="shap_grouped",
)

# The long-vs-short comparison is only drawn when both groups were returned.
if all(group in results for group in ("short", "long")):
    plot_shap_difference_bar(results["short"], results["long"], top_n=20)


In [None]:
import pandas as pd
import glob
import matplotlib.pyplot as plt
import os

def combine_and_plot_shap_summary_by_survival_group(
    folder_path=".",
    horizon_months=12,
    treatment_filter=None,
    output_prefix="shap_combined",
    top_n=20
):
    """
    Combine per-patient SHAP CSVs, split by survival group, and plot mean SHAP summaries.

    Groups (from the "actual_survival_time" column, in days):
    - Short survival: < 12 months (365 days)
    - Long survival: ≥ 12 months (365 days)
    Rows without a survival time are assigned to neither group.

    Every input CSV must provide the columns "feature", "shap_value",
    "treatment" and "actual_survival_time".

    Parameters:
    - folder_path (str): Where shap_patient_*.csv files are stored.
    - horizon_months (int): Filter for a specific horizon.
    - treatment_filter (str or None): Case-insensitive substring that must occur in
      the first row's "treatment" value of a file. Because this is a substring
      test, "5-FU" also matches combinations such as "5-FU + oxaliplatin".
      None (or an empty string) keeps every file.
    - output_prefix (str): Prefix for saved output files.
    - top_n (int): Number of features (highest mean SHAP first) shown per plot.

    Returns:
    - dict: maps "short" / "long" to a DataFrame with columns "feature" and
      "shap_value" (mean per feature, sorted descending). Groups without rows
      are omitted; an empty dict is returned when no file matches.

    Side effects:
    - Shows one bar chart per non-empty group.
    - Writes "<output_prefix>_<group>.csv" with that group's combined rows.
    """
    pattern = os.path.join(folder_path, f"shap_patient_*_{horizon_months}mo.csv")
    # glob returns files in arbitrary order; sort so the combined output is reproducible.
    files = sorted(glob.glob(pattern))

    if not files:
        print("No SHAP files found for horizon:", horizon_months)
        return {}

    frames = []
    for path in files:
        patient_df = pd.read_csv(path)
        if patient_df.empty:
            # Header-only export: nothing to aggregate and no first row to inspect.
            continue
        if treatment_filter:
            treatment = patient_df["treatment"].iloc[0]
            # A missing treatment cannot match any filter.
            if pd.isna(treatment) or treatment_filter.lower() not in str(treatment).lower():
                continue
        patient_df["source_file"] = os.path.basename(path)
        frames.append(patient_df)

    if not frames:
        print(f"No matching SHAP files for treatment='{treatment_filter}' and horizon={horizon_months}")
        return {}

    df_all = pd.concat(frames, ignore_index=True)

    survival_days = df_all["actual_survival_time"]
    missing_survival = survival_days.isna()
    if missing_survival.any():
        print(f"Skipping {int(missing_survival.sum())} rows without an actual survival time.")

    # NaN compares False against 365 and would otherwise be labelled "long";
    # masking with notna() keeps those rows out of both groups.
    df_all["survival_group"] = (
        survival_days.lt(365)
        .map({True: "short", False: "long"})
        .where(~missing_survival)
    )

    results = {}

    for group_label in ["short", "long"]:
        group_df = df_all[df_all["survival_group"] == group_label]
        if group_df.empty:
            print(f"No patients in '{group_label}' group.")
            continue

        summary_df = (
            group_df.groupby("feature")["shap_value"]
            .mean()
            .sort_values(ascending=False)
            .reset_index()
        )

        # Reverse the top slice so the highest mean is drawn as the top bar.
        shown = summary_df.iloc[:top_n].iloc[::-1]
        plt.figure(figsize=(10, 6))
        plt.barh(shown["feature"], shown["shap_value"])
        plt.xlabel("Mean SHAP Value")
        plt.title(f"Top {top_n} Features – {group_label.capitalize()} Survival (< vs ≥ 1 year, Horizon={horizon_months}mo)")
        plt.tight_layout()
        plt.show()

        output_file = f"{output_prefix}_{group_label}.csv"
        group_df.to_csv(output_file, index=False)
        print(f"Saved SHAP data for group '{group_label}' to '{output_file}'")

        results[group_label] = summary_df

    return results


In [None]:
# Re-run the 5-FU aggregation with the survival-group definition currently in scope.
results = combine_and_plot_shap_summary_by_survival_group(
    treatment_filter="5-FU",
    horizon_months=12,
    folder_path=".",
    output_prefix="shap_grouped",
)

# Plot long-minus-short differences only when both groups are present in the result.
if all(group in results for group in ("short", "long")):
    plot_shap_difference_bar(results["short"], results["long"], top_n=20)


In [None]:
# Single-entry treatment map for the 5-FU + oxaliplatin + bevacizumab combination.
five_fu_oxa_beva_map = {"5-FU + oxaliplatin + bevacizumab": valid_treatment_combinations["5-FU + oxaliplatin + bevacizumab"]}

treat_prefix = "systemicTreatmentPlan_"
cols_treat = [c for c in df_all.columns if c.startswith(treat_prefix)]

col_5fu = f"{treat_prefix}5-FU"
col_oxa = f"{treat_prefix}oxaliplatin"
col_beva = f"{treat_prefix}bevacizumab"

required_cols = [col_5fu, col_oxa, col_beva]
other_cols = [c for c in cols_treat if c not in required_cols]

# Fail with a readable message instead of an opaque pandas KeyError further down.
assert set(required_cols) <= set(cols_treat), (
    f"Missing treatment indicator columns: {sorted(set(required_cols) - set(cols_treat))}"
)

# Keep rows where all three required indicators equal 1 and every other
# treatment indicator sums to 0, i.e. exactly this combination and nothing else.
mask_combo = (
    df_all[required_cols].eq(1).all(axis=1)
    &
    (df_all[other_cols].sum(axis=1) == 0)
)

df_5fu_oxa_beva = df_all[mask_combo].copy()
print(f"{len(df_5fu_oxa_beva)} patients with 5-FU + Oxaliplatin + Bevacizumab")


plot_threshold_analysis_by_treatment(
    model=trained_models["DeepSurv_attention"],
    df_all=df_5fu_oxa_beva,
    treatment_map=five_fu_oxa_beva_map,
    treatment_prefix=treat_prefix,
    horizon_days=365,
    model_name="DeepSurv + Attention"
)


In [None]:
# Mutable dict handed to every call below; presumably plot_survival_and_shap_grid
# updates it to track progress — verify against that function.
counter = {'processed': 0}

# NOTE(review): the total passed on is capped at 100, while the loop below
# processes up to 1500 distinct patients — confirm which limit is intended.
total_patients = min(100, df_5fu_oxa_beva["sourceId"].nunique(dropna=False))

# Only the first 1500 distinct sourceIds (in order of appearance) are processed.
for pid in df_5fu_oxa_beva["sourceId"].unique()[:1500]:
    plot_survival_and_shap_grid(
        patient_id=pid,
        model=trained_models["DeepSurv_attention"],
        df_raw=df_raw,
        df_all=df_5fu_oxa_beva,
        treatment_map=five_fu_oxa_beva_map,
        treatment_prefix=treat_prefix,
        counter=counter,
        total=total_patients,
    )


In [None]:
import pandas as pd
import glob
import matplotlib.pyplot as plt
import os

def combine_and_plot_shap_summary_by_survival_group(
    folder_path=".",
    horizon_months=12,
    treatment_filter=None,
    output_prefix="shap_combined",
    top_n=20
):
    """
    Combine per-patient SHAP CSVs, split by survival group, and plot mean SHAP summaries.

    Groups (from the "actual_survival_time" column, binned with right-inclusive edges):
    - Short survival: -1 < t ≤ 183 (about 6 months)
    - Middle: 183 < t ≤ 365 — labelled but neither plotted, saved nor returned
    - Long survival: t > 365 (more than 12 months)
    Rows with a missing survival time, or a value ≤ -1, fall into no group.

    Every input CSV must provide the columns "feature", "shap_value",
    "treatment" and "actual_survival_time".

    Parameters:
    - folder_path (str): Where shap_patient_*.csv files are stored.
    - horizon_months (int): Filter for a specific horizon.
    - treatment_filter (str or None): Case-insensitive substring that must occur in
      the first row's "treatment" value of a file. Because this is a substring
      test, "5-FU" also matches combinations such as "5-FU + oxaliplatin".
      None (or an empty string) keeps every file.
    - output_prefix (str): Prefix for saved output files.
    - top_n (int): Number of features (highest mean SHAP first) shown per plot.

    Returns:
    - dict: maps "short" / "long" to a DataFrame with columns "feature" and
      "shap_value" (mean per feature, sorted descending). Groups without rows
      are omitted; an empty dict is returned when no file matches.

    Side effects:
    - Shows one bar chart per non-empty group.
    - Writes "<output_prefix>_<group>.csv" with that group's combined rows.
    """
    pattern = os.path.join(folder_path, f"shap_patient_*_{horizon_months}mo.csv")
    # glob returns files in arbitrary order; sort so the combined output is reproducible.
    files = sorted(glob.glob(pattern))

    if not files:
        print("No SHAP files found for horizon:", horizon_months)
        return {}

    frames = []
    for path in files:
        patient_df = pd.read_csv(path)
        if patient_df.empty:
            # Header-only export: nothing to aggregate and no first row to inspect.
            continue
        if treatment_filter:
            treatment = patient_df["treatment"].iloc[0]
            # A missing treatment cannot match any filter.
            if pd.isna(treatment) or treatment_filter.lower() not in str(treatment).lower():
                continue
        patient_df["source_file"] = os.path.basename(path)
        frames.append(patient_df)

    if not frames:
        print(f"No matching SHAP files for treatment='{treatment_filter}' and horizon={horizon_months}")
        return {}

    df_all = pd.concat(frames, ignore_index=True)

    # pd.cut bins are right-inclusive by default: (-1, 183], (183, 365], (365, inf].
    df_all["survival_group"] = pd.cut(
        df_all["actual_survival_time"],
        bins=[-1, 183, 365, float("inf")],
        labels=["short", "middle", "long"]
    )

    results = {}

    # Only the two extreme groups are summarised; "middle" is intentionally left out.
    for group_label in ["short", "long"]:
        group_df = df_all[df_all["survival_group"] == group_label]
        if group_df.empty:
            print(f"No patients in '{group_label}' group.")
            continue

        summary_df = (
            group_df.groupby("feature")["shap_value"]
            .mean()
            .sort_values(ascending=False)
            .reset_index()
        )

        # Reverse the top slice so the highest mean is drawn as the top bar.
        shown = summary_df.iloc[:top_n].iloc[::-1]
        plt.figure(figsize=(10, 6))
        plt.barh(shown["feature"], shown["shap_value"])
        plt.xlabel("Mean SHAP Value")
        plt.title(f"Top {top_n} Features – {group_label.capitalize()} Survival (Horizon={horizon_months}mo)")
        plt.tight_layout()
        plt.show()

        output_file = f"{output_prefix}_{group_label}.csv"
        group_df.to_csv(output_file, index=False)
        print(f"Saved SHAP data for group '{group_label}' to '{output_file}'")

        results[group_label] = summary_df

    return results


In [None]:
# For each treatment in the map: aggregate its 12-month SHAP files, then chart the
# long-minus-short differences when both survival groups were returned.
for treatment_name in five_fu_oxa_beva_map:
    print(f"\n🔍 Processing treatment: {treatment_name}")

    # Spaces in the treatment name become underscores in the output file prefix.
    results = combine_and_plot_shap_summary_by_survival_group(
        treatment_filter=treatment_name,
        horizon_months=12,
        folder_path=".",
        output_prefix=f"shap_grouped_{'_'.join(treatment_name.split(' '))}",
    )

    if all(group in results for group in ("short", "long")):
        plot_shap_difference_bar(
            results["short"],
            results["long"],
            top_n=20,
            title=f"SHAP Differences (Long – Short)\nTreatment: {treatment_name}",
        )
