# Demo Notebook:
## Random Survival Forest

In [3]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.


In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from hydra import compose, initialize
from omegaconf import OmegaConf

from FastEHR.dataloader import FoundationalDataModule

from pycox.datasets import support
from pycox.evaluation import EvalSurv
from scipy.integrate import trapz

from sklearn import set_config
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder
from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import OneHotEncoder
from sksurv.ensemble import RandomSurvivalForest

from CPRD.examples.modelling.benchmarks.DeSurv.make_desurv_loader import get_dataloaders
from CPRD.examples.modelling.SurvivEHR.custom_outcome_methods import custom_mm_outcomes


In [3]:
set_config(display="text")  # displays text representation of estimators

# Extract the indicies which relate to the diagnoses

In [4]:
with initialize(version_base=None, config_path="../../SurvivEHR/confs", job_name="desurv-mm-notebook"):
    cfg = compose(config_name="config_CompetingRisk11M")

dm = FoundationalDataModule(path_to_db=cfg.data.path_to_db,
                            path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/",
                            overwrite_meta_information=cfg.data.meta_information_path,
                            load=True)

# Get the indicies for the diagnoses used to stratify patient groups (under the SurvivEHR setup)
conditions = custom_mm_outcomes(dm)
encoded_conditions = dm.tokenizer.encode(conditions)                    # The indicies of the MM events in the xsectional dataset (not adjusted for UNK/PAD/static data)

# Get the number of baseline static variables (after one-hot encoding etc), and the vocab size excluding PAD and UNK tokens
num_cov = dm.train_set[0]["static_covariates"].shape[0]
num_context_tokens = dm.tokenizer._event_counts.shape[0] - 1            # Removing UNK token, which is not included in xsectional datasets

# Convert the `encoded_conditions` indicies to the equivalent in the xsectional dataset
encoded_conditions_xsec = [_ind + num_cov - 1 for _ind in encoded_conditions]
print(encoded_conditions_xsec)

[17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 31, 33, 34, 35, 36, 37, 38, 41, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 58, 59, 60, 61, 63, 64, 67, 71, 72, 73, 75, 77, 80, 82, 85, 89, 90, 93, 94, 96, 97, 98, 104, 106, 108, 109, 110, 112, 115, 119, 121, 129, 135, 138, 141, 144, 148, 149, 151]


# Load data

## Example/test dataloader usage

In [5]:
# dataset_train, dataset_val, dataset_test = get_dataloaders("Hypertension", False, sample_size=2999, seed=1)

dataset_train, dataset_val, dataset_test = get_dataloaders("MultiMorbidity50+", False, benchmark="sklearn_RSF", sample_size=20000, seed=1)

num_xsectional_in_dims = dataset_train[0].shape[-1]
print(num_xsectional_in_dims)
assert num_xsectional_in_dims == num_cov + num_context_tokens, f"{num_xsectional_in_dims} != {num_cov} + {num_context_tokens}"

Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/N=20000_seed1.pickle
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/all.pickle
279


In [6]:
# print(dataset_test[0].head())
# print(dataset_test[1][:3])

# Train model

In [7]:
sample_sizes = [20000]

# the time grid which we generate over
t_eval = np.linspace(0, 1, 1000) 
# the time grid which we calculate scores over
time_grid = np.linspace(start=0, stop=1 , num=300)

In [12]:
model_names, all_ctd, all_ibs, all_inbll = [], [], [], []
all_obs_RMST, all_pred_RMST = [], []

