# CPRD Notebook:
## Evaluation of fine-tuning the pre-trained SurvivEHR-CR model on a supervised cohort study.

Cohort study:

This notebook quantifies the performance obtained when fine-tuning the pre-trained model

In [2]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

%load_ext autoreload
%autoreload 2

print(os.getcwd())

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/SurvivEHR/notebooks/CompetingRisk/1_Study3_MultiMorbidity


In [3]:
import FastEHR
from FastEHR.dataloader import FoundationalDataModule

In [4]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
import wandb
import pickle
from hydra import compose, initialize
from omegaconf import OmegaConf

from FastEHR.dataloader import FoundationalDataModule

from CPRD.examples.modelling.SurvivEHR.run_experiment import run
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling
from CPRD.examples.modelling.SurvivEHR.helpers import count_prior_tokens

import time
import polars as pl
pl.Config.set_tbl_rows(10000)
import pandas as pd
pd.options.display.max_rows = 10000


# TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed
%env SLURM_NTASKS_PER_NODE=28   

env: SLURM_NTASKS_PER_NODE=28


# Fine-tuning on full dataset
The default configuration is for pre-training. Here we modify as necesssary

Here we choose to load in the configuration for a small **pre-trained** 11.4M parameter model, named "CR_11M". We specfiy the `fine-tune` experiment type, which will lead to running the ```SupervisedExperiment```. 

We tell this experiment that we want to perform training (true by default). Additionally, we do choose to perform testing (true by default). As this is a supervised model, this tests the ability to predict the outcomes of interest. In this notebook, this is chosen to be those of the cohort study for predicting Cardiovascular Disease in a Type 2 Diabetes Mellitus population, and we add the folder containing this dataset to the configuration. 

```Note: As this is a supervised dataset, we need to tell the DataModule that the last event observed is a target and must be stripped. This is done by passing a list of targets to the configuration, overriding the null default. This lets the DataModule know that it should process batches as supervised.```

We set the number of workers to be appropriate for the number of CPUs available to reduce bottlenecking, and tell the experiment that we do not want to limit the number of testing batches. In addition, we specify where we want any checkpoints to be saved to avoid bloating the repository.

We design a new optimisation strategy for fine-tuning. Pre-training was achieved with a warmup and cosine annealing, with rates which are no appropriate for much smaller dataset sizes seen in clinical prediction models (CPMs). We here choose a simpler strategy: of ReduceOnPlateau with no warmup, increasing the number of epochs (default is 1) and reduced validation intervals, and the addition of early stopping. Additionally, as this is not a causal model we can increase the batch size. Finally, as this CPM is not trying to predict the value of any outcomes, we set the value weight to zero allowing the model to focus entirely on optimising survival outcome prediction.

In [5]:
# pre_trained_model_ids = ['SurvivEHR-cr-small', 'SurvivEHR-cr-small-v1', 'SurvivEHR-cr-384', 'SurvivEHR-cr-384-v1', 'crPreTrain_small_1337']
pre_trained_model_ids = ["SurvivEHR-cr-small-debug7_exp1000-v1"]
experiments = ["mm"] 
experiment_types = [ "fine-tune-sr"] 
adapter = False

## Extract MM information for experiment setup

We want to stratify RMST by level of existing multi-morbidity, so we create a custom function to pass to the callback

In [6]:
# Load dataset in most minimal form (this isnt used for the experiment - only to extract the token names for the diagnoses)
with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
    cfg = compose(config_name="config_CompetingRisk11M")

dm = FoundationalDataModule(path_to_db=cfg.data.path_to_db,
                            path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/",
                            overwrite_meta_information=cfg.data.meta_information_path,
                            load=True)

conditions = (
    dm.tokenizer._event_counts.filter((pl.col("COUNT") > 0) &
        (pl.col("EVENT").str.contains(r'^[A-Z0-9_]+$')))
      .select("EVENT")
      .to_series()
      .to_list()
)
encoded_conditions = dm.tokenizer.encode(conditions)


In [7]:
display(cfg)

