# Generation Demo Notebook:
## SurvivEHR: Competing Risk Survival Transformer For Causal Sequence Modelling 

In this notebook we demonstrate how different risk factors (smoking status, blood pressure, BMI and deprivation) affect SurvivEHR's predicted risk of future events and the generation of future patient timelines.

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

!pwd

%load_ext autoreload
%autoreload 2
%env SLURM_NTASKS_PER_NODE=28       # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/SurvivEHR/notebooks/CompetingRisk/0_pretraining
env: SLURM_NTASKS_PER_NODE=28       # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import logging
from pycox.evaluation import EvalSurv
from tqdm import tqdm
from hydra import compose, initialize
from omegaconf import OmegaConf
from contextlib import redirect_stdout
import pandas as pd
from tqdm import tqdm
import warnings

from CPRD.examples.modelling.SurvivEHR.run_experiment import run
from CPRD.examples.modelling.SurvivEHR.setup_causal_experiment import CausalExperiment
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling
from CPRD.examples.data.map_to_reduced_names import convert_event_names, EVENT_NAME_SHORT_MAP

from FastEHR.dataloader import FoundationalDataModule
from FastEHR.database.collector import SQLiteDataCollector

torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')
warnings.simplefilter('error', np.VisibleDeprecationWarning)

logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cuda.


# Demonstration of how the input context can affect predicted future risk

In [57]:
def map_batch_to_t2dm_risk_profile(batch, risk="low"):
    """
    Build a counterfactual ("what-if?") copy of `batch` with cardiovascular / T2DM risk factors overwritten
    by one of three preset profiles (low, medium, high), to see how next-event risks change.

    Args:
        batch: dict of tensors; reads "tokens" (2-D, batch x sequence), "values", "ages" and "attention_mask".
        risk: "low", "medium"/"mid" or "high".

    Returns:
        dict with "static_covariates", "tokens", "ages", "values", "attention_mask". Every patient is given
        the profile's baseline covariates (F, ASIAN, born 1963, deprivation 1/3/5). Smoking-status events are
        rewritten to the profile's status. Diastolic/systolic blood pressure and BMI values are replaced by the
        profile's (standardised) values, and weight values are set to NaN. "ages" and "attention_mask" are
        passed through untouched.

    Raises:
        NotImplementedError: if `risk` is not a recognised level.

    Uses the notebook-global datamodule `dm` for covariate encoding, tokenisation and standardisation.
    """
    dbp, sbp, bmi, weight = ("Diastolic_blood_pressure_5", "Systolic_blood_pressure_4",
                             "Body_mass_index_3", "O_E___weight_2")
    never, ex_smoker, current = "Never_smoked_tobacco_85", "Ex_smoker_84", "Current_smoker_83"

    # level -> (deprivation, smoking-event swaps, measurement value overrides)
    profiles = {
        # Normal person
        "low": (1.0,
                {ex_smoker: never, current: never},
                {dbp: 80, sbp: 120, bmi: 24, weight: np.nan}),
        # At some risk
        "mid": (3.0,
                {never: ex_smoker, current: ex_smoker},
                {dbp: 90, sbp: 140, bmi: 28, weight: np.nan}),
        # At higher risk
        "high": (5.0,
                 {never: current, ex_smoker: current},
                 {dbp: 100, sbp: 150, bmi: 32, weight: np.nan}),
    }

    level = "mid" if risk == "medium" else risk
    if not isinstance(level, str) or level not in profiles:
        raise NotImplementedError
    deprivation, event_swaps, value_overrides = profiles[level]
    baseline = dm.test_set._encode_covariates("F", deprivation, "ASIAN", 1963)

    n_patients, _ = batch["tokens"].shape
    target_device = batch["tokens"].device

    # Translate event names to token ids; measurement overrides also get standardised
    event_swap_tokens = {dm.encode([src])[0]: dm.encode([dst])[0] for src, dst in event_swaps.items()}
    value_override_tokens = {dm.encode([name])[0]: dm.standardise(name, raw)
                             for name, raw in value_overrides.items()}

    # Masks are always taken on the ORIGINAL tokens so that swaps never chain into each other
    original_tokens = batch["tokens"]
    new_tokens = original_tokens.clone()
    for src_token, dst_token in event_swap_tokens.items():
        new_tokens[original_tokens == src_token] = dst_token

    new_values = batch["values"].clone()
    for measurement_token, standardised in value_override_tokens.items():
        new_values[original_tokens == measurement_token] = standardised

    return {
        "static_covariates": torch.tile(baseline, (n_patients, 1)).to(target_device),
        "tokens": new_tokens.to(target_device),
        "ages": batch["ages"],
        "values": new_values.to(target_device),
        "attention_mask": batch["attention_mask"],
    }

