# Demo Notebook:
## DeSurv

In [2]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import logging
from tqdm import tqdm
from hydra import compose, initialize
from omegaconf import OmegaConf

from pycox.datasets import support
from pycox.evaluation import EvalSurv
from scipy.integrate import trapz

from FastEHR.dataloader import FoundationalDataModule

from CPRD.src.modules.head_layers.survival.desurv import ODESurvSingle
from CPRD.src.modules.head_layers.survival.desurv import ODESurvMultiple
from CPRD.examples.modelling.benchmarks.make_method_loaders import get_dataloaders
from CPRD.examples.modelling.SurvivEHR.custom_outcome_methods import custom_mm_outcomes

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cpu.


# Extract the indicies which relate to the diagnoses

In [85]:
with initialize(version_base=None, config_path="../../SurvivEHR/confs", job_name="desurv-mm-notebook"):
    cfg = compose(config_name="config_CompetingRisk11M")

dm = FoundationalDataModule(path_to_db=cfg.data.path_to_db,
                            path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/",
                            overwrite_meta_information=cfg.data.meta_information_path,
                            load=True)

# Get the indicies for the diagnoses used to stratify patient groups (under the SurvivEHR setup)
conditions = custom_mm_outcomes(dm)
encoded_conditions = dm.tokenizer.encode(conditions)                    # The indicies of the MM events in the xsectional dataset (not adjusted for UNK/PAD/static data)

# Get the number of baseline static variables (after one-hot encoding etc), and the vocab size excluding PAD and UNK tokens
num_cov = dm.train_set[0]["static_covariates"].shape[0]
num_context_tokens = dm.tokenizer._event_counts.shape[0] - 1            # Removing UNK token, which is not included in xsectional datasets

# Convert the `encoded_conditions` indicies to the equivalent in the xsectional dataset
encoded_conditions_xsec = [_ind + num_cov - 1 for _ind in encoded_conditions]
print(encoded_conditions_xsec)

INFO:root:Creating unsupervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Set seed to 42
INFO:root:Loaded /rds/projects/g/gokhal

[17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 31, 33, 34, 35, 36, 37, 38, 41, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 58, 59, 60, 61, 63, 64, 67, 71, 72, 73, 75, 77, 80, 82, 85, 89, 90, 93, 94, 96, 97, 98, 104, 106, 108, 109, 110, 112, 115, 119, 121, 129, 135, 138, 141, 144, 148, 149, 151]


# Load data

## Example/test dataloader usage

In [63]:
data_loader_train, data_loader_val, data_loader_test = get_dataloaders("MultiMorbidity50+", True, sample_size=20000, seed=1)

for batch in data_loader_test:
    x_train = batch[0]
    num_xsectional_in_dims = x_train[0].shape[0]
    average_time_to_event = torch.mean(batch[1])
    targets = batch[2]
    break

print(num_xsectional_in_dims)
assert num_xsectional_in_dims == num_cov + num_context_tokens, f"{num_xsectional_in_dims} != {num_cov} + {num_context_tokens}"

Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/N=20000_seed1.pickle
(20000, 279)
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/all.pickle
279


Error check: The last event in the dm tokenizer index should align to the new last event index

In [65]:
largest_dm_tokenizer_ind = dm.tokenizer._event_counts.shape[0]
# print(dm.tokenizer.decode([264]))
assert largest_dm_tokenizer_ind + num_cov -1 == num_xsectional_in_dims

# Train model

In [67]:
dataset = "MultiMorbidity50+"
competing_risk = False
# sample_sizes = [int(np.exp(_log_n)) for _log_n in np.linspace(np.log(3000), np.log(500000), 10)]      # [3000, 12500, 30000, 60000, 100000]: # 600, 1200, 
sample_sizes = [20000] #283203, 2999, 5296, 9351, 16509, 29148, 51461, 90856, 160407, 

lr = 1e-3
xdim = x_train.shape[1]

# the time grid which we generate over
t_eval = np.linspace(0, 1, 1000) 
# the time grid which we calculate scores over
time_grid = np.linspace(start=0, stop=1 , num=300)

