In [1]:
import os
import sys
sys.path.append('../src')
os.getcwd()

%pip install -e .. -qqq

Note: you may need to restart the kernel to use updated packages.


In [2]:
import numpy as np
import pandas as pd
import tensorflow as tf
import zlib
import shutil
from pathlib import Path
from datetime import datetime

from utils import save_run, load_run
from config_targets import TARGET_META
from explainer import run_fused_pipeline_for_classes
from selection import build_selection_df_with_aliases
from stability_eval import run_extra_beat_stability_experiment
from eval import evaluate_all_payloads

%load_ext autoreload
%autoreload 2

In [None]:
ROOT = Path.cwd().parent

MODEL_PATH         = ROOT / "model" / "resnet_final.keras"
SNOMED_CLASSES_NPY = ROOT / "data" / "snomed_classes.npy"

ECG_FILENAMES_PATH = ROOT / "data" / "ecg_filenames.npy"
PROBS_PATH         = ROOT / "data" / "ecg_model_probs.npy"
CLASS_NAMES_PATH   = ROOT / "data" / "snomed_classes.npy"
Y_TRUE_PATH        = ROOT / "data" / "ecg_y_true.npy"
ECG_DURATIONS_PATH = ROOT / "data" / "ecg_durations.npy"
STAB_OUT_DIR       = ROOT / "outputs" / "extra_beat_aug"


model = tf.keras.models.load_model(MODEL_PATH, compile=False)
class_names = np.load(SNOMED_CLASSES_NPY, allow_pickle=True)

In [None]:
EVAL_CSV_PATH      = ROOT / "outputs" / "eval" / 'df_eval_attauc_deletion.csv'
STAB_CSV_PATH      = ROOT / "outputs" / "eval" / "df_eval_stability.csv"
ECG_PRED_PATH      = ROOT / "outputs" / "ecg_xai_sel_meta_p0.85_k5.csv"

In [None]:
# from config import DATA_ROOT, MAXLEN
# from utils import import_key_data
# from ecg_predict import batched_predict_all
# from selection import build_y_true_from_labels

# gender, age, labels, ecg_filenames = import_key_data(DATA_ROOT)

# # 1) Build ground-truth multi-hot labels
# y_true = build_y_true_from_labels(labels, class_names)

# # 2) Predict probabilities
# probs = batched_predict_all(
#     model,
#     ecg_filenames,
#     maxlen=MAXLEN,
#     batch_size=32,
# )

# # 3) Optional: binary predictions (0/1) at some threshold
# pred_threshold = 0.5
# y_pred = (probs >= pred_threshold).astype(np.int8)

# Y_PRED_PATH      = ROOT / "outputs" / f"ecg_y_pred_{pred_threshold:.2f}.npy"

# # 4) Save everything
# np.save(ECG_FILENAMES_PATH, ecg_filenames)
# np.save(PROBS_PATH, probs)
# np.save(Y_TRUE_PATH, y_true)
# np.save(Y_PRED_PATH, y_pred)


In [5]:
ecg_filenames = np.load(ECG_FILENAMES_PATH, allow_pickle=True)
probs         = np.load(PROBS_PATH)
class_names   = np.load(CLASS_NAMES_PATH, allow_pickle=True)
y_true        = np.load(Y_TRUE_PATH)

sel_df = build_selection_df_with_aliases(
    ecg_filenames=ecg_filenames,
    probs=probs,
    class_names=class_names,
    target_meta=TARGET_META,
    y_true=y_true,
    k_per_class=50,
    min_prob=0.85,
    max_duration_sec=20.0,
    duration_cache_path=str(ECG_DURATIONS_PATH)
)

sel_df.to_csv(ECG_PRED_PATH, index=False)
sel_df

[INFO] Estimating durations and keeping ECGs <= 20.0 s...
[INFO] Duration filter: keeping 11265/13187 ECGs (<= 20.0 s).
[CLASS 164889003 (atrial fibrillation)] picked 50 examples.
[CLASS 426783006 (sinus rhythm)] picked 50 examples.


