# Demo Notebook:
## Time to Event Transformer For Causal Time Series Modelling 

Including time and tabular values

In [2]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

# Perform sqlite operations on disk
%env SQLITE_TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
%env TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
!echo $SQLITE_TMPDIR
!echo $TMPDIR
!echo $USERPROFILE

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
env: SQLITE_TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
env: TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
/rds/projects/g/gokhalkm-optimal/DataforCharles
/rds/projects/g/gokhalkm-optimal/DataforCharles



In [4]:
import pytorch_lightning
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from tqdm import tqdm
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.TTE.task_heads.causal_tabular import TTETransformerForCausalTimeSeriesModelling
from tqdm import tqdm

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(device)

!pwd
%load_ext autoreload
%autoreload 2

cuda
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/tteGPT


## Build configurations

In [5]:
# Set config to be equivalent architecture of kaparthy benchmark, however they are not comparable tasks.
@dataclass
class DemoConfig:
    # Training input sequences
    block_size: int = 128  
    # Multi-head attention configurations
    n_layer: int = 12   # 6, 12
    n_head: int = 8    # 6, 12
    n_embd: int = 1024  # 384 , 768
    layer_norm_bias = False
    attention_type = "global"
    window_size = 256            # the window size for local attention
    max_positions = 512          # the maximum sequence length that this model might ever be used with 
    # SA dropouts
    attention_dropout = 0.0
    resid_dropout = 0.0
    dropout = 0.0
    
    #
    
    bias: bool = True
    unk_freq_threshold: float = 0.0
    TTELayer = "Exponential"                                  # "Geometric"
    tokens_for_univariate_regression = None

config = DemoConfig()

@dataclass
class OptConfig:
    batch_size: int = 64
    eval_interval: int = 1
    learning_rate: float = 3e-4
    epochs: int = 50
    
opt = OptConfig()

## Create data loader on a reduced cohort

In [6]:
# Get a list of patients which fit a reduced set of criterion
# path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/archive/Version2/"
path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/"

# Build 
dm = FoundationalDataModule(path_to_db=path_to_db,
                            load=True,
                            tokenizer="tabular",
                            batch_size=opt.batch_size,
                            max_seq_length=config.block_size,
                            unk_freq_threshold=config.unk_freq_threshold,
                            min_workers=20,
                            inclusion_conditions=["COUNTRY = 'E'"],
                           )

vocab_size = dm.train_set.tokenizer.vocab_size
print(f"{vocab_size} vocab elements")

# Extract the measurements, using the fact that the diagnoses are all up upper case. This is needed for automatically setting the configuration below
measurements_for_univariate_regression = [record for record in dm.tokenizer._event_counts["EVENT"] if record.upper() != record]
# display(measurements_for_univariate_regression)
# list of univariate measurements to model with Normal distribution
config.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression) 

INFO:root:Loading Polars dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/polars/
INFO:root:Using tokenizer tabular
INFO:root:Tokenzier created based on 3584.43M tokens
INFO:root:Creating split=train/ dataset
INFO:root:	 Loading split=train/ hash map for parquet
INFO:root:	 Hash map created for split=train/ with 22,912,046 samples
INFO:root:Creating split=test/ dataset
INFO:root:	 Loading split=test/ hash map for parquet
INFO:root:	 Hash map created for split=test/ with 1,207,449 samples
INFO:root:Creating split=val/ dataset
INFO:root:	 Loading split=val/ hash map for parquet
INFO:root:	 Hash map created for split=val/ with 1,226,576 samples


184 vocab elements


## View a single patient

In [7]:
dm.train_set.view_sample(1, max_dynamic_events=12, report_time=True)

Time to retrieve sample index 1 was 0.5601367950439453 seconds

SEX                 | F
IMD                 | 4.0
ETHNICITY           | WHITE
birth_year          | 1997.0

Token                                                                      | Age               | Standardised value
Mean_corpusc_Hb_conc__MCHC__14                                             | 1771              | nan               
Mean_corpusc_haemoglobin_MCH__13                                           | 1771              | -0.40             
Mean_corpuscular_volume__MCV__11                                           | 1771              | -0.10             
Monocyte_count_23                                                          | 1771              | 0.06              
Neutrophil_count_19                                                        | 1771              | 0.04              
Platelet_count_12                                                          | 1771              | 0.28              
Red_blood_cell__

