In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import ast
import matplotlib.pyplot as plt
import dill
import torch
import nbimporter
import shap

import os
import sys

# Project root: defaults to the original checkout location, but can be overridden
# with the ACTIN_PREDICTION_ROOT environment variable on other machines.
PROJECT_ROOT = os.environ.get("ACTIN_PREDICTION_ROOT", "/data/repos/actin-personalization/prediction")
os.chdir(PROJECT_ROOT)
# Make the project's python sources importable (resolved relative to PROJECT_ROOT).
sys.path.insert(0, os.path.abspath("src/main/python"))

from models import *
from data.data_processing import DataSplitter, DataPreprocessor
from data.lookups import lookup_manager
from utils.settings import settings
from src.main.python.analysis.predictive_algorithms_training import get_data, plot_different_models_survival_curves

# Preprocessor bound to the configured database; used below to reload the raw
# cohort and rebuild the preprocessed frame with its sourceId column.
preprocessor = DataPreprocessor(settings.db_config_path, settings.db_name)

In [None]:
# Data as prepared by the training script (presumably the same train/test split).
# NOTE(review): in the cells below only X_train is used, for its column names when
# setting attention indices on the loaded models.
df, X_train, X_test, y_train, y_test, encoded_columns = get_data()

In [None]:
def get_preprocessed_data_with_sourceId(preprocessor):
    """Load the raw cohort, preprocess it, and carry ``sourceId`` over to the result.

    Parameters
    ----------
    preprocessor : object providing ``load_data()`` and
        ``preprocess_data(features, df=...)``. The latter must return a 3-tuple
        whose first two items are the processed frame and the updated feature
        list; the third item is discarded.

    Returns
    -------
    tuple
        ``(df_raw, df_all, updated_features)``: the raw frame, the processed
        frame with a ``sourceId`` column looked up from the raw frame by index
        label, and the updated feature list.
    """
    raw_frame = preprocessor.load_data()
    processed, feature_list, _ = preprocessor.preprocess_data(lookup_manager.features, df=raw_frame)

    # sourceId is looked up by index label. This assumes preprocess_data keeps the
    # raw row labels (TODO confirm); otherwise .loc raises or rows get mismatched.
    processed["sourceId"] = raw_frame.loc[processed.index, "sourceId"]
    return raw_frame, processed, feature_list

# df_all is the preprocessed frame with a sourceId column; the cohort cells below
# filter it by treatment indicator columns.
df_raw, df_all, updated_features = get_preprocessed_data_with_sourceId(preprocessor)

## Load trained models

In [None]:
def load_model_outcomes():
    """Read the saved model-comparison table for ``settings.outcome``.

    The table is read from ``<settings.save_path>/<settings.outcome>_model_outcomes.csv``.

    Returns
    -------
    pandas.DataFrame
        The CSV contents as parsed by ``pd.read_csv``.

    Raises
    ------
    FileNotFoundError
        If that CSV file does not exist.
    """
    outcome_csv = os.path.join(f"{settings.save_path}", f"{settings.outcome}_model_outcomes.csv")

    if not os.path.exists(outcome_csv):
        raise FileNotFoundError(f"No saved outcomes found for {settings.outcome} in {settings.save_path}")

    outcomes = pd.read_csv(outcome_csv)
    print(f"Loaded model outcomes from {outcome_csv}")
    return outcomes

