In [74]:
import os
import math
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# =========================
# Config
# =========================
# Path of the saved Keras model that main() loads.
MODEL_PATH = os.path.join("models", "veremi_images_m_2x2_e_i.keras")
# Provide a path to a single sample .npy; shape must match model input (H,W,C).
# If None, a synthetic demo sample is used.
SAMPLE_NPY = None  # e.g., r"samples\example_0001.npy"

# Number of perturbations per feature for MI estimation
K = 80
# Perturbation std as a fraction of the explained sample's own min-max range
# (that is how mime_local_importance derives sigma). Tune as needed.
PERTURB_STD_FRAC = 0.10
# Number of bins to discretize both perturbed feature values and predicted probs for MI
N_BINS = 10
# Set which class to explain: 'pred' for the model's predicted class, or an integer class id
# NOTE(review): main() below hard-codes the explained class (`cexp`) and never
# reads this constant -- confirm which of the two is intended.
CLASS_TO_EXPLAIN = 'pred'

# Value clipping range for the inputs (change if your preprocessing differs)
CLIP_MIN, CLIP_MAX = 0.0, 1.0

# Output directory (created as a side effect when this cell/module is executed)
OUTDIR = "mime_explanations"
os.makedirs(OUTDIR, exist_ok=True)

# Channel names (guarded to fall back if count mismatches):
# plot_and_save_mime uses generic "channel_<i>" labels when the map's channel
# count differs from the length of this list.
CHANNEL_NAMES_DEFAULT = [
    "sendTime", "sender", "posx", "posy",
    "spdx_n", "spdy_n", "aclx", "acly",
    "hedx", "hedy", "hedx_n", "hedy_n"
]

# =========================
# Utilities
# =========================
def pick_classification_head(model, prefer_name=None):
    """
    Select the model output that should be treated as the classification head.

    Selection order:
      1. `model.outputs` itself when it is a single tensor;
      2. the output whose name equals `prefer_name` (if it can be resolved);
      3. the first output whose last dimension is at least 2;
      4. the first output.
    """
    heads = model.outputs
    if isinstance(heads, tf.Tensor):
        return heads

    # Step 2: look the preferred head up by name; any failure is non-fatal.
    if prefer_name is not None and hasattr(model, "output_names"):
        try:
            named_head = heads[list(model.output_names).index(prefer_name)]
        except Exception:
            named_head = None
        if named_head is not None:
            return named_head

    # Step 3: a last dimension >= 2 looks like a class axis.
    for candidate in heads:
        try:
            width = int(candidate.shape[-1])
        except Exception:
            # Unknown or missing static shape: skip this head.
            continue
        if width >= 2:
            return candidate

    # Step 4: nothing matched, fall back to the first output.
    return heads[0]

def ensure_softmax_probs(logits_or_probs):
    """
    Return class probabilities for a model output, normalising only if needed.

    The output shape always equals the input shape.

    * Multi-column input (last axis > 1): returned unchanged when every entry
      is non-negative and every row sums to 1 (within 1e-3); otherwise a
      numerically stable softmax is applied along the last axis.
    * Single-column input (last axis == 1, e.g. a sigmoid head): returned
      unchanged when all values lie in [0, 1]; otherwise the values are treated
      as logits and passed through a sigmoid. A softmax over a single column
      would always give 1.0 and erase the model's signal.

    NOTE: the "already probabilities?" decision is a heuristic made over the
    whole batch; logits that happen to look like probabilities are not detected.
    """
    arr = np.asarray(logits_or_probs)

    if arr.ndim >= 1 and arr.shape[-1] == 1:
        if np.all((arr >= -1e-6) & (arr <= 1.0 + 1e-6)):
            return arr
        # tanh form of the sigmoid: no overflow for large |logit|.
        return 0.5 * (1.0 + np.tanh(0.5 * arr))

    row_sums = arr.sum(axis=-1, keepdims=True)
    if np.all(arr >= -1e-6) and np.all(np.abs(row_sums - 1.0) < 1e-3):
        return arr

    # Subtract the row max before exponentiating to avoid overflow.
    exps = np.exp(arr - np.max(arr, axis=-1, keepdims=True))
    return exps / np.clip(exps.sum(axis=-1, keepdims=True), 1e-12, None)

def discretize(values, n_bins):
    """
    Assign each value to one of `n_bins` equal-width bins over [min, max].

    Returns a tuple (bins, vmin, vmax) where `bins` is a flat integer array of
    indices in [0, n_bins - 1]. When all values are equal, every index is 0.
    """
    flat = np.asarray(values).ravel()
    lo = np.min(flat)
    hi = np.max(flat)

    if hi != lo:
        # Only the left edges are needed; the maximum lands in the last bin
        # because indices are clipped to n_bins - 1.
        left_edges = np.linspace(lo, hi, n_bins + 1)[:-1]
        raw_idx = np.digitize(flat, left_edges, right=False) - 1
        return np.clip(raw_idx, 0, n_bins - 1), lo, hi

    # Degenerate range: a single bin holds everything.
    return np.zeros_like(flat, dtype=int), lo, hi

def mutual_information_discrete(x_bins, y_bins, n_x_bins, n_y_bins, eps=1e-12):
    """
    Empirical mutual information MI(X;Y) in nats from paired bin indices.

    x_bins, y_bins: integer arrays of equal length with values in
        [0..n_x_bins-1] and [0..n_y_bins-1] respectively.
    n_x_bins, n_y_bins: size of the joint histogram.
    eps: kept for backward compatibility; no longer used. Adding it to the
        denominator biased every term downward and could make the estimate
        slightly negative (e.g. -3e-12 for independent variables).

    Returns a non-negative float. Empty input yields 0.0.
    Raises ValueError if the two index arrays differ in shape.
    """
    x_idx = np.asarray(x_bins).ravel()
    y_idx = np.asarray(y_bins).ravel()
    if x_idx.shape != y_idx.shape:
        raise ValueError(
            f"x_bins and y_bins must have the same length, got {x_idx.shape} and {y_idx.shape}"
        )

    n_samples = x_idx.size
    if n_samples == 0:
        return 0.0

    # Joint histogram; np.add.at accumulates correctly for repeated index pairs.
    joint = np.zeros((n_x_bins, n_y_bins), dtype=float)
    np.add.at(joint, (x_idx, y_idx), 1.0)
    joint /= n_samples

    px = joint.sum(axis=1, keepdims=True)
    py = joint.sum(axis=0, keepdims=True)

    # Only occupied cells contribute (0 * log 0 := 0); there the product of
    # marginals is strictly positive, so no epsilon is required.
    occupied = joint > 0
    p_joint = joint[occupied]
    p_indep = (px @ py)[occupied]
    mi = float(np.sum(p_joint * np.log(p_joint / p_indep)))

    # MI is non-negative by definition; clamp away round-off (and -0.0).
    return mi if mi > 0.0 else 0.0

def load_sample(sample_path, expected_shape):
    """
    Return one sample as a float32 batch of shape (1, H, W, C).

    sample_path: path to a .npy file holding an (H, W, C) array, or None to
        get a deterministic synthetic sample drawn uniformly from [0, 1)
        (RandomState seed 0).
    expected_shape: the (H, W, C) the model expects.

    Raises FileNotFoundError if a path is given but no such file exists
    (previously this silently fell back to random data, so a typo in the path
    produced an explanation of noise), and ValueError if the loaded array has
    the wrong shape.
    """
    H, W, C = expected_shape

    if sample_path is None:
        # Synthetic demo sample in [0,1]
        rng = np.random.RandomState(0)
        return rng.rand(H, W, C).astype(np.float32)[None, ...]

    if not os.path.isfile(sample_path):
        raise FileNotFoundError(f"Sample file not found: {sample_path}")

    x = np.load(sample_path)
    if x.shape != (H, W, C):
        raise ValueError(f"Expected sample shape {(H,W,C)}, got {x.shape}")
    return x.astype(np.float32)[None, ...]

class MIMEPredictor:
    """
    Wraps a Keras model so that batches of NumPy inputs can be turned into
    class probabilities of one chosen output head.

    A sub-model ending at `head_tensor` is built once, and its forward pass is
    compiled with a fixed (None, h, w, c) float32 input signature.
    """

    def __init__(self, model, head_tensor, h, w, c):
        # Sub-model sharing the original inputs but emitting only the chosen head.
        if isinstance(model.inputs, (list, tuple)):
            submodel_inputs = list(model.inputs)
        else:
            submodel_inputs = [model.input]
        self.head_model = tf.keras.Model(inputs=submodel_inputs, outputs=head_tensor)

        # Input metadata, kept so the call below can match the model's structure.
        self._single_input = len(self.head_model.inputs) == 1
        self._input_names = getattr(self.head_model, "input_names", None)

        # One fixed signature (variable batch size) to avoid retracing.
        self._sig = tf.TensorSpec(shape=[None, h, w, c], dtype=tf.float32)

        def _predict_tf(x):
            if not self._single_input:
                # Multi-input models would need an explicit tensor mapping here.
                raise ValueError("This predictor currently supports single-input models only.")
            # The tensor is wrapped in a list to mirror the model's input structure.
            return self.head_model([x], training=False)

        self._predict_tf = tf.function(
            _predict_tf,
            reduce_retracing=True,
            input_signature=[self._sig],
        )

    def probs(self, xb_np):
        """Run the head on a NumPy batch and return probabilities as a NumPy array."""
        batch = tf.convert_to_tensor(xb_np, dtype=tf.float32)
        raw = self._predict_tf(batch).numpy()
        return ensure_softmax_probs(raw)

