In [10]:
from __future__ import annotations

import argparse
import json
import logging
from pathlib import Path
from copy import deepcopy
from typing import Tuple, List, Dict, Any

import cv2
import numpy as np
import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
from torchvision import transforms
from PIL import Image

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.transforms import Affine2D
import pandas as pd

In [21]:
# -----------------------------------------------------------------------------
# Logging – enable *INFO* by default so that the user sees progress messages
# -----------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO,
                    format="%(levelname)7s | %(name)s | %(message)s")
logger = logging.getLogger("rotation-pipeline")

# -----------------------------------------------------------------------------
# Configuration – adapt these paths to your project layout
# -----------------------------------------------------------------------------
COCO_JSON        = Path("../data/rotation/batches/rotation_20250721_01/annotations/instances_default.json")
IMAGES_DIR       = Path("../data/rotation/batches/rotation_20250721_01/images/default/")
PRED_JSON        = Path("../data/rotation/batches/rotation_20250721_01/instances_predicted.json")
ATTR_JSON        = Path("../data/rotation/batches/rotation_20250721_01/instances_updated.json")
CHECKPOINT_PATH  = Path("checkpoints/best_model.pth")
CSV_OUT          = Path("../data/rotation/batches/rotation_20250721_01/rot_summary.csv")
DEBUG_DIR        = Path("debug")

# Class labels *in the order used during training*
CLASS_NAMES: List[int] = [0, 180, 270, 90]

IMAGE_SIZE = 300

