# Demo Notebook:
## Survival Transformer For Causal Sequence Modelling 

Including time, and excluding tabular values

In [1]:
import os
from pathlib import Path
import sys
# Select the virtual environment whose directory name ends with the value of `BB_CPU`
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-{node_type}'

# site-packages lives under lib/python<major>.<minor>/ for the running interpreter
py_major, py_minor = sys.version_info[:2]
venv_site_pkgs = Path(venv_dir, 'lib', f'python{py_major}.{py_minor}', 'site-packages')

# Put the environment's packages ahead of everything else on the import path, if they exist
if not venv_site_pkgs.exists():
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")
else:
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")

!pwd

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/SurvStreamGPT


In [2]:
import pytorch_lightning
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling
from pycox.evaluation import EvalSurv
from tqdm import tqdm 

# TODO:
# replace experiment boilerplate with pytorch lightning

# Seed every random number generator this notebook imports (not only torch), so that anything
# drawing from `random` or numpy is repeatable between runs as well.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

logging.basicConfig(level=logging.INFO)

# Run on the GPU when one is visible, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cuda.


## Build configurations

In [14]:
from typing import Optional

# Architecture sizes follow Karpathy's benchmark configuration; the tasks themselves are not comparable.
@dataclass
class DemoConfig:
    """Model configuration handed to `SurvStreamGPTForCausalModelling`.

    `block_size` and `unk_freq_threshold` are also forwarded to the data module
    (as `max_seq_length` and `unk_freq_threshold`).

    Every attribute is annotated so that it is a real dataclass field: it can be
    set through the constructor and takes part in `repr` and equality.
    """
    block_size: int = 128        # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0
    SurvLayer: str = "Single-Risk"                             # "Competing-Risk"
    # Encoded tokens whose values get a univariate regression head; None until filled in
    tokens_for_univariate_regression: Optional[list] = None

config = DemoConfig()

@dataclass
class OptConfig:
    """Optimisation settings read by the data module and the training loop."""
    batch_size: int = 16         # batch size given to the data module
    eval_interval: int = 1       # run validation every this many epochs
    learning_rate: float = 3e-4  # AdamW learning rate
    epochs: int = 30             # maximum number of training epochs

opt = OptConfig()

## Create data loader on a reduced cohort

In [15]:
# from CPRD.data.database import queries

# # Get a list of patients which fit a reduced set of criteria
# PATH_TO_DB = "/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing/processed/cprd.db"
# conn = sqlite3.connect(PATH_TO_DB)
# cursor = conn.cursor()
# # identifiers1 = queries.query_measurement(["bmi", "diastolic_blood_pressure"], cursor)        
# identifiers2 = queries.query_diagnosis(["DEPRESSION", "TYPE1DM", "TYPE2DIABETES"], cursor)    #  "DEPRESSION"  ,  "ANXIETY"
# # all_identifiers = list(set(identifiers1).intersection(identifiers2))    # Turn smaller list into the set
# all_identifiers = identifiers2

# if True:
#     # Lets take only the first N for faster run-time
#     N = np.min((len(all_identifiers), 20000))
#     print(f"Using N={N} random samples, from the available {len(all_identifiers)}")
#     identifiers = random.choices(all_identifiers, k=N)
# else:
#     print(f"Using all available {len(all_identifiers)} samples")
#     identifiers = all_identifiers

# # Build 
# dm = FoundationalDataModule(identifiers=identifiers,
#                             tokenizer="tabular",
#                             batch_size=opt.batch_size,
#                             max_seq_length=config.block_size,
#                             unk_freq_threshold=config.unk_freq_threshold,
#                             include_measurements=True,
#                             include_diagnoses=True,
#                             preprocess_measurements=True
#                            )


# vocab_size = dm.train_set.tokenizer.vocab_size

# print(f"{len(dm.train_set)} training samples")
# print(f"{len(dm.val_set)} validation samples")
# print(f"{len(dm.test_set)} test samples")
# print(f"{vocab_size} vocab elements")
# # print(dm.train_set.tokenizer._itos)

