# Demo Notebook:
## Survival Transformer For Causal Sequence Modelling 

Including time and measurement values (the model below is trained on event tokens, ages and standardised values)

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

!pwd

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/SurvStreamGPT


In [2]:
import pytorch_lightning
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling
from pycox.evaluation import EvalSurv
from tqdm import tqdm

# TODO:
# replace experiment boilerplate with pytorch lightning

# Seed every random number generator this notebook imports (Python, NumPy and
# PyTorch) from a single constant; previously only PyTorch was seeded, leaving
# any use of `random` / `np.random` unseeded between runs.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

logging.basicConfig(level=logging.INFO)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cuda.


## Build configurations

In [3]:
from typing import List, Optional


# Architecture is set to match the Karpathy benchmark; note that the tasks are not comparable.
@dataclass
class DemoConfig:
    """Model architecture settings passed to `SurvStreamGPTForCausalModelling`.

    Every attribute carries a type annotation so that `@dataclass` registers it
    as a real field (un-annotated assignments are silently left as plain class
    attributes and are missing from `__init__`, `repr` and `dataclasses.fields`).
    """
    block_size: int = 128        # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0
    SurvLayer: str = "cr"                                  # "Competing-Risk"
    # Filled in after the data module is built, with the encoded measurement tokens
    # (presumably a list of token ids returned by `dm.encode` -- TODO confirm type).
    tokens_for_univariate_regression: Optional[List[int]] = None


config = DemoConfig()


@dataclass
class OptConfig:
    """Optimisation settings for the training loop."""
    batch_size: int = 64
    eval_interval: int = 1       # evaluate on the validation set every `eval_interval` epochs
    learning_rate: float = 3e-4
    epochs: int = 30


opt = OptConfig()

## Create data loader

In [4]:
# Location of the dataset (the archived Version2 location is kept for reference)
# path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/archive/Version2/"
path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/"

# Build the data module from the configuration objects defined above
dm = FoundationalDataModule(
    path_to_db=path_to_db,
    load=True,
    tokenizer="tabular",
    batch_size=opt.batch_size,
    max_seq_length=config.block_size,
    unk_freq_threshold=config.unk_freq_threshold,
    min_workers=20,
)

vocab_size = dm.train_set.tokenizer.vocab_size
print(f"{vocab_size} vocab elements")

# Diagnoses are written entirely in upper case, so any event name that changes
# when upper-cased is taken to be a measurement. The encoded measurement tokens
# are stored on the config as the univariate values to model with a Normal distribution.
measurements_for_univariate_regression = []
for record in dm.tokenizer._event_counts["EVENT"]:
    if record != record.upper():
        measurements_for_univariate_regression.append(record)
config.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression)

INFO:root:Loading Polars dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/polars/
INFO:root:Using tokenizer tabular
INFO:root:Tokenzier created based on 3584.43M tokens
INFO:root:Creating split=train/ dataset
INFO:root:	 Loading split=train/ hash map for parquet
INFO:root:	 Hash map created for split=train/ with 22,912,046 samples
INFO:root:Creating split=test/ dataset
INFO:root:	 Loading split=test/ hash map for parquet
INFO:root:	 Hash map created for split=test/ with 1,207,449 samples
INFO:root:Creating split=val/ dataset
INFO:root:	 Loading split=val/ hash map for parquet
INFO:root:	 Hash map created for split=val/ with 1,226,576 samples


184 vocab elements


In [5]:
# Print training sample 6546 (capped via `max_dynamic_events=12`) and report how long the lookup took
dm.train_set.view_sample(6546, max_dynamic_events=12, report_time=True)

Time to retrieve sample index 6546 was 0.06381487846374512 seconds

SEX                 | M
IMD                 | 4.0
ETHNICITY           | ASIAN
birth_year          | 1988.0

