# Demo Notebook:
## Time to Event Transformer For Causal Time Series Modelling 

Including time and tabular values

In [1]:
import pytorch_lightning
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.TTE.task_heads.causal_tabular import TTETransformerForCausalTimeSeriesModelling

# TODO:
# replace experiment boilerplate with pytorch lightning

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# device = "cpu"    # if more informative debugging statements are needed
!pwd

cuda
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling


## Build configurations

In [2]:
# Set config to be equivalent architecture of kaparthy benchmark, however they are not comparable tasks.
@dataclass
class DemoConfig:
    block_size: int = 128        # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"    
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0
    TTELayer = "Exponential"                                  # "Geometric"
    tokens_for_univariate_regression = None

config = DemoConfig()

@dataclass
class OptConfig:
    batch_size: int = 64
    eval_interval: int = 1
    learning_rate: float = 3e-4
    epochs: int = 50
    
opt = OptConfig()

## Create data loader on a reduced cohort

In [3]:
from CPRD.data.database import queries

# Get a list of patients which fit a reduced set of criterion
PATH_TO_DB = "/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing/processed/cprd.db"
conn = sqlite3.connect(PATH_TO_DB)
cursor = conn.cursor()
# identifiers1 = queries.query_measurement(["bmi", "diastolic_blood_pressure"], cursor)        
identifiers2 = queries.query_diagnosis(["DEPRESSION", "TYPE1DM", "TYPE2DIABETES"], cursor)    #  "DEPRESSION"  ,  "ANXIETY"
# all_identifiers = list(set(identifiers1).intersection(identifiers2))    # Turn smaller list into the set
all_identifiers = identifiers2

if False:
    # Lets take only the first N for faster run-time
    N = np.min((len(all_identifiers), 10000))
    print(f"Using N={N} random samples, from the available {len(all_identifiers)}")
    identifiers = random.choices(all_identifiers, k=N)
else:
    print(f"Using all available {len(all_identifiers)} samples")
    identifiers = all_identifiers

# Build 
dm = FoundationalDataModule(identifiers=identifiers,
                            tokenizer="tabular",
                            batch_size=opt.batch_size,
                            max_seq_length=config.block_size,
                            unk_freq_threshold=config.unk_freq_threshold,
                            include_measurements=True,
                            include_diagnoses=True,
                            preprocess_measurements=True
                           )


vocab_size = dm.train_set.tokenizer.vocab_size

print(f"{len(dm.train_set)} training samples")
print(f"{len(dm.val_set)} validation samples")
print(f"{len(dm.test_set)} test samples")
print(f"{vocab_size} vocab elements")
# print(dm.train_set.tokenizer._itos)

INFO:root:Building polars dataset


Using all available 129717 samples


INFO:root:Using measurements
INFO:root:Using test/measurement standardisation method: normalise
INFO:root:Removing measurement and test outliers. Using three deviations from mean as cutoff
INFO:root:Using diagnoses
INFO:root:Dropping samples with no dynamic events
INFO:root:Using tabular tokenizer


107683 training samples
5983 validation samples
5982 test samples
90 vocab elements


## Standardisation

This was performed automatically across measurements and tests in the dataloader. The standardisation statistics (bias and scale respectively) are given in the dictionary object. 

We define two mappings to simplify notation later

In [4]:
display(dm.standardisation_dict)

standardise = lambda key, v: (v - dm.standardisation_dict[key][0]) / dm.standardisation_dict[key][1]
unstandardise = lambda key, v: (v * dm.standardisation_dict[key][1]) + dm.standardisation_dict[key][0]

print(standardise("bmi", 30))
print(unstandardise("bmi", standardise("bmi", 20)))

