# Demo Notebook:
## Survival Transformer For Causal Sequence Modelling 

Including time, and excluding tabular values

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

!pwd

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/SurvStreamGPT


In [2]:
import pytorch_lightning
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling
from pycox.evaluation import EvalSurv

# TODO:
# replace experiment boilerplate with pytorch lightning

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cuda.


## Build configurations

In [3]:
# Set config to be equivalent architecture of kaparthy benchmark, however they are not comparable tasks.
@dataclass
class DemoConfig:
    block_size: int = 128        # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"    
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0
    SurvLayer = "Single-Risk"                                  # "Competing-Risk"
    tokens_for_univariate_regression = None

config = DemoConfig()

@dataclass
class OptConfig:
    batch_size: int = 64
    eval_interval: int = 1
    learning_rate: float = 3e-4
    epochs: int = 30
    
opt = OptConfig()

## Create data loader on a reduced cohort

In [4]:
from CPRD.data.database import queries

# Get a list of patients which fit a reduced set of criterion
PATH_TO_DB = "/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing/processed/cprd.db"
conn = sqlite3.connect(PATH_TO_DB)
cursor = conn.cursor()
# identifiers1 = queries.query_measurement(["bmi", "diastolic_blood_pressure"], cursor)        
identifiers2 = queries.query_diagnosis(["DEPRESSION", "TYPE1DM", "TYPE2DIABETES"], cursor)    #  "DEPRESSION"  ,  "ANXIETY"
# all_identifiers = list(set(identifiers1).intersection(identifiers2))    # Turn smaller list into the set
all_identifiers = identifiers2

if True:
    # Lets take only the first N for faster run-time
    N = np.min((len(all_identifiers), 20000))
    print(f"Using N={N} random samples, from the available {len(all_identifiers)}")
    identifiers = random.choices(all_identifiers, k=N)
else:
    print(f"Using all available {len(all_identifiers)} samples")
    identifiers = all_identifiers

# Build 
dm = FoundationalDataModule(identifiers=identifiers,
                            tokenizer="tabular",
                            batch_size=opt.batch_size,
                            max_seq_length=config.block_size,
                            unk_freq_threshold=config.unk_freq_threshold,
                            include_measurements=True,
                            include_diagnoses=True,
                            preprocess_measurements=True
                           )


vocab_size = dm.train_set.tokenizer.vocab_size

print(f"{len(dm.train_set)} training samples")
print(f"{len(dm.val_set)} validation samples")
print(f"{len(dm.test_set)} test samples")
print(f"{vocab_size} vocab elements")
# print(dm.train_set.tokenizer._itos)

INFO:root:Building polars dataset


Using N=20000 random samples, from the available 129717


INFO:root:Using measurements
INFO:root:Using test/measurement standardisation method: normalise
INFO:root:Removing measurement and test outliers. Using three deviations from mean as cutoff
INFO:root:Using diagnoses
INFO:root:Dropping samples with no dynamic events
INFO:root:Using tabular tokenizer


16440 training samples
914 validation samples
913 test samples
90 vocab elements


## Standardisation

This was performed automatically across measurements and tests in the dataloader. The standardisation statistics (bias and scale respectively) are given in the dictionary object. 

We define two mappings to simplify notation later

In [5]:
display(dm.standardisation_dict)

standardise = lambda key, v: (v - dm.standardisation_dict[key][0]) / dm.standardisation_dict[key][1]
unstandardise = lambda key, v: (v * dm.standardisation_dict[key][1]) + dm.standardisation_dict[key][0]

print(standardise("bmi", 30))
print(unstandardise("bmi", standardise("bmi", 20)))

