# Predictive Algorithms

This notebook demonstrates the development and evaluation of predictive algorithms for survival analysis. The aim is to train, evaluate, and optimize various survival models to predict overall survival (OS) and progression-free survival (PFS) for patients, using both classical and deep learning techniques. It includes data preprocessing, model training, hyperparameter optimization, and visualization of survival curves.


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from src.data.data_processing import DataSplitter, DataPreprocessor
from src.data.lookups import LookupManager

db_config_path = '/home/jupyter/.my.cnf'
db_name = 'actin_personalization'
query = "SELECT * FROM knownPalliativeTreatments"

preprocessor = DataPreprocessor(db_config_path, db_name)

lookup_manager = LookupManager()
features = lookup_manager.features

## Data Preprocessing

In this section, we set up the data pipeline for survival analysis. `DataPreprocessor` loads and preprocesses the query results, and `DataSplitter` holds out 10% of the data as a test set (`test_size=0.1`), so that the survival data is structured appropriately for model training.
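
For orientation, the survival target that `Surv.from_dataframe` builds in the cell below is a NumPy structured array pairing an event flag with a duration. The sketch below reproduces that layout with plain NumPy. The helper name `make_survival_array` and the toy values are illustrative assumptions, not part of the project code.

```python
import numpy as np

def make_survival_array(events, durations, event_name="event", time_name="time"):
    """Build a structured array with a boolean event field and a float time field."""
    events = np.asarray(events, dtype=bool)
    durations = np.asarray(durations, dtype=float)
    if events.shape != durations.shape:
        raise ValueError("events and durations must have the same shape")
    y = np.empty(events.shape[0], dtype=[(event_name, bool), (time_name, float)])
    y[event_name] = events
    y[time_name] = durations
    return y

# Toy data (hypothetical): three patients, the second one censored.
y_toy = make_survival_array(
    events=[True, False, True],
    durations=[120.0, 365.0, 48.0],
    event_name="hadSurvivalEvent",
    time_name="observedOsFromTreatmentStartDays",
)
print(y_toy.dtype.names)
print(int(y_toy["hadSurvivalEvent"].sum()), "events out of", len(y_toy))
```

Keeping the event flag and the duration in one array means a row can never lose its censoring status when the data is shuffled or split.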


In [None]:
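from sksurv.util import Surv  # imported here so this cell does not depend on a later one

def get_data(query, event_col, duration_col, features):
    """Load and preprocess the data, then split it into training and test sets."""
    splitter = DataSplitter(test_size=0.1, random_state=42)

    df, features, encoded_columns = preprocessor.preprocess_data(query, duration_col, event_col, features)

    # Note: `y` is not passed to the splitter; the split receives the full DataFrame.
    y = Surv.from_dataframe(event=event_col, time=duration_col, data=df)
    X_train, X_test, y_train, y_test = splitter.split(df[features], df, 'systemicTreatmentPlan', encoded_columns)

    return X_train, X_test, y_train, y_test, encoded_columns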

## Train and Evaluate Models

This section defines the function `train_evaluate_models`, which trains various survival models using predefined configurations. The trained models are evaluated using the following metrics:

- **C-Index**: The Concordance Index measures discrimination: the fraction of comparable patient pairs in which the patient predicted to be at higher risk actually experiences the event earlier. A value of 0.5 corresponds to random ranking and 1.0 to perfect ranking, so higher is better.

- **Integrated Brier Score (IBS)**: This metric evaluates the accuracy of the survival probability predictions over time. It integrates the time-dependent Brier score (the squared difference between the predicted survival probability and the observed survival status) over the evaluation period, typically with inverse-probability-of-censoring weights to account for censored patients. Lower values indicate better predictive performance.

- **Calibration Error (CE)**: Calibration error assesses how well the predicted survival probabilities match the observed probabilities. It indicates whether the model systematically overestimates or underestimates survival probabilities. Lower values signify better calibration.

- **Area Under the Curve (AUC)**: For survival models, AUC is computed from time-dependent ROC curves and reflects the model's discrimination ability at different time points. Higher AUC values indicate better discrimination.

Together, these metrics provide a comprehensive evaluation of the models' predictive performance, capturing different aspects of accuracy, discrimination, and calibration.
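
To make the C-Index concrete, here is a minimal pure-Python sketch of Harrell's concordance index for right-censored data. It is an illustration under simplifying assumptions (tied durations are ignored, and higher scores are taken to mean higher risk); it is not the implementation used by `ModelTrainer`, and the function name and toy values are hypothetical.

```python
def harrell_c_index(durations, events, risk_scores):
    """Harrell's concordance index for right-censored data.

    A pair (i, j) is comparable when the patient with the shorter duration
    had an observed event. It is concordant when that patient also has the
    higher predicted risk; ties in risk count as half.
    """
    concordant = 0.0
    comparable = 0
    n = len(durations)
    for i in range(n):
        if not events[i]:
            continue  # a censored patient cannot anchor a comparable pair
        for j in range(n):
            if durations[i] < durations[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable

# Toy example (hypothetical): four patients, the second one censored.
durations = [5, 10, 15, 20]
events = [1, 0, 1, 1]
risk_scores = [0.9, 0.8, 0.3, 0.4]
c_index = harrell_c_index(durations, events, risk_scores)
print(c_index)  # 0.75: three of the four comparable pairs are ranked correctly
```

The censored patient still contributes as the longer-surviving member of a pair, which is why censored rows are kept rather than dropped.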


In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_survival_curves_for_patient(trained_models, X_test, y_test, patient_index, event_col, duration_col, plot_title):
    """
    Plot survival curves for a specific patient using trained models.
    """
    X_patient = X_test.iloc[[patient_index]]
    actual_duration_days = y_test[duration_col].iloc[patient_index]
    actual_event = y_test[event_col].iloc[patient_index]

    plt.figure(figsize=(12, 8))
    for model_name, model in trained_models.items():
        try:
            surv_funcs = model.predict_survival_function(X_patient)

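            # Evaluate on the time range shared by all returned step functions.
            times = np.linspace(max(fn.x[0] for fn in surv_funcs), min(fn.x[-1] for fn in surv_funcs), 100)
            # np.vstack replaces np.row_stack, which is deprecated in NumPy 2.0.
            surv_probs = np.vstack([fn(times) for fn in surv_funcs])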
            
            plt.step(times / 30.44, surv_probs[0], where="post", label=model_name)
        except Exception as e:
            print(f"Error plotting survival curves for model {model_name}: {e}")

    marker_color = 'red' if actual_event else 'blue'
    marker_label = "Event Time" if actual_event else "Censoring Time"
    plt.axvline(x=actual_duration_days / 30.44, color=marker_color, linestyle='--', label=marker_label)

    plt.title(f"Predicted {plot_title} Survival Curves")
    plt.xlabel("Time (months)")
    plt.ylabel("Survival Probability")
    plt.legend(loc="best")
    plt.grid(True)
    plt.show()
       

In [None]:
import pandas as pd  # needed below to build the results table
from sksurv.util import Surv

from src.models.survival_models import *
from src.models.model_trainer import *

def train_evaluate_models(query, event_col, duration_col, features, configs, patient_index = 78, plot_title = "Overall"):
    
    X_train, X_test, y_train, y_test, encoded_columns = get_data(query, event_col, duration_col, features)
        
    # Network models need the number of input features, known only after preprocessing.
    for params in configs.values():
        if 'input_size' in params:
            params['input_size'] = X_train.shape[1]

    model_classes = {
        'DeepSurv': DeepSurv,
        'LogisticHazardModel': LogisticHazardModel,
        'DeepHitModel': DeepHitModel,
        'PCHazardModel': PCHazardModel,
        'MTLRModel': MTLRModel,
        'AalenAdditive': AalenAdditiveModel,
        'CoxPH': CoxPHModel,
        'RandomSurvivalForest': RandomSurvivalForestModel,
        'GradientBoosting': GradientBoostingSurvivalModel,
    }
    # Instantiate only the models that have a configuration, so a missing entry
    # (e.g. no 'MTLRModel' key in `os_configs`) does not raise a KeyError.
    models = {name: cls(**configs[name]) for name, cls in model_classes.items() if name in configs}
    
    trainer = ModelTrainer(models=models, n_splits=5, random_state=42)

    results, trained_models = trainer.train_and_evaluate(
        X_train,
        y_train,
        X_test,
        y_test,
        treatment_col='systemicTreatmentPlan',
        encoded_columns=encoded_columns,
        event_col=event_col,
        duration_col=duration_col
    )
    
    results_df = pd.DataFrame.from_dict(results, orient='index')
    results_df.reset_index(inplace=True)
    results_df.rename(columns={'index': 'Model'}, inplace=True)
    
    plot_survival_curves_for_patient(trained_models, X_test, y_test, patient_index, event_col, duration_col, plot_title)
   
    return results_df, trained_models
    

### Best Model Configurations

The best configurations for the OS models were determined with the hyperparameter optimization defined later in this notebook (see Hyperparameter Optimization). They are used to instantiate the models for training and evaluation. The PFS configurations are not listed yet because that search is still running.
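
The network configurations carry an `input_size: None` placeholder that `train_evaluate_models` overwrites with the number of training features. A variant that leaves the original dictionary untouched might look like the sketch below; `resolve_input_size` and the toy configuration are hypothetical and only illustrate the idea.

```python
import copy

def resolve_input_size(configs, n_features):
    """Return a deep copy of `configs` with every `input_size` key set to `n_features`."""
    resolved = copy.deepcopy(configs)
    for params in resolved.values():
        if "input_size" in params:
            params["input_size"] = n_features
    return resolved

# Toy configuration (hypothetical values).
toy_configs = {
    "DeepSurv": {"input_size": None, "num_nodes": [64, 32]},
    "CoxPH": {"alpha": 1.0},
}
resolved = resolve_input_size(toy_configs, n_features=17)
print(resolved["DeepSurv"]["input_size"])     # 17
print(toy_configs["DeepSurv"]["input_size"])  # None, the original is unchanged
```

Working on a copy keeps the same configuration dictionary reusable across runs whose feature counts differ.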


In [None]:
# Best configurations for OS
os_configs = {
    'DeepSurv': {
        'input_size': None, 
        'num_nodes': [256, 128, 64],
        'batch_norm': False,
        'dropout': 0.1,
        'weight_decay': 0.0001,
        'lr': 0.001,
        'activation': 'relu',
        'optimizer': 'RMSprop',
        'batch_size': 32,
        'epochs': 500,
    },
    'LogisticHazardModel': {
        'input_size': None, 
        'num_nodes': [256, 128, 64],
        'lr': 0.0005,
        'dropout': 0.15,
        'batch_size': 128,
        'epochs': 500,
        'num_durations': 60,
        'early_stopping_patience': 50,
        'optimizer': 'RMSprop',
    },
    'DeepHitModel': {
        'input_size': None, 
        'num_nodes': [64, 32],
        'activation': 'elu',
        'alpha': 0.2,
        'sigma': 0.1,
        'weight_decay': 0.0001,
        'optimizer': 'Adam',
        'dropout': 0.1,
        'lr': 1e-3,
        'batch_size': 32,
        'epochs': 500,
        'num_durations': 60,
    },
    'PCHazardModel': {
        'input_size': None,  
        'num_nodes': [256, 128, 64],
        'batch_norm': True,
        'dropout': 0.15,
        'lr': 0.01,
        'batch_size': 32,
        'epochs': 500,
        'num_durations': 120,
        'optimizer': 'Adam',
    },
    'AalenAdditive': {
        'fit_intercept': True,
        'alpha': 0.05,
        'coef_penalizer': 10.0,
        'smoothing_penalizer': 0.0,
    },
    'CoxPH': {
        'alpha': 1.0,
        'ties': 'breslow',
        'n_iter': 100,
        'tol': 1e-7,
    },
    'RandomSurvivalForest': {
        'n_estimators': 200,
        'max_depth': 50,
        'min_samples_split': 50,
        'min_samples_leaf': 5,
        'max_features': None,
        'random_state': 42,
    },
    'GradientBoosting': {
        'learning_rate': 0.05, 
        'n_estimators': 100, 
        'max_depth': 10, 
        'subsample': 1.0, 
        'min_samples_leaf': 20, 
        'min_samples_split': 10, 
        'max_features': 'log2'
    },
}

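# The PFS hyperparameter search is still running. Define `pfs_configs` here once it
# finishes; the PFS training cell below requires it.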

In [None]:
os_model_outcomes, os_trained_models = train_evaluate_models(query, event_col='hadSurvivalEvent', duration_col='observedOsFromTreatmentStartDays', features=features, configs=os_configs)
os_model_outcomes

In [None]:
pfs_model_outcomes, pfs_trained_models = train_evaluate_models(query, event_col='hadProgressionEvent', duration_col='observedPfsDays', features=features, configs=pfs_configs)
pfs_model_outcomes

### Metric comparison: OS vs. PFS

This section visualizes the comparison of model performance metrics (C-Index, IBS, CE, AUC) for OS and PFS. The bar plots highlight the strengths and weaknesses of each model in the two prediction tasks.


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def plot_all_metrics(pfs_df, os_df):
    # Tag each result set without modifying the caller's DataFrames.
    combined_df = pd.concat(
        [pfs_df.assign(Type='PFS'), os_df.assign(Type='OS')],
        ignore_index=True,
    )

    metrics = ['c_index', 'ibs', 'ce', 'auc']

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()

    for i, metric in enumerate(metrics):
        ax = axes[i]
        sns.barplot(
            x='Model',
            y=metric,
            hue='Type',
            data=combined_df,
            ax=ax,
            palette='Set1'
        )
        ax.set_title(f'{metric.upper()} Comparison: OS vs. PFS')
        ax.set_xlabel('Model')
        ax.set_ylabel(metric.upper())
        # Rotate tick labels without re-setting them, which avoids a Matplotlib warning.
        ax.tick_params(axis='x', rotation=45)
        ax.legend(title='Type', loc='best')
        ax.grid(axis='y', linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.show()

In [None]:
plot_all_metrics(pfs_model_outcomes, os_model_outcomes)

## Hyperparameter Optimization

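Hyperparameter optimization is performed per model over the parameter grids defined below. For each grid, `random_parameter_search` samples up to `n_samples` configurations (20 by default, or the full grid if it is smaller). Each sampled configuration is trained and evaluated, and the one with the highest AUC is kept. Because the search is random, it finds a good configuration within the sampled budget rather than a guaranteed optimum.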
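
The sampling step can also be done without enumerating every combination first, which helps when a grid has many parameters. The sketch below draws distinct flat indices and decodes each one into a parameter combination. `sample_param_combinations` is a hypothetical alternative to the notebook's `random_parameter_search`, not a replacement the project uses, and the toy grid is made up.

```python
import random
from math import prod

def sample_param_combinations(param_dict, n_samples, seed=42):
    """Sample distinct combinations without materialising the full Cartesian product."""
    keys = list(param_dict)
    sizes = [len(param_dict[k]) for k in keys]
    total = prod(sizes)
    rng = random.Random(seed)
    # Draw distinct flat indices, then decode each one as a mixed-radix number.
    flat_indices = rng.sample(range(total), min(n_samples, total))
    combos = []
    for flat in flat_indices:
        remainder = flat
        combo = {}
        for key, size in zip(reversed(keys), reversed(sizes)):
            remainder, pos = divmod(remainder, size)
            combo[key] = param_dict[key][pos]
        # Restore the original key order for readability.
        combos.append({k: combo[k] for k in keys})
    return combos

# Toy grid (hypothetical): 2 x 3 = 6 possible combinations.
toy_grid = {"lr": [0.001, 0.01], "dropout": [0.1, 0.2, 0.3]}
sampled = sample_param_combinations(toy_grid, n_samples=4)
for params in sampled:
    print(params)
```

Because each flat index maps to exactly one combination, sampling indices without replacement guarantees that no configuration is trained twice.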


In [None]:
param_grids = {
    'LogisticHazardModel': [
        {
            'num_nodes': [[256, 128, 64], [128, 64, 32],[128, 64], [64, 32], [64], [32]],
            'lr': [0.001, 0.0005, 0.01],
            'dropout': [0.1, 0.15, 0.2],
            'batch_size': [32, 64, 128],
            'patience': [20, 30, 50],
            'optimizer': ['Adam', 'RMSprop']
        }
    ],
    'DeepHitModel': [
        {
            'num_nodes': [[256, 128, 64], [128, 64, 32],[128, 64], [64, 32], [64], [32]],
            'activation': ['swish', 'elu', 'relu'],
            'alpha': [0.2, 0.3, 0.4],
            'sigma': [0.05, 0.1, 0.2],
            'weight_decay': [1e-3, 5e-4, 1e-4],
            'optimizer': ['Adam', 'RMSprop']
        }
    ],
    'PCHazardModel': [
        {
            'num_nodes': [[256, 128, 64],[128, 64, 32], [128, 64], [64, 32], [64], [32]],
            'num_durations': [60, 80, 100, 120],
            'lr': [0.0005, 0.001, 0.01],
            'dropout': [0.1, 0.15, 0.2],
            'optimizer': ['Adam', 'RMSprop']
        }
    ],
    'MTLRModel': [
        {
            'num_nodes': [[256, 128, 64],[128, 64, 32], [128, 64], [64, 32], [64], [32]],
            'num_durations': [80, 100, 120],
            'lr': [0.0005, 0.001, 0.01],
            'dropout': [0.1, 0.15, 0.2],
            'optimizer': ['Adam', 'RMSprop']
        }
    ],
    'AalenAdditive': [
        {
            'fit_intercept': [True, False],
            'alpha': [0.01, 0.05, 0.1],
            'coef_penalizer': [0.5, 1.0, 2.0, 5.0, 10.0],
            'smoothing_penalizer': [0.0, 0.5, 1.0],
        }
    ],
    'CoxPH': [
        {
            'alpha': [0, 0.05, 0.1, 0.2, 0.5, 1.0],
            'ties': ['breslow', 'efron'],
            'n_iter': [100, 200, 500],
            'tol': [1e-5, 1e-7, 1e-9],
        }
    ],
    'RandomSurvivalForest': [
        {
            'n_estimators': [50, 100, 200, 500],
            'max_depth': [5, 10, 20, 50],
            'min_samples_split': [10, 20, 50, 100],
            'min_samples_leaf': [5, 10, 15, 20, 30],
            'max_features': ['sqrt', 'log2', None]
        }
    ],
    'GradientBoosting': [
        {
            'learning_rate': [0.01, 0.05, 0.1],
            'n_estimators': [50, 100, 200, 300, 500],
            'max_depth': [3, 5, 10],
            'subsample': [0.8, 1.0],
            'min_samples_leaf': [5, 10, 20],
            'min_samples_split': [10, 20, 50],
            'max_features': ['sqrt', 'log2', None]
        }
    ],
    'DeepSurv': [
        {
            'num_nodes': [[256, 128, 64], [128, 64, 32], [128, 64], [64, 32], [64], [32]],
            'batch_norm': [True, False],
            'dropout': [0.1, 0.2, 0.3],
            'weight_decay': [1e-3, 5e-4, 1e-4],
            'lr': [0.001, 0.0005, 0.005, 0.01],
            'activation': ['elu', 'relu'],
            'optimizer': ['Adam', 'RMSprop']
        }
    ],
}

In [None]:
import random
from itertools import product

def random_parameter_search(param_dict, n_samples):
    """
    Randomly sample `n_samples` parameter combinations from the given param_dict.
    param_dict should be a dict of lists, e.g.:
      {
        'lr': [0.001, 0.0005, 0.01],
        'dropout': [0.1, 0.2],
      }
    """
    keys = list(param_dict.keys())
    values = [param_dict[k] for k in keys]

    all_combos = list(product(*values))
    if len(all_combos) <= n_samples:
        return [dict(zip(keys, combo)) for combo in all_combos]

    sampled_combos = random.sample(all_combos, n_samples)
    return [dict(zip(keys, combo)) for combo in sampled_combos]

def hyperparameter_search(
    X_train, y_train, X_test, y_test, treatment_col, encoded_columns, event_col, duration_col,
    base_models, param_grids, n_samples=20, random_state=42
):
    random.seed(random_state)
    best_models = {}
    all_results = {}
    trainer = ModelTrainer(models={}, n_splits=5, random_state=random_state)

    for model_name, model_instance in base_models.items():
        model_class = type(model_instance)
        if model_name not in param_grids:
            print(f"No hyperparameter grid found for {model_name}, skipping optimization...")
            best_models[model_name] = (model_instance, None)
            continue

        best_score = -np.inf
        best_params = None
        best_model_trained = None
        all_results[model_name] = []

        for param_dict in param_grids[model_name]:
            sampled_params = random_parameter_search(param_dict, n_samples)
          
            for params in sampled_params:
                if issubclass(model_class, NNSurvivalModel):
                    new_model = model_class(input_size=X_train.shape[1], **params)
                else:
                    new_model = model_class(**params)

                print(f"Training {model_name} with parameters: {params}")

                # Evaluate the freshly configured model, not the default-parameter instance.
                trainer.models = {model_name: new_model}
                results, trained_models = trainer.train_and_evaluate(
                    X_train, y_train, X_test, y_test,
                    treatment_col=treatment_col,
                    encoded_columns=encoded_columns,
                    event_col=event_col,
                    duration_col=duration_col
                )

                current_score = results[model_name]['auc']
                all_results[model_name].append((params, results[model_name]))

                if current_score > best_score:
                    best_score = current_score
                    best_params = params
                    best_model_trained = trained_models[model_name]


        best_models[model_name] = (best_model_trained, best_params)
        print(f"Best params for {model_name}: {best_params} with auc={best_score}")

    return best_models, all_results

In [None]:
def optimize_hyperparameters(query, event_col, duration_col, features):
    
    X_train, X_test, y_train, y_test, encoded_columns = get_data(query, event_col, duration_col, features)
          
    models = {
        'DeepSurv': DeepSurv(input_size=X_train.shape[1]),
        'LogisticHazardModel': LogisticHazardModel(input_size=X_train.shape[1]),
        'DeepHitModel': DeepHitModel(input_size=X_train.shape[1]), 
        'PCHazardModel': PCHazardModel(input_size=X_train.shape[1]), 
        'MTLRModel': MTLRModel(input_size=X_train.shape[1]),
        'AalenAdditive': AalenAdditiveModel(),
        'CoxPH': CoxPHModel(),
        'RandomSurvivalForest': RandomSurvivalForestModel(),
        'GradientBoosting': GradientBoostingSurvivalModel(),
    }
    
    best_models, all_results = hyperparameter_search(
        X_train, y_train, X_test, y_test,
        treatment_col='systemicTreatmentPlan', encoded_columns=encoded_columns,
        event_col=event_col, duration_col=duration_col,
        base_models=models, param_grids=param_grids, n_samples = 20
    )
    
    for model_name, trials in all_results.items():
        for params, metrics in trials:
            print(model_name, params, metrics['c_index'], metrics['ibs'], metrics['auc'], metrics['ce'])
            
    return best_models, all_results
            

In [None]:
os_best_models, os_results = optimize_hyperparameters(query, event_col='hadSurvivalEvent', duration_col='observedOsFromTreatmentStartDays', features=features)

In [None]:
pfs_best_models, pfs_results = optimize_hyperparameters(query, event_col='hadProgressionEvent', duration_col='observedPfsDays', features=features)