# =========================
# MIME (local MI) explanation
# =========================
def mime_local_importance(predictor, x1, class_to_explain='pred',
                          k=80, perturb_std_frac=0.10, n_bins=10,
                          clip_min=0.0, clip_max=1.0):
    """
    Estimate, for every input feature of one sample, the mutual information
    between Gaussian perturbations of that feature and the predicted
    probability of the explained class.

    predictor: object with a `probs(batch)` method returning (n, num_classes).
    x1: the sample as a batch of one, shape (1, H, W, C).
    class_to_explain: 'pred' for the arg-max class of `x1`, or a class id.
    k: number of perturbed copies per feature.
    perturb_std_frac: noise std as a fraction of the sample's min-max range.
    n_bins: bins used for both the perturbed values and the probabilities.
    clip_min, clip_max: range the perturbed values are clipped to.

    Returns (mi_map, explained_class) with mi_map of shape (H, W, C), float32.
    Raises ValueError if an explicit class id is out of range.
    Noise comes from the global NumPy RNG, so results vary between runs unless
    the caller seeds it.
    """
    reference_probs = predictor.probs(x1)  # (1, num_classes)
    n_classes = reference_probs.shape[-1]
    if class_to_explain == 'pred':
        explained_class = int(np.argmax(reference_probs[0]))
    else:
        explained_class = int(class_to_explain)
        if not (0 <= explained_class < n_classes):
            raise ValueError(f"class_to_explain out of range: {explained_class}")

    H, W, C = x1.shape[1:]
    mi_map = np.zeros((H, W, C), dtype=np.float32)

    # Noise scale is relative to this sample's own value range
    # (a constant sample is treated as having range 1).
    lo, hi = float(np.min(x1)), float(np.max(x1))
    if hi == lo:
        hi = lo + 1.0
    sigma = perturb_std_frac * max(hi - lo, 1e-6)

    # Reusable (k, H, W, C) batch so the predictor always sees the same shape.
    batch = np.repeat(x1.astype(np.float32), repeats=k, axis=0)

    total_feats = H * W * C
    # np.ndindex walks (i, j, ch) in the same row-major order as nested loops.
    for done, (i, j, ch) in enumerate(np.ndindex(H, W, C), start=1):
        center = float(x1[0, i, j, ch])
        jitter = np.random.normal(loc=0.0, scale=sigma, size=(k,))
        pert_vals = np.clip(center + jitter, clip_min, clip_max).astype(np.float32)

        # Reset to the clean sample, then perturb only the current feature.
        batch[:] = x1
        batch[:, i, j, ch] = pert_vals

        class_probs = predictor.probs(batch)[:, explained_class]  # (k,)

        x_bins = discretize(pert_vals, n_bins)[0]
        y_bins = discretize(class_probs, n_bins)[0]
        mi_map[i, j, ch] = mutual_information_discrete(x_bins, y_bins, n_bins, n_bins)

        if done % 50 == 0:
            print(f"[MIME] Processed {done}/{total_feats} features...")

    return mi_map, explained_class

# =========================
# Plotting helpers
# =========================
def plot_and_save_mime(mi_map, explained_class, outdir=OUTDIR):
    """
    Write three PNG views of an MI map into `outdir` and print their paths:
    a normalized per-pixel heatmap, a per-channel bar chart, and a grid with
    one heatmap per channel sharing a common colour scale.

    mi_map: array of shape (H, W, C), one MI value per input feature.
    explained_class: class id; used only in titles and file names.
    outdir: target directory (this function does not create it).
    """
    import matplotlib.colors as mcolors
    from matplotlib.colors import LinearSegmentedColormap

    H, W, C = mi_map.shape

    # Aggregated views: MI summed over channels, and over spatial positions.
    pixel_totals = mi_map.sum(axis=2)
    channel_totals = mi_map.sum(axis=(0, 1))

    def _unit_scale(values):
        """Min-max scale to [0, 1]; a constant array maps to all zeros."""
        values = np.asarray(values, dtype=float)
        low, high = np.min(values), np.max(values)
        if not (high > low):
            return np.zeros_like(values)
        return (values - low) / (high - low + 1e-12)

    white_to_blue = LinearSegmentedColormap.from_list("wlb", ["white", "lightblue", "blue"])

    # Use the descriptive names only when they match the channel count.
    channel_names = (
        CHANNEL_NAMES_DEFAULT
        if len(CHANNEL_NAMES_DEFAULT) == C
        else [f"channel_{i}" for i in range(C)]
    )

    # --- 1) per-pixel heatmap -------------------------------------------
    pixel_fig = plt.figure(figsize=(5,4))
    heat = plt.imshow(_unit_scale(pixel_totals), interpolation='nearest', cmap=white_to_blue)
    plt.title(f"MIME per-pixel MI (class {explained_class}) [{H}x{W}]")
    plt.colorbar(heat, label="normalized MI")
    plt.xticks(range(W))
    plt.yticks(range(H))
    path_pixel = os.path.join(outdir, f"{H}x{W}_e_bi_per_pixel_class_{explained_class}.png")
    plt.savefig(path_pixel, dpi=200)
    plt.close(pixel_fig)

    # --- 2) per-channel bar chart ---------------------------------------
    bar_fig = plt.figure(figsize=(max(6, C*0.7), 4))
    positions = np.arange(C)
    plt.bar(positions, channel_totals)
    plt.xlabel("Channel")
    plt.ylabel("MI (sum over spatial)")
    plt.title(f"MIME per-channel MI (class {explained_class}) [{H}x{W}]")
    plt.xticks(positions, channel_names, rotation=45, ha="right")
    plt.tight_layout()
    path_chan = os.path.join(outdir, f"{H}x{W}_e_bi_per_channel_class_{explained_class}.png")
    plt.savefig(path_chan, dpi=200)
    plt.close(bar_fig)

    # --- 3) grid of per-channel heatmaps --------------------------------
    # One shared normalisation so colours are comparable across channels.
    shared_norm = mcolors.Normalize(vmin=float(np.min(mi_map)), vmax=float(np.max(mi_map)))

    n_cols = 4
    n_rows = int(math.ceil(C / n_cols))
    grid_fig, axs = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 3*n_rows))
    axs = np.atleast_1d(axs).reshape(n_rows, n_cols)
    flat_axes = axs.ravel()  # row-major, i.e. index ch <-> divmod(ch, n_cols)

    last_image = None
    for ch in range(C):
        panel = flat_axes[ch]
        last_image = panel.imshow(mi_map[:,:,ch], interpolation='nearest',
                                  norm=shared_norm, cmap=white_to_blue)
        panel.set_title(channel_names[ch], fontsize=9)
        panel.set_xticks(range(W))
        panel.set_yticks(range(H))
    # Hide the unused trailing panels of the grid.
    for spare in flat_axes[C:]:
        spare.axis("off")

    cbar = grid_fig.colorbar(last_image, ax=axs, shrink=0.95, pad=0.02)
    cbar.set_label("Mutual Information (MI)")
    grid_fig.suptitle(f"MIME per-feature MI by channel (class {explained_class}) [{H}x{W}]", y=0.995, fontsize=12)
    path_grid = os.path.join(outdir, f"{H}x{W}_e_bi_per_channel_grid_class_{explained_class}.png")
    plt.savefig(path_grid, dpi=200)
    plt.close(grid_fig)

    print("[MIME] Saved:")
    for saved_path in (path_pixel, path_chan, path_grid):
        print(" -", saved_path)

# =========================
# Main
# =========================
def main():
    """Load the model, explain one sample with MIME, print a summary, save plots."""
    model = tf.keras.models.load_model(MODEL_PATH)

    # The leading (batch) entry of the input shape is discarded.
    _, H, W, C = model.input_shape
    print(f"[INFO] Model expects input shape: (None, {H}, {W}, {C})")

    # Predictor bound to the classification head and the model's input dims.
    predictor = MIMEPredictor(
        model,
        head_tensor=pick_classification_head(model, prefer_name="classification_output"),
        h=H,
        w=W,
        c=C,
    )

    # Sample of the right shape (loaded or synthetic), clipped to the valid range.
    sample = load_sample(SAMPLE_NPY, expected_shape=(H, W, C)).astype(np.float32)
    sample = np.clip(sample, CLIP_MIN, CLIP_MAX)

    # One prediction on the clean sample; its result is not used afterwards.
    _ = predictor.probs(sample)

    # NOTE(review): the explained class is hard-coded here; the CLASS_TO_EXPLAIN
    # config constant is not consulted.
    target_class = 0

    mi_map, explained_class = mime_local_importance(
        predictor,
        sample,
        class_to_explain=target_class,
        k=K,
        perturb_std_frac=PERTURB_STD_FRAC,
        n_bins=N_BINS,
        clip_min=CLIP_MIN,
        clip_max=CLIP_MAX,
    )

    print(f"[RESULT] mi_map shape: {mi_map.shape}")
    print(f"[RESULT] explained_class: {explained_class}")
    print(f"[RESULT] total MI sum: {mi_map.sum():.6f}")

    plot_and_save_mime(mi_map, explained_class, outdir=OUTDIR)

if __name__ == "__main__":
    main()

[INFO] Model expects input shape: (None, 3, 1, 12)
[RESULT] mi_map shape: (3, 1, 12)
[RESULT] explained_class: 0
[RESULT] total MI sum: 51.740101
[MIME] Saved:
 - mime_explanations\3x1_e_bi_per_pixel_class_0.png
 - mime_explanations\3x1_e_bi_per_channel_class_0.png
 - mime_explanations\3x1_e_bi_per_channel_grid_class_0.png


In [75]:
import os
import math
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# =========================
# Config
# =========================
# Path of the saved Keras model that main() loads.
MODEL_PATH = os.path.join("models", "veremi_images_1x3_e_i.keras")
# Provide a path to a single sample .npy; shape must match model input (H,W,C).
# If None, a synthetic demo sample is used.
SAMPLE_NPY = None  # e.g., r"samples\example_0001.npy"