In [118]:
model_names, all_ctd, all_ibs, all_inbll = [], [], [], []
all_obs_RMST, all_pred_RMST = [], []

for sample_size in sample_sizes:

    seeds = [1,2,3,4,5]

    for seed in seeds:
        # Load dataset
        data_loader_train, data_loader_val, data_loader_test = get_dataloaders(dataset, competing_risk, sample_size=sample_size, seed=seed)
    
        # Initialise model
        model_name = f"DeSurv-{dataset}-Ns{sample_size}-seed{seed}"
        if competing_risk is False:
            model = ODESurvSingle(xdim, [32, 32], device=device)
        else:
            model = ODESurvMultiple(xdim, [32, 32], num_risks=5)
        print(f"\n\n{model_name} with {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")

        # Load or train model
        torch.manual_seed(seed)
        try:
            state_dict = torch.load("outputs/MM/" + model_name + "_tst_model")
            model.load_state_dict(state_dict)
            print(f"Loaded previously trained model")
        except:
            print(f"Training")
            model.optimize(data_loader_train, n_epochs=20, logging_freq=1, data_loader_val=data_loader_val, max_wait=2)
            print("finished training")
            torch.save(model.state_dict(), "outputs/MM/" + model_name + "_tst_model")
            model.eval()
    
            # state_dict = torch.load("outputs/MM/" + model_name + "_tst_model")
            # model.load_state_dict(state_dict)
           
        
        # argsortttest = np.argsort(t_test)
        # t_test = t_test[argsortttest]
        # e_test = e_test[argsortttest]
        # x_test = x_test[argsortttest,:]
        
        print(f"Testing")    
        model.eval()
        with torch.no_grad():
    
            ctd = []
            ibs = []
            inbll = []
            obs_RMST_by_number_of_preexisting_conditions = [[] for _ in range(len(encoded_conditions_xsec))]
            pred_RMST_by_number_of_preexisting_conditions = [[] for _ in range(len(encoded_conditions_xsec))]
            for batch in tqdm(data_loader_test, total=(len(data_loader_test)), desc="Testing"):
    
                x_test = batch[0].numpy()
                t_test = batch[1].numpy()
                e_test = batch[2].numpy()
    
                # The normalised grid over which to predict
                t_test_grid = torch.tensor(np.concatenate([t_eval] * x_test.shape[0], 0), dtype=torch.float32)
                x_test_grid = torch.tensor(x_test, dtype=torch.float32).repeat_interleave(t_eval.size, 0)
                
                pred_bsz = 51200
                pred = []
                for x_test_batched, t_test_batched in zip(torch.split(x_test_grid, pred_bsz), torch.split(t_test_grid, pred_bsz)):
                    
                    if competing_risk is False:
                        pred_ = model.predict(x_test_batched, t_test_batched)          # shape: (x_test.batched.shape[0],)
                    else:
                        pred_, pi_  = model.predict(x_test_batched, t_test_batched)    # shape: (x_test.batched.shape[0], num_outcomes)
                    pred.append(pred_)
                        
                pred = torch.concat(pred)
            
                pred = pred.reshape((x_test.shape[0], t_eval.size, -1)).cpu().detach().numpy()
                preds = [pred[:, :, _i] for _i in range(pred.shape[-1])]
                # print([_.shape for _ in preds])
        
                # Merge (additively) each outcome risk curve into a single CDF, and update label for if outcome occurred or not
                cdf = np.zeros_like(preds[0])
                lbls = np.zeros_like(e_test)     
                for _outcome_token in np.unique(e_test)[1:]:
                    # print(f"{_outcome_token} of {np.unique(e_test)[1:]} included from {len(preds)} surv CDFs")
                    # print(_outcome_token)
                    cdf += preds[_outcome_token - 1] 
                    lbls += (e_test == _outcome_token)

                ###########################
                # Get RMST Survival times #
                ###########################                
                surv = 1 - cdf
                for sample in range(surv.shape[0]):
                    # Get the number of pre-existing conditions
                    sample_stratification_label = np.sum(x_test[sample][encoded_conditions_xsec] == 1)
                    # Get the RMST predicted under the survival curve
                    sample_predicted_rmst = trapz(surv[sample,:], t_eval)
                    pred_RMST_by_number_of_preexisting_conditions[sample_stratification_label].append(sample_predicted_rmst)
                    
                    if e_test[sample] != 0:
                        # Get the observed RMST - warning: this is IGNORING CENSORING
                        obs_RMST_by_number_of_preexisting_conditions[sample_stratification_label].append(t_test[sample])

                ########################
                # Get survival metrics #
                ########################
                surv = pd.DataFrame(np.transpose((1 - cdf.reshape((x_test.shape[0],t_eval.size)))), index=t_eval)                
                ev = EvalSurv(surv, t_test, lbls, censor_surv='km')         # Evaluate surv curve with unscaled index with unscaled test times to event 
                # Log overall scores
                ctd.append(ev.concordance_td())
                ibs.append(ev.integrated_brier_score(time_grid))
                inbll.append(ev.integrated_nbll(time_grid))
            
            ctd = np.mean(ctd)
            ibs = np.mean(ibs)
            inbll = np.mean(inbll)
    
            print(f"{model_name}:".ljust(20) + f"N={sample_size}.".ljust(15) + f"Ctd: {ctd}. IBS: {ibs}. INBLL: {inbll}")

        model_names.append(model_name)
        all_ctd.append(ctd)
        all_ibs.append(ibs)
        all_inbll.append(inbll)
        all_pred_RMST.append(pred_RMST_by_number_of_preexisting_conditions)
        all_obs_RMST.append(obs_RMST_by_number_of_preexisting_conditions)


Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/N=20000_seed1.pickle
(20000, 279)
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/all.pickle


