# Evaluation of a pre-trained SurvivEHR model


Environment setup for BlueBear (Birmingham HPC)

In [1]:
import os
from pathlib import Path
import sys
# Make the packages of the node-specific virtual environment importable.
# The BB_CPU environment variable (e.g. 'icelake' in the recorded output) selects which
# per-node-type virtual environment directory is looked up.
node_type = os.environ.get('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
python_dir_name = 'python{}.{}'.format(sys.version_info.major, sys.version_info.minor)
venv_site_pkgs = Path(venv_dir, 'lib', python_dir_name, 'site-packages')

if not venv_site_pkgs.exists():
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")
else:
    # Insert at the front of sys.path so this directory is searched before the existing entries.
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")


Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.


In [2]:
# Enable IPython's autoreload extension in mode 2, so edited modules are re-imported
# automatically before each cell executes (no kernel restart needed after source changes).
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import wandb
from hydra import compose, initialize
import polars as pl
# Raise the polars table display limit so long frames are printed in full.
pl.Config.set_tbl_rows(10000)
# import pandas as pd
# pd.options.display.max_rows = 10000
import logging
# INFO level so that INFO messages from the libraries used below appear in the cell output.
logging.basicConfig(level=logging.INFO)
import torch
# Fix the PyTorch random seed.
torch.manual_seed(1337)
# Select the 'medium' float32 matmul precision setting (see the PyTorch docs for the exact trade-off).
torch.set_float32_matmul_precision('medium')

from FastEHR.dataloader.foundational_loader import FoundationalDataModule
from CPRD.examples.modelling.SurvivEHR.run_experiment import run
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

# Use the GPU when PyTorch reports CUDA as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}.")

# TODO: this environment variable works around a local HPC environment issue; it should not be needed.
%env SLURM_NTASKS_PER_NODE=28   

INFO:numexpr.utils:Note: detected 72 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
INFO:numexpr.utils:Note: NumExpr detected 72 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.


Using device: cuda.
env: SLURM_NTASKS_PER_NODE=28


## Choosing configurations
The default configuration is for pre-training. Here we modify it as necessary.

Here we choose to load in the configuration for a small **pre-trained** 11.4M parameter model, named "CR_11M". We specify the `zero-shot` experiment type, which will lead to running a ```CausalExperiment```. 

We tell this experiment that no further training is needed. Additionally, we do choose to perform testing (true by default). As this is a supervised model, this tests the ability to predict the outcomes of interest. In this notebook, this is chosen to be those of the cohort study for predicting Cardiovascular Disease in a Type 2 Diabetes Mellitus population, and we add the folder containing this dataset to the configuration. 

```Note: As this is a supervised dataset, we need to tell the DataModule that the last event observed is a target and must be stripped. This is done by passing a list of targets to the configuration, overriding the null default. This lets the DataModule know that it should process batches as supervised.```

We set the number of workers to be appropriate for the number of CPUs available to reduce bottlenecking, and tell the experiment that we do not want to limit the number of testing batches. In addition, we specify where we want any checkpoints to be saved to avoid bloating the repository.

# Run small (11M) Competing-Risk model experiment

```

```

In [4]:
# Known pre-trained run identifiers, kept as a reference list
# (not read by any of the cells visible in this notebook).
pre_trained_model_ids = [
    "SurvivEHR-cr-small",
    "SurvivEHR-cr-small-v1",
    "SurvivEHR-cr",
    "SurvivEHR-cr-v1",
    "SurvivEHR-cr-v1-v1",
    "SurvivEHR-cr-384",
    "SurvivEHR-cr-384-v1",
    "crPreTrain_small_1337",
    "SurvivEHR-cr-small-192",
    "SurvivEHR-cr-small-192-v1",
]

# Run identifier evaluated below. Alternative tried: "SurvivEHR-cr-small-debug3_2_leadsmall"
pre_trained_model = "SurvivEHR-cr-small-debug3_2_exp1000-v1-v1"
print(pre_trained_model)

SurvivEHR-cr-small-debug3_2_exp1000-v1-v1


In [5]:
# Finish any W&B run that is still active (e.g. left over from a previous execution of this cell).
wandb.finish()

# Load the Hydra configuration "config_CompetingRisk11M" and override selected settings for evaluation.
with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
    cfg = compose(config_name="config_CompetingRisk11M", 
                  overrides=[# Experiment setup: evaluate the chosen run id, no training, testing enabled
                             "experiment.project_name='Evaluating pre-trained models'",
                             f"experiment.run_id='{pre_trained_model}'",
                             "experiment.train=False",
                             "experiment.test=True",
                             # Dataloader
                             "data.batch_size=128",
                             "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
                             "data.min_workers=3",
                             # "data.global_diagnoses=True",
                             # Optimiser
                             # NOTE(review): 0.035 presumably restricts testing to a fraction of the test batches,
                             # whereas the notebook text above says test batches are not limited - confirm which is intended.
                             "optim.limit_test_batches=0.035",
                             # Model
                             # "transformer.n_embd=192",  #384
                             # "transformer.block_size=512", 
                            ]
                 )     

# run() hands back the model and the data module (`dm`) that the later cells use for encode/decode.
model, dm = run(cfg)
# Parameter count is reported in millions.
print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")

# Finish the W&B run again once the experiment has returned.
wandb.finish()

INFO:root:Running cr on 72 CPUs and 1 GPUs
INFO:root:# Loading DataModule for dataset /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/. This will be loaded in causal form.
INFO:root:Creating unsupervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 toke

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.trainer.connectors.signal_connector:SLURM auto-requeueing enabled. Setting signal handlers.


Testing: |          | 0/? [00:00<?, ?it/s]

Loaded model with 11.211302 M parameters


0,1
Test:Cinter,▁
Test:Cinter+0,▁
Test:Cinter+1,▁
Test:Cinter+10,▁
Test:Cinter+13,▁
Test:Cinter+16,▁
Test:Cinter+19,▁
Test:Cinter+2,▁
Test:Cinter+3,▁
Test:Cinter+4,▁

0,1
Test:Cinter,0.99329
Test:Cinter+0,0.9838
Test:Cinter+1,0.92531
Test:Cinter+10,0.81695
Test:Cinter+13,0.75829
Test:Cinter+16,0.74239
Test:Cinter+19,0.72972
Test:Cinter+2,0.90371
Test:Cinter+3,0.88407
Test:Cinter+4,0.85536


In [5]:
# Make sure no W&B run is left active, then echo the run identifier currently selected.
wandb.finish()
print(pre_trained_model)

SurvivEHR-cr-small-debug3_2_exp1000


In [23]:
# Show the token ids that the data module assigns to two groups of event names.
for _event_names in (
    ['IHDINCLUDINGMI_OPTIMALV2', 'ISCHAEMICSTROKE_V2', 'MINFARCTION', 'STROKEUNSPECIFIED_V2', 'STROKE_HAEMRGIC'],
    ['HYPERTENSION'],
):
    display(dm.encode(_event_names))
# display(dm.decode([95, 175, 263,249]).split(" "))

[95, 41, 67, 65, 28]

[129]

In [24]:
# Names in the last five rows of the tokenizer's event-count table ("EVENT" column).
dm.tokenizer._event_counts["EVENT"][-5:].to_list()

['NSAIDS_oral_OPTIMAL_final',
 'Diastolic_blood_pressure_5',
 'Systolic_blood_pressure_4',
 'Statins',
 'Lipid_lowering_drugs_Optimal']

In [None]:
# Deliberate stop: raising here halts a top-to-bottom "Run All" at this cell.
# NOTE(review): presumably so the cells below are only executed by hand - confirm; a markdown
# note would state the intent more clearly than a bare exception.
raise NotImplementedError

In [None]:
# all_nondiagnosis = [cn for cn in dm.tokenizer._event_counts["EVENT"] if cn.upper() != cn]
# display([cn for cn in all_nondiagnosis if "Warf" in cn])

# Get the data from the callback

In [12]:
# Output directory for this run's metric figures; exist_ok=True makes the cell safe to re-run.
os.makedirs(f"figs/metrics/{pre_trained_model}/", exist_ok=True) 

In [8]:
# Client for the W&B public API.
# NOTE(review): no active cell visible in this notebook uses `api` (the only call is commented out).
api = wandb.Api()

In [None]:
# run = api.run("SurvivEHR/test_causal_eval")

In [6]:
# Summary metrics of the evaluated run, hard-coded as a plain dict (the variable name suggests they
# were copied from Weights & Biases - TODO confirm which run they belong to).
# Key families:
#   "Test:Cinter", "Test:Cinter+<k>", "Test:Cintra<token>"                 - SurvivEHR
#   "Test:base_Cinter", "Test:base_Cinter+<k>", "Test:base_Cintra<token>"  - the "base_" counterpart
# Later cells parse the numeric suffix of the "Cintra" keys as an event token and decode it via `dm`,
# and parse the "+<k>" suffix of the "Cinter+" keys as a step offset.
# NOTE(review): because these numbers are pasted in rather than fetched, they go stale whenever the
# model is re-evaluated.
raw_data_from_wandb = {
  "Test:Cinter": 0.9932898541314654,
  "Test:Cinter+0": 0.9837973135640176,
  "Test:Cinter+1": 0.9253109654737564,
  "Test:Cinter+10": 0.8169473112438893,
  "Test:Cinter+13": 0.7582933818261975,
  "Test:Cinter+16": 0.7423882901169279,
  "Test:Cinter+19": 0.7297150491066842,
  "Test:Cinter+2": 0.9037113784980492,
  "Test:Cinter+3": 0.884070654032631,
  "Test:Cinter+4": 0.8553641989115041,
  "Test:Cinter+7": 0.8054318305268877,
  "Test:Cintra100": 0.8998732572877058,
  "Test:Cintra101": 0.920152091254753,
  "Test:Cintra104": 0.841571609632446,
  "Test:Cintra106": 0.9158428390367556,
  "Test:Cintra107": 0.8403041825095057,
  "Test:Cintra11": 0.4410646387832699,
  "Test:Cintra110": 0.9192015209125476,
  "Test:Cintra112": 0.9411472970738968,
  "Test:Cintra113": 0.9976235741444868,
  "Test:Cintra114": 0.8846641318124209,
  "Test:Cintra115": 0.993789607097592,
  "Test:Cintra116": 0.996045627376426,
  "Test:Cintra118": 0.9824510090669784,
  "Test:Cintra119": 0.9804182509505708,
  "Test:Cintra120": 0.88212927756654,
  "Test:Cintra121": 0.9930291508238276,
  "Test:Cintra123": 0.9704995411039729,
  "Test:Cintra124": 0.958555133079848,
  "Test:Cintra126": 0.917979359043998,
  "Test:Cintra127": 0.9874049429657796,
  "Test:Cintra129": 0.945445528186477,
  "Test:Cintra130": 0.980513307984791,
  "Test:Cintra131": 0.9904065516232816,
  "Test:Cintra133": 0.9471301828716276,
  "Test:Cintra134": 0.946595229865192,
  "Test:Cintra135": 0.966624419095902,
  "Test:Cintra136": 0.9572056959665992,
  "Test:Cintra141": 0.6932826362484157,
  "Test:Cintra142": 0.9893708952644312,
  "Test:Cintra143": 0.9744210162461114,
  "Test:Cintra144": 0.9933460076045628,
  "Test:Cintra145": 0.9898605830164764,
  "Test:Cintra146": 0.9399335804014056,
  "Test:Cintra148": 0.9873475809623704,
  "Test:Cintra149": 0.9917950770462278,
  "Test:Cintra150": 0.9872243346007606,
  "Test:Cintra151": 0.9993387336749876,
  "Test:Cintra152": 0.9640245685873066,
  "Test:Cintra153": 0.9982088426609688,
  "Test:Cintra154": 0.94851711026616,
  "Test:Cintra155": 0.9809885931558934,
  "Test:Cintra156": 0.9961604413628572,
  "Test:Cintra157": 0.9767818137691126,
  "Test:Cintra158": 0.9987698501453814,
  "Test:Cintra159": 0.9863117870722432,
  "Test:Cintra160": 0.980445410103205,
  "Test:Cintra161": 0.9970865636264876,
  "Test:Cintra162": 0.9949302915082382,
  "Test:Cintra163": 0.9891261859919692,
  "Test:Cintra164": 0.998198919351611,
  "Test:Cintra165": 0.9938903513048004,
  "Test:Cintra166": 0.9972840847365564,
  "Test:Cintra167": 0.9837325075457648,
  "Test:Cintra168": 0.9834220532319392,
  "Test:Cintra169": 0.9906527249683142,
  "Test:Cintra170": 0.99527716629978,
  "Test:Cintra171": 0.9754286534184669,
  "Test:Cintra172": 0.9713144342301588,
  "Test:Cintra173": 0.9974651457541196,
  "Test:Cintra174": 0.9994568169473113,
  "Test:Cintra175": 0.9811998310097167,
  "Test:Cintra176": 0.9776557292400142,
  "Test:Cintra177": 0.9923003802281368,
  "Test:Cintra178": 0.9969420805922712,
  "Test:Cintra179": 0.9892357949981262,
  "Test:Cintra180": 0.9968993300742356,
  "Test:Cintra181": 0.993178259897115,
  "Test:Cintra182": 0.9868504435994933,
  "Test:Cintra183": 0.9978876214617663,
  "Test:Cintra184": 0.97596958174905,
  "Test:Cintra185": 0.8846641318124207,
  "Test:Cintra186": 0.998562917277926,
  "Test:Cintra187": 0.9991032355262216,
  "Test:Cintra188": 0.9935119802039956,
  "Test:Cintra189": 0.9761398123277638,
  "Test:Cintra190": 0.9955893536121676,
  "Test:Cintra191": 0.9745247148288974,
  "Test:Cintra192": 0.9769779676024796,
  "Test:Cintra193": 0.988022813688214,
  "Test:Cintra194": 0.9956191105967928,
  "Test:Cintra195": 0.9825855513307988,
  "Test:Cintra196": 0.9908808355552756,
  "Test:Cintra197": 0.9634980988593156,
  "Test:Cintra198": 0.953175609069145,
  "Test:Cintra199": 0.9898009393871618,
  "Test:Cintra200": 0.9969676020488326,
  "Test:Cintra201": 0.9948582786035264,
  "Test:Cintra202": 0.988178059674233,
  "Test:Cintra203": 0.9957048303055908,
  "Test:Cintra204": 0.9996510564566587,
  "Test:Cintra205": 0.9593244318684238,
  "Test:Cintra206": 0.9942366995000156,
  "Test:Cintra207": 0.984658850866073,
  "Test:Cintra208": 0.999550574622447,
  "Test:Cintra209": 0.9901680653362424,
  "Test:Cintra210": 0.9918282461189424,
  "Test:Cintra211": 0.9959497437592988,
  "Test:Cintra212": 0.9974936730527062,
  "Test:Cintra213": 0.9912036774303388,
  "Test:Cintra214": 0.9916349809885914,
  "Test:Cintra215": 0.9965503940045176,
  "Test:Cintra216": 0.9889100126742716,
  "Test:Cintra217": 0.9915124430108254,
  "Test:Cintra218": 0.993033776470262,
  "Test:Cintra219": 0.9948033092066932,
  "Test:Cintra220": 0.9998577886889608,
  "Test:Cintra221": 0.9936549429657796,
  "Test:Cintra222": 0.9930658395037016,
  "Test:Cintra223": 0.98715910580831,
  "Test:Cintra224": 0.9963147119040648,
  "Test:Cintra225": 0.9997998799279568,
  "Test:Cintra226": 0.9987823272475868,
  "Test:Cintra227": 0.9990383765142804,
  "Test:Cintra228": 0.98696360673547,
  "Test:Cintra229": 0.993861471522086,
  "Test:Cintra230": 0.995157293381,
  "Test:Cintra231": 0.9900491882431004,
  "Test:Cintra232": 0.99967132822066,
  "Test:Cintra233": 0.995883617126798,
  "Test:Cintra234": 0.9917667732113328,
  "Test:Cintra235": 0.9922622575296488,
  "Test:Cintra236": 0.9987498461598252,
  "Test:Cintra237": 0.9999740458609636,
  "Test:Cintra238": 0.9997224140936172,
  "Test:Cintra239": 0.9998972356386804,
  "Test:Cintra240": 0.9950464550416012,
  "Test:Cintra241": 0.9949798003802248,
  "Test:Cintra242": 0.9925780795988484,
  "Test:Cintra243": 0.999634980988593,
  "Test:Cintra244": 0.9931680608365002,
  "Test:Cintra245": 0.9982176806083638,
  "Test:Cintra246": 0.999464660745906,
  "Test:Cintra247": 0.9893553951831736,
  "Test:Cintra248": 0.9981339357119792,
  "Test:Cintra249": 0.996240682262465,
  "Test:Cintra25": 0.9239543726235742,
  "Test:Cintra250": 0.9952356956342476,
  "Test:Cintra251": 0.9939766824349132,
  "Test:Cintra252": 0.9972767444250328,
  "Test:Cintra253": 0.9988259622440128,
  "Test:Cintra254": 0.9957255950022336,
  "Test:Cintra255": 0.9856599674090168,
  "Test:Cintra256": 0.9928895663389206,
  "Test:Cintra257": 0.9968665238230888,
  "Test:Cintra258": 0.9999752294373364,
  "Test:Cintra259": 0.9969216730037996,
  "Test:Cintra260": 0.992981907396361,
  "Test:Cintra261": 0.9959181006885892,
  "Test:Cintra262": 0.9989362817421452,
  "Test:Cintra263": 0.9997711404033542,
  "Test:Cintra264": 0.9976003895019916,
  "Test:Cintra27": 0.8174904942965779,
  "Test:Cintra31": 0.8973384030418251,
  "Test:Cintra34": 0.4550063371356147,
  "Test:Cintra38": 0.4619771863117871,
  "Test:Cintra39": 0.9163498098859316,
  "Test:Cintra40": 0.7414448669201521,
  "Test:Cintra41": 0.5855513307984791,
  "Test:Cintra44": 0.8326996197718631,
  "Test:Cintra47": 0.870722433460076,
  "Test:Cintra48": 0.688212927756654,
  "Test:Cintra51": 0.8182509505703421,
  "Test:Cintra55": 1,
  "Test:Cintra57": 0.7243346007604563,
  "Test:Cintra58": 0.8935361216730038,
  "Test:Cintra60": 0.7110266159695817,
  "Test:Cintra63": 0.973384030418251,
  "Test:Cintra65": 0.7718631178707225,
  "Test:Cintra67": 0.688212927756654,
  "Test:Cintra70": 0.8225602027883396,
  "Test:Cintra71": 0.1520912547528517,
  "Test:Cintra75": 0.8153781157583437,
  "Test:Cintra78": 0.870722433460076,
  "Test:Cintra79": 0.8726235741444868,
  "Test:Cintra80": 0.9967409016838674,
  "Test:Cintra81": 0.7604562737642586,
  "Test:Cintra82": 0.8384030418250951,
  "Test:Cintra83": 0.8060836501901141,
  "Test:Cintra84": 0.8863117870722436,
  "Test:Cintra85": 0.7266159695817491,
  "Test:Cintra89": 0.7775665399239543,
  "Test:Cintra91": 0.9043092522179976,
  "Test:Cintra93": 0.832699619771863,
  "Test:Cintra94": 0.8647188312987794,
  "Test:Cintra95": 1,
  "Test:Cintra97": 0.8457360130363931,
  "Test:Cintra98": 0.960234474017744,
  "Test:Cintra99": 0.8479087452471483,
  "Test:base_Cinter": 0.8626828301681702,
  "Test:base_Cinter+0": 0.8549209927966827,
  "Test:base_Cinter+1": 0.8549209927966827,
  "Test:base_Cinter+10": 0.8498032617479898,
  "Test:base_Cinter+13": 0.848335813082052,
  "Test:base_Cinter+16": 0.8504674233111297,
  "Test:base_Cinter+19": 0.8518284716003348,
  "Test:base_Cinter+2": 0.8554057884602225,
  "Test:base_Cinter+3": 0.8562878467821434,
  "Test:base_Cinter+4": 0.8553855002076877,
  "Test:base_Cinter+7": 0.8501297603959204,
  "Test:base_Cintra100": 0.376425855513308,
  "Test:base_Cintra101": 0.38022813688212925,
  "Test:base_Cintra104": 0.3916349809885931,
  "Test:base_Cintra106": 0.3992395437262359,
  "Test:base_Cintra107": 0.4030418250950571,
  "Test:base_Cintra11": 0.03802281368821293,
  "Test:base_Cintra110": 0.41444866920152096,
  "Test:base_Cintra112": 0.4220532319391635,
  "Test:base_Cintra113": 0.4258555133079848,
  "Test:base_Cintra114": 0.4296577946768061,
  "Test:base_Cintra115": 0.4334600760456273,
  "Test:base_Cintra116": 0.4372623574144487,
  "Test:base_Cintra118": 0.4448669201520912,
  "Test:base_Cintra119": 0.44866920152091183,
  "Test:base_Cintra120": 0.45247148288973377,
  "Test:base_Cintra121": 0.4562737642585551,
  "Test:base_Cintra123": 0.4638783269961979,
  "Test:base_Cintra124": 0.467680608365019,
  "Test:base_Cintra126": 0.4752851711026616,
  "Test:base_Cintra127": 0.479087452471483,
  "Test:base_Cintra129": 0.48669201520912553,
  "Test:base_Cintra130": 0.49049429657794685,
  "Test:base_Cintra131": 0.49429657794676896,
  "Test:base_Cintra133": 0.5019011406844106,
  "Test:base_Cintra134": 0.5057034220532319,
  "Test:base_Cintra135": 0.5095057034220533,
  "Test:base_Cintra136": 0.5133079847908741,
  "Test:base_Cintra141": 0.532319391634981,
  "Test:base_Cintra142": 0.536121673003802,
  "Test:base_Cintra143": 0.539923954372624,
  "Test:base_Cintra144": 0.5437262357414447,
  "Test:base_Cintra145": 0.5475285171102662,
  "Test:base_Cintra146": 0.5513307984790876,
  "Test:base_Cintra148": 0.5589353612167296,
  "Test:base_Cintra149": 0.5627376425855521,
  "Test:base_Cintra150": 0.5665399239543727,
  "Test:base_Cintra151": 0.5703422053231935,
  "Test:base_Cintra152": 0.5741444866920152,
  "Test:base_Cintra153": 0.5779467680608357,
  "Test:base_Cintra154": 0.5817490494296573,
  "Test:base_Cintra155": 0.5855513307984791,
  "Test:base_Cintra156": 0.5893536121673008,
  "Test:base_Cintra157": 0.5931558935361212,
  "Test:base_Cintra158": 0.5969581749049425,
  "Test:base_Cintra159": 0.6007604562737643,
  "Test:base_Cintra160": 0.6045627376425854,
  "Test:base_Cintra161": 0.6083650190114078,
  "Test:base_Cintra162": 0.6121673003802284,
  "Test:base_Cintra163": 0.615969581749049,
  "Test:base_Cintra164": 0.6197718631178707,
  "Test:base_Cintra165": 0.6235741444866912,
  "Test:base_Cintra166": 0.6273764258555133,
  "Test:base_Cintra167": 0.6311787072243358,
  "Test:base_Cintra168": 0.6349809885931562,
  "Test:base_Cintra169": 0.6387832699619767,
  "Test:base_Cintra170": 0.642585551330798,
  "Test:base_Cintra171": 0.6463878326996174,
  "Test:base_Cintra172": 0.6501901140684417,
  "Test:base_Cintra173": 0.653992395437262,
  "Test:base_Cintra174": 0.6577946768060839,
  "Test:base_Cintra175": 0.6615969581749045,
  "Test:base_Cintra176": 0.6653992395437258,
  "Test:base_Cintra177": 0.6692015209125464,
  "Test:base_Cintra178": 0.6730038022813708,
  "Test:base_Cintra179": 0.6768060836501868,
  "Test:base_Cintra180": 0.6806083650190117,
  "Test:base_Cintra181": 0.6844106463878322,
  "Test:base_Cintra182": 0.6882129277566531,
  "Test:base_Cintra183": 0.6920152091254733,
  "Test:base_Cintra184": 0.6958174904942973,
  "Test:base_Cintra185": 0.6996197718631177,
  "Test:base_Cintra186": 0.7034220532319374,
  "Test:base_Cintra187": 0.7072243346007637,
  "Test:base_Cintra188": 0.7110266159695826,
  "Test:base_Cintra189": 0.7148288973384025,
  "Test:base_Cintra190": 0.718631178707225,
  "Test:base_Cintra191": 0.7224334600760467,
  "Test:base_Cintra192": 0.7262357414448671,
  "Test:base_Cintra193": 0.7300380228136875,
  "Test:base_Cintra194": 0.7338403041825081,
  "Test:base_Cintra195": 0.7376425855513306,
  "Test:base_Cintra196": 0.7414448669201497,
  "Test:base_Cintra197": 0.7452471482889734,
  "Test:base_Cintra198": 0.7490494296577949,
  "Test:base_Cintra199": 0.7528517110266153,
  "Test:base_Cintra200": 0.7566539923954361,
  "Test:base_Cintra201": 0.7604562737642592,
  "Test:base_Cintra202": 0.764258555133078,
  "Test:base_Cintra203": 0.7680608365019019,
  "Test:base_Cintra204": 0.7718631178707158,
  "Test:base_Cintra205": 0.7756653992395429,
  "Test:base_Cintra206": 0.7794676806083684,
  "Test:base_Cintra207": 0.7832699619771852,
  "Test:base_Cintra208": 0.7870722433460138,
  "Test:base_Cintra209": 0.790874524714831,
  "Test:base_Cintra210": 0.794676806083648,
  "Test:base_Cintra211": 0.7984790874524759,
  "Test:base_Cintra212": 0.8022813688212963,
  "Test:base_Cintra213": 0.8060836501901127,
  "Test:base_Cintra214": 0.8098859315589376,
  "Test:base_Cintra215": 0.8136882129277554,
  "Test:base_Cintra216": 0.817490494296578,
  "Test:base_Cintra217": 0.8212927756654119,
  "Test:base_Cintra218": 0.8250950570342191,
  "Test:base_Cintra219": 0.8288973384030506,
  "Test:base_Cintra220": 0.8326996197718595,
  "Test:base_Cintra221": 0.8365019011406819,
  "Test:base_Cintra222": 0.8403041825095002,
  "Test:base_Cintra223": 0.8441064638783322,
  "Test:base_Cintra224": 0.8479087452471428,
  "Test:base_Cintra225": 0.8517110266159725,
  "Test:base_Cintra226": 0.8555133079847834,
  "Test:base_Cintra227": 0.8593155893536092,
  "Test:base_Cintra228": 0.8631178707224335,
  "Test:base_Cintra229": 0.8669201520912636,
  "Test:base_Cintra230": 0.8707224334600675,
  "Test:base_Cintra231": 0.8745247148289041,
  "Test:base_Cintra232": 0.8783269961977118,
  "Test:base_Cintra233": 0.8821292775665415,
  "Test:base_Cintra234": 0.8859315589353532,
  "Test:base_Cintra235": 0.8897338403041877,
  "Test:base_Cintra236": 0.8935361216729947,
  "Test:base_Cintra237": 0.8973384030418267,
  "Test:base_Cintra238": 0.901140684410634,
  "Test:base_Cintra239": 0.9049429657794688,
  "Test:base_Cintra240": 0.908745247148278,
  "Test:base_Cintra241": 0.9125475285171164,
  "Test:base_Cintra242": 0.9163498098859406,
  "Test:base_Cintra243": 0.9201520912547496,
  "Test:base_Cintra244": 0.923954372623586,
  "Test:base_Cintra245": 0.9277566539923904,
  "Test:base_Cintra246": 0.9315589353612244,
  "Test:base_Cintra247": 0.9353612167300316,
  "Test:base_Cintra248": 0.9391634980988668,
  "Test:base_Cintra249": 0.9429657794676788,
  "Test:base_Cintra25": 0.09125475285171104,
  "Test:base_Cintra250": 0.9467680608365008,
  "Test:base_Cintra251": 0.950570342205312,
  "Test:base_Cintra252": 0.9543726235741544,
  "Test:base_Cintra253": 0.9581749049429704,
  "Test:base_Cintra254": 0.961977186311795,
  "Test:base_Cintra255": 0.9657794676806072,
  "Test:base_Cintra256": 0.9695817490494358,
  "Test:base_Cintra257": 0.9733840304182504,
  "Test:base_Cintra258": 0.9771863117870776,
  "Test:base_Cintra259": 0.9809885931558822,
  "Test:base_Cintra260": 0.984790874524719,
  "Test:base_Cintra261": 0.9885931558935528,
  "Test:base_Cintra262": 0.9923954372623228,
  "Test:base_Cintra263": 0.9961977186311592,
  "Test:base_Cintra264": 1,
  "Test:base_Cintra27": 0.0988593155893536,
  "Test:base_Cintra31": 0.1140684410646388,
  "Test:base_Cintra34": 0.12547528517110265,
  "Test:base_Cintra38": 0.14068441064638784,
  "Test:base_Cintra39": 0.1444866920152091,
  "Test:base_Cintra40": 0.1482889733840304,
  "Test:base_Cintra41": 0.1520912547528517,
  "Test:base_Cintra44": 0.1634980988593156,
  "Test:base_Cintra47": 0.17490494296577946,
  "Test:base_Cintra48": 0.17870722433460076,
  "Test:base_Cintra51": 0.19011406844106463,
  "Test:base_Cintra55": 0.20532319391634984,
  "Test:base_Cintra57": 0.2129277566539924,
  "Test:base_Cintra58": 0.21673003802281368,
  "Test:base_Cintra60": 0.22433460076045628,
  "Test:base_Cintra63": 0.23574144486692017,
  "Test:base_Cintra65": 0.2433460076045627,
  "Test:base_Cintra67": 0.2509505703422053,
  "Test:base_Cintra70": 0.2623574144486692,
  "Test:base_Cintra71": 0.2661596958174905,
  "Test:base_Cintra75": 0.2813688212927757,
  "Test:base_Cintra78": 0.29277566539923955,
  "Test:base_Cintra79": 0.2965779467680608,
  "Test:base_Cintra80": 0.3003802281368821,
  "Test:base_Cintra81": 0.3041825095057034,
  "Test:base_Cintra82": 0.30798479087452474,
  "Test:base_Cintra83": 0.311787072243346,
  "Test:base_Cintra84": 0.3155893536121673,
  "Test:base_Cintra85": 0.3193916349809886,
  "Test:base_Cintra89": 0.33460076045627374,
  "Test:base_Cintra91": 0.34220532319391633,
  "Test:base_Cintra93": 0.349809885931559,
  "Test:base_Cintra94": 0.3536121673003801,
  "Test:base_Cintra95": 0.3574144486692015,
  "Test:base_Cintra97": 0.3650190114068441,
  "Test:base_Cintra98": 0.3688212927756655,
  "Test:base_Cintra99": 0.3726235741444867,
  "_runtime": 1576.268884897232,
  "_step": 0,
  "_timestamp": 1742820962.8752978,
  "_wandb.runtime": 1580,
  "epoch": 0,
  "test_loss": -7.535277366638184,
  "test_loss_desurv": -4.07185697555542,
  "test_loss_values": -42.169490814208984,
  "trainer/global_step": 0
}

# Next event concordance

In [7]:
# Per-event ("Cintra") concordance of SurvivEHR, keyed by decoded event name.
# NOTE: despite its name, `Cinter_keys` holds the "Test:Cintra<token>" keys here; the variable
# name is kept so the notebook's existing names stay unchanged.
_cintra_prefix = "Test:Cintra"
# Match on the prefix (not a substring), because the token is sliced off right after it below.
Cinter_keys = [_key for _key in raw_data_from_wandb if _key.startswith(_cintra_prefix)]

decoded_cintra_diagnoses = {}
decoded_cintra_other = {}

for _key in Cinter_keys:
    _event = int(_key[len(_cintra_prefix):])                # token
    _event_name = dm.decode([_event]).split(" ")[0]         # string
    _event_cintra = raw_data_from_wandb[_key]               # concordance

    # Event names written entirely in upper case go to the "diagnoses" group, all others to "other".
    if _event_name.upper() == _event_name:
        _target = decoded_cintra_diagnoses
    else:
        _target = decoded_cintra_other
    # Assign in place: rebuilding the dict with {**d, k: v} copied every earlier entry on each
    # iteration, which is quadratic in the number of events.
    _target[_event_name] = _event_cintra


In [8]:
# display(decoded_cintra_diagnoses)
# display(decoded_cintra_other)

In [9]:
# Per-event ("Cintra") concordance of the "base_" metrics, plus each event's count from the
# tokenizer, keyed by decoded event name.
_base_cintra_prefix = "Test:base_Cintra"
# Match on the prefix (not a substring), because the token is sliced off right after it below.
BaseCinter_keys = [_key for _key in raw_data_from_wandb if _key.startswith(_base_cintra_prefix)]

base_decoded_cintra_diagnoses = {}
base_decoded_cintra_other = {}
base_prevalence_diagnoses = {}
base_prevalence_other = {}

# Loop-invariant: read the event-count table once instead of on every iteration.
event_counts = dm.tokenizer._event_counts

for _key in BaseCinter_keys:
    _event = int(_key[len(_base_cintra_prefix):])           # token
    _event_name = dm.decode([_event]).split(" ")[0]         # string
    _event_cintra = raw_data_from_wandb[_key]               # concordance

    # First "COUNT" entry whose "EVENT" equals the decoded name (raw count, used as prevalence).
    prevalence = event_counts.filter(pl.col("EVENT") == _event_name)["COUNT"][0]

    # Event names written entirely in upper case go to the "diagnoses" group, all others to "other".
    # Entries are assigned in place; rebuilding each dict with {**d, k: v} per iteration was quadratic.
    if _event_name.upper() == _event_name:
        base_decoded_cintra_diagnoses[_event_name] = _event_cintra
        base_prevalence_diagnoses[_event_name] = prevalence
    else:
        base_decoded_cintra_other[_event_name] = _event_cintra
        base_prevalence_other[_event_name] = prevalence



In [10]:
# display(base_decoded_cintra_diagnoses)
# display(base_decoded_cintra_other)

In [11]:
# Events that have both a SurvivEHR and a "base_" per-event concordance.
# Filter in dict insertion order rather than intersecting sets: iteration order of a set of
# strings can differ between kernel sessions (string hash randomisation), which made the order
# of these lists - and hence the order of tied events in the plots - irreproducible.
keys_included_diagnoses = [_name for _name in decoded_cintra_diagnoses if _name in base_decoded_cintra_diagnoses]
keys_included_other = [_name for _name in decoded_cintra_other if _name in base_decoded_cintra_other]

In [13]:
# One grouped bar chart per event group ("diagnoses", "other"): per-event concordance of the
# "base_" metrics and of SurvivEHR on the left axis, with the log event count on the right axis.
# Each figure is written to figs/metrics/<pre_trained_model>/ (that directory must already exist).
for dict_name, result_dict, result_dict_base, result_dict_prev, keys_to_include in zip(
        ["diagnoses", "other"],
        [decoded_cintra_diagnoses, decoded_cintra_other],
        [base_decoded_cintra_diagnoses, base_decoded_cintra_other],
        [base_prevalence_diagnoses, base_prevalence_other],
        [keys_included_diagnoses, keys_included_other],
        ):
    if not keys_to_include:
        # Nothing to draw: a zero-width figure and np.min/np.max of an empty list would both fail.
        logging.warning(f"No events to plot for '{dict_name}'; skipping figure.")
        continue

    # Order the events by increasing log-count. A stable sort keeps tied events in their incoming order.
    arg_sort = np.argsort([np.log(result_dict_prev[_key]) for _key in keys_to_include], kind="stable")
    sorted_keys = [keys_to_include[_i] for _i in arg_sort]

    Y_base = [result_dict_base[_key] for _key in sorted_keys]
    Y_survivEHR = [result_dict[_key] for _key in sorted_keys]
    Y_log_prevalence = [np.log(result_dict_prev[_key]) for _key in sorted_keys]

    # Figure width grows with the number of events so the rotated tick labels stay legible.
    fig, ax1 = plt.subplots(figsize=(len(sorted_keys)/4, 8))
    ax2 = ax1.twinx()

    X_axis = np.arange(len(sorted_keys))
    width = 0.25
    ax1.bar(X_axis - width, Y_base, width, label = f'Concordance by prevalence (Average over events: {raw_data_from_wandb["Test:base_Cinter"]:.3f})', color="mediumblue")
    ax1.bar(X_axis, Y_survivEHR, width, label = f'Concordance by SurvivEHR (Average over events: {raw_data_from_wandb["Test:Cinter"]:.3f})', color="firebrick")
    # Only x and y are passed positionally: a third positional non-string argument (previously the
    # bar `width`) is treated by Axes.plot as another data series, which drew a stray point and
    # duplicated the legend entry.
    ax2.plot(X_axis, Y_log_prevalence, label='Log-prevalence', color="darkseagreen", marker=".")

    ax1.set_xticks(X_axis, sorted_keys, rotation=90)
    ax1.set_xlabel("Events")
    ax1.set_ylabel("Self-supervised Concordance")
    ax2.set_ylabel("Log Prevalence")
    ax1.legend(loc="upper left")
    ax2.legend(loc="upper right")
    # Head-room above the data so the legends do not cover the bars/line.
    ax1.set_ylim(0, 1.2)
    ax2.set_ylim(np.min(Y_log_prevalence)*0.95, np.max(Y_log_prevalence)*1.1)

    fig.tight_layout()
    fig.savefig(f"figs/metrics/{pre_trained_model}/inter_causal_eval_{dict_name}.png", bbox_inches="tight")
    # Close this specific figure so figures do not accumulate across iterations.
    plt.close(fig)

# Future events

## SurvivEHR

In [12]:
# "+<k>" concordances for SurvivEHR itself (keys containing "base" are handled separately).
Cinter_keys = [_key for _key in raw_data_from_wandb if "+" in _key and "base" not in _key]
print(Cinter_keys)

_offset_start = len("Test:Cinter+")
x_survivEHR = [int(_key[_offset_start:]) + 1 for _key in Cinter_keys]    # steps ahead
y_survivEHR = [raw_data_from_wandb[_key] for _key in Cinter_keys]        # concordance

# Reorder both lists by increasing number of steps ahead.
arg_sort = np.argsort(x_survivEHR)
x_survivEHR = [x_survivEHR[_i] for _i in arg_sort]
y_survivEHR = [y_survivEHR[_i] for _i in arg_sort]

print(x_survivEHR)
print(y_survivEHR)

['Test:Cinter+0', 'Test:Cinter+1', 'Test:Cinter+10', 'Test:Cinter+13', 'Test:Cinter+16', 'Test:Cinter+19', 'Test:Cinter+2', 'Test:Cinter+3', 'Test:Cinter+4', 'Test:Cinter+7']
[1, 2, 3, 4, 5, 8, 11, 14, 17, 20]
[0.9817719777595192, 0.9291992280795602, 0.911726470329498, 0.8939083026535495, 0.8706372282753405, 0.8150159937232188, 0.8317987308064285, 0.7816167961646554, 0.7681751908287836, 0.7502532407475381]


In [13]:
# "+<k>" concordances for the "base_" metrics.
Cinter_keys = [_key for _key in raw_data_from_wandb if "+" in _key and "base" in _key]
print(Cinter_keys)

_offset_start = len("Test:base_Cinter+")
x_base = [int(_key[_offset_start:]) + 1 for _key in Cinter_keys]    # steps ahead
y_base = [raw_data_from_wandb[_key] for _key in Cinter_keys]        # concordance

# Reorder both lists by increasing number of steps ahead.
arg_sort = np.argsort(x_base)
x_base = [x_base[_i] for _i in arg_sort]
y_base = [y_base[_i] for _i in arg_sort]

['Test:base_Cinter+0', 'Test:base_Cinter+1', 'Test:base_Cinter+10', 'Test:base_Cinter+13', 'Test:base_Cinter+16', 'Test:base_Cinter+19', 'Test:base_Cinter+2', 'Test:base_Cinter+3', 'Test:base_Cinter+4', 'Test:base_Cinter+7']


In [14]:
# Concordance against the number of steps ahead, for SurvivEHR and for the "base_" metrics.
# Make sure the output directory exists so the cell does not depend on another cell having created it.
os.makedirs("figs", exist_ok=True)

fig, ax = plt.subplots()

ax.plot(x_survivEHR, y_survivEHR,
        label="SurvivEHR decay",
        color="firebrick")

ax.plot(x_base, y_base,
        label="Prevalence prognosis decay",
        color="mediumblue")

# Label the axes so the saved figure can be read on its own.
ax.set_xlabel("Steps ahead")
ax.set_ylabel("Concordance")
ax.legend()
fig.tight_layout()
fig.savefig(f"figs/{pre_trained_model}_inter_decay.png", bbox_inches="tight")
    

## Comparison of decay across different pre-trained runs

Note: These are not fair comparisons, as each run was trained with a different learning-rate scheduler and for a different length of time.

In [21]:
# Look-ahead horizons shared by every run below; each entry gets its own list copy.
_steps_ahead_grid = (1, 2, 3, 4, 5, 8, 11, 14, 17, 20)

# Maps run id -> (legend description, steps ahead, value plotted at each step).
# NOTE(review): the values look hand-copied from earlier evaluations of each run — confirm provenance.
all_survivEHR = {}

all_survivEHR["SurvivEHR-cr-small-v1"] = (
    "baseline: 11M, 384 latent",
    list(_steps_ahead_grid),
    [0.9845138238722124, 0.9366127214016844, 0.9158513102720236, 0.899843080451446, 0.868315386991299,
     0.8239121250528096, 0.8252010439713312, 0.7866451755111037, 0.7641013179336177, 0.7604122318951223],
)

all_survivEHR["SurvivEHR-cr-small-192"] = (
    "11M, 192 latent",
    list(_steps_ahead_grid),
    [0.9530637980778408, 0.8965454649674218, 0.8732735785430125, 0.8556641921660932, 0.8248607427761983,
     0.761096022692981, 0.7637418688146687, 0.7261393067724697, 0.7002773092426885, 0.7196441416973737],
)

all_survivEHR["SurvivEHR-cr-384-v1"] = (
    "129M, 384 latent",
    list(_steps_ahead_grid),
    [0.9707854863671966, 0.9347402411296004, 0.9142678409102016, 0.8998330215060252, 0.8802228115580834,
     0.8381073088297424, 0.856944131635776, 0.8165261475726017, 0.8077417879299013, 0.8017675470146958],
)

all_survivEHR["crPreTrain_small_1337"] = (
    "1M, 192 latent",
    list(_steps_ahead_grid),
    [0.9785524581080296, 0.9322563387278594, 0.9193310330671408, 0.8942603657432556, 0.8602421957376115,
     0.808292594604382, 0.8126680886580727, 0.7806386730589077, 0.7502644443809143, 0.7436763216230898],
)


In [None]:
# Overlay the decay curve of every recorded pre-trained run, plus the baseline
# curve computed earlier in the notebook, and save the figure under figs/.
plt.close()

# Each entry is (legend description, steps ahead, value at each step).
for run_id, (_description, _steps_ahead, _scores) in all_survivEHR.items():
    plt.plot(_steps_ahead, _scores, label=_description)

plt.plot(x_base, y_base,
         label="Prevalence prognosis decay",
         color="k")

# Label the axes so the saved figure can be read on its own.
plt.xlabel("Steps ahead")
plt.ylabel("Concordance (Cinter)")
plt.legend()
plt.title("Performance across pre-trained runs (note, different LR schedulers and run times)")
plt.tight_layout()

# savefig does not create missing directories, so make sure the target folder exists first.
os.makedirs("figs", exist_ok=True)
plt.savefig("figs/all_inter_decay.png", bbox_inches="tight")
    

# Deprecated

In [None]:
# Display the tokenizer's private `_event_counts` table (a later cell reads its "EVENT" column).
# NOTE(review): relies on `dm` already existing from an earlier cell — confirm on a fresh kernel.
dm.tokenizer._event_counts

In [None]:
# Per-run output directory; later cells append "zero_shot/<name>.png" to it when saving figures.
# The trailing "/" matters because those cells build paths by plain string concatenation.
save_path = f"/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/{cfg.experiment.run_id}/"

# Load Pre-Trained model

In [None]:
# Build the checkpoint path by string concatenation, so `cfg.experiment.log_dir` must end with "/".
ckpt_path = cfg.experiment.log_dir + f'checkpoints/{cfg.experiment.run_id}.ckpt'
# NOTE(review): `SurvivalExperiment` is not among the imports at the top of the notebook —
# confirm it is imported somewhere before this cell runs.
model = SurvivalExperiment.load_from_checkpoint(ckpt_path)

# Initialise fine-tuning data module

In [None]:
# Redirect the configuration to the fine-tuning (CVD) dataset before building the data module.
cfg.data.path_to_ds = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/"

# Build the data module from the (updated) configuration.
dm = FoundationalDataModule(
    path_to_db=cfg.data.path_to_db,
    path_to_ds=cfg.data.path_to_ds,
    load=True,
    tokenizer="tabular",
    batch_size=cfg.data.batch_size,
    max_seq_length=cfg.transformer.block_size,
    freq_threshold=cfg.data.unk_freq_threshold,
    min_workers=cfg.data.min_workers,
    overwrite_meta_information=cfg.data.meta_information_path,
)

vocab_size = dm.train_set.tokenizer.vocab_size
print(f"{vocab_size} vocab elements")

# Measurements to model with a univariate Normal distribution.
# Diagnoses are written entirely in upper case, so any event name that is not
# fully upper case is taken to be a measurement.
measurements_for_univariate_regression = []
for record in dm.tokenizer._event_counts["EVENT"]:
    if record != record.upper():
        measurements_for_univariate_regression.append(record)
cfg.head.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression)
# display(measurements_for_univariate_regression)

