# Create Figures for causal evaluation. 

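This notebook loads the metrics and tables that the test callbacks of pre-trained models logged to Weights & Biases, and turns them into the figures saved alongside this notebook.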

In [1]:
import os
from pathlib import Path
import sys
# Put the site-packages directory of the node-specific virtual environment at the front
# of the import search path. The node type is read from the BB_CPU environment variable
# and selects which virtual-environment directory is used.
node_type = os.environ.get('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir).joinpath(
    'lib',
    f'python{sys.version_info.major}.{sys.version_info.minor}',
    'site-packages',
)
if not venv_site_pkgs.exists():
    # Nothing is added to sys.path in this case; the message is the only signal.
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")
else:
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")

%load_ext autoreload
%autoreload 2

# Work from the paper_plots directory: the figures below are saved under relative file names,
# so they are written relative to the working directory set here.
# NOTE(review): the relative Hydra `config_path` used further down presumably also depends on
# this location — confirm.
os.chdir('/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/paper_plots')
print(os.getcwd())

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/paper_plots


In [2]:
import matplotlib.pyplot as plt
import numpy as np
import wandb
import polars as pl
import pandas as pd
from hydra import compose, initialize
import seaborn as sns
import json
import io
from matplotlib.colors import LogNorm, Normalize

from CPRD.examples.data.map_to_reduced_names import convert_event_names, EVENT_NAME_SHORT_MAP
from CPRD.examples.modelling.SurvivEHR.run_experiment import run

%env SLURM_NTASKS_PER_NODE=28   

# Global seaborn theme for every figure produced below: "ticks" axes style, "notebook" scaling context.
sns.set(style="ticks", context="notebook")


env: SLURM_NTASKS_PER_NODE=28


## Initialise the dataloader used for pre-training

In [3]:
# Compose the Hydra configuration of the pre-trained model and override the settings this notebook needs.
# NOTE(review): train/test/log are all switched off, so `run` is presumably only used here to
# build the model and the datamodule — confirm against `run_experiment.run`.
with initialize(version_base=None, config_path="../SurvivEHR/confs", job_name="causal_metric_testing_notebook"):
    cfg = compose(config_name="config_CompetingRisk11M", 
                  overrides=[# Experiment setup
                             "experiment.run_id=SurvivEHR-cr-small-debug7_exp1000-v1-v4-v1",
                             "experiment.train=False",
                             "experiment.test=False",
                             "experiment.log=False",
                             # Data setup (presumably the source of `dm.meta_information` read by the cells below — confirm)
                             "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
                             "data.min_workers=12",
                            ]
                 )     

# `run` returns the model and the datamodule; the datamodule (`dm`) is what the cells below use
# to decode tokens and to look up event metadata.
model, dm = run(cfg)
# Report the total parameter count in millions.
print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Loaded model with 11.20919 M parameters


/rds/bear-apps/2022a/EL8-ice/software/PyTorch-Lightning/2.1.0-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


## Define functions

In [4]:
def get_artifact_data(run_id):
    """
    Load the most recently logged value of every key in a wandb run's history.

    run_id: wandb run path passed to ``wandb.Api().run`` (the calls in this notebook use
            the form "entity/project/run id")

    Returns a dict mapping each history key to its latest logged value. The history table
    holds missing values (NaN/None) for steps at which a key was not logged, so the last
    non-missing entry of each column is used instead of blindly taking the final row.
    A key that is missing on every row keeps the value of the final row.
    """

    wandb.login()

    # Load causal_eval results from log
    wandb_run = wandb.Api().run(run_id)
    history = wandb_run.history(keys=None)

    raw_data_from_wandb = {}
    for key in history.keys():
        column = history[key]
        logged = column.dropna()
        # Fall back to the raw column when nothing was ever logged for this key.
        source = logged if len(logged) > 0 else column
        raw_data_from_wandb[key] = source.to_numpy()[-1]

    return raw_data_from_wandb
    
# Load the logged values of one run as a quick check of the loader.
# NOTE(review): `data` is not read again in the cells visible here.
data = get_artifact_data("cwlgadd/Evaluating pre-trained models/3omvr6q0")
# print(data)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcwlgadd[0m. Use [1m`wandb login --relogin`[0m to force relogin


# Plot self-supervised concordance stratified by event type

In [5]:
def _get_stratified_concordance(raw_data_from_wandb, method, dm):
    """
    Load logged IEC scores from wandb, and split them by the event types using the datamodule.
    
    raw_data_from_wandb: loaded raw wandb logging data, obtained through the causal metric callback, 
    mehod:               a flag which helps find what the IEC score was logged as
    dm:                  FastEHR datamodule
    """

    # Get the subgroups of tokens we want to plot by
    lab_names = dm.meta_information["measurement_tables"][dm.meta_information["measurement_tables"]["count_obs"] > 0]["event"].to_list()
    medication_names = dm.meta_information["measurement_tables"][dm.meta_information["measurement_tables"]["count_obs"] == 0]["event"].to_list()
    diagnosis_names = dm.meta_information["diagnosis_table"]["event"].to_list()
    
    valid_keys = [_key for _key in raw_data_from_wandb.keys() if f"Test:CausalMetrics{method}_stratify_by_" in _key ]

    diagnoses, medications, labs = {}, {}, {}
    for _key in valid_keys:

        # Get the token 
        _event_token = int(_key[len(f"Test:CausalMetrics{method}_stratify_by_"):])                 # token
        _event_name = dm.decode([_event_token]).split(" ")[0]         # string
        
        _event_concordance = {_event_name: raw_data_from_wandb[_key]}
    
        if _event_name in diagnosis_names:
            diagnoses.update(_event_concordance)
        elif _event_name in medication_names:
            medications.update(_event_concordance)
        elif _event_name in lab_names:
            labs.update(_event_concordance)
        else:
            raise NotImplementedError

    return [diagnoses, medications, labs]

def get_event_stratified_concordance(run_names, run_wandb_paths, dm):
    """
    Build a table of inter-event concordance per method and next observed event.

    For each provided wandb run, the logged values are fetched with ``get_artifact_data`` and,
    for each risk strategy ("surv", "clf", "prevalence"), split into diagnoses, medications and
    lab measurements with ``_get_stratified_concordance``. Event names are then converted to
    their short plotting names, and events sharing a short name are merged using a
    count-weighted mean of their concordance.

    run_names:        names of the runs; only used to pair up with ``run_wandb_paths``
    run_wandb_paths:  wandb run paths passed to ``get_artifact_data``
    dm:               datamodule; ``dm.tokenizer._event_counts`` is filtered on its "EVENT"
                      column and its "COUNT" column is used as the merge weight

    Returns a pandas DataFrame with columns "Method", "True next event category",
    "Next observed event", "Count" and "Concordance".
    """

    category_labels = ("diagnoses", "medications", "lab measurements")
    counts_table = dm.tokenizer._event_counts

    rows = []
    for _, path in zip(run_names, run_wandb_paths):
        logged = get_artifact_data(path)

        for strategy in ("surv", "clf", "prevalence"):
            # Scores split by the true next event, in the order diagnoses / medications / labs
            by_category = _get_stratified_concordance(logged, strategy, dm)

            for scores, category in zip(by_category, category_labels):
                for event, score in scores.items():
                    n_obs = counts_table.filter(pl.col("EVENT") == event)["COUNT"].item()
                    rows.append([strategy, event, n_obs, category, score])

    # Identical rows coming from different runs are collapsed into one
    scores_df = pd.DataFrame(
        rows,
        columns=["Method", "True next event", "Count", "True next event category", "Concordance"],
    ).drop_duplicates()

    # Descriptive method labels used in the figures
    scores_df["Method"] = scores_df["Method"].replace({
        "surv": "SurvivEHR",
        "prevalence": "Prevalence",
        "clf": "Cross-entropy",
    })

    # Short names used on the plot axes
    scores_df["Next observed event"] = convert_event_names(scores_df["True next event"], format_to="short")

    # Tokens that share a short name are merged: sum of count * concordance over sum of counts
    scores_df["w_conc"] = scores_df["Concordance"] * scores_df["Count"]
    merged = scores_df.groupby(
        ["Method", "True next event category", "Next observed event"],
        as_index=False,
    ).agg(
        Count=("Count", "sum"),
        Concordance_sum=("w_conc", "sum"),
    )
    merged["Concordance"] = merged["Concordance_sum"] / merged["Count"]

    return merged.drop(columns="Concordance_sum")
    

def plot_event_stratified_concordance(results):
    """
    Save one horizontal grouped bar chart of inter-event concordance per event category.

    results: DataFrame with the columns "True next event category", "Method",
             "Next observed event" and "Concordance"

    For every category a figure is written to ``causal_concordance_<category>.png`` and then
    closed. Events are ordered by increasing concordance of the "Prevalence" method.
    Nothing is returned.
    """
    method_order = ["SurvivEHR", "Cross-entropy", "Prevalence"]

    for category_plot in results["True next event category"].unique():
        subset = results.loc[results["True next event category"] == category_plot]

        # Events sorted by the prevalence baseline, lowest first
        is_prevalence = subset["Method"] == 'Prevalence'
        event_order = subset.loc[is_prevalence].sort_values('Concordance')['Next observed event'].tolist()

        # Figure height scales with the number of table rows (one row per method and event)
        n_rows = len(subset)
        fig, axis = plt.subplots(1, 1, figsize=(5, n_rows / 14), constrained_layout=True)

        sns.barplot(
            data=subset,
            x="Concordance",
            y="Next observed event",
            hue="Method",
            order=event_order,
            hue_order=method_order,
            ax=axis,
        )

        axis.set_xlabel("Inter-event concordance")
        axis.legend(loc='lower center', bbox_to_anchor=(0.0, 1.02),
                    ncol=1, frameon=True, fontsize=10)

        # Grid lines are drawn behind the bars and only between events:
        # major y grid (through the bar centres) is off, minor ticks sit at the half-way positions.
        axis.set_axisbelow(True)
        axis.grid(False, axis='y', which='major')
        axis.set_yticks(np.arange(-0.5, len(event_order), 1), minor=True)
        axis.grid(axis='y', which='minor',
                  linestyle='-', color='grey', alpha=0.5, linewidth=0.8)
        axis.tick_params(axis='y', which='minor', left=False, labelleft=False)

        fig.savefig(f"causal_concordance_{category_plot}.png")
        plt.close(fig)


In [6]:
# Collect the stratified concordance of both pre-trained runs and save one figure per event category.
results = get_event_stratified_concordance(
    run_names=["SurvivEHR", "TransformerMLP"],
    run_wandb_paths=[
        "cwlgadd/Evaluating pre-trained models/3omvr6q0",
        "cwlgadd/Evaluating pre-trained models/jvm6o7pu",
    ],
    dm=dm,
)

# Debug aid: print(results[results["Next observed event"]=="INR"])

plot_event_stratified_concordance(results)

# Plot next-event matrix

In [7]:
def _get_stratified_next_event_matrix(wandb_path, dm, events_of_interest=None):
    """
    Load the logged "Test:NextEventMatrixTruth" table of a wandb run as a labelled matrix.

    wandb_path:         wandb run path passed to ``wandb.Api().run``
    dm:                 datamodule; ``dm.tokenizer._itos`` maps the token in each
                        "next_event_<token>" column name to its event name
    events_of_interest: optional collection of (short) event names; when given, only the rows
                        and columns with these labels are kept

    Returns a pandas DataFrame whose rows and columns are labelled with short event names
    (``EVENT_NAME_SHORT_MAP``, falling back to the full name). Rows/columns that end up with the
    same label are summed. The row labels are taken from the column labels, so the logged
    table is assumed to be square with rows in the same order as its columns.

    Raises RuntimeError if the wandb download returns neither a path nor a text file handle.
    """

    # Load data from artifact
    api = wandb.Api()
    wandb_run = api.run(wandb_path)

    history = wandb_run.history(keys=["Test:NextEventMatrixTruth"])
    table_dict = history["Test:NextEventMatrixTruth"].iloc[-1]

    # Download may return either a local filepath (str) or a file-handle (TextIOWrapper)
    downloaded = wandb_run.file(table_dict["path"]).download(replace=True)

    if isinstance(downloaded, str):
        # downloaded is the path to the .json file
        with open(downloaded, "r") as f:
            table_json = json.load(f)
    elif isinstance(downloaded, io.TextIOBase):
        # downloaded is already an open file-handle; close it once it has been read
        with downloaded:
            table_json = json.load(downloaded)
    else:
        raise RuntimeError(f"Unexpected return type from download(): {type(downloaded)}")

    # Now rebuild the DataFrame
    df = pd.DataFrame(table_json["data"], columns=table_json["columns"])

    # Extract numeric matrix (drop "Prior event" if it exists)
    if "Prior event" in df.columns:
        df = df.loc[:, df.columns != "Prior event"]

    # Column "next_event_<token>" -> event name -> short plotting name (full name if unmapped)
    df.columns = [dm.tokenizer._itos[int(col[len("next_event_"):])] for col in df.columns]
    df.index = [EVENT_NAME_SHORT_MAP[col] if col in EVENT_NAME_SHORT_MAP else col for col in df.columns]
    df = df.rename(columns=EVENT_NAME_SHORT_MAP)

    # Merge events that share a short name. Columns are grouped through a transpose because
    # ``groupby(..., axis=1)`` is deprecated in pandas.
    df = df.T.groupby(level=0).sum().T
    df = df.groupby(df.index).sum()

    print(df)

    # Drop columns and rows
    if events_of_interest is not None:
        rows_to_keep = df.index.intersection(events_of_interest)
        cols_to_keep = df.columns.intersection(events_of_interest)
        df = df.loc[rows_to_keep, cols_to_keep]

    return df


def plot_stratified_next_event_matrix(run_names, run_wandb_paths, dm, events_of_interest=None, save_name="next_event.png"):
    """
    Save a log-scaled heatmap of the logged next-event matrix for each wandb run.

    run_names:          names of the runs; only used to pair up with ``run_wandb_paths``
    run_wandb_paths:    wandb run paths passed to ``_get_stratified_next_event_matrix``
    dm:                 datamodule forwarded to ``_get_stratified_next_event_matrix``
    events_of_interest: optional event names restricting the rows and columns that are plotted
    save_name:          file name the figure is saved under

    Nothing is returned.

    NOTE(review): every run is saved under the same ``save_name``, so with several runs only
    the figure of the last run remains on disk.
    """

    for run_name, wandb_path in zip(run_names, run_wandb_paths):

        df = _get_stratified_next_event_matrix(wandb_path, dm, events_of_interest=events_of_interest)

        fig, axis = plt.subplots(1, 1, figsize=(10, 10), constrained_layout=True)

        sns.heatmap(df, xticklabels=True, yticklabels=True, norm=LogNorm(), ax=axis)

        axis.tick_params(axis="both", which="major", labelsize=8)

        # The matrix columns come from the logged "next_event_<token>" columns and its rows
        # from the prior events; the heatmap draws columns along x and rows along y.
        axis.set_xlabel("Next event")
        axis.set_ylabel("Prior event")

        axis.grid()

        fig.savefig(save_name)
        plt.close(fig)


In [8]:
# display([i for i in dm.tokenizer._event_counts["EVENT"]])
# print(dm.tokenizer._itos)

# events_of_interest = ["CCB",
#                       "ACEI ",
#                       "Diuretic",
#                       "Diuretics excluding lactones ",
#                       "Lipid lowering drug",
#                       "Antiplatelet ",
#                       "Cholesterol:HDL ratio",
#                       # "Plasma_LDL_cholesterol_level_104",
#                       "HbA1c",
#                       "Urea",
#                       "Antipsychotics",
#                       "Benzodiazepines",
#                       "Propanolol",
#                       "T2DM",
#                       "Hypertension"
#                      ]

# Short plotting names of all diagnosis events; a name without a short form is kept unchanged
# rather than raising a KeyError.
# NOTE(review): the loop in the following cell rebinds `events_of_interest`.
events_of_interest = [EVENT_NAME_SHORT_MAP[col] if col in EVENT_NAME_SHORT_MAP else col for col in dm.meta_information["diagnosis_table"]["event"].to_list()]

In [9]:
# Split the event vocabulary into the three groups plotted separately:
# measurement-table events with count_obs > 0 (labs), with count_obs == 0 (medications),
# and the events of the diagnosis table.
lab_names = (
    dm.meta_information["measurement_tables"]
    [dm.meta_information["measurement_tables"]["count_obs"] > 0]
    ["event"]
    .to_list()
)
medication_names = (
    dm.meta_information["measurement_tables"]
    [dm.meta_information["measurement_tables"]["count_obs"] == 0]
    ["event"]
    .to_list()
)
diagnosis_names = dm.meta_information["diagnosis_table"]["event"].to_list()

# One heatmap file per group. The group name (including the "diagoses" spelling) is part of
# the output file name.
for name, events_of_interest in (
    ("lab_measurements", lab_names),
    ("medications", medication_names),
    ("diagoses", diagnosis_names),
):

    # Use the short plotting name where one exists, otherwise keep the full name
    events_of_interest = [col if col not in EVENT_NAME_SHORT_MAP else EVENT_NAME_SHORT_MAP[col] for col in events_of_interest]

    fig = plot_stratified_next_event_matrix(
        run_names=["SurvivEHR", "TransformerMLP"],
        run_wandb_paths=[
            "cwlgadd/Evaluating pre-trained models/3omvr6q0",
            "cwlgadd/Evaluating pre-trained models/jvm6o7pu",
        ],
        dm=dm,
        events_of_interest=events_of_interest,
        save_name=f"next_event_pre_train_{name}.png",
    )


             ACEI   AF  ALP   ALT   ARB   AST  Acarbose  Addison's  Albumin  \
ACEI         79.0  1.0  2.0   1.0   2.0  10.0       0.0        0.0    156.0   
AF            0.0  0.0  0.0   0.0   0.0   0.0       0.0        0.0      0.0   
ALP           5.0  0.0  4.0   0.0   0.0   5.0       0.0        0.0      7.0   
ALT           0.0  0.0  0.0   1.0   0.0   0.0       0.0        0.0      0.0   
ARB          12.0  0.0  1.0   0.0  21.0   5.0       0.0        0.0     18.0   
...           ...  ...  ...   ...   ...   ...       ...        ...      ...   
Warfarin      1.0  0.0  0.0   0.0   0.0   2.0       0.0        0.0      1.0   
Weak opiate   0.0  0.0  0.0   0.0   0.0  26.0       0.0        0.0      0.0   
Weight       12.0  0.0  0.0   0.0   6.0   2.0       0.0        0.0      1.0   
eGFR         79.0  0.0  0.0  48.0   0.0   2.0       0.0        0.0     27.0   
urine ACR     0.0  0.0  0.0   0.0   0.0   0.0       0.0        0.0      0.0   

             Alcohol misuse  ...  Visual impairment

## Extrapolating into the future

In [23]:
def _get_projection_to_future_data(raw_data_from_wandb, method):

    valid_keys = [_key for _key in raw_data_from_wandb.keys() if f"LookAheadMetrics{method}_no_stratify" in _key and "Test:+" in _key ]
    if len(valid_keys) == 0:
        return None

    results = []
    for _key in valid_keys:
        k_ahead= int(_key[6:-len(f"_LookAheadMetrics{method}_no_stratify")])
        results.append([method, k_ahead, raw_data_from_wandb[_key]])

    return results
    
def get_projection_to_future_data(run_names, run_wandb_paths, dm):
    """
    Plot the look-ahead concordance of each method against the number of steps ahead.

    For each wandb run the logged values are fetched with ``get_artifact_data``, and the
    look-ahead scores of each risk strategy ("surv", "clf", "prevalence") are collected with
    ``_get_projection_to_future_data``. Duplicate rows are dropped and only look-aheads below
    15 steps are plotted. The line plot is saved as ``look_ahead.png``; nothing is returned.

    run_names:        names of the runs; only used to pair up with ``run_wandb_paths``
    run_wandb_paths:  wandb run paths passed to ``get_artifact_data``
    dm:               datamodule; not used by this function
    """

    rows = []
    for _, wandb_path in zip(run_names, run_wandb_paths):

        # Load data from artifact
        raw_data_from_wandb = get_artifact_data(wandb_path)

        for risk_strategy in ("surv", "clf", "prevalence"):
            strat_results = _get_projection_to_future_data(raw_data_from_wandb, risk_strategy)
            if strat_results is None:
                # Nothing was logged for this strategy in this run
                continue
            rows.extend(strat_results)

    # Descriptive method labels, then collapse rows that are identical across runs
    results = (
        pd.DataFrame(data=rows, columns=["Method", "Look ahead by", "Concordance"])
        .replace("surv", "SurvivEHR")
        .replace("prevalence", "Prevalence")
        .replace("clf", "Cross-entropy")
        .drop_duplicates()
    )
    results = results[results["Look ahead by"] < 15]

    fig, axis = plt.subplots(1, 1, figsize=(5, 2.5), constrained_layout=True)

    sns.lineplot(
        data=results,
        x="Look ahead by",
        y="Concordance",
        hue="Method",
        hue_order=["SurvivEHR", "Cross-entropy", "Prevalence"],
        lw=3,
        legend=False,
        ax=axis,
    )

    axis.set_ylabel("Marginal lookahead IEC")
    axis.set_xlabel("Number of steps ahead")
    axis.grid()

    fig.savefig("look_ahead.png")
    plt.close(fig)
        

# Produce the look-ahead figure for both pre-trained runs.
get_projection_to_future_data(
    run_names=["SurvivEHR", "TransformerMLP"],
    run_wandb_paths=[
        "cwlgadd/Evaluating pre-trained models/3omvr6q0",
        "cwlgadd/Evaluating pre-trained models/jvm6o7pu",
    ],
    dm=dm,
)