In [None]:
#| default_exp rf_detr.detector

#| export
from typing import List, Tuple, Optional
import os

from PIL import Image


# Module-level cache for the loaded RF-DETR model. _load_model() returns this
# object whenever it is not None and stores a freshly loaded model here.
_RFDETR_MODEL = None
# Label table indexed by the integer class id of each detection (see
# detect_shapes); the position of an entry decides which id maps to it.
# NOTE(review): a later cell in this file rebinds _CLASS_NAMES to a different
# 10-entry list; if both cells are exported into the same module the later
# binding wins - confirm which table matches the checkpoint in use.
_CLASS_NAMES = [
    "objects",  # id 0; presumably a dataset-level placeholder class - TODO confirm
    "Cloud",
    "Diamond",
    "Double Arrow",
    "Pentagon",
    "Racetrack",
    "Star",
    "Sticky Notes",
    "Triangle",
    "arrow",
    "arrow_head",
    "circle",
    "dashed-arrow",
    "dotted-arrow",
    "rectangle",
    "rounded rectangle",
    "solid-arrow",
]

# One detection: (label, (x_min, y_min), (x_max, y_max), confidence).
# The first coordinate pair is the top-left corner, the second the bottom-right.
ShapeDetection = Tuple[str, Tuple[float, float], Tuple[float, float], float]


def _default_weights_path() -> str:
    """Return the absolute path of the default RF-DETR checkpoint.

    The path is resolved relative to this module's location: two directories
    above the folder containing this file, then
    ``weights/pre-trained-model/checkpoint_best_regular.pth``. The file is not
    checked for existence here.
    """
    module_dir = os.path.dirname(__file__)
    # /pipeline/rf_detr/detector.py -> /pipeline -> /<repo_root>
    two_levels_up = os.path.join(module_dir, "..", "..")
    checkpoint_parts = ("weights", "pre-trained-model", "checkpoint_best_regular.pth")
    return os.path.join(os.path.abspath(two_levels_up), *checkpoint_parts)


def _is_git_lfs_pointer_file(path: str) -> bool:
    """Tell whether ``path`` looks like a Git-LFS pointer stub.

    Only existing files smaller than 1 MiB are inspected; for those, the first
    128 bytes are searched for the Git-LFS spec URL fragment. Any error raised
    while probing the file is swallowed and reported as ``False``.

    Args:
        path: Filesystem path of the candidate weights file.

    Returns:
        True if the marker is found in the head of a small file, else False.
    """
    size_limit = 1024 * 1024
    try:
        # Missing files and anything at or above the size limit are never
        # treated as pointers.
        if not os.path.exists(path) or os.path.getsize(path) >= size_limit:
            return False
        with open(path, "rb") as handle:
            return b"git-lfs.github.com/spec" in handle.read(128)
    except Exception:
        # Best-effort probe: unreadable paths simply count as "not a pointer".
        return False


def _load_model(weights_path: Optional[str] = None):
    """Return the cached RF-DETR model, loading it on first use.

    If a model is already cached in ``_RFDETR_MODEL`` it is returned as-is and
    ``weights_path`` is ignored. Otherwise the checkpoint is validated (it must
    exist and must not be a Git-LFS pointer stub), ``RFDETRMedium`` is
    imported, constructed with the checkpoint and optimized for inference.

    Args:
        weights_path: Checkpoint to load; defaults to
            ``_default_weights_path()`` when None.

    Returns:
        The loaded model, or None when the weights are missing, are an LFS
        pointer, or loading raised. Each failure is reported via ``print``.
    """
    global _RFDETR_MODEL
    cached = _RFDETR_MODEL
    if cached is not None:
        return cached

    checkpoint = _default_weights_path() if weights_path is None else weights_path

    if not os.path.exists(checkpoint):
        print(f"[RF-DETR] Weights not found at: {checkpoint}")
        _RFDETR_MODEL = None
        return None

    if _is_git_lfs_pointer_file(checkpoint):
        pointer_warning = (
            "[RF-DETR] Weights file looks like a Git-LFS pointer, not the real .pth weights. "
            "Fetch LFS files (e.g., `git lfs pull`) or replace the file with the full checkpoint. "
            f"Path: {checkpoint}"
        )
        print(pointer_warning)
        _RFDETR_MODEL = None
        return None

    try:
        # Imported here so that merely importing this module does not require
        # rfdetr to be importable.
        from rfdetr import RFDETRMedium

        detector = RFDETRMedium(pretrain_weights=checkpoint)
        detector.optimize_for_inference()
    except Exception as exc:
        print(f"[RF-DETR] Failed to load model: {exc}")
        _RFDETR_MODEL = None
        return None

    # Only a fully constructed and optimized model is ever cached.
    _RFDETR_MODEL = detector
    return detector


