# Demo Notebook:
## Time to Event Transformer For Causal Time Series Modelling 

Including time and tabular values

In [1]:
import os
from pathlib import Path
import sys
# Put the site-packages directory of the node-specific virtual environment at the front of
# the module search path. The node type is read from the BB_CPU environment variable.
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-{node_type}'
python_tag = f'python{sys.version_info.major}.{sys.version_info.minor}'
venv_site_pkgs = Path(venv_dir).joinpath('lib', python_tag, 'site-packages')
if not venv_site_pkgs.exists():
    # Nothing is added to sys.path in this case; the message names the path that was tried.
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")
else:
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")

# Perform sqlite operations on disk: set both SQLITE_TMPDIR and the generic TMPDIR
# environment variables to a directory on project storage.
%env SQLITE_TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
%env TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
# Echo the variables back from the shell to confirm the values it sees.
!echo $SQLITE_TMPDIR
!echo $TMPDIR
# NOTE(review): USERPROFILE prints as an empty line in the recorded output - confirm it is still needed.
!echo $USERPROFILE

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-icelake/lib/python3.10/site-packages' at start of search paths.
env: SQLITE_TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
env: TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
/rds/projects/g/gokhalkm-optimal/DataforCharles
/rds/projects/g/gokhalkm-optimal/DataforCharles



In [2]:
import logging
import random
import sqlite3
from dataclasses import dataclass
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning
import torch
from tqdm import tqdm

from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.TTE.task_heads.causal_tabular import TTETransformerForCausalTimeSeriesModelling

# Seed every random number generator this notebook imports (Python, NumPy and PyTorch),
# not only PyTorch, so that repeated runs start from the same state.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
logging.basicConfig(level=logging.INFO)
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(device)

# Print the working directory of the kernel.
!pwd
# Load the autoreload extension and enable mode 2, so that edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

cuda
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/tteGPT


## Build configurations

In [3]:
# Set config to the equivalent architecture of the Karpathy benchmark; note that the tasks themselves are not comparable.
@dataclass
class DemoConfig:
    """Model configuration handed to the transformer constructor in this demo.

    `block_size` and `unk_freq_threshold` are also passed to the data module
    (as `max_seq_length` and `unk_freq_threshold` respectively).
    `TTELayer` and `tokens_for_univariate_regression` are overwritten per model
    in the model-construction cell.
    """
    block_size: int = 8        # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0
    # The two entries below carry type annotations so that they are genuine dataclass
    # fields: without an annotation they are plain class attributes and are missing from
    # the generated __init__, __repr__ and __eq__.
    TTELayer: str = "Exponential"                                  # "Geometric"
    tokens_for_univariate_regression: Optional[list] = None        # encoded measurement tokens, filled in later

config = DemoConfig()

@dataclass
class OptConfig:
    """Optimisation settings read by the data module and the training loop of this demo."""
    # Number of patients per batch; passed to the data module.
    batch_size: int = 256
    # Validation is run on every `eval_interval`-th epoch.
    eval_interval: int = 1
    # Step size given to the AdamW optimiser.
    learning_rate: float = 0.0003
    # Upper bound on the number of training epochs.
    epochs: int = 50


opt = OptConfig()

## Create data loader on a reduced cohort

In [8]:
# Get a list of patients which fit a reduced set of criteria
# path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/archive/Version2/"
path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/"

# Build the data module. Sequence length, batch size and the unknown-token frequency
# threshold are taken from the two configuration objects defined above.
datamodule_kwargs = dict(
    path_to_db=path_to_db,
    load=False,
    tokenizer="tabular",
    batch_size=opt.batch_size,
    max_seq_length=config.block_size,
    unk_freq_threshold=config.unk_freq_threshold,
    min_workers=10,
)
dm = FoundationalDataModule(**datamodule_kwargs)

# The vocabulary size is needed later to construct the models.
vocab_size = dm.train_set.tokenizer.vocab_size

# Report the size of each split and of the vocabulary.
n_train, n_val, n_test = len(dm.train_set), len(dm.val_set), len(dm.test_set)
print(f"{n_train} training patients")
print(f"{n_val} validation patients")
print(f"{n_test} test patients")
print(f"{vocab_size} vocab elements")

