In [None]:
from pathlib import Path
import tensorflow as tf
import numpy as np
from pathlib import Path

from config import DATA_ROOT
from config_targets import TARGET_META

from utils import import_key_data
from ecg_predict import batched_predict_all
from eval import evaluate_all_payloads, evaluate_explanation
from explainer import run_fused_pipeline_for_classes
from selection import build_selection_df_with_aliases, build_y_true_from_labels

# Enable IPython's autoreload extension in mode 2 so that edits to the imported
# project modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# All artefact paths are resolved against the parent of the working directory.
ROOT = Path.cwd().parent

# Locations of the trained network, the class-name array and the selection CSV.
MODEL_PATH = ROOT.joinpath("model", "resnet_final.keras")
SNOMED_CLASSES_NPY = ROOT.joinpath("data", "snomed_classes.npy")
SEL_DF_CSV = ROOT.joinpath("data", "ecg_model_pred_data.csv")

# Load the class-name array and the Keras model (loaded with compile=False).
class_names = np.load(SNOMED_CLASSES_NPY, allow_pickle=True)
model = tf.keras.models.load_model(MODEL_PATH, compile=False)

In [None]:
# Run the model over every ECG found under DATA_ROOT and persist the inputs,
# ground truth and predictions as .npy files in the working directory.
gender, age, labels, ecg_filenames = import_key_data(DATA_ROOT)

# Make sure class_names matches the model's output order.
# Load from the configured SNOMED_CLASSES_NPY path (the same file used when the
# model was loaded) instead of a second, cwd-relative copy that could drift.
class_names = np.load(SNOMED_CLASSES_NPY, allow_pickle=True).astype(str)

# 1) Build ground-truth multi-hot labels
y_true = build_y_true_from_labels(labels, class_names)

# 2) Predict probabilities
probs = batched_predict_all(
    model,
    ecg_filenames,
    maxlen=5000,
    batch_size=32,
)

# 3) Optional: binary predictions (0/1) at some threshold
pred_threshold = 0.5
y_pred = (probs >= pred_threshold).astype(np.int8)

# 4) Save everything (the threshold is encoded in the y_pred file name)
np.save("ecg_filenames.npy", ecg_filenames)
np.save("ecg_model_probs.npy", probs)
np.save("ecg_y_true.npy", y_true)
np.save(f"ecg_y_pred_{pred_threshold:.2f}.npy", y_pred)


In [None]:
# Selection parameters. They are named once so that the output file name below
# always reflects the values actually passed to the selection routine.
SEL_K_PER_CLASS = 3
SEL_MIN_PROB = 0.85

# Reload the arrays saved by the prediction step from the working directory.
ecg_filenames = np.load("ecg_filenames.npy", allow_pickle=True)
probs         = np.load("ecg_model_probs.npy")
y_true        = np.load("ecg_y_true.npy")
# Class names come from the configured path rather than a cwd-relative copy,
# so they stay in sync with the array used alongside the model.
class_names   = np.load(SNOMED_CLASSES_NPY, allow_pickle=True)

sel_df = build_selection_df_with_aliases(
    ecg_filenames=ecg_filenames,
    probs=probs,
    class_names=class_names,
    target_meta=TARGET_META,
    y_true=y_true,
    k_per_class=SEL_K_PER_CLASS,
    min_prob=SEL_MIN_PROB,
    max_duration_sec=20.0,
    duration_cache_path="ecg_durations.npy"
)

# The file name previously hard-coded "k5" while k_per_class was 3; derive it
# from the parameters so the name cannot misdescribe the contents.
sel_df.to_csv(
    f"ecg_xai_sel_meta_p{SEL_MIN_PROB:.2f}_k{SEL_K_PER_CLASS}.csv",
    index=False,
)
sel_df


[INFO] Estimating durations and keeping ECGs <= 20.0 s...
[INFO] Duration filter: keeping 40665/43101 ECGs (<= 20.0 s).
[INFO] relaxing selection for 17338001 (ventricular premature beats)
[CLASS 17338001 (ventricular premature beats)] picked 3 examples.


Unnamed: 0,group_class,filename,sel_idx,duration_sec,prob_meta
0,17338001,C:\data\georgia-12lead-ecg-challenge-database\...,12161,10.0,0.514689
1,17338001,C:\data\georgia-12lead-ecg-challenge-database\...,14932,10.0,0.999942
2,17338001,C:\data\georgia-12lead-ecg-challenge-database\...,10831,10.0,0.626765


