# Demo Notebook:
## SurvivEHR: Competing Risk Survival Transformer For Causal Sequence Modelling 

In this notebook we demonstrate how a pre-trained SurvivEHR model can be used to generate future events from a patient's medical history.

In [None]:
import os
from pathlib import Path
import sys
# Make the node-specific virtual environment importable from this kernel.
# The environment directory is suffixed with the node type read from the
# BB_CPU environment variable (the suffix is 'None' when it is unset).
node_type = os.environ.get('BB_CPU')
venv_dir = '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-' + str(node_type)

# site-packages lives under lib/python<major>.<minor>/ for the running interpreter
_python_tag = 'python{}.{}'.format(sys.version_info.major, sys.version_info.minor)
venv_site_pkgs = Path(venv_dir).joinpath('lib', _python_tag, 'site-packages')

if not venv_site_pkgs.exists():
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")
else:
    # Prepend so these packages take priority over anything already on the path
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")

# Shell escape: print the kernel's current working directory.
!pwd

# IPython autoreload extension in mode 2, so that edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
from pycox.evaluation import EvalSurv
from tqdm import tqdm
from hydra import compose, initialize
from omegaconf import OmegaConf
from CPRD.examples.modelling.SurvStreamGPT.run_experiment import run
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

# Send INFO-level log records from the project code to the notebook output
logging.basicConfig(level=logging.INFO)

# Fix the PyTorch RNG seed so that generation is repeatable between runs,
# and select the 'medium' float32 matmul precision setting.
torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')

# Prefer the GPU whenever one is visible to PyTorch
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

# Demo Version of SurvStreamGPT

## Build configurations

In [None]:
# Settings that differ from the stored Hydra config for this demo:
# no training or testing is requested, and the dataset path points at the
# fine-tuning (CVD) dataset.
config_overrides = [
    "experiment.type='zero-shot'",
    "experiment.run_id='CR_11M'",
    "experiment.train=False",
    "experiment.test=False",
    "data.path_to_ds=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/",
    "optim.limit_test_batches=null",
]

# Load the configuration file and apply the overrides
with initialize(version_base=None, config_path="../../confs", job_name="testing_notebook"):
    cfg = compose(config_name="config_CompetingRisk11M", overrides=config_overrides)

# Adjustments applied after composition.
# Just load in pretrained model, so experiment logging is switched off.
cfg.experiment.log = False
cfg.data.min_workers = 12

print(OmegaConf.to_yaml(cfg))

# Checkpoint directory for this run id
save_path = (
    "/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/"
    f"{cfg.experiment.run_id}/"
)

In [1]:
 # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed
# NOTE: `%env` is an IPython line magic, so this cell only runs inside a notebook kernel.
%env SLURM_NTASKS_PER_NODE=28      

# `run` is imported in the imports cell near the top of the notebook. Executing this
# cell before that one raises "NameError: name 'run' is not defined", which is what
# the saved output of this cell shows.
# The config cell overrides experiment.train and experiment.test to False, so this call
# is presumably only building/loading the pre-trained model — confirm against `run`.
model, dm_pretrain = run(cfg)     
# Total number of parameters, reported in millions
print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")


env: SLURM_NTASKS_PER_NODE=28


NameError: name 'run' is not defined

In [4]:
# Point the configuration at the fine-tuning dataset before building the datamodule
cfg.data.path_to_ds = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/"

# Build the datamodule from the (updated) configuration
dm = FoundationalDataModule(
    path_to_db=cfg.data.path_to_db,
    path_to_ds=cfg.data.path_to_ds,
    load=True,
    tokenizer="tabular",
    batch_size=cfg.data.batch_size,
    max_seq_length=cfg.transformer.block_size,
    freq_threshold=cfg.data.unk_freq_threshold,
    min_workers=cfg.data.min_workers,
    overwrite_meta_information=cfg.data.meta_information_path,
)

vocab_size = dm.train_set.tokenizer.vocab_size
print(f"{vocab_size} vocab elements")

# List of univariate measurements to model with a Normal distribution.
# Diagnosis names are written entirely in upper case, so every event name that
# is not fully upper case is taken to be a measurement.
measurements_for_univariate_regression = [
    event_name
    for event_name in dm.tokenizer._event_counts["EVENT"]
    if event_name != event_name.upper()
]
cfg.head.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression)
# display(measurements_for_univariate_regression)

INFO:root:Creating unsupervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_C

265 vocab elements


In [15]:
import pickle

# Practice-ID split file stored alongside the pre-training dataset
practice_splits_path = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/practice_id_splits.pickle"

# NOTE: pickle.load can execute arbitrary code; only use it on trusted files.
with open(practice_splits_path, mode='rb') as splits_file:
    splits = pickle.load(splits_file)