## Create models and train

In [10]:
model = TTETransformerForCausalTimeSeriesModelling(config, vocab_size).to(device)

loss_curves_train = []
loss_curves_train_clf = []
loss_curves_train_tte = []
loss_curves_train_values = []

loss_curves_val = []
loss_curves_val_clf = []
loss_curves_val_tte = []
loss_curves_val_values = []

INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using ExponentialTTELayer. This module predicts the time until next event as an exponential distribution


In [9]:
print(f"Training model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

best_val, epochs_since_best = np.inf, 0
for epoch in range(opt.epochs):
    epoch_loss, epoch_clf_loss, epoch_tte_loss, epoch_values_loss = 0, 0, 0, 0
    model.train()
    for i, batch in tqdm(enumerate(dm.train_dataloader()), desc=f"Training epoch {epoch}", total=len(dm.train_dataloader())):
        if i > 1000:
            break

        # evaluate the loss
        _, (loss_clf, loss_tte, loss_values), loss = model(batch['tokens'].to(device), 
                                                           ages=batch['ages'].to(device), 
                                                           values=batch['values'].to(device),
                                                           attention_mask=batch['attention_mask'].to(device)   
                                                           )
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        
        # record
        epoch_clf_loss += loss_clf.item()
        epoch_tte_loss += loss_tte.item()
        epoch_values_loss += loss_values.item()
    epoch_loss /= i
    epoch_clf_loss /= i
    epoch_tte_loss /= i
    epoch_values_loss /= i
    loss_curves_train.append(epoch_loss)
    loss_curves_train_clf.append(epoch_clf_loss)
    loss_curves_train_tte.append(epoch_tte_loss)
    loss_curves_train_values.append(epoch_values_loss)

    # evaluate the loss on val set
    with torch.no_grad(): 
        model.eval()
        if epoch % opt.eval_interval == 0 or epoch == opt.epochs - 1:
            val_loss, val_clf_loss, val_tte_loss, val_values_loss = 0, 0, 0, 0
            for j, batch in tqdm(enumerate(dm.val_dataloader()), desc=f"Validation epoch {epoch}", total=len(dm.val_dataloader())):
                if j > 100:
                    break
                _, (loss_clf, loss_tte, loss_values), loss = model(batch['tokens'].to(device), 
                                                                   ages=batch['ages'].to(device),
                                                                   values=batch['values'].to(device),
                                                                   attention_mask=batch['attention_mask'].to(device)   
                                                                   )
            
                val_loss += loss.item()
                
                # record
                val_clf_loss += loss_clf.item()
                val_tte_loss += loss_tte.item()
                val_values_loss += loss_values.item()
            val_loss /= j
            val_clf_loss /= j
            val_tte_loss /= j
            val_values_loss /= j
            loss_curves_val.append(val_loss)
            loss_curves_val_clf.append(val_clf_loss)
            loss_curves_val_tte.append(val_tte_loss)
            loss_curves_val_values.append(val_values_loss)

            print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}  ({epoch_clf_loss:.2f}, {epoch_tte_loss:.2f}, {epoch_values_loss:.2f}). Val loss {val_loss:.2f} ({val_clf_loss:.2f}, {val_tte_loss:.2f}, {val_values_loss:.2f})")          
            # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning
    
        if val_loss >= best_val:
            epochs_since_best += 1
            if epochs_since_best >= 5:
                break
        else:
            best_val = val_loss
            epochs_since_best = 0



Training model with 154.024665 M parameters


Training epoch 0:   0%|          | 1001/358001 [10:23<61:43:09,  1.61it/s]
Validation epoch 0:   1%|          | 101/19166 [00:27<1:25:56,  3.70it/s]


Epoch 0:	Train loss -0.25  (3.73, -1.11, -3.37). Val loss -0.73 (3.16, -1.98, -3.36)