# Number of perturbations per feature for MI estimation
K = 80
# Perturbation std as a fraction of the explained sample's own min-max range
# (that is how mime_local_importance derives sigma). Tune as needed.
PERTURB_STD_FRAC = 0.10
# Number of bins to discretize both perturbed feature values and predicted probs for MI
N_BINS = 10
# Set which class to explain: 'pred' for the model's predicted class, or an integer class id
# NOTE(review): main() below hard-codes the explained class (`cexp`) and never
# reads this constant -- confirm which of the two is intended.
CLASS_TO_EXPLAIN = 'pred'

# Value clipping range for the inputs (change if your preprocessing differs)
CLIP_MIN, CLIP_MAX = 0.0, 1.0

# Output directory (created as a side effect when this cell/module is executed)
OUTDIR = "mime_explanations"
os.makedirs(OUTDIR, exist_ok=True)

# Channel names (guarded to fall back if count mismatches):
# plot_and_save_mime uses generic "channel_<i>" labels when the map's channel
# count differs from the length of this list.
CHANNEL_NAMES_DEFAULT = [
    "sendTime", "sender", "posx", "posy",
    "spdx_n", "spdy_n", "aclx", "acly",
    "hedx", "hedy", "hedx_n", "hedy_n"
]

# =========================
# Utilities
# =========================
def pick_classification_head(model, prefer_name=None):
    """
    Select the model output that should be treated as the classification head.

    Selection order:
      1. `model.outputs` itself when it is a single tensor;
      2. the output whose name equals `prefer_name` (if it can be resolved);
      3. the first output whose last dimension is at least 2;
      4. the first output.
    """
    heads = model.outputs
    if isinstance(heads, tf.Tensor):
        return heads

    # Step 2: look the preferred head up by name; any failure is non-fatal.
    if prefer_name is not None and hasattr(model, "output_names"):
        try:
            named_head = heads[list(model.output_names).index(prefer_name)]
        except Exception:
            named_head = None
        if named_head is not None:
            return named_head

    # Step 3: a last dimension >= 2 looks like a class axis.
    for candidate in heads:
        try:
            width = int(candidate.shape[-1])
        except Exception:
            # Unknown or missing static shape: skip this head.
            continue
        if width >= 2:
            return candidate

    # Step 4: nothing matched, fall back to the first output.
    return heads[0]

def ensure_softmax_probs(logits_or_probs):
    """
    Return class probabilities for a model output, normalising only if needed.

    The output shape always equals the input shape.

    * Multi-column input (last axis > 1): returned unchanged when every entry
      is non-negative and every row sums to 1 (within 1e-3); otherwise a
      numerically stable softmax is applied along the last axis.
    * Single-column input (last axis == 1, e.g. a sigmoid head): returned
      unchanged when all values lie in [0, 1]; otherwise the values are treated
      as logits and passed through a sigmoid. A softmax over a single column
      would always give 1.0 and erase the model's signal.

    NOTE: the "already probabilities?" decision is a heuristic made over the
    whole batch; logits that happen to look like probabilities are not detected.
    """
    arr = np.asarray(logits_or_probs)

    if arr.ndim >= 1 and arr.shape[-1] == 1:
        if np.all((arr >= -1e-6) & (arr <= 1.0 + 1e-6)):
            return arr
        # tanh form of the sigmoid: no overflow for large |logit|.
        return 0.5 * (1.0 + np.tanh(0.5 * arr))

    row_sums = arr.sum(axis=-1, keepdims=True)
    if np.all(arr >= -1e-6) and np.all(np.abs(row_sums - 1.0) < 1e-3):
        return arr

    # Subtract the row max before exponentiating to avoid overflow.
    exps = np.exp(arr - np.max(arr, axis=-1, keepdims=True))
    return exps / np.clip(exps.sum(axis=-1, keepdims=True), 1e-12, None)

def discretize(values, n_bins):
    """
    Assign each value to one of `n_bins` equal-width bins over [min, max].

    Returns a tuple (bins, vmin, vmax) where `bins` is a flat integer array of
    indices in [0, n_bins - 1]. When all values are equal, every index is 0.
    """
    flat = np.asarray(values).ravel()
    lo = np.min(flat)
    hi = np.max(flat)

    if hi != lo:
        # Only the left edges are needed; the maximum lands in the last bin
        # because indices are clipped to n_bins - 1.
        left_edges = np.linspace(lo, hi, n_bins + 1)[:-1]
        raw_idx = np.digitize(flat, left_edges, right=False) - 1
        return np.clip(raw_idx, 0, n_bins - 1), lo, hi

    # Degenerate range: a single bin holds everything.
    return np.zeros_like(flat, dtype=int), lo, hi

def mutual_information_discrete(x_bins, y_bins, n_x_bins, n_y_bins, eps=1e-12):
    """
    Empirical mutual information MI(X;Y) in nats from paired bin indices.

    x_bins, y_bins: integer arrays of equal length with values in
        [0..n_x_bins-1] and [0..n_y_bins-1] respectively.
    n_x_bins, n_y_bins: size of the joint histogram.
    eps: kept for backward compatibility; no longer used. Adding it to the
        denominator biased every term downward and could make the estimate
        slightly negative (e.g. -3e-12 for independent variables).

    Returns a non-negative float. Empty input yields 0.0.
    Raises ValueError if the two index arrays differ in shape.
    """
    x_idx = np.asarray(x_bins).ravel()
    y_idx = np.asarray(y_bins).ravel()
    if x_idx.shape != y_idx.shape:
        raise ValueError(
            f"x_bins and y_bins must have the same length, got {x_idx.shape} and {y_idx.shape}"
        )

    n_samples = x_idx.size
    if n_samples == 0:
        return 0.0

    # Joint histogram; np.add.at accumulates correctly for repeated index pairs.
    joint = np.zeros((n_x_bins, n_y_bins), dtype=float)
    np.add.at(joint, (x_idx, y_idx), 1.0)
    joint /= n_samples

    px = joint.sum(axis=1, keepdims=True)
    py = joint.sum(axis=0, keepdims=True)

    # Only occupied cells contribute (0 * log 0 := 0); there the product of
    # marginals is strictly positive, so no epsilon is required.
    occupied = joint > 0
    p_joint = joint[occupied]
    p_indep = (px @ py)[occupied]
    mi = float(np.sum(p_joint * np.log(p_joint / p_indep)))

    # MI is non-negative by definition; clamp away round-off (and -0.0).
    return mi if mi > 0.0 else 0.0

def load_sample(sample_path, expected_shape):
    """
    Return one sample as a float32 batch of shape (1, H, W, C).

    sample_path: path to a .npy file holding an (H, W, C) array, or None to
        get a deterministic synthetic sample drawn uniformly from [0, 1)
        (RandomState seed 0).
    expected_shape: the (H, W, C) the model expects.

    Raises FileNotFoundError if a path is given but no such file exists
    (previously this silently fell back to random data, so a typo in the path
    produced an explanation of noise), and ValueError if the loaded array has
    the wrong shape.
    """
    H, W, C = expected_shape

    if sample_path is None:
        # Synthetic demo sample in [0,1]
        rng = np.random.RandomState(0)
        return rng.rand(H, W, C).astype(np.float32)[None, ...]

    if not os.path.isfile(sample_path):
        raise FileNotFoundError(f"Sample file not found: {sample_path}")

    x = np.load(sample_path)
    if x.shape != (H, W, C):
        raise ValueError(f"Expected sample shape {(H,W,C)}, got {x.shape}")
    return x.astype(np.float32)[None, ...]

class MIMEPredictor:
    """
    Wraps a Keras model so that batches of NumPy inputs can be turned into
    class probabilities of one chosen output head.

    A sub-model ending at `head_tensor` is built once, and its forward pass is
    compiled with a fixed (None, h, w, c) float32 input signature.
    """

    def __init__(self, model, head_tensor, h, w, c):
        # Sub-model sharing the original inputs but emitting only the chosen head.
        if isinstance(model.inputs, (list, tuple)):
            submodel_inputs = list(model.inputs)
        else:
            submodel_inputs = [model.input]
        self.head_model = tf.keras.Model(inputs=submodel_inputs, outputs=head_tensor)

        # Input metadata, kept so the call below can match the model's structure.
        self._single_input = len(self.head_model.inputs) == 1
        self._input_names = getattr(self.head_model, "input_names", None)

        # One fixed signature (variable batch size) to avoid retracing.
        self._sig = tf.TensorSpec(shape=[None, h, w, c], dtype=tf.float32)

        def _predict_tf(x):
            if not self._single_input:
                # Multi-input models would need an explicit tensor mapping here.
                raise ValueError("This predictor currently supports single-input models only.")
            # The tensor is wrapped in a list to mirror the model's input structure.
            return self.head_model([x], training=False)

        self._predict_tf = tf.function(
            _predict_tf,
            reduce_retracing=True,
            input_signature=[self._sig],
        )

    def probs(self, xb_np):
        """Run the head on a NumPy batch and return probabilities as a NumPy array."""
        batch = tf.convert_to_tensor(xb_np, dtype=tf.float32)
        raw = self._predict_tf(batch).numpy()
        return ensure_softmax_probs(raw)