In [None]:
def load_trained_model(model_name, model_class, model_kwargs=None):
    """Load one trained model for ``settings.outcome`` from ``settings.save_path``.

    Files are named ``<save_path>/<outcome>_<model_name>`` plus ``.pkl`` or ``.pt``.

    * Names in ``pickled_models`` are unpickled from the ``.pkl`` file with
      ``dill`` and returned as-is. ``model_class`` and ``model_kwargs`` are unused
      for these models.
    * Every other model is rebuilt as ``model_class(**model_kwargs)``. Its
      network weights (``'net_state'``) are loaded from the ``.pt`` file on CPU.
      If present, ``'labtrans'`` restores the label transformer and
      ``duration_index``, and ``'baseline_hazards'`` /
      ``'baseline_cumulative_hazards'`` restore the baseline hazards. The network
      is then put into eval mode.

    Parameters
    ----------
    model_name : str
        Name used in the file name and to choose the loading strategy.
    model_class : callable
        Constructor for models restored from a ``.pt`` state file.
    model_kwargs : dict, optional
        Keyword arguments for ``model_class``; ``None`` means no arguments.

    Returns
    -------
    The loaded model.

    Raises
    ------
    Errors from ``open`` / ``dill.load`` / ``torch.load`` / ``load_state_dict``
    propagate, e.g. FileNotFoundError for a missing file, or KeyError for a
    missing ``'net_state'`` or ``'baseline_cumulative_hazards'`` entry.
    """
    # None default instead of a shared mutable {} default.
    if model_kwargs is None:
        model_kwargs = {}

    model_file_prefix = os.path.join(settings.save_path, f"{settings.outcome}_{model_name}")
    nn_file = model_file_prefix + ".pt"
    sk_file = model_file_prefix + ".pkl"

    pickled_models = ('CoxPH', 'RandomSurvivalForest', 'GradientBoosting', 'AalenAdditive')
    if model_name in pickled_models:
        # dill.load can execute arbitrary code: only load files produced by this project.
        with open(sk_file, "rb") as f:
            model = dill.load(f)
        print(f"Model {model_name} loaded from {sk_file}")
        return model

    model = model_class(**model_kwargs)

    # NOTE(review): torch>=2.6 defaults to weights_only=True, which refuses to
    # unpickle non-tensor entries (e.g. whatever object is stored under
    # 'labtrans'). If loading fails for that reason, pass weights_only=False
    # (trusted files only).
    state = torch.load(nn_file, map_location=torch.device('cpu'))
    model.model.net.load_state_dict(state['net_state'])

    if 'labtrans' in state:
        model.labtrans = state['labtrans']
        model.model.duration_index = model.labtrans.cuts

    if 'baseline_hazards' in state:
        model.model.baseline_hazards_ = state['baseline_hazards']
        model.model.baseline_cumulative_hazards_ = state['baseline_cumulative_hazards']
        print(f"Baseline hazards loaded for {model_name}.")

    model.model.net.eval()
    print(f"Model {model_name} loaded from {nn_file}")
    return model
    
def load_all_trained_models(X_train):
    """Load every model listed in ``settings.json_config_file``.

    Each configured ``(model_class, model_kwargs)`` pair is passed to
    ``load_trained_model``. On success the model is stored under its name and
    ``ModelTrainer._set_attention_indices`` is called with ``X_train``'s column
    names.

    Any ``Exception`` raised during these steps is printed together with its
    type, and loading continues with the next model. If only the attention step
    fails, the model has already been stored. ``KeyboardInterrupt`` and
    ``SystemExit`` are not caught, so the loop can still be interrupted.

    Returns
    -------
    dict
        Model name -> loaded model.
    """
    loaded_models = {}
    config_mgr = ExperimentConfig(settings.json_config_file)
    loaded_configs = config_mgr.load_model_configs()

    for model_name, (model_class, model_kwargs) in loaded_configs.items():
        print(model_name, model_class)
        try:
            loaded_models[model_name] = load_trained_model(
                model_name=model_name,
                model_class=model_class,
                model_kwargs=model_kwargs,
            )
            ModelTrainer._set_attention_indices(loaded_models[model_name], list(X_train.columns))
        except Exception as exc:  # best effort: report and keep loading the rest
            print(f'Could not load: {model_name} ({type(exc).__name__}: {exc})')
    return loaded_models

In [None]:
# Saved evaluation table for settings.outcome; raises FileNotFoundError if it has
# not been written to settings.save_path yet.
model_outcomes = load_model_outcomes()

In [None]:
# Every configured model, keyed by name. Load errors are printed rather than raised,
# so check the cell output for "Could not load" lines.
trained_models = load_all_trained_models(X_train)

## Model Output

In [None]:
import json
# Treatment scenarios keyed by label (e.g. "5-FU", "No treatment"). The path is
# relative to the project root set up by os.chdir in the imports cell.
# Read as UTF-8 explicitly rather than with the platform-default encoding.
with open('src/main/python/data/treatment_combinations.json', 'r', encoding='utf-8') as f:
    valid_treatment_combinations = json.load(f)

## Analysis