INFO:root:Building Polars dataset and saving to /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/polars/
INFO:root:Chunking by unique practice ID with no inclusion conditions
INFO:root:Creating train/test/val splits using practice_ids
INFO:root:Extracting practice_patient_ids for each practice
                                             Train: 100%|██████████| 1341/1341 [00:24<00:00, 53.87it/s]
                                              Test: 100%|██████████| 75/75 [00:01<00:00, 64.24it/s]
                                        Validation: 100%|██████████| 75/75 [00:01<00:00, 52.72it/s]
INFO:root:Collating train split into a DL friendly format. Generating over practices IDs
  0%|          | 0/1341 [00:00<?, ?it/s]INFO:root:collect time 339.5734267234802
INFO:root:collate time 0.001291513442993164
INFO:root:save time 2.993227243423462
  0%|          | 1/1341 [05:42<127:30:43, 342.57s/it]INFO:root:collect time 276.8205626010895
INFO:root:collate time 0.004151582717895508

In [None]:
# display(dm.train_set.tokenizer._itos)

## View the frequency of tokens in the extracted data

In [None]:
import polars as pl

# Raise the polars table row limit to one more than the vocabulary size, then show the
# first rows of the tokenizer's event-count table.
pl.Config.set_tbl_rows(vocab_size + 1)
event_counts = dm.tokenizer._event_counts
display(event_counts.head())

In [None]:
# Extract the measurements, using the fact that the diagnoses are all upper case: any event
# whose name changes when upper-cased is treated as a measurement. This list is needed for
# automatically setting the model configuration below.
measurements_for_univariate_regression = []
for record in dm.tokenizer._event_counts["EVENT"]:
    if record != record.upper():
        measurements_for_univariate_regression.append(record)

# Show the first few measurement names as a sanity check.
display(measurements_for_univariate_regression[:3])
# display(dm.encode(measurements_for_univariate_regression))
# print(dm.decode([7,4,3,2]))

## Create models and train

In [None]:
# Parallel lists: models[k] is described by m_names[k].
models, m_names = [], []

# My development model: one model per time-to-event layer type.
for tte_layer in ["Exponential"]: #, "Geometric"]:

    # Fresh configuration for this model
    config = DemoConfig()
    # Specify which TTE layer to use
    config.TTELayer = tte_layer
    # list of univariate measurements to model with Normal distribution
    config.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression)

    model_name = f"TTETransformerForCausalTimeSeriesModelling: {tte_layer} TTE"
    tte_model = TTETransformerForCausalTimeSeriesModelling(config, vocab_size)
    models.append(tte_model.to(device))
    m_names.append(model_name)

In [None]:
def _new_curves():
    """Return one empty loss history per entry of `models`."""
    return [[] for _ in range(len(models))]


# Training histories: total loss, then its classifier / time-to-event / value components.
loss_curves_train = _new_curves()
loss_curves_train_clf = _new_curves()
loss_curves_train_tte = _new_curves()
loss_curves_train_values = _new_curves()

# Validation histories, in the same order.
loss_curves_val = _new_curves()
loss_curves_val_clf = _new_curves()
loss_curves_val_tte = _new_curves()
loss_curves_val_values = _new_curves()

In [None]:
# Caps on the number of batches consumed per epoch. The values reproduce the previous
# limits, which processed batch indices 0..500 (training) and 0..50 (validation).
max_train_batches = 501
max_val_batches = 51
# Early stopping: number of consecutive validations without improvement that is tolerated.
patience = 5

