# Demo Notebook:
## SurvivEHR: Competing Risk Survival Transformer For Causal Sequence Modelling 

In this notebook we demonstrate how a pre-trained model can be used for generation

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

!pwd

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/SurvivEHR/notebooks/CompetingRisk/0_pretraining


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
from pycox.evaluation import EvalSurv
from tqdm import tqdm
from hydra import compose, initialize
from omegaconf import OmegaConf
from contextlib import redirect_stdout

from CPRD.examples.modelling.SurvivEHR.run_experiment import run
from CPRD.examples.modelling.SurvivEHR.setup_causal_experiment import CausalExperiment
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

from FastEHR.dataloader import FoundationalDataModule
from FastEHR.database.collector import SQLiteDataCollector

torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')

logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cuda.


# Demo Version of SurvStreamGPT

In [4]:
# pre_trained_model_ids = ['SurvivEHR-cr-small', 'SurvivEHR-cr-small-v1', 'SurvivEHR-cr', 'SurvivEHR-cr-v1', 'SurvivEHR-cr-v1-v1', 'SurvivEHR-cr-384', 'SurvivEHR-cr-384-v1', 'crPreTrain_small_1337',
#                         'SurvivEHR-cr-small-192', "SurvivEHR-cr-small-192-v1"]

pre_trained_model = "SurvivEHR-cr-small-debug7_exp1000-v1" # "SurvivEHR-cr-small-debug3_2_exp1000-v1-v1" #  "SurvivEHR-cr-small-debug3_2_Zero"  "SurvivEHR-cr-small-debug3_2_leadsmall" # 

# load the configuration file, override any settings 
with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
    cfg = compose(config_name="config_CompetingRisk11M", 
                  overrides=[# Experiment setup
                             f"experiment.run_id='{pre_trained_model}'",
                             "experiment.train=False",
                             "experiment.test=False",
                             "experiment.log=False",
                             # Dataloader
                             "data.batch_size=128",
                             "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
                             "data.min_workers=3",
                             # Model
                             # "transformer.n_embd=1024",
                             # "transformer.n_embd=384",
                             # "transformer.n_embd=192",
                             # "transformer.n_embd=2304",
                             # "transformer.block_size=512", 
                            ]
                 )     
 
%env SLURM_NTASKS_PER_NODE=28       # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed
experiment, dm = run(cfg)     
print(f"Loaded model with {sum(p.numel() for p in experiment.parameters())/1e6} M parameters")


INFO:root:Running cr on 72 CPUs and 1 GPUs
INFO:root:# Loading DataModule for dataset /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/. This will be loaded in causal form.
INFO:root:Creating unsupervised collator for DataModule


env: SLURM_NTASKS_PER_NODE=28       # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed


INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Set seed to 1337
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/split=train/ dataset, with 23,613,894 samp

Loaded model with 11.20919 M parameters


/rds/bear-apps/2022a/EL8-ice/software/PyTorch-Lightning/2.1.0-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [5]:
print(pre_trained_model)

os.makedirs(f"figs/generation/{pre_trained_model}/", exist_ok=True) 

samples_of_interest = [6,0,10,1]

