This notebook provides a production-ready pipeline for:

1. Converting COCO axis-aligned bboxes to oriented bounding boxes (OBBs).
2. Cropping image patches based on OBBs and trimming borders.
3. Predicting object orientation using a trained ResNet18 model.
4. Updating COCO annotations with predicted rotations.
5. Running end-to-end over all batches in a specified directory.

The pipeline is built from modular functions and features dynamic batch discovery, progress bars, and clean logging.

In [35]:


# %%
# Imports & Configuration
import json
import logging
from pathlib import Path
from copy import deepcopy

import cv2
import numpy as np
import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
from torchvision import transforms
from PIL import Image
import pandas as pd
from typing import Dict, Any
from tqdm.auto import tqdm

# Base directory: the parent of the current working directory. Judging by the
# logged paths this is the repository root when the notebook is started from a
# sub-folder — adjust if your layout differs.
BASE_DIR = Path().resolve().parent
DATA_DIR = BASE_DIR.joinpath("data", "rotation", "batches")
_PIPELINE_DIR = BASE_DIR / "pipeline"
CHECKPOINT_PATH = _PIPELINE_DIR.joinpath("checkpoints", "best_model.pth")
DEBUG_IMAGES_DIR = _PIPELINE_DIR / "debug_imgs"
RESULTS_CSV = _PIPELINE_DIR / "results.csv"

# Orientation classes in degrees; model output index i maps to CLASS_NAMES[i].
CLASS_NAMES = [0, 90, 180, 270]
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Side length (pixels) every patch is resized to before classification.
IMAGE_SIZE = 300

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("rotation_pipeline")

# Preprocessing: square bicubic resize, tensor conversion, then per-channel
# normalisation with the standard ImageNet statistics.
# NOTE(review): presumably identical to the training-time transform — confirm.
from torchvision.transforms import InterpolationMode

