# In [None]:
# %% [markdown]
# # Rotation Prediction Pipeline
# This notebook implements a production-ready pipeline that processes multiple batches of images stored in folders. For each batch it converts the COCO annotations into oriented bounding boxes (OBB), runs a rotation prediction model on every annotated object whose rotation is within 3 degrees of 0, 90, 180 or 270, and saves the updated annotations together with a CSV summary.
# The flow mirrors the original script; the core logic is unchanged. Every subdirectory of `BATCHES_DIR` whose name starts with `rotation_` is processed automatically.

# %%
# ---------------------------
# IMPORTS & CONFIGURATION
# ---------------------------
import json
import logging
from collections import Counter
from copy import deepcopy
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights

# Setup logging
default_level = logging.INFO
logging.basicConfig(level=default_level)
logger = logging.getLogger(__name__)

# Global configuration
# NOTE(review): the relative paths below resolve against the current working
# directory, so the notebook has to be started from the directory they assume.
# Parent of the batch folders; the main loop only processes subdirectories
# whose names start with "rotation_".
BATCHES_DIR = Path("../data/rotation/batches/")
CHECKPOINT_PATH = Path("checkpoints/best_model.pth")
# Cropped patches are written to DEBUG_ROOT / <batch folder name> / for inspection.
DEBUG_ROOT = Path("../pipeline/debug_imgs/")
# Rotation angle in degrees for each model output index, in output order.
# The non-sorted order presumably mirrors the lexicographic sort of the training
# class folders ("0", "180", "270", "90") -- TODO confirm against the training
# code; reordering this list silently changes every prediction.
CLASS_NAMES = [0, 180, 270, 90]
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Side length (pixels) of the square model input.
IMAGE_SIZE = 300

# Transform for model input: bicubic resize to IMAGE_SIZE x IMAGE_SIZE, convert
# to a tensor, then normalise with the ImageNet channel mean/std.
TRANSFORM = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# %%
# ---------------------------
# I/O UTILITIES
# ---------------------------
def load_coco(path: Path) -> dict:
    """Read a COCO annotation file (UTF-8 JSON) and return the parsed dict."""
    logger.info(f"Loading COCO from {path}")
    with path.open(encoding="utf-8") as fh:
        return json.load(fh)

