# Demo Notebook:
## Single Risk Survival Transformer For Causal Sequence Modelling 

This demo uses prompts that include event times (ages) and tabular (numeric) values.

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

!pwd

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/SurvStreamGPT/notebooks/SingleRisk


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
from pycox.evaluation import EvalSurv
from tqdm import tqdm
from hydra import compose, initialize
from omegaconf import OmegaConf
from CPRD.examples.modelling.SurvStreamGPT.experiment import run
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

# Seed torch's RNG so that sampling from the model is repeatable.
torch.manual_seed(1337)
# Select the 'medium' float32 matmul precision setting.
torch.set_float32_matmul_precision('medium')

logging.basicConfig(level=logging.INFO)

# Use the GPU when one is available.
# Set device = "cpu" by hand if more informative debugging statements are needed.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}.")

Using device: cuda.


# Demo Version of SurvStreamGPT

## Build configurations

In [3]:
# Load the Hydra configuration for the 11M-parameter single-risk model.
# Settings can be overridden at load time by passing `overrides=[...]` to `compose`.
with initialize(version_base=None, config_path="../../confs", job_name="testing_notebook"):
    cfg = compose(config_name="config_SingleRisk11M")

# Just load in the pretrained model: no training, no testing, no experiment logging.
cfg.experiment.train = False
cfg.experiment.test = False
cfg.experiment.log = False
cfg.experiment.run_id = "SR_11M"
print(OmegaConf.to_yaml(cfg))

# Folder that the figures below are written under: <ckpt_dir>/<run_id>/.
# Built from the configured checkpoint directory rather than a second hard-coded
# copy of the same absolute path, so it follows the config if that changes.
# The trailing "" keeps the trailing separator that later cells rely on when
# they concatenate sub-folder names onto `save_path`.
save_path = os.path.join(cfg.experiment.ckpt_dir, cfg.experiment.run_id, "")

is_decoder: true
data:
  batch_size: 64
  unk_freq_threshold: 0.0
  min_workers: 20
  global_diagnoses: false
  path_to_db: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/
  path_to_ds: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/
  meta_information_path: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information.pickle
experiment:
  project_name: SurvStreamGPT_${head.SurvLayer}
  run_id: SR_11M
  train: false
  test: false
  verbose: true
  seed: 1337
  log: false
  log_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/
  ckpt_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/
optim:
  num_epochs: 1
  learning_rate: 0.0001
  scheduler: CAWarmRestarts
  scheduler_periods: 5000
  scheduler_warmup: true
  lr_cosine_decay_period: 10000000.0
  val_check_interval: 1000
  early_stop: false
  ear

In [4]:
 # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed
# Set SLURM_NTASKS_PER_NODE for this kernel before the experiment is started
# (the local-HPC workaround described in the TODO above).
%env SLURM_NTASKS_PER_NODE=28      

# `run` returns the model and the data module (`dm`) used by every later cell.
# Training and testing were switched off in the config cell, so this is expected
# to only load the pretrained model.
model, dm = run(cfg)     
# Parameter count, reported in millions.
print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")


INFO:root:Running sr experiment on 72 CPUs and 1 GPUs


env: SLURM_NTASKS_PER_NODE=28


INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 3584.43M tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 184 tokens
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/split=train/ dataset, with 23,343,104 samples
INFO:root:Loaded /rds/projects/g/gokhal

Loaded model with 11.167512 M parameters


In [6]:
# Print a readable view of training sample 100 (static covariates, then one row
# per event); report_time=True also prints how long the retrieval took.
dm.train_set.view_sample(100, report_time=True)

# Optional: pull one batch from the training dataloader to inspect its raw contents.
# for batch in dm.train_dataloader():
#     break
# print(batch)

Time to retrieve sample index 100 was 0.055548906326293945 seconds

SEX                 | F
IMD                 | 2.0
ETHNICITY           | MISSING
birth_year          | 1954.0

Token                                                                      | Age               | Standardised value
Serum_sodium_24                                                            | 16911             | 0.28              
Serum_triglycerides_105                                                    | 16911             | -0.09             
TSH___thyroid_stim_hormone_72                                              | 16911             | -0.17             
Total_white_cell_count_18                                                  | 16911             | -0.14             
Diastolic_blood_pressure_5                                                 | 17252             | 0.22              
Systolic_blood_pressure_4                                                  | 17252             | 0.26              
Diastolic_

