In [None]:
# Imports & Config
# --- YOLOv8 ONNX Preview: Pixel-perfect crops matching drawn bbox ---

import os
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, Iterable

import cv2
import numpy as np
import matplotlib.pyplot as plt
from ultralytics import YOLO

In [None]:
@dataclass(frozen=True)
class PreviewConfig:
    """Immutable settings for the detection / crop preview notebook.

    The dataclass is frozen, so values are chosen at construction time
    (``PreviewConfig(conf_thres=0.25)``) and cannot be reassigned afterwards.
    """

    # Directory scanned (non-recursively) for images, and the exported model to run.
    input_dir: Path = Path("/workspace/")
    model_path: Path = Path("/workspace/runs/detect/train/weights/best.onnx")

    # Detection filtering
    # Class names to keep (matched case-insensitively); None => no filtering.
    # Annotated as frozenset to match the default: a frozen dataclass derives its
    # hash from its fields, and a mutable ``set`` here would make hash(CFG) fail.
    target_classes: Optional[frozenset[str]] = frozenset({"target"})
    conf_thres: float = 0.10               # forwarded to model.predict(conf=...)
    iou_thres: float = 0.45                # forwarded to model.predict(iou=...)

    # Preview/cropping behavior
    take_largest_box: bool = True          # keep only the largest-area box per image
    pad_frac: float = 0.0                  # expand bbox by pad_frac * max(w, h)
    preview_images: int = 20               # number of images to preview
    max_previews_per_image: int = 3        # max crops to show per image (when not taking the largest only)

    # Sampling (optional)
    random_sample: bool = False            # False => take the first `preview_images` files in sorted order
    seed: int = 42                         # RNG seed, only used when random_sample is True

# Build the notebook-wide configuration with default values. PreviewConfig is
# frozen, so overrides must be passed here as keyword arguments.
# The print echoes the effective settings into the cell output.
CFG = PreviewConfig()
print(CFG)


In [None]:
# Utilities (RGB conversion, bbox integerization, drawing)

# File suffixes treated as images. Entries are lower-case because candidates
# are compared via `Path.suffix.lower()`.
SUPPORTED_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}


def to_rgb(img: np.ndarray) -> np.ndarray:
    """Reorder OpenCV channel layout for display with matplotlib.

    A 3-channel array is treated as BGR and returned as RGB; a 4-channel array
    is treated as BGRA and returned as RGBA. Anything else (``None``, 2-D
    arrays, other channel counts) is handed back untouched.
    """
    if img is None:
        return None

    # Only 3-D arrays carry a channel axis we know how to reorder.
    channels = img.shape[2] if img.ndim == 3 else None

    if channels == 3:
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if channels == 4:
        return cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
    return img


def load_image_any(path: Path) -> Optional[np.ndarray]:
    """Read an image file as stored on disk, keeping an alpha channel if present.

    The file is decoded with ``cv2.IMREAD_UNCHANGED``. The result of
    ``cv2.imread`` is returned as-is; callers in this notebook treat ``None``
    as "could not be loaded".
    """
    return cv2.imread(str(path), cv2.IMREAD_UNCHANGED)