def clip_outliers(token, unstandardised_value, measurement_table=None):
    """
    Clamp an unstandardised measurement to the [min, max] range recorded for its event.

    Because of heavy right tails in the value distributions, standardisation of some token values
    over-estimates the lower quantile. This method ensures no values beyond those seen in the data are reported.

    Args:
        token: event name, matched against the "event" column of the measurement table.
        unstandardised_value: scalar value (float, numpy scalar or single-element tensor).
        measurement_table: table with "event", "min" and "max" columns. Defaults to
            `dm.meta_information["measurement_tables"]` from the notebook-global datamodule.

    Returns:
        The clamped value as a float. `unstandardised_value` is returned unchanged when it is None/NaN or
        not numeric, or when the token has no row in the measurement table.
    """
    try:
        value = float(unstandardised_value)
    except (TypeError, ValueError):
        return unstandardised_value
    if np.isnan(value):
        return unstandardised_value

    if measurement_table is None:
        measurement_table = dm.meta_information["measurement_tables"]

    token_meta = measurement_table[measurement_table["event"] == token]
    if len(token_meta) == 0:
        # Presumably a non-measurement event: nothing to clip against
        return unstandardised_value

    # Take scalars. The filtered columns are length-1 Series, and mixing one with a float inside
    # np.min/np.max builds a ragged array that raises. That exception was previously swallowed by a
    # bare `except`, so nothing was ever clipped.
    lower = float(token_meta["min"].to_numpy()[0])
    upper = float(token_meta["max"].to_numpy()[0])
    if not np.isnan(lower):
        value = max(value, lower)
    if not np.isnan(upper):
        value = min(value, upper)
    return value

def report_generation(static, tokens, ages, values, attention_mask, true_seq_len, dm, eos_token="DEATH", **kwargs):
    """
    Print a human-readable report of one patient's timeline.

    The first `true_seq_len` events are listed as the given context. They are followed by a diagnosis
    summary (events whose name is entirely upper case), and any later unmasked events are listed under
    "Predicted future events". Printing stops at the first masked position or right after `eos_token`.

    Args:
        static: static covariates for the patient, decoded with `dm.test_set._decode_covariates`.
        tokens, ages, values, attention_mask: 2-D tensors; only row 0 is reported.
        true_seq_len: number of context events.
        dm: datamodule providing the tokenizer, (un)standardisation and `test_set.time_scale`.
        eos_token: event name after which reporting stops.
        **kwargs: ignored.
    """
    token_row, age_row, value_row, mask_row = (t[0, :] for t in (tokens, ages, values, attention_mask))

    rule = "=" * 120
    header = f"\tEVENT".ljust(75) + "| AGE IN WEEKS (days, years)".ljust(30)

    decoded_static = dm.test_set._decode_covariates(static.cpu())
    print("STATIC INFORMATION")
    print(rule)
    for covariate, covariate_values in decoded_static.items():
        print(f"\t{covariate}:".ljust(20) + f"{covariate_values[0]}")

    event_names = dm.tokenizer.decode(token_row.tolist()).split(" ")
    diagnoses = []
    previous_week = 0
    print("\n\nGiven patient context".upper())
    print(rule)
    print(header + " | VALUE")
    for position, (name, age, value, is_attended) in enumerate(zip(event_names, age_row, value_row, mask_row)):
        if is_attended == 0:
            break

        # Ages are stored in scaled units; recover days, then bin to whole weeks / years
        days = int(age * dm.test_set.time_scale)
        weeks = int(days / 7)
        years = int(days / 365)

        # Visual separator whenever the week changes
        if weeks != previous_week:
            print("\t" + "."*60 + "new week" + "."*60)

        age_text = f"{weeks}\t ({days}, {years})"
        raw_value = clip_outliers(name, dm.unstandardise(name, value))
        value_text = f"{raw_value:.2f}".ljust(10) + f"({value:.2f})"
        print(f"\t{name.ljust(75)}| {age_text.ljust(30)}| {value_text}".ljust(20))

        if name == name.upper():
            diagnoses.append(name)

        # Last context event: summarise, then open the "future" section
        if position == true_seq_len - 1:
            print("\n" + rule)
            print("Diagnosis summary".upper())
            print(f"{diagnoses}")
            print(rule)
            print("\n")
            print("Predicted future events".upper())
            print(rule)
            print(header + "| VALUE")

        previous_week = weeks
        if name == eos_token:
            break

In [58]:
# Model checkpoint(s) to evaluate, paired one-to-one with their Hydra config names
pre_trained_models = ["SurvivEHR-cr-small-debug7_exp1000-v1-v4-v1"]
config_names = ["config_CompetingRisk11M"]

# Datasets to generate for, and the indices (within the first test batch) of the patients to study in each.
# Other options used elsewhere: "FineTune_Hypertension" ([0]), "FineTune_CVD" ([10]),
# "FineTune_MultiMorbidity50+" ([1])
datasets = ["PreTrain"]
patients_of_interest = [[10]]

