# Demo Notebook:
## Competing Risk Survival Transformer For Causal Sequence Modelling 

This demo models sequences of event tokens together with their timing (patient age at each event) and tabular measurement values.

In [4]:
import os
from pathlib import Path
import sys
# Put the node-specific virtual environment's site-packages directory at the
# front of the import search path. The node type is read from the BB_CPU
# environment variable and selects which virtual environment is used.
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
python_tag = f'python{sys.version_info.major}.{sys.version_info.minor}'
venv_site_pkgs = Path(venv_dir).joinpath('lib', python_tag, 'site-packages')

if not venv_site_pkgs.exists():
    # Nothing is added to sys.path in this case; only a diagnostic is printed.
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")
else:
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")

# Print the current working directory; later cells use relative paths
# (e.g. the Hydra config directory and the figure output directory).
!pwd

# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before code is executed and edits to project code take effect.
%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/SurvStreamGPT/notebooks/CompetingRisk
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
from pycox.evaluation import EvalSurv
from tqdm import tqdm
from hydra import compose, initialize
from omegaconf import OmegaConf
from CPRD.examples.modelling.SurvStreamGPT.experiment import run
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

# Seed PyTorch's RNG and set the float32 matmul precision mode to 'medium'.
torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')

# Emit INFO-level log records (the data module and model log through `logging`).
logging.basicConfig(level=logging.INFO)

# Prefer the GPU when CUDA is available, otherwise run on the CPU.
# device = "cpu"    # if more informative debugging statements are needed
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}.")

Using device: cuda.


# Demo Version of SurvStreamGPT

## Build configurations

In [30]:
# Compose the Hydra configuration for the competing-risk model.
# Add "key=value" strings to `config_overrides` to override individual settings.
config_overrides = []
with initialize(version_base=None, config_path="../../confs", job_name="testing_notebook"):
    cfg = compose(config_name="config_CompetingRisk129M", overrides=config_overrides)

# Settings that can be changed after composition for quicker experiments:
# cfg.data.batch_size = 16
# cfg.transformer.block_size = 32
# # cfg.transformer.n_layer = 10

In [31]:
# Show the composed configuration as YAML for inspection.
config_yaml = OmegaConf.to_yaml(cfg)
print(config_yaml)

is_decoder: true
data:
  batch_size: 64
  unk_freq_threshold: 0.0
  min_workers: 20
experiment:
  project_name: SurvStreamGPT_${head.SurvLayer}
  run_id: PreTrain_${head.SurvLayer}_129M_${experiment.seed}
  train: true
  test: true
  verbose: true
  seed: 1337
  log: true
  log_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/
  ckpt_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/
optim:
  num_epochs: 1
  learning_rate: 0.0001
  val_check_interval: 1000
  early_stop_patience: 5
  log_every_n_steps: 20
  limit_val_batches: 0.05
  limit_test_batches: 0.05
transformer:
  block_type: Neo
  block_size: 128
  n_layer: 10
  n_head: 8
  n_embd: 1024
  layer_norm_bias: false
  attention_type: global
  bias: true
  dropout: 0.2
  attention_dropout: 0.2
  resid_dropout: 0.2
head:
  SurvLayer: cr
  surv_weight: 0.5
  tokens_for_univariate_regression: None
  value_weight: 0.5



In [32]:
 # TODO: this environment variable is set to work around a local HPC environment issue; it should not be needed.
%env SLURM_NTASKS_PER_NODE=28      

# TODO: with the variable above set, the packaged `run(cfg)` entry point trains, but because of a
# widgets issue on the HPC it does not print progress to the notebook. It is left disabled here:
# cfg.experiment.train = False
# cfg.experiment.test = False
# cfg.experiment.log = False
# model, dm = run(cfg)


env: SLURM_NTASKS_PER_NODE=28


## Alternatively, define the training process by hand
### Create data loader

In [33]:
# Build the data module; batch size, sequence length, unknown-token threshold
# and worker count all come from the composed configuration.
path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/"
datamodule_kwargs = dict(
    path_to_db=path_to_db,
    load=True,
    tokenizer="tabular",
    batch_size=cfg.data.batch_size,
    max_seq_length=cfg.transformer.block_size,
    unk_freq_threshold=cfg.data.unk_freq_threshold,
    min_workers=cfg.data.min_workers,
    inclusion_conditions=["COUNTRY = 'E'"],
)
dm = FoundationalDataModule(**datamodule_kwargs)

vocab_size = dm.train_set.tokenizer.vocab_size
print(f"{vocab_size} vocab elements")