Unnamed: 0,group_class,filename,sel_idx,duration_sec,prob_meta
0,164889003,C:\UHull\Data\Training_WFDB\A2249.mat,5347,16.968,0.940748
1,164889003,C:\UHull\Data\Training_WFDB\A2978.mat,6076,10.000,0.986492
2,164889003,C:\UHull\Data\Training_WFDB\A2618.mat,5716,14.474,0.988416
3,164889003,C:\UHull\Data\Training_WFDB\A0003.mat,3101,10.000,0.959976
4,164889003,C:\UHull\Data\Training_WFDB\A5377.mat,8475,19.122,0.990656
...,...,...,...,...,...
95,426783006,C:\UHull\Data\WFDB\HR00005.mat,8921,10.000,0.993680
96,426783006,C:\UHull\Data\WFDB\HR01101.mat,10017,10.000,0.992534
97,426783006,C:\UHull\Data\WFDB\HR02292.mat,11208,10.000,0.955642
98,426783006,C:\UHull\Data\WFDB\HR04184.mat,13100,10.000,0.993729


In [6]:
# ---- choose mode ----
run_mode = "eval"   # "eval" or "demo"

MODE_CFG = {
    "eval": {"max_examples_per_class": 50, "plot": False},
    "demo": {"max_examples_per_class": 3,  "plot": False},
}

In [7]:
cfg = MODE_CFG[run_mode]

OUT_BASE = ROOT / "outputs"
RUN_DIR = OUT_BASE / run_mode

# ---- cache check ----
def run_assets_exist(run_dir: Path) -> bool:
    return (
        (run_dir / "all_fused_payloads.joblib").exists()
        and (run_dir / "df_lime_all.parquet").exists()
        and (run_dir / "df_ts_all.parquet").exists()
        and (run_dir / "sel_df.parquet").exists()
    )

target_classes = list(TARGET_META.keys())

if run_assets_exist(RUN_DIR):
    print(f"[{run_mode}] Loading cached assets from: {RUN_DIR}")
    all_fused_payloads, df_lime_all, df_ts_all, sel_df_cached = load_run(RUN_DIR)
    
    # Optionally replace sel_df with cached one to keep indices consistent
    sel_df = sel_df_cached
else:
    print(f"[{run_mode}] No cache found. Running pipeline...\n")
    all_fused_payloads, df_lime_all, df_ts_all = run_fused_pipeline_for_classes(
        target_classes=target_classes,
        sel_df=sel_df,
        model=model,
        class_names=class_names,
        max_examples_per_class=cfg["max_examples_per_class"],
        plot=cfg["plot"],
    )

    save_run(
        RUN_DIR,
        all_fused_payloads,
        df_lime_all,
        df_ts_all,
        sel_df,
        meta={
            "mode": run_mode,
            "saved_at": datetime.now().isoformat(timespec="seconds"),
            "target_classes": target_classes,
            "max_examples_per_class": cfg["max_examples_per_class"],
            "plot": cfg["plot"],
        },
    )
    print(f"[{run_mode}] Saved to: {RUN_DIR}")

[eval] Loading cached assets from: c:\UHull\ecg-xai\outputs\eval


In [8]:
if EVAL_CSV_PATH.exists():
    print(f"Loading cached df_eval_all from: {EVAL_CSV_PATH}")
    df_eval_all = pd.read_csv(EVAL_CSV_PATH)
else:
    print("Cached CSV not found — running evaluate_all_payloads()...")

    df_eval_all = evaluate_all_payloads(
        all_payloads=all_fused_payloads,
        method_label="LIME+TimeSHAP",
        model=model,
        class_names=class_names,
    )

    EVAL_CSV_PATH.parent.mkdir(parents=True, exist_ok=True)
    df_eval_all.to_csv(EVAL_CSV_PATH, index=False)
    print(f"Saved df_eval_all to: {EVAL_CSV_PATH}")

df_eval_all


Loading cached df_eval_all from: c:\UHull\ecg-xai\data\df_eval_attauc_deletion.csv