In [None]:
def _threshold_curves(df_preds, thresholds):
    """Summarise actual survival times of the patients below each probability cut-off.

    For each ``t`` in ``thresholds``, rows with ``predicted_prob_1yr < t`` are
    selected.

    Returns a dict of lists with one entry per threshold:

    * ``"n"``: the row count.
    * ``"med"``, ``"q1"``, ``"q3"``: the median and the 25th/75th percentiles of
      ``actual_survival_time``.
    * ``"lo"``, ``"hi"``: the 1.5*IQR fences clipped to the observed min/max.

    Entries are NaN where no row is below the cut-off.
    """
    stats = ("med", "q1", "q3", "lo", "hi")
    curves = {key: [] for key in stats + ("n",)}
    for t in thresholds:
        sub = df_preds[df_preds["predicted_prob_1yr"] < t]
        curves["n"].append(len(sub))
        if sub.empty:
            for key in stats:
                curves[key].append(np.nan)
            continue
        times = sub["actual_survival_time"].values
        q1_t, q3_t = np.percentile(times, [25, 75])
        iqr = q3_t - q1_t
        curves["med"].append(np.median(times))
        curves["q1"].append(q1_t)
        curves["q3"].append(q3_t)
        curves["lo"].append(max(times.min(), q1_t - 1.5 * iqr))
        curves["hi"].append(min(times.max(), q3_t + 1.5 * iqr))
    return curves


def plot_threshold_analysis_by_treatment(
    model,
    df_all,
    treatment_map: dict = None,
    treatment_prefix: str = "systemicTreatmentPlan",
    horizon_days: int = 365,
    model_name: str = "Model",
    n_thresholds: int = 100,
    palette=plt.cm.tab20,
):
    """Plot actual survival of patients below a sliding predicted-survival cut-off.

    For each treatment scenario the model predicts survival at ``horizon_days``
    via ``get_preds_for_all_treatments``. A cut-off ``t`` is swept over
    ``n_thresholds`` values in [0, 1].

    * The top panel shows, for patients predicted below ``t``, the median
      (solid), interquartile band (shaded) and clipped 1.5*IQR fences (dotted)
      of their actual survival times. A dashed line marks the horizon.
    * The bottom panel shows how many patients are below ``t``.

    Parameters
    ----------
    model : fitted survival model, passed through to the prediction helpers.
    df_all : cohort frame (features plus ``sourceId``, event and duration columns).
    treatment_map : dict label -> {column: value}. If None, one scenario is built
        per column starting with ``treatment_prefix``, plus "No Treatment".
    treatment_prefix : prefix identifying treatment indicator columns.
    horizon_days : time point (days) at which survival is predicted.
    model_name : shown in the figure title.
    n_thresholds : number of cut-off values.
    palette : matplotlib colormap used to colour the scenarios.

    Raises
    ------
    ValueError
        If ``treatment_map`` is None and no column starts with ``treatment_prefix``.

    Notes
    -----
    "Actual survival time" is the ``actual_survival_time`` column from
    ``get_patient_predictions``, i.e. the recorded duration. It still includes
    patients censored after the horizon, whose value is their follow-up time.
    Scenarios with no prediction frame (``None``) are skipped with a message.
    """
    if treatment_map is None:
        cols = [c for c in df_all.columns if c.startswith(treatment_prefix)]
        if not cols:
            raise ValueError(
                f"No columns found with prefix '{treatment_prefix}'. "
            )
        treatment_map = {"No Treatment": {}}
        treatment_map.update({
            col[len(treatment_prefix):] or col: {col: 1}
            for col in cols
        })

    preds_by_tx = get_preds_for_all_treatments(
        model, df_all,
        treatment_map=treatment_map,
        treatment_prefix=treatment_prefix,
        horizon_days=horizon_days,
    )

    # get_patient_predictions returns None when no patient is left after dropping
    # those censored on or before the horizon; indexing None would raise below.
    skipped = [label for label, preds in preds_by_tx.items() if preds is None]
    if skipped:
        print(f"No usable predictions, skipped: {', '.join(map(str, skipped))}")
    preds_by_tx = {label: preds for label, preds in preds_by_tx.items() if preds is not None}

    thresholds = np.linspace(0.0, 1.0, n_thresholds)
    fig, (ax_top, ax_bot) = plt.subplots(
        2, 1, figsize=(11, 8), sharex=True,
        gridspec_kw={"height_ratios": [3, 1.2]}
    )

    colors = palette(np.linspace(0, 1, len(preds_by_tx)))
    for (label, df_preds), col in zip(preds_by_tx.items(), colors):
        curves = _threshold_curves(df_preds, thresholds)
        ax_top.fill_between(thresholds, curves["q1"], curves["q3"], color=col, alpha=0.15)
        ax_top.plot(thresholds, curves["med"], color=col, lw=2, label=f"{label} – median")
        ax_top.plot(thresholds, curves["lo"], color=col, ls=":")
        ax_top.plot(thresholds, curves["hi"], color=col, ls=":")
        ax_bot.step(thresholds, curves["n"], where="post", color=col, lw=1.5, label=label)

    # Horizon wording: whole years read as "N-year", anything else as "N-day".
    if horizon_days % 365 == 0:
        horizon_line_label = f"{horizon_days//365}-year ({horizon_days} d)"
        horizon_text = f"{horizon_days // 365}-year"
    else:
        horizon_line_label = f"{horizon_days} days"
        horizon_text = f"{horizon_days}-day"

    ax_top.axhline(horizon_days, color="black", ls="--", label=horizon_line_label)
    ax_top.set_ylabel("Actual survival time (days)")
    ax_top.grid(alpha=0.3); ax_top.set_ylim(bottom=0)
    ax_top.legend(fontsize=8)

    # The x-axis is the probability at `horizon_days`, which is not always one year.
    ax_bot.set_xlabel(f"Predicted {horizon_text} survival probability")
    ax_bot.set_ylabel("Patients\nbelow cut-off")
    ax_bot.grid(alpha=0.3); ax_bot.set_ylim(bottom=0)
    ax_bot.legend(fontsize=8, ncol=2)

    fig.suptitle(
        f"{model_name}: Actual survival vs survival probability "
        f"(survival predicted at {horizon_days} d)",
        y=0.97
    )
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

