# Demo Notebook:
## Random Survival Forest

In [1]:
import os
from pathlib import Path
import sys
# Put the node-specific virtual environment's site-packages first on the import path.
# The environment directory is selected by the CPU architecture named in $BB_CPU.
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir).joinpath(
    'lib',
    'python{}.{}'.format(*sys.version_info[:2]),
    'site-packages',
)
if not venv_site_pkgs.exists():
    # Nothing is added to sys.path in this case; later imports use the default search paths.
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")
else:
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.


In [2]:
import pytorch_lightning
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
import pickle
from tqdm import tqdm

from pycox.datasets import support
from pycox.evaluation import EvalSurv
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
from torch.utils.data import TensorDataset, DataLoader

from sklearn import set_config
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder
from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import OneHotEncoder
from sksurv.ensemble import RandomSurvivalForest

# Seed every random number generator this notebook can draw from, not only torch:
# pandas' DataFrame.sample (used for the pycox train/val/test split) falls back to
# NumPy's global RNG when no random_state is given, so seeding torch alone leaves
# that split different on every run.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

logging.basicConfig(level=logging.INFO)

# Prefer the GPU when one is visible to torch.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cuda.


In [3]:
# Configure scikit-learn's global display option so estimators are shown using
# their text representation when displayed in this notebook.
set_config(display="text")  # displays text representation of estimators

# Load data

