# CPRD Notebook:
## Evaluation of a loaded pre-trained SurvivEHR-CR model for causal sequence modelling on CPRD.

In [1]:
import os
from pathlib import Path
import sys
# Put the node-specific virtual environment's site-packages directory at the
# front of sys.path. The environment directory is selected by the BB_CPU
# environment variable, which is read as the node type.
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    # Drop any earlier copy first, so re-running this cell keeps exactly one
    # entry and that entry is always at position 0.
    sys.path[:] = [entry for entry in sys.path if entry != str(venv_site_pkgs)]
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
elif node_type is None:
    # With BB_CPU unset the directory name ends in "-None"; report the real cause.
    print(f"Environment variable 'BB_CPU' is not set, so path '{venv_site_pkgs}' could not be resolved to a node-type.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

# Enable IPython's autoreload extension in mode 2, which reloads imported
# modules before each cell executes so that edits to source files are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
import wandb
from tqdm import tqdm
import pickle
from hydra import compose, initialize
from omegaconf import OmegaConf
from CPRD.examples.modelling.SurvStreamGPT.run_experiment import run
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

import time
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import os
import polars as pl
import pandas as pd

# Raise the row cap used when displaying polars and pandas tables to 10,000 rows
pl.Config.set_tbl_rows(10000)
pd.set_option("display.max_rows", 10000)

# Emit INFO-level log records from the root logger
logging.basicConfig(level=logging.INFO)

# Fix the torch RNG seed and select the 'medium' float32 matmul precision setting
torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')

# Use the GPU when torch reports one as available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}.")

# TODO: workaround for a local HPC environment issue; this should not be needed.
# Sets SLURM_NTASKS_PER_NODE to 28 in the kernel's environment for all code run after this cell.
# NOTE(review): presumably read by the SLURM handling triggered inside `run` — confirm which component needs it.
%env SLURM_NTASKS_PER_NODE=28   

Using device: cuda.
env: SLURM_NTASKS_PER_NODE=28


## Load configurations
The default configuration is for pre-training. Here we modify it as necessary.

Here we choose to load in the configuration for a small **pre-trained** 11.4M parameter model, named "CR_11M". We specify the `causal` (equivalently `pre-train` or `self-supervised`) experiment type, which will lead to running the `CausalExperiment`.

We tell this experiment that no further training is needed. Additionally, we choose to perform testing (true by default). As this is a causal model, this does not test the ability to predict the outcomes of interest, but rather the ability to perform the causal modelling task on the chosen dataset. In this notebook, this is chosen to be the original training dataset, containing over 7 billion medical events.

Finally, we set the number of workers to be appropriate for the number of CPUs available to reduce bottlenecking, and tell the experiment that we do not want to limit the number of testing batches. In addition, we specify where we want any checkpoints to be saved to avoid bloating the repository.

In [3]:
# Settings that differ from the defaults in the base configuration file
config_overrides = [
    # Experiment setup
    "experiment.type='causal'",
    "experiment.run_id='CR_11M'",
    "experiment.train=False",
    "experiment.test=True",
    # Dataloader
    "data.min_workers=12",
    # Optimiser
    "optim.limit_test_batches=null",
]

# Compose the Hydra configuration with the overrides applied
with initialize(version_base=None, config_path="../../confs", job_name="testing_notebook"):
    cfg = compose(config_name="config_CompetingRisk11M", overrides=config_overrides)

# Show the composed configuration as YAML
print(OmegaConf.to_yaml(cfg))

# Checkpoint directory for this run, keyed by the configured run id
save_path = (
    "/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/"
    f"{cfg.experiment.run_id}/"
)

is_decoder: true
data:
  batch_size: 64
  unk_freq_threshold: 0.0
  min_workers: 12
  global_diagnoses: false
  path_to_db: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/
  path_to_ds: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/
  meta_information_path: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
experiment:
  type: causal
  project_name: SurvEHR_${head.SurvLayer}
  run_id: CR_11M_old
  train: false
  test: true
  verbose: true
  seed: 1337
  log: true
  log_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/
  ckpt_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/
  fine_tune_outcomes: null
optim:
  num_epochs: 1
  learning_rate: 0.0003
  scheduler: CAWarmRestarts
  scheduler_periods: 5000
  scheduler_warmup: true
  lr_cosine_decay_period: 10000000.0
  v

## Run experiment

In [4]:
# Execute the configured experiment. `run` returns the model (the best one when
# training is enabled) together with the configured data module.
model, dm = run(cfg)

# Close the wandb logging session so that its summary is printed
wandb.finish()

# Report the model size in millions of parameters
n_params = sum(param.numel() for param in model.parameters())
print(f"The loaded model contains {n_params/1e6} M parameters")


INFO:root:Running cr on 72 CPUs and 2 GPUs
INFO:root:
Loading DataModule for dataset /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/. This will be loaded in causal form.
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta i

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: |          | 0/? [00:00<?, ?it/s]

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁
test_loss,▁
test_loss_desurv,▁
test_loss_values,▁
trainer/global_step,▁

0,1
epoch,0.0
test_loss,-17.54701
test_loss_desurv,2.048
test_loss_values,-37.14216
trainer/global_step,0.0


The loaded model contains 11.433294 M parameters


In [5]:
# Print the text representation of the loaded model to inspect its layer structure
print(model)

SurvStreamGPTForCausalModelling(
  (transformer): TTETransformer(
    (wpe): TemporalPositionalEncoding()
    (wte): DataEmbeddingLayer(
      (static_proj): Linear(in_features=16, out_features=384, bias=True)
      (dynamic_embedding_layer): SplitDynamicEmbeddingLayer(
        (cat_event_embed_layer): Embedding(265, 384, padding_idx=0)
        (cat_event_proj): Linear(in_features=384, out_features=384, bias=True)
        (num_value_embed_layer): EmbeddingBag(265, 384, mode='sum', padding_idx=0)
        (num_value_proj): Linear(in_features=384, out_features=384, bias=True)
      )
    )
    (drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadedSelfAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=384, out_features=384, bias=False)
          (v_pr

In [6]:
# Show the last five entries of the "EVENT" column of the tokenizer's event-count table.
# NOTE(review): `_event_counts` is a private attribute and its row ordering is not visible here — confirm what "last" means.
dm.tokenizer._event_counts["EVENT"][-5:].to_list()

['NSAIDS_oral_OPTIMAL_final',
 'Diastolic_blood_pressure_5',
 'Systolic_blood_pressure_4',
 'Statins',
 'Lipid_lowering_drugs_Optimal']