In [None]:
def get_preds_for_all_treatments(
    model,
    df_all,
    treatment_map: dict,
    treatment_prefix: str = "systemicTreatmentPlan",
    horizon_days: int = 365,
    background_size: int = 150
):
    """Predict survival at ``horizon_days`` for every patient under each treatment scenario.

    Parameters
    ----------
    model : object accepted by ``get_patient_predictions``.
    df_all : DataFrame with model features plus ``sourceId``,
        ``settings.event_col`` and ``settings.duration_col``. Those three columns
        are dropped before prediction.
    treatment_map : dict label -> {column: value}, applied via ``apply_treatment``.
        Every column starting with ``treatment_prefix`` is reset to 0 first.
    treatment_prefix : prefix identifying treatment indicator columns.
    horizon_days : time point at which survival is evaluated.
    background_size : currently unused; kept for signature compatibility.

    Returns
    -------
    dict
        label -> DataFrame from ``get_patient_predictions``, or None when no
        patient is left for that scenario after masking.
    """
    durations = df_all[settings.duration_col].values
    events = df_all[settings.event_col].astype(bool).values
    X_base = df_all.drop(columns=["sourceId",
                                  settings.event_col,
                                  settings.duration_col])

    treatment_cols = [c for c in X_base if c.startswith(treatment_prefix)]

    # The MSI flag does not depend on the scenario, so compute it once.
    msi_flag = df_all.get("hasMsi", pd.Series(0)).astype(bool)

    preds_by_tx = {}
    for label, mapping in treatment_map.items():
        X_mod = apply_treatment(X_base, mapping, treatment_cols, msi_flag=msi_flag)
        preds_by_tx[label] = get_patient_predictions(
            model,
            X_mod, durations, events,
            horizon_days=horizon_days
        )

    return preds_by_tx


