In [1]:
import os
from pathlib import Path
import sys
# Put the node-specific virtual environment's site-packages at the front of the import path.
# The environment directory name is suffixed with the value of the BB_CPU environment variable
# (os.getenv returns None when it is unset, which yields a directory name ending in "-None").
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
python_dir_name = f'python{sys.version_info.major}.{sys.version_info.minor}'
venv_site_pkgs = Path(venv_dir).joinpath('lib', python_dir_name, 'site-packages')

if not venv_site_pkgs.exists():
    # Nothing is added to sys.path in this case; later imports fall back to the default environment.
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")
else:
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")

!pwd

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/data


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
from tqdm import tqdm
import pickle
from hydra import compose, initialize
from omegaconf import OmegaConf
from CPRD.examples.modelling.SurvStreamGPT.experiment import run
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

import time
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import os
import polars as pl
pl.Config.set_tbl_rows(10000)
import pandas as pd
pd.options.display.max_rows = 10000

# Fix the torch RNG seed so repeated runs of the notebook are comparable
torch.manual_seed(1337)
# Set the internal precision used for float32 matrix multiplications (see the torch docs for 'medium')
torch.set_float32_matmul_precision('medium')

# Surface INFO-level log messages (e.g. those emitted while building the data module)
logging.basicConfig(level=logging.INFO)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cuda.


## Build configurations

In [3]:
# Compose the experiment configuration with Hydra.
# Settings can be changed at load time by adding entries to `overrides`,
# or afterwards by assigning to `cfg` as in the disabled examples below.
with initialize(
    version_base=None,
    config_path="../modelling/SurvStreamGPT/confs",
    job_name="testing_notebook",
):
    cfg = compose(config_name="config_SingleRisk11M", overrides=[])

# Optional manual overrides (disabled)
# cfg.data.batch_size = 16
# cfg.transformer.block_size = 32
# # cfg.transformer.n_layer = 10

In [4]:
# Print the composed configuration as YAML so the loaded values can be checked by eye
print(OmegaConf.to_yaml(cfg))

is_decoder: true
data:
  batch_size: 64
  unk_freq_threshold: 0.0
  min_workers: 20
  global_diagnoses: false
  path_to_db: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/
  path_to_ds: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/
  meta_information_path: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information.pickle
experiment:
  project_name: SurvStreamGPT_${head.SurvLayer}
  run_id: PreTrain_${head.SurvLayer}_11M_${experiment.seed}
  train: true
  test: true
  verbose: true
  seed: 1337
  log: true
  log_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/
  ckpt_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/
optim:
  num_epochs: 1
  learning_rate: 0.0001
  scheduler: CAWarmRestarts
  scheduler_periods: 5000
  scheduler_warmup: true
  lr_cosine_decay_period: 10000000.0
  val_check_

In [5]:
# Build the data module from the pre-processed dataset (load=True),
# keeping only records that satisfy the inclusion condition COUNTRY = 'E'
dm = FoundationalDataModule(
    path_to_db=cfg.data.path_to_db,
    path_to_ds=cfg.data.path_to_ds,
    load=True,
    tokenizer="tabular",
    batch_size=cfg.data.batch_size,
    max_seq_length=cfg.transformer.block_size,
    freq_threshold=cfg.data.unk_freq_threshold,
    min_workers=cfg.data.min_workers,
    inclusion_conditions=["COUNTRY = 'E'"],
)

vocab_size = dm.train_set.tokenizer.vocab_size
print(f"{vocab_size} vocab elements")

# List of univariate measurements to model with a Normal distribution.
# Measurements are picked out using the fact that diagnosis names are entirely upper case:
# any event name that changes when upper-cased is treated as a measurement.
event_names = dm.tokenizer._event_counts["EVENT"]
measurements_for_univariate_regression = [name for name in event_names if name != name.upper()]
cfg.head.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression)
# display(measurements_for_univariate_regression)

INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/split=train/ dataset, with 23,613,894 samples
INFO:root:Loaded /rds/projects/g/g

265 vocab elements


In [None]:
# Time how long it takes to obtain the first batch from the training dataloader
start = time.time()   # starting time
for batch in dm.train_dataloader():
    break   # only the first batch is needed; `batch` stays available for inspection below