Training epoch 1:   0%|          | 55/358001 [00:39<70:31:02,  1.41it/s]
Exception ignored in: <function WeakValueDictionary.__init__.<locals>.remove at 0x7f072f6ec280>
Traceback (most recent call last):
  File "/rds/bear-apps/2022a/EL8-ice/software/Python/3.10.4-GCCcore-11.3.0/lib/python3.10/weakref.py", line 106, in remove
    def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref):
KeyboardInterrupt: 


KeyboardInterrupt: 

Training model with 86.823001 M parameters
Training epoch 0:   0%|          | 1001/358001 [07:15<43:07:25,  2.30it/s]
Validation epoch 0:   1%|          | 101/19166 [00:22<1:09:29,  4.57it/s]

Epoch 0:	Train loss -1.25  (3.63, -1.28, -6.11). Val loss -3.57 (2.76, -2.22, -11.23)




Training model with 154.024665 M parameters

Training epoch 0:   0%|          | 1001/358001 [10:27<62:12:13,  1.59it/s]
Validation epoch 0:   1%|          | 101/19166 [00:27<1:26:54,  3.66it/s]

Epoch 0:	Train loss -0.18  (3.60, -1.14, -3.00). Val loss -2.03 (2.93, -2.16, -6.86)

Training epoch 1:   0%|          | 1001/358001 [10:28<62:13:17,  1.59it/s]
Validation epoch 1:   1%|          | 101/19166 [00:27<1:27:06,  3.65it/s]

Epoch 1:	Train loss -2.77  (2.27, -2.36, -8.22). Val loss -4.16 (1.81, -2.56, -11.74)

Training epoch 2:   0%|          | 1001/358001 [10:28<62:18:36,  1.59it/s]
Validation epoch 2:   1%|          | 101/19166 [00:27<1:27:07,  3.65it/s]

Epoch 2:	Train loss -3.83  (1.61, -2.57, -10.53). Val loss -4.41 (1.61, -2.57, -12.26)


## Generation

In [14]:
# Default context start
prompt = ["O_E___height_1", "O_E___weight_2"]
values = [163, 90]
ages_in_years = [18.2, 18.2]

# define encoding functions (TODO: add this wrap to datamodule
encode_prompt = lambda prompt_list: torch.from_numpy(np.array(dm.encode(prompt_list)).reshape((1,-1))).to(device)
encode_value = lambda prompt_list, value_list: torch.tensor(np.array([dm.standardise(_cat, _val) for _cat, _val in zip(prompt_list, value_list) ]).reshape((1,-1)), dtype=torch.float32).to(device)
encode_age = lambda age_list: torch.tensor([365 * _age for _age in age_list], dtype=torch.int64).reshape((1,-1)).to(device)

# Convert for model
tokens = encode_prompt(prompt)
values_scaled = encode_value(prompt, values)
ages_in_days = encode_age(ages_in_years)

In [15]:
# generate: sample the next 10 tokens
new_tokens, new_ages, new_values = model.generate(tokens, ages_in_days, values_scaled, max_new_tokens=10)

