# Demo Notebook:
## SurvivEHR: Competing Risk Survival Transformer For Causal Sequence Modelling 

In this notebook we demonstrate how a pre-trained model can be used for generation

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

!pwd

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/SurvivEHR/notebooks/CompetingRisk/0_pretraining


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
from pycox.evaluation import EvalSurv
from tqdm import tqdm
from hydra import compose, initialize
from omegaconf import OmegaConf
from CPRD.examples.modelling.SurvivEHR.run_experiment import run
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

from FastEHR.dataloader import FoundationalDataModule

torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')

logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cuda.


# Demo Version of SurvStreamGPT

## Build configurations

In [3]:
# load the configuration file, override any settings 
with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
    cfg = compose(config_name="config_CompetingRisk11M", 
                  overrides=["experiment.run_id='SurvivEHR-cr-small-v1'",
                             "experiment.train=False",
                             "experiment.test=False",
                             "data.path_to_ds=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/",
                             "optim.limit_test_batches=null"
                            ]
                 )     

cfg.data.min_workers = 12

# Just load in pretrained model
cfg.experiment.log = False

print(OmegaConf.to_yaml(cfg))

save_path = f"/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/{cfg.experiment.run_id}/"

is_decoder: true
data:
  batch_size: 64
  unk_freq_threshold: 0.0
  min_workers: 12
  global_diagnoses: false
  repeating_events: true
  path_to_db: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/cprd.db
  path_to_ds: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/
  meta_information_path: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
  subsample_training: null
experiment:
  type: pre-train
  project_name: SurvivEHR
  run_id: SurvivEHR-cr-small-v1
  fine_tune_id: null
  notes: null
  tags: null
  train: false
  test: false
  verbose: true
  seed: 1337
  log: false
  log_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/
  ckpt_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/
  fine_tune_outcomes: null
fine_tuning:
  custom_stratification_method:
    _targe

In [4]:
 # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed
%env SLURM_NTASKS_PER_NODE=28      

model, dm_pretrain = run(cfg)     
print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")


INFO:root:Running cr on 72 CPUs and 1 GPUs


env: SLURM_NTASKS_PER_NODE=28


INFO:root:# Loading DataModule for dataset /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/. This will be loaded in causal form.
INFO:root:Creating unsupervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular 

Loaded model with 11.211302 M parameters


/rds/bear-apps/2022a/EL8-ice/software/PyTorch-Lightning/2.1.0-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [5]:
# Update dataset path to point to the new dataset 
cfg.data.path_to_ds = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/"

# Build 
dm = FoundationalDataModule(path_to_db=cfg.data.path_to_db,
                            path_to_ds=cfg.data.path_to_ds,
                            load=True,
                            tokenizer="tabular",
                            batch_size=cfg.data.batch_size,
                            max_seq_length=cfg.transformer.block_size,
                            freq_threshold=cfg.data.unk_freq_threshold,
                            min_workers=cfg.data.min_workers,
                            overwrite_meta_information=cfg.data.meta_information_path,
                           )

vocab_size = dm.train_set.tokenizer.vocab_size
print(f"{vocab_size} vocab elements")

# list of univariate measurements to model with Normal distribution
# Extract the measurements, using the fact that the diagnoses are all up upper case.
measurements_for_univariate_regression = [record for record in dm.tokenizer._event_counts["EVENT"] if record.upper() != record]
cfg.head.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression) 
# display(measurements_for_univariate_regression)

INFO:root:Creating unsupervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Set seed to 42
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/Fou

265 vocab elements


In [25]:
import pickle
with open("/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/practice_id_splits.pickle", 'rb') as f:
    splits = pickle.load(f)
type(splits["test"][0])

int

In [27]:
from FastEHR.database.collector import SQLiteDataCollector
collector = SQLiteDataCollector(cfg.data.path_to_db + "cprd.db")
collector.connect()
collector.disconnect()

Full example from the fine-tuning dataset. This includes all events up and including the index event, followed by the outcome or last seen event. This outcome/censoring event is not relevant to this notebook.

In [28]:
dm.train_set.view_sample(100, max_dynamic_events=None, report_time=True)

# for batch in dm.train_dataloader():
#     break
# print(batch.keys())


Time to retrieve sample index 100: 0.3054 seconds

SEX                 | F
IMD                 | 4.0
ETHNICITY           | WHITE
birth_year          | 1972.0
Sequence of 256 events

Token                                                                       | Age at event (days)         | Standardized value
Eosinophil_count_21                                                        | 12366                         | -0.46
Haematocrit_15                                                             | 12366                         | -0.13
Haemoglobin_estimation_9                                                   | 12366                         | -0.49
Lymphocyte_count_20                                                        | 12366                         | -0.49
Mean_corpusc_Hb_conc__MCHC__14                                             | 12366                         | -0.46
Mean_corpusc_haemoglobin_MCH__13                                           | 12366                         | -0.25
M

# Generation from real prompts

In [29]:
# Utility function

