# Model Evaluation
This notebook demonstrates the interpretation and evaluation of the models (as trained in `predictive_algorithms_training.ipynb`):
- Performance Evaluation: Comparing models based on metrics such as concordance index (C-Index), integrated Brier score (IBS), calibration error (CE), and time-dependent AUC.
- Visualization: Generating survival curves and feature importance plots to interpret model predictions and uncover key insights.

All experiment settings (e.g. OS or PFS, grouped treatments or not) are configured in `utils/settings.py`; once they are set, the experiment can be run in this notebook.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Work from the prediction project root so the relative `src.*` imports in the next cells resolve.
# Set the ACTIN_PREDICTION_DIR environment variable to use a different checkout; otherwise the
# path below is used.
os.chdir(os.environ.get('ACTIN_PREDICTION_DIR', '/data/repos/actin-personalization/scripts/personalization/prediction'))

In [None]:
import pandas as pd
import seaborn as sns
import ast
import matplotlib.pyplot as plt
import dill
import torch

from src.models import *
from src.data.data_processing import DataSplitter, DataPreprocessor
from src.data.lookups import lookup_manager
from src.utils.settings import settings

In [None]:
import nbimporter
from src.predictive_algorithms_training import get_data, plot_different_models_survival_curves

In [None]:
# Unpack the data returned by get_data() (imported from predictive_algorithms_training);
# X_train / X_test / y_train / y_test are reused by the evaluation cells below.
df, X_train, X_test, y_train, y_test, encoded_columns = get_data()

### Metric comparison: OS vs. PFS

The trained models are evaluated using the following metrics:

- **C-Index**: The Concordance Index measures how well the predicted survival times align with the actual outcomes. It is a measure of discrimination, indicating the model's ability to correctly rank the survival times of patients. A higher value indicates better predictive accuracy.

- **Integrated Brier Score (IBS)**: This metric evaluates the accuracy of the predicted survival probabilities over time. The Brier score at time $t$ is the (censoring-weighted) mean squared difference between the predicted survival probability and the observed survival status at $t$; the IBS integrates this score over the follow-up period. Lower values indicate better predictive performance.

- **Calibration Error (CE)**: Calibration error assesses how well the predicted survival probabilities match the observed probabilities. It indicates whether the model is systematically overestimating or underestimating survival probabilities. Lower values signify better calibration.

- **Area Under the Curve (AUC)**: For survival models, AUC is typically computed from a time-dependent ROC curve, reflecting the model's discrimination ability at different time points. Higher AUC values indicate better discrimination.

This section compares the model performance metrics (C-Index, IBS, CE, AUC) for the outcome selected in `settings` (OS or PFS). The bar plots highlight the strengths and weaknesses of each model for that prediction task; change `settings.outcome` and rerun to compare OS with PFS.


In [None]:
def load_model_outcomes():
    """Read the saved per-model evaluation results for the configured outcome.

    The file is looked up at ``<settings.save_path>/<settings.outcome>_model_outcomes.csv``.

    Returns:
        pandas.DataFrame with the contents of that CSV.

    Raises:
        FileNotFoundError: if the CSV does not exist.
    """
    outcomes_path = os.path.join(settings.save_path, f"{settings.outcome}_model_outcomes.csv")

    # Fail fast when nothing has been saved for this outcome yet.
    if not os.path.exists(outcomes_path):
        raise FileNotFoundError(f"No saved outcomes found for {settings.outcome} in {settings.save_path}")

    outcomes = pd.read_csv(outcomes_path)
    print(f"Loaded model outcomes from {outcomes_path}")
    return outcomes

In [None]:
# Load the saved per-model metrics for the configured outcome (settings.outcome).
model_outcomes = load_model_outcomes()

In [None]:
def extract_holdout_metrics(df):
    """Expand the per-model 'holdout' metric dictionaries into one column per metric.

    Entries of ``df['holdout']`` may be dicts or their string representation (e.g. after a
    round-trip through CSV); string entries are parsed with ``ast.literal_eval``. The input
    frame is not modified.

    Args:
        df: DataFrame with a 'Model' column and a 'holdout' column of metric dicts.

    Returns:
        DataFrame with one column per metric key plus a 'Model' column, aligned on
        ``df``'s index.
    """
    # Parse only the string entries: literal_eval rejects values that are already dicts.
    # Keep the result local so re-running the cell never changes the caller's frame.
    holdout = df['holdout'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)

    holdout_metrics = holdout.apply(pd.Series)
    holdout_metrics['Model'] = df['Model']
    return holdout_metrics


def plot_all_metrics(df, holdout=True, metrics=('c_index', 'ibs', 'ce', 'auc')):
    """Draw one bar chart per evaluation metric, comparing models side by side.

    Bars are shaded by value (lighter = lower, darker = higher within each metric).

    Args:
        df: DataFrame with a 'Model' column and one numeric column per metric.
        holdout: Currently unused; kept so existing calls keep working.
        metrics: Metric columns to plot, in order. Metrics that are not columns of ``df``
            are skipped.

    Returns:
        The input ``df`` (unchanged).

    Raises:
        ValueError: if none of ``metrics`` is a column of ``df``.
    """
    metrics = [m for m in metrics if m in df.columns]
    if not metrics:
        raise ValueError(f"None of the requested metrics are columns of df: {list(df.columns)}")

    n_cols = 2
    n_rows = (len(metrics) + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 5 * n_rows), squeeze=False)
    axes = axes.flatten()
    cmap = sns.light_palette("#79C", as_cmap=True)

    for ax, metric in zip(axes, metrics):
        # No palette: every bar is recoloured below according to its value.
        sns.barplot(x='Model', y=metric, data=df, ax=ax)

        ax.set_title(f"{metric.upper()} Comparison")
        ax.set(xlabel='Model', ylabel=metric.upper())
        # Rotate the existing tick labels without replacing the tick formatter.
        ax.tick_params(axis='x', labelrotation=45)
        ax.grid(axis='y', linestyle='--', alpha=0.7)

        if metric in ('c_index', 'auc'):
            ax.set_ylim(0, 1)

        # Shade bars by value for some visual weighting.
        min_val, max_val = df[metric].min(), df[metric].max()
        for patch in ax.patches:
            height = patch.get_height()
            normalized = (height - min_val) / (max_val - min_val) if max_val > min_val else 0.5
            # Keep the colormap argument within [0.4, 1.0]; anything above 1.0 would clip
            # to the same colour and hide differences between the best bars.
            patch.set_facecolor(cmap(0.4 + 0.6 * normalized))

    # Hide panels left over when the number of metrics is odd.
    for ax in axes[len(metrics):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()

    return df

In [None]:
# NOTE(review): extract_holdout_metrics() is defined above but never called. plot_all_metrics
# expects one column per metric; if the saved CSV stores the metrics inside a 'holdout' column,
# this should be plot_all_metrics(extract_holdout_metrics(model_outcomes)) — confirm against the CSV.
plot_all_metrics(model_outcomes)

### Import Trained Models
The pretrained models are stored in the Google Cloud Storage bucket: `gs://actin-personalization-models-v1/trained_models/`. 

To download the saved models from the bucket to the `trained_models` folder, run the following command in your terminal:

`gsutil -m cp -r gs://actin-personalization-models-v1/trained_models/ ./trained_models/`

Make sure the `trained_models` folder is inside the `models` folder.


In [None]:
def load_trained_model(model_name, model_class, model_kwargs=None):
    """Load one trained model from ``<settings.save_path>/<settings.outcome>_<model_name>.*``.

    'CoxPH', 'RandomSurvivalForest', 'GradientBoosting' and 'AalenAdditive' are unpickled
    with dill from the ``.pkl`` file. Any other name is treated as a neural network:
    ``model_class(**model_kwargs)`` is instantiated and its network weights (plus baseline
    hazards, when present in the checkpoint) are restored from the ``.pt`` file onto CPU.

    Args:
        model_name: Name used in the file name and to choose the loading route.
        model_class: Wrapper class instantiated for neural-network models (unused for
            pickled models).
        model_kwargs: Keyword arguments for ``model_class``; ``None`` means no arguments.

    Returns:
        The loaded model object.
    """
    model_file_prefix = os.path.join(settings.save_path, f"{settings.outcome}_{model_name}")

    if model_name in ('CoxPH', 'RandomSurvivalForest', 'GradientBoosting', 'AalenAdditive'):
        sk_file = model_file_prefix + ".pkl"
        # dill (like pickle) can execute arbitrary code on load: only use trusted model files.
        with open(sk_file, "rb") as f:
            model = dill.load(f)
        print(f"Model {model_name} loaded from {sk_file}")
        return model

    nn_file = model_file_prefix + ".pt"
    # A None default avoids sharing one mutable dict across calls.
    model = model_class(**(model_kwargs or {}))

    # NOTE(review): the checkpoint may contain non-tensor objects (e.g. baseline hazards); on
    # PyTorch versions where torch.load defaults to weights_only=True this may need
    # weights_only=False for trusted files — confirm with the installed torch version.
    state = torch.load(nn_file, map_location=torch.device('cpu'))
    model.model.net.load_state_dict(state['net_state'])

    # Restore baseline hazards when the checkpoint contains them.
    if 'baseline_hazards' in state and state['baseline_hazards'] is not None:
        model.model.baseline_hazards_ = state['baseline_hazards']
        model.model.baseline_cumulative_hazards_ = state['baseline_cumulative_hazards']
        print(f"Baseline hazards loaded for {model_name}.")

    model.model.net.eval()
    print(f"Model {model_name} loaded from {nn_file}")

    return model
    
def load_all_trained_models():
    """Load every model listed in the experiment's JSON config (settings.json_config_file).

    Returns:
        dict mapping model name -> loaded model, in the order the config lists them.
    """
    model_configs = ExperimentConfig(settings.json_config_file).load_model_configs()
    return {
        name: load_trained_model(model_name=name, model_class=cls, model_kwargs=kwargs)
        for name, (cls, kwargs) in model_configs.items()
    }

In [None]:
# Mapping of model name -> loaded model; used by the evaluation and interpretation cells below.
trained_models = load_all_trained_models()

### Time-Dependent ROC-AUC

This section computes the time-dependent (cumulative/dynamic) AUC of the survival models at specific time points for both OS and PFS, and plots it over time. By evaluating the models' discriminative performance over time, we identify which models perform best at different prediction horizons.

Time interval:
- **Overall Survival (OS)**: For OS, the follow-up times in the dataset extend up to 5 years, allowing us to evaluate model performance over this longer horizon. As survival outcomes often have a broader timespan, a 5-year evaluation provides a comprehensive view of the model's ability to predict long-term survival.

- **Progression-Free Survival (PFS)**: PFS events typically occur sooner than OS events, so we limit the PFS analysis to 3 years, which covers the majority of observed progression events while maintaining meaningful statistical power.

The time-dependent AUC curves for the models selected in `models_to_evaluate` (e.g. Gradient Boosting, CoxPH, DeepSurv, and RSF) are plotted over each time interval, showing how well the models distinguish patients at risk across the respective time frames for OS and PFS.


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sksurv.util import Surv
from sksurv.metrics import cumulative_dynamic_auc

def calculate_time_dependent_auc_for_models(model_dict, X_train, y_train, X_test, y_test, time_points):
    """Compute and plot cumulative/dynamic time-dependent AUC for several survival models.

    Args:
        model_dict: Mapping of display name -> trained model. Models exposing
            ``model.model.predict`` receive ``X_test`` as a float32 array; other models are
            called as ``model.predict(X_test)``.
        X_train, y_train: Training data; converted to a sksurv structured array and passed
            to ``cumulative_dynamic_auc`` as the training survival data.
        X_test, y_test: Evaluation data.
        time_points: Evaluation times; the plot converts them to years by dividing by 365,
            i.e. they are expected in days.

    Returns:
        dict mapping model name -> (AUC value per time point, mean AUC).

    Raises:
        ValueError: if a model returns 2-D predictions whose number of columns cannot be
            aligned with ``time_points``.
    """
    y_train_df = pd.DataFrame({'duration': y_train[settings.duration_col], 'event': y_train[settings.event_col]}, index=X_train.index)
    y_train_struct = Surv.from_dataframe('event', 'duration', y_train_df)

    y_test_df = pd.DataFrame({'duration': y_test[settings.duration_col], 'event': y_test[settings.event_col]}, index=X_test.index)
    y_test_struct = Surv.from_dataframe('event', 'duration', y_test_df)

    n_times = len(time_points)
    auc_results = {}

    for model_name, model in model_dict.items():
        if hasattr(model, "model") and hasattr(model.model, "predict"):
            preds = model.model.predict(X_test.values.astype("float32"))
        else:
            preds = model.predict(X_test)

        # NOTE(review): cumulative_dynamic_auc expects risk scores (higher = earlier event).
        # If a model returns survival probabilities they would need negating — confirm per model.
        if preds.ndim == 1:
            risk_scores = preds
        elif preds.shape[1] == 1:
            risk_scores = preds.ravel()
        elif preds.shape[1] == n_times:
            risk_scores = preds
        elif preds.shape[1] % n_times == 0:
            # Keep every k-th output column so there is one column per requested time point.
            risk_scores = preds[:, ::preds.shape[1] // n_times]
        else:
            # Fail loudly rather than scoring this model with an undefined or stale `risk_scores`.
            raise ValueError(
                f"{model_name}: cannot align {preds.shape[1]} prediction columns "
                f"with {n_times} time points"
            )

        auc_values, mean_auc = cumulative_dynamic_auc(y_train_struct, y_test_struct, risk_scores, time_points)
        auc_results[model_name] = (auc_values, mean_auc)

    fig, ax = plt.subplots(figsize=(10, 6))
    years = [t / 365.0 for t in time_points]

    for model_name, (auc_vals, mean_auc) in auc_results.items():
        ax.plot(years, auc_vals, marker='o', label=f"{model_name} (Mean AUC={mean_auc:.3f})")

    ax.set_xlabel("Time (years)")
    ax.set_ylabel("Time-Dependent AUC")
    ax.set_title(settings.outcome)
    ax.legend(loc="lower right")
    ax.grid(True)
    plt.show()

    return auc_results


In [None]:
# DeepSurv and LogisticHazardModel are compared here; other loaded models (e.g. DeepHitModel,
# PCHazardModel, MTLRModel, GradientBoosting, RandomSurvivalForest) can be added to the tuple.
models_to_evaluate = {
    name: trained_models[name]
    for name in ("DeepSurv", "LogisticHazardModel")
}

calculate_time_dependent_auc_for_models(
    models_to_evaluate,
    X_train,
    y_train,
    X_test,
    y_test,
    time_points=settings.time_points,
)

## Model Interpretation

### Feature Importance

Feature importance analysis helps us identify which features most strongly influence survival predictions across various models. In this section SHAP (SHapley Additive exPlanations) is used for all models to ensure uniformity and interpretability. SHAP values provide a consistent and locally accurate measure of feature importance for individual predictions.

How SHAP is used:
- For classical models like CoxPH and Aalen Additive, SHAP values are calculated based on risk scores or cumulative hazard coefficients. Custom prediction functions are used when necessary (e.g., for Aalen Additive) to align feature importance with model-specific outputs.
- For tree-based models like Random Survival Forest (RSF) and Gradient Boosting Survival Model (GBM), SHAP values replace traditional feature importance metrics to ensure consistency.
- For neural network-based models like DeepSurv and DeepHit, SHAP values are derived using the model's prediction function.

#### Visualization
SHAP provides the following insights:
- Summary Plot - Bar: Displays the average magnitude of SHAP values for each feature, indicating the overall importance of features in the model.
- Summary Plot - Dot: Highlights the distribution of SHAP values for each feature, showing their impact across different samples.


In [None]:
import shap

def nn_predict(X, model, X_train):
    """Predict with a wrapped network (``model.model.predict``) on float32 input.

    A bare numpy array (e.g. as passed in by SHAP) is first labelled with ``X_train``'s
    column names so it is handled like a DataFrame.

    Returns:
        Whatever ``model.model.predict`` returns for the float32 feature matrix.
    """
    frame = pd.DataFrame(X, columns=X_train.columns) if isinstance(X, np.ndarray) else X
    return model.model.predict(frame.values.astype('float32'))

def custom_aalen_predict(X, model):
    """
    Custom predict function for AalenAdditiveModel.

    Scores each row of ``X`` as the dot product of its ``model.selected_features`` values
    with the cumulative hazard coefficients at the last time point (the last row of
    ``model.model.cumulative_hazards_``; no interpolation is performed). When there are more
    coefficients than selected features, a constant "Intercept" column of 1.0 is prepended.

    NOTE(review): this assumes the extra coefficient is an intercept stored in the *first*
    coefficient column. If the fitted model keeps its intercept/baseline column elsewhere,
    features and coefficients will be misaligned — confirm against the fitted model's
    columns. ``X`` must be a DataFrame because features are selected by name.

    Returns:
        1-D numpy array with one risk score per row of ``X``.
    """
    coef_table = model.model.cumulative_hazards_
    design = X[model.selected_features].copy()  # copy so the insert below never touches X
    final_coefs = coef_table.iloc[-1].values

    needs_intercept = len(final_coefs) > design.shape[1]
    if needs_intercept:
        design.insert(0, "Intercept", 1.0)

    return np.einsum('ij,j->i', design.values, final_coefs)

def shap_interpret_model(model_name, model, X_train, max_features=20, shap_sample=200):
    """Explain a trained model with SHAP and draw bar and dot summary plots.

    Args:
        model_name: Model name; 'AalenAdditive' routes predictions through
            custom_aalen_predict.
        model: Trained model. If ``model.predict`` works on a one-row DataFrame it is used
            directly; otherwise predictions go through nn_predict.
        X_train: Feature frame from which background/explained rows are sampled.
        max_features: Maximum number of features shown in each summary plot.
        shap_sample: Maximum number of rows sampled (random_state=42) for the explanation.
    """
    X_sample = X_train.sample(min(shap_sample, len(X_train)), random_state=42)

    if model_name == 'AalenAdditive':
        def prediction_fn(X):
            return custom_aalen_predict(X, model)
    else:
        # Probe model.predict on one row. Any failure, including a model with no `predict`
        # attribute at all, falls back to the neural-network prediction route.
        try:
            model.predict(X_sample.head(1))
            prediction_fn = model.predict
        except Exception:
            def prediction_fn(X):
                return nn_predict(X, model, X_train)

    explainer = shap.Explainer(prediction_fn, X_sample)
    shap_values = explainer(X_sample)

    print(f"SHAP Summary for {model_name}:")

    if shap_values.values.ndim == 3:
        # Multi-output models (e.g. one output per time bin): SHAP values are laid out as
        # (samples, features, outputs). Average over the output axis so each feature keeps
        # its own column and lines up with X_sample.
        aggregated_shap = shap_values.values.mean(axis=-1)
        shap.summary_plot(aggregated_shap, features=X_sample, plot_type="bar", max_display=max_features)
        shap.summary_plot(aggregated_shap, features=X_sample, max_display=max_features)
    else:
        shap.summary_plot(shap_values, features=X_sample, plot_type="bar", max_display=max_features)
        shap.summary_plot(shap_values, features=X_sample, max_display=max_features)

In [None]:
# Run the SHAP summary (bar + dot plots) for every loaded model.
for model_name, model_instance in trained_models.items():
    print(f"\n--- Interpreting {model_name} ---")
    shap_interpret_model(model_name, model_instance, X_train)

### Model Output

### Median Survival Time Calculation

Rather than relying on the mean survival time (which can be skewed by tail behavior), we focus on the median survival time, defined as the time point $t$ at which the survival function $S(t)$ first drops to 0.5 or below. This metric is often more robust in practical settings, as it is less sensitive to subtle differences in the tail of the survival curve.

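The median survival time is obtained by scanning the predicted survival curve from time zero for the first evaluated time point where $S(t) \le 0.5$, and linearly interpolating between that point and the previous one. If the survival probability never drops to 0.5 within the evaluated time range, the median is reported as the last time point (i.e. it lies at or beyond it). The first and third quartile survival times are found in the same way, using $S(t) \le 0.75$ and $S(t) \le 0.25$ respectively.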



In [None]:
def get_median_and_quartiles_survival_time(times, surv_probs):
    """
    Finds the median (50th percentile), 1st quartile (25th percentile), and
    3rd quartile (75th percentile) survival times.

    For each target survival probability p (0.5, 0.75 and 0.25) the curve is scanned from
    the first time point for the first point with ``surv_probs <= p``, and the crossing time
    is linearly interpolated between that point and the previous one. If the curve already
    starts at or below p, ``times[0]`` is returned; if it never reaches p, ``times[-1]`` is
    returned.

    Args:
        times: 1-D time grid (assumed increasing).
        surv_probs: Survival probabilities evaluated at ``times`` (same length).

    Returns:
        (median_time, q1_time, q3_time), where q1 is the time at which survival drops to 0.75
        and q3 the time at which it drops to 0.25.

    Raises:
        ValueError: if ``times`` is empty.
    """
    times = np.asarray(times, dtype=float)
    surv_probs = np.asarray(surv_probs, dtype=float)
    if times.size == 0:
        raise ValueError("times and surv_probs must be non-empty")

    def find_time_for_percentile(percentile):
        # Already at/below the target at the first time point: nothing to interpolate.
        if surv_probs[0] <= percentile:
            return times[0]
        for i in range(1, len(times)):
            if surv_probs[i] <= percentile:
                x0, x1 = times[i - 1], times[i]
                y0, y1 = surv_probs[i - 1], surv_probs[i]
                # Here y0 > percentile >= y1, so the denominator is strictly negative.
                frac = (percentile - y0) / (y1 - y0)
                return x0 + frac * (x1 - x0)
        return times[-1]

    median_time = find_time_for_percentile(0.5)
    q1_time = find_time_for_percentile(0.75)  # Q1: 75% survival probability
    q3_time = find_time_for_percentile(0.25)  # Q3: 25% survival probability

    return median_time, q1_time, q3_time


#### Comparison of Treatments
In this section the predicted survival probabilities are visualized for a single patient under different treatment scenarios. By simulating the patient receiving each available treatment, we can compare how the model predicts their survival trajectory across treatments.

This analysis provides insights into the model's predictions for different treatment options, helping to identify potentially better treatment choices for the patient based on the predicted survival probabilities over time.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

def plot_survival_curves_for_different_treatment_options(model, X_test, y_test, X_train, y_train, patient_index, treatment_col, plot_title):
    """Plot one patient's predicted survival curve under each treatment option.

    The patient at positional row ``patient_index`` of ``X_test`` is copied once per option;
    all treatment columns are zeroed, the option's column is set, and
    ``model.predict_survival_function`` is called. The median survival time (in days) is
    annotated on each curve.

    Treatment columns are those whose name contains ``treatment_col`` (case-insensitive).
    With a single such column the options are "No Treatment" (0) and "Treatment" (1);
    otherwise there is one option per column, labelled with the column name minus
    ``treatment_col``.

    Args:
        model: Object whose ``predict_survival_function(X)`` returns a sequence of callables
            mapping times to survival probabilities.
        X_test: Feature frame containing the patient.
        y_test, X_train, y_train: Unused; kept for interface compatibility.
        patient_index: Positional row index of the patient in ``X_test``.
        treatment_col: Substring that identifies treatment columns.
        plot_title: Figure title.

    Returns:
        DataFrame with columns Treatment, Month, SurvivalProbability at 6-36 months (a month
        is approximated as 30 days). Empty if no treatment columns are found.
    """
    X_patient_original = X_test.iloc[[patient_index]].copy()

    treatment_cols = [col for col in X_test.columns if treatment_col.lower() in col.lower()]
    if len(treatment_cols) == 0:
        print(f"No columns found containing '{treatment_col}' in X_test.columns")
        return pd.DataFrame()

    treatment_options = []
    if len(treatment_cols) == 1:
        t = treatment_cols[0]
        treatment_options.append(("No Treatment", {t: 0}))
        treatment_options.append(("Treatment", {t: 1}))
    else:
        for t in treatment_cols:
            treatment_name = t.replace(treatment_col, "").lstrip("_")
            treatment_options.append((treatment_name if treatment_name else t, {t: 1}))

    times = np.arange(1, settings.max_time, 30)
    months_to_check = [6, 12, 18, 24, 30, 36]
    time_points_days = np.array(months_to_check) * 30  # a month is approximated as 30 days

    plt.figure(figsize=(10, 6))
    df_treatment_probabilities = []
    colors = plt.cm.tab20(np.linspace(0, 1, len(treatment_options)))
    # Only treatments whose curve was actually drawn are added to the legend.
    legend_handles, legend_labels = [], []

    for i, (treatment_label, option) in enumerate(treatment_options):
        X_patient = X_patient_original.copy()
        for col in treatment_cols:
            X_patient[col] = 0
        for col, value in option.items():
            X_patient[col] = value

        try:
            surv_funcs = model.predict_survival_function(X_patient)
            # Use len() rather than truthiness: `not` is ambiguous for numpy arrays/DataFrames.
            if surv_funcs is None or len(surv_funcs) == 0:
                print(f"No survival function returned for treatment: {treatment_label}")
                continue

            surv_func = surv_funcs[0]
            surv_probs = surv_func(times)
            plt.step(times, surv_probs, where="post", color=colors[i], label=treatment_label)
            legend_handles.append(Line2D([0], [0], color=colors[i], lw=2))
            legend_labels.append(treatment_label)

            median, q1, q3 = get_median_and_quartiles_survival_time(times, surv_probs)
            median_surv = surv_func(median)
            plt.text(median, median_surv, f"{int(median)}", fontsize=9,
                     bbox=dict(facecolor="white", edgecolor=colors[i], boxstyle="round,pad=0.2"),
                     color="black")

            month_probs = surv_func(time_points_days)
            for month, prob in zip(months_to_check, month_probs):
                df_treatment_probabilities.append({
                    "Treatment": treatment_label,
                    "Month": month,
                    "SurvivalProbability": prob
                })
        except Exception as e:
            print(f"Error plotting survival curves for treatment '{treatment_label}': {e}")

    if legend_handles:
        plt.legend(legend_handles, legend_labels, loc='center left', bbox_to_anchor=(1, 0.5))

    plt.title(plot_title)
    plt.xlabel("Time (days)")
    plt.ylabel("Survival Probability")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    return pd.DataFrame(df_treatment_probabilities)


In [None]:
# Compare the treatment options for one test-set patient using the DeepSurv model.
best_model = "DeepSurv"
plot_survival_curves_for_different_treatment_options(
    model=trained_models[best_model],
    X_test=X_test,
    y_test=y_test,
    X_train = X_train, 
    y_train = y_train,
    patient_index=100,  # positional row in X_test
    treatment_col="treatment",
    plot_title=f"{settings.outcome}: {settings.experiment_type}"
)

#### Comparison of Models
Next, we focus on the patient's actual treatment choice and compare the predictions from the two best-performing models. By evaluating the survival curves generated by both models for the chosen treatment, we assess their agreement and gain a deeper understanding of the patient-specific predictions.

In [None]:
# Compare the survival curves predicted by the models in models_to_evaluate for one patient.
# NOTE(review): this uses patient_index=78, while the treatment comparison above uses 100, but the
# markdown refers to "the patient" — confirm which patient is intended.
plot_different_models_survival_curves(
    trained_models=models_to_evaluate,
    X_test=X_test,
    y_test=y_test,
    patient_index=78
)

### Look at NCR specific patient

In [None]:
def get_raw_patient_row(db_config_path: str, db_name: str, ncrId: int):
    """Fetch the unprocessed database row for one patient from ``settings.view_name``.

    Connects with pymysql using the ``[RAnalysis]`` group of the option file at
    ``db_config_path``.

    Args:
        db_config_path: Path to the MySQL option file with connection credentials.
        db_name: Database to connect to.
        ncrId: Patient identifier to look up.

    Returns:
        The first matching row as a pandas Series, or None if no row matches.
    """
    import pymysql

    connection = pymysql.connect(
        read_default_file=db_config_path,
        read_default_group='RAnalysis',
        db=db_name
    )
    try:
        # Filter in SQL with a bound parameter instead of downloading the whole view.
        # The view name comes from settings (trusted config) and cannot be bound as a parameter.
        raw_df = pd.read_sql(
            f"SELECT * FROM {settings.view_name} WHERE ncrId = %s",
            connection,
            params=(ncrId,),
        )
    finally:
        # Always release the connection, even if the query fails.
        connection.close()
    return raw_df.iloc[0] if not raw_df.empty else None

def plot_survival_curves_for_ncrId_different_treatments(model, ncrId, features, treatment_col_prefix="systemicTreatmentPlan", plot_title="Survival Curves"):
    """Plot survival curves for one NCR patient under every treatment option.

    The data are preprocessed with DataPreprocessor (with ``ncrId`` kept so the patient can
    be located), some patient characteristics are printed, and
    ``model.predict_survival_function`` is evaluated once with every
    ``<treatment_col_prefix>_*`` column set to 0 ("No Treatment") and once per treatment
    column set to 1. A dashed vertical line marks the patient's observed event (red) or
    censoring (blue) time.

    Args:
        model: Object whose ``predict_survival_function(X)`` returns callables that expose
            their time grid as ``.x``.
        ncrId: Patient identifier.
        features: Feature names passed to the preprocessor (the caller's object is not modified).
        treatment_col_prefix: Prefix of the one-hot treatment columns.
        plot_title: Figure title prefix.
    """
    days_per_month = 30.44  # average month length, used to plot days on a month axis

    features_with_id = list(features)
    if "ncrId" not in features_with_id:
        features_with_id.insert(0, "ncrId")

    preprocessor = DataPreprocessor(settings.db_config_path, settings.db_name)

    processed_df, updated_features, encoded_columns = preprocessor.preprocess_data(features_with_id)

    patient_df = processed_df[processed_df["ncrId"] == ncrId]
    if patient_df.empty:
        print(f"No patient found with ncrId {ncrId}")
        return

    row = patient_df.iloc[0]

    # Age and WHO status are taken from the raw (unpreprocessed) row. get_raw_patient_row can
    # return None; fall back to an empty mapping so those fields print as 'NA'.
    raw_row = get_raw_patient_row(preprocessor.db_config_path, preprocessor.db_name, ncrId)
    if raw_row is None:
        raw_row = {}

    print(f"\n🧬 Patient {ncrId} characteristics:")
    print(f"  - Age at metastasis detection: {raw_row.get('ageAtMetastasisDetection', 'NA')}")
    print(f"  - WHO status: {raw_row.get('whoStatusPreTreatmentStart', 'NA')}")
    print(f"  - MSI status: {'MSI' if row.get('hasMsi', 0) == 1 else 'MSS or NA'}")
    print(f"  - BRAF mutation: {'Yes' if row.get('hasBrafMutation', 0) == 1 else 'No'}")
    print(f"  - BRAF V600E: {'Yes' if row.get('hasBrafV600EMutation', 0) == 1 else 'No'}")
    print(f"  - KRAS G12C: {'Yes' if row.get('hasKrasG12CMutation', 0) == 1 else 'No'}")
    print(f"  - RAS mutation: {'Yes' if row.get('hasRasMutation', 0) == 1 else 'No'}")

    # Remove the identifier and outcome columns; ncrId_missing is dropped only when present.
    drop_cols = ["ncrId", settings.event_col, settings.duration_col]
    if "ncrId_missing" in patient_df.columns:
        drop_cols.append("ncrId_missing")
    X_patient_base = patient_df.drop(columns=drop_cols).copy()

    treatment_cols = [col for col in X_patient_base.columns if col.startswith(treatment_col_prefix)]
    # Treatment labels are the column names without the "<prefix>_" part.
    label_start = len(treatment_col_prefix) + 1
    actual_treatments = [col[label_start:] for col in treatment_cols if row.get(col, 0) == 1]
    print(f"  - Actual received treatment: {', '.join(actual_treatments) if actual_treatments else 'No Treatment'}\n")

    treatment_options = [("No Treatment", {tc: 0 for tc in treatment_cols})]
    for tc in treatment_cols:
        option = {t: 0 for t in treatment_cols}
        option[tc] = 1
        treatment_options.append((tc[label_start:], option))

    actual_duration_days = patient_df[settings.duration_col].iloc[0]
    actual_event = patient_df[settings.event_col].iloc[0]

    plt.figure(figsize=(12, 8))
    colors = plt.cm.tab20(np.linspace(0, 1, len(treatment_options)))

    for i, (treatment_label, treatment_vals) in enumerate(treatment_options):
        X_patient_modified = X_patient_base.copy()
        for tc in treatment_cols:
            X_patient_modified[tc] = 0
        for tc, val in treatment_vals.items():
            X_patient_modified[tc] = val

        try:
            surv_funcs = model.predict_survival_function(X_patient_modified)

            # Evaluate on the time range covered by every returned function.
            times = np.linspace(
                max(fn.x[0] for fn in surv_funcs),
                min(fn.x[-1] for fn in surv_funcs),
                100
            )
            # np.vstack instead of np.row_stack, which is deprecated in NumPy 2.0.
            surv_probs = np.vstack([fn(times) for fn in surv_funcs])
            plt.step(times / days_per_month, surv_probs[0], where="post",
                     color=colors[i],
                     label=f" {treatment_label}")
        except Exception as e:
            print(f"Error plotting survival curves with treatment {treatment_label}: {e}")

    marker_color = 'red' if actual_event else 'blue'
    marker_label = "Event Time" if actual_event else "Censoring Time"
    plt.axvline(x=actual_duration_days / days_per_month, color=marker_color, linestyle='--', label=marker_label)

    plt.title(f"{plot_title} for ncrId: {ncrId}")
    plt.xlabel("Time (months)")
    plt.ylabel("Survival Probability")
    plt.legend(loc="best")
    plt.grid(True)
    plt.show()


In [None]:
# Inspect a single registry patient under every treatment option with the DeepSurv model.
# NOTE(review): a hard-coded patient identifier in a shared notebook may be sensitive —
# consider reading it from settings or an environment variable instead.
plot_survival_curves_for_ncrId_different_treatments(
    model=trained_models['DeepSurv'],
    ncrId=950861304,
    features=lookup_manager.features,
    plot_title=f"{settings.outcome}: Survival Curves"
)