In [None]:
import json
import logging
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
from torchvision import transforms

# Configure logging once for the whole notebook. INFO level surfaces the
# progress messages emitted by the COCO/model loaders and the update loop.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# CONFIGURATION

# Paths (customize these)
# COCO annotation file whose boxes are going to be classified
COCO_JSON        = Path("../data/rotation/instances_unknown.json")
# Directory against which each image's `file_name` is resolved
IMAGES_DIR       = Path("../data/rotation/images/default")
# Destination of the updated COCO file
OUTPUT_JSON      = Path("../data/rotation/instances_predicted.json")
# Checkpoint holding the trained classifier weights
CHECKPOINT_PATH  = Path("checkpoints/best_model.pth")

# Class labels in the exact order your model was trained on.
# The classifier's argmax index is mapped to a rotation through this list, and
# its length sets the size of the model's final layer.
CLASS_NAMES = [0, 90, 180, 270]

In [10]:


# Inference device: use CUDA when it is available, otherwise stay on the CPU
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Classifier input pipeline: resize to a fixed 224x224, convert to a tensor,
# then normalize each channel with the given mean / std.
TRANSFORM = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


In [11]:


# COCO I/O

def load_coco(json_path: Path) -> dict:
    """
    Read a COCO annotation file and return its parsed JSON content.

    Args:
        json_path: Location of the UTF-8 encoded JSON file.

    Returns:
        The decoded JSON document.
    """
    logger.info(f"Loading COCO annotations from {json_path}")
    with open(json_path, mode='r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

    
def save_coco(coco: dict, out_path: Path) -> None:
    """
    Serialize a COCO dictionary to disk as indented UTF-8 JSON.

    Missing parent directories of `out_path` are created first.

    Args:
        coco: The COCO structure to write.
        out_path: Destination file; an existing file is overwritten.
    """
    target_dir = out_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"Writing updated COCO to {out_path}")
    with open(out_path, mode='w', encoding='utf-8') as sink:
        # ensure_ascii=False keeps non-ASCII characters readable in the file
        json.dump(coco, sink, ensure_ascii=False, indent=2)

In [12]:



# MODEL LOADING

def load_model(checkpoint_path: Path) -> torch.nn.Module:
    """
    Build the ResNet-18 rotation classifier and restore its trained weights.

    The final fully connected layer is replaced so that it emits one logit
    per entry of CLASS_NAMES, then the state dict stored under the
    'model_state_dict' key of the checkpoint is loaded.

    Args:
        checkpoint_path: Path to a checkpoint saved with torch.save that
            contains a 'model_state_dict' entry.

    Returns:
        The model, moved to DEVICE and switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no 'model_state_dict' entry.
        RuntimeError: If the stored state dict does not match the architecture.
    """
    logger.info(f"Loading model from {checkpoint_path}")
    # weights=None: the strict load_state_dict below replaces every parameter
    # and buffer, so fetching the ImageNet weights first would be wasted work
    # and would make loading depend on network access / the download cache.
    model = resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, len(CLASS_NAMES))
    # NOTE(review): torch.load relies on pickle; only open checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(str(checkpoint_path), map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(DEVICE).eval()
    return model


In [13]:


# CROPPING & INFERENCE

def crop_box(img: np.ndarray, x: float, y: float, w: float, h: float) -> np.ndarray:
    """
    Crop an axis-aligned box [x, y, w, h] from img.

    Box edges are rounded to the nearest pixel and clamped to the image, so a
    box that only partially overlaps the image yields the overlapping part.
    A box that lies completely outside the image (or has no positive extent)
    yields an empty array.

    Args:
        img: Image array indexed as [row, column, ...].
        x: Left edge of the box.
        y: Top edge of the box.
        w: Box width.
        h: Box height.

    Returns:
        A view into `img` covering the clamped box (possibly empty).
    """
    img_h, img_w = img.shape[:2]

    # Clamp the near edges into [0, size].
    x1 = min(max(int(round(x)), 0), img_w)
    y1 = min(max(int(round(y)), 0), img_h)
    # Clamp the far edges into [near edge, size]. Without the lower bound a
    # box left of / above the image produces a negative stop index, which
    # NumPy interprets as "count from the end" and returns unrelated pixels.
    x2 = min(max(int(round(x + w)), x1), img_w)
    y2 = min(max(int(round(y + h)), y1), img_h)

    return img[y1:y2, x1:x2]