def detect_shapes(file_path: str, threshold: float = 0.25, raise_on_model_failure: bool = False) -> List[ShapeDetection]:
    """Run RF-DETR inference on an image file and return the detected shapes.

    Args:
        file_path: Path of the image to analyse; it is opened with Pillow and
            converted to RGB before inference.
        threshold: Confidence threshold forwarded to ``model.predict``.
        raise_on_model_failure: When True, raise instead of returning an empty
            list if the model cannot be loaded.

    Returns:
        A list of ``(label, (x_min, y_min), (x_max, y_max), confidence)``
        tuples, one per detection. ``label`` comes from ``_CLASS_NAMES`` when
        the class id is a valid index, otherwise it is ``str(class_id)``.
        An empty list is returned when the model is unavailable (and
        ``raise_on_model_failure`` is False), when the image cannot be opened,
        or when inference raises; the latter two are reported via ``print``.

    Raises:
        RuntimeError: If the model is unavailable and
            ``raise_on_model_failure`` is True.
    """
    model = _load_model()
    if model is None:
        if raise_on_model_failure:
            raise RuntimeError("RF-DETR model not available (missing weights, Git-LFS pointer weights, or load failure).")
        return []

    try:
        # Image.open() is lazy and keeps the underlying file open; convert()
        # yields an independent in-memory image, so the source handle can be
        # closed deterministically as soon as the copy exists.
        with Image.open(file_path) as source_image:
            image = source_image.convert("RGB")
    except Exception as exc:
        print(f"[RF-DETR] Failed to open image: {exc}")
        return []

    try:
        detections = model.predict(image, threshold=float(threshold))
        results: List[ShapeDetection] = []
        for bbox, class_id, conf in zip(detections.xyxy, detections.class_id, detections.confidence):
            x_min, y_min, x_max, y_max = bbox
            class_index = int(class_id)
            # Ids outside the label table are surfaced as their string form
            # rather than dropped.
            if 0 <= class_index < len(_CLASS_NAMES):
                class_name = _CLASS_NAMES[class_index]
            else:
                class_name = str(class_id)
            results.append(
                (
                    class_name,
                    (float(x_min), float(y_min)),
                    (float(x_max), float(y_max)),
                    float(conf),
                )
            )
        return results
    except Exception as exc:
        print(f"[RF-DETR] Inference failed: {exc}")
        return []


In [None]:
#| export
import os
from typing import List, Tuple, Optional
from PIL import Image

# Optional dependency: fall back to a None placeholder when rfdetr cannot be
# imported, so this module still imports and _load_model() can report the
# missing library instead of crashing.
RFDETRMedium = None
try:
    from rfdetr import RFDETRMedium
except Exception as import_error:
    print(f"[RF-DETR] Failed to import rfdetr: {import_error}")

# Label table indexed by the integer class id of each detection (see
# detect_shapes); the position of an entry decides which id maps to it.
# NOTE(review): an earlier cell in this file binds _CLASS_NAMES to a different
# 17-entry list; this later binding is the one in effect if both cells are
# exported into the same module - confirm it matches the checkpoint in use.
_CLASS_NAMES = [
    "arrow",
    "circle",
    "diamond",
    "hexagon",
    "line",
    "parallelogram", 
    "pentagon",
    "rectangle",
    "rounded rectangle",
    "solid-arrow",
]

# One detection: (label, (x_min, y_min), (x_max, y_max), confidence).
# The first coordinate pair is the top-left corner, the second the bottom-right.
ShapeDetection = Tuple[str, Tuple[float, float], Tuple[float, float], float]


def _default_weights_path() -> str:
    """Return the absolute path of the default RF-DETR checkpoint.

    The checkpoint is expected at
    ``weights/pre-trained-model/checkpoint_best_regular.pth`` inside the
    directory one level above the folder containing this module. The file is
    not checked for existence here.
    """
    # /pipeline/rf_detr/detector.py -> /pipeline
    package_parent = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    return os.path.join(
        package_parent,
        "weights",
        "pre-trained-model",
        "checkpoint_best_regular.pth",
    )


def _is_git_lfs_pointer_file(path: str) -> bool:
    """Report whether ``path`` appears to be a Git-LFS pointer stub.

    A file qualifies when it exists, is smaller than 1 MiB, and its first 128
    bytes contain the Git-LFS spec URL fragment. Errors raised while probing
    the file are swallowed and treated as "not a pointer".

    Args:
        path: Filesystem path of the candidate weights file.

    Returns:
        True when the marker is found, otherwise False.
    """
    looks_like_pointer = False
    try:
        if os.path.exists(path) and os.path.getsize(path) < 1024 * 1024:
            with open(path, "rb") as stream:
                prefix = stream.read(128)
            looks_like_pointer = b"git-lfs.github.com/spec" in prefix
    except Exception:
        # Best-effort probe: any failure means we cannot call it a pointer.
        looks_like_pointer = False
    return looks_like_pointer


