# Generation Demo Notebook
## SurvivEHR: Competing-Risk Survival Transformer for Causal Sequence Modelling

In this notebook we demonstrate how a pre-trained model can be used to generate future patient timelines.

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

!pwd

%load_ext autoreload
%autoreload 2
%env SLURM_NTASKS_PER_NODE=28       # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/SurvivEHR/notebooks/CompetingRisk/0_pretraining
env: SLURM_NTASKS_PER_NODE=28       # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import logging
from pycox.evaluation import EvalSurv
from tqdm import tqdm
from hydra import compose, initialize
from omegaconf import OmegaConf
from contextlib import redirect_stdout
import pandas as pd
from tqdm import tqdm
import warnings

from CPRD.examples.modelling.SurvivEHR.run_experiment import run
from CPRD.examples.modelling.SurvivEHR.setup_causal_experiment import CausalExperiment
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

from FastEHR.dataloader import FoundationalDataModule
from FastEHR.database.collector import SQLiteDataCollector

torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')
warnings.simplefilter('error', np.VisibleDeprecationWarning)

logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cuda.


# Generation with SurvivEHR

In [3]:
def batch_to_sample(batch, idx, remove_masked=True):
    """
    Extract one patient from a collated batch, keeping a leading batch dimension of size 1.

    Args:
        batch: dict of tensors whose first dimension indexes patients.
        idx: row of the batch to extract.
        remove_masked: if True, drop every position whose token id is 0 (treated as padding) from the
            per-event sequences: "tokens", "ages", "values" and, when present, "attention_mask".

    Returns:
        A new dict with the same keys; the input dict is not modified.
    """
    # Index with a list (`[idx]`) rather than `idx` so the batch dimension is kept.
    sample = {key: tensor[[idx]] for key, tensor in batch.items()}

    if remove_masked:
        keep = sample["tokens"][0, :] != 0
        # The attention mask must be trimmed together with the event sequences; otherwise it keeps the
        # padded length and no longer lines up with tokens/ages/values.
        for key in ("tokens", "ages", "values", "attention_mask"):
            if key in sample:
                sample[key] = sample[key][:, keep]

    return sample

def clip_outliers(token, unstandardised_value, datamodule=None):
    """
    Because of heavy right tails in the value distributions, standardisation of some token values over-estimates the lower quantile.
    This method ensures no values beyond those seen in the data are reported.

    Args:
        token: event name, looked up in the "event" column of `meta_information["measurement_tables"]`.
        unstandardised_value: value on the original scale (anything `float()` accepts, e.g. a 0-d tensor).
        datamodule: object exposing `meta_information`; defaults to the notebook-level `dm`.

    Returns:
        The value clipped into the token's recorded [min, max] range (as a float). The input is returned
        unchanged when it is NaN or not numeric, when the token has no measurement row, or when the
        metadata lacks the needed columns.
    """
    # Conditional expressions only evaluate the chosen branch, so the global `dm` is only required
    # when no datamodule is passed.
    datamodule = dm if datamodule is None else datamodule

    try:
        value = float(unstandardised_value)
    except (TypeError, ValueError):
        return unstandardised_value
    if np.isnan(value):
        return unstandardised_value

    try:
        measurements = datamodule.meta_information["measurement_tables"]
        token_meta = measurements[measurements["event"] == token]
        if len(token_meta) == 0:
            return unstandardised_value
        # `token_meta["min"]` is a one-row column, not a scalar. Passing it to np.min/np.max alongside a
        # scalar builds a ragged array, which raises; the error used to be swallowed, so nothing was ever
        # clipped. Extract the scalar bounds explicitly instead.
        lower = float(np.asarray(token_meta["min"]).ravel()[0])
        upper = float(np.asarray(token_meta["max"]).ravel()[0])
    except KeyError:
        return unstandardised_value

    # Missing (NaN) bounds are ignored rather than turning the value into NaN.
    if not np.isnan(upper):
        value = min(value, upper)
    if not np.isnan(lower):
        value = max(value, lower)
    return value

def report_generation(static, tokens, ages, values, attention_mask, true_seq_len, dm, eos_token="DEATH", **kwargs):
    """
    Print a human-readable report of one patient's observed context and generated continuation.

    The report has three parts:
    - the decoded static covariates;
    - the observed context events;
    - once the `true_seq_len`-th event is printed, a summary of the diagnoses seen so far, followed by
      the predicted future events.

    Args:
        static: static covariates of one patient, decoded with `dm.test_set._decode_covariates`.
        tokens, ages, values, attention_mask: sequences with a leading batch dimension; only row 0 is
            reported.
        true_seq_len: number of events in the observed context.
        dm: data module providing the tokenizer, covariate decoding, `test_set.time_scale` and
            `unstandardise`.
        eos_token: event name after which the report stops.
        **kwargs: accepted and ignored.

    Output goes to stdout (callers redirect it to a file). Uses the module-level `clip_outliers`.
    """
    rule = "="*120
    column_titles = f"\tEVENT".ljust(75) + "| AGE IN WEEKS (days, years)".ljust(30)

    token_row, age_row = tokens[0, :], ages[0, :]
    value_row, mask_row = values[0, :], attention_mask[0, :]

    # Static covariates
    decoded_static = dm.test_set._decode_covariates(static.cpu())
    print("STATIC INFORMATION")
    print(rule)
    for covariate, decoded in decoded_static.items():
        print(f"\t{covariate}:".ljust(20) + f"{decoded[0]}")

    # Dynamic events
    event_names = dm.tokenizer.decode(token_row.tolist()).split(" ")
    context_diagnoses = []
    week_of_previous_event = 0
    print("\n\nGiven patient context".upper())
    print(rule)
    print(column_titles + " | VALUE")

    rows = zip(event_names, age_row, value_row, mask_row)
    for position, (event, scaled_age, standardised, attended) in enumerate(rows):
        if attended == 0:
            break  # reached padding

        # Undo the age scaling, then express the age at day / week / year resolution
        days = int(scaled_age * dm.test_set.time_scale)
        weeks, years = int(days / 7), int(days / 365)

        if weeks != week_of_previous_event:
            print("\t" + "."*60 + "new week" + "."*60)

        age_text = f"{weeks}\t ({days}, {years})"
        clipped = clip_outliers(event, dm.unstandardise(event, standardised))
        value_text = f"{clipped:.2f}".ljust(10) + f"({standardised:.2f})"
        print(f"\t{event.ljust(75)}| {age_text.ljust(30)}| {value_text}".ljust(20))

        # Event names written entirely in upper case are collected as diagnoses
        if event == event.upper():
            context_diagnoses.append(event)

        if position == true_seq_len - 1:
            # Last observed event: summarise the context and open the predictions section
            print("\n" + rule)
            print("Diagnosis summary".upper())
            print(f"{context_diagnoses}")
            print(rule)
            print("\n")
            print("Predicted future events".upper())
            print(rule)
            print(column_titles + "| VALUE")

        week_of_previous_event = weeks
        if event == eos_token:
            break

