# Demo Notebook:
## DeepHit



Note: requires additional modules
* tensorflow

In [2]:
import os
from pathlib import Path
import sys
# Make the node-specific virtual environment importable from this kernel.
# The environment directory is selected by the CPU type exported in $BB_CPU;
# if the variable is unset, `node_type` is None and the lookup below reports
# the path as missing.
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'

# site-packages lives under lib/python<major>.<minor>/ for the running interpreter.
python_dirname = 'python{}.{}'.format(*sys.version_info[:2])
venv_site_pkgs = Path(venv_dir).joinpath('lib', python_dirname, 'site-packages')

if not venv_site_pkgs.exists():
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")
else:
    # Prepend so packages from the virtual environment take precedence on import.
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.


In [11]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import logging
from tqdm import tqdm
from hydra import compose, initialize
from omegaconf import OmegaConf

# from pycox.evaluation import EvalSurv
# from scipy.integrate import trapz
# import math

from FastEHR.dataloader import FoundationalDataModule

from CPRD.examples.modelling.benchmarks.make_method_loaders import get_dataloaders
from CPRD.examples.modelling.SurvivEHR.custom_outcome_methods import custom_mm_outcomes
from CPRD.examples.modelling.benchmarks.DeepHit.train_deephit import run_experiment

# import torchtuples as tt # Some useful functions
# from pycox.models import DeepHitSingle

# Fix the torch RNG seed and surface INFO-level log records in the cell output.
torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)

# Use the GPU when torch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cpu.


# Extract the indices which relate to the diagnoses

In [9]:
# Compose the Hydra configuration used for the competing-risk experiments.
with initialize(
    version_base=None,
    config_path="../../SurvivEHR/confs",
    job_name="deephit-mm-notebook",
):
    cfg = compose(config_name="config_CompetingRisk11M")

# Build the data module: database and meta-information paths come from the
# composed config, while the dataset directory is given explicitly here.
finetune_ds_path = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/"
dm = FoundationalDataModule(
    path_to_db=cfg.data.path_to_db,
    path_to_ds=finetune_ds_path,
    overwrite_meta_information=cfg.data.meta_information_path,
    load=True,
)

# NOTE(review): `get_tokens_for_stratification` is not imported in any visible
# cell, so this line raises NameError on a fresh kernel unless the name is
# provided elsewhere — TODO add the missing import to the imports cell.
stratification_tokens = get_tokens_for_stratification(dm, custom_mm_outcomes)
print(stratification_tokens)

INFO:root:Creating unsupervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Set seed to 42
INFO:root:Loaded /rds/projects/g/gokhal

[17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 31, 33, 34, 35, 36, 37, 38, 41, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 58, 59, 60, 61, 63, 64, 67, 71, 72, 73, 75, 77, 80, 82, 85, 89, 90, 93, 94, 96, 97, 98, 104, 106, 108, 109, 110, 112, 115, 119, 121, 129, 135, 138, 141, 144, 148, 149, 151]


# Train model

In [18]:
# --- Experiment configuration -------------------------------------------------
# NOTE(review): num_nodes, batch_norm, dropout, epochs, batch_size and time_grid
# are defined here but are not passed to get_dataloaders / run_experiment in this
# cell — confirm whether run_experiment should receive them or uses its own defaults.
num_nodes = [32, 32]
batch_norm = True
dropout = 0.1
epochs = 100
batch_size = 256
bins = 200                                                   # Default of 10: did very poorly. Increasing, improving results.
seeds = [1,2,3,4,5]
# t_eval = np.linspace(0, 1, 1000)                           # the time grid which we generate over
time_grid = np.linspace(start=0, stop=1 , num=300)           # the time grid which we calculate scores over

# Single source of truth for the dataset and training-subsample size, so the
# data that is loaded and the label recorded in `model_names` cannot drift apart.
dataset_name = "MultiMorbidity50+"
sample_size = 20000

# --- Per-seed result accumulators (one entry appended per seed) ---------------
model_names = []
all_ctd = []
all_ibs = []
all_inbll = []
all_obs_RMST = []
all_pred_RMST = []