Unnamed: 0.1,Unnamed: 0,meta_code,class_name,sel_idx,mat_path,method,strict_attauc,lenient_attauc,precision_k,strict_p_at_k,lenient_p_at_k,deletion_auc,faithfulness_gain,n_tokens
0,0,164889003,atrial fibrillation,1992,C:\UHull\Data\Training_2\Q0946.mat,LIME+TimeSHAP,0.878189,0.723214,20,0.50,1.0,0.293073,0.024195,168
1,1,164889003,atrial fibrillation,2212,C:\UHull\Data\Training_2\Q1179.mat,LIME+TimeSHAP,0.859948,0.635435,20,0.25,1.0,0.250375,0.105301,204
2,2,164889003,atrial fibrillation,2487,C:\UHull\Data\Training_2\Q1462.mat,LIME+TimeSHAP,0.844199,0.635585,20,0.30,0.7,0.294993,0.013402,372
3,3,164889003,atrial fibrillation,3101,C:\UHull\Data\Training_WFDB\A0003.mat,LIME+TimeSHAP,0.860280,0.650225,20,0.45,0.9,0.289826,-0.045046,300
4,4,164889003,atrial fibrillation,3124,C:\UHull\Data\Training_WFDB\A0026.mat,LIME+TimeSHAP,0.896007,0.739583,20,0.55,1.0,0.230502,0.204329,144
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,95,426783006,sinus rhythm,12529,C:\UHull\Data\WFDB\HR03613.mat,LIME+TimeSHAP,0.846094,1.000000,20,0.35,1.0,0.257836,0.014502,96
96,96,426783006,sinus rhythm,12542,C:\UHull\Data\WFDB\HR03626.mat,LIME+TimeSHAP,0.843778,1.000000,20,0.35,1.0,0.253848,0.039210,180
97,97,426783006,sinus rhythm,13100,C:\UHull\Data\WFDB\HR04184.mat,LIME+TimeSHAP,0.835764,1.000000,20,0.25,1.0,0.297479,-0.002890,144
98,98,426783006,sinus rhythm,13148,C:\UHull\Data\WFDB\HR04232.mat,LIME+TimeSHAP,0.856500,1.000000,20,0.35,1.0,0.289625,0.000269,120


### Stability to an Extra Heartbeat (All Target Classes)

We assess explanation stability under a synthetic perturbation where one existing
heartbeat is duplicated and re-inserted at the **middle** of the recording.

For each target rhythm class in `TARGET_META`:

1. We take the example that was already explained and evaluated in `df_eval_all`.
2. We run `run_extra_beat_stability_experiment`, which:
   - creates augmented versions of the ECG
     - one with an **extra beat inserted in the middle**
   - recomputes **fused LIME + TimeSHAP** explanations for each version
   - compares region-level importance profiles over the *shared beats* using:
     - **Spearman rank correlation** (global ordering of important regions)
     - **Jaccard@K** overlap of the top-K most important regions.

This gives us per-class stability metrics for:
- original vs **extra-beat-in-middle**


In [9]:
# ---- Load if exists, otherwise run + save ----
if STAB_CSV_PATH.exists():
    print(f"Loading cached df_stability from: {STAB_CSV_PATH}")
    df_stability = pd.read_csv(STAB_CSV_PATH)