In [None]:
# Summarise the training split: number of event types and number of events per category,
# read from the data module's meta-information tables.
num_diagnoses = len(dm.train_set.meta_information["diagnosis_table"]["count"])
num_diagnosis_events = sum(dm.train_set.meta_information["diagnosis_table"]["count"])

# Rows of the measurement table with zero observed values are treated as medications.
is_medication = dm.train_set.meta_information["measurement_tables"]["count_obs"] == 0

num_medications = sum(is_medication)
num_medication_events = sum(dm.train_set.meta_information["measurement_tables"][is_medication]["count"])
num_measurement_test = sum(~is_medication)

# Measurement/test events are counted from "count_obs" (the value that was in effect before;
# an earlier sum over "count" was immediately overwritten and has been dropped as dead code).
# NOTE(review): confirm "count_obs" rather than "count" is the intended column here.
num_measurement_test_events = sum(dm.train_set.meta_information["measurement_tables"][~is_medication]["count_obs"])

print(f'{num_diagnosis_events:,} diagnoses of {num_diagnoses} types')
print(f'{num_medication_events:,} medications of {num_medications} types')
print(f'{num_measurement_test_events:,} measurements and tests of {num_measurement_test} types')
# Total number of events across the three categories (event counts, not type counts).
print(f'{num_diagnosis_events+num_medication_events+num_measurement_test_events:,}')