# =========================
# MIME (local MI) explanation
# =========================
def mime_local_importance(predictor, x1, class_to_explain='pred',
                          k=80, perturb_std_frac=0.10, n_bins=10,
                          clip_min=0.0, clip_max=1.0):
    """
    Estimate, for every input feature of one sample, the mutual information
    between Gaussian perturbations of that feature and the predicted
    probability of the explained class.

    predictor: object with a `probs(batch)` method returning (n, num_classes).
    x1: the sample as a batch of one, shape (1, H, W, C).
    class_to_explain: 'pred' for the arg-max class of `x1`, or a class id.
    k: number of perturbed copies per feature.
    perturb_std_frac: noise std as a fraction of the sample's min-max range.
    n_bins: bins used for both the perturbed values and the probabilities.
    clip_min, clip_max: range the perturbed values are clipped to.

    Returns (mi_map, explained_class) with mi_map of shape (H, W, C), float32.
    Raises ValueError if an explicit class id is out of range.
    Noise comes from the global NumPy RNG, so results vary between runs unless
    the caller seeds it.
    """
    reference_probs = predictor.probs(x1)  # (1, num_classes)
    n_classes = reference_probs.shape[-1]
    if class_to_explain == 'pred':
        explained_class = int(np.argmax(reference_probs[0]))
    else:
        explained_class = int(class_to_explain)
        if not (0 <= explained_class < n_classes):
            raise ValueError(f"class_to_explain out of range: {explained_class}")

    H, W, C = x1.shape[1:]
    mi_map = np.zeros((H, W, C), dtype=np.float32)

    # Noise scale is relative to this sample's own value range
    # (a constant sample is treated as having range 1).
    lo, hi = float(np.min(x1)), float(np.max(x1))
    if hi == lo:
        hi = lo + 1.0
    sigma = perturb_std_frac * max(hi - lo, 1e-6)

    # Reusable (k, H, W, C) batch so the predictor always sees the same shape.
    batch = np.repeat(x1.astype(np.float32), repeats=k, axis=0)

    total_feats = H * W * C
    # np.ndindex walks (i, j, ch) in the same row-major order as nested loops.
    for done, (i, j, ch) in enumerate(np.ndindex(H, W, C), start=1):
        center = float(x1[0, i, j, ch])
        jitter = np.random.normal(loc=0.0, scale=sigma, size=(k,))
        pert_vals = np.clip(center + jitter, clip_min, clip_max).astype(np.float32)

        # Reset to the clean sample, then perturb only the current feature.
        batch[:] = x1
        batch[:, i, j, ch] = pert_vals

        class_probs = predictor.probs(batch)[:, explained_class]  # (k,)

        x_bins = discretize(pert_vals, n_bins)[0]
        y_bins = discretize(class_probs, n_bins)[0]
        mi_map[i, j, ch] = mutual_information_discrete(x_bins, y_bins, n_bins, n_bins)

        if done % 50 == 0:
            print(f"[MIME] Processed {done}/{total_feats} features...")

    return mi_map, explained_class

# =========================
# Plotting helpers
# =========================
def plot_and_save_mime(mi_map, explained_class, outdir=OUTDIR):
    """
    Write three PNG views of an MI map into `outdir` and print their paths:
    a normalized per-pixel heatmap, a per-channel bar chart, and a grid with
    one heatmap per channel sharing a common colour scale.

    mi_map: array of shape (H, W, C), one MI value per input feature.
    explained_class: class id; used only in titles and file names.
    outdir: target directory (this function does not create it).
    """
    import matplotlib.colors as mcolors
    from matplotlib.colors import LinearSegmentedColormap

    H, W, C = mi_map.shape

    # Aggregated views: MI summed over channels, and over spatial positions.
    pixel_totals = mi_map.sum(axis=2)
    channel_totals = mi_map.sum(axis=(0, 1))

    def _unit_scale(values):
        """Min-max scale to [0, 1]; a constant array maps to all zeros."""
        values = np.asarray(values, dtype=float)
        low, high = np.min(values), np.max(values)
        if not (high > low):
            return np.zeros_like(values)
        return (values - low) / (high - low + 1e-12)

    white_to_blue = LinearSegmentedColormap.from_list("wlb", ["white", "lightblue", "blue"])

    # Use the descriptive names only when they match the channel count.
    channel_names = (
        CHANNEL_NAMES_DEFAULT
        if len(CHANNEL_NAMES_DEFAULT) == C
        else [f"channel_{i}" for i in range(C)]
    )

    # --- 1) per-pixel heatmap -------------------------------------------
    pixel_fig = plt.figure(figsize=(5,4))
    heat = plt.imshow(_unit_scale(pixel_totals), interpolation='nearest', cmap=white_to_blue)
    plt.title(f"MIME per-pixel MI (class {explained_class}) [{H}x{W}]")
    plt.colorbar(heat, label="normalized MI")
    plt.xticks(range(W))
    plt.yticks(range(H))
    path_pixel = os.path.join(outdir, f"{H}x{W}_e_bi_per_pixel_class_{explained_class}.png")
    plt.savefig(path_pixel, dpi=200)
    plt.close(pixel_fig)

    # --- 2) per-channel bar chart ---------------------------------------
    bar_fig = plt.figure(figsize=(max(6, C*0.7), 4))
    positions = np.arange(C)
    plt.bar(positions, channel_totals)
    plt.xlabel("Channel")
    plt.ylabel("MI (sum over spatial)")
    plt.title(f"MIME per-channel MI (class {explained_class}) [{H}x{W}]")
    plt.xticks(positions, channel_names, rotation=45, ha="right")
    plt.tight_layout()
    path_chan = os.path.join(outdir, f"{H}x{W}_e_bi_per_channel_class_{explained_class}.png")
    plt.savefig(path_chan, dpi=200)
    plt.close(bar_fig)

    # --- 3) grid of per-channel heatmaps --------------------------------
    # One shared normalisation so colours are comparable across channels.
    shared_norm = mcolors.Normalize(vmin=float(np.min(mi_map)), vmax=float(np.max(mi_map)))

    n_cols = 4
    n_rows = int(math.ceil(C / n_cols))
    grid_fig, axs = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 3*n_rows))
    axs = np.atleast_1d(axs).reshape(n_rows, n_cols)
    flat_axes = axs.ravel()  # row-major, i.e. index ch <-> divmod(ch, n_cols)

    last_image = None
    for ch in range(C):
        panel = flat_axes[ch]
        last_image = panel.imshow(mi_map[:,:,ch], interpolation='nearest',
                                  norm=shared_norm, cmap=white_to_blue)
        panel.set_title(channel_names[ch], fontsize=9)
        panel.set_xticks(range(W))
        panel.set_yticks(range(H))
    # Hide the unused trailing panels of the grid.
    for spare in flat_axes[C:]:
        spare.axis("off")

    cbar = grid_fig.colorbar(last_image, ax=axs, shrink=0.95, pad=0.02)
    cbar.set_label("Mutual Information (MI)")
    grid_fig.suptitle(f"MIME per-feature MI by channel (class {explained_class}) [{H}x{W}]", y=0.995, fontsize=12)
    path_grid = os.path.join(outdir, f"{H}x{W}_e_bi_per_channel_grid_class_{explained_class}.png")
    plt.savefig(path_grid, dpi=200)
    plt.close(grid_fig)

    print("[MIME] Saved:")
    for saved_path in (path_pixel, path_chan, path_grid):
        print(" -", saved_path)

# =========================
# Main
# =========================
def main():
    """Load the model, explain one sample with MIME, print a summary, save plots."""
    model = tf.keras.models.load_model(MODEL_PATH)

    # The leading (batch) entry of the input shape is discarded.
    _, H, W, C = model.input_shape
    print(f"[INFO] Model expects input shape: (None, {H}, {W}, {C})")

    # Predictor bound to the classification head and the model's input dims.
    predictor = MIMEPredictor(
        model,
        head_tensor=pick_classification_head(model, prefer_name="classification_output"),
        h=H,
        w=W,
        c=C,
    )

    # Sample of the right shape (loaded or synthetic), clipped to the valid range.
    sample = load_sample(SAMPLE_NPY, expected_shape=(H, W, C)).astype(np.float32)
    sample = np.clip(sample, CLIP_MIN, CLIP_MAX)

    # One prediction on the clean sample; its result is not used afterwards.
    _ = predictor.probs(sample)

    # NOTE(review): the explained class is hard-coded here; the CLASS_TO_EXPLAIN
    # config constant is not consulted.
    target_class = 1

    mi_map, explained_class = mime_local_importance(
        predictor,
        sample,
        class_to_explain=target_class,
        k=K,
        perturb_std_frac=PERTURB_STD_FRAC,
        n_bins=N_BINS,
        clip_min=CLIP_MIN,
        clip_max=CLIP_MAX,
    )

    print(f"[RESULT] mi_map shape: {mi_map.shape}")
    print(f"[RESULT] explained_class: {explained_class}")
    print(f"[RESULT] total MI sum: {mi_map.sum():.6f}")

    plot_and_save_mime(mi_map, explained_class, outdir=OUTDIR)

if __name__ == "__main__":
    main()

[INFO] Model expects input shape: (None, 3, 1, 12)
[RESULT] mi_map shape: (3, 1, 12)
[RESULT] explained_class: 1
[RESULT] total MI sum: -0.000000
[MIME] Saved:
 - mime_explanations\3x1_e_bi_per_pixel_class_1.png
 - mime_explanations\3x1_e_bi_per_channel_class_1.png
 - mime_explanations\3x1_e_bi_per_channel_grid_class_1.png


In [31]:
import os
import math
import glob
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# =========================
# Config
# =========================
# Path of the saved Keras model.
MODEL_PATH = os.path.join("models", "veremi_images_m_7x7_ea_i.keras")

# Optional path to a single sample .npy of shape (H, W, C); None means "not given".
SAMPLE_NPY = None          # e.g., r"samples\example_0001.npy"
# Directory of .npy images -- presumably the `fallback_dir` handed to
# find_sample(); TODO confirm against the call site.
FAC_IMAGE_DIR = "veremi_multilevel_images_7x7_ea"