# Last expression of the cell: shows the type of the first entry in the "test" split
type(splits["test"][0])

int

In [10]:
from CPRD.data.dataset.collector import SQLiteDataCollector
# Build a collector for the "cprd.db" file under the configured database path.
# Plain string concatenation: cfg.data.path_to_db must already end with a path separator.
collector = SQLiteDataCollector(cfg.data.path_to_db + "cprd.db")
# Connect and disconnect straight away; nothing is queried in between, so this
# presumably serves only as a check that the database can be reached.
collector.connect()
collector.disconnect()

A full example from the fine-tuning dataset. It includes all events up to and including the index event, followed by the outcome or the last seen (censoring) event. This outcome/censoring event is not relevant to this notebook.

In [6]:
# Print training sample 100: its static covariates followed by its event table.
# max_dynamic_events=None presumably lifts the limit on the number of events shown,
# and report_time=True adds the retrieval time seen in the output — confirm against view_sample.
dm.train_set.view_sample(100, max_dynamic_events=None, report_time=True)

# Alternative (disabled): pull a single batch from the training dataloader and list its keys.
# for batch in dm.train_dataloader():
#     break
# print(batch.keys())


Time to retrieve sample index 100 was 0.23813700675964355 seconds

SEX                 | F
IMD                 | 4.0
ETHNICITY           | WHITE
birth_year          | 1972.0
Sequence of 128 events

Token                                                                      | Age               | Standardised value
Red_blood_cell__RBC__count_10                                              | 13971             | -0.49             
Serum_HDL_cholesterol_level_100                                            | 13971             | -0.18             
Serum_LDL_cholesterol_level_102                                            | 13971             | -0.39             
Serum_TSH_level_71                                                         | 13971             | -0.50             
Serum_albumin_51                                                           | 13971             | -0.29             
Serum_alkaline_phosphatase_50                                              | 13971             | -0.50    

# Generation from real prompts

In [7]:
# Utility function

# Encoding helpers for hand-written prompts (TODO: add these wrappers to the datamodule).
# Each one returns a tensor with a leading batch dimension of 1, placed on `device`.
def encode_prompt(prompt_list):
    """Encode a list of event names with the datamodule into a (1, n) token tensor."""
    token_ids = np.array(dm.encode(prompt_list)).reshape((1,-1))
    return torch.from_numpy(token_ids).to(device)


def encode_value(prompt_list, value_list):
    """Standardise each value for its paired event name; return a (1, n) float32 tensor."""
    standardised = [dm.standardise(_cat, _val) for _cat, _val in zip(prompt_list, value_list)]
    return torch.tensor(np.array(standardised).reshape((1,-1)), dtype=torch.float32).to(device)


def encode_age(age_list):
    """Multiply each age by 365 (presumably years to days); return a (1, n) int64 tensor."""
    ages_scaled = [365 * _age for _age in age_list]
    return torch.tensor(ages_scaled, dtype=torch.int64).reshape((1,-1)).to(device)

def table(_tokens,_ages,_values):
    """Print one formatted row per event of a single (batch size 1) sequence.

    Each row shows the decoded event name, its un-standardised value and the
    age at the event, both as `age / 365` and as the raw number of days.

    Args:
        _tokens: token ids with a leading batch dimension of exactly 1.
        _ages: ages (in days) aligned with `_tokens`, same leading dimension.
        _values: standardised values aligned with `_tokens`, same leading dimension.

    Raises:
        AssertionError: if any input has a leading dimension other than 1.
    """
    # Only a single sequence can be tabulated at a time
    for _batched in (_tokens, _ages, _values):
        assert _batched.shape[0] == 1

    # The datamodule decodes to one space-separated string; split it back into names
    event_names = dm.decode(_tokens[0].tolist()).split(" ")

    for event_name, age_days, std_value in zip(event_names, _ages[0, :], _values[0, :]):
        raw_value = dm.unstandardise(event_name, std_value)
        name_col = f"\t{event_name}".ljust(60)
        value_col = f"{raw_value:.02f}".ljust(15)
        age_col = f"at age {age_days/365:.0f} ({int(age_days)} days)"
        print(name_col + value_col + age_col)


# Token id of the index event: each patient's history is split around this diagnosis below.
indexing_token_to_pivot_on = dm.encode(["TYPE2DIABETES"])[0]
# Token id passed to model.generate as `eos_token` in the generation loop.
stop_token =  dm.encode(["DEATH"])[0]

For a number of patients, each separated in the output by a `document break`, we:
1) take the context before the index event day and predict what will be the subsequent series of events
2) take the context up to and including the index event day, and predict what will be the subsequent series of events

We can then observe how adding the index event to the context changes the predicted future sequence of events.

In [10]:
# Decoded, lower-cased name of the index event. It is the same for every patient,
# so it is decoded once instead of on every pass through the loops.
_index_event_name = dm.decode([indexing_token_to_pivot_on]).lower()