for dataset, idx in zip([ "PreTrain", "FineTune_Hypertension", "FineTune_CVD", "FineTune_MultiMorbidity50+"],
                        samples_of_interest
                       ):
    
    with open(f'figs/generation/{pre_trained_model}/dataset{dataset}_patient{idx}.txt', 'w') as f:
        with redirect_stdout(f):
    
            dm = FoundationalDataModule(path_to_db=cfg.data.path_to_db,
                                        path_to_ds=f"/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/{dataset}/",
                                        overwrite_meta_information=cfg.data.meta_information_path,
                                        load=True,
                                        # max_seq_length=512,
                                        )
            
            # Get first sample from batch (shuffle is by default turned off in test data, so first sample should equate to dm.test_set[0] above - uncomment print to verify)
            # This is done on batches as the API is written for batches
            for batch in dm.test_dataloader():
                # Convert to supervised (as these are Clinical Prediction Model datasets) to strip final observation which did not occur consecutively
                if dataset != "PreTrain":
                    batch = dm.collate_fn.convert_to_supervised(batch, supervised_time_scale=1.0)
                break
        
            static = dm.test_set._decode_covariates(batch["static_covariates"][idx])
            print("STATIC INFORMATION")
            print("="*120)
            for key, item in static.items():
                print(f"\t{key}:".ljust(20) + f"{item[0]}")
            
            # Get the idx sample from batch
            ################################
            for key, value in batch.items():
                # Take only first element (generating over batches is not implemented)
                if len(batch[key].shape) > 1:
                    batch[key] = batch[key][[idx],:]
                # Move all to same device as model
                batch[key] = batch[key].to(device)
        
            mask = (batch["tokens"][0, :] != 0)
            batch["tokens"] = batch["tokens"][:, mask]
            batch["ages"] = batch["ages"][:, mask]
            batch["values"] = batch["values"][:, mask]

            # Get the unpadded sequence length
            true_seq_len = batch["tokens"].shape[1]

            # Generate forward
            tokens, ages, values, survs = experiment.model.generate(**batch, max_new_tokens=50)

            # 
            tokens = tokens[0, :]
            ages = ages[0, :]
            values = values[0, :]
        
            # Report
            tokens = dm.tokenizer.decode(tokens.tolist()).split(" ")
            diagnoses = []
            last_age = 0
            print("\n\nGiven patient context".upper())
            print("="*120)
            for idx_event, (token, age, value) in enumerate(zip(tokens, ages, values)):
        
                # Unscale age and bin to week fidelity
                age = int(age * dm.test_set.time_scale / 52) 
        
                # If new event create break
                if age != last_age:
                    print("\t" + "-"*117)
                print(f"\t{token.ljust(75)}| {str(age).ljust(30)}| {value:.2f}".ljust(20))
        
                
                if token.upper() == token:
                    diagnoses.append(token)
        
                if idx_event == true_seq_len - 1:
                    print("\n" + "="*120)
                    print("Diagnosis summary".upper())
                    print(f"{diagnoses}")
                    print("="*120)
                    print("\n")
                    print("Predicted future events".upper())
                    print("="*120)
        
                last_age = age
            


INFO:root:Creating unsupervised collator for DataModule


SurvivEHR-cr-small-debug7_exp1000-v1


INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Set seed to 42
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/split=train/ dataset, with 23,613,894 sample

In [8]:
batch["attention_mask"]

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')

In [40]:
# print(survs[0][0])
# print(l)
plt.plot(experiment.model.surv_layer.t_eval, survs[0][0][0,:])
plt.savefig("fig.png")

# Check against the database

# Generate forward

In [5]:
# path_to_directory = os.getcwd() + "/../data/"
PATH_TO_DB = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/cprd.db"

collector = SQLiteDataCollector(db_path=PATH_TO_DB)
collector.connect()

In [72]:
collector.cursor.execute("""SELECT name FROM sqlite_master WHERE type='table' LIMIT 3;""")   # 
results = collector.cursor.fetchall()
for result in results:
    print(result)

('static_table',)
('diagnosis_table',)
('measurement_25_Hydroxyvitamin_D2_level_92',)


SEX                 | F
IMD                 | 4.0
ETHNICITY           | WHITE
birth_year          | 1933.0

In [114]:
collector.cursor.execute("""SELECT * FROM static_table WHERE sex=='F' AND imd=='4' AND ethnicity=='WHITE' AND YEAR_OF_BIRTH LIKE '1933-%'""")   # 

patient_ids = []
results = collector.cursor.fetchall()
for result in results:
    # print(result)
    patient_ids.append(result[1])
patient_ids_str1 = ", ".join(str(pid) for pid in patient_ids)
print(f"{len(patient_ids)} patients with static match")

6884 patients with static match


In [121]:
# query = f"""
#             SELECT DISTINCT PATIENT_ID
#             FROM diagnosis_table
#             WHERE EVENT IN ('HYPERTENSION', 'ANY_DEAFNESS_HEARING_LOSS_V2', 'IHDINCLUDINGMI_OPTIMALV2, OSTEOARTHRITIS,TYPE2DIABETES') 
#                 AND patient_id IN ({patient_ids_str})
#             GROUP BY patient_id
#             HAVING COUNT(DISTINCT event) >= 5
#             """

query = f"""SELECT patient_id
            FROM diagnosis_table
            WHERE patient_id IN ({patient_ids_str1} )
            GROUP BY patient_id
            HAVING 
                 COUNT(
                      DISTINCT CASE WHEN event IN ('HYPERTENSION', 'ANY_DEAFNESS_HEARING_LOSS_V2', 'IHDINCLUDINGMI_OPTIMALV2', 'OSTEOARTHRITIS', 'TYPE2DIABETES')
                                    THEN event
                               END
                    ) = 5
            ORDER BY patient_id;
            """

patient_ids = []
collector.cursor.execute(query)   # measurement_ACE_Inhibitors_D2T
results = collector.cursor.fetchall()
for result in results:
    patient_ids.append(result[0])