else:
    print("Cached df_stability not found — running stability eval...")

    # Clean run: wipe outputs dir first
    if STAB_OUT_DIR.exists():
        shutil.rmtree(STAB_OUT_DIR)
    STAB_OUT_DIR.mkdir(parents=True, exist_ok=True)

    MAX_PER_CLASS = 50
    BASE_SEED = 1234
    stability_rows = []

    df_eval = df_eval_all.copy()
    df_eval["meta_code"] = df_eval["meta_code"].astype(str)

    for meta_code_str, df_cls in df_eval.groupby("meta_code", sort=False):
        class_name = str(TARGET_META[str(meta_code_str)]["name"])

        # Skip already augmented mats (optional; safe to keep)
        df_cls = df_cls[~df_cls["mat_path"].str.contains(r"_extra_.*\.mat$", case=False, na=False)]

        if MAX_PER_CLASS is not None:
            df_cls = df_cls.head(MAX_PER_CLASS)

        print(f"\n=== Class {meta_code_str} ({class_name}) | n={len(df_cls)} ===")

        for _, row in df_cls.iterrows():
            mat_path = row["mat_path"]
            sel_idx  = int(row.get("sel_idx", 0))

            # deterministic per record
            seed = (int(BASE_SEED) ^ (zlib.crc32(mat_path.encode("utf-8")) & 0xFFFFFFFF)) & 0xFFFFFFFF
            print(f"  -> [{meta_code_str}] sel_idx={sel_idx} | seed={seed} | {mat_path}")

            metrics, *_ = run_extra_beat_stability_experiment(
                mat_path=mat_path,
                snomed_code=str(meta_code_str),
                model=model,
                class_names=class_names,
                augment_root=STAB_OUT_DIR,
                seed=seed,
            )

            extra = metrics.get("extra") or {}

            stability_rows.append({
                "meta_code": str(meta_code_str),
                "class_name": class_name,
                "sel_idx": sel_idx,
                "mat_path": mat_path,

                "seed": seed,
                "augment_beat_index": metrics.get("augment_beat_index"),
                "augmented_file": metrics.get("augmented_file"),

                "extra_spearman": extra.get("spearman"),
                "extra_jaccard": extra.get("jaccard_topk"),
                "extra_rbo": extra.get("rbo"),
                "extra_wjacc": extra.get("weighted_jaccard"),
                "extra_k_eff": extra.get("k_eff"),
            })

    df_stability = pd.DataFrame(stability_rows)

    # Save
    STAB_CSV_PATH.parent.mkdir(parents=True, exist_ok=True)
    df_stability.to_csv(STAB_CSV_PATH, index=False)
    print(f"Saved df_stability to: {STAB_CSV_PATH}")

df_stability


Loading cached df_stability from: c:\UHull\ecg-xai\data\df_eval_stability.csv


Unnamed: 0,meta_code,class_name,sel_idx,mat_path,seed,augment_beat_index,augmented_file,extra_spearman,extra_jaccard,extra_rbo,extra_wjacc,extra_k_eff
0,164889003,atrial fibrillation,1992,C:\UHull\Data\Training_2\Q0946.mat,1722255653,1,c:\UHull\ecg-xai\outputs\extra_beat_aug\Q0946\...,0.923705,0.250000,0.184673,0.914618,10.0
1,164889003,atrial fibrillation,2212,C:\UHull\Data\Training_2\Q1179.mat,2636522147,15,c:\UHull\ecg-xai\outputs\extra_beat_aug\Q1179\...,1.000000,1.000000,0.717570,1.000000,12.0
2,164889003,atrial fibrillation,2487,C:\UHull\Data\Training_2\Q1462.mat,1939009456,11,c:\UHull\ecg-xai\outputs\extra_beat_aug\Q1462\...,0.909040,0.333333,0.346345,0.944855,20.0
3,164889003,atrial fibrillation,3101,C:\UHull\Data\Training_WFDB\A0003.mat,2080391318,4,c:\UHull\ecg-xai\outputs\extra_beat_aug\A0003\...,0.854352,0.200000,0.089974,0.911727,18.0
4,164889003,atrial fibrillation,3124,C:\UHull\Data\Training_WFDB\A0026.mat,4180176621,8,c:\UHull\ecg-xai\outputs\extra_beat_aug\A0026\...,0.895691,0.500000,0.202715,0.920192,12.0
...,...,...,...,...,...,...,...,...,...,...,...,...
95,426783006,sinus rhythm,12529,C:\UHull\Data\WFDB\HR03613.mat,581097955,4,c:\UHull\ecg-xai\outputs\extra_beat_aug\HR0361...,0.950845,1.000000,0.285237,0.967647,10.0
96,426783006,sinus rhythm,12542,C:\UHull\Data\WFDB\HR03626.mat,1825973309,1,c:\UHull\ecg-xai\outputs\extra_beat_aug\HR0362...,0.877713,0.058824,0.006607,0.912702,18.0
97,426783006,sinus rhythm,13100,C:\UHull\Data\WFDB\HR04184.mat,2028954540,4,c:\UHull\ecg-xai\outputs\extra_beat_aug\HR0418...,0.977784,0.400000,0.528836,0.975189,10.0
98,426783006,sinus rhythm,13148,C:\UHull\Data\WFDB\HR04232.mat,2894101586,7,c:\UHull\ecg-xai\outputs\extra_beat_aug\HR0423...,0.970517,0.250000,0.330261,0.980306,10.0


In [10]:
import json
import pandas as pd