def log_generation(tokens, ages, values, attention_mask, observed_seq_lens, dm, eos_token="DEATH"):
    """
    Flatten generated timelines into event-transition records for plotting.

    For each patient the scan starts at the last observed (prompt) event. The first record is therefore
    the transition from the final context event into the first generated event.

    Args:
        tokens, ages, values, attention_mask: (batch, sequence) tensors as returned by `generate`; they are
            moved to the CPU here. `values` is accepted but not logged.
        observed_seq_lens: per-patient number of context events.
        dm: data module used to decode token ids (`dm.decode`) and unscale ages (`dm.test_set.time_scale`).
        eos_token: once this is the *previous* event of a transition, the patient's scan stops.

    Returns:
        list of `[previous_event, next_event, generation_step, age_in_years]` lists, where the age is that
        of the next event.
    """
    records = []

    for row in range(tokens.shape[0]):
        # Step back one position so the last context event is the first "previous" event.
        start = observed_seq_lens[row] - 1

        row_tokens = tokens[row, start:].cpu().numpy()
        row_ages = ages[row, start:].cpu().numpy()
        row_mask = attention_mask[row, start:].cpu().numpy()

        transitions = zip(row_tokens[:-1], row_tokens[1:], row_ages[1:], row_mask[1:])
        for step, (prev_id, next_id, next_age, next_attended) in enumerate(transitions):
            if next_attended == 0:
                break  # the next position is padding

            prev_event = dm.decode([prev_id])
            next_event = dm.decode([next_id])
            next_age_days = int(next_age * dm.test_set.time_scale)
            years_old = int(next_age_days / 365)

            if prev_event == eos_token:
                break  # nothing after the terminating event is recorded

            records.append([prev_event, next_event, step, years_old])

    return records

In [4]:
# Run ids of the pre-trained checkpoints to load, paired element-wise (via zip) with the Hydra config
# used to build each model.
pre_trained_models = [
    "SurvivEHR-cr-small-debug7_exp1000-v1-v4-v1",
]
config_names = [
    "config_CompetingRisk11M",
]

# Generate patient timelines for a handful of selected patients

In [None]:
datasets = [ "PreTrain", "FineTune_Hypertension", "FineTune_CVD", "FineTune_MultiMorbidity50+"]
# Rows of the first test batch to report, per dataset (None reports every row in the batch)
patients_of_interest = [None,[0],[10],[1]]
# Number of independent generations sampled from each prompt
number_of_generations = 5

for pre_trained_model, config_name in zip(pre_trained_models, config_names):
    os.makedirs(f"figs/generation/{pre_trained_model}/", exist_ok=True)

    # load the configuration file, override any settings
    with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
        cfg = compose(config_name=config_name,
                      overrides=[# Experiment setup
                                 f"experiment.run_id='{pre_trained_model}'",
                                 "experiment.train=False",
                                 "experiment.test=False",
                                 "experiment.log=False",
                                 # Dataloader
                                 "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
                                 "data.min_workers=3",
                                ]
                     )
    experiment, dm = run(cfg)
    # Sample with dropout (and any other train-only behaviour) disabled. This is a no-op if `run`
    # already returned the model in eval mode.
    experiment.eval()
    print(f"Loaded model {pre_trained_model} with {sum(p.numel() for p in experiment.parameters())/1e6} M parameters")

    for idx_dataset, dataset in enumerate(datasets):
        print(f"Generating patient timelines for dataset {dataset}")

        gen_save_path = f'figs/generation/{pre_trained_model}/{dataset}_dataset/'
        os.makedirs(gen_save_path, exist_ok=True)

        # Load dataset
        dm = FoundationalDataModule(path_to_db=cfg.data.path_to_db,
                                    path_to_ds=f"/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/{dataset}/",
                                    overwrite_meta_information=cfg.data.meta_information_path,
                                    load=True,
                                    supervised=False if dataset.lower()=="pretrain" else True,
                                    )

        # Only the first test batch is used. `next(iter(...))` raises on an empty loader rather than
        # silently reusing a stale `batch` left over from the previous dataset.
        batch = next(iter(dm.test_dataloader()))
        batch = {k: v.to(device) for k, v in batch.items()}

        # Optionally keep only one sample of the batch (shuffle is by default turned off in test data,
        # so this should align to dm.test_set[idx]):
        # batch = batch_to_sample(batch, 6, remove_masked=True)

        for idx_gen in range(number_of_generations):

            # Generate forward; no gradients are needed for sampling
            with torch.no_grad():
                tokens, ages, values, attention_mask = experiment.model.generate(
                    static_covariates=batch["static_covariates"],
                    tokens=batch['tokens'],
                    ages=batch['ages'],
                    values=batch['values'],
                    attention_mask=batch['attention_mask'],
                    max_new_tokens=50,
                    exceed_block_size=True,
                    )

            # Report generated outcomes, one text file per patient and generation
            patients = patients_of_interest[idx_dataset]
            if patients is None:
                patients = list(range(tokens.shape[0]))
            for idx_patient in tqdm(patients, ascii=True, desc=f"Reporting generation {idx_gen + 1} results for all patients in batch"):

                out_dir = gen_save_path + f'patient{idx_patient}/'
                os.makedirs(out_dir, exist_ok=True)
                with open(out_dir + f"generation{idx_gen}.txt", 'w') as f:
                    with redirect_stdout(f):
                        report_generation(
                            static         = batch["static_covariates"][idx_patient],
                            tokens         = tokens[[idx_patient],:],
                            ages           = ages[[idx_patient],:],
                            values         = values[[idx_patient], :],
                            attention_mask = attention_mask[[idx_patient], :],
                            true_seq_len   = batch["attention_mask"][[idx_patient],:].sum(),
                            dm             = dm
                        )


