# Demo Notebook:
## Time to Event Transformer For Causal Sequence Modelling 

Including time, and excluding values

In [1]:
import pytorch_lightning 
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.TTE.task_heads.causal import TTETransformerForCausalSequenceModelling

# TODO:
# replace experiment boilerplate with pytorch lightning

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# device = "cpu"    # if more informative debugging statements are needed
!pwd

cuda
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/tteGPT


## Build configurations

In [2]:
# Set config to be equivalent architecture of kaparthy benchmark, however they are not comparable tasks.
@dataclass
class DemoConfig:
    block_size: int = 128        # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"    
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0
    TTELayer = "Exponential"

config = DemoConfig()

@dataclass
class OptConfig:
    batch_size: int = 64
    eval_interval: int = 1
    learning_rate: float = 3e-4
    epochs: int = 50
    
opt = OptConfig()

## Create data loader on a reduced cohort

In [3]:
from CPRD.data.database import queries

# Get a list of patients which fit a reduced set of criterion
PATH_TO_DB = "/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing/processed/cprd.db"
conn = sqlite3.connect(PATH_TO_DB)
cursor = conn.cursor()
# identifiers1 = queries.query_measurement(["bmi", "diastolic_blood_pressure"], cursor)        
identifiers2 = queries.query_diagnosis(["DEPRESSION", "TYPE1DM", "TYPE2DIABETES"], cursor)    #  "DEPRESSION"  ,  "ANXIETY"
# all_identifiers = list(set(identifiers1).intersection(identifiers2))    # Turn smaller list into the set
all_identifiers = identifiers2

if False:
    # Lets take only the first N for faster run-time
    N = np.min((len(all_identifiers), 10000))
    print(f"Using N={N} random samples, from the available {len(all_identifiers)}")
    identifiers = random.choices(all_identifiers, k=N)
else:
    print(f"Using all available {len(all_identifiers)} samples")
    identifiers = all_identifiers

# Build 
dm = FoundationalDataModule(identifiers=identifiers,
                            tokenizer="tabular",
                            batch_size=opt.batch_size,
                            max_seq_length=config.block_size,
                            unk_freq_threshold=config.unk_freq_threshold,
                            include_measurements=False,
                            include_diagnoses=True,
                           )

vocab_size = dm.train_set.tokenizer.vocab_size

print(f"{len(dm.train_set)} training samples")
print(f"{len(dm.val_set)} validation samples")
print(f"{len(dm.test_set)} test samples")
print(f"{vocab_size} vocab elements")
# print(dm.train_set.tokenizer._itos)

INFO:root:Building polars dataset


Using all available 129717 samples


INFO:root:Using diagnoses
INFO:root:Dropping samples with no dynamic events
INFO:root:Using tabular tokenizer


107680 training samples
5983 validation samples
5982 test samples
74 vocab elements


In [4]:
import polars as pl
pl.Config.set_tbl_rows(vocab_size + 1)
print(dm.tokenizer._event_counts)

shape: (73, 3)
┌───────────────────────────────────┬────────┬──────────┐
│ EVENT                             ┆ counts ┆ freq     │
│ ---                               ┆ ---    ┆ ---      │
│ str                               ┆ u32    ┆ f64      │
╞═══════════════════════════════════╪════════╪══════════╡
│ UNK                               ┆ 0      ┆ 0.0      │
│ DEPRESSION                        ┆ 86522  ┆ 0.177368 │
│ ANXIETY                           ┆ 48177  ┆ 0.098762 │
│ HYPERTENSION                      ┆ 32930  ┆ 0.067506 │
│ TYPE2DIABETES                     ┆ 27362  ┆ 0.056092 │
│ OSTEOARTHRITIS                    ┆ 25391  ┆ 0.052051 │
│ ASTHMA_PUSHASTHMA                 ┆ 25340  ┆ 0.051946 │
│ ATOPICECZEMA                      ┆ 23178  ┆ 0.047514 │
│ ALLERGICRHINITISCONJ              ┆ 18462  ┆ 0.037847 │
│ ANY_DEAFNESS_HEARING_LOSS         ┆ 17233  ┆ 0.035327 │
│ PREVALENT_IBS                     ┆ 11715  ┆ 0.024015 │
│ ALLCA_NOBCC_VFINAL                ┆ 11695  ┆ 0.023974 │

