# Creating the parquet dataset from SQLite tables

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

!pwd

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/data/2_build_pre_training_dataset


In [2]:
import torch
from hydra import compose, initialize
from omegaconf import OmegaConf
from FastEHR.dataloader.foundational_loader import FoundationalDataModule
import logging
import time

torch.manual_seed(1337)

logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")


Using device: cuda.


In [24]:
# load the configuration file, override any settings 
with initialize(version_base=None, config_path="../../modelling/SurvivEHR/confs", job_name="pretrain_dataset_creation_job"):
    cfg = compose(config_name="config_CompetingRisk37M")
print(OmegaConf.to_yaml(cfg))


PATH_TO_DB = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/cprd.db"
PATH_TO_DS = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/"

is_decoder: true
data:
  batch_size: 64
  unk_freq_threshold: 0.0
  min_workers: 12
  global_diagnoses: false
  repeating_events: true
  path_to_db: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/cprd.db
  path_to_ds: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/
  meta_information_path: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
  subsample_training: null
experiment:
  type: pre-train
  project_name: SurvivEHR
  run_id: ${head.SurvLayer}PreTrain_large_${experiment.seed}
  fine_tune_id: null
  notes: null
  tags: null
  train: true
  test: true
  verbose: true
  seed: 1337
  log: true
  log_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/
  ckpt_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/
  fine_tune_outcomes: null
fine_tuning:
  custom_stratificat

In [25]:
# Build 
dm = FoundationalDataModule(path_to_db=PATH_TO_DB,
                            path_to_ds=PATH_TO_DS,
                            load=True,
                            include_diagnoses=True,                            
                            include_measurements=True,
                            drop_missing_data=False,
                            drop_empty_dynamic=True,
                            tokenizer="tabular",
                            practice_inclusion_conditions=["COUNTRY = 'E'"],
                            overwrite_practice_ids = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/practice_id_splits.pickle",
                            overwrite_meta_information=cfg.data.meta_information_path,
                            num_threads=1,
                           )

vocab_size = dm.train_set.tokenizer.vocab_size

print(f"{len(dm.train_set)} training patients")
print(f"{len(dm.val_set)} validation patients")
print(f"{len(dm.test_set)} test patients")
print(f"{vocab_size} vocab elements")

INFO:root:Creating unsupervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Set seed to 42
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalMod

23613894 training patients
1426714 validation patients
1508320 test patients
265 vocab elements


In [14]:
import polars as pl
pl.Config.set_tbl_rows(300)
pl.Config.set_fmt_str_lengths(100)
print(dm.train_set.tokenizer._event_counts)

shape: (264, 3)
┌──────────────────────────────────────────────────────────┬───────────┬───────────┐
│ EVENT                                                    ┆ COUNT     ┆ FREQUENCY │
│ ---                                                      ┆ ---       ┆ ---       │
│ str                                                      ┆ u32       ┆ f64       │
╞══════════════════════════════════════════════════════════╪═══════════╪═══════════╡
│ UNK                                                      ┆ 0         ┆ 0.0       │
│ ADDISONS_DISEASE                                         ┆ 6691      ┆ 8.8559e-7 │
│ CYSTICFIBROSIS                                           ┆ 7053      ┆ 9.3350e-7 │
│ SYSTEMIC_SCLEROSIS                                       ┆ 8772      ┆ 0.000001  │
│ SICKLE_CELL_DISEASE_V2                                   ┆ 11159     ┆ 0.000001  │
│ ADDISON_DISEASE                                          ┆ 11794     ┆ 0.000002  │
│ DOWNSSYNDROME                                  

In [15]:
dm.train_set.meta_information['measurement_tables'].event.to_list()