print(f'{num_measurement_test_events:,}')
# Bare expression in the middle of a cell: evaluated, but not displayed by Jupyter.
dm.train_set.meta_information.keys()

print(dm.train_set.tokenizer._event_counts)

In [None]:
# # import pickle as pkl
# # import pathlib

# pkl_file_to_amend = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_test.pickle"

# with open(pkl_file_to_amend, 'rb') as pickle_file:
#     content = pickle.load(pickle_file)
# display(content)

# # new_dictionary = {}
# # for key in content.keys():
# #     str_to_remove = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/split=val/"
# #     new_key = str(key)[len(str_to_remove):]
# #     new_dictionary[new_key] = content[key]
# # display(new_dictionary)


# # with open(pkl_file_to_amend, 'wb') as handle:
# #     pickle.dump(new_dictionary, handle, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:
# new_dictionary

In [None]:
import copy
import time   # imported here because no visible earlier cell imports it

# Time how long it takes to fetch the first training batch and convert it
# with `convert_batch_to_none_causal`.
start = time.time()   # starting time
for batch in dm.train_dataloader():
    # print(batch["tokens"][1,:])

    c_batch = convert_batch_to_none_causal(batch)
    # print(c_batch["tokens"][1,:])
    # print(c_batch["target_token"][1])

    # print(batch["tokens"][1,:])

    # Only the first batch is needed.
    break

