# Causal evaluation of the pre-trained SurvivEHR foundation model


In this notebook we evaluate the ability of SurvivEHR to perform next-event prediction after pre-training.

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")


Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import wandb
from hydra import compose, initialize
import polars as pl
pl.Config.set_tbl_rows(10000)
# import pandas as pd
# pd.options.display.max_rows = 10000
import logging
logging.basicConfig(level=logging.INFO)
import torch
torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')

from FastEHR.dataloader.foundational_loader import FoundationalDataModule
from CPRD.examples.modelling.SurvivEHR.run_experiment import run
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}.")

 # TODO: define an env variable to fix for a local hpc environment issue, this shouldn't be needed
%env SLURM_NTASKS_PER_NODE=28   

INFO:numexpr.utils:Note: detected 72 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
INFO:numexpr.utils:Note: NumExpr detected 72 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.


Using device: cuda.
env: SLURM_NTASKS_PER_NODE=28


## Choosing configurations
The default configuration is for pre-training. Here we modify as necesssary

Here we choose to load in the configuration for a small **pre-trained** 11.4M parameter model, named "CR_11M". We specfiy the `zero-shot` experiment type, which will lead to running a ```CausalExperiment```. 

We tell this experiment that no further training is needed. Additionally, we do choose to perform testing (true by default). As this is a supervised model, this tests the ability to predict the outcomes of interest. In this notebook, this is chosen to be those of the cohort study for predicting Cardiovascular Disease in a Type 2 Diabetes Mellitus population, and we add the folder containing this dataset to the configuration. 

```Note: As this is a supervised dataset, we need to tell the DataModule that the last event observed is a target and must be stripped. This is done by passing a list of targets to the configuration, overriding the null default. This lets the DataModule know that it should process batches as supervised.```

We set the number of workers to be appropriate for the number of CPUs available to reduce bottlenecking, and tell the experiment that we do not want to limit the number of testing batches. In addition, we specify where we want any checkpoints to be saved to avoid bloating the repository.

# Run small (11M) Competing-Risk model experiment

```

```

In [4]:
# pre_trained_model_ids = ['SurvivEHR-cr-small', 'SurvivEHR-cr-small-v1', 'SurvivEHR-cr', 'SurvivEHR-cr-v1', 'SurvivEHR-cr-v1-v1', 'SurvivEHR-cr-384', 'SurvivEHR-cr-384-v1', 'crPreTrain_small_1337',
                        # 'SurvivEHR-cr-small-192', "SurvivEHR-cr-small-192-v1"]


pre_trained_model, config_name = "SurvivEHR-cr-small-debug7_exp1000-v1-v4-v1", "config_CompetingRisk11M"
# pre_trained_model, config_name = "SurvivEHR-cr-small-debug7_exp1000-v1-v4", "config_CompetingRisk11M"
# pre_trained_model, config_name = "SurvivEHR-cr-big-debug3_2_exp1000-v1", "config_CompetingRiskMOTOR"

print(pre_trained_model)

SurvivEHR-cr-small-debug7_exp1000-v1-v4-v1


In [5]:
wandb.finish()

# load the configuration file, override any settings 
with initialize(version_base=None, config_path="../../../confs", job_name="causal_metric_testing_notebook"):
    cfg = compose(config_name=config_name, 
                  overrides=[# Experiment setup
                             "experiment.project_name='Evaluating pre-trained models'",
                             f"experiment.run_id='{pre_trained_model}'",
                             "experiment.train=False",
                             "experiment.test=True",
                             "data.batch_size=128",
                             "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
                             "data.min_workers=12",
                             "optim.limit_test_batches=0.035",
                            ]
                 )     

model, dm = run(cfg)
print(f"Loaded model with {sum(p.numel() for p in model.parameters())/1e6} M parameters")

wandb.finish()

INFO:root:Running cr on 72 CPUs and 1 GPUs
INFO:root:# Loading DataModule for dataset /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/. This will be loaded in causal form.
INFO:root:Creating unsupervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 toke

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.trainer.connectors.signal_connector:SLURM auto-requeueing enabled. Setting signal handlers.