for sample_size in sample_sizes:

    seeds = [1,2,3,4,5]

    for seed in seeds:
        # Load dataset
        dataset_train, dataset_val, dataset_test = get_dataloaders("MultiMorbidity50+", False, benchmark="sklearn_RSF", sample_size=sample_size, seed=seed)

        # Create RSF model with default bootstrap values due to memory constraints
        model_name = f"RandomSurvivalForest-SR-MultiMorbidity50+-Ns{sample_size}-seed{seed}"
        rsf = RandomSurvivalForest(
            bootstrap=True,
            max_samples=1000,    
            random_state=seed,
            low_memory=False
        )
        print(f"\n\n{model_name}")

        # Train model
        print("Training")
        rsf.fit(dataset_train[0], dataset_train[1])
    
        # Test
        bsz = 2**8
        print(f"Evaluating performance by splitting {dataset_test[0].shape} test samples into batches of size {bsz}")
        ctd = []
        ibs = []
        inbll = []
        obs_RMST_by_number_of_preexisting_conditions = [[] for _ in range(len(encoded_conditions_xsec))]
        pred_RMST_by_number_of_preexisting_conditions = [[] for _ in range(len(encoded_conditions_xsec))]
        test_generator = range(0, dataset_test[0].shape[0], bsz)
        for batch_idx in tqdm(test_generator, total=(len(test_generator)), desc="Testing"):
        
            batch_dataset_test = (dataset_test[0][batch_idx:batch_idx + bsz], dataset_test[1][batch_idx:batch_idx + bsz])
            actual_bsz = batch_dataset_test[0].shape[0]
        
            # Predict survival functionfor batch
            surv = rsf.predict_survival_function(batch_dataset_test[0], return_array=True)
            
            # Find the indices in rsf.unique_times_ that are closest to values in t_eval, so we can evaluate the RSF if the same way as other benchmarks
            closest_indices = [np.abs(rsf.unique_times_ - v).argmin() for v in t_eval]
            surv_reduced = surv[:, closest_indices]

            lbls_test = np.zeros((actual_bsz,))
            t_test = np.zeros((actual_bsz,))
            for sample_idx in range(actual_bsz):
                lbls_test[sample_idx] = 1 if batch_dataset_test[1][sample_idx][0] == True else 0
                t_test[sample_idx] = batch_dataset_test[1][sample_idx][1]
                
            ###########################
            # Get RMST Survival times #
            ###########################                
            for sample in range(surv_reduced.shape[0]):
                x_test = batch_dataset_test[0].to_numpy()[sample, :]
                e_test = batch_dataset_test[1][sample][0]
                
                # Get the number of pre-existing conditions
                sample_stratification_label = np.sum(x_test[encoded_conditions_xsec] == 1)
                # Get the RMST predicted under the survival curve
                sample_predicted_rmst = trapz(surv_reduced[sample,:], t_eval)
                pred_RMST_by_number_of_preexisting_conditions[sample_stratification_label].append(sample_predicted_rmst)

                # Get the observed RMST - warning: this can never properly account for censoring
                if e_test:
                    # If outcome was observed
                    sample_approx_obs_rmst = np.min((1, t_test[sample]))
                else:
                    sample_approx_obs_rmst = 1
                obs_RMST_by_number_of_preexisting_conditions[sample_stratification_label].append(sample_approx_obs_rmst)

        
            ########################
            # Get survival metrics #
            ########################
            df_surv = pd.DataFrame(np.transpose(surv_reduced), index=t_eval)
            ev = EvalSurv(df_surv, t_test, lbls_test, censor_surv='km')             # Same treatment as in SurvivEHR
            # Log overall scores
            ctd.append(ev.concordance_td())
            ibs.append(ev.integrated_brier_score(time_grid))
            inbll.append(ev.integrated_nbll(time_grid))
        
            # print(f"Scores up to sample {batch_idx+bsz}:".ljust(50) + f"Ctd: {np.mean(ctd):.3f}. IBS: {np.mean(ibs):.4f}. INBLL: {np.mean(inbll):.3f}")
        
        ctd = np.mean(ctd)
        ibs = np.mean(ibs)
        inbll = np.mean(inbll)
        
        print(f"{model_name}:".ljust(20) + f"N={sample_size}.".ljust(15) + f"Ctd: {ctd}. IBS: {ibs}. INBLL: {inbll}")
        model_names.append(model_name)
        all_ctd.append(ctd)
        all_ibs.append(ibs)
        all_inbll.append(inbll)
        all_pred_RMST.append(pred_RMST_by_number_of_preexisting_conditions)
        all_obs_RMST.append(obs_RMST_by_number_of_preexisting_conditions)

        # print(f"\tRandom Survival Forest ({'CR' if competing_risk else 'SR'}):".ljust(20) + f"N={sample_size}.".ljust(15) + f"Ctd: {ctd}. IBS: {ibs}. INBLL: {inbll}")


Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/N=20000_seed1.pickle
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/all.pickle


RandomSurvivalForest-SR-MultiMorbidity50+-Ns20000-seed1
Training
Evaluating performance by splitting (107557, 279) test samples into batches of size 256


Testing: 100%|██████████| 421/421 [04:05<00:00,  1.72it/s]


RandomSurvivalForest-SR-MultiMorbidity50+-Ns20000-seed1:N=20000.       Ctd: 0.5834264896399975. IBS: 0.15459947151900016. INBLL: 0.46990139389813723
Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/N=20000_seed2.pickle
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/all.pickle


RandomSurvivalForest-SR-MultiMorbidity50+-Ns20000-seed2
Training
Evaluating performance by splitting (107557, 279) test samples into batches of size 256


Testing: 100%|██████████| 421/421 [04:34<00:00,  1.54it/s]


RandomSurvivalForest-SR-MultiMorbidity50+-Ns20000-seed2:N=20000.       Ctd: 0.5853854477013098. IBS: 0.15429969800844914. INBLL: 0.4690151566676098
Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/N=20000_seed3.pickle
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/all.pickle


RandomSurvivalForest-SR-MultiMorbidity50+-Ns20000-seed3
Training
Evaluating performance by splitting (107557, 279) test samples into batches of size 256


Testing: 100%|██████████| 421/421 [03:32<00:00,  1.98it/s]


RandomSurvivalForest-SR-MultiMorbidity50+-Ns20000-seed3:N=20000.       Ctd: 0.5851260399538806. IBS: 0.15412524528024846. INBLL: 0.46867952099052923
Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/N=20000_seed4.pickle
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/all.pickle


