# Demo Notebook:
## Time to Event Transformer For Causal Time Series Modelling 

Including time (and TODO: values)

In [1]:
import pytorch_lightning 
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.TTE.task_heads.causal_tabular import TTETransformerForCausalTimeSeriesModelling

# TODO:
# replace experiment boilerplate with pytorch lightning

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# device = "cpu"    # if more informative debugging statements are needed
!pwd

cuda
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling


## Build configurations

In [2]:
# Set config to be equivalent architecture of kaparthy benchmark, however they are not comparable tasks.
@dataclass
class DemoConfig:
    block_size: int = 8   #256        # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"    
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0
    TTELayer = "Exponential"                                  # "Geometric"
    tokens_for_univariate_regression = None

config = DemoConfig()

@dataclass
class OptConfig:
    batch_size: int = 64
    eval_interval: int = 1
    learning_rate: float = 3e-4
    epochs: int = 50
    
opt = OptConfig()

## Create data loader on a reduced cohort

In [3]:
from CPRD.data.database import queries

# Get a list of patients which fit a reduced set of criterion
PATH_TO_DB = "/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing/processed/cprd.db"
conn = sqlite3.connect(PATH_TO_DB)
cursor = conn.cursor()
# identifiers1 = queries.query_measurement(["bmi", "diastolic_blood_pressure"], cursor)        
identifiers2 = queries.query_diagnosis(["DEPRESSION", "TYPE1DM", "TYPE2DIABETES"], cursor)    #  "DEPRESSION"  ,  "ANXIETY"
# all_identifiers = list(set(identifiers1).intersection(identifiers2))    # Turn smaller list into the set
all_identifiers = identifiers2

if False:
    # Lets take only the first N for faster run-time
    N = np.min((len(all_identifiers), 10000))
    print(f"Using N={N} random samples, from the available {len(all_identifiers)}")
    identifiers = random.choices(all_identifiers, k=N)
else:
    print(f"Using all available {len(all_identifiers)} samples")
    identifiers = all_identifiers

# Build 
dm = FoundationalDataModule(identifiers=identifiers,
                            tokenizer="tabular",    # non-
                            batch_size=opt.batch_size,
                            max_seq_length=config.block_size,
                            unk_freq_threshold=config.unk_freq_threshold,
                            include_measurements=True,
                            include_diagnoses=True
                           )

vocab_size = dm.train_set.tokenizer.vocab_size

print(f"{len(dm.train_set)} training samples")
print(f"{len(dm.val_set)} validation samples")
print(f"{len(dm.test_set)} test samples")
print(f"{vocab_size} vocab elements")
# print(dm.train_set.tokenizer._itos)

INFO:root:Building polars dataset


Using all available 129717 samples


INFO:root:Dropping samples with no dynamic events
INFO:root:Using tabular tokenizer


107683 training samples
5983 validation samples
5982 test samples
90 vocab elements


## View the frequency of tokens in the extracted data

In [4]:
import polars as pl
pl.Config.set_tbl_rows(vocab_size + 1)
print(dm.tokenizer._event_counts)

shape: (89, 3)
┌───────────────────────────────────┬─────────┬──────────┐
│ EVENT                             ┆ counts  ┆ freq     │
│ ---                               ┆ ---     ┆ ---      │
│ str                               ┆ u32     ┆ f64      │
╞═══════════════════════════════════╪═════════╪══════════╡
│ UNK                               ┆ 0       ┆ 0.0      │
│ diastolic_blood_pressure          ┆ 2223992 ┆ 0.42246  │
│ bmi                               ┆ 865264  ┆ 0.164362 │
│ eosinophil_count                  ┆ 810927  ┆ 0.15404  │
│ basophil_count                    ┆ 531915  ┆ 0.10104  │
│ corrected_serum_calcium_level     ┆ 149201  ┆ 0.028342 │
│ DEPRESSION                        ┆ 86568   ┆ 0.016444 │
│ serum_level                       ┆ 76301   ┆ 0.014494 │
│ calculated_LDL_cholesterol_level  ┆ 68992   ┆ 0.013105 │
│ ANXIETY                           ┆ 48206   ┆ 0.009157 │
│ HYPERTENSION                      ┆ 32843   ┆ 0.006239 │
│ TYPE2DIABETES                     ┆ 272

