In [1]:
import os
import sys
# Make the in-repo sources importable; the path is relative to the notebook's working directory.
sys.path.append('../src')
# The returned working directory is discarded: this is not the cell's last expression, so nothing is shown.
os.getcwd()

# Editable install of the parent directory's project into the kernel environment (-qqq silences pip output).
# NOTE(review): with the editable install in place the sys.path tweak above is presumably redundant — confirm.
%pip install -e .. -qqq

Note: you may need to restart the kernel to use updated packages.


In [2]:
import numpy as np
import pandas as pd
import tensorflow as tf
import zlib
import shutil
from pathlib import Path
from datetime import datetime

from ecgxai.utils import save_run, load_run
from ecgxai.config_targets import TARGET_META
from ecgxai.explainer import run_fused_pipeline_for_classes
from ecgxai.selection import build_selection_df_with_aliases
from ecgxai.stability_eval import run_extra_beat_stability_experiment
from ecgxai.eval import evaluate_all_payloads

%load_ext autoreload
%autoreload 2

In [None]:
# Repository root: the notebook is expected to run from a folder one level below it.
ROOT = Path.cwd().parent
_DATA_DIR = ROOT / "data"
_OUTPUT_DIR = ROOT / "outputs"

# --- Model and label vocabulary ---
MODEL_PATH = ROOT / "model" / "resnet_final.h5"
SNOMED_CLASSES_NPY = _DATA_DIR / "snomed_classes.npy"

# --- Precomputed per-ECG arrays ---
ECG_FILENAMES_PATH = _DATA_DIR / "ecg_filenames.npy"
PROBS_PATH = _DATA_DIR / "ecg_model_probs.npy"
# Same file as SNOMED_CLASSES_NPY; both names are kept because later cells use each of them.
CLASS_NAMES_PATH = _DATA_DIR / "snomed_classes.npy"
Y_TRUE_PATH = _DATA_DIR / "ecg_y_true.npy"
ECG_DURATIONS_PATH = _DATA_DIR / "ecg_durations.npy"

# --- Outputs ---
STAB_OUT_DIR = _OUTPUT_DIR / "extra_beat_aug"
# NOTE(review): the file name says "k5" while the selection cell passes k_per_class=50 — confirm which is intended.
ECG_PRED_PATH = _OUTPUT_DIR / "ecg_xai_sel_meta_p0.85_k5.csv"

# Inference only, so the model is loaded without compiling it.
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
# NOTE(review): allow_pickle=True can execute arbitrary code from the .npy file; only load trusted files.
class_names = np.load(SNOMED_CLASSES_NPY, allow_pickle=True)

In [None]:
# from config import DATA_ROOT, MAXLEN
# from utils import import_key_data
# from ecg_predict import batched_predict_all
# from selection import build_y_true_from_labels

# gender, age, labels, ecg_filenames = import_key_data(DATA_ROOT)

# # 1) Build ground-truth multi-hot labels
# y_true = build_y_true_from_labels(labels, class_names)

# # 2) Predict probabilities
# probs = batched_predict_all(
#     model,
#     ecg_filenames,
#     maxlen=MAXLEN,
#     batch_size=32,
# )

# # 3) Optional: binary predictions (0/1) at some threshold
# pred_threshold = 0.5
# y_pred = (probs >= pred_threshold).astype(np.int8)

# Y_PRED_PATH      = ROOT / "outputs" / f"ecg_y_pred_{pred_threshold:.2f}.npy"

# # 4) Save everything
# np.save(ECG_FILENAMES_PATH, ecg_filenames)
# np.save(PROBS_PATH, probs)
# np.save(Y_TRUE_PATH, y_true)
# np.save(Y_PRED_PATH, y_pred)