print(f"batch loaded in {time.time()-start} seconds")

# for key in batch.keys():
#     print(f"{key}".ljust(20) + f"{batch[key].shape}")

# tokens = batch["tokens"][0].tolist()    
# sentence = dm.decode(tokens).split(" ")
# for token, value in zip(sentence, batch["values"][0].tolist()):
#     print(f"{token}:".ljust(40) + f"{value}")

In [None]:
# Compare the raw batch with its converted counterpart: available keys first.
display(batch.keys(), c_batch.keys())

print(batch["static_covariates"].shape)

# print(dm.train_set.static_1hot)
# print(dm.train_set.static_1hot["SEX"].categories_)
# print(dm.train_set.static_1hot["IMD"].categories_)
# print(dm.train_set.static_1hot["ETHNICITY"].categories_)

# Then one example row: original tokens, converted tokens, and the extracted target token.
_row = 1
print(batch["tokens"][_row, :])
print(c_batch["tokens"][_row, :])
print(c_batch["target_token"][_row])

## View an example sample

In [None]:
# Show one test-set sample; 11003 is a hard-coded sample index.
dm.test_set.view_sample(11003, max_dynamic_events=None, report_time=True)

# Custom wrapper: predicting the last token

To begin with, I will just loop over samples individually to test the zero-shot capacity of SurvivEHR. 

