# Demo Notebook
## Survival Transformer for Causal Sequence Modelling

This version includes event times and excludes tabular values.

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

!pwd

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/SurvStreamGPT


In [12]:
import logging
import random
import sqlite3
from dataclasses import dataclass
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning
import torch
from pycox.evaluation import EvalSurv
from tqdm import tqdm

from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

# TODO:
# replace experiment boilerplate with pytorch lightning

# Seed every RNG library used in this notebook so runs are reproducible
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cuda.


## Build configurations

In [3]:
# Set config to be equivalent to the architecture of Karpathy's benchmark; however, they are not comparable tasks.
@dataclass
class DemoConfig:
    """Architecture and data settings for the SurvStreamGPT demo model."""
    block_size: int = 128        # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0               # forwarded to FoundationalDataModule
    # Survival head to use: "Competing-Risk" or "Single-Risk".
    # Annotated so it is a real dataclass field (settable via __init__, shown in repr/asdict)
    # rather than a plain class attribute.
    SurvLayer: str = "Competing-Risk"
    # Encoded tokens whose values are modelled with univariate regression heads;
    # filled in after the data module has been built.
    tokens_for_univariate_regression: Optional[list] = None

config = DemoConfig()

@dataclass
class OptConfig:
    """Optimisation settings used by the training loop."""
    batch_size: int = 64
    eval_interval: int = 1       # run validation every `eval_interval` epochs
    learning_rate: float = 3e-4
    epochs: int = 30

opt = OptConfig()

## Create data loader on a reduced cohort

In [84]:
# Location of the OPTIMAL master dataset (the archived Version2 copy is kept for reference)
# path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/archive/Version2/"
path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/"

# Build the data module; sequence length and tokenizer threshold come from the model config
dm_kwargs = dict(
    path_to_db=path_to_db,
    load=True,
    tokenizer="tabular",
    batch_size=opt.batch_size,
    max_seq_length=config.block_size,
    unk_freq_threshold=config.unk_freq_threshold,
    min_workers=20,
)
dm = FoundationalDataModule(**dm_kwargs)

vocab_size = dm.train_set.tokenizer.vocab_size
print(f"{vocab_size} vocab elements")

INFO:root:Loading Polars dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/polars/
INFO:root:Using tokenizer tabular
INFO:root:Tokenzier created based on 3584.43M tokens
INFO:root:Creating split=train/ dataset
INFO:root:	 Loading split=train/ hash map for parquet
INFO:root:	 Hash map created for split=train/ with 22,842,428 samples
INFO:root:Creating split=test/ dataset
INFO:root:	 Loading split=test/ hash map for parquet
INFO:root:	 Hash map created for split=test/ with 1,417,644 samples
INFO:root:Creating split=val/ dataset
INFO:root:	 Loading split=val/ hash map for parquet
INFO:root:	 Hash map created for split=val/ with 1,169,583 samples


184 vocab elements


In [85]:
# Allow up to 1000 rows so the measurement and diagnosis summary tables are shown in full
pd.set_option('display.max_rows', 1000)
display(dm.train_set.meta_measurement)
display(dm.train_set.meta_information["diagnosis_table"])

Unnamed: 0,event,count,count_obs,digest,min,max,mean,approx_lqr,approx_uqr
0,25_Hydroxyvitamin_D2_level_92,782791,693470,"({'m': 0.0, 'c': 9.0}, {'m': 0.1, 'c': 112.0},...",0.0,686.0,3.908721,-4.699694,10.870832
1,25_Hydroxyvitamin_D3_level_90,809104,781118,"({'m': 0.1, 'c': 3.0}, {'m': 1.0, 'c': 314.0},...",0.0,951.8,47.14889,-36.308194,121.286799
2,AST___aspartate_transam_SGOT__46,1738489,1680613,"({'m': 0.0, 'c': 3901.0}, {'m': 0.770571428571...",0.0,15330.0,26.61963,3.417134,41.771075
3,AST_serum_level_47,10837982,10485351,"({'m': 0.0, 'c': 53.0}, {'m': 1.8, 'c': 1.0}, ...",-5.0,20700.0,27.25168,4.558863,41.966985
4,Albumin___creatinine_ratio_37,180911,78420,"({'m': -1.0, 'c': 1.0}, {'m': 0.0, 'c': 4213.0...",-1.0,12821.0,10.67255,-4.329046,8.827713
5,Basophil_count_22,86869779,85642540,"({'m': 0.0, 'c': 37098.0}, {'m': 0.01, 'c': 28...",-0.1,111111.0,0.05008992,-0.093801,0.160919
6,Blood_calcium_level_38,415717,385464,"({'m': 0.0, 'c': 33.0}, {'m': 1.0, 'c': 1.0}, ...",0.0,440.0,2.35298,2.025402,2.62252
7,Blood_urea_28,785766,671861,"({'m': 0.0, 'c': 2746.0}, {'m': 0.09, 'c': 1.0...",0.0,1265.0,6.513018,0.270987,10.954279
8,Body_mass_index_3,99868822,97759312,"({'m': 0.0, 'c': 14.0}, {'m': 0.05, 'c': 1.0},...",-32680.0,2100000000.0,293.305,10.476686,43.320395
9,Brain_natriuretic_peptide_level_66,229202,159318,"({'m': 0.0, 'c': 120.0}, {'m': 0.1, 'c': 1.0},...",0.0,500142.0,416.8786,-245.175243,483.219601


Unnamed: 0,event,count
0,ADDISONS_DISEASE,6691
1,ADDISON_DISEASE,11794
2,AF,731332
3,ALCOHOLMISUSE_V2,1125212
4,ALLCANCER_NOHAEM_NOBCC,1496973
5,ALLERGICRHINITISCONJ,3291165
6,ALL_DEMENTIA,528602
7,ANXIETY,3560978
8,ANY_DEAFNESS_HEARING_LOSS_V2,2282766
9,AORTICANEURYSM_V2,101134


0.3508530089269363


In [7]:
# Extract the measurements, using the fact that the diagnoses are all upper case
# (so any event name containing a lower-case letter is treated as a measurement).
# This list is needed for automatically setting the model configuration below.
measurements_for_univariate_regression = [
    event for event in dm.tokenizer._event_counts["EVENT"]
    if event != event.upper()
]

display(measurements_for_univariate_regression)

['Plasma_N_terminal_pro_B_type_natriuretic_peptide_conc_70',
 'N_terminal_pro_brain_natriuretic_peptide_level_67',
 'Plasma_B_natriuretic_peptide_level_69',
 'Plasma_pro_brain_natriuretic_peptide_level_64',
 'Albumin___creatinine_ratio_37',
 'Urine_microalbumin_creatinine_ratio_36',
 'Plasma_ferritin_level_62',
 'Brain_natriuretic_peptide_level_66',
 'Serum_pro_brain_natriuretic_peptide_level_65',
 'Serum_vitamin_D2_level_89',
 'Total_25_hydroxyvitamin_D_level_91',
 'Serum_N_terminal_pro_B_type_natriuretic_peptide_conc_68',
 'Blood_calcium_level_38',
 'INR___international_normalised_ratio_81',
 'Combined_total_vitamin_D2_and_D3_level_93',
 'TSH_level_74',
 'Serum_T4_level_78',
 'Plasma_cholesterol_HDL_ratio_96',
 'Plasma_free_T4_level_77',
 '25_Hydroxyvitamin_D2_level_92',
 'Blood_urea_28',
 '25_Hydroxyvitamin_D3_level_90',
 'Plasma_corrected_calcium_level_43',
 'Serum_25_Hydroxy_vitamin_D3_level_88',
 'Plasma_calcium_level_40',
 'Free_T4_level_76',
 'Plasma_LDL_cholesterol_level_104',

## Create models and train

In [99]:
models, m_names = [], []

# Development model(s): one per survival-head variant ("Single-Risk" currently disabled)
surv_layers = ["Competing-Risk"]  # , "Single-Risk"]
for surv_layer in surv_layers:
    # Fresh configuration per variant, choosing the survival head and the measurement
    # tokens whose values are modelled with a univariate (Normal) regression head.
    # NOTE(review): the logged output reports a "Single-Risk" head even though
    # "Competing-Risk" is requested here — confirm which attribute the model reads.
    config = DemoConfig()
    config.SurvLayer = surv_layer
    config.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression)

    model_for_layer = SurvStreamGPTForCausalModelling(config, vocab_size).to(device)
    models.append(model_for_layer)
    m_names.append(f"SurvStreamGPTForCausalModelling: {surv_layer}")

INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using Single-Risk DeSurvival head. This module predicts a separate survival curve for each possible future event
INFO:root:Internally scaling time in survival head by 1825 days
INFO:root:In generation forwarding DeSurv on the grid between [0.0, 1825.0], with delta=1/300
INFO:root:ModuleDict(
  (Token 15): Linear(in_features=384, out_features=2, bias=True)
  (Token 17): Linear(in_features=384, out_features=2, bias=True)
  (Token 24): Linear(in_features=384, out_features=2, bias=True)
  (Token 26): Linear(in_features=384, out_features=2, bias=True)
  (Token 41): Linear(in_features=384, out_features=2, bias=True)
  (Token 46): Linear(in_features=384, out_features=2, bias=True)
  (Token 49): Linear(in_features=384, out_features=2, bias=True)
  (Token 50): Linear(in_features=384, out_features=2, bias=True)
  (Token 52): Linear(in_features=384, out_features=2, bias

In [100]:
# One independent history list per model, filled by the training loop below.
# The *_clf histories are not appended to by the training loop in this notebook.
(loss_curves_train, loss_curves_train_clf,
 loss_curves_train_surv, loss_curves_train_values) = ([[] for _ in models] for _ in range(4))

(loss_curves_val, loss_curves_val_clf,
 loss_curves_val_surv, loss_curves_val_values) = ([[] for _ in models] for _ in range(4))

In [101]:
# Train each model with AdamW on a capped number of batches per epoch, validating on a
# capped number of validation batches, with early stopping and best-model checkpointing.
MAX_TRAIN_BATCHES = 1001   # batches per training epoch (same as the original `if i > 1000: break`)
MAX_VAL_BATCHES = 101      # batches per validation pass (same as the original `if j > 100: break`)
PATIENCE = 5               # stop after this many validations without improvement
# NOTE: every model is checkpointed to the same file, so with several models only the
# last model's best weights are kept on disk.
CHECKPOINT_PATH = path_to_db + "polars/CR.pt"


def _forward_batch(model, batch):
    """Move one data-loader batch to `device` and return the model's raw output tuple."""
    return model(batch['tokens'].to(device),
                 ages=batch['ages'].to(device),
                 values=batch['values'].to(device),
                 attention_mask=batch['attention_mask'].to(device))


for m_idx, (model, m_name) in enumerate(zip(models, m_names)):

    print(f"Training model `{m_name}`, with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
    model = model.to(device)

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

    best_val, epochs_since_best = np.inf, 0
    for epoch in range(opt.epochs):

        # ---- training ----
        epoch_loss, epoch_surv_loss, epoch_values_loss = 0., 0., 0.
        n_train_batches = 0
        model.train()
        train_loader = dm.train_dataloader()
        for i, batch in tqdm(enumerate(train_loader), desc=f"Training epoch {epoch}", total=len(train_loader)):
            if i >= MAX_TRAIN_BATCHES:
                break

            # evaluate the loss
            _, (losses_desurv, loss_values), loss = _forward_batch(model, batch)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            # record
            epoch_loss += loss.item()
            epoch_surv_loss += torch.sum(losses_desurv).item()
            epoch_values_loss += loss_values.item()
            n_train_batches += 1

        # Average over the batches actually processed (guards against an empty loader)
        n_train_batches = max(n_train_batches, 1)
        epoch_loss /= n_train_batches
        epoch_surv_loss /= n_train_batches
        epoch_values_loss /= n_train_batches
        loss_curves_train[m_idx].append(epoch_loss)
        loss_curves_train_surv[m_idx].append(epoch_surv_loss)
        loss_curves_train_values[m_idx].append(epoch_values_loss)

        # ---- validation, early stopping and checkpointing ----
        if epoch % opt.eval_interval == 0 or epoch == opt.epochs - 1:
            model.eval()
            val_loss, val_surv_loss, val_values_loss = 0., 0., 0.
            n_val_batches = 0
            val_loader = dm.val_dataloader()
            with torch.no_grad():
                for j, batch in tqdm(enumerate(val_loader), desc=f"Validation epoch {epoch}", total=len(val_loader)):
                    if j >= MAX_VAL_BATCHES:
                        break
                    _, (losses_desurv, loss_values), loss = _forward_batch(model, batch)
                    # record
                    val_loss += loss.item()
                    val_surv_loss += torch.sum(losses_desurv).item()
                    val_values_loss += loss_values.item()
                    n_val_batches += 1

            n_val_batches = max(n_val_batches, 1)
            val_loss /= n_val_batches
            val_surv_loss /= n_val_batches
            val_values_loss /= n_val_batches
            loss_curves_val[m_idx].append(val_loss)
            loss_curves_val_surv[m_idx].append(val_surv_loss)
            loss_curves_val_values[m_idx].append(val_values_loss)

            print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}: ({epoch_surv_loss:.2f}, {epoch_values_loss:.2f}). Val loss {val_loss:.2f}: ({val_surv_loss:.2f}, {val_values_loss:.2f})")
            # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning

            # Only compare against the best loss when a fresh validation loss exists;
            # otherwise a stale (or undefined) val_loss would drive early stopping.
            if val_loss >= best_val:
                epochs_since_best += 1
                if epochs_since_best >= PATIENCE:
                    break
            else:
                best_val = val_loss
                epochs_since_best = 0

                # Save best seen model
                torch.save(model.state_dict(), CHECKPOINT_PATH)
            


Training model `SurvStreamGPTForCausalModelling: Competing-Risk`, with 11.006427 M parameters


Training epoch 0:   0%|          | 1001/356913 [06:00<35:37:56,  2.77it/s]
Validation epoch 0:   1%|          | 101/18275 [00:26<1:18:56,  3.84it/s]


Epoch 0:	Train loss -3.65: (0.15, -7.45). Val loss -6.51: (-1.29, -11.72)


Training epoch 1:   0%|          | 1001/356913 [05:53<34:52:26,  2.83it/s]
Validation epoch 1:   1%|          | 101/18275 [00:26<1:18:41,  3.85it/s]


Epoch 1:	Train loss -6.64: (-1.79, -11.48). Val loss -8.20: (-1.96, -14.44)


Training epoch 2:   0%|          | 1001/356913 [05:54<34:58:56,  2.83it/s]
Validation epoch 2:   1%|          | 101/18275 [00:26<1:18:48,  3.84it/s]


Epoch 2:	Train loss -7.53: (-2.25, -12.82). Val loss -8.85: (-2.30, -15.39)


Training epoch 3:   0%|          | 1001/356913 [05:51<34:43:41,  2.85it/s]
Validation epoch 3:   1%|          | 101/18275 [00:26<1:18:34,  3.86it/s]


Epoch 3:	Train loss -8.20: (-2.55, -13.84). Val loss -9.07: (-2.55, -15.59)


Training epoch 4:   0%|          | 1001/356913 [05:51<34:43:58,  2.85it/s]
Validation epoch 4:   1%|          | 101/18275 [00:26<1:18:40,  3.85it/s]


Epoch 4:	Train loss -8.37: (-2.77, -13.98). Val loss -9.75: (-2.77, -16.74)


Training epoch 5:   0%|          | 1001/356913 [05:55<35:06:11,  2.82it/s]
Validation epoch 5:   1%|          | 101/18275 [00:26<1:18:45,  3.85it/s]


Epoch 5:	Train loss -8.76: (-2.96, -14.57). Val loss -8.90: (-2.90, -14.90)


Training epoch 6:   0%|          | 1001/356913 [05:52<34:51:41,  2.84it/s]
Validation epoch 6:   1%|          | 101/18275 [00:26<1:18:56,  3.84it/s]


Epoch 6:	Train loss -9.06: (-3.13, -14.99). Val loss -10.03: (-3.10, -16.96)


Training epoch 7:   0%|          | 1001/356913 [05:55<35:05:24,  2.82it/s]
Validation epoch 7:   1%|          | 101/18275 [00:26<1:18:56,  3.84it/s]


Epoch 7:	Train loss -9.40: (-3.29, -15.51). Val loss -8.93: (-3.24, -14.61)


Training epoch 8:   0%|          | 1001/356913 [05:52<34:47:46,  2.84it/s]
Validation epoch 8:   1%|          | 101/18275 [00:26<1:18:38,  3.85it/s]


Epoch 8:	Train loss -9.74: (-3.42, -16.07). Val loss -10.61: (-3.36, -17.86)


Training epoch 9:   0%|          | 1001/356913 [05:54<35:02:38,  2.82it/s]
Validation epoch 9:   1%|          | 101/18275 [00:26<1:18:44,  3.85it/s]


Epoch 9:	Train loss -9.88: (-3.54, -16.21). Val loss -10.84: (-3.46, -18.22)


Training epoch 10:   0%|          | 1001/356913 [05:53<34:56:47,  2.83it/s]
Validation epoch 10:   1%|          | 101/18275 [00:26<1:18:37,  3.85it/s]


Epoch 10:	Train loss -9.90: (-3.63, -16.17). Val loss -10.56: (-3.53, -17.58)


Training epoch 11:   0%|          | 42/356913 [00:17<42:06:52,  2.35it/s]
Exception ignored in: <function WeakValueDictionary.__init__.<locals>.remove at 0x7f8471982e60>
Traceback (most recent call last):
  File "/rds/bear-apps/2022a/EL8-ice/software/Python/3.10.4-GCCcore-11.3.0/lib/python3.10/weakref.py", line 106, in remove
    def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref):
KeyboardInterrupt: 


In [102]:
# Sample a short synthetic trajectory from each model, prompted with a single diagnosis.
N_NEW_TOKENS = 10

for m_idx, (model, m_name) in enumerate(zip(models, m_names)):

    print(f"Generating from model `{m_name}`, with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
    model = model.to(device)
    # The training cell may have been interrupted while the model was in train mode
    model.eval()

    # Test trained model with a prompt
    # ----------------
    # set context: diagnosis of depression at 20 years old
    tokens = torch.from_numpy(np.array(dm.encode(["DEPRESSION"])).reshape((1, -1))).to(device)
    ages = torch.tensor([[20*365]], device=device)
    values = torch.tensor([[torch.nan]], device=device)

    # generate: sample the next N_NEW_TOKENS tokens
    with torch.no_grad():
        new_tokens, new_ages, new_values = model.generate(tokens, ages, values, max_new_tokens=N_NEW_TOKENS)
    generated = dm.decode(new_tokens[0].tolist())

    # report:
    for _cat, _age, _value in zip(generated.split(" "), new_ages[0, :], new_values[0, :]):
        try:
            _value = dm.unstandardise(_cat, _value)
        except Exception:
            # Presumably raised for tokens without standardisation statistics
            # (e.g. diagnoses); keep the raw value in that case.
            pass
        print(f"{_cat}".ljust(50) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({_age:.1f} days)")

Training model `SurvStreamGPTForCausalModelling: Competing-Risk`, with 11.006427 M parameters
DEPRESSION                                        nan            at age 20 (7300.0 days)
ANXIETY                                           nan            at age 25 (9125.0 days)
Basophil_count_22                                 0.02           at age 30 (10950.0 days)
Eosinophil_count_21                               0.07           at age 35 (12775.0 days)
Erythrocyte_sedimentation_rate_61                 4.09           at age 40 (14600.0 days)
GFR_calculated_abbreviated_MDRD_34                92.08          at age 45 (16425.0 days)
Haematocrit_15                                    0.46           at age 50 (18250.0 days)
Haemoglobin_estimation_9                          154.31         at age 55 (20075.0 days)
Lymphocyte_count_20                               1.71           at age 60 (21900.0 days)
Mean_corpusc_haemoglobin_MCH__13                  29.97          at age 65 (23725.0 days)
Mean_cor

## Comparing output to real data

In [103]:
# Print the first few events of one real training sequence, for comparison with generated output.
N_EVENTS_TO_SHOW = 10

# next(iter(...)) fails loudly on an empty loader instead of silently reusing a stale `batch`
batch = next(iter(dm.train_dataloader()))

conditions = batch["tokens"].numpy().tolist()
for idx, (token, _age, _value) in enumerate(zip(conditions[0], batch["ages"][0, :], batch["values"][0, :])):
    # token 0 is the padding index (see padding_idx=0 in the model's embedding layer)
    if token == 0 or idx >= N_EVENTS_TO_SHOW:
        break
    _cat = dm.decode([token])
    try:
        _value = dm.unstandardise(_cat, _value)
    except Exception:
        # Presumably raised for tokens without standardisation statistics; keep raw value
        pass

    print(f"{_cat}".ljust(50) + f"{_value:.05f}".ljust(15) + f"at age {_age/365:.0f} ({_age:.1f} days)")

Red_blood_cell__RBC__count_10                     5.27000        at age 53 (19414.0 days)
Red_blood_cell_distribution_width_17              13.20000       at age 53 (19414.0 days)
Serum_HDL_cholesterol_level_100                   0.90000        at age 53 (19414.0 days)
Serum_LDL_cholesterol_level_102                   2.20000        at age 53 (19414.0 days)
Serum_TSH_level_71                                1.10000        at age 53 (19414.0 days)
Serum_alanine_aminotransferase_level_45           22.00000       at age 53 (19414.0 days)
Serum_albumin_51                                  39.00000       at age 53 (19414.0 days)
Serum_alkaline_phosphatase_50                     116.00001      at age 53 (19414.0 days)
Serum_cholesterol_HDL_ratio_94                    3.80000        at age 53 (19414.0 days)
Serum_creatinine_31                               78.00000       at age 53 (19414.0 days)


In [105]:
cols = ["k", "r", "b", "y"]
LOSS_FIG_DIR = Path("figs/competing_risk")
LOSS_FIG_DIR.mkdir(parents=True, exist_ok=True)


def plot_loss_curves(train_curves, val_curves, fname, ylabel="loss"):
    """Plot per-model training (dashed) and validation (solid) curves and save to LOSS_FIG_DIR / fname.

    Training losses are recorded once per epoch, so they are plotted at epochs 0, 1, 2, ...
    Validation losses are recorded every `opt.eval_interval` epochs, so they are plotted at
    multiples of it (the extra final-epoch evaluation is approximate when the last epoch is
    not a multiple of `eval_interval`).
    """
    fig, ax = plt.subplots()
    for m_idx, (m_name, train, val) in enumerate(zip(m_names, train_curves, val_curves)):
        colour = cols[m_idx % len(cols)]
        ax.plot(np.arange(len(train)), train, label=f"{m_name}-train", c=colour, linestyle='dashed')
        ax.plot(np.arange(len(val)) * opt.eval_interval, val, label=f"{m_name}-val", c=colour)
    ax.set_xlabel("epoch")
    ax.set_ylabel(ylabel)
    ax.legend()
    fig.savefig(LOSS_FIG_DIR / fname)
    return fig


plot_loss_curves(loss_curves_train, loss_curves_val, "loss.png", ylabel="total loss")
plot_loss_curves(loss_curves_train_surv, loss_curves_val_surv, "loss_desurv.png", ylabel="DeSurv loss")
# Note: despite its name, loss_val.png holds the value-regression loss
plot_loss_curves(loss_curves_train_values, loss_curves_val_values, "loss_val.png", ylabel="value loss");

# Prompt testing

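## Diabetes: How related conditions are impacted by each other
Survival curves for every possible next event, comparing a prompt of depression at age 20 alone against depression followed by a type I or a type II diabetes diagnosis at age 21.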

In [106]:
base_prompt = ["DEPRESSION"]
ages_in_years = [20]
base_values = [torch.tensor([torch.nan])]


def to_days(a_list):
    """Convert a list of ages in years to a (1, len(a_list)) float tensor of days on `device`."""
    return torch.FloatTensor([365 * _a for _a in a_list]).reshape((1, -1)).to(device)


# Create a set of prompts: a control prompt, then the control followed by each diabetes type
desc = ["Depression", "Depression -> Type 1", "Depression -> Type 2"]
prompts = [base_prompt, base_prompt + ["TYPE1DM"], base_prompt + ["TYPE2DIABETES"]]
ages = [ages_in_years, ages_in_years + [21], ages_in_years + [21]]
values = [base_values,
          base_values + [torch.tensor([torch.nan])],
          base_values + [torch.tensor([torch.nan])]]

DIABETES_FIG_DIR = Path("figs/competing_risk/diabetes")
DIABETES_FIG_DIR.mkdir(parents=True, exist_ok=True)

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad():
        model.eval()

        prompt_survs = []
        for p_idx, (prompt, age, value) in enumerate(zip(prompts, ages, values)):
            print(f"\n{desc[p_idx]}: \t ({','.join(prompt)}): ")
            encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1, -1))).to(device)
            (surv, val_dist), _, _ = model(encoded_prompt,
                                           values=torch.tensor(value).to(device),
                                           ages=to_days(age),
                                           is_generation=True)
            prompt_survs.append(surv)

    # One figure per possible next event, overlaying its survival curve under each prompt
    for si, _ in enumerate(surv):
        # assumes survival outputs are ordered by token id starting at 1 — TODO confirm
        event_name = dm.decode([si + 1])
        fig, ax = plt.subplots()
        for p_idx in range(len(prompts)):
            ax.plot(model.surv_layer.t_eval / 365, prompt_survs[p_idx][si][0, :], label=f"{desc[p_idx]}")
        ax.set_xlabel("t (years)")
        ax.set_ylabel("P(T>t)")
        ax.legend()
        fig.savefig(DIABETES_FIG_DIR / f"{event_name}.png")
        plt.close(fig)




SurvStreamGPTForCausalModelling: Competing-Risk
--------------------------------------

Depression: 	 (DEPRESSION): 

Depression -> Type 1: 	 (DEPRESSION,TYPE1DM): 

Depression - > Type 2: 	 (DEPRESSION,TYPE2DIABETES): 


## Values: How increasing BMI affects diagnosis risk

In [107]:
events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5",
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF_V3", "ISCHAEMICSTROKE_V2"
                     ]
prompt = ["Body_mass_index_3"]
bmi_levels = [12., 15., 18., 21., 24., 30., 40.]
values = [torch.tensor([dm.standardise(_cat, v) for _cat in prompt], device=device) for v in bmi_levels]
age = [40]

BMI_FIG_DIR = Path("figs/competing_risk/bmi")
BMI_FIG_DIR.mkdir(parents=True, exist_ok=True)

# The prompt tokens are the same for every BMI level; only the value changes
encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1, -1))).to(device)

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad():
        model.eval()

        prompt_survs = []
        for p_idx, value in enumerate(values):
            print(f"Value {value}\n======")
            (surv, val_dist), _, _ = model(encoded_prompt,
                                           values=value,
                                           ages=to_days(age),
                                           is_generation=True)
            prompt_survs.append(surv)

    for si, _ in enumerate(surv):
        # assumes survival outputs are ordered by token id starting at 1 — TODO confirm
        event_name = dm.decode([si + 1])
        if event_name not in events_of_interest:
            continue

        fig, ax = plt.subplots()
        for p_idx in range(len(prompt_survs)):
            # float() works whether unstandardise returns a Python float or a one-element tensor
            bmi_value = float(dm.unstandardise("Body_mass_index_3", values[p_idx]))
            ax.plot(model.surv_layer.t_eval / 365, prompt_survs[p_idx][si][0, :], label=f"BMI {bmi_value:.2f}")
        ax.set_xlabel("t (years)")
        ax.set_ylabel("P(T>t) ()")
        ax.legend()
        fig.savefig(BMI_FIG_DIR / f"{event_name}.png")
        plt.close(fig)




SurvStreamGPTForCausalModelling: Competing-Risk
--------------------------------------
Value tensor([0.0464], device='cuda:0')
Value tensor([0.1377], device='cuda:0')
Value tensor([0.2291], device='cuda:0')
Value tensor([0.3204], device='cuda:0')
Value tensor([0.4117], device='cuda:0')
Value tensor([0.5944], device='cuda:0')
Value tensor([0.8989], device='cuda:0')


## Values: How increasing diastolic_blood_pressure affects likelihood of diagnoses

In [108]:
events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5",
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF_V3", "ISCHAEMICSTROKE_V2"
                     ]

prompt = ["Diastolic_blood_pressure_5"]
dbp_levels = [60., 70., 80., 90., 100., 120.]
values = [torch.tensor([dm.standardise(_cat, _value) for _cat in prompt], device=device) for _value in dbp_levels]
age = [40]

DBP_FIG_DIR = Path("figs/competing_risk/diastolic_blood_pressure")
DBP_FIG_DIR.mkdir(parents=True, exist_ok=True)

# The prompt tokens are the same for every DBP level; only the value changes
encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1, -1))).to(device)

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad():
        model.eval()

        prompt_survs = []
        for p_idx, value in enumerate(values):
            print(f"Value {value}\n======")
            (surv, val_dist), _, _ = model(encoded_prompt,
                                           values=value,
                                           ages=to_days(age),
                                           is_generation=True)
            prompt_survs.append(surv)

    for si, _ in enumerate(surv):
        # assumes survival outputs are ordered by token id starting at 1 — TODO confirm
        event_name = dm.decode([si + 1])
        if event_name not in events_of_interest:
            continue

        fig, ax = plt.subplots()
        for p_idx in range(len(prompt_survs)):
            # Use the data module's unstandardise (a bare `unstandardise` is not defined);
            # float() works for a Python float or a one-element tensor.
            dbp_value = float(dm.unstandardise("Diastolic_blood_pressure_5", values[p_idx]))
            ax.plot(model.surv_layer.t_eval / 365, prompt_survs[p_idx][si][0, :], label=f"DBP {dbp_value:.2f}")
        ax.set_xlabel("t (years)")
        ax.set_ylabel("P(T>t) ()")
        ax.legend()
        fig.savefig(DBP_FIG_DIR / f"{event_name}.png")
        plt.close(fig)




SurvStreamGPTForCausalModelling: Competing-Risk
--------------------------------------
Value tensor([0.1955], device='cuda:0')
Value tensor([0.3697], device='cuda:0')
Value tensor([0.5439], device='cuda:0')
Value tensor([0.7181], device='cuda:0')
Value tensor([0.8923], device='cuda:0')
Value tensor([1.2407], device='cuda:0')


## Values: How varying diagnosis affects value of diastolic_blood_pressure

In [109]:
# Predicted (standardised) diastolic blood pressure distribution after a single diagnosis at age 40
dbp_token = dm.tokenizer._stoi["Diastolic_blood_pressure_5"]

diagnoses = [["DEPRESSION"], ["TYPE2DIABETES"], ["HF_V3"], ["HYPERTENSION"]]
values = torch.tensor([torch.nan], device=device)
age = [40]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    model.eval()
    # The forward passes must be inside no_grad (previously only model.eval() was)
    with torch.no_grad():
        for p_idx, diagnosis in enumerate(diagnoses):
            print(f"\nDiagnosis {diagnosis}\n======")
            encoded_prompt = torch.from_numpy(np.array(dm.encode(diagnosis)).reshape((1, -1))).to(device)
            (surv, val_dist), _, _ = model(encoded_prompt,
                                           values=values,
                                           ages=to_days(age),
                                           is_generation=True)
            dist = val_dist[model.value_layer.token_key(dbp_token)]
            print(f"standardised diastolic_blood_pressure ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")





SurvStreamGPTForCausalModelling: Competing-Risk
--------------------------------------

Diagnosis ['DEPRESSION']
standardised diastolic_blood_pressure ~ N(0.5, 0.2)

Diagnosis ['TYPE2DIABETES']
standardised diastolic_blood_pressure ~ N(0.6, 0.2)

Diagnosis ['HF_V3']
standardised diastolic_blood_pressure ~ N(0.5, 0.2)

Diagnosis ['HYPERTENSION']
standardised diastolic_blood_pressure ~ N(0.6, 0.2)


## Values: How increasing bmi affects value of diastolic_blood_pressure

In [110]:
# Predicted (standardised) diastolic blood pressure distribution after a BMI measurement at age 40
dbp_token = dm.tokenizer._stoi["Diastolic_blood_pressure_5"]

prompt = ["Body_mass_index_3"]
bmi_levels = [12., 15., 18., 21., 24., 30., 40., 50.]
values = [torch.tensor([dm.standardise(_cat, _value) for _cat in prompt], device=device) for _value in bmi_levels]
age = [40]

# The prompt tokens are the same for every BMI level; only the value changes
encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1, -1))).to(device)

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    model.eval()
    with torch.no_grad():
        for p_idx, value in enumerate(values):
            print(f"\nValues {value.tolist()}\n======")
            (surv, val_dist), _, _ = model(encoded_prompt,
                                           values=value,
                                           ages=to_days(age),
                                           is_generation=True)

            dist = val_dist[model.value_layer.token_key(dbp_token)]
            print(f"standardised diastolic_blood_pressure ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")



SurvStreamGPTForCausalModelling: Competing-Risk
--------------------------------------

Values [0.04638069495558739]
standardised diastolic_blood_pressure ~ N(0.5, 0.2)

Values [0.13772238790988922]
standardised diastolic_blood_pressure ~ N(0.4, 0.2)

Values [0.22906407713890076]
standardised diastolic_blood_pressure ~ N(0.4, 0.2)

Values [0.3204057812690735]
standardised diastolic_blood_pressure ~ N(0.4, 0.2)

Values [0.4117474853992462]
standardised diastolic_blood_pressure ~ N(0.4, 0.2)

Values [0.5944308638572693]
standardised diastolic_blood_pressure ~ N(0.5, 0.2)

Values [0.8989031910896301]
standardised diastolic_blood_pressure ~ N(0.6, 0.2)

Values [1.2033754587173462]
standardised diastolic_blood_pressure ~ N(0.6, 0.2)


# Appendix: model architectures

In [111]:
# Print each model's name, underlined, followed by its module tree
for model_idx, model in enumerate(models):
    title = m_names[model_idx]
    underline = "=" * len(title)
    print(f"\n\n{title}\n{underline}")
    print(f"\n\n{model}")



SurvStreamGPTForCausalModelling: Competing-Risk


SurvStreamGPTForCausalModelling(
  (transformer): TTETransformer(
    (wpe): TemporalPositionalEncoding()
    (wte): DataEmbeddingLayer(
      (token_embed_layer): Embedding(184, 384, padding_idx=0)
      (value_embed_layer): EmbeddingBag(184, 384, mode=sum, padding_idx=0)
    )
    (drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadedSelfAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=384, out_features=384, bias=False)
          (v_proj): Linear(in_features=384, out_features=384, bias=False)
          (q_proj): Linear(in_features=384, out_features=384, bias=False)
          (out_proj): Linear(in_features=384, out_features=384, bias=True)
        )
        (ln_2): LayerNorm((384,), eps=1e-05, eleme

In [112]:
# Export this notebook to HTML; --no-input omits the code-cell inputs from the rendered page.
!jupyter nbconvert --to html --no-input competing_risk.ipynb

[NbConvertApp] Converting notebook competing_risk.ipynb to html
[NbConvertApp] Writing 688398 bytes to competing_risk.html
