# Demo Notebook:
## Competing Risk Survival Transformer For Causal Sequence Modelling 

This demo models event sequences together with event times (ages) and tabular measurement values.

In [1]:
import os
from pathlib import Path
import sys
# Put the site-packages directory of the virtual environment that matches this
# node's CPU type (read from the BB_CPU environment variable) at the front of
# the module search path, and report whether the directory was found.
node_type = os.environ.get('BB_CPU')
venv_dir = '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-' + str(node_type)
_python_dirname = 'python{}.{}'.format(sys.version_info.major, sys.version_info.minor)
venv_site_pkgs = Path(venv_dir).joinpath('lib', _python_dirname, 'site-packages')

if not venv_site_pkgs.exists():
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")
else:
    # Index 0 so that packages from this environment shadow any others.
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")

# Print the notebook's working directory; relative paths used later in this
# notebook (e.g. "../../confs" and "figs/") are resolved against it.
!pwd

# Enable the autoreload extension; mode 2 reloads all imported modules before
# each cell is executed, so edits to project source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/SurvStreamGPT/notebooks/CompetingRisk


In [3]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
from pycox.evaluation import EvalSurv
from tqdm import tqdm
from hydra import compose, initialize
from omegaconf import OmegaConf
from CPRD.examples.modelling.SurvStreamGPT.experiment import run
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

# Fix the PyTorch RNG seed and select the float32 matmul precision mode.
torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')

logging.basicConfig(level=logging.INFO)

# Prefer the GPU when one is available.
# (Force device = "cpu" here if more informative debugging statements are needed.)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}.")

Using device: cuda.


# Demo Version of SurvStreamGPT

## Build configurations

In [4]:
# Compose the Hydra configuration `config_CompetingRisk11M` from the config
# directory "../../confs". Settings can be changed either by adding entries to
# `overrides` or by assigning to `cfg` after it has been composed, for example:
#     cfg.data.batch_size = 16
#     cfg.transformer.block_size = 32
#     cfg.transformer.n_layer = 10
with initialize(
    version_base=None,
    config_path="../../confs",
    job_name="testing_notebook",
):
    cfg = compose(
        config_name="config_CompetingRisk11M",
        overrides=[],
    )

In [5]:
# Show the fully composed configuration as YAML.
cfg_yaml = OmegaConf.to_yaml(cfg)
print(cfg_yaml)

is_decoder: true
data:
  batch_size: 64
  unk_freq_threshold: 0.0
  min_workers: 20
  global_diagnoses: false
  meta_information_path: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/polars/meta_information.pickle
experiment:
  project_name: SurvStreamGPT_${head.SurvLayer}
  run_id: PreTrain_${head.SurvLayer}_11M_${experiment.seed}
  train: true
  test: true
  verbose: true
  seed: 1337
  log: true
  log_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/
  ckpt_dir: /rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/
optim:
  num_epochs: 1
  learning_rate: 0.0003
  val_check_interval: 1000
  early_stop: false
  early_stop_patience: 5
  log_every_n_steps: 20
  limit_val_batches: 0.05
  limit_test_batches: 0.05
transformer:
  block_type: Neo
  block_size: 128
  n_layer: 6
  n_head: 6
  n_embd: 384
  layer_norm_bias: false
  attention_type: global
  bias: true
  dropout: 0.0
  attention_dropout: 0.0
  res

In [6]:
 # TODO: this environment variable is set only to work around an issue in the
 # local HPC environment; it should not be needed.
%env SLURM_NTASKS_PER_NODE=28      

# TODO: with the variable above set, `run(cfg)` trains, but due to a widgets
# issue on the HPC it does not print progress to the notebook. The packaged
# entry point is therefore left disabled here; uncomment to use it instead of
# the hand-written loop below.
# cfg.experiment.train = False
# cfg.experiment.test = False
# cfg.experiment.log = False
# model, dm = run(cfg)


env: SLURM_NTASKS_PER_NODE=28


## Or define the training process by hand
### Create data loader

In [7]:
# Build the data module from the dataset stored under `path_to_db`, taking
# batch size, sequence length and tokenizer thresholds from the configuration.
path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/"
datamodule_kwargs = dict(
    path_to_db=path_to_db,
    load=True,
    tokenizer="tabular",
    batch_size=cfg.data.batch_size,
    max_seq_length=cfg.transformer.block_size,
    global_diagnoses=cfg.data.global_diagnoses,
    freq_threshold=cfg.data.unk_freq_threshold,
    min_workers=cfg.data.min_workers,
    overwrite_meta_information=cfg.data.meta_information_path,
)
dm = FoundationalDataModule(**datamodule_kwargs)

vocab_size = dm.train_set.tokenizer.vocab_size
print(f"{vocab_size} vocab elements")

# List of univariate measurements to model with a Normal distribution.
# Measurements are picked out using the fact that diagnoses are all upper case:
# any event name that is not entirely upper case is treated as a measurement.
event_names = dm.tokenizer._event_counts["EVENT"]
measurements_for_univariate_regression = [name for name in event_names if name != name.upper()]
cfg.head.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression)
# display(measurements_for_univariate_regression)


INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/polars/meta_information.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/polars/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/polars/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/polars/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 3584.43M tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 184 tokens
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/polars/split=train/ dataset, with 23,343,104 samples
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/polars/split=test/ dataset, with 1,263,168 samples
INFO:root:

184 vocab elements


In [8]:
# dm.train_set.view_sample(1000, report_time=True) # max_dynamic_events=120,

### Train

