# CPRD Notebook:
## Evaluation of fine-tuning the pre-trained SurvivEHR-CR model on a supervised cohort study.

Cohort study:

This notebook quantifies the performance obtained when fine-tuning the pre-trained model

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

%load_ext autoreload
%autoreload 2

print(os.getcwd())

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/SurvivEHR/notebooks/CompetingRisk/1_Study3_MultiMorbidity


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
import wandb
import pickle
from hydra import compose, initialize
from omegaconf import OmegaConf

from FastEHR.dataloader.foundational_loader import FoundationalDataModule

from CPRD.examples.modelling.SurvivEHR.run_experiment import run
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

import time
import polars as pl
pl.Config.set_tbl_rows(10000)
import pandas as pd
pd.options.display.max_rows = 10000


# TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed
%env SLURM_NTASKS_PER_NODE=28   

env: SLURM_NTASKS_PER_NODE=28


# Fine-tuning on full dataset
The default configuration is for pre-training. Here we modify as necesssary

Here we choose to load in the configuration for a small **pre-trained** 11.4M parameter model, named "CR_11M". We specfiy the `fine-tune` experiment type, which will lead to running the ```SupervisedExperiment```. 

We tell this experiment that we want to perform training (true by default). Additionally, we do choose to perform testing (true by default). As this is a supervised model, this tests the ability to predict the outcomes of interest. In this notebook, this is chosen to be those of the cohort study for predicting Cardiovascular Disease in a Type 2 Diabetes Mellitus population, and we add the folder containing this dataset to the configuration. 

```Note: As this is a supervised dataset, we need to tell the DataModule that the last event observed is a target and must be stripped. This is done by passing a list of targets to the configuration, overriding the null default. This lets the DataModule know that it should process batches as supervised.```

We set the number of workers to be appropriate for the number of CPUs available to reduce bottlenecking, and tell the experiment that we do not want to limit the number of testing batches. In addition, we specify where we want any checkpoints to be saved to avoid bloating the repository.

We design a new optimisation strategy for fine-tuning. Pre-training was achieved with a warmup and cosine annealing, with rates which are no appropriate for much smaller dataset sizes seen in clinical prediction models (CPMs). We here choose a simpler strategy: of ReduceOnPlateau with no warmup, increasing the number of epochs (default is 1) and reduced validation intervals, and the addition of early stopping. Additionally, as this is not a causal model we can increase the batch size. Finally, as this CPM is not trying to predict the value of any outcomes, we set the value weight to zero allowing the model to focus entirely on optimising survival outcome prediction.

In [5]:
pre_trained_model_ids = ['SurvivEHR-cr-small', 'SurvivEHR-cr-small-v1', 'SurvivEHR-cr-384', 'SurvivEHR-cr-384-v1', 'crPreTrain_small_1337']
experiments = ["mm"] 
experiment_types = [ "fine-tune-cr"] 
adapter = 8

## Get outcome list

In [9]:
if True:
    for pre_trained_model in pre_trained_model_ids[1:2]:
        print(pre_trained_model)
    
        for experiment, experiment_type in zip(experiments, experiment_types):
        
            wandb.finish()
            # load the configuration file, override any settings 
            with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
                cfg = compose(config_name="config_CompetingRisk11M", 
                              overrides=[# Experiment setup
                                         f"experiment.project_name='SurvivEHR-Study3-MM'",
                                         f"experiment.type='{experiment_type}'",
                                         f"experiment.run_id='{pre_trained_model}'",
                                         f"experiment.fine_tune_id='{experiment}-{experiment_type}-A{adapter}-notebook'",
                                         "experiment.train=True",
                                         "experiment.test=True",
                                         "experiment.notes=Table result",
                                         # Dataloader
                                         "data.batch_size=128",
                                         "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
                                         "data.min_workers=3",
                                         "data.global_diagnoses=True",
                                         "data.repeating_events=False",
                                         # Optimiser
                                         "optim.num_epochs=20",
                                         "optim.limit_test_batches=null",
                                         "optim.scheduler=ReduceOnPlateau",
                                         "optim.scheduler_warmup=False",
                                         "optim.learning_rate=1e-3",
                                         "optim.val_check_interval=50",
                                         "optim.early_stop=True",
                                         "optim.early_stop_patience=4",
                                         "optim.limit_val_batches=0.035",
                                         "optim.accumulate_grad_batches=4",
                                         # Model
                                         # "transformer.n_embd=384",
                                         f"transformer.use_fine_tune_adapter={False if adapter is False else True}",
                                         f"transformer.adapter_dim={8 if adapter is False else adapter}",
                                         "transformer.block_size=512", 
                                        ]
                             )
            
            match experiment.lower():
                case "mm":
                    cfg.data.path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity/"

                    # Load dataset in most minimal form (this isnt used for the experiment - only to extract the token names for the diagnoses)
                    dm = FoundationalDataModule(path_to_db=cfg.data.path_to_db,
                                                path_to_ds=cfg.data.path_to_ds,
                                                overwrite_meta_information=cfg.data.meta_information_path,
                                                load=True)
                    conditions = (
                        dm.tokenizer._event_counts.filter((pl.col("COUNT") > 0) &
                            (pl.col("EVENT").str.contains(r'^[A-Z0-9_]+$')))
                          .select("EVENT")
                          .to_series()
                          .to_list()
                    )
                    cfg.experiment.fine_tune_outcomes=conditions
            
            
            model, dm = run(cfg)
            print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
            wandb.finish()



