# Demo Notebook:
## Time to Event Transformer For Causal Sequence Modelling 

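This demo models event tokens together with their timing (the patient's age at each event); the measurement values attached to events are not used.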

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

# Perform sqlite operations on disk
%env SQLITE_TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
%env TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
!echo $SQLITE_TMPDIR
!echo $TMPDIR
!echo $USERPROFILE

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-icelake/lib/python3.10/site-packages' at start of search paths.
env: SQLITE_TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
env: TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
/rds/projects/g/gokhalkm-optimal/DataforCharles
/rds/projects/g/gokhalkm-optimal/DataforCharles



In [2]:
import pytorch_lightning 
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.TTE.task_heads.causal import TTETransformerForCausalSequenceModelling
from tqdm import tqdm
import time
import os
# TODO:
# replace experiment boilerplate with pytorch lightning

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# device = "cpu"    # if more informative debugging statements are needed
!pwd

cuda
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/tteGPT


## Build configurations

In [3]:
# Set config to be equivalent architecture of Karpathy benchmark, however they are not comparable tasks.
@dataclass
class DemoConfig:
    """Model settings for this demo.

    The instance is passed to the model constructor; ``block_size`` and
    ``unk_freq_threshold`` are also forwarded to the data module.
    """
    block_size: int = 256        # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0
    # Annotated so that it is a real dataclass field: it can be overridden via
    # DemoConfig(TTELayer=...) and shows up in repr()/asdict().
    TTELayer: str = "Exponential"       # alternatively "Geometric"

config = DemoConfig()

@dataclass
class OptConfig:
    """Optimisation settings read by the data module and the training loop."""
    batch_size: int = 64
    eval_interval: int = 1       # run validation every `eval_interval` epochs
    learning_rate: float = 3e-4
    epochs: int = 3

opt = OptConfig()

## Create data loader on a reduced cohort

In [4]:
# Get a list of patients which fit a reduced set of criteria
# path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/archive/Version2/"
path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/"

# Build the data module. Batch size, sequence length and the unknown-token
# threshold all come from the config objects defined earlier.
# NOTE(review): the inclusion condition presumably restricts the cohort by country code - confirm.
dm_kwargs = dict(
    path_to_db=path_to_db,
    load=True,
    tokenizer="tabular",
    batch_size=opt.batch_size,
    max_seq_length=config.block_size,
    unk_freq_threshold=config.unk_freq_threshold,
    min_workers=20,
    inclusion_conditions=["COUNTRY = 'E'"],
)
dm = FoundationalDataModule(**dm_kwargs)

# The tokenizer attached to the training split defines the vocabulary size the model needs.
vocab_size = dm.train_set.tokenizer.vocab_size
print(f"{vocab_size} vocab elements")

INFO:root:Loading Polars dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/polars/
INFO:root:Using tokenizer tabular
INFO:root:Tokenzier created based on 3584.43M tokens
INFO:root:Creating split=train/ dataset
INFO:root:	 Loading split=train/ hash map for parquet
INFO:root:	 Hash map created for split=train/ with 22,912,046 samples
INFO:root:Creating split=test/ dataset
INFO:root:	 Loading split=test/ hash map for parquet
INFO:root:	 Hash map created for split=test/ with 1,207,449 samples
INFO:root:Creating split=val/ dataset
INFO:root:	 Loading split=val/ hash map for parquet
INFO:root:	 Hash map created for split=val/ with 1,226,576 samples


184 vocab elements


## View a single patient

In [5]:
# Display training sample index 1 (presumably capped at 12 dynamic events) and report the retrieval time.
# NOTE(review): the rendered output is an individual-level record - clear it before sharing the notebook.
dm.train_set.view_sample(1, max_dynamic_events=12, report_time=True)

Time to retrieve sample index 1 was 0.10879707336425781 seconds

SEX                 | F
IMD                 | 4.0
ETHNICITY           | WHITE
birth_year          | 1997.0

Token                                                                      | Age               | Standardised value
O_E___weight_2                                                             | 5523              | 0.11              
Systolic_blood_pressure_4                                                  | 5523              | -0.11             
Body_mass_index_3                                                          | 5564              | 0.24              
O_E___height_1                                                             | 5564              | -0.05             
O_E___weight_2                                                             | 5564              | 0.17              
Body_mass_index_3                                                          | 5687              | 0.26              
O_E___height_1 

## Create models and train

In [6]:
# Build the model from the demo config and vocabulary size, then move it to the selected device.
model = TTETransformerForCausalSequenceModelling(config, vocab_size)
model = model.to(device)

# Loss histories filled in by the training cell: total, classifier and
# time-to-event components, for the training and validation splits.
loss_curves_train, loss_curves_train_clf, loss_curves_train_tte = [], [], []
loss_curves_val, loss_curves_val_clf, loss_curves_val_tte = [], [], []



INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using ExponentialTTELayer. This module predicts the time until next event as an exponential distribution


In [None]:
# Train the model for `opt.epochs` epochs, validating every `opt.eval_interval`
# epochs (and on the final epoch), with early stopping on the validation loss.
# The best model seen so far is saved under `path_to_db`.
print(f"Training model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
model = model.to(device)

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

# Demo-sized caps on the number of batches visited per epoch, and the number of
# evaluations without improvement tolerated before stopping early.
max_train_batches = 1001
max_val_batches = 101
patience = 5

best_val, epochs_since_best = np.inf, 0
for epoch in range(opt.epochs):

    epoch_loss, epoch_clf_loss, epoch_tte_loss = 0, 0, 0
    n_train_batches = 0
    model.train()
    # Build the loader once per epoch and reuse it for both the length and the iteration.
    train_loader = dm.train_dataloader()
    for i, batch in tqdm(enumerate(train_loader), desc=f"Training epoch {epoch}", total=len(train_loader)):
        if i >= max_train_batches:
            break

        # evaluate the loss
        _, (losses_clf, loss_tte), loss = model(tokens=batch['tokens'].to(device),
                                                ages=batch['ages'].to(device),
                                                attention_mask=batch['attention_mask'].to(device)
                                               )
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # record
        epoch_loss += loss.item()
        epoch_clf_loss += torch.sum(losses_clf).item()
        epoch_tte_loss += loss_tte.item()
        n_train_batches += 1

    # Average over the batches actually processed. The last loop index is only
    # the right divisor when the cap triggers, and is zero for a single batch.
    n_train_batches = max(n_train_batches, 1)
    epoch_loss /= n_train_batches
    epoch_clf_loss /= n_train_batches
    epoch_tte_loss /= n_train_batches
    loss_curves_train.append(epoch_loss)
    loss_curves_train_clf.append(epoch_clf_loss)
    loss_curves_train_tte.append(epoch_tte_loss)

    # evaluate the loss on val set
    if epoch % opt.eval_interval == 0 or epoch == opt.epochs - 1:
        model.eval()
        val_loss, val_clf_loss, val_tte_loss = 0, 0, 0
        n_val_batches = 0
        val_loader = dm.val_dataloader()
        with torch.no_grad():
            for j, batch in tqdm(enumerate(val_loader), desc=f"Validation epoch {epoch}", total=len(val_loader)):
                if j >= max_val_batches:
                    break
                _, (losses_clf, loss_tte), loss = model(tokens=batch['tokens'].to(device),
                                                        ages=batch['ages'].to(device),
                                                        attention_mask=batch['attention_mask'].to(device)
                                                       )
                # record
                val_loss += loss.item()
                val_clf_loss += torch.sum(losses_clf).item()
                val_tte_loss += loss_tte.item()
                n_val_batches += 1

        n_val_batches = max(n_val_batches, 1)
        val_loss /= n_val_batches
        val_clf_loss /= n_val_batches
        val_tte_loss /= n_val_batches
        loss_curves_val.append(val_loss)
        loss_curves_val_clf.append(val_clf_loss)
        loss_curves_val_tte.append(val_tte_loss)

        print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}: ({epoch_clf_loss:.2f}, {epoch_tte_loss:.2f}). Val loss {val_loss:.2f}: ({val_clf_loss:.2f}, {val_tte_loss:.2f})")
        # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning

        # Early stopping is checked only when a fresh validation loss exists, so
        # epochs skipped by `eval_interval` do not count as "no improvement".
        if val_loss >= best_val:
            epochs_since_best += 1
            if epochs_since_best >= patience:
                break
        else:
            best_val = val_loss
            epochs_since_best = 0

            # Save best seen model
            torch.save(model.state_dict(), path_to_db + "polars/SR.pt")
            


Training model with 11.155393 M parameters


Training epoch 0:   0%|          | 1001/358001 [02:58<17:41:52,  5.60it/s]
Validation epoch 0:   1%|          | 101/19166 [00:19<1:01:42,  5.15it/s]


Epoch 0:	Train loss -0.61: (1.22, -2.44). Val loss -0.96: (1.05, -2.98)


Training epoch 1:   0%|          | 540/358001 [01:36<13:12:32,  7.52it/s]

## Generation

In [None]:
# Default context start
prompt = ["O_E___height_1", "O_E___weight_2"]
ages_in_years = [18.2, 18.2]


# define encoding functions (TODO: add this wrap to datamodule)
def encode_prompt(prompt_list):
    """Encode a list of token names with ``dm.encode`` into a single-row tensor on ``device``."""
    encoded = np.array(dm.encode(prompt_list)).reshape((1,-1))
    return torch.from_numpy(encoded).to(device)


def encode_age(age_list):
    """Convert ages in years to days (365 per year) as a single-row int64 tensor on ``device``."""
    age_days = [365 * _age for _age in age_list]
    return torch.tensor(age_days, dtype=torch.int64).reshape((1,-1)).to(device)


# Convert for model
tokens = encode_prompt(prompt)
ages_in_days = encode_age(ages_in_years)

In [None]:
# generate: sample the next 10 tokens
new_tokens, new_ages = model.generate(tokens, ages_in_days, max_new_tokens=10)

# report: one line per event, with a separator printed after the last prompt event
decoded_events = dm.decode(new_tokens[0].tolist()).split(" ")
last_prompt_idx = tokens.shape[-1] - 1
print("PROMPT:")
for _idx, (_cat, _age) in enumerate(zip(decoded_events, new_ages[0, :])):
    # _value = dm.unstandardise(_cat, _value)
    print(str(_cat).ljust(50) + f"at age {_age/365:.0f} ({_age:.1f} days)")    # with value {_value}
    if _idx == last_prompt_idx:
        print("="*75)
        print("GENERATION")

## Comparing generation to real data

In [None]:
# Display the same training sample (index 1) for comparison with the generated sequence,
# presumably capped at 10 dynamic events to line up with the 10 generated tokens.
# NOTE(review): the rendered output is an individual-level record - clear it before sharing the notebook.
dm.train_set.view_sample(1, max_dynamic_events=10, report_time=True)

In [None]:

def plot_loss_curves(train_curve, val_curve, save_path, title):
    """Plot one training/validation loss pair against epoch and save the figure.

    Training losses are recorded once per epoch, so they sit at epochs
    0, 1, 2, ... Validation losses are recorded every ``opt.eval_interval``
    epochs plus the final epoch, so they sit at multiples of the interval,
    clamped to the last trained epoch.

    Args:
        train_curve: per-epoch training losses.
        val_curve: validation losses, one per evaluation.
        save_path: file the figure is written to.
        title: figure title.
    """
    train_epochs = np.arange(len(train_curve))
    last_epoch = max(len(train_curve) - 1, 0)
    val_epochs = np.minimum(np.arange(len(val_curve)) * opt.eval_interval, last_epoch)

    plt.figure()
    plt.plot(train_epochs, train_curve, label="train")
    plt.plot(val_epochs, val_curve, label="val", linestyle='dashed')
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title(title)
    plt.legend()
    plt.savefig(save_path)


# savefig does not create missing directories
os.makedirs("figs/TTE", exist_ok=True)

# Plot loss
plot_loss_curves(loss_curves_train, loss_curves_val, "figs/TTE/loss.png", "Total loss")

# Plot Classifier loss
plot_loss_curves(loss_curves_train_clf, loss_curves_val_clf, "figs/TTE/loss_clf.png", "Classifier loss")

# Plot TTE loss
plot_loss_curves(loss_curves_train_tte, loss_curves_val_tte, "figs/TTE/loss_tte.png", "TTE loss")

# Prompt testing

## Diabetes: How related conditions are impacted by each other
Predicted probability of the type I and type II diabetes tokens for a control prompt, and for prompts that additionally contain a type I or a type II diagnosis.

In [None]:
# Vocabulary indices of the two diagnoses whose predicted probabilities are reported.
t1_token = dm.tokenizer._stoi["TYPE1DM"]
t2_token = dm.tokenizer._stoi["TYPE2DIABETES"]


base_prompt = ["DEPRESSION"]
ages_in_years = [20]


def to_days(a_list):
    """Convert ages in years to days (365 per year) as a single-row float tensor on ``device``."""
    return torch.FloatTensor([365 * _a for _a in a_list]).reshape((1,-1)).to(device)


# Create a set of prompts: the base prompt alone, then followed by each diagnosis at age 21
values = []
desc = ["Control", "Type 1", "Type 2"]
prompts = [base_prompt,
           base_prompt + ["TYPE1DM"],
           base_prompt + ["TYPE2DIABETES"]]
ages = [ages_in_years,
        ages_in_years + [21],
        ages_in_years + [21]]

model.eval()
with torch.no_grad():

    for p_idx, (prompt, age) in enumerate(zip(prompts, ages)):
        print(f"\n{desc[p_idx]}: \t ({','.join(prompt)}): ")
        encoded_tokens = np.array(dm.encode(prompt)).reshape((1,-1))
        encoded_prompt = torch.from_numpy(encoded_tokens).to(device)
        (lgts, tte_dist), _, _ = model(encoded_prompt,
                                       # values=torch.tensor(value).to(device),
                                       ages=to_days(age),
                                       is_generation=True)
        # Softmax over the last axis of the logits turns them into probabilities per vocabulary entry.
        probs = torch.nn.functional.softmax(lgts, dim=2)
        t1_prob = probs[0, 0, t1_token].item()
        t2_prob = probs[0, 0, t2_token].item()
        print(f"\tprobability of type I diabetes: {100*t1_prob:.4f}%")
        print(f"\tprobability of type II diabetes: {100*t2_prob:.4f}%")

# Note: adding a diagnosis (even if potentially orthogonal) at the beginning of the prompt increases probability of either type

## Age: How increasing prompt age affects likelihood of age related diagnoses

In [None]:
# For a fixed single-token prompt, vary the age attached to it and list the
# most and least probable tokens predicted by the model at each age.
prompt = ["ALLERGICRHINITISCONJ"]
ages = [[4],[8],[20],[30],[60],[80],[90]]

# target_conditions=["TYPE1DM"]#, "TYPE2DIABETES", "OSTEOARTHRITIS", "ANY_DEAFNESS_HEARING_LOSS"]


# for condition in target_conditions:
#     print(f"Probability of {condition}")
#     target_token = dm.tokenizer._stoi[condition]

# How many of the most / least probable tokens to list per age.
top_k, bottom_k = 10, 30

# The prompt does not change with age, so encode it once outside the loop.
encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)

# Set inference mode explicitly rather than relying on an earlier cell having done so.
model.eval()
with torch.no_grad():
    for p_idx, age in enumerate(ages):
        print(f"\nAge {age[-1]}\n======")
        (lgts, tte_dist), _, _ = model(encoded_prompt,
                                       ages=to_days(age),
                                       is_generation=True)
        probs = torch.nn.functional.softmax(lgts, dim=2) * 100
        token_probs = probs[0, 0, :]

        # top K (never ask for more entries than the vocabulary holds)
        n_top = min(top_k, token_probs.numel())
        print(f"Top {n_top}")
        top_prob, top_ind = torch.topk(token_probs, n_top)
        for token_name, token_prob in zip(dm.decode(top_ind.tolist()).split(" "), top_prob):
            print(f"\t{token_name}: {token_prob:.2f}%")

        # bottom K: the smallest probabilities, listed in ascending order
        n_bottom = min(bottom_k, token_probs.numel())
        print(f"Bottom {n_bottom}")
        bottom_prob, bottom_ind = torch.topk(token_probs, n_bottom, largest=False)
        for token_name, token_prob in zip(dm.decode(bottom_ind.tolist()).split(" "), bottom_prob):
            print(f"\t{token_name}: {token_prob:.2f}%")

        # print(f"Age: {age[-1]} years old:  {100*float(probs[0, 0, target_token].cpu().detach().numpy()):.4f}%")

# Note: adding a diagnosis (even if potentially orthogonal) at the beginning of the prompt increases probability of either type

# Appendix: model architectures

In [None]:
# Render the model object in the notebook output so its architecture can be inspected.
display(model)

In [None]:
# Export TTE.ipynb to HTML with the code inputs hidden (--no-input), keeping only markdown and outputs.
!jupyter nbconvert --to html --no-input TTE.ipynb