# report:
print(f"PROMPT:")
for _idx, (_cat, _age, _value) in enumerate(zip(dm.decode(new_tokens[0].tolist()).split(" "), 
                                                new_ages[0, :], 
                                                new_values[0, :]
                                               )
                                           ):
    # _value = dm.unstandardise(_cat, _value)
    print(f"{_cat}".ljust(50) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({_age:.1f} days)")    # with value {_value}
    if _idx == tokens.shape[-1] - 1:
        print("="*90)
        print(f"GENERATION")

PROMPT:
O_E___height_1                                    -0.04          at age 18 (6643.0 days)
O_E___weight_2                                    0.16           at age 18 (6643.0 days)
GENERATION
O_E___weight_2                                    0.09           at age 18 (6699.0 days)
Systolic_blood_pressure_4                         -0.23          at age 19 (6942.3 days)
Body_mass_index_3                                 -0.13          at age 21 (7743.8 days)
Diastolic_blood_pressure_5                        -0.18          at age 21 (7745.7 days)
O_E___height_1                                    0.01           at age 21 (7762.7 days)
O_E___weight_2                                    -0.04          at age 21 (7781.6 days)
Systolic_blood_pressure_4                         -0.17          at age 21 (7808.5 days)
Basophil_count_22                                 0.22           at age 22 (8143.9 days)
Eosinophil_count_21                               0.14           at age 22 (8147.8 days)
Eo

## Comparing generation to real data

In [16]:
dm.train_set.view_sample(1, max_dynamic_events=10, report_time=True)

Time to retrieve sample index 1 was 0.12309074401855469 seconds

SEX                 | F
IMD                 | 4.0
ETHNICITY           | WHITE
birth_year          | 1997.0

Token                                                                      | Age               | Standardised value
Mean_corpuscular_volume__MCV__11                                           | 8000              | -0.20             
Monocyte_count_23                                                          | 8000              | 0.20              
Neutrophil_count_19                                                        | 8000              | 0.33              
Platelet_count_12                                                          | 8000              | 0.13              
Red_blood_cell__RBC__count_10                                              | 8000              | -0.05             
Serum_C_reactive_protein_level_59                                          | 8000              | 0.02              
Serum_TSH_level

In [17]:

# Plot loss
plt.figure()
# Training
iterations = np.linspace(0, len(loss_curves_train), len(loss_curves_train)) * opt.eval_interval
plt.plot(iterations, loss_curves_train, label="train")
# Validation
iterations = np.linspace(0, len(loss_curves_val), len(loss_curves_val)) * opt.eval_interval
plt.plot(iterations, loss_curves_val, label="val", linestyle='dashed')
plt.legend()
plt.savefig(f"figs/TTE_tab/loss.png")

# Plot Classifier loss
plt.figure()
# Training
iterations = np.linspace(0, len(loss_curves_train_clf), len(loss_curves_train_clf)) * opt.eval_interval
plt.plot(iterations, loss_curves_train_clf, label="train")
# Validation
iterations = np.linspace(0, len(loss_curves_val_clf), len(loss_curves_val_clf)) * opt.eval_interval
plt.plot(iterations, loss_curves_val_clf, label="val", linestyle='dashed')
plt.legend()
plt.savefig(f"figs/TTE_tab/loss_clf.png")

# Plot TTE loss
plt.figure()
# Training
iterations = np.linspace(0, len(loss_curves_train_tte), len(loss_curves_train_tte)) * opt.eval_interval
plt.plot(iterations, loss_curves_train_tte, label="train", )
# Validation
iterations = np.linspace(0, len(loss_curves_val_tte), len(loss_curves_val_tte)) * opt.eval_interval
plt.plot(iterations, loss_curves_val_tte, label="val", linestyle='dashed')
plt.legend()
plt.savefig(f"figs/TTE_tab/loss_tte.png")

# Plot value loss
plt.figure()
# Training
iterations = np.linspace(0, len(loss_curves_train_values), len(loss_curves_train_values)) * opt.eval_interval
plt.plot(iterations, loss_curves_train_values, label="train", )
# Validation
iterations = np.linspace(0, len(loss_curves_val_values), len(loss_curves_val_values)) * opt.eval_interval
plt.plot(iterations, loss_curves_val_values, label="val", linestyle='dashed')
plt.legend()
plt.savefig(f"figs/TTE_tab/loss_values.png")

# Prompt testing

## Diabetes: How related conditions are impacted by each other
Probability of type II diabetes before and after a type I diagnosis

In [22]:
t1_token = dm.tokenizer._stoi["TYPE1DM"]
t2_token = dm.tokenizer._stoi["TYPE2DIABETES"]


base_prompt = ["DEPRESSION"]
ages_in_years = [20]
base_values = [torch.tensor([torch.nan])]

to_days = lambda a_list: torch.FloatTensor([365 * _a for _a in a_list]).reshape((1,-1)).to(device)

# Create a set of prompts
prompts, ages, values, desc = [], [], [], []
# control prompt
desc.append("Control")
prompts.append(base_prompt)
ages.append(ages_in_years)
values.append(base_values)
# prompt with type 1 diabetes
desc.append("Type 1")
prompts.append(base_prompt + ["TYPE1DM"])
ages.append(ages_in_years + [21])
values.append(base_values + [torch.tensor([torch.nan])])

desc.append("Type 2")
prompts.append(base_prompt + ["TYPE2DIABETES"])
ages.append(ages_in_years + [21])
values.append(base_values + [torch.tensor([torch.nan])])

with torch.no_grad(): 
    model.eval()

    for p_idx, (prompt, age, value) in enumerate(zip(prompts, ages, values)):
        print(f"\n{desc[p_idx]}: \t ({','.join(prompt)}): ")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
        value = torch.tensor(value).reshape((1,-1)).to(device)
        (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                                 values=value,
                                                 ages=to_days(age),
                                                 is_generation=True)
        probs = torch.nn.functional.softmax(lgts, dim=2)
        print(f"\tprobability of type I diabetes: {100*float(probs[0, 0, t1_token].cpu().detach().numpy()):.4f}%")
        print(f"\tprobability of type II diabetes: {100*float(probs[0, 0, t2_token].cpu().detach().numpy()):.4f}%")

# Note: adding a diagnosis (even if potentially orthogonal) at the beginning of the prompt increases probability of either type


Control: 	 (DEPRESSION): 
torch.Size([1, 1])
torch.Size([1, 1])
	probability of type I diabetes: 0.0302%
	probability of type II diabetes: 0.1716%

Type 1: 	 (DEPRESSION,TYPE1DM): 
torch.Size([1, 2])
torch.Size([1, 2])
	probability of type I diabetes: 0.0321%
	probability of type II diabetes: 0.1864%

Type 2: 	 (DEPRESSION,TYPE2DIABETES): 
torch.Size([1, 2])
torch.Size([1, 2])
	probability of type I diabetes: 0.0342%
	probability of type II diabetes: 0.1831%


## Values: How increasing BMI affects likelihood of diagnoses

In [24]:
events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF", "ISCHAEMICSTROKE"
                     ]

