# Spatial Router

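Notebook source for the `routers.spatial` module, which defines the FastAPI endpoints for service readiness (`/ready`), basic health (`/health`), and diagram-to-JSON conversion (`/diagram-to-json`).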

In [None]:
#| default_exp routers.spatial

In [None]:
#| export
import importlib.metadata
import os
import tempfile
from typing import Dict, Any, Optional

from fastapi import APIRouter, HTTPException, status, UploadFile, File
from pydantic import BaseModel

from pipeline.spatial_pipeline import run_spatial_mapping


class ErrorDetail(BaseModel):
    """Error body placed in the ``detail`` of HTTPExceptions raised by this router."""

    success: bool = False  # always False for error responses
    status: str = "error"  # short status tag; this router uses "error" or "not_ready"
    code: int  # HTTP status code, mirrored into the body
    message: str  # human-readable explanation of the failure


class ReadinessResponse(BaseModel):
    """Body returned by ``GET /ready`` when every required file and package is present."""

    # All fields have defaults, so the handler returns ``ReadinessResponse()`` as-is.
    success: bool = True
    status: str = "ready"
    code: int = 200
    message: str = "Service is ready to handle requests"


class SpatialMapResponse(BaseModel):
    """Successful body of ``POST /diagram-to-json``."""

    success: bool = True
    status: str = "ok"
    code: int = 200
    # The dict returned by ``run_spatial_mapping`` for the uploaded diagram;
    # its schema is defined by the pipeline, not by this router.
    data: Dict[str, Any]


# Router holding the readiness, health and diagram-to-JSON endpoints defined
# below; presumably included into the FastAPI app elsewhere (not visible here).
router = APIRouter()


def _repo_root() -> str:
    """Return the repository root, two directories above this module's folder."""
    # Expected layout: <repo_root>/pipeline/routers/spatial.py
    # Computed lazily (not at import time) so the notebook can run without __file__.
    here = os.path.dirname(__file__)
    return os.path.abspath(os.path.join(here, "..", ".."))


def _resolve_repo_path(path: str) -> str:
    """Resolve *path* against the repository root unless it is already absolute."""
    if os.path.isabs(path):
        return path
    return os.path.join(_repo_root(), path)


def _resolve_pipeline_path(path: str) -> str:
    """Resolve *path* relative to the pipeline directory (the parent of this module's folder)."""
    here = os.path.dirname(__file__)
    pipeline_root = os.path.abspath(os.path.join(here, ".."))
    return os.path.join(pipeline_root, path)


def _is_git_lfs_pointer_file(path: str) -> bool:
    """Return True if *path* looks like a Git-LFS pointer instead of real content.

    Only files smaller than 1 MiB are inspected (a real checkpoint is far
    larger); their first 128 bytes are searched for the Git-LFS spec URL.
    Missing or unreadable files yield False.
    """
    try:
        if not os.path.exists(path) or os.path.getsize(path) >= 1024 * 1024:
            return False
        with open(path, "rb") as f:
            head = f.read(128)
    except OSError:
        return False
    return b"git-lfs.github.com/spec" in head


def _missing_files() -> list:
    """Describe each required model file/config that is absent or unusable.

    Returns:
        A list of human-readable strings; empty when everything is in place.
    """
    required_files: Dict[str, Optional[str]] = {
        "RF-DETR weights": _resolve_pipeline_path("weights/pre-trained-model/checkpoint_best_regular.pth"),
        "Angle model weights": _resolve_pipeline_path("weights/angle-models/Triangle.pth"),
        "SAM2 checkpoint": _resolve_repo_path("sam2_checkpoints/sam2_hiera_base_plus.pt"),
    }

    # The SAM2 config ships inside the installed ``sam2`` package. Any failure
    # while importing or locating it is reported as "package not installed".
    try:
        import sam2  # pylint: disable=import-outside-toplevel

        required_files["SAM2 config"] = os.path.join(
            os.path.dirname(sam2.__file__), "sam2_hiera_b+.yaml"
        )
    except Exception:  # deliberate best-effort: import-time errors vary
        required_files["SAM2 config"] = None

    missing = []
    for name, path in required_files.items():
        if path is None:
            missing.append(f"{name}: package not installed")
        elif not os.path.exists(path):
            missing.append(f"{name}: {path}")

    # An un-pulled Git-LFS checkout leaves a tiny text pointer where the
    # RF-DETR checkpoint should be; it "exists" but cannot be loaded.
    rfdetr_path = required_files["RF-DETR weights"]
    if _is_git_lfs_pointer_file(rfdetr_path):
        missing.append(
            "RF-DETR weights are a Git-LFS pointer (not the real checkpoint). Run `git lfs pull` or download the full .pth. "
            f"({rfdetr_path})"
        )
    return missing


def _missing_packages() -> list:
    """Return the names of required distributions that are not installed."""
    missing = []
    for package in ("rfdetr", "paddleocr"):
        try:
            importlib.metadata.version(package)
        except importlib.metadata.PackageNotFoundError:
            missing.append(package)
    return missing


@router.get(
    "/ready", response_model=ReadinessResponse, responses={503: {"model": ErrorDetail}}
)
async def readiness_check():
    """Check if the service is ready to handle requests.

    Verifies that the model weights, the SAM2 checkpoint/config and the
    required Python packages are available.

    Returns:
        ReadinessResponse when everything is present.

    Raises:
        HTTPException: 503 with an ``ErrorDetail`` body (status "not_ready")
            listing every missing file and package.
    """
    missing_files = _missing_files()
    missing_packages = _missing_packages()

    if missing_files or missing_packages:
        # Build the message from parts so it never ends with stray whitespace.
        details = []
        if missing_files:
            details.append(f"Files: {', '.join(missing_files)}.")
        if missing_packages:
            details.append(f"Packages: {', '.join(missing_packages)}.")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=ErrorDetail(
                success=False,
                status="not_ready",
                code=status.HTTP_503_SERVICE_UNAVAILABLE,
                message="Service is not ready. Missing: " + " ".join(details),
            ).model_dump(),
        )

    return ReadinessResponse()


@router.get("/health")
async def health_check():
    """Report that the service process is up.

    Unlike ``/ready``, this performs no file or package checks and always
    returns the same payload.
    """
    payload = dict(status="healthy", message="Service is running")
    return payload


@router.post(
    "/diagram-to-json",
    response_model=SpatialMapResponse,
    responses={500: {"model": ErrorDetail}, 400: {"model": ErrorDetail}, 503: {"model": ErrorDetail}},
)
async def spatial_map(file: UploadFile = File(...)):
    """Convert an uploaded diagram image to a JSON spatial representation.

    The upload is written into a fresh per-request temporary directory,
    passed by path to ``run_spatial_mapping``, and the directory (with
    everything in it) is removed afterwards.

    Args:
        file: The uploaded diagram; its extension must be one of the
            supported types listed in ``allowed_extensions``.

    Returns:
        SpatialMapResponse whose ``data`` is the dict returned by
        ``run_spatial_mapping``.

    Raises:
        HTTPException: 400 for an unsupported or missing file extension;
            503 when the pipeline error mentions an unavailable RF-DETR model
            or a Git-LFS pointer; 500 for any other processing failure.
    """
    allowed_extensions = {".jpg", ".jpeg", ".png", ".pdf", ".bmp", ".tiff"}
    # ``filename`` is client-controlled and may be None. Keep only its final
    # path component so it can never escape the temporary directory below.
    safe_name = os.path.basename(file.filename or "")
    file_ext = os.path.splitext(safe_name)[1].lower()
    if file_ext not in allowed_extensions:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ErrorDetail(
                success=False,
                status="error",
                code=status.HTTP_400_BAD_REQUEST,
                # sorted() keeps the message stable (set order is hash-randomized).
                message=(
                    f"Unsupported file type: {file_ext}. "
                    f"Allowed types: {', '.join(sorted(allowed_extensions))}"
                ),
            ).model_dump(),
        )

    try:
        # One private directory per request: concurrent uploads sharing a
        # filename can no longer overwrite or delete each other's input.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_file_path = os.path.join(temp_dir, safe_name)
            content = await file.read()
            with open(temp_file_path, "wb") as buffer:
                buffer.write(content)

            # NOTE(review): run_spatial_mapping is called synchronously inside
            # an async handler, so the event loop is blocked while it runs.
            result = run_spatial_mapping(temp_file_path)
            return SpatialMapResponse(data=result)

    except Exception as e:
        msg = str(e)
        # The pipeline signals missing/placeholder RF-DETR weights only via
        # its error text; map that case to 503 instead of a generic 500.
        if "RF-DETR model not available" in msg or "Git-LFS" in msg:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail=ErrorDetail(
                    success=False,
                    status="not_ready",
                    code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    message=(
                        "Shape detection model is not available. "
                        "Your RF-DETR weights are missing or are a Git-LFS pointer file. "
                        "Download the real `checkpoint_best_regular.pth` (~400MB) to "
                        "weights/pre-trained-model/checkpoint_best_regular.pth and retry."
                    ),
                ).model_dump(),
            ) from e
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ErrorDetail(
                success=False,
                status="error",
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                message=f"Failed to process diagram: {msg}",
            ).model_dump(),
        ) from e