# Testing the pre-trained Transformer + NN classifier for next-event prediction

This model uses the same Transformer input layer as SurvivEHR, but its head is a classifier that predicts logits over the next event.

In [9]:
import os
from pathlib import Path
import sys
# Put this node's virtual environment first on the import search path.
# The venv directory name is suffixed with the node type read from the BB_CPU environment variable.
node_type = os.getenv('BB_CPU')
if node_type is None:
    print("Environment variable 'BB_CPU' is not set; the node-specific virtual environment cannot be resolved.")
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    # Remove any earlier copy first so re-running this cell leaves exactly one entry, at the front.
    sys.path[:] = [entry for entry in sys.path if entry != str(venv_site_pkgs)]
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

# NOTE(review): presumably read by SLURM-aware code in the experiment pipeline — confirm where it is consumed.
%env SLURM_NTASKS_PER_NODE=28   

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
env: SLURM_NTASKS_PER_NODE=28


In [10]:
import torch
import logging
from hydra import compose, initialize

from setup_causal_mlp_experiment import CausalMLPExperiment, setup_mlp_experiment
from setup_causal_t_mlp_experiment import CausalTMLPExperiment, setup_t_mlp_experiment
from CPRD.examples.modelling.benchmarks.Pretrain.run_experiment import run

# Reproducibility and logging configuration shared by the rest of the notebook.
torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)

# Prefer the GPU when one is visible; set device = "cpu" by hand if more informative debugging statements are needed.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}.")

# Re-applies the value already set in the setup cell (redundant unless that cell was skipped).
%env SLURM_NTASKS_PER_NODE=28   

Using device: cuda.
env: SLURM_NTASKS_PER_NODE=28


# Create and run the pre-trained experiment

In [11]:
# Hydra overrides applied (in this order) on top of `config_CompetingRisk11M`:
# evaluate the pre-trained model on the test split only, without training.
pretrain_overrides = [
    "experiment.run_id='TransformerMLP'",
    "+static=False",
    "experiment.project_name='Evaluating pre-trained models'",
    "experiment.train=False",
    "experiment.test=True",
    "data.batch_size=128",
    "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
    "data.min_workers=12",
    # NOTE(review): presumably the fraction of test batches evaluated — confirm in the experiment runner.
    "optim.limit_test_batches=0.035",
]

with initialize(version_base=None, config_path="../../SurvivEHR/confs", job_name=""):
    cfg = compose(config_name="config_CompetingRisk11M", overrides=pretrain_overrides)

model, dm = run(cfg)
n_params_millions = sum(param.numel() for param in model.parameters()) / 1e6
print(f"Loaded model with {n_params_millions} M parameters")


INFO:root:Running cr on 72 CPUs and 1 GPUs
INFO:root:# Loading DataModule for dataset /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/. This will be loaded in causal form.
INFO:root:Creating unsupervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 toke

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: |          | 0/? [00:00<?, ?it/s]

Loaded model with 11.312201 M parameters


In [4]:
import matplotlib.pyplot as plt
import numpy as np
import wandb
import polars as pl
wandb.login()

# Load causal_eval results from log
api = wandb.Api()
run = api.run("cwlgadd/SurvivEHR/runs/1gyzrm07")

In [8]:
history = run.history(keys=None)
raw_data_from_wandb = {}
for key in history.keys():
    raw_data_from_wandb = {**raw_data_from_wandb, key: history[key].to_numpy()[-1]}
display(raw_data_from_wandb)