# Generate patient timelines for many test patients and condense them into CSV files for plotting

In [None]:
generations_per_patient = 5
datasets = [ "FineTune_MultiMorbidity50+", "FineTune_CVD", "FineTune_Hypertension", "PreTrain",]
# fraction_of_batches = 1.0  # 0.035
# Upper bound on the number of test batches processed per dataset
number_of_batches = 1000

for pre_trained_model, config_name in zip(pre_trained_models, config_names):
    os.makedirs(f"figs/generation/{pre_trained_model}/", exist_ok=True)

    # load the configuration file, override any settings
    with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
        cfg = compose(config_name=config_name,
                      overrides=[# Experiment setup
                                 f"experiment.run_id='{pre_trained_model}'",
                                 "experiment.train=False",
                                 "experiment.test=False",
                                 "experiment.log=False",
                                 # Dataloader
                                 "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
                                 "data.min_workers=3",
                                ]
                     )

    experiment, dm = run(cfg)
    # Sample with dropout (and any other train-only behaviour) disabled. This is a no-op if `run`
    # already returned the model in eval mode.
    experiment.eval()
    print(f"Loaded model {pre_trained_model} with {sum(p.numel() for p in experiment.parameters())/1e6} M parameters")

    for idx_dataset, dataset in enumerate(datasets):
        print(f"Generating patient timelines for dataset {dataset}")

        # Create save path
        gen_save_path = f'figs/generation/{pre_trained_model}/{dataset}_dataset/'
        os.makedirs(gen_save_path, exist_ok=True)

        # Load dataset
        dm = FoundationalDataModule(path_to_db=cfg.data.path_to_db,
                                    path_to_ds=f"/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/{dataset}/",
                                    overwrite_meta_information=cfg.data.meta_information_path,
                                    load=True,
                                    supervised=False if dataset.lower()=="pretrain" else True,
                                    )
        test_loader = dm.test_dataloader()

        # Cap the budget for this dataset only. Overwriting `number_of_batches` itself would shrink the
        # budget for every later dataset to the first dataset's loader length.
        # dataset_number_of_batches = int(fraction_of_batches * len(test_loader))
        dataset_number_of_batches = min(number_of_batches, len(test_loader))

        # Flat list of [previous event, next event, step, age] records across all batches/generations.
        # Built as plain lists (not np.concatenate) so integer columns stay integers and empty batches
        # do not break the concatenation.
        generated_records = []
        for idx_batch, batch in tqdm(enumerate(test_loader), total=dataset_number_of_batches):

            # `>=` so that exactly `dataset_number_of_batches` batches are processed (indices 0..N-1)
            if idx_batch >= dataset_number_of_batches:
                break

            # Put all items on gpu
            batch = {k: v.to(device) for k, v in batch.items()}

            for idx_gen in range(generations_per_patient):

                # Generate forward; no gradients are needed for sampling
                with torch.no_grad():
                    tokens, ages, values, attention_mask = experiment.model.generate(
                        static_covariates=batch["static_covariates"],
                        tokens=batch['tokens'],
                        ages=batch['ages'],
                        values=batch['values'],
                        attention_mask=batch['attention_mask'],
                        max_new_tokens=5,
                        exceed_block_size=True,
                        )

                # Log transitions for plotting
                observed_seq_lens = batch["attention_mask"].sum(axis=1).long().tolist()
                generated_records.extend(
                    log_generation(tokens, ages, values, attention_mask, observed_seq_lens, dm)
                )

        dataset_generated_data = pd.DataFrame(generated_records, columns=["Previous event", "Next event", "Generation step", "Age (years)"])
        dataset_generated_data.to_csv(gen_save_path + f'next_event_{dataset}.csv', index=False)

INFO:root:Running cr on 72 CPUs and 1 GPUs
INFO:root:# Loading DataModule for dataset /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/. This will be loaded in causal form.
INFO:root:Creating unsupervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 toke

Loaded model SurvivEHR-cr-small-debug7_exp1000-v1-v4-v1 with 11.20919 M parameters
Generating patient timelines for dataset FineTune_MultiMorbidity50+


INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_MultiMorbidity50+/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Set seed to 42
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel

Generating patient timelines for dataset FineTune_CVD


INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Set seed to 42
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/split=train/ dataset, with 5

Generating patient timelines for dataset FineTune_Hypertension


INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hypertension/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Set seed to 42
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_Hyper

Generating patient timelines for dataset PreTrain


INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Set seed to 42
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/split=train/ dataset, with 23,613,894 sample

In [None]:
# Keys of the most recently loaded batch (left over from the generation loops above).
batch_keys = batch.keys()
print(batch_keys)

In [None]:
# # dataset_generated_data = np.concatenate(dataset_generated_data)
# # dataset_generated_data = pd.DataFrame(dataset_generated_data, columns=["Previous event", "Next token", "Generation step", "Age (years)"])
# display(dataset_generated_data.head())
# print(len(dataset_generated_data))
# dataset_generated_data.to_csv(gen_save_path + 'next_event_data.csv', index=False)

In [None]:
# Display the current value of `number_of_batches` (set in the bulk-generation cell).
number_of_batches