prompt = ["Body_mass_index_3"]
# values = [torch.tensor([standardise(_cat, v) for _cat in prompt], device=device) for v in [12.,15.,18.,21.,24.,30.,40.]]
values = [torch.tensor([float(v) for _cat in prompt], device=device) for v in np.linspace(-2,2,10)]
print(values)
age = [40]

# for condition in target_conditions:
#     print(f"Probability of {condition}")
#     target_token = dm.tokenizer._stoi[condition]

for p_idx, value in enumerate(values):
    print(f"Value {value}\n======")
    
    encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
    value = value.reshape((1,-1))
    
    (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                              values=value,
                                             ages=to_days(age),
                                             is_generation=True)
    probs = torch.nn.functional.softmax(lgts, dim=2) * 100
    
    topk_prob, topk_ind = torch.sort(probs[0,0,:], descending=True)
    for i, j in zip(dm.decode(topk_ind.tolist()).split(" "), topk_prob):
        if i in events_of_interest:
            print(f"\t{i}: {j:.2f}%")


[tensor([-2.], device='cuda:0'), tensor([-1.5556], device='cuda:0'), tensor([-1.1111], device='cuda:0'), tensor([-0.6667], device='cuda:0'), tensor([-0.2222], device='cuda:0'), tensor([0.2222], device='cuda:0'), tensor([0.6667], device='cuda:0'), tensor([1.1111], device='cuda:0'), tensor([1.5556], device='cuda:0'), tensor([2.], device='cuda:0')]
Value tensor([-2.], device='cuda:0')
	Diastolic_blood_pressure_5: 5.72%
	Body_mass_index_3: 1.66%
	HYPERTENSION: 0.01%
	TYPE2DIABETES: 0.01%
	CKDSTAGE3TO5: 0.00%
	OSTEOARTHRITIS: 0.00%
	TYPE1DM: 0.00%
Value tensor([-1.5556], device='cuda:0')
	Diastolic_blood_pressure_5: 7.37%
	Body_mass_index_3: 1.71%
	TYPE2DIABETES: 0.01%
	HYPERTENSION: 0.01%
	CKDSTAGE3TO5: 0.00%
	OSTEOARTHRITIS: 0.00%
	TYPE1DM: 0.00%