In [None]:
# polars is only used in this cell, to widen its table display limits before
# showing the tokenizer's event counts.
import polars as pl
pl.Config.set_tbl_rows(200)          # show up to 200 table rows
pl.Config.set_fmt_str_lengths(100)   # show up to 100 characters of each string
# NOTE(review): `_event_counts` is a private tokenizer attribute, presumably a
# polars DataFrame given the display settings above — confirm.
display(dm.tokenizer._event_counts)

### Real data

In [7]:
# Summary table of the measurement events held in the data module's meta
# information (counts, min/max/mean and approximate quantile bounds per event).
display(dm.meta_information["measurement_tables"])

Unnamed: 0,event,count,count_obs,digest,min,max,mean,approx_lqr,approx_uqr
0,25_Hydroxyvitamin_D2_level_92,782791,693470,"({'m': 0.0, 'c': 9.0}, {'m': 0.1, 'c': 112.0},...",0.000,6.860000e+02,3.908721,-4.563002,10.788817
1,25_Hydroxyvitamin_D3_level_90,809104,781118,"({'m': 0.1, 'c': 3.0}, {'m': 1.0, 'c': 314.0},...",0.000,9.518000e+02,47.148892,-36.406444,121.163127
2,AST___aspartate_transam_SGOT__46,1738489,1680613,"({'m': 0.0, 'c': 3901.0}, {'m': 0.770571428571...",0.000,1.533000e+04,26.619633,3.417134,41.771075
3,AST_serum_level_47,10837982,10485351,"({'m': 0.0, 'c': 53.0}, {'m': 1.8, 'c': 1.0}, ...",-5.000,2.070000e+04,27.251680,4.558863,41.966985
4,Albumin___creatinine_ratio_37,180911,78420,"({'m': -1.0, 'c': 1.0}, {'m': 0.0, 'c': 4213.0...",-1.000,1.282100e+04,10.672548,-4.329018,8.832383
...,...,...,...,...,...,...,...,...,...
103,Total_cholesterol_HDL_ratio_95,15760772,15489099,"({'m': 0.0, 'c': 21.0}, {'m': 0.001, 'c': 2.0}...",-3.100,3.720369e+09,964.704260,0.445266,6.918591
104,Total_white_cell_count_18,94827537,94179010,"({'m': -0.3, 'c': 1.0}, {'m': 0.0, 'c': 3.0}, ...",-14.700,3.720369e+09,250.189546,1.618844,12.034846
105,Urine_albumin_creatinine_ratio_35,15107807,10249206,"({'m': 0.0, 'c': 213.0}, {'m': 0.01, 'c': 3.0}...",-14.000,1.000000e+10,996.069857,-3.788562,7.911867
106,Urine_microalbumin_creatinine_ratio_36,201318,94009,"({'m': 0.0, 'c': 5691.0}, {'m': 0.001, 'c': 53...",-5.803,3.176800e+04,10.775029,-4.128243,8.251042


## Generation

### Sampling from the model

In [8]:
# Default context start
baseline_covariates = {"sex": "F", "deprivation": 1.0, "ethnicity": "WHITE", "year_of_birth": 1997-40}
prompt = ["O_E___height_1", "O_E___weight_2"]
values = [163, 80]
ages_in_years = [18.2, 18.2]

# define encoding functions (TODO: add this wrap to datamodule
encode_prompt = lambda prompt_list: torch.from_numpy(np.array(dm.encode(prompt_list)).reshape((1,-1))).to(device)
encode_value = lambda prompt_list, value_list: torch.tensor(np.array([dm.standardise(_cat, _val) for _cat, _val in zip(prompt_list, value_list) ]).reshape((1,-1)), dtype=torch.float32).to(device)
encode_age = lambda age_list: torch.tensor([365 * _age for _age in age_list], dtype=torch.int64).reshape((1,-1)).to(device)

# Convert for model
covariates = dm.train_set._encode_covariates(**baseline_covariates).reshape(1,-1).to(device)
tokens = encode_prompt(prompt)
values_scaled = encode_value(prompt, values)
ages_in_days = encode_age(ages_in_years)

print(values_scaled)

tensor([[-0.0446,  0.0652]], device='cuda:0')


