In [16]:
import os
import re
import glob
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import img_to_array, load_img

# =========================
# Config
# =========================
MODEL_PATH = os.path.join("models", "veremi_images_1x3_e.keras")
IMAGE_DIR  = "veremi_images_1x3_e"
OUTDIR     = "shap_explanations"

# How many images to explain (cap for speed)
N_SAMPLES  = 128
IG_STEPS   = 64               # 32–128 typical
CLASS_TO_EXPLAIN = "pred"     # "pred" or an int like 0/1

# If you truly have 12 attributes, name them here:
CHANNEL_NAMES_DEFAULT = [
    "sendTime","sender","posx","posy",
    "spdx_n","spdy_n","aclx","acly",
    "hedx","hedy","hedx_n","hedy_n"
]

os.makedirs(OUTDIR, exist_ok=True)


# =========================
# Helpers
# =========================
def _sorted_image_paths(folder):
    paths = glob.glob(os.path.join(folder, "image_*.png"))
    # natural sort by index in filename image_<idx>.png
    def _key(p):
        m = re.search(r"image_(\d+)\.png$", os.path.basename(p))
        return int(m.group(1)) if m else 10**9
    return sorted(paths, key=_key)

def _load_pngs_as_stack(folder, H, W, C, limit=None):
    """
    Loads PNGs as grayscale (H,W,1), scales to [0,1], and if needed repeats to C.
    Returns X with shape (N,H,W,C).
    """
    paths = _sorted_image_paths(folder)
    if limit is not None:
        paths = paths[:limit]
    Xs = []
    for p in paths:
        img = load_img(p, color_mode='grayscale', target_size=(H, W))
        arr = img_to_array(img).astype(np.float32) / 255.0  # (H,W,1)
        if arr.shape[-1] == 1 and C > 1:
            arr = np.repeat(arr, C, axis=-1)                # replicate channel -> (H,W,C)
        if arr.shape[-1] != C:
            raise ValueError(f"Loaded image has {arr.shape[-1]} channels, expected {C}. File: {p}")
        Xs.append(arr)
    if not Xs:
        raise FileNotFoundError(f"No PNGs found in {folder}. Expected files like image_1.png")
    return np.stack(Xs, axis=0)

def _get_channel_names(C):
    return CHANNEL_NAMES_DEFAULT if C == len(CHANNEL_NAMES_DEFAULT) else [f"ch{c}" for c in range(C)]

def _predicted_classes(model, X):
    y = model(X, training=False)
    y = tf.convert_to_tensor(y)
    if len(y.shape) != 2:
        raise ValueError(f"Unexpected model output shape: {y.shape}")
    K = y.shape[1]
    if K == 1:
        # Probability of class 1
        p1 = tf.squeeze(y, axis=1)
        return tf.cast(p1 >= 0.5, tf.int32).numpy()
    return tf.argmax(y, axis=1).numpy()

def _scores_for_class(y, cls_idx_vec):
    """
    Map model outputs y (N,K) to scalar scores per sample for chosen classes.
    If K==1, y[:,0] == p(class=1).
    """
    y = tf.convert_to_tensor(y)
    N = tf.shape(y)[0]
    K = y.shape[1]
    if K == 1:
        p1 = tf.squeeze(y, axis=1)                     # (N,)
        cls_idx_vec = tf.convert_to_tensor(cls_idx_vec, dtype=tf.int32)
        return tf.where(cls_idx_vec == 1, p1, 1.0 - p1)
    # K > 1
    idx = tf.stack([tf.range(N, dtype=tf.int32), tf.convert_to_tensor(cls_idx_vec, dtype=tf.int32)], axis=1)
    return tf.gather_nd(y, idx)                         # (N,)