Testing: |          | 0/? [00:00<?, ?it/s]

Loaded model with 11.20919 M parameters


0,1
Test:+0_LookAheadMetricsprevalence_no_stratify,▁
Test:+0_LookAheadMetricssurv_no_stratify,▁
Test:+10_LookAheadMetricsprevalence_no_stratify,▁
Test:+10_LookAheadMetricssurv_no_stratify,▁
Test:+13_LookAheadMetricsprevalence_no_stratify,▁
Test:+13_LookAheadMetricssurv_no_stratify,▁
Test:+16_LookAheadMetricsprevalence_no_stratify,▁
Test:+16_LookAheadMetricssurv_no_stratify,▁
Test:+19_LookAheadMetricsprevalence_no_stratify,▁
Test:+19_LookAheadMetricssurv_no_stratify,▁

0,1
Test:+0_LookAheadMetricsprevalence_no_stratify,0.85492
Test:+0_LookAheadMetricssurv_no_stratify,0.98672
Test:+10_LookAheadMetricsprevalence_no_stratify,0.8498
Test:+10_LookAheadMetricssurv_no_stratify,0.81958
Test:+13_LookAheadMetricsprevalence_no_stratify,0.84834
Test:+13_LookAheadMetricssurv_no_stratify,0.78075
Test:+16_LookAheadMetricsprevalence_no_stratify,0.85047
Test:+16_LookAheadMetricssurv_no_stratify,0.77056
Test:+19_LookAheadMetricsprevalence_no_stratify,0.85183
Test:+19_LookAheadMetricssurv_no_stratify,0.75025


In [6]:
wandb.finish()
print(pre_trained_model)

SurvivEHR-cr-small-debug7_exp1000-v1-v4-v1


In [7]:
# display(dm.encode(['IHDINCLUDINGMI_OPTIMALV2', 'ISCHAEMICSTROKE_V2', 'MINFARCTION', 'STROKEUNSPECIFIED_V2', 'STROKE_HAEMRGIC']))
# display(dm.encode(['HYPERTENSION']))
# # display(dm.decode([95, 175, 263,249]).split(" "))

In [8]:
# dm.tokenizer._event_counts["EVENT"][-5:].to_list()

In [9]:
# valued_events = dm.meta_information["measurement_tables"][dm.meta_information["measurement_tables"]["count_obs"] > 0]["event"].to_list()
# non_valued_events = dm.meta_information["measurement_tables"][dm.meta_information["measurement_tables"]["count_obs"] == 0]["event"].to_list()
# diagnoses = dm.meta_information["diagnosis_table"]["event"].to_list()

# # print(valued_events)

In [10]:
print(pre_trained_model)
os.makedirs(f"figs/metrics/{pre_trained_model}/", exist_ok=True) 

SurvivEHR-cr-small-debug7_exp1000-v1-v4-v1


In [26]:
import matplotlib.pyplot as plt
import numpy as np
import wandb
import polars as pl
wandb.login()

# Load causal_eval results from log
api = wandb.Api()
run = api.run("cwlgadd/Evaluating pre-trained models/2qttxkwm") # 3omvr6q0

In [27]:
history = run.history(keys=None)
raw_data_from_wandb = {}
for key in history.keys():
    raw_data_from_wandb = {**raw_data_from_wandb, key: history[key].to_numpy()[-1]}
# display(raw_data_from_wandb.sort())



# Get dataloader so we can extract event names by type

In [28]:
valued_events = dm.meta_information["measurement_tables"][dm.meta_information["measurement_tables"]["count_obs"] > 0]["event"].to_list()
non_valued_events = dm.meta_information["measurement_tables"][dm.meta_information["measurement_tables"]["count_obs"] == 0]["event"].to_list()
diagnoses = dm.meta_information["diagnosis_table"]["event"].to_list()

# Next event concordance

In [29]:
Cinter_keys = [_key for _key in raw_data_from_wandb.keys() if "Test:CausalMetricssurv_stratify_by_" in _key ]