elapsed = time.time() - start
print(f"batch loaded in {elapsed} seconds")

# Disabled: print the shape of every tensor in the batch
# for key in batch.keys():
#     print(f"{key}".ljust(20) + f"{batch[key].shape}")

# Disabled: decode the first sequence in the batch and print each token next to its value
# tokens = batch["tokens"][0].tolist()    
# sentence = dm.decode(tokens).split(" ")
# for token, value in zip(sentence, batch["values"][0].tolist()):
#     print(f"{token}:".ljust(40) + f"{value}")

# View an example sample

In [7]:
# Print one training sample (dataset index 12348): its static covariates followed by its event sequence.
# report_time=True also prints how long the retrieval took.
# max_dynamic_events presumably caps how many events are listed - confirm against view_sample.
dm.train_set.view_sample(12348, max_dynamic_events=120, report_time=True)

Time to retrieve sample index 12348 was 0.03458809852600098 seconds

SEX                 | F
IMD                 | 1.0
ETHNICITY           | WHITE
birth_year          | 1965.0
Sequence of 99 events

Token                                                                      | Age               | Standardised value
Diastolic_blood_pressure_5                                                 | 11137             | -0.04             
Systolic_blood_pressure_4                                                  | 11137             | -0.03             
Diastolic_blood_pressure_5                                                 | 11554             | 0.04              
Systolic_blood_pressure_4                                                  | 11554             | -0.03             
Anxiolytics_mumpredict                                                     | 12157             | nan               
Benzodiazepines                                                            | 12157             | nan     

In [8]:
# Measure the wall-clock time between consecutive batches from the training dataloader.
times = []
# The clock starts before the dataloader is created, so the first recorded interval
# also includes the cost of building/starting the dataloader iterator.
start_time = time.time()
for _idx, row in enumerate(tqdm(dm.train_dataloader())):
    times.append(time.time() - start_time)
    start_time = time.time()
    # `_idx > 10` first holds at _idx == 11, so 12 intervals are recorded before stopping
    if _idx > 10:
        break
# Mean seconds per batch over the recorded intervals
print(np.mean(times))

  0%|          | 11/359580 [00:16<146:46:04,  1.47s/it]

0.8016159335772196





In [9]:
# Re-print the mean per-batch time from `times` (requires the timing loop cell to have been run first)
print(np.mean(times))

0.8016159335772196


## View counts and frequency of each event

In [8]:
# Print the tokenizer's per-event table (columns EVENT, COUNT, FREQUENCY).
# NOTE: `_event_counts` is a private attribute of the tokenizer.
print(dm.tokenizer._event_counts)

shape: (264, 3)
┌───────────────────────────────────┬───────────┬───────────┐
│ EVENT                             ┆ COUNT     ┆ FREQUENCY │
│ ---                               ┆ ---       ┆ ---       │
│ str                               ┆ u32       ┆ f64       │
╞═══════════════════════════════════╪═══════════╪═══════════╡
│ UNK                               ┆ 0         ┆ 0.0       │
│ ADDISONS_DISEASE                  ┆ 6691      ┆ 8.8559e-7 │
│ CYSTICFIBROSIS                    ┆ 7053      ┆ 9.3350e-7 │
│ SYSTEMIC_SCLEROSIS                ┆ 8772      ┆ 0.000001  │
│ SICKLE_CELL_DISEASE_V2            ┆ 11159     ┆ 0.000001  │
│ ADDISON_DISEASE                   ┆ 11794     ┆ 0.000002  │
│ DOWNSSYNDROME                     ┆ 17006     ┆ 0.000002  │
│ HAEMOCHROMATOSIS_V2               ┆ 18631     ┆ 0.000002  │
│ PLASMACELL_NEOPLASM_V2            ┆ 20301     ┆ 0.000003  │
│ SJOGRENSSYNDROME                  ┆ 23326     ┆ 0.000003  │
│ SYSTEMIC_LUPUS_ERYTHEMATOSUS      ┆ 26820     ┆ 0.00

## View the meta information collected during the Polars transformation of the data into ML-friendly parquet files

