In [None]:
import numpy as np
import torch
import torchvision
from ultralytics import YOLO

from drise.core.explainer import DRISE
from drise.core.utils import (
    get_data_path,
    load_image,
    rescale_boxes,
    resize_image_for_detector,
    upscale_saliency_map,
)
from drise.core.visualization import COCO_CLASSES_91, DRISEVisualizer
from drise.core.wrappers.faster_rcnn import TorchvisionFasterRCNNWrapper
from drise.core.wrappers.yolo import YOLOWrapper


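# Demo: Faster R-CNN

Explain every detection that a COCO-pretrained torchvision Faster R-CNN makes with score ≥ 0.5, producing one D-RISE saliency map per detection.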

In [None]:
def demo_faster_rcnn(
    image_path: str,
    num_masks: int = 5000,
    mask_res: tuple = (16, 16),
    mask_prob: float = 0.5,
    save_path: str = get_data_path() / "rcnn_drise_results.png",
):
    """
    End-to-end D-RISE demo with torchvision Faster R-CNN.

    Pipeline:
      1. Load the image and resize it to the detector's working size.
      2. Load a pretrained Faster R-CNN.
      3. Detect objects at a 0.5 score threshold; these are the targets.
      4. Compute one D-RISE saliency map per target detection.
      5. Upscale maps and boxes back to the original resolution and save
         a visualization.

    Args:
        image_path: Path to the input image.
        num_masks: Number of random masks (paper: 5000).
        mask_res: Low-resolution mask grid size as (rows, cols).
        mask_prob: Pixel preservation probability (paper: 0.5).
        save_path: Output visualization path (str or path-like; the
            default is a path object under ``get_data_path()``).

    Returns:
        None. Progress is printed and the figure is written to ``save_path``.
        Returns early (after printing a notice) if nothing is detected.
    """
    print("=" * 60)
    print("D-RISE Demo — Faster R-CNN")
    print("=" * 60)

    print("\n[1/5] Loading image...")
    image_original = load_image(image_path)
    image_resized = resize_image_for_detector(
        image_original, detector_type="fasterrcnn"
    )
    # Images are 3-D, channels first (presumably (C, H, W)); only H and W are used.
    _, orig_H, orig_W = image_original.shape
    _, work_H, work_W = image_resized.shape
    print(f"      Original:  {orig_H}x{orig_W}")
    print(f"      Working:   {work_H}x{work_W}")

    # Size in pixels of one low-res mask cell once upsampled to the working image.
    cell_h = int(np.ceil(work_H / mask_res[0]))
    cell_w = int(np.ceil(work_W / mask_res[1]))
    print(f"      Mask grid: {mask_res} → cell size: {cell_h}x{cell_w} px")

    print("\n[2/5] Loading Faster R-CNN...")
    # `pretrained=True` is deprecated since torchvision 0.13. "DEFAULT" selects
    # the same COCO_V1 weights through the multi-weight API.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        weights="DEFAULT"
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"      Device: {device}")

    # The same model is wrapped twice. The 0.5 threshold picks the detections
    # worth explaining. The 0.01 threshold is used during masking, presumably
    # so weak detections on masked images still contribute to the saliency.
    display_detector = TorchvisionFasterRCNNWrapper(
        model, device=device, score_threshold=0.5
    )
    explain_detector = TorchvisionFasterRCNNWrapper(
        model, device=device, score_threshold=0.01
    )

    print("\n[3/5] Running detector...")
    detections = display_detector.detect(image_resized)
    # Bail out before printing an empty detection list.
    if len(detections) == 0:
        print("      No detections found.")
        return

    print(f"      Found {len(detections)} detections:")
    for i, det in enumerate(detections):
        name = COCO_CLASSES_91[det.label]
        print(f"      [{i}] {name} (score={det.score:.3f})")

    print("\n[4/5] Running D-RISE...")
    print(f"      N={num_masks}, res={mask_res}, p={mask_prob}")

    drise = DRISE(
        detector=explain_detector,
        num_masks=num_masks,
        mask_res=mask_res,
        mask_prob=mask_prob,
        device=device,
    )

    # Boxes are in working-image coordinates here; they are rescaled in step 5.
    target_dets = [
        {"box": det.box, "label": det.label} for det in detections
    ]

    saliency_maps = drise.explain(
        image=image_resized,
        target_detections=target_dets,
        num_classes=91,  # label space of COCO_CLASSES_91 used above
        verbose=True,
    )

    print(f"\n[5/5] Upscaling and saving to {save_path}...")

    # Bring everything back to the original resolution for display.
    saliency_maps_original = [
        upscale_saliency_map(smap, orig_H, orig_W)
        for smap in saliency_maps
    ]

    target_dets_original = rescale_boxes(
        target_dets, work_H, work_W, orig_H, orig_W
    )

    DRISEVisualizer.visualize_explanations(
        image=image_original,
        target_detections=target_dets_original,
        saliency_maps=saliency_maps_original,
        save_path=save_path,
    )

    print("\nDone!")

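# Demo: YOLOv8

The same D-RISE pipeline as above, applied to an ultralytics YOLOv8 detector (80 COCO classes with the default `yolov8n.pt` weights).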

In [None]:
def demo_yolo(
    image_path: str,
    num_masks: int = 5000,
    mask_res: tuple = (16, 16),
    mask_prob: float = 0.5,
    save_path: str = get_data_path() / "yolo_drise_results.png",
    weights: str = "yolov8n.pt",
):
    """
    End-to-end D-RISE demo with ultralytics YOLO.

    Pipeline:
      1. Load the image and resize it to the detector's working size.
      2. Load the YOLO model from ``weights``.
      3. Detect objects at a 0.5 score threshold; these are the targets.
      4. Compute one D-RISE saliency map per target detection.
      5. Upscale maps and boxes back to the original resolution and save
         a visualization.

    Args:
        image_path: Path to the input image.
        num_masks: Number of random masks (paper: 5000).
        mask_res: Low-resolution mask grid size as (rows, cols).
        mask_prob: Pixel preservation probability (paper: 0.5).
        save_path: Output visualization path (str or path-like; the
            default is a path object under ``get_data_path()``).
        weights: Weights file or model name passed to ``YOLO(...)``.
            The number of classes is taken from the loaded model's
            ``names`` mapping, so non-COCO weights also work.

    Returns:
        None. Progress is printed and the figure is written to ``save_path``.
        Returns early (after printing a notice) if nothing is detected.
    """
    print("=" * 60)
    print("D-RISE Demo — YOLOv8")
    print("=" * 60)

    print("\n[1/5] Loading image...")
    image_original = load_image(image_path)
    image_resized = resize_image_for_detector(
        image_original, detector_type="yolo"
    )
    # Images are 3-D, channels first (presumably (C, H, W)); only H and W are used.
    _, orig_H, orig_W = image_original.shape
    _, work_H, work_W = image_resized.shape
    print(f"      Original:  {orig_H}x{orig_W}")
    print(f"      Working:   {work_H}x{work_W}")

    # Size in pixels of one low-res mask cell once upsampled to the working image.
    cell_h = int(np.ceil(work_H / mask_res[0]))
    cell_w = int(np.ceil(work_W / mask_res[1]))
    print(f"      Mask grid: {mask_res} → cell size: {cell_h}x{cell_w} px")

    print(f"\n[2/5] Loading YOLO weights '{weights}'...")
    model = YOLO(weights)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"      Device: {device}")

    # Index → class-name mapping from the loaded model. Its size replaces the
    # previously hard-coded 80 (which is the value for the default COCO weights).
    class_names = model.names
    num_classes = len(class_names)

    # The same model is wrapped twice. The 0.5 threshold picks the detections
    # worth explaining. The 0.01 threshold is used during masking, presumably
    # so weak detections on masked images still contribute to the saliency.
    display_detector = YOLOWrapper(
        model, device=device, score_threshold=0.5, num_classes=num_classes
    )
    explain_detector = YOLOWrapper(
        model, device=device, score_threshold=0.01, num_classes=num_classes
    )

    print("\n[3/5] Running detector...")
    detections = display_detector.detect(image_resized)
    # Bail out before printing an empty detection list.
    if len(detections) == 0:
        print("      No detections found.")
        return

    print(f"      Found {len(detections)} detections:")
    for i, det in enumerate(detections):
        # int() normalizes numpy/tensor scalar labels for the dict lookup.
        name = class_names[int(det.label)]
        print(f"      [{i}] {name} (score={det.score:.3f})")

    print("\n[4/5] Running D-RISE...")
    print(f"      N={num_masks}, res={mask_res}, p={mask_prob}")

    drise = DRISE(
        detector=explain_detector,
        num_masks=num_masks,
        mask_res=mask_res,
        mask_prob=mask_prob,
        device=device,
    )

    # Boxes are in working-image coordinates here; they are rescaled in step 5.
    target_dets = [
        {"box": det.box, "label": det.label} for det in detections
    ]

    saliency_maps = drise.explain(
        image=image_resized,
        target_detections=target_dets,
        num_classes=num_classes,
        verbose=True,
    )

    print(f"\n[5/5] Upscaling and saving to {save_path}...")

    # Bring everything back to the original resolution for display.
    saliency_maps_original = [
        upscale_saliency_map(smap, orig_H, orig_W)
        for smap in saliency_maps
    ]

    target_dets_original = rescale_boxes(
        target_dets, work_H, work_W, orig_H, orig_W
    )

    DRISEVisualizer.visualize_explanations(
        class_names=class_names,
        image=image_original,
        target_detections=target_dets_original,
        saliency_maps=saliency_maps_original,
        save_path=save_path,
    )

    print("\nDone!")

In [None]:
# Input image for both demos below, located under the directory returned by get_data_path().
IMG_PATH = get_data_path() / "cat.jpg"

In [None]:
# Run the Faster R-CNN + D-RISE pipeline on the sample image with default settings.
demo_faster_rcnn(image_path=IMG_PATH)

In [None]:
# Run the YOLO + D-RISE pipeline on the same sample image with default settings.
demo_yolo(image_path=IMG_PATH)