decoded_cintra_diagnoses = {}
decoded_cintra_non_valued = {}
decoded_cintra_valued = {}

for _key in Cinter_keys:
    _event = int(_key[len("Test:CausalMetricssurv_stratify_by_"):])                 # token
    _event_name = dm.decode([_event]).split(" ")[0]         # string
    _event_cintra = raw_data_from_wandb[_key]               # concordance

    if _event_name in diagnoses:# .upper() == _event_name:
        decoded_cintra_diagnoses = {**decoded_cintra_diagnoses, _event_name: _event_cintra}
    elif _event_name in non_valued_events:
        decoded_cintra_non_valued = {**decoded_cintra_non_valued, _event_name: _event_cintra}
    elif _event_name in valued_events:
        decoded_cintra_valued = {**decoded_cintra_valued, _event_name: _event_cintra}
    else:
        raise NotImplementedError


In [30]:
# display(decoded_cintra_diagnoses)
# display(decoded_cintra_other)

In [31]:
BaseCinter_keys = [_key for _key in raw_data_from_wandb.keys() if "Test:CausalMetricsprevalence_stratify_by_" in _key ]

base_decoded_cintra_diagnoses = {}
base_decoded_cintra_non_valued = {}
base_decoded_cintra_valued = {}

base_prevalence_diagnoses = {}
base_prevalence_non_valued = {}
base_prevalence_valued = {}


for _key in BaseCinter_keys:
    _event = int(_key[len("Test:CausalMetricsprevalence_stratify_by_"):])                 # token
    _event_name = dm.decode([_event]).split(" ")[0]         # string
    _event_cintra = raw_data_from_wandb[_key]               # concordance

    prevalence = dm.tokenizer._event_counts
    prevalence = prevalence.filter(pl.col("EVENT") ==_event_name)["COUNT"][0]

    if _event_name in diagnoses: #.upper() == _event_name:
        base_decoded_cintra_diagnoses = {**base_decoded_cintra_diagnoses, _event_name: _event_cintra}
        base_prevalence_diagnoses = {**base_prevalence_diagnoses, _event_name: prevalence}
        
    elif _event_name in non_valued_events:
        base_decoded_cintra_non_valued = {**base_decoded_cintra_non_valued, _event_name: _event_cintra}
        base_prevalence_non_valued = {**base_prevalence_non_valued, _event_name: prevalence}

    elif _event_name in valued_events:
        base_decoded_cintra_valued = {**base_decoded_cintra_valued, _event_name: _event_cintra}
        base_prevalence_valued = {**base_prevalence_valued, _event_name: prevalence}

    else:
        raise NotImplementedError


In [32]:
# display(base_decoded_cintra_diagnoses)

In [33]:
keys_included_diagnoses = list(set(base_decoded_cintra_diagnoses.keys()) & set(decoded_cintra_diagnoses.keys()))
keys_included_non_valued = list(set(base_decoded_cintra_non_valued.keys()) & set(decoded_cintra_non_valued.keys()))
keys_included_valued = list(set(base_decoded_cintra_valued.keys()) & set(decoded_cintra_valued.keys()))