In [9]:
# Generate: continue the prompt, requesting 40 new tokens (max_new_tokens=40).
new_tokens, new_ages, new_values = model.generate(tokens, ages_in_days, values_scaled, covariates, max_new_tokens=40)

# Report: list the prompt events first, then the sampled continuation.
# Values are printed as returned by the model; dm.unstandardise(_cat, _value)
# is not applied here.
_rule = "="*90
print(f"Baseline covariates: \n{baseline_covariates}\n" + _rule)
print("PROMPT:")
_event_names = dm.decode(new_tokens[0].tolist()).split(" ")
_last_prompt_idx = tokens.shape[-1] - 1
for _idx, (_cat, _age, _value) in enumerate(zip(_event_names, new_ages[0, :], new_values[0, :])):
    _name_col = f"{_cat}".ljust(50)
    _value_col = f"{_value:.02f}".ljust(15)
    print(_name_col + _value_col + f"at age {_age/365:.0f} ({int(_age)} days)")
    # Once the last prompt event has been printed, everything after it was generated.
    if _idx == _last_prompt_idx:
        print(_rule)
        print("GENERATION")



Baseline covariates: 
{'sex': 'F', 'deprivation': 1.0, 'ethnicity': 'WHITE', 'year_of_birth': 1957}
PROMPT:
O_E___height_1                                    -0.04          at age 18 (6643 days)
O_E___weight_2                                    0.07           at age 18 (6643 days)
GENERATION
UNK                                               nan            at age 21 (7585 days)
FIBROMYALGIA                                      nan            at age 22 (7981 days)
Brain_natriuretic_peptide_level_66                -0.33          at age 28 (10262 days)
Lymphocyte_count_20                               -0.07          at age 38 (13760 days)
OSTEOARTHRITIS                                    nan            at age 39 (14095 days)
Calcium_adjusted_level_41                         0.15           at age 39 (14338 days)
TSH_level_74                                      0.02           at age 43 (15767 days)
Combined_total_vitamin_D2_and_D3_level_93         0.41           at age 46 (16802 days)
Brain

# Prompt testing

## Diagnoses: How related conditions are impacted by each other


In [10]:
# Single-event prompts to compare: each is one diagnosis or smoking status
# recorded at age 20, with no associated value (NaN).
exp_prompts = [["DEPRESSION"], ["TYPE1DM"], ["TYPE2DIABETES"], ["Never_smoked_tobacco_85"], ["Ex_smoker_84"]]
exp_ages = [[20] for _ in range(len(exp_prompts))]
exp_values = [[np.nan] for _ in range(len(exp_prompts))]

# savefig does not create missing folders, so make sure the output folder exists
# before any figure is written.
os.makedirs(save_path + "diabetes/", exist_ok=True)

with torch.no_grad(): 
    model.eval()

    # One forward pass per prompt; keep the per-event curves returned under
    # outputs["surv"]["surv_CDF"] for each of them.
    _exp_survs = []
    for p_idx, (_exp_prompt, _exp_age, _exp_value) in enumerate(zip(exp_prompts, 
                                                                    exp_ages, 
                                                                    exp_values)):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)
        
        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_generation=True)
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    # One figure per event, overlaying the curve obtained from every prompt.
    for si, _ in enumerate(surv):
        plt.close()
        # Curve `si` is decoded as token `si + 1`.
        # NOTE(review): presumably because token 0 is reserved for padding — confirm.
        event_name = dm.decode([si + 1])
        for p_idx in range(len(exp_prompts)):
            plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{'->'.join(exp_prompts[p_idx]).lower()}")
        plt.xlabel("Time (years)")
        plt.ylabel(f"$P(T>t)$ ({event_name})")
        plt.legend()
        plt.savefig(save_path + f"diabetes/{event_name}.png")


## Values: How increasing BMI affects diagnosis risk

In [11]:
# Only events in this list get a figure.
events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF_V3", "ISCHAEMICSTROKE_V2",
                      "DEATH"
                     ]

# Prompt: a single BMI measurement at age 40, swept over the raw values below
# (encode_value standardises them before they reach the model).
_exp_prompt = ["Body_mass_index_3"]
_exp_age = [40]
_exp_values = [[18.], [21.], [24.], [30.], [40.]]

# savefig does not create missing folders, so make sure the output folder exists
# before any figure is written.
os.makedirs(save_path + "bmi/", exist_ok=True)