for seed in seeds:

    # Load the datasets and meta information for this seed's training subsample
    dataset_train, dataset_val, dataset_test, meta_information = get_dataloaders(
        dataset_name, False, "deephit", sample_size=sample_size, seed=seed, bins=bins
    )

    # Train and evaluate the DeepHit benchmark
    result_dict = run_experiment(dataset_train, dataset_val, dataset_test, meta_information,
                                 dm=dm, custom_outcomes_method=custom_mm_outcomes)

    print(result_dict)

    # Record this seed's metrics under a name that encodes dataset, sample size and seed
    model_names.append(f"DeepHit-SR-{dataset_name}-Ns{sample_size}-seed{seed}")
    all_ctd.append(result_dict["ctd"])
    all_ibs.append(result_dict["ibs"])
    all_inbll.append(result_dict["inbll"])
    all_pred_RMST.append(result_dict["pred_RMST"])
    all_obs_RMST.append(result_dict["approx_obs_RMST"])
    


Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/N=20000_seed1.pickle
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/benchmark_data/all.pickle
lr_finder best lr: 0.0001
setting to lr: 0.001
0:	[9s / 9s],		train_loss: 0.7848,	val_loss: 0.7548
1:	[7s / 16s],		train_loss: 0.7697,	val_loss: 0.7470
2:	[2s / 19s],		train_loss: 0.7576,	val_loss: 0.7425
3:	[7s / 27s],		train_loss: 0.7518,	val_loss: 0.7391
4:	[3s / 30s],		train_loss: 0.7492,	val_loss: 0.7353
5:	[2s / 32s],		train_loss: 0.7374,	val_loss: 0.7323
6:	[2s / 34s],		train_loss: 0.7356,	val_loss: 0.7313
7:	[2s / 36s],		train_loss: 0.7324,	val_loss: 0.7303
8:	[2s / 39s],		train_loss: 0.7287,	val_loss: 0.7298
9:	[3s / 42s],		train_loss: 0.7290,	val_loss: 0.7307
10:	[2s / 44s],		train_loss: 0.7239,	val_loss: 0.7302
11:	[2s / 46s],		train_loss:

In [21]:
# Calibration plot: predicted vs. observed RMST against the number of
# pre-existing conditions, one predicted curve per seed.
num_pre_existing = np.arange(len(all_obs_RMST[0]))
max_conditions = 10                         # only the first 10 x-axis positions are plotted
fig_path = "figs/calibration_deephit.png"

fig, ax = plt.subplots()
for pred_RMST_seed in all_pred_RMST:
    ax.plot(num_pre_existing[:max_conditions], pred_RMST_seed[:max_conditions], color='b')

# Observed values are evaluated on `all` the test data, which is shared across the
# subsampled datasets, so only the first seed's copy is drawn.
ax.plot(num_pre_existing[:max_conditions], all_obs_RMST[0][:max_conditions], color='k')
ax.set_xlabel("Number of pre-existing conditions")
ax.set_ylabel("Survival time")

# savefig does not create missing directories; make sure the target folder
# exists so the cell also works from a fresh checkout.
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
fig.savefig(fig_path)

In [22]:
import scipy.stats

def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean and a two-sided Student-t confidence interval.

    Args:
        data: Sequence of numeric observations.
        confidence: Confidence level of the interval (0.95 by default).

    Returns:
        Tuple ``(mean, lower, upper, half_width)`` where ``lower`` and ``upper``
        are ``mean -/+ half_width`` and ``half_width`` is the standard error of
        the mean scaled by the t critical value with ``n - 1`` degrees of freedom.
    """
    samples = np.array(data) * 1.0      # promote to floating point
    count = len(samples)

    centre = np.mean(samples)
    std_err = scipy.stats.sem(samples)

    # Two-sided critical value: upper-tail quantile at (1 + confidence) / 2.
    t_crit = scipy.stats.t.ppf((1 + confidence) / 2., count - 1)
    half_width = std_err * t_crit

    lower = centre - half_width
    upper = centre + half_width
    return centre, lower, upper, half_width
    
# Report (mean, lower, upper, half-width) across seeds for each recorded metric,
# in the order: ctd, ibs, inbll.
for metric_values in (all_ctd, all_ibs, all_inbll):
    print(mean_confidence_interval(metric_values))

(0.5851262186769184, 0.58067685718006, 0.5895755801737768, 0.004449361496858366)
(0.15952456146634558, 0.15873542210751018, 0.16031370082518098, 0.0007891393588354023)
(0.4817130012522169, 0.47973916052329973, 0.48368684198113404, 0.0019738407289171295)
