# Creating the parquet dataset from SQLite tables (built by `build_dataset.py`) and inspecting it

In [3]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

!pwd

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/data/2_build_pre_training_dataset
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
import logging
import time

import numpy as np
import pandas as pd
import polars as pl

# Disable every log record at CRITICAL severity and below *before* the project modules are
# imported, so anything they log (including during import) is suppressed.
logging.disable(logging.CRITICAL)

from FastEHR.dataloader.foundational_loader import FoundationalDataModule
from CPRD.examples.data.map_to_reduced_names import convert_event_names, EVENT_NAME_SHORT_MAP, EVENT_NAME_LONG_MAP

# Let printed polars frames show up to 300 rows. The trailing ';' stops Jupyter from
# echoing the `polars.config.Config` object that this call returns.
pl.Config.set_tbl_rows(300);

polars.config.Config

In [8]:
# Location of the SQLite database and of the parquet dataset built from it.
# Defaults point at the shared project storage; set CPRD_PATH_TO_DB / CPRD_PATH_TO_DS in the
# environment to use a different location without editing the notebook.
_FOUNDATIONAL_MODEL_DIR = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel"
PATH_TO_DB = os.getenv("CPRD_PATH_TO_DB", f"{_FOUNDATIONAL_MODEL_DIR}/cprd.db")
# Kept as a string with a trailing '/', exactly as before, in case the loader concatenates onto it.
PATH_TO_DS = os.getenv("CPRD_PATH_TO_DS", f"{_FOUNDATIONAL_MODEL_DIR}/PreTrain/")

In [9]:
# Create

#####
##### The dataset is built by ./build_dataset.py, not in this notebook;
##### the cells below only load the already-built dataset.
#####

In [11]:
# Load

# Load the data module. load=True presumably reads the dataset already built at PATH_TO_DS
# (see ./build_dataset.py) instead of rebuilding it.
datamodule_kwargs = dict(
    path_to_db=PATH_TO_DB,
    path_to_ds=PATH_TO_DS,
    load=True,
    include_diagnoses=True,
    include_measurements=True,
    drop_missing_data=False,
    drop_empty_dynamic=True,
    tokenizer="tabular",
)
dm = FoundationalDataModule(**datamodule_kwargs)

vocab_size = dm.train_set.tokenizer.vocab_size

# Patients per split, then the tokenizer's vocabulary size.
print(f"{len(dm.train_set)} training patients")
print(f"{len(dm.val_set)} validation patients")
print(f"{len(dm.test_set)} test patients")
print(f"{vocab_size} vocab elements")

23613894 training patients
1426714 validation patients
1508320 test patients
265 vocab elements


In [15]:
# Take the first training batch to inspect its structure. next(iter(...)) fails loudly
# (StopIteration) if the loader is empty, so `batch` can never silently be a stale value
# left over from another cell.
batch = next(iter(dm.train_dataloader()))
print(batch)

{'static_covariates': tensor([], size=(64, 0)), 'tokens': tensor([[  3, 263,   2,  ...,   0,   0,   0],
        [  3, 250, 264,  ...,   0,   0,   0],
        [  3, 259, 256,  ...,   0,   0,   0],
        ...,
        [  3, 188,   2,  ...,   0,   0,   0],
        [  3, 233,   2,  ...,   0,   0,   0],
        [  3, 210, 126,  ...,   0,   0,   0]]), 'ages': tensor([[ 5.2899,  5.2899,  5.2899,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.7551,  4.7551,  4.7551,  ...,  0.0000,  0.0000,  0.0000],
        [10.8038, 10.8038, 10.8038,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 5.4268,  5.4268,  5.4268,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0416,  0.0416,  0.0416,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.3923,  2.3923,  2.3923,  ...,  0.0000,  0.0000,  0.0000]]), 'values': tensor([[    nan,     nan,     nan,  ...,     nan,     nan,     nan],
        [    nan,  0.0274, -0.2522,  ...,     nan,     nan,     nan],
        [    nan,     nan,     nan,  ...,     nan,     nan

## Get the tokenizer mapping from event name (string) to token index

In [None]:
# NOTE(review): `_stoi` is a private tokenizer attribute (presumably string -> token index);
# it may change without notice if the tokenizer implementation changes.
train_tokenizer = dm.train_set.tokenizer
map_name_to_idx = train_tokenizer._stoi
print(map_name_to_idx)

# Summary statistics

In [None]:
# Convert DEXTER produced name to long-name
map_to_short = lambda x: EVENT_NAME_SHORT_MAP.get(x, x)
map_to_long = lambda x: EVENT_NAME_LONG_MAP.get(x, x)

# Conditional number formatter for the LaTeX summary tables.
# (lower bound on |x|, format spec), checked from largest bound to smallest.
_FORMATTER_RULES = (
    (1000, ".2g"),  # 2 significant figures, general format (may use e-notation)
    (100, ".0f"),   # no decimals
    (10, ".1f"),    # 1 decimal place
    (1, ".2f"),     # 2 decimal places
    (0.1, ".3f"),   # 3 decimal places
)


def formatter(x):
    """Format a number for a LaTeX table, choosing precision from its magnitude.

    Rules, applied to |x| (the sign is kept in the output):
      • NaN / None          -> ""  (blank cell)
      • |x| >= 1000         -> 2 significant figures, general format
                               (12345 -> '1.2e+04', 1000 -> '1e+03')
      • 100 <= |x| < 1000   -> no decimals        (150.4 -> '150')
      • 10 <= |x| < 100     -> 1 decimal place    (12.34 -> '12.3')
      • 1 <= |x| < 10       -> 2 decimal places   (1.234 -> '1.23')
      • 0.1 <= |x| < 1      -> 3 decimal places   (0.1234 -> '0.123')
      • |x| < 0.1           -> 3 significant figures, general format
                               (0.0001234 -> '0.000123', 1e-7 -> '1e-07')
    """
    if pd.isna(x):
        return ""  # keep NaNs blank
    magnitude = np.abs(x)
    for lower_bound, spec in _FORMATTER_RULES:
        if magnitude >= lower_bound:
            return format(x, spec)
    return format(x, ".3g")  # 3 significant figures for small numbers

In [None]:
def _with_event_name_columns(df):
    """Return a copy of `df` with display-name and token-index columns added.

    For each value of df["event"], adds "Event" (long name, via map_to_long),
    "Event (plotting)" (short name, via map_to_short) and "idx" (looked up in
    map_name_to_idx). The input frame is left unmodified.
    """
    events = df["event"]
    return df.assign(**{
        "Event": events.apply(map_to_long),
        "Event (plotting)": events.apply(map_to_short),
        # NOTE(review): events absent from map_name_to_idx come out as NaN when it is a
        # plain dict (pandas Series.map semantics) — confirm for the tokenizer's `_stoi`.
        "idx": events.map(map_name_to_idx),
    })


def get_stats_table(df, report_values=False):
    """Print a LaTeX table of per-event counts, optionally with value statistics.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per event with at least "event" and "count" columns; with
        `report_values=True` it also needs "mean", "min", "max" and "count_obs".
    report_values : bool, default False
        If True, also report mean/min/max and a "missing" column computed as
        count - count_obs.

    Returns None; `df` is not modified.
    """
    if report_values:
        table = df.assign(missing=df["count"] - df["count_obs"])
        columns = ["event", "mean", "min", "max", "count", "missing"]
    else:
        table = df
        columns = ["event", "count"]

    fmt = {
        "mean":    formatter,         # precision chosen by magnitude
        "min":     formatter,
        "max":     formatter,
        "count":   "{:.0f}".format,   # whole numbers
        "missing": "{:.0f}".format,   # whole numbers
    }
    # Formatters for columns that are not selected are simply unused.
    # NOTE(review): since pandas 2.0, to_latex no longer escapes LaTeX special characters by
    # default; if event names contain e.g. '_' or '&', consider passing escape=True.
    print(table[columns].to_latex(index=False, formatters=fmt))


def get_vocab_table(df):
    """Print a LaTeX table of each event's plotting name and token index.

    `df` needs an "event" column; it is not modified. Returns None.
    """
    table = _with_event_name_columns(df)
    print(table[["event", "Event (plotting)", "idx"]].to_latex(index=False))

## Diagnoses

In [None]:
# Diagnosis events: the stats table reports counts only (no value statistics).
diagnoses = (
    dm.meta_information["diagnosis_table"]
    .copy()
)
get_stats_table(df=diagnoses, report_values=False)

In [None]:
# Diagnosis events with their plotting names and token indices.
get_vocab_table(df=diagnoses)

## Medications

In [None]:
# Measurement-table events with no observed values (count_obs == 0) are reported as
# medications, with counts only.
measurement_tables = dm.meta_information["measurement_tables"]
no_observations = measurement_tables["count_obs"] == 0
medication = measurement_tables[no_observations].copy()

get_stats_table(df=medication)

In [None]:
# Medication events with their plotting names and token indices.
get_vocab_table(df=medication)

## Investigations

In [None]:
# Measurement-table events with observed values (count_obs > 0) are reported as
# investigations, including value statistics (mean/min/max/missing).
# NOTE(review): rows whose count_obs is NaN (if any) match neither this filter nor the
# medication filter (count_obs == 0).
measurement_tables = dm.meta_information["measurement_tables"]
has_observations = measurement_tables["count_obs"] > 0
investigation = measurement_tables[has_observations].copy()

get_stats_table(investigation, report_values=True)

In [None]:
# Investigation events with their plotting names and token indices.
get_vocab_table(df=investigation)

In [None]:
# Map each EVENT in the tokenizer's event-count table to its long name, then merge rows that
# share a long name by summing all remaining columns (first-seen order is kept).
# Polars renamed `groupby` -> `group_by` (0.19) and superseded `map_dict` with `replace`
# (0.19.16); the old names were removed in 1.0, so use whichever this install provides.
event_col = pl.col("EVENT")
if hasattr(event_col, "replace"):
    long_event_name = event_col.replace(EVENT_NAME_LONG_MAP)  # unmapped names are kept as-is
else:
    # pl.first() as the default keeps the original value for unmapped names.
    long_event_name = event_col.map_dict(EVENT_NAME_LONG_MAP, default=pl.first())

renamed_counts = dm.train_set.tokenizer._event_counts.with_columns(long_event_name)
group_by = getattr(renamed_counts, "group_by", None) or renamed_counts.groupby
event_counts = group_by("EVENT", maintain_order=True).agg(pl.all().sum())

print(event_counts)

## Time to load individual samples

In [None]:
from tqdm import tqdm

# Number of individual training samples to time.
N_TIMED_SAMPLES = 100

times = []
start = time.perf_counter()  # monotonic, high-resolution clock for measuring durations
for row_idx, row in enumerate(tqdm(dm.train_set, total=N_TIMED_SAMPLES)):
    now = time.perf_counter()
    times.append(now - start)  # time to fetch this sample (plus progress-bar overhead)
    start = now
    if row_idx + 1 >= N_TIMED_SAMPLES:
        break
print(np.mean(times))

## Loading times for batches

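The mean batch time below is likely over-estimated: the first batch presumably also includes the dataloader's start-up cost, and only a few batches are timed.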

In [None]:
# Number of batches to time; the first presumably also pays any dataloader start-up cost.
N_TIMED_BATCHES = 4

times = []
start = time.perf_counter()  # monotonic, high-resolution clock for measuring durations
# The loop variable is deliberately not named `batch`, so the batch inspected earlier is kept.
for batch_idx, _timed_batch in enumerate(tqdm(dm.train_dataloader(), total=N_TIMED_BATCHES)):
    now = time.perf_counter()
    times.append(now - start)
    start = now
    if batch_idx + 1 >= N_TIMED_BATCHES:
        break
print(np.mean(times))

In [None]:
# Inspect one training sample (index 1236), showing at most 12 dynamic events;
# report_time=True presumably also reports how long the sample took to load.
sample_idx = 1236
dm.train_set.view_sample(sample_idx, max_dynamic_events=12, report_time=True)

## Vocabulary

In [None]:
# Print the whole vocabulary, one "key: item" line per entry, in the mapping's own order.
# NOTE(review): `_itos` is a private tokenizer attribute (presumably token index -> string).
vocab_itos = dm.train_set.tokenizer._itos
for key, item in vocab_itos.items():
    print(f"{key}: {item}")