# Train every model in turn, record per-epoch losses into the loss_curves_* lists, then
# sample a short continuation from a fixed prompt.
for m_idx, (model, m_name) in enumerate(zip(models, m_names)):
    
    print(f"Training model `{m_name}`, with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
    model = model.to(device)

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

    best_val, epochs_since_best = np.inf, 0
    for epoch in range(opt.epochs):
        epoch_loss, epoch_clf_loss, epoch_tte_loss, epoch_values_loss = 0, 0, 0, 0
        n_train_batches = 0
        model.train()
        train_loader = dm.train_dataloader()
        for i, batch in tqdm(enumerate(train_loader), desc=f"Training epoch {epoch}", total=min(len(train_loader), max_train_batches)):
            if i >= max_train_batches:
                break

            # evaluate the loss: the model returns (outputs, per-component losses, total loss)
            _, (loss_clf, loss_tte, loss_values), loss = model(batch['tokens'].to(device), 
                                                               ages=batch['ages'].to(device), 
                                                               values=batch['values'].to(device),
                                                               attention_mask=batch['attention_mask'].to(device)   
                                                               )
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            
            # record
            epoch_clf_loss += loss_clf.item()
            epoch_tte_loss += loss_tte.item()
            epoch_values_loss += loss_values.item()
            n_train_batches += 1

        # Average over the number of batches actually processed. Dividing by the last loop
        # index was off by one whenever the loader ran out before the cap, and divided by
        # zero for a single-batch loader.
        n_train_batches = max(n_train_batches, 1)
        epoch_loss /= n_train_batches
        epoch_clf_loss /= n_train_batches
        epoch_tte_loss /= n_train_batches
        epoch_values_loss /= n_train_batches
        loss_curves_train[m_idx].append(epoch_loss)
        loss_curves_train_clf[m_idx].append(epoch_clf_loss)
        loss_curves_train_tte[m_idx].append(epoch_tte_loss)
        loss_curves_train_values[m_idx].append(epoch_values_loss)

        # evaluate the loss on val set
        with torch.no_grad(): 
            model.eval()
            if epoch % opt.eval_interval == 0 or epoch == opt.epochs - 1:
                val_loss, val_clf_loss, val_tte_loss, val_values_loss = 0, 0, 0, 0
                n_val_batches = 0
                val_loader = dm.val_dataloader()
                for j, batch in tqdm(enumerate(val_loader), desc=f"Validation epoch {epoch}", total=min(len(val_loader), max_val_batches)):
                    if j >= max_val_batches:
                        break
                    _, (loss_clf, loss_tte, loss_values), loss = model(batch['tokens'].to(device), 
                                                                       ages=batch['ages'].to(device),
                                                                       values=batch['values'].to(device),
                                                                       attention_mask=batch['attention_mask'].to(device)   
                                                                       )
                    val_loss += loss.item()
                    
                    # record
                    val_clf_loss += loss_clf.item()
                    val_tte_loss += loss_tte.item()
                    val_values_loss += loss_values.item()
                    n_val_batches += 1

                # Same averaging rule as for the training losses.
                n_val_batches = max(n_val_batches, 1)
                val_loss /= n_val_batches
                val_clf_loss /= n_val_batches
                val_tte_loss /= n_val_batches
                val_values_loss /= n_val_batches
                loss_curves_val[m_idx].append(val_loss)
                loss_curves_val_clf[m_idx].append(val_clf_loss)
                loss_curves_val_tte[m_idx].append(val_tte_loss)
                loss_curves_val_values[m_idx].append(val_values_loss)

                print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}  ({epoch_clf_loss:.2f}, {epoch_tte_loss:.2f}, {epoch_values_loss:.2f}). Val loss {val_loss:.2f} ({val_clf_loss:.2f}, {val_tte_loss:.2f}, {val_values_loss:.2f})")          
                # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning

                # Early stopping is only assessed on epochs that were actually validated, so
                # a stale val_loss from an earlier epoch never counts against the patience.
                if val_loss >= best_val:
                    epochs_since_best += 1
                    if epochs_since_best >= patience:
                        break
                else:
                    best_val = val_loss
                    epochs_since_best = 0

    # Test trained model with a prompt
    # ----------------    
    # set context: diagnosis of depression at 20 years old
    tokens = torch.from_numpy(np.array(dm.encode(["DEPRESSION"])).reshape((1,-1))).to(device)
    ages = torch.tensor([[20*365]], device=device)
    values = torch.tensor([[torch.nan]], device=device)
    
    # generate: sample the next 10 tokens
    new_tokens, new_ages, new_values = model.generate(tokens, ages, values, max_new_tokens=10)
    generated = dm.decode(new_tokens[0].tolist())
    # report:
    for _cat, _age, _value in zip(generated.split(" "), new_ages[0, :], new_values[0, :]):
        try:
            _value = unstandardise(_cat, _value)
        except Exception:
            # Best effort: fall back to the standardised value when it cannot be converted.
            # NOTE(review): `unstandardise` is not defined in any visible cell - confirm where it comes from.
            pass
        print(f"\t {_cat}:{_value:.02f}, at age {_age/365:.0f} ({_age:.1f} days)")    # with value {_value}