def predict_angle(model: torch.nn.Module, patch: np.ndarray) -> int:
    """
    Run the classifier on the BGR-uint8 patch and return one of CLASS_NAMES.

    Args:
        model: Classifier producing one logit per entry of CLASS_NAMES.
        patch: Non-empty H x W x 3 image patch in OpenCV's BGR channel order.

    Returns:
        The CLASS_NAMES entry whose logit is highest.
    """
    # OpenCV uses BGR channel order; reorder to RGB before preprocessing
    patch_rgb = cv2.cvtColor(patch, cv2.COLOR_BGR2RGB)
    # TRANSFORM starts with transforms.Resize, which accepts PIL images or
    # tensors but not NumPy arrays, so wrap the array in a PIL image first.
    patch_pil = transforms.ToPILImage()(patch_rgb)
    # apply transforms and add the batch dimension
    tensor = TRANSFORM(patch_pil).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(tensor)
        idx = int(logits.argmax(dim=1))
    return CLASS_NAMES[idx]

In [14]:


#  UPDATE LOOP

def update_rotations():
    """
    Predict a rotation for every annotation and write the updated COCO file.

    Loads the annotations from COCO_JSON and the classifier from
    CHECKPOINT_PATH. For each annotation with a 4-element bbox [x, y, w, h],
    the box is cropped from its image, classified, and the annotation is
    rewritten in place as [cx, cy, w, h, angle]; the angle is also stored in
    attributes.rotation. The result is saved to OUTPUT_JSON.

    Annotations are skipped (and left unchanged) when the bbox does not have
    four values, the image metadata is missing, the image file cannot be
    read, or the box does not overlap the image.
    """
    coco = load_coco(COCO_JSON)
    model = load_model(CHECKPOINT_PATH)

    # build lookup for images
    images = {img['id']: img for img in coco.get('images', [])}
    # Cache of decoded images keyed by file name. A value of None records a
    # file that failed to load, so it is neither re-read nor re-reported for
    # every further annotation on the same image.
    img_cache: dict = {}

    for ann in coco.get('annotations', []):
        ann_id = ann.get('id')
        # `or []` also covers an explicit null bbox in the JSON
        bbox = ann.get('bbox') or []
        if len(bbox) != 4:
            logger.warning(f"Skipping annotation {ann_id} with unexpected bbox {bbox}")
            continue

        x, y, w, h = bbox
        img_id = ann['image_id']
        img_info = images.get(img_id)
        if img_info is None:
            logger.warning(f"No image metadata for id {img_id}")
            continue

        fname = img_info['file_name']
        if fname not in img_cache:
            img_path = IMAGES_DIR / fname
            loaded = cv2.imread(str(img_path))
            if loaded is None:
                logger.error(f"Failed to load image {img_path}")
            img_cache[fname] = loaded

        image = img_cache[fname]
        if image is None:
            continue

        # crop & predict
        patch = crop_box(image, x, y, w, h)
        if patch.size == 0:
            # An empty crop cannot be color-converted or resized
            logger.warning(f"Skipping annotation {ann_id}: bbox {bbox} yields an empty crop")
            continue
        pred_angle = predict_angle(model, patch)
        logger.debug(f"Ann {ann_id}: predicted {pred_angle}°")

        # update to OBB [cx, cy, w, h, angle]
        cx = x + w/2
        cy = y + h/2
        ann['bbox'] = [cx, cy, w, h, float(pred_angle)]
        # also store back into attributes.rotation
        ann.setdefault('attributes', {})['rotation'] = float(pred_angle)

    save_coco(coco, OUTPUT_JSON)
    logger.info("All annotations updated with predicted rotations.")


# Entry point. The call to update_rotations() is commented out, so running
# this cell only prints a reminder. Uncomment the call to actually run
# inference and write OUTPUT_JSON.
# NOTE(review): presumably disabled to avoid an accidental run; confirm.
if __name__ == "__main__":
    #update_rotations()
    print("UNCOMMENT")

UNCOMMENT
