# Demo Notebook:
## Time to Event Transformer For Causal Sequence Modelling 

Including time, and excluding values

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

# Perform sqlite operations on disk
%env SQLITE_TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
%env TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
!echo $SQLITE_TMPDIR
!echo $TMPDIR
!echo $USERPROFILE

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-icelake/lib/python3.10/site-packages' at start of search paths.
env: SQLITE_TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
env: TMPDIR=/rds/projects/g/gokhalkm-optimal/DataforCharles
/rds/projects/g/gokhalkm-optimal/DataforCharles
/rds/projects/g/gokhalkm-optimal/DataforCharles



In [2]:
import pytorch_lightning 
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.TTE.task_heads.causal import TTETransformerForCausalSequenceModelling
from tqdm import tqdm
import time
import os
# TODO:
# replace experiment boilerplate with pytorch lightning

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# device = "cpu"    # if more informative debugging statements are needed
!pwd

cuda
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/tteGPT


## Build configurations

In [3]:
# Set config to be equivalent architecture of kaparthy benchmark, however they are not comparable tasks.
@dataclass
class DemoConfig:
    block_size: int = 256        # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"    
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0
    TTELayer = "Exponential"

config = DemoConfig()

@dataclass
class OptConfig:
    batch_size: int = 64
    eval_interval: int = 1
    learning_rate: float = 3e-4
    epochs: int = 10
    
opt = OptConfig()

In [4]:
print(os.cpu_count())

72


## Create data loader on a reduced cohort

In [6]:
# Get a list of patients which fit a reduced set of criterion
path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/archive/Version2/"

# Build 
dm = FoundationalDataModule(path_to_db=path_to_db,
                            load=False,
                            tokenizer="tabular",
                            batch_size=opt.batch_size,
                            max_seq_length=config.block_size,
                            unk_freq_threshold=config.unk_freq_threshold,
                            min_workers=10,
                           )

vocab_size = dm.train_set.tokenizer.vocab_size

print(f"{len(dm.train_set)} training patients")
print(f"{len(dm.val_set)} validation patients")
print(f"{len(dm.test_set)} test patients")
print(f"{vocab_size} vocab elements")

INFO:root:Building Polars dataset and saving to /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/archive/Version2/polars/
INFO:root:Chunking by unique practice ID with no inclusion conditions
INFO:root:Creating train/test/val splits using practice_ids
INFO:root:Extracting practice_patient_ids for each practice
                                             Train: 100%|██████████| 27/27 [00:00<00:00, 62.78it/s]
                                              Test: 100%|██████████| 2/2 [00:00<00:00, 98.20it/s]
                                        Validation: 100%|██████████| 2/2 [00:00<00:00, 86.67it/s]
INFO:root:Collecting meta information from database. This will be used for tokenization and standardisation.
                                      Measurements: 100%|██████████| 108/108 [00:11<00:00,  9.47it/s]
INFO:root:Collating train split into a DL friendly format. Generating over practices IDs
100%|██████████| 27/27 [04:20<00:00,  9.64s/it]
INFO:root:Collating test split into a

458330 training patients
25347 validation patients
22562 test patients
184 vocab elements





In [12]:
import time

# start = time.time()   # starting time
# for row_idx, row in enumerate(dm.train_set):
#     print(time.time() - start)
#     start = time.time()
#     if row_idx > opt.batch_size - 1:
#         break

start = time.time()   # starting time
for batch_idx, batch in enumerate(dm.train_dataloader()):
    print(time.time() - start)
    # time.sleep(np.abs(np.random.normal(10,0.5)))
    start = time.time()
    if batch_idx > 3:
        break
# print(f"{row} loaded in {time.time()-start} seconds")

11.740551948547363
1.9602737426757812
6.127357482910156e-05
3.0279159545898438e-05
2.8133392333984375e-05


In [8]:
display(batch)