In [4]:
def get_dataloaders(dataset, competing_risk, sample_size=None):

    match dataset.lower():
        case "pycox":
            df_train = support.read_df()
            df_test = df_train.sample(frac=0.2)
            df_train = df_train.drop(df_test.index)
            df_val = df_train.sample(frac=0.2)
            df_train = df_train.drop(df_val.index)
            
            cols_standardize = ['x0', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13']
            cols_leave = ['x1', 'x2', 'x3', 'x4', 'x5', 'x6']
            
            standardize = [([col], StandardScaler()) for col in cols_standardize]
            leave = [(col, None) for col in cols_leave]
            
            x_mapper = DataFrameMapper(standardize + leave)
            
            x_train = x_mapper.fit_transform(df_train).astype('float32')
            x_val = x_mapper.transform(df_val).astype('float32')
            x_test = x_mapper.transform(df_test).astype('float32')
            
            get_target = lambda df: (df['duration'].values, df['event'].values)
            y_train = get_target(df_train)
            y_val = get_target(df_val)
            y_test = get_target(df_test)
            
            t_train, e_train = y_train
            t_val, e_val = y_val
            t_test, e_test = y_test
            
            t_train_max = np.amax(t_train)
            t_train = t_train / t_train_max
            t_val = t_val / t_train_max
            t_test = t_test / t_train_max
            
    
        case "hypertension" | "cvd":
    
            # Training samples
            if sample_size is not None:
                save_path =  f"/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_{dataset}/" + f"benchmark_data/N={sample_size}.pickle" 
            else:
                save_path = f"/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_{dataset}/" + "benchmark_data/all.pickle"
                
            with open(save_path, "rb") as handle:
                print(f"Loading training dataset from {save_path}")
                data_train = pickle.load(handle)
            
            # display(data["X_train"].head())
            # display(data["y_train"])
            # print(data.keys())
            
            data = {}
            data["X_train"] = data_train["X_train"]
            data["y_train"] = data_train["y_train"]
    
            # Test and validation samples
    
            save_path = f"/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_{dataset}/" + "benchmark_data/all.pickle"
            with open(save_path, "rb") as handle:
                print(f"Loading validation/test datasets from {save_path}")
                data_val_test = pickle.load(handle)
                
            data["X_val"] = data_val_test["X_val"]
            data["y_val"] = data_val_test["y_val"]
            data["X_test"] = data_val_test["X_test"]
            data["y_test"] = data_val_test["y_test"]
    
            # Convert to correct formats
            x_train = data["X_train"].to_numpy(dtype=np.float32)
            x_val = data["X_val"].to_numpy(dtype=np.float32)
            x_test = data["X_test"].to_numpy(dtype=np.float32)
            
            t_train = np.asarray([i[1] for i in data["y_train"]])
            t_val = np.asarray([i[1] for i in data["y_val"]])        
            t_test = np.asarray([i[1] for i in data["y_test"]])
    
            if competing_risk is False:
                e_train = np.asarray([0 if i[0] == 0 else 1 for i in data["y_train"]])
                e_val = np.asarray([0 if i[0] == 0 else 1 for i in data["y_val"]])
                e_test = np.asarray([0 if i[0] == 0 else 1 for i in data["y_test"]])
            else:
                e_train = np.asarray([i[0] for i in data["y_train"]])
                e_val = np.asarray([i[0] for i in data["y_val"]])
                e_test = np.asarray([i[0] for i in data["y_test"]])

    # display(x_train.shape)
    # display(type(x_train))
    # display(type(x_train[0,0]))
    # display(e_train.shape)
    # display(type(e_train))
    # display(type(e_train[0]))
    # display(t_train.shape)
    # display(type(t_train))
    # display(type(t_train[0]))
    # print(np.mean(e_train))
    # print(np.mean(t_train))
    # print(np.std(t_train))
    # print(np.mean(x_train))
    # print(t_train.min())
    # print(t_train.max())
    # print(np.unique(e_test, return_counts=True))

    # print(x_train.shape)
    # print(t_train.shape)
    # print(e_train.shape)

    
    Xtrain = pd.DataFrame(x_train, )    #  columns=list(dm.train_set.tokenizer._stoi.keys())[1:]
    Xval = pd.DataFrame(x_val, )    #  columns=list(dm.train_set.tokenizer._stoi.keys())[1:]
    Xtest = pd.DataFrame(x_test, )    #  columns=list(dm.train_set.tokenizer._stoi.keys())[1:]

    if competing_risk is False:
        ytrain = np.array([(_yk, _yt) for _yk, _yt in zip(e_train, t_train)], dtype=[('cens', 'bool'), ('time', '<f8')])
        yval = np.array([(_yk, _yt) for _yk, _yt in zip(e_val, t_val)], dtype=[('cens', 'bool'), ('time', '<f8')])
        ytest = np.array([(_yk, _yt) for _yk, _yt in zip(e_test, t_test)], dtype=[('cens', 'bool'), ('time', '<f8')])
    else:
        # Package does not support Competing Risks
        raise NotImplementedError
        
        # ytrain = np.array([(_yk, _yt) for _yk, _yt in zip(e_train, t_train)])# , dtype=[('cens', 'float'), ('time', '<f8')])
        # yval = np.array([(_yk, _yt) for _yk, _yt in zip(e_val, t_val)]) #, dtype=[('cens', 'float'), ('time', '<f8')])
        # ytest = np.array([(_yk, _yt) for _yk, _yt in zip(e_test, t_test)]) #, dtype=[('cens', 'float'), ('time', '<f8')])
    # print(Xtrain.head())
    # print(ytrain[:5])

    return (Xtrain, ytrain), (Xval, yval), (Xtest, ytest)




# Example dataloader function usage

In [7]:
# Load the single-risk CVD benchmark with the 2999-sample training subset.
dataset_train, dataset_val, dataset_test = get_dataloaders(
    dataset="CVD",
    competing_risk=False,
    sample_size=2999,
)

Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/benchmark_data/N=2999.pickle
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/benchmark_data/all.pickle
4.401644706726074
(array([0, 1]), array([31472,  4286]))
(2999, 279)
(2999,)
(2999,)


In [8]:
# Preview the test split: first feature rows, then the first three (event, time) targets.
print(dataset_test[0].head(), dataset_test[1][:3], sep="\n")

   0    1    2    3    4    5    6    7    8    9    ...  269  270  271  272  \
0  0.0  0.0  1.0  0.0  0.0  0.0  0.0  1.0  0.0  0.0  ...  0.0  0.0  0.0  0.0   
1  1.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0  0.0  0.0  ...  1.0  0.0  1.0  1.0   
2  1.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0  0.0  0.0  ...  0.0  1.0  1.0  1.0   
3  1.0  0.0  0.0  0.0  0.0  0.0  1.0  0.0  0.0  0.0  ...  1.0  0.0  1.0  1.0   
4  1.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0  0.0  0.0  ...  1.0  0.0  0.0  0.0   

   273  274  275  276  277  278  
0  0.0  0.0  0.0  0.0  0.0  0.0  
1  1.0  1.0  1.0  1.0  1.0  1.0  
2  0.0  1.0  1.0  1.0  0.0  0.0  
3  0.0  1.0  1.0  1.0  0.0  0.0  
4  0.0  1.0  1.0  1.0  0.0  0.0  

[5 rows x 279 columns]
[(False, 2.43013716) (False, 0.34082222) (False, 2.12383461)]


# Train model

In [11]:
# Experiment configuration: which benchmark to load and whether to model competing risks.
dataset = "CVD" # "Hypertension"
competing_risk = False
# Ten training-set sizes, log-spaced between 3000 and 500000. int() truncates towards zero,
# and the saved run shows the first entry coming out as 2999 (it loads "N=2999.pickle"),
# so these exact values determine which benchmark files are read.
sample_sizes = [int(np.exp(_log_n)) for _log_n in np.linspace(np.log(3000), np.log(500000), 10)]
# Alternative: a single run where get_dataloaders takes the training split from "all.pickle".
# sample_sizes = [None]

# the time grid which we generate over
# NOTE(review): the test targets displayed earlier in this notebook contain times above 1
# (e.g. 2.43) — confirm that restricting both grids to [0, 1] is intended.
t_eval = np.linspace(0, 1, 1000) 
# the time grid which we calculate scores over
time_grid = np.linspace(start=0, stop=1 , num=300)


In [13]:
# Loop over different dataset sizes, train the Random Survival Forest with bootstrapping and report results.
# Scores are computed per test batch and then averaged with equal weight per batch,
# so a smaller final batch counts as much as a full one.

bsz = 512    # number of test samples scored per EvalSurv call

for sample_size in sample_sizes:

    # Get dataset for given sample size
    dataset_train, dataset_val, dataset_test = get_dataloaders(dataset, competing_risk, sample_size=sample_size)

    # Create the RSF model. Bootstrap sampling is capped at max_samples=1000 per tree
    # (not the library default) because of memory constraints.
    print("Fitting Random Survival Forest")
    rsf = RandomSurvivalForest(
        bootstrap=True,
        max_samples=1000,
        random_state=42,
        low_memory=False
    )
    rsf.fit(dataset_train[0], dataset_train[1])

    # Find the indices in rsf.unique_times_ that are closest to the values in t_eval, so the RSF
    # is evaluated on the same grid as the other benchmarks. This depends only on the fitted
    # model and t_eval, so it is computed once per model rather than once per test batch.
    closest_indices = [np.abs(rsf.unique_times_ - v).argmin() for v in t_eval]

    # Test
    print(f"Evaluating performance by splitting {dataset_test[0].shape} test samples into batches of size {bsz}")

    ctd = []
    ibs = []
    inbll = []
    for batch_idx in range(0, dataset_test[0].shape[0], bsz):

        batch_dataset_test = (dataset_test[0][batch_idx:batch_idx + bsz], dataset_test[1][batch_idx:batch_idx + bsz])

        # Predict the survival function for the batch and keep only the columns nearest to t_eval
        surv = rsf.predict_survival_function(batch_dataset_test[0], return_array=True)
        surv_reduced = surv[:, closest_indices]

        # Format appropriately: EvalSurv receives one column per sample, indexed by time
        df_surv = pd.DataFrame(np.transpose(surv_reduced), index=t_eval)

        # Pull event flags (as 0.0/1.0) and times out of the structured target array
        lbls_test = batch_dataset_test[1]['cens'].astype(float)
        t_test = batch_dataset_test[1]['time'].astype(float)

        # Same treatment as in SurvivEHR
        ev = EvalSurv(df_surv, t_test, lbls_test, censor_surv='km')
        ctd.append(ev.concordance_td())
        ibs.append(ev.integrated_brier_score(time_grid))
        inbll.append(ev.integrated_nbll(time_grid))

    print(f"\tRandom Survival Forest ({'CR' if competing_risk else 'SR'}):".ljust(20) + f"N={sample_size}.".ljust(15) + f"Ctd: {np.mean(ctd)}. IBS: {np.mean(ibs)}. INBLL: {np.mean(inbll)}")





Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/benchmark_data/N=2999.pickle
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/benchmark_data/all.pickle
4.401644706726074
(array([0, 1]), array([31472,  4286]))
(2999, 279)
(2999,)
(2999,)
Fitting Random Survival Forest
Evaluating performance by splitting (35758, 279) test samples into batches of size 512
Random Survival Forest (SR):N=2999.        Ctd: 0.5817231098874919. IBS: 0.03399845492489739. INBLL: 0.14861047476033662
Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/benchmark_data/N=5296.pickle
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/benchmark_data/all.pickle
4.402191162109375
(array([0, 1]), array([31472,  4286]))
(5296

# Output across different setups

Cardiovascular disease (CVD), single-risk (SR) setting:

```
Random Survival Forest (SR):N=2999.        Ctd: 0.5817231098874919. IBS: 0.03399845492489739. INBLL: 0.14861047476033662
Random Survival Forest (SR):N=5296.        Ctd: 0.5726886646549809. IBS: 0.03396416779973878. INBLL: 0.149062513427203
Random Survival Forest (SR):N=9351.        Ctd: 0.5883719280021017. IBS: 0.03384484456385947. INBLL: 0.14750556866346265
Random Survival Forest (SR):N=16509.       Ctd: 0.5960278990576237. IBS: 0.0337225214683957. INBLL: 0.14647954529430643
Random Survival Forest (SR):N=29148.       Ctd: 0.5967549074819676. IBS: 0.033762424117441375. INBLL: 0.14621266423294518
Random Survival Forest (SR):N=51461.       Ctd: 0.6081889934667545. IBS: 0.03375098553961131. INBLL: 0.1457478005690567
Random Survival Forest (SR):N=90856.       Ctd: 0.6091873754214131. IBS: 0.033695391903505144. INBLL: 0.1452056774612423
Random Survival Forest (SR):N=160407.      Ctd: 0.6068623782034415. IBS: 0.033734096545360415. INBLL: 0.14558305134538546
Random Survival Forest (SR):N=283203.      Ctd: 0.6140593631057283. IBS: 0.03370825445709043. INBLL: 0.14507898552247775
Random Survival Forest (SR):N=500000.      Ctd: 0.6118567664875652. IBS: 0.033727277767864286. INBLL: 0.14548991693533334
```

'Ctd:   0.605'

'IBS:   0.0337'

'INBLL: 0.146'

In [12]:
# Number of per-batch concordance scores held in `ctd` (the training loop appends one per test batch).
# NOTE(review): this cell's execution count (12) is lower than the training cell's (13), so the
# saved output may come from an earlier run — re-run top to bottom to confirm.
print(len(ctd))

70
