# Creating the parquet dataset from SQLite tables

In [1]:
import os
from pathlib import Path
import sys
# Optionally put a node-type-specific virtual environment at the front of the
# import search path. The environment is selected through the BB_CPU
# environment variable. Disabled by default: set the flag to True to enable.
USE_NODE_VENV = False

if USE_NODE_VENV:
    node_type = os.getenv('BB_CPU')
    if node_type is None:
        # Without BB_CPU the directory name would end in "-None", which can never
        # be the intended environment, so report the real cause instead.
        print("Environment variable 'BB_CPU' is not set. Cannot select a node-specific virtual environment.")
    else:
        venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
        # site-packages lives under lib/pythonX.Y/ for the running interpreter version
        venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
        if venv_site_pkgs.exists():
            # Insert at index 0 so this environment takes priority over other search paths
            sys.path.insert(0, str(venv_site_pkgs))
            print(f"Added path '{venv_site_pkgs}' at start of search paths.")
        else:
            print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

# Show the working directory this notebook is running from.
!pwd

# Load the autoreload extension and reload imported modules automatically
# before code is executed, so edits to imported source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

/home/ubuntu/Documents/GitHub/SurvivEHR/SurvivEHR_ExampleData/example/2_build_dataset


In [10]:
import torch
from hydra import compose, initialize
from omegaconf import OmegaConf
from SurvivEHR.SurvivEHR_ExampleData.dataloader.foundational_loader import FoundationalDataModule
import logging
import time

# Fix the torch RNG seed so repeated runs of the notebook are comparable.
torch.manual_seed(1337)

# Emit INFO-level log records (the data module reports its progress via logging).
logging.basicConfig(level=logging.INFO)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")


Using device: cpu.


In [11]:
# Disabled alternative: read the settings from a hydra configuration file instead.
# # load the configuration file, override any settings 
# with initialize(version_base=None, config_path="../modelling/SurvStreamGPT/confs", job_name="dataset_creation_notebook"):
#     cfg = compose(config_name="config_CompetingRisk129M", overrides=[])
# print(OmegaConf.to_yaml(cfg))

# Root directory of the built example data; it must end with "/" because the
# paths below are formed by plain string concatenation.
path_to_directory = "/home/ubuntu/Documents/GitHub/SurvivEHR/SurvivEHR_ExampleData/example/data/_built/"

# SQLite database file the dataset is built from.
PATH_TO_DB = f"{path_to_directory}example_database.db"
# Directory the parquet dataset is written to / read from.
PATH_TO_DS = f"{path_to_directory}dataset/"

In [51]:
# Build the data module from the SQLite database.
# In the recorded run (load=False) the log shows the parquet dataset being
# built and saved under PATH_TO_DS.
datamodule_options = dict(
    path_to_db=PATH_TO_DB,
    path_to_ds=PATH_TO_DS,
    load=False,
    include_diagnoses=True,
    include_measurements=True,
    drop_missing_data=False,
    drop_empty_dynamic=False,        # Change to True in real use case (False due to lack of example data)
    tokenizer="tabular",
    # practice_inclusion_conditions=["COUNTRY = 'E'"],
)
dm = FoundationalDataModule(**datamodule_options)

# Size of the vocabulary held by the training set's tokenizer.
vocab_size = dm.train_set.tokenizer.vocab_size

# Report how many patients ended up in each split.
for split_label, split in (("training", dm.train_set),
                           ("validation", dm.val_set),
                           ("test", dm.test_set)):
    print(f"{len(split)} {split_label} patients")
print(f"{vocab_size} vocab elements")

INFO:root:Creating unsupervised collator for DataModule
INFO:root:Building Polars datasets and saving to /home/ubuntu/Documents/GitHub/SurvivEHR/SurvivEHR_ExampleData/example/data/_built/dataset/
INFO:root:Chunking by unique practice ID with no practice inclusion conditions
INFO:root:Creating train/test/val splits using practice_ids
INFO:root:

Collecting meta information from database. This will be used for tokenization and (optionally) standardisation.
INFO:root:	 Static meta information
INFO:root:	 Diagnosis meta information
INFO:root:	 Measurements meta information
                                      Measurements: 0it [00:00, ?it/s]