# Module-level cache for the loaded RF-DETR model. _load_model() returns this
# object whenever it is not None and stores a freshly loaded model here.
_RFDETR_MODEL = None


def _load_model(weights_path: Optional[str] = None):
    """Load the RF-DETR model once and cache it in ``_RFDETR_MODEL``.

    A previously cached model is returned immediately and ``weights_path`` is
    ignored in that case. Otherwise the checkpoint must exist, must not be a
    Git-LFS pointer stub, and the ``RFDETRMedium`` class must have been
    imported successfully; the model is then constructed and optimized for
    inference.

    Args:
        weights_path: Checkpoint to load; defaults to
            ``_default_weights_path()`` when None.

    Returns:
        The loaded model, or None if any precondition fails or loading raises.
        Every failure is reported via ``print``.
    """
    global _RFDETR_MODEL
    if _RFDETR_MODEL is not None:
        return _RFDETR_MODEL

    checkpoint = weights_path if weights_path is not None else _default_weights_path()

    # Stays None unless every step below succeeds.
    loaded = None
    if not os.path.exists(checkpoint):
        print(f"[RF-DETR] Weights not found at: {checkpoint}")
    elif _is_git_lfs_pointer_file(checkpoint):
        print("[RF-DETR] Weights file looks like a Git-LFS pointer, not the real .pth weights. "
              "Fetch LFS files (e.g., `git lfs pull`) or replace the file with the full checkpoint. "
              f"Path: {checkpoint}")
    else:
        try:
            if RFDETRMedium is None:
                # The module-level import of rfdetr failed earlier.
                print("[RF-DETR] RF-DETR library not available")
            else:
                candidate = RFDETRMedium(pretrain_weights=checkpoint)
                candidate.optimize_for_inference()
                loaded = candidate
        except Exception as exc:
            print(f"[RF-DETR] Failed to load model: {exc}")
            loaded = None

    _RFDETR_MODEL = loaded
    return loaded


def detect_shapes(file_path: str, threshold: float = 0.25, raise_on_model_failure: bool = False) -> List[ShapeDetection]:
    """Run RF-DETR inference on an image file and return the detected shapes.

    Args:
        file_path: Path of the image to analyse; it is opened with Pillow and
            converted to RGB before inference.
        threshold: Confidence threshold forwarded to ``model.predict``.
        raise_on_model_failure: When True, raise instead of returning an empty
            list if the model cannot be loaded or inference fails.

    Returns:
        A list of ``(label, (x_min, y_min), (x_max, y_max), confidence)``
        tuples, one per detection. ``label`` comes from ``_CLASS_NAMES`` when
        the class id is a valid index, otherwise it is ``str(class_id)``.
        An empty list is returned when the image cannot be opened (regardless
        of ``raise_on_model_failure``), and when the model is unavailable or
        inference fails while ``raise_on_model_failure`` is False.

    Raises:
        RuntimeError: If ``raise_on_model_failure`` is True and either the
            model is unavailable or inference raises (the original error is
            chained as the cause).
    """
    model = _load_model()
    if model is None:
        if raise_on_model_failure:
            raise RuntimeError("RF-DETR model not available (missing weights, Git-LFS pointer weights, or load failure).")
        return []

    try:
        # Image.open() is lazy and keeps the underlying file open; convert()
        # yields an independent in-memory image, so the source handle can be
        # closed deterministically as soon as the copy exists.
        with Image.open(file_path) as source_image:
            image = source_image.convert("RGB")
    except Exception as exc:
        print(f"[RF-DETR] Failed to open image: {exc}")
        return []

    try:
        detections = model.predict(image, threshold=threshold)
        results: List[ShapeDetection] = []
        for bbox, class_id, conf in zip(detections.xyxy, detections.class_id, detections.confidence):
            x_min, y_min, x_max, y_max = bbox
            class_index = int(class_id)
            # Ids outside the label table are surfaced as their string form
            # rather than dropped.
            if 0 <= class_index < len(_CLASS_NAMES):
                class_name = _CLASS_NAMES[class_index]
            else:
                class_name = str(class_id)
            results.append(
                (
                    class_name,
                    (float(x_min), float(y_min)),
                    (float(x_max), float(y_max)),
                    float(conf),
                )
            )
        return results
    except Exception as exc:
        print(f"[RF-DETR] Inference failed: {exc}")
        if raise_on_model_failure:
            raise RuntimeError(f"RF-DETR inference failed: {exc}") from exc
        return []