# MIME settings. The values equal the defaults of mime_local_importance's
# k / perturb_std_frac / n_bins parameters (perturbations per feature, noise
# std as a fraction of the sample's min-max range, histogram bins).
K = 80
PERTURB_STD_FRAC = 0.10
N_BINS = 10
# 'pred' selects the model's arg-max class in mime_local_importance; an integer
# selects that class id.
CLASS_TO_EXPLAIN = 'pred'

# Range that perturbed input values are clipped to.
CLIP_MIN, CLIP_MAX = 0.0, 1.0

# Output directory for the PNGs (created as a side effect when this cell runs).
OUTDIR = "mime_explanations"
os.makedirs(OUTDIR, exist_ok=True)

# Channel labels for the plots; plot_and_save_mime falls back to generic
# "channel_<i>" names when the map's channel count differs from this list.
CHANNEL_NAMES_DEFAULT = [
    "sendTime", "sender", "posx", "posy",
    "spdx_n", "spdy_n", "aclx", "acly",
    "hedx", "hedy", "hedx_n", "hedy_n"
]

# =========================
# Utilities
# =========================
def pick_classification_head(model, prefer_name=None):
    """
    Select the model output that should be treated as the classification head.

    Selection order:
      1. `model.outputs` itself when it is a single tensor;
      2. the output whose name equals `prefer_name` (if it can be resolved);
      3. the first output whose last dimension is at least 2;
      4. the first output.
    """
    heads = model.outputs
    if isinstance(heads, tf.Tensor):
        return heads

    # Step 2: look the preferred head up by name; any failure is non-fatal.
    if prefer_name is not None and hasattr(model, "output_names"):
        try:
            named_head = heads[list(model.output_names).index(prefer_name)]
        except Exception:
            named_head = None
        if named_head is not None:
            return named_head

    # Step 3: a last dimension >= 2 looks like a class axis.
    for candidate in heads:
        try:
            width = int(candidate.shape[-1])
        except Exception:
            # Unknown or missing static shape: skip this head.
            continue
        if width >= 2:
            return candidate

    # Step 4: nothing matched, fall back to the first output.
    return heads[0]

def ensure_softmax_probs(logits_or_probs):
    """
    Return class probabilities for a model output, normalising only if needed.

    The output shape always equals the input shape.

    * Multi-column input (last axis > 1): returned unchanged when every entry
      is non-negative and every row sums to 1 (within 1e-3); otherwise a
      numerically stable softmax is applied along the last axis.
    * Single-column input (last axis == 1, e.g. a sigmoid head): returned
      unchanged when all values lie in [0, 1]; otherwise the values are treated
      as logits and passed through a sigmoid. A softmax over a single column
      would always give 1.0 and erase the model's signal.

    NOTE: the "already probabilities?" decision is a heuristic made over the
    whole batch; logits that happen to look like probabilities are not detected.
    """
    arr = np.asarray(logits_or_probs)

    if arr.ndim >= 1 and arr.shape[-1] == 1:
        if np.all((arr >= -1e-6) & (arr <= 1.0 + 1e-6)):
            return arr
        # tanh form of the sigmoid: no overflow for large |logit|.
        return 0.5 * (1.0 + np.tanh(0.5 * arr))

    row_sums = arr.sum(axis=-1, keepdims=True)
    if np.all(arr >= -1e-6) and np.all(np.abs(row_sums - 1.0) < 1e-3):
        return arr

    # Subtract the row max before exponentiating to avoid overflow.
    exps = np.exp(arr - np.max(arr, axis=-1, keepdims=True))
    return exps / np.clip(exps.sum(axis=-1, keepdims=True), 1e-12, None)

def discretize(values, n_bins):
    """
    Assign each value to one of `n_bins` equal-width bins over [min, max].

    Returns a tuple (bins, vmin, vmax) where `bins` is a flat integer array of
    indices in [0, n_bins - 1]. When all values are equal, every index is 0.
    """
    flat = np.asarray(values).ravel()
    lo = np.min(flat)
    hi = np.max(flat)

    if hi != lo:
        # Only the left edges are needed; the maximum lands in the last bin
        # because indices are clipped to n_bins - 1.
        left_edges = np.linspace(lo, hi, n_bins + 1)[:-1]
        raw_idx = np.digitize(flat, left_edges, right=False) - 1
        return np.clip(raw_idx, 0, n_bins - 1), lo, hi

    # Degenerate range: a single bin holds everything.
    return np.zeros_like(flat, dtype=int), lo, hi

def mutual_information_discrete(x_bins, y_bins, n_x_bins, n_y_bins, eps=1e-12):
    """
    Empirical mutual information MI(X;Y) in nats from paired bin indices.

    x_bins, y_bins: integer arrays of equal length with values in
        [0..n_x_bins-1] and [0..n_y_bins-1] respectively.
    n_x_bins, n_y_bins: size of the joint histogram.
    eps: kept for backward compatibility; no longer used. Adding it to the
        denominator biased every term downward and could make the estimate
        slightly negative (e.g. -3e-12 for independent variables).

    Returns a non-negative float. Empty input yields 0.0.
    Raises ValueError if the two index arrays differ in shape (previously the
    longer one was silently truncated while the normaliser used the full
    length of x_bins).
    """
    x_idx = np.asarray(x_bins).ravel()
    y_idx = np.asarray(y_bins).ravel()
    if x_idx.shape != y_idx.shape:
        raise ValueError(
            f"x_bins and y_bins must have the same length, got {x_idx.shape} and {y_idx.shape}"
        )

    n_samples = x_idx.size
    if n_samples == 0:
        return 0.0

    # Joint histogram; np.add.at accumulates correctly for repeated index pairs.
    joint = np.zeros((n_x_bins, n_y_bins), dtype=float)
    np.add.at(joint, (x_idx, y_idx), 1.0)
    joint /= n_samples

    px = joint.sum(axis=1, keepdims=True)
    py = joint.sum(axis=0, keepdims=True)

    # Only occupied cells contribute (0 * log 0 := 0); there the product of
    # marginals is strictly positive, so no epsilon is required.
    occupied = joint > 0
    p_joint = joint[occupied]
    p_indep = (px @ py)[occupied]
    mi = float(np.sum(p_joint * np.log(p_joint / p_indep)))

    # MI is non-negative by definition; clamp away round-off (and -0.0).
    return mi if mi > 0.0 else 0.0

def find_sample(sample_path, fallback_dir, expected_shape):
    """
    Return one float32 sample as a batch of shape (1, H, W, C).

    Sources are tried in order:
      1. `sample_path`, if it names an existing file (a wrong shape raises
         ValueError);
      2. the first file in `fallback_dir` matching "image_*.npy" (or, if none
         match, "*.npy") whose array has the expected shape; files that fail
         to load are skipped;
      3. a deterministic synthetic sample from RandomState(0).rand.
    """
    H, W, C = expected_shape
    wanted = (H, W, C)

    def _as_batch(arr):
        # float32 with a leading batch axis of size 1
        return arr.astype(np.float32)[None, ...]

    # 1) explicit sample file
    if sample_path and os.path.isfile(sample_path):
        explicit = np.load(sample_path)
        if explicit.shape != wanted:
            raise ValueError(f"Expected {expected_shape}, got {explicit.shape}")
        return _as_batch(explicit)

    # 2) first shape-compatible file from the fallback directory
    if fallback_dir and os.path.isdir(fallback_dir):
        hits = sorted(glob.glob(os.path.join(fallback_dir, "image_*.npy"))) \
            or sorted(glob.glob(os.path.join(fallback_dir, "*.npy")))
        for fp in hits:
            try:
                candidate = np.load(fp)
                if candidate.shape == wanted:
                    return _as_batch(candidate)
            except Exception:
                # Unreadable file: try the next candidate.
                pass

    # 3) synthetic fallback
    return _as_batch(np.random.RandomState(0).rand(H, W, C))

class MIMEPredictor:
    """
    Wraps a Keras model so that batches of NumPy inputs can be turned into
    class probabilities of one chosen output head.

    A sub-model ending at `head_tensor` is built once, and its forward pass is
    compiled with a fixed (None, h, w, c) float32 input signature.
    """

    def __init__(self, model, head_tensor, h, w, c):
        # Sub-model sharing the original inputs but emitting only the chosen head.
        if isinstance(model.inputs, (list, tuple)):
            submodel_inputs = list(model.inputs)
        else:
            submodel_inputs = [model.input]
        self.head_model = tf.keras.Model(inputs=submodel_inputs, outputs=head_tensor)

        self._single_input = len(self.head_model.inputs) == 1
        self._input_names = getattr(self.head_model, "input_names", None)
        # One fixed signature (variable batch size) to avoid retracing.
        self._sig = tf.TensorSpec(shape=[None, h, w, c], dtype=tf.float32)

        def _predict_tf(x):
            if not self._single_input:
                raise ValueError("Only single-input models are supported.")
            # Feed as a list to match the recorded input structure.
            return self.head_model([x], training=False)

        self._predict_tf = tf.function(
            _predict_tf,
            reduce_retracing=True,
            input_signature=[self._sig],
        )

    def probs(self, xb_np):
        """Run the head on a NumPy batch and return probabilities as a NumPy array."""
        batch = tf.convert_to_tensor(xb_np, dtype=tf.float32)
        raw = self._predict_tf(batch).numpy()
        return ensure_softmax_probs(raw)