In [5]:
# Extract the measurements, using the fact that the diagnoses are all up upper case. This is needed for automatically setting the configuration below
measurements_for_univariate_regression = [record for record in dm.tokenizer._event_counts["EVENT"] if record.upper() != record]

print(measurements_for_univariate_regression)
print(dm.encode(measurements_for_univariate_regression))
print(dm.decode([7,4,3,2]))

['diastolic_blood_pressure', 'bmi', 'eosinophil_count', 'basophil_count', 'corrected_serum_calcium_level', 'serum_level', 'calculated_LDL_cholesterol_level', 'aspartate_transam', 'blood_urea', 'calcium_adjusted_level', 'combined_total_vitamin_D2_and_D3_level', 'hydroxyvitamin3', 'hydroxyvitamin2', 'creatinine_ratio', 'brain_natriuretic_peptide_level', 'blood_calcium']
[2, 3, 4, 5, 6, 8, 9, 18, 24, 26, 32, 39, 40, 57, 72, 83]
DEPRESSION eosinophil_count bmi diastolic_blood_pressure


## Create models and train

In [6]:
models, m_names = [], []

# My development model
for tte_layer in ["Exponential"]: #, "Geometric"]:
    
    ## Create configuration
    config = DemoConfig()
    # Specify which TTE layer to use
    config.TTELayer = tte_layer    
    # list of univariate measurements to model with Normal distribution
    config.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression) 
    
    models.append(TTETransformerForCausalTimeSeriesModelling(config, vocab_size).to(device))
    m_names.append(f"TTETransformerForCausalTimeSeriesModelling: {tte_layer} TTE")

INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using ExponentialTTELayer. This module predicts the time until next event as an exponential distribution
INFO:root:ModuleDict(
  (Token 2): Linear(in_features=384, out_features=2, bias=True)
  (Token 3): Linear(in_features=384, out_features=2, bias=True)
  (Token 4): Linear(in_features=384, out_features=2, bias=True)
  (Token 5): Linear(in_features=384, out_features=2, bias=True)
  (Token 6): Linear(in_features=384, out_features=2, bias=True)
  (Token 8): Linear(in_features=384, out_features=2, bias=True)
  (Token 9): Linear(in_features=384, out_features=2, bias=True)
  (Token 18): Linear(in_features=384, out_features=2, bias=True)
  (Token 24): Linear(in_features=384, out_features=2, bias=True)
  (Token 26): Linear(in_features=384, out_features=2, bias=True)
  (Token 32): Linear(in_features=384, out_features=2, bias=True)
  (Token 39): Linear(in_features=384

In [7]:
loss_curves_train = [[] for _ in models]
loss_curves_train_clf = [[] for _ in models]
loss_curves_train_tte = [[] for _ in models]
loss_curves_train_values = [[] for _ in models]

loss_curves_val = [[] for _ in models]
loss_curves_val_clf = [[] for _ in models]
loss_curves_val_tte = [[] for _ in models]
loss_curves_val_values = [[] for _ in models]

In [8]:
for m_idx, (model, m_name) in enumerate(zip(models, m_names)):
    
    print(f"Training model `{m_name}`, with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
    model = model.to(device)

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

    best_val, epochs_since_best = np.inf, 0
    for epoch in range(opt.epochs):
        epoch_loss, epoch_clf_loss, epoch_tte_loss, epoch_values_loss = 0, 0, 0, 0
        model.train()
        for i, batch in enumerate(dm.train_dataloader()):
            # evaluate the loss
            _, (loss_clf, loss_tte, loss_values), loss = model(batch['tokens'].to(device), 
                                                               ages=batch['ages'].to(device), 
                                                               values=batch['values'].to(device),
                                                               attention_mask=batch['attention_mask'].to(device)   
                                                               )
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            
            # record
            epoch_clf_loss += loss_clf.item()
            epoch_tte_loss += loss_tte.item()
            epoch_values_loss += loss_values.item()
        epoch_loss /= i
        epoch_clf_loss /= i
        epoch_tte_loss /= i
        epoch_values_loss /= i
        loss_curves_train[m_idx].append(epoch_loss)
        loss_curves_train_clf[m_idx].append(epoch_clf_loss)
        loss_curves_train_tte[m_idx].append(epoch_tte_loss)
        loss_curves_train_values[m_idx].append(epoch_values_loss)

        # evaluate the loss on val set
        with torch.no_grad(): 
            model.eval()
            if epoch % opt.eval_interval == 0 or epoch == opt.epochs - 1:
                val_loss, val_clf_loss, val_tte_loss, val_values_loss = 0, 0, 0, 0
                for j, batch in enumerate(dm.val_dataloader()):
                    _, (loss_clf, loss_tte, loss_values), loss = model(batch['tokens'].to(device), 
                                                                       ages=batch['ages'].to(device),
                                                                       values=batch['values'].to(device),
                                                                       attention_mask=batch['attention_mask'].to(device)   
                                                                       )
                    val_loss += loss.item()
                    
                    # record
                    val_clf_loss += loss_clf.item()
                    val_tte_loss += loss_tte.item()
                    val_values_loss += loss_values.item()
                val_loss /= j
                val_clf_loss /= j
                val_tte_loss /= j
                val_values_loss /= j
                loss_curves_val[m_idx].append(val_loss)
                loss_curves_val_clf[m_idx].append(val_clf_loss)
                loss_curves_val_tte[m_idx].append(val_tte_loss)
                loss_curves_val_values[m_idx].append(val_values_loss)

                print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}  ({epoch_clf_loss:.2f}, {epoch_tte_loss:.2f}, {epoch_values_loss:.2f}). Val loss {val_loss:.2f} ({val_clf_loss:.2f}, {val_tte_loss:.2f}, {val_values_loss:.2f})")          
                # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning
        
            if val_loss >= best_val:
                epochs_since_best += 1
                if epochs_since_best >= 5:
                    break
            else:
                best_val = val_loss
                epochs_since_best = 0

    # Test trained model with a prompt
    # ----------------    
    # set context: diagnosis of depression at 20 years old
    tokens = torch.from_numpy(np.array(dm.encode(["DEPRESSION"])).reshape((1,-1))).to(device)
    ages = torch.tensor([[20*365]], device=device)
    values = torch.tensor([[torch.nan]], device=device)
    
    # generate: sample the next 10 tokens
    new_tokens, new_ages, new_values = model.generate(tokens, ages, values, max_new_tokens=10)
    generated = dm.decode(new_tokens[0].tolist())
    # report:
    #    note, Not considering value yet.
    for _cat, _age, _value in zip(generated.split(" "), new_ages[0, :], new_values[0, :]):
        print(f"\t {_cat}:{_value:.02f}, at age {_age/365:.0f} ({_age:.1f} days)")    # with value {_value}


Training model `TTETransformerForCausalTimeSeriesModelling: Exponential TTE`, with 10.757217 M parameters
Epoch 0:	Train loss 11.76  (2.00, -0.27, 33.55). Val loss 3.05 (1.97, -0.33, 7.50)
Epoch 1:	Train loss 5.12  (1.94, -0.36, 13.78). Val loss 2.60 (1.92, -0.44, 6.32)
Epoch 2:	Train loss 2.33  (1.85, -0.62, 5.74). Val loss 2.26 (1.84, -0.75, 5.70)
Epoch 3:	Train loss 2.29  (1.82, -0.77, 5.82). Val loss 2.21 (1.83, -0.80, 5.59)
Epoch 4:	Train loss 2.16  (1.81, -0.80, 5.46). Val loss 2.18 (1.81, -0.82, 5.56)
Epoch 5:	Train loss 2.21  (1.81, -0.82, 5.66). Val loss 2.11 (1.81, -0.85, 5.37)
Epoch 6:	Train loss 2.17  (1.80, -0.84, 5.54). Val loss 2.43 (1.80, -0.86, 6.36)
Epoch 7:	Train loss 2.10  (1.79, -0.86, 5.36). Val loss 2.06 (1.79, -0.87, 5.25)
Epoch 8:	Train loss 2.05  (1.78, -0.88, 5.25). Val loss 2.03 (1.78, -0.90, 5.21)
Epoch 9:	Train loss 2.06  (1.77, -0.89, 5.29). Val loss 2.45 (1.78, -0.93, 6.49)
Epoch 10:	Train loss 2.05  (1.77, -0.90, 5.26). Val loss 2.01 (1.78, -0.91, 5.15)

## Comparing output to real data

In [9]:
for batch in dm.train_dataloader():
    break
conditions = batch["tokens"].numpy().tolist()
# delta_ages = batch["ages"][:, 1:] - batch["ages"][:, :-1]
for idx, (token, age, value) in enumerate(zip(conditions[0], batch["ages"][0,:],  batch["values"][0,:])):
    if token == 0 or idx >= 10:
        break
    print(f"{dm.decode([token])}:{value:.02f}, at age {age/365:.0f} ({age:.1f} days)")

basophil_count:0.05, at age 43 (15737.0 days)
eosinophil_count:0.13, at age 43 (15737.0 days)
diastolic_blood_pressure:73.00, at age 46 (16709.0 days)
diastolic_blood_pressure:71.00, at age 46 (16769.0 days)
basophil_count:0.06, at age 46 (16770.0 days)
eosinophil_count:0.11, at age 46 (16770.0 days)
basophil_count:0.04, at age 46 (16825.0 days)
eosinophil_count:0.13, at age 46 (16825.0 days)


In [10]:
cols = ["k", "r", "b", "y"]

# Plot loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train[m_idx]), len(loss_curves_train[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val[m_idx]), len(loss_curves_val[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.savefig(f"figs/TTE_tab/loss.png")

# Plot classifier loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train_clf[m_idx]), len(loss_curves_train_clf[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train_clf[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val_clf[m_idx]), len(loss_curves_val_clf[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val_clf[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.savefig(f"figs/TTE_tab/loss_clf.png")

# Plot tte loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train_tte[m_idx]), len(loss_curves_train_tte[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train_tte[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val_tte[m_idx]), len(loss_curves_val_tte[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val_tte[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.savefig(f"figs/TTE_tab/loss_tte.png")

# Plot values loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train_values[m_idx]), len(loss_curves_train_values[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train_values[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val_values[m_idx]), len(loss_curves_val_values[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val_values[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.legend()
plt.savefig(f"figs/TTE_tab/loss_val.png")

# Prompt testing

## Diabetes: How related conditions are impacted by each other
Probability of type II diabetes before and after a type I diagnosis

In [11]:
t1_token = dm.tokenizer._stoi["TYPE1DM"]
t2_token = dm.tokenizer._stoi["TYPE2DIABETES"]


base_prompt = ["DEPRESSION"]
ages_in_years = [20]
base_values = [torch.tensor([torch.nan])]

to_days = lambda a_list: torch.FloatTensor([365 * _a for _a in a_list]).reshape((1,-1)).to(device)

# Create a set of prompts
prompts, ages, values, desc = [], [], [], []
# control prompt
desc.append("Control")
prompts.append(base_prompt)
ages.append(ages_in_years)
values.append(base_values)
# prompt with type 1 diabetes
desc.append("Type 1")
prompts.append(base_prompt + ["TYPE1DM"])
ages.append(ages_in_years + [21])
values.append(base_values + [torch.tensor([torch.nan])])

desc.append("Type 2")
prompts.append(base_prompt + ["TYPE2DIABETES"])
ages.append(ages_in_years + [21])
values.append(base_values + [torch.tensor([torch.nan])])

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad(): 
            model.eval()
    
            for p_idx, (prompt, age, value) in enumerate(zip(prompts, ages, values)):
                print(f"\n{desc[p_idx]}: \t ({','.join(prompt)}): ")
                encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
                (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                                         values=torch.tensor(value).to(device),
                                                         ages=to_days(age),
                                                         is_generation=True)
                probs = torch.nn.functional.softmax(lgts, dim=2)
                print(f"\tprobability of type I diabetes: {100*float(probs[0, 0, t1_token].cpu().detach().numpy()):.4f}%")
                print(f"\tprobability of type II diabetes: {100*float(probs[0, 0, t2_token].cpu().detach().numpy()):.4f}%")

# Note: adding a diagnosis (even if potentially orthogonal) at the beginning of the prompt increases probability of either type



TTETransformerForCausalTimeSeriesModelling: Exponential TTE
--------------------------------------

Control: 	 (DEPRESSION): 
	probability of type I diabetes: 0.0207%
	probability of type II diabetes: 0.0457%

Type 1: 	 (DEPRESSION,TYPE1DM): 
	probability of type I diabetes: 0.0529%
	probability of type II diabetes: 0.1771%

Type 2: 	 (DEPRESSION,TYPE2DIABETES): 
	probability of type I diabetes: 0.0467%
	probability of type II diabetes: 0.2782%


## Values: How increasing BMI affects likelihood of diagnoses

In [12]:
events_of_interest = ["bmi", "diastolic_blood_pressure", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF", "ISCHAEMICSTROKE"
                     ]

prompt = ["bmi", "bmi", "bmi"]
values = [torch.tensor([v for _ in prompt], device=device) for v in [12.,15.,18.,21.,24.,30.,40.]]
age = [39, 40, 41]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")

    # for condition in target_conditions:
    #     print(f"Probability of {condition}")
    #     target_token = dm.tokenizer._stoi[condition]

    for p_idx, value in enumerate(values):
        print(f"Values {value}\n======")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
        (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                                  values=value,
                                                 ages=to_days(age),
                                                 is_generation=True)
        probs = torch.nn.functional.softmax(lgts, dim=2) * 100
        
        topk_prob, topk_ind = torch.sort(probs[0,0,:], descending=True)
        for i, j in zip(dm.decode(topk_ind.tolist()).split(" "), topk_prob):
            if i in events_of_interest:
                print(f"\t{i}: {j:.2f}%")




TTETransformerForCausalTimeSeriesModelling: Exponential TTE
--------------------------------------
Values tensor([12., 12., 12.], device='cuda:0')
	diastolic_blood_pressure: 34.98%
	bmi: 21.18%
	TYPE1DM: 0.83%
	HYPERTENSION: 0.05%
	OSTEOARTHRITIS: 0.05%
	TYPE2DIABETES: 0.03%
	CKDSTAGE3TO5: 0.01%
	HF: 0.01%
	ISCHAEMICSTROKE: 0.00%
Values tensor([15., 15., 15.], device='cuda:0')
	diastolic_blood_pressure: 45.70%
	bmi: 18.38%
	TYPE1DM: 0.50%
	HYPERTENSION: 0.05%
	OSTEOARTHRITIS: 0.05%
	TYPE2DIABETES: 0.04%
	CKDSTAGE3TO5: 0.01%
	HF: 0.01%
	ISCHAEMICSTROKE: 0.00%
Values tensor([18., 18., 18.], device='cuda:0')
	diastolic_blood_pressure: 57.82%
	bmi: 14.90%
	TYPE1DM: 0.25%
	HYPERTENSION: 0.06%
	OSTEOARTHRITIS: 0.05%
	TYPE2DIABETES: 0.05%
	CKDSTAGE3TO5: 0.01%
	HF: 0.01%
	ISCHAEMICSTROKE: 0.00%
Values tensor([21., 21., 21.], device='cuda:0')
	diastolic_blood_pressure: 67.61%
	bmi: 11.83%
	TYPE1DM: 0.11%
	TYPE2DIABETES: 0.07%
	HYPERTENSION: 0.07%
	OSTEOARTHRITIS: 0.05%
	CKDSTAGE3TO5: 0.01%
	H

## Values: How increasing diastolic_blood_pressure affects likelihood of diagnoses

In [13]:
events_of_interest = ["bmi", "diastolic_blood_pressure", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF", "ISCHAEMICSTROKE"
                     ]

prompt = ["diastolic_blood_pressure", "diastolic_blood_pressure", "diastolic_blood_pressure"]
values = [torch.tensor([v for _ in prompt], device=device) for v in [60.,70.,80.,90.,100.,120.]]
age = [39, 40, 41]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")

    # for condition in target_conditions:
    #     print(f"Probability of {condition}")
    #     target_token = dm.tokenizer._stoi[condition]

    for p_idx, value in enumerate(values):
        print(f"Values {value}\n======")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
        (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                                  values=value,
                                                 ages=to_days(age),
                                                 is_generation=True)
        probs = torch.nn.functional.softmax(lgts, dim=2) * 100
        
        topk_prob, topk_ind = torch.sort(probs[0,0,:], descending=True)
        for i, j in zip(dm.decode(topk_ind.tolist()).split(" "), topk_prob):
            if i in events_of_interest:
                print(f"\t{i}: {j:.2f}%")




TTETransformerForCausalTimeSeriesModelling: Exponential TTE
--------------------------------------
Values tensor([60., 60., 60.], device='cuda:0')
	diastolic_blood_pressure: 43.26%
	bmi: 20.45%
	OSTEOARTHRITIS: 0.15%
	TYPE2DIABETES: 0.11%
	CKDSTAGE3TO5: 0.10%
	TYPE1DM: 0.07%
	HYPERTENSION: 0.06%
	HF: 0.05%
	ISCHAEMICSTROKE: 0.01%
Values tensor([70., 70., 70.], device='cuda:0')
	diastolic_blood_pressure: 44.14%
	bmi: 20.32%
	OSTEOARTHRITIS: 0.22%
	TYPE2DIABETES: 0.19%
	HYPERTENSION: 0.15%
	CKDSTAGE3TO5: 0.11%
	HF: 0.06%
	TYPE1DM: 0.06%
	ISCHAEMICSTROKE: 0.01%
Values tensor([80., 80., 80.], device='cuda:0')
	diastolic_blood_pressure: 45.20%
	bmi: 19.83%
	HYPERTENSION: 0.45%
	TYPE2DIABETES: 0.35%
	OSTEOARTHRITIS: 0.33%
	CKDSTAGE3TO5: 0.11%
	HF: 0.07%
	TYPE1DM: 0.05%
	ISCHAEMICSTROKE: 0.02%
Values tensor([90., 90., 90.], device='cuda:0')
	diastolic_blood_pressure: 50.88%
	bmi: 17.82%
	HYPERTENSION: 1.36%
	TYPE2DIABETES: 0.56%
	OSTEOARTHRITIS: 0.39%
	CKDSTAGE3TO5: 0.11%
	HF: 0.07%
	TYPE1D

## Values: How varying diagnosis affects value of diastolic_blood_pressure

In [16]:
t1_token = dm.tokenizer._stoi["diastolic_blood_pressure"]

diagnoses = [["DEPRESSION"],["TYPE2DIABETES"], ["HF"], ["HYPERTENSION"]]
values = torch.tensor([torch.nan], device=device)
age = [39]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")

    # for condition in target_conditions:
    #     print(f"Probability of {condition}")
    #     target_token = dm.tokenizer._stoi[condition]

    for p_idx, diagnosis in enumerate(diagnoses):
        print(f"\nDiagnosis {diagnosis}\n======")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(diagnosis)).reshape((1,-1))).to(device)
        (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                                  values=values,
                                                 ages=to_days(age),
                                                 is_generation=True)
        dist = val_dist[model.value_layer.token_key(t1_token)]
        print(f"diastolic_blood_pressure ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")
        # print(f"\tprobability of type II diabetes: {100*float(probs[0, 0, t2_token].cpu().detach().numpy()):.4f}%")