In [None]:
# Deliberate stop, presumably so that "Run All" does not continue into the exploratory cells below.
# (Previously misspelt `NotImplementedErrror`, which only stopped execution by raising a NameError.)
raise NotImplementedError("Execution intentionally stopped here; the cells below are exploratory.")

In [None]:
# Convert the current batch into its supervised form and inspect its contents.
# (Alternative source batch: next(iter(dm.val_dataloader())).)
# NOTE: this rebinds `batch`, so re-running the cell converts an already-converted batch.
batch = dm.collate_fn.convert_to_supervised(batch, supervised_time_scale=1.0)
print(batch.keys())

batch_size, seq_len = batch['tokens'].shape
for shape_info in (batch_size, seq_len, batch["tokens"].shape):
    print(shape_info)

In [None]:
# What a standardised value of 0 maps back to on this measurement's original scale.
token = "25_Hydroxyvitamin_D2_level_92"
unstandardised_zero = dm.unstandardise(token, 0)
print(unstandardised_zero)

print()

In [None]:
import polars as pl

event_counts = dm.tokenizer._event_counts
print(event_counts["FREQUENCY"].sum())
print(event_counts)

# Use the frequency of ANXIETY as a threshold to split the vocabulary into more / less frequent events.
anxiety_freq = event_counts.filter(pl.col("EVENT") == "ANXIETY")["FREQUENCY"].item()
display(anxiety_freq)

# display([i for i in event_counts["EVENT"] if i.upper() == i ])
at_least_as_frequent = event_counts.filter(pl.col("FREQUENCY") >= anxiety_freq)
display(len(at_least_as_frequent) - 1)  # remove UNK token (assumes UNK falls in this group)
less_frequent = event_counts.filter(pl.col("FREQUENCY") < anxiety_freq)
display(len(less_frequent))

In [None]:
# Plot survs[0][0][0, :] against the model's evaluation time grid and save it to fig.png.
# NOTE(review): `survs` is not defined by any cell shown in this notebook, so this cell fails on a fresh
# kernel. Fail with an explicit message instead of a bare NameError.
if "survs" not in globals():
    raise NameError("`survs` is not defined: compute it in an earlier cell before running this plot.")

fig, ax = plt.subplots()
ax.plot(experiment.model.surv_layer.t_eval, survs[0][0][0, :])
ax.set_xlabel("t_eval")
ax.set_ylabel("survs[0][0][0, :]")
fig.savefig("fig.png")

# Check against the database

# Generate forward

In [None]:
# Open a connection to the CPRD SQLite database so the loaded/generated timelines can be cross-checked.
PATH_TO_DB = "/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/cprd.db"
collector = SQLiteDataCollector(db_path=PATH_TO_DB)
collector.connect()

In [None]:
# List (at most) three table names from the database schema.
collector.cursor.execute("""SELECT name FROM sqlite_master WHERE type='table' LIMIT 3;""")
results = collector.cursor.fetchall()
for table_row in results:
    print(table_row)

SEX                 | F
IMD                 | 4.0
ETHNICITY           | WHITE
birth_year          | 1933.0

In [None]:
# Patients whose static record matches the reported static information (sex, IMD, ethnicity, birth year).
collector.cursor.execute("""SELECT * FROM static_table WHERE sex=='F' AND imd=='4' AND ethnicity=='WHITE' AND YEAR_OF_BIRTH LIKE '1933-%'""")

results = collector.cursor.fetchall()
# Column index 1 of each returned row is taken as the patient id.
patient_ids = [row[1] for row in results]
patient_ids_str1 = ", ".join(map(str, patient_ids))
print(f"{len(patient_ids)} patients with static match")

In [None]:
# Among the static matches, keep only the patients who have every one of these diagnoses.
required_events = (
    'HYPERTENSION',
    'ANY_DEAFNESS_HEARING_LOSS_V2',
    'IHDINCLUDINGMI_OPTIMALV2',
    'OSTEOARTHRITIS',
    'TYPE2DIABETES',
)
required_events_sql = ", ".join(f"'{event}'" for event in required_events)

# The HAVING threshold is derived from the event list, so the two cannot drift apart when events are
# added or removed (it was previously a hard-coded 5).
query = f"""SELECT patient_id
            FROM diagnosis_table
            WHERE patient_id IN ({patient_ids_str1} )
            GROUP BY patient_id
            HAVING
                 COUNT(
                      DISTINCT CASE WHEN event IN ({required_events_sql})
                                    THEN event
                               END
                    ) = {len(required_events)}
            ORDER BY patient_id;
            """

collector.cursor.execute(query)
results = collector.cursor.fetchall()
patient_ids = [row[0] for row in results]
patient_ids_str2 = ", ".join(str(pid) for pid in patient_ids)
print(f"{len(patient_ids)} patients with static match and all of these events")

In [None]:
# Print every diagnosis of the matched patients, ordered by patient and then by date.
# Note: `patient_ids` gets one entry per returned row, so ids repeat once per diagnosis.
patient_ids = []
collector.cursor.execute(f"""SELECT * FROM diagnosis_table WHERE patient_id IN ({patient_ids_str2}) ORDER BY patient_id ASC, date ASC""")
results = collector.cursor.fetchall()
for diagnosis_row in results:
    print(diagnosis_row)
    patient_ids.append(diagnosis_row[1])


In [None]:
patient_ids_str = ", ".join(map(str, patient_ids))
# NOTE(review): `patient_ids_str` is not used by the query below, which is pinned to one patient id.
collector.cursor.execute(f"""SELECT * FROM measurement_Systolic_blood_pressure_4 WHERE patient_id == 2666145020970""")
results = collector.cursor.fetchall()
for measurement_row in results:
    print(measurement_row)


In [None]:
# Static record and BMI measurements for one specific patient, printed in that order.
single_patient_queries = (
    """SELECT * FROM static_table WHERE practice_id=='21573' AND patient_id=='6626432621573'""",
    """SELECT * FROM measurement_Body_mass_index_3 WHERE practice_id=='21573' AND patient_id=='6626432621573'""",
)
for sql in single_patient_queries:
    collector.cursor.execute(sql)
    results = collector.cursor.fetchall()
    for row in results:
        print(row)

