# Predictive Algorithms

In [None]:
# Enable IPython's autoreload so edits to imported project modules (e.g. src.*)
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from src.data.data_processing import DataSplitter, DataPreprocessor
from src.data.lookups import LookupManager

# Database credentials file (presumably a MySQL option file, given the `.my.cnf` name).
db_config_path = '/home/jupyter/.my.cnf'
db_name = 'actin_personalization'
# Query handed to the preprocessor by `train_evaluate_models` to build the modelling dataset.
query = "SELECT * FROM knownPalliativeTreatments"

# Shared preprocessor; `train_evaluate_models` calls its `preprocess_data(query, ...)`.
preprocessor = DataPreprocessor(db_config_path, db_name)

# Default candidate feature list provided by the project's lookup tables.
lookup_manager = LookupManager()
features = lookup_manager.features

## Train and evaluate models

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_survival_curves_for_patient(trained_models, X_test, y_test, patient_index, event_col, duration_col, plot_title,
                                     days_per_month=30.44, n_points=100):
    """
    Plot every model's predicted survival curve for a single test patient.

    Each model's ``predict_survival_function`` output is evaluated on a common
    time grid and drawn as a step function. A vertical dashed line marks the
    patient's observed time: red for an event, blue for censoring.

    Args:
        trained_models: Mapping of model name -> fitted model exposing
            ``predict_survival_function``.
        X_test: Feature DataFrame of the test split.
        y_test: DataFrame with ``event_col`` and ``duration_col`` for the test split.
        patient_index: Positional (``iloc``) index of the patient to plot.
        event_col: Column of ``y_test`` holding the event indicator.
        duration_col: Column of ``y_test`` holding the observed duration in days.
        plot_title: Prefix for the figure title (e.g. "Overall").
        days_per_month: Factor converting days to months on the x-axis.
        n_points: Number of grid points at which each curve is evaluated.

    Models whose prediction or plotting fails are reported and skipped.
    """
    X_patient = X_test.iloc[[patient_index]]
    actual_duration_days = y_test[duration_col].iloc[patient_index]
    actual_event = y_test[event_col].iloc[patient_index]

    plt.figure(figsize=(12, 8))
    for model_name, model in trained_models.items():
        try:
            surv_funcs = model.predict_survival_function(X_patient)

            # Assumes each returned function exposes its time grid as `.x` and is
            # callable on an array of times (as sksurv's StepFunction is).
            # The grid spans only the range covered by every returned function.
            times = np.linspace(
                max(fn.x[0] for fn in surv_funcs),
                min(fn.x[-1] for fn in surv_funcs),
                n_points,
            )
            # np.vstack replaces np.row_stack, which is deprecated as of NumPy 2.0.
            surv_probs = np.vstack([fn(times) for fn in surv_funcs])

            plt.step(times / days_per_month, surv_probs[0], where="post", label=model_name)
        except Exception as e:
            # Best effort: one incompatible model should not prevent plotting the others.
            print(f"Error plotting survival curves for model {model_name}: {e}")

    marker_color = 'red' if actual_event else 'blue'
    marker_label = "Event Time" if actual_event else "Censoring Time"
    plt.axvline(x=actual_duration_days / days_per_month, color=marker_color, linestyle='--', label=marker_label)

    plt.title(f"Predicted {plot_title} Survival Curves")
    plt.xlabel("Time (months)")
    plt.ylabel("Survival Probability")
    plt.legend(loc="best")
    plt.grid(True)
    plt.show()
       

In [None]:
from src.models.survival_models import *
from src.models.model_trainer import *
from sksurv.util import Surv

