# CPRD Notebook:
## Evaluation of fine-tuning the pre-trained SurvivEHR-CR model on a supervised cohort study.

Cohort study: predicting Cardiovascular Disease in a Type 2 Diabetes Mellitus population.

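This notebook quantifies the performance obtained when fine-tuning the pre-trained model on a sub-population. In the cells below, the fine-tuning cohort is converted into a cross-sectional dataset on which a Random Survival Forest is fitted and scored.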

In [2]:
import os
from pathlib import Path
import sys
# Make the node-specific virtual environment importable: when its site-packages directory
# exists for the CPU type named by $BB_CPU, it is placed at the front of sys.path.
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(
    venv_dir,
    'lib',
    f'python{sys.version_info.major}.{sys.version_info.minor}',
    'site-packages',
)
if not venv_site_pkgs.exists():
    # Nothing is added to sys.path in this case; only a hint is printed
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")
else:
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.


In [3]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
import wandb
from tqdm import tqdm
import pickle
from hydra import compose, initialize
from omegaconf import OmegaConf
from CPRD.examples.modelling.SurvStreamGPT.run_experiment import run
from CPRD.data.foundational_loader import FoundationalDataModule
# from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

from sklearn import set_config
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder

from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import OneHotEncoder
from sksurv.ensemble import RandomSurvivalForest

import time
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import os
import polars as pl
pl.Config.set_tbl_rows(10000)
import pandas as pd
pd.options.display.max_rows = 10000

# Fix the torch seed and relax float32 matmul precision before anything else touches torch
torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')

logging.basicConfig(level=logging.INFO)

# Use the GPU when one is available.
# Set device = "cpu" by hand if more informative debugging statements are needed.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}.")

# TODO: remove this workaround. The variable is set by hand only to get around an issue in the local HPC environment and should not be needed.
%env SLURM_NTASKS_PER_NODE=28   

Using device: cuda.
env: SLURM_NTASKS_PER_NODE=28


In [4]:
# Compose the Hydra configuration for this run.
# Every entry in the list below replaces a value from the base config file.
config_overrides = [
    # Experiment setup
    "experiment.type='clinical prediction model'",
    "experiment.run_id='CR_11M_new'",
    "experiment.train=True",
    "experiment.test=True",
    'experiment.fine_tune_outcomes=["IHDINCLUDINGMI_OPTIMALV2", "ISCHAEMICSTROKE_V2", "MINFARCTION", "STROKEUNSPECIFIED_V2", "STROKE_HAEMRGIC"]',
    # Dataloader
    "data.path_to_ds=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/",
    # "data.batch_size=512",
    "data.min_workers=12",
    # Optimiser
    "optim.num_epochs=10",
    "optim.limit_test_batches=null",
    "optim.scheduler=ReduceOnPlateau",
    "optim.scheduler_warmup=False",
    "optim.val_check_interval=50",
    "optim.early_stop=True",
    # Head
    "head.surv_weight=1",
    "head.value_weight=0",
]

with initialize(version_base=None, config_path="../../confs", job_name="testing_notebook"):
    cfg = compose(config_name="config_CompetingRisk11M", overrides=config_overrides)

# Echo the resolved configuration so the run is self-documenting
print(OmegaConf.to_yaml(cfg))

# Checkpoint directory for this run id
save_path = f"/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/{cfg.experiment.run_id}/"


is_decoder: true
data:
  batch_size: 64
  unk_freq_threshold: 0.0
  min_workers: 12
  global_diagnoses: false
  path_to_db: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/
  path_to_ds: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/
  meta_information_path: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
experiment:
  type: clinical prediction model
  project_name: SurvEHR_${head.SurvLayer}
  run_id: CR_11M_new
  train: true
  test: true
  verbose: true
  seed: 1337
  log: true
  log_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/
  ckpt_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/
  fine_tune_outcomes:
  - IHDINCLUDINGMI_OPTIMALV2
  - ISCHAEMICSTROKE_V2
  - MINFARCTION
  - STROKEUNSPECIFIED_V2
  - STROKE_HAEMRGIC
optim:
  num_epochs: 10
  learning_r

In [5]:
# The DataModule is loaded in supervised form whenever fine-tuning outcomes are configured
supervised = cfg.experiment.fine_tune_outcomes is not None
logging.info("="*100)
logging.info(f"# Loading DataModule for dataset {cfg.data.path_to_ds}. This will be loaded in {'supervised' if supervised else 'causal'} form.")
logging.info("="*100)
dm = FoundationalDataModule(path_to_db=cfg.data.path_to_db,
                            path_to_ds=cfg.data.path_to_ds,
                            load=True,
                            tokenizer="tabular",
                            batch_size=cfg.data.batch_size,
                            max_seq_length=cfg.transformer.block_size,
                            global_diagnoses=cfg.data.global_diagnoses,
                            freq_threshold=cfg.data.unk_freq_threshold,
                            min_workers=cfg.data.min_workers,
                            overwrite_meta_information=cfg.data.meta_information_path,
                            supervised=supervised
                           )

# Get required information from the initialised DataModule
# ... vocab size
vocab_size = dm.train_set.tokenizer.vocab_size
# ... the measurement events with at least one recorded observation. Their encoded tokens are
#     written into the config as the univariate measurements modelled by the value head.
measurement_tables = dm.train_set.meta_information["measurement_tables"]
measurements_for_univariate_regression = measurement_tables[measurement_tables["count_obs"] > 0]["event"].to_list()
cfg.head.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression)
logging.debug(OmegaConf.to_yaml(cfg))

# Encode the outcomes straight from the configuration so that the labels built later can never
# drift away from the outcomes the experiment was configured with.
target_tokens = dm.encode(list(cfg.experiment.fine_tune_outcomes)) if supervised else []


INFO:root:# Loading DataModule for dataset /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/. This will be loaded in supervised form.
INFO:root:Creating supervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabula

# Run experiment

In [6]:
# X, y = load_gbsg2()
# grade_str = X.loc[:, "tgrade"].astype(object).values[:, np.newaxis]
# grade_num = OrdinalEncoder(categories=[["I", "II", "III"]]).fit_transform(grade_str)

# X_no_grade = X.drop("tgrade", axis=1)
# Xt = OneHotEncoder().fit_transform(X_no_grade)
# Xt.loc[:, "tgrade"] = grade_num
# print(X.head())
# print(Xt.head())

# print(type(y))
# print(y)

# Create dataset

In [7]:
def make_xsectional_dataset(dataset, n=None):
    """Flatten per-patient samples into a cross-sectional survival dataset.

    Each sample contributes one feature row: its 16 static covariates followed by one
    binary indicator per vocabulary token (tokens 0 and 1, PAD and UNK, are skipped).
    An indicator is 1 when the token occurs anywhere in the sample's history, i.e. in
    every token except the last one. The last token is the outcome.

    Relies on the notebook-level names ``vocab_size``, ``target_tokens`` and ``tqdm``.

    Args:
        dataset: iterable of samples with keys "static_covariates", "tokens" and "ages",
            exposing ``tokenizer._itos`` for the column names.
        n: maximum number of samples to convert; ``None`` converts every sample.

    Returns:
        X: ``pd.DataFrame`` with one row per sample (static columns, then token indicators).
        y: structured array with fields ``Status`` (True when the final token is one of
            ``target_tokens``) and ``Survival_in_days`` (age difference between the final
            two events, truncated to an integer).
    """
    columns = [f'static_{_idx}' for _idx in range(16)] + [f'{dataset.tokenizer._itos[_idx]}' for _idx in range(2, vocab_size)]
    rows = []
    Y = []

    for s_idx, sample in tqdm(enumerate(dataset), total=n):

        # Stop before converting sample number n, so that exactly n samples are returned
        if n is not None and s_idx >= n:
            break

        # Input
        ########
        # Static variables are already processed into categories where required
        static = sample["static_covariates"]

        # Binary history vector over the vocabulary, set in one indexed assignment
        tokens = torch.as_tensor(sample["tokens"])
        history = tokens[:-1]
        seen = history[(history >= 2) & (history < vocab_size)].long()
        token_binary = torch.zeros(vocab_size-2)
        token_binary[seen - 2] = 1

        rows.append(torch.hstack((static, token_binary)).tolist())

        # Target
        ########
        target = int(tokens[-1]) in target_tokens
        # NOTE(review): assumes every sample holds at least two events - confirm against the dataset
        delta_age = sample["ages"][-1] - sample["ages"][-2]
        Y.append((target, int(delta_age)))

    # Build the frame once; appending rows to a DataFrame inside the loop is quadratic
    X = pd.DataFrame(rows, columns=columns)
    y = np.array(Y, dtype=[('Status', '?'), ('Survival_in_days', '<f8')])

    return X, y

def make_xsectional_dataset2(datamodule, target_tokens, split='train', n=None):

    X = pd.DataFrame(columns=[f'static_{_idx}' for _idx in range(16)] + [f'{datamodule.train_set.tokenizer._itos[_idx]}' for _idx in range(2,vocab_size)])
    Y = []

    match split:
        case 'train':
            dataloader = datamodule.train_dataloader()
        case 'val':
            dataloader = datamodule.val_dataloader()
        case 'test':
            dataloader = datamodule.test_dataloader()
    
    for b_idx, batch in tqdm(enumerate(dataloader), total=n):

        # Input
        ########
        # Static variables are already processed into categories where required
        static = batch["static_covariates"].numpy()
    
        # Get a binary vector of vocab_size elements, which indicate if patient has any history of a condition (at any time, as long as it fits study criteria)
        # Note, 0 and 1 are PAD and UNK tokens which arent required
        input_tokens = batch["tokens"]
        token_binary = np.zeros((static.shape[0], vocab_size-2))
        for s_idx in range(static.shape[0]):
            for tkn_idx in range(2, vocab_size):
                if tkn_idx in input_tokens[s_idx, :]:
                    token_binary[s_idx, tkn_idx-2] = 1
    
        batch_input = np.hstack((static, token_binary))
        batch_df = pd.DataFrame(batch_input, columns=X.columns)
        X = pd.concat([X, batch_df])
        
        # Target
        ########
        targets = batch["target_token"].numpy()
        for s_idx in range(static.shape[0]):
            if targets[s_idx] in target_tokens:
                target = True
            else:
                target = False
            
            Y.append((target, int(batch["target_age_delta"][s_idx] )))
    
        if n is not None and b_idx >= n:
            break

    y = np.array(Y, dtype=[('Status', '?'), ('Survival_in_days', '<f8')])

    return X, y


In [8]:
# for batch in dm.train_dataloader():
#     break
# print(batch)

In [9]:
# X_train, y_train = make_xsectional_dataset(dm.train_set, n=1e6)
# X_test, y_test = make_xsectional_dataset(dm.test_set, n=None)
# # print(Y)
# # print(X.head())


In [10]:
# Converting the training split took about 3.5 hours when this cell was last run, so the
# result is cached next to the dataset and re-used on later runs (e.g. after a kernel restart).
XSECTIONAL_CACHE = Path(cfg.data.path_to_ds) / "xsectional_data.pickle"
REBUILD_XSECTIONAL = False    # set to True to ignore an existing cache and rebuild it

if XSECTIONAL_CACHE.exists() and not REBUILD_XSECTIONAL:
    # NOTE: unpickling executes arbitrary code - only load a cache file written by this notebook
    with open(XSECTIONAL_CACHE, 'rb') as handle:
        data = pickle.load(handle)
    X_train, y_train = data["X_train"], data["y_train"]
    X_test, y_test = data["X_test"], data["y_test"]
else:
    # Passing the number of batches as n converts each split in full and sizes the progress bar
    n_train = len(dm.train_dataloader())
    X_train, y_train = make_xsectional_dataset2(dm, target_tokens, split='train', n=n_train)

    n_test = len(dm.test_dataloader())
    X_test, y_test = make_xsectional_dataset2(dm, target_tokens, split='test', n=n_test)

    data = {"X_train": X_train,
            "y_train": y_train,
            "X_test": X_test,
            "y_test": y_test}
    with open(XSECTIONAL_CACHE, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
    

100%|██████████| 8939/8939 [3:30:51<00:00,  1.42s/it]  
100%|██████████| 559/559 [03:17<00:00,  2.82it/s]


In [1]:

# Sanity check of the training data: dimensions first, then only the first few outcome
# records rather than the whole label array.
# X_train and y_train must already exist in this kernel; this cell raised a NameError when
# it was run before the dataset cell.
print(X_train.shape, y_train.shape)
print(y_train[:5])


NameError: name 'X_train' is not defined

In [11]:
# Show the encoded ids of the outcome tokens produced earlier by dm.encode(...).
# dm.train_set.tokenizer._stoi   # uncomment to inspect the tokenizer mapping instead
print(target_tokens)

[95, 41, 67, 65, 28]


In [12]:
# Random Survival Forest fitted on the cross-sectional features.
# NOTE(review): max_samples=1000 presumably caps the bootstrap sample drawn for each tree - confirm
# this is intended given the size of the training set.
rsf = RandomSurvivalForest(
    n_estimators=1000,
    min_samples_split=50,
    min_samples_leaf=15,
    bootstrap=True,
    max_samples=1000,
    low_memory=False,
    n_jobs=-1,
    random_state=1337,
)
est = rsf.fit(X_train, y_train)

In [None]:
# Score the fitted forest on the held-out test split
# (scikit-survival documents `score` as returning the concordance index).
rsf.score(X_test, y_test)

In [None]:
from sksurv.metrics import integrated_brier_score

# Evaluation grid: every 20th day from day 1 up to (but excluding) day 365*5
times = np.arange(1, 365*5)[::20]

# One predicted survival function per test patient, each evaluated on the whole grid at once
survs = est.predict_survival_function(X_test)
preds = np.asarray([fn(times) for fn in survs])

# NOTE(review): scikit-survival expects `times` to lie within the follow-up range of the test
# data - confirm that holds for this cohort before trusting the score.
score = integrated_brier_score(y_train, y_test, preds, times)

print(score)