patient_ids_str2 = ", ".join(str(pid) for pid in patient_ids)
print(f"{len(patient_ids)} patients with static match and all of these events")

114 patients with static match and all of these events


In [122]:
patient_ids = []
collector.cursor.execute(f"""SELECT * FROM diagnosis_table WHERE patient_id IN ({patient_ids_str2}) ORDER BY patient_id ASC, date ASC""")   # event=='ALCOHOLMISUSE_V2' AND date LIKE '2008-%' AND 
results = collector.cursor.fetchall()
for result in results:
    print(result)
    patient_ids.append(result[1])

# collector.cursor.execute("""SELECT * FROM diagnosis_table WHERE event=='HYPERTENSION' AND date LIKE '2003-%' LIMIT 10""")   # measurement_ACE_Inhibitors_D2T
# results = collector.cursor.fetchall()
# for result in results:
#     print(result)


(20368, 67978220368, 'ASTHMA_PUSHASTHMA', '1998-01-01')
(20368, 67978220368, 'PMRANDGCA', '2000-01-01')
(20368, 67978220368, 'OSTEOARTHRITIS', '2000-02-01')
(20368, 67978220368, 'ANY_DEAFNESS_HEARING_LOSS_V2', '2000-11-01')
(20368, 67978220368, 'COPD', '2001-01-01')
(20368, 67978220368, 'TYPE2DIABETES', '2002-01-11')
(20368, 67978220368, 'HYPERTENSION', '2002-02-28')
(20368, 67978220368, 'IHDINCLUDINGMI_OPTIMALV2', '2004-03-01')
(20368, 67978220368, 'AF', '2007-05-10')
(20368, 67978220368, 'CKDSTAGE3TO5', '2008-11-03')
(20368, 67978220368, 'VALVULARDISEASES_V2', '2009-08-20')
(20368, 67978220368, 'DEATH', '2010-10-09')
(20502, 73302920502, 'DEPRESSION', '2007-12-17')
(20502, 73302920502, 'PREVALENT_IBS_V2', '2008-09-08')
(20502, 73302920502, 'HYPERTENSION', '2008-10-13')
(20502, 73302920502, 'TYPE2DIABETES', '2008-10-13')
(20502, 73302920502, 'ANXIETY', '2008-10-13')
(20502, 73302920502, 'OSTEOARTHRITIS', '2009-02-27')
(20502, 73302920502, 'IHDINCLUDINGMI_OPTIMALV2', '2012-11-01')
(205

In [106]:
patient_ids_str = ", ".join(str(pid) for pid in patient_ids)
collector.cursor.execute(f"""SELECT * FROM measurement_Systolic_blood_pressure_4 WHERE patient_id == 2666145020970""")   # 5437879821203
results = collector.cursor.fetchall()
for result in results:
    print(result)


(20970, 2666145020970, 'Systolic_blood_pressure_4', 142.0, '2012-08-14')
(20970, 2666145020970, 'Systolic_blood_pressure_4', 130.0, '2016-07-12')
(20970, 2666145020970, 'Systolic_blood_pressure_4', 139.0, '2017-07-04')
(20970, 2666145020970, 'Systolic_blood_pressure_4', 169.0, '2018-02-26')
(20970, 2666145020970, 'Systolic_blood_pressure_4', 163.0, '2018-03-05')
(20970, 2666145020970, 'Systolic_blood_pressure_4', 174.0, '2018-03-05')
(20970, 2666145020970, 'Systolic_blood_pressure_4', 154.0, '2018-11-15')
(20970, 2666145020970, 'Systolic_blood_pressure_4', 168.0, '2018-11-22')
(20970, 2666145020970, 'Systolic_blood_pressure_4', 134.0, '2019-08-16')
(20970, 2666145020970, 'Systolic_blood_pressure_4', 142.0, '2021-12-02')


In [13]:
collector.cursor.execute("""SELECT * FROM static_table WHERE practice_id=='21573' AND patient_id=='6626432621573'""")   # 

results = collector.cursor.fetchall()
for result in results:
    print(result)


collector.cursor.execute("""SELECT * FROM measurement_Body_mass_index_3 WHERE practice_id=='21573' AND patient_id=='6626432621573'""")   # 

results = collector.cursor.fetchall()
for result in results:
    print(result)

(21573, 6626432621573, 'ASIAN', '2000-07-15', 'F', 'E', 4, 'North West', '2019-12-13', '2019-12-13', '2021-09-21')
(21573, 6626432621573, 'Body_mass_index_3', 17.6, '2018-12-13')


## Build configurations

In [4]:
# load the configuration file, override any settings 
with initialize(version_base=None, config_path="../../confs", job_name="testing_notebook"):
    cfg = compose(config_name="config_CompetingRisk11M", 
                  # overrides=[
                  #     ]
                 )

# Just load in pretrained model
cfg.experiment.train = False
cfg.experiment.test = False
cfg.experiment.log = False
cfg.experiment.run_id="CR_11M"



print(OmegaConf.to_yaml(cfg))

save_path = f"/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/{cfg.experiment.run_id}/"

is_decoder: true
data:
  batch_size: 64
  unk_freq_threshold: 0.0
  min_workers: 20
  global_diagnoses: false
  path_to_db: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/
  path_to_ds: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/
  meta_information_path: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
experiment:
  type: pre-train
  project_name: SurvStreamGPT_${head.SurvLayer}
  run_id: CR_11M
  train: false
  test: false
  verbose: true
  seed: 1337
  log: false
  log_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/
  ckpt_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/
  fine_tune_outcomes: None
optim:
  num_epochs: 1
  learning_rate: 0.0003
  scheduler: CAWarmRestarts
  scheduler_periods: 5000
  scheduler_warmup: true
  lr_cosine_decay_period: 1000000

In [5]:
 # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed
%env SLURM_NTASKS_PER_NODE=28      

model, dm = run(cfg)     
print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")


INFO:root:Running cr on 72 CPUs and 1 GPUs


env: SLURM_NTASKS_PER_NODE=28


INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/split=train/ dataset, with 23,613,894 samples
INFO:root:Loaded /rds/p

Loaded model with 11.433294 M parameters


/rds/bear-apps/2022a/EL8-ice/software/PyTorch-Lightning/2.1.0-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [6]:
dm.train_set.view_sample(10, max_dynamic_events=None, report_time=True)

# for batch in dm.train_dataloader():
#     break
# print(batch["tokens"].shape)


Time to retrieve sample index 10 was 0.09404301643371582 seconds

SEX                 | F
IMD                 | 1.0
ETHNICITY           | WHITE
birth_year          | 1998.0
Sequence of 58 events

Token                                                                      | Age               | Standardised value
Systemic_oral_corticosteroids_optimal                                      | 82                | nan               
ATOPICECZEMA                                                               | 601               | nan               
First_gen_H1_antihistamines                                                | 5691              | nan               
Current_smoker_83                                                          | 6035              | nan               
Diastolic_blood_pressure_5                                                 | 6035              | -0.38             
Systolic_blood_pressure_4                                                  | 6035              | -0.29      

# Generation from real prompts

In [7]:
# define encoding functions (TODO: add this wrap to datamodule
encode_prompt = lambda prompt_list: torch.from_numpy(np.array(dm.encode(prompt_list)).reshape((1,-1))).to(device)
encode_value = lambda prompt_list, value_list: torch.tensor(np.array([dm.standardise(_cat, _val) for _cat, _val in zip(prompt_list, value_list) ]).reshape((1,-1)), dtype=torch.float32).to(device)
encode_age = lambda age_list: torch.tensor([365 * _age for _age in age_list], dtype=torch.int64).reshape((1,-1)).to(device)

def table(_tokens,_ages,_values):
    # print table rows 
    assert _tokens.shape[0] == 1
    assert _ages.shape[0] == 1
    assert _values.shape[0] == 1
    
    for _idx, (_cat, _age, _value) in enumerate(zip(dm.decode(_tokens[0].tolist()).split(" "), 
                                                    _ages[0, :], 
                                                    _values[0, :]
                                                    )
                                                ):
        _value = dm.unstandardise(_cat, _value)
        print(f"\t{_cat}".ljust(60) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({int(_age)} days)")    # with value {_value}



In [29]:
dm.meta_information["diagnosis_table"]["event"].to_list()
dm.meta_information["measurement_tables"]["event"].to_list()

['25_Hydroxyvitamin_D2_level_92',
 '25_Hydroxyvitamin_D3_level_90',
 'AST___aspartate_transam_SGOT__46',
 'AST_serum_level_47',
 'Albumin___creatinine_ratio_37',
 'Basophil_count_22',
 'Blood_calcium_level_38',
 'Blood_urea_28',
 'Body_mass_index_3',
 'Brain_natriuretic_peptide_level_66',
 'Calcium_adjusted_level_41',
 'Calculated_LDL_cholesterol_level_103',
 'Combined_total_vitamin_D2_and_D3_level_93',
 'Corrected_serum_calcium_level_42',
 'Current_smoker_83',
 'Diastolic_blood_pressure_5',
 'Eosinophil_count_21',
 'Erythrocyte_sedimentation_rate_61',
 'Ex_smoker_84',
 'Free_T4_level_76',
 'GFR_calculated_abbreviated_MDRD_34',
 'Haematocrit___PCV_16',
 'Haematocrit_15',
 'Haemoglobin_A1c_level___IFCC_standardised_6',
 'Haemoglobin_A1c_level_8',
 'Haemoglobin_estimation_9',
 'HbA1c_level__DCCT_aligned__7',
 'INR___international_normalised_ratio_81',
 'International_normalised_ratio_82',
 'Lymphocyte_count_20',
 'Mean_corpusc_Hb_conc__MCHC__14',
 'Mean_corpusc_haemoglobin_MCH__13',
 'Me

## Brute force search, get some prompts from the test dataset which show some different criteria

In [8]:
# indexing_conditions_to_pivot_on = "TYPE2DIABETES"    # TYPE1DM, HYPERTENSION, OSTEOARTHRITIS, CKDSTAGE3TO5, HF_V3, ISCHAEMICSTROKE_V2, DEPRESSION
# exclude_on_events = ["Statins",
#                      "Metformin_612_A10BD2",
#                      "Lipid_lowering_drugs_Optimal"]

indexing_conditions_to_pivot_on = ["POLYCYSTIC_OVARIAN_SYNDROME_PCOS_V2",
                                   "COPD",
                                   # "ENDOMETRIOSIS_ADENOMYOSIS_V2"
                                  ]
exclude_on_events = []



In [9]:
indexing_token_to_pivot_on = dm.encode(indexing_conditions_to_pivot_on)
print(indexing_token_to_pivot_on)


tokens_to_exclude_on = dm.encode(exclude_on_events)
print(tokens_to_exclude_on)

patients_satisfying_criteria = []
samples_satisfying_criteria = []
example_count = 0

for _idx, sample in tqdm(enumerate(dm.test_set), total=len(dm.test_set)):

    number_of_index_events = sum([tkn for tkn in indexing_token_to_pivot_on if tkn in sample["tokens"]])
    
    if (len(sample["tokens"]) > 5) and (number_of_index_events==len(indexing_token_to_pivot_on)):

        # todo: this is excluded events at any time, change to before the index event
        number_of_excluded_events = sum([tkn for tkn in tokens_to_exclude_on if tkn in sample["tokens"]])

        if number_of_excluded_events == 0:
            patients_satisfying_criteria.append(_idx)
            samples_satisfying_criteria.append(sample)

            if example_count >= 4:
                break
            else:
                example_count += 1
                print(example_count)

    # elif _idx > 100000:
    #     break
    else:
        pass

[58, 83]
[]


  0%|          | 4458/1508320 [03:44<21:04:18, 19.82it/s]

KeyboardInterrupt



In [None]:
print(patients_satisfying_criteria)
# patients_satisfying_criteria = [724, 1760, 2055, 2099, 2167]

In [None]:
for _patient_idx in patients_satisfying_criteria:
    print(_patient_idx)

    # Get the sample
    sample = dm.test_set[_patient_idx]
    _index = (sample["tokens"] == indexing_token_to_pivot_on).nonzero(as_tuple=True)[0].item()

    # chunk by day
    _day_at_index = int(sample["ages"][_index])
    _index_pre = sum(sample["ages"] < _day_at_index)
    _index_inc = sum(sample["ages"] <= _day_at_index)
    
    for _phase, _split_at in enumerate([_index_pre, _index_inc]):

        if _phase == 0:
            print(f"\n\nBefore {dm.decode([_token]).lower()} is seen in the medical history")
        else:
            print('\n------------------------------------ page break ------------------------------------')
            print(f"\n\nAfter the diagnosis of {dm.decode([_token]).lower()} is then seen in the medical history")

        _covariates = sample["static_covariates"].reshape((1,-1))
        _tokens = sample["tokens"][:_split_at].reshape((1,-1))
        _ages = sample["ages"][:_split_at].reshape((1,-1))
        _values = sample["values"][:_split_at].reshape((1,-1))

        # Report the initial part of their historical context
        _dec_covariates = dm.train_set._decode_covariates(_covariates)
        print(f"\n\nMedical history of a \n\t" + \
                        f"{_dec_covariates['ETHNICITY'][0].lower()}, " + \
                        f"{'male' if _dec_covariates['SEX'][0] == 'M' else 'female'} patient, " + \
                        f"born in {int(_dec_covariates['birth_year'][0])}, " + \
                        f"with IMD (deprivation) level {int(_dec_covariates['IMD'][0])}. \n\n" 
              )
        table(_tokens, _ages, _values)

        
        # Predict the future and report
        new_tokens, new_ages, new_values = model.generate(_tokens.to(device), _ages.to(device), _values.to(device), _covariates.to(device), max_new_tokens=20)
        print(f"""\nSurvivEHR then predicts the next events to be:
               """)
        table(new_tokens[:, _tokens.shape[1]:].reshape((1,-1)), 
              new_ages[:, _tokens.shape[1]:].reshape((1,-1)),
              new_values[:, _tokens.shape[1]:].reshape((1,-1))
             )

    print('\n----------------------------------------------------------------------------------------')
    print('------------------------------------ document break ------------------------------------')
    print('----------------------------------------------------------------------------------------')



In [None]:
dm.test_set.view_sample(patients_satisfying_criteria[1], max_dynamic_events=None, report_time=True)

# Generation from fixed prompts

### Sampling from the model

In [33]:
model= model.to(device)

baseline_covariates = {"sex": "F", "deprivation": 5.0, "ethnicity": "WHITE", "year_of_birth": 1997-65}

multimorbidity_conditions = ["TYPE2DIABETES", "Metformin_612_A10BD2", "DEPRESSION", "IHDINCLUDINGMI_OPTIMALV2", "COPD"]
at_ages = [40, 40, 44, 49, 65]

prompt, ages_in_years, values = [], [], []

for condition, age in zip(multimorbidity_conditions, at_ages):
    # Default context start
    prompt.append(condition)
    ages_in_years.append(age)
    values.append(np.nan)

    # Convert for model
    covariates = dm.train_set._encode_covariates(**baseline_covariates).reshape(1,-1).to(device)
    tokens = encode_prompt(prompt)
    values_scaled = encode_value(prompt, values)
    ages_in_days = encode_age(ages_in_years)

    # generate: sample the next 10 tokens
    new_tokens, new_ages, new_values = model.generate(tokens, ages_in_days, values_scaled, covariates, max_new_tokens=10)
    
    # report:
    print(f"Baseline covariates: \n{baseline_covariates}\n" + "="*90)
    print(f"PROMPT:")
    for _idx, (_cat, _age, _value) in enumerate(zip(dm.decode(new_tokens[0].tolist()).split(" "), 
                                                    new_ages[0, :], 
                                                    new_values[0, :]
                                                   )
                                               ):
        _value = dm.unstandardise(_cat, _value)
        print(f"{_cat}".ljust(50) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({int(_age)} days)")    # with value {_value}
        if _idx == tokens.shape[-1] - 1:
            print("="*90)
            print(f"GENERATION")



Baseline covariates: 
{'sex': 'F', 'deprivation': 5.0, 'ethnicity': 'WHITE', 'year_of_birth': 1932}
PROMPT:
TYPE2DIABETES                                     nan            at age 40 (14600 days)
GENERATION
CalciumChannelBlck_D2T                            nan            at age 40 (14607 days)
Benzodiazepines                                   nan            at age 40 (14646 days)
Plasma_ferritin_level_62                          314956.41      at age 42 (15240 days)
OSTEOARTHRITIS                                    nan            at age 43 (15652 days)
ALLERGICRHINITISCONJ                              nan            at age 47 (17050 days)
Levothyroxine_                                    nan            at age 52 (18816 days)
MINFARCTION                                       nan            at age 53 (19523 days)
LYMPHOMA_PREVALENCE_V2                            nan            at age 57 (20804 days)
POP_reg_contraceptive                             nan            at age 61 (22243 days)
A

# Prompt testing

In [22]:
# generate: sample the next 10 tokens
new_tokens, new_ages, new_values = model.generate(tokens, ages_in_days, values_scaled, covariates, max_new_tokens=50)

# report:
print(f"Baseline covariates: \n{baseline_covariates}\n" + "="*90)
print(f"PROMPT:")
for _idx, (_cat, _age, _value) in enumerate(zip(dm.decode(new_tokens[0].tolist()).split(" "), 
                                                new_ages[0, :], 
                                                new_values[0, :]
                                               )
                                           ):
    # _value = dm.unstandardise(_cat, _value)
    print(f"{_cat}".ljust(50) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({int(_age)} days)")    # with value {_value}
    if _idx == tokens.shape[-1] - 1:
        print("="*90)
        print(f"GENERATION")



Baseline covariates: 
{'sex': 'F', 'deprivation': 5.0, 'ethnicity': 'WHITE', 'year_of_birth': 1932}
PROMPT:
TYPE2DIABETES                                     nan            at age 40 (14600 days)
GENERATION
Current_smoker_83                                 0.84           at age 41 (15049 days)
HF_V3                                             nan            at age 44 (15907 days)
All_Antiplatelets                                 nan            at age 45 (16568 days)
Total_25_hydroxyvitamin_D_level_91                -0.08          at age 48 (17521 days)
Albumin___creatinine_ratio_37                     0.80           at age 50 (18396 days)
LEUKAEMIA_PREVALENCEV2                            nan            at age 51 (18557 days)
OSTEOPOROSIS                                      nan            at age 54 (19683 days)
GFR_calculated_abbreviated_MDRD_34                -0.10          at age 54 (19786 days)
Serum_bilirubin_level_53                          1.38           at age 59 (21519 days)
T

## Diagnoses: How related conditions are impacted by each other - multi-morbidity

In [32]:
exp_prompts = [["TYPE2DIABETES", "Metformin_612_A10BD2"],
               ["TYPE2DIABETES", "Metformin_612_A10BD2", "DEPRESSION",],
               ["TYPE2DIABETES", "Metformin_612_A10BD2", "DEPRESSION", "IHDINCLUDINGMI_OPTIMALV2"],
               ["TYPE2DIABETES", "Metformin_612_A10BD2", "DEPRESSION", "IHDINCLUDINGMI_OPTIMALV2", "COPD"],
              ]
exp_promps_lbl = ["T2D+Metformin", "+ Depression", "+IHD/MI", "+COPD"]
exp_ages = [[40, 40],
            [40, 40, 44],
            [40, 40, 44, 49],
            [40, 40, 44, 49, 65],
           ]
exp_values = [[np.nan, np.nan],
              [np.nan, np.nan, np.nan],
              [np.nan, np.nan, np.nan, np.nan],
              [np.nan, np.nan, np.nan, np.nan, np.nan],
              ]

with torch.no_grad(): 
    model.eval()

    _exp_survs = []
    for p_idx, (_exp_prompt, _exp_age, _exp_value) in enumerate(zip(exp_prompts, 
                                                                    exp_ages, 
                                                                    exp_values)):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)
        
        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_generation=True,
                              return_loss=False,
                              return_generation=True)
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    for si, _ in enumerate(surv):
        plt.close()
        event_name = dm.decode([si + 1])
        for p_idx in range(len(exp_prompts)):
            plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{exp_promps_lbl[p_idx]}")
        plt.xlabel("Time (years)")
        plt.ylabel(f"$P(T>t)$ ({event_name})")
        plt.legend()
        plt.savefig(save_path + f"multimorbidity/{event_name}.png")


## Diagnoses: How related conditions are impacted by each other


In [23]:
exp_prompts = [["DEPRESSION"], ["TYPE1DM"], ["TYPE2DIABETES"], ["Never_smoked_tobacco_85"], ["Ex_smoker_84"]]
exp_ages = [[20] for _ in range(len(exp_prompts))]
exp_values = [[np.nan] for _ in range(len(exp_prompts))]

with torch.no_grad(): 
    model.eval()

    _exp_survs = []
    for p_idx, (_exp_prompt, _exp_age, _exp_value) in enumerate(zip(exp_prompts, 
                                                                    exp_ages, 
                                                                    exp_values)):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)
        
        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_generation=True,
                              return_loss=False,
                              return_generation=True)
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    for si, _ in enumerate(surv):
        plt.close()
        event_name = dm.decode([si + 1])
        for p_idx in range(len(exp_prompts)):
            plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{'->'.join(exp_prompts[p_idx]).lower()}")
        plt.xlabel("Time (years)")
        plt.ylabel(f"$P(T>t)$ ({event_name})")
        plt.legend()
        plt.savefig(save_path + f"diabetes/{event_name}.png")


TypeError: SurvStreamGPTForCausalModelling.forward() got an unexpected keyword argument 'is_causal'

## Values: How increasing BMI affects diagnosis risk

In [None]:
events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF_V3", "ISCHAEMICSTROKE_V2",
                      "DEATH"
                     ]

