In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

# Perform sqlite operations on disk
%env SQLITE_TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
%env TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
!echo $SQLITE_TMPDIR
!echo $TMPDIR
!echo $USERPROFILE

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-icelake/lib/python3.10/site-packages' at start of search paths.
env: SQLITE_TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
env: TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
/rds/projects/g/gokhalkm-optimal/DataforCharles
/rds/projects/g/gokhalkm-optimal/DataforCharles



In [7]:
import pytorch_lightning 
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.TTE.task_heads.causal import TTETransformerForCausalSequenceModelling
from tqdm import tqdm
import time
import os
import polars as pl
pl.Config.set_tbl_rows(vocab_size + 1)
# TODO:
# replace experiment boilerplate with pytorch lightning

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# device = "cpu"    # if more informative debugging statements are needed
!pwd

cuda
/rds/homes/g/gaddcz/Projects/CPRD/examples/data


In [3]:
# Set config to be equivalent architecture of Karpathy benchmark, however they are not comparable tasks.
@dataclass
class DemoConfig:
    """Model configuration used by this demo notebook.

    Every attribute is a dataclass field, so each one can be overridden
    through the constructor, e.g. ``DemoConfig(TTELayer="...")``.
    """
    block_size: int = 256        # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0
    # Annotated so that @dataclass registers it as a field; without the
    # annotation it is a plain class attribute that is missing from
    # __init__, __repr__ and __eq__.
    TTELayer: str = "Exponential"

config = DemoConfig()

@dataclass
class OptConfig:
    """Optimisation settings for this demo notebook."""

    batch_size: int = 128          # passed to the data module as its batch size
    eval_interval: int = 1         # presumably counted in epochs - TODO confirm
    learning_rate: float = 3e-4
    epochs: int = 3


opt = OptConfig()

In [5]:
# Root directory of the dataset used by the data module.
# Previous location, kept for reference:
#   "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/archive/Version2/"
path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/"

# Build the data module. Batch size, sequence length and the unknown-token
# frequency threshold all come from the config objects defined earlier.
# NOTE(review): `load=True` presumably reuses an already-built dataset
# rather than rebuilding it - confirm against FoundationalDataModule.
dm = FoundationalDataModule(
    path_to_db=path_to_db,
    load=True,
    tokenizer="tabular",
    batch_size=opt.batch_size,
    max_seq_length=config.block_size,
    unk_freq_threshold=config.unk_freq_threshold,
    min_workers=20,
)

# Vocabulary size as reported by the training set's tokenizer.
vocab_size = dm.train_set.tokenizer.vocab_size
print(f"{vocab_size} vocab elements")

INFO:root:Loading Polars dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/polars/
INFO:root:Using tokenizer tabular
INFO:root:Tokenzier created based on 3584.43M tokens
INFO:root:Creating split=train/ dataset
INFO:root:	 Loading split=train/ hash map for parquet
INFO:root:	 Hash map created for split=train/ with 22,912,046 samples
INFO:root:Creating split=test/ dataset
INFO:root:	 Loading split=test/ hash map for parquet
INFO:root:	 Hash map created for split=test/ with 1,207,449 samples
INFO:root:Creating split=val/ dataset
INFO:root:	 Loading split=val/ hash map for parquet
INFO:root:	 Hash map created for split=val/ with 1,226,576 samples


184 vocab elements


In [6]:
# Inspect training sample 1, capped at 12 dynamic events, and report how
# long the retrieval took.
dm.train_set.view_sample(
    1,
    max_dynamic_events=12,
    report_time=True,
)

Time to retrieve sample index 1 was 0.1088871955871582 seconds

SEX                 | F
IMD                 | 4.0
ETHNICITY           | WHITE
birth_year          | 1997.0

Token                                                                      | Age               | Standardised value
O_E___height_1                                                             | 6639              | -0.03             
O_E___weight_2                                                             | 6639              | 0.19              
Systolic_blood_pressure_4                                                  | 6639              | -0.03             
Diastolic_blood_pressure_5                                                 | 6665              | 0.18              
Systolic_blood_pressure_4                                                  | 6665              | -0.03             
Body_mass_index_3                                                          | 7031              | 0.26              
Diastolic_blood_