def save_coco(coco: dict, path: Path):
    """Write ``coco`` to ``path`` as indented UTF-8 JSON, creating parent folders.

    Non-ASCII characters are written as-is rather than escaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Writing COCO to {path}")
    serialized = json.dumps(coco, ensure_ascii=False, indent=2)
    with path.open("w", encoding="utf-8") as fh:
        fh.write(serialized)

# %%
# ---------------------------
# MODEL LOADING
# ---------------------------
def load_model(ckpt_path: Path) -> nn.Module:
    """Build the ResNet-18 rotation classifier and load trained weights into it.

    The final fully-connected layer is replaced by one with ``len(CLASS_NAMES)``
    outputs. The checkpoint may be a dict holding the weights under
    ``"model_state_dict"`` or a bare state dict. The model is moved to
    ``DEVICE`` and switched to eval mode.

    Args:
        ckpt_path: Path of the checkpoint file.

    Returns:
        The ready-to-use model.
    """
    logger.info(f"Loading model from {ckpt_path}")
    # weights=None: load_state_dict below is strict, so every parameter and
    # buffer is overwritten by the checkpoint. Fetching ImageNet weights first
    # is wasted work and fails on machines without network access.
    model = resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, len(CLASS_NAMES))
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    ckpt = torch.load(str(ckpt_path), map_location=DEVICE)
    state_dict = ckpt.get("model_state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
    model.load_state_dict(state_dict)
    model.to(DEVICE).eval()
    return model

# %%
# ---------------------------
# UTILITY FUNCTIONS
# ---------------------------
def create_obb_tuple(anns: dict):
    """Rewrite one annotation's bbox into OBB form, in place.

    A 4-element list ``[x, y, w, h]`` becomes ``[cx, cy, w, h, angle]``, where
    ``(cx, cy) = (x + w/2, y + h/2)``. ``angle`` is taken from
    ``attributes["rotation"]`` and defaults to 0.0. Any other bbox value is left
    untouched and a warning is logged. Always returns None.
    """
    bbox = anns.get("bbox")
    if not (isinstance(bbox, list) and len(bbox) == 4):
        logger.warning(f"Unexpected bbox for ann {anns.get('id')}: {bbox}")
        return
    left, top, width, height = bbox
    rotation = anns.get("attributes", {}).get("rotation", 0.0)
    anns["bbox"] = [left + width / 2, top + height / 2, width, height, rotation]


def replace_obb(coco: dict):
    """Convert every annotation's bbox in ``coco`` to OBB form (in place).

    Returns the same ``coco`` object for convenient chaining.
    """
    annotation_list = coco.get("annotations", [])
    for annotation in annotation_list:
        create_obb_tuple(annotation)
    return coco


def extract_rotation(ann: dict) -> float:
    """Return an annotation's rotation angle as a float.

    Uses the fifth bbox element when the bbox has OBB form
    ``[cx, cy, w, h, angle]``. Otherwise it falls back to
    ``attributes["rotation"]``, and finally to 0.0.
    """
    bbox = ann.get("bbox", [])
    if len(bbox) == 5:
        return float(bbox[4])
    attributes = ann.get("attributes", {})
    return float(attributes.get("rotation", 0.0))


def crop_obb_exact_mask_trim(img: np.ndarray, cx: float, cy: float, w: float, h: float, angle: float, pad: int = 0) -> np.ndarray:
    """Crop the pixels covered by an oriented bounding box.

    The four box corners are built in the box's local frame, rotated by
    ``angle`` degrees and translated to ``(cx, cy)``. The axis-aligned rectangle
    enclosing that polygon (grown by ``pad`` pixels and clipped to the image) is
    copied. Pixels outside the polygon are masked out with ``cv2.bitwise_and``,
    and the result is trimmed to the filled extent of the polygon mask.

    Args:
        img: Image array indexed as ``img[y, x]``.
        cx, cy: Box centre in pixel coordinates.
        w, h: Box width and height in the box's local frame.
        angle: Box rotation in degrees.
        pad: Extra pixels around the polygon's bounding rectangle before masking.

    Returns:
        The masked, trimmed patch.

    Raises:
        ValueError: if the box does not overlap the image, so there is nothing
            to crop. Previously this surfaced as an opaque numpy/cv2 error on an
            empty array.
    """
    theta = np.deg2rad(angle)
    ct, st = np.cos(theta), np.sin(theta)
    local = np.float32([[-w / 2, -h / 2], [w / 2, -h / 2], [w / 2, h / 2], [-w / 2, h / 2]])
    poly = (local @ np.float32([[ct, -st], [st, ct]]).T) + [cx, cy]
    xs, ys = poly[:, 0], poly[:, 1]

    # Bounding rectangle of the polygon, padded and clipped to the image.
    x0, x1 = int(np.floor(xs.min())) - pad, int(np.ceil(xs.max())) + pad
    y0, y1 = int(np.floor(ys.min())) - pad, int(np.ceil(ys.max())) + pad
    x0, y0 = max(0, x0), max(0, y0)
    x1, y1 = min(img.shape[1] - 1, x1), min(img.shape[0] - 1, y1)
    if x1 < x0 or y1 < y0:
        raise ValueError(
            f"OBB (cx={cx}, cy={cy}, w={w}, h={h}, angle={angle}) lies outside "
            f"the {img.shape[1]}x{img.shape[0]} image"
        )

    roi = img[y0:y1 + 1, x0:x1 + 1].copy()
    poly_roi = (poly - [x0, y0]).astype(np.int32)
    mask = np.zeros(roi.shape[:2], np.uint8)
    cv2.fillPoly(mask, [poly_roi], 255)
    patch = cv2.bitwise_and(roi, roi, mask=mask)

    ys_nz, xs_nz = np.nonzero(mask)
    if ys_nz.size == 0:
        # The rectangle overlaps the image but the polygon itself does not.
        raise ValueError(
            f"OBB (cx={cx}, cy={cy}, w={w}, h={h}, angle={angle}) covers no image pixels"
        )
    return patch[ys_nz.min():ys_nz.max() + 1, xs_nz.min():xs_nz.max() + 1]


def predict_angle(model: nn.Module, patch: np.ndarray) -> float:
    """Classify a BGR image patch and return the predicted class angle.

    The patch is converted to RGB, passed through ``TRANSFORM`` as a batch of
    one, and the arg-max output index is mapped through ``CLASS_NAMES``.
    """
    pil_image = Image.fromarray(cv2.cvtColor(patch, cv2.COLOR_BGR2RGB))
    batch = TRANSFORM(pil_image)[None].to(DEVICE)
    with torch.no_grad():
        logits = model(batch)
        best_index = int(logits.argmax(dim=1)[0])
    return float(CLASS_NAMES[best_index])


def is_valid_rotation(angle: float, tol: int = 3, classes=None) -> bool:
    """Return True if ``angle`` is within ``tol`` degrees of a class angle.

    Distance is measured around the circle (modulo 360). So 358 and -2 count as
    2 degrees from 0, and 361 as 1 degree from 0. The previous plain
    ``abs(angle - base)`` rejected such wrap-around angles.

    Args:
        angle: Angle to test, in degrees.
        tol: Maximum allowed angular distance, in degrees.
        classes: Candidate angles; defaults to ``CLASS_NAMES``.
    """
    bases = CLASS_NAMES if classes is None else classes
    for base in bases:
        diff = abs(angle - base) % 360
        if min(diff, 360 - diff) <= tol:
            return True
    return False

# %%
# ---------------------------
# PIPELINE: PROCESS ONE BATCH
# ---------------------------
def _read_image_cached(image_id, path: Path, cache: dict):
    """Return the decoded image for ``image_id``, reading ``path`` on first use.

    Returns None (after logging an error) when the file is missing or OpenCV
    cannot decode it. Failures are cached too, so each bad file is reported once.
    """
    if image_id in cache:
        return cache[image_id]
    if not path.exists():
        logger.error(f"Missing image {path}")
        img = None
    else:
        img = cv2.imread(str(path))  # cv2.imread returns None instead of raising
        if img is None:
            logger.error(f"Unreadable image {path}")
    cache[image_id] = img
    return img


def _predict_annotation(ann: dict, ann_pred: dict, images: dict, imgs_dir: Path,
                        debug_dir: Path, cache: dict, model: nn.Module):
    """Predict the rotation of one OBB annotation and update ``ann_pred`` in place.

    Returns the ``{"id", "orig", "pred"}`` record for results.csv, or None when
    the annotation is skipped.
    """
    bbox = ann.get("bbox")
    if not isinstance(bbox, list) or len(bbox) != 5:
        # Not in [cx, cy, w, h, angle] form: the OBB conversion could not handle it.
        logger.warning(f"Skipping ann {ann.get('id')}: bbox is not an OBB: {bbox}")
        return None
    cx, cy, w, h, orig = bbox
    if not is_valid_rotation(orig):
        return None

    image_id = ann.get("image_id")
    image_info = images.get(image_id)
    if image_info is None:
        logger.error(f"Skipping ann {ann.get('id')}: unknown image_id {image_id}")
        return None
    fname = image_info["file_name"]
    img = _read_image_cached(image_id, imgs_dir / fname, cache)
    if img is None:
        return None

    try:
        patch = crop_obb_exact_mask_trim(img, cx, cy, w, h, orig)
    except (ValueError, cv2.error) as err:
        # e.g. a box outside the image: skip this object instead of the whole batch.
        logger.warning(f"Skipping ann {ann.get('id')}: cannot crop OBB ({err})")
        return None

    out_file = debug_dir / f"{Path(fname).stem}_{ann['id']}.png"
    cv2.imwrite(str(out_file), patch)
    pred = predict_angle(model, patch)
    if pred != orig:
        # NOTE(review): only the angle is replaced; w and h are kept. When pred
        # and orig differ by ~90/270 degrees the box footprint rotates as well --
        # confirm whether w and h should be swapped in that case.
        ann_pred["bbox"][4] = pred
        ann_pred.setdefault("attributes", {})["rotation"] = pred
    return {"id": ann["id"], "orig": orig, "pred": pred}


def process_batch(batch_dir: Path, model: nn.Module):
    """Convert one batch's boxes to OBBs and re-predict each object's rotation.

    Reads ``annotations/instances_default.json`` under ``batch_dir`` and writes:
      * ``annotations/instances_obbs.json``: the input with every convertible
        bbox rewritten as ``[cx, cy, w, h, angle]``;
      * ``annotations/instances_predicted.json``: the same, with the angle (and
        ``attributes["rotation"]``) replaced by the prediction where it differs;
      * ``results.csv``: one row (``id``, ``orig``, ``pred``) per processed object.
    Each cropped patch is also saved as a PNG under ``DEBUG_ROOT / batch_dir.name``.

    Only annotations whose rotation passes ``is_valid_rotation`` are predicted.
    The following are logged and skipped rather than aborting the batch:
    annotations with a non-OBB bbox, an unknown image id, a missing or
    unreadable image file, or a box that cannot be cropped.
    """
    # Paths
    coco_in = batch_dir / "annotations" / "instances_default.json"
    ococo = batch_dir / "annotations" / "instances_obbs.json"
    pred_out = batch_dir / "annotations" / "instances_predicted.json"
    imgs_dir = batch_dir / "images" / "default"
    debug_dir = DEBUG_ROOT / batch_dir.name
    debug_dir.mkdir(parents=True, exist_ok=True)

    # Load & convert to OBB
    coco = load_coco(coco_in)
    coco = replace_obb(coco)
    save_coco(coco, ococo)

    # Predictions go into a deep copy so the OBB annotations stay unmodified.
    coco_pred = deepcopy(coco)
    images = {img["id"]: img for img in coco["images"]}
    annotations = coco.get("annotations", [])
    # A decoded image stays cached only while some later annotation still refers
    # to it, so memory no longer grows with the number of images in the batch.
    remaining = Counter(ann.get("image_id") for ann in annotations)
    cache = {}
    records = []

    for ann, ann_pred in zip(annotations, coco_pred.get("annotations", [])):
        record = _predict_annotation(ann, ann_pred, images, imgs_dir, debug_dir, cache, model)
        if record is not None:
            records.append(record)
        image_id = ann.get("image_id")
        remaining[image_id] -= 1
        if remaining[image_id] == 0:
            cache.pop(image_id, None)

    # Save predictions & summary
    save_coco(coco_pred, pred_out)
    # Explicit columns keep a header row even when no object was processed.
    df = pd.DataFrame(records, columns=["id", "orig", "pred"])
    df.to_csv(batch_dir / "results.csv", index=False)
    logger.info(f"Finished batch {batch_dir.name}, processed {len(records)} objects.")

# %%
# ---------------------------
# MAIN: LOOP ALL BATCHES
# ---------------------------
if __name__ == "__main__":
    # Load the model once and reuse it for every batch.
    model = load_model(CHECKPOINT_PATH)
    # Sorted so batches run in a deterministic order (iterdir() order is arbitrary).
    batch_dirs = sorted(
        p for p in BATCHES_DIR.iterdir() if p.is_dir() and p.name.startswith("rotation_")
    )
    failed = []
    for batch in batch_dirs:
        logger.info(f"=== Processing {batch.name} ===")
        try:
            process_batch(batch, model)
        except Exception:
            # Top-level boundary: one broken batch (e.g. a missing or malformed
            # annotation file) must not prevent the remaining batches from running.
            logger.exception(f"Failed to process batch {batch.name}")
            failed.append(batch.name)
    if failed:
        logger.error(f"{len(failed)} batch(es) failed: {', '.join(failed)}")
    logger.info("All batches processed.")


# Output captured from a previous run:
# INFO:__main__:Loading model from checkpoints/best_model.pth
# INFO:__main__:=== Processing rotation_20250721_01 ===
# INFO:__main__:Loading COCO from ../data/rotation/batches/rotation_20250721_01/annotations/instances_default.json
# INFO:__main__:Writing COCO to ../data/rotation/batches/rotation_20250721_01/annotations/instances_obbs.json