def bbox_to_int_exclusive(
    xyxy: np.ndarray,
    img_shape: tuple,
    pad_frac: float = 0.0,
    rounding: str = "round",
) -> tuple[int, int, int, int]:
    """
    Convert float xyxy to integer bbox with end-exclusive convention:
      crop = img[y1:y2, x1:x2]
    This makes bbox math consistent with numpy slicing.

    Args:
      xyxy      : four numbers (x1, y1, x2, y2) in pixel coordinates.
      img_shape : image shape; only the first two entries (H, W) are used.
      pad_frac  : the box is grown on every side by pad_frac * max(w, h).
      rounding  :
        - 'round'     : round all coords (np.round, i.e. halves go to even)
        - any other   : floor for (x1,y1), ceil for (x2,y2) ('floorceil')

    Returns:
      (x1, y1, x2, y2) as Python ints with 0 <= x1 < x2 <= W and
      0 <= y1 < y2 <= H whenever the image has at least one pixel along that
      axis. A box lying entirely past the right/bottom edge collapses onto the
      last pixel column/row instead of producing an empty slice. For a
      zero-sized axis the span is (0, 0).
    """
    x1, y1, x2, y2 = map(float, xyxy)

    # Padding is symmetric and based on the longer side of the unpadded box.
    pad = pad_frac * max(x2 - x1, y2 - y1)
    x1 -= pad
    y1 -= pad
    x2 += pad
    y2 += pad

    if rounding == "round":
        lo_fn = hi_fn = np.round
    else:
        lo_fn, hi_fn = np.floor, np.ceil

    H, W = (int(v) for v in img_shape[:2])

    def _span(lo: float, hi: float, size: int) -> tuple[int, int]:
        """Integerize one axis and clamp it to a non-empty [start, stop) range."""
        start = int(lo_fn(lo))
        stop = int(hi_fn(hi))
        # The start must be a valid pixel index (<= size - 1); clamping it to
        # `size` would leave no room for a non-empty slice.
        start = min(max(start, 0), max(size - 1, 0))
        # The stop is exclusive: at least one past start, at most `size`.
        stop = min(max(stop, start + 1), size)
        return start, stop

    xi1, xi2 = _span(x1, x2, W)
    yi1, yi2 = _span(y1, y2, H)

    return xi1, yi1, xi2, yi2


def draw_boxes_pixelperfect(
    img: np.ndarray,
    boxes_xyxy: np.ndarray,
    clses: np.ndarray,
    confs: np.ndarray,
    id2name: dict,
    pad_frac: float = 0.0,
) -> np.ndarray:
    """
    Draw rectangles consistent with numpy cropping.
    Because cv2.rectangle uses inclusive endpoint, we draw at (x2-1, y2-1)
    for end-exclusive bbox [x1, x2), [y1, y2).

    Args:
      img        : image in OpenCV channel order (BGR, BGRA, or single-channel
                   grayscale). It is not modified.
      boxes_xyxy : iterable of (x1, y1, x2, y2) float boxes.
      clses      : class id per box, looked up in `id2name` for the label.
      confs      : confidence per box, printed with two decimals.
      id2name    : mapping class id -> display name (the id is shown if missing).
      pad_frac   : forwarded to bbox_to_int_exclusive so drawn boxes match crops.

    Returns:
      A new 3-channel RGB image with boxes and labels drawn (alpha is dropped).
    """
    # Build a private 3-channel BGR canvas for cv2 drawing. Grayscale input
    # (2-D or HxWx1) is expanded explicitly; feeding it to a 3-channel colour
    # conversion would raise inside OpenCV.
    if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
        draw_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    elif img.ndim == 3 and img.shape[2] == 4:
        draw_img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    else:
        draw_img = img.copy()

    for (x1, y1, x2, y2), c, cf in zip(boxes_xyxy, clses, confs):
        xi1, yi1, xi2, yi2 = bbox_to_int_exclusive(
            (x1, y1, x2, y2), draw_img.shape, pad_frac=pad_frac, rounding="round"
        )

        # pixel-perfect: draw inclusive endpoint to match slicing
        pt1 = (xi1, yi1)
        pt2 = (max(xi1, xi2 - 1), max(yi1, yi2 - 1))

        cv2.rectangle(draw_img, pt1, pt2, (0, 255, 0), 2)

        label = f"{id2name.get(int(c), int(c))}:{float(cf):.2f}"
        cv2.putText(
            draw_img, label, (xi1, max(0, yi1 - 5)),
            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA
        )

    # matplotlib expects RGB channel order.
    return cv2.cvtColor(draw_img, cv2.COLOR_BGR2RGB)


In [None]:
# ONNX input-size detection (imgsz) + Provider check