In [8]:
# Show the tokenizer's per-event counts table. `display` gives the notebook's
# rich table rendering instead of a plain-text dump, matching how the other
# cells in this notebook show their tables.
# NOTE(review): `_event_counts` is a private attribute of the tokenizer;
# prefer a public accessor if one exists.
display(dm.tokenizer._event_counts)

shape: (183, 3)
┌───────────────────────────────────┬───────────┬───────────┐
│ EVENT                             ┆ COUNT     ┆ FREQUENCY │
│ ---                               ┆ ---       ┆ ---       │
│ str                               ┆ u32       ┆ f64       │
╞═══════════════════════════════════╪═══════════╪═══════════╡
│ UNK                               ┆ 0         ┆ 0.0       │
│ ADDISONS_DISEASE                  ┆ 6691      ┆ 0.000002  │
│ CYSTICFIBROSIS                    ┆ 7053      ┆ 0.000002  │
│ SYSTEMIC_SCLEROSIS                ┆ 8772      ┆ 0.000002  │
│ SICKLE_CELL_DISEASE_V2            ┆ 11159     ┆ 0.000003  │
│ ADDISON_DISEASE                   ┆ 11794     ┆ 0.000003  │
│ DOWNSSYNDROME                     ┆ 17006     ┆ 0.000005  │
│ HAEMOCHROMATOSIS_V2               ┆ 18631     ┆ 0.000005  │
│ PLASMACELL_NEOPLASM_V2            ┆ 20301     ┆ 0.000006  │
│ SJOGRENSSYNDROME                  ┆ 23326     ┆ 0.000007  │
│ SYSTEMIC_LUPUS_ERYTHEMATOSUS      ┆ 26820     ┆ 0.00

In [9]:
# Table of per-measurement meta information taken from the training set.
meta_measurement = dm.train_set.meta_information["measurement_tables"]

# For each measurement event, display the `approx_lqr` value of the first
# row whose `event` column matches that event.
for measurement in meta_measurement["event"]:
    event_meta = meta_measurement.loc[meta_measurement["event"] == measurement]
    display(event_meta["approx_lqr"].to_numpy()[0])

-4.646246225231446

-36.374906271016215

3.4171338038240116

4.558863141370336

-4.327567643517323

-0.09380115366794299

2.025401527575622

0.23197832603292756

10.46932224355399

-245.08594949475705

2.0852967482708866

0.1803616361823952

-33.97440418922238

2.049077139786654

-11.221780004827133

48.77673642066362

-0.18513714481306787

-21.913521168501823

-2956.165860047326

5.337434618821507

25.888325065998195

0.27568323624704905

0.29179201759677853

13.031608961134598

2.9676883416277944

96.39517387655162

3.0213707473470905

0.6897185519972915

0.42469387969626915

0.10592671915519225

266.316085693917

24.96548937304727

74.5955180673202

-0.015579047822844971

-608.8215624025

-0.08415847981411151

-1101.1510071503315

0.44082834066354026

131.9086062119922

17.514388056157323

-178.95179427526466

-11.241276526203777

0.3410042161093746

-0.07071026791180035

-634.2811525461449

-1.3435911355534242

-6.462308605404846

31.29299777566158

-59.841460829377425

2.066499658254852

0.611548432375844

2.0704853785268815

29.78406994771627

-90.0527323030709

3.87030918674629

-26.335877811704524

2.9021453944583095

-347.07341172914676

132.4522144388257

-0.9687525953043634

1.638190669798231

-0.4748875165692883

0.42659851815511507

75.53188401700393

3.2771862521066706

10.41179363867223

-34.65095292326544

-7.819155475746674

0.3326011632094429

-0.011790829546063808

-615.9263350979847

-117.61110071295766

-1.1587890715552331

-4.416892478441605

30.61276663933536

13.084213334528137

-1.7519816723332982

2.0583367066129465

1.748473038434375

0.6425251680749247

23.4382982516102

-112.64091914730389

-4.5791376368802

4.277940480826725

-29.040429250235967

0.15515852397669994

3.1318756077076455

-500.2572283308142