LEADS12 = ["I","II","III","aVR","aVL","aVF","V1","V2","V3","V4","V5","V6"]

def _lead_map_from_row(row):
    # your CSV has: "I,II,III,aVR,aVL,aVF,V1,V2,V3,V4,V5,V6"
    names = str(row.get("lead_names", "")).split(",")
    if len(names) == 12:
        return names
    return LEADS12

def payloads_from_lime_df(df_lime: pd.DataFrame):
    out = {}
    for _, row in df_lime.iterrows():
        meta_code = str(int(row["group_class"]))      # e.g., "164889003"
        sel_idx   = str(int(row["val_idx"]))          # matches your fused keys style
        lead_names = _lead_map_from_row(row)

        spans_top5 = json.loads(row["perlead_spans_top5_json"])  # keys are lead indices as strings
        perlead_spans = {}
        for k, spans in spans_top5.items():
            li = int(k)
            if 0 <= li < len(lead_names):
                perlead_spans[lead_names[li]] = spans  # spans are [t0,t1,w] in seconds

        out.setdefault(meta_code, {})[sel_idx] = {
            "mat_path": row["mat_path"],
            "perlead_spans": perlead_spans,
            "page_seconds": float(row.get("page_seconds", 10.0)),
        }
    return out

def payloads_from_timeshap_df(df_ts: pd.DataFrame):
    out = {}
    for _, row in df_ts.iterrows():
        meta_code = str(int(row["group_class"]))
        sel_idx   = str(int(row["val_idx"]))
        lead_names = _lead_map_from_row(row)

        spans_top5 = json.loads(row["perlead_timeshap_top5_json"])
        perlead_spans = {}
        for k, spans in spans_top5.items():
            li = int(k)
            if 0 <= li < len(lead_names):
                perlead_spans[lead_names[li]] = spans

        out.setdefault(meta_code, {})[sel_idx] = {
            "mat_path": row["mat_path"],
            "perlead_spans": perlead_spans,
            "page_seconds": float(row.get("page_seconds", 10.0)),
        }
    return out


In [13]:
# =====================================
# 2) NOTEBOOK: reload eval + run Table Y
# =====================================
import json
import pandas as pd
from pathlib import Path
import importlib

import eval as eval_mod
importlib.reload(eval_mod)          # <- reload after editing eval.py
from eval import evaluate_all_payloads  # <- re-import updated function

In [14]:
# =====================================
# 2) NOTEBOOK: reload eval + run Table Y
# =====================================

# Load existing files
df_lime = pd.read_csv(RUN_DIR / "df_lime_all.csv")
df_ts   = pd.read_csv(RUN_DIR / "df_ts_all.csv")

with open(RUN_DIR / "df_fused_all.json", "r") as f:
    fused_payloads = json.load(f)

lime_payloads = payloads_from_lime_df(df_lime)
ts_payloads   = payloads_from_timeshap_df(df_ts)

# Run eval + cache
out_dir = Path(RUN_DIR / "explain_ablations")
out_dir.mkdir(parents=True, exist_ok=True)

# X0: LIME (priors OFF)
df_eval_lime = evaluate_all_payloads(
    lime_payloads,
    method_label="LIME",
    model=model,
    class_names=class_names,
    use_priors=False,
)

# X1: TimeSHAP (priors OFF)
df_eval_ts = evaluate_all_payloads(
    ts_payloads,
    method_label="TimeSHAP",
    model=model,
    class_names=class_names,
    use_priors=False,
)

# X2: Fusion (priors OFF)
df_eval_fused_no = evaluate_all_payloads(
    fused_payloads,
    method_label="Fusion_no_priors",
    model=model,
    class_names=class_names,
    use_priors=False,
)

# X3: Fusion (priors ON)
df_eval_fused_yes = evaluate_all_payloads(
    fused_payloads,
    method_label="Fusion_with_priors",
    model=model,
    class_names=class_names,
    use_priors=True,
    prior_alpha=0.8,   # keep same as before; can sweep if you want later
)