Value tensor([-1.1111], device='cuda:0')
	Diastolic_blood_pressure_5: 10.39%
	Body_mass_index_3: 1.65%
	TYPE2DIABETES: 0.01%
	HYPERTENSION: 0.01%
	CKDSTAGE3TO5: 0.01%
	OSTEOARTHRITIS: 0.00%
	TYPE1DM: 0.00%
Value tensor([-0.6667], device='cuda:0')

## Values: How increasing diastolic_blood_pressure affects likelihood of diagnoses

In [26]:
events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF", "ISCHAEMICSTROKE"
                     ]

prompt = ["Diastolic_blood_pressure_5"]
# values = [torch.tensor([standardise(_cat, _value) for _cat in prompt], device=device) for _value in [60.,70.,80.,90.,100.,120.]]
values = [torch.tensor([float(v) for _cat in prompt], device=device) for v in np.linspace(-2,2,10)]
age = [40]


# for condition in target_conditions:
#     print(f"Probability of {condition}")
#     target_token = dm.tokenizer._stoi[condition]

for p_idx, value in enumerate(values):
    print(f"Value {value}\n======")
    encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
    value = value.reshape((1,-1))
    
    (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                              values=value,
                                             ages=to_days(age),
                                             is_generation=True)
    probs = torch.nn.functional.softmax(lgts, dim=2) * 100
    
    topk_prob, topk_ind = torch.sort(probs[0,0,:], descending=True)
    for i, j in zip(dm.decode(topk_ind.tolist()).split(" "), topk_prob):
        if i in events_of_interest:
            print(f"\t{i}: {j:.2f}%")


Value tensor([-2.], device='cuda:0')
	Diastolic_blood_pressure_5: 1.46%
	Body_mass_index_3: 1.36%
	TYPE2DIABETES: 0.01%
	CKDSTAGE3TO5: 0.01%
	OSTEOARTHRITIS: 0.00%
	HYPERTENSION: 0.00%
	TYPE1DM: 0.00%
Value tensor([-1.5556], device='cuda:0')
	Diastolic_blood_pressure_5: 1.54%
	Body_mass_index_3: 1.18%
	TYPE2DIABETES: 0.01%
	CKDSTAGE3TO5: 0.01%
	OSTEOARTHRITIS: 0.00%
	HYPERTENSION: 0.00%
	TYPE1DM: 0.00%
Value tensor([-1.1111], device='cuda:0')
	Diastolic_blood_pressure_5: 1.49%
	Body_mass_index_3: 0.89%
	TYPE2DIABETES: 0.00%
	CKDSTAGE3TO5: 0.00%
	OSTEOARTHRITIS: 0.00%
	HYPERTENSION: 0.00%
	TYPE1DM: 0.00%
Value tensor([-0.6667], device='cuda:0')
	Diastolic_blood_pressure_5: 1.21%
	Body_mass_index_3: 0.52%
	TYPE2DIABETES: 0.00%
	CKDSTAGE3TO5: 0.00%
	OSTEOARTHRITIS: 0.00%
	HYPERTENSION: 0.00%
	TYPE1DM: 0.00%
Value tensor([-0.2222], device='cuda:0')
	Diastolic_blood_pressure_5: 0.85%
	Body_mass_index_3: 0.25%
	TYPE2DIABETES: 0.00%
	CKDSTAGE3TO5: 0.00%
	HYPERTENSION: 0.00%
	OSTEOARTHRITIS: 0

## Values: How varying diagnosis affects value of diastolic_blood_pressure

In [29]:
# display(dm.tokenizer._stoi)
t1_token = dm.tokenizer._stoi["Diastolic_blood_pressure_5"]

diagnoses = [["DEPRESSION"],["TYPE2DIABETES"], ["HF"], ["HYPERTENSION"]]
values = torch.tensor([torch.nan], device=device)
age = [39]


# for condition in target_conditions:
#     print(f"Probability of {condition}")
#     target_token = dm.tokenizer._stoi[condition]