In [34]:
for dict_name, result_dict, result_dict_base, result_dict_prev, keys_to_include in zip(["diagnoses", "medications", "measurements"],
                                                                     [decoded_cintra_diagnoses, decoded_cintra_non_valued, decoded_cintra_valued], 
                                                                     [base_decoded_cintra_diagnoses, base_decoded_cintra_non_valued, base_decoded_cintra_valued], 
                                                                     [base_prevalence_diagnoses, base_prevalence_non_valued, base_prevalence_valued],
                                                                     [keys_included_diagnoses, keys_included_non_valued, keys_included_valued]
                                                                     ):
    plt.close()
    # plt.figure(figsize=(len(keys_to_include)/5,5))
    fig, ax1 = plt.subplots(figsize=(len(keys_to_include)/4,8))
    ax2 = ax1.twinx()  

    X_axis = np.arange(len(keys_to_include)) 

    Y_base = [result_dict_base[_key] for _key in keys_to_include]
    Y_survivEHR = [result_dict[_key] for _key in keys_to_include]
    Y_log_prevalence = [np.log(result_dict_prev[_key]) for _key in keys_to_include]

    # Sort by prevalence
    arg_sort = np.argsort(Y_log_prevalence)
    Y_base = [Y_base[_i] for _i in arg_sort]
    Y_survivEHR = [Y_survivEHR[_i] for _i in arg_sort]
    Y_log_prevalence = [Y_log_prevalence[_i] for _i in arg_sort]
    keys_to_include = [keys_to_include[_i] for _i in arg_sort]

    width = 0.25
    ax1.bar(X_axis - width, Y_base, width, label = f'Concordance by prevalence (Average over events: {raw_data_from_wandb["Test:CausalMetricsprevalence_no_stratify"]:.3f})', color="mediumblue") 
    ax1.bar(X_axis, Y_survivEHR, width, label = f'Concordance by SurvivEHR (Average over events: {raw_data_from_wandb["Test:CausalMetricssurv_no_stratify"]:.3f})', color="firebrick") 
    ax2.plot(X_axis, Y_log_prevalence, width, label='Log-prevalence', color="darkseagreen", marker=".")  #  + width

    ax1.set_xticks(X_axis, keys_to_include, rotation=90) 
    # ax1.xticks(X_axis, keys_to_include) 
    ax1.set_xlabel("Events") 
    ax1.set_ylabel("Self-supervised Concordance") 
    ax2.set_ylabel("Log Prevalence") 
    ax1.legend(loc="upper left")
    ax2.legend(loc="upper right")
    ax1.set_ylim(0, 1.2)
    ax2.set_ylim(np.min(Y_log_prevalence)*0.95, np.max(Y_log_prevalence)*1.1)
    

    # plt.bar(result_dict.keys(), result_dict.values(), 0.5, color='g')
    # ax1.xticks()

    # ybar = raw_data_from_wandb["Test:Cinter"]
    # ax1.plot([0, len(result_dict)-1], 
    #          [ybar, ybar],
    #          label=f"SurvivEHR marginalised over events",
    #          color="firebrick")

    # ybar = raw_data_from_wandb["Test:base_Cinter"]
    # ax1.plot([0, len(result_dict)-1], 
    #          [ybar, ybar],
    #          label=f"Prevalence marginalised over events",
    #          color="mediumblue")
    
    plt.tight_layout()
    plt.savefig(f"figs/metrics/{pre_trained_model}/inter_causal_eval_{dict_name}.png", bbox_inches="tight")
    plt.close()

# Future events

## SurvivEHR

In [20]:
Cinter_keys = [_key for _key in raw_data_from_wandb.keys() if "+" in _key and "prevalence" not in _key ]

x_survivEHR, y_survivEHR = [], []
for _key in Cinter_keys:
    x_survivEHR.append(int(_key[6:-33]) + 1 )                # steps ahead
    y_survivEHR.append(raw_data_from_wandb[_key] )           # concordance


arg_sort = np.argsort(x_survivEHR)
x_survivEHR = [x_survivEHR[_i] for _i in arg_sort]
y_survivEHR = [y_survivEHR[_i] for _i in arg_sort]

print(x_survivEHR)
print(y_survivEHR)

[1, 2, 3, 4, 5, 8, 11, 14, 17, 20]
[0.9867206756214528, 0.9235438435297677, 0.9036353028748456, 0.8843464116617102, 0.8586765504680968, 0.8083891604804155, 0.8195837363045009, 0.7807488841130765, 0.7705623374024421, 0.7502532407475381]


In [21]:
Cinter_keys = [_key for _key in raw_data_from_wandb.keys() if "+" in _key and "prevalence" in _key ]
# print(Cinter_keys)

x_base, y_base = [], []
for _key in Cinter_keys:
    x_base.append(int(_key[6:-len("_LookAheadMetricsprevalence_no_stratify")]) + 1 )                # steps ahead
    y_base.append(raw_data_from_wandb[_key] )                   # concordance

