# Creating the parquet dataset from SQLite tables

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

!pwd

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/data/study_data


In [2]:
import torch
from hydra import compose, initialize
from omegaconf import OmegaConf
from CPRD.data.foundational_loader import FoundationalDataModule

import logging
import time

from CPRD.examples.data.study_data.study_criteria import t2d_inclusion_method

torch.manual_seed(1337)

logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")


Using device: cuda.


In [3]:
# load the configuration file, override any settings 
with initialize(version_base=None, config_path="../../modelling/SurvStreamGPT/confs", job_name="dataset_creation_notebook"):
    cfg = compose(config_name="config_CompetingRisk37M", overrides=[])

# Create new dataset 
cfg.data.path_to_ds = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/"
max_seq_length=cfg.transformer.block_size = 1e6

print(OmegaConf.to_yaml(cfg))

is_decoder: true
data:
  batch_size: 64
  unk_freq_threshold: 0.0
  min_workers: 12
  global_diagnoses: false
  path_to_db: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/
  path_to_ds: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/
  meta_information_path: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
  subsample_training: null
experiment:
  type: pre-train
  project_name: SurvivEHR
  run_id: ${head.SurvLayer}PreTrain_large_${experiment.seed}
  fine_tune_id: null
  train: true
  test: true
  verbose: true
  seed: 1337
  log: true
  log_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/
  ckpt_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/
  fine_tune_outcomes: null
  notes: null
optim:
  num_epochs: 1
  learning_rate: 0.0003
  scheduler_warmup: true
  

In [7]:
# Build 
dm = FoundationalDataModule(path_to_db=cfg.data.path_to_db,
                            path_to_ds=cfg.data.path_to_ds,
                            load=True,
                            include_diagnoses=True,                            
                            include_measurements=True,
                            drop_missing_data=False,
                            drop_empty_dynamic=True,
                            tokenizer="tabular",
                            overwrite_practice_ids = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/practice_id_splits.pickle",
                            overwrite_meta_information=cfg.data.meta_information_path,
                            # study_inclusion_method=t2d_inclusion_method(min_events=50),
                            num_threads=5
                           )

vocab_size = dm.train_set.tokenizer.vocab_size

print(f"{len(dm.train_set)} training patients")
print(f"{len(dm.val_set)} validation patients")
print(f"{len(dm.test_set)} test patients")
print(f"{vocab_size} vocab elements")

INFO:root:Creating unsupervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_C

572096 training patients
33280 validation patients
35758 test patients
265 vocab elements


In [8]:
dm.train_set.view_sample(1, max_dynamic_events=None, report_time=True)

Time to retrieve sample index 1 was 0.10100746154785156 seconds

SEX                 | M
IMD                 | 4.0
ETHNICITY           | MISSING
birth_year          | 1966.0
Sequence of 79 events

Token                                                                      | Age at event (in days)      | Standardised value
EPILEPSY                                                                   | 10032.0                     | nan               
Anticonvulsants_OPTIMAL                                                    | 10595.0                     | nan               
Carbamazepine_Optimal                                                      | 10595.0                     | nan               
Anticonvulsants_OPTIMAL                                                    | 10717.0                     | nan               
Carbamazepine_Optimal                                                      | 10717.0                     | nan               
Anticonvulsants_OPTIMAL                        

In [9]:
dm.meta_information["diagnosis_table"].event.to_list()