def integrated_gradients(model, X, class_sel="pred", steps=64, baseline=None):
    """
    Compute Integrated Gradients for each sample. Returns attributions same shape as X.
    """
    X = np.asarray(X, dtype=np.float32)
    N, H, W, C = X.shape
    if baseline is None:
        baseline = np.zeros_like(X, dtype=np.float32)
    elif baseline.shape != X.shape and baseline.shape != (1, H, W, C):
        raise ValueError("Baseline shape mismatch.")
    if baseline.shape == (1, H, W, C):
        baseline = np.repeat(baseline, N, axis=0)

    # decide class per sample
    if class_sel == "pred":
        cls_vec = _predicted_classes(model, X)
    elif isinstance(class_sel, int):
        cls_vec = np.full((N,), int(class_sel), dtype=np.int32)
    else:
        cls_vec = np.asarray(class_sel, dtype=np.int32)
        if cls_vec.shape != (N,):
            raise ValueError("class_sel must be 'pred', int, or an array of shape (N,)")

    alphas = tf.linspace(0.0, 1.0, steps)
    attributions = np.zeros_like(X, dtype=np.float32)

    for i in range(N):
        xi = tf.convert_to_tensor(X[i:i+1])            # (1,H,W,C)
        bi = tf.convert_to_tensor(baseline[i:i+1])
        direction = xi - bi

        grads_accum = 0.0
        for a in alphas:
            xt = bi + a * direction                    # (1,H,W,C)
            with tf.GradientTape() as tape:
                tape.watch(xt)
                y = model(xt, training=False)          # (1,K)
                # scalar score for class cls_vec[i]
                score = _scores_for_class(y, np.array([cls_vec[i]], dtype=np.int32))
                score = tf.squeeze(score, axis=0)      # scalar
            grads = tape.gradient(score, xt)           # (1,H,W,C)
            grads_accum = grads_accum + grads

        avg_grads = grads_accum / tf.cast(steps, tf.float32)
        attr = (xi - bi) * avg_grads                   # (1,H,W,C)
        attributions[i] = attr.numpy()[0]

    return attributions  # (N,H,W,C)

def occlusion_per_channel(model, X, class_sel="pred"):
    """
    Zero out one channel at a time; return (mean |Δscore| per channel, std across samples).
    """
    X = np.asarray(X, dtype=np.float32)
    N, H, W, C = X.shape

    # base
    y_base = model(X, training=False)
    if class_sel == "pred":
        cls_vec = _predicted_classes(model, X)
    elif isinstance(class_sel, int):
        cls_vec = np.full((N,), int(class_sel), dtype=np.int32)
    else:
        cls_vec = np.asarray(class_sel, dtype=np.int32)
    base_scores = _scores_for_class(y_base, cls_vec).numpy()  # (N,)

    means = np.zeros((C,), dtype=np.float32)
    stds  = np.zeros((C,), dtype=np.float32)

    for ch in range(C):
        X_occ = X.copy()
        X_occ[:, :, :, ch] = 0.0
        y_occ = model(X_occ, training=False)
        occ_scores = _scores_for_class(y_occ, cls_vec).numpy()  # (N,)
        diffs = np.abs(base_scores - occ_scores)                # (N,)
        means[ch] = diffs.mean().astype(np.float32)
        stds[ch]  = diffs.std().astype(np.float32)

    return means, stds  # (C,), (C,)

def _barplot(path, names, values, title, ylabel="importance"):
    plt.figure(figsize=(10, 4.2))
    xs = np.arange(len(names))
    plt.bar(xs, values)
    plt.xticks(xs, names, rotation=45, ha='right')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(path, dpi=160)
    plt.close()

def _barplot_with_error(path, names, means, stds, title, ylabel="importance"):
    plt.figure(figsize=(10, 4.2))
    xs = np.arange(len(names))
    plt.bar(xs, means, yerr=stds, capsize=4)
    plt.xticks(xs, names, rotation=45, ha='right')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(path, dpi=160)
    plt.close()

def _heatmap(path, A, title):
    plt.figure(figsize=(4.5, 4))
    plt.imshow(A, interpolation='nearest')
    plt.title(title)
    plt.colorbar()
    plt.tight_layout()
    plt.savefig(path, dpi=160)
    plt.close()


