In [10]:
import os
# NOTE(review): this is not the cell's last expression, so the returned path is not displayed.
os.getcwd()

# Editable install of the parent directory into the kernel's environment; -qqq silences pip output.
%pip install -e .. -qqq

Note: you may need to restart the kernel to use updated packages.


In [11]:
from pathlib import Path
import tensorflow as tf
import numpy as np
import pandas as pd
from datetime import datetime

from utils import save_run, load_run
from config_targets import TARGET_META
from explainer import run_fused_pipeline_for_classes
from selection import build_selection_df_with_aliases

# Enable IPython's autoreload extension; mode 2 reloads all modules before executing each cell,
# so edits to the local project modules are picked up without restarting the kernel.
# Re-running this cell only prints an "already loaded" notice (see the recorded output).
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
# Project root: the parent of the notebook's working directory.
ROOT = Path.cwd().parent
DATA_DIR = ROOT / "data"

MODEL_PATH         = ROOT / "model" / "resnet_final.keras"
SNOMED_CLASSES_NPY = DATA_DIR / "snomed_classes.npy"
SEL_DF_CSV         = DATA_DIR / "ecg_model_pred_data.csv"

# Arrays and tables that later cells save to and/or load from disk.
ECG_FILENAMES_PATH = DATA_DIR / "ecg_filenames.npy"
PROBS_PATH         = DATA_DIR / "ecg_model_probs.npy"
# Same file as SNOMED_CLASSES_NPY; aliased so the two names cannot drift apart.
CLASS_NAMES_PATH   = SNOMED_CLASSES_NPY
Y_TRUE_PATH        = DATA_DIR / "ecg_y_true.npy"
ECG_DURATIONS_PATH = DATA_DIR / "ecg_durations.npy"
# NOTE(review): the filename encodes "k5", but the selection cell passes k_per_class=20 - confirm.
ECG_PRED_PATH      = DATA_DIR / "ecg_xai_sel_meta_p0.85_k5.csv"

# compile=False: the model is loaded without compiling it.
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
# NOTE: allow_pickle=True unpickles object arrays, which can execute arbitrary code;
# only load this file from a trusted source.
class_names = np.load(SNOMED_CLASSES_NPY, allow_pickle=True)

In [13]:
from config import DATA_ROOT
from utils import import_key_data
from ecg_predict import batched_predict_all
from selection import build_y_true_from_labels

# NOTE: the recorded progress bars for this cell add up to roughly 13 minutes.
pred_threshold = 0.5
Y_PRED_PATH = ROOT / "data" / f"ecg_y_pred_{pred_threshold:.2f}.npy"

gender, age, labels, ecg_filenames = import_key_data(DATA_ROOT)

# Ground-truth multi-hot labels.
y_true = build_y_true_from_labels(labels, class_names)

# Predicted probabilities for all ECG files.
probs = batched_predict_all(model, ecg_filenames, maxlen=5000, batch_size=32)

# Optional binary (0/1) predictions: a probability at or above the threshold counts as 1.
y_pred = (probs >= pred_threshold).astype(np.int8)

# Persist every array, in the same order as before.
for out_path, array in (
    (ECG_FILENAMES_PATH, ecg_filenames),
    (PROBS_PATH, probs),
    (Y_TRUE_PATH, y_true),
    (Y_PRED_PATH, y_pred),
):
    np.save(out_path, array)


Indexing ECG headers: 100%|██████████| 13203/13203 [02:31<00:00, 87.09it/s]
Predicting ECGs: 100%|██████████| 413/413 [10:28<00:00,  1.52s/it]


In [14]:
# Reload the arrays saved by the prediction cell.
ecg_filenames, class_names = (
    np.load(npy_path, allow_pickle=True)
    for npy_path in (ECG_FILENAMES_PATH, CLASS_NAMES_PATH)
)
probs, y_true = (np.load(npy_path) for npy_path in (PROBS_PATH, Y_TRUE_PATH))

# All selection settings in one place.
selection_kwargs = {
    "ecg_filenames": ecg_filenames,
    "probs": probs,
    "class_names": class_names,
    "target_meta": TARGET_META,
    "y_true": y_true,
    "k_per_class": 20,
    "min_prob": 0.85,
    "max_duration_sec": 20.0,
    "duration_cache_path": str(ECG_DURATIONS_PATH),
}
sel_df = build_selection_df_with_aliases(**selection_kwargs)