## Build configurations

In [None]:
# Load the configuration file. The overrides mean the pre-trained model is only loaded: no training,
# testing or logging.
# NOTE: config_path must match the other cells in this notebook ("../../../confs"). It was previously
# "../../confs", which resolves to a different directory from the same working directory.
with initialize(version_base=None, config_path="../../../confs", job_name="testing_notebook"):
    cfg = compose(config_name="config_CompetingRisk11M",
                  overrides=[# Just load in pretrained model
                             "experiment.train=False",
                             "experiment.test=False",
                             "experiment.log=False",
                             "experiment.run_id='CR_11M'",
                             ]
                 )

print(OmegaConf.to_yaml(cfg))

# Checkpoint directory path for this run id.
save_path = f"/rds/projects/s/subramaa-mum-predict/CharlesGadd_Oxford/FoundationModelOutput/checkpoints/{cfg.experiment.run_id}/"

In [None]:
 # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed
%env SLURM_NTASKS_PER_NODE=28      

model, dm = run(cfg)     
print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")


In [None]:
# Inspect training sample 10, passing max_dynamic_events=None and report_time=True.
# (To inspect a collated batch instead: `batch = next(iter(dm.train_dataloader()))`.)
sample_index = 10
dm.train_set.view_sample(sample_index, max_dynamic_events=None, report_time=True)


# Generation from real prompts

In [None]:
# Encoding helpers: turn a hand-written prompt into (1, n) tensors on `device` (PEP 8: use `def` rather
# than assigning lambdas to names).
# TODO: add these wrappers to the datamodule.

def encode_prompt(prompt_list):
    """Encode event names with `dm.encode` into a (1, n) tensor of token ids on `device`."""
    return torch.from_numpy(np.array(dm.encode(prompt_list)).reshape((1,-1))).to(device)


def encode_value(prompt_list, value_list):
    """Standardise each value with `dm.standardise(event, value)`; returns a (1, n) float32 tensor on `device`."""
    standardised = [dm.standardise(_cat, _val) for _cat, _val in zip(prompt_list, value_list)]
    return torch.tensor(np.array(standardised).reshape((1,-1)), dtype=torch.float32).to(device)


def encode_age(age_list):
    """Convert ages in years to days (x365) as a (1, n) int64 tensor on `device` (fractional days are dropped by the int64 cast)."""
    return torch.tensor([365 * _age for _age in age_list], dtype=torch.int64).reshape((1,-1)).to(device)