{'corrected_serum_calcium_level': (2.319034923472716, 0.12451622974555095),
 'blood_calcium': (2.334465408805031, 0.1487655865623607),
 'serum_level': (27.19654609350041, 20.584990688686364),
 'aspartate_transam': (26.791031390134528, 18.877776290025214),
 'bmi': (29.629965163503474, 7.013281083178253),
 'hydroxyvitamin2': (3.268025477707002, 2.8054309314024435),
 'creatinine_ratio': (4.435476477683948, 8.185329269651573),
 'diastolic_blood_pressure': (78.86937661562213, 11.727257179342669),
 'hydroxyvitamin3': (52.36982317356912, 30.382475290251843),
 'basophil_count': (0.06984214216499786, 0.10630741873638763),
 'combined_total_vitamin_D2_and_D3_level': (56.94353322028673,
  29.241759267841857),
 'calculated_LDL_cholesterol_level': (2.5891371173802233, 1.0358897687206567),
 'blood_urea': (6.702299445123703, 4.281261339020057),
 'brain_natriuretic_peptide_level': (156.9857534246576, 291.8915253494199),
 'calcium_adjusted_level': (2.315393830170679, 0.10855075164585075),
 'eosinophil_c

0.052762014256647304
20.0


## View the frequency of tokens in the extracted data

In [5]:
import polars as pl
pl.Config.set_tbl_rows(vocab_size + 1)
print(dm.tokenizer._event_counts)

shape: (89, 3)
┌───────────────────────────────────┬─────────┬──────────┐
│ EVENT                             ┆ counts  ┆ freq     │
│ ---                               ┆ ---     ┆ ---      │
│ str                               ┆ u32     ┆ f64      │
╞═══════════════════════════════════╪═════════╪══════════╡
│ UNK                               ┆ 0       ┆ 0.0      │
│ diastolic_blood_pressure          ┆ 2225753 ┆ 0.422385 │
│ bmi                               ┆ 864405  ┆ 0.16404  │
│ eosinophil_count                  ┆ 812964  ┆ 0.154278 │
│ basophil_count                    ┆ 533086  ┆ 0.101165 │
│ corrected_serum_calcium_level     ┆ 149024  ┆ 0.028281 │
│ DEPRESSION                        ┆ 86594   ┆ 0.016433 │
│ serum_level                       ┆ 76803   ┆ 0.014575 │
│ calculated_LDL_cholesterol_level  ┆ 69057   ┆ 0.013105 │
│ ANXIETY                           ┆ 48189   ┆ 0.009145 │
│ HYPERTENSION                      ┆ 32889   ┆ 0.006241 │
│ TYPE2DIABETES                     ┆ 272

In [6]:
# Extract the measurements, using the fact that the diagnoses are all up upper case. This is needed for automatically setting the configuration below
measurements_for_univariate_regression = [record for record in dm.tokenizer._event_counts["EVENT"] if record.upper() != record]

print(measurements_for_univariate_regression)
print(dm.encode(measurements_for_univariate_regression))
print(dm.decode([7,4,3,2]))

['diastolic_blood_pressure', 'bmi', 'eosinophil_count', 'basophil_count', 'corrected_serum_calcium_level', 'serum_level', 'calculated_LDL_cholesterol_level', 'aspartate_transam', 'blood_urea', 'calcium_adjusted_level', 'combined_total_vitamin_D2_and_D3_level', 'hydroxyvitamin3', 'hydroxyvitamin2', 'creatinine_ratio', 'brain_natriuretic_peptide_level', 'blood_calcium']
[2, 3, 4, 5, 6, 8, 9, 17, 24, 26, 30, 39, 40, 57, 71, 83]
DEPRESSION eosinophil_count bmi diastolic_blood_pressure


## Create models and train

In [7]:
models, m_names = [], []

# My development model
for tte_layer in ["Exponential"]: #, "Geometric"]:
    
    ## Create configuration
    config = DemoConfig()
    # Specify which TTE layer to use
    config.TTELayer = tte_layer    
    # list of univariate measurements to model with Normal distribution
    config.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression) 
    
    models.append(TTETransformerForCausalTimeSeriesModelling(config, vocab_size).to(device))
    m_names.append(f"TTETransformerForCausalTimeSeriesModelling: {tte_layer} TTE")

INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using ExponentialTTELayer. This module predicts the time until next event as an exponential distribution
INFO:root:ModuleDict(
  (Token 2): Linear(in_features=384, out_features=2, bias=True)
  (Token 3): Linear(in_features=384, out_features=2, bias=True)
  (Token 4): Linear(in_features=384, out_features=2, bias=True)
  (Token 5): Linear(in_features=384, out_features=2, bias=True)
  (Token 6): Linear(in_features=384, out_features=2, bias=True)
  (Token 8): Linear(in_features=384, out_features=2, bias=True)
  (Token 9): Linear(in_features=384, out_features=2, bias=True)
  (Token 17): Linear(in_features=384, out_features=2, bias=True)
  (Token 24): Linear(in_features=384, out_features=2, bias=True)
  (Token 26): Linear(in_features=384, out_features=2, bias=True)
  (Token 30): Linear(in_features=384, out_features=2, bias=True)
  (Token 39): Linear(in_features=384

In [8]:
loss_curves_train = [[] for _ in models]
loss_curves_train_clf = [[] for _ in models]
loss_curves_train_tte = [[] for _ in models]
loss_curves_train_values = [[] for _ in models]

loss_curves_val = [[] for _ in models]
loss_curves_val_clf = [[] for _ in models]
loss_curves_val_tte = [[] for _ in models]
loss_curves_val_values = [[] for _ in models]

In [9]:
for m_idx, (model, m_name) in enumerate(zip(models, m_names)):
    
    print(f"Training model `{m_name}`, with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
    model = model.to(device)

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

    best_val, epochs_since_best = np.inf, 0
    for epoch in range(opt.epochs):
        epoch_loss, epoch_clf_loss, epoch_tte_loss, epoch_values_loss = 0, 0, 0, 0
        model.train()
        for i, batch in enumerate(dm.train_dataloader()):
            # evaluate the loss
            _, (loss_clf, loss_tte, loss_values), loss = model(batch['tokens'].to(device), 
                                                               ages=batch['ages'].to(device), 
                                                               values=batch['values'].to(device),
                                                               attention_mask=batch['attention_mask'].to(device)   
                                                               )
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            
            # record
            epoch_clf_loss += loss_clf.item()
            epoch_tte_loss += loss_tte.item()
            epoch_values_loss += loss_values.item()
        epoch_loss /= i
        epoch_clf_loss /= i
        epoch_tte_loss /= i
        epoch_values_loss /= i
        loss_curves_train[m_idx].append(epoch_loss)
        loss_curves_train_clf[m_idx].append(epoch_clf_loss)
        loss_curves_train_tte[m_idx].append(epoch_tte_loss)
        loss_curves_train_values[m_idx].append(epoch_values_loss)

        # evaluate the loss on val set
        with torch.no_grad(): 
            model.eval()
            if epoch % opt.eval_interval == 0 or epoch == opt.epochs - 1:
                val_loss, val_clf_loss, val_tte_loss, val_values_loss = 0, 0, 0, 0
                for j, batch in enumerate(dm.val_dataloader()):
                    _, (loss_clf, loss_tte, loss_values), loss = model(batch['tokens'].to(device), 
                                                                       ages=batch['ages'].to(device),
                                                                       values=batch['values'].to(device),
                                                                       attention_mask=batch['attention_mask'].to(device)   
                                                                       )
                    val_loss += loss.item()
                    
                    # record
                    val_clf_loss += loss_clf.item()
                    val_tte_loss += loss_tte.item()
                    val_values_loss += loss_values.item()
                val_loss /= j
                val_clf_loss /= j
                val_tte_loss /= j
                val_values_loss /= j
                loss_curves_val[m_idx].append(val_loss)
                loss_curves_val_clf[m_idx].append(val_clf_loss)
                loss_curves_val_tte[m_idx].append(val_tte_loss)
                loss_curves_val_values[m_idx].append(val_values_loss)

                print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}  ({epoch_clf_loss:.2f}, {epoch_tte_loss:.2f}, {epoch_values_loss:.2f}). Val loss {val_loss:.2f} ({val_clf_loss:.2f}, {val_tte_loss:.2f}, {val_values_loss:.2f})")          
                # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning
        
            if val_loss >= best_val:
                epochs_since_best += 1
                if epochs_since_best >= 5:
                    break
            else:
                best_val = val_loss
                epochs_since_best = 0

    # Test trained model with a prompt
    # ----------------    
    # set context: diagnosis of depression at 20 years old
    tokens = torch.from_numpy(np.array(dm.encode(["DEPRESSION"])).reshape((1,-1))).to(device)
    ages = torch.tensor([[20*365]], device=device)
    values = torch.tensor([[torch.nan]], device=device)
    
    # generate: sample the next 10 tokens
    new_tokens, new_ages, new_values = model.generate(tokens, ages, values, max_new_tokens=10)
    generated = dm.decode(new_tokens[0].tolist())
    # report:
    for _cat, _age, _value in zip(generated.split(" "), new_ages[0, :], new_values[0, :]):
        try:
            _value = unstandardise(_cat, _value)
        except:
            pass
        print(f"\t {_cat}:{_value:.02f}, at age {_age/365:.0f} ({_age:.1f} days)")    # with value {_value}

Training model `TTETransformerForCausalTimeSeriesModelling: Exponential TTE`, with 10.757217 M parameters
Epoch 0:	Train loss 1.87  (1.71, -0.74, 4.65). Val loss 1.83 (1.70, -0.80, 4.59)
Epoch 1:	Train loss 1.68  (1.64, -0.85, 4.24). Val loss 1.63 (1.64, -0.91, 4.16)
Epoch 2:	Train loss 1.66  (1.63, -0.88, 4.24). Val loss 1.60 (1.64, -0.93, 4.09)
Epoch 3:	Train loss 1.63  (1.61, -0.92, 4.18). Val loss 1.63 (1.63, -0.94, 4.20)
Epoch 4:	Train loss 1.58  (1.60, -0.93, 4.07). Val loss 1.57 (1.62, -0.96, 4.06)
Epoch 5:	Train loss 1.56  (1.59, -0.94, 4.04). Val loss 1.54 (1.61, -0.97, 3.99)
Epoch 6:	Train loss 1.55  (1.58, -0.95, 4.01). Val loss 1.52 (1.60, -0.98, 3.95)
Epoch 7:	Train loss 1.55  (1.58, -0.95, 4.01). Val loss 1.60 (1.61, -0.98, 4.17)
Epoch 8:	Train loss 1.52  (1.58, -0.97, 3.95). Val loss 1.54 (1.60, -0.99, 4.01)
Epoch 9:	Train loss 1.52  (1.58, -0.97, 3.94). Val loss 1.52 (1.59, -1.00, 3.96)
Epoch 10:	Train loss 1.50  (1.57, -0.98, 3.89). Val loss 1.51 (1.59, -1.00, 3.94)
Ep

## Comparing output to real data

In [10]:
for batch in dm.train_dataloader():
    break
conditions = batch["tokens"].numpy().tolist()
# delta_ages = batch["ages"][:, 1:] - batch["ages"][:, :-1]
for idx, (token, _age, _value) in enumerate(zip(conditions[0], batch["ages"][0,:],  batch["values"][0,:])):
    if token == 0 or idx >= 10:
        break
    _cat = dm.decode([token])
    try:
        _value = unstandardise(_cat, _value)
    except:
        pass
        
    print(f"{_cat}:{_value:.02f}, at age {_age/365:.0f} ({_age:.1f} days)")

bmi:22.90, at age 24 (8791.0 days)
diastolic_blood_pressure:80.00, at age 24 (8791.0 days)
DEPRESSION:nan, at age 29 (10601.0 days)
ANXIETY:nan, at age 29 (10692.0 days)
eosinophil_count:0.10, at age 34 (12259.0 days)
diastolic_blood_pressure:70.00, at age 36 (13122.0 days)


In [11]:
cols = ["k", "r", "b", "y"]

# Plot loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train[m_idx]), len(loss_curves_train[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val[m_idx]), len(loss_curves_val[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.savefig(f"figs/TTE_tab/loss.png")

# Plot classifier loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train_clf[m_idx]), len(loss_curves_train_clf[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train_clf[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val_clf[m_idx]), len(loss_curves_val_clf[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val_clf[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.savefig(f"figs/TTE_tab/loss_clf.png")

# Plot tte loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train_tte[m_idx]), len(loss_curves_train_tte[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train_tte[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val_tte[m_idx]), len(loss_curves_val_tte[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val_tte[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.savefig(f"figs/TTE_tab/loss_tte.png")

# Plot values loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train_values[m_idx]), len(loss_curves_train_values[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train_values[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val_values[m_idx]), len(loss_curves_val_values[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val_values[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.savefig(f"figs/TTE_tab/loss_val.png")

# Prompt testing

## Diabetes: How related conditions are impacted by each other
Probability of type II diabetes before and after a type I diagnosis

In [12]:
t1_token = dm.tokenizer._stoi["TYPE1DM"]
t2_token = dm.tokenizer._stoi["TYPE2DIABETES"]


base_prompt = ["DEPRESSION"]
ages_in_years = [20]
base_values = [torch.tensor([torch.nan])]

to_days = lambda a_list: torch.FloatTensor([365 * _a for _a in a_list]).reshape((1,-1)).to(device)

# Create a set of prompts
prompts, ages, values, desc = [], [], [], []
# control prompt
desc.append("Control")
prompts.append(base_prompt)
ages.append(ages_in_years)
values.append(base_values)
# prompt with type 1 diabetes
desc.append("Type 1")
prompts.append(base_prompt + ["TYPE1DM"])
ages.append(ages_in_years + [21])
values.append(base_values + [torch.tensor([torch.nan])])

desc.append("Type 2")
prompts.append(base_prompt + ["TYPE2DIABETES"])
ages.append(ages_in_years + [21])
values.append(base_values + [torch.tensor([torch.nan])])

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad(): 
        model.eval()

        for p_idx, (prompt, age, value) in enumerate(zip(prompts, ages, values)):
            print(f"\n{desc[p_idx]}: \t ({','.join(prompt)}): ")
            encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
            (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                                     values=torch.tensor(value).to(device),
                                                     ages=to_days(age),
                                                     is_generation=True)
            probs = torch.nn.functional.softmax(lgts, dim=2)
            print(f"\tprobability of type I diabetes: {100*float(probs[0, 0, t1_token].cpu().detach().numpy()):.4f}%")
            print(f"\tprobability of type II diabetes: {100*float(probs[0, 0, t2_token].cpu().detach().numpy()):.4f}%")

# Note: adding a diagnosis (even if potentially orthogonal) at the beginning of the prompt increases probability of either type



TTETransformerForCausalTimeSeriesModelling: Exponential TTE
--------------------------------------

Control: 	 (DEPRESSION): 
	probability of type I diabetes: 0.0461%
	probability of type II diabetes: 0.0275%

Type 1: 	 (DEPRESSION,TYPE1DM): 
	probability of type I diabetes: 0.2617%
	probability of type II diabetes: 0.1534%

Type 2: 	 (DEPRESSION,TYPE2DIABETES): 
	probability of type I diabetes: 0.4178%
	probability of type II diabetes: 0.1901%


## Values: How increasing BMI affects likelihood of diagnoses

In [18]:
events_of_interest = ["bmi", "diastolic_blood_pressure", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF", "ISCHAEMICSTROKE"
                     ]

prompt = ["bmi"]
values = [torch.tensor([standardise(_cat, v) for _cat in prompt], device=device) for v in [12.,15.,18.,21.,24.,30.,40.]]
age = [40]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")

    # for condition in target_conditions:
    #     print(f"Probability of {condition}")
    #     target_token = dm.tokenizer._stoi[condition]

    for p_idx, value in enumerate(values):
        print(f"Value {value}\n======")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
        (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                                  values=value,
                                                 ages=to_days(age),
                                                 is_generation=True)
        probs = torch.nn.functional.softmax(lgts, dim=2) * 100
        
        topk_prob, topk_ind = torch.sort(probs[0,0,:], descending=True)
        for i, j in zip(dm.decode(topk_ind.tolist()).split(" "), topk_prob):
            if i in events_of_interest:
                print(f"\t{i}: {j:.2f}%")




TTETransformerForCausalTimeSeriesModelling: Exponential TTE
--------------------------------------
Value tensor([-2.5138], device='cuda:0')
	diastolic_blood_pressure: 63.54%
	bmi: 13.19%
	TYPE1DM: 0.17%
	OSTEOARTHRITIS: 0.16%
	HYPERTENSION: 0.07%
	TYPE2DIABETES: 0.04%
	HF: 0.01%
	CKDSTAGE3TO5: 0.01%
	ISCHAEMICSTROKE: 0.00%
Value tensor([-2.0860], device='cuda:0')
	diastolic_blood_pressure: 70.36%
	bmi: 10.19%
	OSTEOARTHRITIS: 0.12%
	TYPE1DM: 0.11%
	HYPERTENSION: 0.05%
	TYPE2DIABETES: 0.03%
	HF: 0.00%
	CKDSTAGE3TO5: 0.00%
	ISCHAEMICSTROKE: 0.00%
Value tensor([-1.6583], device='cuda:0')
	diastolic_blood_pressure: 75.80%
	bmi: 7.91%
	OSTEOARTHRITIS: 0.10%
	TYPE1DM: 0.07%
	HYPERTENSION: 0.04%
	TYPE2DIABETES: 0.03%
	HF: 0.00%
	CKDSTAGE3TO5: 0.00%
	ISCHAEMICSTROKE: 0.00%
Value tensor([-1.2305], device='cuda:0')
	diastolic_blood_pressure: 79.70%
	bmi: 6.39%
	OSTEOARTHRITIS: 0.09%
	TYPE2DIABETES: 0.05%
	HYPERTENSION: 0.04%
	TYPE1DM: 0.04%
	HF: 0.00%
	CKDSTAGE3TO5: 0.00%
	ISCHAEMICSTROKE: 0.0

## Values: How increasing diastolic_blood_pressure affects likelihood of diagnoses

In [14]:
events_of_interest = ["bmi", "diastolic_blood_pressure", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF", "ISCHAEMICSTROKE"
                     ]

prompt = ["diastolic_blood_pressure"]
values = [torch.tensor([standardise(_cat, _value) for _cat in prompt], device=device) for _value in [60.,70.,80.,90.,100.,120.]]
age = [40]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")

    # for condition in target_conditions:
    #     print(f"Probability of {condition}")
    #     target_token = dm.tokenizer._stoi[condition]

    for p_idx, value in enumerate(values):
        print(f"Value {value}\n======")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
        (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                                  values=value,
                                                 ages=to_days(age),
                                                 is_generation=True)
        probs = torch.nn.functional.softmax(lgts, dim=2) * 100
        
        topk_prob, topk_ind = torch.sort(probs[0,0,:], descending=True)
        for i, j in zip(dm.decode(topk_ind.tolist()).split(" "), topk_prob):
            if i in events_of_interest:
                print(f"\t{i}: {j:.2f}%")




TTETransformerForCausalTimeSeriesModelling: Exponential TTE
--------------------------------------
Value tensor([-1.6090], device='cuda:0')
	diastolic_blood_pressure: 49.06%
	bmi: 13.46%
	OSTEOARTHRITIS: 0.94%
	TYPE2DIABETES: 0.51%
	HYPERTENSION: 0.40%
	TYPE1DM: 0.15%
	HF: 0.04%
	CKDSTAGE3TO5: 0.04%
	ISCHAEMICSTROKE: 0.02%
Value tensor([-0.7563], device='cuda:0')
	diastolic_blood_pressure: 49.79%
	bmi: 13.67%
	OSTEOARTHRITIS: 0.89%
	HYPERTENSION: 0.55%
	TYPE2DIABETES: 0.50%
	TYPE1DM: 0.14%
	HF: 0.03%
	CKDSTAGE3TO5: 0.02%
	ISCHAEMICSTROKE: 0.02%
Value tensor([0.0964], device='cuda:0')
	diastolic_blood_pressure: 51.23%
	bmi: 13.49%
	HYPERTENSION: 1.20%
	OSTEOARTHRITIS: 0.97%
	TYPE2DIABETES: 0.69%
	TYPE1DM: 0.13%
	HF: 0.03%
	CKDSTAGE3TO5: 0.02%
	ISCHAEMICSTROKE: 0.02%
Value tensor([0.9491], device='cuda:0')
	diastolic_blood_pressure: 54.53%
	bmi: 11.45%
	HYPERTENSION: 4.11%
	OSTEOARTHRITIS: 1.10%
	TYPE2DIABETES: 1.01%
	TYPE1DM: 0.11%
	HF: 0.04%
	CKDSTAGE3TO5: 0.03%
	ISCHAEMICSTROKE: 0.0

## Values: How varying diagnosis affects value of diastolic_blood_pressure

In [15]:
t1_token = dm.tokenizer._stoi["diastolic_blood_pressure"]

diagnoses = [["DEPRESSION"],["TYPE2DIABETES"], ["HF"], ["HYPERTENSION"]]
values = torch.tensor([torch.nan], device=device)
age = [39]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")

    # for condition in target_conditions:
    #     print(f"Probability of {condition}")
    #     target_token = dm.tokenizer._stoi[condition]

    for p_idx, diagnosis in enumerate(diagnoses):
        print(f"\nDiagnosis {diagnosis}\n======")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(diagnosis)).reshape((1,-1))).to(device)
        (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                                  values=values,
                                                 ages=to_days(age),
                                                 is_generation=True)
        dist = val_dist[model.value_layer.token_key(t1_token)]
        print(f"standardised diastolic_blood_pressure ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")
        # print(f"\tprobability of type II diabetes: {100*float(probs[0, 0, t2_token].cpu().detach().numpy()):.4f}%")