In [25]:
# Load the precomputed arrays written by the (disabled) prediction cell above.
# NOTE(review): allow_pickle=True can execute arbitrary code from the .npy file; only load trusted files.
# NOTE(review): the arrays are presumably row-aligned (one entry per ECG) — confirm against the producer.
ecg_filenames = np.load(ECG_FILENAMES_PATH, allow_pickle=True)
probs         = np.load(PROBS_PATH)
class_names   = np.load(CLASS_NAMES_PATH, allow_pickle=True)
y_true        = np.load(Y_TRUE_PATH)

# Pick up to `k_per_class` ECGs per target class, subject to the probability
# and duration thresholds passed below; durations are cached at ECG_DURATIONS_PATH.
sel_df = build_selection_df_with_aliases(
    ecg_filenames=ecg_filenames,
    probs=probs,
    class_names=class_names,
    target_meta=TARGET_META,
    y_true=y_true,
    k_per_class=50,
    min_prob=0.85,
    max_duration_sec=20.0,
    duration_cache_path=str(ECG_DURATIONS_PATH)
)

# Create the output folder first (as the evaluation and stability cells do),
# so a fresh checkout without an outputs/ directory does not fail on to_csv.
ECG_PRED_PATH.parent.mkdir(parents=True, exist_ok=True)
sel_df.to_csv(ECG_PRED_PATH, index=False)
sel_df

[INFO] Duration filter enabled: keeping ECGs <= 20.0 s
[INFO] Duration filter: keeping 11265/13187 ECGs.
[CLASS 164889003 (atrial fibrillation)] picked 50 examples.
[CLASS 426783006 (sinus rhythm)] picked 50 examples.


Unnamed: 0,group_class,filename,sel_idx,duration_sec,prob_meta
0,164889003,C:\UHull\Data\Training_WFDB\A2249.mat,5347,16.968,0.940748
1,164889003,C:\UHull\Data\Training_WFDB\A2978.mat,6076,10.000,0.986492
2,164889003,C:\UHull\Data\Training_WFDB\A2618.mat,5716,14.474,0.988416
3,164889003,C:\UHull\Data\Training_WFDB\A0003.mat,3101,10.000,0.959976
4,164889003,C:\UHull\Data\Training_WFDB\A5377.mat,8475,19.122,0.990656
...,...,...,...,...,...
95,426783006,C:\UHull\Data\WFDB\HR00005.mat,8921,10.000,0.993680
96,426783006,C:\UHull\Data\WFDB\HR01101.mat,10017,10.000,0.992534
97,426783006,C:\UHull\Data\WFDB\HR02292.mat,11208,10.000,0.955642
98,426783006,C:\UHull\Data\WFDB\HR04184.mat,13100,10.000,0.993729


In [16]:
# ---- choose mode ----
# "demo" explains a few examples per class with plots; "eval" explains more and skips plotting.
run_mode = "demo"   # "eval" or "demo"

# Per-mode settings consumed by the pipeline cell.
MODE_CFG = dict(
    eval=dict(max_examples_per_class=50, plot=False),
    demo=dict(max_examples_per_class=3, plot=True),
)

# Cached result tables live in a folder named after the active mode.
_mode_out_dir = ROOT / "outputs" / run_mode
EVAL_CSV_PATH = _mode_out_dir / 'df_eval_attauc_deletion.csv'
STAB_CSV_PATH = _mode_out_dir / "df_eval_stability.csv"

In [17]:
# Folder that holds the cached assets of the selected mode.
OUT_BASE = ROOT / "outputs"
RUN_DIR = OUT_BASE / run_mode

# Settings of the selected mode (raises KeyError for an unknown run_mode).
cfg = MODE_CFG[run_mode]

# ---- cache check ----
def run_assets_exist(run_dir: Path) -> bool:
    """Tell whether ``run_dir`` contains every cached file of a pipeline run.

    Args:
        run_dir: Directory to inspect.

    Returns:
        True only if all four expected files exist directly inside
        ``run_dir``; False as soon as one of them is missing.
    """
    required_files = (
        "all_fused_payloads.joblib",
        "df_lime_all.parquet",
        "df_ts_all.parquet",
        "sel_df.parquet",
    )
    # all() stops at the first missing file, like the original `and` chain.
    return all((run_dir / file_name).exists() for file_name in required_files)