SurvivEHR-cr-small-v1


/rds/bear-apps/2022a/EL8-ice/software/PyTorch-Lightning/2.1.0-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/pytorch_lightning/core/saving.py:173: Found keys that are in the model state dict but not in the checkpoint: ['model.transformer.blocks.0.adapter_1.proj.0.weight', 'model.transformer.blocks.0.adapter_1.proj.0.bias', 'model.transformer.blocks.0.adapter_1.proj.2.weight', 'model.transformer.blocks.0.adapter_1.proj.2.bias', 'model.transformer.blocks.0.adapter_2.proj.0.weight', 'model.transformer.blocks.0.adapter_2.proj.0.bias', 'model.transformer.blocks.0.adapter_2.proj.2.weight', 'model.transformer.blocks.0.adapter_2.proj.2.bias', 'model.transformer.blocks.1.adapter_1.proj.0.weight', 'model.transformer.blocks.1.adapter_1.proj.0.bias', 'model.transformer.blocks.1.adapter_1.proj.2.weight', 'model.transformer.blocks.1.adapter_1.proj.2.bias', 'model.transformer.blocks.1.adapter_2.proj.0.weight', 'model.transformer.blocks.1.adapter_2.proj.0.bias', 'model.transformer.blocks.1.adapte

/rds/bear-apps/2022a/EL8-ice/software/PyTorch-Lightning/2.1.0-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:630: Checkpoint directory /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                            | Params
---------------------------------------------------------------
0 | model      | SurvStreamGPTForCausalModelling | 11.3 M
1 | surv_layer | ODESurvCompetingRiskLayer       | 31.7 K
---------------------------------------------------------------
120 K     Trainable params
11.2 M    Non-trainable params
11.3 M    Total params
45.286    Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[14.0833,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 3.152
Epoch 0, global step 12: 'val_loss' reached 3.15236 (best 3.15236), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.189 >= min_delta = 0. New best score: 2.963
Epoch 0, global step 25: 'val_loss' reached 2.96288 (best 2.96288), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.7600,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.5375,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample 

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.111 >= min_delta = 0. New best score: 2.852
Epoch 0, global step 37: 'val_loss' reached 2.85193 (best 2.85193), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[215,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.8285,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.5233,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample 

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.066 >= min_delta = 0. New best score: 2.786
Epoch 0, global step 50: 'val_loss' reached 2.78634 (best 2.78634), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.2121,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.3375,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.020 >= min_delta = 0. New best score: 2.767
Epoch 0, global step 62: 'val_loss' reached 2.76654 (best 2.76654), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.5288,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.5638,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample 

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.022 >= min_delta = 0. New best score: 2.744
Epoch 0, global step 75: 'val_loss' reached 2.74416 (best 2.74416), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[13.5830,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.014 >= min_delta = 0. New best score: 2.730
Epoch 0, global step 87: 'val_loss' reached 2.73019 (best 2.73019), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.4055,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.5101,  0.0000,  0.0000,  0.0000,  0.0000]])
Metric val_loss improved by 0.016 >= min_delta = 0. New best score: 2.714
Epoch 0, global step 100: 'val_loss' reached 2.71441 (best 2.71441), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.4860,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.4323,  0.0000,  0.0000,  0.0000,  0.0000]])
Metric val_loss improved by 0.009 >= min_delta = 0. New best score: 2.705
Epoch 0, global step 112: 'val_loss' reached 2.70536 (best 2.70536), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.5244,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.006 >= min_delta = 0. New best score: 2.700
Epoch 0, global step 125: 'val_loss' reached 2.69967 (best 2.69967), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[207,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.1644,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[242,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.1156,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 137: 'val_loss' was not in top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.006 >= min_delta = 0. New best score: 2.693
Epoch 0, global step 150: 'val_loss' reached 2.69336 (best 2.69336), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[255,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[13.1397,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.6658,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.002 >= min_delta = 0. New best score: 2.692
Epoch 0, global step 162: 'val_loss' reached 2.69169 (best 2.69169), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.8033,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.0948,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.0789,  0.0000,  0.0000,  0.0000,  0.0000]])
Metric val_loss improved by 0.010 >= min_delta = 0. New best score: 2.681
Epoch 0, global step 175: 'val_loss' reached 2.68142 (best 2.68142), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[167,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.4592,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.1326,  0.0000,  0.0000,  0.0000,  0.0000]])
Metric val_loss improved by 0.009 >= min_delta = 0. New best score: 2.672
Epoch 0, global step 187: 'val_loss' reached 2.67201 (best 2.67201), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.0625,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.006 >= min_delta = 0. New best score: 2.666
Epoch 0, global step 200: 'val_loss' reached 2.66574 (best 2.66574), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.007 >= min_delta = 0. New best score: 2.659
Epoch 0, global step 212: 'val_loss' reached 2.65913 (best 2.65913), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[13.8619,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

                                    		 Bad sample tokens: tensor([[167,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[13.2745,  0.0000,  0.0000,  0.0000,  0.0000]])
Metric val_loss improved by 0.002 >= min_delta = 0. New best score: 2.657
Epoch 0, global step 225: 'val_loss' reached 2.65725 (best 2.65725), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[167,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[13.0871,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.007 >= min_delta = 0. New best score: 2.650
Epoch 0, global step 237: 'val_loss' reached 2.65046 (best 2.65046), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.6263,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.003 >= min_delta = 0. New best score: 2.648
Epoch 0, global step 250: 'val_loss' reached 2.64795 (best 2.64795), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[13.4959,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.012 >= min_delta = 0. New best score: 2.636
Epoch 0, global step 262: 'val_loss' reached 2.63602 (best 2.63602), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.8334,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.3375,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.6784,  0.0000,  0.0000,  0.0000,  0.0000]])
Epoch 0, global step 275: 'val_loss' was not in top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.9025,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.5595,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
    

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.016 >= min_delta = 0. New best score: 2.620
Epoch 0, global step 287: 'val_loss' reached 2.62016 (best 2.62016), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[255,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.4274,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.1764,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 300: 'val_loss' was not in top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.4153,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.1332,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[13.6389,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.003 >= min_delta = 0. New best score: 2.617
Epoch 0, global step 312: 'val_loss' reached 2.61679 (best 2.61679), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[189,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.5293,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.015 >= min_delta = 0. New best score: 2.602
Epoch 0, global step 325: 'val_loss' reached 2.60180 (best 2.60180), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[255,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.9616,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.9479,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0. New best score: 2.601
Epoch 0, global step 337: 'val_loss' reached 2.60150 (best 2.60150), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[152,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.8937,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.004 >= min_delta = 0. New best score: 2.598
Epoch 0, global step 350: 'val_loss' reached 2.59778 (best 2.59778), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[210,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[13.4734,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0],
        [152,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.1742,  0.0000,  0.0000,  0.0000,  0.0000],
        [10.714

Validation: |          | 0/? [00:00<?, ?it/s]

                                    		 Bad sample tokens: tensor([[152,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.1277,  0.0000,  0.0000,  0.0000,  0.0000]])
Metric val_loss improved by 0.015 >= min_delta = 0. New best score: 2.583
Epoch 0, global step 362: 'val_loss' reached 2.58251 (best 2.58251), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.1381,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.007 >= min_delta = 0. New best score: 2.575
Epoch 0, global step 375: 'val_loss' reached 2.57545 (best 2.57545), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[13.9342,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.6756,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.003 >= min_delta = 0. New best score: 2.573
Epoch 0, global step 387: 'val_loss' reached 2.57270 (best 2.57270), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.2904,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.8745,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.003 >= min_delta = 0. New best score: 2.570
Epoch 0, global step 400: 'val_loss' reached 2.56961 (best 2.56961), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[255,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.5644,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[182,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.3414,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.004 >= min_delta = 0. New best score: 2.566
Epoch 0, global step 412: 'val_loss' reached 2.56603 (best 2.56603), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[254,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.0444,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0. New best score: 2.566
Epoch 0, global step 425: 'val_loss' reached 2.56601 (best 2.56601), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.7666,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.7578,  0.0000,  0.0000,  0.0000,  0.0000]])
Metric val_loss improved by 0.005 >= min_delta = 0. New best score: 2.561
Epoch 0, global step 437: 'val_loss' reached 2.56134 (best 2.56134), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.1847,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.002 >= min_delta = 0. New best score: 2.560
Epoch 0, global step 450: 'val_loss' reached 2.55954 (best 2.55954), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 462: 'val_loss' was not in top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[152,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.0027,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.9770,  0.0000,  0.0000,  0.0000,  0.0000]])
Metric val_loss improved by 0.001 >= min_delta = 0. New best score: 2.558
Epoch 0, global step 475: 'val_loss' reached 2.55808 (best 2.55808), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 487: 'val_loss' was not in top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.3584,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.1414,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.4016,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
    

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 500: 'val_loss' was not in top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.1375,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[189,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.5512,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.4203,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.6099,  0.0000,  0.0000,  0.0000,  0.0000]])
Metric val_loss improved by 0.002 >= min_delta = 0. New best score: 2.556
Epoch 0, global step 512: 'val_loss' reached 2.55598 (best 2.55598), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.0663,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.009 >= min_delta = 0. New best score: 2.547
Epoch 0, global step 525: 'val_loss' reached 2.54670 (best 2.54670), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[207,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.3814,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[152,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.3868,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 537: 'val_loss' was not in top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[177,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.0433,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.5233,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.6049,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
    

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0. New best score: 2.547
Epoch 0, global step 550: 'val_loss' reached 2.54650 (best 2.54650), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.4493,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.6784,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample

Validation: |          | 0/? [00:00<?, ?it/s]

                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.1644,  0.0000,  0.0000,  0.0000,  0.0000]])
Metric val_loss improved by 0.003 >= min_delta = 0. New best score: 2.543
Epoch 0, global step 562: 'val_loss' reached 2.54336 (best 2.54336), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.4690,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.002 >= min_delta = 0. New best score: 2.541
Epoch 0, global step 575: 'val_loss' reached 2.54113 (best 2.54113), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[251,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.3912,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0],
        [152,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.7162,  0.0000,  0.0000,  0.0000,  0.0000],
        [10.326

Validation: |          | 0/? [00:00<?, ?it/s]

                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.6866,  0.0000,  0.0000,  0.0000,  0.0000]])
Metric val_loss improved by 0.001 >= min_delta = 0. New best score: 2.540
Epoch 0, global step 587: 'val_loss' reached 2.53979 (best 2.53979), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[255,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.6471,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 600: 'val_loss' was not in top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 612: 'val_loss' was not in top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.8060,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[167,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.3879,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.002 >= min_delta = 0. New best score: 2.538
Epoch 0, global step 625: 'val_loss' reached 2.53782 (best 2.53782), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.001 >= min_delta = 0. New best score: 2.536
Epoch 0, global step 637: 'val_loss' reached 2.53640 (best 2.53640), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.2153,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 650: 'val_loss' was not in top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.7474,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.003 >= min_delta = 0. New best score: 2.533
Epoch 0, global step 662: 'val_loss' reached 2.53329 (best 2.53329), saving model to '/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/SurvivEHR-cr-small-v1_mm-fine-tune-cr-A8-notebook.ckpt' as top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.9764,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.5074,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 675: 'val_loss' was not in top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 687: 'val_loss' was not in top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.0652,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[11.6049,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.7386,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
    

Validation: |          | 0/? [00:00<?, ?it/s]

                                    		 Bad sample tokens: tensor([[207,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.0258,  0.0000,  0.0000,  0.0000,  0.0000]])
Epoch 0, global step 700: 'val_loss' was not in top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])


Validation: |          | 0/? [00:00<?, ?it/s]

Monitored metric val_loss did not improve in the last 4 records. Best score: 2.533. Signaling Trainer to stop.
Epoch 0, global step 712: 'val_loss' was not in top 1
                                    		 Bad sample tokens: tensor([[184,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[12.1134,  0.0000,  0.0000,  0.0000,  0.0000]])
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: |          | 0/? [00:00<?, ?it/s]

                                    		 Bad sample tokens: tensor([[255,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.7184,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[234,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.1036,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[152,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.2855,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[217,   0,   0,   0,   0]])
                                    		 and corresponding ages tensor([[10.0384,  0.0000,  0.0000,  0.0000,  0.0000]])
                                    		 Bad sample tokens: tensor([[260,   0,   0,   0,   0]])
                                    		 and corresponding a

Loaded model with 11.321432 M parameters


VBox(children=(Label(value='244.416 MB of 244.416 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0,…

0,1
Scheduler,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Test:OutcomePerformanceMetricsctd,▁
Test:OutcomePerformanceMetricsibs,▁
Test:OutcomePerformanceMetricsinbll,▁
Val:OutcomePerformanceMetricsctd,▁▂▃▅▅▆▆▆▇▇▇▇▇▇▇▇▇▇██████████████████████
Val:OutcomePerformanceMetricsibs,█▆▆▅▅▄▄▄▄▄▃▃▃▃▂▃▂▂▂▃▂▂▁▁▁▁▁▂▂▁▁▁▁▂▁▁▂▁▁▁
Val:OutcomePerformanceMetricsinbll,█▇▆▆▆▅▅▅▅▄▄▃▃▄▃▄▂▃▂▃▂▂▂▁▂▁▁▂▂▁▂▁▁▂▂▁▂▁▁▁
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
test_loss,▁
train_loss,█▇▇▄▆▄▇▅▆▄▆▆▄▅▄▆▄▂▃▄▅▄▃▄▇▅▃▃▇▁▄▄▅▆▃

0,1
Scheduler,0.001
Test:OutcomePerformanceMetricsctd,0.72615
Test:OutcomePerformanceMetricsibs,0.10713
Test:OutcomePerformanceMetricsinbll,0.33452
Val:OutcomePerformanceMetricsctd,0.7278
Val:OutcomePerformanceMetricsibs,0.09737
Val:OutcomePerformanceMetricsinbll,0.31127
epoch,1.0
test_loss,2.73171
train_loss,2.51459


```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
└─────────────────────────────────────┴─────────────────────────────────────┘
```

In [5]:
# [_i if _i == _i.upper() else 0 for _i in dm.train_set.tokenizer._stoi.keys()]

# Fine-tuning on sub-set of data

In [6]:
sample_sizes = [int(np.exp(_log_n)) for _log_n in np.linspace(np.log(3000), np.log(500000), 10)]
print(sample_sizes)

[2999, 5296, 9351, 16509, 29148, 51461, 90856, 160407, 283203, 500000]


In [None]:
sample_sizes = [int(np.exp(_log_n)) for _log_n in np.linspace(np.log(3000), np.log(500000), 10)]      # [3000, 12500, 30000, 60000, 100000]: # 600, 1200, 
sample_sizes = [None]

accumulate_grad_batches = 5


for pre_trained_model in pre_trained_model_ids[1:2]:
    print(pre_trained_model)

    for experiment, experiment_type in zip(experiments, experiment_types):

        for sample_size in sample_sizes:

            wandb.finish()
            # load the configuration file, override any settings 
            with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
                cfg = compose(config_name="config_CompetingRisk11M", 
                              overrides=[# Experiment setup
                                         f"experiment.project_name='SurvivEHR-Study1-CVD'",
                                         f"experiment.type='{experiment_type}'",
                                         f"experiment.run_id='{pre_trained_model}'",
                                         f"experiment.fine_tune_id='{experiment}-{experiment_type}-A{adapter}-Ns{sample_size}-500-notebook'",
                                         "experiment.train=True",
                                         "experiment.test=True",
                                         "experiment.notes=Ablation on increasing cohort study size result",
                                         # Dataloader
                                         "data.batch_size=128",
                                         "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
                                         "data.min_workers=3",
                                         "data.global_diagnoses=True",
                                         "data.repeating_events=False",
                                         # f"data.subsample_training={sample_size}",
                                         # Optimiser
                                         "optim.num_epochs=500",
                                         "optim.limit_test_batches=null",
                                         "optim.scheduler=ReduceOnPlateau",
                                         "optim.scheduler_warmup=False",
                                         "optim.learning_rate=1e-3",
                                         "optim.val_check_interval=0.25",
                                         "optim.early_stop=True",
                                         "optim.early_stop_patience=4",
                                         "optim.limit_val_batches=0.035",
                                         f"optim.accumulate_grad_batches={accumulate_grad_batches}",
                                         # Model
                                         # "transformer.n_embd=384",
                                         f"transformer.use_fine_tune_adapter={False if adapter is False else True}",
                                         f"transformer.adapter_dim={8 if adapter is False else adapter}",
                                         "transformer.block_size=500", 
                                        ]
                             )
            
            
            match experiment.lower():
                case "cvd":
                    cfg.data.path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/"
                    cfg.experiment.fine_tune_outcomes=["IHDINCLUDINGMI_OPTIMALV2", "ISCHAEMICSTROKE_V2", "MINFARCTION", "STROKEUNSPECIFIED_V2", "STROKE_HAEMRGIC"]
                case "hypertension":
                    cfg.data.path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/"
                    cfg.experiment.fine_tune_outcomes=["HYPERTENSION"]
            
            
            model, dm = run(cfg)
            print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
            wandb.finish()


SurvivEHR-cr-small-v1


/rds/bear-apps/2022a/EL8-ice/software/PyTorch-Lightning/2.1.0-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/pytorch_lightning/core/saving.py:173: Found keys that are in the model state dict but not in the checkpoint: ['model.transformer.blocks.0.adapter_1.proj.0.weight', 'model.transformer.blocks.0.adapter_1.proj.0.bias', 'model.transformer.blocks.0.adapter_1.proj.2.weight', 'model.transformer.blocks.0.adapter_1.proj.2.bias', 'model.transformer.blocks.0.adapter_2.proj.0.weight', 'model.transformer.blocks.0.adapter_2.proj.0.bias', 'model.transformer.blocks.0.adapter_2.proj.2.weight', 'model.transformer.blocks.0.adapter_2.proj.2.bias', 'model.transformer.blocks.1.adapter_1.proj.0.weight', 'model.transformer.blocks.1.adapter_1.proj.0.bias', 'model.transformer.blocks.1.adapter_1.proj.2.weight', 'model.transformer.blocks.1.adapter_1.proj.2.bias', 'model.transformer.blocks.1.adapter_2.proj.0.weight', 'model.transformer.blocks.1.adapter_2.proj.0.bias', 'model.transformer.blocks.1.adapte

/rds/bear-apps/2022a/EL8-ice/software/PyTorch-Lightning/2.1.0-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:630: Checkpoint directory /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                            | Params
---------------------------------------------------------------
0 | model      | SurvStreamGPTForCausalModelling | 11.3 M
1 | surv_layer | ODESurvCompetingRiskLayer       | 27.1 K
---------------------------------------------------------------
115 K     Trainable params
11.2 M    Non-trainable params
11.3 M    Total params
45.268    Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

In [None]:
wandb.finish()

In [None]:
dm.tokenizer._event_counts

In [None]:
all_CVD = {
    "SurvivEHR-cr-small-v1": (
        "SurvivEHR-CR",
        [2999, 5296, 9351, 16509, 29148, 51461, 90856, 160407, 283203, 500000]],
        [0.6270919442176819],
        [0.03389651543792928],
        [0.14657540914939907]
        ),
    
}

