# BEHRT dataset for cardiovascular disease outcomes in patients with Type 2 Diabetes Mellitus

In this notebook we demonstrate the workflow of adding the BEHRT adapter to FastEHR.

In [1]:
import os
from pathlib import Path
import sys
# Put the node-specific virtual environment's site-packages first on the import path.
# The environment directory name is suffixed with the value of the BB_CPU environment variable.
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
python_dir = 'python{}.{}'.format(sys.version_info.major, sys.version_info.minor)
venv_site_pkgs = Path(venv_dir).joinpath('lib', python_dir, 'site-packages')

if not venv_site_pkgs.exists():
    # Best effort only: report the missing path and continue with the default search paths.
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")
else:
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")

# Print the working directory (the Hydra config_path used later in this notebook is relative).
!pwd

# Enable the autoreload extension so edited modules are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/data/4_convert_BEHRT_data


In [2]:
import torch
from hydra import compose, initialize
from omegaconf import OmegaConf
import logging
import time
import pickle as pkl
from tqdm import tqdm
import pandas as pd

from FastEHR.dataloader import FoundationalDataModule
from CPRD.examples.data.study_criteria import t2d_inclusion_method

# NOTE(review): logging.disable(logging.CRITICAL) suppresses every log record at
# CRITICAL severity and below, so the INFO-level basicConfig further down emits
# nothing while this is in effect. Confirm which of the two is intended.
logging.disable(logging.CRITICAL)

# Fix the PyTorch RNG seed for reproducibility.
torch.manual_seed(1337)

logging.basicConfig(level=logging.INFO)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")


Using device: cuda.


In [3]:
# Load the SurvivEHR Hydra configuration; no overrides are passed to compose(),
# individual settings are instead overridden on the returned config below.
with initialize(version_base=None, config_path="../../modelling/SurvivEHR/confs", job_name="dataset_creation_notebook"):
    cfg = compose(config_name="config_CompetingRisk11M", overrides=[])

# Point the configuration at the FineTune_CVD dataset directory.
cfg.data.path_to_ds = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/"
# Remove the windowing applied to SurvivEHR by default so BEHRT can set its own windowing.
# block_size is a size, so it is stored as an integer (the literal 1e6 is a float).
cfg.transformer.block_size = 1_000_000

print(OmegaConf.to_yaml(cfg))

is_decoder: true
data:
  batch_size: 64
  unk_freq_threshold: 0.0
  min_workers: 12
  global_diagnoses: false
  repeating_events: true
  path_to_db: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/cprd.db
  path_to_ds: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/
  meta_information_path: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
  subsample_training: null
experiment:
  type: pre-train
  project_name: SurvivEHR
  run_id: ${head.SurvLayer}PreTrain_small_${experiment.seed}
  fine_tune_id: null
  notes: null
  tags: null
  train: true
  test: true
  verbose: true
  seed: 1337
  log: true
  log_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/
  ckpt_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/
fine_tuning:
  fine_tune_outcomes: null
  custom_outcome

# Adding an adapter

Using FastEHR we can add an adapter to an existing dataloader, which will be called inside the collator. This adapter formats the data into the form required by the downstream task - in this case BEHRT.

In doing so we retain the PyTorch dataloader style. Note, however, that the downstream BEHRT model instead requires all of the data to be stored in a single dataframe.

In [4]:
# Build the data module with the BEHRT adapter attached.
practice_id_splits_path = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/practice_id_splits.pickle"

dm = FoundationalDataModule(
    path_to_db=cfg.data.path_to_db,
    path_to_ds=cfg.data.path_to_ds,
    load=True,
    overwrite_practice_ids=practice_id_splits_path,
    overwrite_meta_information=cfg.data.meta_information_path,
    num_threads=6,
    supervised=True,
    adapter="BEHRT",
)

vocab_size = dm.train_set.tokenizer.vocab_size

# Report the size of each split, then the vocabulary size of the training-set tokenizer.
for split_label, split_set in (
    ("training", dm.train_set),
    ("validation", dm.val_set),
    ("test", dm.test_set),
):
    print(f"{len(split_set)} {split_label} patients")
print(f"{vocab_size} vocab elements")

# If we were to view the sample, this would be before collation and so we would see the original format.
# dm.train_set.view_sample(1, max_dynamic_events=None, report_time=True)


572096 training patients
33280 validation patients
35758 test patients
265 vocab elements


## Adapter tokeniser

Our adapter also provides the new tokeniser which would be compatible with BEHRT, including the new special tokens "SEP", "CLS", and "MASK".