# Number of test patients to report. This matches the previous
# `if idx_sample > 20: break`, which stopped after the 22nd patient.
n_patients_to_report = 22

# zip(range(n), ...) ends the loop before an (n+1)-th sample is fetched, so the cap
# still holds when a patient is skipped with `continue`.
for idx_sample, sample in zip(range(n_patients_to_report), dm.test_set):

    # Locate the index event. A sample may not contain it at all, or may contain it
    # more than once; calling `.item()` on the raw match tensor would raise in both
    # cases, so skip the former and pivot on the first occurrence for the latter.
    _index_positions = (sample["tokens"] == indexing_token_to_pivot_on).nonzero(as_tuple=True)[0]
    if _index_positions.numel() == 0:
        print(f"\n\nSkipping test sample {idx_sample}: no {_index_event_name} event in the medical history")
        continue
    _index = _index_positions[0].item()

    # chunk by day: count the events strictly before, and up to and including, the
    # day of the index event. The counts are used as slice ends below, which assumes
    # events are ordered by age (as the original code did).
    _day_at_index = int(sample["ages"][_index])
    _index_pre = int((sample["ages"] < _day_at_index).sum())
    _index_inc = int((sample["ages"] <= _day_at_index).sum())

    # Static covariates are shared by both prompts for this patient
    _covariates = sample["static_covariates"].reshape((1,-1))

    for _phase, _split_at in enumerate([_index_pre, _index_inc]):

        if _phase == 0:
            print(f"\n\nBefore {_index_event_name} is seen in the medical history")
        else:
            print('\n------------------------------------ page break ------------------------------------')
            print(f"\n\nAfter the diagnosis of {_index_event_name} is then seen in the medical history")

        # Context given to the model: everything before the split point
        _tokens = sample["tokens"][:_split_at].reshape((1,-1))
        _ages = sample["ages"][:_split_at].reshape((1,-1))
        _values = sample["values"][:_split_at].reshape((1,-1))

        # Report the initial part of their historical context
        _dec_covariates = dm.train_set._decode_covariates(_covariates)
        print("\n\nMedical history of a \n\t"
              f"{_dec_covariates['ETHNICITY'][0].lower()}, "
              f"{'male' if _dec_covariates['SEX'][0] == 'M' else 'female'} patient, "
              f"born in {int(_dec_covariates['birth_year'][0])}, "
              f"with IMD (deprivation) level {int(_dec_covariates['IMD'][0])}. \n\n"
              )
        table(_tokens, _ages, _values)

        # Predict the future and report. Generation is inference only, so gradient
        # tracking is disabled to avoid building an autograd graph.
        with torch.no_grad():
            new_tokens, new_ages, new_values = model.generate(
                _tokens.to(device),
                _ages.to(device),
                _values.to(device),
                _covariates.to(device),
                max_new_tokens=20,
                eos_token=stop_token,
            )
        print("""\nSurvivEHR then predicts the next events to be:
               """)

        # The generated sequences start with the prompt; show only the new events
        _n_context = _tokens.shape[1]
        table(new_tokens[:, _n_context:].reshape((1,-1)),
              new_ages[:, _n_context:].reshape((1,-1)),
              new_values[:, _n_context:].reshape((1,-1))
             )

    print('\n----------------------------------------------------------------------------------------')
    print('------------------------------------ document break ------------------------------------')
    print('----------------------------------------------------------------------------------------')





Before type2diabetes is seen in the medical history


Medical history of a 
	black, male patient, born in 1968, with IMD (deprivation) level 5. 


	SUBSTANCEMISUSE                                            nan            at age 26 (9605 days)
	SCHIZOPHRENIAMM_V2                                         nan            at age 36 (12977 days)

SurvivEHR then predicts the next events to be:
               
	ACE_Inhibitors_D2T                                         nan            at age 36 (13138 days)
	Serum_albumin_51                                           36.11          at age 37 (13341 days)
	Tricyclic_Antidepressants_final                            nan            at age 37 (13372 days)
	Lipid_lowering_drugs_Optimal                               nan            at age 37 (13676 days)
	O_E___weight_2                                             92.05          at age 38 (13764 days)
	Diastolic_blood_pressure_5                                 71.18          at age 38 (13856 days)
	COC

In [None]:
# dm.test_set.view_sample(10, max_dynamic_events=None, report_time=True)

# Appendix: model architectures

In [None]:
# Render the model object returned by `run(cfg)`; the model-loading cell must have run first.
display(model)

In [None]:
# Shell escape: export the notebook to HTML with code inputs hidden (--no-input).
# The file name must match this notebook's name on disk.
!jupyter nbconvert --to html --no-input 2_generation_fine_tune.ipynb