arg_sort = np.argsort(x_base)
x_base = [x_base[_i] for _i in arg_sort]
y_base = [y_base[_i] for _i in arg_sort]

print(x_base)
print(y_base)

[1, 2, 3, 4, 5, 8, 11, 14, 17, 20]
[0.8549209927966827, 0.8554057884602225, 0.8562878467821434, 0.856383584919344, 0.8553855002076877, 0.8501297603959204, 0.8498032617479898, 0.848335813082052, 0.8504674233111297, 0.8518284716003348]


In [22]:
plt.close()

plt.plot(x_survivEHR, y_survivEHR,
         label=f"SurvivEHR",
         color="firebrick")

plt.plot(x_base, y_base,
         label=f"Prevalence",
         color="mediumblue")

plt.xticks(x_survivEHR)
plt.ylabel("Self-supervised multi-step concordance")
plt.xlabel("Number of steps ahead")
plt.legend()
plt.tight_layout()
plt.savefig(f"figs/metrics/{pre_trained_model}/inter_decay.png", bbox_inches="tight")
    

In [23]:
# new_dictionary

In [24]:
import copy
start = time.time()   # starting time
for batch in dm.train_dataloader():
    # print(batch["tokens"][1,:])
    
    c_batch = convert_batch_to_none_causal(batch)
    # print(c_batch["tokens"][1,:])
    # print(c_batch["target_token"][1])

    # print(batch["tokens"][1,:])
    
    break
    
print(f"batch loaded in {time.time()-start} seconds")    
    
# for key in batch.keys():
#     print(f"{key}".ljust(20) + f"{batch[key].shape}")

# tokens = batch["tokens"][0].tolist()    
# sentence = dm.decode(tokens).split(" ")
# for token, value in zip(sentence, batch["values"][0].tolist()):
#     print(f"{token}:".ljust(40) + f"{value}")

NameError: name 'time' is not defined

In [None]:
display(batch.keys())
display(c_batch.keys())

print(batch["static_covariates"].shape)

# print(dm.train_set.static_1hot)
# print(dm.train_set.static_1hot["SEX"].categories_)
# print(dm.train_set.static_1hot["IMD"].categories_)
# print(dm.train_set.static_1hot["ETHNICITY"].categories_)

print(batch["tokens"][1,:])
print(c_batch["tokens"][1,:])
print(c_batch["target_token"][1])

## View an example sample

In [None]:
dm.test_set.view_sample(11003, max_dynamic_events=None, report_time=True)

# Custom wrapper prediction last token

To begin with, I will just loop over samples individually to test the zero-shot capacity of SurvivEHR. 

In [None]:


# Verifying on datamodule 
for _idx, batch in enumerate(dm.test_dataloader()):
    if _idx > 10:
        break
    print(_idx)
    print(torch.stack([batch["tokens"][10,:5], 
                       batch["values"][10,:5],  
                       batch["ages"][10,:5],
                       batch["attention_mask"][10,:5]]))
    batch = replace_last_non_pad_with_pad(batch)
    print(torch.stack([batch["tokens"][10,:5], 
                       batch["values"][10,:5],  
                       batch["ages"][10,:5],
                       batch["attention_mask"][10,:5]]))

In [None]:
outcome_of_interest = ["COPD", "SUBSTANCEMISUSE"]
outcome_token = dm.encode(outcome_of_interest)[0]
print(outcome_token)
# print(model(batch))