In [None]:
def apply_treatment(df, mapping, treatment_cols, msi_flag):
    """Return a copy of ``df`` describing one hypothetical treatment scenario.

    Steps, applied to a copy (``df`` itself is not modified):

    1. Every column in ``treatment_cols`` is set to 0.
    2. Each ``column: value`` pair of ``mapping`` is written where that column
       exists; unknown columns are ignored.
    3. If ``df`` has a ``hasMsi`` column, it is overwritten with ``msi_flag``.
    4. ``hasTreatment`` is set to 1 where any treatment column is positive, else 0.
    """
    scenario = df.copy()
    scenario[treatment_cols] = 0

    applicable = {col: val for col, val in mapping.items() if col in scenario}
    for col, val in applicable.items():
        scenario[col] = val

    if "hasMsi" in scenario:
        scenario["hasMsi"] = msi_flag

    any_treatment = scenario[treatment_cols].sum(axis=1) > 0
    scenario["hasTreatment"] = any_treatment.astype(int)
    return scenario

In [None]:
def get_patient_predictions(model, X, durations, events, horizon_days=365):
    """Predict survival at ``horizon_days`` for patients whose status at that time is known.

    Patients censored on or before the horizon (``duration <= horizon_days``
    without an event) are dropped.

    Parameters
    ----------
    model : object whose ``predict_survival_function(X)`` returns an iterable
        of callables, one per row, each evaluated at ``horizon_days``.
    X : DataFrame of model inputs, row-aligned with ``durations``/``events``.
    durations : array-like of follow-up times, compared directly with ``horizon_days``.
    events : array-like event indicators; any truthy value (True, 1, ...) is an event.
    horizon_days : evaluation time point.

    Returns
    -------
    pandas.DataFrame or None
        Columns ``predicted_prob_1yr`` (survival probability at ``horizon_days``;
        the name is kept for compatibility whatever the horizon),
        ``actual_survival_time`` and ``event_observed``. None if no patient remains.
    """
    durations = np.asarray(durations)
    # Coerce to bool. With 0/1 integers, ``~events`` would be a bitwise NOT
    # (-1/-2), turning the mask below into integer fancy indexing.
    events = np.asarray(events).astype(bool)

    censored_before_horizon = (durations <= horizon_days) & ~events
    mask_known = ~censored_before_horizon

    X_masked = X.loc[mask_known]
    if len(X_masked) == 0:
        return None

    surv_funcs = model.predict_survival_function(X_masked)
    predicted_probs = np.array([fn(horizon_days) for fn in surv_funcs])

    return pd.DataFrame({
        "predicted_prob_1yr": predicted_probs,
        "actual_survival_time": durations[mask_known],
        "event_observed": events[mask_known],
    })

In [None]:
# Cohort: patients whose only systemic treatment was 5-FU, scored under the
# "5-FU" scenario from the treatment-combination lookup.
five_fu_map = {"5-FU": valid_treatment_combinations["5-FU"]}

treat_prefix = "systemicTreatmentPlan_"
cols_treat = [c for c in df_all.columns if c.startswith(treat_prefix)]

col_5fu = f"{treat_prefix}5-FU"
other_cols = [c for c in cols_treat if c != col_5fu]

# 5-FU set, and all other treatment indicators summing to zero.
mask_5fu = df_all[col_5fu].eq(1) & df_all[other_cols].sum(axis=1).eq(0)

df_5fu = df_all.loc[mask_5fu].copy()
print(f"{len(df_5fu)} patients with 5-FU")

plot_threshold_analysis_by_treatment(
    trained_models["DeepSurv_attention"],
    df_5fu,
    treatment_map=five_fu_map,
    treatment_prefix=treat_prefix,
    horizon_days=365,
    model_name="DeepSurv + Attention",
)


In [None]:
# Same 5-FU cohort and scenario, with survival evaluated at 182 days (~6 months).
plot_threshold_analysis_by_treatment(
    trained_models["DeepSurv_attention"],
    df_5fu,
    treatment_map=five_fu_map,
    treatment_prefix=treat_prefix,
    horizon_days=182,
    model_name="DeepSurv + Attention",
)

In [None]:
# Same 5-FU cohort and scenario, with survival evaluated at 91 days (~3 months).
plot_threshold_analysis_by_treatment(
    trained_models["DeepSurv_attention"],
    df_5fu,
    treatment_map=five_fu_map,
    treatment_prefix=treat_prefix,
    horizon_days=91,
    model_name="DeepSurv + Attention",
)

In [None]:
# Cohort: patients treated with exactly 5-FU + oxaliplatin + bevacizumab (all three
# indicators set, every other treatment indicator zero), scored under that scenario.
five_fu_oxa_beva_map = {"5-FU + oxaliplatin + bevacizumab": valid_treatment_combinations["5-FU + oxaliplatin + bevacizumab"]}