{'is_decoder': True, 'data': {'batch_size': 64, 'unk_freq_threshold': 0.0, 'min_workers': 12, 'global_diagnoses': False, 'repeating_events': True, 'path_to_db': '/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/cprd.db', 'path_to_ds': '/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/', 'meta_information_path': '/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle', 'subsample_training': None}, 'experiment': {'type': 'pre-train', 'project_name': 'SurvivEHR', 'run_id': '${head.SurvLayer}PreTrain_small_${experiment.seed}', 'fine_tune_id': None, 'notes': None, 'tags': None, 'train': True, 'test': True, 'verbose': True, 'seed': 1337, 'log': True, 'log_dir': '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/', 'ckpt_dir': '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/', 'fine_tune_

In [8]:
display(dm.tokenizer.decode([196,233]))

'Antipsychotics_OPTIMAL Anticonvulsants_OPTIMAL'

# Test utility function

Utility function `count_prior_tokens()` is used to stratify samples basedon the number of times they experienced a set of events. In this case these are diagnoses events and our stratification is on the level of existing multi-morbidity

In [9]:
for batch in dm.train_dataloader():
    break

In [10]:
for idx, sample in enumerate(dm.train_set):
    if torch.unique(sample["tokens"]).shape[0] < 5:
        print(idx)
        print(sample["tokens"])

3512
tensor([255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 254, 255, 255, 255, 255, 133, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 2


KeyboardInterrupt



In [11]:
display(count_prior_tokens(batch, encoded_conditions)[:5])

print(encoded_conditions)
print(sum([_i in encoded_conditions for _i in batch["tokens"][0]]))

['5 current diagnoses',
 '3 current diagnoses',
 '1 current diagnosis',
 '1 current diagnosis',
 '0 current diagnoses']

[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 18, 19, 20, 21, 22, 23, 26, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 43, 44, 45, 46, 48, 49, 52, 56, 57, 58, 60, 62, 65, 67, 70, 74, 75, 78, 79, 81, 82, 83, 89, 91, 93, 94, 95, 97, 100, 104, 106, 114, 120, 123, 126, 129, 133, 134, 136]
5


## Get outcome list

In [None]:
for pre_trained_model in pre_trained_model_ids:
    print(pre_trained_model)

    for experiment, experiment_type in zip(experiments, experiment_types):
    
        wandb.finish()
        # load the configuration file, override any settings 
        with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
            cfg = compose(config_name="config_CompetingRisk11M", 
                          overrides=[# Experiment setup
                                     f"experiment.project_name='SurvivEHR-Study3-MM'",
                                     f"experiment.type='{experiment_type}'",
                                     f"experiment.run_id='{pre_trained_model}'",
                                     f"experiment.fine_tune_id='{experiment}-{experiment_type}-A{adapter}-notebook-MM50+_3'",
                                     "experiment.train=True",
                                     "experiment.test=True",
                                     # Dataloader
                                     "data.batch_size=128",
                                     "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
                                     "data.min_workers=12",
                                     "data.global_diagnoses=True",
                                     "data.repeating_events=False",
                                     # f"data.subsample_training={1000}",
                                     # Optimiser
                                     "optim.num_epochs=20",
                                     "optim.limit_test_batches=null",
                                     "optim.scheduler=ReduceOnPlateau",
                                     "optim.scheduler_warmup=False",
                                     "optim.learning_rate=1e-3",
                                     "optim.val_check_interval=0.25",
                                     "optim.early_stop=True",
                                     "optim.early_stop_patience=10",
                                     "optim.limit_val_batches=1.0",
                                     "optim.limit_test_batches=1.0",
                                     "optim.accumulate_grad_batches=5",
                                     # Model
                                     # "transformer.n_embd=384",
                                     f"transformer.use_fine_tune_adapter={False if adapter is False else True}",
                                     f"transformer.adapter_dim={8 if adapter is False else adapter}",
                                     "transformer.block_size=500", 
                                     "head.value_weight=0",
                                    ]
                         )
        
        match experiment.lower():
            case "mm":
                cfg.data.path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/"
                cfg.experiment.fine_tune_outcomes=conditions
                cfg.fine_tuning.custom_stratification_method._target_="CPRD.examples.modelling.SurvivEHR.helpers.count_prior_tokens"
                cfg.fine_tuning.custom_stratification_method.tokens=encoded_conditions
        
        model, dm = run(cfg)
        print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
        wandb.finish()



SurvivEHR-cr-small-debug7_exp1000-v1


/rds/bear-apps/2022a/EL8-ice/software/PyTorch-Lightning/2.1.0-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/pytorch_lightning/core/saving.py:173: Found keys that are in the model state dict but not in the checkpoint: ['reduce_hidden.0.weight', 'reduce_hidden.0.bias', 'surv_layer.sr_ode.net.u', 'surv_layer.sr_ode.net.w', 'surv_layer.sr_ode.net.BaseNet.mapping.0.weight', 'surv_layer.sr_ode.net.BaseNet.mapping.0.bias', 'surv_layer.sr_ode.net.BaseNet.mapping.2.weight', 'surv_layer.sr_ode.net.BaseNet.mapping.2.bias', 'surv_layer.sr_ode.net.BaseNet.mapping.4.weight', 'surv_layer.sr_ode.net.BaseNet.mapping.4.bias']
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(limit_test_batches=1.0)` was configured so 100% of the batches will be used..
ERROR:wandb.jupyter:Failed to detect the name of this

/rds/bear-apps/2022a/EL8-ice/software/PyTorch-Lightning/2.1.0-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:630: Checkpoint directory /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                            | Params
------------------------------------------------------------------
0 | model         | SurvStreamGPTForCausalModelling | 11.2 M
1 | reduce_hidden | Sequential                      | 147 K 
2 | surv_layer    | ODESurvSingleRiskLayer          | 13.5 K
------------------------------------------------------------------
11.4 M    Trainable params
30        Non-trainable params
11.4 M    Total params
45.482    Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/rds/bear-apps/2022a/EL8-ice/software/PyTorch-Lightning/2.1.0-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:211: You called `self.log('val_loss_values', ...)` in your `validation_step` but the value needs to be floating point. Converting it to torch.float32.


Training: |          | 0/? [00:00<?, ?it/s]

/rds/bear-apps/2022a/EL8-ice/software/PyTorch-Lightning/2.1.0-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:211: You called `self.log('train_loss_values', ...)` in your `training_step` but the value needs to be floating point. Converting it to torch.float32.
                                    		 Bad sample tokens: tensor([[196,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.2296,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[160,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.5134,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

                                    		 Bad sample tokens: tensor([[233,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.6395,  0.0000,  0.0000,  0.0000,  0.0000]])
Metric val_loss improved. New best score: 0.678
Epoch 0, global step 644: 'val_loss' reached 0.67848 (best 0.67848), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-debug7_exp1000-v1_mm-fine-tune-sr-AFalse-notebook-MM50+_3.ckpt' as top 1


```
/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1.ckpt
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
└─────────────────────────────────────┴─────────────────────────────────────┘
```

In [None]:
tokens = encoded_conditions
print(isinstance(tokens, list))
print([isinstance(i, int) for i in tokens])
# [_i if _i == _i.upper() else 0 for _i in dm.train_set.tokenizer._stoi.keys()]

# Fine-tuning on sub-set of data

In [None]:
sample_sizes = [int(np.exp(_log_n)) for _log_n in np.linspace(np.log(3000), np.log(500000), 10)]
print(sample_sizes)

In [None]:
sample_sizes = [int(np.exp(_log_n)) for _log_n in np.linspace(np.log(3000), np.log(500000), 10)]      # [3000, 12500, 30000, 60000, 100000]: # 600, 1200, 
sample_sizes = [None]

accumulate_grad_batches = 5


for pre_trained_model in pre_trained_model_ids[1:2]:
    print(pre_trained_model)

    for experiment, experiment_type in zip(experiments, experiment_types):

        for sample_size in sample_sizes:

            wandb.finish()
            # load the configuration file, override any settings 
            with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
                cfg = compose(config_name="config_CompetingRisk11M", 
                              overrides=[# Experiment setup
                                         f"experiment.project_name='SurvivEHR-Study1-CVD'",
                                         f"experiment.type='{experiment_type}'",
                                         f"experiment.run_id='{pre_trained_model}'",
                                         f"experiment.fine_tune_id='{experiment}-{experiment_type}-A{adapter}-Ns{sample_size}-500-notebook'",
                                         "experiment.train=True",
                                         "experiment.test=True",
                                         "experiment.notes=Ablation on increasing cohort study size result",
                                         # Dataloader
                                         "data.batch_size=128",
                                         "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
                                         "data.min_workers=3",
                                         "data.global_diagnoses=True",
                                         "data.repeating_events=False",
                                         # f"data.subsample_training={sample_size}",
                                         # Optimiser
                                         "optim.num_epochs=500",
                                         "optim.limit_test_batches=null",
                                         "optim.scheduler=ReduceOnPlateau",
                                         "optim.scheduler_warmup=False",
                                         "optim.learning_rate=1e-3",
                                         "optim.val_check_interval=0.25",
                                         "optim.early_stop=True",
                                         "optim.early_stop_patience=4",
                                         "optim.limit_val_batches=0.035",
                                         f"optim.accumulate_grad_batches={accumulate_grad_batches}",
                                         # Model
                                         # "transformer.n_embd=384",
                                         f"transformer.use_fine_tune_adapter={False if adapter is False else True}",
                                         f"transformer.adapter_dim={8 if adapter is False else adapter}",
                                         "transformer.block_size=500", 
                                        ]
                             )
            
            
            match experiment.lower():
                case "cvd":
                    cfg.data.path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/"
                    cfg.experiment.fine_tune_outcomes=["IHDINCLUDINGMI_OPTIMALV2", "ISCHAEMICSTROKE_V2", "MINFARCTION", "STROKEUNSPECIFIED_V2", "STROKE_HAEMRGIC"]
                case "hypertension":
                    cfg.data.path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/"
                    cfg.experiment.fine_tune_outcomes=["HYPERTENSION"]
            
            
            model, dm = run(cfg)
            print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
            wandb.finish()


In [None]:
wandb.finish()

In [None]:
dm.tokenizer._event_counts

In [None]:
all_CVD = {
    "SurvivEHR-cr-small-v1": (
        "SurvivEHR-CR",
        [2999, 5296, 9351, 16509, 29148, 51461, 90856, 160407, 283203, 500000]],
        [0.6270919442176819],
        [0.03389651543792928],
        [0.14657540914939907]
        ),
    
}