Token                                                                      | Age               | Standardised value
Body_mass_index_3                                                          | 8805              | -0.14             
O_E___height_1                                                             | 8805              | 0.25              
O_E___weight_2                                                             | 8805              | 0.02              
Diastolic_blood_pressure_5                                                 | 11070             | -0.06             
Systolic_blood_pressure_4                                                  | 11070             | -0.22             
Diastolic_blood_pressure_5                                                 | 11081             | -0.13             
Systolic_blo

In [6]:
# display(dm.meta_information["static_table"])
# Xc = dm.train_set._encode_covariates(sex="M", deprivation=1.0, ethnicity="WHITE", year_of_birth=1992)
# print(Xc)

In [7]:
# import pandas as pd
# pd.set_option('display.max_rows', 1000) #replace n with the number of columns you want to see completely
# display(dm.train_set.meta_measurement)
# display(dm.train_set.meta_information["diagnosis_table"])

## Create models and train

In [8]:
# Instantiate the model from the config and move it onto the selected device
model = SurvStreamGPTForCausalModelling(config, vocab_size).to(device)

# Per-epoch loss histories: total loss, survival (DeSurv) component and value component
loss_curves_train, loss_curves_train_surv, loss_curves_train_values = [], [], []
loss_curves_val, loss_curves_val_surv, loss_curves_val_values = [], [], []

INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using Competing-Risk DeSurv head.
INFO:root:Internally scaling time in survival head by 1825 days
INFO:root:In generation forwarding DeSurv on the grid between [0.0, 1825.0], with delta=1/300


In [9]:
print(f"Training model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
model = model.to(device)

# Number of batches used per epoch. These equal what the previous index-based
# `break` actually processed (batch indices 0..1000 and 0..100).
MAX_TRAIN_BATCHES = 1001
MAX_VAL_BATCHES = 101
# Stop after this many consecutive validation rounds without improvement.
EARLY_STOPPING_PATIENCE = 5


def _forward_batch(batch):
    """Move one batch onto `device` and run the model on it.

    Returns:
        (losses_desurv, loss_values, loss) unpacked from the model output.
    """
    _, (losses_desurv, loss_values), loss = model(tokens=batch['tokens'].to(device),
                                                  ages=batch['ages'].to(device),
                                                  values=batch['values'].to(device),
                                                  covariates=batch["static_covariates"].to(device),
                                                  attention_mask=batch['attention_mask'].to(device)
                                                  )
    return losses_desurv, loss_values, loss


# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

best_val, epochs_since_best = np.inf, 0
for epoch in range(opt.epochs):

    # ---- Training ----------------------------------------------------------
    model.train()
    train_loader = dm.train_dataloader()      # built once per epoch (also used for the progress total)
    epoch_loss, epoch_surv_loss, epoch_values_loss = 0, 0, 0
    n_train_batches = 0
    for batch in tqdm(train_loader, desc=f"Training epoch {epoch}", total=min(len(train_loader), MAX_TRAIN_BATCHES)):
        if n_train_batches >= MAX_TRAIN_BATCHES:
            break

        # evaluate the loss
        losses_desurv, loss_values, loss = _forward_batch(batch)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # record
        epoch_loss += loss.item()
        epoch_surv_loss += torch.sum(losses_desurv).item()
        epoch_values_loss += loss_values.item()
        n_train_batches += 1

    if n_train_batches == 0:
        raise RuntimeError("Training dataloader yielded no batches.")

    # Average over the batches actually processed. Dividing by the loop index was
    # off by one whenever the loader ran out before the cap was reached.
    epoch_loss /= n_train_batches
    epoch_surv_loss /= n_train_batches
    epoch_values_loss /= n_train_batches
    loss_curves_train.append(epoch_loss)
    loss_curves_train_surv.append(epoch_surv_loss)
    loss_curves_train_values.append(epoch_values_loss)

    # ---- Validation --------------------------------------------------------
    # Only every `eval_interval` epochs, plus the final epoch. Early stopping is
    # skipped on the other epochs so it never acts on a stale validation loss.
    if epoch % opt.eval_interval != 0 and epoch != opt.epochs - 1:
        continue

    model.eval()
    val_loader = dm.val_dataloader()
    val_loss, val_surv_loss, val_values_loss = 0, 0, 0
    n_val_batches = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Validation epoch {epoch}", total=min(len(val_loader), MAX_VAL_BATCHES)):
            if n_val_batches >= MAX_VAL_BATCHES:
                break
            losses_desurv, loss_values, loss = _forward_batch(batch)

            # record
            val_loss += loss.item()
            val_surv_loss += torch.sum(losses_desurv).item()
            val_values_loss += loss_values.item()
            n_val_batches += 1

    if n_val_batches == 0:
        raise RuntimeError("Validation dataloader yielded no batches.")

    val_loss /= n_val_batches
    val_surv_loss /= n_val_batches
    val_values_loss /= n_val_batches
    loss_curves_val.append(val_loss)
    loss_curves_val_surv.append(val_surv_loss)
    loss_curves_val_values.append(val_values_loss)

    print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}: ({epoch_surv_loss:.2f}, {epoch_values_loss:.2f}). Val loss {val_loss:.2f}: ({val_surv_loss:.2f}, {val_values_loss:.2f})")
    # TODO: Note not fully accurate as last batch is likely not the same size, will be fixed with lightning

    # ---- Early stopping / checkpointing ------------------------------------
    # Written as a strict "<" test so that a NaN validation loss counts as
    # "no improvement" rather than being stored as the best value.
    if val_loss < best_val:
        best_val = val_loss
        epochs_since_best = 0

        # Save best seen model
        torch.save(model.state_dict(), path_to_db + "polars/CR.pt")
    else:
        epochs_since_best += 1
        if epochs_since_best >= EARLY_STOPPING_PATIENCE:
            break
            