INFO:root:Writing test split into a DL friendly .parquet dataset.
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 31.07it/s]
INFO:root:Created dataset at /home/ubuntu/Documents/GitHub/SurvivEHR/SurvivEHR_ExampleData/example/data/_built/dataset/split=test with 2 samples
INFO:root:Creating file_row_count_dicts for file-index look-u

9 training patients
0 validation patients
2 test patients
5 vocab elements


# Meta information

While building the dataset, summary statistics of the events were accumulated and stored in `meta_information`.

These are saved to file, but are also loaded into the data module, where they are used for
* pre-processing
* tokenizer building

In [64]:
# Print each meta-information table under an underlined heading.
meta_information = dm.train_set.meta_information
for table_name in meta_information.keys():
    underline = "=" * len(table_name)
    print(f"\n\n{table_name}\n{underline}\n")
    display(meta_information[table_name])



static_table



{'SEX':   category  count
 0        F     33
 1        I     33
 2        M     34,
 'IMD':    category  count
 0       NaN     16
 1       1.0     17
 2       2.0     17
 3       3.0     17
 4       4.0     17
 5       5.0     16,
 'ETHNICITY':   category  count
 0    ASIAN     20
 1    BLACK     20
 2  MISSING     20
 3    MIXED     20
 4    WHITE     20}



diagnosis_table



Unnamed: 0,event,count
0,AF,6
1,DEATH,100
2,STROKE_HAEMRGIC,6




measurement_tables



Unnamed: 0,event,count,count_obs,digest,min,max,mean,approx_lqr,approx_uqr


These are also stored within the tokenizer:

In [65]:
import polars as pl
# Widen polars' display limits so the table below is not truncated:
# up to 300 rows, and up to 100 characters per string value.
pl.Config.set_tbl_rows(300)
pl.Config.set_fmt_str_lengths(100)

# Event counts held by the training set's tokenizer.
event_counts = dm.train_set.tokenizer._event_counts
print(event_counts)

shape: (4, 3)
┌─────────────────┬───────┬───────────┐
│ EVENT           ┆ COUNT ┆ FREQUENCY │
│ ---             ┆ ---   ┆ ---       │
│ str             ┆ u32   ┆ f64       │
╞═════════════════╪═══════╪═══════════╡
│ UNK             ┆ 0     ┆ 0.0       │
│ AF              ┆ 6     ┆ 0.053571  │
│ STROKE_HAEMRGIC ┆ 6     ┆ 0.053571  │
│ DEATH           ┆ 100   ┆ 0.892857  │
└─────────────────┴───────┴───────────┘


# Test data loading times (so we can optimise CPU usage)

In [68]:
import pyarrow.parquet as pq
import time

# Location of the training split inside the parquet dataset.
train_split_dir = PATH_TO_DS + "split=train/"
print(train_split_dir)

# Only read the rows belonging to a single practice.
dataset1 = pq.ParquetDataset(train_split_dir,
                             filters=[('PRACTICE_ID','=','20968')]
                             )

# Time to read: only the read and the pandas conversion are timed, so the
# reported figure is not inflated by filtering and printing the result.
# perf_counter() is monotonic and high resolution, unlike time.time().
start = time.perf_counter()   # starting time
df  = dataset1.read().to_pandas()
elapsed = time.perf_counter() - start

# Show every row that shares the first row's row_nr. Positional access
# (.iloc) does not depend on the index labels, and the empty case is
# reported instead of raising.
if df.empty:
    print("No rows were read for the requested practice.")
else:
    print(df[df["row_nr"] == df["row_nr"].iloc[0]])
print(elapsed)


/home/ubuntu/Documents/GitHub/SurvivEHR/SurvivEHR_ExampleData/example/data/_built/dataset/split=train/
   row_nr  PATIENT_ID         VALUE        EVENT DAYS_SINCE_BIRTH  \
0       0           4  [None, None]  [AF, DEATH]   [18919, 25222]   

                                                DATE ETHNICITY YEAR_OF_BIRTH  \