# Choose the measurements whose values are modelled with a univariate Normal
# distribution. Diagnosis tokens are written entirely in upper case, so every
# event name that is not fully upper case is treated as a measurement.
measurements_for_univariate_regression = []
for event_name in dm.tokenizer._event_counts["EVENT"]:
    if event_name != event_name.upper():
        measurements_for_univariate_regression.append(event_name)

# Store the encoded measurement tokens on the config, which is later passed to the model.
cfg.head.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression)
# display(measurements_for_univariate_regression)


INFO:root:Loading Polars dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/polars/
INFO:root:Using tokenizer tabular
INFO:root:Tokenzier created based on 3584.43M tokens
INFO:root:Creating split=train/ dataset
INFO:root:	 Loading split=train/ hash map for parquet
INFO:root:	 Hash map created for split=train/ with 22,912,046 samples
INFO:root:Creating split=test/ dataset
INFO:root:	 Loading split=test/ hash map for parquet
INFO:root:	 Hash map created for split=test/ with 1,207,449 samples
INFO:root:Creating split=val/ dataset
INFO:root:	 Loading split=val/ hash map for parquet
INFO:root:	 Hash map created for split=val/ with 1,226,576 samples


184 vocab elements


In [42]:
# Inspect a single training sample and report how long it took to retrieve.
# NOTE(review): the rendered output appears to be an individual patient record
# (static covariates plus dated events) - confirm it may be shared, or clear
# this cell's output before committing/exporting the notebook.
sample_index = 1000
dm.train_set.view_sample(sample_index, report_time=True)  # max_dynamic_events=120,

# Optional: time the first ~100 batches of the training dataloader.
# import time
# start_time = time.time()
# for _idx, batch in enumerate(tqdm(dm.train_dataloader())):
#     if _idx > 100:
#         break
# print(time.time() - start_time)
# print(dm.meta_information.keys())


Time to retrieve sample index 1000 was 0.23534679412841797 seconds

SEX                 | F
IMD                 | 3.0
ETHNICITY           | WHITE
birth_year          | 1966.0

Token                                                                      | Age               | Standardised value
PSORIASIS                                                                  | 535               | nan               
ULCERATIVE_COLITIS                                                         | 9179              | nan               
DEPRESSION                                                                 | 9614              | nan               
ANXIETY                                                                    | 9614              | nan               
CROHNS_DISEASE                                                             | 10590             | nan               
Serum_free_T4_level_75                                                     | 15603             | 0.08              
Total_white_

### Train

In [14]:
# Build the model and move it to the selected device (one .to() call is enough).
model = SurvStreamGPTForCausalModelling(cfg, vocab_size).to(device)

# Per-epoch mean losses: combined loss, summed DeSurv (survival) losses, and value loss.
loss_curves_train = []
loss_curves_train_surv = []
loss_curves_train_values = []

loss_curves_val = []
loss_curves_val_surv = []
loss_curves_val_values = []
print(f"Training model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.optim.learning_rate)


def _forward_batch(batch):
    """Move the tensors of one batch to `device` and return the model's output for it.

    The output is unpacked by the callers as `(_, (losses_desurv, loss_values), loss)`.
    """
    return model(tokens=batch['tokens'].to(device),
                 ages=batch['ages'].to(device),
                 values=batch['values'].to(device),
                 covariates=batch["static_covariates"].to(device),
                 attention_mask=batch['attention_mask'].to(device)
                 )


best_val, epochs_since_best = np.inf, 0
for epoch in range(10):

    # ---- Training: a capped number of batches per epoch ----
    epoch_loss, epoch_surv_loss, epoch_values_loss = 0, 0, 0
    num_train_batches = 0
    model.train()
    # Build the loader once per epoch and reuse it for both len() and iteration.
    train_loader = dm.train_dataloader()
    for i, batch in tqdm(enumerate(train_loader), desc=f"Training epoch {epoch}", total=len(train_loader)):

        # evaluate the loss
        _, (losses_desurv, loss_values), loss = _forward_batch(batch)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # record
        epoch_loss += loss.item()
        epoch_surv_loss += torch.sum(losses_desurv).item()
        epoch_values_loss += loss_values.item()
        num_train_batches += 1

        if i > 1000:
            break

    # Average over the number of batches actually processed. Dividing by the last
    # batch index `i` would be off by one and would divide by zero for a single batch.
    num_train_batches = max(num_train_batches, 1)
    epoch_loss /= num_train_batches
    epoch_surv_loss /= num_train_batches
    epoch_values_loss /= num_train_batches
    loss_curves_train.append(epoch_loss)
    loss_curves_train_surv.append(epoch_surv_loss)
    loss_curves_train_values.append(epoch_values_loss)

    # ---- Validation: run every epoch on a capped number of batches ----
    model.eval()
    val_loss, val_surv_loss, val_values_loss = 0, 0, 0
    num_val_batches = 0
    val_loader = dm.val_dataloader()
    with torch.no_grad():
        for j, batch in tqdm(enumerate(val_loader), desc=f"Validation epoch {epoch}", total=len(val_loader)):
            if j > 100:
                break
            _, (losses_desurv, loss_values), loss = _forward_batch(batch)
            # record
            val_loss += loss.item()
            val_surv_loss += torch.sum(losses_desurv).item()
            val_values_loss += loss_values.item()
            num_val_batches += 1

    # Same reasoning as for training: divide by the batch count, not the last index.
    num_val_batches = max(num_val_batches, 1)
    val_loss /= num_val_batches
    val_surv_loss /= num_val_batches
    val_values_loss /= num_val_batches
    loss_curves_val.append(val_loss)
    loss_curves_val_surv.append(val_surv_loss)
    loss_curves_val_values.append(val_values_loss)

    print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}: ({epoch_surv_loss:.2f}, {epoch_values_loss:.2f}). Val loss {val_loss:.2f}: ({val_surv_loss:.2f}, {val_values_loss:.2f})")
    # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning

    # ---- Early stopping: stop after 5 consecutive epochs without a new best validation loss ----
    if val_loss >= best_val:
        epochs_since_best += 1
        if epochs_since_best >= 5:
            break
    else:
        best_val = val_loss
        epochs_since_best = 0

        # Save best seen model
        # torch.save(model.state_dict(), path_to_db + "polars/CR.pt")
            

INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using Competing-Risk DeSurv head.
INFO:root:Internally scaling time in survival head by 1825 days
INFO:root:In generation forwarding DeSurv on the grid between [0.0, 1825.0] with 1826 intervals of delta=1.0


Training model with 129.024283 M parameters


Training epoch 0:   0%|          | 1001/358001 [05:57<35:24:23,  2.80it/s]
Validation epoch 0:   1%|          | 101/19166 [00:23<1:14:17,  4.28it/s]


Epoch 0:	Train loss -0.71: (1.43, -2.85). Val loss -2.85: (0.64, -6.35)


Training epoch 1:   0%|          | 667/358001 [03:59<35:39:33,  2.78it/s]
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f0f23f7c4c0>
Traceback (most recent call last):
  File "/rds/bear-apps/2022a/EL8-ice/software/PyTorch/2.0.1-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
  File "/rds/bear-apps/2022a/EL8-ice/software/PyTorch/2.0.1-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1442, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/rds/bear-apps/2022a/EL8-ice/software/Python/3.10.4-GCCcore-11.3.0/lib/python3.10/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/rds/bear-apps/2022a/EL8-ice/software/Python/3.10.4-GCCcore-11.3.0/lib/python3.10/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/rds/bear-apps/2022a/EL8-ice/software/Pyth

KeyboardInterrupt: 

In [None]:
def plot_loss_curves(train_curve, val_curve, title, save_path):
    """Plot per-epoch training and validation losses on one figure and save it.

    Args:
        train_curve: sequence of training losses, one entry per epoch.
        val_curve: sequence of validation losses, one entry per epoch.
        title: figure title.
        save_path: file the figure is written to; missing parent directories are created.
    """
    fig, ax = plt.subplots()
    # The curves hold one value per epoch, so the x-axis is simply the epoch index.
    ax.plot(np.arange(len(train_curve)), train_curve, label="train")
    ax.plot(np.arange(len(val_curve)), val_curve, label="val", linestyle='dashed')
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(title)
    ax.legend()

    save_path = Path(save_path)
    # savefig does not create directories, so make sure the target folder exists.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path)


# Plot loss
plot_loss_curves(loss_curves_train, loss_curves_val,
                 "Total loss", "figs/competing_risk/loss.png")

# Plot DeSurv loss
plot_loss_curves(loss_curves_train_surv, loss_curves_val_surv,
                 "DeSurv loss", "figs/competing_risk/loss_desurv.png")

# Plot value loss
plot_loss_curves(loss_curves_train_values, loss_curves_val_values,
                 "Value loss", "figs/competing_risk/loss_val.png")

# Appendix: model architectures

In [None]:
# Render the instantiated model so its architecture can be inspected.
display(model)

In [None]:
# Export competing_risk.ipynb to HTML with the code cells hidden (--no-input).
# NOTE(review): presumably this converts the copy saved on disk, so save the notebook first - confirm.
!jupyter nbconvert --to html --no-input competing_risk.ipynb