In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

# Perform sqlite operations on disk
%env SQLITE_TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
%env TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
!echo $SQLITE_TMPDIR
!echo $TMPDIR
!echo $USERPROFILE

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-icelake/lib/python3.10/site-packages' at start of search paths.
env: SQLITE_TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
env: TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
/rds/projects/g/gokhalkm-optimal/DataforCharles
/rds/projects/g/gokhalkm-optimal/DataforCharles



In [7]:
import pytorch_lightning 
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.TTE.task_heads.causal import TTETransformerForCausalSequenceModelling
from tqdm import tqdm
import time
import os
import polars as pl
pl.Config.set_tbl_rows(vocab_size + 1)
# TODO:
# replace experiment boilerplate with pytorch lightning

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# device = "cpu"    # if more informative debugging statements are needed
!pwd

cuda
/rds/homes/g/gaddcz/Projects/CPRD/examples/data


In [3]:
# Set the config to match the architecture of the Karpathy benchmark; note that the tasks themselves are not comparable.
@dataclass
class DemoConfig:
    """Model/data configuration used by this demo notebook.

    Every attribute is a dataclass field with a default, so any of them can be
    overridden by keyword, e.g. ``DemoConfig(n_layer=2, TTELayer="...")``.
    """
    block_size: int = 256        # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0     # forwarded to the data module's `unk_freq_threshold`
    # Annotated so @dataclass registers it as a real field. Without the
    # annotation it was a plain class attribute: absent from __init__, repr
    # and ==, and impossible to override at construction time.
    TTELayer: str = "Exponential"

config = DemoConfig()

# Optimisation settings for the demo, kept separate from the model config.
@dataclass()
class OptConfig:
    # Batch size; this is the value handed to the data module.
    batch_size: int = 128
    eval_interval: int = 1
    learning_rate: float = 0.0003
    epochs: int = 3


# Default optimisation settings used by the rest of the notebook.
opt = OptConfig()

In [5]:
# Location of the dataset on project storage (previous archive location kept
# below for reference).
# path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/archive/Version2/"
path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/"

# Build the data module. Batch size, maximum sequence length and the unknown-
# token frequency threshold come from the config objects defined above.
dm = FoundationalDataModule(
    path_to_db=path_to_db,
    load=True,
    tokenizer="tabular",
    batch_size=opt.batch_size,
    max_seq_length=config.block_size,
    unk_freq_threshold=config.unk_freq_threshold,
    min_workers=20,
)

# Vocabulary size as reported by the training split's tokenizer.
vocab_size = dm.train_set.tokenizer.vocab_size
print(f"{vocab_size} vocab elements")

INFO:root:Loading Polars dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/polars/
INFO:root:Using tokenizer tabular
INFO:root:Tokenzier created based on 3584.43M tokens
INFO:root:Creating split=train/ dataset
INFO:root:	 Loading split=train/ hash map for parquet
INFO:root:	 Hash map created for split=train/ with 22,912,046 samples
INFO:root:Creating split=test/ dataset
INFO:root:	 Loading split=test/ hash map for parquet
INFO:root:	 Hash map created for split=test/ with 1,207,449 samples
INFO:root:Creating split=val/ dataset
INFO:root:	 Loading split=val/ hash map for parquet
INFO:root:	 Hash map created for split=val/ with 1,226,576 samples


184 vocab elements


In [6]:
# Inspect one training sample (index 1): prints its static covariates followed
# by a table of up to 12 dynamic events, plus the time taken to retrieve it.
dm.train_set.view_sample(1, max_dynamic_events=12, report_time=True)

Time to retrieve sample index 1 was 0.1088871955871582 seconds

SEX                 | F
IMD                 | 4.0
ETHNICITY           | WHITE
birth_year          | 1997.0

Token                                                                      | Age               | Standardised value
O_E___height_1                                                             | 6639              | -0.03             
O_E___weight_2                                                             | 6639              | 0.19              
Systolic_blood_pressure_4                                                  | 6639              | -0.03             
Diastolic_blood_pressure_5                                                 | 6665              | 0.18              
Systolic_blood_pressure_4                                                  | 6665              | -0.03             
Body_mass_index_3                                                          | 7031              | 0.26              
Diastolic_blood_

In [8]:
# Print the tokenizer's per-event table (EVENT / COUNT / FREQUENCY columns).
# NOTE(review): this reads a private attribute (`_event_counts`); prefer a
# public accessor if the tokenizer exposes one.
print(dm.tokenizer._event_counts)

shape: (183, 3)
┌───────────────────────────────────┬───────────┬───────────┐
│ EVENT                             ┆ COUNT     ┆ FREQUENCY │
│ ---                               ┆ ---       ┆ ---       │
│ str                               ┆ u32       ┆ f64       │
╞═══════════════════════════════════╪═══════════╪═══════════╡
│ UNK                               ┆ 0         ┆ 0.0       │
│ ADDISONS_DISEASE                  ┆ 6691      ┆ 0.000002  │
│ CYSTICFIBROSIS                    ┆ 7053      ┆ 0.000002  │
│ SYSTEMIC_SCLEROSIS                ┆ 8772      ┆ 0.000002  │
│ SICKLE_CELL_DISEASE_V2            ┆ 11159     ┆ 0.000003  │
│ ADDISON_DISEASE                   ┆ 11794     ┆ 0.000003  │
│ DOWNSSYNDROME                     ┆ 17006     ┆ 0.000005  │
│ HAEMOCHROMATOSIS_V2               ┆ 18631     ┆ 0.000005  │
│ PLASMACELL_NEOPLASM_V2            ┆ 20301     ┆ 0.000006  │
│ SJOGRENSSYNDROME                  ┆ 23326     ┆ 0.000007  │
│ SYSTEMIC_LUPUS_ERYTHEMATOSUS      ┆ 26820     ┆ 0.00