Training model with 11.308635 M parameters


Training epoch 0:   0%|          | 1001/358001 [06:23<37:56:35,  2.61it/s]
Validation epoch 0:   1%|          | 101/19166 [00:21<1:06:43,  4.76it/s]


Epoch 0:	Train loss -3.77: (1.11, -8.65). Val loss -6.84: (-0.35, -13.33)


Training epoch 1:   0%|          | 200/358001 [01:19<39:32:06,  2.51it/s]

KeyboardInterrupt



## Generation

In [120]:
# Starting context for generation: static covariates plus a two-event prompt
baseline_covariates = {"sex": "F", "deprivation": 4.0, "ethnicity": "WHITE", "year_of_birth": 1997}
prompt = ["O_E___height_1", "O_E___weight_2"]
values = [163, 90]
ages_in_years = [18.2, 18.2]


# Encoding helpers (TODO: move these wrappers into the datamodule)
def encode_prompt(prompt_list):
    """Encode event names with `dm.encode` into a (1, len) tensor on `device`."""
    token_ids = np.array(dm.encode(prompt_list)).reshape((1, -1))
    return torch.from_numpy(token_ids).to(device)


def encode_value(prompt_list, value_list):
    """Standardise each value for its event with `dm.standardise`; returns a (1, len) float32 tensor on `device`."""
    standardised = [dm.standardise(_cat, _val) for _cat, _val in zip(prompt_list, value_list)]
    return torch.tensor(np.array(standardised).reshape((1, -1)), dtype=torch.float32).to(device)


def encode_age(age_list):
    """Convert ages in years to days (365 per year) as a (1, len) int64 tensor on `device`."""
    ages_days = [365 * _age for _age in age_list]
    return torch.tensor(ages_days, dtype=torch.int64).reshape((1, -1)).to(device)


# Convert the context into model inputs
covariates = dm.train_set._encode_covariates(**baseline_covariates).reshape(1, -1).to(device)
tokens = encode_prompt(prompt)
values_scaled = encode_value(prompt, values)
ages_in_days = encode_age(ages_in_years)