## Create models and train

In [5]:
models, m_names = [], []

# My development model
for tte_layer in ["Exponential", "Geometric"]:
    config = DemoConfig()
    config.TTELayer = tte_layer
    models.append(TTETransformerForCausalSequenceModelling(config, vocab_size).to(device))
    m_names.append(f"TPPTransformerForCausalSequenceModelling: {tte_layer} TTE")

INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using ExponentialTTELayer. This module predicts the time until next event as an exponential distribution
INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using GeometricTTELayer. This module predicts the time until next event as a geometric distribution, supported on the set {0,1,...}


In [6]:
loss_curves_train = [[] for _ in models]
loss_curves_train_clf = [[] for _ in models]
loss_curves_train_tte = [[] for _ in models]

loss_curves_val = [[] for _ in models]
loss_curves_val_clf = [[] for _ in models]
loss_curves_val_tte = [[] for _ in models]

In [7]:
for m_idx, (model, m_name) in enumerate(zip(models, m_names)):
    
    print(f"Training model `{m_name}`, with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
    model = model.to(device)

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

    best_val, epochs_since_best = np.inf, 0
    for epoch in range(opt.epochs):
        epoch_loss, epoch_clf_loss, epoch_tte_loss = 0, 0, 0
        model.train()
        for i, batch in enumerate(dm.train_dataloader()):
            # evaluate the loss
            _, (loss_clf, loss_tte), loss = model(batch['tokens'].to(device), 
                                                  ages=batch['ages'].to(device), 
                                                  attention_mask=batch['attention_mask'].to(device)  
                                                  )
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # record
            epoch_clf_loss += loss_clf.item()
            epoch_tte_loss += loss_tte.item()
        epoch_loss /= i
        epoch_clf_loss /= i
        epoch_tte_loss /= i
        loss_curves_train[m_idx].append(epoch_loss)
        loss_curves_train_clf[m_idx].append(epoch_clf_loss)
        loss_curves_train_tte[m_idx].append(epoch_tte_loss)

        # evaluate the loss on val set
        with torch.no_grad(): 
            model.eval()
            if epoch % opt.eval_interval == 0 or epoch == opt.epochs - 1:
                val_loss, val_clf_loss, val_tte_loss = 0, 0, 0
                for j, batch in enumerate(dm.val_dataloader()):
                    _, (loss_clf, loss_tte), loss = model(batch['tokens'].to(device), 
                                                          ages=batch['ages'].to(device),
                                                          attention_mask=batch['attention_mask'].to(device) 
                                                          )
                    val_loss += loss.item()
                    # record
                    val_clf_loss += loss_clf.item()
                    val_tte_loss += loss_tte.item()
                val_loss /= j
                val_clf_loss /= j
                val_tte_loss /= j
                loss_curves_val[m_idx].append(val_loss)
                loss_curves_val_clf[m_idx].append(val_clf_loss)
                loss_curves_val_tte[m_idx].append(val_tte_loss)
                print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}  ({epoch_clf_loss:.2f}, {epoch_tte_loss:.2f}). Val loss {val_loss:.2f} ({val_clf_loss:.2f}, {val_tte_loss:.2f})")          
                # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning
        
            if val_loss >= best_val:
                epochs_since_best += 1
                if epochs_since_best >= 5:
                    break
            else:
                best_val = val_loss
                epochs_since_best = 0

    # Test trained model with a prompt
    # ----------------    
    # set context: diagnosis of depression at 20 years old
    tokens = torch.from_numpy(np.array(dm.encode(["DEPRESSION"])).reshape((1,-1))).to(device)
    ages = torch.tensor([[20*365]], device=device)
    # values = torch.tensor([[torch.nan]], device=device)
    # generate: sample the next 10 tokens
    new_tokens, new_ages = model.generate(tokens, ages, max_new_tokens=10)
    generated = dm.decode(new_tokens[0].tolist())
    # report:
    #    note, Not considering value yet.
    for _cat, _age in zip(generated.split(" "), new_ages[0, :]):
        print(f"\t {_cat} at age {_age/365:.0f} ({_age:.1f} days)")    # with value {_value}