def table(_tokens,_ages,_values):
    # print table rows 
    assert _tokens.shape[0] == 1
    assert _ages.shape[0] == 1
    assert _values.shape[0] == 1
    
    for _idx, (_cat, _age, _value) in enumerate(zip(dm.decode(_tokens[0].tolist()).split(" "), 
                                                    _ages[0, :], 
                                                    _values[0, :]
                                                    )
                                                ):
        _value = dm.unstandardise(_cat, _value)
        print(f"\t{_cat}".ljust(60) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({int(_age)} days)")    # with value {_value}



In [None]:
# Show the diagnosis and measurement event names available for prompting.
# Only a cell's last expression is displayed automatically, so display both lists explicitly.
display(dm.meta_information["diagnosis_table"]["event"].to_list())
display(dm.meta_information["measurement_tables"]["event"].to_list())

## Brute-force search: find prompts in the test dataset that satisfy the chosen criteria

In [None]:
# Conditions that a test patient is intended to have (all of them) to be selected as a prompt.
# Alternatives used previously: TYPE2DIABETES, TYPE1DM, HYPERTENSION, OSTEOARTHRITIS, CKDSTAGE3TO5,
# HF_V3, ISCHAEMICSTROKE_V2, DEPRESSION, ENDOMETRIOSIS_ADENOMYOSIS_V2.
indexing_conditions_to_pivot_on = ["POLYCYSTIC_OVARIAN_SYNDROME_PCOS_V2", "COPD"]

# Events that disqualify a patient if present.
# Example used previously: ["Statins", "Metformin_612_A10BD2", "Lipid_lowering_drugs_Optimal"].
exclude_on_events = []



In [None]:
# Encode the criteria once so each sample is checked against integer token ids.
indexing_token_to_pivot_on = dm.encode(indexing_conditions_to_pivot_on)
print(indexing_token_to_pivot_on)

tokens_to_exclude_on = dm.encode(exclude_on_events)
print(tokens_to_exclude_on)

MAX_EXAMPLES = 5          # stop the search once this many matching patients are found
MIN_SEQUENCE_LENGTH = 5   # require strictly more events than this

patients_satisfying_criteria = []
samples_satisfying_criteria = []

for _idx, sample in tqdm(enumerate(dm.test_set), total=len(dm.test_set)):

    if len(sample["tokens"]) <= MIN_SEQUENCE_LENGTH:
        continue

    # Count how many indexing conditions occur at least once. Summing the token *ids* (as before)
    # only coincidentally equals the number of conditions, so count matches instead.
    number_of_index_events = sum(1 for tkn in indexing_token_to_pivot_on if tkn in sample["tokens"])
    if number_of_index_events != len(indexing_token_to_pivot_on):
        continue

    # TODO: this excludes on events at any time; change to only events before the index event
    number_of_excluded_events = sum(1 for tkn in tokens_to_exclude_on if tkn in sample["tokens"])
    if number_of_excluded_events > 0:
        continue

    patients_satisfying_criteria.append(_idx)
    samples_satisfying_criteria.append(sample)
    print(len(patients_satisfying_criteria))

    if len(patients_satisfying_criteria) >= MAX_EXAMPLES:
        break

In [None]:
# Indices into dm.test_set of the patients selected above.
# (A list from an earlier search run was [724, 1760, 2055, 2099, 2167].)
print(patients_satisfying_criteria)

In [None]:
# For each selected patient, show their history just before, and then including, the index event,
# together with what SurvivEHR generates next from each truncated history.

# `dm.encode` may return a single id or a list of ids; normalise to a list of python ints.
_index_tokens = [int(_tkn) for _tkn in np.atleast_1d(indexing_token_to_pivot_on)]

for _patient_idx in patients_satisfying_criteria:
    print(_patient_idx)

    # Get the sample
    sample = dm.test_set[_patient_idx]

    # Index event: the point at which every indexing condition has been observed, i.e. the latest
    # of the first occurrences of each indexing token. Comparing the token tensor against the
    # whole list at once cannot locate a single position when there are several conditions.
    _first_occurrences = [
        (sample["tokens"] == _tkn).nonzero(as_tuple=True)[0][0].item()
        for _tkn in _index_tokens
    ]
    _index = max(_first_occurrences)
    _token = int(sample["tokens"][_index])   # index event token, named in the headings below

    # chunk by day: split strictly before the index day, then up to and including it
    # (assumes events are sorted by age — TODO confirm)
    _day_at_index = int(sample["ages"][_index])
    _index_pre = int((sample["ages"] < _day_at_index).sum())
    _index_inc = int((sample["ages"] <= _day_at_index).sum())

    for _phase, _split_at in enumerate([_index_pre, _index_inc]):

        if _phase == 0:
            print(f"\n\nBefore {dm.decode([_token]).lower()} is seen in the medical history")
        else:
            print('\n------------------------------------ page break ------------------------------------')
            print(f"\n\nAfter the diagnosis of {dm.decode([_token]).lower()} is then seen in the medical history")

        _covariates = sample["static_covariates"].reshape((1,-1))
        _tokens = sample["tokens"][:_split_at].reshape((1,-1))
        _ages = sample["ages"][:_split_at].reshape((1,-1))
        _values = sample["values"][:_split_at].reshape((1,-1))

        # Report the initial part of their historical context
        _dec_covariates = dm.train_set._decode_covariates(_covariates)
        print(f"\n\nMedical history of a \n\t" + \
                        f"{_dec_covariates['ETHNICITY'][0].lower()}, " + \
                        f"{'male' if _dec_covariates['SEX'][0] == 'M' else 'female'} patient, " + \
                        f"born in {int(_dec_covariates['birth_year'][0])}, " + \
                        f"with IMD (deprivation) level {int(_dec_covariates['IMD'][0])}. \n\n" 
              )
        table(_tokens, _ages, _values)

        # Predict the future and report only the newly generated events
        new_tokens, new_ages, new_values = model.generate(_tokens.to(device), _ages.to(device), _values.to(device), _covariates.to(device), max_new_tokens=20)
        print(f"""\nSurvivEHR then predicts the next events to be:
               """)
        table(new_tokens[:, _tokens.shape[1]:].reshape((1,-1)), 
              new_ages[:, _tokens.shape[1]:].reshape((1,-1)),
              new_values[:, _tokens.shape[1]:].reshape((1,-1))
             )

    print('\n----------------------------------------------------------------------------------------')
    print('------------------------------------ document break ------------------------------------')
    print('----------------------------------------------------------------------------------------')



In [None]:
# View the full record of the second selected patient.
_example_patient = patients_satisfying_criteria[1]
dm.test_set.view_sample(_example_patient, max_dynamic_events=None, report_time=True)

# Generation from fixed prompts

### Sampling from the model

In [None]:
model = model.to(device)

# Hypothetical patient used for the fixed-prompt experiments below.
baseline_covariates = {"sex": "F", "deprivation": 5.0, "ethnicity": "WHITE", "year_of_birth": 1997-65}

# Events added to the prompt one at a time, each at the paired age (in years).
multimorbidity_conditions = ["TYPE2DIABETES", "Metformin_612_A10BD2", "DEPRESSION", "IHDINCLUDINGMI_OPTIMALV2", "COPD"]
at_ages = [40, 40, 44, 49, 65]

# The covariates do not change as the prompt grows, so encode them once, outside the loop.
covariates = dm.train_set._encode_covariates(**baseline_covariates).reshape(1,-1).to(device)

prompt, ages_in_years, values = [], [], []

for condition, age in zip(multimorbidity_conditions, at_ages):
    # Extend the prompt with the next event (these events are given no value, i.e. np.nan)
    prompt.append(condition)
    ages_in_years.append(age)
    values.append(np.nan)

    # Convert for model
    tokens = encode_prompt(prompt)
    values_scaled = encode_value(prompt, values)
    ages_in_days = encode_age(ages_in_years)

    # generate: sample the next 10 tokens after the prompt
    new_tokens, new_ages, new_values = model.generate(tokens, ages_in_days, values_scaled, covariates, max_new_tokens=10)
    
    # report: prompt events first, then a separator, then the generated events
    print(f"Baseline covariates: \n{baseline_covariates}\n" + "="*90)
    print(f"PROMPT:")
    for _idx, (_cat, _age, _value) in enumerate(zip(dm.decode(new_tokens[0].tolist()).split(" "), 
                                                    new_ages[0, :], 
                                                    new_values[0, :]
                                                   )
                                               ):
        _value = dm.unstandardise(_cat, _value)
        print(f"{_cat}".ljust(50) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({int(_age)} days)")
        if _idx == tokens.shape[-1] - 1:
            print("="*90)
            print(f"GENERATION")



# Prompt testing

In [None]:
# Generate a longer continuation (50 new tokens) from the final prompt built in the previous cell
# (`tokens`, `ages_in_days`, `values_scaled`, `covariates` are reused from there).
new_tokens, new_ages, new_values = model.generate(tokens, ages_in_days, values_scaled, covariates, max_new_tokens=50)

# report: unlike the previous cell, values are printed as standardised model outputs
# (the dm.unstandardise call is intentionally not applied here).
print(f"Baseline covariates: \n{baseline_covariates}\n" + "="*90)
print(f"PROMPT:")
_event_names = dm.decode(new_tokens[0].tolist()).split(" ")
for _idx, (_cat, _age, _value) in enumerate(zip(_event_names, new_ages[0, :], new_values[0, :])):
    print(f"{_cat}".ljust(50) + f"{_value:.02f}".ljust(15) + f"at age {_age/365:.0f} ({int(_age)} days)")
    # the prompt occupies the first tokens.shape[-1] positions; everything after is generated
    if _idx == tokens.shape[-1] - 1:
        print("="*90)
        print(f"GENERATION")

## Diagnoses: How related conditions affect each other (multimorbidity)

In [None]:
# Survival curves for every event as conditions are added to the prompt one at a time.
# Uses `covariates` encoded in the fixed-prompt sampling cell above.
exp_prompts = [["TYPE2DIABETES", "Metformin_612_A10BD2"],
               ["TYPE2DIABETES", "Metformin_612_A10BD2", "DEPRESSION",],
               ["TYPE2DIABETES", "Metformin_612_A10BD2", "DEPRESSION", "IHDINCLUDINGMI_OPTIMALV2"],
               ["TYPE2DIABETES", "Metformin_612_A10BD2", "DEPRESSION", "IHDINCLUDINGMI_OPTIMALV2", "COPD"],
              ]
exp_promps_lbl = ["T2D+Metformin", "+ Depression", "+IHD/MI", "+COPD"]
exp_ages = [[40, 40],
            [40, 40, 44],
            [40, 40, 44, 49],
            [40, 40, 44, 49, 65],
           ]
# No values are attached to these events
exp_values = [[np.nan] * len(_prompt) for _prompt in exp_prompts]

# savefig does not create missing directories, so make sure the output folder exists.
_fig_dir = save_path + "multimorbidity/"
os.makedirs(_fig_dir, exist_ok=True)

with torch.no_grad(): 
    model.eval()

    _exp_survs = []
    for p_idx, (_exp_prompt, _exp_age, _exp_value) in enumerate(zip(exp_prompts, 
                                                                    exp_ages, 
                                                                    exp_values)):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)
        
        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_generation=True,
                              return_loss=False,
                              return_generation=True)
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    # One figure per event; event token ids are assumed to start at 1 (hence si + 1) — TODO confirm.
    # NOTE(review): the curves come from "surv_CDF" but are labelled P(T>t); confirm which quantity it is.
    for si, _ in enumerate(surv):
        plt.close()
        event_name = dm.decode([si + 1])
        for p_idx in range(len(exp_prompts)):
            plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{exp_promps_lbl[p_idx]}")
        plt.xlabel("Time (years)")
        plt.ylabel(f"$P(T>t)$ ({event_name})")
        plt.legend()
        plt.savefig(_fig_dir + f"{event_name}.png")