In [None]:


# Verifying on datamodule
def _print_row_head(_batch, _row=10, _n_events=5):
    """Print the first `_n_events` tokens, values, ages and attention-mask entries of row `_row`, stacked."""
    _fields = ("tokens", "values", "ages", "attention_mask")
    print(torch.stack([_batch[_field][_row, :_n_events] for _field in _fields]))


for _idx, batch in enumerate(dm.test_dataloader()):
    # Only inspect the first 11 batches (indices 0-10).
    if _idx > 10:
        break
    print(_idx)
    # Same row before and after the replacement, for a side-by-side check.
    _print_row_head(batch)
    batch = replace_last_non_pad_with_pad(batch)
    _print_row_head(batch)

In [None]:
# Outcomes of interest for the zero-shot evaluation below.
outcome_of_interest = ["COPD", "SUBSTANCEMISUSE"]
# NOTE(review): only the first encoded token is kept (presumably the one for "COPD"),
# even though two outcomes are listed — confirm this is intended.
outcome_token = dm.encode(outcome_of_interest)[0]
print(outcome_token)
# print(model(batch))

In [None]:
# Collect flattened hidden states and binary outcome labels over the first 10 test batches.
Hs, labels = [], []
mins,maxes=[],[]
for _idx, batch in enumerate(dm.test_dataloader()):

    batch = replace_last_non_pad_with_pad(batch)
    print(batch["tokens"].shape)
    outputs, _, hidden_states = model(batch, is_generation=True)
    print(outputs)

    hidden_states = hidden_states.cpu().detach().numpy()                           # (64, 128, 384) 
    # Flatten everything except the leading dimension so each sample becomes one feature vector.
    Hs.append( hidden_states.reshape(hidden_states.shape[0], -1) )
    # Label is 1 where the target token equals the outcome token, else 0.
    labels.append((batch["target_token"] == outcome_token).long().numpy())

    # Stop after 10 batches (indices 0-9).
    if _idx == 9:
        break