Training model `TPPTransformerForCausalSequenceModelling: Exponential TTE`, with 10.726465 M parameters
Epoch 0:	Train loss 2.12  (2.81, 1.43). Val loss 2.10 (2.78, 1.42)
Epoch 1:	Train loss 2.07  (2.74, 1.40). Val loss 2.08 (2.75, 1.41)
Epoch 2:	Train loss 2.06  (2.72, 1.40). Val loss 2.08 (2.76, 1.41)
Epoch 3:	Train loss 2.05  (2.71, 1.39). Val loss 2.09 (2.75, 1.42)
Epoch 4:	Train loss 2.04  (2.70, 1.39). Val loss 2.08 (2.74, 1.41)
Epoch 5:	Train loss 2.04  (2.69, 1.39). Val loss 2.08 (2.75, 1.41)
Epoch 6:	Train loss 2.03  (2.68, 1.38). Val loss 2.08 (2.75, 1.41)
Epoch 7:	Train loss 2.03  (2.67, 1.38). Val loss 2.08 (2.75, 1.41)
Epoch 8:	Train loss 2.02  (2.66, 1.38). Val loss 2.08 (2.75, 1.41)




Epoch 9:	Train loss 2.01  (2.65, 1.38). Val loss 2.08 (2.76, 1.41)
	 DEPRESSION at age 20 (7300.0 days)
	 ANXIETY at age 21 (7507.9 days)
	 PREVALENT_IBS at age 24 (8609.7 days)
	 ALCOHOLMISUSE at age 29 (10644.2 days)
	 SUBSTANCEMISUSE at age 34 (12559.1 days)
	 EATINGDISORDERS at age 38 (13740.4 days)
	 STROKEUNSPECIFIED at age 40 (14678.8 days)
	 ENDOMETRIOSIS_ADENOMYOSIS_V2 at age 41 (15134.7 days)
	 ALLCA_NOBCC_VFINAL at age 42 (15294.9 days)
	 OSTEOARTHRITIS at age 45 (16252.7 days)
	 HAEMOCHROMATOSIS at age 46 (16733.5 days)
Training model `TPPTransformerForCausalSequenceModelling: Geometric TTE`, with 10.726465 M parameters
Epoch 0:	Train loss 5.32  (2.83, 7.80). Val loss 5.34 (2.79, 7.89)
Epoch 1:	Train loss 5.26  (2.75, 7.77). Val loss 5.33 (2.77, 7.88)
Epoch 2:	Train loss 5.25  (2.73, 7.77). Val loss 5.32 (2.76, 7.89)
Epoch 3:	Train loss 5.24  (2.72, 7.76). Val loss 5.32 (2.75, 7.88)
Epoch 4:	Train loss 5.23  (2.71, 7.76). Val loss 5.32 (2.75, 7.88)
Epoch 5:	Train loss 5.23 



Epoch 10:	Train loss 5.21  (2.67, 7.75). Val loss 5.32 (2.75, 7.88)
	 DEPRESSION at age 20 (7300.0 days)
	 ANXIETY at age 33 (11871.0 days)
	 ASTHMA_PUSHASTHMA at age 48 (17364.0 days)
	 OSTEOARTHRITIS at age 48 (17612.0 days)
	 HYPERTENSION at age 51 (18797.0 days)
	 SUBSTANCEMISUSE at age 53 (19219.0 days)
	 PREVALENT_IBS at age 53 (19220.0 days)
	 IHD_NOMI at age 62 (22631.0 days)
	 ISCHAEMICSTROKE at age 63 (22859.0 days)
	 STROKEUNSPECIFIED at age 64 (23516.0 days)
	 ALCOHOLMISUSE at age 69 (25014.0 days)


## Comparing output to real data

In [8]:
for batch in dm.train_dataloader():
    break
conditions = batch["tokens"].numpy().tolist()
# delta_ages = batch["ages"][:, 1:] - batch["ages"][:, :-1]
for idx, (token, age) in enumerate(zip(conditions[0], batch["ages"][0,:])):
    if token == 0 or idx >= 10:
        break
    print(f"{dm.decode([token])}, at age {age/365:.0f} ({age:.1f} days)")

HYPERTENSION, at age 26 (9611.0 days)
DEPRESSION, at age 26 (9618.0 days)
NAFLD, at age 37 (13455.0 days)
COPD, at age 38 (13994.0 days)
ASTHMA_PUSHASTHMA, at age 38 (14049.0 days)
ALCOHOLMISUSE, at age 41 (15039.0 days)
IHD_NOMI, at age 51 (18632.0 days)
ANXIETY, at age 51 (18643.0 days)
PERIPHERAL_NEUROPATHY, at age 53 (19419.0 days)