['ADDISONS_DISEASE',
 'ADDISON_DISEASE',
 'AF',
 'ALCOHOLMISUSE_V2',
 'ALLCANCER_NOHAEM_NOBCC',
 'ALLERGICRHINITISCONJ',
 'ALL_DEMENTIA',
 'ANXIETY',
 'ANY_DEAFNESS_HEARING_LOSS_V2',
 'AORTICANEURYSM_V2',
 'ASTHMA_PUSHASTHMA',
 'ATOPICECZEMA',
 'AUTISM',
 'BIPOLAR',
 'BRONCHIECTASIS',
 'CHRONICFATIGUESYNDROMEMM_V2',
 'CHRONIC_LIVER_DISEASE_ALCOHOL',
 'CKDSTAGE3TO5',
 'COPD',
 'CROHNS_DISEASE',
 'CYSTICFIBROSIS',
 'DEATH',
 'DEPRESSION',
 'DOWNSSYNDROME',
 'EATINGDISORDERS',
 'ENDOMETRIOSIS_ADENOMYOSIS_V2',
 'EPILEPSY',
 'FIBROMYALGIA',
 'GOUT',
 'HAEMOCHROMATOSIS_V2',
 'HF_V3',
 'HIVAIDS',
 'HYPERTENSION',
 'HYPERTHYROIDISM_V2',
 'HYPOTHYROIDISM_DRAFT_V1',
 'IHDINCLUDINGMI_OPTIMALV2',
 'ILD_SH',
 'ISCHAEMICSTROKE_V2',
 'LEUKAEMIA_PREVALENCEV2',
 'LYMPHOMA_PREVALENCE_V2',
 'MENIERESDISEASE',
 'MINFARCTION',
 'MS',
 'NAFLD_V2',
 'OSA',
 'OSTEOARTHRITIS',
 'OSTEOPOROSIS',
 'OTHER_CHRONIC_LIVER_DISEASE_OPTIMAL',
 'PAD_STRICT',
 'PARKINSONS',
 'PERIPHERAL_NEUROPATHY',
 'PERNICIOUSANAEMIA',


In [10]:
import polars as pl
pl.Config.set_tbl_rows(300)
pl.Config.set_fmt_str_lengths(100)

print(dm.train_set.tokenizer._event_counts)

shape: (264, 3)
┌──────────────────────────────────────────────────────────┬───────────┬───────────┐
│ EVENT                                                    ┆ COUNT     ┆ FREQUENCY │
│ ---                                                      ┆ ---       ┆ ---       │
│ str                                                      ┆ u32       ┆ f64       │
╞══════════════════════════════════════════════════════════╪═══════════╪═══════════╡
│ UNK                                                      ┆ 0         ┆ 0.0       │
│ ADDISONS_DISEASE                                         ┆ 6691      ┆ 8.8559e-7 │
│ CYSTICFIBROSIS                                           ┆ 7053      ┆ 9.3350e-7 │
│ SYSTEMIC_SCLEROSIS                                       ┆ 8772      ┆ 0.000001  │
│ SICKLE_CELL_DISEASE_V2                                   ┆ 11159     ┆ 0.000001  │
│ ADDISON_DISEASE                                          ┆ 11794     ┆ 0.000002  │
│ DOWNSSYNDROME                                  

# Test data loading times (so we can optimise cpu usage)

In [None]:
import pyarrow.parquet as pq

dataset1 = pq.ParquetDataset(path_to_db + "polars/split=train/", 
                            filters=[('PRACTICE_ID','=','p20763')]
                            )

import time

start = time.time()   # starting time
df  = dataset1.read().to_pandas()
df = df[df["row_nr"] == 100]
print(df)
print(time.time() - start)


In [13]:
from tqdm import tqdm
import numpy as np

times = []
start = time.time()   # starting time
for row_idx, row in enumerate(tqdm(dm.train_set)):
    # print(f"Sample loaded in {time.time()-start} seconds")
    times.append(time.time()-start)
    start = time.time()
    if row_idx > 1e4:
        break
print(np.mean(times))

  0%|          | 2044/23343104 [01:24<269:23:50, 24.07it/s]

KeyboardInterrupt



In [14]:
times = []
start = time.time()   # starting time
for batch_idx, batch in enumerate(tqdm(dm.train_dataloader())):
    # print(f"batch loaded in {time.time()-start} seconds")    
    times.append(time.time()-start)
    start = time.time()
    if batch_idx > 1e4:
        break
print(np.mean(times))

# for key in batch.keys():
#     print(f"{key}".ljust(20) + f"{batch[key].shape}")

# tokens = batch["tokens"][0].tolist()    
# sentence = dm.decode(tokens).split(" ")
# for token, value in zip(sentence, batch["values"][0].tolist()):
#     print(f"{token}:".ljust(40) + f"{value}")

  0%|          | 4/364736 [00:31<808:15:55,  7.98s/it]

KeyboardInterrupt



In [15]:
dm.train_set.view_sample(1236, max_dynamic_events=12, report_time=True)

Time to retrieve sample index 1236 was 0.02561783790588379 seconds

SEX                 | F
IMD                 | 4.0
ETHNICITY           | ASIAN
birth_year          | 2013.0

Token                                                                      | Age               | Standardised value
O_E___height_1                                                             | 47                | nan               
O_E___weight_2                                                             | 47                | nan               
O_E___weight_2                                                             | 1468              | nan               
Body_mass_index_3                                                          | 1503              | -0.39             
O_E___height_1                                                             | 1503              | nan               
O_E___weight_2                                                             | 1503              | nan               
GFR_calculat