## Diagnoses: How a single prior condition or smoking status affects future risk


In [None]:
# Survival curves for every event given a single prior event at age 20.
# Uses `covariates` encoded in the fixed-prompt sampling cell above.
exp_prompts = [["DEPRESSION"], ["TYPE1DM"], ["TYPE2DIABETES"], ["Never_smoked_tobacco_85"], ["Ex_smoker_84"]]
exp_ages = [[20] for _ in range(len(exp_prompts))]
exp_values = [[np.nan] for _ in range(len(exp_prompts))]

# savefig does not create missing directories, so make sure the output folder exists.
_fig_dir = save_path + "diabetes/"
os.makedirs(_fig_dir, exist_ok=True)

with torch.no_grad(): 
    model.eval()

    _exp_survs = []
    for p_idx, (_exp_prompt, _exp_age, _exp_value) in enumerate(zip(exp_prompts, 
                                                                    exp_ages, 
                                                                    exp_values)):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)
        
        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_generation=True,
                              return_loss=False,
                              return_generation=True)
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    # One figure per event, one curve per prompt (event ids assumed to start at 1 — TODO confirm).
    for si, _ in enumerate(surv):
        plt.close()
        event_name = dm.decode([si + 1])
        for p_idx in range(len(exp_prompts)):
            plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{'->'.join(exp_prompts[p_idx]).lower()}")
        plt.xlabel("Time (years)")
        plt.ylabel(f"$P(T>t)$ ({event_name})")
        plt.legend()
        plt.savefig(_fig_dir + f"{event_name}.png")


## Values: How increasing BMI affects diagnosis risk

In [None]:
# Survival curves for selected events given a single BMI measurement at age 40, for several BMI values.
events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF_V3", "ISCHAEMICSTROKE_V2",
                      "DEATH"
                     ]

_exp_prompt = ["Body_mass_index_3"]
_exp_age = [40]
_exp_values = [[18.], [21.], [24.], [30.], [40.]]

# savefig does not create missing directories, so make sure the output folder exists.
_fig_dir = save_path + "bmi/"
os.makedirs(_fig_dir, exist_ok=True)

with torch.no_grad(): 
    model.eval()

    _exp_survs = []
    for p_idx, _exp_value in enumerate(_exp_values):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_causal=False,
                              return_loss=False,
                              return_generation=True)
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    # Only events listed in events_of_interest are plotted and saved.
    for si, _ in enumerate(surv):
        plt.close()
        event_name = dm.decode([si + 1])
        if event_name in events_of_interest:
            for p_idx in range(len(_exp_values)):
                plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{_exp_values[p_idx][0]:.2f}")
            plt.xlabel("t (years)")
            plt.ylabel(f"$P(T>t)$ ({event_name})")
            plt.legend()
            plt.savefig(_fig_dir + f"{event_name}.png")


## Values: How increasing DBP affects diagnosis risk

In [None]:
# Survival curves for selected events given a single diastolic blood pressure reading at age 40.
events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF_V3", "ISCHAEMICSTROKE_V2",
                      "DEATH"
                     ]

_exp_prompt = ["Diastolic_blood_pressure_5"]
_exp_age = [40]
_exp_values = [[60.], [70.], [80.], [90.], [100.], [110.]]

# savefig does not create missing directories, so make sure the output folder exists.
_fig_dir = save_path + "diastolic_blood_pressure/"
os.makedirs(_fig_dir, exist_ok=True)