# Every class key declared in TARGET_META is explained.
target_classes = list(TARGET_META)

if not run_assets_exist(RUN_DIR):
    print(f"[{run_mode}] No cache found. Running pipeline...\n")

    # Run the fused explanation pipeline on the current selection.
    all_fused_payloads, df_lime_all, df_ts_all = run_fused_pipeline_for_classes(
        target_classes=target_classes,
        sel_df=sel_df,
        model=model,
        class_names=class_names,
        max_examples_per_class=cfg["max_examples_per_class"],
        plot=cfg["plot"],
    )

    # Settings stored with the cached run, stamped with the save time.
    run_meta = {
        "mode": run_mode,
        "saved_at": datetime.now().isoformat(timespec="seconds"),
        "target_classes": target_classes,
        "max_examples_per_class": cfg["max_examples_per_class"],
        "plot": cfg["plot"],
    }
    save_run(
        RUN_DIR,
        all_fused_payloads,
        df_lime_all,
        df_ts_all,
        sel_df,
        meta=run_meta,
    )
    print(f"[{run_mode}] Saved to: {RUN_DIR}")
else:
    print(f"[{run_mode}] Loading cached assets from: {RUN_DIR}")
    all_fused_payloads, df_lime_all, df_ts_all, sel_df_cached = load_run(RUN_DIR)

    # Use the cached selection so indices stay consistent with the cached payloads.
    # NOTE(review): this rebinds `sel_df` built by the selection cell; the cache is
    # not checked against the current MODE_CFG settings — confirm that is acceptable.
    sel_df = sel_df_cached

[demo] Loading cached assets from: c:\UHull\ecg-xai\outputs\demo


In [18]:
# Compute the evaluation table once per mode; later runs read it back from CSV.
if not EVAL_CSV_PATH.exists():
    print("Cached CSV not found — running evaluate_all_payloads()...")

    df_eval_all = evaluate_all_payloads(
        all_payloads=all_fused_payloads,
        method_label="LIME+TimeSHAP",
        model=model,
        class_names=class_names,
    )

    # Persist the result so the next run takes the cached branch.
    EVAL_CSV_PATH.parent.mkdir(parents=True, exist_ok=True)
    df_eval_all.to_csv(EVAL_CSV_PATH, index=False)
    print(f"Saved df_eval_all to: {EVAL_CSV_PATH}")
else:
    print(f"Loading cached df_eval_all from: {EVAL_CSV_PATH}")
    # NOTE(review): dtypes are re-inferred from the CSV here, so columns such as
    # meta_code may differ in dtype from a freshly computed frame — confirm.
    df_eval_all = pd.read_csv(EVAL_CSV_PATH)

df_eval_all


Cached CSV not found — running evaluate_all_payloads()...
Saved df_eval_all to: c:\UHull\ecg-xai\outputs\demo\df_eval_attauc_deletion.csv


Unnamed: 0,meta_code,class_name,sel_idx,mat_path,method,strict_attauc,lenient_attauc,precision_k,strict_p_at_k,lenient_p_at_k,deletion_auc,faithfulness_gain,n_tokens
0,164889003,atrial fibrillation,3954,C:\UHull\Data\Training_WFDB\A0850.mat,LIME+TimeSHAP,0.881522,0.616818,20,0.5,0.8,0.299139,0.002795,276
1,164889003,atrial fibrillation,7920,C:\UHull\Data\Training_WFDB\A4816.mat,LIME+TimeSHAP,0.857396,0.649963,20,0.35,1.0,0.286059,0.01218,156
2,164889003,atrial fibrillation,8285,C:\UHull\Data\Training_WFDB\A5181.mat,LIME+TimeSHAP,0.908459,0.661862,20,0.65,1.0,0.23307,0.186725,276
3,426783006,sinus rhythm,9042,C:\UHull\Data\WFDB\HR00120.mat,LIME+TimeSHAP,0.8425,1.0,20,0.35,1.0,0.294592,0.003249,120
4,426783006,sinus rhythm,11786,C:\UHull\Data\WFDB\HR02864.mat,LIME+TimeSHAP,0.847314,1.0,20,0.3,1.0,0.296682,0.002256,132
5,426783006,sinus rhythm,12105,C:\UHull\Data\WFDB\HR03183.mat,LIME+TimeSHAP,0.870118,1.0,20,0.4,1.0,0.294881,0.003958,156