DeSurv-MultiMorbidity50+-Ns20000-seed1 with 10081 parameters
Loaded previously trained model
Testing


Testing: 100%|██████████| 421/421 [04:53<00:00,  1.43it/s]


DeSurv-MultiMorbidity50+-Ns20000-seed1:N=20000.       Ctd: 0.6011444192666519. IBS: 0.1533536157774033. INBLL: 0.4668451181069194
Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/N=20000_seed2.pickle
(20000, 279)
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/all.pickle


DeSurv-MultiMorbidity50+-Ns20000-seed2 with 10081 parameters
Loaded previously trained model
Testing


Testing: 100%|██████████| 421/421 [05:20<00:00,  1.31it/s]


DeSurv-MultiMorbidity50+-Ns20000-seed2:N=20000.       Ctd: 0.6013657435469636. IBS: 0.1534244307392567. INBLL: 0.467045335486115
Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/N=20000_seed3.pickle
(20000, 279)
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/all.pickle


DeSurv-MultiMorbidity50+-Ns20000-seed3 with 10081 parameters
Loaded previously trained model
Testing


Testing: 100%|██████████| 421/421 [05:06<00:00,  1.37it/s]


DeSurv-MultiMorbidity50+-Ns20000-seed3:N=20000.       Ctd: 0.6016623977315547. IBS: 0.1529709001483522. INBLL: 0.4661057071524355
Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/N=20000_seed4.pickle
(20000, 279)
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/all.pickle


DeSurv-MultiMorbidity50+-Ns20000-seed4 with 10081 parameters
Loaded previously trained model
Testing


Testing: 100%|██████████| 421/421 [04:55<00:00,  1.43it/s]


DeSurv-MultiMorbidity50+-Ns20000-seed4:N=20000.       Ctd: 0.5997379291750893. IBS: 0.15314308465268128. INBLL: 0.46653789613422
Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/N=20000_seed5.pickle
(20000, 279)
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/all.pickle


DeSurv-MultiMorbidity50+-Ns20000-seed5 with 10081 parameters
Loaded previously trained model
Testing


Testing: 100%|██████████| 421/421 [05:08<00:00,  1.36it/s]

DeSurv-MultiMorbidity50+-Ns20000-seed5:N=20000.       Ctd: 0.6028073757006384. IBS: 0.15275661753237382. INBLL: 0.46536915735690737