# =========================
# Main
# =========================
def main():
    # Load model
    model = tf.keras.models.load_model(MODEL_PATH)
    _, H, W, C = model.input_shape
    if None in (H, W, C):
        # fallback if model.input_shape is (None, 5, 5, 12)
        H, W, C = 5, 5, 12

    # Load data (same preprocessing as training: grayscale → repeat to 12 channels)
    X_all = _load_pngs_as_stack(IMAGE_DIR, H, W, C, limit=N_SAMPLES)
    N = X_all.shape[0]
    print(f"[INFO] Loaded {N} images with shape {X_all.shape}")

    chan_names = _get_channel_names(C)

    # ---------- Integrated Gradients ----------
    ig = integrated_gradients(model, X_all, class_sel=CLASS_TO_EXPLAIN, steps=IG_STEPS)  # (N,H,W,C)
    # Per-sample per-channel importance = sum of |attr| over H,W
    per_sample_ch = np.sum(np.abs(ig), axis=(1, 2))  # (N,C)
    ig_mean = per_sample_ch.mean(axis=0)
    ig_std  = per_sample_ch.std(axis=0)

    # Plot per-channel IG with error bars
    _barplot_with_error(
        os.path.join(OUTDIR, "1x3_e_b_ig_per_channel.png"),
        chan_names, ig_mean, ig_std,
        "Integrated Gradients — per-channel importance", ylabel="IG |attr| (mean ± std)"
    )

    # Also save a global per-pixel heatmap (sum over samples & channels)
    per_pixel = np.sum(np.abs(ig), axis=(0, 3))  # (H,W)
    _heatmap(
        os.path.join(OUTDIR, "1x3_e_b_ig_per_pixel.png"),
        per_pixel, "Integrated Gradients — per-pixel importance"
    )

    # Print top channels (IG)
    order = np.argsort(ig_mean)[::-1]
    print("\n[IG] Top channels:")
    for r, ch in enumerate(order, 1):
        print(f"{r:2d}. {chan_names[ch]}  IG={ig_mean[ch]:.6f} ± {ig_std[ch]:.6f}")

    # ---------- Occlusion ----------
    occ_mean, occ_std = occlusion_per_channel(model, X_all, class_sel=CLASS_TO_EXPLAIN)  # (C,),(C,)
    _barplot_with_error(
        os.path.join(OUTDIR, "1x3_e_b_occlusion_per_channel.png"),
        chan_names, occ_mean, occ_std,
        "Occlusion — Δscore per channel", ylabel="|Δscore| (mean ± std)"
    )

    order2 = np.argsort(occ_mean)[::-1]
    print("\n[Occlusion] Top channels:")
    for r, ch in enumerate(order2, 1):
        print(f"{r:2d}. {chan_names[ch]}  Δscore={occ_mean[ch]:.6f} ± {occ_std[ch]:.6f}")

    print(f"\n[Saved] Plots in: {OUTDIR}")

if __name__ == "__main__":
    main()

[INFO] Loaded 128 images with shape (128, 3, 1, 12)

[IG] Top channels:
 1. sender  IG=0.000000 ± 0.000000
 2. hedy_n  IG=0.000000 ± 0.000000
 3. spdy_n  IG=0.000000 ± 0.000000
 4. acly  IG=0.000000 ± 0.000000
 5. aclx  IG=0.000000 ± 0.000000
 6. posx  IG=0.000000 ± 0.000000
 7. posy  IG=0.000000 ± 0.000000
 8. hedx_n  IG=0.000000 ± 0.000000
 9. hedx  IG=0.000000 ± 0.000000
10. spdx_n  IG=0.000000 ± 0.000000
11. hedy  IG=0.000000 ± 0.000000
12. sendTime  IG=0.000000 ± 0.000000

[Occlusion] Top channels:
 1. hedy_n  Δscore=0.000000 ± 0.000000
 2. hedx_n  Δscore=0.000000 ± 0.000000
 3. hedy  Δscore=0.000000 ± 0.000000
 4. hedx  Δscore=0.000000 ± 0.000000
 5. acly  Δscore=0.000000 ± 0.000000
 6. aclx  Δscore=0.000000 ± 0.000000
 7. spdy_n  Δscore=0.000000 ± 0.000000
 8. spdx_n  Δscore=0.000000 ± 0.000000
 9. posy  Δscore=0.000000 ± 0.000000
10. posx  Δscore=0.000000 ± 0.000000
11. sender  Δscore=0.000000 ± 0.000000
12. sendTime  Δscore=0.000000 ± 0.000000

[Saved] Plots in: shap_explanati