with torch.no_grad(): 
    model.eval()

    # One forward pass per BMI value; keep the per-event curves for each of them.
    _exp_survs = []
    for p_idx, _exp_value in enumerate(_exp_values):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_generation=True)
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    # One figure per event of interest, overlaying the curve for every BMI value.
    for si, _ in enumerate(surv):
        plt.close()
        # Curve `si` is decoded as token `si + 1`.
        # NOTE(review): presumably because token 0 is reserved for padding — confirm.
        event_name = dm.decode([si + 1])
        if event_name in events_of_interest:
            for p_idx in range(len(_exp_values)):
                plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{_exp_values[p_idx][0]:.2f}")
            plt.xlabel("t (years)")
            plt.ylabel(f"$P(T>t)$ ({event_name})")
            plt.legend()
            plt.savefig(save_path + f"bmi/{event_name}.png")


## Values: How increasing DBP affects diagnosis risk

In [12]:
# Only events in this list get a figure.
events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF_V3", "ISCHAEMICSTROKE_V2",
                      "DEATH"
                     ]


# Prompt: a single diastolic blood pressure measurement at age 40, swept over
# the raw values below (encode_value standardises them before they reach the model).
_exp_prompt = ["Diastolic_blood_pressure_5"]
_exp_age = [40]
_exp_values = [[60.], [70.], [80.], [90.], [100.], [110.]]

# savefig does not create missing folders, so make sure the output folder exists
# before any figure is written.
os.makedirs(save_path + "diastolic_blood_pressure/", exist_ok=True)

with torch.no_grad(): 
    model.eval()

    # One forward pass per DBP value; keep the per-event curves for each of them.
    _exp_survs = []
    for p_idx, _exp_value in enumerate(_exp_values):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_generation=True)
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    # One figure per event of interest, overlaying the curve for every DBP value.
    for si, _ in enumerate(surv):
        plt.close()
        # Curve `si` is decoded as token `si + 1`.
        # NOTE(review): presumably because token 0 is reserved for padding — confirm.
        event_name = dm.decode([si + 1])
        if event_name in events_of_interest:
            for p_idx in range(len(_exp_values)):
                plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{_exp_values[p_idx][0]:.2f}")
            plt.xlabel("t (years)")
            # Name the event on the axis; the label previously had empty parentheses.
            plt.ylabel(f"$P(T>t)$ ({event_name})")
            plt.legend()
            plt.savefig(save_path + f"diastolic_blood_pressure/{event_name}.png")


## Values: How varying diagnosis affects value of DBP

In [13]:
# Measurement whose predicted value distribution is reported for each prompt.
measurements_of_interest = "Diastolic_blood_pressure_5"

# Single-diagnosis prompts, each recorded at age 20 with no associated value.
_exp_prompts = [["DEPRESSION"], ["TYPE2DIABETES"], ["HF_V3"], ["HYPERTENSION"]]
_exp_age = [20]
_exp_value = [np.nan]

with torch.no_grad():
    model.eval()

    for p_idx, _exp_prompt in enumerate(_exp_prompts):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        outputs, _, _ = model(
            _tokens,
            values=_values_scaled,
            ages=_ages_in_days,
            covariates=covariates,
            is_generation=True,
        )
        val_dist = outputs["values_dist"]

        # Look up the predicted distribution for the measurement of interest,
        # keyed through the value layer by the measurement's token id.
        dist = val_dist[model.value_layer.token_key(dm.tokenizer._stoi[measurements_of_interest])]

        # Report the distribution's location and scale, in standardised units.
        _prompt_col = "->".join(_exp_prompt).ljust(30)
        _result_col = f"standardised {measurements_of_interest} ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})"
        print(_prompt_col + "leads to".ljust(20) + _result_col)


DEPRESSION                    leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.1, 0.2)
TYPE2DIABETES                 leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.1, 0.2)
HF_V3                         leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.1, 0.2)
HYPERTENSION                  leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.1, 0.2)


## Values: How increasing BMI affects value of DBP

In [14]:
# Measurement whose predicted value distribution is reported for each BMI value.
measurements_of_interest = "Diastolic_blood_pressure_5"


