In [14]:
import json
import logging
from pathlib import Path
from copy import deepcopy

import cv2
import numpy as np
import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
from torchvision import transforms
from PIL import Image

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.transforms import Affine2D
import pandas as pd
from typing import Dict, Any, List
from matplotlib.lines import Line2D
from torchvision.transforms import InterpolationMode

In [15]:



# ---------------------------
# LOGGING & CONFIG
# ---------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# CONFIGURATION
COCO_JSON        = Path("../data/rotation/batches/rotation_20250721_01/annotations/instances_default.json")
IMAGES_DIR       = Path("../data/rotation/batches/rotation_20250721_01/images/default/")
PRED_JSON        = Path("../data/rotation/batches/rotation_20250721_01/annotations/instances_predicted.json")
COCO_5OBB       = Path("../data/rotation/batches/rotation_20250721_01/annotations/instances_updated.json")
CHECKPOINT_PATH = Path("checkpoints/best_model.pth")

BATCHES_DIR = Path("../data/rotation/batches/")

DEBUG_IMAGES_DEST = Path("../pipeline/debug_imgs/")



CLASS_NAMES = [0, 180, 270, 90 ]
DEVICE      = torch.device("cuda" if torch.cuda.is_available() else "cpu")

IMAGE_SIZE = 300

TRANSFORM = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE),
    interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

In [16]:



#instead of hardcoding the batchname: rotation_20250721_01 it should run through all batches dynamically and perform the predictions and in the end have a new file with the new angles# ---------------------------
# I/O
# ---------------------------
def load_coco(path: Path) -> dict:
    logger.info(f"Loading COCO from {path}")
    return json.loads(path.read_text(encoding="utf-8"))

def save_coco(coco: dict, path: Path):
    path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Writing COCO to {path}")
    path.write_text(json.dumps(coco, ensure_ascii=False, indent=2), encoding="utf-8")


In [17]:

# ---------------------------
# MODEL
# ---------------------------
def load_model(ckpt_path: Path) -> nn.Module:
    logger.info(f"Loading model from {ckpt_path}")
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, len(CLASS_NAMES))
    ckpt = torch.load(str(ckpt_path), map_location=DEVICE)
    model.load_state_dict(ckpt["model_state_dict"])
    return model.to(DEVICE).eval()

In [18]:

def create_obb_tuple(anns: Dict[str, Any]) -> None:
    """
    Convert axis-aligned bbox [x, y, w, h] to oriented bbox [cx, cy, w, h, angle].
    Added validation and debug logging in place of print statements.
    """
    bbox = anns.get("bbox")
    if isinstance(bbox, list) and len(bbox) == 4:
        x, y, w, h = bbox
        cx = x + w / 2
        cy = y + h / 2
        angle = anns.get("attributes", {}).get("rotation", 0.0)
        anns["bbox"] = [cx, cy, w, h, angle]
        logger.debug(f"Converted bbox to OBB: {anns['bbox']} for annotation id {anns.get('id')}")
    else:
        logger.warning(f"Unexpected bbox format for annotation id {anns.get('id')}: {bbox}")

In [19]:

def replace_obb(coco: Dict[str, Any], batch_dir: Path) -> None:
    """
    Apply create_obb_tuple to all annotations and write updated JSON efficiently.
    """
    logger.info("Replacing OBBs in annotations")
    for anns in coco.get('annotations', []):
        create_obb_tuple(anns)

    #out_path = batch_dir / "annotations" / "instances_updated.json"
    out_path = COCO_5OBB
    logger.info(f"Writing updated annotations to {out_path}")
    with open(out_path, 'w', encoding='utf-8') as f:
        # Use json.dump for more efficient writing
        json.dump(coco, f, ensure_ascii=False)

In [20]:





def convert_all_batches() -> None:
    """
    Iterate over all batch directories and update OBBs.
    """
    for p in BATCHES_DIR.iterdir():
        if "rotation" in p.name:
            json_path = p / "annotations" / "instances_default.json"
            if json_path.exists():
                logger.info(f"Processing batch: {p.name}")
                coco = load_coco(json_path)
                replace_obb(coco, p)
            else:
                logger.warning(f"Missing JSON at {json_path}")
                


convert_all_batches()# ---------------------------
# HELPERS
# ---------------------------
def extract_current_rotation(ann: dict) -> float:
    bb = ann.get("bbox", [])
    if len(bb) == 5:
        return float(bb[4])
    attrs = ann.get("attributes", {})
    return float(attrs.get("rotation", 0.0))

INFO:__main__:Processing batch: rotation_20250721_01
INFO:__main__:Loading COCO from ../data/rotation/batches/rotation_20250721_01/annotations/instances_default.json
INFO:__main__:Replacing OBBs in annotations
INFO:__main__:Writing updated annotations to ../data/rotation/batches/rotation_20250721_01/annotations/instances_updated.json


In [21]:



