# Predictive Algorithms for Survival Analysis

This notebook demonstrates the pipeline for developing and evaluating predictive algorithms in survival analysis. The primary objective is to model and predict **overall survival (OS)** and **progression-free survival (PFS)** for patients. Using both classical statistical methods and state-of-the-art deep learning techniques, the notebook covers the entire process, including:
- Data Preprocessing: Preparing survival datasets for analysis, ensuring compatibility with various model types.
- Model Training: Building survival models tailored to predict survival outcomes and handle censored data.
- Hyperparameter Optimization: Fine-tuning models for optimal performance.
- Performance Evaluation: Comparing models based on metrics such as concordance index (C-Index), integrated Brier score (IBS), calibration error (CE), and time-dependent AUC.
- Visualization: Generating survival curves and feature importance plots to interpret model predictions and uncover key insights.

This workflow provides a framework to explore survival modeling techniques and tailor them to specific datasets and objectives.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from src.models import *
from src.data.data_processing import DataSplitter, DataPreprocessor
from src.data.lookups import LookupManager

In [None]:
db_config_path = '/home/jupyter/.my.cnf'
db_name = 'actin_personalization'
query = "SELECT * FROM knownPalliativeTreatments"

preprocessor = DataPreprocessor(db_config_path, db_name)

lookup_manager = LookupManager()
features = lookup_manager.features

## Data Preprocessing

In this section, we set up the data pipeline for survival analysis. The `DataSplitter` and `DataPreprocessor` classes are used to load, preprocess, and split the data into training and testing sets. This ensures the survival data is structured appropriately for model training.
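scikit-survival estimators expect the outcome as a NumPy structured array with a boolean event field and a float time field; this is what `Surv.from_dataframe` builds from the event and duration columns. A minimal sketch of the same structure in plain NumPy, assuming toy values and the field names `event`/`time` (chosen here for illustration; `to_structured` is a hypothetical helper):

```python
import numpy as np

def to_structured(event, time, name_event="event", name_time="time"):
    """Pack event indicators and follow-up times into a structured array
    with a boolean event field followed by a float time field."""
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    if event.shape != time.shape:
        raise ValueError("event and time must have the same shape")
    y = np.empty(event.shape[0], dtype=[(name_event, bool), (name_time, float)])
    y[name_event] = event
    y[name_time] = time
    return y

# Toy example: three patients, the second one censored at day 400
y = to_structured([True, False, True], [120.0, 400.0, 35.0])
print(y.dtype.names)           # ('event', 'time')
print(y["time"][~y["event"]])  # follow-up times of censored patients
```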


In [None]:
from sksurv.util import Surv

def get_data(query, event_col, duration_col, features):
    """
    Load and preprocess the data, then split it into training and test sets.

    Returns the full preprocessed DataFrame, X_train, X_test, y_train, y_test and the encoded columns.
    """
    splitter = DataSplitter(test_size=0.1, random_state=42)

    df, features, encoded_columns = preprocessor.preprocess_data(query, duration_col, event_col, features)

    X_train, X_test, y_train, y_test = splitter.split(df[features], df, event_col, encoded_columns)

    return df, X_train, X_test, y_train, y_test, encoded_columns

In [None]:
os_df, os_X_train, os_X_test, os_y_train, os_y_test, os_encoded_columns = get_data(
    query, 'hadSurvivalEvent', 'observedOsFromTreatmentStartDays', features
)

In [None]:
pfs_df, pfs_X_train, pfs_X_test, pfs_y_train, pfs_y_test, pfs_encoded_columns = get_data(
    query, 'hadProgressionEvent', 'observedPfsDays', features
)

## Train and Evaluate Models

This section defines the function `train_evaluate_models`, which trains various survival models using predefined configurations. The trained models are evaluated using the following metrics:

- **C-Index**: The concordance index measures how well the predicted risk scores rank patients by their observed survival times. Among comparable pairs (pairs in which the patient with the shorter follow-up had an observed event), it is the fraction in which the model assigns the higher risk to the patient who failed first. It is a measure of discrimination: 0.5 corresponds to random ranking and 1.0 to perfect ranking.

- **Integrated Brier Score (IBS)**: The Brier score at time *t* is the mean squared difference between the predicted survival probability and the observed status (1 if the patient is still event-free at *t*, 0 otherwise), using inverse-probability-of-censoring weights to account for censored patients. The IBS integrates this score over a range of time points. Lower values indicate better predictions.

- **Calibration Error (CE)**: Calibration error assesses how well the predicted survival probabilities match the observed event frequencies, i.e. whether the model systematically overestimates or underestimates survival. Lower values indicate better calibration.

- **Area Under the Curve (AUC)**: For survival models, AUC is computed from time-dependent ROC curves: at each time point *t*, it measures how well the model separates patients who experienced the event by *t* from those still event-free after *t*. Higher values indicate better discrimination at that horizon.
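To make the pairwise definition of the C-index concrete, here is a minimal NumPy sketch of Harrell's C-index. It is an illustrative reimplementation, not the project's `ModelTrainer` code; tied event times are simply treated as non-comparable, whereas scikit-survival's `concordance_index_censored` handles ties more carefully.

```python
import numpy as np

def harrell_c_index(time, event, risk):
    """Harrell's concordance index.

    A pair (i, j) is comparable when patient i had an observed event and
    t_i < t_j. It is concordant when patient i (who failed first) also has
    the higher predicted risk; tied risk scores count as half-concordant.
    """
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    risk = np.asarray(risk, dtype=float)
    concordant, comparable = 0.0, 0
    for i in np.flatnonzero(event):
        later = time > time[i]  # patients still event-free after t_i
        comparable += later.sum()
        concordant += (risk[i] > risk[later]).sum()
        concordant += 0.5 * (risk[i] == risk[later]).sum()
    return concordant / comparable if comparable else float("nan")

time = [5, 10, 15, 20]
event = [True, True, False, True]
print(harrell_c_index(time, event, [4, 3, 2, 1]))  # 1.0 (perfect ranking)
print(harrell_c_index(time, event, [1, 2, 3, 4]))  # 0.0 (reversed ranking)
```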

Together, these metrics provide a comprehensive evaluation of the models' predictive performance, capturing different aspects of accuracy, discrimination, and calibration.
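The censoring weights in the Brier score are easy to overlook, so here is a sketch of an inverse-probability-of-censoring-weighted Brier score and its integral. It assumes survival probabilities have already been evaluated on a common time grid, estimates the censoring distribution with a Kaplan-Meier fit on the same data, and evaluates it at the observed time rather than its left limit (a simplification). scikit-survival's `integrated_brier_score` instead estimates the censoring weights from the training outcomes; the helper names below are hypothetical.

```python
import numpy as np

def km_censoring_survival(time, event):
    """Kaplan-Meier estimate of G(t) = P(C > t), treating censorings as 'events'."""
    time = np.asarray(time, dtype=float)
    censored = ~np.asarray(event, dtype=bool)
    grid = np.unique(time[censored])
    values, s = [], 1.0
    for u in grid:
        at_risk = np.sum(time >= u)
        n_cens = np.sum((time == u) & censored)
        s *= 1.0 - n_cens / at_risk
        values.append(s)
    values = np.asarray(values)

    def G(t):
        t = np.atleast_1d(np.asarray(t, dtype=float))
        if grid.size == 0:
            return np.ones_like(t)  # no censoring: G(t) = 1 everywhere
        idx = np.searchsorted(grid, t, side="right") - 1
        return np.where(idx >= 0, values[np.clip(idx, 0, None)], 1.0)

    return G

def ipcw_brier_score(time, event, surv_probs, times):
    """Brier score at each time in `times`; surv_probs[i, k] = S(times[k] | x_i)."""
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    surv_probs = np.asarray(surv_probs, dtype=float)
    G = km_censoring_survival(time, event)
    g_obs = np.maximum(G(time), 1e-12)  # weights for patients who failed by t
    scores = []
    for k, t in enumerate(times):
        p = surv_probs[:, k]
        failed = (time <= t) & event  # observed event by t: true status 0
        at_risk = time > t            # still event-free at t: true status 1
        g_t = max(G(t)[0], 1e-12)
        contrib = np.where(failed, p ** 2 / g_obs, 0.0) + np.where(at_risk, (1 - p) ** 2 / g_t, 0.0)
        scores.append(contrib.mean())  # patients censored before t contribute 0
    return np.asarray(scores)

def integrated_brier_score(time, event, surv_probs, times):
    """Trapezoidal integral of the Brier score over `times`, divided by the time span."""
    times = np.asarray(times, dtype=float)
    bs = ipcw_brier_score(time, event, surv_probs, times)
    area = np.sum((bs[1:] + bs[:-1]) / 2 * np.diff(times))
    return area / (times[-1] - times[0])

time = np.array([1.0, 2.0, 3.0, 4.0])
event = np.array([True, True, True, True])
times = np.array([1.5, 2.5, 3.5])
perfect = (time[:, None] > times[None, :]).astype(float)  # S = 1 before the event, 0 after
print(integrated_brier_score(time, event, perfect, times))                # 0.0
print(integrated_brier_score(time, event, np.full((4, 3), 0.5), times))   # 0.25
```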


In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_survival_curves_for_patient(trained_models, X_test, y_test, patient_index, event_col, duration_col, plot_title, actual_line = True):
    """
    Plot survival curves for a specific patient using trained models.
    """
    X_patient = X_test.iloc[[patient_index]]
    actual_duration_days = y_test[duration_col].iloc[patient_index]
    actual_event = y_test[event_col].iloc[patient_index]

    plt.figure(figsize=(12, 8))
    for model_name, model in trained_models.items():
        try:
            surv_funcs = model.predict_survival_function(X_patient)

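            # Evaluate on a common grid inside every step function's domain
            times = np.linspace(max(fn.x[0] for fn in surv_funcs), min(fn.x[-1] for fn in surv_funcs), 100)
            surv_probs = np.vstack([fn(times) for fn in surv_funcs])

            # Convert days to months (average month length of 30.44 days)
            plt.step(times / 30.44, surv_probs[0], where="post", label=model_name)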
        except Exception as e:
            print(f"Error plotting survival curves for model {model_name}: {e}")

    if actual_line:
        marker_color = 'red' if actual_event else 'blue'
        marker_label = "Event Time" if actual_event else "Censoring Time"
        
        plt.axvline(x=actual_duration_days / 30.44, color=marker_color, linestyle='--', label=marker_label)

    plt.title(f"Predicted {plot_title} Curves")
    plt.xlabel("Time (months)")
    plt.ylabel("Survival Probability")
    plt.legend(loc="best")
    plt.grid(True)
    plt.show()
       

In [None]:
import os
import pandas as pd
import torch

def train_evaluate_models(query, event_col, duration_col, features, configs, title, patient_index=78, save_models=False):

    # get_data also returns the full DataFrame, which is not needed here
    _, X_train, X_test, y_train, y_test, encoded_columns = get_data(query, event_col, duration_col, features)
    
    models = {}        
    for model_name, (model_class, model_kwargs) in configs.items():
        if issubclass(model_class, NNSurvivalModel):
            model_kwargs['input_size'] = X_train.shape[1]
        models[model_name] = model_class(**model_kwargs)
    
    trainer = ModelTrainer(models=models, n_splits=5, random_state=42)

    results, trained_models = trainer.train_and_evaluate(
        X_train,
        y_train,
        X_test,
        y_test,
        encoded_columns=encoded_columns,
        event_col= event_col,
        duration_col= duration_col, 
        title = title, 
        save_models = save_models
    )
    
    results_df = pd.DataFrame.from_dict(results, orient='index')
    results_df.reset_index(inplace=True)
    results_df.rename(columns={'index': 'Model'}, inplace=True)
    
    if save_models:
        save_path = "src/models/trained_models"
        os.makedirs(save_path, exist_ok=True)

        csv_file = os.path.join(save_path, f"{title}_model_outcomes.csv")
        if os.path.exists(csv_file):
            existing_df = pd.read_csv(csv_file)
            merged_df = pd.concat([existing_df, results_df]).drop_duplicates(
                subset=['Model'], keep='last'
            )
            merged_df.to_csv(csv_file, index=False)
            print(f"Updated model outcomes saved to {csv_file}")
        else:
            results_df.to_csv(csv_file, index=False)
            print(f"Model outcomes saved to {csv_file}")
    
    plot_survival_curves_for_patient(trained_models, X_test, y_test, patient_index, event_col, duration_col, title)
    
    return results_df, trained_models
    

#### Best Model Configurations

The best configurations for the OS and PFS models were determined by the hyperparameter optimization defined at the end of this notebook and are stored in `models/model_configurations`. They are passed to `train_evaluate_models` as `os_configs` and `pfs_configs` to instantiate the models for training and evaluation.
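`train_evaluate_models` iterates over `configs.items()`, unpacks each value as `(model_class, model_kwargs)`, and injects `input_size` for neural network models. A minimal sketch of that structure, using hypothetical toy classes (`ToyCoxModel`, `ToyNeuralModel`) and a hypothetical `build_models` helper in place of the project's classes; unlike the notebook cell, it copies the keyword arguments so that repeated calls leave the shared configuration dict unchanged.

```python
import copy

# Hypothetical stand-ins for the model wrappers in src.models
class ToyCoxModel:
    def __init__(self, alpha=0.0):
        self.alpha = alpha

class ToyNeuralModel:
    def __init__(self, input_size, hidden=32):
        self.input_size = input_size
        self.hidden = hidden

# Mapping: model name -> (model class, constructor keyword arguments)
toy_configs = {
    "CoxPH": (ToyCoxModel, {"alpha": 0.1}),
    "DeepSurv": (ToyNeuralModel, {"hidden": 64}),
}

def build_models(configs, n_features, neural_classes=(ToyNeuralModel,)):
    models = {}
    for name, (model_class, kwargs) in configs.items():
        kwargs = copy.deepcopy(kwargs)  # leave the shared config dict untouched
        if issubclass(model_class, neural_classes):
            kwargs["input_size"] = n_features  # network input width depends on the data
        models[name] = model_class(**kwargs)
    return models

models = build_models(toy_configs, n_features=42)
print(models["DeepSurv"].input_size)  # 42
```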


In [None]:
os_model_outcomes, os_trained_models = train_evaluate_models(
    query, event_col='hadSurvivalEvent', duration_col='observedOsFromTreatmentStartDays',
    features=features, configs=os_configs, title="OS", save_models=True
)
os_model_outcomes

In [None]:
pfs_model_outcomes, pfs_trained_models = train_evaluate_models(
    query, event_col='hadProgressionEvent', duration_col='observedPfsDays',
    features=features, configs=pfs_configs, title="PFS", save_models=True
)
pfs_model_outcomes

## Model Evaluation

### Metric comparison: OS vs. PFS

This section compares the hold-out performance metrics (C-Index, IBS, CE, AUC) of each model for OS and PFS. The bar plots highlight the strengths and weaknesses of each model in the two prediction tasks.


In [None]:
def load_model_outcomes(title, save_path="src/models/trained_models"):
    csv_file = os.path.join(save_path, f"{title}_model_outcomes.csv")
    
    if os.path.exists(csv_file):
        results_df = pd.read_csv(csv_file)
        print(f"Loaded model outcomes from {csv_file}")
    else:
        raise FileNotFoundError(f"No saved outcomes found for {title} in {save_path}")
    
    return results_df

In [None]:
os_model_outcomes = load_model_outcomes("OS")
pfs_model_outcomes = load_model_outcomes("PFS")

In [None]:
os_model_outcomes

In [None]:
pfs_model_outcomes

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import ast

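def extract_holdout_metrics(df):
    # Work on a copy so the caller's DataFrame is not modified
    df = df.copy()
    # Metrics reloaded from CSV are stored as strings; parse them back into dicts
    if df['holdout'].apply(lambda x: isinstance(x, str)).any():
        df['holdout'] = df['holdout'].apply(ast.literal_eval)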
    
    holdout_metrics = df['holdout'].apply(pd.Series)
    holdout_metrics['Model'] = df['Model']
    
    return holdout_metrics

def plot_all_metrics(pfs_df, os_df, holdout=True):
       
    if holdout:
        pfs_df = extract_holdout_metrics(pfs_df)
        os_df = extract_holdout_metrics(os_df)
    
    pfs_df['Type'] = 'PFS'
    os_df['Type'] = 'OS'

    combined_df = pd.concat([pfs_df, os_df], ignore_index=True)

    metrics = ['c_index', 'ibs', 'ce', 'auc']

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()

    for i, metric in enumerate(metrics):
        ax = axes[i]
        sns.barplot(
            x='Model',
            y=metric,
            hue='Type',
            data=combined_df,
            ax=ax,
            palette='Set1'
        )
        ax.set_title(f'{metric.upper()} Comparison: OS vs. PFS')
        ax.set_xlabel('Model')
        ax.set_ylabel(metric.upper())
        ax.tick_params(axis='x', labelrotation=45)
        ax.legend(title='Type', loc='best')
        ax.grid(axis='y', linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.show()

In [None]:
plot_all_metrics(pfs_model_outcomes, os_model_outcomes)

### Import Trained Models
The pretrained models are stored in the Google Cloud Storage bucket `gs://actin-personalization-models-v1/trained_models/`.

To download them into a local `trained_models` folder, run the following command in your terminal:

`gsutil -m cp -r gs://actin-personalization-models-v1/trained_models/ ./trained_models/`

Make sure the resulting `trained_models` folder ends up at `src/models/trained_models`, the default `save_path` used by the loading functions below.
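Before loading, it can help to check that every expected file is in place. A small sketch, assuming the file naming used by `load_trained_model` (`{title}_{model_name}.pkl` for the classical and tree-based models, `{title}_{model_name}.pt` for neural networks); `missing_model_files` is a hypothetical helper, not part of the project.

```python
import os
import tempfile

def missing_model_files(title, model_names, nn_model_names, save_path="src/models/trained_models"):
    """Return the expected model file paths that do not exist under save_path."""
    missing = []
    for name in model_names:
        ext = ".pt" if name in nn_model_names else ".pkl"
        path = os.path.join(save_path, f"{title}_{name}{ext}")
        if not os.path.isfile(path):
            missing.append(path)
    return missing

# Usage with a temporary folder that only contains the CoxPH file
with tempfile.TemporaryDirectory() as tmp:
    open(os.path.join(tmp, "OS_CoxPH.pkl"), "wb").close()
    print(missing_model_files("OS", ["CoxPH", "DeepSurv"], {"DeepSurv"}, save_path=tmp))
```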


In [None]:
import dill
from src.models.survival_models import NNSurvivalModel

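def load_trained_model(model_name, title, model_class, model_kwargs=None, save_path="src/models/trained_models"):
    """
    Load a trained model from disk: classical and tree-based models from a dill pickle (.pkl),
    neural network models by rebuilding the network and loading its saved weights (.pt).
    """
    model_kwargs = dict(model_kwargs or {})  # copy, so the caller's config dict is not mutated
    model_file_prefix = os.path.join(save_path, f"{title}_{model_name}")
    nn_file = model_file_prefix + ".pt"
    sk_file = model_file_prefix + ".pkl"

    if model_name in ['AalenAdditive', 'CoxPH', 'RandomSurvivalForest', 'GradientBoosting']:
        with open(sk_file, "rb") as f:
            model = dill.load(f)
        print(f"Model {model_name} loaded from {sk_file}")
        return model

    # The network architecture must match the saved weights, so input_size has to equal
    # the number of features used during training (100 is only a fallback default)
    model_kwargs.setdefault('input_size', 100)
    model = model_class(**model_kwargs)

    model.model.net.load_state_dict(torch.load(nn_file, map_location=torch.device('cpu')))
    model.model.net.eval()
    print(f"Model {model_name} loaded from {nn_file}")
    return model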
        

In [None]:
def load_all_trained_models(model_specs, title, save_path="src/models/trained_models"):
    loaded_models = {}
    for model_name, (model_class, model_kwargs) in model_specs.items():
        loaded_model = load_trained_model(
            model_name=model_name, 
            title=title, 
            model_class=model_class, 
            model_kwargs=model_kwargs, 
            save_path=save_path
        )
        loaded_models[model_name] = loaded_model
    return loaded_models

os_trained_models = load_all_trained_models(os_configs, title="OS", save_path="src/models/trained_models")
pfs_trained_models = load_all_trained_models(pfs_configs, title="PFS", save_path="src/models/trained_models")

### Time-Dependent ROC-AUC

This section computes the time-dependent (cumulative/dynamic) AUC of the survival models at regular time points for both OS and PFS. Evaluating discrimination over time shows which models perform best at different prediction horizons.

Time horizons:
- **Overall Survival (OS)**: Follow-up times for OS in the dataset extend up to 5 years, so model performance is evaluated over this longer horizon. Because OS outcomes span a broader period, a 5-year evaluation shows how well the models predict long-term survival.

- **Progression-Free Survival (PFS)**: PFS events typically occur sooner than OS events, and very few patients have PFS follow-up beyond 4 years. With a 5-year horizon, the AUC at the later time points would rest on a small subset of patients and could be unreliable. The PFS analysis is therefore limited to 3 years, which covers the majority of observed progression events while retaining enough patients for stable estimates.
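The horizon argument can be checked directly by counting how many patients are still under follow-up at each evaluation time. A sketch with made-up toy durations (`at_risk_counts` and `valid_time_points` are hypothetical helpers). The second helper reflects a constraint of scikit-survival's `cumulative_dynamic_auc`, which rejects evaluation times outside the follow-up range of the test data.

```python
import numpy as np

def at_risk_counts(durations, time_points):
    """Number of patients whose follow-up extends beyond each evaluation time."""
    durations = np.asarray(durations, dtype=float)
    return {int(t): int(np.sum(durations > t)) for t in time_points}

def valid_time_points(durations, time_points):
    """Keep evaluation times inside the observed follow-up range [min, max)."""
    durations = np.asarray(durations, dtype=float)
    return [t for t in time_points if durations.min() <= t < durations.max()]

# Toy PFS durations in days (illustrative only)
toy_pfs_days = np.array([30, 90, 150, 200, 365, 400, 700, 1000, 1300, 1600])
quarterly = [int(round(i * 365 / 4)) for i in range(1, 21)]  # every 3 months, up to 5 years

counts = at_risk_counts(toy_pfs_days, quarterly)
print(counts[365], counts[1460])  # 5 1
print(valid_time_points(toy_pfs_days, quarterly)[-1])  # last usable time point
```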

The time-dependent AUC of the best-performing models (Gradient Boosting, CoxPH, DeepSurv, and RSF) is plotted every 3 months over the respective horizon, showing how well each model distinguishes high-risk from low-risk patients over time for OS and PFS.
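The case/control definition behind the cumulative/dynamic AUC can be sketched without censoring weights. This is a simplified illustration only: scikit-survival's `cumulative_dynamic_auc` additionally weights cases by inverse probability of censoring weights estimated from the training outcomes. The function name and toy data are hypothetical.

```python
import numpy as np

def cumulative_dynamic_auc_unweighted(time, event, risk, t):
    """Cumulative/dynamic AUC at time t without censoring weights.

    Cases: patients with an observed event at or before t.
    Controls: patients still event-free after t.
    Returns the fraction of (case, control) pairs in which the case has the
    higher risk score, counting ties as one half.
    """
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    risk = np.asarray(risk, dtype=float)
    cases = risk[(time <= t) & event]
    controls = risk[time > t]
    if cases.size == 0 or controls.size == 0:
        return float("nan")
    diff = cases[:, None] - controls[None, :]
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))

time = np.array([100, 200, 300, 400, 500])
event = np.array([True, True, False, True, False])
risk = np.array([0.9, 0.7, 0.5, 0.3, 0.1])
aucs = [cumulative_dynamic_auc_unweighted(time, event, risk, t) for t in (150, 350)]
print(aucs)  # [1.0, 1.0]
```

Note that the patient censored at day 300 is neither a case nor a control at day 350, which is exactly the information loss that the censoring weights in the library version compensate for.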


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sksurv.util import Surv
from sksurv.metrics import cumulative_dynamic_auc

def calculate_time_dependent_auc_for_models(model_dict, X_train, y_train, X_test, y_test, duration_col, event_col, time_points, title="Time-Dependent AUC"):
    
    y_train_df = pd.DataFrame({'duration': y_train[duration_col], 'event': y_train[event_col]}, index=X_train.index)
    y_train_structured = Surv.from_dataframe('event', 'duration', y_train_df)

    y_test_df = pd.DataFrame({'duration': y_test[duration_col], 'event': y_test[event_col]}, index=X_test.index)
    y_test_structured = Surv.from_dataframe('event', 'duration', y_test_df)

    auc_results = {}

    for model_name, model in model_dict.items():
        model_class = type(model)
        if issubclass(model_class, NNSurvivalModel):
            risk_scores = model.model.predict(X_test.values.astype('float32'), is_dataloader=False).ravel()
        else:
            risk_scores = model.predict(X_test).ravel()

        auc_values, mean_auc = cumulative_dynamic_auc(y_train_structured, y_test_structured, risk_scores, time_points)
        auc_results[model_name] = (auc_values, mean_auc)

    plt.figure(figsize=(10, 6))
    years = [t / 365 for t in time_points]
    for model_name, (auc_values, mean_auc) in auc_results.items():
        plt.plot(years, auc_values, label=f"{model_name} (Mean AUC={mean_auc:.3f})", marker='o')

    plt.xlabel("Time (years)")
    plt.ylabel("AUC")
    plt.title(title)
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.show()
    
    return auc_results

In [None]:
os_models_to_plot = {
    "GradientBoosting": os_trained_models["GradientBoosting"],
    "CoxPH": os_trained_models["CoxPH"],
    "DeepSurv": os_trained_models["DeepSurv"],
    "RandomSurvivalForest": os_trained_models["RandomSurvivalForest"]
}

os_time_points = [int(round(i * 365 / 4)) for i in range(1, 21)] # Every 3 months (up to 5 years)

calculate_time_dependent_auc_for_models(
    os_models_to_plot, os_X_train, os_y_train, os_X_test, os_y_test, 
    duration_col="observedOsFromTreatmentStartDays", 
    event_col="hadSurvivalEvent", 
    time_points=os_time_points, 
    title="OS"
)

In [None]:
pfs_models_to_plot = {
    "GradientBoosting": pfs_trained_models["GradientBoosting"],
    "CoxPH": pfs_trained_models["CoxPH"],
    "DeepSurv": pfs_trained_models["DeepSurv"],
    "RandomSurvivalForest": pfs_trained_models["RandomSurvivalForest"]
}

pfs_time_points = [int(round(i * 365 / 4)) for i in range(1, 13)] # Every 3 months (up to 3 years)

calculate_time_dependent_auc_for_models(
    pfs_models_to_plot, pfs_X_train, pfs_y_train, pfs_X_test, pfs_y_test, 
    duration_col="observedPfsDays", 
    event_col="hadProgressionEvent", 
    time_points=pfs_time_points, 
    title="PFS"
)

### Performance per Treatment
To assess model performance across treatment groups, we evaluate the metrics separately for each treatment. For every treatment column, the dataset is filtered to the patients who received that treatment, the model generates survival or risk predictions for them, and the metrics are calculated from their observed outcomes.

**Note on inflated metrics:**
Because the model is evaluated on the full dataset, which includes the training data, these metrics are likely inflated compared with an evaluation on unseen test data: the model has already learned from most of these patients during training.

**Why this is still useful:**
Despite the potential inflation, this approach remains useful for exploring the model's predictions across treatment groups. It can reveal patterns, strengths, or weaknesses in how the model handles specific treatments, guiding further refinements or highlighting treatment groups that need more robust predictions.
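Because the one-hot treatment columns are scaled, the subgroup filter first has to recover the original 0/1 flags, which is what `binarize_scaled_treatment_cols` below does. A sketch of the idea on toy data, assuming the scaling was increasing (e.g. standardization or min-max scaling), so that the larger of the two values corresponds to the original 1. The column name and `recover_binary` helper are hypothetical.

```python
import numpy as np
import pandas as pd

# Toy data: a one-hot treatment column standardized like a continuous feature
raw = pd.Series([0, 1, 1, 0, 1], name="systemicTreatmentPlan_A")
scaled = (raw - raw.mean()) / raw.std()

def recover_binary(col, atol=1e-8):
    """Map a scaled two-valued column back to 0/1 (larger value -> 1)."""
    values = np.unique(col)
    if values.size != 2:
        raise ValueError(f"{col.name} has {values.size} distinct values")
    flags = np.isclose(col, values.max(), atol=atol).astype(int)
    return pd.Series(flags, index=col.index, name=col.name)

flags = recover_binary(scaled)
subgroup_index = flags.index[flags == 1]  # patients who received treatment A
print(list(subgroup_index))  # [1, 2, 4]
```

Only the subgroup mask needs the recovered flags; if a model was trained on the scaled columns, it would presumably still need the scaled values as input.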


In [None]:
def binarize_scaled_treatment_cols(X: pd.DataFrame, col_prefix="systemicTreatmentPlan") -> pd.DataFrame:
    """
    Given an X DataFrame where certain columns named like
    'systemicTreatmentPlan_*' have been scaled but originally were binary (0/1),
    map them back to 0/1.
    """
    X = X.copy()
    treatment_cols = [c for c in X.columns if col_prefix in c]
    
    for col in treatment_cols:
        unique_vals = sorted(X[col].unique())
        if len(unique_vals) == 2:
            low_val, high_val = unique_vals
            X[col] = X[col].apply(
                lambda x: 1 if abs(x - high_val) < 1e-8 else 0
            )
        else:
            print(f"Warning: {col} has {len(unique_vals)} distinct values, cannot binarize cleanly.")
    
    return X

def evaluate_model_by_treatment(
    model, df, duration_col, event_col, model_name, col_prefix='systemicTreatmentPlan'
):
    trainer = ModelTrainer(models={}, n_splits=5, random_state=42)

    treatment_cols = [col for col in df.columns if col_prefix in col]
    results_list = []

    y_df = pd.DataFrame({
        "duration": df[duration_col],
        "event": df[event_col].astype(bool)
    }, index=df.index)
    y_struct = Surv.from_dataframe("event", "duration", y_df)

    for t_col in treatment_cols:
        X_sub = df[df[t_col] == 1].drop(columns=[event_col, duration_col])
        y_sub = y_df.loc[X_sub.index]
        
        y_sub_struct = Surv.from_dataframe("event", "duration", y_sub)

        metrics = trainer._evaluate_model(
            model=model,
            X_val=X_sub,
            y_train_structured=y_struct, 
            y_val_structured=y_sub_struct,
            model_name=model_name
        )

        results_list.append({
            "treatment_column": t_col,
            "patient_count": len(X_sub),
            **metrics
        })

    return pd.DataFrame(results_list)

In [None]:
os_models_to_evaluate = {
    "GradientBoosting": os_trained_models["GradientBoosting"],
    "RandomSurvivalForest": os_trained_models["RandomSurvivalForest"]
}

os_df = binarize_scaled_treatment_cols(os_df)

for model_name, model_instance in os_models_to_evaluate.items():
    results_df = evaluate_model_by_treatment(
        model=model_instance,
        df=os_df,
        duration_col='observedOsFromTreatmentStartDays',
        event_col='hadSurvivalEvent',
        model_name=model_name
    )
    print(f"Metrics per treatment for {model_name}:\n", results_df)


In [None]:
pfs_models_to_evaluate = {
    "GradientBoosting": pfs_trained_models["GradientBoosting"],
    "RandomSurvivalForest": pfs_trained_models["RandomSurvivalForest"]
}

pfs_df = binarize_scaled_treatment_cols(pfs_df)

for model_name, model_instance in pfs_models_to_evaluate.items():
    results_df = evaluate_model_by_treatment(
        model=model_instance,
        df=pfs_df,
        duration_col='observedPfsDays',
        event_col='hadProgressionEvent',
        model_name=model_name
    )
    print(f"Metrics per treatment for {model_name}:\n", results_df)


## Model Interpretation

### Feature Importance

Feature importance analysis identifies which features most strongly influence the survival predictions of each model. In this section, SHAP (SHapley Additive exPlanations) is used for all models so that feature importance is computed the same way across model types. SHAP values provide a consistent, locally accurate attribution of each individual prediction to the input features.

How SHAP is used:
- For classical models like CoxPH and Aalen Additive, SHAP values are calculated from the models' risk scores. For Aalen Additive, a custom prediction function computes the risk score from the cumulative regression coefficients at the last time point, so that the explained output matches the model's structure.
- For tree-based models like Random Survival Forest (RSF) and Gradient Boosting Survival Model (GBM), SHAP values replace traditional feature importance metrics to ensure consistency.
- For neural network-based models like DeepSurv and DeepHit, SHAP values are derived using the model's prediction function.
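To show what quantity the SHAP explainers estimate, here is a brute-force computation of exact Shapley values for a small model, with absent features filled in from background rows and the result averaged. The `shap` library uses faster approximations (for a plain prediction function, typically permutation sampling), so this is an illustration of the definition, not of its internals. For a linear risk score, the Shapley value of feature *j* reduces to `w_j * (x_j - mean_j)`, which the usage example checks; `exact_shapley` and the toy weights are hypothetical.

```python
import itertools
import math
import numpy as np

def exact_shapley(f, x, background):
    """Exact Shapley values of f at point x (features absent from a coalition
    take their values from each background row; outputs are averaged)."""
    x = np.asarray(x, dtype=float)
    background = np.asarray(background, dtype=float)
    m = x.shape[0]

    def value(subset):
        data = background.copy()
        idx = list(subset)
        if idx:
            data[:, idx] = x[idx]  # features in the coalition take x's values
        return f(data).mean()

    phi = np.zeros(m)
    for j in range(m):
        others = [k for k in range(m) if k != j]
        for size in range(m):
            weight = math.factorial(size) * math.factorial(m - size - 1) / math.factorial(m)
            for subset in itertools.combinations(others, size):
                phi[j] += weight * (value(subset + (j,)) - value(subset))
    return phi

rng = np.random.default_rng(0)
background = rng.normal(size=(50, 3))
weights = np.array([0.5, -1.0, 2.0])

def linear_risk(X):
    return X @ weights  # stand-in for a linear risk score

x = np.array([1.0, 0.0, -1.0])
phi = exact_shapley(linear_risk, x, background)
print(np.allclose(phi, weights * (x - background.mean(axis=0))))  # True
```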

#### Visualization
SHAP provides the following insights:
- Summary Plot - Bar: Displays the average magnitude of SHAP values for each feature, indicating the overall importance of features in the model.
- Summary Plot - Dot: Highlights the distribution of SHAP values for each feature, showing their impact across different samples.
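The bar summary plot ranks features by their mean absolute SHAP value. A sketch of that aggregation, assuming SHAP values are stored as `(n_samples, n_features)` or, for multi-output models such as discrete-time networks, `(n_samples, n_features, n_outputs)`. In the 3D case it averages over the output axis first, keeping one value per sample and feature; averaging signed values can cancel across outputs, so averaging absolute values per output is a reasonable alternative. Feature names are hypothetical.

```python
import numpy as np

def mean_abs_shap(shap_values, feature_names):
    """Global importance as in a SHAP bar plot: mean |SHAP| per feature, sorted."""
    values = np.asarray(shap_values, dtype=float)
    if values.ndim == 3:
        values = values.mean(axis=-1)  # average over outputs, not over features
    importance = np.abs(values).mean(axis=0)
    order = np.argsort(importance)[::-1]
    return [(feature_names[i], float(importance[i])) for i in order]

toy = np.array([[0.2, -1.0, 0.0],
                [-0.4, 0.5, 0.1]])
print(mean_abs_shap(toy, ["age", "stage", "ecog"]))
```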


In [None]:
import shap

def nn_predict(X, model, X_train):
    if isinstance(X, np.ndarray):
        X = pd.DataFrame(X, columns=X_train.columns)
    X_tensor = X.values.astype('float32')
    return model.model.predict(X_tensor)

def custom_aalen_predict(X, model):
    """
    Custom predict function for AalenAdditiveModel.
    Aligns cumulative hazard coefficients with input features.
    """
    cumulative_coefs = model.model.cumulative_hazards_
    X = X[model.selected_features].copy()

    # Use the cumulative regression coefficients at the last observed time point
    latest_coefs = cumulative_coefs.iloc[-1].values

    # Add an intercept column if the model has one more coefficient than features
    if len(latest_coefs) > X.shape[1]:
        X.insert(0, "Intercept", 1.0)

    # Risk score per patient: dot product of features and coefficients
    X_array = X.values
    risk_scores = np.einsum('ij,j->i', X_array, latest_coefs)
    return risk_scores

def shap_interpret_model(model_name, model, X_train, max_features=20, shap_sample=200):

    X_sample = X_train.sample(min(shap_sample, len(X_train)), random_state=42)

    predict_functions = {
        'AalenAdditive': lambda X: custom_aalen_predict(X, model),
        'default': model.predict
    }

    if model_name == 'AalenAdditive':
        prediction_fn = predict_functions['AalenAdditive']
    else:
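        # Models whose predict works directly on DataFrames are used as-is;
        # otherwise fall back to the neural network prediction wrapper
        try:
            model.predict(X_sample.head(1))
            prediction_fn = predict_functions['default']
        except Exception:
            prediction_fn = lambda x: nn_predict(x, model, X_train)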

    explainer = shap.Explainer(prediction_fn, X_sample)
    shap_values = explainer(X_sample)

    print(f"SHAP Summary for {model_name}:")

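    if len(shap_values.values.shape) == 3:  # multi-output model, shape (samples, features, outputs)
        # Average over the output axis (last), keeping one value per sample and feature
        aggregated_shap = shap_values.values.mean(axis=-1)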
        shap.summary_plot(aggregated_shap, features=X_sample, plot_type="bar", max_display=max_features)
        shap.summary_plot(aggregated_shap, features=X_sample, max_display=max_features)
    else:
        shap.summary_plot(shap_values, features=X_sample, plot_type="bar", max_display=max_features)
        shap.summary_plot(shap_values, features=X_sample, max_display=max_features)


In [None]:
for model_name, model_instance in os_trained_models.items():
    print(f"\n--- Interpreting {model_name} ---")
    shap_interpret_model(model_name, model_instance, os_X_train)

In [None]:
for model_name, model_instance in pfs_trained_models.items():
    print(f"\n--- Interpreting {model_name} ---")
    shap_interpret_model(model_name, model_instance, pfs_X_train)

### Model Output

#### Comparison of Treatments
In this section, the predicted survival probabilities for a single patient are visualized under different treatment scenarios. By simulating the patient receiving each available treatment, we can compare the survival trajectory the model predicts under each treatment.

This shows how the model's predictions vary across treatment options and which treatments it associates with higher predicted survival for this patient. Because the models are trained on observational data, these differences reflect learned associations rather than estimated causal treatment effects.
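The treatment simulation amounts to building one copy of the patient's feature row per treatment, with exactly that treatment's one-hot column switched on. A sketch assuming the treatment columns hold raw 0/1 values; the column names and the `counterfactual_treatment_rows` helper are hypothetical. All scenario rows can then be passed to the model in a single prediction call.

```python
import pandas as pd

def counterfactual_treatment_rows(patient_row, treatment_cols):
    """One copy of a single-row DataFrame per treatment, with that treatment's
    one-hot column set to 1 and all other treatment columns set to 0."""
    rows = []
    for t_col in treatment_cols:
        row = patient_row.copy()
        row[treatment_cols] = 0
        row[t_col] = 1
        rows.append(row)
    out = pd.concat(rows)
    out.index = treatment_cols  # label each scenario by its treatment column
    return out

patient = pd.DataFrame({"age": [67], "plan_A": [1], "plan_B": [0], "plan_C": [0]})
scenarios = counterfactual_treatment_rows(patient, ["plan_A", "plan_B", "plan_C"])
print(scenarios)
```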

In [None]:
import numpy as np
import matplotlib.pyplot as plt

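def plot_survival_curves_for_all_treatments(model, X_test, y_test, patient_index, event_col, duration_col, col_prefix='systemicTreatmentPlan', plot_title='Survival Curves by Treatment'):
    """
    Plot one patient's predicted survival curve under each treatment by switching
    the one-hot treatment columns to each treatment in turn.
    """
    X_patient_original = X_test.iloc[[patient_index]].copy()
    actual_duration = y_test.loc[X_patient_original.index, duration_col].values[0]
    actual_event = y_test.loc[X_patient_original.index, event_col].values[0]

    treatment_cols = [col for col in X_test.columns if col_prefix in col]

    colors = plt.cm.tab20(np.linspace(0, 1, len(treatment_cols)))
    plt.figure(figsize=(10, 6))

    max_days = 1825  # plot up to 5 years

    for i, t_col in enumerate(treatment_cols):
        X_patient = X_patient_original.copy()

        # Switch off all treatments, then switch on the current one
        X_patient[treatment_cols] = 0
        X_patient[t_col] = 1

        surv_fn = model.predict_survival_function(X_patient)[0]

        # Evaluate only inside the step function's domain (scikit-survival step functions raise an error outside it)
        times = np.linspace(surv_fn.x[0], min(surv_fn.x[-1], max_days), 100)
        plt.step(times, surv_fn(times), where="post", label=t_col[len(col_prefix) + 1:], color=colors[i])

    # Mark the patient's observed event or censoring time
    plt.axvline(x=actual_duration, color='red' if actual_event else 'blue', linestyle='--',
                label="Event Time" if actual_event else "Censoring Time")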

    plt.title(plot_title)
    plt.xlabel("Time (days)")
    plt.ylabel("Survival Probability")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    plt.show()


In [None]:
os_best_model = "RandomSurvivalForest"
plot_survival_curves_for_all_treatments(
    model=os_trained_models[os_best_model],
    X_test=os_X_test,
    y_test=os_y_test,
    patient_index=100,
    event_col="hadSurvivalEvent",
    duration_col="observedOsFromTreatmentStartDays",
    col_prefix="systemicTreatmentPlan",
    plot_title=f"OS: One Patient, All Treatments ({os_best_model})"
)


In [None]:
pfs_best_model = "GradientBoosting"
plot_survival_curves_for_all_treatments(
    model=pfs_trained_models[pfs_best_model],
    X_test=pfs_X_test,
    y_test=pfs_y_test,
    patient_index=100,
    event_col="hadProgressionEvent",
    duration_col="observedPfsDays",
    col_prefix="systemicTreatmentPlan",
    plot_title=f"PFS: One Patient, All Treatments ({pfs_best_model})"
)

#### Comparison of Models
Next, we focus on the patient's actual treatment choice and compare the predictions from the two best-performing models. By evaluating the survival curves generated by both models for the chosen treatment, we assess their agreement and gain a deeper understanding of the patient-specific predictions.
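Agreement between the two models can also be summarized numerically, for instance by the largest and average absolute difference between their survival curves on the shared time range. A sketch under that assumption (the metric and the `curve_agreement` helper are hypothetical choices). For scikit-survival step functions, the time points and probabilities are available as `fn.x` and `fn.y`.

```python
import numpy as np

def curve_agreement(times_a, surv_a, times_b, surv_b, n_grid=200):
    """Max and mean absolute difference between two right-continuous step
    survival curves, evaluated on their overlapping time range."""
    lo = max(times_a[0], times_b[0])
    hi = min(times_a[-1], times_b[-1])
    grid = np.linspace(lo, hi, n_grid)

    def step(times, surv, t):
        # Value of the last step at or before t
        idx = np.searchsorted(times, t, side="right") - 1
        return np.asarray(surv)[np.clip(idx, 0, None)]

    diff = np.abs(step(times_a, surv_a, grid) - step(times_b, surv_b, grid))
    return float(diff.max()), float(diff.mean())

t = np.array([0.0, 100.0, 200.0, 300.0])
s1 = np.array([1.0, 0.8, 0.6, 0.4])
s2 = np.array([1.0, 0.7, 0.6, 0.5])
max_diff, mean_diff = curve_agreement(t, s1, t, s2)
print(round(max_diff, 3))  # 0.1
```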

In [None]:
plot_survival_curves_for_patient(
    trained_models=os_models_to_evaluate,
    X_test=os_X_test,
    y_test=os_y_test,
    patient_index=100,
    event_col="hadSurvivalEvent",
    duration_col="observedOsFromTreatmentStartDays",
    plot_title="OS (GradientBoosting & RSF)"
)

In [None]:
plot_survival_curves_for_patient(
    trained_models=pfs_models_to_evaluate,
    X_test=pfs_X_test,
    y_test=pfs_y_test,
    patient_index=100,
    event_col="hadProgressionEvent",
    duration_col="observedPfsDays",
    plot_title="PFS (GradientBoosting & RSF)"
)

## Feature Selection

Explicit feature selection is applied to the CoxPH and Aalen Additive models to improve interpretability and reduce noise:

- `CoxPH`: 
    - Features whose absolute coefficient is below a threshold (0.01 by default) are removed.
- `Aalen Additive`:
    - Features with low cumulative impact (mean absolute cumulative coefficient below 0.001 by default) are excluded.

Both models are then refitted on the retained features and evaluated on the hold-out set. Other models, such as tree-based or neural survival models, handle uninformative features through their architecture or regularization, so explicit feature filtering is less necessary for them.
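The two selection rules can be sketched independently of the fitted models, assuming the coefficients are available as a name-to-value mapping (CoxPH) and as a time-by-feature table of cumulative coefficients (Aalen Additive). The helper names and toy values are hypothetical; as in `feature_select_coxph`, the coefficient rule falls back to all features if none pass.

```python
import pandas as pd

def select_by_abs_coef(coefs, threshold):
    """Keep features whose absolute coefficient exceeds threshold (all if none pass)."""
    coefs = pd.Series(coefs)
    kept = coefs.index[coefs.abs() > threshold]
    return list(kept) if len(kept) else list(coefs.index)

def select_by_mean_abs_cumulative(cumulative_coefs, threshold):
    """Keep columns of a (time x feature) table whose mean |value| over time exceeds threshold."""
    mean_abs = cumulative_coefs.abs().mean()
    return list(mean_abs.index[mean_abs > threshold])

cox_coefs = {"age": 0.30, "stage": 0.005, "ecog": -0.12}  # toy values
print(select_by_abs_coef(cox_coefs, 0.01))  # ['age', 'ecog']

cum = pd.DataFrame({"age": [0.0, 0.01, 0.02], "stage": [0.0, 0.0001, 0.0002]})
print(select_by_mean_abs_cumulative(cum, 0.001))  # ['age']
```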


In [None]:
def feature_select_coxph(model, X_train, threshold=0.01):
    """
    For CoxPH: Remove features with abs(coef) < threshold.
    """
    if hasattr(model.model, 'coef_'):
        coefs = model.model.coef_
        feature_mask = np.abs(coefs) > threshold
        retained = model.selected_features[feature_mask]
        if len(retained) == 0:
            retained = model.selected_features
        return retained
    else:
        return model.selected_features

def feature_select_aalen_additive(model, X_train, threshold=0.001):
    """
    For AalenAdditive: Remove features with mean absolute cumulative coefficient < threshold.
    """
    if not model.model.cumulative_hazards_.empty:
        cum_haz = model.model.cumulative_hazards_
        mean_abs_coefs = cum_haz.abs().mean()
        feature_mask = mean_abs_coefs > threshold
        retained = mean_abs_coefs.index[feature_mask]
        
        return retained
    else:
        return model.selected_features

In [None]:
def refit_model_with_selected_features(model_name, original_model, X_train, y_train, X_test, y_test, retained_features, title, duration_col, event_col, save_models=True):
    """
    Refit the given model with the selected features and evaluate it using the ModelTrainer's _evaluate_model method.
    """
    
    retained_features = [f for f in retained_features if f != "Intercept"]

    y_train_df = pd.DataFrame({'duration': y_train[duration_col], 'event': y_train[event_col]}, index=X_train.index)
    y_train_structured = Surv.from_dataframe('event', 'duration', y_train_df)

    y_test_df = pd.DataFrame({'duration': y_test[duration_col], 'event': y_test[event_col]}, index=y_test.index)
    y_test_structured = Surv.from_dataframe('event', 'duration', y_test_df)

    # Refit the model with the reduced features
    model_class = type(original_model)
    model_kwargs = getattr(original_model, 'kwargs', {})
    new_model = model_class(**model_kwargs)
    new_model.fit(X_train[retained_features], y_train_structured)

    trainer = ModelTrainer(models={}, n_splits=5, random_state=42)
    holdout_metrics = trainer._evaluate_model(
        new_model,
        X_test[retained_features],
        y_train_structured,
        y_test_structured,
        y_test_df,
        model_name,
        event_col
    )
    print(f"{model_name} Feature-Selected Hold-Out Results: {holdout_metrics}")

    if save_models:
        save_new_model(new_model, model_name, title)

    return new_model

def save_new_model(model, model_name, title, suffix="_feature_selected", save_path="src/models/trained_models"):
    new_model_name = model_name + suffix
    model_file = os.path.join(save_path, f"{title}_{new_model_name}")
    with open(model_file + ".pkl", "wb") as f:
        dill.dump(model, f)
    print(f"New model with feature selection saved as {title}_{new_model_name}.pkl")


In [None]:
def select_features_and_refit(X_train, y_train, X_test, y_test, configs, duration_col, event_col, title):
    coxph_model = load_trained_model("CoxPH", title, CoxPHModel, model_kwargs=configs['CoxPH'][1])
    aalen_model = load_trained_model("AalenAdditive", title, AalenAdditiveModel, model_kwargs=configs['AalenAdditive'][1])
    
    coxph_retained = feature_select_coxph(coxph_model, X_train, threshold=0.01)
    aalen_retained = feature_select_aalen_additive(aalen_model, X_train, threshold=0.001)

    new_coxph_model = refit_model_with_selected_features(
        "CoxPH", coxph_model, X_train, y_train, X_test, y_test,
        coxph_retained, title, duration_col=duration_col, event_col=event_col, save_models=True
    )

    new_aalen_model = refit_model_with_selected_features(
        "AalenAdditive", aalen_model, X_train, y_train, X_test, y_test,
        aalen_retained, title, duration_col=duration_col, event_col=event_col, save_models=True
    )

    return new_coxph_model, new_aalen_model

In [None]:
select_features_and_refit(os_X_train, os_y_train, os_X_test, os_y_test, os_configs, 'observedOsFromTreatmentStartDays', 'hadSurvivalEvent', title='OS')

In [None]:
select_features_and_refit(pfs_X_train, pfs_y_train, pfs_X_test, pfs_y_test, pfs_configs, 'observedPfsDays', 'hadProgressionEvent', title='PFS')

## Hyperparameter Optimization

Hyperparameter optimization is performed for each model over a predefined parameter grid (`param_grids`). The search is implemented in `models/hyperparameter_optimization` (`random_parameter_search`) and is called here through `hyperparameter_search`, which samples configurations from the grids and selects the best configuration per model according to `metric_comparison` (here the time-dependent AUC). This tunes each model to the given data.

After optimization, the best configurations for both OS and PFS were stored in `models/configs/model_configurations.py`.
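The internals of `hyperparameter_search` are not shown in this notebook. As a generic illustration, a random search over a parameter grid could look like the sketch below; `random_search`, the grid, and the toy scoring function are hypothetical, and in practice `score_fn` would fit a model and return a validation metric such as the time-dependent AUC.

```python
import itertools
import random

def random_search(param_grid, score_fn, n_iter=10, seed=42, higher_is_better=True):
    """Sample up to n_iter distinct configurations from a grid and return the best one.

    param_grid maps parameter names to lists of candidate values;
    score_fn(config) returns a validation score for one configuration.
    """
    names = sorted(param_grid)
    all_configs = [dict(zip(names, values)) for values in itertools.product(*(param_grid[n] for n in names))]
    rng = random.Random(seed)
    sampled = rng.sample(all_configs, k=min(n_iter, len(all_configs)))
    scored = [(score_fn(cfg), cfg) for cfg in sampled]
    pick = max if higher_is_better else min
    best_score, best_config = pick(scored, key=lambda pair: pair[0])
    return best_config, best_score, scored

grid = {"learning_rate": [0.01, 0.1, 0.3], "max_depth": [2, 3, 5]}  # hypothetical grid

def toy_score(cfg):
    # Toy objective peaking at learning_rate=0.1, max_depth=3
    return -abs(cfg["learning_rate"] - 0.1) - abs(cfg["max_depth"] - 3)

best, score, history = random_search(grid, toy_score, n_iter=9)
print(best)  # {'learning_rate': 0.1, 'max_depth': 3}
```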

In [None]:
def optimize_hyperparameters(query, event_col, duration_col, features, metric_comparison):
    
    # get_data also returns the full DataFrame, which is not needed here
    _, X_train, X_test, y_train, y_test, encoded_columns = get_data(query, event_col, duration_col, features)
          
    models = {
        'DeepSurv': DeepSurv(input_size=X_train.shape[1]),
        'LogisticHazardModel': LogisticHazardModel(input_size=X_train.shape[1]),
        'DeepHitModel': DeepHitModel(input_size=X_train.shape[1]), 
        'PCHazardModel': PCHazardModel(input_size=X_train.shape[1]), 
        'MTLRModel': MTLRModel(input_size=X_train.shape[1]),
        'AalenAdditive': AalenAdditiveModel(),
        'CoxPH': CoxPHModel(),
        'RandomSurvivalForest': RandomSurvivalForestModel(),
        'GradientBoosting': GradientBoostingSurvivalModel(),
    }
    
    best_models, all_results = hyperparameter_search(
        X_train, y_train, X_test, y_test,
        treatment_col='systemicTreatmentPlan', encoded_columns=encoded_columns,
        event_col=event_col, duration_col=duration_col,
        base_models=models, param_grids=param_grids, metric_comparison=metric_comparison
    )
       
    return best_models, all_results         

In [None]:
os_best_models, os_results = optimize_hyperparameters(query, event_col = 'hadSurvivalEvent', duration_col = 'observedOsFromTreatmentStartDays', features = features, metric_comparison='auc')

In [None]:
pfs_best_models, pfs_results = optimize_hyperparameters(query, event_col = 'hadProgressionEvent', duration_col = 'observedPfsDays', features=features, metric_comparison='auc')