In [121]:
# Generate up to 10 new tokens continuing from the encoded prompt
new_tokens, new_ages, new_values = model.generate(tokens, ages_in_days, values_scaled, covariates, max_new_tokens=10)

# Report the prompt events first, then the generated continuation
separator = "=" * 90
print(f"Baseline covariates: \n{baseline_covariates}\n" + separator)
print("PROMPT:")
decoded_events = dm.decode(new_tokens[0].tolist()).split(" ")
last_prompt_idx = tokens.shape[-1] - 1
for _idx, (_cat, _age, _value) in enumerate(zip(decoded_events, new_ages[0, :], new_values[0, :])):
    # _value = dm.unstandardise(_cat, _value)
    report_line = f"{_cat}".ljust(50) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({_age:.1f} days)"
    print(report_line)    # with value {_value}
    if _idx == last_prompt_idx:
        # Everything after the last prompt event was produced by the model
        print(separator)
        print("GENERATION")

Baseline covariates: 
{'sex': 'F', 'deprivation': 4.0, 'ethnicity': 'WHITE', 'year_of_birth': 1997}
PROMPT:
O_E___height_1                                    -0.04          at age 18 (6643.0 days)
O_E___weight_2                                    0.16           at age 18 (6643.0 days)
GENERATION
Plasma_C_reactive_protein_60                      0.28           at age 18 (6649.1 days)
Platelet_count_12                                 0.15           at age 18 (6649.1 days)
Red_blood_cell__RBC__count_10                     -0.10          at age 18 (6704.0 days)
GFR_calculated_abbreviated_MDRD_34                0.06           at age 18 (6722.3 days)
Haematocrit_15                                    -0.15          at age 18 (6728.5 days)
Haemoglobin_estimation_9                          -0.25          at age 18 (6728.5 days)
Lymphocyte_count_20                               -0.09          at age 18 (6734.6 days)
Mean_corpusc_haemoglobin_MCH__13                  -0.27          at age 19 (6777

## Comparing generation to real data

In [122]:
# Print training sample 1 as a real-data comparison for the generated sequence, reporting the lookup time
dm.train_set.view_sample(1, report_time=True)

Time to retrieve sample index 1 was 0.10558390617370605 seconds

SEX                 | F
IMD                 | 4.0
ETHNICITY           | WHITE
birth_year          | 1997.0

Token                                                                      | Age               | Standardised value
Haemoglobin_estimation_9                                                   | 7416              | -0.08             
HbA1c_level__DCCT_aligned__7                                               | 7416              | -0.27             
Lymphocyte_count_20                                                        | 7416              | -0.01             
Mean_corpusc_haemoglobin_MCH__13                                           | 7416              | -0.06             
Mean_corpuscular_volume__MCV__11                                           | 7416              | -0.12             
Monocyte_count_23                                                          | 7416              | 0.14              
Neutrophil_coun

In [123]:
def _plot_loss_curves(train_curve, val_curve, save_path, ylabel="loss"):
    """Plot a training and a validation loss history against epoch and save the figure.

    Args:
        train_curve: one value per completed training epoch.
        val_curve: one value per validation round (every `opt.eval_interval` epochs,
            plus the final epoch).
        save_path: image path passed to `plt.savefig`; its parent directory is created if missing.
        ylabel: label for the y axis.
    """
    n_train, n_val = len(train_curve), len(val_curve)

    # Training losses are recorded every epoch, so their x positions are 0, 1, 2, ...
    # (np.linspace(0, n, n) placed them at 0, n/(n-1), ..., n instead).
    train_epochs = np.arange(n_train)
    # Validation losses are recorded every `eval_interval` epochs. The extra round on
    # the final epoch is clipped back onto the last trained epoch.
    val_epochs = np.arange(n_val) * opt.eval_interval
    if n_train > 0:
        val_epochs = np.minimum(val_epochs, n_train - 1)

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    plt.figure()
    plt.plot(train_epochs, train_curve, label="train")
    plt.plot(val_epochs, val_curve, label="val", linestyle='dashed')
    plt.xlabel("epoch")
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(save_path)


# Total loss
_plot_loss_curves(loss_curves_train, loss_curves_val, "figs/competing_risk/loss.png")

# DeSurv (survival) loss
_plot_loss_curves(loss_curves_train_surv, loss_curves_val_surv, "figs/competing_risk/loss_desurv.png", ylabel="DeSurv loss")

# Value loss
_plot_loss_curves(loss_curves_train_values, loss_curves_val_values, "figs/competing_risk/loss_val.png", ylabel="value loss")

# Prompt testing

## Diagnoses: How related conditions are impacted by each other


In [124]:
# Single-event prompts, each placed at age 20 (years) with no associated value
exp_prompts = [["DEPRESSION"], ["TYPE1DM"], ["TYPE2DIABETES"], ["TYPE2DIABETES"], ["Never_smoked_tobacco_85"]]
exp_ages = [[20] for _ in exp_prompts]
exp_values = [[np.nan] for _ in exp_prompts]

model.eval()
with torch.no_grad():

    # One forward pass per prompt, keeping the survival output of each
    _exp_survs = []
    for _exp_prompt, _exp_age, _exp_value in zip(exp_prompts, exp_ages, exp_values):
        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        (surv, val_dist), _, _ = model(_tokens,
                                       values=_values_scaled,
                                       ages=_ages_in_days,
                                       covariates=covariates,
                                       is_generation=True)
        _exp_survs.append(surv)

    # One figure per entry of the survival output, overlaying every prompt
    for si, _ in enumerate(surv):
        plt.close()
        event_name = dm.decode([si + 1])    # entry `si` is named by decoding token `si + 1`
        for _prompt, _prompt_surv in zip(exp_prompts, _exp_survs):
            plt.plot(model.surv_layer.t_eval / 365, _prompt_surv[si][0, :], label='->'.join(_prompt).lower())
        plt.legend()
        plt.savefig(f"figs/competing_risk/diabetes/{event_name}.png")


## Values: How increasing BMI affects diagnosis risk

In [125]:
# Only these events get a figure
events_of_interest = [
    "Body_mass_index_3", "Diastolic_blood_pressure_5",
    "TYPE1DM", "TYPE2DIABETES",
    "HYPERTENSION", "OSTEOARTHRITIS",
    "CKDSTAGE3TO5",
    "HF_V3", "ISCHAEMICSTROKE_V2",
]

# A single BMI measurement at age 40 (years), swept over a range of raw values
_exp_prompt = ["Body_mass_index_3"]
_exp_age = [40]
_exp_values = [[12.], [15.], [18.], [21.], [24.], [30.], [40.]]

model.eval()
with torch.no_grad():

    # One forward pass per BMI value, keeping the survival output of each
    _exp_survs = []
    for _exp_value in _exp_values:
        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        (surv, val_dist), _, _ = model(_tokens,
                                       values=_values_scaled,
                                       ages=_ages_in_days,
                                       covariates=covariates,
                                       is_generation=True)
        _exp_survs.append(surv)

    # One figure per event of interest, overlaying every BMI value
    for si, _ in enumerate(surv):
        plt.close()
        event_name = dm.decode([si + 1])    # entry `si` is named by decoding token `si + 1`
        if event_name not in events_of_interest:
            continue
        for _value, _value_surv in zip(_exp_values, _exp_survs):
            plt.plot(model.surv_layer.t_eval / 365, _value_surv[si][0, :], label=f"{_value[0]:.2f}")
        plt.xlabel("t (years)")
        plt.ylabel("P(T>t) ()")
        plt.legend()
        plt.savefig(f"figs/competing_risk/bmi/{event_name}.png")


## Values: How increasing DBP affects diagnosis risk

In [131]:
# Only these events get a figure
events_of_interest = [
    "Body_mass_index_3", "Diastolic_blood_pressure_5",
    "TYPE1DM", "TYPE2DIABETES",
    "HYPERTENSION", "OSTEOARTHRITIS",
    "CKDSTAGE3TO5",
    "HF_V3", "ISCHAEMICSTROKE_V2",
]

# A single diastolic blood pressure measurement at age 40 (years), swept over raw values
_exp_prompt = ["Diastolic_blood_pressure_5"]
_exp_age = [40]
_exp_values = [[60.], [70.], [80.], [90.], [100.], [110.]]

model.eval()
with torch.no_grad():

    # One forward pass per DBP value, keeping the survival output of each
    _exp_survs = []
    for _exp_value in _exp_values:
        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        (surv, val_dist), _, _ = model(_tokens,
                                       values=_values_scaled,
                                       ages=_ages_in_days,
                                       covariates=covariates,
                                       is_generation=True)
        _exp_survs.append(surv)

    # One figure per event of interest, overlaying every DBP value
    for si, _ in enumerate(surv):
        plt.close()
        event_name = dm.decode([si + 1])    # entry `si` is named by decoding token `si + 1`
        if event_name not in events_of_interest:
            continue
        for _value, _value_surv in zip(_exp_values, _exp_survs):
            plt.plot(model.surv_layer.t_eval / 365, _value_surv[si][0, :], label=f"{_value[0]:.2f}")
        plt.xlabel("t (years)")
        plt.ylabel("P(T>t) ()")
        plt.legend()
        plt.savefig(f"figs/competing_risk/diastolic_blood_pressure/{event_name}.png")


## Values: How varying diagnosis affects value of DBP

In [127]:
# Measurement whose predicted value distribution is reported
measurements_of_interest = "Diastolic_blood_pressure_5"

# Single-diagnosis prompts, each at age 20 (years) with no associated value
_exp_prompts = [["DEPRESSION"], ["TYPE2DIABETES"], ["HF_V3"], ["HYPERTENSION"]]
_exp_age = [20]
_exp_value = [np.nan]

model.eval()
with torch.no_grad():
    for _exp_prompt in _exp_prompts:
        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        (surv, val_dist), _, _ = model(_tokens,
                                       values=_values_scaled,
                                       ages=_ages_in_days,
                                       covariates=covariates,
                                       is_generation=True)

        # Look up the value distribution predicted for the measurement of interest
        measurement_token = dm.tokenizer._stoi[measurements_of_interest]
        dist = val_dist[model.value_layer.token_key(measurement_token)]
        summary = f"standardised {measurements_of_interest} ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})"
        print(f"{'->'.join(_exp_prompt)}".ljust(30) + "leads to".ljust(20) + summary)



DEPRESSION                    leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.1, 0.1)
TYPE2DIABETES                 leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.1, 0.1)
HF_V3                         leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.1, 0.1)
HYPERTENSION                  leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.1, 0.1)