In [4]:
# Every key of TARGET_META is a class to run the fused explanation pipeline on.
target_classes = [*TARGET_META.keys()]

print("Target classes:", target_classes)

# Arguments for the fused pipeline, gathered in one place for readability.
_fused_kwargs = dict(
    target_classes=target_classes,
    sel_df=sel_df,
    model=model,
    class_names=class_names,
    max_examples_per_class=5,
    plot=False,
)

# The call returns three objects: the fused payloads plus two frames that are
# bound here to df_lime_all and df_ts_all.
all_fused_payloads, df_lime_all, df_ts_all = run_fused_pipeline_for_classes(
    **_fused_kwargs
)

Target classes: ['17338001']


In [8]:
# Score every fused payload; rows are tagged with the given method label.
_EVAL_METHOD_LABEL = "LIME+TimeSHAP"
df_eval_all = evaluate_all_payloads(
    all_fused_payloads, method_label=_EVAL_METHOD_LABEL, debug=True
)

# Display the resulting evaluation table.
df_eval_all

Unnamed: 0,meta_code,class_name,sel_idx,mat_path,method,strict_attauc,lenient_attauc,n_tokens
0,17338001,ventricular premature beats,10831,C:\data\georgia-12lead-ecg-challenge-database\...,LIME+TimeSHAP,0.874136,1.0,360
1,17338001,ventricular premature beats,12161,C:\data\georgia-12lead-ecg-challenge-database\...,LIME+TimeSHAP,0.889893,1.0,528
2,17338001,ventricular premature beats,14932,C:\data\georgia-12lead-ecg-challenge-database\...,LIME+TimeSHAP,0.85751,1.0,288


In [9]:
# Inspect a single explanation in detail for the first target class.
meta_code = target_classes[0]

# Pick the row from the rows of *this* class. Previously the row was taken as
# df_eval_all.iloc[0] independently of meta_code, which can pair a row with the
# wrong class (or raise KeyError) once more than one target class is evaluated.
# Both sides are compared as strings so the column dtype does not matter.
rows_for_class = df_eval_all[df_eval_all["meta_code"].astype(str) == str(meta_code)]
if rows_for_class.empty:
    raise ValueError(f"No evaluated payloads found for class {meta_code!r}")
row = rows_for_class.iloc[0]  # pick one row
sel_idx = int(row["sel_idx"])
payload = all_fused_payloads[meta_code][sel_idx]

# Re-run the evaluation for this one payload with debug output enabled;
# fs falls back to 500.0 when the payload carries no "fs" entry.
out = evaluate_explanation(
    mat_path=payload["mat_path"],
    fs=payload.get("fs", 500.0),
    payload=payload,
    class_name=row["class_name"],
    debug=True,
)

# Summarise the positive/negative token counts and list the top tokens.
print("Strict pos/neg:", out.debug.n_pos_strict, out.debug.n_neg_strict)
print("Lenient pos/neg:", out.debug.n_pos_lenient, out.debug.n_neg_lenient)
print("Top tokens:")
for t in out.debug.top_tokens:
    print(
        f"{t.idx:3d}  {t.lead:3s}  {t.window_type:7s}  "
        f"{t.t_start:5.2f}-{t.t_end:5.2f}  "
        f"score={t.score:7.4f}  "
        f"strict={t.strict_label}  lenient={t.lenient_label}"
    )

Strict pos/neg: 90 270
Lenient pos/neg: 150 210
Top tokens:
 53  II   qrs_term   5.73- 5.85  score= 1.0000  strict=0  lenient=1
124  aVL  qrs       2.93- 3.06  score= 1.0000  strict=0  lenient=0
139  aVL  qrs_term   3.02- 3.14  score= 1.0000  strict=0  lenient=0
244  V3   qrs       2.93- 3.06  score= 1.0000  strict=1  lenient=1
 38  II   qrs       5.64- 5.77  score= 1.0000  strict=0  lenient=1
193  V1   qrs       9.22- 9.35  score= 1.0000  strict=1  lenient=1
167  aVF  qrs_term   1.66- 1.78  score= 1.0000  strict=0  lenient=0
240  V3   qrs       0.22- 0.35  score= 1.0000  strict=1  lenient=1
297  V4   qrs_term   8.43- 8.55  score= 1.0000  strict=0  lenient=1
208  V1   qrs_term   9.31- 9.43  score= 1.0000  strict=1  lenient=1