TTETransformerForCausalTimeSeriesModelling: Exponential TTE
--------------------------------------

Diagnosis ['DEPRESSION']
diastolic_blood_pressure ~ N(80.7, 10.6)

Diagnosis ['TYPE2DIABETES']
diastolic_blood_pressure ~ N(84.4, 10.6)

Diagnosis ['HF']
diastolic_blood_pressure ~ N(85.1, 10.9)

Diagnosis ['HYPERTENSION']
diastolic_blood_pressure ~ N(86.7, 11.0)


## Values: How increasing bmi affects value of diastolic_blood_pressure

In [27]:
t1_token = dm.tokenizer._stoi["diastolic_blood_pressure"]

prompt = ["bmi", "bmi", "bmi"]
values = [torch.tensor([v for _ in prompt], device=device) for v in [12.,15.,18.,21.,24.,30.,40.]]
age = [39, 40, 41]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")

    # for condition in target_conditions:
    #     print(f"Probability of {condition}")
    #     target_token = dm.tokenizer._stoi[condition]

    for p_idx, value in enumerate(values):
        print(f"Values {value.tolist()}\n======")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
        (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                                  values=value,
                                                 ages=to_days(age),
                                                 is_generation=True)
        
        dist = val_dist[model.value_layer.token_key(t1_token)]
        print(f"diastolic_blood_pressure ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")
        # print(f"\tprobability of type II diabetes: {100*float(probs[0, 0, t2_token].cpu().detach().numpy()):.4f}%")