In [None]:
Hs, labels = [], []
mins,maxes=[],[]
for _idx, batch in enumerate(dm.test_dataloader()):

    batch = replace_last_non_pad_with_pad(batch)
    print(batch["tokens".shape)
    outputs, _, hidden_states = model(batch, is_generation=True)
    print(outputs)
    
    hidden_states = hidden_states.cpu().detach().numpy()                           # (64, 128, 384) 
    Hs.append( hidden_states.reshape(hidden_states.shape[0], -1) )
    labels.append((batch["target_token"] == outcome_token).long().numpy())

    if _idx == 9:
        break



# Visualise hidden dimension labelled by target

In [None]:
import umap
from sklearn.preprocessing import StandardScaler

H = np.concatenate(Hs, 0)
lbl = np.concatenate(labels, 0)

H = StandardScaler().fit_transform(H)
reducer = umap.UMAP()
H_proj = reducer.fit_transform(H)

plt.close()
plt.scatter(H_proj[:,0], H_proj[:,1], c=lbl)
plt.savefig(save_path + f"zero_shot/hidden_umap.png")

In [None]:
print(outputs["surv"]["surv_CDF"][outcome_token].shape)

# The first two tokens in the vocab correspond to the PAD and UNK tokens. There is no CDF corresponding to the PAD token, so the indexing for surv_CDF begins as ["UNK", "ADDISONS_DISEASE", ...]
# print(dm.decode([0,1,2]))

outcomes = ["COPD", "SUBSTANCEMISUSE"]
outcome_tokens = dm.encode(outcomes)

# for outcome in outcomes:
    # observed_outcome_token = dm.encode([outcome])[0]
cdf = np.zeros_like(outputs["surv"]["surv_CDF"][0])
lbls = np.zeros_like(batch["target_token"])

for _outcome_token in outcome_tokens:
    cdf += outputs["surv"]["surv_CDF"][_outcome_token - 1] 
    lbls += (batch["target_token"] == _outcome_token).long().numpy()

plt.close()
cdf_true = cdf[lbls==1,:]
cdf_false = cdf[lbls==0,:]
for i in range(cdf_true.shape[0]):
    plt.plot(np.linspace(1,1826,1826), cdf_true[i,:], c="r", label="outcome occurred next" if i == 0 else None, alpha=1)
for i in range(cdf_false.shape[0]):
    plt.plot(np.linspace(1,1826,1826), cdf_false[i,:], c="k", label="outcome did not occur next" if i == 0 else None, alpha=0.3)

plt.legend(loc=2)
plt.xlabel("days")
plt.ylabel(f"P(t>T) - outcomes={','.join(outcomes)}")
plt.savefig(save_path + f"zero_shot/cdf_outcomes.png")

In [None]:
print(batch["target_token"].unique())
print(len(outputs["surv"]["surv_CDF"]))

In [None]:
dm.decode([2])

In [None]:
outputs["surv"]["surv_CDF"][observed_outcome_token - 1]

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import wandb
import polars as pl
wandb.login()

# Load causal_eval results from log
api = wandb.Api()
run = api.run("cwlgadd/Evaluating pre-trained models/1felbu63")

history = run.history(keys=None)
raw_data_from_wandb = {}
for key in history.keys():
    raw_data_from_wandb = {**raw_data_from_wandb, key: history[key].to_numpy()[-1]}



In [None]:
import wandb
import json
import pandas as pd
import io

history = run.history(keys=["Test:NextEventMatrixTruth"])
table_dict = history["Test:NextEventMatrixTruth"].iloc[-1]

print("Available keys in table_dict:", list(table_dict.keys()))
print("table_dict['path'] =", table_dict["path"])

# 2) Download may return either a local filepath (str) or a file‐handle (TextIOWrapper)
file_obj = run.file(table_dict["path"])
downloaded = file_obj.download(replace=True)

# 3) If `downloaded` is a string, that's the path on disk. If it's TextIOWrapper, it’s already open.
if isinstance(downloaded, str):
    # downloaded is the path to the .json file
    with open(downloaded, "r") as f:
        table_json = json.load(f)
elif isinstance(downloaded, io.TextIOBase):
    # downloaded is already an open file‐handle
    table_json = json.load(downloaded)
else:
    raise RuntimeError(f"Unexpected return type from download(): {type(downloaded)}")

# 4) Now rebuild the DataFrame
rows = table_json["data"]
cols = table_json["columns"]
df = pd.DataFrame(rows, columns=cols)

# 5) Extract numeric matrix (drop "Prior event" if it exists)
if "Prior event" in df.columns:
    matrix = df.loc[:, df.columns != "Prior event"].to_numpy()
else:
    matrix = df.to_numpy()

print("Recovered matrix shape:", matrix.shape)
print(matrix)