In [135]:
mean_pred_RMST = [[np.mean(_i) if len(_i) > 0 else np.nan for _i in _pred_RMST] for _pred_RMST in all_pred_RMST]
mean_obs_RMST = [[np.mean(_i) if len(_i) > 0 else np.nan for _i in _obs_RMST] for _obs_RMST in all_obs_RMST]

num_pre_existing = np.arange(len(obs_RMST))

plt.close()
for _mean_pred_RMST in mean_pred_RMST:
    plt.plot(num_pre_existing[:10], _mean_pred_RMST[:10], color='b')
    
plt.plot(num_pre_existing[:10], mean_obs_RMST[0][:10], color='k')   # these are evaluated on the `all` the test data - which is shared across subsampled datasets 
plt.xlabel("Number of pre-existing conditions")
plt.ylabel("Survival time")
plt.savefig("calibration_desurv.png")

In [133]:
print(f"\nModel names: \n\t {model_names}")
print(f"\nConcordance (time-dependent): \n\t {all_ctd}")
print(f"\nIntegrated Brier Score: \n\t {all_ibs}")
print(f"\nINBLL: \n\t {all_inbll}")
print(f"\nNaive observed RMST: \n\t {mean_obs_RMST[0][:10]}")
print(f"\nPredicted RMST:")
for _mean_pred_RMST in mean_pred_RMST:
    print(f"\t {_mean_pred_RMST[:10]}")


Model names: 
	 ['DeSurv-MultiMorbidity50+-Ns20000-seed1', 'DeSurv-MultiMorbidity50+-Ns20000-seed2', 'DeSurv-MultiMorbidity50+-Ns20000-seed3', 'DeSurv-MultiMorbidity50+-Ns20000-seed4', 'DeSurv-MultiMorbidity50+-Ns20000-seed5']

Concordance (time-dependent): 
	 [0.6011444192666519, 0.6013657435469636, 0.6016623977315547, 0.5997379291750893, 0.6028073757006384]

Integrated Brier Score: 
	 [0.1533536157774033, 0.1534244307392567, 0.1529709001483522, 0.15314308465268128, 0.15275661753237382]

INBLL: 
	 [0.4668451181069194, 0.467045335486115, 0.4661057071524355, 0.46653789613422, 0.46536915735690737]

Naive observed RMST: 
	 [0.81378, 0.7986956, 0.8275756, 0.81665397, 0.7681157, 0.67035997, 0.6242169, 0.56990874, 0.5747824, 0.49749595]

Predicted RMST:
	 [0.798844688330255, 0.7708229796585668, 0.7570225176336992, 0.7470901657899797, 0.7330107328216403, 0.7157286434126142, 0.7040444357894772, 0.7021473203879012, 0.7107821175735707, 0.6752714445342004]
	 [0.7973564746561381, 0.77319405722141

# Output across different setups

Hypertension Single Risk

In [12]:
from statistics import NormalDist

def confidence_interval(data, confidence=0.95):
  dist = NormalDist.from_samples(data)
  z = NormalDist().inv_cdf((1 + confidence) / 2.)
  h = dist.stdev * z / ((len(data) - 1) ** .5)
  return dist.mean - h, dist.mean + h, h

data = [0.6011444192666519, 0.6013657435469636, 0.6016623977315547, 0.5997379291750893, 0.6028073757006384]
print(np.mean(data))
print(confidence_interval(data))

data = [0.1533536157774033, 0.1534244307392567, 0.1529709001483522, 0.15314308465268128, 0.15275661753237382]
print(np.mean(data))
print(confidence_interval(data))

data =  [0.4668451181069194, 0.467045335486115, 0.4661057071524355, 0.46653789613422, 0.46536915735690737]
print(np.mean(data))
print(confidence_interval(data))

0.6013435730841796
(0.6002630775641054, 0.6024240686042537, 0.0010804955200741837)
0.15312972977001346
(0.1528606396925774, 0.15339881984744952, 0.0002690900774360697)
0.4663806428473194
(0.4657267343591914, 0.46703455133544763, 0.0006539084881281537)
