# Model Evaluation
This notebook demonstrates the interpretation and evaluation of the models (as trained in `predictive_algorithms_training.ipynb`):
- Performance Evaluation: Comparing models based on metrics such as concordance index (C-Index), integrated Brier score (IBS), calibration error (CE), and time-dependent AUC.
- Visualization: Generating survival curves and feature importance plots to interpret model predictions and uncover key insights.

All experiment settings (e.g. OS or PFS, grouped treatments or not) are configured in `src/utils/settings.py`; once they are set, the experiment can be run in this notebook.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import ast
import matplotlib.pyplot as plt
import dill
import torch
import nbimporter
import shap

import os
os.chdir('/data/repos/actin-personalization/scripts/personalization/prediction')

from src.models import *
from src.data.data_processing import DataSplitter, DataPreprocessor
from src.data.lookups import lookup_manager
from src.utils.settings import settings
from src.predictive_algorithms_training import get_data, plot_different_models_survival_curves

# Data preprocessor for the configured database; reused below to rebuild single-patient feature rows.
preprocessor = DataPreprocessor(settings.db_config_path, settings.db_name)

In [None]:
# Load the dataset and its train/test split via get_data() from src.predictive_algorithms_training.
df, X_train, X_test, y_train, y_test, encoded_columns = get_data()

## Model outcome & metrics

### Load models & outcomes
The pretrained models and outcomes are stored in the Google Cloud Storage bucket: `gs://actin-personalization-models-v1/trained_models/`. 

To download the saved models from the bucket into the `trained_models` folder, run the following command in your terminal:

`gsutil -m cp -r gs://actin-personalization-models-v1/trained_models/ ./trained_models/`

Make sure the trained_models folder is inside the models folder.

In [None]:
def load_model_outcomes():
    """
    Read the saved per-model evaluation results for the configured outcome.

    The file is looked up at ``<settings.save_path>/<settings.outcome>_model_outcomes.csv``.

    Returns:
        pd.DataFrame: the contents of the outcomes CSV.

    Raises:
        FileNotFoundError: if no outcomes CSV exists for the configured outcome.
    """
    csv_file = os.path.join(f"{settings.save_path}", f"{settings.outcome}_model_outcomes.csv")

    # Guard clause: fail early with a message naming the outcome and directory.
    if not os.path.exists(csv_file):
        raise FileNotFoundError(f"No saved outcomes found for {settings.outcome} in {settings.save_path}")

    outcomes = pd.read_csv(csv_file)
    print(f"Loaded model outcomes from {csv_file}")
    return outcomes

In [None]:
def load_trained_model(model_name, model_class, model_kwargs=None):
    """
    Load one trained model from ``settings.save_path``.

    Models named 'CoxPH', 'RandomSurvivalForest', 'GradientBoosting' or 'AalenAdditive' are
    unpickled with dill from ``<outcome>_<model_name>.pkl``. Every other model is rebuilt as
    ``model_class(**model_kwargs)`` and restored from ``<outcome>_<model_name>.pt``: network
    weights, plus the label transform and baseline hazards when they were saved.

    Args:
        model_name: Name used in the saved file names.
        model_class: Wrapper class used to rebuild neural-network models.
        model_kwargs: Keyword arguments for ``model_class``; defaults to no arguments.

    Returns:
        The loaded model; neural networks are switched to eval mode.
    """
    # Avoid a shared mutable default argument.
    if model_kwargs is None:
        model_kwargs = {}

    model_file_prefix = os.path.join(settings.save_path, f"{settings.outcome}_{model_name}")
    nn_file = model_file_prefix + ".pt"
    sk_file = model_file_prefix + ".pkl"

    if model_name in ['CoxPH', 'RandomSurvivalForest', 'GradientBoosting', 'AalenAdditive']:
        # dill (like pickle) can execute arbitrary code on load: only load trusted model files.
        with open(sk_file, "rb") as f:
            model = dill.load(f)
        print(f"Model {model_name} loaded from {sk_file}")
        return model

    model = model_class(**model_kwargs)

    # The checkpoint also holds non-tensor Python objects (e.g. the label transform), which
    # torch >= 2.6 rejects under its default weights_only=True. weights_only=False unpickles
    # arbitrary objects, so only load trusted checkpoint files.
    state = torch.load(nn_file, map_location=torch.device('cpu'), weights_only=False)

    model.model.net.load_state_dict(state['net_state'])

    if 'labtrans' in state:
        # Restore the label transform and use its cut points as the model's time index.
        model.labtrans = state['labtrans']
        model.model.duration_index = model.labtrans.cuts

    if 'baseline_hazards' in state:
        # Restore baseline hazards when they were saved alongside the network.
        model.model.baseline_hazards_ = state['baseline_hazards']
        model.model.baseline_cumulative_hazards_ = state['baseline_cumulative_hazards']

        print(f"Baseline hazards loaded for {model_name}.")

    model.model.net.eval()
    print(f"Model {model_name} loaded from {nn_file}")

    return model
    
def load_all_trained_models(X_train):
    """
    Load every model listed in the experiment config (``settings.json_config_file``).

    This is best-effort: a model that fails to load (or whose attention indices cannot be set)
    is reported with the error and skipped, so the remaining models can still be evaluated.

    Args:
        X_train: Training feature frame; its column names are passed to
            ``ModelTrainer._set_attention_indices``.

    Returns:
        dict: model name -> loaded model.
    """
    loaded_models = {}
    config_mgr = ExperimentConfig(settings.json_config_file)
    loaded_configs = config_mgr.load_model_configs()
    feature_names = list(X_train.columns)

    for model_name, (model_class, model_kwargs) in loaded_configs.items():
        print(model_name, model_class)
        try:
            loaded_model = load_trained_model(
                model_name=model_name,
                model_class=model_class,
                model_kwargs=model_kwargs
            )
            loaded_models[model_name] = loaded_model

            ModelTrainer._set_attention_indices(loaded_models[model_name], feature_names)
        except Exception as exc:
            # Keep going (best-effort), but say why instead of silently swallowing the error.
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
            print(f'Could not load: {model_name} ({type(exc).__name__}: {exc})')
    return loaded_models

In [None]:
# Per-model evaluation results saved during training (raises FileNotFoundError if missing).
model_outcomes = load_model_outcomes()

In [None]:
# Load every configured model; models that fail to load are skipped (see the printed messages).
trained_models = load_all_trained_models(X_train)

### Metric comparison

The trained models are evaluated using the following metrics:

- **C-Index**: The Concordance Index measures how well the predicted survival times align with the actual outcomes. It is a measure of discrimination, indicating the model's ability to correctly rank the survival times of patients. A higher value indicates better predictive accuracy.

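- **Integrated Brier Score (IBS)**: This metric evaluates the accuracy of the survival probability predictions over time. The Brier score at a given time is the mean squared difference between the predicted survival probability and the observed survival status, weighted by inverse probability of censoring to account for censored patients. The IBS integrates (averages) this score over the evaluation time range. Lower values indicate better predictive performance.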

- **Calibration Error (CE)**: Calibration error assesses how well the predicted survival probabilities match the observed probabilities. It indicates whether the model is systematically overestimating or underestimating survival probabilities. Lower values signify better calibration.

- **Area Under the Curve (AUC)**: For survival models, AUC is typically computed over a time-dependent ROC curve, reflecting the model's discrimination ability at different time points. Higher AUC values indicate better overall performance.

This section visualizes the comparison of model performance metrics (C-Index, IBS, CE, AUC) for the outcome selected in the settings (OS or PFS). The bar plots highlight the strengths and weaknesses of each model on that prediction task.


In [None]:
def extract_holdout_metrics(df):
    """
    Expand the per-model holdout metrics into one column per metric.

    Each ``holdout`` cell is a dict of metric name -> value, or the string representation of
    such a dict (e.g. as read back from a CSV); string cells are parsed with ``ast.literal_eval``.

    Args:
        df: DataFrame with 'Model' and 'holdout' columns. It is not modified.

    Returns:
        pd.DataFrame: one column per holdout metric plus 'Model', indexed like ``df``.
    """
    # Parse only string cells, so columns mixing dicts and strings work, and keep the
    # result local instead of overwriting the caller's 'holdout' column.
    parsed = df['holdout'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)

    holdout_metrics = parsed.apply(pd.Series)
    holdout_metrics['Model'] = df['Model']

    return holdout_metrics

def plot_all_metrics(df, holdout=True):
    """
    Draw a 2x2 grid of bar charts comparing models on 'c_index', 'ibs', 'ce' and 'auc'.

    Within each chart, bars are shaded from light to dark by their value relative to the
    metric's min/max in ``df`` (darker = larger; note that for IBS and CE lower is better).
    The C-Index and AUC axes are fixed to [0, 1].

    Args:
        df: DataFrame with a 'Model' column and one column per metric.
        holdout: Unused; kept for backward compatibility.

    Returns:
        The input ``df``.
    """
    metrics = ['c_index', 'ibs', 'ce', 'auc']
    # Light-to-dark colormap used to shade each bar by its relative value.
    cmap = sns.light_palette("#79C", as_cmap=True)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()

    for ax, metric in zip(axes, metrics):
        # No palette: bar colours are overwritten below, and palette without hue is
        # deprecated in recent seaborn.
        sns.barplot(x='Model', y=metric, data=df, ax=ax)
        title = f"{metric.upper()} Comparison"

        ax.set_title(title)
        ax.set(xlabel='Model', ylabel=metric.upper())
        # Rotate the existing labels in place (set_xticklabels on a non-fixed locator warns).
        plt.setp(ax.get_xticklabels(), rotation=45)
        ax.grid(axis='y', linestyle='--', alpha=0.7)

        if metric in ['c_index', 'auc']:
            ax.set_ylim(0, 1)

        min_val, max_val = df[metric].min(), df[metric].max()
        for patch in ax.patches:
            height = patch.get_height()
            normalized = (height - min_val) / (max_val - min_val) if max_val > min_val else 0.5
            # Map normalized [0, 1] onto [0.4, 1.0] of the colormap; 0.4 + 0.8 * x exceeded
            # 1.0 and clipped the top quarter of the value range to a single colour.
            patch.set_facecolor(cmap(0.4 + 0.6 * normalized))

    plt.tight_layout()
    plt.show()

    return df

In [None]:
# Bar-plot C-Index, IBS, CE and AUC per model from the loaded outcomes.
plot_all_metrics(model_outcomes)

### Time-Dependent ROC-AUC

This section computes the time-dependent (cumulative/dynamic) AUC of the survival models at specific time points and plots it over time. By evaluating the models' discriminative performance over time, we identify which models perform best at different prediction horizons.

Time interval:
- **Overall Survival (OS)**: For OS, the follow-up times in the dataset extend up to 5 years, allowing us to evaluate model performance over this longer horizon. As survival outcomes often have a broader timespan, a 5-year evaluation provides a comprehensive view of the model's ability to predict long-term survival.

The time-dependent AUC of each model is plotted against time (in years), showing how well the models distinguish at-risk patients across the follow-up period of the selected outcome (OS or PFS).


In [None]:
from sksurv.util import Surv
from sksurv.metrics import cumulative_dynamic_auc

def _risk_scores_for_time_points(preds, n_time_points, model_name="model"):
    """Reduce raw predictions to the ``estimate`` shape accepted by ``cumulative_dynamic_auc``."""
    preds = np.asarray(preds)
    if preds.ndim == 1:
        return preds
    if preds.shape[1] == 1:
        return preds.ravel()

    n_pred_times = preds.shape[1]
    if n_pred_times == n_time_points:
        return preds
    if n_pred_times % n_time_points == 0:
        # Keep every k-th column. NOTE(review): assumes those columns line up with the
        # evaluation time points — confirm against the model's time grid.
        return preds[:, ::n_pred_times // n_time_points]

    # Fail loudly: otherwise an undefined (or the previous model's) score would be used.
    raise ValueError(
        f"{model_name}: cannot align {n_pred_times} prediction columns with "
        f"{n_time_points} time points"
    )


def calculate_time_dependent_auc_for_models(model_dict, X_train, y_train, X_test, y_test, time_points):
    """
    Compute and plot the cumulative/dynamic time-dependent AUC of several survival models.

    Each model scores ``X_test``: wrappers exposing ``model.model.predict`` receive a float32
    array, all other models get ``model.predict(X_test)``. The predictions are reduced to
    risk scores and evaluated with sksurv's ``cumulative_dynamic_auc``, using the training
    outcomes to estimate the censoring distribution.

    Args:
        model_dict: Mapping of display name -> trained model.
        X_train, y_train: Training features and outcomes (``settings.duration_col`` /
            ``settings.event_col``).
        X_test, y_test: Test features and outcomes.
        time_points: Evaluation times in days (plotted as days / 365 years).

    Returns:
        dict: model name -> (AUC per time point, mean AUC).

    Raises:
        ValueError: if a model's predictions cannot be aligned with ``time_points``.
    """
    y_train_df = pd.DataFrame({'duration': y_train[settings.duration_col], 'event': y_train[settings.event_col]}, index=X_train.index)
    y_train_struct = Surv.from_dataframe('event', 'duration', y_train_df)

    y_test_df = pd.DataFrame({'duration': y_test[settings.duration_col], 'event': y_test[settings.event_col]}, index=X_test.index)
    y_test_struct = Surv.from_dataframe('event', 'duration', y_test_df)

    auc_results = {}

    for model_name, model in model_dict.items():
        if hasattr(model, "model") and hasattr(model.model, "predict"):
            preds = model.model.predict(X_test.values.astype("float32"))
        else:
            preds = model.predict(X_test)

        risk_scores = _risk_scores_for_time_points(preds, len(time_points), model_name)

        auc_values, mean_auc = cumulative_dynamic_auc(y_train_struct, y_test_struct, risk_scores, time_points)
        auc_results[model_name] = (auc_values, mean_auc)

    plt.figure(figsize=(10, 6))
    years = [t / 365.0 for t in time_points]

    # plt.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap still accepts a lut size.
    cmap = plt.get_cmap("tab20", max(len(auc_results), 1))

    for i, (model_name, (auc_vals, mean_auc)) in enumerate(auc_results.items()):
        plt.plot(years, auc_vals, marker='o', label=f"{model_name} (Mean AUC={mean_auc:.3f})", color=cmap(i))

    plt.xlabel("Time (years)")
    plt.ylabel("Time-Dependent AUC")
    plt.title(settings.outcome)
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.show()

    return auc_results


In [None]:
# Models compared in the time-dependent AUC and SHAP sections: each neural model with and
# without attention, plus GradientBoosting and RandomSurvivalForest.
model_names_to_evaluate = [
    "DeepSurv",
    "DeepSurv_attention",

    "LogisticHazardModel",
    "LogisticHazardModel_attention",

    "DeepHitModel",
    "DeepHitModel_attention",

    "PCHazardModel",
    "PCHazardModel_attention",

    "MTLRModel",
    "MTLRModel_attention",

    "GradientBoosting",
    "RandomSurvivalForest",
]

# load_all_trained_models skips models that fail to load; report and drop those here
# instead of failing with a KeyError.
missing_models = [name for name in model_names_to_evaluate if name not in trained_models]
if missing_models:
    print(f"Not loaded, skipped in evaluation: {missing_models}")

models_to_evaluate = {
    name: trained_models[name] for name in model_names_to_evaluate if name in trained_models
}

calculate_time_dependent_auc_for_models(
    models_to_evaluate, X_train, y_train, X_test, y_test,
    time_points=settings.time_points
)

## Model Interpretation

In [None]:
# Short display names for (one-hot encoded) feature columns, used to label SHAP and attention
# plots. Columns without an entry keep their original name (callers use
# feature_short_names.get(name, name) or DataFrame.rename with this mapping).
feature_short_names = {
        "ageAtMetastasisDetection": "Age (metastasis)",
        "albumine": "Albumin",
        "alkalinePhosphatase": "Alk. phosphatase",
        "anorectalVergeDistanceCategory": "Tumor–ARV distance",
        "asaClassificationPreSurgeryOrEndoscopy": "ASA class",
        "carcinoEmbryonicAntigen": "CEA",
        "cci": "CCI (score)",
        "cciHasAids": "CCI: AIDS",
        "cciHasCerebrovascularDisease": "CCI: Cerebrovascular disease",
        "cciHasCollagenosis": "CCI: Collagenosis",
        "cciHasCongestiveHeartFailure": "CCI: Heart failure",
        "cciHasCopd": "CCI: COPD",
        "cciHasDementia": "CCI: Dementia",
        "cciHasDiabetesMellitus": "CCI: Diabetes",
        "cciHasDiabetesMellitusWithEndOrganDamage": "CCI: Diabetes w/ EOD",
        "cciHasHemiplegiaOrParaplegia": "CCI: Hemiplegia",
        "cciHasLiverDisease": "CCI: Liver disease",
        "cciHasMildLiverDisease": "CCI: Mild liver disease",
        "cciHasMyocardialInfarct": "CCI: MI",
        "cciHasOtherMalignancy": "CCI: Other malignancy",
        "cciHasOtherMetastaticSolidTumor": "CCI: Other metastasis",
        "cciHasPeripheralVascularDisease": "CCI: PVD",
        "cciHasRenalDisease": "CCI: Renal disease",
        "cciHasUlcerDisease": "CCI: Ulcer disease",
        "cciNumberOfCategories": "CCI: # categories",
        "distanceToMesorectalFasciaMm": "MRF distance (mm)",
        "hasBrafMutation": "BRAF mut.",
        "hasBrafV600EMutation": "BRAF V600E",
        "hasDoublePrimaryTumor": "Double primary",
        "hasHadPriorTumor": "Prior tumor",
        "hasKrasG12CMutation": "KRAS G12C",
        "hasMsi": "MSI",
        "hasRasMutation": "RAS mut.",
        "investigatedLymphNodesNumber": "# nodes (investigated)",
        "lactateDehydrogenase": "LDH",
        "leukocytesAbsolute": "Leukocytes",
        "maximumSizeOfLiverMetastasisMm": "Max liver met (mm)",
        "mesorectalFasciaIsClear": "MRF clear",
        "neutrophilsAbsolute": "Neutrophils",
        "numberOfLiverMetastases": "# liver mets",
        "positiveLymphNodesNumber": "# positive nodes",
        "presentedWithIleus": "Ileus at presentation",
        "presentedWithPerforation": "Perforation",
        "sex": "Sex",
        "sidedness": "Tumor sidedness",
        "stageCTNM": "cTNM",
        "stagePTNM": "pTNM",
        "stageTNM": "TNM",
        "tumorDifferentiationGrade": "Grade",
        "tumorIncidenceYear": "Incidence year",
        "whoStatusPreTreatmentStart": "WHO status",
        "observedOsFromMetastasisDetectionDays": "OS (days)",

        # Systemic treatment flags (keys must match the column names, e.g. "panitumab")
        "systemicTreatmentPlan_5-FU": "5-FU",
        "systemicTreatmentPlan_oxaliplatin": "Oxaliplatin",
        "systemicTreatmentPlan_irinotecan": "Irinotecan",
        "systemicTreatmentPlan_bevacizumab": "Bevacizumab",
        "systemicTreatmentPlan_panitumab": "Panitumumab",
        "systemicTreatmentPlan_pembrolizumab": "Pembrolizumab",
        "systemicTreatmentPlan_nivolumab": "Nivolumab",

        # Metastasis locations
        "metastasisLocationGroupsPriorToSystemicTreatment_BRAIN": "Met: Brain",
        "metastasisLocationGroupsPriorToSystemicTreatment_BRONCHUS_AND_LUNG": "Met: Lung",
        "metastasisLocationGroupsPriorToSystemicTreatment_LIVER_AND_INTRAHEPATIC_BILE_DUCTS": "Met: Liver",
        "metastasisLocationGroupsPriorToSystemicTreatment_LYMPH_NODES": "Met: LN",
        "metastasisLocationGroupsPriorToSystemicTreatment_OTHER": "Met: Other",
        "metastasisLocationGroupsPriorToSystemicTreatment_PERITONEUM": "Met: Peritoneum",

        # Tumor types
        "consolidatedTumorType_CRC_ADENOCARCINOMA": "CRC: Adeno",
        "consolidatedTumorType_CRC_MUCINOUS": "CRC: Mucinous",
        "consolidatedTumorType_CRC_OTHER": "CRC: Other",
        "consolidatedTumorType_CRC_SIGNET_RING_CELL": "CRC: Signet ring",

        # Extra mural invasion
        "extraMuralInvasionCategory_ABOVE_FIVE_MM": "EMI >5mm",
        "extraMuralInvasionCategory_LESS_THAN_FIVE_MM": "EMI <5mm",
        "extraMuralInvasionCategory_NA": "EMI: NA",

        # Basis of diagnosis
        "tumorBasisOfDiagnosis_CLINICAL_AND_DIAGNOSTIC_INVESTIGATION": "Dx: Clinical + diag",
        "tumorBasisOfDiagnosis_CLINICAL_ONLY_INVESTIGATION": "Dx: Clinical only",
        "tumorBasisOfDiagnosis_CYTOLOGICAL_CONFIRMATION": "Dx: Cytology",
        "tumorBasisOfDiagnosis_HISTOLOGICAL_CONFIRMATION": "Dx: Histology",
        "tumorBasisOfDiagnosis_HISTOLOGICAL_CONFIRMATION_METASTASES": "Dx: Histology (met)",
        "tumorBasisOfDiagnosis_SPEC_BIOCHEMICAL_IMMUNOLOGICAL_LAB_INVESTIGATION": "Dx: Biochem/immuno",

        # Tumor location
        "tumorLocation_APPENDIX": "Tumor: Appendix",
        "tumorLocation_ASCENDING_COLON": "Tumor: Asc. colon",
        "tumorLocation_COECUM": "Tumor: Cecum",
        "tumorLocation_COLON_NOS": "Tumor: Colon NOS",
        "tumorLocation_COLON_OVERLAPPING": "Tumor: Overlapping",
        "tumorLocation_DESCENDING_COLON": "Tumor: Desc. colon",
        "tumorLocation_FLEXURA_HEPATICA": "Tumor: Hepatic flexure",
        "tumorLocation_FLEXURA_LIENALIS": "Tumor: Splenic flexure",
        "tumorLocation_OVARY": "Tumor: Ovary",
        "tumorLocation_RECTOSIGMOID": "Tumor: Rectosigmoid",
        "tumorLocation_RECTUM": "Tumor: Rectum",
        "tumorLocation_SIGMOID_COLON": "Tumor: Sigmoid",
        "tumorLocation_TRANSVERSE_COLON": "Tumor: Transverse",
    }

In [None]:
def compute_effective_input(model, X_tensor: torch.Tensor) -> np.ndarray:
    """
    Return the attention-weighted ("effective") input for the first row of ``X_tensor``.

    The input is passed through the attention layer's gate, attention weights are computed
    from the gated input, and the gated input is multiplied element-wise by those weights.

    Args:
        model: Wrapper whose ``model.net.attention`` exposes ``_apply_gate`` and ``attn``.
        X_tensor: Input features; only row 0 is used.  # assumes (n_rows, n_features) — TODO confirm

    Returns:
        np.ndarray: gated input * attention weights for row 0.
    """
    attn_layer = model.model.net.attention

    # Run the gate under no_grad too: if the gate has trainable parameters its output would
    # otherwise require grad, and converting it with .numpy() raises a RuntimeError.
    with torch.no_grad():
        x_gated = attn_layer._apply_gate(X_tensor.clone())
        weights = attn_layer.attn(x_gated).detach().cpu().numpy()[0]

    x_gated_np = x_gated.detach().cpu().numpy()[0]

    return x_gated_np * weights

### Feature Importance

Feature importance analysis helps us identify which features most strongly influence survival predictions across various models. In this section SHAP (SHapley Additive exPlanations) is used for all models to ensure uniformity and interpretability. SHAP values provide a consistent and locally accurate measure of feature importance for individual predictions.

How SHAP is used:
- For classical models like CoxPH and Aalen Additive, SHAP values are calculated based on risk scores or cumulative hazard coefficients. Custom prediction functions are used when necessary (e.g., for Aalen Additive) to align feature importance with model-specific outputs.
- For tree-based models like Random Survival Forest (RSF) and Gradient Boosting Survival Model (GBM), SHAP values replace traditional feature importance metrics to ensure consistency.
- For neural network-based models like DeepSurv and DeepHit, SHAP values are derived using the model's prediction function.

#### Visualization
SHAP provides the following insights:
- Summary Plot - Bar: Displays the average magnitude of SHAP values for each feature, indicating the overall importance of features in the model.
- Summary Plot - Dot: Highlights the distribution of SHAP values for each feature, showing their impact across different samples.


In [None]:
def nn_predict(X, model, X_train):
    """
    Predict with a neural-network wrapper on tabular input.

    NumPy input (e.g. from SHAP) is first labelled with ``X_train``'s column names; the
    features are then cast to float32 and passed to ``model.model.predict``.
    """
    frame = pd.DataFrame(X, columns=X_train.columns) if isinstance(X, np.ndarray) else X
    features = frame.values.astype('float32')
    return model.model.predict(features)

def custom_aalen_predict(X, model):
    """
    Custom predict function for AalenAdditiveModel.

    Computes one linear risk score per row from the cumulative hazard coefficients at the
    last time point of ``model.model.cumulative_hazards_``, using only
    ``model.selected_features``. When there are more coefficients than selected features,
    a constant "Intercept" column (1.0) is prepended to the inputs.

    Args:
        X: DataFrame containing at least ``model.selected_features``; it is not modified.
        model: Wrapper exposing ``selected_features`` and the fitted ``model.cumulative_hazards_``.

    Returns:
        np.ndarray: one risk score per row of ``X``.

    Raises:
        ValueError: if the number of coefficients does not match the number of inputs.
    """
    cumulative_coefs = model.model.cumulative_hazards_
    X = X[model.selected_features].copy()

    # Take the coefficients at the latest (last) time point; no interpolation is done.
    latest_coefs = cumulative_coefs.iloc[-1].values

    if len(latest_coefs) > X.shape[1]:
        X.insert(0, "Intercept", 1.0)

    if len(latest_coefs) != X.shape[1]:
        raise ValueError(
            f"AalenAdditive has {len(latest_coefs)} coefficients but {X.shape[1]} input columns"
        )

    return X.values @ latest_coefs

def _select_shap_prediction_fn(model_name, model, X_sample, X_train):
    """Return the prediction function SHAP should explain for ``model``."""
    if model_name == 'AalenAdditive':
        return lambda X: custom_aalen_predict(X, model)
    try:
        # Probe the model's own `predict` on one DataFrame row. Looking `predict` up inside
        # the try also covers models without that attribute.
        model.predict(X_sample.head(1))
    except Exception:
        return lambda X: nn_predict(X, model, X_train)
    return model.predict


def shap_interpret_model(model_name, model, X_train, feature_short_names, max_features=20, shap_sample=200):
    """
    Compute SHAP values for one model and show bar and dot summary plots.

    A random sample (``random_state=42``) of at most ``shap_sample`` training rows is used
    both as SHAP background data and as the rows to explain. Prediction function:
    'AalenAdditive' uses ``custom_aalen_predict``; models whose ``predict`` works on a
    DataFrame row use it directly; all others fall back to ``nn_predict``.

    For multi-output predictions (e.g. one output per time point), SHAP values are averaged
    over the output axis, so each feature keeps one value per sample.

    Args:
        model_name: Model name, used for the AalenAdditive special case and the printed header.
        model: Trained model.
        X_train: Training features to sample from.
        feature_short_names: Mapping of column name -> display name.
        max_features: Maximum number of features shown per plot.
        shap_sample: Maximum number of rows sampled for SHAP.
    """
    X_sample = X_train.sample(min(shap_sample, len(X_train)), random_state=42)

    X_display = X_sample.rename(columns=feature_short_names)
    display_names = list(X_display.columns)

    prediction_fn = _select_shap_prediction_fn(model_name, model, X_sample, X_train)

    explainer = shap.Explainer(prediction_fn, X_sample, feature_names=display_names)
    shap_values = explainer(X_sample)

    print(f"SHAP Summary for {model_name}:")

    if shap_values.values.ndim == 3:
        # Values are (n_samples, n_features, n_outputs): average over the outputs (axis 2),
        # not over the features, so the result still lines up with X_display.
        aggregated_shap = shap_values.values.mean(axis=2)
        shap.summary_plot(aggregated_shap, features=X_display, plot_type="bar", max_display=max_features)
        shap.summary_plot(aggregated_shap, features=X_display, max_display=max_features)
    else:
        shap.summary_plot(shap_values, features=X_display, plot_type="bar", max_display=max_features)
        shap.summary_plot(shap_values, features=X_display, max_display=max_features)

In [None]:
# Run the SHAP interpretation (bar + dot summary plots) for every model in models_to_evaluate.
for model_name, model_instance in models_to_evaluate.items():
    print(f"\n--- Interpreting {model_name} ---")
    shap_interpret_model(model_name, model_instance, X_train, feature_short_names)

### Attention weights

In [None]:
def plot_effective_feature_contributions(model, preprocessor, settings, ncr_id, max_display=20):
    """
    Plot the largest attention-weighted ("effective") inputs of one patient.

    The full feature matrix is rebuilt with ``preprocessor.preprocess_data`` (duplicate columns
    dropped) and the rows with ``ncrId == ncr_id`` are selected. The effective input is taken
    from ``compute_effective_input``. The ``max_display`` non-zero contributions with the
    largest absolute value are then shown as a horizontal bar chart.

    Raises:
        ValueError: if no preprocessed row has ``ncrId == ncr_id``.
    """
    requested = ["ncrId"]
    requested.extend(name for name in lookup_manager.features if name != "ncrId")
    df_all, _, _ = preprocessor.preprocess_data(requested)
    df_all = df_all.loc[:, ~df_all.columns.duplicated()].copy()

    patient_rows = df_all[df_all["ncrId"] == ncr_id]
    if patient_rows.empty:
        raise ValueError(f"No patient found with ncrId={ncr_id}")

    # Model inputs exclude the identifier and the survival targets.
    inputs = patient_rows.drop(columns=["ncrId", settings.event_col, settings.duration_col])
    effective = compute_effective_input(model, torch.tensor(inputs.values, dtype=torch.float32))

    labels = [feature_short_names.get(name, name) for name in inputs.columns.tolist()]
    contributions = pd.DataFrame({
        "Feature": labels,
        "Effective Contribution": effective
    })
    nonzero = contributions.loc[contributions["Effective Contribution"].abs() > 0]
    top = nonzero.sort_values("Effective Contribution", key=abs, ascending=False).head(max_display)

    plt.figure(figsize=(10, 6))
    sns.barplot(data=top, x="Effective Contribution", y="Feature", palette="viridis")
    plt.title(f"Effective Input Contribution to MLP – Patient {ncr_id}")
    plt.tight_layout()
    plt.show()

In [None]:
# ncrId of the patient whose attention-weighted inputs should be plotted.
# Fill this in before running the cell; while it is None the plot is skipped.
patient_ncr_id = None

if patient_ncr_id is None:
    print("Set patient_ncr_id to a patient's ncrId to plot effective feature contributions.")
else:
    plot_effective_feature_contributions(
        model=trained_models["DeepSurv_attention"],
        preprocessor=preprocessor,
        settings=settings,
        ncr_id=patient_ncr_id,
        max_display=15
    )

## Model Output

### Comparison of Treatments
In this section the predicted survival probabilities can be visualized for a single patient under different treatment scenarios. By simulating the patient receiving each available treatment, we can compare how the model predicts their survival trajectory across treatments.

This analysis provides insights into the model's predictions for different treatment options, helping to identify potentially better treatment choices for the patient based on the predicted survival probabilities over time.

In [None]:
# Drug suffixes of the one-hot systemicTreatmentPlan_* columns, in column order.
# Note: the column name uses the spelling "panitumab".
_TREATMENT_PLAN_DRUGS = (
    "5-FU",
    "oxaliplatin",
    "irinotecan",
    "bevacizumab",
    "panitumab",
    "pembrolizumab",
    "nivolumab",
)


def _treatment_plan(*active_drugs):
    """Return systemicTreatmentPlan_* flags: 1 for each drug in ``active_drugs``, 0 otherwise."""
    return {
        f"systemicTreatmentPlan_{drug}": int(drug in active_drugs)
        for drug in _TREATMENT_PLAN_DRUGS
    }


# Treatment scenarios simulated for a patient: display label -> treatment-column flags.
valid_treatment_combinations = {
    "No Treatment": _treatment_plan(),
    "5-FU": _treatment_plan("5-FU"),
    "5-FU + oxaliplatin": _treatment_plan("5-FU", "oxaliplatin"),
    "5-FU + oxaliplatin + bevacizumab": _treatment_plan("5-FU", "oxaliplatin", "bevacizumab"),
    "5-FU + irinotecan": _treatment_plan("5-FU", "irinotecan"),
    "5-FU + irinotecan + bevacizumab": _treatment_plan("5-FU", "irinotecan", "bevacizumab"),
    "5-FU + irinotecan + panitumumab": _treatment_plan("5-FU", "irinotecan", "panitumab"),
    "5-FU + oxaliplatin + irinotecan": _treatment_plan("5-FU", "oxaliplatin", "irinotecan"),
    "5-FU + oxaliplatin + irinotecan + bevacizumab": _treatment_plan(
        "5-FU", "oxaliplatin", "irinotecan", "bevacizumab"
    ),
    "PEMBROLIZUMAB": _treatment_plan("pembrolizumab"),
}

#### Median Survival Time Calculation

Rather than relying on the mean survival time (which can be skewed by tail behavior), we focus on the median survival time, defined as the first time point t at which the survival function S(t) drops to 0.5 or below. This metric is often more robust in practical settings, as it is less sensitive to subtle differences in the survival curve’s tail.

The median survival time is obtained by scanning the survival curve from time zero until finding the earliest point where S(t) ≤ 0.5. If the survival probability never reaches 0.5 within the observed follow-up, the median is reported as the last time point of the curve (the true median lies at or beyond it).



In [None]:
def compute_survival_stats(time_grid: np.ndarray, surv_probs: np.ndarray):
    """
    Summarise a survival curve S(t) sampled on ``time_grid``.

    Args:
        time_grid: Increasing time points at which the curve is sampled.
        surv_probs: Survival probability at each point of ``time_grid``.

    Returns:
        tuple ``(median, area)``, both in the units of ``time_grid``. No conversion to months
        is done; the notebook's grids are presumably in days.
          - median: first time point with S(t) <= 0.5, or the last time point when the curve
            never reaches 0.5 (the true median then lies at or beyond it);
          - area: trapezoidal area under S(t) over the grid, i.e. the restricted mean
            survival time up to the last time point.

    Raises:
        ValueError: if the inputs are empty or differ in length.
    """
    time_grid = np.asarray(time_grid)
    surv_probs = np.asarray(surv_probs)
    if time_grid.size == 0:
        raise ValueError("time_grid must contain at least one time point")
    if len(time_grid) != len(surv_probs):
        raise ValueError(
            f"time_grid and surv_probs differ in length ({len(time_grid)} vs {len(surv_probs)})"
        )

    below = np.flatnonzero(surv_probs <= 0.5)
    median_days = time_grid[below[0]] if below.size else time_grid[-1]

    # np.trapz was renamed np.trapezoid in NumPy 2.0 (the old name is deprecated there).
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    auc_days = trapezoid(surv_probs, time_grid)

    return median_days, auc_days

#### Patient specific prediction

In [None]:
import matplotlib.gridspec as gridspec
from PIL import Image
import io
import pymysql

def compute_risk_at_horizon(model, X, horizon_days=360):
    """
    Estimate each row's event risk by ``horizon_days`` as 1 - S(horizon_days).

    S(t) comes from ``model.predict_survival_function(X)``. Each returned function is evaluated
    with ``np.interp`` on its support points (``.x``, ``.y``): linear interpolation inside the
    support, clamped to the first/last value outside it.

    Returns:
        np.ndarray: one risk value per survival function.
    """
    risks = []
    for curve in model.predict_survival_function(X):
        surv_at_horizon = np.interp(horizon_days, curve.x, curve.y)
        risks.append(1.0 - surv_at_horizon)
    return np.array(risks)

def create_shap_image(shap_values, max_display=15):
    """
    Render a SHAP bar plot to an in-memory PNG (600 dpi) and return it as a PIL image.

    Note: ``shap_values.feature_names`` is replaced in place by the short names from
    ``feature_short_names`` (names without a short form are kept as-is).
    """
    short_names = []
    for name in shap_values.feature_names:
        short_names.append(feature_short_names.get(name, name))
    shap_values.feature_names = short_names

    fig, ax = plt.subplots(figsize=(16, 5.5))
    shap.plots.bar(shap_values, max_display=max_display, show=False)

    # Let long, right-aligned feature labels extend past the axes instead of being clipped.
    for tick_label in ax.get_yticklabels():
        tick_label.set_clip_on(False)
        tick_label.set_horizontalalignment("right")
    plt.subplots_adjust(left=0.28)

    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format="png", dpi=600)
    plt.close(fig)
    png_buffer.seek(0)
    return Image.open(png_buffer)

def apply_treatment(df, mapping, treatment_cols, msi_flag):
    """
    Return a copy of ``df`` with a simulated treatment plan applied.

    All ``treatment_cols`` are reset to 0. Then each ``mapping`` entry whose column exists in
    ``df`` is set. ``hasMsi`` is overwritten with ``msi_flag``, and ``hasTreatment`` is
    recomputed (1 if any treatment column is non-zero). Each of these two is updated only if
    the column is already present, so no new model feature is introduced.

    Args:
        df: Feature rows to modify (not changed in place).
        mapping: Column name -> value describing the treatment plan.
        treatment_cols: All treatment-plan columns to reset before applying ``mapping``.
        msi_flag: Value written to ``hasMsi``.

    Returns:
        pd.DataFrame: the modified copy.
    """
    df_copy = df.copy()
    if len(treatment_cols) > 0:
        df_copy[treatment_cols] = 0
    for col, val in mapping.items():
        if col in df_copy.columns:
            df_copy[col] = val
    if "hasMsi" in df_copy.columns:
        df_copy["hasMsi"] = msi_flag

    # Guarded like hasMsi: adding a column the model was not trained on breaks prediction.
    if "hasTreatment" in df_copy.columns:
        df_copy["hasTreatment"] = (
            df_copy[treatment_cols].sum(axis=1) > 0
        ).astype(int)
    return df_copy

def _load_raw_patient_record(preprocessor, settings, target_id):
    """
    Return the raw (un-preprocessed) row of `target_id` from `settings.view_name`.

    The ncrId filter is applied in SQL as a bound parameter, so only the
    requested patient is transferred instead of the whole view.

    Raises
    ------
    ValueError
        If the view contains no row with this ncrId.
    """
    with pymysql.connect(
        read_default_file=preprocessor.db_config_path,
        read_default_group="RAnalysis",
        db=preprocessor.db_name
    ) as conn:
        raw = pd.read_sql(
            f"SELECT * FROM {settings.view_name} WHERE ncrId = %s",
            conn,
            params=(target_id,),
        )
    if raw.empty:
        raise ValueError(f"Could not find ncrId={target_id} in {settings.view_name}")
    return raw.iloc[0]


def _select_showcase_treatments(risks):
    """
    Pick the treatments to explain with SHAP from `risks` ({label: risk}, lower is better).

    Order: no treatment, best immunotherapy (pembrolizumab/nivolumab), best other
    treatment, worst treatment. Categories without candidates are skipped
    (instead of raising on an empty min/max), and duplicates are removed while
    keeping the order.
    """
    def is_io(label):
        return any(t in label.lower() for t in ("pembrolizumab", "nivolumab"))

    def is_chemo(label):
        return not is_io(label) and "no treatment" not in label.lower()

    no_tx = next((l for l in risks if "No Treatment" in l), list(risks)[0])
    candidates = [
        no_tx,
        min((l for l in risks if is_io(l)), key=risks.get, default=None),
        min((l for l in risks if is_chemo(l)), key=risks.get, default=None),
        max((l for l in risks if l != no_tx), key=risks.get, default=None),
    ]
    return list(dict.fromkeys(l for l in candidates if l is not None))


def plot_survival_and_shap_grid(
    model, preprocessor, settings, patient_id: int = None, patient_row: dict = None,
    treatment_map: dict = None, treatment_prefix="systemicTreatmentPlan",
    horizons=(12,), background_size=150, max_shap_display=15, display_stats=('median', 'auc'),
    risk_horizon_days=360,
):
    """
    Plot counterfactual survival curves and SHAP explanations for one patient.

    Layout:
      - top-left: one survival curve per treatment in `treatment_map` (the
        patient's features with that treatment applied via `apply_treatment`),
        labelled with the predicted median and AUC in months, plus a vertical
        line at the observed event/censoring time;
      - top-right: text summary (age, WHO status, MSI and mutation flags);
      - bottom: SHAP bar plots, one row per horizon in `horizons` and one
        column per showcase treatment (no treatment, best immunotherapy, best
        other treatment, worst treatment, ranked by risk at `risk_horizon_days`).

    Parameters
    ----------
    model : fitted survival model exposing `predict_survival_function`.
    preprocessor : DataPreprocessor used to load and encode the cohort.
    settings : experiment settings (duration/event columns, view name).
    patient_id : ncrId of an existing patient (used when `patient_row` is None).
    patient_row : raw feature dict of a new patient; must contain 'ncrId'.
    treatment_map : {label: {column: value}} treatments to compare (required).
    treatment_prefix : prefix of the treatment indicator columns.
    horizons : SHAP horizons in months (converted to days as months * 30).
    background_size : number of cohort rows sampled as SHAP background.
    max_shap_display : number of features shown per SHAP bar plot.
    display_stats : not used by this implementation (kept for compatibility).
    risk_horizon_days : horizon (days) used to rank treatments for the SHAP columns.

    Raises
    ------
    ValueError
        If `treatment_map` is missing/empty, no patient is specified, or the
        patient cannot be found.
    """
    if not treatment_map:
        raise ValueError("treatment_map must be a non-empty {label: {column: value}} mapping")
    if patient_id is None and patient_row is None:
        raise ValueError("Provide either patient_id or patient_row")

    raw_df = preprocessor.load_data()

    if patient_row is not None:
        raw_df = pd.concat([raw_df, pd.DataFrame([patient_row])], ignore_index=True)
        target_id = patient_row['ncrId']
    else:
        target_id = patient_id

    features = ["ncrId"] + lookup_manager.features
    df_all, updated_features, _ = preprocessor.preprocess_data(features, df=raw_df)

    patient_df = df_all[df_all["ncrId"] == target_id]
    if patient_df.empty:
        raise ValueError(f"Could not find ncrId={target_id} after preprocessing")

    X_base = patient_df[updated_features].drop(columns=["ncrId"])

    def flag(col):
        # Patient's value of a binary feature as int; 0 when the column is absent.
        return int(X_base.get(col, pd.Series([0])).iloc[0])

    msi_flag = flag("hasMsi")
    msi_label = "MSI" if msi_flag else "MSS"

    if patient_row is None:
        raw_patient = _load_raw_patient_record(preprocessor, settings, target_id)
    else:
        raw_patient = patient_row

    treatment_cols = [c for c in X_base if c.startswith(treatment_prefix)]

    # Evaluate all curves on the time range shared by the patient's predicted
    # survival functions.
    survival_fs = model.predict_survival_function(X_base)
    time_start = max(sf.x[0] for sf in survival_fs)
    time_end = min(sf.x[-1] for sf in survival_fs)
    time_grid = np.linspace(time_start, time_end, 100)
    cmap = plt.cm.tab20(np.linspace(0, 1, len(treatment_map)))
    risks = {}

    fig = plt.figure(figsize=(28, 20))
    gs_master = gridspec.GridSpec(2, 2, height_ratios=[2.0, 4.5], width_ratios=[1.3, 1.1], figure=fig)

    ax_curve = fig.add_subplot(gs_master[0, 0])
    for i, (label, mapping) in enumerate(treatment_map.items()):
        X_mod = apply_treatment(X_base, mapping, treatment_cols, msi_flag)
        surv_fn, = model.predict_survival_function(X_mod)
        surv_prob = surv_fn(time_grid)

        median_days, auc_days = compute_survival_stats(time_grid, surv_prob)
        median_mo, auc_mo = median_days / 30.44, auc_days / 30.44

        ax_curve.step(time_grid / 30.44, surv_prob, where="post", color=cmap[i], label=f"{label}  (median ≈ {median_mo:.0f} mo, auc = {auc_mo:.0f})")
        risks[label] = compute_risk_at_horizon(model, X_mod, risk_horizon_days)

    actual_time = patient_df[settings.duration_col].iloc[0] / 30.44
    event_flag = int(patient_df[settings.event_col].iloc[0])

    ax_curve.axvline(actual_time, color="red" if event_flag else "blue", linestyle="--", label="Event" if event_flag else "Censor")
    ax_curve.set_title(f"Patient {target_id} – survival curves")
    ax_curve.set_xlabel("Time (months)")
    ax_curve.set_ylabel("Survival probability")
    ax_curve.legend(ncol=2, fontsize=8)
    ax_curve.grid(True)

    ax_info = fig.add_subplot(gs_master[0, 1])
    ax_info.axis("off")
    info_lines = [
        f"Patient {target_id}",
        f"Age at metastasis detection: {raw_patient.get('ageAtMetastasisDetection','NA')}",
        f"WHO status: {raw_patient.get('whoStatusPreTreatmentStart','NA')}",
        f"MSI status: {msi_label}",
        f"BRAF mutation: {'Yes' if flag('hasBrafMutation') else 'No'}",
        f"BRAF V600E: {'Yes' if flag('hasBrafV600EMutation') else 'No'}",
        f"KRAS G12C: {'Yes' if flag('hasKrasG12CMutation') else 'No'}",
        f"RAS mutation: {'Yes' if flag('hasRasMutation') else 'No'}",
    ]
    ax_info.text(0, 1, "\n".join(info_lines), va="top", ha="left", fontsize=11)

    showcase = _select_showcase_treatments(risks)

    X_bg = df_all.drop(columns=["ncrId", settings.event_col, settings.duration_col]).sample(min(background_size, len(df_all)), random_state=42)

    # Attention masking only applies to models wrapping a network with an
    # `attention` attribute; other models are explained without masking.
    net = getattr(getattr(model, "model", None), "net", None)
    uses_attention = net is not None and hasattr(net, "attention")

    gs_shap = gridspec.GridSpecFromSubplotSpec(len(horizons), len(showcase), subplot_spec=gs_master[1, :], hspace=0.005, wspace=0.005)

    for r, H in enumerate(horizons):
        horizon_days = H * 30
        # The explainer depends only on the horizon, so build it once per row;
        # it is used within this iteration, so the closure sees the right value.
        explainer = shap.Explainer(
            lambda x: compute_risk_at_horizon(model, x, horizon_days),
            X_bg,
            feature_names=X_bg.columns
        )
        for c, lbl in enumerate(showcase):
            X_mod = apply_treatment(X_base, treatment_map[lbl], treatment_cols, msi_flag)
            sv = explainer(X_mod)[0]

            if uses_attention:
                # Zero attributions where compute_effective_input(...) is 0
                # (presumably features suppressed by the attention layer).
                X_tensor = torch.tensor(X_mod.values, dtype=torch.float32)
                x_effective = compute_effective_input(model, X_tensor)
                mask = (x_effective != 0).astype(float)
                sv.values = sv.values * mask

            sv_exp = shap.Explanation(
                sv.values,
                base_values=sv.base_values,
                data=sv.data,
                feature_names=X_bg.columns.tolist()
            )
            img = create_shap_image(sv_exp, max_display=max_shap_display)

            ax = fig.add_subplot(gs_shap[r, c])
            ax.imshow(img)
            ax.axis("off")

            if r == 0:
                ax.set_title(lbl, fontsize=12)
            if c == 0:
                ax.annotate(
                    f"{H} mo",
                    xy=(-0.1, 0.5),
                    xycoords="axes fraction",
                    rotation=90,
                    va="center",
                    ha="right",
                    fontsize=11
                )
    plt.tight_layout()
    plt.show()
    

In [None]:
preprocessor = DataPreprocessor(settings.db_config_path, settings.db_name)

# Survival curves + SHAP explanations for one registry patient (by ncrId).
plot_survival_and_shap_grid(
    model=trained_models["DeepSurv_attention"],
    preprocessor=preprocessor,
    settings=settings,
    patient_id=950861304,
    treatment_map=valid_treatment_combinations,
)

# A new, unseen patient can be plotted instead by passing a dict of raw
# feature values (it must include 'ncrId') as `patient_row`:
# patient = {'variable1': 0, ... 'lastvariable': 1}
# plot_survival_and_shap_grid(
#     model=trained_models["DeepSurv_attention"],
#     preprocessor=preprocessor,
#     settings=settings,
#     treatment_map=valid_treatment_combinations,
#     patient_row=patient
# )

### Predicted Median Survival per Treatment
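This section compares, per systemic treatment, the model's predicted survival (median or AUC) with the observed survival of the patients who actually received that treatment (optionally restricted to MSI patients, as in the example below).

**Violin plots** show the distribution and variability of the predicted median (or AUC) survival times next to the observed survival times for each treatment.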




In [None]:
def _predict_violin_survival_functions(model, X, treatment_name):
    """
    Return one predicted survival function per row of `X` (None where prediction failed).

    A single batched call is tried first; only if it raises is the group
    predicted row by row, so a single failing row does not drop the group.
    """
    try:
        return list(model.predict_survival_function(X))
    except Exception:
        # Deliberate fallback: retry row by row below to isolate failing rows.
        pass

    surv_fns = []
    for pos in range(len(X)):
        try:
            surv_fn, = model.predict_survival_function(X.iloc[[pos]])
        except Exception as e:
            print(f"Skipping patient in {treatment_name} due to error: {e}")
            surv_fn = None
        surv_fns.append(surv_fn)
    return surv_fns


def generate_violin_data_real_treatment_with_observed(
    preprocessor, model, treatment_map, settings, msi_only=False, time_grid=None
):
    """
    Collect predicted and observed survival of patients who actually received each treatment.

    For each treatment in `treatment_map` ({label: {column: value}}), select the
    patients whose encoded columns equal all given values (flags for columns
    absent from the data are ignored). Predict their survival curves with
    `model` and summarise each curve with `compute_survival_stats`.

    Parameters
    ----------
    msi_only : restrict the cohort to patients with hasMsi == 1.
    time_grid : times (days) at which each survival function is evaluated;
        defaults to 100 points on [0, 2000].

    Returns
    -------
    pd.DataFrame with one row per scored patient and the columns `treatment`,
    `predicted_median_survival`, `predicted_auc_survival` (days) and
    `observed_survival` (the raw duration column, censored or not).

    Raises
    ------
    ValueError
        If `msi_only` is set and there is no hasMsi column or no MSI patient.
    """
    if time_grid is None:
        time_grid = np.linspace(0, 2000, 100)

    df_all, _, _ = preprocessor.preprocess_data([f for f in lookup_manager.features if f != "ncrId"])
    df_all = df_all.loc[:, ~df_all.columns.duplicated()].copy()

    if msi_only:
        if "hasMsi" not in df_all.columns:
            raise ValueError("MSI feature (hasMsi) not found in dataset.")
        df_all = df_all[df_all["hasMsi"] == 1]
        if df_all.empty:
            raise ValueError("No MSI patients found!")

    model_input_cols = df_all.drop(columns=[settings.event_col, settings.duration_col, "ncrId"], errors="ignore").columns.tolist()

    results = []

    for treatment_name, treatment_flags in treatment_map.items():
        mask = np.ones(len(df_all), dtype=bool)
        for drug_flag, expected_val in treatment_flags.items():
            if drug_flag in df_all.columns:
                mask &= (df_all[drug_flag] == expected_val).to_numpy()

        selected_patients = df_all[mask]

        if selected_patients.empty:
            print(f"Skipping {treatment_name}: no patients found.")
            continue

        # Predict the whole group at once; slicing the frame (instead of
        # iterrows) also keeps the column dtypes intact.
        surv_fns = _predict_violin_survival_functions(
            model, selected_patients[model_input_cols], treatment_name
        )
        observed_days = selected_patients[settings.duration_col].to_numpy()

        for surv_fn, days_observed in zip(surv_fns, observed_days):
            if surv_fn is None:
                continue  # prediction failure was already reported
            try:
                surv_prob = surv_fn(time_grid)
                median_days_predicted, auc_days_predicted = compute_survival_stats(time_grid, surv_prob)
            except Exception as e:
                # e.g. (presumably) time_grid outside the domain of this survival function
                print(f"Skipping patient in {treatment_name} due to error: {e}")
                continue

            results.append({
                "treatment": treatment_name,
                "predicted_median_survival": median_days_predicted,
                "predicted_auc_survival": auc_days_predicted,
                "observed_survival": days_observed
            })

    return pd.DataFrame(results)

def plot_predicted_and_observed_violin(df_violin, type_days = 'median'):
    """
    Draw split violins of predicted vs observed survival (days) per treatment.

    Parameters
    ----------
    df_violin : DataFrame with `treatment`, `observed_survival` and the
        predicted columns produced by
        generate_violin_data_real_treatment_with_observed.
    type_days : 'median' to plot `predicted_median_survival`, or 'auc' to plot
        `predicted_auc_survival`, next to `observed_survival`.

    Raises
    ------
    ValueError
        If `type_days` is neither 'median' nor 'auc' (checked before plotting).
    """
    predicted_cols = {
        'median': "predicted_median_survival",
        'auc': "predicted_auc_survival",
    }
    if type_days not in predicted_cols:
        raise ValueError(f"type_days must be one of {sorted(predicted_cols)}, got {type_days!r}")

    data = df_violin.melt(id_vars="treatment",
                          value_vars=[predicted_cols[type_days], "observed_survival"],
                          var_name="Type", value_name="Survival")

    plt.figure(figsize=(18, 7))
    sns.violinplot(
        data=data,
        x="treatment",
        y="Survival",
        hue="Type",
        split=True,
        inner="quartile"
    )

    plt.xticks(rotation=45, ha="right")
    plt.ylabel("Survival Time (days)")
    plt.title("Predicted vs Observed Survival per Treatment")
    plt.grid(axis="y", linestyle="--", alpha=0.5)
    plt.legend(title="Type", loc="upper right")
    plt.tight_layout()
    plt.show()

In [None]:
# MSI-only cohort: predicted median vs observed survival per received treatment.
df_violin = generate_violin_data_real_treatment_with_observed(
    preprocessor=preprocessor,
    model=trained_models["DeepSurv_attention"],
    treatment_map=valid_treatment_combinations,
    settings=settings,
    msi_only=True,
)

plot_predicted_and_observed_violin(df_violin, type_days='median')

In [None]:
# Same cohort, showing the predicted AUC (restricted mean survival) instead of the median.
plot_predicted_and_observed_violin(df_violin, type_days='auc')

### Treatment-Specific Time-Dependent AUCs
To further evaluate model discrimination, we compute time-dependent AUCs for survival prediction within each treatment subgroup.
- AUCs are calculated at 1, 2, 3, 4, and 5 years post-treatment.
- Treatment groups with fewer than 30 patients were excluded to ensure statistical reliability.

In [None]:
from sksurv.util import Surv
from sksurv.metrics import cumulative_dynamic_auc

def compute_time_dependent_auc_real_treatment(
    model, preprocessor, settings, treatment_map, times_in_days, min_patients=30
):
    """
    Time-dependent (cumulative/dynamic) AUC per treatment group.

    For each treatment in `treatment_map` ({label: {column: value}}), select
    the patients whose encoded columns equal all given values (flags for
    absent columns are ignored). Then score 1 - S(t) from `model` with
    sksurv's `cumulative_dynamic_auc`, using the group itself as the reference
    for the censoring distribution.

    Parameters
    ----------
    times_in_days : evaluation times; only those with
        min(duration) <= t < max(duration) of the group are used, because
        `cumulative_dynamic_auc` rejects times at or beyond the last follow-up.
    min_patients : groups with fewer patients are skipped (default 30).

    Returns
    -------
    pd.DataFrame indexed by treatment with columns `AUC_<years>yr`.
    """
    df_all, _, _ = preprocessor.preprocess_data(
        [f for f in lookup_manager.features if f != "ncrId"]
    )
    df_all = df_all.loc[:, ~df_all.columns.duplicated()].copy()

    drop_cols = [settings.event_col, settings.duration_col, "ncrId"]
    input_cols = [c for c in df_all.columns if c not in drop_cols]

    results = {}

    for name, flags in treatment_map.items():
        mask = np.ones(len(df_all), dtype=bool)
        for col, val in flags.items():
            if col in df_all:
                mask &= (df_all[col] == val).to_numpy()
        sel = df_all[mask]
        n = len(sel)
        if n < min_patients:
            print(f"Skipping {name}: only {n} patients")
            continue

        print(f"Computing AUCs for {name} on {n} patients…")

        y_df = sel[[settings.event_col, settings.duration_col]]
        y_struct = Surv.from_dataframe(
            settings.event_col, settings.duration_col, y_df
        )

        min_t = y_df[settings.duration_col].min()
        max_t = y_df[settings.duration_col].max()
        # Strict upper bound: sksurv requires times < max follow-up time.
        # Sorted and unique so risk_mat columns line up with the times sksurv uses.
        times = sorted({t for t in times_in_days if min_t <= t < max_t})
        if not times:
            print(f"  no valid times in [{min_t:.0f},{max_t:.0f}]")
            continue

        X = sel[input_cols]
        surv_fns = model.predict_survival_function(X)

        # Linear interpolation of each survival curve at the evaluation times.
        surv_mat = np.vstack([
            np.interp(times, fn.x, fn.y) for fn in surv_fns
        ])

        risk_mat = 1 - surv_mat
        aucs, _ = cumulative_dynamic_auc(y_struct, y_struct, risk_mat, np.array(times))

        col_names = [f"AUC_{int(t/365)}yr" for t in times]
        results[name] = dict(zip(col_names, aucs))

    return pd.DataFrame(results).T


In [None]:
# AUC at 1, 2, 3, 4 and 5 years for each sufficiently large treatment group.
auc_per_treatment = compute_time_dependent_auc_real_treatment(
    model=trained_models["DeepSurv_attention"],
    preprocessor=preprocessor,
    settings=settings,
    treatment_map=valid_treatment_combinations,
    times_in_days=[365 * years for years in range(1, 6)],
)

print(auc_per_treatment)


### Predicted vs Observed

To check that our model does **more than output the same prognosis for every patient**, we plot the **predicted median survival (in days)**, or alternatively the area under the predicted survival curve, against the **actually observed survival time** for each individual.

> **Why not predict the exact number of days?**  
> Predicting a precise survival day is extremely difficult—many clinical and biological factors that affect outcome are not (or cannot be) captured in the registry.  
> That is exactly why we focus on *survival probabilities* (or the median of a full survival curve) instead of a single “day‑of‑death” estimate.  
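> Consequently, comparing the model’s *median* to the *exact* observed day is not a perfectly fair yardstick, but it is still a useful sanity check. Note also that for censored patients the observed time is only a lower bound on their true survival.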




In [None]:
from matplotlib.ticker import MaxNLocator
from sklearn.metrics import mean_absolute_error, mean_squared_error

def plot_predicted_vs_observed_hexbin(
    model,
    preprocessor,
    settings,
    *,
    max_observed_days: int,
    metric: str = "median",
    figsize=(6,6),
    gridsize: int = 60,
    cmap: str = "Oranges",
    bins: str = "log",
    show_fit: bool = False,
    print_errors: bool = True,
    max_predicted_days: float = 1200,
):
    """
    Hexbin plot of predicted survival (median or AUC) against observed survival time.

    Only patients with observed duration <= `max_observed_days` are plotted
    and scored. Observed durations are used as-is (censored patients included).

    Parameters
    ----------
    model : fitted survival model exposing `predict_survival_function`.
    preprocessor, settings : used to load the full cohort and its duration column.
    max_observed_days : filter and x-axis limit on the observed duration.
    metric : 'median' (first time S(t) <= 0.5) or 'auc' (area under S(t)),
        computed on each predicted curve's own time points.
    figsize, gridsize, cmap, bins : forwarded to the figure / `ax.hexbin`.
    show_fit : overlay a least-squares line of predicted on observed.
    print_errors : print MAE and RMSE (days) of the plotted points.
    max_predicted_days : upper y-axis limit (default 1200 days).

    Raises
    ------
    ValueError
        If `metric` is not 'median' or 'auc' (checked before any work is done).
    """
    if metric not in ("median", "auc"):
        raise ValueError("metric must be 'median' or 'auc'")
    # Position of the requested statistic in compute_survival_stats' result.
    stat_index = 0 if metric == "median" else 1

    feats = ["ncrId"] + [f for f in lookup_manager.features if f != "ncrId"]
    df_all, _, _ = preprocessor.preprocess_data(feats)
    df_all = df_all.loc[:, ~df_all.columns.duplicated()]

    X = df_all.drop(columns=["ncrId", settings.event_col, settings.duration_col])
    times_obs = df_all[settings.duration_col].values

    metrics_pred = np.array([
        compute_survival_stats(sf.x, sf.y)[stat_index]
        for sf in model.predict_survival_function(X)
    ])

    mask = times_obs <= max_observed_days
    tobs = times_obs[mask]
    tpred = metrics_pred[mask]

    if print_errors and tobs.size:
        mae  = mean_absolute_error(tobs, tpred)
        rmse = np.sqrt(mean_squared_error(tobs, tpred))
        print(f"{metric.upper()} prediction error (days):")
        print(f"  MAE  = {mae:.1f}")
        print(f"  RMSE = {rmse:.1f}")

    fig, ax = plt.subplots(figsize=figsize, dpi=110)
    ax.plot([0, max_observed_days], [0, max_observed_days],
            ls=":", color="black")
    hb = ax.hexbin(
        tobs, tpred,
        gridsize=gridsize,
        cmap=cmap,
        bins=bins
    )
    fig.colorbar(hb, ax=ax, label=f"{bins} count")

    if show_fit:
        valid = ~np.isnan(tpred)
        if valid.sum() > 1:
            coeffs = np.polyfit(tobs[valid], tpred[valid], 1)
            xfit = np.linspace(0, max_observed_days, 100)
            ax.plot(xfit, np.polyval(coeffs, xfit),
                    ls="--", color="tab:orange")

    ax.set_xlabel("Observed survival (days)")
    ax.set_ylabel(f"Predicted {metric} (days)")
    ax.set_title(f"Pred vs Obs (hexbin, ≤ {max_observed_days}d, metric={metric})")
    ax.set_xlim(0, max_observed_days)
    ax.set_ylim(0, max_predicted_days)
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.grid(True, ls=":", lw=0.5)
    plt.tight_layout()
    plt.show()

In [None]:
# DeepSurv (attention): predicted median vs observed survival, first year only.
plot_predicted_vs_observed_hexbin(
    trained_models["DeepSurv_attention"],
    preprocessor,
    settings,
    max_observed_days=365,
    metric="median",
    gridsize=80,
    bins="log",
    show_fit=True,
)

In [None]:
# DeepHit (attention): predicted AUC (restricted mean) vs observed survival, first year only.
plot_predicted_vs_observed_hexbin(
    trained_models["DeepHitModel_attention"],
    preprocessor,
    settings,
    max_observed_days=365,
    metric="auc",
    gridsize=80,
    bins="log",
    show_fit=True,
)

### Model prediction: No Treatment
In this section, we aim to understand how well our model identifies patients for whom "no treatment" might not be worse than receiving treatment. We do this by comparing model predictions to actual observed outcomes.

Clinical intuition—and data—tell us that untreated patients generally have worse outcomes. Yet our model occasionally predicts that no treatment could be among the better options. We want to:

- Verify how often the model makes this prediction.
- Compare it to actual outcomes using patient survival.
- Understand for which patients the model is right (or wrong).

We match treated and untreated patients in both directions: each patient (treated or untreated) is matched to their nearest neighbour in the opposite group, and the observed survival (OS) of the two is compared (censoring is not taken into account). A pair that is found in both directions is counted only once, so there is no double-counting.
   

In [None]:
from matplotlib.patches import Circle
from sklearn.neighbors import NearestNeighbors

def count_no_treatment_rankings(
    model, preprocessor, settings, treatment_map, horizon_days=365
) -> pd.DataFrame:
    """
    Per patient, count how many treatments the model ranks worse than 'No Treatment'.

    For every patient (first row per ncrId), each treatment in `treatment_map`
    is applied counterfactually with `apply_treatment` (keeping the patient's
    own `hasMsi`), and the risk 1 - S(horizon_days) is predicted with
    `compute_risk_at_horizon`. A treatment counts as worse when its risk is
    strictly higher than the risk under the 'No Treatment' entry.

    The cohort is predicted in one model call per treatment, instead of one
    call per patient per treatment.

    Returns
    -------
    pd.DataFrame with columns `ncrId` and `treatments_worse`. It is empty
    (with these columns) when `treatment_map` has no 'No Treatment' entry or
    there is no data.
    """
    df_raw = preprocessor.load_data()
    features = ["ncrId"] + lookup_manager.features
    df_all, updated_features, _ = preprocessor.preprocess_data(features, df=df_raw)
    treatment_cols = [c for c in updated_features if c.startswith("systemicTreatmentPlan")]

    if "No Treatment" not in treatment_map or df_all.empty:
        return pd.DataFrame(columns=["ncrId", "treatments_worse"])

    # One row per patient, in order of first appearance (same order as ncrId.unique()).
    patients = df_all.drop_duplicates(subset="ncrId", keep="first")
    X_all = patients[updated_features].drop(columns=["ncrId"])
    # apply_treatment overwrites hasMsi; pass each patient's own value so it is preserved.
    msi_flag = X_all["hasMsi"].astype(int) if "hasMsi" in X_all else 0

    risks = {
        label: compute_risk_at_horizon(
            model, apply_treatment(X_all, mapping, treatment_cols, msi_flag), horizon_days
        )
        for label, mapping in treatment_map.items()
    }

    no_tx_risk = risks["No Treatment"]
    treatments_worse = np.zeros(len(patients), dtype=int)
    for label, risk in risks.items():
        if label != "No Treatment":
            treatments_worse += risk > no_tx_risk

    return pd.DataFrame({
        "ncrId": patients["ncrId"].to_numpy(),
        "treatments_worse": treatments_worse,
    })


def _nearest_opposite_matches(patients, pool, feature_cols, n_neighbors, duration_col, patient_treated):
    """
    Match every row of `patients` to its nearest row of `pool` in feature space.

    Returns one record per patient with both ncrIds and observed durations;
    `patient_treated` is copied into each record. `n_neighbors` is clipped to
    the pool size, and only the closest neighbour is used.
    """
    k = min(n_neighbors, len(pool))
    nn = NearestNeighbors(n_neighbors=k).fit(pool[feature_cols])
    _, indices = nn.kneighbors(patients[feature_cols])
    # kneighbors returns positional indices into `pool`.
    nearest = pool.iloc[indices[:, 0]]
    return [
        {
            "ncrId_patient": pid,
            "os_patient": os_patient,
            "ncrId_match": mid,
            "os_match": os_match,
            "patient_treated": patient_treated,
        }
        for pid, os_patient, mid, os_match in zip(
            patients["ncrId"], patients[duration_col],
            nearest["ncrId"], nearest[duration_col],
        )
    ]


def compare_matched_os_bidirectional(preprocessor, settings, lookup_manager, n_neighbors=5):
    """
    Match treated and untreated patients and compare their observed survival.

    Every treated patient is matched to the nearest untreated patient and vice
    versa. Distance is Euclidean on the numeric model features; ncrId,
    duration, event and hasTreatment are excluded. A pair found in both
    directions is kept once, in the direction whose patient has the smaller
    ncrId. One-directional matches are all kept.

    Parameters
    ----------
    n_neighbors : neighbours requested from NearestNeighbors (clipped to the
        group size); only the nearest one is used.

    Returns
    -------
    df_matches : DataFrame with ncrId_patient, os_patient, ncrId_match,
        os_match, patient_treated and patient_not_worse (os_patient > os_match).
    not_worse_ids : set of ncrId_patient values with patient_not_worse True.

    Note: observed durations are compared directly; censoring is ignored and
    features are not scaled before computing distances.
    """
    raw = preprocessor.load_data()
    features = ["ncrId"] + lookup_manager.features
    df_all, updated_features, _ = preprocessor.preprocess_data(features, df=raw)

    treated = df_all[df_all["hasTreatment"] == 1]
    untreated = df_all[df_all["hasTreatment"] == 0]

    # ncrId is an identifier, not a patient characteristic, so it must not
    # enter the distance (large IDs would dominate it).
    excluded = {settings.duration_col, settings.event_col, "hasTreatment", "ncrId"}
    feature_cols = [
        col for col in updated_features
        if pd.api.types.is_numeric_dtype(df_all[col]) and col not in excluded
    ]

    columns = ["ncrId_patient", "os_patient", "ncrId_match", "os_match", "patient_treated"]
    if treated.empty or untreated.empty:
        return pd.DataFrame(columns=columns + ["patient_not_worse"]), set()

    forward = _nearest_opposite_matches(
        treated, untreated, feature_cols, n_neighbors, settings.duration_col, True
    )
    backward = _nearest_opposite_matches(
        untreated, treated, feature_cols, n_neighbors, settings.duration_col, False
    )

    # Pairs written as (treated id, untreated id) to detect mutual matches.
    forward_pairs = {(m["ncrId_patient"], m["ncrId_match"]) for m in forward}
    backward_pairs = {(m["ncrId_match"], m["ncrId_patient"]) for m in backward}

    def keep(match, is_mutual):
        # Mutual pairs appear twice; keep only the copy with the smaller patient id.
        return not is_mutual or match["ncrId_patient"] < match["ncrId_match"]

    matches = [
        m for m in forward
        if keep(m, (m["ncrId_patient"], m["ncrId_match"]) in backward_pairs)
    ]
    matches += [
        m for m in backward
        if keep(m, (m["ncrId_match"], m["ncrId_patient"]) in forward_pairs)
    ]

    df_matches = pd.DataFrame(matches, columns=columns)
    df_matches["patient_not_worse"] = df_matches["os_patient"] > df_matches["os_match"]

    not_worse_ids = set(df_matches.loc[df_matches["patient_not_worse"], "ncrId_patient"])

    return df_matches, not_worse_ids

def compute_overlap_statistics_bidirectional(
    model, preprocessor, settings, lookup_manager, treatment_map, n_neighbors=1, horizon_days=365
):
    """
    Compare the model-based and matching-based "not worse without treatment" patient sets.

    - predicted set: ncrIds with at least one treatment ranked worse than
      'No Treatment' by count_no_treatment_rankings;
    - actual set: ncrIds returned by compare_matched_os_bidirectional
      (patients who outlived their matched counterpart).

    Returns {"actual": {"total", "count"}, "predicted": {"total", "count"},
    "overlap": {"count"}}; both totals are the number of preprocessed rows.
    """
    rankings = count_no_treatment_rankings(
        model, preprocessor, settings, treatment_map, horizon_days
    )
    cohort, *_ = preprocessor.preprocess_data(["ncrId"] + lookup_manager.features)
    n_rows = cohort.shape[0]

    predicted_ids = set(rankings.loc[rankings["treatments_worse"] > 0, "ncrId"])

    _, observed_ids = compare_matched_os_bidirectional(
        preprocessor, settings, lookup_manager, n_neighbors
    )

    return {
        "actual":    {"total": n_rows, "count": len(observed_ids)},
        "predicted": {"total": n_rows, "count": len(predicted_ids)},
        "overlap":   {"count": len(observed_ids & predicted_ids)}
    }

def plot_area_venn(stats):
    """
    Draw an area-proportional two-circle Venn diagram of the overlap statistics.

    Each circle's area equals its count (radius sqrt(n / pi)). The distance
    between the centres is found by bisection, so that the lens-shaped
    intersection area matches the overlap count. `stats` uses the keys
    "actual", "predicted" and "overlap", each holding a "count".
    """
    n_actual = stats["actual"]["count"]
    n_predicted = stats["predicted"]["count"]
    n_overlap = stats["overlap"]["count"]
    only_actual = n_actual - n_overlap
    only_predicted = n_predicted - n_overlap

    radius_a = np.sqrt(n_actual / np.pi)
    radius_p = np.sqrt(n_predicted / np.pi)

    def lens_area(ra, rb, dist):
        """Intersection area of two circles with radii ra, rb whose centres are `dist` apart."""
        if dist >= ra + rb:
            return 0
        if dist <= abs(rb - ra):
            return np.pi * min(ra, rb)**2
        ra2, rb2, d2 = ra*ra, rb*rb, dist*dist
        angle_a = np.arccos((d2 + ra2 - rb2) / (2*dist*ra))
        angle_b = np.arccos((d2 + rb2 - ra2) / (2*dist*rb))
        return (ra2*angle_a + rb2*angle_b
                - 0.5*np.sqrt((-dist+ra+rb)*(dist+ra-rb)*(dist-ra+rb)*(dist+ra+rb)))

    # The lens area shrinks as the centres move apart, so bisect on the distance.
    low, high = abs(radius_a - radius_p), radius_a + radius_p
    for _ in range(50):
        middle = (low + high) / 2
        if lens_area(radius_a, radius_p, middle) < n_overlap:
            high = middle
        else:
            low = middle
    dist = (low + high) / 2

    fig, ax = plt.subplots(figsize=(8,6))
    ax.add_patch(Circle((0,0),    radius_a, color='skyblue', alpha=0.5))
    ax.add_patch(Circle((dist,0), radius_p, color='salmon',  alpha=0.5))
    ax.text(-radius_a*0.3,      0, str(only_actual),    ha='center', va='center', fontsize=12)
    ax.text(dist+radius_p*0.3,  0, str(only_predicted), ha='center', va='center', fontsize=12)
    ax.text(dist/2,             0, str(n_overlap),      ha='center', va='center', fontsize=12, color='white')
    ax.text(0,    radius_a+2, f"Actual Not Worse\n(n={n_actual})", ha='center', va='bottom', fontsize=13)
    ax.text(dist, radius_p+2, f"Predicted Not Worse\n(n={n_predicted})", ha='center', va='bottom', fontsize=13)
    ax.set_xlim(-radius_a*1.2, dist+radius_p*1.2)
    ax.set_ylim(-max(radius_a,radius_p)*1.2, max(radius_a,radius_p)*1.2 + 2.5)
    ax.set_aspect('equal')
    ax.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Compare model-predicted vs matching-based "not worse without treatment" patients.
stats = compute_overlap_statistics_bidirectional(
    model=trained_models['DeepSurv_attention'],
    preprocessor=preprocessor,
    settings=settings,
    lookup_manager=lookup_manager,
    treatment_map=valid_treatment_combinations,
    n_neighbors=1,
    horizon_days=365,
)

print(f"Actual fraction not worse:    {stats['actual']['count']}/{stats['actual']['total']} "
      f"= {stats['actual']['count']/stats['actual']['total']:.1%}")
print(f"Predicted fraction not worse: {stats['predicted']['count']}/{stats['predicted']['total']} "
      f"= {stats['predicted']['count']/stats['predicted']['total']:.1%}")

plot_area_venn(stats)

#### Feature differences in untreated patients who survive longer
In this section, we explore why some untreated patients survive longer than expected—both according to the model and in real-world observations.
To do so, we analyze which patient features are associated with better survival outcomes without systemic treatment. Specifically, we look at:

- Patients the model predicts will do better without treatment.
- Patients who actually did better than their matched treated counterparts.

We compare these "not worse" patients to other untreated patients who did fare worse, using two strategies:

1. **Standardized Mean Differences (SMD)**:
    This helps us understand which features differ the most between untreated patients who do better and those who do not. We compute SMDs both for *model-predicted* and *actually observed* "not worse" patients (via matched survival comparison)

2. **Paired Feature Differences vs. Matched Treated Patients**:
    For untreated patients who survived longer than their matched treated peers, we look at average differences in feature values compared to those treated patients. This gives insight into what makes some untreated patients truly exceptional.

The goal is to uncover whether certain biological, clinical, or demographic characteristics help explain why some patients might not need systemic therapy—and whether the model is learning these patterns or not.

In [None]:
def _standardized_mean_difference(a, b):
    """
    Standardized mean difference between two samples.

    SMD = (mean(a) - mean(b)) / sqrt((sd(a)^2 + sd(b)^2) / 2), with sample SDs
    (ddof=1). Returns NaN when the pooled SD is zero or undefined (constant
    feature, or fewer than two values in a group) instead of +/-inf. This keeps
    such features from topping the |SMD| ranking.
    """
    s1, s2 = a.std(ddof=1), b.std(ddof=1)
    pooled = np.sqrt((s1*s1 + s2*s2) / 2)
    if not np.isfinite(pooled) or pooled == 0:
        return np.nan
    return (a.mean() - b.mean()) / pooled


def _untreated_smd_table(unt, flag_col, num_feats):
    """SMD of each feature between rows of `unt` where `flag_col` is True and the rest, sorted by |SMD| (NaN last)."""
    flagged = unt.loc[unt[flag_col], num_feats]
    others = unt.loc[~unt[flag_col], num_feats]
    smd = {f: _standardized_mean_difference(flagged[f], others[f]) for f in num_feats}
    return (
        pd.DataFrame.from_dict(smd, orient="index", columns=["SMD"])
          .assign(absSMD=lambda df: df["SMD"].abs())
          .sort_values("absSMD", ascending=False)
    )


def _paired_feature_deltas(df_matches, df_all, num_feats):
    """Mean (untreated − matched treated) difference per feature, over untreated patients who outlived their match."""
    pairs = df_matches.query("patient_treated == False and patient_not_worse")
    left = pairs[["ncrId_patient", "ncrId_match"]].rename(columns={"ncrId_patient": "ncrId"})
    df_pair = (
        left
        .merge(df_all[["ncrId"] + num_feats], on="ncrId")
        .merge(
            df_all[["ncrId"] + num_feats].set_index("ncrId"),
            left_on="ncrId_match", right_index=True,
            suffixes=("_unt", "_tr")
        )
    )

    delta = {
        f: (df_pair[f+"_unt"] - df_pair[f+"_tr"]).mean()
        for f in num_feats
    }
    return (
        pd.DataFrame.from_dict(delta, orient="index", columns=["Mean_Untreated−Treated"])
          .assign(absDiff=lambda df: df["Mean_Untreated−Treated"].abs())
          .sort_values("absDiff", ascending=False)
    )


def analyze_untreated_feature_differences(
    model,
    preprocessor,
    settings,
    lookup_manager,
    treatment_map,
    n_neighbors: int = 5,
    horizon_days: int = 365
):
    """
    Find the features that separate untreated patients who do "not worse" than treatment.

    Returns
    -------
    df_smd_actual : SMD per numeric feature between untreated patients in the
        matching-based "not worse" set (compare_matched_os_bidirectional) and
        the other untreated patients, sorted by |SMD| descending.
    df_smd_pred : the same, for the model-based set (patients with at least one
        treatment ranked worse than 'No Treatment' by count_no_treatment_rankings).
    df_delta : mean feature difference (untreated − matched treated) over
        untreated patients who outlived their treated match, sorted by |diff|.
    """
    df_pred = count_no_treatment_rankings(model, preprocessor, settings, treatment_map, horizon_days)
    pred_not_worse = set(df_pred.loc[df_pred["treatments_worse"] > 0, "ncrId"])

    df_matches, actual_not_worse = compare_matched_os_bidirectional(
        preprocessor, settings, lookup_manager, n_neighbors
    )

    features = ["ncrId"] + lookup_manager.features
    df_all, updated_features, _ = preprocessor.preprocess_data(features)
    unt = df_all[df_all["hasTreatment"] == 0].copy()
    unt["actual_not_worse"] = unt["ncrId"].isin(actual_not_worse)
    unt["pred_not_worse"]   = unt["ncrId"].isin(pred_not_worse)

    excluded = {settings.duration_col, settings.event_col, "hasTreatment", "ncrId"}
    num_feats = [
        col for col in updated_features
        if pd.api.types.is_numeric_dtype(df_all[col]) and col not in excluded
    ]

    df_smd_actual = _untreated_smd_table(unt, "actual_not_worse", num_feats)
    df_smd_pred = _untreated_smd_table(unt, "pred_not_worse", num_feats)
    df_delta = _paired_feature_deltas(df_matches, df_all, num_feats)

    return df_smd_actual, df_smd_pred, df_delta


In [None]:
# Feature differences of untreated patients who did "not worse" (actual and predicted).
df_smd_actual, df_smd_pred, df_delta = analyze_untreated_feature_differences(
    model=trained_models["DeepSurv_attention"],
    preprocessor=preprocessor,
    settings=settings,
    lookup_manager=lookup_manager,
    treatment_map=valid_treatment_combinations,
    n_neighbors=5,
)

print("Actual SMDs:")
display(df_smd_actual.head(10))

print("Predicted SMDs:")
display(df_smd_pred.head(10))

print("Mean Pairwise Deltas:")
display(df_delta.head(10))

In [None]:
def plot_smds(df_smd, title, number_of_features = 10):
    """
    Horizontal bar chart of the first `number_of_features` rows of `df_smd`.

    `df_smd` must have an `SMD` column and is expected to be sorted by |SMD|
    (as returned by analyze_untreated_feature_differences). The selected rows
    are re-sorted by signed SMD for plotting. Negative values are drawn in
    salmon, the others in skyblue. Feature names are shortened via
    `feature_short_names` for display only; the caller's DataFrame is not
    modified.
    """
    df_top = df_smd.head(number_of_features).copy()
    df_top.index = [feature_short_names.get(f, f) for f in df_top.index]
    df_top = df_top.sort_values("SMD")
    colors = ['salmon' if v < 0 else 'skyblue' for v in df_top["SMD"]]

    plt.figure(figsize=(8,6))
    plt.barh(df_top.index, df_top["SMD"], color=colors)
    plt.axvline(0, color='black', linewidth=0.8)
    plt.title(title)
    plt.xlabel("Standardized Mean Difference")
    plt.tight_layout()
    plt.show()

plot_smds(df_smd_actual, title="Top features distinguishing actual not-worse untreated patients")
plot_smds(df_smd_pred, title="Top features driving model predictions for not-worse untreated patients")


### Classification-Based Evaluation of Survival Models

Survival models typically predict a **survival curve** over time. However, in many clinical situations, we may want a simpler answer to the question:

> *"Will this patient survive at least 1 year?"*

To evaluate this, we reframe survival prediction as a **binary classification task**:

- Patients who survive beyond a defined time horizon (e.g. 365 days) are labeled **positive**.
- Patients who die **before** that horizon are labeled **negative**.
- Patients censored **at or before** the horizon are **excluded** from evaluation (since their true outcome is unknown).

We use the model's survival probability at the chosen horizon as a **pseudo-probability** of surviving:
- If the model's predicted survival probability at 365 days exceeds a chosen threshold (0.45 in the example below) → predict *survives*.
- Otherwise → predict *does not survive*.

For each model, we compute classic classification metrics: **Accuracy**, **Precision**, **Recall**, **F1-score**, **ROC AUC**

These metrics provide an intuitive understanding of how well the model distinguishes between short-term and long-term survivors.

In [None]:
from sklearn.metrics import (
    accuracy_score, precision_score,
    recall_score, f1_score, roc_auc_score
)

def evaluate_survival_as_classifier(
    models: dict,
    preprocessor,
    settings,
    horizon_days: int = 365,
    threshold: float = 0.5
) -> pd.DataFrame:
    """Score survival models as binary "survives past ``horizon_days``" classifiers.

    Each patient is labelled positive if their recorded duration is greater
    than ``horizon_days`` and negative if they had an event at or before it.
    Patients censored at or before the horizon have an unknown outcome and are
    dropped. A model predicts "survives" when its survival function evaluated
    at ``horizon_days`` is strictly greater than ``threshold``; the raw
    survival probability is used as the score for ROC AUC.

    Parameters
    ----------
    models : dict
        Mapping of model name -> fitted model exposing
        ``predict_survival_function(X)``, which is expected to return an
        iterable of callables (one per row of ``X``) that map a time to a
        survival probability.
    preprocessor
        Object whose ``preprocess_data(features)`` returns a 3-tuple; only the
        first element (the preprocessed DataFrame) is used.
    settings
        Provides the ``duration_col`` and ``event_col`` column names.
    horizon_days : int, default 365
        Time point at which survival is classified.
    threshold : float, default 0.5
        Survival-probability cut-off for predicting "survives".

    Returns
    -------
    pandas.DataFrame
        One row per model (index = model name) with columns ``accuracy``,
        ``precision``, ``recall``, ``f1``, ``roc_auc`` (NaN when ROC AUC is
        undefined, e.g. only one class present) and ``n_samples`` (number of
        patients with a known outcome at the horizon).

    Raises
    ------
    ValueError
        If no patient has a known outcome at ``horizon_days``.

    Notes
    -----
    NOTE(review): this evaluates on the full cohort returned by
    ``preprocessor.preprocess_data``, which presumably includes the patients
    the models were trained on, so the metrics may be optimistic. Confirm
    whether a held-out set should be used instead.
    """
    feats = ["ncrId"] + [f for f in lookup_manager.features if f != "ncrId"]
    df_all, _, _ = preprocessor.preprocess_data(feats)
    # Keep only the first occurrence of any duplicated column label.
    df_all = df_all.loc[:, ~df_all.columns.duplicated()].copy()

    durations = df_all[settings.duration_col].to_numpy()
    events = df_all[settings.event_col].astype(bool).to_numpy()

    # Censored at or before the horizon -> status at the horizon is unknown.
    mask_known = ~((durations <= horizon_days) & ~events)
    if not mask_known.any():
        raise ValueError(
            f"No patients with a known outcome at horizon_days={horizon_days}; "
            "cannot compute classification metrics."
        )

    y_true = durations[mask_known] > horizon_days
    n_samples = len(y_true)

    X = df_all.drop(columns=["ncrId", settings.duration_col, settings.event_col])
    X = X.iloc[mask_known]

    results = {}
    for name, model in models.items():
        surv_funcs = model.predict_survival_function(X)
        probs = np.array([fn(horizon_days) for fn in surv_funcs])
        y_pred = probs > threshold

        try:
            auc = roc_auc_score(y_true, probs)
        except ValueError:
            # Undefined when y_true contains a single class.
            auc = np.nan

        results[name] = {
            "accuracy":  accuracy_score(y_true, y_pred),
            "precision": precision_score(y_true, y_pred, zero_division=0),
            "recall":    recall_score(y_true, y_pred, zero_division=0),
            "f1":        f1_score(y_true, y_pred, zero_division=0),
            "roc_auc":   auc,
            "n_samples": n_samples,
        }

    # orient="index" keeps per-column dtypes (n_samples stays an int), unlike
    # DataFrame(results).T, which coerces every column to a common float dtype.
    return pd.DataFrame.from_dict(results, orient="index")

In [None]:
# 1-year survival framed as binary classification with a 0.45 survival-probability cut-off.
year_survival_df = evaluate_survival_as_classifier(
    trained_models, preprocessor, settings,
    horizon_days=365, threshold=0.45,
)
print("1-year performance:\n", year_survival_df)

## Comparison of Models
Next, we focus on a single patient's actual treatment and compare the predictions of the best-performing models. By comparing the survival curves each model generates for that treatment, we assess how well the models agree and gain a deeper understanding of patient-specific predictions.

In [None]:
# Compare every evaluated model's survival curve for one test patient (row 70).
plot_different_models_survival_curves(
    patient_index=70,
    trained_models=models_to_evaluate,
    X_test=X_test,
    y_test=y_test,
)