# Demo Notebook:
## Zero-shot evaluation the Competing Risk Survival Transformer For Causal Sequence Modelling.

Evaluating the pre-trained model on a cohort study for predicting Cardiovascular Disease in a Type 2 Diabetes Mellitus population.

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
from tqdm import tqdm
import pickle
from hydra import compose, initialize
from omegaconf import OmegaConf
from CPRD.examples.modelling.SurvStreamGPT.run_experiment import run
from CPRD.data.foundational_loader import FoundationalDataModule, convert_batch_to_none_causal
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling
from CPRD.examples.modelling.SurvStreamGPT.setup_zeroshot_experiment import setup_zeroshot_experiment, ZeroShotExperiment

import time
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import os
import polars as pl
pl.Config.set_tbl_rows(10000)
import pandas as pd
pd.options.display.max_rows = 10000

torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')

logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

 # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed
%env SLURM_NTASKS_PER_NODE=28   

Using device: cuda.
env: SLURM_NTASKS_PER_NODE=28


## Load configurations
Modifying as necesssary for zero-shot application. 

Here we choose to load in the configuration for a small **pre-trained** 11M parameter model, named "CR_11M". We specfiy the zero-shot experiment type, which will lead to running the ```ZeroShotExperiment```. 
We tell this experiment that no further training is needed. Additionally, we do choose to perform testing. As this is a causal model, this would not test the ability to predict the outcomes of interest, but to perform the causal modelling task on the new dataset. This also allows us to implement outcome predictions as a callback hook at the end of testing.

Instead, we want to test the pre-trained model's capacity to predict the relative risk of outcomes. Here, this is to test will check for the pre-trained model's capacity to predict ```COPD``` and ```SUBSTANCEMISUSE```. To do this we point the experiment to the ```FineTune_CVD``` dataset, and set the outcomes of interest. This is performed internally in ```setup_zeroshot_experiment``` through callbacks, in the next cell.

In [3]:
# load the configuration file, override any settings 
with initialize(version_base=None, config_path="../../confs", job_name="testing_notebook"):
    cfg = compose(config_name="config_CompetingRisk11M", 
                  overrides=["experiment.type='zero-shot'",
                             "experiment.run_id='CR_11M'",
                             "experiment.train=False",
                             "experiment.test=True",
                             'experiment.fine_tune_outcomes=["IHDINCLUDINGMI_OPTIMALV2", "ISCHAEMICSTROKE_V2", "MINFARCTION", "STROKEUNSPECIFIED_V2", "STROKE_HAEMRGIC"]',
                             "data.path_to_ds=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD_stillwithsingleeventpatients/",
                             "optim.limit_test_batches=null"
                            ]
                 )     

cfg.data.batch_size=1024

print(OmegaConf.to_yaml(cfg))

save_path = f"/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/{cfg.experiment.run_id}/"

is_decoder: true
data:
  batch_size: 1024
  unk_freq_threshold: 0.0
  min_workers: 20
  global_diagnoses: false
  path_to_db: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/
  path_to_ds: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD_stillwithsingleeventpatients/
  meta_information_path: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
experiment:
  type: zero-shot
  project_name: SurvStreamGPT_${head.SurvLayer}
  run_id: CR_11M
  train: false
  test: true
  verbose: true
  seed: 1337
  log: true
  log_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/
  ckpt_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/
  fine_tune_outcomes:
  - IHDINCLUDINGMI_OPTIMALV2
  - ISCHAEMICSTROKE_V2
  - MINFARCTION
  - STROKEUNSPECIFIED_V2
  - STROKE_HAEMRGIC
optim:
  num_epoch

# Load Experiment

In [17]:
model, dm = run(cfg)
print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")

INFO:root:Running cr on 72 CPUs and 1 GPUs
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD_stillwithsingleeventpatients/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD_stillwithsingleeventpatients/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD_stillwithsingleeventpatients/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Loaded /rds/projects/g