TTETransformerForCausalTimeSeriesModelling: Exponential TTE
--------------------------------------

Diagnosis ['DEPRESSION']
standardised diastolic_blood_pressure ~ N(0.1, 0.9)

Diagnosis ['TYPE2DIABETES']
standardised diastolic_blood_pressure ~ N(0.2, 0.9)

Diagnosis ['HF']
standardised diastolic_blood_pressure ~ N(0.3, 1.0)

Diagnosis ['HYPERTENSION']
standardised diastolic_blood_pressure ~ N(0.5, 1.1)


## Values: How increasing bmi affects value of diastolic_blood_pressure

In [19]:
t1_token = dm.tokenizer._stoi["diastolic_blood_pressure"]

prompt = ["bmi"]
values = [torch.tensor([standardise(_cat, _value) for _cat in prompt], device=device) for _value in [12.,15.,18.,21.,24.,30.,40.,60.]]
age = [40]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")

    # for condition in target_conditions:
    #     print(f"Probability of {condition}")
    #     target_token = dm.tokenizer._stoi[condition]

    for p_idx, value in enumerate(values):
        print(f"Values {value.tolist()}\n======")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
        (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                                  values=value,
                                                 ages=to_days(age),
                                                 is_generation=True)
        
        dist = val_dist[model.value_layer.token_key(t1_token)]
        print(f"standardised diastolic_blood_pressure ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")
        # print(f"\tprobability of type II diabetes: {100*float(probs[0, 0, t2_token].cpu().detach().numpy()):.4f}%")