In [1]:
# Hand-written training loop for the demo.
#
# The limits below are local to this cell; they are deliberately hard-coded
# rather than read from cfg.optim so that the demo's behaviour does not change.
MAX_EPOCHS = 10              # upper bound on the number of epochs
EARLY_STOP_PATIENCE = 5      # stop after this many epochs without a new best val loss
MAX_TRAIN_BATCHES = 1002     # batches per training epoch (previous loop stopped after index 1001)
MAX_VAL_BATCHES = 101        # batches per validation pass (previous loop stopped after index 100)

# Keys of the loss dictionary returned by the model that are tracked per epoch.
LOSS_KEYS = ("loss", "loss_desurv", "loss_values")

model = SurvStreamGPTForCausalModelling(cfg, vocab_size).to(device)

loss_curves_train = []
loss_curves_train_surv = []
loss_curves_train_values = []

loss_curves_val = []
loss_curves_val_surv = []
loss_curves_val_values = []
print(f"Training model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.optim.learning_rate)


def _run_epoch(loader, max_batches, desc, train):
    """Run `model` over at most `max_batches` batches of `loader`.

    Args:
        loader: dataloader yielding dict batches with the keys 'tokens', 'ages',
            'values', 'static_covariates' and 'attention_mask'.
        max_batches: maximum number of batches to process.
        desc: label shown on the progress bar.
        train: when True, take an optimizer step on each batch's "loss".

    Returns:
        Tuple (loss, loss_desurv, loss_values): the mean of each per-batch loss
        over the batches that were actually processed.

    Raises:
        RuntimeError: if the loader yields no batches.
    """
    sums = dict.fromkeys(LOSS_KEYS, 0.0)
    n_batches = 0
    with tqdm(loader, desc=desc, total=min(len(loader), max_batches)) as progress:
        for batch in progress:
            # evaluate the loss
            _, loss_dict, _ = model(tokens=batch['tokens'].to(device),
                                    ages=batch['ages'].to(device),
                                    values=batch['values'].to(device),
                                    covariates=batch["static_covariates"].to(device),
                                    attention_mask=batch['attention_mask'].to(device)
                                   )

            if train:
                optimizer.zero_grad(set_to_none=True)
                loss_dict["loss"].backward()
                optimizer.step()

            # record
            for key in LOSS_KEYS:
                sums[key] += loss_dict[key].item()
            n_batches += 1

            if n_batches >= max_batches:
                break

    if n_batches == 0:
        raise RuntimeError(f"{desc}: the dataloader yielded no batches")

    # Divide by the number of batches processed (not by the last loop index).
    # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning
    return tuple(sums[key] / n_batches for key in LOSS_KEYS)


best_val, epochs_since_best = np.inf, 0
for epoch in range(MAX_EPOCHS):

    model.train()
    epoch_loss, epoch_surv_loss, epoch_values_loss = _run_epoch(
        dm.train_dataloader(), MAX_TRAIN_BATCHES, f"Training epoch {epoch}", train=True)
    loss_curves_train.append(epoch_loss)
    loss_curves_train_surv.append(epoch_surv_loss)
    loss_curves_train_values.append(epoch_values_loss)

    # evaluate the loss on val set (every epoch; early stopping below needs it)
    model.eval()
    with torch.no_grad():
        val_loss, val_surv_loss, val_values_loss = _run_epoch(
            dm.val_dataloader(), MAX_VAL_BATCHES, f"Validation epoch {epoch}", train=False)
    loss_curves_val.append(val_loss)
    loss_curves_val_surv.append(val_surv_loss)
    loss_curves_val_values.append(val_values_loss)

    print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}: ({epoch_surv_loss:.2f}, {epoch_values_loss:.2f}). Val loss {val_loss:.2f}: ({val_surv_loss:.2f}, {val_values_loss:.2f})")

    # Early stopping on the validation loss.
    if val_loss >= best_val:
        epochs_since_best += 1
        if epochs_since_best >= EARLY_STOP_PATIENCE:
            break
    else:
        best_val = val_loss
        epochs_since_best = 0

        # Save best seen model
        # torch.save(model.state_dict(), path_to_db + "polars/CR.pt")
            

In [None]:
def _plot_loss_curves(train_curve, val_curve, title, out_path):
    """Plot one train/validation loss curve pair against epoch and save it.

    Args:
        train_curve: sequence with one training-loss value per epoch.
        val_curve: sequence with one validation-loss value per epoch.
        title: figure title.
        out_path: file the figure is written to.
    """
    fig, ax = plt.subplots()
    # Each curve holds one entry per epoch, so the x-axis is the epoch number.
    ax.plot(np.arange(1, len(train_curve) + 1), train_curve, label="train")
    ax.plot(np.arange(1, len(val_curve) + 1), val_curve, label="val", linestyle='dashed')
    ax.set(title=title, xlabel="Epoch", ylabel="Loss")
    ax.legend()
    fig.savefig(out_path)


# Make sure the output directory exists before saving any figure.
figs_dir = Path("figs")
figs_dir.mkdir(parents=True, exist_ok=True)

# Plot loss
_plot_loss_curves(loss_curves_train, loss_curves_val, "Total loss", figs_dir / "loss.png")

# Plot DeSurv loss
_plot_loss_curves(loss_curves_train_surv, loss_curves_val_surv, "DeSurv loss", figs_dir / "loss_desurv.png")

# Plot value loss
_plot_loss_curves(loss_curves_train_values, loss_curves_val_values, "Value loss", figs_dir / "loss_val.png")

# Appendix: model architectures

In [None]:
# Show the model object (its module structure) using the notebook's rich display.
display(model)

In [None]:
# Export competing_risk.ipynb (presumably this notebook - confirm the file name)
# to HTML; --no-input leaves the code cells out of the exported document.
!jupyter nbconvert --to html --no-input competing_risk.ipynb