def detect_onnx_input_hw(onnx_path: Path, fallback: int = 512) -> int:
    """
    Read expected HxW from ONNX input shape to avoid size mismatch errors.

    Args:
      onnx_path : path to the exported .onnx model.
      fallback  : size used for any dynamic (non-integer) dimension, and
                  returned outright when the model cannot be inspected.

    Returns:
      A single imgsz: min(H, W) of the first model input (dims 2 and 3).
      Never raises; on any failure a warning is printed and `fallback`
      is returned, so a fallback value is distinguishable from a detected one.
    """
    try:
        # Optional dependency: imported lazily so the notebook still runs without it.
        import onnxruntime as ort

        # Only graph metadata is needed here, so a CPU session suffices and
        # avoids initialising a GPU provider just to read a shape.
        sess = ort.InferenceSession(
            str(onnx_path),
            providers=["CPUExecutionProvider"]
        )
        shp = sess.get_inputs()[0].shape  # e.g., [None, 3, 512, 512]
        # Dynamic axes are reported as non-int entries; substitute the fallback.
        H, W = (d if isinstance(d, int) and d > 0 else fallback for d in shp[2:4])
        return int(min(H, W))
    except Exception as exc:
        print(f"[WARN] Could not read ONNX input size from {onnx_path}: {exc} -> using imgsz={int(fallback)}")
        return int(fallback)


# Provider availability info (optional)
# Report which ONNX Runtime execution providers exist in this environment.
# Purely informational: any failure here is reported and then ignored.
try:
    import onnxruntime as ort
    providers = ort.get_available_providers()
except Exception as err:
    print(f"[INFO] ONNXRuntime check skipped: {err}")
else:
    print("[INFO] ONNXRuntime providers:", providers)
    if "CUDAExecutionProvider" not in providers:
        print("[INFO] CUDAExecutionProvider not available -> inference will run on CPU.")

# Inference size used by the prediction cell; 512 is used when the model
# does not expose a fixed input size.
imgsz = detect_onnx_input_hw(CFG.model_path, fallback=512)
print(f"[INFO] Detected ONNX imgsz: {imgsz}x{imgsz}")


In [None]:
# Load YOLO ONNX + Resolve classes safely

model = YOLO(str(CFG.model_path))

# Class-id -> name table. Try the wrapped inner model first, then the wrapper
# itself, and finally fall back to a single placeholder class.
id2name = getattr(getattr(model, "model", None), "names", None)
if not id2name:
    id2name = getattr(model, "names", None)
if not id2name:
    id2name = {0: "class0"}

# Reverse lookup keyed by lower-cased name so matching is case-insensitive.
name2id_lower = {}
for class_id, class_name in id2name.items():
    name2id_lower[str(class_name).lower()] = int(class_id)

# target_ids is None when every detection should be shown: either no filter
# was configured, or a configured name is unknown to the model.
target_ids = None
if CFG.target_classes is not None:
    missing = [c for c in CFG.target_classes if c.lower() not in name2id_lower]
    if missing:
        print(f"[WARN] Missing classes in model: {missing} -> disable class filtering (show all detections).")
    else:
        target_ids = {name2id_lower[c.lower()] for c in CFG.target_classes}

print("[INFO] model.names:", id2name)
print("[INFO] target_ids :", "ALL" if target_ids is None else target_ids)


In [None]:
# Collect images + Run prediction + Build previews

def list_images(input_dir: Path) -> list[Path]:
    """Return the image files directly inside `input_dir`, sorted by path.

    Only regular files whose lower-cased suffix is in SUPPORTED_EXTS are kept;
    sub-directories are not descended into.
    """
    return sorted(
        entry
        for entry in input_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in SUPPORTED_EXTS
    )


img_files = list_images(CFG.input_dir)
if not img_files:
    raise RuntimeError(f"No images found in: {CFG.input_dir}")

# Optionally random sample (seeded, so the same subset is drawn on every run)
if CFG.random_sample:
    rng = np.random.default_rng(CFG.seed)
    idx = rng.choice(len(img_files), size=min(CFG.preview_images, len(img_files)), replace=False)
    img_files = [img_files[i] for i in sorted(idx)]
else:
    img_files = img_files[:CFG.preview_images]