{'tokens': tensor([[176, 183, 160,  ...,   0,   0,   0],
         [106,  44,  45,  ...,   0,   0,   0],
         [104, 111, 176,  ...,   0,   0,   0],
         ...,
         [176, 181, 182,  ...,   0,   0,   0],
         [176, 183, 181,  ..., 178, 170, 159],
         [176, 183, 160,  ...,   0,   0,   0]]),
 'ages': tensor([[14226, 14226, 14226,  ...,     0,     0,     0],
         [ 9264, 12419, 12419,  ...,     0,     0,     0],
         [ 4673,  6347,  9042,  ...,     0,     0,     0],
         ...,
         [10995, 10995, 10995,  ...,     0,     0,     0],
         [16725, 16725, 16725,  ..., 18352, 18352, 18352],
         [10325, 10325, 10325,  ...,     0,     0,     0]]),
 'values': tensor([[7.5033e-08, 2.2157e-04, 1.0164e-03,  ...,        nan,        nan,
                 nan],
         [       nan,        nan,        nan,  ...,        nan,        nan,
                 nan],
         [       nan,        nan, 6.3632e-08,  ...,        nan,        nan,
                 nan],
       

In [9]:
import polars as pl
pl.Config.set_tbl_rows(vocab_size + 1)
print(dm.tokenizer._event_counts)

shape: (183, 3)
┌───────────────────────────────────┬─────────┬───────────┐
│ EVENT                             ┆ COUNT   ┆ FREQUENCY │
│ ---                               ┆ ---     ┆ ---       │
│ str                               ┆ u32     ┆ f64       │
╞═══════════════════════════════════╪═════════╪═══════════╡
│ UNK                               ┆ 0       ┆ 0.0       │
│ Plasma_N_terminal_pro_B_type_nat… ┆ 39      ┆ 4.8626e-7 │
│ CYSTICFIBROSIS                    ┆ 135     ┆ 0.000002  │
│ SICKLE_CELL_DISEASE_V2            ┆ 136     ┆ 0.000002  │
│ SYSTEMIC_SCLEROSIS                ┆ 211     ┆ 0.000003  │
│ ADDISON_DISEASE                   ┆ 250     ┆ 0.000003  │
│ DOWNSSYNDROME                     ┆ 383     ┆ 0.000005  │
│ PLASMACELL_NEOPLASM               ┆ 426     ┆ 0.000005  │
│ HAEMOCHROMATOSIS_V2               ┆ 536     ┆ 0.000007  │
│ SJOGRENSSYNDROME                  ┆ 557     ┆ 0.000007  │
│ SYSTEMIC_LUPUS_ERYTHEMATOSUS      ┆ 604     ┆ 0.000008  │
│ N_terminal_pro_brain_n

## Create models and train

In [13]:
models, m_names = [], []

# My development model
for tte_layer in ["Exponential", "Geometric"]:
    config = DemoConfig()
    config.TTELayer = tte_layer
    models.append(TTETransformerForCausalSequenceModelling(config, vocab_size).to(device))
    m_names.append(f"TPPTransformerForCausalSequenceModelling: {tte_layer} TTE")

INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using ExponentialTTELayer. This module predicts the time until next event as an exponential distribution
INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using GeometricTTELayer. This module predicts the time until next event as a geometric distribution, supported on the set {0,1,...}


In [14]:
loss_curves_train = [[] for _ in models]
loss_curves_train_clf = [[] for _ in models]
loss_curves_train_tte = [[] for _ in models]

loss_curves_val = [[] for _ in models]
loss_curves_val_clf = [[] for _ in models]
loss_curves_val_tte = [[] for _ in models]

In [16]:
for m_idx, (model, m_name) in enumerate(zip(models, m_names)):
    
    print(f"Training model `{m_name}`, with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
    model = model.to(device)

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

    best_val, epochs_since_best = np.inf, 0
    for epoch in range(opt.epochs):
        epoch_loss, epoch_clf_loss, epoch_tte_loss = 0, 0, 0
        model.train()
        start = time.time()
        for i, batch in tqdm(enumerate(dm.train_dataloader()), desc=f"Training epoch {epoch}", total=len(dm.train_dataloader())):
            print(f"Time to load batch {time.time()-start}")
            # if i > 50:
            #     break
                
            # evaluate the loss
            _, (loss_clf, loss_tte), loss = model(batch['tokens'].to(device), 
                                                  ages=batch['ages'].to(device), 
                                                  attention_mask=batch['attention_mask'].to(device)  
                                                  )
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # record
            epoch_clf_loss += loss_clf.item()
            epoch_tte_loss += loss_tte.item()
            start = time.time()
        epoch_loss /= i
        epoch_clf_loss /= i
        epoch_tte_loss /= i
        loss_curves_train[m_idx].append(epoch_loss)
        loss_curves_train_clf[m_idx].append(epoch_clf_loss)
        loss_curves_train_tte[m_idx].append(epoch_tte_loss)

        # evaluate the loss on val set
        with torch.no_grad(): 
            model.eval()
            if epoch % opt.eval_interval == 0 or epoch == opt.epochs - 1:
                val_loss, val_clf_loss, val_tte_loss = 0, 0, 0
                for j, batch in tqdm(enumerate(dm.val_dataloader()), desc=f"Validation epoch {epoch}", total=len(dm.val_dataloader())):
                    # if j > 20:
                    #     break
                    _, (loss_clf, loss_tte), loss = model(batch['tokens'].to(device), 
                                                          ages=batch['ages'].to(device),
                                                          attention_mask=batch['attention_mask'].to(device) 
                                                          )
                    val_loss += loss.item()
                    # record
                    val_clf_loss += loss_clf.item()
                    val_tte_loss += loss_tte.item()
                val_loss /= j
                val_clf_loss /= j
                val_tte_loss /= j
                loss_curves_val[m_idx].append(val_loss)
                loss_curves_val_clf[m_idx].append(val_clf_loss)
                loss_curves_val_tte[m_idx].append(val_tte_loss)
                print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}  ({epoch_clf_loss:.2f}, {epoch_tte_loss:.2f}). Val loss {val_loss:.2f} ({val_clf_loss:.2f}, {val_tte_loss:.2f})")          
                # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning
        
            if val_loss >= best_val:
                epochs_since_best += 1
                if epochs_since_best >= 5:
                    break
            else:
                best_val = val_loss
                epochs_since_best = 0

    # Test trained model with a prompt
    # ----------------    
    # set context: diagnosis of depression at 20 years old
    tokens = torch.from_numpy(np.array(dm.encode(["DEPRESSION"])).reshape((1,-1))).to(device)
    ages = torch.tensor([[20*365]], device=device)
    # values = torch.tensor([[torch.nan]], device=device)
    # generate: sample the next 10 tokens
    new_tokens, new_ages = model.generate(tokens, ages, max_new_tokens=10)
    generated = dm.decode(new_tokens[0].tolist())
    # report:
    #    note, Not considering value yet.
    for _cat, _age in zip(generated.split(" "), new_ages[0, :]):
        print(f"\t {_cat} at age {_age/365:.0f} ({_age:.1f} days)")    # with value {_value}


Training model `TPPTransformerForCausalSequenceModelling: Exponential TTE`, with 10.853185 M parameters


Training epoch 0:   0%|          | 0/7162 [00:00<?, ?it/s]

Time to load batch 12.829303979873657


Training epoch 0:   0%|          | 2/7162 [00:15<13:09:10,  6.61s/it]

Time to load batch 0.011183738708496094
Time to load batch 0.011097192764282227


Training epoch 0:   0%|          | 4/7162 [00:16<4:28:07,  2.25s/it] 

Time to load batch 0.002125978469848633
Time to load batch 0.0014393329620361328


Training epoch 0:   0%|          | 6/7162 [00:16<2:02:12,  1.02s/it]

Time to load batch 0.0020134449005126953
Time to load batch 0.002160787582397461


Training epoch 0:   0%|          | 8/7162 [00:16<1:03:45,  1.87it/s]

Time to load batch 0.0020194053649902344
Time to load batch 0.0021080970764160156


Training epoch 0:   0%|          | 10/7162 [00:16<37:11,  3.21it/s] 

Time to load batch 0.0012030601501464844


Training epoch 0:   0%|          | 11/7162 [00:26<6:16:04,  3.16s/it]

Time to load batch 9.427608489990234
Time to load batch 0.004099607467651367


Training epoch 0:   0%|          | 13/7162 [00:26<3:09:53,  1.59s/it]

Time to load batch 0.01113438606262207
Time to load batch 0.00103759765625


Training epoch 0:   0%|          | 15/7162 [00:27<1:40:53,  1.18it/s]

Time to load batch 0.0010645389556884766
Time to load batch 0.0009520053863525391


Training epoch 0:   0%|          | 17/7162 [00:27<57:12,  2.08it/s]  

Time to load batch 0.0010423660278320312
Time to load batch 0.0009596347808837891


Training epoch 0:   0%|          | 20/7162 [00:27<28:37,  4.16it/s]

Time to load batch 0.009146451950073242
Time to load batch 0.0010325908660888672


Training epoch 0:   0%|          | 21/7162 [00:39<7:30:08,  3.78s/it]

Time to load batch 11.906419277191162
Time to load batch 0.011122941970825195


Training epoch 0:   0%|          | 23/7162 [00:39<3:48:08,  1.92s/it]

Time to load batch 0.011175394058227539
Time to load batch 0.0011758804321289062


Training epoch 0:   0%|          | 25/7162 [00:40<1:58:59,  1.00s/it]

Time to load batch 0.0008556842803955078
Time to load batch 0.0009136199951171875


Training epoch 0:   0%|          | 28/7162 [00:40<48:59,  2.43it/s]  

Time to load batch 0.002193450927734375
Time to load batch 0.0009546279907226562
Time to load batch 0.0009748935699462891


Training epoch 0:   0%|          | 30/7162 [00:40<30:41,  3.87it/s]

Time to load batch 0.0009722709655761719


Training epoch 0:   0%|          | 31/7162 [00:51<6:49:05,  3.44s/it]

Time to load batch 10.715041875839233
Time to load batch 0.011060476303100586


Training epoch 0:   0%|          | 33/7162 [00:52<3:38:20,  1.84s/it]

Time to load batch 0.27408289909362793
Time to load batch 0.011074304580688477


Training epoch 0:   0%|          | 35/7162 [00:52<1:53:54,  1.04it/s]

Time to load batch 0.0008938312530517578
Time to load batch 0.0008678436279296875


Training epoch 0:   1%|          | 37/7162 [00:52<1:03:04,  1.88it/s]

Time to load batch 0.0008814334869384766
Time to load batch 0.0013027191162109375


Training epoch 0:   1%|          | 39/7162 [00:52<38:07,  3.11it/s]  

Time to load batch 0.0008828639984130859
Time to load batch 0.0008823871612548828


Training epoch 0:   1%|          | 41/7162 [01:04<7:18:19,  3.69s/it]

Time to load batch 11.555003881454468
Time to load batch 0.011188983917236328


Training epoch 0:   1%|          | 43/7162 [01:04<3:42:25,  1.87s/it]

Time to load batch 0.0009324550628662109


Training epoch 0:   1%|          | 44/7162 [01:05<3:03:08,  1.54s/it]

Time to load batch 0.6487607955932617
Time to load batch 0.008224010467529297


Training epoch 0:   1%|          | 46/7162 [01:06<1:38:00,  1.21it/s]

Time to load batch 0.011124134063720703
Time to load batch 0.011170387268066406


Training epoch 0:   1%|          | 48/7162 [01:06<55:10,  2.15it/s]  

Time to load batch 0.011152029037475586
Time to load batch 0.011102676391601562


Training epoch 0:   1%|          | 50/7162 [01:06<33:41,  3.52it/s]

Time to load batch 0.0008754730224609375


Training epoch 0:   1%|          | 51/7162 [01:19<7:49:57,  3.97s/it]

Time to load batch 12.420634508132935
Time to load batch 0.0008759498596191406


Training epoch 0:   1%|          | 54/7162 [01:19<2:49:37,  1.43s/it]

Time to load batch 0.0013346672058105469
Time to load batch 0.002171754837036133
Time to load batch 0.002204418182373047


Training epoch 0:   1%|          | 57/7162 [01:19<1:06:47,  1.77it/s]

Time to load batch 0.002172708511352539
Time to load batch 0.0009388923645019531


Training epoch 0:   1%|          | 59/7162 [01:19<38:52,  3.05it/s]  

Time to load batch 0.0010182857513427734
Time to load batch 0.0008676052093505859
Time to load batch 0.0008306503295898438


Training epoch 0:   1%|          | 61/7162 [01:31<7:20:47,  3.72s/it]

Time to load batch 11.673893928527832
Time to load batch 0.0008962154388427734


Training epoch 0:   1%|          | 63/7162 [01:32<3:42:49,  1.88s/it]

Time to load batch 0.0009682178497314453
Time to load batch 0.0009853839874267578


Training epoch 0:   1%|          | 66/7162 [01:32<1:24:30,  1.40it/s]

Time to load batch 0.000982522964477539
Time to load batch 0.0009758472442626953
Time to load batch 0.000982522964477539


Training epoch 0:   1%|          | 68/7162 [01:32<48:12,  2.45it/s]  

Time to load batch 0.0009586811065673828
Time to load batch 0.0009136199951171875


Training epoch 0:   1%|          | 70/7162 [01:32<31:20,  3.77it/s]

Time to load batch 0.0009670257568359375


Training epoch 0:   1%|          | 71/7162 [01:43<6:47:39,  3.45s/it]

Time to load batch 10.73011302947998
Time to load batch 0.001009225845336914


Training epoch 0:   1%|          | 73/7162 [01:44<3:35:41,  1.83s/it]

Time to load batch 0.21232819557189941


Training epoch 0:   1%|          | 74/7162 [01:47<4:11:10,  2.13s/it]

Time to load batch 2.679110050201416
Time to load batch 0.011190652847290039


Training epoch 0:   1%|          | 76/7162 [01:47<2:10:53,  1.11s/it]

Time to load batch 0.011202335357666016
Time to load batch 0.011143207550048828


Training epoch 0:   1%|          | 78/7162 [01:47<1:11:54,  1.64it/s]

Time to load batch 0.0029489994049072266
Time to load batch 0.0022025108337402344


Training epoch 0:   1%|          | 80/7162 [01:47<42:08,  2.80it/s]  

Time to load batch 0.0018265247344970703


Training epoch 0:   1%|          | 81/7162 [01:56<5:41:46,  2.90s/it]

Time to load batch 8.687105417251587
Time to load batch 0.0009522438049316406


Training epoch 0:   1%|          | 83/7162 [01:56<2:55:28,  1.49s/it]

Time to load batch 0.0009829998016357422


Training epoch 0:   1%|          | 84/7162 [01:59<3:23:27,  1.72s/it]

Time to load batch 2.126112461090088
Time to load batch 0.011151552200317383


Training epoch 0:   1%|          | 86/7162 [01:59<1:47:50,  1.09it/s]

Time to load batch 0.011116981506347656
Time to load batch 0.006203174591064453


Training epoch 0:   1%|          | 88/7162 [01:59<59:59,  1.97it/s]  

Time to load batch 0.0022420883178710938
Time to load batch 0.001630544662475586


Training epoch 0:   1%|▏         | 90/7162 [01:59<36:16,  3.25it/s]

Time to load batch 0.0009443759918212891


Training epoch 0:   1%|▏         | 91/7162 [02:08<5:37:38,  2.87s/it]

Time to load batch 8.697573900222778
Time to load batch 0.011085271835327148


Training epoch 0:   1%|▏         | 93/7162 [02:09<2:54:04,  1.48s/it]

Time to load batch 0.0008785724639892578


Training epoch 0:   1%|▏         | 94/7162 [02:11<3:11:33,  1.63s/it]

Time to load batch 1.8437871932983398
Time to load batch 0.013039112091064453


Training epoch 0:   1%|▏         | 96/7162 [02:11<1:42:18,  1.15it/s]

Time to load batch 0.011188507080078125
Time to load batch 0.0021576881408691406


Training epoch 0:   1%|▏         | 99/7162 [02:11<43:25,  2.71it/s]  

Time to load batch 0.0019481182098388672
Time to load batch 0.002140522003173828


Training epoch 0:   1%|▏         | 100/7162 [02:11<34:06,  3.45it/s]

Time to load batch 0.0039017200469970703


Training epoch 0:   1%|▏         | 101/7162 [02:20<5:50:46,  2.98s/it]

Time to load batch 9.123026371002197
Time to load batch 0.002267599105834961


Training epoch 0:   1%|▏         | 103/7162 [02:21<3:21:16,  1.71s/it]

Time to load batch 0.6441376209259033


Training epoch 0:   1%|▏         | 104/7162 [02:24<3:37:25,  1.85s/it]

Time to load batch 2.025434732437134
Time to load batch 0.011183738708496094


Training epoch 0:   1%|▏         | 107/7162 [02:24<1:23:20,  1.41it/s]

Time to load batch 0.0008568763732910156
Time to load batch 0.002160310745239258
Time to load batch 0.002149820327758789


Training epoch 0:   2%|▏         | 109/7162 [02:24<47:44,  2.46it/s]  

Time to load batch 0.0021750926971435547
Time to load batch 0.003252267837524414


Training epoch 0:   2%|▏         | 111/7162 [02:33<5:27:27,  2.79s/it]

Time to load batch 8.403613805770874
Time to load batch 0.011160850524902344


Training epoch 0:   2%|▏         | 113/7162 [02:34<3:09:34,  1.61s/it]

Time to load batch 0.5921590328216553


Training epoch 0:   2%|▏         | 114/7162 [02:37<3:56:23,  2.01s/it]

Time to load batch 2.7698140144348145
Time to load batch 0.011161565780639648


Training epoch 0:   2%|▏         | 116/7162 [02:37<2:04:20,  1.06s/it]

Time to load batch 0.011176109313964844
Time to load batch 0.0015871524810791016


Training epoch 0:   2%|▏         | 119/7162 [02:37<50:59,  2.30it/s]  

Time to load batch 0.0118560791015625
Time to load batch 0.00093841552734375


Training epoch 0:   2%|▏         | 120/7162 [02:37<39:17,  2.99it/s]

Time to load batch 0.0010635852813720703


Training epoch 0:   2%|▏         | 121/7162 [02:46<5:25:31,  2.77s/it]

Time to load batch 8.328293085098267
Time to load batch 0.011215448379516602


Training epoch 0:   2%|▏         | 123/7162 [02:47<3:23:35,  1.74s/it]

Time to load batch 1.0537564754486084


Training epoch 0:   2%|▏         | 124/7162 [02:50<4:20:04,  2.22s/it]

Time to load batch 3.207258701324463
Time to load batch 0.0009429454803466797


Training epoch 0:   2%|▏         | 127/7162 [02:51<1:37:28,  1.20it/s]

Time to load batch 0.0008916854858398438
Time to load batch 0.0009694099426269531


Training epoch 0:   2%|▏         | 128/7162 [02:51<1:18:55,  1.49it/s]

Time to load batch 0.1824812889099121
Time to load batch 0.0022122859954833984


Training epoch 0:   2%|▏         | 130/7162 [02:51<44:51,  2.61it/s]  

Time to load batch 0.0020160675048828125


Training epoch 0:   2%|▏         | 131/7162 [02:58<4:36:52,  2.36s/it]

Time to load batch 6.837976932525635
Time to load batch 0.011117935180664062


Training epoch 0:   2%|▏         | 133/7162 [03:01<4:03:46,  2.08s/it]

Time to load batch 2.8154335021972656


Training epoch 0:   2%|▏         | 134/7162 [03:03<3:33:50,  1.83s/it]

Time to load batch 1.1070115566253662
Time to load batch 0.011150598526000977


Training epoch 0:   2%|▏         | 136/7162 [03:03<1:53:24,  1.03it/s]

Time to load batch 0.011122941970825195
Time to load batch 0.0023534297943115234


Training epoch 0:   2%|▏         | 138/7162 [03:05<2:04:25,  1.06s/it]

Time to load batch 1.709071159362793
Time to load batch 0.011184930801391602


Training epoch 0:   2%|▏         | 140/7162 [03:05<1:09:37,  1.68it/s]

Time to load batch 0.011078596115112305


Training epoch 0:   2%|▏         | 141/7162 [03:10<3:54:49,  2.01s/it]

Time to load batch 5.153005838394165
Time to load batch 0.011121988296508789


Training epoch 0:   2%|▏         | 143/7162 [03:14<3:49:24,  1.96s/it]

Time to load batch 3.0020358562469482


Training epoch 0:   2%|▏         | 144/7162 [03:16<3:41:10,  1.89s/it]

Time to load batch 1.604011058807373
Time to load batch 0.010957956314086914


Training epoch 0:   2%|▏         | 146/7162 [03:16<1:55:17,  1.01it/s]

Time to load batch 0.0007851123809814453
Time to load batch 0.0007920265197753906


Training epoch 0:   2%|▏         | 148/7162 [03:18<2:23:33,  1.23s/it]

Time to load batch 2.2406063079833984
Time to load batch 0.010959863662719727


Training epoch 0:   2%|▏         | 150/7162 [03:19<1:18:33,  1.49it/s]

Time to load batch 0.011229515075683594


Training epoch 0:   2%|▏         | 151/7162 [03:23<3:44:26,  1.92s/it]

Time to load batch 4.701861381530762
Time to load batch 0.0008213520050048828


Training epoch 0:   2%|▏         | 153/7162 [03:27<4:09:57,  2.14s/it]

Time to load batch 3.7636775970458984


Training epoch 0:   2%|▏         | 154/7162 [03:28<3:19:55,  1.71s/it]

Time to load batch 0.5910639762878418
Time to load batch 0.0014684200286865234


Training epoch 0:   2%|▏         | 156/7162 [03:28<1:44:30,  1.12it/s]

Time to load batch 0.0009388923645019531
Time to load batch 0.0008788108825683594


Training epoch 0:   2%|▏         | 158/7162 [03:31<2:23:13,  1.23s/it]

Time to load batch 2.399024486541748
Time to load batch 0.011173486709594727


Training epoch 0:   2%|▏         | 160/7162 [03:31<1:17:47,  1.50it/s]

Time to load batch 0.0007991790771484375


Training epoch 0:   2%|▏         | 161/7162 [03:38<4:41:46,  2.41s/it]

Time to load batch 6.352023601531982
Time to load batch 0.0008747577667236328


Training epoch 0:   2%|▏         | 163/7162 [03:41<3:57:07,  2.03s/it]

Time to load batch 2.5930628776550293


Training epoch 0:   2%|▏         | 164/7162 [03:41<3:15:52,  1.68s/it]

Time to load batch 0.7443728446960449


Training epoch 0:   2%|▏         | 165/7162 [03:42<2:44:32,  1.41s/it]

Time to load batch 0.6824781894683838
Time to load batch 0.0008502006530761719


Training epoch 0:   2%|▏         | 167/7162 [03:42<1:27:35,  1.33it/s]

Time to load batch 0.011080265045166016


Training epoch 0:   2%|▏         | 168/7162 [03:44<1:46:03,  1.10it/s]

Time to load batch 1.1668801307678223
Time to load batch 0.003968000411987305


Training epoch 0:   2%|▏         | 170/7162 [03:44<1:00:31,  1.93it/s]

Time to load batch 0.011058807373046875


Training epoch 0:   2%|▏         | 171/7162 [03:51<4:29:53,  2.32s/it]

Time to load batch 6.3120129108428955


Training epoch 0:   2%|▏         | 172/7162 [03:51<3:13:47,  1.66s/it]

Time to load batch 0.009000778198242188


Training epoch 0:   2%|▏         | 173/7162 [03:53<3:35:07,  1.85s/it]

Time to load batch 2.1189963817596436


Training epoch 0:   2%|▏         | 174/7162 [03:55<3:33:02,  1.83s/it]

Time to load batch 1.6530673503875732


Training epoch 0:   2%|▏         | 176/7162 [03:56<2:11:12,  1.13s/it]

Time to load batch 0.8512420654296875
Time to load batch 0.0008246898651123047
Time to load batch 0.0008726119995117188


Training epoch 0:   2%|▏         | 178/7162 [03:57<1:52:19,  1.04it/s]

Time to load batch 1.1609904766082764
Time to load batch 0.011106491088867188


Training epoch 0:   3%|▎         | 180/7162 [03:58<1:02:46,  1.85it/s]

Time to load batch 0.011111736297607422


Training epoch 0:   3%|▎         | 181/7162 [04:04<4:26:43,  2.29s/it]

Time to load batch 6.239028215408325
Time to load batch 0.01104593276977539


Training epoch 0:   3%|▎         | 183/7162 [04:06<3:39:32,  1.89s/it]

Time to load batch 2.303036689758301


Training epoch 0:   3%|▎         | 184/7162 [04:09<3:49:39,  1.97s/it]

Time to load batch 2.054827928543091
Time to load batch 0.0008058547973632812


Training epoch 0:   3%|▎         | 187/7162 [04:09<1:26:58,  1.34it/s]

Time to load batch 0.0007984638214111328
Time to load batch 0.0007636547088623047


Training epoch 0:   3%|▎         | 188/7162 [04:11<2:27:47,  1.27s/it]

Time to load batch 2.3499643802642822
Time to load batch 0.0008666515350341797


Training epoch 0:   3%|▎         | 190/7162 [04:12<1:21:52,  1.42it/s]

Time to load batch 0.01287531852722168


Training epoch 0:   3%|▎         | 191/7162 [04:16<3:23:14,  1.75s/it]

Time to load batch 4.0509934425354
Time to load batch 0.011053323745727539


Training epoch 0:   3%|▎         | 193/7162 [04:21<4:23:30,  2.27s/it]

Time to load batch 4.475890398025513


Training epoch 0:   3%|▎         | 194/7162 [04:22<3:41:58,  1.91s/it]

Time to load batch 0.9543466567993164
Time to load batch 0.0009021759033203125


Training epoch 0:   3%|▎         | 196/7162 [04:22<1:56:36,  1.00s/it]

Time to load batch 0.0008187294006347656
Time to load batch 0.0008349418640136719


Training epoch 0:   3%|▎         | 198/7162 [04:26<2:59:05,  1.54s/it]

Time to load batch 3.27091908454895
Time to load batch 0.0008308887481689453


Training epoch 0:   3%|▎         | 200/7162 [04:26<1:35:45,  1.21it/s]

Time to load batch 0.00510859489440918


Training epoch 0:   3%|▎         | 201/7162 [04:29<2:46:20,  1.43s/it]

Time to load batch 2.7187700271606445
Time to load batch 0.0008521080017089844


Training epoch 0:   3%|▎         | 203/7162 [04:35<4:57:22,  2.56s/it]

Time to load batch 5.967419385910034
Time to load batch 0.011674880981445312


Training epoch 0:   3%|▎         | 205/7162 [04:35<2:34:03,  1.33s/it]

Time to load batch 0.0008358955383300781
Time to load batch 0.0008218288421630859


Training epoch 0:   3%|▎         | 207/7162 [04:36<1:23:10,  1.39it/s]

Time to load batch 0.0008082389831542969


Training epoch 0:   3%|▎         | 208/7162 [04:39<2:55:35,  1.52s/it]

Time to load batch 3.2420828342437744
Time to load batch 0.0008349418640136719


Training epoch 0:   3%|▎         | 210/7162 [04:39<1:34:18,  1.23it/s]

Time to load batch 0.0010807514190673828


Training epoch 0:   3%|▎         | 211/7162 [04:43<3:18:02,  1.71s/it]

Time to load batch 3.6540701389312744
Time to load batch 0.01111459732055664


Training epoch 0:   3%|▎         | 213/7162 [04:49<4:49:46,  2.50s/it]

Time to load batch 5.314417123794556


Training epoch 0:   3%|▎         | 214/7162 [04:49<3:33:22,  1.84s/it]

Time to load batch 0.19888687133789062
Time to load batch 0.0015797615051269531


Training epoch 0:   3%|▎         | 216/7162 [04:49<1:50:43,  1.05it/s]

Time to load batch 0.0008516311645507812
Time to load batch 0.0007927417755126953


Training epoch 0:   3%|▎         | 218/7162 [04:52<2:40:03,  1.38s/it]

Time to load batch 2.851625442504883
Time to load batch 0.0009908676147460938


Training epoch 0:   3%|▎         | 220/7162 [04:52<1:24:32,  1.37it/s]

Time to load batch 0.0007953643798828125


Training epoch 0:   3%|▎         | 221/7162 [04:56<3:03:12,  1.58s/it]

Time to load batch 3.450319290161133
Time to load batch 0.010961294174194336


Training epoch 0:   3%|▎         | 223/7162 [05:02<4:40:43,  2.43s/it]

Time to load batch 5.273319959640503
Time to load batch 0.0009243488311767578


Training epoch 0:   3%|▎         | 226/7162 [05:02<1:44:08,  1.11it/s]

Time to load batch 0.000762939453125
Time to load batch 0.0008096694946289062
Time to load batch 0.0008087158203125


Training epoch 0:   3%|▎         | 228/7162 [05:06<3:23:28,  1.76s/it]

Time to load batch 4.178831100463867
Time to load batch 0.008320331573486328


Training epoch 0:   3%|▎         | 230/7162 [05:07<1:47:27,  1.08it/s]

Time to load batch 0.0008149147033691406


Training epoch 0:   3%|▎         | 231/7162 [05:08<2:07:35,  1.10s/it]

Time to load batch 1.3878977298736572
Time to load batch 0.011075258255004883


Training epoch 0:   3%|▎         | 233/7162 [05:15<4:58:39,  2.59s/it]

Time to load batch 6.583277702331543


Training epoch 0:   3%|▎         | 234/7162 [05:15<3:43:30,  1.94s/it]

Time to load batch 0.2730898857116699
Time to load batch 0.015057802200317383


Training epoch 0:   3%|▎         | 236/7162 [05:16<1:57:58,  1.02s/it]

Time to load batch 0.0008318424224853516
Time to load batch 0.0008246898651123047


Training epoch 0:   3%|▎         | 238/7162 [05:17<1:53:50,  1.01it/s]

Time to load batch 1.4059770107269287
Time to load batch 0.010954856872558594


Training epoch 0:   3%|▎         | 240/7162 [05:18<1:03:34,  1.81it/s]

Time to load batch 0.009050846099853516


Training epoch 0:   3%|▎         | 241/7162 [05:22<3:08:48,  1.64s/it]

Time to load batch 4.02508544921875
Time to load batch 0.011055946350097656


Training epoch 0:   3%|▎         | 243/7162 [05:28<5:09:47,  2.69s/it]

Time to load batch 6.025081157684326


Training epoch 0:   3%|▎         | 244/7162 [05:29<3:53:36,  2.03s/it]

Time to load batch 0.3609330654144287
Time to load batch 0.0007429122924804688


Training epoch 0:   3%|▎         | 246/7162 [05:29<2:00:41,  1.05s/it]

Time to load batch 0.0007398128509521484
Time to load batch 0.0008630752563476562


Training epoch 0:   3%|▎         | 248/7162 [05:31<2:04:12,  1.08s/it]

Time to load batch 1.6772682666778564
Time to load batch 0.0008444786071777344


Training epoch 0:   3%|▎         | 250/7162 [05:31<1:09:23,  1.66it/s]

Time to load batch 0.0008487701416015625


Training epoch 0:   4%|▎         | 251/7162 [05:34<2:29:16,  1.30s/it]

Time to load batch 2.7720608711242676
Time to load batch 0.0021164417266845703


Training epoch 0:   4%|▎         | 253/7162 [05:41<5:19:57,  2.78s/it]

Time to load batch 6.930565595626831


Training epoch 0:   4%|▎         | 255/7162 [05:42<3:04:14,  1.60s/it]

Time to load batch 0.8904118537902832
Time to load batch 0.0008528232574462891
Time to load batch 0.0008039474487304688


Training epoch 0:   4%|▎         | 257/7162 [05:42<1:36:13,  1.20it/s]

Time to load batch 0.0008256435394287109


Training epoch 0:   4%|▎         | 258/7162 [05:44<2:06:27,  1.10s/it]

Time to load batch 1.5832939147949219
Time to load batch 0.009287118911743164


Training epoch 0:   4%|▎         | 260/7162 [05:44<1:09:57,  1.64it/s]

Time to load batch 0.011110782623291016


Training epoch 0:   4%|▎         | 261/7162 [05:46<1:45:02,  1.10it/s]

Time to load batch 1.5020818710327148
Time to load batch 0.011085033416748047


Training epoch 0:   4%|▎         | 263/7162 [05:53<5:05:15,  2.65s/it]

Time to load batch 7.119077920913696


Training epoch 0:   4%|▎         | 264/7162 [05:55<4:41:52,  2.45s/it]

Time to load batch 1.8573341369628906
Time to load batch 0.0007660388946533203


Training epoch 0:   4%|▎         | 267/7162 [05:56<1:44:40,  1.10it/s]

Time to load batch 0.0008165836334228516
Time to load batch 0.0008294582366943359


Training epoch 0:   4%|▎         | 268/7162 [05:58<2:18:19,  1.20s/it]

Time to load batch 1.7603390216827393
Time to load batch 0.0011942386627197266


Training epoch 0:   4%|▍         | 270/7162 [05:58<1:15:14,  1.53it/s]

Time to load batch 0.0008556842803955078


Training epoch 0:   4%|▍         | 271/7162 [05:59<1:28:28,  1.30it/s]

Time to load batch 0.9316871166229248
Time to load batch 0.0009012222290039062


Training epoch 0:   4%|▍         | 272/7162 [06:04<2:33:58,  1.34s/it]

KeyboardInterrupt



## Comparing output to real data

In [22]:
for batch in dm.train_dataloader():
    break
conditions = batch["tokens"].numpy().tolist()
# delta_ages = batch["ages"][:, 1:] - batch["ages"][:, :-1]
for idx, (token, age) in enumerate(zip(conditions[0], batch["ages"][0,:])):
    if token == 0 or idx >= 10:
        break
    print(f"{dm.decode([token])}, at age {age/365:.0f} ({age:.1f} days)")

Diastolic_blood_pressure_5, at age 29 (10762.0 days)
Systolic_blood_pressure_4, at age 29 (10762.0 days)
ANXIETY, at age 30 (11068.0 days)
Diastolic_blood_pressure_5, at age 30 (11109.0 days)
Systolic_blood_pressure_4, at age 30 (11109.0 days)


In [23]:
cols = ["k", "r", "b", "y"]

# Plot loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train[m_idx]), len(loss_curves_train[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val[m_idx]), len(loss_curves_val[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.yscale("log")
plt.legend()
plt.savefig(f"figs/TTE/logloss.png")

# Plot classifier loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train_clf[m_idx]), len(loss_curves_train_clf[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train_clf[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val_clf[m_idx]), len(loss_curves_val_clf[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val_clf[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.yscale("log")
plt.legend()
plt.savefig(f"figs/TTE/logloss_clf.png")

# Plot tte loss
plt.figure()
for m_idx, _ in enumerate(models):
    # Training
    iterations = np.linspace(0, len(loss_curves_train_tte[m_idx]), len(loss_curves_train_tte[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_train_tte[m_idx], label=f"{m_names[m_idx]}-train", c=cols[m_idx], linestyle='dashed')
    # Validation
    iterations = np.linspace(0, len(loss_curves_val_tte[m_idx]), len(loss_curves_val_tte[m_idx])) * opt.eval_interval
    plt.plot(iterations, loss_curves_val_tte[m_idx], label=f"{m_names[m_idx]}-val", c=cols[m_idx])
plt.yscale("log")
plt.legend()
plt.savefig(f"figs/TTE/logloss_tte.png")

  plt.yscale("log")


# Prompt testing

## Diabetes: How related conditions are impacted by each other
Probability of type II diabetes before and after a type I diagnosis

In [None]:
t1_token = dm.tokenizer._stoi["TYPE1DM"]
t2_token = dm.tokenizer._stoi["TYPE2DIABETES"]


base_prompt = ["DEPRESSION"]
ages_in_years = [20]

to_days = lambda a_list: torch.FloatTensor([365 * _a for _a in a_list]).reshape((1,-1)).to(device)

# Create a set of prompts
prompts, ages, values, desc = [], [], [], []
desc.append("Control")
prompts.append(base_prompt)
ages.append(ages_in_years)
desc.append("Type 1")
prompts.append(base_prompt + ["TYPE1DM"])
ages.append(ages_in_years + [21])
desc.append("Type 2")
prompts.append(base_prompt + ["TYPE2DIABETES"])
ages.append(ages_in_years + [21])

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad(): 
            model.eval()
    
            for p_idx, (prompt, age) in enumerate(zip(prompts, ages)):
                print(f"\n{desc[p_idx]}: \t ({','.join(prompt)}): ")
                encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
                (lgts, tte_dist), _, _ = model(encoded_prompt,
                                                       # values=torch.tensor(value).to(device),
                                                       ages=to_days(age),
                                                       is_generation=True)
                probs = torch.nn.functional.softmax(lgts, dim=2)
                print(f"\tprobability of type I diabetes: {100*float(probs[0, 0, t1_token].cpu().detach().numpy()):.4f}%")
                print(f"\tprobability of type II diabetes: {100*float(probs[0, 0, t2_token].cpu().detach().numpy()):.4f}%")

# Note: adding a diagnosis (even if potentially orthogonal) at the beginning of the prompt increases probability of either type

## Age: How increasing prompt age affects likelihood of age related diagnoses

In [None]:
prompt = ["ALLERGICRHINITISCONJ"]
ages = [[4],[8],[20],[30],[60],[80],[90]]

# target_conditions=["TYPE1DM"]#, "TYPE2DIABETES", "OSTEOARTHRITIS", "ANY_DEAFNESS_HEARING_LOSS"]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")

    # for condition in target_conditions:
    #     print(f"Probability of {condition}")
    #     target_token = dm.tokenizer._stoi[condition]

    for p_idx, age in enumerate(ages):
        print(f"\nAge {age[-1]}\n======")
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
        (lgts, tte_dist), _, _ = model(encoded_prompt,
                                       ages=to_days(age),
                                       is_generation=True)
        probs = torch.nn.functional.softmax(lgts, dim=2) * 100

        # top K
        k = 10
        print(f"Top {k}")
        topk_prob, topk_ind = torch.topk(probs[0,0,:], k)
        for i, j in zip(dm.decode(topk_ind.tolist()).split(" "), topk_prob):
            print(f"\t{i}: {j:.2f}%")

        # bottom K
        k = 30
        print(f"Bottom {k}")
        topk_prob, topk_ind = torch.topk(-probs[0,0,:], k)
        for i, j in zip(dm.decode(topk_ind.tolist()).split(" "), topk_prob):
            print(f"\t{i}: {-j:.2f}%")
        
            # print(f"Age: {age[-1]} years old:  {100*float(probs[0, 0, target_token].cpu().detach().numpy()):.4f}%")

# Note: adding a diagnosis (even if potentially orthogonal) at the beginning of the prompt increases probability of either type

# Appendix: model architectures

In [None]:
for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n" + "="*len(m_names[model_idx]))
    print(f"\n\n{model}")

In [None]:
!jupyter nbconvert --to html --no-input TTE.ipynb