['25_Hydroxyvitamin_D2_level_92',
 '25_Hydroxyvitamin_D3_level_90',
 'AST___aspartate_transam_SGOT__46',
 'AST_serum_level_47',
 'Albumin___creatinine_ratio_37',
 'Basophil_count_22',
 'Blood_calcium_level_38',
 'Blood_urea_28',
 'Body_mass_index_3',
 'Brain_natriuretic_peptide_level_66',
 'Calcium_adjusted_level_41',
 'Calculated_LDL_cholesterol_level_103',
 'Combined_total_vitamin_D2_and_D3_level_93',
 'Corrected_serum_calcium_level_42',
 'Current_smoker_83',
 'Diastolic_blood_pressure_5',
 'Eosinophil_count_21',
 'Erythrocyte_sedimentation_rate_61',
 'Ex_smoker_84',
 'Free_T4_level_76',
 'GFR_calculated_abbreviated_MDRD_34',
 'Haematocrit___PCV_16',
 'Haematocrit_15',
 'Haemoglobin_A1c_level___IFCC_standardised_6',
 'Haemoglobin_A1c_level_8',
 'Haemoglobin_estimation_9',
 'HbA1c_level__DCCT_aligned__7',
 'INR___international_normalised_ratio_81',
 'International_normalised_ratio_82',
 'Lymphocyte_count_20',
 'Mean_corpusc_Hb_conc__MCHC__14',
 'Mean_corpusc_haemoglobin_MCH__13',
 'Me

## Time to load individual samples

In [17]:
from tqdm import tqdm
import numpy as np

times = []
start = time.time()   # starting time
for row_idx, row in enumerate(tqdm(dm.train_set)):
    # print(f"Sample loaded in {time.time()-start} seconds")
    times.append(time.time()-start)
    start = time.time()
    if row_idx > 100:
        break
print(np.mean(times))

  0%|          | 101/23613894 [00:09<610:08:56, 10.75it/s]

0.0921168748070212





## Time to load batch (with only one worker)

In [18]:
times = []
start = time.time()   # starting time
for batch_idx, batch in enumerate(tqdm(dm.train_dataloader())):
    # print(f"batch loaded in {time.time()-start} seconds")    
    times.append(time.time()-start)
    start = time.time()
    if batch_idx > 2:
        break
print(np.mean(times))

# for key in batch.keys():
#     print(f"{key}".ljust(20) + f"{batch[key].shape}")

# tokens = batch["tokens"][0].tolist()    
# sentence = dm.decode(tokens).split(" ")
# for token, value in zip(sentence, batch["values"][0].tolist()):
#     print(f"{token}:".ljust(40) + f"{value}")

  0%|          | 3/368968 [00:47<1606:30:27, 15.67s/it]

10.459600329399109





In [19]:
dm.train_set.view_sample(1236, max_dynamic_events=12, report_time=True)

Time to retrieve sample index 1236 was 0.07914471626281738 seconds

SEX                 | M
IMD                 | 4.0
ETHNICITY           | MISSING
birth_year          | 1961.0
Sequence of 61 events

Token                                                                      | Age at event (in days)      | Standardised value
ANXIETY                                                                    | 13695.0                     | nan               
ALCOHOLMISUSE_V2                                                           | 13695.0                     | nan               
Diastolic_blood_pressure_5                                                 | 18601.0                     | 0.01              
Systolic_blood_pressure_4                                                  | 18601.0                     | -0.13             
Body_mass_index_3                                                          | 19281.0                     | -0.17             
O_E___height_1                              

In [20]:
display(dm.train_set.tokenizer._stoi.keys())

dict_keys(['PAD', 'UNK', 'ADDISONS_DISEASE', 'CYSTICFIBROSIS', 'SYSTEMIC_SCLEROSIS', 'SICKLE_CELL_DISEASE_V2', 'ADDISON_DISEASE', 'DOWNSSYNDROME', 'HAEMOCHROMATOSIS_V2', 'PLASMACELL_NEOPLASM_V2', 'SJOGRENSSYNDROME', 'SYSTEMIC_LUPUS_ERYTHEMATOSUS', 'HIVAIDS', 'PSORIATICARTHRITIS2021', 'MS', 'Plasma_N_terminal_pro_B_type_natriuretic_peptide_conc_70', 'LEUKAEMIA_PREVALENCEV2', 'N_terminal_pro_brain_natriuretic_peptide_level_67', 'ILD_SH', 'CHRONIC_LIVER_DISEASE_ALCOHOL', 'PERNICIOUSANAEMIA', 'MENIERESDISEASE', 'LYMPHOMA_PREVALENCE_V2', 'CROHNS_DISEASE', 'AllHIVdrugs_HIV', 'Plasma_B_natriuretic_peptide_level_69', 'CHRONICFATIGUESYNDROMEMM_V2', 'Plasma_pro_brain_natriuretic_peptide_level_64', 'STROKE_HAEMRGIC', 'PARKINSONS', 'AORTICANEURYSM_V2', 'BIPOLAR', 'BRONCHIECTASIS', 'ULCERATIVE_COLITIS', 'SCHIZOPHRENIAMM_V2', 'PTSDDIAGNOSIS', 'TYPE1DM', 'FIBROMYALGIA', 'VISUAL_IMPAIRMENT', 'AUTISM', 'NAFLD_V2', 'ISCHAEMICSTROKE_V2', 'Albumin___creatinine_ratio_37', 'PVD_V3', 'EATINGDISORDERS', 'PMRA

In [23]:
from FastEHR.dataloader.dataset.collector import SQLiteDataCollector
print(cfg.data.path_to_db)
collector = SQLiteDataCollector(PATH_TO_DB)
collector.connect()

/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/


In [24]:
collector.cursor.execute("""SELECT * FROM measurement_ACE_Inhibitors_D2T LIMIT 10""")
results = collector.cursor.fetchall()
for result in results:
    print(result)

(20931, 2375682920931, 'ACE_Inhibitors_D2T', None, '2007-09-25')
(20931, 2375682920931, 'ACE_Inhibitors_D2T', None, '2007-10-11')
(20931, 2375682920931, 'ACE_Inhibitors_D2T', None, '2007-11-30')
(20931, 2375682920931, 'ACE_Inhibitors_D2T', None, '2008-01-28')
(20931, 2375682920931, 'ACE_Inhibitors_D2T', None, '2008-03-26')
(20931, 2375682920931, 'ACE_Inhibitors_D2T', None, '2008-05-20')
(20931, 2375682920931, 'ACE_Inhibitors_D2T', None, '2008-07-22')
(20931, 2375682920931, 'ACE_Inhibitors_D2T', None, '2008-09-16')
(20931, 2375682920931, 'ACE_Inhibitors_D2T', None, '2008-11-10')
(20931, 2375682920931, 'ACE_Inhibitors_D2T', None, '2008-12-31')
