# Predictive Algorithms for Survival Analysis: Predicting Survival Days

This notebook was made to test the feasibility of predicting specific survival days and contains a full pipeline for training and evaluating predictive models for survival analysis. We handle both overall survival (OS) and progression‑free survival (PFS) by reusing the same functions. The pipeline loads and preprocesses data, visualizes the target distributions (before and after log transformation), displays feature–target correlation plots and residual plots, evaluates a set of regression models, and performs hyperparameter optimization (using random search) for selected models.


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

# NOTE(review): hard-coded absolute path -- adjust to the local checkout before running.
os.chdir('/data/repos/actin-personalization/prediction')
# Put the project's Python source directory first on the import path.
sys.path.insert(0, os.path.abspath("src/main/python"))

In [None]:
from data.data_processing import DataSplitter, DataPreprocessor
from data.lookups import LookupManager
from utils.settings import settings
from models import *

In [None]:
# Module-level preprocessor; get_data() below uses it to load and preprocess rows.
preprocessor = DataPreprocessor(settings.db_config_path, settings.db_name)

# Feature list provided by the lookup manager; passed to get_data() below.
lookup_manager = LookupManager()
features = lookup_manager.features

## Data Preprocessing

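We define a function to load and prepare the data. It loads the data, filters for rows with the event of interest, preprocesses the features, and splits the features and target into training and test sets. The survival-days target is returned on its raw scale; the log transformation is applied later, inside the evaluation functions.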


In [None]:
def get_data(features):
    """Load event-only rows, preprocess them and return a train/test split.

    Only rows whose ``settings.event_col`` equals 1 are kept. The targets are
    the ``settings.duration_col`` values cast to float; no transformation is
    applied to them here.

    Args:
        features: Feature names handed to ``preprocessor.preprocess_data``.

    Returns:
        Tuple ``(df, X_train, X_test, y_train, y_test, features,
        encoded_columns)`` where ``df``, ``features`` and ``encoded_columns``
        are the values returned by ``preprocessor.preprocess_data``.
    """
    data_splitter = DataSplitter(test_size=0.1, random_state=42)

    loaded = preprocessor.load_data()
    # Keep only the rows in which the event was observed.
    event_mask = loaded[settings.event_col] == 1
    event_rows = loaded[event_mask].copy()

    df, features, encoded_columns = preprocessor.preprocess_data(features, df=event_rows)

    target_columns = [settings.event_col, settings.duration_col]
    targets = df[target_columns].copy()

    X_train, X_test, train_targets, test_targets = data_splitter.split(
        df[features], targets, encoded_columns
    )

    # Only the duration column is used as the regression target.
    y_train = train_targets[settings.duration_col].astype(float)
    y_test = test_targets[settings.duration_col].astype(float)

    return df, X_train, X_test, y_train, y_test, features, encoded_columns


In [None]:
# Build the dataset once; these notebook-level names are reused by the cells below.
# Note that `features` is rebound to the list returned by preprocessing.
df, X_train, X_test, y_train, y_test, features, encoded_columns = get_data(features)

### Visualization of Target

We now visualize the distribution of the OS target before and after log transformation. This helps us understand the skewness of the data and the effect of the transformation.


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def visualize_target_distribution(df, target_col):
    """Show histograms of a target column on the raw and the log1p scale.

    Args:
        df: DataFrame containing ``target_col``.
        target_col: Name of the column to plot.

    Returns:
        ``np.log1p(df[target_col])``, the values shown in the second plot.
    """

    def _draw_histogram(values, title, x_label, **hist_kwargs):
        # One figure per call: histogram with KDE overlay, 30 bins.
        plt.figure(figsize=(10, 6))
        sns.histplot(values, kde=True, bins=30, **hist_kwargs)
        plt.title(title)
        plt.xlabel(x_label)
        plt.ylabel("Frequency")
        plt.show()

    _draw_histogram(
        df[target_col],
        f"Distribution of {target_col} (Days)",
        f"{target_col} (Days)",
    )

    log_target = np.log1p(df[target_col])
    _draw_histogram(
        log_target,
        f"Distribution of Log-Transformed {target_col}",
        f"Log({target_col} + 1)",
        color='green',
    )
    return log_target

In [None]:
# Plot the duration column on the raw and log1p scales (returns the log1p series).
visualize_target_distribution(df, settings.duration_col)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def plot_correlation_with_target(df, target_col):
    """Draw a bar plot of each numeric feature's correlation with the target.

    The event-indicator columns ``hadSurvivalEvent`` and
    ``hadProgressionEvent`` and the target column itself are left out of the
    feature set. Bars are ordered from the highest to the lowest correlation.

    Args:
        df: DataFrame holding the features and ``target_col``.
        target_col: Name of the column the features are correlated with.
    """
    excluded = ('hadSurvivalEvent', 'hadProgressionEvent', target_col)
    candidate_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    feature_cols = [name for name in candidate_cols if name not in excluded]

    correlations = df[feature_cols].corrwith(df[target_col])
    correlations = correlations.sort_values(ascending=False)

    plt.figure(figsize=(8, 12))
    sns.barplot(x=correlations.values, y=correlations.index)
    plt.title(f"Correlation of Numeric Features with {target_col}")
    plt.xlabel("Correlation Coefficient")
    plt.ylabel("Feature")
    plt.tight_layout()
    plt.show()

In [None]:
# Correlate every numeric feature with the duration column and plot the result.
plot_correlation_with_target(df, settings.duration_col)

### KNN and Best-K Determination

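The function `determine_best_k_nn` runs 5‑fold cross‑validation over a range of K values for a K‑Nearest Neighbors regressor and plots the cross‑validated negative MSE. The target is used exactly as passed in (the raw survival days in the call below); apply `np.log1p` to the target first if the selection should be made on the log scale.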

In [None]:
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import cross_val_score

def determine_best_k_nn(X_train, y_train, k_range):
    """Pick the neighbour count with the best 5-fold cross-validated score.

    Each ``k`` in ``k_range`` is scored with the mean of
    ``neg_mean_squared_error`` over 5 folds, using ``y_train`` as given. The
    scores are plotted against ``k``.

    Args:
        X_train: Training features.
        y_train: Training target.
        k_range: Iterable of candidate ``n_neighbors`` values.

    Returns:
        ``(best_k, scores)`` where ``scores`` maps each ``k`` to its mean CV
        score and ``best_k`` is the key with the highest score.
    """
    scores = {
        k: np.mean(
            cross_val_score(
                KNeighborsRegressor(n_neighbors=k),
                X_train,
                y_train,
                cv=5,
                scoring='neg_mean_squared_error',
            )
        )
        for k in k_range
    }
    # Scores are negated MSE, so the largest value is the best one.
    best_k = max(scores, key=scores.get)

    candidate_ks = list(scores.keys())
    mean_scores = list(scores.values())
    plt.figure(figsize=(8, 6))
    plt.plot(candidate_ks, mean_scores, marker='o')
    plt.title("Cross-Validated Negative MSE for Different k")
    plt.xlabel("k")
    plt.ylabel("CV Negative MSE")
    plt.show()

    return best_k, scores

In [None]:
# Scan k = 1..20 on the training split (target passed as-is).
best_knn_k, knn_scores = determine_best_k_nn(X_train, y_train, range(1, 21))    

## Model Evaluation and Optimization Functions

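The `evaluate_models` function trains a set of models on the raw target and reports MSE and R² on the raw scale. The `evaluate_models_logtarget` function trains the same models on the log-transformed target and reports performance on both the log scale and the original scale. The `hyperparameter_search` function performs randomized hyperparameter tuning for a single model, and `optimize_models` applies it to every model that has a parameter grid defined.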


In [None]:
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.neural_network import MLPRegressor
from sklearn.svm import SVR
from xgboost import XGBRegressor

In [None]:
def evaluate_models(models, X_train, X_test, y_train, y_test, label):
    """Fit every model on the untransformed target and score it on the test split.

    After all models are scored, ``plot_best_model`` is called to plot the
    model with the highest ``"R² (raw)"``.

    Args:
        models: Mapping of model name to an estimator with ``fit``/``predict``.
        X_train: Training features.
        X_test: Test features.
        y_train: Training target (used as given).
        y_test: Test target (used as given).
        label: Text appended to the title of the best-model plot.

    Returns:
        ``(trained_models, results)``: a mapping of model name to the fitted
        estimator, and a list of dicts with the keys ``"Model"``,
        ``"MSE (raw)"`` and ``"R² (raw)"``.

    Note:
        Estimators are fitted in place, so ``trained_models`` holds the very
        same objects as ``models``. Refitting those instances elsewhere also
        changes what the returned mapping refers to.
    """
    trained_models = {}
    results = []
    for model_name, model in models.items():
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        mse_ = mean_squared_error(y_test, y_pred)
        r2_ = r2_score(y_test, y_pred)

        trained_models[model_name] = model
        results.append({
            "Model": model_name,
            "MSE (raw)": mse_,
            "R² (raw)": r2_
        })

    plot_best_model(models, results, X_train, X_test, y_train, y_test, label, metric = "R² (raw)")

    return trained_models, results

def evaluate_models_logtarget(models, X_train, X_test, y_train, y_test, label):
    """Fit every model on the log1p-transformed target and score it on both scales.

    Predictions are made on the log scale and back-transformed with
    ``np.expm1`` for the original-scale metrics. The model with the highest
    ``"R² (orig)"`` is then plotted on the log scale, with the plots labelled
    as log-scale, using the same predictions the reported metrics were
    computed from.

    Args:
        models: Mapping of model name to an estimator with ``fit``/``predict``.
        X_train: Training features.
        X_test: Test features.
        y_train: Training target on the original scale.
        y_test: Test target on the original scale.
        label: Text appended to the title of the best-model plot.

    Returns:
        ``(trained_models, results)``: a mapping of model name to the fitted
        estimator, and a list of dicts with the keys ``"Model"``,
        ``"MSE (log)"``, ``"R² (log)"``, ``"MSE (orig)"`` and ``"R² (orig)"``.
        Both are empty when ``models`` is empty, in which case nothing is
        plotted.

    Note:
        Estimators are fitted in place, so ``trained_models`` holds the very
        same objects as ``models``.
    """
    trained_models = {}
    results = []
    log_predictions = {}

    y_train_log = np.log1p(y_train)
    y_test_log = np.log1p(y_test)

    for model_name, model in models.items():
        model.fit(X_train, y_train_log)
        y_pred_log = model.predict(X_test)
        # Kept so the best model can be plotted without refitting it.
        log_predictions[model_name] = y_pred_log

        mse_log = mean_squared_error(y_test_log, y_pred_log)
        r2_log = r2_score(y_test_log, y_pred_log)

        # Back-transform to days for the original-scale metrics.
        y_pred_orig = np.expm1(y_pred_log)
        mse_orig = mean_squared_error(y_test, y_pred_orig)
        r2_orig = r2_score(y_test, y_pred_orig)

        trained_models[model_name] = model
        results.append({
            "Model": model_name,
            "MSE (log)": mse_log,
            "R² (log)": r2_log,
            "MSE (orig)": mse_orig,
            "R² (orig)": r2_orig
        })

    if results:
        ranked = pd.DataFrame(results).sort_values(by="R² (orig)", ascending=False)
        best_model_name = ranked.iloc[0]["Model"]
        best_pred_log = log_predictions[best_model_name]

        # The plotted values are on the log scale, so label them accordingly.
        plot_predictions(
            y_test_log,
            best_pred_log,
            f"{best_model_name}: Predicted vs Actual {label}",
            is_log=True,
        )
        plot_residuals(y_test_log, best_pred_log, is_log=True)

    return trained_models, results


In [None]:
def plot_predictions(y_test, y_pred, title, is_log=False):
    """Scatter predicted against actual values with a red identity line.

    Args:
        y_test: Actual values (must provide ``.min()`` / ``.max()``).
        y_pred: Predicted values (must provide ``.min()`` / ``.max()``).
        title: Plot title.
        is_log: When true, the axis labels mention log-scale days.
    """
    if is_log:
        scale_label = "Log(Observed Days)"
    else:
        scale_label = "Observed Days"

    plt.figure(figsize=(8, 6))
    sns.scatterplot(x=y_test, y=y_pred)

    # The identity line spans the combined range of actual and predicted values.
    lower = min(y_test.min(), y_pred.min())
    upper = max(y_test.max(), y_pred.max())
    diagonal = [lower, upper]
    plt.plot(diagonal, diagonal, color='red', lw=2)

    plt.xlabel(f"Actual ({scale_label})")
    plt.ylabel(f"Predicted ({scale_label})")
    plt.title(title)
    plt.show()

def plot_residuals(y_true, y_pred, is_log=False):
    """Plot the residual histogram and the residuals-vs-predicted scatter.

    Residuals are computed as ``y_true - y_pred``.

    Args:
        y_true: Actual values.
        y_pred: Predicted values.
        is_log: When true, the titles say "Log Scale" instead of "Raw Scale".
    """
    if is_log:
        scale_label = "Log Scale"
    else:
        scale_label = "Raw Scale"
    errors = y_true - y_pred

    # Figure 1: distribution of the residuals.
    plt.figure(figsize=(8, 6))
    sns.histplot(errors, kde=True, bins=30)
    plt.title(f"Residuals Distribution ({scale_label})")
    plt.xlabel("Residuals")
    plt.ylabel("Frequency")
    plt.show()

    # Figure 2: residuals against predictions, with a zero reference line.
    plt.figure(figsize=(8, 6))
    sns.scatterplot(x=y_pred, y=errors)
    plt.axhline(0, color='red', linestyle='--')
    plt.title(f"Residuals vs Predicted ({scale_label})")
    plt.xlabel("Predicted")
    plt.ylabel("Residuals")
    plt.show()

def plot_best_model(models, results, X_train, X_test, y_train, y_test, label, metric = "R²", is_log=False):
    """Refit the best-scoring model and plot its predictions and residuals.

    The best model is the ``results`` entry with the highest value under
    ``metric``. It is looked up in ``models``, refitted on the given training
    data, and its test-set predictions are passed to ``plot_predictions`` and
    ``plot_residuals``.

    Args:
        models: Mapping of model name to estimator.
        results: List of dicts with a ``"Model"`` key and a ``metric`` key.
        X_train: Training features.
        X_test: Test features.
        y_train: Training target, on the scale the plots should show.
        y_test: Test target, on the same scale as ``y_train``.
        label: Text appended to the prediction plot's title.
        metric: Key in the ``results`` dicts used for ranking; it must be
            present in them.
        is_log: Set to true when ``y_train``/``y_test`` are log-transformed so
            the plots are labelled as log-scale. Defaults to raw-scale labels.
    """
    results_df = pd.DataFrame(results)
    best_model_name = results_df.sort_values(by=metric, ascending=False).iloc[0]["Model"]

    best_model = models[best_model_name]
    # Fitting happens in place on the estimator stored in `models`.
    best_model.fit(X_train, y_train)
    y_pred = best_model.predict(X_test)

    plot_predictions(y_test, y_pred, f"{best_model_name}: Predicted vs Actual {label}", is_log=is_log)
    plot_residuals(y_test, y_pred, is_log=is_log)
    

In [None]:
from sklearn.metrics import make_scorer, r2_score

def r2_original_score(y_true_log, y_pred_log):
    """Return R² computed after mapping both inputs back with ``np.expm1``.

    Args:
        y_true_log: True values on the log1p scale.
        y_pred_log: Predicted values on the log1p scale.
    """
    return r2_score(np.expm1(y_true_log), np.expm1(y_pred_log))

def hyperparameter_search(model, param_dist, X_train, y_train, use_log=False, cv_folds=5, n_iter=20):
    """Run a randomized hyperparameter search for one model.

    Args:
        model: Estimator to tune.
        param_dist: Parameter distributions handed to ``RandomizedSearchCV``.
        X_train: Training features.
        y_train: Training target, used as given.
        use_log: When true, candidates are scored with ``r2_original_score``;
            otherwise with the built-in ``'r2'`` scorer.
        cv_folds: Number of cross-validation folds.
        n_iter: Number of sampled parameter settings.

    Returns:
        ``(best_estimator, best_params, best_score)`` from the fitted search.
    """
    scoring = make_scorer(r2_original_score, greater_is_better=True) if use_log else 'r2'

    search = RandomizedSearchCV(
        model,
        param_dist,
        n_iter=n_iter,
        cv=cv_folds,
        scoring=scoring,
        verbose=1,
        n_jobs=-1,
        random_state=42,
    )
    search.fit(X_train, y_train)

    return search.best_estimator_, search.best_params_, search.best_score_

def optimize_models(models, X_train, y_train, use_log=False, cv_folds=5, n_iter=20):
    """Tune every model that has an entry in ``days_param_grids``.

    Models without an entry are reported and skipped. For the others, the
    first element of their ``days_param_grids`` entry is used as the search
    space for ``hyperparameter_search``.

    Args:
        models: Mapping of model name to estimator.
        X_train: Training features.
        y_train: Training target on the original scale.
        use_log: When true, the target is transformed with ``np.log1p`` before
            the search and the flag is forwarded to ``hyperparameter_search``.
        cv_folds: Number of cross-validation folds.
        n_iter: Number of sampled parameter settings per model.

    Returns:
        Mapping of model name to a dict with the keys ``"best_estimator"``,
        ``"best_params"`` and ``"best_score"``.
    """
    if use_log:
        search_target = np.log1p(y_train)
    else:
        search_target = y_train

    optimized_results = {}
    for model_name, model in models.items():
        if model_name not in days_param_grids:
            print(f"No parameter grid defined for {model_name}. Skipping optimization.")
            continue

        search_space = days_param_grids[model_name][0]
        best_est, best_params, best_score = hyperparameter_search(
            model,
            search_space,
            X_train,
            search_target,
            use_log=use_log,
            cv_folds=cv_folds,
            n_iter=n_iter,
        )
        optimized_results[model_name] = {
            "best_estimator": best_est,
            "best_params": best_params,
            "best_score": best_score
        }
        print(f"{model_name} optimized: best_params: {best_params}, best_score (R²): {best_score}")

    return optimized_results


## Full Pipeline Function

The `run_pipeline` function runs the modelling workflow for a given survival type ("OS" or "PFS") on an already prepared train/test split. It evaluates the models on the raw and on the log-transformed target, plots the predictions and residuals of the best model, and optionally executes hyperparameter tuning.


In [None]:
def run_pipeline(models,  X_train, X_test, y_train, y_test, survival_type, optimize=False):
    """Evaluate the models on the raw and log target, optionally tuning them.

    The data is not loaded here; the caller supplies the train/test split.
    ``survival_type`` only selects the label used in plot titles.

    Args:
        models: Mapping of model name to estimator.
        X_train: Training features.
        X_test: Test features.
        y_train: Training target on the original scale.
        y_test: Test target on the original scale.
        survival_type: ``"OS"`` or ``"PFS"`` (case-insensitive).
        optimize: When true, run ``optimize_models`` for the raw and for the
            log-transformed target (10 sampled settings each) and print the
            summaries.

    Returns:
        ``(trained_models, log_trained_models)`` as returned by
        ``evaluate_models`` and ``evaluate_models_logtarget``.

    Raises:
        ValueError: If ``survival_type`` is neither ``"OS"`` nor ``"PFS"``.
    """
    labels = {'os': "Observed OS", 'pfs': "Observed PFS"}
    label = labels.get(survival_type.lower())
    if label is None:
        raise ValueError("survival_type must be 'OS' or 'PFS'")

    print("=== Evaluating Models on Raw Target ===")
    trained_models, results = evaluate_models(models, X_train, X_test, y_train, y_test, label)
    print(pd.DataFrame(results))

    print("=== Evaluating Models on Log-Transformed Target ===")
    log_trained_models, log_results = evaluate_models_logtarget(models, X_train, X_test, y_train, y_test, label)
    print(pd.DataFrame(log_results))

    if optimize:
        print("\n=== Hyperparameter Optimization for non-transformed output===")
        optimized_raw = optimize_models(models, X_train, y_train, use_log=False, n_iter = 10)
        print("Optimized Models Summary:")
        print(pd.DataFrame(optimized_raw).T)

        print("\n=== Hyperparameter Optimization for log transformed output===")
        optimized_log = optimize_models(models, X_train, y_train, use_log=True, n_iter = 10)
        print("Optimized Models Summary:")
        print(pd.DataFrame(optimized_log).T)

    return trained_models, log_trained_models

### Running the Pipeline

Call `run_pipeline` with either `"OS"` or `"PFS"` to execute the entire workflow.


In [None]:
# Candidate regressors keyed by display name; the same instances are fitted in
# place by the evaluation functions called from run_pipeline.
models = {
    "LinearRegression": LinearRegression(),
    "Ridge": Ridge(solver='svd', alpha= 10.0, random_state=42),
    "Lasso": Lasso(max_iter=5000, alpha=1.0, random_state=42),
    "RandomForest": RandomForestRegressor(n_estimators=500, min_samples_split = 10, min_samples_leaf = 2, max_features = 'sqrt', max_depth = 20, random_state=42),
    "GradientBoosting": GradientBoostingRegressor(random_state=42),
    "MLPRegressor": MLPRegressor(solver = 'adam', learning_rate_init = 0.001, hidden_layer_sizes=(32,), alpha=0.001, activation='relu', random_state=42),
    "SVR_RBF": SVR(gamma = 'auto', kernel='rbf', C=10.0),
    "XGBRegressor": XGBRegressor(subsample = 0.6, n_estimators = 300, max_depth=5, learning_rate=0.01, random_state=42), 
    "KNN": KNeighborsRegressor(weights='distance', p=1, n_neighbors=9)
}

# settings.outcome must be 'OS' or 'PFS' (case-insensitive), otherwise run_pipeline raises ValueError.
trained_models, log_trained_models = run_pipeline(models,  X_train, X_test, y_train, y_test, settings.outcome, optimize = False)

### Patient outcomes

In [None]:
def predict_patient(model, df, features, target_col, patient_index, use_log=False):
    """Print the predicted and the actual target value for one row of ``df``.

    Args:
        model: Fitted estimator with a ``predict`` method.
        df: DataFrame holding the feature columns and ``target_col``.
        features: Feature column names passed to the model.
        target_col: Name of the column holding the actual value.
        patient_index: Positional (``iloc``) index of the row to predict.
        use_log: When true, the prediction is back-transformed with
            ``np.expm1`` before it is printed.
    """
    # Double brackets keep the selection a one-row DataFrame.
    row = df.iloc[[patient_index]]
    model_input = row[features]

    prediction = model.predict(model_input)
    if use_log:
        prediction = np.expm1(prediction)

    observed = row[target_col].iloc[0]

    print(f"--- Prediction for Patient at Index {patient_index} ---")
    print("Patient Features:")
    display(model_input)
    print(f"\nActual {target_col}: {observed}")
    print(f"Predicted {target_col}: {prediction[0]}")

In [None]:
# use_log=True back-transforms the prediction with expm1, so the model has to be
# one fitted on the log-transformed target: take it from log_trained_models.
predict_patient(
    model=log_trained_models["GradientBoosting"],
    df=df,
    features=features,
    target_col=settings.duration_col,
    patient_index=12,
    use_log=True
)