In [None]:
# batch['attention_mask'].shape
# Debug aid: `batch` is not created in this cell; it is whatever batch the training cell
# above left bound, so this cell only works after that cell has run.
print(batch["values"])

## Comparing output to real data

In [None]:
# Print the first events of the first patient in one training batch, in the same format as
# the generated sequences above, so the two can be compared by eye.
# next(iter(...)) raises on an empty loader instead of silently reusing a stale `batch`.
batch = next(iter(dm.train_dataloader()))
conditions = batch["tokens"].numpy().tolist()
# delta_ages = batch["ages"][:, 1:] - batch["ages"][:, :-1]
for idx, (token, _age, _value) in enumerate(zip(conditions[0], batch["ages"][0,:],  batch["values"][0,:])):
    # Stop at the first token with id 0 (presumably padding - TODO confirm) or after 10 events.
    if token == 0 or idx >= 10:
        break
    _cat = dm.decode([token])
    try:
        _value = unstandardise(_cat, _value)
    except Exception:
        # Best effort: fall back to the standardised value when it cannot be converted.
        # NOTE(review): `unstandardise` is not defined in any visible cell - confirm where it comes from.
        pass
        
    print(f"{_cat}:{_value:.02f}, at age {_age/365:.0f} ({_age:.1f} days)")

In [None]:
cols = ["k", "r", "b", "y"]


def _plot_loss_curves(train_curves, val_curves, save_path, ylabel):
    """Plot train (dashed) and validation (solid) loss against epoch for every model.

    Args:
        train_curves: one list of per-epoch training losses per model.
        val_curves: one list of validation losses per model, one entry per validated epoch.
        save_path: file the figure is written to.
        ylabel: label for the y axis.
    """
    plt.figure()
    for m_idx in range(len(models)):
        # Cycle through the palette so that more than len(cols) models does not raise.
        colour = cols[m_idx % len(cols)]
        train_curve, val_curve = train_curves[m_idx], val_curves[m_idx]

        # Training: one entry is appended per epoch, so the x axis is the epoch index.
        train_epochs = np.arange(len(train_curve))
        plt.plot(train_epochs, train_curve, label=f"{m_names[m_idx]}-train", c=colour, linestyle='dashed')

        # Validation: entries are eval_interval epochs apart; the forced evaluation on the
        # final epoch may fall off that grid, so clip to the last trained epoch.
        last_epoch = max(len(train_curve) - 1, 0)
        val_epochs = np.minimum(np.arange(len(val_curve)) * opt.eval_interval, last_epoch)
        plt.plot(val_epochs, val_curve, label=f"{m_names[m_idx]}-val", c=colour)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(save_path)


# savefig does not create missing directories.
Path("figs/TTE_tab").mkdir(parents=True, exist_ok=True)

# Plot loss
_plot_loss_curves(loss_curves_train, loss_curves_val, "figs/TTE_tab/loss.png", "Total loss")

# Plot classifier loss
_plot_loss_curves(loss_curves_train_clf, loss_curves_val_clf, "figs/TTE_tab/loss_clf.png", "Classifier loss")

# Plot tte loss
_plot_loss_curves(loss_curves_train_tte, loss_curves_val_tte, "figs/TTE_tab/loss_tte.png", "Time-to-event loss")

# Plot values loss (the file name is kept as before: "val" here refers to values, not validation)
_plot_loss_curves(loss_curves_train_values, loss_curves_val_values, "figs/TTE_tab/loss_val.png", "Values loss")

# Prompt testing

## Diabetes: How related conditions are impacted by each other
Probability of type II diabetes before and after a type I diagnosis

In [None]:
# Vocabulary indices of the two diagnoses whose probabilities are reported below.
t1_token = dm.tokenizer._stoi["TYPE1DM"]
t2_token = dm.tokenizer._stoi["TYPE2DIABETES"]

# Shared context for every prompt: a depression diagnosis at age 20 with no associated value.
base_prompt = ["DEPRESSION"]
ages_in_years = [20]
base_values = [torch.tensor([torch.nan])]