RandomSurvivalForest-SR-MultiMorbidity50+-Ns20000-seed4
Training
Evaluating performance by splitting (107557, 279) test samples into batches of size 256


Testing: 100%|██████████| 421/421 [04:28<00:00,  1.57it/s]


RandomSurvivalForest-SR-MultiMorbidity50+-Ns20000-seed4:N=20000.       Ctd: 0.5838205168142997. IBS: 0.1542414519560594. INBLL: 0.46888269152341966
Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/N=20000_seed5.pickle
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/all.pickle


RandomSurvivalForest-SR-MultiMorbidity50+-Ns20000-seed5
Training
Evaluating performance by splitting (107557, 279) test samples into batches of size 256


Testing: 100%|██████████| 421/421 [04:54<00:00,  1.43it/s]

RandomSurvivalForest-SR-MultiMorbidity50+-Ns20000-seed5:N=20000.       Ctd: 0.5840655791806779. IBS: 0.154206618938596. INBLL: 0.46882119779176207





In [13]:
mean_pred_RMST = [[np.mean(_i) if len(_i) > 0 else np.nan for _i in _pred_RMST] for _pred_RMST in all_pred_RMST]
mean_obs_RMST = [[np.mean(_i) if len(_i) > 0 else np.nan for _i in _obs_RMST] for _obs_RMST in all_obs_RMST]

num_pre_existing = np.arange(len(all_obs_RMST[0]))

plt.close()
for _mean_pred_RMST in mean_pred_RMST:
    plt.plot(num_pre_existing[:10], _mean_pred_RMST[:10], color='b')
    
plt.plot(num_pre_existing[:10], mean_obs_RMST[0][:10], color='k')   # these are evaluated on the `all` the test data - which is shared across subsampled datasets 
plt.xlabel("Number of pre-existing conditions")
plt.ylabel("Survival time")
plt.savefig("calibration_rsf.png")

In [14]:
print(f"\nModel names: \n\t {model_names}")
print(f"\nConcordance (time-dependent): \n\t {all_ctd}")
print(f"\nIntegrated Brier Score: \n\t {all_ibs}")
print(f"\nINBLL: \n\t {all_inbll}")
print(f"\nNaive observed RMST: \n\t {mean_obs_RMST[0][:10]}")
print(f"\nPredicted RMST:")
for _mean_pred_RMST in mean_pred_RMST:
    print(f"\t {_mean_pred_RMST[:10]}")


Model names: 
	 ['RandomSurvivalForest-SR-MultiMorbidity50+-Ns20000-seed1', 'RandomSurvivalForest-SR-MultiMorbidity50+-Ns20000-seed2', 'RandomSurvivalForest-SR-MultiMorbidity50+-Ns20000-seed3', 'RandomSurvivalForest-SR-MultiMorbidity50+-Ns20000-seed4', 'RandomSurvivalForest-SR-MultiMorbidity50+-Ns20000-seed5']

Concordance (time-dependent): 
	 [0.5834264896399975, 0.5853854477013098, 0.5851260399538806, 0.5838205168142997, 0.5840655791806779]

Integrated Brier Score: 
	 [0.15459947151900016, 0.15429969800844914, 0.15412524528024846, 0.1542414519560594, 0.154206618938596]

INBLL: 
	 [0.46990139389813723, 0.4690151566676098, 0.46867952099052923, 0.46888269152341966, 0.46882119779176207]

Naive observed RMST: 
	 [0.8224276887173136, 0.7990325691775028, 0.7849306412327886, 0.7729643687595084, 0.7592866692082013, 0.7325878823314543, 0.7207588323375635, 0.6876065383604424, 0.757309837544218, 0.6978854766258826]

Predicted RMST:
	 [0.798877246538506, 0.7751059552144429, 0.7619923814891338, 0

# Output across different setups

Hypertension Single Risk

In [5]:
from statistics import NormalDist

def confidence_interval(data, confidence=0.95):
  dist = NormalDist.from_samples(data)
  z = NormalDist().inv_cdf((1 + confidence) / 2.)
  h = dist.stdev * z / ((len(data) - 1) ** .5)
  return dist.mean - h, dist.mean + h, h

data = [0.5834264896399975, 0.5853854477013098, 0.5851260399538806, 0.5838205168142997, 0.5840655791806779]
print(np.mean(data))
print(confidence_interval(data))

data = [0.15459947151900016, 0.15429969800844914, 0.15412524528024846, 0.1542414519560594, 0.154206618938596]
print(np.mean(data))
print(confidence_interval(data))

data =  [0.46990139389813723, 0.4690151566676098, 0.46867952099052923, 0.46888269152341966, 0.46882119779176207]
print(np.mean(data))
print(confidence_interval(data))

0.5843648146580331
(0.5835322030735344, 0.5851974262425319, 0.0008326115844988198)
0.15429449714047064
(0.15411632696594102, 0.15447266731500026, 0.00017817017452961926)
0.4690599921742916
(0.46858412446163566, 0.46953585988694746, 0.0004758677126559015)