# Counterfactual risk profiles, and the outcomes whose predicted risk is recorded
risk_levels = ["low", "mid", "high"]
outcomes_of_interest = [
    "HYPERTENSION",
    "Body_mass_index_3",
    "HF_V3",
    "TYPE2DIABETES",
    "TYPE1DM",
    "DEATH",
]


# Generate "what-if?" next-event risks for a handful of selected patients

In [59]:
for pre_trained_model, config_name in zip(pre_trained_models, config_names):
    os.makedirs(f"figs/generation/{pre_trained_model}/", exist_ok=True)

    # load the configuration file, override any settings
    with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
        cfg = compose(config_name=config_name,
                      overrides=[# Experiment setup
                                 f"experiment.run_id='{pre_trained_model}'",
                                 "experiment.train=False",
                                 "experiment.test=False",
                                 "experiment.log=False",
                                 # Dataloader
                                 "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
                                 "data.min_workers=3",
                                ]
                     )
    experiment, dm = run(cfg)
    # Inference only: make sure any dropout is disabled so the what-if risks are deterministic
    experiment.eval()
    print(f"Loaded model {pre_trained_model} with {sum(p.numel() for p in experiment.parameters())/1e6} M parameters")

    for idx_dataset, dataset in enumerate(datasets):
        print(f"Generating patient's next event risk for dataset {dataset}")
        gen_save_path = f'figs/generation/{pre_trained_model}/{dataset}_dataset/'
        os.makedirs(gen_save_path, exist_ok=True)

        # store dataset results in
        data_rows = []

        # Load dataset
        dm = FoundationalDataModule(path_to_db=cfg.data.path_to_db,
                                    path_to_ds=f"/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/{dataset}/",
                                    overwrite_meta_information=cfg.data.meta_information_path,
                                    load=True,
                                    supervised=False if dataset.lower()=="pretrain" else True,
                                    )

        # Only the first test batch is used
        batch = next(iter(dm.test_dataloader()))
        batch = {k: v.to(device) for k, v in batch.items()}

        # The patients within the batch we are interested in (every patient when None)
        bsz, L = batch["tokens"].shape
        patients = patients_of_interest[idx_dataset] if patients_of_interest[idx_dataset] is not None else [i for i in range(bsz)]

        # Outcome tokens do not change, so encode them once instead of inside the context loop
        outcome_tokens = {event_name: dm.encode([event_name])[0] for event_name in outcomes_of_interest}

        # For each risk level, map the patients of interest to the risk level, report them to file, and record the next-event risks
        batches_by_risk = []
        for risk_level in risk_levels:

            # Map to risk level
            ###################
            batch_risk_adjusted = map_batch_to_t2dm_risk_profile(batch, risk=risk_level)
            batches_by_risk.append(batch_risk_adjusted)

            # Number of unmasked events per patient of interest
            true_seq_lens = {idx_patient: int(batch_risk_adjusted["attention_mask"][idx_patient, :].sum())
                             for idx_patient in patients}

            # Report mapped timelines for each patient of interest
            ######################################################
            for idx_patient in tqdm(patients, ascii=True, desc=f"Saving {risk_level}-converted timeline for all considered patients in batch"):
                out_dir = gen_save_path + f'patient{idx_patient}/'
                os.makedirs(out_dir, exist_ok=True)
                with open(out_dir + f"mapped_to_risk_level_{risk_level}.txt", 'w') as f:
                    with redirect_stdout(f):
                        report_generation(
                            static         = batch_risk_adjusted["static_covariates"][idx_patient],
                            tokens         = batch_risk_adjusted["tokens"][[idx_patient],:],
                            ages           = batch_risk_adjusted["ages"][[idx_patient],:],
                            values         = batch_risk_adjusted["values"][[idx_patient], :],
                            attention_mask = batch_risk_adjusted["attention_mask"][[idx_patient], :],
                            true_seq_len   = batch_risk_adjusted["attention_mask"][[idx_patient], :].sum(),
                            dm             = dm
                        )

            # Record risks for every context length, up to and including the full timeline
            ###############################################################################
            # Context lengths run 1..(longest patient timeline) INCLUSIVE. The previous `range(1, L)` stopped
            # at L-1, so a patient whose timeline fills the whole window never had its full-context risk recorded.
            max_context_len = max(true_seq_lens.values(), default=0)
            with torch.no_grad():  # inference only: don't build an autograd graph for every forward pass
                for l in tqdm(range(1, max_context_len + 1),
                              ascii=True, desc=f"Recording {risk_level}-converted outcome risks for all considered patients in batch"):
                    # Get next-event risks for patients in batch mapped to the risk level
                    outputs, _, _  = experiment.model(
                        covariates        = batch_risk_adjusted["static_covariates"],
                        tokens            = batch_risk_adjusted['tokens'][:, :l],
                        ages              = batch_risk_adjusted['ages'][:, :l],
                        values            = batch_risk_adjusted['values'][:, :l],
                        attention_mask    = batch_risk_adjusted['attention_mask'][:, :l],
                        is_generation     = True,
                        return_loss       = False,
                        return_generation = True,
                        )
                    pred_surv = outputs["surv"]["surv_CDF"]

                    for idx_patient in patients:
                        # A context longer than this patient's timeline would only add padding
                        if l > true_seq_lens[idx_patient]:
                            continue

                        last_observed_token = batch_risk_adjusted['tokens'][idx_patient, l-1]
                        last_observed_event = dm.decode([last_observed_token.tolist()])

                        for event_name, tkn_of_interest in outcome_tokens.items():
                            # NOTE(review): assumes surv_CDF is indexed by (token id - 1), i.e. no output for token 0 — confirm
                            event_surv_pred = pred_surv[tkn_of_interest - 1][idx_patient]

                            data_rows.append(
                                dict(dataset             = dataset,
                                     risk_level          = risk_level,
                                     patient_idx         = idx_patient,
                                     token               = tkn_of_interest,
                                     survival_risk       = event_surv_pred,
                                     total_survival_risk = np.mean(event_surv_pred),
                                     context_len         = l,
                                     last_observed_event = last_observed_event,
                                     )
                            )

        df = pd.DataFrame(data_rows)
        print(df.head())
        df.to_pickle(gen_save_path + "risk_level_survival_table.pkl")



