# CPRD Notebook:
## Evaluation of fine-tuning the pre-trained SurvivEHR-CR model on a supervised cohort study.

Cohort study: predicting Cardiovascular Disease in a Type 2 Diabetes Mellitus population.

This notebook quantifies the performance obtained when fine-tuning the pre-trained model to a sub-population.

In [None]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

%load_ext autoreload
%autoreload 2

print(os.getcwd())

In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
import wandb
from tqdm import tqdm
import pickle
from hydra import compose, initialize
from omegaconf import OmegaConf
from CPRD.examples.modelling.SurvStreamGPT.run_experiment import run
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

import time
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import os
import polars as pl
pl.Config.set_tbl_rows(10000)
import pandas as pd
pd.options.display.max_rows = 10000

torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')

logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

 # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed
%env SLURM_NTASKS_PER_NODE=28   

Using device: cuda.
env: SLURM_NTASKS_PER_NODE=28


# Choosing configurations
The default configuration is for pre-training. Here we modify it as necessary.

Here we choose to load in the configuration for a small **pre-trained** 11.4M parameter model, named "CR_11M". We specify the `fine-tune` experiment type, which will lead to running the ```SupervisedExperiment```. 

We tell this experiment that we want to perform training (true by default) and testing (also true by default). As this is a supervised model, testing measures its ability to predict the outcomes of interest. In this notebook, these are the outcomes of the cohort study for predicting Cardiovascular Disease in a Type 2 Diabetes Mellitus population, and we add the folder containing this dataset to the configuration. 

```Note: As this is a supervised dataset, we need to tell the DataModule that the last event observed is a target and must be stripped. This is done by passing a list of targets to the configuration, overriding the null default. This lets the DataModule know that it should process batches as supervised.```

We set the number of workers to be appropriate for the number of CPUs available to reduce bottlenecking, and tell the experiment that we do not want to limit the number of testing batches. In addition, we specify where we want any checkpoints to be saved to avoid bloating the repository.

We design a new optimisation strategy for fine-tuning. Pre-training used a warmup followed by cosine annealing, with learning rates that are not appropriate for the much smaller dataset sizes seen in clinical prediction models (CPMs). Here we choose a simpler strategy: ReduceOnPlateau with no warmup, an increased number of epochs (default is 1), a shorter validation interval, and early stopping. Additionally, as this is not a causal model we can increase the batch size. Finally, as this CPM is not trying to predict the value of any outcomes, we set the value weight to zero, allowing the model to focus entirely on optimising survival outcome prediction.

In [5]:
# Fine-tune the pre-trained competing-risk model separately on each supervised cohort.
# For every experiment: compose the Hydra config with the fine-tuning overrides, point it
# at that cohort's dataset and outcome tokens, then train/test via `run`.
FINETUNE_DATA_ROOT = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel"

# experiment name (lower-case) -> (dataset directory, event tokens used as fine-tuning outcomes)
FINETUNE_COHORTS = {
    "cvd": (
        f"{FINETUNE_DATA_ROOT}/FineTune_CVD/",
        ["IHDINCLUDINGMI_OPTIMALV2", "ISCHAEMICSTROKE_V2", "MINFARCTION", "STROKEUNSPECIFIED_V2", "STROKE_HAEMRGIC"],
    ),
    "hypertension": (
        f"{FINETUNE_DATA_ROOT}/FineTune_Hypertension/",
        ["HYPERTENSION"],
    ),
}

experiments = ["hypertension", "cvd"]

