# CPRD Notebook:
## Evaluation of fine-tuning the pre-trained SurvivEHR-CR model on a supervised cohort study.

Cohort study: predicting Cardiovascular Disease in a Type 2 Diabetes Mellitus population.

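This notebook quantifies the performance obtained when fine-tuning the pre-trained model on a sub-population. It also flattens each fine-tuning cohort into a cross-sectional (tabular) dataset — static covariates plus a binary indicator for every event in a patient's history — so that baseline survival models (a Random Survival Forest and DeSurv) can be benchmarked on the same prediction task.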

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

%load_ext autoreload
%autoreload 2

import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

print(os.getcwd())
    

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/data


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
import wandb
from tqdm import tqdm
import pickle
from hydra import compose, initialize
from omegaconf import OmegaConf
from CPRD.examples.modelling.SurvivEHR.run_experiment import run
from FastEHR.dataloader.foundational_loader import FoundationalDataModule
import pickle 

from sklearn import set_config
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder

from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import OneHotEncoder
from sksurv.ensemble import RandomSurvivalForest
from CPRD.src.modules.head_layers.survival.desurv import ODESurvSingle
from pycox.evaluation import EvalSurv

import time
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import os
import polars as pl
pl.Config.set_tbl_rows(10000)
import pandas as pd
pd.options.display.max_rows = 10000

from contextlib import redirect_stdout

torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')

logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

 # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed
%env SLURM_NTASKS_PER_NODE=28   

Using device: cuda.
env: SLURM_NTASKS_PER_NODE=28


In [3]:
def make_dataset(datamodule, target_tokens, split='train', n=None):

    X = pd.DataFrame(columns=[f'static_{_idx}' for _idx in range(16)] + [f'{datamodule.train_set.tokenizer._itos[_idx]}' for _idx in range(2,vocab_size)])
    Y = []

    match split:
        case 'train':
            dataloader = datamodule.train_dataloader()
        case 'val':
            dataloader = datamodule.val_dataloader()
        case 'test':
            dataloader = datamodule.test_dataloader()
        case _:
            raise NotImplementedError
    
    for b_idx, batch in tqdm(enumerate(dataloader), total=n, desc=f"Creating {split} cross-sectional dataset to be used for benchmarking."):

        # Input
        ########
        # Static variables are already processed into categories where required
        static = batch["static_covariates"].numpy()
    
        # Get a binary vector of vocab_size elements, which indicate if patient has any history of a condition (at any time, as long as it fits study criteria)
        # Note, 0 and 1 are PAD and UNK tokens which arent required
        input_tokens = batch["tokens"]
        token_binary = np.zeros((static.shape[0], vocab_size-2))
        for s_idx in range(static.shape[0]):
            for tkn_idx in range(2, vocab_size):
                if tkn_idx in input_tokens[s_idx, :]:
                    token_binary[s_idx, tkn_idx-2] = 1
    
        batch_input = np.hstack((static, token_binary))
        batch_df = pd.DataFrame(batch_input, columns=X.columns)
        X = pd.concat([X, batch_df])
        
        # Target
        ########
        targets = batch["target_token"].numpy()
        for s_idx in range(static.shape[0]):
            # default to 0
            target = 0

            # replace with target if its in the outcome set
            for idx_outcome, outcome in enumerate(target_tokens):
                if targets[s_idx] == outcome:
                    target = idx_outcome + 1
            
            Y.append((target, batch["target_age_delta"][s_idx] ))
    
        # if n is not None and b_idx >= n:
        #     break

    y = np.array(Y, dtype=[('Status', 'int'), ('Survival_in_days', '<f8')])

    return X, y


## Load or create dataset

In [4]:
# Ten training-set sizes spaced evenly on a log scale between 3,000 and 500,000 patients.
sample_sizes = list(map(lambda log_n: int(np.exp(log_n)), np.linspace(np.log(3000), np.log(500000), 10)))


In [None]:
experiments = [ "mm"]   # "Hypertension", "" "MM"  "cvd"
sample_sizes = [None]
# sample_sizes = [int(np.exp(_log_n)) for _log_n in np.linspace(np.log(3000), np.log(500000), 10)] #+ [None]
# sample_sizes = sample_sizes[-1:]


for experiment in experiments:
    for sample_size in sample_sizes:

            # set the seeds - if we are using all data there is no need to bootstrap
            if sample_size is None:
                seeds = [42]
            else:
                seeds = [1,2,3,4,5]
                
            for seed in seeds:
                
                # load the configuration file, override any settings 
                with initialize(version_base=None, config_path="../modelling/SurvivEHR/confs", job_name="testing_notebook"):
                    cfg = compose(config_name="config_CompetingRisk37M") 
                    cfg.transformer.block_size=1000000     # Ensure all records get included
            
                match experiment.lower():
                    case "cvd":
                        cfg.data.path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/"
                    case "hypertension":
                        cfg.data.path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/"
                    case "mm":
                        # cfg.data.path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity2/"
                        cfg.data.path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/"
                        
            
                supervised = True 
                logging.info("="*100)
                logging.info(f"# Loading DataModule for dataset {cfg.data.path_to_ds}. This will be loaded in {'supervised' if supervised else 'causal'} form.")
                logging.info("="*100)
                dm = FoundationalDataModule(path_to_db=cfg.data.path_to_db,
                                            path_to_ds=cfg.data.path_to_ds,
                                            load=True,
                                            tokenizer="tabular",
                                            batch_size=cfg.data.batch_size,
                                            max_seq_length=cfg.transformer.block_size,
                                            global_diagnoses=cfg.data.global_diagnoses,
                                            freq_threshold=cfg.data.unk_freq_threshold,
                                            min_workers=cfg.data.min_workers,
                                            overwrite_meta_information=cfg.data.meta_information_path,
                                            supervised=supervised,
                                            subsample_training=sample_size,
                                            seed=seed,
                                           )
                if sample_size is not None:
                    print(dm.train_set.subsample_indicies)
                # Get required information from initialised dataloader
                # ... vocab size
                vocab_size = dm.train_set.tokenizer.vocab_size
                # ... Extract the measurements, using the fact that the diagnoses are all up upper case. This is needed for automatically setting the configuration below
                #     encode into the list of univariate measurements to model with Normal distribution
                # measurements_for_univariate_regression = [record for record in dm.tokenizer._event_counts["EVENT"] if record.upper() != record]
                # cfg.head.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression) #
                measurements_for_univariate_regression = dm.train_set.meta_information["measurement_tables"][dm.train_set.meta_information["measurement_tables"]["count_obs"] > 0]["event"].to_list()
                cfg.head.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression)
                logging.debug(OmegaConf.to_yaml(cfg))
        
                match experiment.lower():
                    case "cvd":
                        conditions = ["IHDINCLUDINGMI_OPTIMALV2", "ISCHAEMICSTROKE_V2", "MINFARCTION", "STROKEUNSPECIFIED_V2", "STROKE_HAEMRGIC"]
                        cfg.experiment.fine_tune_outcomes=conditions
                    case "hypertension":
                        conditions = ["HYPERTENSION"]
                        cfg.experiment.fine_tune_outcomes=conditions
                    case "mm":
                        conditions = (
                            dm.tokenizer._event_counts.filter((pl.col("COUNT") > 0) &
                                (pl.col("EVENT").str.contains(r'^[A-Z0-9_]+$')))
                              .select("EVENT")
                              .to_series()
                              .to_list()
                        )
                        cfg.experiment.fine_tune_outcomes=conditions
                
                target_tokens = dm.encode(conditions)
            
                # print(OmegaConf.to_yaml(cfg))
        
                if sample_size is not None:
                    save_path = cfg.data.path_to_ds + f"benchmark_data/N={sample_size}_seed{seed}.pickle" 
                else:
                    save_path = cfg.data.path_to_ds + "benchmark_data/all.pickle"
                
                try:
                    # Load the pickled file for testing
                    print(f"Trying to load {save_path}")
                    
                    with open(save_path, "rb") as handle:
                        data = pickle.load(handle)
                
                except:
                    print(f"Loading failed, creating dataset")
                    
                    # Training set
                    n_train =  len(dm.train_dataloader()) 
                    X_train, y_train = make_dataset(dm, target_tokens, split='train', n=n_train)    
        
                    data = {
                        "X_train": X_train,
                        "y_train": y_train,
                    }
        
                     # Test and validation sets - only for the full dataset version, as there is no point repeating this operation
                    if sample_size is None:
                        
                        n_val = len(dm.val_dataloader())  
                        X_val, y_val = make_dataset(dm, target_tokens, split='val', n=n_val)
                        
                        n_test = len(dm.test_dataloader())  
                        X_test, y_test = make_dataset(dm, target_tokens, split='test', n=n_test)
                        
                        print(X_test)
                        print(X_test.head())
                        
                        data = {**data,
                                "X_val": X_val,
                                "y_val": y_val,
                                "X_test": X_test,
                                "y_test": y_test
                                }
                    
                    with open(save_path, 'wb') as handle:
                        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
                        print(f"Saving to {save_path}")
            
            


INFO:root:# Loading DataModule for dataset /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/. This will be loaded in supervised form.
INFO:root:Creating supervised collator for DataModule
INFO:root:Scaling supervised target ages by a factor of 1.0 times the context scale.
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalMod

In [None]:
# Quick inspection of the last-loaded datamodule's tokenizer.
# Only a cell's final expression is displayed, so print each item explicitly.
print(dm.decode([33]))

# Events observed at least once in this dataset (the COUNT > 0 filter used when choosing the "mm" outcomes).
# Previously this name was printed without ever being defined, which raised a NameError.
events_more_than_zero = dm.tokenizer._event_counts.filter(pl.col("COUNT") > 0)
print(events_more_than_zero)


In [None]:
# Re-create the datamodule from the `cfg` left behind by the dataset-building loop above:
# supervised form when fine-tuning outcomes are configured, causal form otherwise.
# NOTE(review): this cell depends on `cfg` from the last loop iteration (hidden state), while the
# target below is hard-coded to HYPERTENSION - confirm both refer to the same cohort.
supervised = cfg.experiment.fine_tune_outcomes is not None

logging.info("=" * 100)
logging.info(f"# Loading DataModule for dataset {cfg.data.path_to_ds}. This will be loaded in {'supervised' if supervised else 'causal'} form.")
logging.info("=" * 100)
dm = FoundationalDataModule(
    path_to_db=cfg.data.path_to_db,
    path_to_ds=cfg.data.path_to_ds,
    load=True,
    tokenizer="tabular",
    batch_size=cfg.data.batch_size,
    max_seq_length=cfg.transformer.block_size,
    global_diagnoses=cfg.data.global_diagnoses,
    freq_threshold=cfg.data.unk_freq_threshold,
    min_workers=cfg.data.min_workers,
    overwrite_meta_information=cfg.data.meta_information_path,
    supervised=supervised,
)

# Vocabulary size, read from the training set's tokenizer.
vocab_size = dm.train_set.tokenizer.vocab_size

# Measurement events with at least one observation; these are encoded into the list of
# univariate measurements to model with a Normal distribution.
_measurement_tables = dm.train_set.meta_information["measurement_tables"]
measurements_for_univariate_regression = _measurement_tables[_measurement_tables["count_obs"] > 0]["event"].to_list()
del _measurement_tables
cfg.head.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression)
logging.debug(OmegaConf.to_yaml(cfg))

# Outcome(s) defining the survival target. For the CVD cohort use instead:
# dm.encode(['IHDINCLUDINGMI_OPTIMALV2', 'ISCHAEMICSTROKE_V2', 'MINFARCTION', 'STROKEUNSPECIFIED_V2', 'STROKE_HAEMRGIC'])
target_tokens = dm.encode(['HYPERTENSION'])


# Create dataset

In [None]:
# for batch in dm.train_dataloader():
#     break
# print(batch)

In [None]:
# X_train, y_train = make_static_dataset(dm.train_set, n=1e6)
# X_test, y_test = make_static_dataset(dm.test_set, n=None)
# # print(Y)
# # print(X.head())


In [None]:
# Build (load=False) or reload (load=True) the cross-sectional train/val/test tables used by the baselines below.
load = False
xsectional_cache_path = '/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/xsectional_data_CR.pickle'

if not load:

    # `make_dataset` (defined above) builds the tables; the previously called
    # `make_xsectional_dataset2` is not defined anywhere in this notebook.
    n_train = len(dm.train_dataloader())
    X_train, y_train = make_dataset(dm, target_tokens, split='train', n=n_train)

    n_val = len(dm.val_dataloader())
    X_val, y_val = make_dataset(dm, target_tokens, split='val', n=n_val)

    n_test = len(dm.test_dataloader())
    X_test, y_test = make_dataset(dm, target_tokens, split='test', n=n_test)

    # Show a preview rather than dumping the full frame.
    print(X_test.shape)
    print(X_test.head())

    data = {
        "X_train": X_train,
        "y_train": y_train,
        "X_val": X_val,
        "y_val": y_val,
        "X_test": X_test,
        "y_test": y_test
    }
    with open(xsectional_cache_path, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
else:
    # NOTE: pickle.load can execute arbitrary code - only load files written by this notebook.
    with open(xsectional_cache_path, "rb") as handle:
        data = pickle.load(handle)

    X_train = data["X_train"]
    y_train = data["y_train"]
    # Only the first 10,000 validation/test rows are kept (presumably to keep evaluation fast).
    # NOTE(review): the build branch above keeps all rows, so downstream metrics differ by branch.
    X_val = data["X_val"][:10000]
    y_val = data["y_val"][:10000]
    X_test = data["X_test"][:10000]
    y_test = data["y_test"][:10000]

    print(X_val.shape)
    print(y_val)
    

In [None]:
# Distribution of the status codes in the test set (for data from `make_dataset`: 0 = not a target outcome, k = k-th target outcome).
print(np.unique([_record[0] for _record in y_test], return_counts=True))

In [None]:
# List which arrays are present in the loaded/created `data` dictionary.
print(data.keys())

In [None]:
# Test-set survival targets in the form pycox's EvalSurv expects:
# durations (second field) and a binary event indicator derived from the status (first field).
target_ages = np.asarray([_record[1] for _record in y_test])
# Any target outcome (status >= 1) counts as an event and status 0 as censored. An `== True`
# comparison would only match status 1 and silently censor outcomes 2..K.
lbls = np.asarray([1 if _record[0] > 0 else 0 for _record in y_test])

In [None]:
# Show the encoded outcome tokens used to build the survival targets.
# (Inspect dm.train_set.tokenizer._stoi for the full token -> index mapping.)
# dm.train_set.tokenizer._stoi
print(target_tokens)

In [None]:
# Evaluation horizon of 365 * 5 time units (five years if times are measured in days).
_time_scale = 5 * 365
# Unit-spaced grid 0, 1, ..., _time_scale (inclusive) on which the survival curves are evaluated.
t_eval = np.linspace(0, _time_scale, num=_time_scale + 1)
print(t_eval[-1])

# Random Survival Forest

In [None]:
# Random Survival Forest baseline: 100 trees, each grown on a bootstrap sample of 1,000 patients.
# low_memory=False is required to use predict_survival_function later.
#
# scikit-survival requires a boolean event indicator, whereas `make_dataset` stores an integer
# status (0 = not a target outcome, k = k-th target outcome). Collapse it to "any target outcome".
y_train_event = np.empty(len(y_train), dtype=[("Status", bool), ("Survival_in_days", "<f8")])
y_train_event["Status"] = [_record[0] > 0 for _record in y_train]
y_train_event["Survival_in_days"] = [_record[1] for _record in y_train]

rsf = RandomSurvivalForest(
    n_estimators=100, n_jobs=-1, random_state=1337,
    bootstrap=True, max_samples=1000, low_memory=False,
    # min_samples_split=50,
    # min_samples_leaf=15,
)
est = rsf.fit(X_train, y_train_event)

In [None]:
# est.unique_times_ = t_eval

In [None]:
# rsf.score(X_test, y_test)

In [None]:
from sksurv.metrics import integrated_brier_score

# One survival step function per test patient from the fitted forest. (integrated_brier_score is
# available for an sksurv-based evaluation; the pycox EvalSurv metrics further down are used instead.)
survs = est.predict_survival_function(X_test, return_array=False)

In [None]:
print(survs.shape)
# Evaluate each step function on the whole grid in a single vectorised call per patient, rather
# than one Python call per (patient, time) pair. Rows are test patients; columns are times in t_eval.
preds = np.asarray([np.atleast_1d(fn(t_eval)) for fn in survs])

print(preds.shape)

In [None]:
# Test-set metrics for the forest's survival curves using pycox's EvalSurv: time-dependent
# concordance, integrated Brier score and integrated negative binomial log-likelihood.
# Censoring weights come from a Kaplan-Meier estimate (censor_surv='km').
surv = pd.DataFrame(preds.T, index=t_eval)   # rows: times in t_eval, columns: test patients
ev = EvalSurv(surv, target_ages, lbls, censor_surv='km')

time_grid = np.linspace(0, t_eval[-1], num=300)
print(ev.concordance_td())
print(ev.integrated_brier_score(time_grid))
print(ev.integrated_nbll(time_grid))

# DeSurv

In [None]:
# DeSurv baseline, with its covariate dimension taken from the cross-sectional feature table.
# NOTE(review): this cell does not run as written. The `return` below sits outside any function
# (a SyntaxError), and `data_loader`, `n_epochs`, `self`, `logging_freq`, `data_loader_val`,
# `best_val_loss`, `wait` and `max_wait` are never defined in this notebook. The loop looks like it
# was pasted from a model's training method. TODO: wrap it in a function that takes the model (in
# place of `self`) and the DataLoaders, and initialise `best_val_loss` / `wait` before the loop.
desurv_model = ODESurvSingle(cov_dim=X_train.shape[1],
                             hidden_dim=[],
                             n=15)

# NOTE(review): `batch_size` is not used below.
batch_size = data_loader.batch_size

for epoch in range(n_epochs):

    train_loss = 0.0

    # Batches are unpacked as (covariates, times, event indicators) and sorted by time before the
    # forward pass (presumably required by the model's loss - TODO confirm).
    for batch_idx, (x, t, k) in enumerate(data_loader):
        argsort_t = torch.argsort(t)
        x_ = x[argsort_t,:].to(self.odenet.device)
        t_ = t[argsort_t].to(self.odenet.device)
        k_ = k[argsort_t].to(self.odenet.device)

        self.optimizer.zero_grad()
        loss = self.forward(x_,t_,k_)
        loss.backward()
        self.optimizer.step()

        train_loss += loss.item()

    # Every `logging_freq` epochs: report the training loss and, when validation data are given,
    # run early stopping on the summed validation loss.
    if epoch % logging_freq == 0:
        print(f"\tEpoch: {epoch:2}. Total loss: {train_loss:11.2f}")
        if data_loader_val is not None:
            val_loss = 0
            for batch_idx, (x, t, k) in enumerate(data_loader_val):
                argsort_t = torch.argsort(t)
                x_ = x[argsort_t,:].to(self.odenet.device)
                t_ = t[argsort_t].to(self.odenet.device)
                k_ = k[argsort_t].to(self.odenet.device)

                loss = self.forward(x_,t_,k_)
                val_loss += loss.item()

            # Checkpoint the best weights so far to the file "low_" in the working directory.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                wait = 0
                print(f"best_epoch: {epoch}")
                torch.save(self.state_dict(), "low_")
            else:
                wait += 1

            # After more than `max_wait` checks without improvement, restore the best weights and stop.
            if wait > max_wait:
                state_dict = torch.load("low_")
                self.load_state_dict(state_dict)
                return

            print(f"\tEpoch: {epoch:2}. Total val loss: {val_loss:11.2f}")