# =========================
# MIME explanation
# =========================
def mime_local_importance(predictor, x1, class_to_explain='pred',
                          k=80, perturb_std_frac=0.10, n_bins=10,
                          clip_min=0.0, clip_max=1.0):
    """
    Estimate, for every input feature of one sample, the mutual information
    between Gaussian perturbations of that feature and the predicted
    probability of the explained class.

    predictor: object with a `probs(batch)` method returning (n, num_classes).
    x1: the sample as a batch of one, shape (1, H, W, C).
    class_to_explain: 'pred' for the arg-max class of `x1`, or a class id.
    k: number of perturbed copies per feature.
    perturb_std_frac: noise std as a fraction of the sample's min-max range.
    n_bins: bins used for both the perturbed values and the probabilities.
    clip_min, clip_max: range the perturbed values are clipped to.

    Returns (mi_map, explained_class) with mi_map of shape (H, W, C), float32.
    Raises ValueError if an explicit class id is outside [0, num_classes).
    Noise comes from the global NumPy RNG, so results vary between runs unless
    the caller seeds it.
    """
    base_probs = predictor.probs(x1)
    num_classes = base_probs.shape[-1]
    if class_to_explain == 'pred':
        explained_class = int(np.argmax(base_probs[0]))
    else:
        explained_class = int(class_to_explain)
        # Fail fast: a negative id would otherwise silently index from the end
        # (explaining the wrong class) and a too-large id would only fail
        # inside the loop.
        if explained_class < 0 or explained_class >= num_classes:
            raise ValueError(f"class_to_explain out of range: {explained_class}")

    H, W, C = x1.shape[1:]
    mi_map = np.zeros((H, W, C), dtype=np.float32)

    # Noise scale is relative to this sample's own value range
    # (a constant sample is treated as having range 1).
    vmin, vmax = float(np.min(x1)), float(np.max(x1))
    if vmax == vmin:
        vmax = vmin + 1.0
    value_range = max(vmax - vmin, 1e-6)
    sigma = perturb_std_frac * value_range

    # Reusable (k, H, W, C) batch so the predictor always sees the same shape.
    xb = np.repeat(x1.astype(np.float32), repeats=k, axis=0)

    total_feats = H * W * C
    idx = 0
    for i in range(H):
        for j in range(W):
            for ch in range(C):
                base_val = float(x1[0, i, j, ch])
                noise = np.random.normal(loc=0.0, scale=sigma, size=(k,))
                pert_vals = np.clip(base_val + noise, clip_min, clip_max).astype(np.float32)
                # Reset to the clean sample, then perturb only this feature.
                xb[:] = x1
                xb[:, i, j, ch] = pert_vals
                probs = predictor.probs(xb)
                y = probs[:, explained_class]
                x_bins, _, _ = discretize(pert_vals, n_bins)
                y_bins, _, _ = discretize(y, n_bins)
                mi = mutual_information_discrete(x_bins, y_bins, n_bins, n_bins)
                mi_map[i, j, ch] = mi
                idx += 1
                if idx % 100 == 0:
                    print(f"[MIME] Processed {idx}/{total_feats} features...")
    return mi_map, explained_class

# =========================
# Plotting
# =========================
def _normalize(a):
    a = np.asarray(a, dtype=float)
    m, M = np.min(a), np.max(a)
    if M > m:
        return (a - m) / (M - m + 1e-12)
    return np.zeros_like(a)

def plot_and_save_mime(mi_map, explained_class, outdir=OUTDIR, tag=""):
    """
    Render an MI map as three PNG figures and save them in ``outdir``.

    Figures written:
      * per-pixel heatmap: MI summed over channels, then min-max normalized;
      * per-channel bar chart: MI summed over both spatial axes;
      * grid of per-channel heatmaps that share one color scale.

    Args:
        mi_map: Array of shape (H, W, C) with one MI value per feature.
        explained_class: Class index shown in the titles and file names.
        outdir: Destination directory; it is created if it does not exist.
        tag: Optional suffix inserted before ``.png`` in every file name.
            The default empty string leaves the file names untagged.
    """
    import matplotlib.colors as mcolors
    from matplotlib.colors import LinearSegmentedColormap

    H, W, C = mi_map.shape
    channel_names = CHANNEL_NAMES_DEFAULT if len(CHANNEL_NAMES_DEFAULT) == C else [f"channel_{i}" for i in range(C)]

    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs(outdir, exist_ok=True)

    def _png_path(kind):
        # `tag` goes right before the extension; with tag="" the name is unchanged.
        return os.path.join(outdir, f"{H}x{W}_ea_mi_{kind}_class_{explained_class}{tag}.png")

    per_pixel = mi_map.sum(axis=2)
    per_channel = mi_map.sum(axis=(0, 1))

    cmap = LinearSegmentedColormap.from_list("wlb", ["white", "lightblue", "blue"])

    # Per-pixel heatmap
    fig = plt.figure(figsize=(5.5, 4.5))
    im = plt.imshow(_normalize(per_pixel), interpolation='nearest', cmap=cmap)
    plt.title(f"MIME per-pixel MI (class {explained_class}) [{H}x{W}]")
    plt.colorbar(im, label="normalized MI")
    plt.xticks(range(W))
    plt.yticks(range(H))
    plt.savefig(_png_path("per_pixel"), dpi=200)
    plt.close(fig)

    # Per-channel bar
    fig = plt.figure(figsize=(max(8, C*0.7), 4.5))
    xs = np.arange(C)
    plt.bar(xs, per_channel)
    plt.xlabel("Channel")
    plt.ylabel("MI (sum over spatial)")
    plt.title(f"MIME per-channel MI (class {explained_class}) [{H}x{W}]")
    plt.xticks(xs, channel_names, rotation=45, ha="right")
    plt.savefig(_png_path("per_channel"), dpi=200)
    plt.close(fig)

    # Grid of channel heatmaps; one Normalize instance keeps panels comparable.
    vmin = float(np.min(mi_map))
    vmax = float(np.max(mi_map))
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    n_cols = 4
    n_rows = int(math.ceil(C / n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 3*n_rows))
    axs = np.atleast_1d(axs).reshape(n_rows, n_cols)
    im_last = None
    for ch in range(C):
        r, c = divmod(ch, n_cols)
        ax = axs[r, c]
        im_last = ax.imshow(mi_map[:, :, ch], interpolation='nearest', norm=norm, cmap=cmap)
        ax.set_title(channel_names[ch], fontsize=9)
        ax.set_xticks(range(W))
        ax.set_yticks(range(H))
    # Hide the grid cells that have no channel to show.
    for idx in range(C, n_rows*n_cols):
        r, c = divmod(idx, n_cols)
        axs[r, c].axis("off")
    cbar = fig.colorbar(im_last, ax=axs, shrink=0.95, pad=0.02)
    cbar.set_label("Mutual Information (MI)")
    fig.suptitle(f"MIME per-feature MI by channel (class {explained_class}) [{H}x{W}]", y=0.995, fontsize=12)
    plt.savefig(_png_path("per_channel_grid"), dpi=200)
    plt.close(fig)

    print("[MIME] Saved 3 PNG visualizations.")

# =========================
# Main
# =========================
def main(class_to_explain=0):
    """
    Load the model, explain one sample with MIME, and save the plots.

    Args:
        class_to_explain: Integer class index to explain, or ``'pred'`` to
            explain the class with the highest predicted probability.
            Defaults to 0, the value this function previously hard-coded.
    """
    model = tf.keras.models.load_model(MODEL_PATH)

    # Infer expected (H, W, C) from model input
    _, H, W, C = model.input_shape
    print(f"[INFO] Model expects input shape: (None, {H}, {W}, {C})")

    head = pick_classification_head(model, prefer_name="classification_output")
    predictor = MIMEPredictor(model, head_tensor=head, h=H, w=W, c=C)

    # Load sample matching the model's expected shape (or synthesize)
    x1 = find_sample(SAMPLE_NPY, FAC_IMAGE_DIR, expected_shape=(H, W, C))
    x1 = np.clip(x1, CLIP_MIN, CLIP_MAX).astype(np.float32)

    base_probs = predictor.probs(x1)
    print(f"[INFO] Predicted probs: {np.round(base_probs[0], 6)}")
    # Resolve 'pred' here so the log line reports the actual class index.
    if class_to_explain == 'pred':
        cexp = int(np.argmax(base_probs[0]))
    else:
        cexp = int(class_to_explain)
    print(f"[INFO] Explaining class: {cexp}")

    mi_map, explained_class = mime_local_importance(
        predictor, x1,
        class_to_explain=cexp,
        k=K,
        perturb_std_frac=PERTURB_STD_FRAC,
        n_bins=N_BINS,
        clip_min=CLIP_MIN,
        clip_max=CLIP_MAX
    )
    print(f"[RESULT] mi_map shape: {mi_map.shape}")
    print(f"[RESULT] total MI sum: {mi_map.sum():.6f}")

    # Tag filenames with directory hint if useful
    tag = "" if not FAC_IMAGE_DIR else f"_{os.path.basename(FAC_IMAGE_DIR)}"
    plot_and_save_mime(mi_map, explained_class, outdir=OUTDIR, tag=tag)
    print(f"[INFO] PNGs saved in {OUTDIR}")

# Run main() only when this code is executed directly rather than imported.
if __name__ == "__main__":
    main()