_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
TRANSFORM = transforms.Compose(
    [
        transforms.Resize(
            (IMAGE_SIZE, IMAGE_SIZE),
            interpolation=InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# %%
# I/O Functions
def load_coco(path: Path) -> dict:
    """Parse the UTF-8 COCO annotation JSON file at *path* and return it as a dict."""
    logger.info(f"Loading COCO JSON: {path}")
    with path.open(encoding="utf-8") as fh:
        return json.load(fh)

def save_coco(coco: dict, path: Path):
    """Write *coco* to *path* as pretty-printed UTF-8 JSON.

    Parent directories are created as needed. The JSON is first written to a
    sibling ``<name>.tmp`` file and then moved over *path* with
    ``Path.replace``, so an interrupted write leaves the previous file (if any)
    intact instead of a truncated one.

    Args:
        coco: JSON-serialisable COCO dictionary.
        path: Destination file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Saving updated COCO to {path}")
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(json.dumps(coco, ensure_ascii=False, indent=2), encoding="utf-8")
    # Rename over the target only after the full payload is on disk.
    tmp_path.replace(path)

# %%
# Model Loading
def load_model(ckpt_path: Path) -> nn.Module:
    """Build the ResNet18 orientation classifier and load weights from *ckpt_path*.

    The final fully-connected layer is replaced by one with ``len(CLASS_NAMES)``
    outputs. The checkpoint may be a training checkpoint with a
    ``"model_state_dict"`` entry, or a bare state dict.

    Returns:
        The model on ``DEVICE`` in eval mode.
    """
    logger.info(f"Loading model from {ckpt_path}")
    # load_state_dict below is strict and overwrites every parameter and buffer,
    # so pretrained ImageNet weights would only cost a (possibly network)
    # download. Build the bare architecture instead.
    model = resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, len(CLASS_NAMES))
    ckpt = torch.load(str(ckpt_path), map_location=DEVICE)
    if isinstance(ckpt, dict):
        state_dict = ckpt.get("model_state_dict", ckpt)
    else:
        state_dict = ckpt
    model.load_state_dict(state_dict)
    return model.to(DEVICE).eval()

# %%
# OBB Utilities
def create_obb(ann: Dict[str, Any]):
    """Rewrite ``ann["bbox"]`` in place from COCO ``[x, y, w, h]`` to ``[cx, cy, w, h, angle]``.

    ``(x, y)`` is the box's top-left corner and ``(cx, cy)`` its centre. The
    angle is read from ``ann["attributes"]["rotation"`` and defaults to
    ``0.0`` when the annotation has no attributes or no rotation entry.
    Returns ``None``.
    """
    left, top, width, height = ann["bbox"]
    center_x = left + width / 2
    center_y = top + height / 2
    attrs = ann.get("attributes", {})
    rotation = attrs.get("rotation", 0.0)
    ann["bbox"] = [center_x, center_y, width, height, rotation]

# %%
# Image & Rotation Helpers
def crop_obb_trim(img: np.ndarray, cx, cy, w, h, angle, pad=0) -> np.ndarray:
    """Crop the oriented box ``(cx, cy, w, h, angle)`` out of *img*.

    The four box corners are rotated by *angle* degrees about the centre. The
    axis-aligned region enclosing them is grown by *pad* pixels, clipped to the
    image and sliced out. Pixels outside the rotated polygon are then zeroed,
    and the result is trimmed to the polygon's extent.

    Args:
        img: Image array indexed ``[row (y), column (x), ...]``, e.g. as
            returned by ``cv2.imread``.
        cx, cy: Box centre in pixel coordinates.
        w, h: Box width and height before rotation.
        angle: Rotation in degrees. In image coordinates (y pointing down), a
            positive angle turns the box clockwise on screen.
        pad: Extra pixels added around the enclosing region before clipping.

    Returns:
        The masked, trimmed patch. If the box does not overlap the image, or
        its polygon is empty, an empty array is returned. That array is 0 x 0
        and keeps *img*'s dtype and trailing channel dims. This replaces the
        previous behaviour of raising or slicing an unrelated region.
    """
    theta = np.deg2rad(angle)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rot = np.array([[cos_t, -sin_t], [sin_t, cos_t]], dtype=np.float32)
    half_extents = np.array(
        [[-w / 2, -h / 2], [w / 2, -h / 2], [w / 2, h / 2], [-w / 2, h / 2]],
        dtype=np.float32,
    )
    corners = half_extents @ rot.T + np.array([cx, cy])
    xs, ys = corners[:, 0], corners[:, 1]

    img_h, img_w = img.shape[:2]
    x0 = max(int(np.floor(xs.min())) - pad, 0)
    x1 = min(int(np.ceil(xs.max())) + pad, img_w - 1)
    y0 = max(int(np.floor(ys.min())) - pad, 0)
    y1 = min(int(np.ceil(ys.max())) + pad, img_h - 1)

    empty = img[:0, :0]
    # A box entirely outside the image gives x1 < x0 (or y1 < y0). Without this
    # guard a negative x1/y1 wraps around in the slice below and selects an
    # unrelated part of the image.
    if x1 < x0 or y1 < y0:
        return empty

    roi = img[y0:y1 + 1, x0:x1 + 1]
    polygon = np.round(corners - [x0, y0]).astype(np.int32)
    mask = cv2.fillPoly(np.zeros(roi.shape[:2], np.uint8), [polygon], 255)
    masked = cv2.bitwise_and(roi, roi, mask=mask)

    ys_nz, xs_nz = np.nonzero(mask)
    if ys_nz.size == 0:
        # Polygon did not cover any ROI pixel: nothing to crop.
        return empty
    return masked[ys_nz.min():ys_nz.max() + 1, xs_nz.min():xs_nz.max() + 1]

def predict_rotation(model: nn.Module, patch: np.ndarray) -> float:
    """Classify the orientation of a BGR image *patch* with *model*.

    The patch is converted to RGB, preprocessed with ``TRANSFORM`` and run as a
    batch of one on ``DEVICE`` without gradient tracking. Returns the
    ``CLASS_NAMES`` entry of the highest-scoring class, as a float
    (e.g. ``90.0``).
    """
    pil_image = Image.fromarray(cv2.cvtColor(patch, cv2.COLOR_BGR2RGB))
    batch = TRANSFORM(pil_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        scores = model(batch)
    best_index = scores.argmax(dim=1).item()
    return float(CLASS_NAMES[best_index])

# %%
# Batch Processing
def _read_image_cached(path: Path, cache: dict, max_cached: int):
    """Return the image at *path*, reading through a small least-recently-used *cache*.

    Logs an error and returns ``None`` when the file is missing or
    ``cv2.imread`` cannot decode it. Failed reads are not cached.
    """
    if path in cache:
        # Re-insert so the entry moves to the end (most recently used).
        cache[path] = cache.pop(path)
        return cache[path]
    if not path.exists():
        logger.error(f"Missing file {path}")
        return None
    img = cv2.imread(str(path))
    if img is None:
        logger.error(f"Failed to read {path}")
        return None
    cache[path] = img
    # Dicts preserve insertion order, so the first key is the least recently used.
    while cache and len(cache) > max_cached:
        del cache[next(iter(cache))]
    return img


def process_batch(batch_dir: Path, model: nn.Module, debug: bool=False,
                  *, max_cached_images: int = 16) -> pd.DataFrame:
    """Predict the rotation class of every usable annotation in one batch.

    Expects ``annotations/instances_default.json`` and ``images/default/``
    under *batch_dir*. Each annotation's bbox is converted with ``create_obb``,
    cropped with ``crop_obb_trim`` and classified with ``predict_rotation``.

    An annotation is skipped in any of these cases:
    - its rotation is not one of ``CLASS_NAMES`` (the count is logged);
    - its ``image_id`` is unknown;
    - its image file is missing or unreadable;
    - its crop is empty.

    Args:
        batch_dir: Batch directory.
        model: Classifier passed to ``predict_rotation``.
        debug: If true, each patch is written to ``DEBUG_IMAGES_DIR``.
        max_cached_images: Maximum number of decoded images kept in memory at
            once. Least recently used images are evicted first.

    Returns:
        DataFrame with columns ``id``, ``orig``, ``pred`` and ``file``. It has
        one row per classified annotation, in annotation order. The columns
        are present even when nothing was classified.
    """
    img_dir = batch_dir / "images" / "default"
    coco_default = load_coco(batch_dir / "annotations" / "instances_default.json")
    # Convert on a deep copy so coco_default keeps its original xywh boxes.
    coco_obb = deepcopy(coco_default)
    for ann in coco_obb["annotations"]:
        create_obb(ann)

    # Build the id -> image lookup once; a linear scan per annotation would be
    # O(annotations x images).
    images_by_id = {img["id"]: img for img in coco_default["images"]}
    if debug:
        DEBUG_IMAGES_DIR.mkdir(parents=True, exist_ok=True)

    records = []
    cache = {}
    skipped_rotation = 0
    for ann in tqdm(coco_obb["annotations"], desc=batch_dir.name):
        cx, cy, w, h, orig = ann["bbox"]
        # Only annotations whose rotation is one of the model's classes are scored.
        if orig not in CLASS_NAMES:
            skipped_rotation += 1
            continue
        img_info = images_by_id.get(ann["image_id"])
        if img_info is None:
            logger.error(f"Annotation {ann['id']} references unknown image_id {ann['image_id']}")
            continue
        path = img_dir / img_info["file_name"]
        img = _read_image_cached(path, cache, max_cached_images)
        if img is None:
            continue

        patch = crop_obb_trim(img, cx, cy, w, h, orig)
        if patch.size == 0:
            logger.warning(f"Empty crop for annotation {ann['id']} in {path.name}; skipping")
            continue
        pred = predict_rotation(model, patch)
        records.append({"id": ann["id"], "orig": orig, "pred": pred, "file": path.name})
        if debug:
            cv2.imwrite(str(DEBUG_IMAGES_DIR / f"{batch_dir.name}_{ann['id']}.png"), patch)

    if skipped_rotation:
        logger.info(
            f"{batch_dir.name}: skipped {skipped_rotation} annotations "
            f"with rotation not in {CLASS_NAMES}"
        )
    # Explicit columns keep the schema stable even when no annotation was scored.
    return pd.DataFrame(records, columns=["id", "orig", "pred", "file"])

# %%
# Run All Batches

# Load the classifier once, then score every batch directory under DATA_DIR
# whose name contains "rotation". Batches are processed in sorted name order.
torch_model = load_model(CHECKPOINT_PATH)
all_results = []
for batch in sorted(DATA_DIR.iterdir()):
    if not batch.is_dir() or "rotation" not in batch.name:
        continue
    df = process_batch(batch, torch_model, debug=False)
    if df.empty:
        # Nothing was classified; skip it so an all-empty run doesn't fail in
        # the groupby below on missing columns.
        logger.warning(f"No predictions for batch {batch.name}; skipping.")
        continue
    df["batch"] = batch.name
    all_results.append(df)

if all_results:
    df_all = pd.concat(all_results, ignore_index=True)
    # Confusion-style table: rows are (batch, annotated rotation), columns are
    # predicted rotation, and cells are counts.
    # NOTE(review): `display` is provided by IPython/Jupyter; it is undefined
    # when this runs as a plain Python script.
    display(df_all.groupby(["batch", "orig"])["pred"]
            .value_counts().unstack(fill_value=0))
    RESULTS_CSV.parent.mkdir(parents=True, exist_ok=True)
    df_all.to_csv(RESULTS_CSV, index=False)
else:
    logger.warning("No batches processed.")


INFO:rotation_pipeline:Loading model from /Users/gerhardkarbeutz/cerpro/ocr-rec-lab/pipeline/checkpoints/best_model.pth


INFO:rotation_pipeline:Loading COCO JSON: /Users/gerhardkarbeutz/cerpro/ocr-rec-lab/data/rotation/batches/rotation_20250721_01/annotations/instances_default.json


rotation_20250721_01:   0%|          | 0/21163 [00:00<?, ?it/s]

Unnamed: 0_level_0,pred,0.0,90.0,180.0,270.0
batch,orig,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
rotation_20250721_01,0.0,13062,37,165,1180
rotation_20250721_01,90.0,69,0,9,26
rotation_20250721_01,180.0,38,0,2,4
rotation_20250721_01,270.0,890,1,46,452