# Device selection (GPU preferred)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
# -----------------------------------------------------------------------------
# TorchVision preprocessing – identical to ImageNet‐trained ResNet18
# -----------------------------------------------------------------------------
TRANSFORM = transforms.Compose([
    transforms.Resize((IMAGE, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

In [19]:
# IO helpers
# -----------------------------------------------------------------------------

def load_coco(path: Path) -> Dict[str, Any]:
    logger.info("Loading COCO from %s", path)
    return json.loads(path.read_text(encoding="utf‑8"))


def save_coco(coco: Dict[str, Any], path: Path) -> None:
    path.parent.mkdir(parents=True,exist_ok=True)
    logger.info("Writing COCO to %s", path)
    path.write_text(json.dumps(coco,ensure_ascii=False,indent=2), encoding="utf‑8")

# -----------------------------------------------------------------------------
# Model loader
# -----------------------------------------------------------------------------

def load_model(ckpt_path: Path) -> nn.Module:
    """Load the fine‑tuned ResNet18 classifier."""
    logger.info("Loading model from %s", ckpt_path)
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, len(CLASS_NAMES))
    ckpt = torch.load(str(ckpt_path), map_location=DEVICE)
    model.load_state_dict(ckpt["model_state_dict"])
    return model.to(DEVICE).eval()


In [20]:
# Geometry helpers
# -----------------------------------------------------------------------------

def crop_box(img: np.ndarray, x: float, y: float, w: float, h: float) -> np.ndarray:
    x1, y1 = max(0, int(round(x))), max(0, int(round(y)))
    x2 = min(img.shape[1], int(round(x + w)))
    y2 = min(img.shape[0], int(round(y + h)))
    return img[y1:y2, x1:x2]


def predict_angle(model: nn.Module, patch: np.ndarray, cur_rot: float) -> float:
    """Predict the *absolute* CW rotation of the object in *degrees*."""
    # Undo the current CW rotation so the patch is visually upright for the CNN.
    h, w = patch.shape[:2]
    M = cv2.getRotationMatrix2D((w/2, h/2), ‑cur_rot, 1.0)  # OpenCV wants CCW, hence –cur_rot
    straight = cv2.warpAffine(patch, M, (w, h), borderMode=cv2.BORDER_REPLICATE)

    rgb = cv2.cvtColor(straight, cv2.COLOR_BGR2RGB)
    pil = Image.fromarray(rgb)
    t = TRANSFORM(pil).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(t)
        idx = int(torch.argmax(logits, dim=1).item())
    return float(CLASS_NAMES[idx])


def extract_current_rotation(ann: Dict[str, Any]) -> float:
    """Return the CW rotation (°) stored in the annotation – 0 if absent."""
    bb = ann.get("bbox", [])
    if len(bb) == 5:
        return float(bb[4])
    attrs = ann.get("attributes", {})
    return float(attrs.get("rotation", 0.0))

SyntaxError: invalid non-printable character U+00A0 (4291106713.py, line 5)

In [6]:
# Debug Utilities – optional visualisation helpers
# -----------------------------------------------------------------------------

def debug_annotation(ann_id: int, image_id: int, fname: str, orig_rot: float, pred_rot: float,
                     bbox: List[float], full_img: np.ndarray, save_dir: Path | None = None):
    """Display or save side‑by‑side view of (crop, full image with overlays)."""
    x, y, w, h = bbox
    patch = crop_box(full_img, x, y, w, h)

    # Convert BGR → RGB for matplotlib
    patch_rgb = cv2.cvtColor(patch, cv2.COLOR_BGR2RGB)
    full_rgb = cv2.cvtColor(full_img, cv2.COLOR_BGR2RGB)

    logger.debug("Annotation %s (%s): cur = %g°, pred = %g°", ann_id, fname, orig_rot, pred_rot)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4), dpi=150)
    ax1.imshow(patch_rgb)
    ax1.axis("off")
    ax1.set_title(f"Crop {patch.shape[1]}×{patch.shape[0]}")

    ax2.imshow(full_rgb)
    ax2.axis("off")

    # ----- Visual overlays ---------------------------------------------------
    # Axis‑aligned bbox (orange dashed)
    ax2.add_patch(Rectangle((x, y), w, h, fill=False, edgecolor="orange", linestyle="--"))

    # Rotated bbox (red) – rotate **CCW** by *pred_rot* around centre (cx,cy)
    cx, cy = x + w / 2, y + h / 2
    obb = Rectangle((cx - w / 2, cy - h / 2), w, h, linewidth=1, edgecolor="red", fill=False)
    obb.set_transform(Affine2D().rotate_deg_around(cx, cy, pred_rot) + ax2.transData)
    ax2.add_patch(obb)

    # Anchor points
    ax2.scatter([x], [y], color="blue", s=5)
    ax2.scatter([cx], [cy], color="lime", s=5)
    ax2.set_title("Full Image with boxes")

    plt.tight_layout()

    if save_dir is not None:
        save_dir.mkdir(parents=True, exist_ok=True)
        outfile = save_dir / f"ann_{ann_id:05d}.png"
        fig.savefig(outfile, dpi=150)
        plt.close(fig)
        logger.debug("Saved debug image → %s", outfile)
    else:
        plt.show()

In [7]:
# Core routine
# -----------------------------------------------------------------------------

def update_rotations(target_image_id: int | None = None, debug: bool = False,
                     save_csv: Path | None = None, save_debug_dir: Path | None = None) -> None:
    """Predict & update rotations.

    Parameters
    ----------
    target_image_id : int | None
        If provided, restrict processing to this single *image_id* (handy during QA).
    debug : bool
        ``True`` → draw crops for every processed annotation.
    save_csv : Path | None
        If given, write summary DataFrame to this path (.csv).
    save_debug_dir : Path | None
        Directory that will receive one *PNG* per annotation when ``debug`` is *True*.
    """
    coco_raw = load_coco(COCO_JSON)
    coco_new = deepcopy(coco_raw)
    model = load_model(CHECKPOINT_PATH)

    # Cache full‑resolution images to avoid re‑loading on every annotation
    images: Dict[int, Dict] = {img["id"]: img for img in coco_raw["images"]}
    cache: Dict[str, np.ndarray] = {}
    records: List[Dict] = []

    # ---------------------------------------------------------------------
    # Pass 1 – iterate over annotations, predict angle, write back results
    # ---------------------------------------------------------------------
    for ann, ann_new in zip(coco_raw["annotations"], coco_new["annotations"]):
        if target_image_id is not None and ann["image_id"] != target_image_id:
            continue

        x, y, w, h = ann["bbox"][:4]
        cur_rot = extract_current_rotation(ann)

        # Only evaluate if *cur_rot* is close to one of the discrete classes
        if not any(abs(cur_rot - base) <= 10 for base in CLASS_NAMES):
            continue

        img_info = images[ann["image_id"]]
        fname = img_info["file_name"]

        if fname not in cache:
            img = cv2.imread(str(IMAGES_DIR / fname))
            if img is None:
                logger.error("Could not load image %s", fname)
                continue
            cache[fname] = img
        img_full = cache[fname]

        # ---------------- Prediction ----------------
        crop = crop_box(img_full, x, y, w, h)
        pred_rot = predict_angle(model, crop, cur_rot)

        records.append({
            "ann_id": ann["id"],
            "img_id": ann["image_id"],
            "file_name": fname,
            "bbox": [x, y, w, h],
            "orig_rot": cur_rot,
            "pred_rot": pred_rot,
        })

        # Write back only if changed – avoid dirty diffs in repo
        if pred_rot != cur_rot:
            logger.info("Ann %5d: %3g° → %3g°", ann["id"], cur_rot, pred_rot)
            cx, cy = x + w / 2, y + h / 2
            ann_new["bbox"] = [cx, cy, w, h, pred_rot]
            ann_new.setdefault("attributes", {})["rotation"] = pred_rot

        # Optional debug visualisation
        if debug:
            debug_annotation(ann["id"], ann["image_id"], fname, cur_rot, pred_rot,
                             [x, y, w, h], img_full, save_dir=save_debug_dir)

    # ---------------------------------------------------------------------
    # Pass 2 – persist artifacts
    # ---------------------------------------------------------------------
    save_coco(coco_new, PRED_JSON)
    save_coco(coco_new, ATTR_JSON)  # identical content, different file name for backwards compat

    df = pd.DataFrame(records)
    logger.info("Processed %s annotations", len(df))

    if save_csv is not None:
        save_csv.parent.mkdir(parents=True, exist_ok=True)
        df.to_csv(save_csv, index=False)
        logger.info("Wrote CSV summary → %s", save_csv)

    if not df.empty:
        logger.info("\n%s", df.head().to_string(index=False))


In [8]:
# Command‑line Interface
# -----------------------------------------------------------------------------

def _parse_args() -> argparse.Namespace:  # noqa: D401
    """CLI wrapper so you can run the script stand‑alone."""
    p = argparse.ArgumentParser(description="Rotation correction pipeline")
    p.add_argument("--image-id", type=int, default=None,
                   help="Process only this image_id (for debugging)")
    p.add_argument("--debug", action="store_true",
                   help="Show / save per‑annotation debug panels")
    p.add_argument("--csv-out", type=Path, default=CSV_OUT,
                   help="Write summary CSV to this path")
    p.add_argument("--debug-dir", type=Path, default=DEBUG_DIR,
                   help="When --debug, save images here instead of showing them")
    return p.parse_args()


def main() -> None:
    args = _parse_args()
    update_rotations(target_image_id=args.image_id,
                     debug=args.debug,
                     save_csv=args.csv_out,
                     save_debug_dir=args.debug_dir if args.debug else None)


if __name__ == "__main__":
    main()

usage: ipykernel_launcher.py [-h] [--image-id IMAGE_ID] [--debug]
                             [--csv-out CSV_OUT] [--debug-dir DEBUG_DIR]
ipykernel_launcher.py: error: unrecognized arguments: --f=/Users/gerhardkarbeutz/Library/Jupyter/runtime/kernel-v36b67a94d2752417c2ae831325c4efa8870227c18.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
