# Demo Notebook
## Time-to-Event Transformer for Causal Sequence Modelling

This demo conditions on event times (patient ages) and excludes measurement values.

In [1]:
import pytorch_lightning 
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.TTE.task_heads.causal import TTETransformerForCausalSequenceModelling
from tqdm import tqdm
import os

# TODO:
# replace experiment boilerplate with pytorch lightning

# Seed every RNG the notebook draws from (torch, numpy, and the stdlib `random`
# module) so that Restart & Run All reproduces the same results.
SEED = 1337
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# device = "cpu"    # if more informative debugging statements are needed
print(os.getcwd())  # pure-Python equivalent of `!pwd`

  return torch._C._cuda_getDeviceCount() > 0


cpu
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/tteGPT


## Build configurations

In [2]:
# Set config to be equivalent architecture of kaparthy benchmark, however they are not comparable tasks.
@dataclass
class DemoConfig:
    block_size: int = 256        # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"    
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0
    TTELayer = "Exponential"

config = DemoConfig()

@dataclass
class OptConfig:
    batch_size: int = 128
    eval_interval: int = 1
    learning_rate: float = 3e-4
    epochs: int = 50
    
opt = OptConfig()

In [3]:


# Report how many logical CPUs are available, to inform the dataloader worker count.
print(f"{os.cpu_count()}")

72


## Create data loader on a reduced cohort

In [4]:
# Get a list of patients which fit a reduced set of criterion
path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/archive/Version2/"

# Build 
dm = FoundationalDataModule(path_to_db=path_to_db,
                            load=True,
                            tokenizer="tabular",
                            batch_size=opt.batch_size,
                            max_seq_length=config.block_size,
                            unk_freq_threshold=config.unk_freq_threshold,
                            min_workers=10, # os.cpu_count()
                           )

vocab_size = dm.train_set.tokenizer.vocab_size

print(f"{len(dm.train_set)} training patients")
print(f"{len(dm.val_set)} validation patients")
print(f"{len(dm.test_set)} test patients")
print(f"{vocab_size} vocab elements")


INFO:root:Loading Polars dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/archive/Version2/polars/
INFO:root:{'diagnosis_table':                    event  count
0        ADDISON_DISEASE    250
1                     AF  17696
2          ALCOHOLMISUSE  31555
3     ALLCA_NOBCC_VFINAL  35375
4   ALLERGICRHINITISCONJ  71148
..                   ...    ...
69               TYPE1DM   3184
70         TYPE2DIABETES  30317
71    ULCERATIVE_COLITIS   2697
72      VALVULARDISEASES   9296
73     VISUAL_IMPAIRMENT   3512

[74 rows x 2 columns], 'measurement_table':                                                  event    count  count_obs  \
0                        25-Hydroxyvitamin_D2_level_92    10102       9212   
1                        25-Hydroxyvitamin_D3_level_90    10010       9927   
2                    AST_-_aspartate_transam._SGOT__46    41243      38870   
3                                   AST_serum_level_47   165378     157037   
4                        Albumin

8466869 training patients
405239 validation patients
373904 test patients
184 vocab elements


In [5]:
import time

# start = time.time()   # starting time
# for row_idx, row in enumerate(dm.train_set):
#     print(time.time() - start)
#     start = time.time()
#     if row_idx > opt.batch_size - 1:
#         break

start = time.time()   # starting time
for row_idx, row in enumerate(dm.train_dataloader()):
    print(time.time() - start)
    time.sleep(np.abs(np.random.normal(10,0.5)))
    start = time.time()
    if row_idx > 300:
        break
# print(f"{row} loaded in {time.time()-start} seconds")

30.107887983322144
0.47054362297058105
0.06446027755737305
0.00037550926208496094
0.0003097057342529297
0.00037980079650878906
0.0002963542938232422
0.001577138900756836
0.00029468536376953125
0.0002884864807128906
0.004949808120727539
0.001485586166381836
0.003221750259399414
0.00033926963806152344
0.00030159950256347656
0.0003521442413330078
0.00030422210693359375
0.0014061927795410156
0.0003914833068847656
0.0003902912139892578
0.0014438629150390625
0.001495361328125
0.001459360122680664
0.0014119148254394531
0.0014584064483642578
0.0014448165893554688
0.0014202594757080078
0.0014004707336425781
0.0013849735260009766
0.0014469623565673828
0.0013654232025146484
0.0014650821685791016
0.0023679733276367188
0.0014128684997558594
0.0014066696166992188
0.0013804435729980469
0.0014569759368896484
0.0013403892517089844
0.0013632774353027344
0.00141143798828125
0.0013647079467773438
0.0014128684997558594
0.0014157295227050781
0.0013294219970703125
0.0013530254364013672
0.0013823509216308594



KeyboardInterrupt



In [None]:
import polars as pl

# Raise polars' printed-row limit to one more than the vocabulary size so that the
# tokenizer's event-count table is shown without truncation.
pl.Config.set_tbl_rows(n=vocab_size + 1)
print(dm.tokenizer._event_counts)

## Create models and train

In [None]:
# One model per time-to-event head, so the heads can be compared side by side.
models, m_names = [], []

for tte_layer in ["Exponential", "Geometric"]:
    # Use a separate name so the global `config` (used when building the data module)
    # is not silently replaced by the last loop iteration's config.
    model_config = DemoConfig()
    model_config.TTELayer = tte_layer
    models.append(TTETransformerForCausalSequenceModelling(model_config, vocab_size).to(device))
    # Label matches the class actually instantiated above.
    m_names.append(f"TTETransformerForCausalSequenceModelling: {tte_layer} TTE")

In [None]:
# Loss histories with one inner list per model: the total loss plus its classification
# (clf) and time-to-event (tte) components, for the training and validation splits.
loss_curves_train, loss_curves_train_clf, loss_curves_train_tte = (
    [[] for _model in models] for _term in ("total", "clf", "tte")
)

loss_curves_val, loss_curves_val_clf, loss_curves_val_tte = (
    [[] for _model in models] for _term in ("total", "clf", "tte")
)

In [None]:
MAX_TRAIN_BATCHES = 51   # batches per training epoch (this demo truncates each epoch)
MAX_VAL_BATCHES = 21     # batches per validation pass
PATIENCE = 5             # validations without improvement before early stopping


def _run_epoch(model, dataloader, max_batches, optimizer=None, desc=""):
    """Run up to ``max_batches`` batches through ``model``; return mean (total, clf, tte) losses.

    If ``optimizer`` is given, a gradient step is taken on every batch (training).
    Otherwise losses are only accumulated; the caller should put the model in eval
    mode and wrap the call in ``torch.no_grad()``.

    Means are taken over the number of batches actually processed, so they remain
    correct when the dataloader yields fewer than ``max_batches`` batches.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    from itertools import islice

    total_loss = clf_loss = tte_loss = 0.0
    n_batches = 0
    n_expected = min(max_batches, len(dataloader))
    for batch in tqdm(islice(dataloader, max_batches), desc=desc, total=n_expected):
        _, (loss_clf, loss_tte), loss = model(batch['tokens'].to(device),
                                              ages=batch['ages'].to(device),
                                              attention_mask=batch['attention_mask'].to(device)
                                              )
        if optimizer is not None:
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
        total_loss += loss.item()
        clf_loss += loss_clf.item()
        tte_loss += loss_tte.item()
        n_batches += 1
    if n_batches == 0:
        raise ValueError(f"{desc}: dataloader yielded no batches")
    # TODO: unweighted mean over batches, so a smaller final batch is over-weighted;
    # will be fixed with lightning.
    return total_loss / n_batches, clf_loss / n_batches, tte_loss / n_batches


for m_idx, (model, m_name) in enumerate(zip(models, m_names)):

    print(f"Training model `{m_name}`, with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
    model = model.to(device)

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

    best_val, evals_since_best = np.inf, 0
    for epoch in range(opt.epochs):
        model.train()
        epoch_loss, epoch_clf_loss, epoch_tte_loss = _run_epoch(
            model, dm.train_dataloader(), MAX_TRAIN_BATCHES,
            optimizer=optimizer, desc=f"Training epoch {epoch}")
        loss_curves_train[m_idx].append(epoch_loss)
        loss_curves_train_clf[m_idx].append(epoch_clf_loss)
        loss_curves_train_tte[m_idx].append(epoch_tte_loss)

        # Validate every `eval_interval` epochs and on the final epoch. Early stopping is
        # decided only here, so it never reads a stale or undefined validation loss.
        if epoch % opt.eval_interval != 0 and epoch != opt.epochs - 1:
            continue
        model.eval()
        with torch.no_grad():
            val_loss, val_clf_loss, val_tte_loss = _run_epoch(
                model, dm.val_dataloader(), MAX_VAL_BATCHES, desc=f"Validation epoch {epoch}")
        loss_curves_val[m_idx].append(val_loss)
        loss_curves_val_clf[m_idx].append(val_clf_loss)
        loss_curves_val_tte[m_idx].append(val_tte_loss)
        print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}  ({epoch_clf_loss:.2f}, {epoch_tte_loss:.2f}). Val loss {val_loss:.2f} ({val_clf_loss:.2f}, {val_tte_loss:.2f})")

        if val_loss >= best_val:
            evals_since_best += 1
            if evals_since_best >= PATIENCE:
                break
        else:
            best_val = val_loss
            evals_since_best = 0

    # Test trained model with a prompt
    # ----------------
    # Context: a single DEPRESSION diagnosis at 20 years old (ages are in days).
    model.eval()
    with torch.no_grad():
        tokens = torch.from_numpy(np.array(dm.encode(["DEPRESSION"])).reshape((1, -1))).to(device)
        # NOTE(review): integer-dtype ages here, whereas the prompt cells below pass float
        # days (torch.FloatTensor) — confirm generate() treats both the same.
        ages = torch.tensor([[20 * 365]], device=device)
        # Sample the next 10 tokens.
        new_tokens, new_ages = model.generate(tokens, ages, max_new_tokens=10)
    generated = dm.decode(new_tokens[0].tolist())
    # Values are not modelled yet, so only each event and its age are reported.
    for _cat, _age in zip(generated.split(" "), new_ages[0, :]):
        print(f"\t {_cat} at age {_age/365:.0f} ({_age:.1f} days)")


## Comparing generated output to real data

In [None]:
# Print the first few events of the first patient in a training batch, for comparison
# with the trajectories generated after training.
N_EVENTS_TO_SHOW = 10

# next(iter(...)) raises StopIteration on an empty loader instead of silently reusing a
# stale `batch` left over from the training cell.
batch = next(iter(dm.train_dataloader()))
conditions = batch["tokens"].numpy().tolist()
for idx, (token, age) in enumerate(zip(conditions[0], batch["ages"][0, :])):
    # presumably token id 0 is padding (end of the real sequence) — confirm with the tokenizer
    if token == 0 or idx >= N_EVENTS_TO_SHOW:
        break
    print(f"{dm.decode([token])}, at age {age/365:.0f} ({age:.1f} days)")

In [None]:
cols = ["k", "r", "b", "y"]
os.makedirs("figs/TTE", exist_ok=True)   # savefig fails if the directory does not exist


def plot_loss_curves(train_curves, val_curves, title, filename):
    """Plot per-model train (dashed) and validation (solid) losses on a log y-axis and save.

    Args:
        train_curves: one list per model, with one training loss per epoch.
        val_curves: one list per model, with one validation loss per evaluation
            (every ``opt.eval_interval`` epochs, plus the final epoch).
        title: figure title.
        filename: path the figure is saved to.

    Returns:
        The matplotlib Figure.
    """
    fig, ax = plt.subplots()
    for m_idx, (train, val) in enumerate(zip(train_curves, val_curves)):
        colour = cols[m_idx % len(cols)]   # cycle colours if there are more models than colours
        # Training loss is recorded every epoch: x = 0, 1, 2, ...
        ax.plot(np.arange(len(train)), train, label=f"{m_names[m_idx]}-train", c=colour, linestyle='dashed')
        # Validation runs every `eval_interval` epochs; the forced final-epoch evaluation
        # is clipped onto the last trained epoch.
        val_epochs = np.minimum(np.arange(len(val)) * opt.eval_interval, max(len(train) - 1, 0))
        ax.plot(val_epochs, val, label=f"{m_names[m_idx]}-val", c=colour)
    ax.set(yscale="log", xlabel="Epoch", ylabel="Loss", title=title)
    ax.legend()
    fig.savefig(filename)
    return fig


plot_loss_curves(loss_curves_train, loss_curves_val, "Total loss", "figs/TTE/logloss.png")
plot_loss_curves(loss_curves_train_clf, loss_curves_val_clf, "Classification loss", "figs/TTE/logloss_clf.png")
plot_loss_curves(loss_curves_train_tte, loss_curves_val_tte, "Time-to-event loss", "figs/TTE/logloss_tte.png");

# Prompt testing

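## Diabetes: how related conditions affect each other
Next-event probabilities of type I and type II diabetes for a control history (depression at age 20), and for the same history with a type I or a type II diabetes diagnosis added at age 21.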

In [None]:
t1_token = dm.tokenizer._stoi["TYPE1DM"]
t2_token = dm.tokenizer._stoi["TYPE2DIABETES"]


base_prompt = ["DEPRESSION"]
ages_in_years = [20]


def to_days(a_list):
    """Convert a list of ages in years into a (1, len) float tensor of days on `device`."""
    return torch.FloatTensor([365 * _a for _a in a_list]).reshape((1, -1)).to(device)


# Prompts: a control history, plus the same history with a type I or a type II diabetes
# diagnosis added one year later.
desc = ["Control", "Type 1", "Type 2"]
prompts = [base_prompt, base_prompt + ["TYPE1DM"], base_prompt + ["TYPE2DIABETES"]]
ages = [ages_in_years, ages_in_years + [21], ages_in_years + [21]]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    model.eval()   # set explicitly rather than relying on the mode a previous cell left
    with torch.no_grad():
        for p_idx, (prompt, age) in enumerate(zip(prompts, ages)):
            print(f"\n{desc[p_idx]}: \t ({','.join(prompt)}): ")
            encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1, -1))).to(device)
            (lgts, _tte_dist), _, _ = model(encoded_prompt,
                                            ages=to_days(age),
                                            is_generation=True)
            # Next-event distribution after the last prompt event. Index -1 equals 0 when
            # the model returns only the final position, and stays correct if it returns all.
            probs = torch.nn.functional.softmax(lgts, dim=2)[0, -1]
            print(f"\tprobability of type I diabetes: {100 * probs[t1_token].item():.4f}%")
            print(f"\tprobability of type II diabetes: {100 * probs[t2_token].item():.4f}%")

# Note: adding a diagnosis (even if potentially orthogonal) at the beginning of the prompt increases probability of either type

## Age: how increasing the prompt age affects the likelihood of age-related diagnoses

In [None]:
# The same single-event prompt placed at increasing ages; for each age, list the most
# and least likely next events.
prompt = ["ALLERGICRHINITISCONJ"]
ages = [[4], [8], [20], [30], [60], [80], [90]]
TOP_K, BOTTOM_K = 10, 30

# The prompt tokens do not depend on the age, so encode them once.
encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1, -1))).to(device)

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    model.eval()   # set explicitly rather than relying on the mode a previous cell left
    with torch.no_grad():
        for age in ages:
            print(f"\nAge {age[-1]}\n======")
            (lgts, _tte_dist), _, _ = model(encoded_prompt,
                                            ages=to_days(age),
                                            is_generation=True)
            # Next-event probabilities (in %) after the last prompt event.
            probs = torch.nn.functional.softmax(lgts, dim=2)[0, -1] * 100

            print(f"Top {TOP_K}")
            top_prob, top_ind = torch.topk(probs, TOP_K)
            for name, pct in zip(dm.decode(top_ind.tolist()).split(" "), top_prob):
                print(f"\t{name}: {pct:.2f}%")

            print(f"Bottom {BOTTOM_K}")
            bottom_prob, bottom_ind = torch.topk(probs, BOTTOM_K, largest=False)
            for name, pct in zip(dm.decode(bottom_ind.tolist()).split(" "), bottom_prob):
                print(f"\t{name}: {pct:.2f}%")

# Appendix: model architectures

In [None]:
# Print each model's name, underlined with '=', followed by its module tree.
for idx, net in enumerate(models):
    title = m_names[idx]
    underline = "=" * len(title)
    print("\n\n" + title + "\n" + underline)
    print("\n\n" + str(net))

In [None]:
# Export this notebook to HTML with the code cells hidden (--no-input), e.g. for sharing results.
!jupyter nbconvert --to html --no-input TTE.ipynb