132.03160863389553

-19.637593973229656

-0.9080738144377101

1.6670643034994281

-0.5305374511083958

0.4114468717090314

-93.14980914107466

-7.7969674401202775

-30.18561580872797

82.36669364711412

-1.5387806929769965

-1.9622082579070372

-27.585735519352404

13.39294673862522

-2.3081087331177175

0.5039839286086738

1.5954247277740436

-3.8672877345805388

-4.142799746985002

22.62504211170451

In [14]:
import pandas as pd
# Raise the pandas row-display limit to 1000 so the tables below are less
# likely to be truncated. This is a global option and stays in effect for
# later cells.
pd.set_option("display.max_rows", 1000)

# Display every entry stored in the training set's meta information.
for _key in dm.train_set.meta_information:
    display(dm.train_set.meta_information[_key])


{'SEX':   category     count
 0        F  14278868
 1        I       683
 2        M  13841966,
 'IMD':    category    count
 0       NaN  2442913
 1       1.0  4789813
 2       2.0  4981882
 3       3.0  5072397
 4       4.0  5650727
 5       5.0  5183785,
 'ETHNICITY':   category     count
 0    ASIAN   2267997
 1    BLACK   1156866
 2  MISSING   8058247
 3    MIXED    485838
 4    OTHER    422374
 5    WHITE  15730195}

Unnamed: 0,event,count
0,ADDISONS_DISEASE,6691
1,ADDISON_DISEASE,11794
2,AF,731332
3,ALCOHOLMISUSE_V2,1125212
4,ALLCANCER_NOHAEM_NOBCC,1496973
5,ALLERGICRHINITISCONJ,3291165
6,ALL_DEMENTIA,528602
7,ANXIETY,3560978
8,ANY_DEAFNESS_HEARING_LOSS_V2,2282766
9,AORTICANEURYSM_V2,101134


Unnamed: 0,event,count,count_obs,digest,min,max,mean,approx_lqr,approx_uqr
0,25_Hydroxyvitamin_D2_level_92,782791,693470,"({'m': 0.0, 'c': 9.0}, {'m': 0.1, 'c': 112.0},...",0.0,686.0,3.908721,-4.646246,10.80667
1,25_Hydroxyvitamin_D3_level_90,809104,781118,"({'m': 0.1, 'c': 3.0}, {'m': 1.0, 'c': 314.0},...",0.0,951.8,47.14889,-36.374906,121.281425
2,AST___aspartate_transam_SGOT__46,1738489,1680613,"({'m': 0.0, 'c': 3901.0}, {'m': 0.770571428571...",0.0,15330.0,26.61963,3.417134,41.771075
3,AST_serum_level_47,10837982,10485351,"({'m': 0.0, 'c': 53.0}, {'m': 1.8, 'c': 1.0}, ...",-5.0,20700.0,27.25168,4.558863,41.966985
4,Albumin___creatinine_ratio_37,180911,78420,"({'m': -1.0, 'c': 1.0}, {'m': 0.0, 'c': 4213.0...",-1.0,12821.0,10.67255,-4.327568,8.831512
5,Basophil_count_22,86869779,85642540,"({'m': 0.0, 'c': 37098.0}, {'m': 0.01, 'c': 28...",-0.1,111111.0,0.05008992,-0.093801,0.160919
6,Blood_calcium_level_38,415717,385464,"({'m': 0.0, 'c': 33.0}, {'m': 1.0, 'c': 1.0}, ...",0.0,440.0,2.35298,2.025402,2.62252
7,Blood_urea_28,785766,671861,"({'m': 0.0, 'c': 2746.0}, {'m': 0.09, 'c': 1.0...",0.0,1265.0,6.513018,0.231978,11.019293
8,Body_mass_index_3,99868822,97759312,"({'m': 0.0, 'c': 14.0}, {'m': 0.05, 'c': 1.0},...",-32680.0,2100000000.0,293.305,10.469322,43.324813
9,Brain_natriuretic_peptide_level_66,229202,159318,"({'m': 0.0, 'c': 120.0}, {'m': 0.1, 'c': 1.0},...",0.0,500142.0,416.8786,-245.085949,483.166025