def to_days(a_list):
    """Convert a list of ages in years into a (1, len(a_list)) float tensor of days on `device`."""
    return torch.FloatTensor([365 * _a for _a in a_list]).reshape((1,-1)).to(device)


# Create a set of prompts: the control, then the control followed by a type 1 or a type 2
# diabetes diagnosis at age 21. The four lists are parallel.
desc = ["Control", "Type 1", "Type 2"]
prompts = [base_prompt, base_prompt + ["TYPE1DM"], base_prompt + ["TYPE2DIABETES"]]
ages = [ages_in_years, ages_in_years + [21], ages_in_years + [21]]
values = [
    base_values,
    base_values + [torch.tensor([torch.nan])],
    base_values + [torch.tensor([torch.nan])],
]

for model_name, model in zip(m_names, models):
    print(f"\n\n{model_name}\n--------------------------------------")
    with torch.no_grad():
        model.eval()

        for prompt_desc, prompt, age, value in zip(desc, prompts, ages, values):
            print(f"\n{prompt_desc}: \t ({','.join(prompt)}): ")
            token_ids = np.array(dm.encode(prompt)).reshape((1, -1))
            encoded_prompt = torch.from_numpy(token_ids).to(device)
            (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                                     values=torch.tensor(value).to(device),
                                                     ages=to_days(age),
                                                     is_generation=True)
            # Turn the logits into probabilities over the vocabulary (last axis).
            probs = torch.nn.functional.softmax(lgts, dim=2)
            p_type1 = probs[0, 0, t1_token].item()
            p_type2 = probs[0, 0, t2_token].item()
            print(f"\tprobability of type I diabetes: {100*p_type1:.4f}%")
            print(f"\tprobability of type II diabetes: {100*p_type2:.4f}%")

# Note: adding a diagnosis (even if potentially orthogonal) at the beginning of the prompt increases probability of either type

## Values: How increasing BMI affects likelihood of diagnoses

In [None]:
# Only these events are reported from the ranked next-event probabilities.
events_of_interest = [
    "bmi", "diastolic_blood_pressure",
    "TYPE1DM", "TYPE2DIABETES",
    "HYPERTENSION", "OSTEOARTHRITIS",
    "CKDSTAGE3TO5",
    "HF", "ISCHAEMICSTROKE",
]

# Prompt: a single bmi measurement at age 40, probed over a range of raw values.
# NOTE(review): `standardise` is not defined in any visible cell - confirm where it comes from.
prompt = ["bmi"]
age = [40]
values = []
for raw_value in [12.,15.,18.,21.,24.,30.,40.]:
    standardised = [standardise(_cat, raw_value) for _cat in prompt]
    values.append(torch.tensor(standardised, device=device))

for model_name, model in zip(m_names, models):
    print(f"\n\n{model_name}\n--------------------------------------")

    for value in values:
        print(f"Value {value}\n======")
        token_ids = np.array(dm.encode(prompt)).reshape((1, -1))
        encoded_prompt = torch.from_numpy(token_ids).to(device)
        (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                                 values=value,
                                                 ages=to_days(age),
                                                 is_generation=True)
        # Probabilities over the vocabulary, expressed as percentages.
        probs = torch.nn.functional.softmax(lgts, dim=2) * 100

        # Rank all tokens by probability and print the events of interest in that order.
        ranked_probs, ranked_tokens = probs[0, 0, :].sort(descending=True)
        ranked_names = dm.decode(ranked_tokens.tolist()).split(" ")
        for event_name, event_prob in zip(ranked_names, ranked_probs):
            if event_name in events_of_interest:
                print(f"\t{event_name}: {event_prob:.2f}%")


## Values: How increasing diastolic_blood_pressure affects likelihood of diagnoses

In [None]:
# Only these events are reported from the ranked next-event probabilities.
events_of_interest = [
    "bmi", "diastolic_blood_pressure",
    "TYPE1DM", "TYPE2DIABETES",
    "HYPERTENSION", "OSTEOARTHRITIS",
    "CKDSTAGE3TO5",
    "HF", "ISCHAEMICSTROKE",
]

# Prompt: a single diastolic blood pressure measurement at age 40, probed over a range of raw values.
# NOTE(review): `standardise` is not defined in any visible cell - confirm where it comes from.
prompt = ["diastolic_blood_pressure"]
age = [40]
values = []
for raw_value in [60.,70.,80.,90.,100.,120.]:
    standardised = [standardise(_cat, raw_value) for _cat in prompt]
    values.append(torch.tensor(standardised, device=device))