### Stability to an Extra Heartbeat (All Target Classes)

We assess explanation stability under a synthetic perturbation where one existing
heartbeat is duplicated and re-inserted at the **middle** of the recording.

For each target rhythm class in `TARGET_META`:

1. We take every example (up to `MAX_PER_CLASS` per class) that was already explained and evaluated in `df_eval_all`.
2. We run `run_extra_beat_stability_experiment`, which:
   - creates augmented versions of the ECG
     - one with an **extra beat inserted in the middle**
   - recomputes **fused LIME + TimeSHAP** explanations for each version
   - compares region-level importance profiles over the *shared beats* using:
     - **Spearman rank correlation** (global ordering of important regions)
     - **Jaccard@K** overlap of the top-K most important regions.

This gives us one row of stability metrics per ECG, labelled with its class, for:
- original vs **extra-beat-in-middle**


In [19]:
# ---- Load if exists, otherwise run + save ----
if STAB_CSV_PATH.exists():
    print(f"Loading cached df_stability from: {STAB_CSV_PATH}")
    df_stability = pd.read_csv(STAB_CSV_PATH)
else:
    print("Cached df_stability not found — running stability eval...")

    # Clean run: wipe outputs dir first
    if STAB_OUT_DIR.exists():
        shutil.rmtree(STAB_OUT_DIR)
    STAB_OUT_DIR.mkdir(parents=True, exist_ok=True)

    MAX_PER_CLASS = 50
    BASE_SEED = 1234
    stability_rows = []

    df_eval = df_eval_all.copy()
    df_eval["meta_code"] = df_eval["meta_code"].astype(str)

    for meta_code_str, df_cls in df_eval.groupby("meta_code", sort=False):
        class_name = str(TARGET_META[str(meta_code_str)]["name"])

        # Skip already augmented mats (optional; safe to keep)
        df_cls = df_cls[~df_cls["mat_path"].str.contains(r"_extra_.*\.mat$", case=False, na=False)]

        if MAX_PER_CLASS is not None:
            df_cls = df_cls.head(MAX_PER_CLASS)

        print(f"\n=== Class {meta_code_str} ({class_name}) | n={len(df_cls)} ===")

        for _, row in df_cls.iterrows():
            mat_path = row["mat_path"]
            sel_idx  = int(row.get("sel_idx", 0))

            # deterministic per record
            seed = (int(BASE_SEED) ^ (zlib.crc32(mat_path.encode("utf-8")) & 0xFFFFFFFF)) & 0xFFFFFFFF
            print(f"  -> [{meta_code_str}] sel_idx={sel_idx} | seed={seed} | {mat_path}")

            metrics, *_ = run_extra_beat_stability_experiment(
                mat_path=mat_path,
                snomed_code=str(meta_code_str),
                model=model,
                class_names=class_names,
                augment_root=STAB_OUT_DIR,
                seed=seed,
            )

            extra = metrics.get("extra") or {}

            stability_rows.append({
                "meta_code": str(meta_code_str),
                "class_name": class_name,
                "sel_idx": sel_idx,
                "mat_path": mat_path,

                "seed": seed,
                "augment_beat_index": metrics.get("augment_beat_index"),
                "augmented_file": metrics.get("augmented_file"),

                "extra_spearman": extra.get("spearman"),
                "extra_jaccard": extra.get("jaccard_topk"),
                "extra_rbo": extra.get("rbo"),
                "extra_wjacc": extra.get("weighted_jaccard"),
                "extra_k_eff": extra.get("k_eff"),
            })

    df_stability = pd.DataFrame(stability_rows)

    # Save
    STAB_CSV_PATH.parent.mkdir(parents=True, exist_ok=True)
    df_stability.to_csv(STAB_CSV_PATH, index=False)
    print(f"Saved df_stability to: {STAB_CSV_PATH}")