In [16]:
# Root directory of the dataset handed to the data module
path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/archive/Version2/"

# Build
# (options currently left at their defaults: include_measurements=True, drop_missing_data=True,
#  include_diagnoses=True, drop_empty_dynamic=True)
dm_kwargs = dict(
    path_to_db=path_to_db,
    load=True,
    tokenizer="tabular",
    batch_size=opt.batch_size,
    max_seq_length=config.block_size,
    unk_freq_threshold=config.unk_freq_threshold,
    min_workers=1,
)
dm = FoundationalDataModule(**dm_kwargs)

vocab_size = dm.train_set.tokenizer.vocab_size

# Report the size of each split, then the vocabulary size
for split_label, split_attr in (("training", "train_set"), ("validation", "val_set"), ("test", "test_set")):
    print(f"{len(getattr(dm, split_attr))} {split_label} patients")
print(f"{vocab_size} vocab elements")

INFO:root:Loading Polars dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/archive/Version2/polars/
INFO:root:Using tokenizer tabular
INFO:root:Tokenzier created based on 74.82M tokens
INFO:root:Creating dataset
INFO:root:Creating hash map
Calculating chunk index splits : 27it [00:05,  4.56it/s]
INFO:root:Creating dataset
INFO:root:Creating hash map
Calculating chunk index splits : 2it [00:00,  8.99it/s]
INFO:root:Creating dataset
INFO:root:Creating hash map
Calculating chunk index splits : 2it [00:00, 13.05it/s]

466364 training patients
17841 validation patients
22034 test patients
184 vocab elements





In [17]:
import time

# Time how long it takes to pull the first batch out of a fresh training loader
start = time.time()   # starting time
batch = next(iter(dm.train_dataloader()))
print(f"batch loaded in {time.time()-start} seconds")

# List every entry of the batch together with its shape
for key, entry in batch.items():
    print(f"{key}".ljust(20) + f"{entry.shape}")


batch loaded in 5.673753976821899 seconds
tokens              torch.Size([16, 128])
ages                torch.Size([16, 128])
values              torch.Size([16, 128])
attention_mask      torch.Size([16, 128])


## Standardisation

This was performed automatically across measurements and tests in the dataloader. The standardisation statistics (bias and scale respectively) are given in the dictionary object. 

We define two mappings, `standardise` and `unstandardise`, to simplify notation later.

In [18]:
# NOTE(review): `standardise` and `unstandardise` are defined nowhere else in this notebook, and
# both definitions below are commented out. Later cells still call them: the BMI and
# diastolic_blood_pressure prompt cells call them directly, and the generation / "real data"
# report cells call `unstandardise` inside a try/except. Restore the two lambdas (or remove the
# callers) before expecting those cells to run on a fresh kernel.
# TODO confirm that `dm` still exposes `standardisation_dict` as (bias, scale) pairs per key.
# display(dm.standardisation_dict)

# standardise = lambda key, v: (v - dm.standardisation_dict[key][0]) / dm.standardisation_dict[key][1]
# unstandardise = lambda key, v: (v * dm.standardisation_dict[key][1]) + dm.standardisation_dict[key][0]

# print(standardise("bmi", 30))
# print(unstandardise("bmi", standardise("bmi", 20)))

## View the frequency of tokens in the extracted data

In [19]:
import polars as pl

# Allow Polars to print up to one row per vocabulary element (plus one) before truncating
pl.Config.set_tbl_rows(vocab_size + 1)

# Preview the first rows of the tokenizer's per-event count table
event_counts = dm.tokenizer._event_counts
display(event_counts.head())

EVENT,COUNT,FREQUENCY
str,u32,f64
"""UNK""",0,0.0
"""Plasma_N-termi…",23,3.0742e-07
"""SICKLE_CELL_DI…",123,2e-06
"""CYSTICFIBROSIS…",127,2e-06
"""SYSTEMIC_SCLER…",199,3e-06


In [21]:
# Diagnoses are written entirely in upper case, so every event name that changes when upper-cased
# is treated as a measurement. The list is needed to set the model configuration automatically below.
measurements_for_univariate_regression = []
for record in dm.tokenizer._event_counts["EVENT"]:
    if record != record.upper():
        measurements_for_univariate_regression.append(record)