INFO:root:Running cr on 72 CPUs and 1 GPUs
INFO:root:# Loading DataModule for dataset /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/. This will be loaded in causal form.
INFO:root:Creating unsupervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 toke

Loaded model SurvivEHR-cr-small-debug7_exp1000-v1-v4-v1 with 11.20919 M parameters
Generating patient's next event risk for dataset PreTrain


INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Set seed to 42
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/split=train/ dataset, with 23,613,894 sample

    dataset risk_level  patient_idx  token  \
0  PreTrain        low           10    129   
1  PreTrain        low           10    247   
2  PreTrain        low           10     74   
3  PreTrain        low           10    100   
4  PreTrain        low           10     36   

                                       survival_risk  total_survival_risk  \
0  [0.0, 9.656151e-08, 1.3165229e-07, 1.4441537e-...         1.514882e-07   
1  [0.0, 3.3511304e-05, 4.8969396e-05, 5.6102104e...         6.210396e-05   
2  [0.0, 5.6980434e-09, 9.533974e-09, 1.2116994e-...         1.740161e-08   
3  [0.0, 3.266222e-08, 5.3049597e-08, 6.578297e-0...         8.682732e-08   
4  [0.0, 1.4359887e-09, 2.2147377e-09, 2.6371827e...         3.131878e-09   

   context_len         last_observed_event  
0            1  Diastolic_blood_pressure_5  
1            1  Diastolic_blood_pressure_5  
2            1  Diastolic_blood_pressure_5  
3            1  Diastolic_blood_pressure_5  
4            1  Diastolic_blood_pre




In [60]:
# Scratch checks (all commented out): search event names containing "diab", round-trip the static
# covariate encoding, and try map_batch_to_t2dm_risk_profile on the loaded batch.
# for i in list(dm.tokenizer._event_counts["EVENT"]):
#     if "diab" in i.lower():
#         print(i)

# print(batch["static_covariates"].shape)
# # static = dm.test_set._decode_covariates(batch["static_covariates"].cpu())
# # print(static)

# static = dm.test_set._encode_covariates("F", 1.0, "ASIAN", 1963)
# print(static.shape)
# print(dm.test_set._decode_covariates(static))

# print(map_batch_to_t2dm_risk_profile(batch, risk="low"))
# print(map_batch_to_t2dm_risk_profile(batch, risk="high"))

In [79]:
datasets = [ "PreTrain",
             "FineTune_Hypertension",
             "FineTune_CVD", 
             "FineTune_MultiMorbidity50+"]
patients_of_interest = [[10],
                        [0],
                        [10],
                        [1]]

# Colour per risk level, shared by every figure
custom_colors = {
    "low": sns.xkcd_rgb['avocado'],
    "mid": sns.xkcd_rgb["golden rod"],
    "high": sns.xkcd_rgb["burnt red"],
}
sns.set(style="whitegrid")

for pre_trained_model, config_name in zip(pre_trained_models, config_names):
    os.makedirs(f"figs/generation/{pre_trained_model}/", exist_ok=True)

    # Times SurvivEHR forecasts over (model time units) and the factor that converts them to days.
    # (Previously `scale_time = model_time = ...` overwrote model_time with the scale.)
    model_time = experiment.model.surv_layer.t_eval
    scale_time = dm.train_set.time_scale

    for idx_dataset, dataset in enumerate(datasets):
        gen_save_path = f'figs/generation/{pre_trained_model}/{dataset}_dataset/'
        table_path = gen_save_path + "risk_level_survival_table.pkl"
        if not os.path.exists(table_path):
            # The generation cell must have been run for this dataset first
            print(f"Skipping {dataset}: {table_path} not found.")
            continue
        patients = patients_of_interest[idx_dataset] if patients_of_interest[idx_dataset] is not None else [i for i in range(bsz)]
        df = pd.read_pickle(table_path)

        for idx_patient in patients:
            out_dir = gen_save_path + f'patient{idx_patient}/'

            for event_name in outcomes_of_interest:
                tkn_of_interest = dm.encode([event_name])[0]

                plot_df = (df.query("dataset==@dataset & patient_idx==@idx_patient & token==@tkn_of_interest")
                           .sort_values('context_len', ascending=False)                            # largest first
                           .drop_duplicates(['dataset', "risk_level", "token", 'patient_idx'])     # keep first row it meets per group, the one with largest context
                           .reset_index(drop=True)
                          )

                # Explode along the survival_risk vector to expand the dataframe
                plot_df_long = (
                    plot_df
                      .explode("survival_risk")                 # duplicates meta-columns
                      .astype({"survival_risk": float})         # explode leaves object dtype; make it numeric
                      .assign(                                  # add the matching time point
                          time_idx=lambda d: d.groupby(
                              ["dataset", "patient_idx", "risk_level"]
                          ).cumcount()
                      )
                )
                plot_df_long["Time to event (years)"] = plot_df_long["time_idx"].apply(lambda i: model_time[i] * scale_time / 365)

                fig, ax = plt.subplots(figsize=(6*0.9, 4*0.9))
                sns.lineplot(
                    data=plot_df_long,
                    x="Time to event (years)",
                    y="survival_risk",
                    hue="risk_level",
                    legend=False,
                    palette=custom_colors,
                    lw=2.5,
                    ax=ax,
                )
                ax.set_ylabel("Risk of event")
                fig.tight_layout()
                fig.savefig(out_dir + f"risk_{event_name}.png")
                plt.close(fig)  # many figures are created in this loop; free each once saved

  ax = sns.lineplot(
  ax = sns.lineplot(
  ax = sns.lineplot(


In [68]:
# Sanity-check the long-format frame built for the last plotted outcome
plot_columns = plot_df_long.columns
risk_levels_present = plot_df_long["risk_level"].unique()
print(plot_columns)
print(risk_levels_present)


Index(['dataset', 'risk_level', 'patient_idx', 'token', 'survival_risk',
       'total_survival_risk', 'context_len', 'last_observed_event', 'time_idx',
       'Time to event (years)'],
      dtype='object')
['high' 'mid' 'low']


In [56]:
sns.set(style="whitegrid")

for pre_trained_model, config_name in zip(pre_trained_models, config_names):
    os.makedirs(f"figs/generation/{pre_trained_model}/", exist_ok=True)

    for idx_dataset, dataset in enumerate(datasets):
        gen_save_path = f'figs/generation/{pre_trained_model}/{dataset}_dataset/'
        table_path = gen_save_path + "risk_level_survival_table.pkl"
        if not os.path.exists(table_path):
            # The generation cell must have been run for this dataset first
            print(f"Skipping {dataset}: {table_path} not found.")
            continue
        patients = patients_of_interest[idx_dataset] if patients_of_interest[idx_dataset] is not None else [i for i in range(bsz)]
        df = pd.read_pickle(table_path)

        # Shorten event names for the point labels. `.map(dict)` alone turns names missing from the
        # map into NaN, so fall back to the full name.
        df["last_observed_event"] = (df["last_observed_event"]
                                     .map(EVENT_NAME_SHORT_MAP)
                                     .fillna(df["last_observed_event"]))

        for idx_patient in patients:
            out_dir = gen_save_path + f'patient{idx_patient}/'

            for event_name in outcomes_of_interest:
                tkn_of_interest = dm.encode([event_name])[0]

                plot_df = df.query("dataset==@dataset & patient_idx==@idx_patient & token==@tkn_of_interest")

                fig, ax = plt.subplots(dpi=600)
                sns.lineplot(
                    data=plot_df,
                    x="context_len", y="total_survival_risk",
                    hue="risk_level",
                    marker='.',
                    ax=ax,
                )

                # Label each point with the event observed last in that context. Columns are accessed by
                # name: positional `row[6]` on a label-indexed Series is deprecated in pandas.
                for row in plot_df.itertuples(index=False):
                    ax.text(row.context_len, row.total_survival_risk, f'{row.last_observed_event}', size=4)

                ax.set_xlabel("Context length")
                # total_survival_risk is the mean of the predicted CDF (cumulative risk) over the forecast
                # grid. It grows with risk, so it is not a restricted mean survival time.
                ax.set_ylabel("Mean predicted cumulative risk")
                ax.set_title(f"AUC survival curves for {event_name} against provided context length, stratified by risk level")
                fig.tight_layout()
                fig.savefig(out_dir + f"rmst_{event_name}.png")
                plt.close(fig)  # many figures are created in this loop; free each once saved
        

In [None]:
# Peek at the first row of the last plotted frame. `.iloc` is used for positional access because
# integer keys via `[]` on a label-indexed Series are deprecated in pandas.
for _, first_row in plot_df.iterrows():
    print(first_row)
    print(first_row.iloc[5])
    print(first_row.iloc[6])
    print(first_row.iloc[7])
    break

In [None]:
print(idx_patient)
print(df.head())

patient_rows = df.query("dataset==@dataset & patient_idx==@idx_patient")
# Longest context first, so drop_duplicates keeps that row for each (dataset, risk level, token, patient)
patient_rows_longest_first = patient_rows.sort_values('context_len', ascending=False)
rows_of_interest = (patient_rows_longest_first
                    .drop_duplicates(['dataset', "risk_level", "token", 'patient_idx'])
                    .reset_index(drop=True))
print(rows_of_interest)

# Generate "what-if?" next-event risks for a handful of curated cases (not yet implemented)

In [None]:
# Deliberate stop: "Run All" halts here so the separate demo cells below are only run by hand.
raise NotImplementedError("Curated 'what-if?' cases are not implemented yet; the remaining cells are a separate single-risk demo - run them individually if needed.")

# Demo Version of SurvStreamGPT

## Build configurations

In [None]:
# load the configuration file, override any settings
# config_path is resolved relative to this notebook. The confs directory is three levels up, matching the
# competing-risk generation cell (the earlier "../../confs" pointed one directory too shallow).
with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
    cfg = compose(config_name="config_SingleRisk11M")

# Just load in pretrained model
cfg.experiment.train = False
cfg.experiment.test = False
cfg.experiment.log = False
cfg.experiment.run_id = "SR_11M"
print(OmegaConf.to_yaml(cfg))

save_path = f"/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/{cfg.experiment.run_id}/"

In [None]:
 # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed
# Set via os.environ so trailing whitespace or an inline comment cannot end up in the value, which can
# happen with the `%env` magic.
os.environ["SLURM_NTASKS_PER_NODE"] = "28"

model, dm = run(cfg)
print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")


In [None]:
# Show one training sample with event times; 100 is presumably the sample's position in the training set — confirm
sample_index = 100
dm.train_set.view_sample(sample_index, report_time=True)

In [None]:
import polars as pl

# Widen polars' global display limits (string width 100, up to 200 rows)
# so the tokenizer's event-count table is readable.
pl.Config.set_fmt_str_lengths(100)
pl.Config.set_tbl_rows(200)

display(dm.tokenizer._event_counts)

### Real data

In [None]:
# Measurement tables stored in the datamodule's meta information.
_measurement_tables = dm.meta_information["measurement_tables"]
display(_measurement_tables)

## Generation

### Sampling from the model

In [None]:
# Default context: baseline covariates plus a two-event prompt with measured values.
baseline_covariates = {"sex": "F", "deprivation": 1.0, "ethnicity": "WHITE", "year_of_birth": 1997-40}
prompt = ["O_E___height_1", "O_E___weight_2"]
values = [163, 80]
ages_in_years = [18.2, 18.2]


# Encoding helpers that turn plain Python lists into (1, seq_len) model inputs.
# NOTE(review): they use `device`, which must already be defined (it is not set in
# this section of the notebook). TODO: move these wrappers onto the datamodule.
def encode_prompt(prompt_list):
    """Encode event names with `dm.encode` into a (1, len) tensor on `device`."""
    return torch.from_numpy(np.array(dm.encode(prompt_list)).reshape((1, -1))).to(device)


def encode_value(prompt_list, value_list):
    """Standardise each (event, value) pair with `dm.standardise`; return a (1, len) float32 tensor."""
    standardised = [dm.standardise(_cat, _val) for _cat, _val in zip(prompt_list, value_list)]
    return torch.tensor(np.array(standardised).reshape((1, -1)), dtype=torch.float32).to(device)


def encode_age(age_list):
    """Convert ages in years to days (365 per year), as an int64 (1, len) tensor."""
    return torch.tensor([365 * _age for _age in age_list], dtype=torch.int64).reshape((1, -1)).to(device)


# Convert the default context for the model
covariates = dm.train_set._encode_covariates(**baseline_covariates).reshape(1, -1).to(device)
tokens = encode_prompt(prompt)
values_scaled = encode_value(prompt, values)
ages_in_days = encode_age(ages_in_years)

print(values_scaled)

In [None]:
# Sample a continuation of the default context, then print prompt and generated events.
# Eval mode + no_grad, as in the experiment cells below, so sampling does not depend on
# the module's current train/eval state (e.g. dropout) and builds no autograd graph.
model.eval()
with torch.no_grad():
    # generate: sample with max_new_tokens=40
    new_tokens, new_ages, new_values = model.generate(tokens, ages_in_days, values_scaled, covariates, max_new_tokens=40)

# report:
n_prompt_tokens = tokens.shape[-1]
print(f"Baseline covariates: \n{baseline_covariates}\n" + "="*90)
print("PROMPT:")
decoded_events = dm.decode(new_tokens[0].tolist()).split(" ")
for _idx, (_cat, _age, _value) in enumerate(zip(decoded_events, new_ages[0, :], new_values[0, :])):
    # Values are printed as returned by the model (presumably the standardised scale;
    # unstandardising via dm.unstandardise is left disabled).
    print(f"{_cat}".ljust(50) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({int(_age)} days)")
    # Divider after the last prompt token: everything that follows was generated.
    if _idx == n_prompt_tokens - 1:
        print("="*90)
        print("GENERATION")

# Prompt testing

## Diagnoses: How related conditions are impacted by each other


In [None]:
# Compare the predicted per-event curves after single-token prompts of different
# diagnoses / smoking statuses, each recorded at age 20 with no value.
exp_prompts = [["DEPRESSION"], ["TYPE1DM"], ["TYPE2DIABETES"], ["Never_smoked_tobacco_85"], ["Ex_smoker_84"]]
exp_ages = [[20] for _ in exp_prompts]
exp_values = [[np.nan] for _ in exp_prompts]

# `savefig` does not create directories, so make sure the output folder exists.
_out_dir = save_path + "diabetes/"
os.makedirs(_out_dir, exist_ok=True)

with torch.no_grad():
    model.eval()

    _exp_survs = []
    for _exp_prompt, _exp_age, _exp_value in zip(exp_prompts, exp_ages, exp_values):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_generation=True)
        # NOTE(review): key is named "surv_CDF" while the axis below says P(T>t);
        # confirm which quantity it holds.
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    # One figure per event; `surv` (from the last pass) is only used for its length.
    # Event i is decoded from token id i + 1 (presumably id 0 is reserved -- confirm).
    for si in range(len(surv)):
        plt.close()
        event_name = dm.decode([si + 1])
        for p_idx in range(len(exp_prompts)):
            plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{'->'.join(exp_prompts[p_idx]).lower()}")
        plt.xlabel("Time (years)")
        plt.ylabel(f"$P(T>t)$ ({event_name})")
        plt.legend()
        plt.savefig(_out_dir + f"{event_name}.png")


## Values: How increasing BMI affects diagnosis risk

In [None]:
# How does the BMI value recorded at age 40 shift the predicted curve of each event
# of interest? One forward pass per BMI value, with the default covariates.
events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5",
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF_V3", "ISCHAEMICSTROKE_V2",
                      "DEATH"
                     ]

_exp_prompt = ["Body_mass_index_3"]
_exp_age = [40]
_exp_values = [[18.], [21.], [24.], [30.], [40.]]

# `savefig` does not create directories, so make sure the output folder exists.
_out_dir = save_path + "bmi/"
os.makedirs(_out_dir, exist_ok=True)

with torch.no_grad():
    model.eval()

    _exp_survs = []
    for p_idx, _exp_value in enumerate(_exp_values):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_generation=True)
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    # One figure per event of interest; `surv` (last pass) is only used for its length.
    for si in range(len(surv)):
        plt.close()
        event_name = dm.decode([si + 1])
        if event_name in events_of_interest:
            for p_idx in range(len(_exp_values)):
                plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{_exp_values[p_idx][0]:.2f}")
            plt.xlabel("t (years)")
            plt.ylabel(f"$P(T>t)$ ({event_name})")
            # Title the legend so the bare numbers are identifiable as BMI values.
            plt.legend(title="BMI")
            plt.savefig(_out_dir + f"{event_name}.png")


## Values: How increasing DBP affects diagnosis risk

In [None]:
# How does the diastolic blood pressure value recorded at age 40 shift the predicted
# curve of each event of interest? One forward pass per DBP value.
events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5",
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF_V3", "ISCHAEMICSTROKE_V2",
                      "DEATH"
                     ]


_exp_prompt = ["Diastolic_blood_pressure_5"]
_exp_age = [40]
_exp_values = [[60.], [70.], [80.], [90.], [100.], [110.]]

# `savefig` does not create directories, so make sure the output folder exists.
_out_dir = save_path + "diastolic_blood_pressure/"
os.makedirs(_out_dir, exist_ok=True)

with torch.no_grad():
    model.eval()

    _exp_survs = []
    for p_idx, _exp_value in enumerate(_exp_values):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_generation=True)
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    # One figure per event of interest; `surv` (last pass) is only used for its length.
    for si in range(len(surv)):
        plt.close()
        event_name = dm.decode([si + 1])
        if event_name in events_of_interest:
            for p_idx in range(len(_exp_values)):
                plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{_exp_values[p_idx][0]:.2f}")
            plt.xlabel("t (years)")
            # The label previously read "P(T>t) ()" -- name the event, as in the BMI cell.
            plt.ylabel(f"$P(T>t)$ ({event_name})")
            plt.legend(title="Diastolic BP")
            plt.savefig(_out_dir + f"{event_name}.png")


## Values: How the preceding diagnosis affects the predicted DBP value

In [None]:
measurements_of_interest = "Diastolic_blood_pressure_5"

# Single-diagnosis prompts at age 20; diagnoses carry no measured value, hence NaN.
_exp_prompts = [["DEPRESSION"], ["TYPE2DIABETES"], ["HF_V3"], ["HYPERTENSION"]]
_exp_age = [20]
_exp_value = [np.nan]

with torch.no_grad():
    model.eval()

    for p_idx, _exp_prompt in enumerate(_exp_prompts):
        # Build the model inputs for this prompt.
        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        outputs, _, _ = model(
            _tokens,
            values=_values_scaled,
            ages=_ages_in_days,
            covariates=covariates,
            is_generation=True,
        )
        val_dist = outputs["values_dist"]

        # Predicted value distribution for the DBP token, reported as N(loc, scale).
        _dbp_token_id = dm.tokenizer._stoi[measurements_of_interest]
        dist = val_dist[model.value_layer.token_key(_dbp_token_id)]
        _prompt_label = "->".join(_exp_prompt)
        _summary = f"standardised {measurements_of_interest} ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})"
        print(_prompt_label.ljust(30) + "leads to".ljust(20) + _summary)


## Values: How increasing BMI affects the predicted diastolic blood pressure (DBP)

In [None]:
# How does the recorded BMI value shift the predicted (standardised) diastolic BP?
measurements_of_interest = "Diastolic_blood_pressure_5"


_exp_prompt = ["Body_mass_index_3"]
_exp_values = [[18.], [21.], [24.], [30.], [40.]]
# Defined explicitly instead of silently inheriting `_exp_age` from the previous cell;
# 20 is the value it held here when the notebook was run top-to-bottom.
_exp_age = [20]


with torch.no_grad():
    model.eval()

    for p_idx, _exp_value in enumerate(_exp_values):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_generation=True)
        val_dist = outputs["values_dist"]

        # Predicted value distribution for the DBP token, reported as N(loc, scale).
        dist = val_dist[model.value_layer.token_key(dm.tokenizer._stoi[measurements_of_interest])]
        print(f"{'->'.join(_exp_prompt)} of {_exp_value[0]}".ljust(30) + "leads to".ljust(20) + f"standardised {measurements_of_interest} ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")


## Baseline covariates: impact of sex

In [None]:
# How does the `sex` baseline covariate shift the predicted curve of each event of
# interest, for a fixed DBP prompt (90 at age 20)?
events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5",
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF_V3", "ISCHAEMICSTROKE_V2",
                      "POLYCYSTIC_OVARIAN_SYNDROME_PCOS_V2",
                      "DEATH"
                     ]

_genders = ["M", "F", "I"]
_exp_prompt = ["Diastolic_blood_pressure_5"]
_exp_age = [20]
_exp_value = [90.]

# `savefig` does not create directories, so make sure the output folder exists.
_out_dir = save_path + "gender/"
os.makedirs(_out_dir, exist_ok=True)

with torch.no_grad():
    model.eval()

    _exp_survs = []
    for p_idx, _gender in enumerate(_genders):

        # Only `sex` varies between passes; the other baseline covariates are fixed.
        _baseline_covariate = {"sex": _gender, "deprivation": 4.0, "ethnicity": "WHITE", "year_of_birth": 1997}
        _covariates = dm.train_set._encode_covariates(**_baseline_covariate).reshape(1,-1).to(device)
        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=_covariates,
                              is_generation=True)
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    # One figure per event of interest; `surv` (last pass) is only used for its length.
    for si in range(len(surv)):
        plt.close()
        event_name = dm.decode([si + 1])
        if event_name in events_of_interest:
            for p_idx in range(len(_genders)):
                plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{_genders[p_idx]}")
            plt.xlabel("t (years)")
            # The label previously read "P(T>t) ()" -- name the event, as in the BMI cell.
            plt.ylabel(f"$P(T>t)$ ({event_name})")
            plt.legend(title="Sex")
            plt.savefig(_out_dir + f"{event_name}.png")


# Appendix: model architecture

In [None]:
# Show the loaded model's repr (for a torch nn.Module this lists its submodules).
display(model)

In [None]:
# Export the notebook to HTML with code cells hidden (--no-input).
# NOTE(review): assumes this notebook is saved as generation.ipynb in the working directory -- confirm.
!jupyter nbconvert --to html --no-input generation.ipynb