for p_idx, diagnosis in enumerate(diagnoses):
    print(f"\nDiagnosis {diagnosis}\n======")
    encoded_prompt = torch.from_numpy(np.array(dm.encode(diagnosis)).reshape((1,-1))).to(device)
    values = values.reshape((1,-1))
    (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                              values=values,
                                             ages=to_days(age),
                                             is_generation=True)
    dist = val_dist[model.value_layer.token_key(t1_token)]
    print(f"standardised diastolic_blood_pressure ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")
    # print(f"\tprobability of type II diabetes: {100*float(probs[0, 0, t2_token].cpu().detach().numpy()):.4f}%")




Diagnosis ['DEPRESSION']
standardised diastolic_blood_pressure ~ N(-0.1, 0.2)

Diagnosis ['TYPE2DIABETES']
standardised diastolic_blood_pressure ~ N(-0.1, 0.2)

Diagnosis ['HF']
standardised diastolic_blood_pressure ~ N(-0.1, 0.2)

Diagnosis ['HYPERTENSION']
standardised diastolic_blood_pressure ~ N(-0.1, 0.2)


## Values: How increasing bmi affects value of diastolic_blood_pressure

In [30]:
t1_token = dm.tokenizer._stoi["Diastolic_blood_pressure_5"]

prompt = ["Body_mass_index_3"]
# values = [torch.tensor([standardise(_cat, _value) for _cat in prompt], device=device) for _value in [12.,15.,18.,21.,24.,30.,40.,50.]]
values = [torch.tensor([float(v) for _cat in prompt], device=device) for v in np.linspace(-2,2,10)]
age = [40]

# for condition in target_conditions:
#     print(f"Probability of {condition}")
#     target_token = dm.tokenizer._stoi[condition]

for p_idx, value in enumerate(values):
    print(f"Values {value.tolist()}\n======")
    encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
    value = value.reshape((1,-1))

    (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                             values=value,
                                             ages=to_days(age),
                                             is_generation=True)
    
    dist = val_dist[model.value_layer.token_key(t1_token)]
    print(f"standardised diastolic_blood_pressure ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")
    # print(f"\tprobability of type II diabetes: {100*float(probs[0, 0, t2_token].cpu().detach().numpy()):.4f}%")

Values [-2.0]
standardised diastolic_blood_pressure ~ N(-0.5, 0.2)
Values [-1.5555555820465088]
standardised diastolic_blood_pressure ~ N(-0.4, 0.2)
Values [-1.1111111640930176]
standardised diastolic_blood_pressure ~ N(-0.4, 0.2)
Values [-0.6666666865348816]
standardised diastolic_blood_pressure ~ N(-0.3, 0.2)
Values [-0.2222222238779068]
standardised diastolic_blood_pressure ~ N(-0.1, 0.2)
Values [0.2222222238779068]
standardised diastolic_blood_pressure ~ N(0.0, 0.2)
Values [0.6666666865348816]
standardised diastolic_blood_pressure ~ N(0.1, 0.2)
Values [1.1111111640930176]
standardised diastolic_blood_pressure ~ N(0.2, 0.2)
Values [1.5555555820465088]
standardised diastolic_blood_pressure ~ N(0.3, 0.2)
Values [2.0]
standardised diastolic_blood_pressure ~ N(0.3, 0.2)


# Appendix: model architectures

In [32]:
display(model)

TTETransformerForCausalTimeSeriesModelling(
  (transformer): TTETransformer(
    (wpe): TemporalPositionalEncoding()
    (wte): DataEmbeddingLayer(
      (static_proj): Linear(in_features=16, out_features=384, bias=True)
      (dynamic_embedding_layer): SplitDynamicEmbeddingLayer(
        (cat_event_embed_layer): Embedding(184, 384, padding_idx=0)
        (cat_event_proj): Linear(in_features=384, out_features=384, bias=True)
        (num_value_embed_layer): EmbeddingBag(184, 384, mode=sum, padding_idx=0)
        (num_value_proj): Linear(in_features=384, out_features=384, bias=True)
      )
    )
    (drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadedSelfAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=384, out_features=384, bias=False)
          (v

In [33]:
!jupyter nbconvert --to html --no-input TTE_tabular.ipynb

[NbConvertApp] Converting notebook TTE_tabular.ipynb to html
[NbConvertApp] Writing 606251 bytes to TTE_tabular.html