Testing: |          | 0/? [00:00<?, ?it/s]

                                		 Bad sample tokens: tensor([[ 95,   0,   0,   0,   0],
        [175,   0,   0,   0,   0]], device='cuda:0')
                                		 and corresponding ages tensor([[22467,     0,     0,     0,     0],
        [33030,     0,     0,     0,     0]], device='cuda:0')
                                		 Bad sample tokens: tensor([[106,   0,   0,   0,   0]], device='cuda:0')
                                		 and corresponding ages tensor([[24484,     0,     0,     0,     0]], device='cuda:0')
                                		 Bad sample tokens: tensor([[263,   0,   0,   0,   0]], device='cuda:0')
                                		 and corresponding ages tensor([[30467,     0,     0,     0,     0]], device='cuda:0')
                                		 Bad sample tokens: tensor([[249,   0,   0,   0,   0]], device='cuda:0')
                                		 and corresponding ages tensor([[29741,     0,     0,     0,     0]], device='cuda:0')


Loaded model with 11.433294 M parameters


In [None]:
import wandb
wandb.finish()

In [13]:
dm.encode(['Diastolic_blood_pressure_5.png'])
display(dm.decode([95, 175, 263,249]).split(" "))

['IHDINCLUDINGMI_OPTIMALV2',
 'Gabapentin_Oral_OPTIMAL',
 'Statins',
 'Thiazide_Diuretics_v2']

In [None]:
dm.tokenizer._event_counts["EVENT"][-5:].to_list()

# Load Pre-Trained model

In [None]:
ckpt_path = cfg.experiment.log_dir + f'checkpoints/{cfg.experiment.run_id}.ckpt'
model = SurvivalExperiment.load_from_checkpoint(ckpt_path)

# Initialise fine-tuning data module

In [4]:
# Update dataset path to point to the new dataset 
cfg.data.path_to_ds = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD_badoutcomelist/"

# Build 
dm = FoundationalDataModule(path_to_db=cfg.data.path_to_db,
                            path_to_ds=cfg.data.path_to_ds,
                            load=True,
                            tokenizer="tabular",
                            batch_size=cfg.data.batch_size,
                            max_seq_length=cfg.transformer.block_size,
                            freq_threshold=cfg.data.unk_freq_threshold,
                            min_workers=cfg.data.min_workers,
                            overwrite_meta_information=cfg.data.meta_information_path,
                           )

vocab_size = dm.train_set.tokenizer.vocab_size
print(f"{vocab_size} vocab elements")

# list of univariate measurements to model with Normal distribution
# Extract the measurements, using the fact that the diagnoses are all up upper case.
measurements_for_univariate_regression = [record for record in dm.tokenizer._event_counts["EVENT"] if record.upper() != record]
cfg.head.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression) 
# display(measurements_for_univariate_regression)

INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD_badoutcomelist/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD_badoutcomelist/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD_badoutcomelist/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD_badoutco

265 vocab elements


In [20]:
# import pickle as pkl
# import pathlib

pkl_file_to_amend = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_test.pickle"

with open(pkl_file_to_amend, 'rb') as pickle_file:
    content = pickle.load(pickle_file)
display(content)

# new_dictionary = {}
# for key in content.keys():
#     str_to_remove = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/split=val/"
#     new_key = str(key)[len(str_to_remove):]
#     new_dictionary[new_key] = content[key]
# display(new_dictionary)


# with open(pkl_file_to_amend, 'wb') as handle:
#     pickle.dump(new_dictionary, handle, protocol=pickle.HIGHEST_PROTOCOL)