def train_evaluate_models(query, event_col, duration_col, features, patient_index=78, plot_title="Overall",
                          treatment_col='systemicTreatmentPlan'):
    """
    Preprocess the data, train and evaluate all survival models, and plot one patient's curves.

    Args:
        query: SQL query passed to the module-level ``preprocessor``.
        event_col: Name of the event-indicator column.
        duration_col: Name of the observed-duration column (plotted as days / 30.44 = months).
        features: Candidate feature names; the preprocessor returns the final list used.
        patient_index: Positional index into the test split of the patient whose
            predicted survival curves are plotted. If it is out of range, the plot
            is skipped instead of raising after training has finished.
        plot_title: Prefix used in the survival-curve plot title.
        treatment_col: Column passed to the splitter and trainer as the treatment column.

    Returns:
        Tuple ``(results_df, trained_models)``. ``results_df`` has one row per model
        and a ``Model`` column. ``trained_models`` maps model name -> fitted model.
    """
    # Imported locally: the notebook imports pandas only in a later cell, so relying
    # on the global `pd` works only if the star imports above happen to leak it.
    import pandas as pd

    splitter = DataSplitter(test_size=0.1, random_state=42)

    # The preprocessor returns its own feature list (presumably expanded with the
    # encoded columns), so `features` is deliberately rebound here.
    df, features, encoded_columns = preprocessor.preprocess_data(query, duration_col, event_col, features)

    X_train, X_test, y_train, y_test = splitter.split(df[features], df, treatment_col, encoded_columns)

    input_size = X_train.shape[1]
    models = {
        'DeepSurv': DeepSurv(input_size=input_size),
        'LogisticHazardModel': LogisticHazardModel(input_size=input_size),
        'DeepHitModel': DeepHitModel(input_size=input_size),
        'PCHazardModel': PCHazardModel(input_size=input_size),
        'MTLRModel': MTLRModel(input_size=input_size),
        'AalenAdditive': AalenAdditiveModel(),
        'CoxPH': CoxPHModel(),
        'RandomSurvivalForest': RandomSurvivalForestModel(),
        'GradientBoosting': GradientBoostingSurvivalModel(),
    }

    trainer = ModelTrainer(models=models, n_splits=5, random_state=42)

    results, trained_models = trainer.train_and_evaluate(
        X_train,
        y_train,
        X_test,
        y_test,
        treatment_col=treatment_col,
        encoded_columns=encoded_columns,
        event_col=event_col,
        duration_col=duration_col,
    )

    # One row per model; the dict keys become the 'Model' column.
    results_df = (
        pd.DataFrame.from_dict(results, orient='index')
        .reset_index()
        .rename(columns={'index': 'Model'})
    )

    # Guard the plot so a bad index cannot discard the results of a long training run.
    n_test = len(X_test)
    if -n_test <= patient_index < n_test:
        plot_survival_curves_for_patient(trained_models, X_test, y_test, patient_index, event_col, duration_col, plot_title)
    else:
        print(f"Skipping survival-curve plot: patient_index {patient_index} is outside the test set (size {n_test}).")

    return results_df, trained_models
    

In [None]:
# Overall-survival (OS) models: train, evaluate, and display the metrics table.
os_model_outcomes, os_trained_models = train_evaluate_models(
    query,
    event_col='hadSurvivalEvent',
    duration_col='observedOsFromTreatmentStartDays',
    features=features,
)
os_model_outcomes

In [None]:
# Progression-free-survival (PFS) models: train, evaluate, and display the metrics table.
pfs_model_outcomes, pfs_trained_models = train_evaluate_models(
    query,
    event_col='hadProgressionEvent',
    duration_col='observedPfsDays',
    features=features,
    plot_title="Progression-Free",
)
pfs_model_outcomes

## Visualize results

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def plot_all_metrics(pfs_df, os_df, metrics=('c_index', 'ibs', 'ce', 'auc')):
    """
    Draw one bar chart per metric comparing each model's OS and PFS results.

    Args:
        pfs_df: Per-model PFS results with a ``Model`` column and one column per metric.
        os_df: Per-model OS results with the same layout.
        metrics: Metric column names to plot, one subplot each, laid out two per row.

    The input DataFrames are not modified; the ``Type`` label is added to copies.
    """
    # assign() returns new frames, so the caller's results tables are left untouched.
    combined_df = pd.concat(
        [pfs_df.assign(Type='PFS'), os_df.assign(Type='OS')],
        ignore_index=True,
    )

    n_cols = 2
    n_rows = max(1, (len(metrics) + n_cols - 1) // n_cols)
    # squeeze=False keeps `axes` 2-D even for a single row, so flatten() always works.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    for ax, metric in zip(axes, metrics):
        sns.barplot(
            x='Model',
            y=metric,
            hue='Type',
            data=combined_df,
            ax=ax,
            palette='Set1'
        )
        ax.set_title(f'{metric.upper()} Comparison: OS vs. PFS')
        ax.set_xlabel('Model')
        ax.set_ylabel(metric.upper())
        # Rotate existing tick labels in place rather than re-setting them.
        ax.tick_params(axis='x', labelrotation=45)
        ax.legend(title='Type', loc='best')
        ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Hide grid slots left empty when the number of metrics is odd.
    for ax in axes[len(metrics):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()

In [None]:
# Compare OS and PFS metrics side by side for every trained model.
plot_all_metrics(pfs_df=pfs_model_outcomes, os_df=os_model_outcomes)