# Creating the parquet dataset from SQLite tables

In [1]:
import os
from pathlib import Path
import sys
# Select the virtual environment whose name ends with the value of the BB_CPU
# environment variable (treated here as the node type).
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir).joinpath(
    'lib',
    'python{}.{}'.format(*sys.version_info[:2]),
    'site-packages',
)

# Put that environment's site-packages first on the import search path when it
# exists; otherwise just report the missing directory.
if not venv_site_pkgs.exists():
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")
else:
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")

# Print the current working directory of the notebook process.
!pwd

# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell is executed, so edits to library code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/data/3_build_fine_tuning_datasets/Study3_MultiMorbidity


In [2]:
import torch
from hydra import compose, initialize
from omegaconf import OmegaConf
from FastEHR.dataloader.foundational_loader import FoundationalDataModule

import logging
import time

from CPRD.examples.data.study_criteria import multimorbidity_inclusion_method

# Seed torch's random number generator so repeated runs are comparable.
torch.manual_seed(1337)

# Surface INFO-level log messages in the notebook output.
logging.basicConfig(level=logging.INFO)

# Default to the CPU and switch to the GPU only when torch reports one.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")


Using device: cuda.


In [3]:
# Compose the experiment configuration with Hydra; no overrides are applied at
# compose time.
with initialize(
    version_base=None,
    config_path="../../../modelling/SurvStreamGPT/confs",
    job_name="dataset_creation_notebook",
):
    cfg = compose(config_name="config_CompetingRisk11M", overrides=[])

# Point the configuration at the directory for the new dataset, then print the
# whole configuration as YAML for inspection.
cfg.data.path_to_ds = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity/"
print(OmegaConf.to_yaml(cfg))

is_decoder: true
data:
  batch_size: 64
  unk_freq_threshold: 0.0
  min_workers: 12
  global_diagnoses: false
  repeating_events: true
  path_to_db: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/cprd.db
  path_to_ds: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity/
  meta_information_path: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
  subsample_training: null
experiment:
  type: pre-train
  project_name: SurvivEHR
  run_id: ${head.SurvLayer}PreTrain_small_${experiment.seed}
  fine_tune_id: null
  notes: null
  tags: null
  train: true
  test: true
  verbose: true
  seed: 1337
  log: true
  log_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/
  ckpt_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/
  fine_tune_outcomes: null
optim:
  num_epoch

In [5]:
# Build the data module from the SQLite database referenced by the config and
# write the resulting dataset to cfg.data.path_to_ds. Practice-ID splits and
# meta information are taken from the files passed via the overwrite_* arguments.
dm = FoundationalDataModule(
    path_to_db=cfg.data.path_to_db,
    path_to_ds=cfg.data.path_to_ds,
    load=False,
    include_diagnoses=True,
    include_measurements=True,
    drop_missing_data=False,
    drop_empty_dynamic=True,
    tokenizer="tabular",
    overwrite_practice_ids="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/practice_id_splits.pickle",
    overwrite_meta_information=cfg.data.meta_information_path,
    study_inclusion_method=multimorbidity_inclusion_method(),  # min_events=50
    # Note: most of the overhead is in the Polars processing, so unless you
    # have lots of CPUs, more threads may not speed up dataset creation.
    num_threads=5,
)

vocab_size = dm.train_set.tokenizer.vocab_size

# Report the size of each split, followed by the tokenizer's vocabulary size.
for split_label, split_attr in (
    ("training", "train_set"),
    ("validation", "val_set"),
    ("test", "test_set"),
):
    print(f"{len(getattr(dm, split_attr))} {split_label} patients")
print(f"{vocab_size} vocab elements")