{'COUNTRY=E/HEALTH_AUTH=London/PRACTICE_ID=20970/CHUNK=135/cbc791c7ef9b43cfaf4dbeea2624abb1-0.parquet': 69,
 'COUNTRY=E/HEALTH_AUTH=London/PRACTICE_ID=20970/CHUNK=132/cbc791c7ef9b43cfaf4dbeea2624abb1-0.parquet': 39,
 'COUNTRY=E/HEALTH_AUTH=London/PRACTICE_ID=20970/CHUNK=133/cbc791c7ef9b43cfaf4dbeea2624abb1-0.parquet': 250,
 'COUNTRY=E/HEALTH_AUTH=London/PRACTICE_ID=20970/CHUNK=134/cbc791c7ef9b43cfaf4dbeea2624abb1-0.parquet': 250,
 'COUNTRY=E/HEALTH_AUTH=London/PRACTICE_ID=21042/CHUNK=58/87a120eadabc4d95ae45ca9b8c79d017-0.parquet': 215,
 'COUNTRY=E/HEALTH_AUTH=London/PRACTICE_ID=21042/CHUNK=59/87a120eadabc4d95ae45ca9b8c79d017-0.parquet': 250,
 'COUNTRY=E/HEALTH_AUTH=London/PRACTICE_ID=21042/CHUNK=61/87a120eadabc4d95ae45ca9b8c79d017-0.parquet': 250,
 'COUNTRY=E/HEALTH_AUTH=London/PRACTICE_ID=21042/CHUNK=62/87a120eadabc4d95ae45ca9b8c79d017-0.parquet': 250,
 'COUNTRY=E/HEALTH_AUTH=London/PRACTICE_ID=21042/CHUNK=63/87a120eadabc4d95ae45ca9b8c79d017-0.parquet': 35,
 'COUNTRY=E/HEALTH_AUTH=Lon

In [None]:
# new_dictionary

In [5]:
start = time.time()   # starting time
for batch in dm.train_dataloader():
    break
print(f"batch loaded in {time.time()-start} seconds")    
    
# for key in batch.keys():
#     print(f"{key}".ljust(20) + f"{batch[key].shape}")

# tokens = batch["tokens"][0].tolist()    
# sentence = dm.decode(tokens).split(" ")
# for token, value in zip(sentence, batch["values"][0].tolist()):
#     print(f"{token}:".ljust(40) + f"{value}")

batch loaded in 104.21096539497375 seconds


In [16]:
display(batch.keys())
display(convert_batch_to_none_causal(batch).keys())

print(batch["static_covariates"].shape)

print(dm.train_set.static_1hot)
print(dm.train_set.static_1hot["SEX"].categories_)
print(dm.train_set.static_1hot["IMD"].categories_)
print(dm.train_set.static_1hot["ETHNICITY"].categories_)

dict_keys(['static_covariates', 'tokens', 'ages', 'values', 'attention_mask', 'target_token', 'target_age_delta', 'target_value'])



dict_keys(['static_covariates', 'tokens', 'ages', 'values', 'attention_mask', 'target_token', 'target_age_delta', 'target_value'])

torch.Size([1022, 16])
{'SEX': OneHotEncoder(), 'IMD': OneHotEncoder(), 'ETHNICITY': OneHotEncoder()}
[array(['F', 'I', 'M'], dtype=object)]
[array([ 1.,  2.,  3.,  4.,  5., nan])]
[array(['ASIAN', 'BLACK', 'MISSING', 'MIXED', 'OTHER', 'WHITE'],
      dtype=object)]


## View an example sample

In [9]:
dm.test_set.view_sample(11003, max_dynamic_events=None, report_time=True)

SEX
IMD
ETHNICITY
Time to retrieve sample index 11003 was 0.22292256355285645 seconds

SEX                 | M
IMD                 | 4.0
ETHNICITY           | WHITE
birth_year          | 1959.0
Sequence of 128 events

Token                                                                      | Age               | Standardised value
Pregabalin_Optimal                                                         | 21650             | nan               
Diastolic_blood_pressure_5                                                 | 21663             | -0.29             
International_normalised_ratio_82                                          | 21663             | -0.49             
Systolic_blood_pressure_4                                                  | 21663             | -0.12             
Anticonvulsants_OPTIMAL                                                    | 21670             | nan               
Pregabalin_Optimal                                                         | 21670    

# Custom wrapper prediction last token

To begin with, I will just loop over samples individually to test the zero-shot capacity of SurvivEHR. 

In [None]:


# Verifying on datamodule 
for _idx, batch in enumerate(dm.test_dataloader()):
    if _idx > 10:
        break
    print(_idx)
    print(torch.stack([batch["tokens"][10,:5], 
                       batch["values"][10,:5],  
                       batch["ages"][10,:5],
                       batch["attention_mask"][10,:5]]))
    batch = replace_last_non_pad_with_pad(batch)
    print(torch.stack([batch["tokens"][10,:5], 
                       batch["values"][10,:5],  
                       batch["ages"][10,:5],
                       batch["attention_mask"][10,:5]]))

In [None]:
outcome_of_interest = ["COPD", "SUBSTANCEMISUSE"]
outcome_token = dm.encode(outcome_of_interest)[0]
print(outcome_token)
# print(model(batch))

In [None]:
Hs, labels = [], []
mins,maxes=[],[]
for _idx, batch in enumerate(dm.test_dataloader()):

    batch = replace_last_non_pad_with_pad(batch)
    print(batch["tokens".shape)
    outputs, _, hidden_states = model(batch, is_generation=True)
    print(outputs)
    
    hidden_states = hidden_states.cpu().detach().numpy()                           # (64, 128, 384) 
    Hs.append( hidden_states.reshape(hidden_states.shape[0], -1) )
    labels.append((batch["target_token"] == outcome_token).long().numpy())

    if _idx == 9:
        break



# Visualise hidden dimension labelled by target

In [None]:
import umap
from sklearn.preprocessing import StandardScaler

H = np.concatenate(Hs, 0)
lbl = np.concatenate(labels, 0)

H = StandardScaler().fit_transform(H)
reducer = umap.UMAP()
H_proj = reducer.fit_transform(H)

plt.close()
plt.scatter(H_proj[:,0], H_proj[:,1], c=lbl)
plt.savefig(save_path + f"zero_shot/hidden_umap.png")

In [None]:
print(outputs["surv"]["surv_CDF"][outcome_token].shape)

# The first two tokens in the vocab correspond to the PAD and UNK tokens. There is no CDF corresponding to the PAD token, so the indexing for surv_CDF begins as ["UNK", "ADDISONS_DISEASE", ...]
# print(dm.decode([0,1,2]))

outcomes = ["COPD", "SUBSTANCEMISUSE"]
outcome_tokens = dm.encode(outcomes)

# for outcome in outcomes:
    # observed_outcome_token = dm.encode([outcome])[0]
cdf = np.zeros_like(outputs["surv"]["surv_CDF"][0])
lbls = np.zeros_like(batch["target_token"])

for _outcome_token in outcome_tokens:
    cdf += outputs["surv"]["surv_CDF"][_outcome_token - 1] 
    lbls += (batch["target_token"] == _outcome_token).long().numpy()

plt.close()
cdf_true = cdf[lbls==1,:]
cdf_false = cdf[lbls==0,:]
for i in range(cdf_true.shape[0]):
    plt.plot(np.linspace(1,1826,1826), cdf_true[i,:], c="r", label="outcome occurred next" if i == 0 else None, alpha=1)
for i in range(cdf_false.shape[0]):
    plt.plot(np.linspace(1,1826,1826), cdf_false[i,:], c="k", label="outcome did not occur next" if i == 0 else None, alpha=0.3)

plt.legend(loc=2)
plt.xlabel("days")
plt.ylabel(f"P(t>T) - outcomes={','.join(outcomes)}")
plt.savefig(save_path + f"zero_shot/cdf_outcomes.png")

In [None]:
print(batch["target_token"].unique())
print(len(outputs["surv"]["surv_CDF"]))

In [None]:
dm.decode([2])

In [None]:
outputs["surv"]["surv_CDF"][observed_outcome_token - 1]