{'creatinine_ratio': (4.613647798742136, 8.013549757225261),
 'blood_urea': (6.661150210084032, 3.5151115243185727),
 'bmi': (29.96208708895408, 7.104642555249171),
 'eosinophil_count': (0.22112432491405676, 0.18412155821672552),
 'hydroxyvitamin2': (3.3176282051282047, 2.624011268273257),
 'combined_total_vitamin_D2_and_D3_level': (54.18057692307692,
  28.489483861873047),
 'diastolic_blood_pressure': (78.89241975790014, 11.729769809594963),
 'corrected_serum_calcium_level': (2.3180350373559864, 0.12127745858022278),
 'calculated_LDL_cholesterol_level': (2.5496743006577454, 1.0433663664568302),
 'basophil_count': (0.06863472505097562, 0.10071943861842063),
 'blood_calcium': (2.3682051282051284, 0.20732471093066332),
 'calcium_adjusted_level': (2.3160170045781627, 0.10605442751470515),
 'aspartate_transam': (26.941471048513304, 19.44413875598056),
 'brain_natriuretic_peptide_level': (173.09296875, 302.17777324463435),
 'hydroxyvitamin3': (51.70338733431519, 29.183941674076287),
 'serum

0.0053363572834371955
20.0


## View the frequency of tokens in the extracted data

In [6]:
import polars as pl
pl.Config.set_tbl_rows(vocab_size + 1)
print(dm.tokenizer._event_counts)

shape: (89, 3)
┌───────────────────────────────────┬────────┬──────────┐
│ EVENT                             ┆ counts ┆ freq     │
│ ---                               ┆ ---    ┆ ---      │
│ str                               ┆ u32    ┆ f64      │
╞═══════════════════════════════════╪════════╪══════════╡
│ UNK                               ┆ 0      ┆ 0.0      │
│ diastolic_blood_pressure          ┆ 354853 ┆ 0.417141 │
│ bmi                               ┆ 141779 ┆ 0.166666 │
│ eosinophil_count                  ┆ 131696 ┆ 0.154813 │
│ basophil_count                    ┆ 88056  ┆ 0.103513 │
│ corrected_serum_calcium_level     ┆ 24429  ┆ 0.028717 │
│ DEPRESSION                        ┆ 13328  ┆ 0.015668 │
│ serum_level                       ┆ 12956  ┆ 0.01523  │
│ calculated_LDL_cholesterol_level  ┆ 11295  ┆ 0.013278 │
│ ANXIETY                           ┆ 7270   ┆ 0.008546 │
│ HYPERTENSION                      ┆ 5272   ┆ 0.006197 │
│ TYPE2DIABETES                     ┆ 4930   ┆ 0.005795 │

In [7]:
# Extract the measurements, using the fact that the diagnoses are all up upper case. This is needed for automatically setting the configuration below
measurements_for_univariate_regression = [record for record in dm.tokenizer._event_counts["EVENT"] if record.upper() != record]

print(measurements_for_univariate_regression)
print(dm.encode(measurements_for_univariate_regression))
print(dm.decode([7,4,3,2]))

['diastolic_blood_pressure', 'bmi', 'eosinophil_count', 'basophil_count', 'corrected_serum_calcium_level', 'serum_level', 'calculated_LDL_cholesterol_level', 'aspartate_transam', 'blood_urea', 'calcium_adjusted_level', 'combined_total_vitamin_D2_and_D3_level', 'hydroxyvitamin3', 'hydroxyvitamin2', 'creatinine_ratio', 'brain_natriuretic_peptide_level', 'blood_calcium']
[2, 3, 4, 5, 6, 8, 9, 16, 23, 26, 30, 40, 41, 54, 67, 81]
DEPRESSION eosinophil_count bmi diastolic_blood_pressure


## Create models and train

In [8]:
models, m_names = [], []

# My development model
for surv_layer in ["Single-Risk"]: #, "Competing-Risk"]:
    
    ## Create configuration
    config = DemoConfig()
    # Specify which survival head layer to use
    config.SurvLayer = surv_layer   
    # list of univariate measurements to model with Normal distribution
    config.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression) 
    
    models.append(SurvStreamGPTForCausalModelling(config, vocab_size).to(device))
    m_names.append(f"SurvStreamGPTForCausalModelling: {surv_layer}")

INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using Single-Risk DeSurvival head. This module predicts a separate survival curve for each possible future event
INFO:root:Internally scaling time in survival head by 1825 days
INFO:root:In generation forwarding DeSurv on the grid between [0.0, 1825.0], with delta=1/300
INFO:root:ModuleDict(
  (Token 2): Linear(in_features=384, out_features=2, bias=True)
  (Token 3): Linear(in_features=384, out_features=2, bias=True)
  (Token 4): Linear(in_features=384, out_features=2, bias=True)
  (Token 5): Linear(in_features=384, out_features=2, bias=True)
  (Token 6): Linear(in_features=384, out_features=2, bias=True)
  (Token 8): Linear(in_features=384, out_features=2, bias=True)
  (Token 9): Linear(in_features=384, out_features=2, bias=True)
  (Token 16): Linear(in_features=384, out_features=2, bias=True)
  (Token 23): Linear(in_features=384, out_features=2, bias=True)


In [9]:
loss_curves_train = [[] for _ in models]
loss_curves_train_clf = [[] for _ in models]
loss_curves_train_surv = [[] for _ in models]
loss_curves_train_values = [[] for _ in models]

loss_curves_val = [[] for _ in models]
loss_curves_val_clf = [[] for _ in models]
loss_curves_val_surv = [[] for _ in models]
loss_curves_val_values = [[] for _ in models]

In [10]:
for m_idx, (model, m_name) in enumerate(zip(models, m_names)):
    
    print(f"Training model `{m_name}`, with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
    model = model.to(device)

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

    best_val, epochs_since_best = np.inf, 0
    for epoch in range(opt.epochs):
        
        epoch_loss, epoch_surv_loss, epoch_values_loss = 0, 0, 0
        model.train()
        for i, batch in enumerate(dm.train_dataloader()):
            # print(i)
            # evaluate the loss
            _, (losses_desurv, loss_values), loss = model(batch['tokens'].to(device), 
                                                        ages=batch['ages'].to(device), 
                                                        values=batch['values'].to(device),
                                                        attention_mask=batch['attention_mask'].to(device)   
                                                        )
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            # record
            epoch_loss += loss.item()            
            epoch_surv_loss += torch.sum(losses_desurv).item()
            epoch_values_loss += loss_values.item()
        
        epoch_loss /= i
        epoch_surv_loss /= i
        epoch_values_loss /= i
        loss_curves_train[m_idx].append(epoch_loss)
        loss_curves_train_surv[m_idx].append(epoch_surv_loss)
        loss_curves_train_values[m_idx].append(epoch_values_loss)

        # evaluate the loss on val set
        with torch.no_grad(): 
            model.eval()
            if epoch % opt.eval_interval == 0 or epoch == opt.epochs - 1:
                val_loss, val_surv_loss, val_values_loss = 0, 0, 0
                for j, batch in enumerate(dm.val_dataloader()):
                    _, (losses_desurv, loss_values), loss = model(batch['tokens'].to(device), 
                                                                  ages=batch['ages'].to(device),
                                                                  values=batch['values'].to(device),
                                                                  attention_mask=batch['attention_mask'].to(device)   
                                                                  )
                    # record
                    val_loss += loss.item()                    
                    val_surv_loss += torch.sum(losses_desurv).item()
                    val_values_loss += loss_values.item()
                    
                val_loss /= j
                val_surv_loss /= j
                val_values_loss /= j
                loss_curves_val[m_idx].append(val_loss)
                loss_curves_val_surv[m_idx].append(val_surv_loss)
                loss_curves_val_values[m_idx].append(val_values_loss)

                print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}: ({epoch_surv_loss:.2f}, {epoch_values_loss:.2f}). Val loss {val_loss:.2f}: ({val_surv_loss:.2f}, {val_values_loss:.2f})")          
                # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning
        
            if val_loss >= best_val:
                epochs_since_best += 1
                if epochs_since_best >= 5:
                    break
            else:
                best_val = val_loss
                epochs_since_best = 0

    # Test trained model with a prompt
    # ----------------    
    # set context: diagnosis of depression at 20 years old
    tokens = torch.from_numpy(np.array(dm.encode(["DEPRESSION"])).reshape((1,-1))).to(device)
    ages = torch.tensor([[20*365]], device=device)
    values = torch.tensor([[torch.nan]], device=device)
    
    # generate: sample the next 10 tokens
    new_tokens, new_ages, new_values = model.generate(tokens, ages, values, max_new_tokens=10)
    generated = dm.decode(new_tokens[0].tolist())
    # report:
    for _cat, _age, _value in zip(generated.split(" "), new_ages[0, :], new_values[0, :]):
        try:
            _value = unstandardise(_cat, _value)
        except:
            pass
        print(f"{_cat}".ljust(50) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({_age:.1f} days)")    # with value {_value}

Training model `SurvStreamGPTForCausalModelling: Single-Risk`, with 10.722272 M parameters
Epoch 0:	Train loss 1.02: (1.71, 0.34). Val loss 0.98: (1.65, 0.30)
Epoch 1:	Train loss 0.81: (1.34, 0.28). Val loss 0.84: (1.39, 0.28)
Epoch 2:	Train loss 0.69: (1.11, 0.27). Val loss 0.72: (1.16, 0.28)
Epoch 3:	Train loss 0.59: (0.92, 0.26). Val loss 0.62: (0.97, 0.28)
Epoch 4:	Train loss 0.51: (0.75, 0.26). Val loss 0.54: (0.80, 0.27)
Epoch 5:	Train loss 0.43: (0.61, 0.25). Val loss 0.46: (0.65, 0.27)
Epoch 6:	Train loss 0.37: (0.48, 0.25). Val loss 0.40: (0.53, 0.28)
Epoch 7:	Train loss 0.31: (0.37, 0.24). Val loss 0.34: (0.41, 0.26)
Epoch 8:	Train loss 0.26: (0.27, 0.24). Val loss 0.29: (0.32, 0.27)
Epoch 9:	Train loss 0.21: (0.19, 0.24). Val loss 0.24: (0.23, 0.26)
Epoch 10:	Train loss 0.18: (0.11, 0.24). Val loss 0.22: (0.16, 0.27)
Epoch 11:	Train loss 0.15: (0.05, 0.24). Val loss 0.17: (0.09, 0.26)
Epoch 12:	Train loss 0.11: (-0.02, 0.23). Val loss 0.15: (0.02, 0.28)
Epoch 13:	Train loss 

## Comparing output to real data

In [11]:
for batch in dm.train_dataloader():
    break
    
conditions = batch["tokens"].numpy().tolist()
# delta_ages = batch["ages"][:, 1:] - batch["ages"][:, :-1]
for idx, (token, _age, _value) in enumerate(zip(conditions[0], batch["ages"][0,:],  batch["values"][0,:])):
    if token == 0 or idx >= 10:
        break
    _cat = dm.decode([token])
    try:
        _value = unstandardise(_cat, _value)
    except:
        pass
        
    print(f"{_cat}".ljust(50) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({_age:.1f} days)")

bmi                                               37.70          at age 37 (13413.0 days)
ASTHMA_PUSHASTHMA                                 nan            at age 37 (13684.0 days)
diastolic_blood_pressure                          80.00          at age 40 (14480.0 days)
bmi                                               38.30          at age 41 (14882.0 days)
diastolic_blood_pressure                          70.00          at age 41 (14882.0 days)
diastolic_blood_pressure                          88.00          at age 42 (15470.0 days)
OSTEOARTHRITIS                                    nan            at age 44 (16024.0 days)
aspartate_transam                                 18.00          at age 47 (17231.0 days)
basophil_count                                    0.09           at age 47 (17231.0 days)
eosinophil_count                                  0.53           at age 47 (17231.0 days)


In [12]:
cols = ["k", "r", "b", "y"]

# Plot loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train[m_idx]), len(loss_curves_train[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val[m_idx]), len(loss_curves_val[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.savefig(f"figs/single_risk/loss.png")

# Plot DeSurv loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train_surv[m_idx]), len(loss_curves_train_surv[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train_surv[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val_surv[m_idx]), len(loss_curves_val_surv[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val_surv[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.savefig(f"figs/single_risk/loss_desurv.png")

plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train_values[m_idx]), len(loss_curves_train_values[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train_values[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val_values[m_idx]), len(loss_curves_val_values[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val_values[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.savefig(f"figs/single_risk/loss_val.png")

# Prompt testing

## Diabetes: How related conditions are impacted by each other
Probability of type II diabetes before and after a type I diagnosis

In [13]:
t1_token = dm.tokenizer._stoi["TYPE1DM"]
t2_token = dm.tokenizer._stoi["TYPE2DIABETES"]


base_prompt = ["DEPRESSION"]
ages_in_years = [20]
base_values = [torch.tensor([torch.nan])]

to_days = lambda a_list: torch.FloatTensor([365 * _a for _a in a_list]).reshape((1,-1)).to(device)

# Create a set of prompts
prompts, ages, values, desc = [], [], [], []
# control prompt
desc.append("Depression")
prompts.append(base_prompt)
ages.append(ages_in_years)
values.append(base_values)
# prompt with type 1 diabetes
desc.append("Depression -> Type 1")
prompts.append(base_prompt + ["TYPE1DM"])
ages.append(ages_in_years + [21])
values.append(base_values + [torch.tensor([torch.nan])])

desc.append("Depression - > Type 2")
prompts.append(base_prompt + ["TYPE2DIABETES"])
ages.append(ages_in_years + [21])
values.append(base_values + [torch.tensor([torch.nan])])

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad(): 
        model.eval()

        prompt_survs = []
        for p_idx, (prompt, age, value) in enumerate(zip(prompts, ages, values)):
            print(f"\n{desc[p_idx]}: \t ({','.join(prompt)}): ")
            encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
            (surv, val_dist), _, _ = model(encoded_prompt,
                                           values=torch.tensor(value).to(device),
                                           ages=to_days(age),
                                           is_generation=True)
            prompt_survs.append(surv)

        for si, _ in enumerate(surv):
            plt.close()
            event_name = dm.decode([si + 1])
            for p_idx in range(len(prompts)):
                plt.plot(model.surv_layer.t_eval / 365, prompt_survs[p_idx][si][0, :], label=f"{desc[p_idx]}")
            plt.legend()
            plt.savefig(f"figs/single_risk/diabetes/{event_name}.png")




SurvStreamGPTForCausalModelling: Single-Risk
--------------------------------------

Depression: 	 (DEPRESSION): 

Depression -> Type 1: 	 (DEPRESSION,TYPE1DM): 

Depression - > Type 2: 	 (DEPRESSION,TYPE2DIABETES): 


## Values: How increasing BMI affects diagnosis risk

In [14]:
events_of_interest = ["bmi", "diastolic_blood_pressure", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF", "ISCHAEMICSTROKE"
                     ]
prompt = ["bmi"]
values = [torch.tensor([standardise(_cat, v) for _cat in prompt], device=device) for v in [12.,15.,18.,21.,24.,30.,40.]]
age = [40]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad(): 
        model.eval()

        prompt_survs = []
        for p_idx, value in enumerate(values):
            print(f"Value {value}\n======")
            encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
            (surv, val_dist), _, _ = model(encoded_prompt,
                                           values=value,
                                           ages=to_days(age),
                                           is_generation=True)
            prompt_survs.append(surv)

        for si, _ in enumerate(surv):
            plt.close()
            event_name = dm.decode([si + 1])
            
            if event_name in events_of_interest:
                
                for p_idx in range(len(prompt_survs)):
                    bmi_value = unstandardise("bmi", values[p_idx])
                    plt.plot(model.surv_layer.t_eval / 365, prompt_survs[p_idx][si][0, :], label=f"BMI {bmi_value.item():.2f}")
                plt.xlabel("t (years)")
                plt.ylabel("P(T>t) ()")
                plt.legend()
                plt.savefig(f"figs/single_risk/bmi/{event_name}.png")




SurvStreamGPTForCausalModelling: Single-Risk
--------------------------------------
Value tensor([-2.5282], device='cuda:0')
Value tensor([-2.1060], device='cuda:0')
Value tensor([-1.6837], device='cuda:0')
Value tensor([-1.2614], device='cuda:0')
Value tensor([-0.8392], device='cuda:0')
Value tensor([0.0053], device='cuda:0')
Value tensor([1.4129], device='cuda:0')


## Values: How increasing diastolic_blood_pressure affects likelihood of diagnoses

In [15]:
events_of_interest = ["bmi", "diastolic_blood_pressure", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF", "ISCHAEMICSTROKE"
                     ]

prompt = ["diastolic_blood_pressure"]
values = [torch.tensor([standardise(_cat, _value) for _cat in prompt], device=device) for _value in [60.,70.,80.,90.,100.,120.]]
age = [40]


for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad(): 
        model.eval()

        prompt_survs = []
        for p_idx, value in enumerate(values):
            print(f"Value {value}\n======")
            encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
            (surv, val_dist), _, _ = model(encoded_prompt,
                                           values=value,
                                           ages=to_days(age),
                                           is_generation=True)
            prompt_survs.append(surv)

        for si, _ in enumerate(surv):
            plt.close()
            event_name = dm.decode([si + 1])
            
            if event_name in events_of_interest:
                
                for p_idx in range(len(prompt_survs)):
                    dbp_value = unstandardise("diastolic_blood_pressure", values[p_idx])
                    plt.plot(model.surv_layer.t_eval / 365, prompt_survs[p_idx][si][0, :], label=f"DBP {dbp_value.item():.2f}")
                plt.xlabel("t (years)")
                plt.ylabel("P(T>t) ()")
                plt.legend()
                plt.savefig(f"figs/single_risk/diastolic_blood_pressure/{event_name}.png")




SurvStreamGPTForCausalModelling: Single-Risk
--------------------------------------
Value tensor([-1.6106], device='cuda:0')
Value tensor([-0.7581], device='cuda:0')
Value tensor([0.0944], device='cuda:0')
Value tensor([0.9470], device='cuda:0')
Value tensor([1.7995], device='cuda:0')
Value tensor([3.5046], device='cuda:0')


## Values: How varying diagnosis affects value of diastolic_blood_pressure

In [16]:
measurements_of_interest = ["diastolic_blood_pressure"]
t1_token = dm.tokenizer._stoi["diastolic_blood_pressure"]

diagnoses = [["DEPRESSION"],["TYPE2DIABETES"], ["HF"], ["HYPERTENSION"]]
values = torch.tensor([torch.nan], device=device)
age = [40]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad(): 
        model.eval()

    for p_idx, diagnosis in enumerate(diagnoses):
        print(f"\nDiagnosis {diagnosis}\n======")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(diagnosis)).reshape((1,-1))).to(device)
        (surv, val_dist), _, _ = model(encoded_prompt,
                                       values=values,
                                       ages=to_days(age),
                                       is_generation=True)
        dist = val_dist[model.value_layer.token_key(t1_token)]
        print(f"standardised diastolic_blood_pressure ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")





SurvStreamGPTForCausalModelling: Single-Risk
--------------------------------------

Diagnosis ['DEPRESSION']
standardised diastolic_blood_pressure ~ N(-0.1, 0.9)

Diagnosis ['TYPE2DIABETES']
standardised diastolic_blood_pressure ~ N(-0.2, 0.9)

Diagnosis ['HF']
standardised diastolic_blood_pressure ~ N(-0.2, 1.0)

Diagnosis ['HYPERTENSION']
standardised diastolic_blood_pressure ~ N(0.5, 1.1)


## Values: How increasing bmi affects value of diastolic_blood_pressure

In [17]:
t1_token = dm.tokenizer._stoi["diastolic_blood_pressure"]

prompt = ["bmi"]
values = [torch.tensor([standardise(_cat, _value) for _cat in prompt], device=device) for _value in [12.,15.,18.,21.,24.,30.,40.,50.]]
age = [40]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")

    for p_idx, value in enumerate(values):
        print(f"\nValues {value.tolist()}\n======")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
        (surv, val_dist), _, _ = model(encoded_prompt,
                                       values=value,
                                       ages=to_days(age),
                                       is_generation=True)
        
        dist = val_dist[model.value_layer.token_key(t1_token)]
        print(f"standardised diastolic_blood_pressure ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")
        # print(f"\tprobability of type II diabetes: {100*float(probs[0, 0, t2_token].cpu().detach().numpy()):.4f}%")



SurvStreamGPTForCausalModelling: Single-Risk
--------------------------------------

Values [-2.5282182693481445]
standardised diastolic_blood_pressure ~ N(-0.8, 0.9)

Values [-2.105959177017212]
standardised diastolic_blood_pressure ~ N(-0.7, 1.0)

Values [-1.6837000846862793]
standardised diastolic_blood_pressure ~ N(-0.6, 1.0)

Values [-1.2614409923553467]
standardised diastolic_blood_pressure ~ N(-0.4, 0.9)

Values [-0.8391818404197693]
standardised diastolic_blood_pressure ~ N(-0.2, 0.9)

Values [0.005336357280611992]
standardised diastolic_blood_pressure ~ N(0.1, 0.9)

Values [1.4128667116165161]
standardised diastolic_blood_pressure ~ N(0.5, 1.1)

Values [2.820397138595581]
standardised diastolic_blood_pressure ~ N(0.5, 1.2)


# Appendix: model architectures

In [18]:
for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n" + "="*len(m_names[model_idx]))
    print(f"\n\n{model}")



SurvStreamGPTForCausalModelling: Single-Risk


SurvStreamGPTForCausalModelling(
  (transformer): TTETransformer(
    (wpe): TemporalPositionalEncoding()
    (wte): DataEmbeddingLayer(
      (token_embed_layer): Embedding(90, 384, padding_idx=0)
      (value_embed_layer): EmbeddingBag(90, 384, mode=sum, padding_idx=0)
    )
    (drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadedSelfAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=384, out_features=384, bias=False)
          (v_proj): Linear(in_features=384, out_features=384, bias=False)
          (q_proj): Linear(in_features=384, out_features=384, bias=False)
          (out_proj): Linear(in_features=384, out_features=384, bias=True)
        )
        (ln_2): LayerNorm((384,), eps=1e-05, elementwis

In [19]:
!jupyter nbconvert --to html --no-input single_risk.ipynb

[NbConvertApp] Converting notebook single_risk.ipynb to html
[NbConvertApp] Writing 597697 bytes to single_risk.html