# NOTE(review): ECG_PRED_PATH's filename says "k5" while k_per_class is 20 - confirm.
sel_df.to_csv(ECG_PRED_PATH, index=False)
sel_df

[INFO] Estimating durations and keeping ECGs <= 20.0 s...
[INFO] Saved duration cache to c:\UHull\ecg-xai\data\ecg_durations.npy
[INFO] Duration filter: keeping 11279/13201 ECGs (<= 20.0 s).
[CLASS 164889003 (atrial fibrillation)] picked 20 examples.
[CLASS 426783006 (sinus rhythm)] picked 20 examples.
[INFO] relaxing selection for 17338001 (ventricular premature beats)
[CLASS 17338001 (ventricular premature beats)] picked 20 examples.


Unnamed: 0,group_class,filename,sel_idx,duration_sec,prob_meta
0,164889003,C:\UHull\Data\Training_WFDB\A5170.mat,8278,20.0,0.995147
1,164889003,C:\UHull\Data\Training_WFDB\A4816.mat,7924,14.0,0.972553
2,164889003,C:\UHull\Data\Training_WFDB\A0833.mat,3941,10.0,0.98607
3,164889003,C:\UHull\Data\WFDB\HR03231.mat,12159,10.0,0.923075
4,164889003,C:\UHull\Data\Training_WFDB\A5717.mat,8825,10.0,0.996607
5,164889003,C:\UHull\Data\Training_WFDB\A5159.mat,8267,19.0,0.999407
6,164889003,C:\UHull\Data\Training_WFDB\A0370.mat,3478,10.0,0.89298
7,164889003,C:\UHull\Data\Training_WFDB\A0003_extra_end_ex...,3108,10.0,0.950776
8,164889003,C:\UHull\Data\Training_WFDB\A4162.mat,7270,20.0,0.997943
9,164889003,C:\UHull\Data\Training_WFDB\A3098.mat,6206,10.0,0.976818


In [15]:
# ---- choose mode ----
# Used as the key into MODE_CFG and as the name of the output sub-directory (RUN_DIR).
run_mode = "eval"   # "eval" or "demo"

In [16]:
# Per-mode pipeline settings, selected by `run_mode`.
# NOTE(review): "plot" is False in both modes - confirm whether "demo" is meant to plot.
MODE_CFG = {
    "eval": {"max_examples_per_class": 20, "plot": False},
    "demo": {"max_examples_per_class": 3,  "plot": False},
}

# Fail early with a readable message instead of a bare KeyError when run_mode is mistyped.
if run_mode not in MODE_CFG:
    raise ValueError(
        f"Unknown run_mode {run_mode!r}; expected one of {sorted(MODE_CFG)}"
    )

cfg = MODE_CFG[run_mode]

# Each mode reads and writes its own directory: <ROOT>/outputs/<run_mode>.
OUT_BASE = ROOT / "outputs"
RUN_DIR = OUT_BASE / run_mode

# ---- cache check ----
def run_assets_exist(run_dir: Path) -> bool:
    """Return True only when every cached artefact of a run is present in ``run_dir``.

    Args:
        run_dir: Directory that is expected to hold the saved run files.

    Returns:
        True if all four expected files exist, otherwise False.
    """
    required_files = (
        "all_fused_payloads.joblib",
        "df_lime_all.parquet",
        "df_ts_all.parquet",
        "sel_df.parquet",
    )
    return all((run_dir / filename).exists() for filename in required_files)

target_classes = list(TARGET_META.keys())

if not run_assets_exist(RUN_DIR):
    # No complete cache on disk: compute the explanations, then persist them.
    print(f"[{run_mode}] No cache found. Running pipeline...\n")
    pipeline_outputs = run_fused_pipeline_for_classes(
        target_classes=target_classes,
        sel_df=sel_df,
        model=model,
        class_names=class_names,
        max_examples_per_class=cfg["max_examples_per_class"],
        plot=cfg["plot"],
    )
    all_fused_payloads, df_lime_all, df_ts_all = pipeline_outputs

    # Settings stored alongside the run.
    run_meta = {
        "mode": run_mode,
        "saved_at": datetime.now().isoformat(timespec="seconds"),
        "target_classes": target_classes,
        "max_examples_per_class": cfg["max_examples_per_class"],
        "plot": cfg["plot"],
    }
    save_run(RUN_DIR, all_fused_payloads, df_lime_all, df_ts_all, sel_df, meta=run_meta)
    print(f"[{run_mode}] Saved to: {RUN_DIR}")