# print(measurements_for_univariate_regression)
# print(dm.encode(measurements_for_univariate_regression))
# print(dm.decode([7,4,3,2]))

## Create models and train

In [22]:
models = []
m_names = []

# My development model ("Competing-Risk" is currently switched off)
for surv_layer in ("Single-Risk",):

    # Fresh configuration for this model
    config = DemoConfig()
    # Survival head layer to use
    config.SurvLayer = surv_layer
    # Encoded univariate measurements, each modelled with a Normal distribution
    encoded_measurements = dm.encode(measurements_for_univariate_regression)
    config.tokens_for_univariate_regression = encoded_measurements

    survival_model = SurvStreamGPTForCausalModelling(config, vocab_size).to(device)
    models.append(survival_model)
    m_names.append(f"SurvStreamGPTForCausalModelling: {surv_layer}")

INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using Single-Risk DeSurvival head. This module predicts a separate survival curve for each possible future event
INFO:root:Internally scaling time in survival head by 1825 days
INFO:root:In generation forwarding DeSurv on the grid between [0.0, 1825.0], with delta=1/300
INFO:root:ModuleDict(
  (Token 2): Linear(in_features=384, out_features=2, bias=True)
  (Token 12): Linear(in_features=384, out_features=2, bias=True)
  (Token 13): Linear(in_features=384, out_features=2, bias=True)
  (Token 15): Linear(in_features=384, out_features=2, bias=True)
  (Token 19): Linear(in_features=384, out_features=2, bias=True)
  (Token 33): Linear(in_features=384, out_features=2, bias=True)
  (Token 36): Linear(in_features=384, out_features=2, bias=True)
  (Token 44): Linear(in_features=384, out_features=2, bias=True)
  (Token 45): Linear(in_features=384, out_features=2, bias=

In [23]:
def _new_curve_store():
    """Return a list holding one independent, empty loss history per entry of `models`."""
    return [[] for _ in models]


# Per-model loss histories on the training set (total, and the `clf`, `surv` and `values` parts)
loss_curves_train = _new_curve_store()
loss_curves_train_clf = _new_curve_store()
loss_curves_train_surv = _new_curve_store()
loss_curves_train_values = _new_curve_store()

# The same histories on the validation set
loss_curves_val = _new_curve_store()
loss_curves_val_clf = _new_curve_store()
loss_curves_val_surv = _new_curve_store()
loss_curves_val_values = _new_curve_store()

In [24]:
# Caps that keep this demo quick: only the first batches of each loader are visited per epoch
MAX_TRAIN_BATCHES = 51
MAX_VAL_BATCHES = 21
# Stop once the validation loss has failed to improve for this many consecutive evaluations
PATIENCE = 5

for m_idx, (model, m_name) in enumerate(zip(models, m_names)):

    print(f"Training model `{m_name}`, with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
    model = model.to(device)

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

    best_val, epochs_since_best = np.inf, 0
    for epoch in range(opt.epochs):

        # ---------------- training ----------------
        epoch_loss, epoch_surv_loss, epoch_values_loss = 0, 0, 0
        n_train_batches = 0
        model.train()
        train_loader = dm.train_dataloader()
        for batch in tqdm(train_loader,
                          desc=f"Training epoch {epoch}",
                          total=min(len(train_loader), MAX_TRAIN_BATCHES)):

            # evaluate the loss
            _, (losses_desurv, loss_values), loss = model(batch['tokens'].to(device),
                                                          ages=batch['ages'].to(device),
                                                          values=batch['values'].to(device),
                                                          attention_mask=batch['attention_mask'].to(device)
                                                          )
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            # record
            epoch_loss += loss.item()
            epoch_surv_loss += torch.sum(losses_desurv).item()
            epoch_values_loss += loss_values.item()

            n_train_batches += 1
            if n_train_batches >= MAX_TRAIN_BATCHES:
                break

        # Average over the batches actually processed. Dividing by the last loop index is off by
        # one whenever the loader runs out before the cap, and divides by zero for a single batch.
        n_train_batches = max(n_train_batches, 1)
        epoch_loss /= n_train_batches
        epoch_surv_loss /= n_train_batches
        epoch_values_loss /= n_train_batches
        loss_curves_train[m_idx].append(epoch_loss)
        loss_curves_train_surv[m_idx].append(epoch_surv_loss)
        loss_curves_train_values[m_idx].append(epoch_values_loss)

        # ---------------- validation ----------------
        # evaluate the loss on val set
        with torch.no_grad():
            model.eval()
            if epoch % opt.eval_interval == 0 or epoch == opt.epochs - 1:
                val_loss, val_surv_loss, val_values_loss = 0, 0, 0
                n_val_batches = 0
                val_loader = dm.val_dataloader()
                # The progress bar total comes from the validation loader, not the training loader
                for batch in tqdm(val_loader,
                                  desc=f"Validation epoch {epoch}",
                                  total=min(len(val_loader), MAX_VAL_BATCHES)):
                    _, (losses_desurv, loss_values), loss = model(batch['tokens'].to(device),
                                                                  ages=batch['ages'].to(device),
                                                                  values=batch['values'].to(device),
                                                                  attention_mask=batch['attention_mask'].to(device)
                                                                  )
                    # record
                    val_loss += loss.item()
                    val_surv_loss += torch.sum(losses_desurv).item()
                    val_values_loss += loss_values.item()

                    n_val_batches += 1
                    if n_val_batches >= MAX_VAL_BATCHES:
                        break

                n_val_batches = max(n_val_batches, 1)
                val_loss /= n_val_batches
                val_surv_loss /= n_val_batches
                val_values_loss /= n_val_batches
                loss_curves_val[m_idx].append(val_loss)
                loss_curves_val_surv[m_idx].append(val_surv_loss)
                loss_curves_val_values[m_idx].append(val_values_loss)

                print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}: ({epoch_surv_loss:.2f}, {epoch_values_loss:.2f}). Val loss {val_loss:.2f}: ({val_surv_loss:.2f}, {val_values_loss:.2f})")
                # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning

                # Early stopping is only judged on epochs where a fresh validation loss exists;
                # otherwise a stale value from an earlier epoch would count against the patience.
                if val_loss >= best_val:
                    epochs_since_best += 1
                    if epochs_since_best >= PATIENCE:
                        break
                else:
                    best_val = val_loss
                    epochs_since_best = 0

    # Test trained model with a prompt
    # ----------------
    # set context: diagnosis of depression at 20 years old
    tokens = torch.from_numpy(np.array(dm.encode(["DEPRESSION"])).reshape((1,-1))).to(device)
    ages = torch.tensor([[20*365]], device=device)
    values = torch.tensor([[torch.nan]], device=device)

    # generate: sample the next 10 tokens
    new_tokens, new_ages, new_values = model.generate(tokens, ages, values, max_new_tokens=10)
    generated = dm.decode(new_tokens[0].tolist())
    # report:
    for _cat, _age, _value in zip(generated.split(" "), new_ages[0, :], new_values[0, :]):
        try:
            _value = unstandardise(_cat, _value)
        except Exception:
            # Best effort: keep the standardised value when it cannot be mapped back (for example
            # when `unstandardise` is not defined in this session). Ctrl-C is no longer swallowed.
            pass
        print(f"{_cat}".ljust(50) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({_age:.1f} days)")    # with value {_value}

Training model `SurvStreamGPTForCausalModelling: Single-Risk`, with 10.865304 M parameters


Training epoch 0:   0%|          | 51/29148 [02:18<21:56:50,  2.72s/it]
Validation epoch 0:   0%|          | 21/29148 [00:24<9:28:13,  1.17s/it]


Epoch 0:	Train loss 237.65: (4.42, 470.88). Val loss 1091136543880417.75: (3.02, 2182273087760832.75)


Training epoch 1:   0%|          | 51/29148 [02:19<22:10:05,  2.74s/it]
Validation epoch 1:   0%|          | 21/29148 [00:24<9:28:15,  1.17s/it]


Epoch 1:	Train loss 60.28: (4.29, 116.28). Val loss 869623810089096.38: (2.99, 1739247620178189.75)


Training epoch 2:   0%|          | 51/29148 [02:21<22:22:53,  2.77s/it]
Validation epoch 2:   0%|          | 21/29148 [00:24<9:27:52,  1.17s/it]


Epoch 2:	Train loss 681.18: (4.27, 1358.09). Val loss 702718261969965.38: (3.04, 1405436523939927.50)


Training epoch 3:   0%|          | 41/29148 [01:49<21:34:49,  2.67s/it]

KeyboardInterrupt



## Comparing output to real data

In [None]:
# Take the first batch from a fresh training loader
batch = next(iter(dm.train_dataloader()))

conditions = batch["tokens"].numpy().tolist()
# delta_ages = batch["ages"][:, 1:] - batch["ages"][:, :-1]
# Print the first (at most 10) events of the first patient in the batch, stopping at token 0
for idx, (token, _age, _value) in enumerate(zip(conditions[0], batch["ages"][0,:],  batch["values"][0,:])):
    if token == 0 or idx >= 10:
        break
    _cat = dm.decode([token])
    try:
        _value = unstandardise(_cat, _value)
    except Exception:
        # Best effort: keep the standardised value when it cannot be mapped back (for example
        # when `unstandardise` is not defined in this session). Ctrl-C is no longer swallowed.
        pass

    print(f"{_cat}".ljust(50) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({_age:.1f} days)")

In [None]:
cols = ["k", "r", "b", "y"]


def _plot_loss_curves(train_curves, val_curves, save_path):
    """Plot each model's training (dashed) and validation (solid) loss history and save the figure.

    Args:
        train_curves: one list per model; the training loop appends one value per epoch.
        val_curves: one list per model; the training loop appends one value per validation run,
            i.e. every `opt.eval_interval` epochs (plus the final epoch).
        save_path: file the figure is written to; missing parent directories are created.
    """
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    plt.figure()
    for m_idx, _ in enumerate(models):
        # Wrap around the palette rather than failing when there are more models than colours
        colour = cols[m_idx % len(cols)]

        # Training: one point per epoch, so the x position is simply the epoch index
        train_curve = train_curves[m_idx]
        plt.plot(np.arange(len(train_curve)), train_curve,
                 label=f"{m_names[m_idx]}-train", c=colour, linestyle='dashed')

        # Validation: the k-th recorded value belongs to epoch k * eval_interval
        val_curve = val_curves[m_idx]
        plt.plot(np.arange(len(val_curve)) * opt.eval_interval, val_curve,
                 label=f"{m_names[m_idx]}-val", c=colour)
    plt.xlabel("epoch")
    plt.legend()
    plt.savefig(save_path)


# Plot loss
_plot_loss_curves(loss_curves_train, loss_curves_val, "figs/single_risk/loss.png")

# Plot DeSurv loss
_plot_loss_curves(loss_curves_train_surv, loss_curves_val_surv, "figs/single_risk/loss_desurv.png")

# Plot values loss
_plot_loss_curves(loss_curves_train_values, loss_curves_val_values, "figs/single_risk/loss_val.png")

# Prompt testing

## Diabetes: How related conditions are impacted by each other
Probability of type II diabetes before and after a type I diagnosis

In [None]:
t1_token = dm.tokenizer._stoi["TYPE1DM"]
t2_token = dm.tokenizer._stoi["TYPE2DIABETES"]


base_prompt = ["DEPRESSION"]
ages_in_years = [20]
base_values = [torch.tensor([torch.nan])]


def to_days(a_list):
    """Convert a list of ages in years to a (1, len) float tensor of days on `device`."""
    return torch.FloatTensor([365 * _a for _a in a_list]).reshape((1,-1)).to(device)


# Create a set of prompts: the control, then the control followed by each diabetes diagnosis at 21
desc = [
    "Depression",
    "Depression -> Type 1",
    "Depression - > Type 2",
]
prompts = [
    base_prompt,
    base_prompt + ["TYPE1DM"],
    base_prompt + ["TYPE2DIABETES"],
]
ages = [
    ages_in_years,
    ages_in_years + [21],
    ages_in_years + [21],
]
values = [
    base_values,
    base_values + [torch.tensor([torch.nan])],
    base_values + [torch.tensor([torch.nan])],
]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad():
        model.eval()

        # Forward every prompt once and keep the survival output of each
        prompt_survs = []
        for label, prompt, age, value in zip(desc, prompts, ages, values):
            print(f"\n{label}: \t ({','.join(prompt)}): ")
            token_ids = np.array(dm.encode(prompt)).reshape((1,-1))
            encoded_prompt = torch.from_numpy(token_ids).to(device)
            model_out = model(encoded_prompt,
                              values=torch.tensor(value).to(device),
                              ages=to_days(age),
                              is_generation=True)
            (surv, val_dist), _, _ = model_out
            prompt_survs.append(surv)

        # One figure per entry of the survival output, overlaying all prompts
        for si, _ in enumerate(surv):
            plt.close()
            event_name = dm.decode([si + 1])
            for label, survs in zip(desc, prompt_survs):
                plt.plot(model.surv_layer.t_eval / 365, survs[si][0, :], label=f"{label}")
            plt.legend()
            plt.savefig(f"figs/single_risk/diabetes/{event_name}.png")


## Values: How increasing BMI affects diagnosis risk

In [None]:
events_of_interest = [
    "bmi", "diastolic_blood_pressure",
    "TYPE1DM", "TYPE2DIABETES",
    "HYPERTENSION", "OSTEOARTHRITIS",
    "CKDSTAGE3TO5",
    "HF", "ISCHAEMICSTROKE",
]
prompt = ["bmi"]

# One standardised value tensor per raw BMI level
bmi_levels = [12.,15.,18.,21.,24.,30.,40.]
values = []
for raw_bmi in bmi_levels:
    standardised = [standardise(_cat, raw_bmi) for _cat in prompt]
    values.append(torch.tensor(standardised, device=device))
age = [40]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad():
        model.eval()

        # Forward the same prompt once per BMI value and keep each survival output
        prompt_survs = []
        for value in values:
            print(f"Value {value}\n======")
            token_ids = np.array(dm.encode(prompt)).reshape((1,-1))
            encoded_prompt = torch.from_numpy(token_ids).to(device)
            model_out = model(encoded_prompt,
                              values=value,
                              ages=to_days(age),
                              is_generation=True)
            (surv, val_dist), _, _ = model_out
            prompt_survs.append(surv)

        # Save one figure per event of interest, overlaying the curves for all BMI values
        for si, _ in enumerate(surv):
            plt.close()
            event_name = dm.decode([si + 1])
            if event_name not in events_of_interest:
                continue

            for value, survs in zip(values, prompt_survs):
                bmi_value = unstandardise("bmi", value)
                plt.plot(model.surv_layer.t_eval / 365, survs[si][0, :], label=f"BMI {bmi_value.item():.2f}")
            plt.xlabel("t (years)")
            plt.ylabel("P(T>t) ()")
            plt.legend()
            plt.savefig(f"figs/single_risk/bmi/{event_name}.png")


## Values: How increasing diastolic_blood_pressure affects likelihood of diagnoses

In [None]:
events_of_interest = [
    "bmi", "diastolic_blood_pressure",
    "TYPE1DM", "TYPE2DIABETES",
    "HYPERTENSION", "OSTEOARTHRITIS",
    "CKDSTAGE3TO5",
    "HF", "ISCHAEMICSTROKE",
]

prompt = ["diastolic_blood_pressure"]

# One standardised value tensor per raw diastolic blood pressure level
dbp_levels = [60.,70.,80.,90.,100.,120.]
values = []
for raw_dbp in dbp_levels:
    standardised = [standardise(_cat, raw_dbp) for _cat in prompt]
    values.append(torch.tensor(standardised, device=device))
age = [40]


for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad():
        model.eval()

        # Forward the same prompt once per pressure value and keep each survival output
        prompt_survs = []
        for value in values:
            print(f"Value {value}\n======")
            token_ids = np.array(dm.encode(prompt)).reshape((1,-1))
            encoded_prompt = torch.from_numpy(token_ids).to(device)
            model_out = model(encoded_prompt,
                              values=value,
                              ages=to_days(age),
                              is_generation=True)
            (surv, val_dist), _, _ = model_out
            prompt_survs.append(surv)

        # Save one figure per event of interest, overlaying the curves for all pressure values
        for si, _ in enumerate(surv):
            plt.close()
            event_name = dm.decode([si + 1])
            if event_name not in events_of_interest:
                continue

            for value, survs in zip(values, prompt_survs):
                dbp_value = unstandardise("diastolic_blood_pressure", value)
                plt.plot(model.surv_layer.t_eval / 365, survs[si][0, :], label=f"DBP {dbp_value.item():.2f}")
            plt.xlabel("t (years)")
            plt.ylabel("P(T>t) ()")
            plt.legend()
            plt.savefig(f"figs/single_risk/diastolic_blood_pressure/{event_name}.png")


## Values: How varying diagnosis affects value of diastolic_blood_pressure

In [None]:
measurements_of_interest = ["diastolic_blood_pressure"]
# NOTE: despite its name, `t1_token` holds the token id of diastolic_blood_pressure in this cell
t1_token = dm.tokenizer._stoi["diastolic_blood_pressure"]

diagnoses = [["DEPRESSION"],["TYPE2DIABETES"], ["HF"], ["HYPERTENSION"]]
values = torch.tensor([torch.nan], device=device)
age = [40]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    # The prompt loop runs inside the no_grad block so that pure inference does not build
    # autograd graphs (it previously sat one indentation level outside of it).
    with torch.no_grad():
        model.eval()

        for p_idx, diagnosis in enumerate(diagnoses):
            print(f"\nDiagnosis {diagnosis}\n======")
            encoded_prompt = torch.from_numpy(np.array(dm.encode(diagnosis)).reshape((1,-1))).to(device)
            (surv, val_dist), _, _ = model(encoded_prompt,
                                           values=values,
                                           ages=to_days(age),
                                           is_generation=True)
            # Predicted value distribution for the diastolic_blood_pressure token
            dist = val_dist[model.value_layer.token_key(t1_token)]
            print(f"standardised diastolic_blood_pressure ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")


## Values: How increasing bmi affects value of diastolic_blood_pressure

In [None]:
# NOTE: despite its name, `t1_token` holds the token id of diastolic_blood_pressure in this cell
t1_token = dm.tokenizer._stoi["diastolic_blood_pressure"]

prompt = ["bmi"]
values = [torch.tensor([standardise(_cat, _value) for _cat in prompt], device=device) for _value in [12.,15.,18.,21.,24.,30.,40.,50.]]
age = [40]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")

    # Match the other prompt cells: evaluation mode and no autograd bookkeeping during inference
    with torch.no_grad():
        model.eval()

        for p_idx, value in enumerate(values):
            print(f"\nValues {value.tolist()}\n======")
            encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
            (surv, val_dist), _, _ = model(encoded_prompt,
                                           values=value,
                                           ages=to_days(age),
                                           is_generation=True)

            # Predicted value distribution for the diastolic_blood_pressure token
            dist = val_dist[model.value_layer.token_key(t1_token)]
            print(f"standardised diastolic_blood_pressure ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")
            # print(f"\tprobability of type II diabetes: {100*float(probs[0, 0, t2_token].cpu().detach().numpy()):.4f}%")

# Appendix: model architectures

In [None]:
# Print each model's name, underlined to the same width, followed by its module tree
for m_name, model in zip(m_names, models):
    underline = "=" * len(m_name)
    print(f"\n\n{m_name}\n" + underline)
    print(f"\n\n{model}")

In [None]:
# Shell command: export `single_risk.ipynb` to HTML with the code inputs left out (`--no-input`).
!jupyter nbconvert --to html --no-input single_risk.ipynb