# define encoding functions (TODO: add this wrap to datamodule
encode_prompt = lambda prompt_list: torch.from_numpy(np.array(dm.encode(prompt_list)).reshape((1,-1))).to(device)
encode_value = lambda prompt_list, value_list: torch.tensor(np.array([dm.standardise(_cat, _val) for _cat, _val in zip(prompt_list, value_list) ]).reshape((1,-1)), dtype=torch.float32).to(device)
encode_age = lambda age_list: torch.tensor([365 * _age for _age in age_list], dtype=torch.int64).reshape((1,-1)).to(device)

def table(_tokens,_ages,_values):
    # print table rows 
    assert _tokens.shape[0] == 1
    assert _ages.shape[0] == 1
    assert _values.shape[0] == 1
    
    for _idx, (_cat, _age, _value) in enumerate(zip(dm.decode(_tokens[0].tolist()).split(" "), 
                                                    _ages[0, :], 
                                                    _values[0, :]
                                                    )
                                                ):
        _value = dm.unstandardise(_cat, _value)
        print(f"\t{_cat}".ljust(60) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({int(_age)} days)")    # with value {_value}


indexing_token_to_pivot_on = dm.encode(["TYPE2DIABETES"])[0]
stop_token =  dm.encode(["DEATH"])[0]

For a number of patients, split by a `document break`, 
1) take the context before the index event day and predict what will be the subsequent series of events
2) take the context up to and including the index event day, and predict what will be the subsequent series of events

We can then observe how the index event correlated with future time-series predictions

In [33]:
for idx_sample, sample in enumerate(dm.test_set):

    # Get the sample
    # sample = dm.test_set[_patient_idx]
    _index = (sample["tokens"] == indexing_token_to_pivot_on).nonzero(as_tuple=True)[0].item()

    # chunk by day
    _day_at_index = int(sample["ages"][_index])
    _index_pre = sum(sample["ages"] < _day_at_index)
    _index_inc = sum(sample["ages"] <= _day_at_index)
    
    for _phase, _split_at in enumerate([_index_pre, _index_inc]):

        if _phase == 0:
            print(f"\n\nBefore {dm.decode([indexing_token_to_pivot_on]).lower()} is seen in the medical history")
        else:
            print('\n------------------------------------ page break ------------------------------------')
            print(f"\n\nAfter the diagnosis of {dm.decode([indexing_token_to_pivot_on]).lower()} is then seen in the medical history")

        _covariates = sample["static_covariates"].reshape((1,-1))
        _tokens = sample["tokens"][:_split_at].reshape((1,-1))
        _ages = sample["ages"][:_split_at].reshape((1,-1))
        _values = sample["values"][:_split_at].reshape((1,-1))

        # Report the initial part of their historical context
        _dec_covariates = dm.train_set._decode_covariates(_covariates)
        print(f"\n\nMedical history of a \n\t" + \
                        f"{_dec_covariates['ETHNICITY'][0].lower()}, " + \
                        f"{'male' if _dec_covariates['SEX'][0] == 'M' else 'female'} patient, " + \
                        f"born in {int(_dec_covariates['birth_year'][0])}, " + \
                        f"with IMD (deprivation) level {int(_dec_covariates['IMD'][0])}. \n\n" 
              )
        table(_tokens, _ages * 1825, _values)

        
        # Predict the future and report
        new_tokens, new_ages, new_values = model.model.generate(_tokens.to(device), _ages.to(device), _values.to(device), _covariates.to(device), max_new_tokens=20, eos_token=stop_token)
        print(f"""\nSurvivEHR then predicts the next events to be:
               """)
        table(new_tokens[:, _tokens.shape[1]:].reshape((1,-1)), 
              new_ages[:, _tokens.shape[1]:].reshape((1,-1)) * 1825,
              new_values[:, _tokens.shape[1]:].reshape((1,-1))
             )

    print('\n----------------------------------------------------------------------------------------')
    print('------------------------------------ document break ------------------------------------')
    print('----------------------------------------------------------------------------------------')

    if idx_sample > 20:
        break





Before type2diabetes is seen in the medical history


Medical history of a 
	black, male patient, born in 1968, with IMD (deprivation) level 5. 


	SUBSTANCEMISUSE                                            nan            at age 26 (9605 days)

SurvivEHR then predicts the next events to be:
               
	Ex_smoker_84                                               158.05         at age 26 (9605 days)
	Ex_smoker_84                                               -10.70         at age 26 (9605 days)
	Diastolic_blood_pressure_5                                 71.52          at age 26 (9605 days)
	Systolic_blood_pressure_4                                  116.83         at age 26 (9605 days)
	Anxiolytics_mumpredict                                     nan            at age 26 (9605 days)
	Benzodiazepines                                            nan            at age 26 (9605 days)
	SSRIs_Optimal                                              nan            at age 26 (9605 days)
	SSRIs_Opti

IndexError: index -1 is out of bounds for dimension 1 with size 0

In [None]:
# dm.test_set.view_sample(10, max_dynamic_events=None, report_time=True)

# Appendix: model architectures

In [None]:
display(model)

In [None]:
!jupyter nbconvert --to html --no-input 2_generation_fine_tune.ipynb