In [9]:
cols = ["k", "r", "b", "y"]

# Plot loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train[m_idx]), len(loss_curves_train[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val[m_idx]), len(loss_curves_val[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.yscale("log")
plt.legend()
plt.savefig(f"figs/TTE/logloss.png")

# Plot classifier loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train_clf[m_idx]), len(loss_curves_train_clf[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train_clf[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val_clf[m_idx]), len(loss_curves_val_clf[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val_clf[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.yscale("log")
plt.legend()
plt.savefig(f"figs/TTE/logloss_clf.png")

# Plot tte loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train_tte[m_idx]), len(loss_curves_train_tte[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train_tte[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val_tte[m_idx]), len(loss_curves_val_tte[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val_tte[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.yscale("log")
plt.legend()
plt.savefig(f"figs/TTE/logloss_tte.png")

# Prompt testing

## Diabetes: How related conditions are impacted by each other
Probability of type II diabetes before and after a type I diagnosis

In [10]:
t1_token = dm.tokenizer._stoi["TYPE1DM"]
t2_token = dm.tokenizer._stoi["TYPE2DIABETES"]


base_prompt = ["DEPRESSION"]
ages_in_years = [20]

to_days = lambda a_list: torch.FloatTensor([365 * _a for _a in a_list]).reshape((1,-1)).to(device)

# Create a set of prompts
prompts, ages, values, desc = [], [], [], []
desc.append("Control")
prompts.append(base_prompt)
ages.append(ages_in_years)
desc.append("Type 1")
prompts.append(base_prompt + ["TYPE1DM"])
ages.append(ages_in_years + [21])
desc.append("Type 2")
prompts.append(base_prompt + ["TYPE2DIABETES"])
ages.append(ages_in_years + [21])

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad(): 
            model.eval()
    
            for p_idx, (prompt, age) in enumerate(zip(prompts, ages)):
                print(f"\n{desc[p_idx]}: \t ({','.join(prompt)}): ")
                encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
                (lgts, tte_dist), _, _ = model(encoded_prompt,
                                                       # values=torch.tensor(value).to(device),
                                                       ages=to_days(age),
                                                       is_generation=True)
                probs = torch.nn.functional.softmax(lgts, dim=2)
                print(f"\tprobability of type I diabetes: {100*float(probs[0, 0, t1_token].cpu().detach().numpy()):.4f}%")
                print(f"\tprobability of type II diabetes: {100*float(probs[0, 0, t2_token].cpu().detach().numpy()):.4f}%")

# Note: adding a diagnosis (even if potentially orthogonal) at the beginning of the prompt increases probability of either type



TPPTransformerForCausalSequenceModelling: Exponential TTE
--------------------------------------

Control: 	 (DEPRESSION): 
	probability of type I diabetes: 0.1267%
	probability of type II diabetes: 0.1983%

Type 1: 	 (DEPRESSION,TYPE1DM): 
	probability of type I diabetes: 0.0709%
	probability of type II diabetes: 4.5693%

Type 2: 	 (DEPRESSION,TYPE2DIABETES): 
	probability of type I diabetes: 4.9913%
	probability of type II diabetes: 0.0227%


TPPTransformerForCausalSequenceModelling: Geometric TTE
--------------------------------------

Control: 	 (DEPRESSION): 
	probability of type I diabetes: 0.1530%
	probability of type II diabetes: 0.2534%

Type 1: 	 (DEPRESSION,TYPE1DM): 
	probability of type I diabetes: 0.1382%
	probability of type II diabetes: 1.9343%

Type 2: 	 (DEPRESSION,TYPE2DIABETES): 
	probability of type I diabetes: 2.6466%
	probability of type II diabetes: 0.0310%


## Age: How increasing prompt age affects likelihood of age related diagnoses

In [11]:
prompt = ["DEPRESSION"]
ages = [[4],[8],[20],[30],[60],[80],[90]]