def crop_obb_exact_mask_trim(
        img   : np.ndarray,
        cx    : float,
        cy    : float,
        w     : float,
        h     : float,
        angle : float,
        pad   : int = 0
) -> np.ndarray:
    """Same behaviour as crop_obb_exact_mask but returns *no* black border."""
    theta  = np.deg2rad(angle)                      # CCW positive
    ct, st = np.cos(theta), np.sin(theta)

    local  = np.float32([[-w/2, -h/2],
                         [ w/2, -h/2],
                         [ w/2,  h/2],
                         [-w/2,  h/2]])
    R      = np.float32([[ct, -st], [st, ct]])
    poly   = (local @ R.T) + np.float32([cx, cy])

    xs, ys = poly[:, 0], poly[:, 1]
    x0, x1 = int(np.floor(xs.min())) - pad, int(np.ceil(xs.max())) + pad
    y0, y1 = int(np.floor(ys.min())) - pad, int(np.ceil(ys.max())) + pad
    x0 = max(0, x0); y0 = max(0, y0)
    x1 = min(img.shape[1]-1, x1); y1 = min(img.shape[0]-1, y1)

    roi   = img[y0:y1+1, x0:x1+1].copy()
    poly_roi = (poly - np.float32([x0, y0])).astype(np.int32)

    mask  = np.zeros(roi.shape[:2], np.uint8)
    cv2.fillPoly(mask, [poly_roi], 255)
    patch = cv2.bitwise_and(roi, roi, mask=mask)

    # ── NEW: trim away any all-black rows / cols ───────────────────────────
    ys_nonzero, xs_nonzero = np.where(mask > 0)
    patch = patch[ys_nonzero.min():ys_nonzero.max()+1,
                  xs_nonzero.min():xs_nonzero.max()+1]
    # ─────────────────────────────────────────────────────────────────────── works good
    return patch


def predict_angle(model: nn.Module, patch: np.ndarray) -> float:
    rgb = cv2.cvtColor(patch, cv2.COLOR_BGR2RGB)
    pil = Image.fromarray(rgb)
    tensor = TRANSFORM(pil).unsqueeze(0).to(DEVICE)
    #print(f"Tensor {tensor}")
    with torch.no_grad():
        prediction = model(tensor)
        #print(f"logits {prediction}")

        idx = torch.argmax(prediction, dim=1).item()
        #print(f"idx {idx}")

    return float(CLASS_NAMES[idx])


def is_valid_rotation(angle, base_angles=CLASS_NAMES, tolerance=3):
    return any(abs(angle - base) <= tolerance for base in base_angles)

#TODO: Lower Tolerance and check if it works on both directions
def load_image_if_needed(fname, cache):
    if fname in cache:
        return cache[fname]

    img = cv2.imread(str(IMAGES_DIR / fname))
    if img is None:
        logger.error(f"Cannot load {fname}")
        return None

    cache[fname] = img
    return img, 



In [22]:
# ---------------------------
# MAIN UPDATE
# ---------------------------
def update_rotations(target_image_id: int = 1, debug: bool = False):
    coco_default_4obb = load_coco(COCO_JSON)
    coco_updated_5obb = load_coco(COCO_5OBB)
    coco_pred_5obb = deepcopy(coco_updated_5obb)
    model = load_model(CHECKPOINT_PATH)

    images = {img["id"]: img for img in coco_default_4obb["images"]}
    cache = {}
    records = []
    ax = None  # only initialized if debug=True

    for ann, ann_pred in zip(coco_updated_5obb["annotations"], coco_pred_5obb["annotations"]):
        if ann["image_id"] != target_image_id:
            continue

        cx, cy, w, h, orig_rot = ann["bbox"]
        if not is_valid_rotation(orig_rot):
            continue

        fname = images[ann["image_id"]]["file_name"]
        full_img = load_image_if_needed(fname, cache)
        if full_img is None:
            continue

        
        ann_id=ann["id"]
        rotated_box = crop_obb_exact_mask_trim(full_img, cx, cy, w, h, orig_rot,
                       pad=0) 
        
        out_file = DEBUG_IMAGES_DEST / f"{Path(fname).stem}_{ann_id}.png"
        cv2.imwrite(str(out_file), rotated_box)

        
        pred_rot = predict_angle(model, rotated_box)

        records.append({
            "ann_id": ann["id"],
            "img_id": ann["image_id"],
            "file_name": fname,
            "bbox": [cx, cy, w, h, orig_rot],
            "orig_rot": orig_rot,
            "pred_rot": pred_rot
        })
        
        
        
        
        logger.info(f" Original {orig_rot}° → Predicted {pred_rot}° Ann ID {ann_id}")


        # 🟢 Update angle only if it changed (rest of bbox stays the same)
        if pred_rot != orig_rot:


            ann_pred["bbox"] = [cx, cy, w, h, pred_rot]
            ann_pred.setdefault("attributes", {})["rotation"] = pred_rot
        else:
            ann_pred["bbox"] = [cx, cy, w, h, orig_rot]


            

    # 💾 Save updated 5OBB predictions
    save_coco(coco_pred_5obb, PRED_JSON)

    # 🔍 Log summary
    df = pd.DataFrame(records)
    df.to_csv("./results.csv")
    print("\n=== Summary ===")
    print(df.to_string(index=False))




if __name__ == "__main__":
    update_rotations(target_image_id=2, debug=True)
    

INFO:__main__:Loading COCO from ../data/rotation/batches/rotation_20250721_01/annotations/instances_default.json
INFO:__main__:Loading COCO from ../data/rotation/batches/rotation_20250721_01/annotations/instances_updated.json
INFO:__main__:Loading model from checkpoints/best_model.pth


AttributeError: 'tuple' object has no attribute 'shape'