In [5]:
# Show the token -> index mapping exposed by the adapter.
behrt_tokenizer = dm.adapter.tokenizer
display(behrt_tokenizer)

{'PAD': 0,
 'UNK': 1,
 'SEP': 2,
 'CLS': 3,
 'MASK': 4,
 'ADDISONS_DISEASE': 5,
 'CYSTICFIBROSIS': 6,
 'SYSTEMIC_SCLEROSIS': 7,
 'SICKLE_CELL_DISEASE_V2': 8,
 'ADDISON_DISEASE': 9,
 'DOWNSSYNDROME': 10,
 'HAEMOCHROMATOSIS_V2': 11,
 'PLASMACELL_NEOPLASM_V2': 12,
 'SJOGRENSSYNDROME': 13,
 'SYSTEMIC_LUPUS_ERYTHEMATOSUS': 14,
 'HIVAIDS': 15,
 'PSORIATICARTHRITIS2021': 16,
 'MS': 17,
 'Plasma_N_terminal_pro_B_type_natriuretic_peptide_conc_70': 18,
 'LEUKAEMIA_PREVALENCEV2': 19,
 'N_terminal_pro_brain_natriuretic_peptide_level_67': 20,
 'ILD_SH': 21,
 'CHRONIC_LIVER_DISEASE_ALCOHOL': 22,
 'PERNICIOUSANAEMIA': 23,
 'MENIERESDISEASE': 24,
 'LYMPHOMA_PREVALENCE_V2': 25,
 'CROHNS_DISEASE': 26,
 'AllHIVdrugs_HIV': 27,
 'Plasma_B_natriuretic_peptide_level_69': 28,
 'CHRONICFATIGUESYNDROMEMM_V2': 29,
 'Plasma_pro_brain_natriuretic_peptide_level_64': 30,
 'STROKE_HAEMRGIC': 31,
 'PARKINSONS': 32,
 'AORTICANEURYSM_V2': 33,
 'BIPOLAR': 34,
 'BRONCHIECTASIS': 35,
 'ULCERATIVE_COLITIS': 36,
 'SCHIZOPHRE

## Loading with an adapter

Now when we load data into memory we get the correct form. We additionally retain values - despite these not being used in BEHRT.

We do not retain static baseline covariates, as these are not used in BEHRT. However, it is simple to modify the existing adapter, or to create new adapters, to include this information.

In [6]:
# Take only the first batch from the training dataloader.
# next(iter(...)) raises StopIteration here if the loader yields nothing, rather than
# leaving `batch` undefined for the code that follows.
batch = next(iter(dm.train_dataloader()))

def print_row(tkn, age, value):
    """Print one event as a fixed-width table row: token name, age and value.

    Args:
        tkn: scalar holding a token index; decoded through the adapter tokenizer.
        age: scalar age; multiplied by ``dm.train_set.time_scale`` and divided by 365
            before printing (presumably a conversion to years - TODO confirm units).
        value: scalar value, printed to one decimal place.

    Each argument must provide an ``.item()`` method returning a Python scalar.
    """
    # The adapter tokenizer maps token string -> index; invert it to decode indices.
    index_to_token = {index: token for token, index in dm.adapter.tokenizer.items()}

    token_col = "{}".format(index_to_token[tkn.item()]).ljust(50)
    scaled_age = age.item() * dm.train_set.time_scale / 365
    age_col = format(scaled_age, ".1f").ljust(10)
    value_col = format(value.item(), ".1f")

    print(token_col + age_col + value_col)

# First patient of batch: keep only the positions where the attention mask equals 1.
atn_mask = batch["attention_mask"][0, :]
is_attended = atn_mask == 1
events = batch["tokens"][0, is_attended]
ages = batch["ages"][0, is_attended]
values = batch["values"][0, is_attended]

# Table header followed by one row per retained event.
print("Event".ljust(50) + "Age".ljust(10) + "Value".ljust(10))
print("=" * 70)
for row in zip(events, ages, values):
    print_row(*row)


# Final row: the supervised target for the same patient. The target age delta is
# multiplied by the collator's supervised_time_scale before being passed to print_row.
print("=" * 70)
target_age = batch['target_age_delta'][0] * dm.collate_fn.supervised_time_scale
print_row(batch['target_token'][0], target_age, batch['target_value'][0])


Event                                             Age       Value     
CLS                                               33.9      nan
Body_mass_index_3                                 33.9      -0.3
Diastolic_blood_pressure_5                        33.9      -0.3
O_E___height_1                                    33.9      0.1
O_E___weight_2                                    33.9      -0.4
Systolic_blood_pressure_4                         33.9      -0.3
SEP                                               33.9      nan
Serum_cholesterol_97                              33.9      -0.4
Serum_triglycerides_105                           33.9      -0.4
SEP                                               33.9      nan
Serum_cholesterol_97                              34.9      -0.4
Serum_triglycerides_105                           34.9      -0.5
SEP                                               34.9      nan
NSAIDS_oral_OPTIMAL_final                         36.9      nan
SEP                      

# The BEHRT pandas dataset

We now have a usable dataloader for a BEHRT model. However, the original code requires all data be saved in a single dataframe. This is done in the ``build_cvd_BEHRT_dataset.py`` file found in this directory. 

Here, we inspect what is produced.


### Tokeniser

The tokeniser (already shown above as an attribute of the datamodule class) was saved to file so that it can be loaded into the BEHRT model framework.

In [7]:
# Build the path with pathlib so it does not depend on path_to_ds ending in "/".
token2idx_path = Path(cfg.data.path_to_ds) / "BEHRT" / "token2idx.pkl"

# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open(token2idx_path, "rb") as f:
    bert_vocab = pkl.load(f)

display(bert_vocab)

{'token2idx': {'PAD': 0,
  'UNK': 1,
  'SEP': 2,
  'CLS': 3,
  'MASK': 4,
  'ADDISONS_DISEASE': 5,
  'CYSTICFIBROSIS': 6,
  'SYSTEMIC_SCLEROSIS': 7,
  'SICKLE_CELL_DISEASE_V2': 8,
  'ADDISON_DISEASE': 9,
  'DOWNSSYNDROME': 10,
  'HAEMOCHROMATOSIS_V2': 11,
  'PLASMACELL_NEOPLASM_V2': 12,
  'SJOGRENSSYNDROME': 13,
  'SYSTEMIC_LUPUS_ERYTHEMATOSUS': 14,
  'HIVAIDS': 15,
  'PSORIATICARTHRITIS2021': 16,
  'MS': 17,
  'Plasma_N_terminal_pro_B_type_natriuretic_peptide_conc_70': 18,
  'LEUKAEMIA_PREVALENCEV2': 19,
  'N_terminal_pro_brain_natriuretic_peptide_level_67': 20,
  'ILD_SH': 21,
  'CHRONIC_LIVER_DISEASE_ALCOHOL': 22,
  'PERNICIOUSANAEMIA': 23,
  'MENIERESDISEASE': 24,
  'LYMPHOMA_PREVALENCE_V2': 25,
  'CROHNS_DISEASE': 26,
  'AllHIVdrugs_HIV': 27,
  'Plasma_B_natriuretic_peptide_level_69': 28,
  'CHRONICFATIGUESYNDROMEMM_V2': 29,
  'Plasma_pro_brain_natriuretic_peptide_level_64': 30,
  'STROKE_HAEMRGIC': 31,
  'PARKINSONS': 32,
  'AORTICANEURYSM_V2': 33,
  'BIPOLAR': 34,
  'BRONCHIECTA

### DataFrame

Similarly, the pandas DataFrame was also saved to file.

In [8]:
# Build the path with pathlib so it does not depend on path_to_ds ending in "/".
behrt_dataset_path = Path(cfg.data.path_to_ds) / "BEHRT" / "dataset.parquet"
df = pd.read_parquet(behrt_dataset_path)

print(f"Dataframe with {len(df)} samples")

Dataframe with 572096 samples


In [9]:
# Preview the first rows of the BEHRT dataframe.
df_preview = df.head()
display(df_preview)

Unnamed: 0,patid,caliber_id,age
0,P0000001,"[CLS, Body_mass_index_3, Diastolic_blood_press...","[6.778630256652832, 6.778630256652832, 6.77863..."
1,P0000002,"[CLS, HYPERTENSION, SEP, OSTEOARTHRITIS, PREVA...","[11.300822257995605, 11.300822257995605, 11.30..."
2,P0000003,"[CLS, O_E___height_1, O_E___weight_2, SEP, AST...","[4.4553422927856445, 4.4553422927856445, 4.455..."
3,P0000004,"[CLS, Erythrocyte_sedimentation_rate_61, Eryth...","[5.9572601318359375, 5.9572601318359375, 5.957..."
4,P0000005,"[CLS, Body_mass_index_3, Diastolic_blood_press...","[6.888767242431641, 6.888767242431641, 6.88876..."