else:
    print(f"[{run_mode}] Loading cached assets from: {RUN_DIR}")
    all_fused_payloads, df_lime_all, df_ts_all, sel_df_cached = load_run(RUN_DIR)

    # Use the cached selection so its indices stay consistent with the cached payloads.
    sel_df = sel_df_cached

[eval] No cache found. Running pipeline...


=== [1/3] Processing class: 164889003 ===

 FUSED class=164889003 | window=0.5s | m_event=100 | m_feat=100 | topk_events=5
 LIME done: 20 rows in 5m14s
 TimeSHAP done: 20 rows in 4m05s
  Fusing: 20 common records


                                                                                                        

  Class 164889003 total: 9m20s
— Progress: 1/3 classes | ETA ~ 18m40s

=== [2/3] Processing class: 426783006 ===

 FUSED class=426783006 | window=0.25s | m_event=100 | m_feat=100 | topk_events=5




 LIME done: 20 rows in 5m07s
 TimeSHAP done: 20 rows in 4m05s
  Fusing: 20 common records


                                                                                       

  Class 426783006 total: 9m12s
— Progress: 2/3 classes | ETA ~ 9m16s

=== [3/3] Processing class: 17338001 ===

 FUSED class=17338001 | window=0.4s | m_event=100 | m_feat=100 | topk_events=5




 LIME done: 20 rows in 5m08s
 TimeSHAP done: 20 rows in 4m02s
  Fusing: 20 common records


                                                                                             

  Class 17338001 total: 9m10s
— Progress: 3/3 classes | ETA ~ 0m00s

 All classes complete in 27m43s
[eval] Saved to: c:\UHull\ecg-xai\outputs\eval




In [17]:
from eval import evaluate_all_payloads

# Evaluate every fused payload; later cells filter the resulting DataFrame by `meta_code`
# and aggregate its metric columns per class.
# NOTE(review): debug=True is left on - confirm it is still wanted for full runs.
df_eval_all = evaluate_all_payloads(
    all_payloads=all_fused_payloads,
    method_label="LIME+TimeSHAP",
    debug=True,
    model=model,
    class_names=class_names,
)

# Display the evaluation table as the cell output.
df_eval_all

Unnamed: 0,meta_code,class_name,sel_idx,mat_path,method,strict_attauc,lenient_attauc,precision_k,strict_p_at_k,lenient_p_at_k,deletion_auc,faithfulness_gain,n_tokens
0,164889003,atrial fibrillation,3106,C:\UHull\Data\Training_WFDB\A0003_extra_end.mat,LIME+TimeSHAP,0.86028,0.650225,20,0.45,0.9,0.289826,-0.045046,300
1,164889003,atrial fibrillation,3108,C:\UHull\Data\Training_WFDB\A0003_extra_end_ex...,LIME+TimeSHAP,0.859418,0.656603,20,0.55,0.9,0.256976,0.07935,288
2,164889003,atrial fibrillation,3129,C:\UHull\Data\Training_WFDB\A0023.mat,LIME+TimeSHAP,0.859834,0.668802,20,0.55,0.9,0.255702,0.081672,228
3,164889003,atrial fibrillation,3478,C:\UHull\Data\Training_WFDB\A0370.mat,LIME+TimeSHAP,0.876087,0.682095,20,0.5,1.0,0.265549,-0.062598,276
4,164889003,atrial fibrillation,3941,C:\UHull\Data\Training_WFDB\A0833.mat,LIME+TimeSHAP,0.84696,0.674,20,0.35,0.85,0.269187,0.105275,300
5,164889003,atrial fibrillation,5643,C:\UHull\Data\Training_WFDB\A2535.mat,LIME+TimeSHAP,0.855664,0.640076,20,0.4,0.8,0.225168,0.212271,192
6,164889003,atrial fibrillation,5681,C:\UHull\Data\Training_WFDB\A2573.mat,LIME+TimeSHAP,0.87716,0.705392,20,0.55,0.85,0.287784,0.04203,216
7,164889003,atrial fibrillation,5767,C:\UHull\Data\Training_WFDB\A2659.mat,LIME+TimeSHAP,0.827551,0.635523,20,0.35,0.9,0.297027,0.00691,84
8,164889003,atrial fibrillation,6206,C:\UHull\Data\Training_WFDB\A3098.mat,LIME+TimeSHAP,0.864839,0.675685,20,0.5,1.0,0.288048,-0.03639,276
9,164889003,atrial fibrillation,6275,C:\UHull\Data\Training_WFDB\A3167.mat,LIME+TimeSHAP,0.852883,0.642687,20,0.5,0.85,0.265366,0.096957,348