# Visualise hidden dimension labelled by target

In [None]:
import umap
from sklearn.preprocessing import StandardScaler

# Stack the per-batch arrays, standardise the features, then project with UMAP.
lbl = np.concatenate(labels, 0)
H = StandardScaler().fit_transform(np.concatenate(Hs, 0))

reducer = umap.UMAP()
H_proj = reducer.fit_transform(H)

# Scatter the first two projected coordinates, coloured by the outcome label.
plt.close()
_first_component, _second_component = H_proj[:, 0], H_proj[:, 1]
plt.scatter(_first_component, _second_component, c=lbl)
plt.savefig(save_path + "zero_shot/hidden_umap.png")

In [None]:
# The first two tokens in the vocab correspond to the PAD and UNK tokens. There is no CDF corresponding to the PAD token, so the indexing for surv_CDF begins as ["UNK", "ADDISONS_DISEASE", ...]
# Hence vocabulary token k is found at surv_CDF index k - 1 (same offset as in the loop below).
# print(dm.decode([0,1,2]))
print(outputs["surv"]["surv_CDF"][outcome_token - 1].shape)

outcomes = ["COPD", "SUBSTANCEMISUSE"]
outcome_tokens = dm.encode(outcomes)

# for outcome in outcomes:
    # observed_outcome_token = dm.encode([outcome])[0]
