# Demo Notebook:
## Survival Transformer For Causal Sequence Modelling 

Including time, and excluding tabular values

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

!pwd

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-env-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/SurvStreamGPT


In [2]:
import pytorch_lightning
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling
from pycox.evaluation import EvalSurv
from tqdm import tqdm 

# TODO:
# replace experiment boilerplate with pytorch lightning

# Seed every random number generator imported here (Python, NumPy, PyTorch) from one constant,
# so that sampling/shuffling is repeatable between runs.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cuda.


## Build configurations

In [3]:
from typing import List, Optional

# Architecture chosen to match Karpathy's benchmark configuration; note the tasks themselves are not comparable.
@dataclass
class DemoConfig:
    """Model and data configuration, used to build both the data module and the model."""
    block_size: int = 128        # what is the maximum context length for predictions?
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    bias: bool = True
    attention_type: str = "global"
    dropout: float = 0.0
    unk_freq_threshold: float = 0.0
    # These two are annotated so they are genuine dataclass fields (settable through the
    # constructor and included in repr/eq), not plain class attributes shared by all instances.
    SurvLayer: str = "Single-Risk"                                  # or "Competing-Risk"
    # Ids of measurement tokens whose values are modelled with a univariate Normal distribution
    tokens_for_univariate_regression: Optional[List[int]] = None

config = DemoConfig()

@dataclass
class OptConfig:
    """Optimisation / training-loop settings."""
    batch_size: int = 64
    eval_interval: int = 1        # run validation every `eval_interval` epochs
    learning_rate: float = 3e-4
    epochs: int = 30

opt = OptConfig()

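## Create the data loader (the commented-out cell below builds a reduced cohort from the SQLite database; the active cell after it loads the full pre-processed Polars dataset)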

In [15]:
# Legacy, not executed: builds a reduced cohort (up to 20,000 patients with depression or
# diabetes diagnoses) by querying the SQLite database, then creates the data module from those
# patient identifiers. Kept for reference; the active cell below instead loads the full
# pre-processed Polars dataset.
# from CPRD.data.database import queries

# # Get a list of patients which fit a reduced set of criterion
# PATH_TO_DB = "/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModel/preprocessing/processed/cprd.db"
# conn = sqlite3.connect(PATH_TO_DB)
# cursor = conn.cursor()
# # identifiers1 = queries.query_measurement(["bmi", "diastolic_blood_pressure"], cursor)        
# identifiers2 = queries.query_diagnosis(["DEPRESSION", "TYPE1DM", "TYPE2DIABETES"], cursor)    #  "DEPRESSION"  ,  "ANXIETY"
# # all_identifiers = list(set(identifiers1).intersection(identifiers2))    # Turn smaller list into the set
# all_identifiers = identifiers2

# if True:
#     # Lets take only the first N for faster run-time
#     N = np.min((len(all_identifiers), 20000))
#     print(f"Using N={N} random samples, from the available {len(all_identifiers)}")
#     identifiers = random.choices(all_identifiers, k=N)
# else:
#     print(f"Using all available {len(all_identifiers)} samples")
#     identifiers = all_identifiers

# # Build 
# dm = FoundationalDataModule(identifiers=identifiers,
#                             tokenizer="tabular",
#                             batch_size=opt.batch_size,
#                             max_seq_length=config.block_size,
#                             unk_freq_threshold=config.unk_freq_threshold,
#                             include_measurements=True,
#                             include_diagnoses=True,
#                             preprocess_measurements=True
#                            )


# vocab_size = dm.train_set.tokenizer.vocab_size

# print(f"{len(dm.train_set)} training samples")
# print(f"{len(dm.val_set)} validation samples")
# print(f"{len(dm.test_set)} test samples")
# print(f"{vocab_size} vocab elements")
# # print(dm.train_set.tokenizer._itos)

In [13]:
# Root of the pre-processed dataset; with load=True the data module loads the Polars dataset
# stored under it (the full cohort, not a reduced sub-set of patients).
# An older snapshot is available under ".../OPTIMAL_MASTER_DATASET/archive/Version2/".
path_to_db = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/"

# Tokenizer / batching options, tied to the model config so sequence lengths agree
dm_kwargs = dict(
    load=True,
    tokenizer="tabular",
    batch_size=opt.batch_size,
    max_seq_length=config.block_size,
    unk_freq_threshold=config.unk_freq_threshold,
    min_workers=20,
)
dm = FoundationalDataModule(path_to_db=path_to_db, **dm_kwargs)

vocab_size = dm.train_set.tokenizer.vocab_size
print(f"{vocab_size} vocab elements")

INFO:root:Loading Polars dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/polars/
INFO:root:Using tokenizer tabular
INFO:root:Tokenzier created based on 3584.43M tokens
INFO:root:Creating split=train/ dataset
INFO:root:	 Loading split=train/ hash map for parquet
INFO:root:	 Hash map created for split=train/ with 22,842,428 samples
INFO:root:Creating split=test/ dataset
INFO:root:	 Loading split=test/ hash map for parquet
INFO:root:	 Hash map created for split=test/ with 1,417,644 samples
INFO:root:Creating split=val/ dataset
INFO:root:	 Loading split=val/ hash map for parquet
INFO:root:	 Hash map created for split=val/ with 1,169,583 samples


184 vocab elements


In [5]:
import time

# Time how long the first training batch takes to arrive (this includes creating the
# DataLoader iterator). perf_counter is monotonic, so it is the right clock for durations.
start = time.perf_counter()
batch = next(iter(dm.train_dataloader()))
print(f"batch loaded in {time.perf_counter() - start} seconds")

# Shape of every tensor in the batch
for key in batch.keys():
    print(f"{key}".ljust(20) + f"{batch[key].shape}")


batch loaded in 8.500611543655396 seconds
tokens              torch.Size([64, 128])
ages                torch.Size([64, 128])
values              torch.Size([64, 128])
attention_mask      torch.Size([64, 128])


## View the frequency of tokens in the extracted data

In [9]:
import polars as pl

# Large enough for the table to list every event in the vocabulary with its count and frequency
MAX_TABLE_ROWS = 1000
pl.Config.set_tbl_rows(MAX_TABLE_ROWS)
display(dm.tokenizer._event_counts)

EVENT,COUNT,FREQUENCY
str,u32,f64
"""UNK""",0,0.0
"""ADDISONS_DISEA…",6691,2e-06
"""CYSTICFIBROSIS…",7053,2e-06
"""SYSTEMIC_SCLER…",8772,2e-06
"""SICKLE_CELL_DI…",11159,3e-06
"""ADDISON_DISEAS…",11794,3e-06
"""DOWNSSYNDROME""",17006,5e-06
"""HAEMOCHROMATOS…",18631,5e-06
"""PLASMACELL_NEO…",20301,6e-06
"""SJOGRENSSYNDRO…",23326,7e-06


In [7]:
import pandas as pd

# Raise the pandas row limit so the full per-measurement summary table is displayed
pd.options.display.max_rows = 1000
display(dm.train_set.meta_measurement)

Unnamed: 0,event,count,count_obs,digest,min,max,mean,approx_lqr,approx_uqr
0,25_Hydroxyvitamin_D2_level_92,782791,693470,"({'m': 0.0, 'c': 9.0}, {'m': 0.1, 'c': 112.0},...",0.0,686.0,3.908721,-4.699694,10.870832
1,25_Hydroxyvitamin_D3_level_90,809104,781118,"({'m': 0.1, 'c': 3.0}, {'m': 1.0, 'c': 314.0},...",0.0,951.8,47.14889,-36.308194,121.286799
2,AST___aspartate_transam_SGOT__46,1738489,1680613,"({'m': 0.0, 'c': 3901.0}, {'m': 0.770571428571...",0.0,15330.0,26.61963,3.417134,41.771075
3,AST_serum_level_47,10837982,10485351,"({'m': 0.0, 'c': 53.0}, {'m': 1.8, 'c': 1.0}, ...",-5.0,20700.0,27.25168,4.558863,41.966985
4,Albumin___creatinine_ratio_37,180911,78420,"({'m': -1.0, 'c': 1.0}, {'m': 0.0, 'c': 4213.0...",-1.0,12821.0,10.67255,-4.329046,8.827713
5,Basophil_count_22,86869779,85642540,"({'m': 0.0, 'c': 37098.0}, {'m': 0.01, 'c': 28...",-0.1,111111.0,0.05008992,-0.093801,0.160919
6,Blood_calcium_level_38,415717,385464,"({'m': 0.0, 'c': 33.0}, {'m': 1.0, 'c': 1.0}, ...",0.0,440.0,2.35298,2.025402,2.62252
7,Blood_urea_28,785766,671861,"({'m': 0.0, 'c': 2746.0}, {'m': 0.09, 'c': 1.0...",0.0,1265.0,6.513018,0.270987,10.954279
8,Body_mass_index_3,99868822,97759312,"({'m': 0.0, 'c': 14.0}, {'m': 0.05, 'c': 1.0},...",-32680.0,2100000000.0,293.305,10.476686,43.320395
9,Brain_natriuretic_peptide_level_66,229202,159318,"({'m': 0.0, 'c': 120.0}, {'m': 0.1, 'c': 1.0},...",0.0,500142.0,416.8786,-245.175243,483.219601


In [11]:
# Measurement tokens are the events whose names contain lower-case characters (diagnosis names
# are entirely upper case). This list is used below to configure the model automatically.
event_names = dm.tokenizer._event_counts["EVENT"]
measurements_for_univariate_regression = [
    name for name in event_names if name != name.upper()
]

display(measurements_for_univariate_regression)

['Plasma_N_terminal_pro_B_type_natriuretic_peptide_conc_70',
 'N_terminal_pro_brain_natriuretic_peptide_level_67',
 'Plasma_B_natriuretic_peptide_level_69',
 'Plasma_pro_brain_natriuretic_peptide_level_64',
 'Albumin___creatinine_ratio_37',
 'Urine_microalbumin_creatinine_ratio_36',
 'Plasma_ferritin_level_62',
 'Brain_natriuretic_peptide_level_66',
 'Serum_pro_brain_natriuretic_peptide_level_65',
 'Serum_vitamin_D2_level_89',
 'Total_25_hydroxyvitamin_D_level_91',
 'Serum_N_terminal_pro_B_type_natriuretic_peptide_conc_68',
 'Blood_calcium_level_38',
 'INR___international_normalised_ratio_81',
 'Combined_total_vitamin_D2_and_D3_level_93',
 'TSH_level_74',
 'Serum_T4_level_78',
 'Plasma_cholesterol_HDL_ratio_96',
 'Plasma_free_T4_level_77',
 '25_Hydroxyvitamin_D2_level_92',
 'Blood_urea_28',
 '25_Hydroxyvitamin_D3_level_90',
 'Plasma_corrected_calcium_level_43',
 'Serum_25_Hydroxy_vitamin_D3_level_88',
 'Plasma_calcium_level_40',
 'Free_T4_level_76',
 'Plasma_LDL_cholesterol_level_104',

## Create models and train

In [14]:
models, m_names = [], []

# Survival heads to build (add "Competing-Risk" to also compare the competing-risk head)
for surv_layer in ("Single-Risk",):
    # A fresh configuration for each model variant
    config = DemoConfig()
    config.SurvLayer = surv_layer                 # which survival head layer to use
    # Measurement tokens whose values are modelled with a univariate Normal distribution
    config.tokens_for_univariate_regression = dm.encode(measurements_for_univariate_regression)

    model = SurvStreamGPTForCausalModelling(config, vocab_size).to(device)
    models.append(model)
    m_names.append(f"SurvStreamGPTForCausalModelling: {surv_layer}")

INFO:root:Using Temporal Positional Encoding. This module uses the patient's age at an event within their time series.
INFO:root:Using Single-Risk DeSurvival head. This module predicts a separate survival curve for each possible future event
INFO:root:Internally scaling time in survival head by 1825 days
INFO:root:In generation forwarding DeSurv on the grid between [0.0, 1825.0], with delta=1/300
INFO:root:ModuleDict(
  (Token 15): Linear(in_features=384, out_features=2, bias=True)
  (Token 17): Linear(in_features=384, out_features=2, bias=True)
  (Token 24): Linear(in_features=384, out_features=2, bias=True)
  (Token 26): Linear(in_features=384, out_features=2, bias=True)
  (Token 41): Linear(in_features=384, out_features=2, bias=True)
  (Token 46): Linear(in_features=384, out_features=2, bias=True)
  (Token 49): Linear(in_features=384, out_features=2, bias=True)
  (Token 50): Linear(in_features=384, out_features=2, bias=True)
  (Token 52): Linear(in_features=384, out_features=2, bias

In [15]:
# One list of per-epoch losses per model: the total loss and each of its components.
# (The `_clf` curves are allocated but not populated by the training loop.)
(loss_curves_train, loss_curves_train_clf,
 loss_curves_train_surv, loss_curves_train_values) = ([[] for _model in models] for _curve in range(4))

(loss_curves_val, loss_curves_val_clf,
 loss_curves_val_surv, loss_curves_val_values) = ([[] for _model in models] for _curve in range(4))

In [16]:
import itertools

# Each "epoch" only visits a small slice of the (very large) loaders so that an epoch takes
# seconds rather than days. Set a cap to None to iterate over the full loader.
MAX_TRAIN_BATCHES = 51
MAX_VAL_BATCHES = 21
# Stop after this many consecutive evaluations without improving on the best validation loss
EARLY_STOPPING_PATIENCE = 5


def _run_epoch(model, dataloader, desc, optimizer=None, max_batches=None):
    """Run `model` over (at most `max_batches` batches of) `dataloader` and average its losses.

    If `optimizer` is given, a gradient step is taken after every batch (training). Otherwise the
    losses are only evaluated, and the caller should wrap the call in `torch.no_grad()`.

    Returns:
        (mean total loss, mean summed DeSurv loss, mean value loss), each averaged over the number
        of batches actually processed.

    Raises:
        ValueError: if the dataloader yields no batches.

    TODO: this is a per-batch mean, so a smaller final batch is weighted like a full one
    (will be fixed when moving to lightning).
    """
    n_batches = len(dataloader) if max_batches is None else min(len(dataloader), max_batches)

    sum_loss, sum_surv_loss, sum_values_loss, n_seen = 0.0, 0.0, 0.0, 0
    # islice stops after `n_batches` without fetching (and discarding) an extra batch
    for batch in tqdm(itertools.islice(dataloader, n_batches), desc=desc, total=n_batches):
        _, (losses_desurv, loss_values), loss = model(batch['tokens'].to(device),
                                                      ages=batch['ages'].to(device),
                                                      values=batch['values'].to(device),
                                                      attention_mask=batch['attention_mask'].to(device)
                                                      )
        if optimizer is not None:
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

        sum_loss += loss.item()
        sum_surv_loss += torch.sum(losses_desurv).item()
        sum_values_loss += loss_values.item()
        n_seen += 1

    if n_seen == 0:
        raise ValueError(f"{desc}: the dataloader yielded no batches")
    return sum_loss / n_seen, sum_surv_loss / n_seen, sum_values_loss / n_seen


for m_idx, (model, m_name) in enumerate(zip(models, m_names)):

    print(f"Training model `{m_name}`, with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
    model = model.to(device)

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.learning_rate)

    best_val, epochs_since_best = np.inf, 0
    for epoch in range(opt.epochs):

        model.train()
        epoch_loss, epoch_surv_loss, epoch_values_loss = _run_epoch(
            model, dm.train_dataloader(), f"Training epoch {epoch}",
            optimizer=optimizer, max_batches=MAX_TRAIN_BATCHES,
        )
        loss_curves_train[m_idx].append(epoch_loss)
        loss_curves_train_surv[m_idx].append(epoch_surv_loss)
        loss_curves_train_values[m_idx].append(epoch_values_loss)

        # Only evaluate on evaluation epochs. Early stopping lives after this check so that it
        # never acts on a stale validation loss left over from an earlier epoch.
        if epoch % opt.eval_interval != 0 and epoch != opt.epochs - 1:
            continue

        model.eval()
        with torch.no_grad():
            val_loss, val_surv_loss, val_values_loss = _run_epoch(
                model, dm.val_dataloader(), f"Validation epoch {epoch}",
                max_batches=MAX_VAL_BATCHES,
            )
        loss_curves_val[m_idx].append(val_loss)
        loss_curves_val_surv[m_idx].append(val_surv_loss)
        loss_curves_val_values[m_idx].append(val_values_loss)

        print(f"Epoch {epoch}:\tTrain loss {epoch_loss:.2f}: ({epoch_surv_loss:.2f}, {epoch_values_loss:.2f}). Val loss {val_loss:.2f}: ({val_surv_loss:.2f}, {val_values_loss:.2f})")

        if val_loss < best_val:
            best_val, epochs_since_best = val_loss, 0
        else:
            epochs_since_best += 1
            if epochs_since_best >= EARLY_STOPPING_PATIENCE:
                print(f"Early stopping at epoch {epoch}: no improvement in the last {epochs_since_best} evaluations")
                break

    

Training model `SurvStreamGPTForCausalModelling: Single-Risk`, with 10.865304 M parameters


Training epoch 0:   0%|          | 51/356913 [00:34<67:24:19,  1.47it/s]
Validation epoch 0:   0%|          | 21/356913 [00:10<51:42:21,  1.92it/s]


Epoch 0:	Train loss 4.52: (3.85, 5.18). Val loss -0.78: (2.84, -4.41)


Training epoch 1:   0%|          | 51/356913 [00:32<62:53:29,  1.58it/s]
Validation epoch 1:   0%|          | 21/356913 [00:10<50:46:52,  1.95it/s]


Epoch 1:	Train loss -0.76: (3.63, -5.15). Val loss -2.49: (2.67, -7.65)


Training epoch 2:   0%|          | 51/356913 [00:32<62:38:09,  1.58it/s]
Validation epoch 2:   0%|          | 21/356913 [00:10<50:44:36,  1.95it/s]


Epoch 2:	Train loss -1.74: (3.45, -6.93). Val loss -3.20: (2.59, -8.98)


Training epoch 3:   0%|          | 51/356913 [00:33<64:08:31,  1.55it/s]
Validation epoch 3:   0%|          | 21/356913 [00:10<50:41:32,  1.96it/s]


Epoch 3:	Train loss -2.16: (3.31, -7.64). Val loss -3.49: (2.52, -9.49)


Training epoch 4:   0%|          | 51/356913 [00:32<62:16:56,  1.59it/s]
Validation epoch 4:   0%|          | 21/356913 [00:10<50:45:03,  1.95it/s]


Epoch 4:	Train loss -2.36: (3.42, -8.13). Val loss -3.93: (2.47, -10.32)


Training epoch 5:   0%|          | 51/356913 [00:32<62:18:25,  1.59it/s]
Validation epoch 5:   0%|          | 21/356913 [00:10<50:57:11,  1.95it/s]


Epoch 5:	Train loss -2.61: (3.21, -8.43). Val loss -4.25: (2.43, -10.93)


Training epoch 6:   0%|          | 51/356913 [00:32<62:12:29,  1.59it/s]
Validation epoch 6:   0%|          | 21/356913 [00:10<50:55:35,  1.95it/s]


Epoch 6:	Train loss -2.84: (3.08, -8.75). Val loss -4.07: (2.38, -10.53)


Training epoch 7:   0%|          | 51/356913 [00:32<63:23:57,  1.56it/s]
Validation epoch 7:   0%|          | 21/356913 [00:10<50:46:45,  1.95it/s]


Epoch 7:	Train loss -2.98: (3.14, -9.11). Val loss -4.40: (2.32, -11.12)


Training epoch 8:   0%|          | 51/356913 [00:32<63:15:05,  1.57it/s]
Validation epoch 8:   0%|          | 21/356913 [00:10<51:00:09,  1.94it/s]


Epoch 8:	Train loss -2.09: (3.10, -7.27). Val loss -3.22: (2.32, -8.75)


Training epoch 9:   0%|          | 51/356913 [00:31<61:13:25,  1.62it/s]
Validation epoch 9:   0%|          | 21/356913 [00:10<50:54:46,  1.95it/s]


Epoch 9:	Train loss -1.90: (2.97, -6.78). Val loss -4.15: (2.28, -10.57)



KeyboardInterrupt



In [17]:
# Best-effort conversion of values back to the original measurement scale. No `unstandardise`
# helper is defined in this notebook, so by default values are printed in the model's
# standardised units; if such a helper has been defined, it is used instead.
_unstandardise = globals().get("unstandardise")

for m_idx, (model, m_name) in enumerate(zip(models, m_names)):

    print(f"Generating from model `{m_name}`, with {sum(p.numel() for p in model.parameters())/1e6} M parameters")
    model = model.to(device)
    # Training may have been interrupted mid-epoch (leaving train mode), so switch explicitly
    model.eval()

    # Test trained model with a prompt
    # ----------------
    # set context: diagnosis of depression at 20 years old (ages are given in days)
    tokens = torch.from_numpy(np.array(dm.encode(["DEPRESSION"])).reshape((1,-1))).to(device)
    ages = torch.tensor([[20*365]], device=device)
    values = torch.tensor([[torch.nan]], device=device)

    # generate: sample the next 10 tokens (inference only, so no autograd graph is needed)
    with torch.no_grad():
        new_tokens, new_ages, new_values = model.generate(tokens, ages, values, max_new_tokens=10)
    generated = dm.decode(new_tokens[0].tolist()).split(" ")

    # Pad event names to the longest one so long names never run into the value column
    name_width = max(len(name) for name in generated) + 2
    # report:
    for _cat, _age, _value in zip(generated, new_ages[0, :], new_values[0, :]):
        if _unstandardise is not None:
            try:
                _value = _unstandardise(_cat, _value)
            except Exception:
                pass  # presumably no value scale for this token (e.g. a diagnosis); keep raw value
        print(f"{_cat}".ljust(name_width) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({_age:.1f} days)")

Generating from model `SurvStreamGPTForCausalModelling: Single-Risk`, with 10.865304 M parameters
DEPRESSION                                        nan            at age 20 (7300.0 days)
Serum_ferritin_63                                 0.47           at age 23 (8496.3 days)
HYPERTHYROIDISM_V2                                nan            at age 28 (10321.3 days)
Haematocrit___PCV_16                              0.39           at age 29 (10577.7 days)
Plasma_N_terminal_pro_B_type_natriuretic_peptide_conc_70-0.01          at age 34 (12402.7 days)
Serum_vitamin_D2_level_89                         0.54           at age 38 (13959.1 days)
Total_white_cell_count_18                         0.66           at age 40 (14484.0 days)
Serum_folate_80                                   0.36           at age 43 (15546.1 days)
Plasma_HDL_cholesterol_level_101                  0.83           at age 47 (17249.0 days)
Serum_potassium_26                                0.43           at age 52 (19074.0 days

## Comparing output to real data

In [18]:
# Print the first few events of the first patient in a training batch, in the same format as the
# generated sequence above, for a qualitative comparison.
N_EVENTS_TO_SHOW = 10
# No `unstandardise` helper is defined in this notebook, so values are shown in standardised
# units unless such a helper has been defined.
_unstandardise = globals().get("unstandardise")

batch = next(iter(dm.train_dataloader()))

conditions = batch["tokens"].numpy().tolist()
event_rows = []
for idx, (token, _age, _value) in enumerate(zip(conditions[0], batch["ages"][0, :], batch["values"][0, :])):
    # token id 0 is treated as the end of this patient's sequence (presumably padding)
    if token == 0 or idx >= N_EVENTS_TO_SHOW:
        break
    _cat = dm.decode([token])
    if _unstandardise is not None:
        try:
            _value = _unstandardise(_cat, _value)
        except Exception:
            pass  # presumably no value scale for this token (e.g. a diagnosis); keep raw value
    event_rows.append((_cat, _value, _age))

# Pad event names to the longest one so long names never run into the value column
name_width = max((len(row[0]) for row in event_rows), default=0) + 2
for _cat, _value, _age in event_rows:
    print(f"{_cat}".ljust(name_width) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({_age:.1f} days)")

ASTHMA_PUSHASTHMA                                 nan            at age 14 (5081.0 days)
Diastolic_blood_pressure_5                        0.54           at age 30 (10993.0 days)
Systolic_blood_pressure_4                         0.37           at age 30 (10993.0 days)
O_E___height_1                                    0.37           at age 30 (10995.0 days)
Body_mass_index_3                                 0.33           at age 31 (11210.0 days)
O_E___height_1                                    0.37           at age 31 (11210.0 days)
O_E___weight_2                                    0.32           at age 31 (11210.0 days)
Diastolic_blood_pressure_5                        0.20           at age 31 (11275.0 days)
Systolic_blood_pressure_4                         0.37           at age 31 (11275.0 days)
Body_mass_index_3                                 0.28           at age 32 (11694.0 days)


In [29]:
from pathlib import Path

# Training/validation curves for the total loss and for each of its two components
FIG_DIR = Path("figs/single_risk")
FIG_DIR.mkdir(parents=True, exist_ok=True)
cols = ["k", "r", "b", "y"]


def _plot_loss_curves(train_curves, val_curves, ylabel, filename):
    """Plot each model's train (dashed) and validation (solid) curve and save to FIG_DIR/filename.

    Training losses are recorded every epoch, so they are plotted at epochs 0, 1, 2, ...
    Validation losses are recorded every `opt.eval_interval` epochs, so their positions are
    scaled by the interval. An extra final-epoch evaluation that is not on a multiple of the
    interval is therefore only approximately placed.
    """
    fig, ax = plt.subplots()
    for m_idx, _ in enumerate(models):
        colour = cols[m_idx % len(cols)]
        ax.plot(np.arange(len(train_curves[m_idx])), train_curves[m_idx],
                label=f"{m_names[m_idx]}-train", c=colour, linestyle='dashed')
        ax.plot(np.arange(len(val_curves[m_idx])) * opt.eval_interval, val_curves[m_idx],
                label=f"{m_names[m_idx]}-val", c=colour)
    ax.set_xlabel("epoch")
    ax.set_ylabel(ylabel)
    ax.legend()
    fig.savefig(FIG_DIR / filename)
    return fig


_plot_loss_curves(loss_curves_train, loss_curves_val, "total loss", "loss.png")
_plot_loss_curves(loss_curves_train_surv, loss_curves_val_surv, "DeSurv loss", "loss_desurv.png")
_plot_loss_curves(loss_curves_train_values, loss_curves_val_values, "value loss", "loss_val.png");

# Prompt testing

## Diabetes: How related conditions are impacted by each other
Survival curves for every event after a depression diagnosis at age 20, compared with the same prompt followed at age 21 by a type I or a type II diabetes diagnosis.

In [20]:
from pathlib import Path

t1_token = dm.tokenizer._stoi["TYPE1DM"]
t2_token = dm.tokenizer._stoi["TYPE2DIABETES"]

FIG_DIR_DIABETES = Path("figs/single_risk/diabetes")
FIG_DIR_DIABETES.mkdir(parents=True, exist_ok=True)

base_prompt = ["DEPRESSION"]
ages_in_years = [20]
base_values = [torch.tensor([torch.nan])]


def to_days(a_list):
    """Convert a list of ages in years into a (1, len(a_list)) float tensor of days on `device`."""
    return torch.FloatTensor([365 * _a for _a in a_list]).reshape((1, -1)).to(device)


# (description, prompt tokens, ages in years, values): a depression-only control prompt, then the
# same prompt followed at age 21 by a type 1 or a type 2 diabetes diagnosis.
prompt_specs = [
    ("Depression", base_prompt, ages_in_years, base_values),
    ("Depression -> Type 1", base_prompt + ["TYPE1DM"], ages_in_years + [21],
     base_values + [torch.tensor([torch.nan])]),
    ("Depression -> Type 2", base_prompt + ["TYPE2DIABETES"], ages_in_years + [21],
     base_values + [torch.tensor([torch.nan])]),
]
desc, prompts, ages, values = (list(column) for column in zip(*prompt_specs))

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad():
        model.eval()

        prompt_survs = []
        for p_idx, (prompt, age, value) in enumerate(zip(prompts, ages, values)):
            print(f"\n{desc[p_idx]}: \t ({','.join(prompt)}): ")
            encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)
            (surv, val_dist), _, _ = model(encoded_prompt,
                                           values=torch.tensor(value).to(device),
                                           ages=to_days(age),
                                           is_generation=True)
            prompt_survs.append(surv)

        # One figure per event, overlaying its survival curve under each prompt
        for si, _ in enumerate(surv):
            # assumes survival output `si` corresponds to token id si + 1 — TODO confirm
            event_name = dm.decode([si + 1])
            fig, ax = plt.subplots()
            for p_idx in range(len(prompts)):
                ax.plot(model.surv_layer.t_eval / 365, prompt_survs[p_idx][si][0, :], label=f"{desc[p_idx]}")
            ax.set_xlabel("t (years)")
            ax.set_ylabel("P(T>t)")
            ax.set_title(event_name)
            ax.legend()
            fig.savefig(FIG_DIR_DIABETES / f"{event_name}.png")
            plt.close(fig)  # one figure per event; close to avoid accumulating open figures




SurvStreamGPTForCausalModelling: Single-Risk
--------------------------------------

Depression: 	 (DEPRESSION): 

Depression -> Type 1: 	 (DEPRESSION,TYPE1DM): 

Depression - > Type 2: 	 (DEPRESSION,TYPE2DIABETES): 


## Values: How increasing BMI affects diagnosis risk

In [30]:
from pathlib import Path

events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF_V3", "ISCHAEMICSTROKE_V2"
                     ]
prompt = ["Body_mass_index_3"]
# Prompt values are given directly in the model's standardised units, swept over [0, 1]
values = [torch.tensor([float(v) for _cat in prompt], device=device) for v in np.linspace(0,1,5)]
age = [40]

FIG_DIR_BMI = Path("figs/single_risk/bmi")
FIG_DIR_BMI.mkdir(parents=True, exist_ok=True)

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad():
        model.eval()
        # The prompt tokens are the same for every value, so encode them once
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)

        prompt_survs = []
        for p_idx, value in enumerate(values):
            print(f"Value {value}\n======")
            (surv, val_dist), _, _ = model(encoded_prompt,
                                           values=value,
                                           ages=to_days(age),
                                           is_generation=True)
            prompt_survs.append(surv)

        # One figure per event of interest, overlaying its survival curve for each BMI value
        for si, _ in enumerate(surv):
            event_name = dm.decode([si + 1])
            if event_name not in events_of_interest:
                continue

            fig, ax = plt.subplots()
            for bmi_value, p_surv in zip(values, prompt_survs):
                ax.plot(model.surv_layer.t_eval / 365, p_surv[si][0, :], label=f"BMI {bmi_value.item():.2f}")
            ax.set_xlabel("t (years)")
            ax.set_ylabel("P(T>t)")
            ax.legend()
            fig.savefig(FIG_DIR_BMI / f"{event_name}.png")
            plt.close(fig)




SurvStreamGPTForCausalModelling: Single-Risk
--------------------------------------
Value tensor([0.], device='cuda:0')
Value tensor([0.2500], device='cuda:0')
Value tensor([0.5000], device='cuda:0')
Value tensor([0.7500], device='cuda:0')
Value tensor([1.], device='cuda:0')


Exception ignored in: <function _ConnectionBase.__del__ at 0x7fe37dab9cf0>
Traceback (most recent call last):
  File "/rds/bear-apps/2022a/EL8-ice/software/Python/3.10.4-GCCcore-11.3.0/lib/python3.10/multiprocessing/connection.py", line 137, in __del__
    self._close()
  File "/rds/bear-apps/2022a/EL8-ice/software/Python/3.10.4-GCCcore-11.3.0/lib/python3.10/multiprocessing/connection.py", line 366, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor


## Values: How increasing diastolic_blood_pressure affects likelihood of diagnoses

In [25]:
from pathlib import Path

events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF_V3", "ISCHAEMICSTROKE_V2"
                     ]

prompt = ["Diastolic_blood_pressure_5"]
# Prompt values are given directly in the model's standardised units, swept over [0, 1]
values = [torch.tensor([float(v) for _cat in prompt], device=device) for v in np.linspace(0,1,5)]
age = [40]

FIG_DIR_DBP = Path("figs/single_risk/diastolic_blood_pressure")
FIG_DIR_DBP.mkdir(parents=True, exist_ok=True)

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    with torch.no_grad():
        model.eval()
        # The prompt tokens are the same for every value, so encode them once
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)

        prompt_survs = []
        for p_idx, value in enumerate(values):
            print(f"Value {value}\n======")
            (surv, val_dist), _, _ = model(encoded_prompt,
                                           values=value,
                                           ages=to_days(age),
                                           is_generation=True)
            prompt_survs.append(surv)

        # One figure per event of interest, overlaying its survival curve for each DBP value
        for si, _ in enumerate(surv):
            event_name = dm.decode([si + 1])
            if event_name not in events_of_interest:
                continue

            fig, ax = plt.subplots()
            for dbp_value, p_surv in zip(values, prompt_survs):
                ax.plot(model.surv_layer.t_eval / 365, p_surv[si][0, :], label=f"DBP {dbp_value.item():.2f}")
            ax.set_xlabel("t (years)")
            ax.set_ylabel("P(T>t)")
            ax.legend()
            fig.savefig(FIG_DIR_DBP / f"{event_name}.png")
            plt.close(fig)




SurvStreamGPTForCausalModelling: Single-Risk
--------------------------------------
Value tensor([0.], device='cuda:0')
Value tensor([0.2500], device='cuda:0')
Value tensor([0.5000], device='cuda:0')
Value tensor([0.7500], device='cuda:0')
Value tensor([1.], device='cuda:0')


## Values: How varying the diagnosis affects the predicted value of diastolic_blood_pressure

In [26]:
measurements_of_interest = ["Diastolic_blood_pressure_5"]
dbp_token = dm.tokenizer._stoi[measurements_of_interest[0]]

diagnoses = [["DEPRESSION"],["TYPE2DIABETES"], ["HF"], ["HYPERTENSION"]]
values = torch.tensor([torch.nan], device=device)
age = [40]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    model.eval()
    # Inference only: run the forward passes without building an autograd graph
    with torch.no_grad():
        for p_idx, diagnosis in enumerate(diagnoses):
            print(f"\nDiagnosis {diagnosis}\n======")
            # NOTE(review): other cells use "HF_V3"; flag prompt tokens missing from the vocabulary
            missing = [token for token in diagnosis if token not in dm.tokenizer._stoi]
            if missing:
                print(f"WARNING: {missing} not found in the tokenizer vocabulary")
            encoded_prompt = torch.from_numpy(np.array(dm.encode(diagnosis)).reshape((1,-1))).to(device)
            (surv, val_dist), _, _ = model(encoded_prompt,
                                           values=values,
                                           ages=to_days(age),
                                           is_generation=True)
            dist = val_dist[model.value_layer.token_key(dbp_token)]
            print(f"standardised diastolic_blood_pressure ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")





SurvStreamGPTForCausalModelling: Single-Risk
--------------------------------------

Diagnosis ['DEPRESSION']
standardised diastolic_blood_pressure ~ N(0.5, 0.2)

Diagnosis ['TYPE2DIABETES']
standardised diastolic_blood_pressure ~ N(0.5, 0.2)

Diagnosis ['HF']
standardised diastolic_blood_pressure ~ N(0.5, 0.2)

Diagnosis ['HYPERTENSION']
standardised diastolic_blood_pressure ~ N(0.6, 0.2)


## Values: How increasing BMI affects the predicted value of diastolic_blood_pressure

In [28]:
dbp_token = dm.tokenizer._stoi["Diastolic_blood_pressure_5"]

prompt = ["Body_mass_index_3"]
# Prompt values are given directly in the model's standardised units, swept over [0, 1]
values = [torch.tensor([float(v) for _cat in prompt], device=device) for v in np.linspace(0,1,5)]
age = [40]

for model_idx, model in enumerate(models):
    print(f"\n\n{m_names[model_idx]}\n--------------------------------------")
    model.eval()
    # Inference only: run the forward passes without building an autograd graph
    with torch.no_grad():
        # The prompt tokens are the same for every value, so encode them once
        encoded_prompt = torch.from_numpy(np.array(dm.encode(prompt)).reshape((1,-1))).to(device)

        for p_idx, value in enumerate(values):
            print(f"\nValues {value.tolist()}\n======")
            (surv, val_dist), _, _ = model(encoded_prompt,
                                           values=value,
                                           ages=to_days(age),
                                           is_generation=True)

            dist = val_dist[model.value_layer.token_key(dbp_token)]
            print(f"standardised diastolic_blood_pressure ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")



SurvStreamGPTForCausalModelling: Single-Risk
--------------------------------------

Values [0.0]
standardised diastolic_blood_pressure ~ N(0.4, 0.2)

Values [0.25]
standardised diastolic_blood_pressure ~ N(0.4, 0.2)

Values [0.5]
standardised diastolic_blood_pressure ~ N(0.5, 0.2)

Values [0.75]
standardised diastolic_blood_pressure ~ N(0.6, 0.2)

Values [1.0]
standardised diastolic_blood_pressure ~ N(0.6, 0.2)


# Appendix: model architectures

In [None]:
# Print each model's module tree under an underlined heading with its name
for m_name, model in zip(m_names, models):
    underline = "=" * len(m_name)
    print(f"\n\n{m_name}\n{underline}")
    print(f"\n\n{model}")

In [None]:
# Export this notebook to HTML; --no-input omits the code cells so only outputs and markdown remain
!jupyter nbconvert --to html --no-input single_risk.ipynb