for model_name, model in zip(m_names, models):
    print(f"\n\n{model_name}\n--------------------------------------")

    for value in values:
        print(f"Value {value}\n======")
        token_ids = np.array(dm.encode(prompt)).reshape((1, -1))
        encoded_prompt = torch.from_numpy(token_ids).to(device)
        (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                                 values=value,
                                                 ages=to_days(age),
                                                 is_generation=True)
        # Probabilities over the vocabulary, expressed as percentages.
        probs = torch.nn.functional.softmax(lgts, dim=2) * 100

        # Rank all tokens by probability and print the events of interest in that order.
        ranked_probs, ranked_tokens = probs[0, 0, :].sort(descending=True)
        ranked_names = dm.decode(ranked_tokens.tolist()).split(" ")
        for event_name, event_prob in zip(ranked_names, ranked_probs):
            if event_name in events_of_interest:
                print(f"\t{event_name}: {event_prob:.2f}%")


## Values: How varying diagnosis affects value of diastolic_blood_pressure

In [None]:
# Despite its name, t1_token here holds the vocabulary index of diastolic_blood_pressure.
t1_token = dm.tokenizer._stoi["diastolic_blood_pressure"]

# Each prompt is a single diagnosis at age 39 with no associated value.
diagnoses = [["DEPRESSION"],["TYPE2DIABETES"], ["HF"], ["HYPERTENSION"]]
values = torch.tensor([torch.nan], device=device)
age = [39]

for model_name, model in zip(m_names, models):
    print(f"\n\n{model_name}\n--------------------------------------")

    for diagnosis in diagnoses:
        print(f"\nDiagnosis {diagnosis}\n======")
        token_ids = np.array(dm.encode(diagnosis)).reshape((1, -1))
        encoded_prompt = torch.from_numpy(token_ids).to(device)
        (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                                 values=values,
                                                 ages=to_days(age),
                                                 is_generation=True)
        # Look up the predicted value distribution for the blood pressure token and report
        # its location and scale.
        dist_key = model.value_layer.token_key(t1_token)
        dist = val_dist[dist_key]
        print(f"standardised diastolic_blood_pressure ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")



## Values: How increasing bmi affects value of diastolic_blood_pressure

In [None]:
# Despite its name, t1_token here holds the vocabulary index of diastolic_blood_pressure.
t1_token = dm.tokenizer._stoi["diastolic_blood_pressure"]

# Prompt: a single bmi measurement at age 40, probed over a range of raw values.
# NOTE(review): `standardise` is not defined in any visible cell - confirm where it comes from.
prompt = ["bmi"]
age = [40]
values = []
for raw_value in [12.,15.,18.,21.,24.,30.,40.,50.]:
    standardised = [standardise(_cat, raw_value) for _cat in prompt]
    values.append(torch.tensor(standardised, device=device))

for model_name, model in zip(m_names, models):
    print(f"\n\n{model_name}\n--------------------------------------")

    for value in values:
        print(f"Values {value.tolist()}\n======")
        token_ids = np.array(dm.encode(prompt)).reshape((1, -1))
        encoded_prompt = torch.from_numpy(token_ids).to(device)
        (lgts, tte_dist, val_dist), _, _ = model(encoded_prompt,
                                                 values=value,
                                                 ages=to_days(age),
                                                 is_generation=True)

        # Look up the predicted value distribution for the blood pressure token and report
        # its location and scale.
        dist_key = model.value_layer.token_key(t1_token)
        dist = val_dist[dist_key]
        print(f"standardised diastolic_blood_pressure ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")

# Appendix: model architectures

In [None]:
# Print each model's name, underlined to its own length, followed by the model's printed form.
for model_name, model in zip(m_names, models):
    underline = "=" * len(model_name)
    print(f"\n\n{model_name}\n" + underline)
    print(f"\n\n{model}")

In [None]:
# Export the notebook to HTML without the code inputs (--no-input); the notebook file is
# expected to be named TTE_tabular.ipynb in the current working directory.
!jupyter nbconvert --to html --no-input TTE_tabular.ipynb