for experiment in experiments:
    # Fail fast on an unrecognised name instead of silently training on whatever
    # dataset/outcomes the base config defaults to.
    cohort = FINETUNE_COHORTS.get(experiment.lower())
    if cohort is None:
        raise ValueError(f"Unknown experiment '{experiment}'. Expected one of: {sorted(FINETUNE_COHORTS)}")
    path_to_ds, fine_tune_outcomes = cohort

    # Close any W&B run left open by an earlier iteration or cell.
    wandb.finish()
    # load the configuration file, override any settings
    with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
        cfg = compose(config_name="config_CompetingRisk37M",
                      overrides=[# Experiment setup
                                 "experiment.type='fine-tune-cr'",
                                 "experiment.run_id='CR_11M_24_11_01_big_posencscale_'",  # CR_11M_24_10_31_posencscale
                                 f"experiment.fine_tune_id='{experiment}-Fine-Tune'",
                                 "experiment.train=True",
                                 "experiment.test=True",
                                 # Dataloader
                                 f"data.meta_information_path={FINETUNE_DATA_ROOT}/PreTrain/meta_information_QuantJenny.pickle",
                                 "data.min_workers=12",
                                 # Optimiser
                                 "optim.num_epochs=20",
                                 "optim.limit_test_batches=null",
                                 "optim.scheduler=ReduceOnPlateau",
                                 "optim.scheduler_warmup=False",
                                 "optim.learning_rate=1e-4",
                                 "optim.val_check_interval=50",
                                 "optim.early_stop=True",
                                 "optim.early_stop_patience=5",
                                 "optim.limit_val_batches=0.035",
                                 # Transformer
                                 # "transformer.dropout=0.5",
                                 # "transformer.attention_dropout=0.5",
                                 # "transformer.resid_dropout=0.5",
                                 # NOTE(review): no `head.value_weight=0` override is passed here,
                                 # although the prose describes setting the value weight to zero — confirm intent.
                                ]
                     )

    # Cohort-specific dataset and supervised outcome tokens.
    cfg.data.path_to_ds = path_to_ds
    cfg.experiment.fine_tune_outcomes = fine_tune_outcomes

    model, dm = run(cfg)
    print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
    wandb.finish()


INFO:root:Running cr on 72 CPUs and 1 GPUs
INFO:root:# Loading DataModule for dataset /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/. This will be loaded in supervised form.
INFO:root:Creating supervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/file_row_count_dict_val.pickle
I

/rds/bear-apps/2022a/EL8-ice/software/PyTorch-Lightning/2.1.0-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:630: Checkpoint directory /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:root:Using ReduceLROnPlateau scheduler
INFO:root:Not using warm-up in scheduler

  | Name       | Type                            | Params
---------------------------------------------------------------
0 | model      | SurvStreamGPTForCausalModelling | 129 M 
1 | surv_layer | ODESurvCompetingRiskLayer       | 67.9 K
2 | dropout    | Dropout                         | 0     
---------------------------------------------------------------
129 M     Trainable params
572       Non-trainable params
129 M     Total params
517.697   Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 0.872
Epoch 0, global step 50: 'val_loss' reached 0.87158 (best 0.87158), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__hypertension-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.196 >= min_delta = 0. New best score: 0.676
Epoch 0, global step 100: 'val_loss' reached 0.67567 (best 0.67567), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__hypertension-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.030 >= min_delta = 0. New best score: 0.646
Epoch 0, global step 150: 'val_loss' reached 0.64608 (best 0.64608), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__hypertension-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.023 >= min_delta = 0. New best score: 0.623
Epoch 0, global step 200: 'val_loss' reached 0.62278 (best 0.62278), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__hypertension-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.045 >= min_delta = 0. New best score: 0.578
Epoch 0, global step 250: 'val_loss' reached 0.57780 (best 0.57780), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__hypertension-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.025 >= min_delta = 0. New best score: 0.553
Epoch 0, global step 300: 'val_loss' reached 0.55281 (best 0.55281), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__hypertension-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.006 >= min_delta = 0. New best score: 0.547
Epoch 0, global step 350: 'val_loss' reached 0.54651 (best 0.54651), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__hypertension-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.007 >= min_delta = 0. New best score: 0.539
Epoch 0, global step 400: 'val_loss' reached 0.53906 (best 0.53906), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__hypertension-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 450: 'val_loss' was not in top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 500: 'val_loss' was not in top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 550: 'val_loss' was not in top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 600: 'val_loss' was not in top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Monitored metric val_loss did not improve in the last 5 records. Best score: 0.539. Signaling Trainer to stop.
Epoch 0, global step 650: 'val_loss' was not in top 1
INFO:root:Re-loading from best cached checkpoint /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__hypertension-Fine-Tune.ckpt
INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using Competing-Risk DeSurv head.
INFO:root:In generation forwarding DeSurv on the grid between [0.0, 1.0] with 1000 intervals
INFO:root:Using Competing-Risk DeSurv head.
INFO:root:In generation forwarding DeSurv on the grid between [0.0, 1.0] with 1000 intervals
INFO:root:Trainable parameters: all parameters
INFO:root:Testing model.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: |          | 0/? [00:00<?, ?it/s]