with torch.no_grad(): 
    model.eval()

    _exp_survs = []
    for p_idx, _exp_value in enumerate(_exp_values):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_causal=False,
                              return_loss=False,
                              return_generation=True
                             )
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    for si, _ in enumerate(surv):
        plt.close()
        event_name = dm.decode([si + 1])
        if event_name in events_of_interest:
            for p_idx in range(len(_exp_values)):
                plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{_exp_values[p_idx][0]:.2f}")
            plt.xlabel("t (years)")
            # Name the event on the axis, consistent with the other survival figures.
            plt.ylabel(f"$P(T>t)$ ({event_name})")
            plt.legend()
            plt.savefig(_fig_dir + f"{event_name}.png")


## Values: How different diagnoses affect the predicted diastolic blood pressure (DBP)

In [None]:
# For each single-diagnosis prompt at age 20, report the model's predicted distribution for the
# next diastolic blood pressure value (in standardised units).
measurements_of_interest = "Diastolic_blood_pressure_5"

_exp_prompts = [["DEPRESSION"], ["TYPE2DIABETES"], ["HF_V3"], ["HYPERTENSION"]]
_exp_age = [20]
_exp_value = [np.nan]

with torch.no_grad():
    model.eval()

    for p_idx, _exp_prompt in enumerate(_exp_prompts):
        outputs, _, _ = model(
            encode_prompt(_exp_prompt),
            values=encode_value(_exp_prompt, _exp_value),
            ages=encode_age(_exp_age),
            covariates=covariates,
            is_causal=False,
            return_loss=False,
            return_generation=True,
        )
        val_dist = outputs["values_dist"]

        # Look up the predicted value distribution for the measurement of interest.
        _measurement_token = dm.tokenizer._stoi[measurements_of_interest]
        dist = val_dist[model.value_layer.token_key(_measurement_token)]

        _lhs = f"{'->'.join(_exp_prompt)}".ljust(30)
        _rhs = f"standardised {measurements_of_interest} ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})"
        print(_lhs + "leads to".ljust(20) + _rhs)


## Values: How increasing BMI affects the predicted diastolic blood pressure (DBP)

In [None]:
# For BMI prompts of increasing value, report the predicted distribution of the next diastolic
# blood pressure value (in standardised units).
measurements_of_interest = "Diastolic_blood_pressure_5"

_exp_prompt = ["Body_mass_index_3"]
# Set explicitly: previously this was silently inherited from the preceding cell (which sets [20]).
_exp_age = [20]
_exp_values = [[18.], [21.], [24.], [30.], [40.]]

with torch.no_grad(): 
    model.eval()

    for p_idx, _exp_value in enumerate(_exp_values):

        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)
        
        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=covariates,
                              is_causal=False,
                              return_loss=False,
                              return_generation=True
                             )
        val_dist = outputs["values_dist"]

        dist = val_dist[model.value_layer.token_key(dm.tokenizer._stoi[measurements_of_interest])]
        print(f"{'->'.join(_exp_prompt)} of {_exp_value[0]}".ljust(30) + "leads to".ljust(20) + f"standardised {measurements_of_interest} ~ N({dist.loc.item():.1f}, {dist.scale.item():.1f})")


## Baseline covariates: impact of sex

In [None]:
# Survival curves for selected events given a DBP reading of 90 at age 20, for each value of the
# "sex" baseline covariate.
events_of_interest = ["Body_mass_index_3", "Diastolic_blood_pressure_5", 
                      "TYPE1DM", "TYPE2DIABETES",
                      "HYPERTENSION", "OSTEOARTHRITIS",
                      "CKDSTAGE3TO5",
                      "HF_V3", "ISCHAEMICSTROKE_V2",
                      "POLYCYSTIC_OVARIAN_SYNDROME_PCOS_V2",
                      "DEATH",
                      "COCP_reg_contraception",
                      "all_contraceptive"
                     ]

_genders = ["M", "F", "I"]
_exp_prompt = ["Diastolic_blood_pressure_5"]
_exp_age = [20]
_exp_value = [90.]

# savefig does not create missing directories, so make sure the output folder exists.
_fig_dir = save_path + "gender/"
os.makedirs(_fig_dir, exist_ok=True)

with torch.no_grad(): 
    model.eval()

    _exp_survs = []
    for p_idx, _gender in enumerate(_genders):

        _baseline_covariate = {"sex": _gender, "deprivation": 4.0, "ethnicity": "WHITE", "year_of_birth": 1997}
        _covariates = dm.train_set._encode_covariates(**_baseline_covariate).reshape(1,-1).to(device)
        _tokens = encode_prompt(_exp_prompt)
        _values_scaled = encode_value(_exp_prompt, _exp_value)
        _ages_in_days = encode_age(_exp_age)

        outputs, _, _ = model(_tokens,
                              values=_values_scaled,
                              ages=_ages_in_days,
                              covariates=_covariates,
                              is_causal=False,
                              return_loss=False,
                              return_generation=True
                             )        
        surv = outputs["surv"]["surv_CDF"]
        _exp_survs.append(surv)

    for si, _ in enumerate(surv):
        plt.close()
        event_name = dm.decode([si + 1])
        if event_name in events_of_interest:
            for p_idx in range(len(_genders)):
                plt.plot(model.surv_layer.t_eval / 365, _exp_survs[p_idx][si][0, :], label=f"{_genders[p_idx]}")
            plt.xlabel("t (years)")
            # Name the event on the axis, consistent with the other survival figures.
            plt.ylabel(f"$P(T>t)$ ({event_name})")
            plt.legend()
            plt.savefig(_fig_dir + f"{event_name}.png")


# Appendix: model architecture

In [None]:
# Show the full module structure of the loaded model.
display(model)

In [None]:
# Export this notebook to HTML with code (input) cells hidden.
# Assumes the notebook is saved as 2_generation.ipynb in the current working directory.
!jupyter nbconvert --to html --no-input 2_generation.ipynb