# Predictive Algorithms

This notebook develops and evaluates predictive algorithms for survival analysis. It trains, evaluates, and optimizes several survival models, both classical and deep learning based, to predict overall survival (OS) and progression-free survival (PFS). It covers data preprocessing, model training, hyperparameter optimization, and visualization of survival curves.


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from src.models import *
from src.data.data_processing import DataSplitter, DataPreprocessor
from src.data.lookups import LookupManager

In [None]:
db_config_path = '/home/jupyter/.my.cnf'
db_name = 'actin_personalization'
query = "SELECT * FROM knownPalliativeTreatments"

preprocessor = DataPreprocessor(db_config_path, db_name)

lookup_manager = LookupManager()
features = lookup_manager.features

## Data Preprocessing

In this section, we set up the data pipeline for survival analysis. The `DataSplitter` and `DataPreprocessor` classes are used to load, preprocess, and split the data into training and testing sets. This ensures the survival data is structured appropriately for model training.
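
The internals of `DataSplitter` are not shown in this notebook. The sketch below assumes that it stratifies on the treatment column (suggested by the `'systemicTreatmentPlan'` argument passed to `split`), so that each treatment keeps roughly the same share in the train and test sets. The function `stratified_split` and the toy records are hypothetical and only illustrate the idea.

```python
import random
from collections import defaultdict


def stratified_split(rows, strata_key, test_size=0.1, seed=42):
    """Split rows into train/test so each stratum keeps roughly the same share."""
    rng = random.Random(seed)

    # Group the rows by the value of the stratification column.
    groups = defaultdict(list)
    for row in rows:
        groups[row[strata_key]].append(row)

    train, test = [], []
    for key in sorted(groups):
        members = list(groups[key])  # copy so the grouped list is left untouched
        rng.shuffle(members)
        n_test = round(len(members) * test_size)
        test.extend(members[:n_test])
        train.extend(members[n_test:])
    return train, test


# Hypothetical toy records; only the column name mirrors this notebook.
patients = [
    {"id": i, "systemicTreatmentPlan": "A" if i % 2 == 0 else "B"}
    for i in range(40)
]
train_rows, test_rows = stratified_split(patients, "systemicTreatmentPlan")
```

With 20 patients per treatment and `test_size=0.1`, each treatment contributes two patients to the test set.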


In [None]:
def get_data(query, event_col, duration_col, features):
    """Load and preprocess the query result, then split it into train and test sets."""
    splitter = DataSplitter(test_size=0.1, random_state=42)

    df, features, encoded_columns = preprocessor.preprocess_data(query, duration_col, event_col, features)

    # The full frame is passed as the target so the event and duration columns stay available downstream.
    X_train, X_test, y_train, y_test = splitter.split(df[features], df, 'systemicTreatmentPlan', encoded_columns)

    return X_train, X_test, y_train, y_test, encoded_columns

## Train and Evaluate Models

This section defines the function `train_evaluate_models`, which trains various survival models using predefined configurations. The trained models are evaluated using the following metrics:

- **C-Index**: The Concordance Index measures how well the model ranks patients: among comparable pairs, it is the fraction in which the patient predicted to be at higher risk experiences the event first. It is a measure of discrimination; 0.5 corresponds to random ranking and 1.0 to perfect ranking, so higher values are better.

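- **Integrated Brier Score (IBS)**: This metric evaluates the accuracy of the predicted survival probabilities over time. The Brier score at a given time is the mean squared difference between the predicted survival probability and the observed survival status, typically weighted by the inverse probability of censoring to account for censored patients. The IBS integrates this score over the evaluation time range. Lower values indicate better predictive performance.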

- **Calibration Error (CE)**: Calibration error assesses how well the predicted survival probabilities match the observed probabilities. It indicates whether the model is systematically overestimating or underestimating survival probabilities. Lower values signify better calibration.

- **Area Under the Curve (AUC)**: For survival models, AUC is typically computed over a time-dependent ROC curve, reflecting the model's discrimination ability at different time points. Higher AUC values indicate better overall performance.

Together, these metrics provide a comprehensive evaluation of the models' predictive performance, capturing different aspects of accuracy, discrimination, and calibration.
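
To make the C-Index concrete, here is a minimal pure-Python sketch of Harrell's concordance index. It is not the implementation used by `ModelTrainer`, which is not shown here; the function name and the toy data are illustrative, and ties in observed time are simply skipped for brevity.

```python
def concordance_index(times, events, risk_scores):
    """Harrell's C-index: share of comparable pairs ranked correctly by risk."""
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            # A pair is comparable only if the earlier time is an observed event.
            continue
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0   # higher risk, earlier event: concordant
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5   # tied risk counts as half
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Toy example: four patients, the second one is censored.
times = [5, 10, 12, 20]
events = [True, False, True, True]
risk = [0.9, 0.4, 0.1, 0.6]

print(concordance_index(times, events, risk))  # 0.75
```

Three of the four comparable pairs are ordered correctly here. The one discordant pair is the patient with an event at day 12, who received a lower risk score than the patient with an event at day 20.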


In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_survival_curves_for_patient(trained_models, X_test, y_test, patient_index, event_col, duration_col, plot_title):
    """
    Plot survival curves for a specific patient using trained models.
    """
    X_patient = X_test.iloc[[patient_index]]
    actual_duration_days = y_test[duration_col].iloc[patient_index]
    actual_event = y_test[event_col].iloc[patient_index]

    plt.figure(figsize=(12, 8))
    for model_name, model in trained_models.items():
        try:
            surv_funcs = model.predict_survival_function(X_patient)

            # Evaluate on the time range shared by all returned step functions.
            times = np.linspace(max(fn.x[0] for fn in surv_funcs), min(fn.x[-1] for fn in surv_funcs), 100)
            surv_probs = np.vstack([fn(times) for fn in surv_funcs])
            
            plt.step(times / 30.44, surv_probs[0], where="post", label=model_name)
        except Exception as e:
            print(f"Error plotting survival curves for model {model_name}: {e}")

    marker_color = 'red' if actual_event else 'blue'
    marker_label = "Event Time" if actual_event else "Censoring Time"
    plt.axvline(x=actual_duration_days / 30.44, color=marker_color, linestyle='--', label=marker_label)

    plt.title(f"Predicted {plot_title} Curves")
    plt.xlabel("Time (months)")
    plt.ylabel("Survival Probability")
    plt.legend(loc="best")
    plt.grid(True)
    plt.show()
       

In [None]:
from sksurv.util import Surv
import os
import pandas as pd
import joblib
import torch

def train_evaluate_models(query, event_col, duration_col, features, configs, title, patient_index = 78, save_models=True):
    
    X_train, X_test, y_train, y_test, encoded_columns = get_data(query, event_col, duration_col, features)
    
    models = {}        
    for model_name, (model_class, model_kwargs) in configs.items():
        if issubclass(model_class, NNSurvivalModel):
            model_kwargs['input_size'] = X_train.shape[1]
        models[model_name] = model_class(**model_kwargs)
    
    trainer = ModelTrainer(models=models, n_splits=5, random_state=42)

    results, trained_models = trainer.train_and_evaluate(
        X_train,
        y_train,
        X_test,
        y_test,
        treatment_col='systemicTreatmentPlan',
        encoded_columns=encoded_columns,
        event_col= event_col,
        duration_col= duration_col, 
        title = title
    )
    
    results_df = pd.DataFrame.from_dict(results, orient='index')
    results_df.reset_index(inplace=True)
    results_df.rename(columns={'index': 'Model'}, inplace=True)
    
    if save_models:
        save_path = "src/models/trained_models"
        os.makedirs(save_path, exist_ok=True)

        csv_file = os.path.join(save_path, f"{title}_model_outcomes.csv")
        results_df.to_csv(csv_file, index=False)
        print(f"Model outcomes saved to {csv_file}")
    
    plot_survival_curves_for_patient(trained_models, X_test, y_test, patient_index, event_col, duration_col, title)
    
    return results_df, trained_models
    

### Best Model Configurations

The best configurations for the OS and PFS models were determined by the hyperparameter optimization described in the Hyperparameter Optimization section at the end of this notebook. They are stored in `models/model_configurations` and are used here to instantiate the models for training and evaluation.


In [None]:
os_model_outcomes, os_trained_models = train_evaluate_models(query, event_col='hadSurvivalEvent', duration_col='observedOsFromTreatmentStartDays', features=features, configs=os_configs, title="OS", save_models=True)
os_model_outcomes

In [None]:
pfs_model_outcomes, pfs_trained_models = train_evaluate_models(query, event_col='hadProgressionEvent', duration_col='observedPfsDays', features=features, configs=pfs_configs, title="PFS")
pfs_model_outcomes

### Metric Comparison: OS vs. PFS

This section visualizes the comparison of model performance metrics (C-Index, IBS, CE, AUC) for OS and PFS. The bar plots highlight the strengths and weaknesses of each model in the two prediction tasks.


In [None]:
def load_model_outcomes(title, save_path="src/models/trained_models"):
    csv_file = os.path.join(save_path, f"{title}_model_outcomes.csv")
    
    if os.path.exists(csv_file):
        results_df = pd.read_csv(csv_file)
        print(f"Loaded model outcomes from {csv_file}")
    else:
        raise FileNotFoundError(f"No saved outcomes found for {title} in {save_path}")
    
    return results_df

In [None]:
os_model_outcomes = load_model_outcomes("OS")
pfs_model_outcomes = load_model_outcomes("PFS")

In [None]:
os_model_outcomes

In [None]:
pfs_model_outcomes

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import ast

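def extract_holdout_metrics(df):
    # Work on a copy so the caller's DataFrame is not modified.
    df = df.copy()
    # Outcomes loaded from CSV store the holdout dict as a string; parse only those entries.
    df['holdout'] = df['holdout'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)

    # Expand the holdout dict into one column per metric.
    holdout_metrics = df['holdout'].apply(pd.Series)
    holdout_metrics['Model'] = df['Model']

    return holdout_metrics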

def plot_all_metrics(pfs_df, os_df, holdout=True):
       
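    if holdout:
        pfs_df = extract_holdout_metrics(pfs_df)
        os_df = extract_holdout_metrics(os_df)
    else:
        # Copy so that adding the 'Type' column does not alter the caller's frames.
        pfs_df, os_df = pfs_df.copy(), os_df.copy()

    pfs_df['Type'] = 'PFS'
    os_df['Type'] = 'OS'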

    combined_df = pd.concat([pfs_df, os_df], ignore_index=True)

    metrics = ['c_index', 'ibs', 'ce', 'auc']

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()

    for i, metric in enumerate(metrics):
        ax = axes[i]
        sns.barplot(
            x='Model',
            y=metric,
            hue='Type',
            data=combined_df,
            ax=ax,
            palette='Set1'
        )
        ax.set_title(f'{metric.upper()} Comparison: OS vs. PFS')
        ax.set_xlabel('Model')
        ax.set_ylabel(metric.upper())
        # Rotate tick labels without resetting them, which avoids a fixed-locator warning.
        ax.tick_params(axis='x', rotation=45)
        ax.legend(title='Type', loc='best')
        ax.grid(axis='y', linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.show()

In [None]:
plot_all_metrics(pfs_model_outcomes, os_model_outcomes)

## Model Interpretation

Import the saved models if they were trained previously:

In [None]:
import os
import torch
import joblib
import dill
from src.models.survival_models import NNSurvivalModel

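def load_trained_model(model_name, title, model_class, model_kwargs=None, save_path="src/models/trained_models"):
    # Copy the kwargs so that neither a shared default nor the caller's config dict is mutated.
    model_kwargs = dict(model_kwargs or {})
    model_file_prefix = os.path.join(save_path, f"{title}_{model_name}")
    nn_file = model_file_prefix + ".pt"
    sk_file = model_file_prefix + ".pkl"

    if issubclass(model_class, NNSurvivalModel):
        if 'input_size' not in model_kwargs:
            # Fallback only: this must match the feature count used during training,
            # otherwise load_state_dict fails on a shape mismatch.
            model_kwargs['input_size'] = 100
        model = model_class(**model_kwargs)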
    
        model.model.net.load_state_dict(torch.load(nn_file, map_location=torch.device('cpu')))
        model.model.net.eval()
        print(f"Model {model_name} loaded from {nn_file}")
        return model
    else:
        with open(sk_file, "rb") as f:
            model = dill.load(f)
        print(f"Model {model_name} loaded from {sk_file}")
        return model

In [None]:
def load_all_trained_models(model_specs, title, save_path="src/models/trained_models"):
    loaded_models = {}
    for model_name, (model_class, model_kwargs) in model_specs.items():
        loaded_model = load_trained_model(
            model_name=model_name, 
            title=title, 
            model_class=model_class, 
            model_kwargs=model_kwargs, 
            save_path=save_path
        )
        loaded_models[model_name] = loaded_model
    return loaded_models

os_trained_models = load_all_trained_models(os_configs, title="OS", save_path="src/models/trained_models")
pfs_trained_models = load_all_trained_models(pfs_configs, title="PFS", save_path="src/models/trained_models")

### Feature Importance

Once the Random Survival Forest (RSF) and Gradient Boosting Survival Model (GBM) have been trained, we can extract feature importance scores. These importance scores help us identify which features most strongly influence survival predictions. High importance scores mean the feature plays a significant role in splitting the data and improving the model’s predictive power.

For complex, non-linear models such as DeepSurv, DeepHit, and other neural network-based models, we will use SHAP. SHAP values provide a consistent and locally accurate measure of feature importance for individual predictions.
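
Since the implementation for this section is still pending, the following is only a sketch of one model-agnostic option, permutation importance, and not necessarily the method this notebook will adopt. Each feature is shuffled in turn and the resulting drop in C-Index is recorded. All names (`c_index`, `permutation_importance`, `toy_risk`) and the toy data are hypothetical, and `toy_risk` stands in for a trained model's risk prediction.

```python
import random


def c_index(times, events, risks):
    """Harrell's C-index over comparable pairs (ties in time are skipped)."""
    concordant, comparable = 0.0, 0
    for i in range(len(times)):
        if not events[i]:
            continue
        for j in range(len(times)):
            if times[i] < times[j]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5
    return concordant / comparable


def permutation_importance(predict_risk, X, times, events, n_repeats=20, seed=0):
    """Mean drop in C-index when each feature column is shuffled."""
    rng = random.Random(seed)
    baseline = c_index(times, events, [predict_risk(row) for row in X])
    importances = []
    for col in range(len(X[0])):
        drops = []
        for _ in range(n_repeats):
            shuffled = [row[col] for row in X]
            rng.shuffle(shuffled)
            # Rebuild the rows with only this column permuted.
            X_perm = [row[:col] + [v] + row[col + 1:] for row, v in zip(X, shuffled)]
            score = c_index(times, events, [predict_risk(row) for row in X_perm])
            drops.append(baseline - score)
        importances.append(sum(drops) / n_repeats)
    return baseline, importances


def toy_risk(row):
    # Stand-in for a trained model: risk depends on the first feature only.
    return row[0]


X = [[float(i), float((i * 7) % 10)] for i in range(10)]
times = [100 - 10 * i for i in range(10)]   # higher first feature, shorter survival
events = [True] * 10

baseline, importances = permutation_importance(toy_risk, X, times, events)
```

In this toy setup the second feature is never used by `toy_risk`, so its importance is exactly zero, while shuffling the first feature lowers the C-Index.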


In [None]:
# TODO: feature importance extraction is still being implemented

## Feature Selection

To ensure optimal model performance and interpretability, we apply feature selection methods tailored to the model type:
- **CoxPH**: Identify significant linear predictors using hazard ratios and p-values.
- **Tree-Based Models (RSF, GradientBoosting)**: Use feature importance scores to retain only the most influential variables.
- **Neural Models (DeepSurv, DeepHit, PCHazard, MTLR)**: Apply regularization techniques (e.g., Elastic Net) to encourage sparse solutions and reduce unnecessary complexity.

By combining these approaches, we refine the input space to only include variables that provide meaningful predictive value.


In [None]:
# TODO: feature selection still needs to be implemented

## Hyperparameter Optimization

Hyperparameter optimization is performed for each model over a predefined grid of parameters. The `random_parameter_search` function, defined in `models/hyperparameter_optimization`, samples configurations from this grid to identify the best-performing parameters for each model on the given data.
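
The idea of sampling configurations from a grid can be sketched as follows. This is not the project's `random_parameter_search`, whose signature is not shown here. The functions `sample_configurations` and `random_search`, the toy grid, and the stand-in scoring function are all assumptions made for illustration.

```python
import random


def sample_configurations(param_grid, n_samples, seed=42):
    """Draw n_samples configurations, picking one value per parameter at random."""
    rng = random.Random(seed)
    keys = sorted(param_grid)
    return [{k: rng.choice(param_grid[k]) for k in keys} for _ in range(n_samples)]


def random_search(param_grid, score_fn, n_samples=10, seed=42, higher_is_better=True):
    """Score each sampled configuration and keep the best one."""
    best_params, best_score = None, None
    results = []
    for params in sample_configurations(param_grid, n_samples, seed):
        score = score_fn(params)
        results.append((params, score))
        if best_score is None:
            improved = True
        elif higher_is_better:
            improved = score > best_score
        else:
            improved = score < best_score
        if improved:
            best_params, best_score = params, score
    return best_params, best_score, results


# Hypothetical grid; in practice score_fn would train a model and return a validation metric.
toy_grid = {"learning_rate": [0.001, 0.01, 0.1], "num_layers": [1, 2, 3]}


def toy_score(params):
    return -abs(params["learning_rate"] - 0.01) - abs(params["num_layers"] - 2)


best_params, best_score, results = random_search(toy_grid, toy_score, n_samples=20)
```

The `higher_is_better` flag matters when the comparison metric is an error such as IBS or CE, where lower scores win, as opposed to C-Index or AUC.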


In [None]:
def optimize_hyperparameters(query, event_col, duration_col, features, metric_comparison):
    
    X_train, X_test, y_train, y_test, encoded_columns = get_data(query, event_col, duration_col, features)
          
    models = {
        'DeepSurv': DeepSurv(input_size=X_train.shape[1]),
        'LogisticHazardModel': LogisticHazardModel(input_size=X_train.shape[1]),
        'DeepHitModel': DeepHitModel(input_size=X_train.shape[1]), 
        'PCHazardModel': PCHazardModel(input_size=X_train.shape[1]), 
        'MTLRModel': MTLRModel(input_size=X_train.shape[1]),
        'AalenAdditive': AalenAdditiveModel(),
        'CoxPH': CoxPHModel(),
        'RandomSurvivalForest': RandomSurvivalForestModel(),
        'GradientBoosting': GradientBoostingSurvivalModel(),
    }
    
    best_models, all_results = hyperparameter_search(
        X_train, y_train, X_test, y_test,
        treatment_col='systemicTreatmentPlan', encoded_columns=encoded_columns,
        event_col=event_col, duration_col=duration_col,
        base_models=models, param_grids=param_grids, metric_comparison=metric_comparison
    )
       
    return best_models, all_results         

In [None]:
os_best_models, os_results = optimize_hyperparameters(query, event_col = 'hadSurvivalEvent', duration_col = 'observedOsFromTreatmentStartDays', features = features, metric_comparison='auc')

In [None]:
pfs_best_models, pfs_results = optimize_hyperparameters(query, event_col = 'hadProgressionEvent', duration_col = 'observedPfsDays', features=features, metric_comparison='auc')