In [19]:
# index=False keeps the row index out of the file (it would otherwise come back as an
# "Unnamed: 0" column on read), matching how sel_df is exported earlier in the notebook.
df_eval_all.to_csv(ROOT / "data" / "df_eval_all.csv", index=False)

### Stability to an Extra Heartbeat (All Target Classes)

We assess explanation stability under a synthetic perturbation where one existing
heartbeat is duplicated and re-inserted either at the **end** of the ECG or in
the **middle** of the recording.

For each target rhythm class in `TARGET_META`:

1. We take the first example of that class that was already explained and evaluated in `df_eval_all`.
2. We run `run_extra_beat_stability_experiment`, which:
   - creates two augmented versions of the ECG
     - one with an **extra beat appended at the end**
     - one with an **extra beat inserted in the middle**
   - recomputes **fused LIME + TimeSHAP** explanations for each version
   - compares region-level importance profiles over the *shared beats* using:
     - **Spearman rank correlation** (global ordering of important regions)
     - **Jaccard@K** overlap of the top-K most important regions.

This gives us per-class stability metrics for:

- original vs **extra-beat-at-end**
- original vs **extra-beat-in-middle**


In [None]:
from eval import run_extra_beat_stability_experiment

target_classes = list(TARGET_META.keys())

# Per-class results, keyed by the SNOMED code as a string.
stability_results = {}

for snomed_code in target_classes:
    snomed_code_str = str(snomed_code)

    # Compare as strings so the match works whether `meta_code` holds ints or strings.
    df_cls = df_eval_all[df_eval_all["meta_code"].astype(str) == snomed_code_str]

    if df_cls.empty:
        print(f"[WARN] No evaluated examples for class {snomed_code_str}, skipping.")
        continue

    # Pick the first evaluated example for this class.
    # NOTE(review): the recorded df_eval_all output lists paths such as
    # "A0003_extra_end.mat", which look like files produced by an earlier stability run;
    # confirm that the first row is an original recording and not an augmented one.
    row = df_cls.iloc[0]
    mat_path = row["mat_path"]
    # Index with the original key rather than its str() form, so that non-string keys in
    # TARGET_META do not raise KeyError.
    class_name = TARGET_META[snomed_code]["name"]

    print(f"\n=== Running extra-beat stability for {snomed_code_str} ({class_name}) ===")
    print(f"MAT path: {mat_path}")

    # Only `metrics` is stored below; the other returned objects stay available for inspection.
    metrics, sel_df_stab, fused_payloads_stab, df_lime_stab, df_ts_stab = run_extra_beat_stability_experiment(
        mat_path=mat_path,
        snomed_code=snomed_code_str,
        model=model,
        class_names=class_names,
    )

    stability_results[snomed_code_str] = {
        "name": class_name,
        "mat_path": mat_path,
        "metrics": metrics,
    }

stability_results



### Stability Summary Table

To make the results easier to inspect and report, we summarise the stability
metrics for each target class in a single table.


In [None]:
def _stability_row(code, info):
    """Flatten one class's stability metrics into a single summary-table record."""
    end_metrics = info["metrics"]["extra_end"]
    mid_metrics = info["metrics"]["extra_mid"]
    return {
        "SNOMED code": code,
        "Class name": info["name"],
        "Spearman (extra end)":  end_metrics["spearman"],
        "Jaccard@K (extra end)": end_metrics["jaccard_topk"],
        "Spearman (extra mid)":  mid_metrics["spearman"],
        "Jaccard@K (extra mid)": mid_metrics["jaccard_topk"],
    }