TTETransformerForCausalTimeSeriesModelling: Exponential TTE
--------------------------------------
Values [-2.5137970447540283]
standardised diastolic_blood_pressure ~ N(-0.8, 0.9)
Values [-2.0860371589660645]
standardised diastolic_blood_pressure ~ N(-0.7, 0.9)
Values [-1.6582773923873901]
standardised diastolic_blood_pressure ~ N(-0.6, 0.8)
Values [-1.2305175065994263]
standardised diastolic_blood_pressure ~ N(-0.5, 0.8)
Values [-0.8027576804161072]
standardised diastolic_blood_pressure ~ N(-0.3, 0.8)
Values [0.05276201292872429]
standardised diastolic_blood_pressure ~ N(-0.0, 0.9)
Values [1.478628158569336]
standardised diastolic_blood_pressure ~ N(0.4, 1.0)
Values [4.330360412597656]
standardised diastolic_blood_pressure ~ N(0.1, 1.1)


# Appendix: model architectures

In [17]:
for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n" + "="*len(m_names[model_idx]))
    print(f"\n\n{model}")



TTETransformerForCausalTimeSeriesModelling: Exponential TTE


TTETransformerForCausalTimeSeriesModelling(
  (transformer): TTETransformer(
    (wpe): TemporalPositionalEncoding()
    (wte): DataEmbeddingLayer(
      (token_embed_layer): Embedding(90, 384, padding_idx=0)
      (value_embed_layer): EmbeddingBag(90, 384, mode=sum, padding_idx=0)
    )
    (drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadedSelfAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=384, out_features=384, bias=False)
          (v_proj): Linear(in_features=384, out_features=384, bias=False)
          (q_proj): Linear(in_features=384, out_features=384, bias=False)
          (out_proj): Linear(in_features=384, out_features=384, bias=True)
        )
        (ln_2): LayerNorm((38