df_eval_lime.to_csv(out_dir / "df_eval_lime_priorsOFF.csv", index=False)
df_eval_ts.to_csv(out_dir / "df_eval_timeshap_priorsOFF.csv", index=False)
df_eval_fused_no.to_csv(out_dir / "df_eval_fused_priorsOFF.csv", index=False)
df_eval_fused_yes.to_csv(out_dir / "df_eval_fused_priorsON.csv", index=False)

print("Saved:", out_dir)


Saved: c:\UHull\ecg-xai\outputs\eval\explain_ablations


In [18]:
# =====================================
# 3) Build Table Y summary CSV (robust)
# =====================================
import pandas as pd
import numpy as np

def _pick_col(df, candidates):
    """Return the first existing column name from a list of candidates."""
    for c in candidates:
        if c in df.columns:
            return c
    return None

def summarise_for_table_y(df: pd.DataFrame, variant: str) -> pd.DataFrame:
    # Handle both possible naming conventions
    col_group = _pick_col(df, ["group_class", "meta_code"])
    if col_group is None:
        raise KeyError(f"Couldn't find group column. Available columns: {list(df.columns)}")

    # Your eval outputs usually have these exact names, but we guard anyway
    col_strict_attauc = _pick_col(df, ["strict_attauc", "strict_auc"])
    col_lenient_attauc = _pick_col(df, ["lenient_attauc", "lenient_auc"])
    col_strict_p = _pick_col(df, ["strict_p_at_k", "strict_p_at_20", "p_strict"])
    col_lenient_p = _pick_col(df, ["lenient_p_at_k", "lenient_p_at_20", "p_lenient"])
    col_del_auc = _pick_col(df, ["deletion_auc", "del_auc"])
    col_faith = _pick_col(df, ["faithfulness_gain", "faith_gain"])

    needed = {
        "strict_attauc": col_strict_attauc,
        "lenient_attauc": col_lenient_attauc,
        "strict_p_at_k": col_strict_p,
        "lenient_p_at_k": col_lenient_p,
        "deletion_auc": col_del_auc,
        "faithfulness_gain": col_faith,
    }

    missing = [k for k,v in needed.items() if v is None]
    if missing:
        raise KeyError(
            f"Missing required metrics columns: {missing}\n"
            f"Available columns: {list(df.columns)}"
        )

    # Build a clean, standardised view
    work = df[[col_group] + list(needed.values())].copy()
    work = work.rename(columns={
        col_group: "group_class",
        col_strict_attauc: "strict_attauc",
        col_lenient_attauc: "lenient_attauc",
        col_strict_p: "strict_p_at_k",
        col_lenient_p: "lenient_p_at_k",
        col_del_auc: "deletion_auc",
        col_faith: "faithfulness_gain",
    })

    out = work.groupby("group_class", as_index=False).mean(numeric_only=True)
    out.insert(1, "variant", variant)
    return out

table_y = pd.concat([
    summarise_for_table_y(df_eval_lime,      "LIME (priors off)"),
    summarise_for_table_y(df_eval_ts,        "TimeSHAP (priors off)"),
    summarise_for_table_y(df_eval_fused_no,  "Fusion (priors off)"),
    summarise_for_table_y(df_eval_fused_yes, "Fusion (priors on)"),
], ignore_index=True)

table_y.to_csv(out_dir / "table_Y_ablation_summary.csv", index=False)
table_y


Unnamed: 0,group_class,variant,strict_attauc,lenient_attauc,strict_p_at_k,lenient_p_at_k,deletion_auc,faithfulness_gain
0,164889003,LIME (priors off),0.620541,0.585299,0.471,0.919,0.267543,0.081332
1,426783006,LIME (priors off),0.513141,0.557734,0.146,0.5,0.280654,0.006317
2,164889003,TimeSHAP (priors off),0.604608,0.581377,0.434,0.911,0.264257,0.093673
3,426783006,TimeSHAP (priors off),0.54109,0.511139,0.2,0.335,0.28064,0.006126
4,164889003,Fusion (priors off),0.617845,0.611438,0.3,0.862,0.26452,0.092843
5,426783006,Fusion (priors off),0.53861,0.54912,0.201,0.51,0.280246,0.007555
6,164889003,Fusion (priors on),0.865854,0.668023,0.45,0.901,0.26452,0.092843
7,426783006,Fusion (priors on),0.851929,1.0,0.343,1.0,0.280246,0.007555