_exp_prompt = ["Body_mass_index_3"]
_exp_age = [40]
_exp_values = [[18.], [21.], [24.], [30.], [40.]]

with torch.no_grad(): 
    model.eval()

    _exp_survs = []
    for p_idx, _exp_value in enumerate(_exp_values):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_causal=False,
                              return_loss=False,
                              return_generation=True)
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    for si, _ in enumerate(surv):
        plt.close()
        event_name = dm.decode([si + 1])
        if event_name in events_of_interest:
            for p_idx in range(len(_exp_values)):
                plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{_exp_values[p_idx][0]:.2f}")
            plt.xlabel("t (years)")
            plt.ylabel(f"$P(T>t)$ ({event_name})")
            plt.legend()
            plt.savefig(save_path + f"bmi/{event_name}.png")


## Values: How increasing DBP affects diagnosis risk

In [None]:
events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF_V3", "ISCHAEMICSTROKE_V2",
                      "DEATH"
                     ]


_exp_prompt = ["Diastolic_blood_pressure_5"]
_exp_age = [40]
_exp_values = [[60.], [70.], [80.], [90.], [100.], [110.]]

with torch.no_grad(): 
    model.eval()

    _exp_survs = []
    for p_idx, _exp_value in enumerate(_exp_values):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_causal=False,
                              return_loss=False,
                              return_generation=True
                             )
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    for si, _ in enumerate(surv):
        plt.close()
        event_name = dm.decode([si + 1])
        if event_name in events_of_interest:
            for p_idx in range(len(_exp_values)):
                plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{_exp_values[p_idx][0]:.2f}")
            plt.xlabel("t (years)")
            plt.ylabel("P(T>t) ()")
            plt.legend()
            plt.savefig(save_path + f"diastolic_blood_pressure/{event_name}.png")