0  [1992-06-14T00:00:00.000000, 2009-09-16T00:00:...     ASIAN    1940-08-27   

  SEX  IMD INDEX_DATE START_DATE   END_DATE COUNTRY      HEALTH_AUTH  \
0   F  2.0 2008-09-04 2008-09-04 2022-04-29       E  East of England   

  PRACTICE_ID CHUNK  
0       20968     0  
0.03758740425109863


## Time to load individual samples

TODO: the error raised when running the cell below is due to the sparsely populated example dataset.

In [71]:
from tqdm import tqdm
import numpy as np

# Measure the time between consecutive samples yielded by the training set.
# perf_counter() is monotonic and high resolution, which suits the
# millisecond-scale intervals measured here better than time.time().
times = []
start = time.perf_counter()   # starting time
for row_idx, row in enumerate(tqdm(dm.train_set)):
    # print(f"Sample loaded in {time.perf_counter()-start} seconds")
    now = time.perf_counter()
    times.append(now - start)
    # Reuse the same timestamp as the start of the next interval so no time
    # falls between two separate clock reads.
    start = now
    if row_idx > 100:
        break

# Mean seconds per sample; nan (without a numpy warning) if nothing was loaded.
print(np.mean(times) if times else float("nan"))

10it [00:00, 116.03it/s]                                                        


## Time to load a batch (with only one worker)

In [72]:
# Measure the time between consecutive batches yielded by the training
# dataloader. perf_counter() is monotonic and high resolution, unlike time.time().
times = []
start = time.perf_counter()   # starting time
for batch_idx, batch in enumerate(tqdm(dm.train_dataloader())):
    # print(f"batch loaded in {time.perf_counter()-start} seconds")
    now = time.perf_counter()
    times.append(now - start)
    # Reuse the same timestamp as the start of the next interval.
    start = now
    if batch_idx > 2:
        break

# Mean seconds per batch; nan (without a numpy warning) if nothing was loaded.
print(np.mean(times) if times else float("nan"))

# Disabled inspection snippet: print the shape of each entry of the last batch,
# then decode the first sequence and list each token next to its value.
# for key in batch.keys():
#     print(f"{key}".ljust(20) + f"{batch[key].shape}")

# tokens = batch["tokens"][0].tolist()    
# sentence = dm.decode(tokens).split(" ")
# for token, value in zip(sentence, batch["values"][0].tolist()):
#     print(f"{token}:".ljust(40) + f"{value}")

100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  5.07it/s]

0.17037177085876465





In [85]:
# Print a human-readable view of one training sample, including how long it
# took to retrieve (report_time=True).
# max_dynamic_events presumably caps the number of events listed - TODO confirm.
sample_index = 2
dm.train_set.view_sample(sample_index, max_dynamic_events=12, report_time=True)

Time to retrieve sample index 2 was 0.017313241958618164 seconds

SEX                 | M
IMD                 | 4.0
ETHNICITY           | WHITE
birth_year          | 1940.0
Sequence of 2 events

Token                                                                      | Age at event (in days)      | Standardised value
STROKE_HAEMRGIC                                                            | 18917.0                     | nan               
DEATH                                                                      | 25222.0                     | nan               


# Tokenizer keys

Here we can see that the tokenizer has been built from only the few event types present in the example dataset.

In [76]:
# Show the tokens known to the training set's tokenizer (keys of its _stoi mapping).
token_lookup = dm.train_set.tokenizer._stoi
display(token_lookup.keys())

dict_keys(['PAD', 'UNK', 'AF', 'STROKE_HAEMRGIC', 'DEATH'])

# Connecting to the SQLite database

In [80]:
from SurvivEHR.SurvivEHR_ExampleData.dataloader.dataset.collector import SQLiteDataCollector
# Create a collector for the example SQLite database and open the connection.
collector = SQLiteDataCollector(
    db_path=PATH_TO_DB,
)
collector.connect()

We can run quick queries against the database. The query below reads from a measurement table, so it requires valued events to have been included in the example dataset.

In [84]:
# Fetch up to ten rows from one measurement table and print them one per line.
query = """SELECT * FROM measurement_ACE_Inhibitors_D2T LIMIT 10"""
collector.cursor.execute(query)
results = collector.cursor.fetchall()
for record in results:
    print(record)

In [83]:
# NOTE(review): leftover scratch code, kept commented out. `generator` is not
# defined in any cell visible in this notebook - define it or delete this cell.
# for batch in generator:
#     print(batch.columns)
#     break