[INFO] Model expects input shape: (None, 7, 7, 12)
[INFO] Predicted probs: [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[INFO] Explaining class: 0
[MIME] Processed 100/588 features...
[MIME] Processed 200/588 features...
[MIME] Processed 300/588 features...
[MIME] Processed 400/588 features...
[MIME] Processed 500/588 features...
[RESULT] mi_map shape: (7, 7, 12)
[RESULT] total MI sum: -0.000000
[MIME] Saved 3 PNG visualizations.
[INFO] PNGs saved in mime_explanations


In [50]:
import os
import math
import glob
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# =========================
# Config
# =========================
# Path of the saved Keras model that main() loads.
MODEL_PATH = os.path.join("models", "veremi_images_m_7x7_ea_i.keras")

# Optional path to a single sample .npy file; passed to find_sample() first.
SAMPLE_NPY = None          # e.g., r"samples\example_0001.npy"
# Directory find_sample() scans for a matching .npy when SAMPLE_NPY is unusable.
FAC_IMAGE_DIR = "veremi_multilevel_images_7x7_ea"

# Number of perturbations drawn per feature for the MI estimate.
K = 80
# Perturbation std as a fraction of the sample's value range.
PERTURB_STD_FRAC = 0.10
# Number of bins used to discretize perturbed values and predicted probabilities.
N_BINS = 10
# 'pred' selects the predicted class; an integer selects a fixed class id.
# NOTE(review): main() below does not currently read this value.
CLASS_TO_EXPLAIN = 'pred'

# Clipping range applied to the input sample and to the perturbed values.
CLIP_MIN, CLIP_MAX = 0.0, 1.0

# Output directory for the PNGs; created at import/run time if missing.
OUTDIR = "mime_explanations"
os.makedirs(OUTDIR, exist_ok=True)

# Channel labels used in plots when their count matches the map's channel count.
CHANNEL_NAMES_DEFAULT = [
    "sendTime", "sender", "posx", "posy",
    "spdx_n", "spdy_n", "aclx", "acly",
    "hedx", "hedy", "hedx_n", "hedy_n"
]

# =========================
# Utilities
# =========================
def pick_classification_head(model, prefer_name=None):
    """
    Select a single output tensor of ``model`` to use as the classification head.

    Selection order:
      1. ``model.outputs`` itself when it is already a single ``tf.Tensor``;
      2. the output whose name equals ``prefer_name`` (if it can be found);
      3. the first output whose last dimension is at least 2;
      4. the first output.
    """
    heads = model.outputs
    if isinstance(heads, tf.Tensor):
        return heads

    selected = None

    # Named lookup is best-effort: any failure falls through to the scan below.
    if prefer_name is not None and hasattr(model, "output_names"):
        try:
            position = list(model.output_names).index(prefer_name)
            selected = heads[position]
        except Exception:
            selected = None

    if selected is None:
        for candidate in heads:
            try:
                width = int(candidate.shape[-1])
            except Exception:
                # Unknown or unreadable last dimension: try the next output.
                continue
            if width >= 2:
                selected = candidate
                break

    return heads[0] if selected is None else selected

def ensure_softmax_probs(logits_or_probs):
    """
    Return probabilities along the last axis.

    The input is returned unchanged (as an array) when every entry is
    >= -1e-6 and every last-axis sum is within 1e-3 of 1. Otherwise a
    max-shifted softmax over the last axis is applied.
    """
    scores = np.asarray(logits_or_probs)
    row_sums = scores.sum(axis=-1, keepdims=True)

    non_negative = np.all(scores >= -1e-6)
    sums_to_one = np.all(np.abs(row_sums - 1.0) < 1e-3)
    if non_negative and sums_to_one:
        return scores

    # Subtract the per-row maximum before exponentiating to avoid overflow.
    shifted = scores - np.max(scores, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    denom = np.clip(exps.sum(axis=-1, keepdims=True), 1e-12, None)
    return exps / denom

def discretize(values, n_bins):
    """
    Assign each value to one of ``n_bins`` equal-width bins over [min, max].

    Returns ``(bins, vmin, vmax)`` where ``bins`` is a flat integer array with
    entries in ``0 .. n_bins - 1``. If all values are equal, every entry is
    placed in bin 0.
    """
    flat = np.asarray(values).ravel()
    lo, hi = np.min(flat), np.max(flat)
    if lo == hi:
        return np.zeros_like(flat, dtype=int), lo, hi

    # Only the left edges are needed: the maximum then lands in the last bin.
    left_edges = np.linspace(lo, hi, n_bins + 1)[:-1]
    raw = np.digitize(flat, left_edges, right=False) - 1
    return np.clip(raw, 0, n_bins - 1), lo, hi

def mutual_information_discrete(x_bins, y_bins, n_x_bins, n_y_bins, eps=1e-12):
    """
    Compute the mutual information (in nats) of two binned integer series.

    The joint distribution is estimated from co-occurrence counts and the sum
    ``p(x, y) * log(p(x, y) / (p(x) * p(y)))`` is taken over the non-empty
    cells only. Empty cells contribute nothing, so no epsilon smoothing is
    needed and independent or constant inputs give exactly 0 instead of a
    slightly negative value.

    Args:
        x_bins: Integer bin indices in ``0 .. n_x_bins - 1``.
        y_bins: Integer bin indices in ``0 .. n_y_bins - 1``; same length as
            ``x_bins``.
        n_x_bins: Number of bins on the x axis of the joint histogram.
        n_y_bins: Number of bins on the y axis of the joint histogram.
        eps: Unused; kept so existing callers that pass it keep working.

    Returns:
        Non-negative MI estimate as a Python float; 0.0 for empty input.

    Raises:
        ValueError: If ``x_bins`` and ``y_bins`` differ in length.
    """
    x_bins = np.asarray(x_bins).ravel()
    y_bins = np.asarray(y_bins).ravel()
    n_samples = len(x_bins)
    if len(y_bins) != n_samples:
        raise ValueError(
            f"x_bins and y_bins must have the same length, "
            f"got {n_samples} and {len(y_bins)}"
        )
    if n_samples == 0:
        return 0.0

    # np.add.at accumulates repeated (x, y) index pairs, unlike fancy `+=`.
    joint = np.zeros((n_x_bins, n_y_bins), dtype=float)
    np.add.at(joint, (x_bins, y_bins), 1.0)
    joint /= n_samples

    px = joint.sum(axis=1, keepdims=True)
    py = joint.sum(axis=0, keepdims=True)
    independent = px @ py

    # Wherever joint > 0 both marginals are > 0, so the ratio is well defined.
    occupied = joint > 0
    mi = np.sum(joint[occupied] * np.log(joint[occupied] / independent[occupied]))
    # MI is non-negative by definition; clamp away floating-point round-off.
    return float(max(mi, 0.0))

def find_sample(sample_path, fallback_dir, expected_shape):
    """
    Return one float32 sample of shape (1, H, W, C).

    Sources are tried in order:
      1. ``sample_path`` if it is an existing file (a shape mismatch raises
         ``ValueError``);
      2. the first file in ``fallback_dir`` matching ``image_*.npy`` (or any
         ``*.npy`` if none match that pattern) whose shape is (H, W, C);
         files that fail to load are skipped;
      3. a synthetic uniform-random sample drawn with a fixed seed of 0.
    """
    H, W, C = expected_shape
    wanted = (H, W, C)

    # 1) Explicit sample file: a wrong shape is an error, not a skip.
    if sample_path and os.path.isfile(sample_path):
        arr = np.load(sample_path)
        if arr.shape != wanted:
            raise ValueError(f"Expected {expected_shape}, got {arr.shape}")
        return arr.astype(np.float32)[None, ...]

    # 2) Scan the fallback directory in sorted order for a matching file.
    if fallback_dir and os.path.isdir(fallback_dir):
        paths = sorted(glob.glob(os.path.join(fallback_dir, "image_*.npy")))
        if not paths:
            paths = sorted(glob.glob(os.path.join(fallback_dir, "*.npy")))
        for path in paths:
            try:
                arr = np.load(path)
                if arr.shape == wanted:
                    return arr.astype(np.float32)[None, ...]
            except Exception:
                # Best-effort scan: unreadable files are ignored.
                continue

    # 3) Fallback synthetic
    synthetic = np.random.RandomState(0).rand(H, W, C).astype(np.float32)
    return synthetic[None, ...]

class MIMEPredictor:
    """Query one output head of a Keras model with NumPy batches."""

    def __init__(self, model, head_tensor, h, w, c):
        """
        Build a sub-model ending at ``head_tensor`` and a ``tf.function``
        wrapper with a fixed ``[None, h, w, c]`` float32 input signature.
        """
        has_input_list = isinstance(model.inputs, (list, tuple))
        submodel_inputs = list(model.inputs) if has_input_list else [model.input]

        self.head_model = tf.keras.Model(inputs=submodel_inputs, outputs=head_tensor)
        self._single_input = len(self.head_model.inputs) == 1
        self._input_names = getattr(self.head_model, "input_names", None)
        self._sig = tf.TensorSpec(shape=[None, h, w, c], dtype=tf.float32)

        @tf.function(reduce_retracing=True, input_signature=[self._sig])
        def _predict_tf(x):
            if not self._single_input:
                raise ValueError("Only single-input models are supported.")
            # Feed as list to match the recorded input structure (avoids warnings)
            return self.head_model([x], training=False)

        self._predict_tf = _predict_tf

    def probs(self, xb_np):
        """
        Return class probabilities for a NumPy batch.

        The batch is converted to a float32 tensor, run through the wrapped
        head model, and the result is passed to ``ensure_softmax_probs``.
        """
        batch = tf.convert_to_tensor(xb_np, dtype=tf.float32)
        raw_output = self._predict_tf(batch)
        return ensure_softmax_probs(raw_output.numpy())

# =========================
# MIME explanation
# =========================
def mime_local_importance(predictor, x1, class_to_explain='pred',
                          k=80, perturb_std_frac=0.10, n_bins=10,
                          clip_min=0.0, clip_max=1.0):
    """
    Estimate a per-feature mutual-information (MI) map for a single sample.

    Each feature ``x1[0, i, j, ch]`` is perturbed ``k`` times with Gaussian
    noise while every other feature keeps its original value. The perturbed
    values and the resulting probabilities of the explained class are both
    discretized into ``n_bins`` bins, and the MI between the two binned
    series is stored as that feature's importance.

    Args:
        predictor: Object whose ``probs(batch)`` returns an array indexed as
            ``[sample, class]``.
        x1: Array of shape (1, H, W, C) holding the sample to explain.
        class_to_explain: ``'pred'`` to explain the arg-max class of ``x1``,
            otherwise a value convertible to an integer class index.
        k: Number of perturbations drawn per feature.
        perturb_std_frac: Noise standard deviation expressed as a fraction of
            the value range (max - min) found in ``x1``.
        n_bins: Number of bins used for both discretizations.
        clip_min: Lower bound applied to perturbed values.
        clip_max: Upper bound applied to perturbed values.

    Returns:
        Tuple ``(mi_map, explained_class)``: a float32 array of shape
        (H, W, C) and the integer index of the explained class.

    Raises:
        IndexError: If the requested class index does not exist in the
            predictor's output.
    """
    base_probs = predictor.probs(x1)
    num_classes = base_probs.shape[-1]
    if class_to_explain == 'pred':
        explained_class = int(np.argmax(base_probs[0]))
    else:
        explained_class = int(class_to_explain)

    # Fail fast: without this check a bad index only surfaces after the first
    # batch of model evaluations, as an unexplained NumPy IndexError.
    if not -num_classes <= explained_class < num_classes:
        raise IndexError(
            f"class_to_explain={explained_class} is out of range for "
            f"{num_classes} classes"
        )

    H, W, C = x1.shape[1:]
    mi_map = np.zeros((H, W, C), dtype=np.float32)

    # The noise scale is relative to the sample's own value range; a constant
    # sample falls back to a unit range so sigma never collapses to zero.
    vmin, vmax = float(np.min(x1)), float(np.max(x1))
    if vmax == vmin:
        vmax = vmin + 1.0
    value_range = max(vmax - vmin, 1e-6)
    sigma = perturb_std_frac * value_range

    # k copies of the sample; only one feature column differs per iteration.
    xb = np.repeat(x1.astype(np.float32), repeats=k, axis=0)

    total_feats = H * W * C
    # np.ndindex walks (i, j, ch) in the same order as three nested loops.
    for idx, (i, j, ch) in enumerate(np.ndindex(H, W, C), start=1):
        base_val = float(x1[0, i, j, ch])
        noise = np.random.normal(loc=0.0, scale=sigma, size=(k,))
        pert_vals = np.clip(base_val + noise, clip_min, clip_max).astype(np.float32)

        xb[:, i, j, ch] = pert_vals
        probs = predictor.probs(xb)
        # Restore just the touched feature instead of re-copying the batch.
        xb[:, i, j, ch] = x1[0, i, j, ch]

        y = probs[:, explained_class]
        x_bins, _, _ = discretize(pert_vals, n_bins)
        y_bins, _, _ = discretize(y, n_bins)
        mi_map[i, j, ch] = mutual_information_discrete(x_bins, y_bins, n_bins, n_bins)

        if idx % 100 == 0:
            print(f"[MIME] Processed {idx}/{total_feats} features...")
    return mi_map, explained_class

# =========================
# Plotting
# =========================
def _normalize(a):
    a = np.asarray(a, dtype=float)
    m, M = np.min(a), np.max(a)
    if M > m:
        return (a - m) / (M - m + 1e-12)
    return np.zeros_like(a)

def plot_and_save_mime(mi_map, explained_class, outdir=OUTDIR, tag=""):
    """
    Render an MI map as three PNG figures and save them in ``outdir``.

    Figures written:
      * per-pixel heatmap: MI summed over channels, then min-max normalized;
      * per-channel bar chart: MI summed over both spatial axes;
      * grid of per-channel heatmaps that share one color scale.

    Args:
        mi_map: Array of shape (H, W, C) with one MI value per feature.
        explained_class: Class index shown in the titles and file names.
        outdir: Destination directory; it is created if it does not exist.
        tag: Optional suffix inserted before ``.png`` in every file name.
            The default empty string leaves the file names untagged.
    """
    import matplotlib.colors as mcolors
    from matplotlib.colors import LinearSegmentedColormap

    H, W, C = mi_map.shape
    channel_names = CHANNEL_NAMES_DEFAULT if len(CHANNEL_NAMES_DEFAULT) == C else [f"channel_{i}" for i in range(C)]

    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs(outdir, exist_ok=True)

    def _png_path(kind):
        # `tag` goes right before the extension; with tag="" the name is unchanged.
        return os.path.join(outdir, f"{H}x{W}_ea_mi_{kind}_class_{explained_class}{tag}.png")

    per_pixel = mi_map.sum(axis=2)
    per_channel = mi_map.sum(axis=(0, 1))

    cmap = LinearSegmentedColormap.from_list("wlb", ["white", "lightblue", "blue"])

    # Per-pixel heatmap
    fig = plt.figure(figsize=(5.5, 4.5))
    im = plt.imshow(_normalize(per_pixel), interpolation='nearest', cmap=cmap)
    plt.title(f"MIME per-pixel MI (class {explained_class}) [{H}x{W}]")
    plt.colorbar(im, label="normalized MI")
    plt.xticks(range(W))
    plt.yticks(range(H))
    plt.savefig(_png_path("per_pixel"), dpi=200)
    plt.close(fig)

    # Per-channel bar
    fig = plt.figure(figsize=(max(8, C*0.7), 4.5))
    xs = np.arange(C)
    plt.bar(xs, per_channel)
    plt.xlabel("Channel")
    plt.ylabel("MI (sum over spatial)")
    plt.title(f"MIME per-channel MI (class {explained_class}) [{H}x{W}]")
    plt.xticks(xs, channel_names, rotation=45, ha="right")
    plt.savefig(_png_path("per_channel"), dpi=200)
    plt.close(fig)

    # Grid of channel heatmaps; one Normalize instance keeps panels comparable.
    vmin = float(np.min(mi_map))
    vmax = float(np.max(mi_map))
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    n_cols = 4
    n_rows = int(math.ceil(C / n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 3*n_rows))
    axs = np.atleast_1d(axs).reshape(n_rows, n_cols)
    im_last = None
    for ch in range(C):
        r, c = divmod(ch, n_cols)
        ax = axs[r, c]
        im_last = ax.imshow(mi_map[:, :, ch], interpolation='nearest', norm=norm, cmap=cmap)
        ax.set_title(channel_names[ch], fontsize=9)
        ax.set_xticks(range(W))
        ax.set_yticks(range(H))
    # Hide the grid cells that have no channel to show.
    for idx in range(C, n_rows*n_cols):
        r, c = divmod(idx, n_cols)
        axs[r, c].axis("off")
    cbar = fig.colorbar(im_last, ax=axs, shrink=0.95, pad=0.02)
    cbar.set_label("Mutual Information (MI)")
    fig.suptitle(f"MIME per-feature MI by channel (class {explained_class}) [{H}x{W}]", y=0.995, fontsize=12)
    plt.savefig(_png_path("per_channel_grid"), dpi=200)
    plt.close(fig)

    print("[MIME] Saved 3 PNG visualizations.")

def main(class_to_explain=19):
    """
    Load the model, explain one sample with MIME, and save the plots.

    Args:
        class_to_explain: Integer class index to explain, or ``'pred'`` to
            explain the class with the highest predicted probability.
            Defaults to 19, the value this function previously hard-coded.
    """
    model = tf.keras.models.load_model(MODEL_PATH)

    # Infer expected (H, W, C) from model input
    _, H, W, C = model.input_shape
    print(f"[INFO] Model expects input shape: (None, {H}, {W}, {C})")

    head = pick_classification_head(model, prefer_name="classification_output")
    predictor = MIMEPredictor(model, head_tensor=head, h=H, w=W, c=C)

    # Load sample matching the model's expected shape (or synthesize)
    x1 = find_sample(SAMPLE_NPY, FAC_IMAGE_DIR, expected_shape=(H, W, C))
    x1 = np.clip(x1, CLIP_MIN, CLIP_MAX).astype(np.float32)

    base_probs = predictor.probs(x1)
    print(f"[INFO] Predicted probs: {np.round(base_probs[0], 6)}")
    # Resolve 'pred' here so the log line reports the actual class index.
    if class_to_explain == 'pred':
        cexp = int(np.argmax(base_probs[0]))
    else:
        cexp = int(class_to_explain)
    print(f"[INFO] Explaining class: {cexp}")

    mi_map, explained_class = mime_local_importance(
        predictor, x1,
        class_to_explain=cexp,
        k=K,
        perturb_std_frac=PERTURB_STD_FRAC,
        n_bins=N_BINS,
        clip_min=CLIP_MIN,
        clip_max=CLIP_MAX
    )
    print(f"[RESULT] mi_map shape: {mi_map.shape}")
    print(f"[RESULT] total MI sum: {mi_map.sum():.6f}")

    # Tag filenames with directory hint if useful
    tag = "" if not FAC_IMAGE_DIR else f"_{os.path.basename(FAC_IMAGE_DIR)}"
    plot_and_save_mime(mi_map, explained_class, outdir=OUTDIR, tag=tag)
    print(f"[INFO] PNGs saved in {OUTDIR}")

# Run main() only when this code is executed directly rather than imported.
if __name__ == "__main__":
    main()

[INFO] Model expects input shape: (None, 7, 7, 12)
[INFO] Predicted probs: [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[INFO] Explaining class: 19
[MIME] Processed 100/588 features...
[MIME] Processed 200/588 features...
[MIME] Processed 300/588 features...
[MIME] Processed 400/588 features...
[MIME] Processed 500/588 features...
[RESULT] mi_map shape: (7, 7, 12)
[RESULT] total MI sum: -0.000000
[MIME] Saved 3 PNG visualizations.
[INFO] PNGs saved in mime_explanations