INFO:root:Creating unsupervised collator for DataModule
INFO:root:Building Polars datasets and saving to /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity/
INFO:root:Using train/test/val splits from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/practice_id_splits.pickle
INFO:root:Processing test split...
[Parallel(n_jobs=5)]: Using backend ThreadingBackend with 5 concurrent workers.
Thread generating parquet for 15 practices:   0%|          | 0/15 [00:00<?, ?it/s]
Thread generating parquet for 15 practices:   0%|          | 0/15 [00:00<?, ?it/s][A

Thread generating parquet for 15 practices:   0%|          | 0/15 [00:00<?, ?it/s][A[A


Thread generating parquet for 15 practices:   0%|          | 0/15 [00:00<?, ?it/s][A[A[A



Thread generating parquet for 14 practices:   0%|          | 0/14 [00:00<?, ?it/s][A[A[A[A



Thread generating parquet for 14 practices:   7%|▋         | 1/14

In [None]:
# View the training-set sample at index 1. max_dynamic_events=None presumably
# removes the cap on how many dynamic events are shown, and report_time=True
# presumably reports how long the lookup took — TODO confirm against view_sample.
dm.train_set.view_sample(1, max_dynamic_events=None, report_time=True)

In [None]:
# List the values of the `event` field of the "diagnosis_table" entry in the
# data module's meta information (presumably one entry per diagnosis — confirm).
dm.meta_information["diagnosis_table"].event.to_list()

In [None]:
import polars as pl
# Raise Polars' display limits so printed tables are not truncated:
# up to 300 rows per table and up to 100 characters per string value.
pl.Config.set_tbl_rows(300)
pl.Config.set_fmt_str_lengths(100)

# Print the tokenizer's event counts (a private attribute; presumably a Polars
# table given the display settings above — confirm).
print(dm.train_set.tokenizer._event_counts)

# Test data loading times (so we can optimise CPU usage)

In [None]:
import pyarrow.parquet as pq
import time

# Time how long it takes to read one practice's parquet data and select a
# single row from it.
#
# The parquet files are read from the dataset directory (cfg.data.path_to_ds),
# which is where the data module reported saving the Polars datasets; no
# variable named `path_to_db` is defined in this notebook. The
# "polars/split=train/" sub-directory layout is assumed — confirm against the
# dataset directory. cfg.data.path_to_ds is expected to end with "/".
dataset1 = pq.ParquetDataset(cfg.data.path_to_ds + "polars/split=train/",
                             filters=[('PRACTICE_ID', '=', 'p20763')]
                             )

start = time.time()   # starting time (dataset construction above is not timed)
df = dataset1.read().to_pandas()
df = df[df["row_nr"] == 100]
print(df)
print(time.time() - start)


In [None]:
from tqdm import tqdm
import numpy as np

# Measure the wall-clock time taken to fetch each training sample, one at a
# time, and print the mean. Iteration stops once the index exceeds 10,000.
times = list()
start = time.time()   # reference point for the first sample
for sample_idx, sample in enumerate(tqdm(dm.train_set)):
    # print(f"Sample loaded in {time.time()-start} seconds")
    times.append(time.time() - start)
    start = time.time()
    if sample_idx > 10_000:
        break
print(np.mean(times))

In [None]:
# Measure the wall-clock time taken to produce each batch from the training
# dataloader and print the mean. Iteration stops once the index exceeds 10,000.
times = list()
start = time.time()   # reference point for the first batch
for step, batch in enumerate(tqdm(dm.train_dataloader())):
    # print(f"batch loaded in {time.time()-start} seconds")
    times.append(time.time() - start)
    start = time.time()
    if step > 10_000:
        break
print(np.mean(times))

# for key in batch.keys():
#     print(f"{key}".ljust(20) + f"{batch[key].shape}")

# tokens = batch["tokens"][0].tolist()    
# sentence = dm.decode(tokens).split(" ")
# for token, value in zip(sentence, batch["values"][0].tolist()):
#     print(f"{token}:".ljust(40) + f"{value}")

In [None]:
# View the training-set sample at index 1236. max_dynamic_events=12 presumably
# limits the display to 12 dynamic events, and report_time=True presumably
# reports how long the lookup took — TODO confirm against view_sample.
dm.train_set.view_sample(1236, max_dynamic_events=12, report_time=True)