# One record per class, in the insertion order of `stability_results`.
rows = [_stability_row(code, info) for code, info in stability_results.items()]

df_stability_summary = pd.DataFrame(rows)
df_stability_summary


Overall, the fused explanations were **perfectly stable** to appending an
extra heartbeat at the end of the recording across all classes
(Spearman ≈ 1.0, Jaccard@K = 1.0). This indicates that adding extra context
at the tail of the ECG does not affect the explanations for the existing beats.

In contrast, inserting an extra beat in the middle of the recording led to
a reduction in stability. Spearman rank correlation remained high
(≈ 0.94–0.97), suggesting that the global ordering of important regions is
largely preserved, but Jaccard@K dropped (e.g. 0.33–0.60), indicating that the
exact composition of the top-K most important regions is more sensitive to
perturbations in the internal temporal structure of the signal.


In [None]:
import pandas as pd

# Metrics summarised per class:
#   strict_attauc / lenient_attauc - ranking-based localisation accuracy
#   deletion_auc                   - faithfulness (lower is better)
#   n_tokens                       - token count (descriptive only)
_METRIC_COLUMNS = ("strict_attauc", "lenient_attauc", "deletion_auc", "n_tokens")
_STATISTICS = ("mean", "std")

# Named aggregations such as strict_attauc_mean=("strict_attauc", "mean").
named_aggregations = {
    f"{column}_{stat}": (column, stat)
    for column in _METRIC_COLUMNS
    for stat in _STATISTICS
}

# One row per (SNOMED meta-code, human-readable class name) pair.
summary = (
    df_eval_all
    .groupby(["meta_code", "class_name"])
    .agg(**named_aggregations)
    .reset_index()
)

summary


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Class names label the x-axis; `x` and `classes` are reused by the next plotting cell.
classes = summary["class_name"].values
x = np.arange(len(classes))
width = 0.35  # bar width

fig, ax = plt.subplots(figsize=(8, 4))

# Two bars per class, placed side by side around each tick:
#   strict  - explainer vs a strict, narrow definition of "correct"
#   lenient - explainer vs a broader, more forgiving ground truth
attauc_series = (
    (-width / 2, "strict_attauc", "Strict AttAUC"),
    (+width / 2, "lenient_attauc", "Lenient AttAUC"),
)
for offset, column_prefix, legend_label in attauc_series:
    ax.bar(
        x + offset,
        summary[f"{column_prefix}_mean"],
        width,
        yerr=summary[f"{column_prefix}_std"],
        capsize=4,
        label=legend_label,
    )

ax.set_xticks(x)
ax.set_xticklabels(classes, rotation=20, ha="right")
ax.set(
    ylabel="AttAUC (ranking-based)",
    ylim=(0, 1.05),
    title="Token-level localisation accuracy (AttAUC)",
)
ax.legend()
fig.tight_layout()
plt.show()


In [None]:
fig, ax = plt.subplots(figsize=(8, 4))

# Lower deletion AUC => more faithful (probability drops faster).
deletion_mean = summary["deletion_auc_mean"]
deletion_std = summary["deletion_auc_std"]
ax.bar(x, deletion_mean, yerr=deletion_std, capsize=4)

# `x` and `classes` come from the AttAUC plotting cell.
ax.set_xticks(x)
ax.set_xticklabels(classes, rotation=20, ha="right")
ax.set(
    ylabel="Deletion AUC (lower = more faithful)",
    title="Faithfulness of explanations via targeted deletion",
)
fig.tight_layout()
plt.show()


In [None]:
fig, ax = plt.subplots(figsize=(5, 5))

# Strict AttAUC on the x-axis, deletion AUC on the y-axis: one point per class.
ax.scatter(summary["strict_attauc_mean"], summary["deletion_auc_mean"])

# Label each point with its class name, nudged slightly up and to the right.
for record in summary.itertuples(index=False):
    ax.text(
        record.strict_attauc_mean + 0.005,
        record.deletion_auc_mean + 0.002,
        record.class_name,
        fontsize=8,
    )

ax.set(
    xlabel="Strict AttAUC (higher = more accurate localisation)",
    ylabel="Deletion AUC (lower = more faithful)",
    title="Accuracy vs Faithfulness across diagnostic classes",
)
fig.tight_layout()
plt.show()
