In [None]:
import numpy as np
import cv2
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model

In [None]:
def load_keras_model(path, custom_objects=None):
    """
    Return a Keras model given either a filesystem path or a model object.

    - path: a `str` or any `os.PathLike` (e.g. `pathlib.Path`) pointing at a
      saved model, in which case it is loaded with `load_model`. Any other
      object is assumed to already be a model and is returned unchanged.
    - custom_objects: optional dict forwarded to `load_model` (only used when
      `path` is actually a path).
    """
    import os  # local import: only needed for the path-like check below

    if isinstance(path, (str, os.PathLike)):
        # os.fspath normalizes Path objects to a plain string, so the loader
        # always receives the same argument type regardless of how the caller
        # spelled the path.
        return load_model(os.fspath(path), custom_objects=custom_objects)
    return path

In [None]:
def get_last_conv_layer_name(model):
    """
    Find the Conv2D layer closest to the model output and return its name.

    Only the top-level entries of `model.layers` are inspected, walking from
    the last layer backwards. A ValueError is raised when none of them is a
    `tf.keras.layers.Conv2D` instance.
    """
    not_found = object()
    conv_names = (
        candidate.name
        for candidate in reversed(model.layers)
        if isinstance(candidate, tf.keras.layers.Conv2D)
    )
    # The generator is lazy, so the scan stops at the first match from the end.
    found = next(conv_names, not_found)
    if found is not_found:
        raise ValueError("No Conv2D layer found in the model. Grad-CAM requires at least one conv layer.")
    return found

In [None]:
def make_gradcam_heatmap(img_array, model, last_conv_layer_name=None, pred_index=None):
    """
    Generate a Grad-CAM heatmap for a single image.
    - img_array: array with shape (1, H, W, C) (batched). It is cast to float32
      before the forward pass; any model-specific preprocessing must already
      have been applied by the caller.
    - model: tf.keras.Model
    - last_conv_layer_name: optional name of conv layer to use. If None, it is
      auto-detected with get_last_conv_layer_name.
    - pred_index: optional integer class index. If None, uses the top predicted class.
    Returns a 2-D numpy array normalized to [0, 1]. Its spatial size is that of
    the chosen conv layer's feature map, not of the input image; resize it
    before overlaying.
    Raises RuntimeError if no gradient could be computed.
    """
    if last_conv_layer_name is None:
        last_conv_layer_name = get_last_conv_layer_name(model)

    # Model that gives conv outputs and predictions.
    # `model.inputs` is already a list; wrapping it in another list produces a
    # nested input structure that newer Keras versions report as a mismatch.
    last_conv_layer = model.get_layer(last_conv_layer_name)
    grad_model = tf.keras.models.Model(model.inputs, [last_conv_layer.output, model.output])

    img_tensor = tf.cast(img_array, tf.float32)

    with tf.GradientTape() as tape:
        # Watch the input explicitly so the whole forward pass is recorded even
        # when the layers up to the conv layer hold no trainable variables
        # (e.g. a frozen backbone); otherwise the gradient below can be None.
        tape.watch(img_tensor)
        conv_outputs, predictions = grad_model(img_tensor)
        if pred_index is None:
            pred_index = tf.argmax(predictions[0])
        class_channel = predictions[:, pred_index]

    # Gradients of the target class w.r.t. feature map
    grads = tape.gradient(class_channel, conv_outputs)
    if grads is None:
        raise RuntimeError("GradientTape returned None. Check model and input tensor shapes/preprocessing.")

    # Channel-wise mean of gradients: one importance weight per feature channel
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # Weight each channel of the (single) feature map and sum over channels
    feature_map = conv_outputs[0]  # HxWxChannels of the conv layer
    heatmap = tf.reduce_sum(feature_map * pooled_grads, axis=-1)

    # Apply ReLU and normalize; an all-zero map is returned as zeros rather
    # than dividing by zero.
    heatmap = tf.maximum(heatmap, 0)
    max_val = tf.reduce_max(heatmap)
    if max_val == 0:
        heatmap = tf.zeros_like(heatmap)
    else:
        heatmap /= max_val

    return heatmap.numpy()

In [None]:
def overlay_heatmap_on_image(img, heatmap, alpha=0.4, colormap=cv2.COLORMAP_JET):
    """
    Overlay heatmap on image and return RGB uint8 image.
    - img: HxW or HxWx1 or HxWx3 numpy array. Values can be float [0,1] or
      uint8 [0,255]; other ranges are clipped to [0,255]. A 3-channel image is
      interpreted as RGB and keeps its channel order in the result.
    - heatmap: 2-D numpy array with values in [0,1]; it is resized to the
      image's height and width.
    - alpha: overlay strength of heatmap (the image gets weight 1 - alpha).
    - colormap: OpenCV colormap constant used to colorize the heatmap.
    Returns: HxWx3 RGB uint8 image.
    """
    # Prepare heatmap: resize to the image, colorize, and move to RGB so that
    # the blend happens in the same channel order as the input image.
    img_h, img_w = img.shape[:2]
    heatmap_resized = cv2.resize(heatmap, (img_w, img_h))
    heatmap_uint8 = np.uint8(255 * np.clip(heatmap_resized, 0.0, 1.0))
    heatmap_bgr = cv2.applyColorMap(heatmap_uint8, colormap)  # OpenCV returns BGR
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)

    # Normalize input image to uint8
    img_copy = img.copy()
    if img_copy.dtype != np.uint8:
        # Values with max <= 1 are treated as [0,1] and scaled up; anything
        # else is taken to be on a 0-255 scale already. Clip before casting so
        # out-of-range values saturate instead of wrapping around.
        scale = 255.0 if img_copy.max() <= 1.0 else 1.0
        img_copy = np.clip(img_copy * scale, 0, 255).astype(np.uint8)

    if img_copy.ndim == 2:
        img_rgb = cv2.cvtColor(img_copy, cv2.COLOR_GRAY2RGB)
    elif img_copy.ndim == 3 and img_copy.shape[2] == 1:
        img_rgb = cv2.cvtColor(img_copy[:, :, 0], cv2.COLOR_GRAY2RGB)
    else:
        # assume already 3-channel RGB
        img_rgb = img_copy

    # Both operands are RGB uint8, so no channel swap is needed afterwards.
    return cv2.addWeighted(heatmap_rgb, alpha, img_rgb, 1 - alpha, 0)