In [9]:
import pandas as pd
# Let pandas display up to 1000 rows so the meta-information tables are shown in full
pd.set_option('display.max_rows', 1000)
# Display every entry of the meta information collected for the training set, one after another
for meta_table in dm.train_set.meta_information.values():
    display(meta_table)


{'SEX':   category     count
 0        F  14278868
 1        I       683
 2        M  13841966,
 'IMD':    category    count
 0       NaN  2442913
 1       1.0  4789813
 2       2.0  4981882
 3       3.0  5072397
 4       4.0  5650727
 5       5.0  5183785,
 'ETHNICITY':   category     count
 0    ASIAN   2267997
 1    BLACK   1156866
 2  MISSING   8058247
 3    MIXED    485838
 4    OTHER    422374
 5    WHITE  15730195}

Unnamed: 0,event,count
0,ADDISONS_DISEASE,6691
1,ADDISON_DISEASE,11794
2,AF,731332
3,ALCOHOLMISUSE_V2,1125212
4,ALLCANCER_NOHAEM_NOBCC,1496973
5,ALLERGICRHINITISCONJ,3291165
6,ALL_DEMENTIA,528602
7,ANXIETY,3560978
8,ANY_DEAFNESS_HEARING_LOSS_V2,2282766
9,AORTICANEURYSM_V2,101134


Unnamed: 0,event,count,count_obs,digest,min,max,mean,approx_lqr,approx_uqr
0,25_Hydroxyvitamin_D2_level_92,782791,693470,"({'m': 0.0, 'c': 9.0}, {'m': 0.1, 'c': 112.0},...",0.0,686.0,3.908721,-4.673936,10.857805
1,25_Hydroxyvitamin_D3_level_90,809104,781118,"({'m': 0.1, 'c': 3.0}, {'m': 1.0, 'c': 314.0},...",0.0,951.8,47.14889,-36.770239,121.65323
2,AST___aspartate_transam_SGOT__46,1738489,1680613,"({'m': 0.0, 'c': 3901.0}, {'m': 0.770571428571...",0.0,15330.0,26.61963,3.417134,41.771075
3,AST_serum_level_47,10837982,10485351,"({'m': 0.0, 'c': 53.0}, {'m': 1.8, 'c': 1.0}, ...",-5.0,20700.0,27.25168,4.558863,41.966985
4,Albumin___creatinine_ratio_37,180911,78420,"({'m': -1.0, 'c': 1.0}, {'m': 0.0, 'c': 4213.0...",-1.0,12821.0,10.67255,-4.318622,8.826145
5,Basophil_count_22,86869779,85642540,"({'m': 0.0, 'c': 37098.0}, {'m': 0.01, 'c': 28...",-0.1,111111.0,0.05008992,-0.093801,0.160919
6,Blood_calcium_level_38,415717,385464,"({'m': 0.0, 'c': 33.0}, {'m': 1.0, 'c': 1.0}, ...",0.0,440.0,2.35298,2.025402,2.62252
7,Blood_urea_28,785766,671861,"({'m': 0.0, 'c': 2746.0}, {'m': 0.09, 'c': 1.0...",0.0,1265.0,6.513018,0.158376,11.063455
8,Body_mass_index_3,99868822,97759312,"({'m': 0.0, 'c': 14.0}, {'m': 0.05, 'c': 1.0},...",-32680.0,2100000000.0,293.305,10.477956,43.319633
9,Brain_natriuretic_peptide_level_66,229202,159318,"({'m': 0.0, 'c': 120.0}, {'m': 0.1, 'c': 1.0},...",0.0,500142.0,416.8786,-245.377936,483.324994


# Outlier bounds

### Get the pre-processed, quantile-based estimates of the upper and lower outlier bounds

In [15]:
meta_measurement = dm.train_set.meta_information["measurement_tables"]

# Fill with the quantile-based estimates: event name -> [lower bound, upper bound]
measurement_limit_list = {
    event: [lower, upper]
    for event, lower, upper in zip(
        meta_measurement["event"].tolist(),
        meta_measurement["approx_lqr"].tolist(),
        meta_measurement["approx_uqr"].tolist(),
    )
}
display(measurement_limit_list)