df_stability


Cached df_stability not found — running stability eval...

=== Class 164889003 (atrial fibrillation) | n=3 ===
  -> [164889003] sel_idx=3954 | seed=1484973458 | C:\UHull\Data\Training_WFDB\A0850.mat
FUSED 164889003 | n=2 | win=0.5s mE=100 mF=100 topkE=5 | LIME=0m35s TS=0m28s | fused=2 | total=1m03s
Fused pipeline complete: 1 classes in 1m03s
  -> [164889003] sel_idx=7920 | seed=3357105374 | C:\UHull\Data\Training_WFDB\A4816.mat
FUSED 164889003 | n=2 | win=0.5s mE=100 mF=100 topkE=5 | LIME=0m37s TS=0m29s | fused=2 | total=1m06s
Fused pipeline complete: 1 classes in 1m06s
  -> [164889003] sel_idx=8285 | seed=67408971 | C:\UHull\Data\Training_WFDB\A5181.mat
FUSED 164889003 | n=2 | win=0.5s mE=100 mF=100 topkE=5 | LIME=0m38s TS=0m34s | fused=2 | total=1m13s
Fused pipeline complete: 1 classes in 1m13s

=== Class 426783006 (sinus rhythm) | n=3 ===
  -> [426783006] sel_idx=9042 | seed=2831083760 | C:\UHull\Data\WFDB\HR00120.mat
FUSED 426783006 | n=2 | win=0.25s mE=100 mF=100 topkE=5 | LIME=0m

Unnamed: 0,meta_code,class_name,sel_idx,mat_path,seed,augment_beat_index,augmented_file,extra_spearman,extra_jaccard,extra_rbo,extra_wjacc,extra_k_eff
0,164889003,atrial fibrillation,3954,C:\UHull\Data\Training_WFDB\A0850.mat,1484973458,12,c:\UHull\ecg-xai\outputs\extra_beat_aug\A0850\...,0.939491,0.538462,0.354224,0.958656,20.0
1,164889003,atrial fibrillation,7920,C:\UHull\Data\Training_WFDB\A4816.mat,3357105374,1,c:\UHull\ecg-xai\outputs\extra_beat_aug\A4816\...,0.9402,1.0,0.46113,0.946926,10.0
2,164889003,atrial fibrillation,8285,C:\UHull\Data\Training_WFDB\A5181.mat,67408971,12,c:\UHull\ecg-xai\outputs\extra_beat_aug\A5181\...,0.85402,0.16129,0.09266,0.902965,18.0
3,426783006,sinus rhythm,9042,C:\UHull\Data\WFDB\HR00120.mat,2831083760,2,c:\UHull\ecg-xai\outputs\extra_beat_aug\HR0012...,0.838932,1.0,0.587905,0.897584,10.0
4,426783006,sinus rhythm,11786,C:\UHull\Data\WFDB\HR02864.mat,286412424,6,c:\UHull\ecg-xai\outputs\extra_beat_aug\HR0286...,0.942036,1.0,0.424008,0.97489,10.0
5,426783006,sinus rhythm,12105,C:\UHull\Data\WFDB\HR03183.mat,3221927589,2,c:\UHull\ecg-xai\outputs\extra_beat_aug\HR0318...,0.945377,0.391304,0.343491,0.948073,16.0