TTETransformerForCausalTimeSeriesModelling: Exponential TTE
--------------------------------------
Values [12.0, 12.0, 12.0]
diastolic_blood_pressure ~ N(65.4, 9.9)
Values [15.0, 15.0, 15.0]
diastolic_blood_pressure ~ N(66.8, 9.8)
Values [18.0, 18.0, 18.0]
diastolic_blood_pressure ~ N(68.9, 9.7)
Values [21.0, 21.0, 21.0]
diastolic_blood_pressure ~ N(71.5, 9.7)
Values [24.0, 24.0, 24.0]
diastolic_blood_pressure ~ N(74.6, 9.7)
Values [30.0, 30.0, 30.0]
diastolic_blood_pressure ~ N(79.2, 9.7)
Values [40.0, 40.0, 40.0]
diastolic_blood_pressure ~ N(82.6, 10.0)


# Appendix: model architectures

In [15]:
for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n" + "="*len(m_names[model_idx]))
    print(f"\n\n{model}")



TTETransformerForCausalTimeSeriesModelling: Exponential TTE


TTETransformerForCausalTimeSeriesModelling(
  (transformer): TTETransformer(
    (wpe): TemporalPositionalEncoding()
    (wte): DataEmbeddingLayer(
      (token_embed_layer): Embedding(90, 384, padding_idx=0)
      (value_embed_layer): EmbeddingBag(90, 384, mode=sum, padding_idx=0)
    )
    (drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadedSelfAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=384, out_features=384, bias=False)
          (v_proj): Linear(in_features=384, out_features=384, bias=False)
          (q_proj): Linear(in_features=384, out_features=384, bias=False)
          (out_proj): Linear(in_features=384, out_features=384, bias=True)
        )
        (ln_2): LayerNorm((38