{'Test:_clfCintra42': 0.6221590909090909,
 'Test:_prevalenceCintra97': 0.36501901140684373,
 'Test:_clfCintra201': 0.3991222770007209,
 'Test:_clfCintra234': 0.8622548865171152,
 'Test:_prevalenceCintra128': 0.48288973384030853,
 'Test:_clfCintra64': 0.2610588645071404,
 'Test:_prevalenceCinter': 0.8595876432531979,
 'Test:_prevalenceCintra257': 0.9733840304180014,
 'Test:_prevalenceCintra248': 0.939163498099471,
 'Test:_clfCintra14': 0.47777777777777775,
 'Test:_clfCintra21': 0.35858585858585856,
 'Test:_clfCintra43': 0.5701368523949168,
 'Test:_prevalenceCintra227': 0.8593155893538502,
 'Test:_prevalenceCintra254': 0.9619771863128516,
 'Test:_prevalenceCintra28': 0.1026615969581749,
 'Test:_clfCintra157': 0.3487940556645585,
 'Test:_clfCintra130': 0.4214485095688359,
 'Test:_prevalenceCintra171': 0.6463878326995584,
 'Test:_prevalenceCintra9': 0.030418250950570335,
 'Test:_prevalenceCintra219': 0.8288973384034037,
 'Test:_clfCintra109': 0.8214413639945547,
 'Test:_prevalenceCintra133

In [6]:
# Split event names into three categories using the DataModule's meta information:
# measurement-table rows with count_obs > 0, rows with count_obs == 0, and the diagnosis table.
measurement_tables = dm.meta_information["measurement_tables"]
valued_events = measurement_tables[measurement_tables["count_obs"] > 0]["event"].to_list()
non_valued_events = measurement_tables[measurement_tables["count_obs"] == 0]["event"].to_list()
diagnoses = dm.meta_information["diagnosis_table"]["event"].to_list()

NameError: name 'dm' is not defined

In [54]:
# Map each per-event classifier C-index logged to W&B under "Test:_clfCintra<token>" to its decoded
# event name, split by event category. (Name kept for compatibility; these are per-event *intra* keys.)
CLF_CINTRA_PREFIX = "Test:_clfCintra"
Cinter_keys = [_key for _key in raw_data_from_wandb if _key.startswith(CLF_CINTRA_PREFIX)]

# Set lookups instead of repeated linear scans over the category lists.
diagnosis_names = set(diagnoses)
non_valued_names = set(non_valued_events)
valued_names = set(valued_events)

decoded_cintra_diagnoses = {}
decoded_cintra_non_valued = {}
decoded_cintra_valued = {}

for _key in Cinter_keys:
    _event = int(_key[len(CLF_CINTRA_PREFIX):])             # token
    _event_name = dm.decode([_event]).split(" ")[0]         # string
    _event_cintra = raw_data_from_wandb[_key]               # concordance

    if _event_name in diagnosis_names:
        decoded_cintra_diagnoses[_event_name] = _event_cintra
    elif _event_name in non_valued_names:
        decoded_cintra_non_valued[_event_name] = _event_cintra
    elif _event_name in valued_names:
        decoded_cintra_valued[_event_name] = _event_cintra
    else:
        raise NotImplementedError(
            f"Decoded event '{_event_name}' (token {_event}, key '{_key}') is not a diagnosis, "
            "non-valued or valued event."
        )


In [55]:
# Same mapping for the prevalence-baseline C-index ("Test:_prevalenceCintra<token>"), also recording
# each event's count from the tokenizer, which the plots use as the event's prevalence.
PREVALENCE_CINTRA_PREFIX = "Test:_prevalenceCintra"
BaseCinter_keys = [_key for _key in raw_data_from_wandb if _key.startswith(PREVALENCE_CINTRA_PREFIX)]

base_decoded_cintra_diagnoses = {}
base_decoded_cintra_non_valued = {}
base_decoded_cintra_valued = {}

base_prevalence_diagnoses = {}
base_prevalence_non_valued = {}
base_prevalence_valued = {}

# Loop-invariant lookups, fetched once.
# NOTE(review): `_event_counts` is a private tokenizer attribute and may change without notice.
event_counts = dm.tokenizer._event_counts
diagnosis_names = set(diagnoses)
non_valued_names = set(non_valued_events)
valued_names = set(valued_events)

for _key in BaseCinter_keys:
    _event = int(_key[len(PREVALENCE_CINTRA_PREFIX):])      # token
    _event_name = dm.decode([_event]).split(" ")[0]         # string
    _event_cintra = raw_data_from_wandb[_key]               # concordance

    matching_counts = event_counts.filter(pl.col("EVENT") == _event_name)["COUNT"]
    if len(matching_counts) == 0:
        raise KeyError(f"No tokenizer event count found for '{_event_name}' (key '{_key}').")
    prevalence = matching_counts[0]

    if _event_name in diagnosis_names:
        base_decoded_cintra_diagnoses[_event_name] = _event_cintra
        base_prevalence_diagnoses[_event_name] = prevalence
    elif _event_name in non_valued_names:
        base_decoded_cintra_non_valued[_event_name] = _event_cintra
        base_prevalence_non_valued[_event_name] = prevalence
    elif _event_name in valued_names:
        base_decoded_cintra_valued[_event_name] = _event_cintra
        base_prevalence_valued[_event_name] = prevalence
    else:
        raise NotImplementedError(
            f"Decoded event '{_event_name}' (token {_event}, key '{_key}') is not a diagnosis, "
            "non-valued or valued event."
        )


In [56]:
# Events that have both a TransformerMLP and a prevalence-baseline C-index, per category.
# Sorted so the order (and hence tie-breaking when sorting by prevalence for the plots) does not
# depend on per-process string-hash randomisation of set iteration order.
keys_included_diagnoses = sorted(base_decoded_cintra_diagnoses.keys() & decoded_cintra_diagnoses.keys())
keys_included_non_valued = sorted(base_decoded_cintra_non_valued.keys() & decoded_cintra_non_valued.keys())
keys_included_valued = sorted(base_decoded_cintra_valued.keys() & decoded_cintra_valued.keys())


In [67]:
# For each event category, plot per-event concordance of the prevalence baseline and of TransformerMLP
# side by side, with log-prevalence on a twin axis. Events are ordered by increasing prevalence and
# one PNG per category is written to figs/.
BAR_WIDTH = 0.25
os.makedirs("figs", exist_ok=True)

for dict_name, result_dict, result_dict_base, result_dict_prev, keys_to_include in zip(
    ["diagnoses", "medications", "measurements"],
    [decoded_cintra_diagnoses, decoded_cintra_non_valued, decoded_cintra_valued],
    [base_decoded_cintra_diagnoses, base_decoded_cintra_non_valued, base_decoded_cintra_valued],
    [base_prevalence_diagnoses, base_prevalence_non_valued, base_prevalence_valued],
    [keys_included_diagnoses, keys_included_non_valued, keys_included_valued],
):
    if not keys_to_include:
        # An empty category would give a zero-width figure and np.min/np.max on an empty array.
        print(f"No events with both concordances for '{dict_name}'; skipping figure.")
        continue

    # Order events by increasing prevalence (stable sort so ties keep their input order).
    log_prevalence = np.log([result_dict_prev[_key] for _key in keys_to_include])
    order = np.argsort(log_prevalence, kind="stable")
    keys_sorted = [keys_to_include[_i] for _i in order]
    Y_base = [result_dict_base[_key] for _key in keys_sorted]
    Y_survivEHR = [result_dict[_key] for _key in keys_sorted]
    Y_log_prevalence = log_prevalence[order]

    X_axis = np.arange(len(keys_sorted))
    fig, ax1 = plt.subplots(figsize=(max(len(keys_sorted) / 4, 4), 8))
    ax2 = ax1.twinx()

    # Bar pair centred on each tick.
    ax1.bar(X_axis - BAR_WIDTH / 2, Y_base, BAR_WIDTH, label = f'Concordance by prevalence (Average over events: {raw_data_from_wandb["Test:_prevalenceCinter"]:.3f})', color="mediumblue")
    ax1.bar(X_axis + BAR_WIDTH / 2, Y_survivEHR, BAR_WIDTH, label = f'Concordance by TransformerMLP (Average over events: {raw_data_from_wandb["Test:_clfCinter"]:.3f})', color="firebrick")
    # No bar width here: a third positional argument to plot() is parsed as an extra data series,
    # which drew a stray point at (0, 0.25) and a duplicate legend entry.
    ax2.plot(X_axis, Y_log_prevalence, label='Log-prevalence', color="darkseagreen", marker=".")

    ax1.set_xticks(X_axis, keys_sorted, rotation=90)
    ax1.set_xlabel("Events")
    ax1.set_ylabel("Self-supervised Concordance")
    ax2.set_ylabel("Log Prevalence")
    ax1.legend(loc="upper left")
    ax2.legend(loc="upper right")
    ax1.set_ylim(0, 1.2)
    ax2.set_ylim(np.min(Y_log_prevalence) * 0.95, np.max(Y_log_prevalence) * 1.1)

    fig.tight_layout()
    fig.savefig(f"figs/inter_causal_eval_{dict_name}.png", bbox_inches="tight")
    plt.close(fig)