# Creating the parquet dataset from SQLite tables

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

!pwd

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/data/2_build_pre_training_dataset


In [105]:
import numpy as np
import polars as pl
import pandas as pd
import time
from FastEHR.dataloader.foundational_loader import FoundationalDataModule
from CPRD.examples.data.map_to_reduced_names import convert_event_names, EVENT_NAME_SHORT_MAP, EVENT_NAME_LONG_MAP

pl.Config.set_tbl_rows(300)

polars.config.Config

In [24]:
PATH_TO_DB = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/cprd.db"
PATH_TO_DS = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/"

In [25]:
# Create

#####
##### See ./build_dataset.py for dataset creation.
#####

In [26]:
# Load

dm = FoundationalDataModule(path_to_db=PATH_TO_DB,
                            path_to_ds=PATH_TO_DS,
                            load=True,
                            include_diagnoses=True,
                            include_measurements=True,
                            drop_missing_data=False,
                            drop_empty_dynamic=True,
                            tokenizer="tabular",
                           )

vocab_size = dm.train_set.tokenizer.vocab_size

print(f"{len(dm.train_set)} training patients")
print(f"{len(dm.val_set)} validation patients")
print(f"{len(dm.test_set)} test patients")
print(f"{vocab_size} vocab elements")

INFO:root:Creating unsupervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Set seed to 42
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain

23613894 training patients
1426714 validation patients
1508320 test patients
265 vocab elements


## Get the tokenizer mapping from string to token idx

In [129]:
map_name_to_idx = dm.train_set.tokenizer._stoi
print(map_name_to_idx)