## Values: How increasing bmi affects value of diastolic_blood_pressure

In [128]:
# Measurement whose predicted value distribution is reported
measurements_of_interest = "Diastolic_blood_pressure_5"

# A single BMI measurement, swept over a range of raw values.
_exp_prompt = ["Body_mass_index_3"]
# The age is now set here instead of relying on `_exp_age` left behind by another
# cell. 20 (years) is the value the preceding cell assigns, so results are unchanged.
_exp_age = [20]
_exp_values = [[12.], [15.], [18.], [21.], [24.], [30.], [40.]]

model.eval()
with torch.no_grad():
    for _exp_value in _exp_values:
        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        (surv, val_dist), _, _ = model(_tokens,
                                       values=_values_scaled,
                                       ages=_ages_in_days,
                                       covariates=covariates,
                                       is_generation=True)

        # Look up the value distribution predicted for the measurement of interest
        dist = val_dist[model.value_layer.token_key(dm.tokenizer._stoi[measurements_of_interest])]
        print(f"{'->'.join(_exp_prompt)} of {_exp_value[0]}".ljust(30) + "leads to".ljust(20) + f"standardised {measurements_of_interest} ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")
        

Body_mass_index_3 of 12.0     leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.3, 0.1)
Body_mass_index_3 of 15.0     leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.2, 0.1)
Body_mass_index_3 of 18.0     leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.2, 0.1)
Body_mass_index_3 of 21.0     leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.2, 0.1)
Body_mass_index_3 of 24.0     leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.1, 0.1)
Body_mass_index_3 of 30.0     leads to            standardised Diastolic_blood_pressure_5 ~ N(-0.0, 0.1)
Body_mass_index_3 of 40.0     leads to            standardised Diastolic_blood_pressure_5 ~ N(0.1, 0.1)