## Values: How varying diagnosis affects value of DBP

In [None]:
measurements_of_interest = "Diastolic_blood_pressure_5"


_exp_prompts = [["DEPRESSION"], ["TYPE2DIABETES"], ["HF_V3"], ["HYPERTENSION"]]
_exp_age = [20]
_exp_value = [np.nan]


with torch.no_grad(): 
    model.eval()

    for p_idx, _exp_prompt in enumerate(_exp_prompts):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)
        
        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_causal=False,
                              return_loss=False,
                              return_generation=True
                             )
        val_dist = outputs["values_dist"]

        dist = val_dist[model.value_layer.token_key(dm.tokenizer._stoi[measurements_of_interest])]
        print(f"{'->'.join(_exp_prompt)}".ljust(30) + "leads to".ljust(20) + f"standardised {measurements_of_interest} ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")


## Values: How increasing bmi affects value of diastolic_blood_pressure

In [None]:
measurements_of_interest = "Diastolic_blood_pressure_5"


_exp_prompt = ["Body_mass_index_3"]
_exp_values = [[18.], [21.], [24.], [30.], [40.]]
_exp_value = [np.nan]


with torch.no_grad(): 
    model.eval()

    for p_idx, _exp_value in enumerate(_exp_values):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)
        
        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_causal=False,
                              return_loss=False,
                              return_generation=True
                             )
        val_dist = outputs["values_dist"]

        dist = val_dist[model.value_layer.token_key(dm.tokenizer._stoi[measurements_of_interest])]
        print(f"{'->'.join(_exp_prompt)} of {_exp_value[0]}".ljust(30) + "leads to".ljust(20) + f"standardised {measurements_of_interest} ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")