Loaded model with 129.424208 M parameters


VBox(children=(Label(value='71.765 MB of 71.765 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
Scheduler,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Test:OutcomePerformanceMetricsctd,▁
Test:OutcomePerformanceMetricsibs,▁
Test:OutcomePerformanceMetricsinbll,▁
Val:OutcomePerformanceMetricsctd,▁▅▆███▇▇█▇▇▇▇
Val:OutcomePerformanceMetricsibs,█▅▅▄▂▁▂▁▂▃▂▃▃
Val:OutcomePerformanceMetricsinbll,█▆▇▅▃▁▁▁▂▂▁▃▂
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_loss,▁
train_loss,█▆▅▄▄▅▃▃▄▃▁▄▃▃▂▃▂▃▃▃▃▂▂▃▂▂▂▂▂▂▂▁

0,1
Scheduler,0.0001
Test:OutcomePerformanceMetricsctd,0.70422
Test:OutcomePerformanceMetricsibs,0.08789
Test:OutcomePerformanceMetricsinbll,0.29442
Val:OutcomePerformanceMetricsctd,0.66998
Val:OutcomePerformanceMetricsibs,0.08578
Val:OutcomePerformanceMetricsinbll,0.28804
epoch,1.0
test_loss,0.50387
train_loss,0.36745


INFO:root:Running cr on 72 CPUs and 1 GPUs
INFO:root:# Loading DataModule for dataset /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/. This will be loaded in supervised form.
INFO:root:Creating supervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 

/rds/bear-apps/2022a/EL8-ice/software/PyTorch-Lightning/2.1.0-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:630: Checkpoint directory /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:root:Using ReduceLROnPlateau scheduler
INFO:root:Not using warm-up in scheduler

  | Name       | Type                            | Params
---------------------------------------------------------------
0 | model      | SurvStreamGPTForCausalModelling | 129 M 
1 | surv_layer | ODESurvCompetingRiskLayer       | 68.2 K
2 | dropout    | Dropout                         | 0     
---------------------------------------------------------------
129 M     Trainable params
572       Non-trainable params
129 M     Total params
517.698   Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 1.055
Epoch 0, global step 50: 'val_loss' reached 1.05469 (best 1.05469), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__cvd-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.354 >= min_delta = 0. New best score: 0.700
Epoch 0, global step 100: 'val_loss' reached 0.70023 (best 0.70023), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__cvd-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.106 >= min_delta = 0. New best score: 0.595
Epoch 0, global step 150: 'val_loss' reached 0.59467 (best 0.59467), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__cvd-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.022 >= min_delta = 0. New best score: 0.573
Epoch 0, global step 200: 'val_loss' reached 0.57302 (best 0.57302), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__cvd-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.002 >= min_delta = 0. New best score: 0.571
Epoch 0, global step 250: 'val_loss' reached 0.57104 (best 0.57104), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__cvd-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.002 >= min_delta = 0. New best score: 0.569
Epoch 0, global step 300: 'val_loss' reached 0.56864 (best 0.56864), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__cvd-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.001 >= min_delta = 0. New best score: 0.567
Epoch 0, global step 350: 'val_loss' reached 0.56747 (best 0.56747), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__cvd-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.004 >= min_delta = 0. New best score: 0.564
Epoch 0, global step 400: 'val_loss' reached 0.56391 (best 0.56391), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__cvd-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.001 >= min_delta = 0. New best score: 0.563
Epoch 0, global step 450: 'val_loss' reached 0.56283 (best 0.56283), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__cvd-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.003 >= min_delta = 0. New best score: 0.560
Epoch 0, global step 500: 'val_loss' reached 0.56014 (best 0.56014), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__cvd-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.001 >= min_delta = 0. New best score: 0.559
Epoch 0, global step 550: 'val_loss' reached 0.55888 (best 0.55888), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__cvd-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.002 >= min_delta = 0. New best score: 0.557
Epoch 0, global step 600: 'val_loss' reached 0.55680 (best 0.55680), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__cvd-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 650: 'val_loss' was not in top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 700: 'val_loss' was not in top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 750: 'val_loss' was not in top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 800: 'val_loss' was not in top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.001 >= min_delta = 0. New best score: 0.555
Epoch 0, global step 850: 'val_loss' reached 0.55538 (best 0.55538), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__cvd-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 900: 'val_loss' was not in top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 950: 'val_loss' was not in top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 1000: 'val_loss' was not in top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.002 >= min_delta = 0. New best score: 0.554
Epoch 0, global step 1050: 'val_loss' reached 0.55353 (best 0.55353), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__cvd-Fine-Tune.ckpt' as top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 1100: 'val_loss' was not in top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 1150: 'val_loss' was not in top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 1200: 'val_loss' was not in top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 1250: 'val_loss' was not in top 1


Validation: |          | 0/? [00:00<?, ?it/s]

Monitored metric val_loss did not improve in the last 5 records. Best score: 0.554. Signaling Trainer to stop.
Epoch 0, global step 1300: 'val_loss' was not in top 1
INFO:root:Re-loading from best cached checkpoint /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/CR_11M_24_11_01_big_posencscale__cvd-Fine-Tune.ckpt
INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using Competing-Risk DeSurv head.
INFO:root:In generation forwarding DeSurv on the grid between [0.0, 1.0] with 1000 intervals
INFO:root:Using Competing-Risk DeSurv head.
INFO:root:In generation forwarding DeSurv on the grid between [0.0, 1.0] with 1000 intervals
INFO:root:Trainable parameters: all parameters
INFO:root:Testing model.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: |          | 0/? [00:00<?, ?it/s]

Loaded model with 129.424472 M parameters


0,1
Scheduler,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Test:OutcomePerformanceMetricsctd,▁
Test:OutcomePerformanceMetricsibs,▁
Test:OutcomePerformanceMetricsinbll,▁
Val:OutcomePerformanceMetricsctd,▁▄▃▄▄▄▅▅▆▆█▇▇▇▇▇█▇▇▇██▇███
Val:OutcomePerformanceMetricsibs,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Val:OutcomePerformanceMetricsinbll,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
test_loss,▁
train_loss,██▅▄▃▂▃▃▅▄▂▃▄▂▄▃▄▂▂▂▂▄▄▅▂▃▅▂▂▂▁▂▄▃▃▄▆▁▅▁

0,1
Scheduler,0.0001
Test:OutcomePerformanceMetricsctd,0.64245
Test:OutcomePerformanceMetricsibs,0.03376
Test:OutcomePerformanceMetricsinbll,0.14482
Val:OutcomePerformanceMetricsctd,0.64511
Val:OutcomePerformanceMetricsibs,0.03414
Val:OutcomePerformanceMetricsinbll,0.14665
epoch,1.0
test_loss,0.57647
train_loss,0.27516


In [10]:
# Inspect the tokenizer vocabulary in `_stoi` key order: tokens unchanged by `.upper()`
# are shown as-is, every other token is replaced by a 0 placeholder so the list
# length still equals the vocabulary size.
[(token if token.upper() == token else 0) for token in dm.train_set.tokenizer._stoi.keys()]

['PAD',
 'UNK',
 'ADDISONS_DISEASE',
 'CYSTICFIBROSIS',
 'SYSTEMIC_SCLEROSIS',
 'SICKLE_CELL_DISEASE_V2',
 'ADDISON_DISEASE',
 'DOWNSSYNDROME',
 'HAEMOCHROMATOSIS_V2',
 'PLASMACELL_NEOPLASM_V2',
 'SJOGRENSSYNDROME',
 'SYSTEMIC_LUPUS_ERYTHEMATOSUS',
 'HIVAIDS',
 'PSORIATICARTHRITIS2021',
 'MS',
 0,
 'LEUKAEMIA_PREVALENCEV2',
 0,
 'ILD_SH',
 'CHRONIC_LIVER_DISEASE_ALCOHOL',
 'PERNICIOUSANAEMIA',
 'MENIERESDISEASE',
 'LYMPHOMA_PREVALENCE_V2',
 'CROHNS_DISEASE',
 0,
 0,
 'CHRONICFATIGUESYNDROMEMM_V2',
 0,
 'STROKE_HAEMRGIC',
 'PARKINSONS',
 'AORTICANEURYSM_V2',
 'BIPOLAR',
 'BRONCHIECTASIS',
 'ULCERATIVE_COLITIS',
 'SCHIZOPHRENIAMM_V2',
 'PTSDDIAGNOSIS',
 'TYPE1DM',
 'FIBROMYALGIA',
 'VISUAL_IMPAIRMENT',
 'AUTISM',
 'NAFLD_V2',
 'ISCHAEMICSTROKE_V2',
 0,
 'PVD_V3',
 'EATINGDISORDERS',
 'PMRANDGCA',
 'RHEUMATOIDARTHRITIS',
 0,
 'ENDOMETRIOSIS_ADENOMYOSIS_V2',
 'HYPERTHYROIDISM_V2',
 0,
 0,
 'OSA',
 0,
 0,
 0,
 'PAD_STRICT',
 'OTHER_CHRONIC_LIVER_DISEASE_OPTIMAL',
 'POLYCYSTIC_OVARIAN_SYNDRO

# Small 11.4M parameter model

```

```

In [None]:
# Compose the small 11.4M-parameter competing-risk config for the CVD cohort and build it.
# Both `experiment.train` and `experiment.test` are False here, so `run` is asked neither
# to train nor to test (presumably it only constructs the model and DataModule — confirm).
_data_root = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel"
_cvd_outcomes = '["IHDINCLUDINGMI_OPTIMALV2", "ISCHAEMICSTROKE_V2", "MINFARCTION", "STROKEUNSPECIFIED_V2", "STROKE_HAEMRGIC"]'

_experiment_overrides = [
    "experiment.type='fine-tune'",
    "experiment.run_id='CR_11M'",
    "experiment.fine_tune_id='CVD_full-finetune-1e-4'",
    "experiment.train=False",
    "experiment.test=False",
    f"experiment.fine_tune_outcomes={_cvd_outcomes}",
]
_data_overrides = [
    f"data.path_to_ds={_data_root}/FineTune_CVD/",
    f"data.meta_information_path={_data_root}/PreTrain/meta_information_QuantJenny.pickle",
    "data.min_workers=12",
]
_optim_overrides = [
    "optim.num_epochs=2",
    "optim.limit_test_batches=null",
    "optim.scheduler=ReduceOnPlateau",
    "optim.scheduler_warmup=False",
    "optim.learning_rate=1e-4",
    "optim.val_check_interval=50",
    "optim.early_stop=True",
    "optim.limit_val_batches=0.035",
]
# Zero value weight: the supervised task only optimises survival outcome prediction.
_head_overrides = [
    "head.surv_weight=1",
    "head.value_weight=0",
]

with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
    cfg = compose(
        config_name="config_CompetingRisk11M",
        overrides=_experiment_overrides + _data_overrides + _optim_overrides + _head_overrides,
    )

model, dm = run(cfg)
n_params_millions = sum(p.numel() for p in model.parameters()) / 1e6
print(f"Loaded model with {n_params_millions} M parameters")
wandb.finish()


In [None]:
# Close any Weights & Biases run still open in this kernel, e.g. one left behind
# by an interrupted training cell.
wandb.finish()

In [None]:
# Inspect the tokenizer's private `_event_counts` attribute — presumably the number of
# occurrences of each event token used to build the vocabulary; confirm.
# NOTE(review): the vocabulary inspection cell uses `dm.train_set.tokenizer`; check that
# `dm.tokenizer` refers to the same tokenizer object.
dm.tokenizer._event_counts