{'PAD': 0, 'UNK': 1, 'ADDISONS_DISEASE': 2, 'CYSTICFIBROSIS': 3, 'SYSTEMIC_SCLEROSIS': 4, 'SICKLE_CELL_DISEASE_V2': 5, 'ADDISON_DISEASE': 6, 'DOWNSSYNDROME': 7, 'HAEMOCHROMATOSIS_V2': 8, 'PLASMACELL_NEOPLASM_V2': 9, 'SJOGRENSSYNDROME': 10, 'SYSTEMIC_LUPUS_ERYTHEMATOSUS': 11, 'HIVAIDS': 12, 'PSORIATICARTHRITIS2021': 13, 'MS': 14, 'Plasma_N_terminal_pro_B_type_natriuretic_peptide_conc_70': 15, 'LEUKAEMIA_PREVALENCEV2': 16, 'N_terminal_pro_brain_natriuretic_peptide_level_67': 17, 'ILD_SH': 18, 'CHRONIC_LIVER_DISEASE_ALCOHOL': 19, 'PERNICIOUSANAEMIA': 20, 'MENIERESDISEASE': 21, 'LYMPHOMA_PREVALENCE_V2': 22, 'CROHNS_DISEASE': 23, 'AllHIVdrugs_HIV': 24, 'Plasma_B_natriuretic_peptide_level_69': 25, 'CHRONICFATIGUESYNDROMEMM_V2': 26, 'Plasma_pro_brain_natriuretic_peptide_level_64': 27, 'STROKE_HAEMRGIC': 28, 'PARKINSONS': 29, 'AORTICANEURYSM_V2': 30, 'BIPOLAR': 31, 'BRONCHIECTASIS': 32, 'ULCERATIVE_COLITIS': 33, 'SCHIZOPHRENIAMM_V2': 34, 'PTSDDIAGNOSIS': 35, 'TYPE1DM': 36, 'FIBROMYALGIA': 37, 

# Summary statistics

In [126]:
# Convert DEXTER produced name to long-name
map_to_short = lambda x: EVENT_NAME_SHORT_MAP.get(x, x)
map_to_long = lambda x: EVENT_NAME_LONG_MAP.get(x, x)

# conditional formatter
def formatter(x):
    """
    • Whole number (no decimals) if x ≥ 1  
    • Four significant figures if 0 ≤ x < 1  
      (uses general format so 0.0001234 → '0.0001234', 1e-7 → '1.000e-07')
    """
    if pd.isna(x):
        return ""                    # keep NaNs blank
    if np.abs(x) >= 1000:
        return f"{x:.2g}"
    if np.abs(x) >= 100:
        return f"{x:.0f}"
    if np.abs(x) >= 10:
        return f"{x:.1f}"
    if np.abs(x) >= 1:
        return f"{x:.2f}"
    if np.abs(x) >= 0.1:
        return f"{x:.3f}"
    return f"{x:.3g}"                # 4-sig-figs for small numbers

In [159]:
def get_stats_table(df, report_values=False):
    
    # Create the plotting name column
    df.loc[:, "Event"] = df["event"].apply(map_to_long)
    df.loc[:, "Event (plotting)"] = df["event"].apply(map_to_short)

    df.loc[:, "idx"] = df["event"].map(map_name_to_idx)

    if report_values:
        df["missing"] = df["count"] - df["count_obs"]
        columns = ["event","mean", "min", "max", "count", "missing"]
    else:
        columns = ["event", "count"]
    
    latex_columns = df[columns]
    fmt = {
        "mean":    formatter,   # e.g. whole numbers
        "min":     formatter,   # e.g. whole numbers
        "max":     formatter,   # e.g. whole numbers
        "count":   "{:.0f}".format,   # e.g. whole numbers
        "missing": "{:.0f}".format,   # e.g. whole numbers
    }
    latex_code = latex_columns.to_latex(index=False, formatters=fmt)
    print(latex_code)

def get_vocab_table(df):
    
    # Create the plotting name column
    df.loc[:, "Event"] = df["event"].apply(map_to_long)
    df.loc[:, "Event (plotting)"] = df["event"].apply(map_to_short)

    df.loc[:, "idx"] = df["event"].map(map_name_to_idx)
    
    columns = ["event", "Event (plotting)", "idx"]
    
    latex_columns = df[columns]
    latex_code = latex_columns.to_latex(index=False)
    print(latex_code)

## Diagnoses

In [158]:
diagnoses = dm.meta_information["diagnosis_table"].copy()

get_stats_table(diagnoses)

\begin{tabular}{lr}
\toprule
                              event &   count \\
\midrule
                   ADDISONS\_DISEASE &    6691 \\
                    ADDISON\_DISEASE &   11794 \\
                                 AF &  731332 \\
                   ALCOHOLMISUSE\_V2 & 1125212 \\
             ALLCANCER\_NOHAEM\_NOBCC & 1496973 \\
               ALLERGICRHINITISCONJ & 3291165 \\
                       ALL\_DEMENTIA &  528602 \\
                            ANXIETY & 3560978 \\
       ANY\_DEAFNESS\_HEARING\_LOSS\_V2 & 2282766 \\
                  AORTICANEURYSM\_V2 &  101134 \\
                  ASTHMA\_PUSHASTHMA & 4175115 \\
                       ATOPICECZEMA & 4369082 \\
                             AUTISM &  156860 \\
                            BIPOLAR &  108852 \\
                     BRONCHIECTASIS &  112618 \\
        CHRONICFATIGUESYNDROMEMM\_V2 &   82799 \\
      CHRONIC\_LIVER\_DISEASE\_ALCOHOL &   63405 \\
                       CKDSTAGE3TO5 & 1088754 \\
               

  latex_code = latex_columns.to_latex(index=False, formatters=fmt)


In [160]:
get_vocab_table(diagnoses)

\begin{tabular}{llr}
\toprule
                              event &        Event (plotting) &  idx \\
\midrule
                   ADDISONS\_DISEASE &               Addison's &    2 \\
                    ADDISON\_DISEASE &               Addison's &    6 \\
                                 AF &                      AF &   81 \\
                   ALCOHOLMISUSE\_V2 &          Alcohol misuse &   94 \\
             ALLCANCER\_NOHAEM\_NOBCC &                  Cancer &  104 \\
               ALLERGICRHINITISCONJ &       Allergic rhinitis &  123 \\
                       ALL\_DEMENTIA &                Dementia &   75 \\
                            ANXIETY &                 Anxiety &  126 \\
       ANY\_DEAFNESS\_HEARING\_LOSS\_V2 &            Hearing loss &  114 \\
                  AORTICANEURYSM\_V2 &        Aortic aneurysms &   30 \\
                  ASTHMA\_PUSHASTHMA &                  Asthma &  134 \\
                       ATOPICECZEMA &                  Eczema &  136 \\
             

  latex_code = latex_columns.to_latex(index=False)


## Medications

In [157]:
medication = dm.meta_information["measurement_tables"][dm.meta_information["measurement_tables"]["count_obs"] == 0].copy()

get_stats_table(medication)

\begin{tabular}{lr}
\toprule
                                          event &     count \\
\midrule
                             ACE\_Inhibitors\_D2T & 191462111 \\
                               AMD\_DH\_Donepezil &   5635153 \\
                               AMD\_DH\_Memantine &   2309471 \\
                                    ARBs\_Luyuan &  72329278 \\
                                 Acarbose\_AURUM &    507434 \\
                      AldosteroneAntagonist\_D2T &  11182704 \\
                                AllHIVdrugs\_HIV &     82262 \\
                        All\_AntiArrhythmics\_D2T &  29502352 \\
                              All\_Antiplatelets & 162940629 \\
                              All\_Diuretics\_D2T & 176070847 \\
                 All\_Diuretics\_ExclLactones\_D2T & 176070847 \\
                                   AlphaBlocker &  51635379 \\
                                     Amantadine &    479791 \\
                          Amitriptyline\_optimal &  46415710 \

  latex_code = latex_columns.to_latex(index=False, formatters=fmt)


In [161]:
get_vocab_table(medication)

\begin{tabular}{llr}
\toprule
                                          event &            Event (plotting) &  idx \\
\midrule
                             ACE\_Inhibitors\_D2T &                        ACEI &  259 \\
                               AMD\_DH\_Donepezil &                   Donepezil &  142 \\
                               AMD\_DH\_Memantine &                   Memantine &  116 \\
                                    ARBs\_Luyuan &                         ARB &  221 \\
                                 Acarbose\_AURUM &                    Acarbose &   72 \\
                      AldosteroneAntagonist\_D2T &      Aldosterone antagonist &  155 \\
                                AllHIVdrugs\_HIV &             Antiretrovirals &   24 \\
                        All\_AntiArrhythmics\_D2T &             Antiarrhythmics &  191 \\
                              All\_Antiplatelets &                Antiplatelet &  256 \\
                              All\_Diuretics\_D2T &                 

  latex_code = latex_columns.to_latex(index=False)


## Investigations

In [156]:
investigation = dm.meta_information["measurement_tables"][dm.meta_information["measurement_tables"]["count_obs"] > 0].copy()
# print(investigation.head())

get_stats_table(investigation, report_values=True)

\begin{tabular}{lrrrrr}
\toprule
                                             event &    mean &      min &     max &     count &  missing \\
\midrule
                     25\_Hydroxyvitamin\_D2\_level\_92 &    3.91 &        0 &     686 &    782791 &    89321 \\
                     25\_Hydroxyvitamin\_D3\_level\_90 &    47.1 &        0 &     952 &    809104 &    27986 \\
                  AST\_\_\_aspartate\_transam\_SGOT\_\_46 &    26.6 &        0 & 1.5e+04 &   1738489 &    57876 \\
                                AST\_serum\_level\_47 &    27.3 &    -5.00 & 2.1e+04 &  10837982 &   352631 \\
                     Albumin\_\_\_creatinine\_ratio\_37 &    10.7 &    -1.00 & 1.3e+04 &    180911 &   102491 \\
                                 Basophil\_count\_22 &  0.0501 &   -0.100 & 1.1e+05 &  86869779 &  1227239 \\
                            Blood\_calcium\_level\_38 &    2.35 &        0 &     440 &    415717 &    30253 \\
                                     Blood\_urea\_28 &    6.51 &  

  latex_code = latex_columns.to_latex(index=False, formatters=fmt)


In [162]:
get_vocab_table(investigation)

\begin{tabular}{llr}
\toprule
                                             event &       Event (plotting) &  idx \\
\midrule
                     25\_Hydroxyvitamin\_D2\_level\_92 &              Vitamin D &   84 \\
                     25\_Hydroxyvitamin\_D3\_level\_90 &              Vitamin D &   86 \\
                  AST\_\_\_aspartate\_transam\_SGOT\_\_46 &                    AST &  107 \\
                                AST\_serum\_level\_47 &                    AST &  154 \\
                     Albumin\_\_\_creatinine\_ratio\_37 &              urine ACR &   42 \\
                                 Basophil\_count\_22 &              Basophils &  231 \\
                            Blood\_calcium\_level\_38 &                Calcium &   63 \\
                                     Blood\_urea\_28 &                   Urea &   85 \\
                                 Body\_mass\_index\_3 &                    BMI &  247 \\
                Brain\_natriuretic\_peptide\_level\_66 &            

  latex_code = latex_columns.to_latex(index=False)


In [142]:
event_counts = (
    dm.train_set.tokenizer._event_counts
        .with_columns(
            pl.col("EVENT").map_dict(EVENT_NAME_LONG_MAP, default=pl.first())
        )
        .groupby("EVENT", maintain_order=True)     # 2️⃣ group by event name
        .agg(pl.all().sum()) 
)

print(event_counts)

shape: (203, 3)
┌───────────────────────────────────┬───────────┬───────────┐
│ EVENT                             ┆ COUNT     ┆ FREQUENCY │
│ ---                               ┆ ---       ┆ ---       │
│ str                               ┆ u32       ┆ f64       │
╞═══════════════════════════════════╪═══════════╪═══════════╡
│ UNK                               ┆ 0         ┆ 0.0       │
│ Addison's disease                 ┆ 18485     ┆ 0.000002  │
│ Cystic Fibrosis                   ┆ 7053      ┆ 9.3350e-7 │
│ Systemic sclerosis                ┆ 8772      ┆ 0.000001  │
│ Sickle cell disease               ┆ 11159     ┆ 0.000001  │
│ Down sydrome                      ┆ 17006     ┆ 0.000002  │
│ Haemochromatosis                  ┆ 18631     ┆ 0.000002  │
│ Myelomas                          ┆ 20301     ┆ 0.000003  │
│ Sjögren's syndrome                ┆ 23326     ┆ 0.000003  │
│ Systemic lupus erythematosus      ┆ 26820     ┆ 0.000004  │
│ HIV/AIDS                          ┆ 41951     ┆ 0.00

## Time to load individual samples

In [145]:
from tqdm import tqdm
import numpy as np

times = []
start = time.time()   # starting time
for row_idx, row in enumerate(tqdm(dm.train_set)):
    # print(f"Sample loaded in {time.time()-start} seconds")
    times.append(time.time()-start)
    start = time.time()
    if row_idx > 100:
        break
print(np.mean(times))

  0%|          | 101/23613894 [00:07<456:14:43, 14.38it/s]

0.0688858990575753





## Loading times for batches

This will be over-estimated

In [146]:
times = []
start = time.time()   # starting time
for batch_idx, batch in enumerate(tqdm(dm.train_dataloader())):
    # print(f"batch loaded in {time.time()-start} seconds")    
    times.append(time.time()-start)
    start = time.time()
    if batch_idx > 2:
        break
print(np.mean(times))

  0%|          | 3/368968 [00:47<1612:40:15, 15.73s/it]

10.507936418056488





In [147]:
dm.train_set.view_sample(1236, max_dynamic_events=12, report_time=True)

Time to retrieve sample index 1236: 0.0748 seconds

SEX                 | M
IMD                 | 4.0
ETHNICITY           | MISSING
birth_year          | 1961.0
Sequence of 61 events

Token                                                                       | Age at event (days)         | Standardized value
ANXIETY                                                                    | 13695                         | nan
ALCOHOLMISUSE_V2                                                           | 13695                         | nan
Diastolic_blood_pressure_5                                                 | 18601                         | 0.01
Systolic_blood_pressure_4                                                  | 18601                         | -0.13
Body_mass_index_3                                                          | 19281                         | -0.17
O_E___height_1                                                             | 19281                         | 0.10
O_E__

## Vocabulary

In [154]:
for key, item in dm.train_set.tokenizer._itos.items():
    print(f"{key}: {item}")

0: PAD
1: UNK
2: ADDISONS_DISEASE
3: CYSTICFIBROSIS
4: SYSTEMIC_SCLEROSIS
5: SICKLE_CELL_DISEASE_V2
6: ADDISON_DISEASE
7: DOWNSSYNDROME
8: HAEMOCHROMATOSIS_V2
9: PLASMACELL_NEOPLASM_V2
10: SJOGRENSSYNDROME
11: SYSTEMIC_LUPUS_ERYTHEMATOSUS
12: HIVAIDS
13: PSORIATICARTHRITIS2021
14: MS
15: Plasma_N_terminal_pro_B_type_natriuretic_peptide_conc_70
16: LEUKAEMIA_PREVALENCEV2
17: N_terminal_pro_brain_natriuretic_peptide_level_67
18: ILD_SH
19: CHRONIC_LIVER_DISEASE_ALCOHOL
20: PERNICIOUSANAEMIA
21: MENIERESDISEASE
22: LYMPHOMA_PREVALENCE_V2
23: CROHNS_DISEASE
24: AllHIVdrugs_HIV
25: Plasma_B_natriuretic_peptide_level_69
26: CHRONICFATIGUESYNDROMEMM_V2
27: Plasma_pro_brain_natriuretic_peptide_level_64
28: STROKE_HAEMRGIC
29: PARKINSONS
30: AORTICANEURYSM_V2
31: BIPOLAR
32: BRONCHIECTASIS
33: ULCERATIVE_COLITIS
34: SCHIZOPHRENIAMM_V2
35: PTSDDIAGNOSIS
36: TYPE1DM
37: FIBROMYALGIA
38: VISUAL_IMPAIRMENT
39: AUTISM
40: NAFLD_V2
41: ISCHAEMICSTROKE_V2
42: Albumin___creatinine_ratio_37
43: PVD_V3
4