# target_conditions=["TYPE1DM"]#, "TYPE2DIABETES", "OSTEOARTHRITIS", "ANY_DEAFNESS_HEARING_LOSS"]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")

    # for condition in target_conditions:
    #     print(f"Probability of {condition}")
    #     target_token = dm.tokenizer._stoi[condition]

    for p_idx, age in enumerate(ages):
        print(f"\nAge {age[-1]}\n======")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
        (lgts, tte_dist), _, _ = model(encoded_prompt,
                                       ages=to_days(age),
                                       is_generation=True)
        probs = torch.nn.functional.softmax(lgts, dim=2) * 100

        # top K
        k = 10
        print(f"Top {k}")
        topk_prob, topk_ind = torch.topk(probs[0,0,:], k)
        for i, j in zip(dm.decode(topk_ind.tolist()).split(" "), topk_prob):
            print(f"\t{i}: {j:.2f}%")

        # bottom K
        k = 30
        print(f"Bottom {k}")
        topk_prob, topk_ind = torch.topk(-probs[0,0,:], k)
        for i, j in zip(dm.decode(topk_ind.tolist()).split(" "), topk_prob):
            print(f"\t{i}: {-j:.2f}%")
        
            # print(f"Age: {age[-1]} years old:  {100*float(probs[0, 0, target_token].cpu().detach().numpy()):.4f}%")

# Note: adding a diagnosis (even if potentially orthogonal) at the beginning of the prompt increases probability of either type



TPPTransformerForCausalSequenceModelling: Exponential TTE
--------------------------------------

Age 4
Top 10
	ANXIETY: 26.16%
	ASTHMA_PUSHASTHMA: 15.51%
	ATOPICECZEMA: 12.51%
	ALLERGICRHINITISCONJ: 10.22%
	ANY_DEAFNESS_HEARING_LOSS: 8.91%
	EATINGDISORDERS: 3.22%
	PSORIASIS: 2.87%
	AUTISM: 2.73%
	PREVALENT_IBS: 2.61%
	SUBSTANCEMISUSE: 2.20%
Bottom 30
	UNK: 0.00%
	PAD: 0.00%
	SICKLE_CELL_DISEASE: 0.00%
	AORTICANEURYSM: 0.00%
	CHRONIC_LIVER_DISEASE_ALCOHOL: 0.00%
	CYSTICFIBROSIS: 0.00%
	PSORIATICARTHRITIS2021: 0.00%
	PARKINSONS: 0.00%
	HAEMOCHROMATOSIS: 0.00%
	PLASMACELL_NEOPLASM: 0.00%
	SJOGRENSSYNDROME: 0.00%
	PERNICIOUSANAEMIA: 0.00%
	MENIERESDISEASE: 0.01%
	PMRANDGCA: 0.01%
	ALL_DEMENTIA: 0.01%
	SYSTEMIC_SCLEROSIS: 0.01%
	DOWNSSYNDROME: 0.01%
	ILD_SH: 0.01%
	ISCHAEMICSTROKE: 0.01%
	PAD_STRICT: 0.01%
	MINFARCTION: 0.01%
	ADDISON_DISEASE: 0.01%
	SYSTEMIC_LUPUS_ERYTHEMATOSUS: 0.01%
	CKDSTAGE3TO5: 0.02%
	MS: 0.02%
	AF: 0.02%
	STROKE_HAEMRGIC: 0.02%
	NAFLD: 0.02%
	HF: 0.03%
	BRONCHIECT

# Appendix: model architectures

In [12]:
for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n" + "="*len(m_names[model_idx]))
    print(f"\n\n{model}")



TPPTransformerForCausalSequenceModelling: Exponential TTE


TTETransformerForCausalSequenceModelling(
  (transformer): TTETransformer(
    (wpe): TemporalPositionalEncoding()
    (wte): DataEmbeddingLayer(
      (token_embed_layer): Embedding(74, 384, padding_idx=0)
      (value_embed_layer): EmbeddingBag(74, 384, mode=sum, padding_idx=0)
    )
    (drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadedSelfAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=384, out_features=384, bias=False)
          (v_proj): Linear(in_features=384, out_features=384, bias=False)
          (q_proj): Linear(in_features=384, out_features=384, bias=False)
          (out_proj): Linear(in_features=384, out_features=384, bias=True)
        )
        (ln_2): LayerNorm((384,),

In [13]:
!jupyter nbconvert --to html --no-input TTE.ipynb

[NbConvertApp] Converting notebook TTE.ipynb to html
[NbConvertApp] Writing 589446 bytes to TTE.html
