# Predictive Algorithms for Survival Analysis

This notebook demonstrates the pipeline for developing and evaluating predictive algorithms in survival analysis. The primary objective is to model and predict **overall survival (OS)** and **progression-free survival (PFS)** for patients. Using both classical statistical methods and state-of-the-art deep learning techniques, the notebook covers the entire process, including:
- Data Preprocessing: Preparing survival datasets for analysis, ensuring compatibility with various model types.
- Model Training: Building survival models tailored to predict survival outcomes and handle censored data.
- Hyperparameter Optimization: Fine-tuning models for optimal performance.

The interpretation and evaluation of the models are covered in `predictive_algorithms_interpretation.ipynb`, which contains:
- Performance Evaluation: Comparing models based on metrics such as concordance index (C-Index), integrated Brier score (IBS), calibration error (CE), and time-dependent AUC.
- Visualization: Generating survival curves and feature importance plots to interpret model predictions and uncover key insights.

This workflow provides a framework to explore survival modeling techniques and tailor them to specific datasets and objectives.

All experiment settings (e.g. OS or PFS, grouped or ungrouped treatments) are configured in `utils/settings.py`. The experiment is then run from this notebook.
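The notebook reads a handful of attributes from `settings`. Below is a minimal sketch of what such a settings object could look like, written as a plain dataclass. The attribute names are the ones used in this notebook. The class name `ExperimentSettings`, every default value, and the `OS`/`PFS` check are hypothetical placeholders and do not reproduce the contents of `utils/settings.py`. The grouped-treatments option is left out because its attribute name does not appear in this notebook.

```python
from dataclasses import dataclass


@dataclass
class ExperimentSettings:
    """Hypothetical stand-in for the object exported by utils/settings.py.

    Only the attribute names come from this notebook; every default is a placeholder.
    """
    outcome: str = "OS"                     # "OS" or "PFS"
    event_col: str = "event"                # placeholder column name
    duration_col: str = "duration"          # placeholder column name (days)
    max_time: int = 1825                    # days; matches the default in train_evaluate_models
    save_models: bool = False
    save_path: str = "models/trained_models"
    db_config_path: str = "path/to/db_config"   # placeholder
    db_name: str = "example_db"                 # placeholder
    json_config_file: str = "models/configs/model_hyperparams.json"

    def __post_init__(self):
        # Fail early on an unsupported outcome instead of deep inside the pipeline.
        if self.outcome not in ("OS", "PFS"):
            raise ValueError(f"outcome must be 'OS' or 'PFS', got {self.outcome!r}")


settings = ExperimentSettings(outcome="PFS")
```

Keeping the settings in one validated object means a typo in the outcome name is caught when the settings are created rather than when a column lookup fails later.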

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
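import os
import sys

# OpenMP/MKL thread limits should be set before torch (or numpy) is imported,
# because the native runtimes read these variables when they initialise.
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "4"

import torch

torch.set_num_threads(4)          # intra-op parallelism
torch.set_num_interop_threads(4)  # inter-op parallelism; must be called before any inter-op work starts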

In [None]:
os.chdir('/data/repos/actin-personalization/prediction')
sys.path.insert(0, os.path.abspath("src/main/python"))

from models import *
from utils.settings import settings
from data.data_processing import DataSplitter, DataPreprocessor

## Data Preprocessing

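In this section, we set up the data pipeline for survival analysis. The `DataSplitter` and `DataPreprocessor` classes from `data/data_processing.py` load and preprocess the data and split it into training and test sets (`test_size=0.1`, i.e. a 90/10 split, with a fixed `random_state` for reproducibility). The survival target is encoded with scikit-survival's `Surv.from_dataframe` as a structured array of event indicators and durations, so that censored observations are represented explicitly.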
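`get_data` builds its target with scikit-survival's `Surv.from_dataframe`, which returns a NumPy structured array with a boolean event field and a float time field named after the columns passed in. That is why later cells can index it as `y_test[i][settings.duration_col]`. The sketch below builds an array of the same shape directly with NumPy. The helper `make_survival_target` and the field names `event` and `duration_days` are hypothetical; in the notebook the names come from `settings.event_col` and `settings.duration_col`.

```python
import numpy as np

# Hypothetical field names; the notebook takes them from settings.event_col / settings.duration_col.
EVENT_COL = "event"
DURATION_COL = "duration_days"


def make_survival_target(events, durations):
    """Build a structured (event, duration) array shaped like Surv.from_dataframe output."""
    events = np.asarray(events, dtype=bool)        # True = event observed, False = censored
    durations = np.asarray(durations, dtype=float)
    if events.shape != durations.shape:
        raise ValueError("events and durations must have the same length")
    y = np.empty(len(events), dtype=[(EVENT_COL, bool), (DURATION_COL, float)])
    y[EVENT_COL] = events
    y[DURATION_COL] = durations
    return y


y = make_survival_target([1, 0, 1], [120, 365, 30])
# Patient 1 is censored: followed for 365 days without an observed event.
print(y[1][DURATION_COL], y[1][EVENT_COL])
```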


In [None]:
from sksurv.util import Surv

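def get_data():
    """Load and preprocess the survival data and return a train/test split.

    Returns the full dataframe, the split features and targets, and the encoded columns.
    """
    preprocessor = DataPreprocessor(settings.db_config_path, settings.db_name)
    splitter = DataSplitter(test_size=0.1, random_state=42)  # 90/10 split, reproducible

    df, features, encoded_columns = preprocessor.preprocess_data()

    # Structured array of (event indicator, duration) pairs, as used by scikit-survival.
    y = Surv.from_dataframe(event=settings.event_col, time=settings.duration_col, data=df)

    X_train, X_test, y_train, y_test = splitter.split(df[features], y, encoded_columns)

    return df, X_train, X_test, y_train, y_test, encoded_columns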

In [None]:
df, X_train, X_test, y_train, y_test, encoded_columns = get_data()

## Train and Evaluate Models

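This section defines the function `train_evaluate_models`, which trains various survival models using predefined configurations (all of which can be found in `models/configs/...`). The trained models are evaluated with the C-index, integrated Brier score (IBS), calibration error (CE) and time-dependent area under the ROC curve (AUC), as explained in `predictive_algorithms_interpretation.ipynb`.

Together, these metrics cover complementary aspects of predictive performance: the C-index and time-dependent AUC measure discrimination (whether patients with higher predicted risk experience the event earlier), CE measures calibration (whether predicted survival probabilities match observed rates), and the IBS summarizes the overall accuracy of the predicted survival curves over time, reflecting both discrimination and calibration.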
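To make the C-index concrete, here is a minimal NumPy sketch of Harrell's concordance index for right-censored data. It assumes that a higher score means higher risk. It is an illustration of the metric, not the implementation used by `ModelTrainer`. Pairs with tied times are skipped, which is a simplification, so results can differ slightly from scikit-survival's `concordance_index_censored` when there are ties in time.

```python
import numpy as np


def harrell_c_index(event, time, risk):
    """Harrell's C-index for right-censored data (simplified tie handling).

    A pair (i, j) is comparable when patient i had an observed event and
    time[i] < time[j]; it is concordant when i was assigned the higher risk.
    Ties in predicted risk count as half-concordant.
    """
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    risk = np.asarray(risk, dtype=float)

    concordant, comparable = 0.0, 0
    for i in np.flatnonzero(event):          # only observed events anchor a pair
        later = time > time[i]               # patients still event-free after i's event
        comparable += int(later.sum())
        concordant += np.sum(risk[i] > risk[later]) + 0.5 * np.sum(risk[i] == risk[later])
    return concordant / comparable if comparable else float("nan")


# Risk ordering that matches the observed event order exactly.
event = [1, 1, 0, 1]
time = [5, 10, 12, 3]
print(harrell_c_index(event, time, [0.8, 0.4, 0.1, 0.9]))
```

A value of 1.0 means perfect ranking, 0.5 is what a constant (uninformative) score gives, and 0.0 means the ranking is exactly reversed. Censored patients only contribute as the later member of a pair, which is how the index accounts for censoring.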


In [None]:
import matplotlib.pyplot as plt
import numpy as np

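def plot_different_models_survival_curves(trained_models, X_test, y_test, patient_index, actual_line=True):
    """
    Plot every trained model's predicted survival curve for one test patient.

    If `actual_line` is True, a dashed vertical line marks the patient's observed
    event time (red) or censoring time (blue). Time is converted from days to months.
    """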
    X_patient = X_test.iloc[[patient_index]]
    actual_duration_days = y_test[patient_index][settings.duration_col]
    actual_event = y_test[patient_index][settings.event_col]

    plt.figure(figsize=(12, 8))
    for model_name, model in trained_models.items():
        try:
            surv_funcs = model.predict_survival_function(X_patient)

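            # Evaluate all curves on a shared grid spanning the range where every curve is defined.
            times = np.linspace(max(fn.x[0] for fn in surv_funcs), min(fn.x[-1] for fn in surv_funcs), 100)
            surv_probs = np.vstack([fn(times) for fn in surv_funcs])  # np.row_stack is deprecated since NumPy 2.0

            plt.step(times / 30.44, surv_probs[0], where="post", label=model_name)  # days -> months (30.44 days/month)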
        except Exception as e:
            print(f"Error plotting survival curves for model {model_name}: {e}")

    if actual_line:
        marker_color = 'red' if actual_event else 'blue'
        marker_label = "Event Time" if actual_event else "Censoring Time"
        
        plt.axvline(x=actual_duration_days / 30.44, color=marker_color, linestyle='--', label=marker_label)

    plt.title(f"Predicted {settings.outcome} Curves")
    plt.xlabel("Time (months)")
    plt.ylabel("Survival Probability")
    plt.legend(loc="best")
    plt.grid(True)
    plt.show()
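The plotting cell above evaluates each predicted survival function on one shared time grid, running from the latest start to the earliest end of the curves so that every curve is defined at every grid point, and draws it with `where="post"`. The sketch below shows those two ideas in plain NumPy. The helpers `eval_step` and `common_grid` are hypothetical names, and each curve is assumed to be given as sorted jump times `x` with values `y`.

```python
import numpy as np


def eval_step(x, y, t):
    """Evaluate a right-continuous step function at times t.

    The value y[k] holds on [x[k], x[k+1]), matching plt.step(..., where="post").
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    t = np.asarray(t, dtype=float)
    if np.any(t < x[0]) or np.any(t > x[-1]):
        raise ValueError("t lies outside the step function's domain")
    idx = np.searchsorted(x, t, side="right") - 1   # index of the last jump at or before t
    return y[idx]


def common_grid(domains, n=100):
    """Grid on which every curve is defined: from the latest start to the earliest end."""
    lo = max(d[0] for d in domains)
    hi = min(d[-1] for d in domains)
    return np.linspace(lo, hi, n)


# Survival drops from 1.0 to 0.8 at day 10 and to 0.5 at day 20.
print(eval_step([0, 10, 20], [1.0, 0.8, 0.5], [0, 5, 10, 15, 20]))
```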

In [None]:
import json

import pandas as pd

def save_model_output(results_df):
    os.makedirs(settings.save_path, exist_ok=True)

    csv_file = os.path.join(settings.save_path, f"{settings.outcome}_model_outcomes.csv")
    if os.path.exists(csv_file):
        existing_df = pd.read_csv(csv_file)
        merged_df = pd.concat([existing_df, results_df]).drop_duplicates(
            subset=['Model'], keep='last'
        )
        merged_df.to_csv(csv_file, index=False)
        print(f"Updated model outcomes saved to {csv_file}")
    else:
        results_df.to_csv(csv_file, index=False)
        print(f"Model outcomes saved to {csv_file}")

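def train_evaluate_models(configs, max_time=1825, patient_index=78):
    """Train every configured model, evaluate it on the hold-out set, optionally save
    the results, and plot the predicted survival curves for one test patient."""
    df, X_train, X_test, y_train, y_test, encoded_columns = get_data()
    if settings.save_models:
        # Make sure the target directory exists before writing the label encodings.
        os.makedirs(os.path.join(settings.save_path, "preprocessor"), exist_ok=True)
        with open(os.path.join(settings.save_path, "preprocessor", "label_encodings.json"), "w") as f:
            json.dump(encoded_columns, f)

    models = {}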
    for model_name, (model_class, model_kwargs) in configs.items():
        if issubclass(model_class, NNSurvivalModel):
            model_kwargs['input_size'] = X_train.shape[1]
        models[model_name] = model_class(**model_kwargs)
    
    trainer = ModelTrainer(models=models)
    
    results, trained_models = trainer.train_and_evaluate(
        X_train,
        y_train,
        X_test,
        y_test,
        encoded_columns=encoded_columns,
    )
    
    results_df = pd.DataFrame.from_dict(results, orient='index')
    results_df.reset_index(inplace=True)
    results_df.rename(columns={'index': 'Model'}, inplace=True)
    
    if settings.save_models:
        save_model_output(results_df)
        ExperimentConfig.update_model_hyperparams(configs)
        
    plot_different_models_survival_curves(trained_models, X_test, y_test, patient_index)
    
    return results_df, trained_models

## Hyperparameter Optimization

Hyperparameter optimization is performed for each model over a predefined grid of parameters. The `random_parameter_search` function (in `models/hyperparameter_optimization`) samples configurations from each grid and keeps the best-performing parameters for each model. Random search does not guarantee the global optimum, but it ensures that every model is compared at a tuned rather than a default configuration for the given data.

After optimization, the models and their results are stored in `models/trained_models`, and the optimal configurations in `models/configs/model_hyperparams.json`.
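The sampling idea can be sketched in a few lines of standard-library Python. This is a minimal illustration of random search over a grid, not the project's `random_parameter_search`. The function name `random_search` and the `evaluate` callback (a caller-supplied function returning a score where higher is better, e.g. a validation C-index) are hypothetical.

```python
import itertools
import random


def random_search(param_grid, evaluate, n_iter=10, seed=42):
    """Sample up to n_iter distinct configurations from a grid and keep the best one.

    param_grid: dict mapping parameter name -> list of candidate values.
    evaluate:   callable(config_dict) -> score, higher is better.
    """
    keys = sorted(param_grid)
    # Enumerate the full grid, then draw a reproducible random subset of it.
    all_configs = [dict(zip(keys, values)) for values in itertools.product(*(param_grid[k] for k in keys))]
    rng = random.Random(seed)
    sampled = rng.sample(all_configs, k=min(n_iter, len(all_configs)))

    results = [(evaluate(cfg), cfg) for cfg in sampled]
    best_score, best_config = max(results, key=lambda r: r[0])
    return best_config, best_score, results


grid = {"lr": [0.01, 0.001], "hidden": [32, 64]}
toy_score = lambda cfg: cfg["hidden"] / 1000 - abs(cfg["lr"] - 0.001)  # stand-in for a validation metric
best_config, best_score, results = random_search(grid, toy_score, n_iter=10)
print(best_config, best_score)
```

Sampling without replacement from the enumerated grid keeps the budget (`n_iter`) fixed regardless of grid size. For very large grids one would sample values per parameter instead of enumerating every combination.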

In [None]:
def optimize_hyperparameters():
    
    df, X_train, X_test, y_train, y_test, encoded_columns = get_data()
          
    models = {
        'DeepSurv': DeepSurv(input_size=X_train.shape[1], use_attention=False),
        'DeepSurv_attention': DeepSurv(input_size=X_train.shape[1], use_attention=True),
        
        'LogisticHazardModel': LogisticHazardModel(input_size=X_train.shape[1], use_attention=False),
        'LogisticHazardModel_attention': LogisticHazardModel(input_size=X_train.shape[1], use_attention=True),
 
        'DeepHitModel': DeepHitModel(input_size=X_train.shape[1], use_attention=False),
        'DeepHitModel_attention': DeepHitModel(input_size=X_train.shape[1], use_attention=True),
        
        'PCHazardModel': PCHazardModel(input_size=X_train.shape[1], use_attention=False), 
        'PCHazardModel_attention': PCHazardModel(input_size=X_train.shape[1], use_attention=True),
        
        'MTLRModel': MTLRModel(input_size=X_train.shape[1], use_attention=False),
        'MTLRModel_attention': MTLRModel(input_size=X_train.shape[1], use_attention=True),

        'CoxPH': CoxPHModel(),
        'RandomSurvivalForest': RandomSurvivalForestModel(),
        'GradientBoosting': GradientBoostingSurvivalModel(),
    }
    
    best_models, all_results = hyperparameter_search(
        X_train, y_train, X_test, y_test,
        encoded_columns=encoded_columns,
        base_models=models, param_grids=curve_param_grids)
           
    return best_models, all_results  

In [None]:
best_models, results = optimize_hyperparameters()

### Best Model Configurations

The best configurations found during optimization are stored in `models/configs/model_hyperparams.json` and are used to instantiate the models for training and evaluation.

Once you have trained or updated the models locally, you can upload the entire `trained_models` directory back to the bucket with:

`gsutil -m cp -r /data/patient_like_me/prediction/trained_models/ gs://actin-personalization-models-v1/`
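The loop in `train_evaluate_models` expects `configs` to map each model name to a `(model_class, kwargs)` pair, and it adds `input_size` for neural models. Below is a minimal sketch of that pattern, assuming a hypothetical JSON layout of `{model_name: kwargs}`. `MODEL_REGISTRY`, `load_configs`, `instantiate` and the stub classes are illustrative stand-ins, not the actual `ExperimentConfig` API or the real wrappers from `models`.

```python
import json


# Hypothetical stand-ins for the wrapper classes provided by `from models import *`.
class NNSurvivalModel:
    """Base class for neural models, which need the number of input features."""
    def __init__(self, input_size, **kwargs):
        self.input_size = input_size
        self.kwargs = kwargs


class DeepSurv(NNSurvivalModel):
    pass


class CoxPHModel:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


# Hypothetical registry mapping the names used in the JSON file to model classes.
MODEL_REGISTRY = {"CoxPH": CoxPHModel, "DeepSurv": DeepSurv}


def load_configs(json_text):
    """Turn {model_name: kwargs} JSON into {model_name: (model_class, kwargs)}."""
    raw = json.loads(json_text)
    return {name: (MODEL_REGISTRY[name], dict(kwargs)) for name, kwargs in raw.items()}


def instantiate(configs, n_features):
    """Instantiate every configured model, injecting input_size for neural models."""
    models = {}
    for name, (model_class, kwargs) in configs.items():
        kwargs = dict(kwargs)  # copy, so the loaded configs are not mutated
        if issubclass(model_class, NNSurvivalModel):
            kwargs["input_size"] = n_features
        models[name] = model_class(**kwargs)
    return models


configs = load_configs('{"CoxPH": {"alpha": 0.1}, "DeepSurv": {"dropout": 0.2}}')
models = instantiate(configs, n_features=12)
```

Unlike the notebook's loop, which writes `input_size` into `model_kwargs` in place, this sketch copies the kwargs so that re-running it on the same `configs` has no side effects.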


In [None]:
config = ExperimentConfig(settings.json_config_file)
configs = config.load_model_configs()

In [None]:
model_outcomes, trained_models = train_evaluate_models(configs=configs)
model_outcomes

## Feature Selection

Explicit feature selection is applied to the CoxPH and Aalen Additive models to improve interpretability and reduce noise:

- `CoxPH`: 
    - Features with high p-values (non-significant) are removed.
    - Multicollinearity is addressed by excluding highly correlated predictors.
    - Features whose absolute coefficient falls below a threshold are removed before refitting.
- `Aalen Additive`:
    - Features with low cumulative impact (mean absolute cumulative coefficient below a threshold) are excluded.
    
Other models, such as tree-based and neural survival models, perform feature selection implicitly through their architecture or regularization (e.g. split selection in trees, weight penalties in neural networks), so explicit feature filtering is not applied to them.
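The multicollinearity step listed for CoxPH can be illustrated with a simple greedy correlation filter. This is a minimal sketch, not necessarily how the CoxPH wrapper does it. The function name `drop_correlated` and the default threshold of 0.9 are hypothetical, and the filter assumes no constant columns, because their correlation is undefined.

```python
import numpy as np


def drop_correlated(X, names, threshold=0.9):
    """Greedy multicollinearity filter.

    Walks through the columns in order and keeps a feature only if its absolute
    Pearson correlation with every previously kept feature is <= threshold.
    Assumes no constant columns.
    """
    X = np.asarray(X, dtype=float)
    corr = np.abs(np.corrcoef(X, rowvar=False))  # features are columns
    kept = []
    for j in range(X.shape[1]):
        if all(corr[j, k] <= threshold for k in kept):
            kept.append(j)
    return [names[j] for j in kept]


# "age_scaled" is an exact linear function of "age" and is dropped; "score" (|r| = 0.8) is kept.
X = [[1, 3, 2], [2, 5, 1], [3, 7, 4], [4, 9, 3], [5, 11, 5]]
print(drop_correlated(X, ["age", "age_scaled", "score"], threshold=0.9))
```

Because the filter is greedy, column order decides which member of a correlated pair survives. Placing clinically preferred features first is a simple way to control that.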


In [None]:
def feature_select_coxph(model, X_train, threshold=0.01):
    """
    For CoxPH: remove features with abs(coef) < threshold.
    Falls back to all selected features if none pass the threshold.
    """
    # np.asarray allows boolean masking even if selected_features is a plain list.
    selected = np.asarray(model.selected_features)
    if not hasattr(model.model, 'coef_'):
        return selected
    feature_mask = np.abs(np.asarray(model.model.coef_)) >= threshold
    retained = selected[feature_mask]
    return retained if len(retained) > 0 else selected

def feature_select_aalen_additive(model, X_train, threshold=0.001):
    """
    For AalenAdditive: remove features with mean absolute cumulative coefficient < threshold.
    Falls back to all selected features if none pass the threshold.
    """
    cum_haz = model.model.cumulative_hazards_
    if cum_haz.empty:
        return model.selected_features
    mean_abs_coefs = cum_haz.abs().mean()  # one value per covariate column
    retained = mean_abs_coefs.index[mean_abs_coefs >= threshold]
    return retained if len(retained) > 0 else model.selected_features

In [None]:
import dill

def refit_model_with_selected_features(model_name, original_model, X_train, y_train, X_test, y_test, retained_features):
    """Refit a fresh copy of `original_model` on the retained features and report hold-out metrics."""
    retained_features = [f for f in retained_features if f != "Intercept"]

    # y_train / y_test are structured arrays without an index, so the DataFrame
    # index is taken from the matching feature matrices.
    y_train_df = pd.DataFrame({'duration': y_train[settings.duration_col], 'event': y_train[settings.event_col]}, index=X_train.index)
    y_train_structured = Surv.from_dataframe('event', 'duration', y_train_df)

    y_test_df = pd.DataFrame({'duration': y_test[settings.duration_col], 'event': y_test[settings.event_col]}, index=X_test.index)
    y_test_structured = Surv.from_dataframe('event', 'duration', y_test_df)

    # Re-instantiate the same model class with the same constructor arguments.
    model_class = type(original_model)
    model_kwargs = getattr(original_model, 'kwargs', {})
    new_model = model_class(**model_kwargs)
    new_model.fit(X_train[retained_features], y_train_structured)

    trainer = ModelTrainer(models={}, n_splits=5, random_state=42, max_time=settings.max_time)
    holdout_metrics = trainer._evaluate_model(
        new_model,
        X_test[retained_features],
        y_train_structured,
        y_test_structured,
        y_test_df,
        model_name
    )
    print(f"{model_name} Feature-Selected Hold-Out Results: {holdout_metrics}")

    if settings.save_models:
        save_new_model(new_model, model_name)

    return new_model

def save_new_model(model, model_name, suffix="_feature_selected"):
    """Pickle a refitted model to settings.save_path with dill."""
    new_model_name = model_name + suffix
    model_file = os.path.join(settings.save_path, f"{settings.outcome}_{new_model_name}.pkl")
    with open(model_file, "wb") as f:
        dill.dump(model, f)
    print(f"New model with feature selection saved as {settings.outcome}_{new_model_name}.pkl")

In [None]:
def select_features_and_refit(X_train, y_train, X_test, y_test, configs):
    """Load the trained CoxPH and Aalen Additive models, prune their features, and refit them."""
    coxph_model = load_trained_model("CoxPH", CoxPHModel, model_kwargs=configs['CoxPH'][1])
    aalen_model = load_trained_model("AalenAdditive", AalenAdditiveModel, model_kwargs=configs['AalenAdditive'][1])

    coxph_retained = feature_select_coxph(coxph_model, X_train, threshold=0.01)
    aalen_retained = feature_select_aalen_additive(aalen_model, X_train, threshold=0.001)

    new_coxph_model = refit_model_with_selected_features("CoxPH", coxph_model, X_train, y_train, X_test, y_test, coxph_retained)
    new_aalen_model = refit_model_with_selected_features("AalenAdditive", aalen_model, X_train, y_train, X_test, y_test, aalen_retained)

    return new_coxph_model, new_aalen_model

In [None]:
select_features_and_refit(X_train, y_train, X_test, y_test, configs)