# Accumulate the CDFs of all outcomes of interest, and flag samples whose target is any of them.
cdf = np.zeros_like(outputs["surv"]["surv_CDF"][0])
lbls = np.zeros_like(batch["target_token"])

for _outcome_token in outcome_tokens:
    cdf += outputs["surv"]["surv_CDF"][_outcome_token - 1] 
    lbls += (batch["target_token"] == _outcome_token).long().numpy()

plt.close()
cdf_true = cdf[lbls==1,:]
cdf_false = cdf[lbls==0,:]

# Time axis 1..n, one point per CDF column, built once; this is the same
# grid as the former hard-coded 1..1826 when the CDF has 1826 columns.
_n_time_points = cdf.shape[1]
_time_grid = np.linspace(1, _n_time_points, _n_time_points)

for i in range(cdf_true.shape[0]):
    plt.plot(_time_grid, cdf_true[i,:], c="r", label="outcome occurred next" if i == 0 else None, alpha=1)
for i in range(cdf_false.shape[0]):
    plt.plot(_time_grid, cdf_false[i,:], c="k", label="outcome did not occur next" if i == 0 else None, alpha=0.3)

plt.legend(loc=2)
plt.xlabel("days")
plt.ylabel(f"P(t>T) - outcomes={','.join(outcomes)}")
plt.savefig(save_path + "zero_shot/cdf_outcomes.png")

In [None]:
# Distinct target tokens present in the current batch.
print(batch["target_token"].unique())
# Number of entries in surv_CDF.
print(len(outputs["surv"]["surv_CDF"]))

In [None]:
# Decode token 2, the first vocabulary entry after the two leading special tokens
# (a comment in the CDF-plotting cell states that the first two tokens are PAD and UNK).
dm.decode([2])

In [None]:
# NOTE(review): `observed_outcome_token` is only assigned in a commented-out line of the
# CDF-plotting cell, so this raises NameError on a fresh kernel unless it is defined elsewhere — confirm.
outputs["surv"]["surv_CDF"][observed_outcome_token - 1]