# CPRD Notebook:
## Evaluation of the pre-trained SurvivEHR-CR model on a supervised cohort study, with no additional training (zero-shot).

Cohort study: predicting Cardiovascular Disease in a Type 2 Diabetes Mellitus population.

This notebook quantifies the performance obtained when applying the pre-trained model to a sub-population without fine-tuning.

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
import wandb
import pickle
from hydra import compose, initialize
from omegaconf import OmegaConf

from FastEHR.dataloader.foundational_loader import FoundationalDataModule

from CPRD.examples.modelling.SurvivEHR.run_experiment import run
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

import time
import polars as pl
pl.Config.set_tbl_rows(10000)
import pandas as pd
pd.options.display.max_rows = 10000

torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')

logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

# TODO: this env variable works around a local HPC environment issue and shouldn't be needed
%env SLURM_NTASKS_PER_NODE=28

ModuleNotFoundError: No module named 'torch'

## Choosing configurations
The default configuration is for pre-training. Here we modify it as necessary.

Here we load the configuration for a small **pre-trained** 11.4M-parameter model, named "CR_11M". We specify the `zeroshot` experiment type, which runs the `SupervisedExperiment`.

We tell this experiment that no further training is needed (`experiment.train=False`), and that we do want to perform testing (`experiment.test=True`, the default). Because the dataset is supervised, testing measures the ability to predict the outcomes of interest. In this notebook these are the outcomes of the cohort study for predicting Cardiovascular Disease in a Type 2 Diabetes Mellitus population, and we add the folder containing this dataset to the configuration.

> **Note:** As this is a supervised dataset, we need to tell the DataModule that the last event observed is a target and must be stripped. This is done by passing a list of target outcomes to the configuration (`cfg.experiment.fine_tune_outcomes`), overriding the null default, which lets the DataModule know that it should process batches in supervised form.

We set the minimum number of dataloader workers (`data.min_workers`) to suit the number of CPUs available, reducing data-loading bottlenecks, and tell the experiment not to limit the number of test batches (`optim.limit_test_batches=null`). In addition, we specify where any checkpoints should be saved, to avoid bloating the repository.

# Run small (11M) Competing-Risk model experiment

```
notebook-test3
  removing the value contribution to the loss function
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  Test:OutcomePerformanceMetricsctd  │         0.5369176864624023          │
│  Test:OutcomePerformanceMetricsibs  │         0.10644617670905024         │
│ Test:OutcomePerformanceMetricsinbll │         0.8475475761826847          │
│              test_loss              │         17.466724395751953          │
└─────────────────────────────────────┴─────────────────────────────────────┘

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  Test:OutcomePerformanceMetricsctd  │         0.5394920706748962          │
│  Test:OutcomePerformanceMetricsibs  │        0.035522244827896336         │
│ Test:OutcomePerformanceMetricsinbll │         0.24506395237725365         │
│              test_loss              │          17.61273765563965          │
└─────────────────────────────────────┴─────────────────────────────────────┘

notebook-test4-3232
   adding value contribution back in, but making DeSurv head deeper
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  Test:OutcomePerformanceMetricsctd  │         0.5443774461746216          │
│  Test:OutcomePerformanceMetricsibs  │         0.10633526848584007         │
│ Test:OutcomePerformanceMetricsinbll │         0.7019155487106575          │
│              test_loss              │         15.524889945983887          │
└─────────────────────────────────────┴─────────────────────────────────────┘
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  Test:OutcomePerformanceMetricsctd  │         0.5000342130661011          │
│  Test:OutcomePerformanceMetricsibs  │         0.03566100741780196         │
│ Test:OutcomePerformanceMetricsinbll │         0.5747876250223403          │
│              test_loss              │          17.3300838470459           │
└─────────────────────────────────────┴─────────────────────────────────────┘

SurvivEHR-cr-small
    column 1: a bit over one epoch,
    column 2: a bit over one epoch further

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  Test:OutcomePerformanceMetricsctd  │         0.5091821551322937          │         0.5178812742233276          │
│  Test:OutcomePerformanceMetricsibs  │         0.10525757690619604         │         0.10537813909147978         │
│ Test:OutcomePerformanceMetricsinbll │         0.6082171127343547          │         0.7108097155724724          │
│              test_loss              │          1.011661171913147          │        -0.11626045405864716         │
└─────────────────────────────────────┴─────────────────────────────────────┴─────────────────────────────────────┘
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  Test:OutcomePerformanceMetricsctd  │         0.5508394241333008          │         0.5341746807098389          │
│  Test:OutcomePerformanceMetricsibs  │          0.035395744412963          │         0.03552938539484051         │
│ Test:OutcomePerformanceMetricsinbll │         0.22044220403102463         │          0.225344592765462          │
│              test_loss              │         1.6153547763824463          │         0.3721931278705597          │
└─────────────────────────────────────┴─────────────────────────────────────┴─────────────────────────────────────┘

SurvivEHR-cr
    column 1: barely trained,
    column 2: one full epoch further
    column 3: (queued) one full epoch further (but higher CA learning rate)

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  Test:OutcomePerformanceMetricsctd  │         0.4814554750919342          │         0.5357928276062012          │
│  Test:OutcomePerformanceMetricsibs  │         0.10645012149839331         │         0.10630172990591351         │
│ Test:OutcomePerformanceMetricsinbll │         0.7736776537668243          │         0.7560898019817786          │
│              test_loss              │         11.424819946289062          │         3.4465575218200684          │
└─────────────────────────────────────┴─────────────────────────────────────┴─────────────────────────────────────┘

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  Test:OutcomePerformanceMetricsctd  │         0.5488200783729553          │
│  Test:OutcomePerformanceMetricsibs  │         0.10655508144812552         │
│ Test:OutcomePerformanceMetricsinbll │         0.8669728193358219          │
│              test_loss              │         0.8350259065628052          │
└─────────────────────────────────────┴─────────────────────────────────────┘

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  Test:OutcomePerformanceMetricsctd  │         0.5324709415435791          │         0.5513700842857361          │
│  Test:OutcomePerformanceMetricsibs  │         0.0356440989232476          │         0.03561174257369593         │
│ Test:OutcomePerformanceMetricsinbll │         0.3058541084156331          │         0.27518337504141976         │
│              test_loss              │         12.514366149902344          │          4.268644332885742          │
└─────────────────────────────────────┴─────────────────────────────────────┴─────────────────────────────────────┘

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  Test:OutcomePerformanceMetricsctd  │         0.5337772965431213          │
│  Test:OutcomePerformanceMetricsibs  │        0.035621223180408765         │
│ Test:OutcomePerformanceMetricsinbll │          0.281117343849122          │
│              test_loss              │         0.9384847283363342          │
└─────────────────────────────────────┴─────────────────────────────────────┘




SurvivEHR-cr-384
    column 1: one full epoch
    column 2: (running) one full epoch further (but higher CA learning rate)

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  Test:OutcomePerformanceMetricsctd  │         0.5163285732269287          │
│  Test:OutcomePerformanceMetricsibs  │         0.10512818377777854         │
│ Test:OutcomePerformanceMetricsinbll │         0.5792421591597969          │
│              test_loss              │          3.765381097793579          │
└─────────────────────────────────────┴─────────────────────────────────────┘

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  Test:OutcomePerformanceMetricsctd  │         0.5281378030776978          │
│  Test:OutcomePerformanceMetricsibs  │         0.10532110992913866         │
│ Test:OutcomePerformanceMetricsinbll │         0.5780003907973813          │
│              test_loss              │         -1.1865565776824951         │
└─────────────────────────────────────┴─────────────────────────────────────┘
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  Test:OutcomePerformanceMetricsctd  │         0.4940091669559479          │
│  Test:OutcomePerformanceMetricsibs  │         0.10620601581280624         │
│ Test:OutcomePerformanceMetricsinbll │         0.7204389549016048          │
│              test_loss              │         -2.023712158203125          │
└─────────────────────────────────────┴─────────────────────────────────────┘




┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  Test:OutcomePerformanceMetricsctd  │         0.5163038372993469          │
│  Test:OutcomePerformanceMetricsibs  │         0.03557852038258427         │
│ Test:OutcomePerformanceMetricsinbll │         0.2669052804846556          │
│              test_loss              │         5.5669636726379395          │
└─────────────────────────────────────┴─────────────────────────────────────┘

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  Test:OutcomePerformanceMetricsctd  │         0.5406502485275269          │
│  Test:OutcomePerformanceMetricsibs  │         0.03548340070928004         │
│ Test:OutcomePerformanceMetricsinbll │         0.23103872763338784         │
│              test_loss              │         0.7352104187011719          │
└─────────────────────────────────────┴─────────────────────────────────────┘
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│  Test:OutcomePerformanceMetricsctd  │         0.5240659713745117          │
│  Test:OutcomePerformanceMetricsibs  │        0.035569368165653606         │
│ Test:OutcomePerformanceMetricsinbll │         0.26702761352148036         │
│              test_loss              │         -0.2968777120113373         │
└─────────────────────────────────────┴─────────────────────────────────────┘



```

In [4]:
pre_trained_model_ids = ['SurvivEHR-cr-small', 'SurvivEHR-cr-small-v1', 'SurvivEHR-cr', 'SurvivEHR-cr-v1', 'SurvivEHR-cr-v1-v1', 'SurvivEHR-cr-384', 'SurvivEHR-cr-384-v1', 'crPreTrain_small_1337']
# CR_11M_24_11_01_big_posencscale_ - adjusted the positional encoding scalar, with single layer heads
# SurvivEHR-cr-small               -  with adjusted scalar, and deeper heads
# SurvivEHR-cr-384                 -  with bigger Transformer
# SurvivEHR-cr                     -  with bigger Transformer, and 1024 latent dimensions
experiments = ["cvd"]

for pre_trained_model_id in pre_trained_model_ids[-1:]:
    for experiment in experiments:
        
        print(f"Running {experiment} on pre-trained model {pre_trained_model_id}")
        
        # load the configuration file, override any settings 
        with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
            cfg = compose(config_name="config_CompetingRisk_tiny",
                          overrides=[# Experiment setup
                                     f"experiment.project_name='SurvivEHR-Study1:CVD'",
                                     "experiment.type='zeroshot'",
                                     f"experiment.run_id='{pre_trained_model_id}'",
                                     f"experiment.fine_tune_id='{experiment}-0shot'",
                                     "experiment.train=False",
                                     "experiment.test=True",
                                     # Dataloader
                                     "data.batch_size=128",
                                     "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
                                     "data.min_workers=3",
                                     "data.global_diagnoses=True",
                                     # Optimiser
                                     "optim.limit_test_batches=null",
                                     # Model 
                                     "transformer.use_fine_tune_adapter=False",
                                     # "transformer.n_embd=1024",   # 384
                                     "transformer.block_size=512", 
                                    ]
                         )   
        
        match experiment.lower():
            case "cvd":
                cfg.data.path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/"
                cfg.experiment.fine_tune_outcomes=["IHDINCLUDINGMI_OPTIMALV2", "ISCHAEMICSTROKE_V2", "MINFARCTION", "STROKEUNSPECIFIED_V2", "STROKE_HAEMRGIC"]
            case "hypertension":
                cfg.data.path_to_ds="/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/"
                cfg.experiment.fine_tune_outcomes=["HYPERTENSION"]
            case _:
                # Fail early rather than silently running with the pre-training dataset and no targets
                raise ValueError(f"Unknown experiment '{experiment}'")
        
        
        # print(OmegaConf.to_yaml(cfg))
        model, dm = run(cfg)
        print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
        
        wandb.finish()


Running cvd on pre-trained model crPreTrain_small_1337


INFO:root:Running cr on 72 CPUs and 1 GPUs
INFO:root:# Loading DataModule for dataset /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/. This will be loaded in supervised form.
INFO:root:Creating supervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.



Loaded model with 1.55879 M parameters


Run summary:
Test:OutcomePerformanceMetricsctd     0.57766
Test:OutcomePerformanceMetricsibs     0.03555
Test:OutcomePerformanceMetricsinbll   0.24701
epoch                                 0.0
test_loss                             6.45279
trainer/global_step                   0.0


In [4]:
dm.train_set[0]["ages"][1:] - dm.train_set[0]["ages"][:-1]
dm.train_set[0].keys()

dict_keys(['static_covariates', 'tokens', 'ages', 'values'])

In [5]:
# import wandb
print(model)
# wandb.finish()

FewShotExperiment(
  (model): SurvStreamGPTForCausalModelling(
    (transformer): TTETransformer(
      (wpe): TemporalPositionalEncoding()
      (wte): DataEmbeddingLayer(
        (static_proj): Linear(in_features=16, out_features=192, bias=True)
        (dynamic_embedding_layer): SplitDynamicEmbeddingLayer(
          (cat_event_embed_layer): Embedding(265, 192, padding_idx=0)
          (cat_event_proj): Linear(in_features=192, out_features=192, bias=True)
          (num_value_embed_layer): EmbeddingBag(265, 192, mode='sum', padding_idx=0)
          (num_value_proj): Linear(in_features=192, out_features=192, bias=True)
        )
      )
      (drop): Dropout(p=0.0, inplace=False)
      (blocks): ModuleList(
        (0-2): 3 x Block(
          (ln_1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (attn): MultiHeadedSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): L

In [6]:
dm.encode(['IHDINCLUDINGMI_OPTIMALV2', 'ISCHAEMICSTROKE_V2', 'MINFARCTION', 'STROKEUNSPECIFIED_V2', 'STROKE_HAEMRGIC'])
# display(dm.decode([95, 175, 263,249]).split(" "))

[95, 41, 67, 65, 28]

In [7]:
dm.tokenizer._event_counts["EVENT"][-5:].to_list()

['NSAIDS_oral_OPTIMAL_final',
 'Diastolic_blood_pressure_5',
 'Systolic_blood_pressure_4',
 'Statins',
 'Lipid_lowering_drugs_Optimal']

In [8]:
save_path = f"/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/{cfg.experiment.run_id}/"

# Load Pre-Trained model

In [9]:
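# NOTE: `SurvivalExperiment` is not imported anywhere in this notebook, so this cell raises a NameError.
# Import the experiment class from the CPRD package first, or reuse the `model` returned by `run(cfg)` above.
ckpt_path = os.path.join(cfg.experiment.log_dir, 'checkpoints', f'{cfg.experiment.run_id}.ckpt')
model = SurvivalExperiment.load_from_checkpoint(ckpt_path)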

NameError: name 'SurvivalExperiment' is not defined

# Initialise fine-tuning data module

In [None]:
# Update the dataset path to point to the fine-tuning (CVD) dataset
cfg.data.path_to_ds = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/"

# Build the data module, reusing the pre-training meta information
dm = FoundationalDataModule(path_to_db=cfg.data.path_to_db,
                            path_to_ds=cfg.data.path_to_ds,
                            load=True,
                            tokenizer="tabular",
                            batch_size=cfg.data.batch_size,
                            max_seq_length=cfg.transformer.block_size,
                            freq_threshold=cfg.data.unk_freq_threshold,
                            min_workers=cfg.data.min_workers,
                            overwrite_meta_information=cfg.data.meta_information_path,
                           )

vocab_size = dm.train_set.tokenizer.vocab_size
print(f"{vocab_size} vocab elements")

# List of univariate measurements to model with a Normal distribution.
# Extract the measurements using the fact that the diagnoses are all upper case.
measurements_for_univariate_regression = [record for record in dm.tokenizer._event_counts["EVENT"] if record.upper() != record]
cfg.head.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression) 
# display(measurements_for_univariate_regression)
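
The upper-case rule used above can be checked in isolation. Below is a minimal sketch. It assumes, as the comment in the cell does, that diagnosis names are entirely upper case and every other event name contains at least one lower-case letter. The helper name `split_events_by_case` is hypothetical. Under this rule, medication names such as `Statins` also fall on the measurement side of the split.

```python
def split_events_by_case(event_names):
    """Split event names into (diagnoses, others) using the upper-case convention.

    Hypothetical helper mirroring the list comprehension in the notebook:
    a name equal to its own upper-case form is treated as a diagnosis.
    """
    diagnoses = [name for name in event_names if name.upper() == name]
    others = [name for name in event_names if name.upper() != name]
    return diagnoses, others


events = ["IHDINCLUDINGMI_OPTIMALV2", "STROKE_HAEMRGIC",
          "Diastolic_blood_pressure_5", "Statins"]
diagnoses, measurements = split_events_by_case(events)
print(diagnoses)     # ['IHDINCLUDINGMI_OPTIMALV2', 'STROKE_HAEMRGIC']
print(measurements)  # ['Diastolic_blood_pressure_5', 'Statins']
```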

In [None]:
for batch in dm.train_dataloader():
    break
display(batch)

In [None]:
num_diagnoses = len(dm.train_set.meta_information["diagnosis_table"]["count"])
num_diagnosis_events = sum(dm.train_set.meta_information["diagnosis_table"]["count"])

# Medications are stored in the measurement tables but have no observed values (count_obs == 0)
is_medication = dm.train_set.meta_information["measurement_tables"]["count_obs"] == 0

num_medications = sum(is_medication)
num_medication_events = sum(dm.train_set.meta_information["measurement_tables"][is_medication]["count"])
num_measurement_test = sum(~is_medication)
# Measurement and test events are counted using `count_obs`
num_measurement_test_events = sum(dm.train_set.meta_information["measurement_tables"][~is_medication]["count_obs"])

print(f'{num_diagnosis_events:,} diagnoses of {num_diagnoses} types')
print(f'{num_medication_events:,} medications of {num_medications} types')
print(f'{num_measurement_test_events:,} measurements and tests of {num_measurement_test} types')
print(f'{num_diagnosis_events + num_medication_events + num_measurement_test_events:,} events in total')

print(dm.train_set.meta_information.keys())
print(dm.train_set.tokenizer._event_counts)

In [None]:
# # import pickle as pkl
# # import pathlib

# pkl_file_to_amend = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_test.pickle"

# with open(pkl_file_to_amend, 'rb') as pickle_file:
#     content = pickle.load(pickle_file)
# display(content)

# # new_dictionary = {}
# # for key in content.keys():
# #     str_to_remove = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/split=val/"
# #     new_key = str(key)[len(str_to_remove):]
# #     new_dictionary[new_key] = content[key]
# # display(new_dictionary)


# # with open(pkl_file_to_amend, 'wb') as handle:
# #     pickle.dump(new_dictionary, handle, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:
# new_dictionary

In [None]:
import copy
start = time.time()   # starting time
for batch in dm.train_dataloader():
    # print(batch["tokens"][1,:])
    
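    # NOTE: `convert_batch_to_none_causal` is not defined in this notebook; define or import it before running this cell
    c_batch = convert_batch_to_none_causal(batch)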
    # print(c_batch["tokens"][1,:])
    # print(c_batch["target_token"][1])

    # print(batch["tokens"][1,:])
    
    break
    
print(f"batch loaded in {time.time()-start} seconds")    
    
# for key in batch.keys():
#     print(f"{key}".ljust(20) + f"{batch[key].shape}")

# tokens = batch["tokens"][0].tolist()    
# sentence = dm.decode(tokens).split(" ")
# for token, value in zip(sentence, batch["values"][0].tolist()):
#     print(f"{token}:".ljust(40) + f"{value}")

In [None]:
display(batch.keys())
display(c_batch.keys())

print(batch["static_covariates"].shape)

# print(dm.train_set.static_1hot)
# print(dm.train_set.static_1hot["SEX"].categories_)
# print(dm.train_set.static_1hot["IMD"].categories_)
# print(dm.train_set.static_1hot["ETHNICITY"].categories_)

print(batch["tokens"][1,:])
print(c_batch["tokens"][1,:])
print(c_batch["target_token"][1])

## View an example sample

In [None]:
dm.test_set.view_sample(11003, max_dynamic_events=None, report_time=True)

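# Custom wrapper: predicting the last token

To begin with, I loop over a few test batches, replacing each sequence's last observed event with padding so that it can serve as the prediction target, to test the zero-shot capacity of SurvivEHR.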
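The cells below call `replace_last_non_pad_with_pad`, which is not defined in this notebook. The following is a minimal NumPy sketch of what such a helper could do, following the supervised setup described in the configuration section, where the last observed event is the target and is stripped. Its assumptions:

* sequences are right-padded with pad token `0`, matching `padding_idx=0` in the printed embedding layers;
* `attention_mask == 1` marks real events;
* NaN marks a missing value.

The name `strip_last_event` and the `target_value`/`target_age` keys are hypothetical; only `target_token` appears in the notebook. Real batches are torch tensors, so they would need `.numpy()` first (or a torch port).

```python
import numpy as np

PAD_TOKEN = 0  # assumption: matches padding_idx=0 in the printed embedding layers


def strip_last_event(batch, pad_token=PAD_TOKEN):
    """Move each sequence's last observed event into target_* entries and pad it out.

    Hypothetical helper. Assumes right-padded (batch, seq_len) arrays under the keys
    "tokens", "values", "ages" and "attention_mask", with attention_mask == 1 on real
    events and NaN marking a missing value. The input batch is not modified.
    """
    out = {key: np.array(value, copy=True) for key, value in batch.items()}
    mask = out["attention_mask"].astype(bool)
    rows = np.arange(mask.shape[0])
    last = mask.sum(axis=1) - 1  # position of the last real event in each row
    if np.any(last < 0):
        raise ValueError("every sequence needs at least one observed event")

    # Record the target event before overwriting it (advanced indexing returns copies)
    out["target_token"] = out["tokens"][rows, last]
    out["target_value"] = out["values"][rows, last]
    out["target_age"] = out["ages"][rows, last]

    # Replace the last event with padding so the model cannot see it
    out["tokens"][rows, last] = pad_token
    out["values"][rows, last] = np.nan
    out["ages"][rows, last] = 0
    out["attention_mask"][rows, last] = 0
    return out


batch = {
    "tokens": np.array([[5, 7, 9, 0], [3, 4, 0, 0]]),
    "values": np.array([[np.nan, 1.5, 2.0, np.nan], [0.1, 0.2, np.nan, np.nan]]),
    "ages": np.array([[10.0, 20.0, 30.0, 0.0], [5.0, 6.0, 0.0, 0.0]]),
    "attention_mask": np.array([[1, 1, 1, 0], [1, 1, 0, 0]]),
}
stripped = strip_last_event(batch)
print(stripped["target_token"])  # [9 4]
```

Copying every entry first keeps the original batch intact, so the before/after comparison printed in the next cell remains meaningful.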

In [None]:


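# Verify the last-event replacement on a few test batches (first 5 positions of sample index 10)
# NOTE: `replace_last_non_pad_with_pad` is not defined in this notebook; define or import it before running this cell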
for _idx, batch in enumerate(dm.test_dataloader()):
    if _idx > 10:
        break
    print(_idx)
    print(torch.stack([batch["tokens"][10,:5], 
                       batch["values"][10,:5],  
                       batch["ages"][10,:5],
                       batch["attention_mask"][10,:5]]))
    batch = replace_last_non_pad_with_pad(batch)
    print(torch.stack([batch["tokens"][10,:5], 
                       batch["values"][10,:5],  
                       batch["ages"][10,:5],
                       batch["attention_mask"][10,:5]]))

In [None]:
outcome_of_interest = ["COPD", "SUBSTANCEMISUSE"]
outcome_token = dm.encode(outcome_of_interest)[0]
print(outcome_token)
# print(model(batch))

In [None]:
Hs, labels = [], []
mins,maxes=[],[]
for _idx, batch in enumerate(dm.test_dataloader()):

    batch = replace_last_non_pad_with_pad(batch)
    print(batch["tokens"].shape)
    outputs, _, hidden_states = model(batch, is_generation=True)
    print(outputs)
    
    hidden_states = hidden_states.cpu().detach().numpy()                           # (batch_size, seq_len, n_embd)
    Hs.append( hidden_states.reshape(hidden_states.shape[0], -1) )
    labels.append((batch["target_token"] == outcome_token).long().numpy())

    if _idx == 9:
        break



# Visualise hidden dimension labelled by target

In [None]:
import umap
from sklearn.preprocessing import StandardScaler

H = np.concatenate(Hs, 0)
lbl = np.concatenate(labels, 0)

H = StandardScaler().fit_transform(H)
reducer = umap.UMAP()
H_proj = reducer.fit_transform(H)

plt.close()
plt.scatter(H_proj[:,0], H_proj[:,1], c=lbl)
os.makedirs(save_path + "zero_shot", exist_ok=True)
plt.savefig(save_path + "zero_shot/hidden_umap.png")

In [None]:
# surv_CDF has no entry for the PAD token, so token t is stored at index t - 1
print(outputs["surv"]["surv_CDF"][outcome_token - 1].shape)

# The first two tokens in the vocab correspond to the PAD and UNK tokens. There is no CDF corresponding to the PAD token, so the indexing for surv_CDF begins as ["UNK", "ADDISONS_DISEASE", ...]
# print(dm.decode([0,1,2]))

outcomes = ["COPD", "SUBSTANCEMISUSE"]
outcome_tokens = dm.encode(outcomes)

# for outcome in outcomes:
    # observed_outcome_token = dm.encode([outcome])[0]
cdf = np.zeros_like(outputs["surv"]["surv_CDF"][0])
lbls = np.zeros_like(batch["target_token"])

for _outcome_token in outcome_tokens:
    cdf += outputs["surv"]["surv_CDF"][_outcome_token - 1] 
    lbls += (batch["target_token"] == _outcome_token).long().numpy()

plt.close()
cdf_true = cdf[lbls==1,:]
cdf_false = cdf[lbls==0,:]
for i in range(cdf_true.shape[0]):
    plt.plot(np.linspace(1,1826,1826), cdf_true[i,:], c="r", label="outcome occurred next" if i == 0 else None, alpha=1)
for i in range(cdf_false.shape[0]):
    plt.plot(np.linspace(1,1826,1826), cdf_false[i,:], c="k", label="outcome did not occur next" if i == 0 else None, alpha=0.3)

plt.legend(loc=2)
plt.xlabel("days")
plt.ylabel(f"P(t>T) - outcomes={','.join(outcomes)}")
os.makedirs(save_path + "zero_shot", exist_ok=True)
plt.savefig(save_path + "zero_shot/cdf_outcomes.png")
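
The cell above sums the per-outcome CDFs, shifting each token by one because `surv_CDF` has no entry for the PAD token. Below is a minimal NumPy sketch of that indexing. It assumes `surv_CDF` is a list of arrays of shape `(batch, n_times)` indexed by `token - 1`, as the comment in the cell states. The helper name `combined_cdf` and the toy curves are hypothetical. Summing is meaningful only if, as in a competing-risk head, each entry is the cumulative incidence of that event being the *next* one. The events are then mutually exclusive, so their curves add up to the probability that any of the chosen outcomes occurs next by time t.

```python
import numpy as np


def combined_cdf(surv_cdf, outcome_tokens, pad_offset=1):
    """Sum cumulative incidence curves over a set of outcome tokens.

    Hypothetical helper: surv_cdf[token - pad_offset] is assumed to hold an
    array of shape (batch, n_times) for that token, as there is no curve for PAD.
    """
    total = np.zeros_like(np.asarray(surv_cdf[0], dtype=float))
    for token in outcome_tokens:
        total += np.asarray(surv_cdf[token - pad_offset], dtype=float)
    return total


# Toy curves for tokens 1 (UNK), 2 and 3: batch of 2, 4 time points
surv_cdf = [np.zeros((2, 4)),
            np.array([[0.0, 0.1, 0.2, 0.3], [0.0, 0.0, 0.1, 0.1]]),
            np.array([[0.0, 0.05, 0.1, 0.2], [0.0, 0.1, 0.2, 0.4]])]
cdf = combined_cdf(surv_cdf, outcome_tokens=[2, 3])
print(cdf)
```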

In [None]:
print(batch["target_token"].unique())
print(len(outputs["surv"]["surv_CDF"]))

In [None]:
dm.decode([2])

In [None]:
outputs["surv"]["surv_CDF"][outcome_tokens[0] - 1]