# Prompt: a single BMI measurement, swept over the raw values below
# (encode_value standardises them before they reach the model).
_exp_prompt = ["Body_mass_index_3"]
_exp_values = [[18.], [21.], [24.], [30.], [40.]]
# Age (in years) of the BMI measurement. This cell used to rely on `_exp_age`
# left over from an earlier cell; it is set explicitly here, to the value that
# was inherited on a top-to-bottom run, so the cell also works on its own.
_exp_age = [20]


with torch.no_grad(): 
    model.eval()

    for p_idx, _exp_value in enumerate(_exp_values):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)
        
        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_generation=True)
        val_dist = outputs["values_dist"]

        # Look up the predicted distribution for the measurement of interest,
        # keyed through the value layer by the measurement's token id.
        dist = val_dist[model.value_layer.token_key(dm.tokenizer._stoi[measurements_of_interest])]
        print(f"{'->'.join(_exp_prompt)} of {_exp_value[0]}".ljust(30) + "leads to".ljust(20) + f"standardised {measurements_of_interest} ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")


Body_mass_index_3 of 18.0     leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.1, 0.1)
Body_mass_index_3 of 21.0     leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.1, 0.1)
Body_mass_index_3 of 24.0     leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.1, 0.1)
Body_mass_index_3 of 30.0     leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.1, 0.2)
Body_mass_index_3 of 40.0     leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.0, 0.1)


## Baseline covariates: Impact of gender

In [15]:
# Only events in this list get a figure.
events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF_V3", "ISCHAEMICSTROKE_V2",
                      "POLYCYSTIC_OVARIAN_SYNDROME_PCOS_V2",
                      "DEATH"
                     ]

# The prompt is fixed (one DBP reading of 90 at age 20); only the `sex`
# baseline covariate changes between runs.
_genders = ["M", "F", "I"]
_exp_prompt = ["Diastolic_blood_pressure_5"]
_exp_age = [20]
_exp_value = [90.]

# savefig does not create missing folders, so make sure the output folder exists
# before any figure is written.
os.makedirs(save_path + "gender/", exist_ok=True)

with torch.no_grad(): 
    model.eval()

    # One forward pass per sex category; keep the per-event curves for each of them.
    _exp_survs = []
    for p_idx, _gender in enumerate(_genders):

        _baseline_covariate = {"sex": _gender, "deprivation": 4.0, "ethnicity": "WHITE", "year_of_birth": 1997}
        _covariates = dm.train_set._encode_covariates(**_baseline_covariate).reshape(1,-1).to(device)
        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=_covariates,
                              is_generation=True)
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    # One figure per event of interest, overlaying the curve for every sex category.
    for si, _ in enumerate(surv):
        plt.close()
        # Curve `si` is decoded as token `si + 1`.
        # NOTE(review): presumably because token 0 is reserved for padding — confirm.
        event_name = dm.decode([si + 1])
        if event_name in events_of_interest:
            for p_idx in range(len(_genders)):
                plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{_genders[p_idx]}")
            plt.xlabel("t (years)")
            # Name the event on the axis; the label previously had empty parentheses.
            plt.ylabel(f"$P(T>t)$ ({event_name})")
            plt.legend()
            plt.savefig(save_path + f"gender/{event_name}.png")


# Appendix: model architectures

In [16]:
# Show the module tree of the loaded model (layers and their sizes).
display(model)

SurvStreamGPTForCausalModelling(
  (transformer): TTETransformer(
    (wpe): TemporalPositionalEncoding()
    (wte): DataEmbeddingLayer(
      (static_proj): Linear(in_features=16, out_features=384, bias=True)
      (dynamic_embedding_layer): SplitDynamicEmbeddingLayer(
        (cat_event_embed_layer): Embedding(184, 384, padding_idx=0)
        (cat_event_proj): Linear(in_features=384, out_features=384, bias=True)
        (num_value_embed_layer): EmbeddingBag(184, 384, mode='sum', padding_idx=0)
        (num_value_proj): Linear(in_features=384, out_features=384, bias=True)
      )
    )
    (drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadedSelfAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=384, out_features=384, bias=False)
          (v_pr

In [17]:
# Export generation.ipynb to HTML; --no-input leaves the code cells out of the export.
# NOTE(review): assumes this notebook is saved as generation.ipynb in the working directory — confirm.
!jupyter nbconvert --to html --no-input generation.ipynb

[NbConvertApp] Converting notebook generation.ipynb to html
[NbConvertApp] Writing 588210 bytes to generation.html