{'25_Hydroxyvitamin_D2_level_92': [-4.67393603726698, 10.857804944243398],
 '25_Hydroxyvitamin_D3_level_90': [-36.770239453029085, 121.65323028859444],
 'AST___aspartate_transam_SGOT__46': [3.4171338038240116, 41.77107521642682],
 'AST_serum_level_47': [4.558863141370336, 41.966985461548944],
 'Albumin___creatinine_ratio_37': [-4.318622277704492, 8.826145270522485],
 'Basophil_count_22': [-0.09380115366794299, 0.1609191424008079],
 'Blood_calcium_level_38': [2.025401527575622, 2.6225197728812093],
 'Blood_urea_28': [0.15837579012681058, 11.063454600066638],
 'Body_mass_index_3': [10.477955941584963, 43.319632799587666],
 'Brain_natriuretic_peptide_level_66': [-245.3779363349517,
  483.32499369786353],
 'Calcium_adjusted_level_41': [2.0852967482708866, 2.6282899313330343],
 'Calculated_LDL_cholesterol_level_103': [0.19823248550905048,
  5.486489620225761],
 'Combined_total_vitamin_D2_and_D3_level_93': [-33.92579046750491,
  127.10458377019135],
 'Corrected_serum_calcium_level_42': [2.04

### Update with the clinician-set bounds

In [16]:
# Manually set the clinician-provided limits on top of the quantile-based estimates.
# - Assigning a full [lower, upper] list replaces both bounds of that measurement.
# - Assigning to index [0] or [1] replaces only the lower or upper bound; the other side keeps
#   the lower/upper quantile approximation (TODO: modify these quantiles to be less rigid).
# - When both lower and upper are missing, the name (not formatted) is included as a comment
#   and both quantile approximations are kept.
# NOTE(review): a full-list assignment with a mistyped name silently creates a new dictionary
# entry, whereas an indexed assignment raises KeyError when the name is not already present.

# 25-HYDROXYVITAMIN_D2_LEVEL_92
# ALBUMIN_/_CREATININE_RATIO_37
measurement_limit_list["AST___aspartate_transam_SGOT__46"]            = [0.1, 40000]
measurement_limit_list["AST_serum_level_47"]                          = [0.1, 40000]
measurement_limit_list["Basophil_count_22"][1]                        = 10
measurement_limit_list["Blood_calcium_level_38"]                      = [0.1, 10]
measurement_limit_list["Blood_urea_28"][0]                            = 0.1
measurement_limit_list["Body_mass_index_3"]                           = [5, 130]
# BRAIN_NATRIURETIC_PEPTIDE_LEVEL_66
measurement_limit_list["Calcium_adjusted_level_41"]                   = [0.1, 10]
measurement_limit_list["Calculated_LDL_cholesterol_level_103"]        = [0, 30]
measurement_limit_list["Combined_total_vitamin_D2_and_D3_level_93"][0]= 0.1
measurement_limit_list["Corrected_serum_calcium_level_42"]            = [0.1, 10]
measurement_limit_list["Current_smoker_83"]                           = [0, 80]                   # but manually treated as always missing inside model anyway - as awful quality
measurement_limit_list["Diastolic_blood_pressure_5"]                  = [20, 400]
measurement_limit_list["eGFR_using_creatinine_CKD_EPI_per_1_73_square_metres_33"][1] = 300
measurement_limit_list["Eosinophil_count_21"][1]                      = 10
measurement_limit_list["Erythrocyte_sedimentation_rate_61"][0]        = 0
measurement_limit_list["Ex_smoker_84"]                                = [0, 80]                   # but manually treated as always missing inside model anyway - as awful quality
measurement_limit_list["Free_T4_level_76"]                            = [0.1, 1000]
measurement_limit_list["GFR_calculated_abbreviated_MDRD_34"]          = [0.1, 300]
measurement_limit_list["Haematocrit___PCV_16"]                        = [0.1, 1]
measurement_limit_list["Haematocrit_15"]                              = [0, 1]
measurement_limit_list["Haemoglobin_A1c_level___IFCC_standardised_6"] = [5, 250]
measurement_limit_list["Haemoglobin_A1c_level_8"]                     = [5, 250]
measurement_limit_list["Haemoglobin_estimation_9"]                    = [10, 300]
measurement_limit_list["HbA1c_level__DCCT_aligned__7"]                = [5, 250]
measurement_limit_list["INR___international_normalised_ratio_81"][1]  = 100
measurement_limit_list["International_normalised_ratio_82"][1]        = 100
measurement_limit_list["Lymphocyte_count_20"]                         = [0.1, 200]
measurement_limit_list["Mean_corpusc_Hb_conc__MCHC__14"]              = [10, 500]
measurement_limit_list["Mean_corpusc_haemoglobin_MCH__13"]            = [1, 100]
measurement_limit_list["Mean_corpuscular_volume__MCV__11"]            = [10, 250]
measurement_limit_list["Monocyte_count_23"][1]                        = 20
measurement_limit_list["Neutrophil_count_19"]                         = [0, 50]
measurement_limit_list["Never_smoked_tobacco_85"]                     = [0, 0]                 # but manually treated as always missing inside model anyway
measurement_limit_list["Non_HDL_cholesterol_level_108"][0]            = 0
# N_TERMINAL_PRO-BRAIN_NATRIURETIC_PEPTIDE_LEVEL_67
measurement_limit_list["O_E___height_1"]                              = [20, 272]
measurement_limit_list["O_E___weight_2"]                              = [0.25, 650]
measurement_limit_list["Serum_alanine_aminotransferase_level_45"]     = [0.1, 40000]
measurement_limit_list["Plasma_albumin_level_52"]                     = [2, 180]
measurement_limit_list["Plasma_alkaline_phosphatase_level_49"][0]     = 1
# PLASMA_B_NATRIURETIC_PEPTIDE_LEVEL_69
measurement_limit_list["Plasma_calcium_level_40"]                     = [0.1, 10]
measurement_limit_list["Plasma_cholesterol_HDL_ratio_96"]             = [1, 250]
# NOTE(review): every other calcium measurement in this cell uses [0.1, 10]; an upper bound of 1
# here looks like a possible typo - confirm against the clinician-provided list.
measurement_limit_list["Plasma_corrected_calcium_level_43"]           = [0.1, 1]
measurement_limit_list["Plasma_creatinine_level_32"][0]               = 1
# PLASMA_C_REACTIVE_PROTEIN_60
measurement_limit_list["Plasma_ferritin_level_62"][1]                 = 500000
measurement_limit_list["Plasma_free_T4_level_77"]                     = [0.1, 1000]
# PLASMA_GAMMA-GLUTAMYL_TRANSFERASE_LEVEL_58
measurement_limit_list["Plasma_HDL_cholesterol_level_101"][1]         = 3.9
measurement_limit_list["Plasma_LDL_cholesterol_level_104"]            = [0, 30]
# PLASMA_N-TERMINAL_PRO_B-TYPE_NATRIURETIC_PEPTIDE_CONC_70
measurement_limit_list["Plasma_potassium_level_27"]                   = [0.1, 10]
# PLASMA_PRO-BRAIN_NATRIURETIC_PEPTIDE_LEVEL_64
measurement_limit_list["Plasma_sodium_level_25"]                      = [90, 200]
measurement_limit_list["Plasma_total_bilirubin_level_54"][1]          = 900
measurement_limit_list["Plasma_total_cholesterol_level_99"][1]        = 50
# PLASMA_TRIGLYCERIDE_LEVEL_106
measurement_limit_list["Plasma_TSH_level_73"][1]                      = 500
measurement_limit_list["Plasma_urea_level_30"][0]                     = 0.1
# NOTE: the literal 10e6 equals 1e7 (ten million)
measurement_limit_list["Platelet_count_12"]                           = [0, 10e6]
measurement_limit_list["Red_blood_cell_distribution_width_17"][1]     = 100
measurement_limit_list["Red_blood_cell__RBC__count_10"]               = [0, 500]
measurement_limit_list["Serum_25_Hydroxy_vitamin_D3_level_88"][0]     = 0.1
measurement_limit_list["Plasma_alanine_aminotransferase_level_44"]    = [0.1, 40000]                   # Jenny named this SERUM_ALANINE_AMINOTRANSFERASE_LEVEL_44?
measurement_limit_list["Serum_albumin_51"]                            = [2, 180]
measurement_limit_list["Serum_alkaline_phosphatase_50"]               = [1, 20000]
measurement_limit_list["Serum_bilirubin_level_53"][1]                 = 900
measurement_limit_list["Serum_calcium_39"]                            = [0.1, 10]
measurement_limit_list["Serum_cholesterol_97"][1]                     = 50
measurement_limit_list["Serum_cholesterol_HDL_ratio_94"]              = [1, 250]
measurement_limit_list["Serum_creatinine_31"]                         = [1, 10000]
# SERUM_C_REACTIVE_PROTEIN_LEVEL_59
measurement_limit_list["Serum_ferritin_63"][1]                        = 5e5
# SERUM_FOLATE_80
# NOTE(review): the other free-T4 entries (Free_T4_level_76, Plasma_free_T4_level_77) use an
# upper bound of 1000 - confirm that 10000 is intended here.
measurement_limit_list["Serum_free_T4_level_75"]                      = [0.1, 10000]
# SERUM_GAMMA-GLUTAMYL_TRANSFERASE_LEVEL_57
measurement_limit_list["Serum_HDL_cholesterol_level_100"]             = [0, 3.9]
measurement_limit_list["Serum_LDL_cholesterol_level_102"]             = [0, 30]
# SERUM_N-TERMINAL_PRO_B-TYPE_NATRIURETIC_PEPTIDE_CONC_68
# NOTE(review): these are the same [0, 3.9] bounds as Serum_HDL_cholesterol_level_100 above,
# while Non_HDL_cholesterol_level_108 only has its lower bound set - confirm this is intended.
measurement_limit_list["Serum_non_high_density_lipoprotein_cholesterol_level_107"] = [0, 3.9]
measurement_limit_list["Serum_potassium_26"]                          = [0.1, 10]
# SERUM_PRO-BRAIN_NATRIURETIC_PEPTIDE_LEVEL_65
measurement_limit_list["Serum_sodium_24"]                             = [90, 200]
measurement_limit_list["Serum_T4_level_78"]                           = [0.1, 1000]
# SERUM_TOTAL_25-HYDROXY_VITAMIN_D_LEVEL_87
measurement_limit_list["Serum_total_bilirubin_level_56"][1]           = 900
measurement_limit_list["Serum_total_cholesterol_level_98"][1]         = 50
measurement_limit_list["Serum_triglycerides_105"]                     = [0, 100]
measurement_limit_list["Serum_TSH_level_71"][1]                       = 500
measurement_limit_list["Serum_urea_level_29"][0]                      = 0.1
measurement_limit_list["Serum_vitamin_B12_79"]                        = [5, 20000]
# measurement_limit_list["Serum_vitamin_D2_level_89"] = 
measurement_limit_list["Serum_vitamin_D_86"]                          = [0.1, 1000]
measurement_limit_list["Systolic_blood_pressure_4"]                   = [30, 400]
measurement_limit_list["Total_25_hydroxyvitamin_D_level_91"]          = [0.1, 1000]
measurement_limit_list["Total_alkaline_phosphatase_48"][0]            = 0.1
measurement_limit_list["Total_bilirubin_55"][1]                       = 900
measurement_limit_list["Total_cholesterol_HDL_ratio_95"]              = [1, 250]
measurement_limit_list["Total_white_cell_count_18"]                   = [0.1, 5000]
measurement_limit_list["TSH___thyroid_stim_hormone_72"][1]            = 500
measurement_limit_list["TSH_level_74"][1]                             = 500
measurement_limit_list["Urine_albumin_creatinine_ratio_35"][1]        = 300
# measurement_limit_list["Urine_microalbumin_creatinine_ratio_36"] 


In [17]:
# Print one aligned line per measurement: name, lower bound, "to", upper bound (two decimals each)
for event_name, (lower, upper) in measurement_limit_list.items():
    name_column = f"{event_name}".ljust(80)
    lower_column = f"{lower:.2f}".ljust(10)
    print(name_column + lower_column + f"to\t {upper:.2f}")

25_Hydroxyvitamin_D2_level_92                                                   -4.67     to	 10.86
25_Hydroxyvitamin_D3_level_90                                                   -36.77    to	 121.65
AST___aspartate_transam_SGOT__46                                                0.10      to	 40000.00
AST_serum_level_47                                                              0.10      to	 40000.00
Albumin___creatinine_ratio_37                                                   -4.32     to	 8.83
Basophil_count_22                                                               -0.09     to	 10.00
Blood_calcium_level_38                                                          0.10      to	 10.00
Blood_urea_28                                                                   0.10      to	 11.06
Body_mass_index_3                                                               5.00      to	 130.00
Brain_natriuretic_peptide_level_66                                              -245.38   to	

### Update the meta information on file

In [18]:
# Load from file - this is the pre-processed information the dataloader loads back in.
# It is used for building the tokenizer and on-the-fly standardisation.
# os.path.join (rather than string concatenation) keeps the path valid whether or not
# cfg.data.path_to_ds ends with a path separator.
# NOTE: pickle.load can execute arbitrary code - only load files from a trusted location.
meta_information_file = os.path.join(cfg.data.path_to_ds, 'meta_information.pickle')
with open(meta_information_file, 'rb') as handle:
    meta_information = pickle.load(handle)

In [19]:
# Replace the pre-processed quantile-based limits in the loaded meta information with the
# clinician-derived feasible limits held in `measurement_limit_list`.
# The table is modified in place: `measurement_table` is the same DataFrame object that is
# stored under meta_information["measurement_tables"].
measurement_table = meta_information["measurement_tables"]
unmatched_events = []
for _key, (lower_bound, upper_bound) in measurement_limit_list.items():
    # Build the row mask once per event and reuse it for both bound columns
    is_event = measurement_table["event"] == _key
    if not is_event.any():
        # The names are typed by hand, so record any that match no row instead of skipping silently
        unmatched_events.append(_key)
        continue
    measurement_table.loc[is_event, "approx_lqr"] = lower_bound
    measurement_table.loc[is_event, "approx_uqr"] = upper_bound

if unmatched_events:
    logging.warning(f"No measurement table rows matched these events, so their limits were not applied: {unmatched_events}")

print(meta_information["measurement_tables"].head())

                              event     count  count_obs  \
0     25_Hydroxyvitamin_D2_level_92    782791     693470   
1     25_Hydroxyvitamin_D3_level_90    809104     781118   
2  AST___aspartate_transam_SGOT__46   1738489    1680613   
3                AST_serum_level_47  10837982   10485351   
4     Albumin___creatinine_ratio_37    180911      78420   

                                              digest  min      max       mean  \
0  ({'m': 0.0, 'c': 9.0}, {'m': 0.1, 'c': 112.0},...  0.0    686.0   3.908721   
1  ({'m': 0.1, 'c': 3.0}, {'m': 1.0, 'c': 314.0},...  0.0    951.8  47.148892   
2  ({'m': 0.0, 'c': 3901.0}, {'m': 0.770571428571...  0.0  15330.0  26.619633   
3  ({'m': 0.0, 'c': 53.0}, {'m': 1.8, 'c': 1.0}, ... -5.0  20700.0  27.251680   
4  ({'m': -1.0, 'c': 1.0}, {'m': 0.0, 'c': 4213.0... -1.0  12821.0  10.672548   

   approx_lqr    approx_uqr  
0   -4.673936     10.857805  
1  -36.770239    121.653230  
2    0.100000  40000.000000  
3    0.100000  40000.000000  
4 

In [20]:
# Save the updated meta information under a separate file name, so the file it was loaded from
# ('meta_information.pickle') is not overwritten.
# The path is built once with os.path.join, so the write and the read-back below are guaranteed to
# use the same file and the path is valid whether or not cfg.data.path_to_ds ends with a separator.
quantjenny_meta_file = os.path.join(cfg.data.path_to_ds, 'meta_information_QuantJenny.pickle')
with open(quantjenny_meta_file, 'wb') as handle:
    pickle.dump(meta_information, handle, protocol=pickle.HIGHEST_PROTOCOL)

# test this looks correct: read the file back and show the first rows of the measurement table
with open(quantjenny_meta_file, 'rb') as handle:
    meta_information_new = pickle.load(handle)
print(meta_information_new["measurement_tables"].head())

                              event     count  count_obs  \
0     25_Hydroxyvitamin_D2_level_92    782791     693470   
1     25_Hydroxyvitamin_D3_level_90    809104     781118   
2  AST___aspartate_transam_SGOT__46   1738489    1680613   
3                AST_serum_level_47  10837982   10485351   
4     Albumin___creatinine_ratio_37    180911      78420   

                                              digest  min      max       mean  \
0  ({'m': 0.0, 'c': 9.0}, {'m': 0.1, 'c': 112.0},...  0.0    686.0   3.908721   
1  ({'m': 0.1, 'c': 3.0}, {'m': 1.0, 'c': 314.0},...  0.0    951.8  47.148892   
2  ({'m': 0.0, 'c': 3901.0}, {'m': 0.770571428571...  0.0  15330.0  26.619633   
3  ({'m': 0.0, 'c': 53.0}, {'m': 1.8, 'c': 1.0}, ... -5.0  20700.0  27.251680   
4  ({'m': -1.0, 'c': 1.0}, {'m': 0.0, 'c': 4213.0... -1.0  12821.0  10.672548   

   approx_lqr    approx_uqr  
0   -4.673936     10.857805  
1  -36.770239    121.653230  
2    0.100000  40000.000000  
3    0.100000  40000.000000  
4 

### Now we can load this file into the loader and it will use the updated bounds

In [36]:
# Rebuild the data module, this time pointing it at the pickle that holds the updated bounds.
# os.path.join keeps the path valid whether or not cfg.data.path_to_ds ends with a separator.
dm = FoundationalDataModule(path_to_db=cfg.data.path_to_db,
                            path_to_ds=cfg.data.path_to_ds,
                            load=True,
                            tokenizer="tabular",
                            batch_size=cfg.data.batch_size,
                            max_seq_length=cfg.transformer.block_size,
                            freq_threshold=cfg.data.unk_freq_threshold,
                            min_workers=cfg.data.min_workers,
                            inclusion_conditions=["COUNTRY = 'E'"],
                            overwrite_meta_information=os.path.join(cfg.data.path_to_ds, 'meta_information_QuantJenny.pickle'),
                           )

vocab_size = dm.train_set.tokenizer.vocab_size

# List of univariate measurements to model with a Normal distribution:
# every event in the measurement table with at least one observed value (count_obs > 0).
measurement_table = dm.train_set.meta_information["measurement_tables"]
measurements_for_univariate_regression = measurement_table.loc[measurement_table["count_obs"] > 0, "event"].to_list()
cfg.head.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression)
# display(measurements_for_univariate_regression)

# Show the meta information the data module actually loaded, to confirm the updated bounds are in use
display(dm.train_set.meta_information)

INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/split=train/ dataset, with 23,613,894 samples
INFO:root:Loaded /rds/p

{'static_table': {'SEX':   category     count
  0        F  14278868
  1        I       683
  2        M  13841966,
  'IMD':    category    count
  0       NaN  2442913
  1       1.0  4789813
  2       2.0  4981882
  3       3.0  5072397
  4       4.0  5650727
  5       5.0  5183785,
  'ETHNICITY':   category     count
  0    ASIAN   2267997
  1    BLACK   1156866
  2  MISSING   8058247
  3    MIXED    485838
  4    OTHER    422374
  5    WHITE  15730195},
 'diagnosis_table':                                   event    count
 0                      ADDISONS_DISEASE     6691
 1                       ADDISON_DISEASE    11794
 2                                    AF   731332
 3                      ALCOHOLMISUSE_V2  1125212
 4                ALLCANCER_NOHAEM_NOBCC  1496973
 5                  ALLERGICRHINITISCONJ  3291165
 6                          ALL_DEMENTIA   528602
 7                               ANXIETY  3560978
 8          ANY_DEAFNESS_HEARING_LOSS_V2  2282766
 9                  