## Baseline, impact of gender

In [None]:
events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF_V3", "ISCHAEMICSTROKE_V2",
                      "POLYCYSTIC_OVARIAN_SYNDROME_PCOS_V2",
                      "DEATH",
                      "COCP_reg_contraception",
                      "all_contraceptive"
                     ]

_genders = ["M", "F", "I"]
_exp_prompt = ["Diastolic_blood_pressure_5"]
_exp_age = [20]
_exp_value = [90.]

with torch.no_grad(): 
    model.eval()

    _exp_survs = []
    for p_idx, _gender in enumerate(_genders):

        _baseline_covariate = {"sex": _gender, "deprivation": 4.0, "ethnicity": "WHITE", "year_of_birth": 1997}
        _covariates = dm.train_set._encode_covariates(**_baseline_covariate).reshape(1,-1).to(device)
        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=_covariates,
                              is_causal=False,
                              return_loss=False,
                              return_generation=True
                             )        
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    for si, _ in enumerate(surv):
        plt.close()
        event_name = dm.decode([si + 1])
        if event_name in events_of_interest:
            for p_idx in range(len(_genders)):
                plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{_genders[p_idx]}")
            plt.xlabel("t (years)")
            plt.ylabel("P(T>t) ()")
            plt.legend()
            plt.savefig(save_path + f"gender/{event_name}.png")


# Appendix: model architectures

In [None]:
display(model)

In [None]:
!jupyter nbconvert --to html --no-input 2_generation.ipynb