## Baseline, impact of gender

In [129]:

# Only these events get a figure
events_of_interest = [
    "Body_mass_index_3", "Diastolic_blood_pressure_5",
    "TYPE1DM", "TYPE2DIABETES",
    "HYPERTENSION", "OSTEOARTHRITIS",
    "CKDSTAGE3TO5",
    "HF_V3", "ISCHAEMICSTROKE_V2",
]

# Same prompt for every run; only the `sex` static covariate changes
_genders = ["M", "F", "I"]
_exp_prompt = ["Diastolic_blood_pressure_5"]
_exp_age = [20]
_exp_value = [90.]

model.eval()
with torch.no_grad():

    # One forward pass per `sex` value, keeping the survival output of each
    _exp_survs = []
    for _gender in _genders:
        _baseline_covariate = {"sex": _gender, "deprivation": 4.0, "ethnicity": "WHITE", "year_of_birth": 1997}
        _covariates = dm.train_set._encode_covariates(**_baseline_covariate).reshape(1, -1).to(device)
        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        (surv, val_dist), _, _ = model(_tokens,
                                       values=_values_scaled,
                                       ages=_ages_in_days,
                                       covariates=_covariates,
                                       is_generation=True)
        _exp_survs.append(surv)

    # One figure per event of interest, overlaying every `sex` value
    for si, _ in enumerate(surv):
        plt.close()
        event_name = dm.decode([si + 1])    # entry `si` is named by decoding token `si + 1`
        if event_name not in events_of_interest:
            continue
        for _label, _gender_surv in zip(_genders, _exp_survs):
            plt.plot(model.surv_layer.t_eval / 365, _gender_surv[si][0, :], label=f"{_label}")
        plt.xlabel("t (years)")
        plt.ylabel("P(T>t) ()")
        plt.legend()
        plt.savefig(f"figs/competing_risk/gender/{event_name}.png")


# Appendix: model architectures

In [130]:
# Show the module hierarchy of the model
display(model)

SurvStreamGPTForCausalModelling(
  (transformer): TTETransformer(
    (wpe): TemporalPositionalEncoding()
    (wte): DataEmbeddingLayer(
      (static_proj): Linear(in_features=16, out_features=384, bias=True)
      (dynamic_embedding_layer): SplitDynamicEmbeddingLayer(
        (cat_event_embed_layer): Embedding(184, 384, padding_idx=0)
        (cat_event_proj): Linear(in_features=384, out_features=384, bias=True)
        (num_value_embed_layer): EmbeddingBag(184, 384, mode=sum, padding_idx=0)
        (num_value_proj): Linear(in_features=384, out_features=384, bias=True)
      )
    )
    (drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadedSelfAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=384, out_features=384, bias=False)
          (v_proj): Lin

In [118]:
# Export this notebook to HTML; `--no-input` leaves the code cells out of the export
!jupyter nbconvert --to html --no-input competing_risk.ipynb

[NbConvertApp] Converting notebook competing_risk.ipynb to html
[NbConvertApp] Writing 611212 bytes to competing_risk.html
