# Evaluation of a pre-trained SurvivEHR model


Environment setup for BlueBear (Birmingham HPC)

In [1]:
import os
from pathlib import Path
import sys

# Make the packages of the node-type specific virtual environment importable.
# `BB_CPU` selects which venv to use (the recorded output below resolves it to 'icelake').
node_type = os.getenv('BB_CPU')
if node_type is None:
    print("Environment variable 'BB_CPU' is not set; cannot select a node-type specific virtual environment.")

# Root folder holding the `virtual-envTorch2.0-<node_type>` venvs. Can be overridden
# with SURVIVEHR_VENV_ROOT so the notebook is usable outside this user's home folder.
venv_root = os.getenv('SURVIVEHR_VENV_ROOT', '/rds/homes/g/gaddcz/Projects/CPRD')
venv_dir = f'{venv_root}/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'

if venv_site_pkgs.exists():
    site_pkgs_str = str(venv_site_pkgs)
    # Drop any earlier copy first so that re-running the cell keeps exactly one
    # entry, still at the front of the search path.
    sys.path[:] = [entry for entry in sys.path if entry != site_pkgs_str]
    sys.path.insert(0, site_pkgs_str)
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")


Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.


In [3]:
# Reload edited project modules automatically before each cell executes.
# NOTE: running `%load_ext` a second time only prints an "already loaded" notice (see output below).
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
import numpy as np
import matplotlib.pyplot as plt
import wandb
from hydra import compose, initialize
import polars as pl
pl.Config.set_tbl_rows(10000)
# import pandas as pd
# pd.options.display.max_rows = 10000
import logging
logging.basicConfig(level=logging.INFO)
import torch
torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')

from FastEHR.dataloader.foundational_loader import FoundationalDataModule
from CPRD.examples.modelling.SurvivEHR.run_experiment import run
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}.")

# TODO: workaround for a local HPC environment issue; this shouldn't be needed.
# NOTE(review): presumably keeps PyTorch Lightning's SLURM environment detection consistent — confirm.
%env SLURM_NTASKS_PER_NODE=28   

INFO:numexpr.utils:Note: detected 72 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
INFO:numexpr.utils:Note: NumExpr detected 72 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.


Using device: cuda.
env: SLURM_NTASKS_PER_NODE=28


## Choosing configurations
The default configuration is for pre-training. Here we modify it as necessary.

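Here we load the configuration for a small **pre-trained** 11.4M-parameter competing-risk model (`config_CompetingRisk11M`, previously referred to as "CR_11M"). We specify the `fine-tune-sr` experiment type (`experiment.type=fine-tune-sr`) and select the fine-tuned run to evaluate with `experiment.fine_tune_id`. (An earlier version of this notebook used the `zero-shot` experiment type, which leads to running a ```CausalExperiment```.)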

We tell this experiment that no further training is needed (`experiment.train=False`) and that we do want to perform testing (`experiment.test=True`, which is also the default). As this is a supervised model, testing measures its ability to predict the outcomes of interest, which are set by `fine_tuning.custom_outcome_method`. We add the folder containing the dataset to the configuration via `data.path_to_ds`. In the cell below this is the `FineTune_MultiMorbidity50+` cohort, rather than the cohort this notebook originally described (predicting Cardiovascular Disease in a Type 2 Diabetes Mellitus population).

```Note: As this is a supervised dataset, we need to tell the DataModule that the last event observed is a target and must be stripped. This is done by passing a list of targets to the configuration, overriding the null default. This lets the DataModule know that it should process batches as supervised.```

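We set the minimum number of data-loading workers (`data.min_workers=12`) to suit the number of CPUs available and reduce bottlenecking. Note that testing is currently limited to a single batch (`optim.limit_test_batches=1`) as a quick check; remove or raise this override to evaluate the whole test set. The overrides below do not set a checkpoint directory, so checkpoints go wherever the base configuration specifies; set it explicitly if you need to keep checkpoints out of the repository.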

# Run small (11M) Competing-Risk model experiment

```

```

In [5]:
# Run id of the pre-trained checkpoint to evaluate (passed as `experiment.run_id`
# below) and the Hydra config name it pairs with.
#
# Other checkpoints used previously (run id -> config):
#   'SurvivEHR-cr-small', 'SurvivEHR-cr-small-v1', 'SurvivEHR-cr', 'SurvivEHR-cr-v1',
#   'SurvivEHR-cr-v1-v1', 'SurvivEHR-cr-384', 'SurvivEHR-cr-384-v1', 'crPreTrain_small_1337',
#   'SurvivEHR-cr-small-192', 'SurvivEHR-cr-small-192-v1'
#   "SurvivEHR-cr-small-debug7_exp1000-v1-v4"  -> "config_CompetingRisk11M"
#   "SurvivEHR-cr-small-debug7_exp1000"        -> "config_CompetingRisk11M"
#   "SurvivEHR-cr-big-debug3_2_exp1000-v1"     -> "config_CompetingRiskMOTOR"
pre_trained_model = "SurvivEHR-cr-small-debug7_exp1000-v1-v4-v1"
config_name = "config_CompetingRisk11M"

print(pre_trained_model)

SurvivEHR-cr-small-debug7_exp1000-v1-v4-v1


# Embedding with different stratification

In [6]:
# Dotted import paths of the callables used to label/stratify the hidden embeddings.
# They are handed to Hydra below as a `_target_` override, so they are strings,
# not function objects.
labels_module = "CPRD.examples.modelling.SurvivEHR.callbacks.embedding_labels."
wrappers_module = "CPRD.examples.modelling.SurvivEHR.callbacks.embedding_wrappers."

core_stratification_methods = [labels_module + func_name for func_name in ["log_token_count",
                                                                           "Token_count"]]

wrapper_stratification_methods = [wrappers_module + func_name for func_name in ["Number_of_preexisting_comorbidities",
                                                                                "Collection_history",
                                                                                "Type2_Diabetes_history",
                                                                                "CVD_history",
                                                                                "Hypertension_history",
                                                                                "Gender",
                                                                                "IMD",
                                                                                "Ethnicity",
                                                                                "Birth_year"]]

# Full list used by the experiment: the core labels first, then the wrappers.
custom_stratification_methods = core_stratification_methods + wrapper_stratification_methods
custom_stratification_methods

['CPRD.examples.modelling.SurvivEHR.callbacks.embedding_labels.log_token_count',
 'CPRD.examples.modelling.SurvivEHR.callbacks.embedding_labels.Token_count',
 'CPRD.examples.modelling.SurvivEHR.callbacks.embedding_wrappers.Number_of_preexisting_comorbidities',
 'CPRD.examples.modelling.SurvivEHR.callbacks.embedding_wrappers.Collection_history',
 'CPRD.examples.modelling.SurvivEHR.callbacks.embedding_wrappers.Type2_Diabetes_history',
 'CPRD.examples.modelling.SurvivEHR.callbacks.embedding_wrappers.CVD_history',
 'CPRD.examples.modelling.SurvivEHR.callbacks.embedding_wrappers.Hypertension_history',
 'CPRD.examples.modelling.SurvivEHR.callbacks.embedding_wrappers.Gender',
 'CPRD.examples.modelling.SurvivEHR.callbacks.embedding_wrappers.IMD',
 'CPRD.examples.modelling.SurvivEHR.callbacks.embedding_wrappers.Ethnicity',
 'CPRD.examples.modelling.SurvivEHR.callbacks.embedding_wrappers.Birth_year']

In [7]:
# Close any W&B run still open from a previous execution so this evaluation starts a fresh one.
wandb.finish()


# Load the configuration file and override the settings needed to evaluate (test only,
# no training) the fine-tuned model built on `pre_trained_model`.
with initialize(version_base=None, config_path="../../../confs", job_name="causal_metric_testing_notebook"):
    cfg = compose(
        config_name=config_name,
        overrides=[
            # Experiment setup
            "experiment.project_name='Evaluating fine-tuned models'",
            "experiment.type=fine-tune-sr",
            f"experiment.run_id='{pre_trained_model}'",
            "experiment.fine_tune_id='MM_fine-tune-sr-Afalse8-Ns20000-s1'",
            "experiment.train=False",
            "experiment.test=True",
            # Data
            "data.batch_size=128",
            "data.path_to_ds='/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/'",
            "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
            "data.min_workers=12",
            "data.global_diagnoses=True",
            "data.repeating_events=False",
            # Quick check on a single test batch. NOTE(review): presumably forwarded to
            # Lightning's `limit_test_batches`; raise/remove for a full evaluation.
            "optim.limit_test_batches=1",
            # Fine-tuning outcomes and embedding callbacks
            "fine_tuning.custom_outcome_method._target_='CPRD.examples.modelling.SurvivEHR.helpers.custom_mm_outcomes'",
            # The Python list repr is embedded directly as a Hydra list override.
            f"fine_tuning.custom_stratification_method._target_={custom_stratification_methods}",
            "fine_tuning.use_callbacks.performance_metrics=False",
            "fine_tuning.use_callbacks.hidden_embedding=100",
            # "transformer.block_size=1000",
        ],
    )

# Close the W&B run even if `run` raises, so a failure does not leave it open.
try:
    model, dm = run(cfg)
    print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
finally:
    wandb.finish()

INFO:root:Running cr on 72 CPUs and 1 GPUs
INFO:root:# Loading DataModule for dataset /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/. This will be loaded in supervised form.
INFO:root:Creating supervised collator for DataModule
INFO:root:Scaling supervised target ages by a factor of 1.0 times the context scale.
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.trainer.connectors.signal_connector:SLURM auto-requeueing enabled. Setting signal handlers.


Testing: |          | 0/? [00:00<?, ?it/s]

  log_labels = [np.log(_label) for _label in labels]
/rds/bear-apps/2022a/EL8-ice/software/PyTorch-Lightning/2.1.0-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:211: You called `self.log('test_loss_values', ...)` in your `test_step` but the value needs to be floating point. Converting it to torch.float32.


Loaded model with 11.370501 M parameters


0,1
epoch,▁
test_loss,▁
test_loss_desurv,▁
test_loss_values,▁
trainer/global_step,▁

0,1
epoch,0.0
test_loss,0.6336
test_loss_desurv,0.6336
test_loss_values,0.0
trainer/global_step,0.0


In [None]:
# DISABLED alternative to the evaluation cell above: evaluates `pre_trained_model` under the
# 'Evaluating pre-trained models' project, without the fine-tuning overrides used above
# (no `experiment.type`, `experiment.fine_tune_id` or `data.path_to_ds` override).
# Uncomment the whole cell to run it instead.
# wandb.finish()


# # load the configuration file, override any settings 
# with initialize(version_base=None, config_path="../../../confs", job_name="causal_metric_testing_notebook"):
#     cfg = compose(config_name=config_name, 
#                   overrides=[# Experiment setup
#                              "experiment.project_name='Evaluating pre-trained models'",
#                              f"experiment.run_id='{pre_trained_model}'",
#                              "experiment.train=False",
#                              "experiment.test=True",
#                              "data.batch_size=128",
#                              "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
#                              "data.min_workers=12",
#                              "optim.limit_test_batches=1",
#                              f"fine_tuning.custom_stratification_method._target_={custom_stratification_methods}",
#                              "fine_tuning.use_callbacks.performance_metrics=False",
#                              "fine_tuning.use_callbacks.hidden_embedding=100",
#                              # "transformer.block_size=2000",
#                             ]
#                  )     

# model, dm = run(cfg)
# print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")

# wandb.finish()

In [None]:
# Ensure no W&B run is left open, then echo which checkpoint this notebook is evaluating.
wandb.finish()
print(pre_trained_model)

In [None]:
# Take the first test batch and show the first patient's token sequence and attention mask.
# The batch is fetched here, so the cell does not rely on a `batch` variable left over from
# an earlier session and therefore works on a fresh kernel.
batch = next(iter(dm.test_dataloader()))
patient_tokens, patient_mask = next(zip(batch["tokens"], batch["attention_mask"]))
print(patient_tokens)
print(patient_mask)

In [None]:

# Decode the batch's static covariates once and reuse the result.
# NOTE(review): decodes with the *train* dataset's decoder, although `batch` presumably
# comes from the test loader; this assumes the decoder is shared across splits — confirm.
decoded_static_covariates = dm.train_set._decode_covariates(batch["static_covariates"])
print(decoded_static_covariates.keys())

# Print each patient's value for one covariate. A dedicated loop variable keeps the
# `patient_tokens` inspected in the previous cell from being overwritten.
key = "birth_year"
key_static_covariates = decoded_static_covariates[key]
for patient_value in key_static_covariates:
    print(patient_value.numpy())

In [None]:
# Look up the token ids of the outcome events of interest.
# NOTE(review): the first group appears to be the cardiovascular events (IHD/MI/stroke) — confirm.
cvd_event_names = ['IHDINCLUDINGMI_OPTIMALV2', 'ISCHAEMICSTROKE_V2', 'MINFARCTION', 'STROKEUNSPECIFIED_V2', 'STROKE_HAEMRGIC']
hypertension_event_names = ['HYPERTENSION']
for event_names in (cvd_event_names, hypertension_event_names):
    display(dm.encode(event_names))
# Reverse lookup, token ids -> names:
# display(dm.decode([95, 175, 263,249]).split(" "))

In [None]:
# Show the last five entries of the tokenizer's "EVENT" column of event counts.
# NOTE(review): relies on the private `_event_counts` attribute of the tokenizer.
dm.tokenizer._event_counts["EVENT"][-5:].to_list()