print(f"[INFO] Previewing {len(img_files)} image(s).")

# One entry per image: (file name, annotated RGB image or None, list of crops),
# where each crop is (class name, confidence, RGB/RGBA crop array).
rows = []
for path in img_files:
    img = load_image_any(path)
    if img is None:
        # Unreadable file: recorded with vis=None so the display cell can report it.
        rows.append((path.name, None, []))
        continue

    # NOTE:
    # For ONNX, actual provider selection is handled by onnxruntime.
    # Ultralytics 'device' arg may not force CUDA for ONNX the same way as PyTorch.
    results = model.predict(
        source=str(path),
        imgsz=imgsz,
        conf=CFG.conf_thres,
        iou=CFG.iou_thres,
        verbose=False
    )

    if not results:
        rows.append((path.name, to_rgb(img), []))
        continue

    res = results[0]

    # Boxes are expressed in the pixel grid of the frame the detector decoded
    # itself (res.orig_img). Our own copy was read with IMREAD_UNCHANGED, which
    # skips EXIF orientation, so it may be rotated relative to that frame.
    # Draw and crop on the detector's frame; keep our copy only when it adds an
    # alpha channel and has the same height/width.
    frame = getattr(res, "orig_img", None)
    has_alpha = img.ndim == 3 and img.shape[2] == 4
    if frame is None or (has_alpha and img.shape[:2] == frame.shape[:2]):
        frame = img

    if res.boxes is None or len(res.boxes) == 0:
        rows.append((path.name, to_rgb(frame), []))
        continue

    boxes = res.boxes.xyxy.detach().cpu().numpy()
    clses = res.boxes.cls.detach().cpu().numpy().astype(int)
    confs = res.boxes.conf.detach().cpu().numpy()

    # Visualization: draw all detections
    vis = draw_boxes_pixelperfect(frame, boxes, clses, confs, id2name, pad_frac=CFG.pad_frac)

    # Filter class if requested
    idxs = list(range(len(clses)))
    if target_ids is not None:
        idxs = [i for i in idxs if clses[i] in target_ids]

    # If empty after filtering, fallback to top-1 confidence
    if not idxs:
        idxs = [int(np.argmax(confs))]

    # Choose largest box or top-N boxes
    if CFG.take_largest_box and len(idxs) > 1:
        # max() keeps the first index among equal areas.
        largest = max(
            idxs,
            key=lambda i: (boxes[i][2] - boxes[i][0]) * (boxes[i][3] - boxes[i][1]),
        )
        idxs = [largest]
    else:
        idxs = idxs[:CFG.max_previews_per_image]

    crops = []
    for i in idxs:
        # Same integerization as the drawing code, so the crop matches the drawn box.
        xi1, yi1, xi2, yi2 = bbox_to_int_exclusive(boxes[i], frame.shape, pad_frac=CFG.pad_frac, rounding="round")
        crop = frame[yi1:yi2, xi1:xi2]  # end-exclusive slicing
        crops.append((id2name.get(int(clses[i]), f"cls{int(clses[i])}"), float(confs[i]), to_rgb(crop)))

    rows.append((path.name, vis, crops))


In [None]:
# Display previews
for fname, vis, crops in rows:
    if vis is None:
        print(f"[SKIP] failed to load: {fname}")
        continue

    # Left panel is always the annotated image; to its right comes one panel
    # per crop, or a single placeholder panel when there are no crops.
    panels = [(vis, f"Detections: {fname}")]
    if crops:
        panels.extend(
            (cr, f"Crop ({cls_name}, conf={cf:.2f})") for cls_name, cf, cr in crops
        )
    else:
        panels.append((vis, "No crops (after filter)"))

    cols = len(panels)
    fig, axes = plt.subplots(1, cols, figsize=(5 * cols, 5), squeeze=False)

    for ax, (picture, caption) in zip(axes[0], panels):
        ax.imshow(picture)
        ax.set_title(caption)
        ax.axis("off")

    fig.tight_layout()
    plt.show()