treat_prefix  = "systemicTreatmentPlan_"
cols_treat    = [c for c in df_all.columns if c.startswith(treat_prefix)]

col_5fu       = f"{treat_prefix}5-FU"
col_oxa       = f"{treat_prefix}oxaliplatin"
col_beva      = f"{treat_prefix}bevacizumab"

required_cols = [col_5fu, col_oxa, col_beva]
other_cols    = [c for c in cols_treat if c not in required_cols]

mask_combo = (
    df_all[required_cols].eq(1).all(axis=1)
    &
    (df_all[other_cols].sum(axis=1) == 0)
)

df_5fu_oxa_beva = df_all[mask_combo].copy()
print(f"{len(df_5fu_oxa_beva)} patients with 5-FU + Oxaliplatin + Bevacizumab")


plot_threshold_analysis_by_treatment(
    model            = trained_models["DeepSurv_attention"],
    df_all           = df_5fu_oxa_beva,
    treatment_map    = five_fu_oxa_beva_map,
    treatment_prefix = treat_prefix,
    horizon_days     = 365,
    model_name       = "DeepSurv + Attention"
)


In [None]:
# Cohort: patients without any recorded systemic treatment, scored under the
# "No treatment" scenario from the treatment-combination lookup.
no_treatment_map = {"No treatment": valid_treatment_combinations["No treatment"]}

treat_prefix = "systemicTreatmentPlan_"
cols_treat = [c for c in df_all.columns if c.startswith(treat_prefix)]

mask_no_treatment = df_all[cols_treat].sum(axis=1).eq(0)

df_no_treatment = df_all.loc[mask_no_treatment].copy()
print(f"{len(df_no_treatment)} patients with no systemic treatment")

plot_threshold_analysis_by_treatment(
    trained_models["DeepSurv_attention"],
    df_no_treatment,
    treatment_map=no_treatment_map,
    treatment_prefix=treat_prefix,
    horizon_days=365,
    model_name="DeepSurv + Attention",
)


In [None]:
# Cohort: patients without any recorded systemic treatment, but scored under the
# 5-FU scenario.
# NOTE(review): this looks like a counterfactual ("as if they had received 5-FU");
# confirm that is intended. The figure title does not say so, and no_treatment_map
# is rebuilt here without being used.
no_treatment_map = {"No treatment": valid_treatment_combinations["No treatment"]}

treat_prefix = "systemicTreatmentPlan_"
cols_treat = [c for c in df_all.columns if c.startswith(treat_prefix)]

mask_no_treatment = df_all[cols_treat].sum(axis=1).eq(0)

df_no_treatment = df_all.loc[mask_no_treatment].copy()
print(f"{len(df_no_treatment)} patients with no systemic treatment")

plot_threshold_analysis_by_treatment(
    trained_models["DeepSurv_attention"],
    df_no_treatment,
    treatment_map=five_fu_map,
    treatment_prefix=treat_prefix,
    horizon_days=365,
    model_name="DeepSurv + Attention",
)


In [None]:
# List the treatment combinations that actually occur in df_all.
# NOTE(review): this rebinds `valid_treatment_combinations` (loaded from JSON
# earlier). If the cohort cells above are re-run after this one, they will look up
# their maps in this data-derived version instead; confirm that is intended.
treat_prefix = "systemicTreatmentPlan_"
cols_treat = [c for c in df_all.columns if c.startswith(treat_prefix)]

unique_treatments = (
    df_all[cols_treat]
    .drop_duplicates()
    .reset_index(drop=True)
)

def make_label(row):
    """Return the active treatments of ``row`` joined by " + ", or "No treatment" if none."""
    active = [col.replace(treat_prefix, "") for col in cols_treat if row[col] == 1]
    if not active:
        return "No treatment"
    return " + ".join(active)

unique_treatments["label"] = unique_treatments.apply(make_label, axis=1)

valid_treatment_combinations = {
    combo["label"]: combo[cols_treat].to_dict()
    for _, combo in unique_treatments.iterrows()
}

for label, treatment_vector in valid_treatment_combinations.items():
    print(label, treatment_vector, sep=": ")