In [None]:
# High-level convenience
def visualize_gradcam_for_index(model_path, x_data, idx=0, pred_index=None, last_conv_layer_name=None, alpha=0.4, save_path=None, cmap=cv2.COLORMAP_JET, custom_objects=None):
    """
    Compute Grad-CAM for x_data[idx], display it next to the original image,
    and optionally save the figure.
    - model_path: path to a saved Keras model, or an already-loaded model
      object (resolved through load_keras_model).
    - x_data: indexable collection of images (HxW, HxWx1 or HxWx3 each).
    - idx: position of the image inside x_data.
    - pred_index: class index to explain; None explains the top predicted class.
    - last_conv_layer_name: conv layer to use; None lets make_gradcam_heatmap
      auto-detect it.
    - alpha: heatmap strength in the overlay.
    - save_path: if not None, the side-by-side figure is written to this path.
    - cmap: OpenCV colormap constant for the heatmap.
    - custom_objects: forwarded to load_keras_model.
    Returns tuple: (heatmap, overlay_image_rgb, predicted_class_index)
    Raises IndexError if idx is outside x_data.
    """
    # Load model or use object
    model = load_keras_model(model_path, custom_objects=custom_objects)

    # Validate x_data index
    if idx < 0 or idx >= len(x_data):
        raise IndexError(f"idx {idx} out of range for x_data with length {len(x_data)}")

    img = np.asarray(x_data[idx])

    # Prepare batched input expected by model
    if img.ndim == 2:  # (H, W) -> (1, H, W, 1)
        input_tensor = img[np.newaxis, ..., np.newaxis]
        img_for_display = img
    else:
        input_tensor = img[np.newaxis, ...]
        if img.ndim == 3 and img.shape[-1] == 3:
            img_for_display = img
        else:
            # Single-channel or unexpected layout: drop singleton axes for display
            img_for_display = img.squeeze()

    # Compute heatmap (layer auto-detection, if needed, happens inside)
    heatmap = make_gradcam_heatmap(input_tensor, model, last_conv_layer_name=last_conv_layer_name, pred_index=pred_index)

    # Overlay
    overlay_rgb = overlay_heatmap_on_image(img_for_display, heatmap, alpha=alpha, colormap=cmap)

    # Prediction info (always the model's top class, even when pred_index
    # selects a different class for the heatmap)
    preds = model.predict(input_tensor)
    pred_cls = int(np.argmax(preds[0]))

    # imshow accepts uint8 [0,255] and float [0,1] as-is. Only floats on a
    # 0-255 scale need converting; blindly casting floats in [0,1] to uint8
    # would truncate them to 0 and show a black image.
    original_view = img_for_display
    if original_view.ndim != 2 and original_view.dtype != np.uint8 and original_view.max() > 1.0:
        original_view = np.clip(original_view, 0, 255).astype(np.uint8)

    # Display side-by-side
    fig, (ax_original, ax_cam) = plt.subplots(1, 2, figsize=(6, 3))
    ax_original.set_title("Original")
    if original_view.ndim == 2:
        ax_original.imshow(original_view, cmap='gray')
    else:
        ax_original.imshow(original_view)
    ax_original.axis('off')

    ax_cam.set_title(f"Grad-CAM (pred={pred_cls})")
    ax_cam.imshow(overlay_rgb)
    ax_cam.axis('off')
    fig.tight_layout()

    # Save before showing: some backends release the figure once it is shown.
    if save_path is not None:
        fig.savefig(save_path, bbox_inches="tight")
    plt.show()

    return heatmap, overlay_rgb, pred_cls

In [None]:
if __name__ == "__main__":
    # Demo configuration: edit these three values to try the helper.
    demo_model_path = "best_model.keras"   # change if needed
    demo_x_test_path = None  # optional path to a .npy file holding the test images
    demo_idx = 0

    if demo_x_test_path is not None:
        demo_x_data = np.load(demo_x_test_path)
    else:
        # No file configured: reuse an `x_test` array if the session already
        # defines one, otherwise there is nothing to visualize.
        demo_x_data = globals().get("x_test")

    if demo_x_data is None:
        print("Grad-CAM demo skipped: set demo_x_test_path to a .npy file or define x_test first.")
    else:
        # The original call used undefined names (model_path, x_data, idx),
        # which raised NameError on a fresh kernel.
        visualize_gradcam_for_index(demo_model_path, demo_x_data, idx=demo_idx)