# CPRD Notebook:
## Evaluation of fine-tuning the pre-trained SurvivEHR-CR model on a supervised cohort study.

Cohort study: predicting follow-up Systolic Blood Pressure following anti-hypertensive medication (within 60 days of hypertension diagnosis)

This notebook quantifies the performance obtained when fine-tuning the pre-trained model to a sub-population.

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

%load_ext autoreload
%autoreload 2

print(os.getcwd())

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/SurvivEHR/notebooks/CompetingRisk/1_Study2_BP


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
import wandb
from tqdm import tqdm
import pickle
from hydra import compose, initialize
from omegaconf import OmegaConf

from FastEHR.dataloader import FoundationalDataModule


from CPRD.examples.modelling.SurvivEHR.run_experiment import run
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

import time
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import polars as pl
pl.Config.set_tbl_rows(10000)
import pandas as pd
pd.options.display.max_rows = 10000

torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')

logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

# TODO: this environment variable works around a local HPC environment issue and should not be needed
%env SLURM_NTASKS_PER_NODE=28

Using device: cuda.
env: SLURM_NTASKS_PER_NODE=28


# Fine-tuning on full dataset
The default configuration is for pre-training. Here we modify it as necessary.

Here we choose to load the configuration for a small **pre-trained** 11.4M parameter model, named "CR_11M". We specify the `fine-tune` experiment type, which leads to running the ```SupervisedExperiment```.

We tell this experiment to perform both training and testing (each is true by default). As this is a supervised model, testing measures its ability to predict the outcomes of interest. In this notebook, the outcome is that of the cohort study described above, follow-up systolic blood pressure after anti-hypertensive medication, and we add the folder containing this dataset to the configuration.

```Note: As this is a supervised dataset, we need to tell the DataModule that the last event observed is a target and must be stripped. This is done by passing a list of targets to the configuration, overriding the null default. This lets the DataModule know that it should process batches as supervised.```
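The note above can be illustrated with a minimal sketch. This is not the FastEHR collator itself: the function name `split_context_and_target`, the event dictionary layout, and the example values are hypothetical, chosen only to show the idea of stripping the final event when targets are configured.

```python
from typing import Dict, List, Optional, Tuple

Event = Dict[str, object]


def split_context_and_target(
    events: List[Event], targets: Optional[List[str]]
) -> Tuple[List[Event], Optional[Event]]:
    """Hypothetical helper: strip the last event when supervised targets are set."""
    if not targets:
        # Null default: no targets configured, so the whole sequence is context.
        return list(events), None
    if not events:
        raise ValueError("A supervised sample needs at least one event.")
    # Supervised mode: the last observed event is the target and is removed
    # from the context that the model sees.
    return list(events[:-1]), events[-1]


# Illustrative patient record (token names other than the outcome are made up).
patient = [
    {"token": "HYPERTENSION", "age_days": 20000, "value": None},
    {"token": "ANTIHYPERTENSIVE", "age_days": 20010, "value": None},
    {"token": "Systolic_blood_pressure_4", "age_days": 20050, "value": 138.0},
]

context, target = split_context_and_target(patient, ["Systolic_blood_pressure_4"])
print(len(context), target["token"])
```

With `targets=None` the same call returns the full sequence and no target, which corresponds to the unsupervised default described in the note.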

We set the number of workers to be appropriate for the number of CPUs available to reduce bottlenecking, and tell the experiment that we do not want to limit the number of testing batches. In addition, we specify where we want any checkpoints to be saved to avoid bloating the repository.

We design a new optimisation strategy for fine-tuning. Pre-training used warmup and cosine annealing, with rates that are not appropriate for the much smaller datasets seen in clinical prediction models (CPMs). Here we choose a simpler strategy: ReduceOnPlateau with no warmup, more epochs (the default is 1), shorter validation intervals, and early stopping. Additionally, as this is not a causal model we can increase the batch size. Finally, as this CPM predicts the value of the outcome (follow-up systolic blood pressure), we keep the value weight at 1 alongside the survival weight, so that both heads contribute to the loss.
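To make the interaction between the plateau scheduler and early stopping concrete, here is a simplified pure-Python sketch. It is not the experiment's implementation: the function `simulate_plateau_schedule`, the decay `factor`, and the scheduler patience `lr_patience` are assumptions for illustration; only the starting learning rate (1e-3) and the early-stopping patience (4) come from the configuration in this notebook.

```python
from typing import List, Optional, Tuple


def simulate_plateau_schedule(
    val_losses: List[float],
    lr: float = 1e-3,        # optim.learning_rate
    factor: float = 0.1,     # assumed decay factor
    lr_patience: int = 2,    # assumed scheduler patience
    stop_patience: int = 4,  # optim.early_stop_patience
) -> Tuple[List[float], Optional[int]]:
    """Return the learning rate after each validation check and the stop index."""
    best = float("inf")
    bad_for_lr = 0    # non-improving checks since the last improvement or LR cut
    bad_for_stop = 0  # non-improving checks since the last improvement
    lrs: List[float] = []
    for i, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            bad_for_lr = 0
            bad_for_stop = 0
        else:
            bad_for_lr += 1
            bad_for_stop += 1
        if bad_for_lr > lr_patience:
            lr *= factor  # plateau detected: reduce the learning rate
            bad_for_lr = 0
        lrs.append(lr)
        if bad_for_stop >= stop_patience:
            return lrs, i  # early stopping triggers at this check
    return lrs, None  # training ran to the end without stopping early


# Synthetic losses (not results from this study) that improve, then plateau.
synthetic = [1.0, 0.8, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75]
lrs, stopped_at = simulate_plateau_schedule(synthetic)

# Validation losses logged in this notebook's training output: always improving.
logged = [-284.940, -381.002, -385.295, -388.929, -391.598, -393.918]
logged_lrs, logged_stop = simulate_plateau_schedule(logged)
print(stopped_at, logged_stop)
```

On the synthetic series the learning rate is cut once before early stopping ends the run; on the logged values every check improves, so neither mechanism fires.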

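The batch size, gradient accumulation, and validation interval used in the configuration interact in a way that is easy to check by hand. The sketch below assumes that `optim.val_check_interval` counts training batches and that the logged global step counts optimiser steps; this assumption is consistent with the checkpoint messages in this notebook's training output, which report global steps 12, 25, 37, 50, 62 and 75.

```python
batch_size = 128             # data.batch_size
accumulate_grad_batches = 4  # optim.accumulate_grad_batches
val_check_interval = 50      # optim.val_check_interval (assumed: training batches)

# Number of samples contributing to each optimiser step.
effective_batch_size = batch_size * accumulate_grad_batches

# Optimiser steps completed at each of the first six validation checks.
steps_at_validation = [
    (k * val_check_interval) // accumulate_grad_batches for k in range(1, 7)
]

print(effective_batch_size)  # 512
print(steps_at_validation)   # [12, 25, 37, 50, 62, 75]
```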

In [3]:
pre_trained_model_ids = ['SurvivEHR-cr-small', 'SurvivEHR-cr-small-v1', 'SurvivEHR-cr-384', 'SurvivEHR-cr-384-v1', 'crPreTrain_small_1337']
adapter = False

In [None]:
for pre_trained_model in pre_trained_model_ids[1:2]:
    print(pre_trained_model)
    
    wandb.finish()
    # load the configuration file, override any settings 
    with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
        cfg = compose(config_name="config_CompetingRisk11M", 
                      overrides=[# Experiment setup
                                 "experiment.project_name='SurvivEHR-Study2-fine-tune-value'",
                                 "experiment.type='fine-tune'",
                                 f"experiment.run_id='{pre_trained_model}'",
                                 f"experiment.fine_tune_id='SBP-Value-A{adapter}-notebook'",
                                 "experiment.train=True",
                                 "experiment.test=True",                                 
                                 "experiment.notes=Table result",
                                 "experiment.log_dir='/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/Study2/'",
                                 "experiment.ckpt_dir='/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/Study2/checkpoints/'",
                                 # Fine tuning
                                 "experiment.fine_tune_outcomes=['Systolic_blood_pressure_4']",
                                 "fine_tuning.head.surv_weight=1",
                                 "fine_tuning.head.value_weight=1",
                                 # Dataloader
                                 "data.path_to_ds='/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_BPpostHypertension/'",
                                 "data.batch_size=128",
                                 "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
                                 "data.min_workers=3",
                                 "data.global_diagnoses=True",
                                 # Optimiser
                                 "optim.num_epochs=20",
                                 "optim.limit_test_batches=null",
                                 "optim.scheduler=ReduceOnPlateau",
                                 "optim.scheduler_warmup=False",
                                 "optim.learning_rate=1e-3",
                                 "optim.val_check_interval=50",
                                 "optim.early_stop=True",
                                 "optim.early_stop_patience=4",
                                 "optim.limit_val_batches=0.035",
                                 "optim.accumulate_grad_batches=4",
                                 # Model
                                 # "transformer.n_embd=384",
                                 f"transformer.use_fine_tune_adapter={adapter is not False}",
                                 f"transformer.adapter_dim={8 if adapter is False else adapter}",
                                 "transformer.block_size=512", 
                                 "transformer.dropout=0.0",
                                 "transformer.resid_dropout=0.0",
                                 "transformer.attention_dropout=0.0",                                  
                                ]
                     )
    
    model, dm = run(cfg)
    print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
    wandb.finish()
    


SurvivEHR-cr-small-v1


INFO:root:Running cr on 72 CPUs and 1 GPUs
INFO:root:# Loading DataModule for dataset /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_BPpostHypertension/. This will be loaded in supervised form.
INFO:root:Creating supervised collator for DataModule
INFO:root:Scaling supervised target ages by a factor of 1.0 times the context scale.
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_BPpostHypertension/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_BPpostHypertension/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optim

{'_target_': None, 'tokens': None}


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
INFO:root:Training model.
ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcwlgadd[0m. Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:root:Using ReduceLROnPlateau scheduler
INFO:root:Not using warm-up in scheduler

  | Name        | Type                            | Params
----------------------------------------------------------------
0 | model       | SurvStreamGPTForCausalModelling | 11.2 M
1 | surv_layer  | ODESurvCompetingRiskLayer       | 26.9 K
2 | value_layer | GaussianRegressionLayer         | 13.4 K
----------------------------------------------------------------
11.3 M    Trainable params
30        Non-trainable params
11.3 M    Total params
45.006    Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

[5.27265987e-16 2.34064237e-07 1.93300852e-06 2.20521097e-05]
not reaching the requested tolerance 3.9528470752104736e-06.
  _, diffusion_map = lobpcg(


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: -284.940
Epoch 0, global step 12: 'val_loss' reached -284.94012 (best -284.94012), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/Study2/checkpoints/SurvivEHR-cr-small-v1_SBP-Value-AFalse-notebook.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 96.062 >= min_delta = 0. New best score: -381.002
Epoch 0, global step 25: 'val_loss' reached -381.00217 (best -381.00217), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/Study2/checkpoints/SurvivEHR-cr-small-v1_SBP-Value-AFalse-notebook.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 4.293 >= min_delta = 0. New best score: -385.295
Epoch 0, global step 37: 'val_loss' reached -385.29526 (best -385.29526), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/Study2/checkpoints/SurvivEHR-cr-small-v1_SBP-Value-AFalse-notebook.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 3.633 >= min_delta = 0. New best score: -388.929
Epoch 0, global step 50: 'val_loss' reached -388.92874 (best -388.92874), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/Study2/checkpoints/SurvivEHR-cr-small-v1_SBP-Value-AFalse-notebook.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 2.669 >= min_delta = 0. New best score: -391.598
Epoch 0, global step 62: 'val_loss' reached -391.59802 (best -391.59802), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/Study2/checkpoints/SurvivEHR-cr-small-v1_SBP-Value-AFalse-notebook.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 2.320 >= min_delta = 0. New best score: -393.918
Epoch 0, global step 75: 'val_loss' reached -393.91797 (best -393.91797), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/Study2/checkpoints/SurvivEHR-cr-small-v1_SBP-Value-AFalse-notebook.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
# [_i if _i == _i.upper() else 0 for _i in dm.train_set.tokenizer._stoi.keys()]

In [None]:
wandb.finish()