# Dense Object Detection on SKU-110K (RetinaNet & LocalRetinaNet)

This notebook contains a **complete training & evaluation pipeline** for:

- A **torchvision RetinaNet baseline** (`retinanet_resnet50_fpn`)
- A custom **LocalRetinaNet** model with:
  - ResNet-50 + FPN backbone (P3–P5)
  - Multi-level anchors
  - Local correlation-aware focal loss



## 1. Imports

In [24]:
import os
import zipfile
import json
import math
import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as Fnn
from torch.utils.data import DataLoader, Dataset

import torchvision
from torchvision.transforms import functional as F
from torchvision.ops import box_iou

from PIL import Image
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

from google.colab import drive

print("PyTorch:", torch.__version__)
print("Torchvision:", torchvision.__version__)
print("CUDA available:", torch.cuda.is_available())
print("MPS available:", torch.backends.mps.is_available())


ModuleNotFoundError: No module named 'google'

## 2. Download and Mount Pretrained Models

In [None]:
PRETRAINED_DIR = "./output/pt-models"
# Example: https://drive.google.com/file/d/FILE_ID/view?usp=sharing
PRETRAINED_FILE_ID = "1M8TJoZ-P8wiU1KLGlSDpRUggNavzCWia"
ZIP_PATH = "./pt-models.zip"

os.makedirs("./output", exist_ok=True)
os.makedirs(PRETRAINED_DIR, exist_ok=True)

existing_files = [f for f in os.listdir(PRETRAINED_DIR) if f.endswith(".pth")]

if len(existing_files) > 0:
    print("Pretrained models already available:")
    for f in existing_files:
        print(" •", f)
    print("\nSkipping download...")
else:
    drive.mount('/content/drive')

    print("Downloading pretrained models from Google Drive...")
    !gdown --id $PRETRAINED_FILE_ID -O $ZIP_PATH

    print("Extracting ZIP...")
    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        for member in zip_ref.namelist():
            # Skip macOS metadata
            if member.startswith('__MACOSX') or member.endswith('.DS_Store'):
                continue

            # Extract into output/pt-models while flattening structure
            filename = os.path.basename(member)
            if not filename:
                continue  # skip folders

            source = zip_ref.open(member)
            target_path = os.path.join(PRETRAINED_DIR, filename)

            with open(target_path, "wb") as target:
                with source as src:
                    target.write(src.read())

    print("Extraction complete!")

    if os.path.exists(ZIP_PATH):
        os.remove(ZIP_PATH)
        print(f"Removed ZIP file: {ZIP_PATH}")

## 3. Download and Mount Dataset

In [None]:
DATA_DIR = "./data/SKU110K_modified"
ZIP_PATH = "./SKU110K_modified.zip"
# Example: https://drive.google.com/file/d/FILE_ID/view?usp=sharing
FILE_ID = "1QrZ6zTbOSiE28TQkBExb4Fa7EM6i5mfr"

# If folder already exists, skip download
if os.path.exists(DATA_DIR):
    print("Dataset already exists — skipping download.")
else:
    os.makedirs("./data", exist_ok=True)

    from google.colab import drive
    drive.mount('/content/drive')
    
    print("Downloading dataset ZIP...")
    !gdown --id $FILE_ID -O $ZIP_PATH

    print("Extracting ZIP...")

    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        for member in zip_ref.infolist():

            # Skip macOS junk files
            if member.filename.startswith("__MACOSX") or member.filename.endswith(".DS_Store"):
                continue

            zip_ref.extract(member, "./data")

    print("Extraction complete!")

    # Remove ZIP file
    if os.path.exists(ZIP_PATH):
        os.remove(ZIP_PATH)
        print(f"Removed ZIP file: {ZIP_PATH}")

## 4. Utility functions

In [2]:
# ======================================================================================
# UTILS
# ======================================================================================

def timestamp():
    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")


def ensure_dir(path):
    Path(path).mkdir(parents=True, exist_ok=True)


def fix_coco_json(path):
    """
    Add missing fields so pycocotools does not crash on annotation JSON.
    """
    with open(path, "r") as f:
        data = json.load(f)

    # Prediction JSON is a list -> ignore
    if isinstance(data, list):
        return

    changed = False

    if "info" not in data:
        data["info"] = {}
        changed = True

    if "licenses" not in data:
        data["licenses"] = []
        changed = True

    if changed:
        with open(path, "w") as f:
            json.dump(data, f)


def pad_images_to_batch(images):
    """
    Pad list of CHW images to max height and max width within the batch.
    Keeps content in top-left, pads bottom/right with zeros.
    """
    import torch.nn.functional as F_pad

    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)

    padded = []
    for img in images:
        _, h, w = img.shape
        pad_bottom = max_h - h
        pad_right = max_w - w
        img_padded = F_pad.pad(img, (0, pad_right, 0, pad_bottom))
        padded.append(img_padded)

    return torch.stack(padded)

## 5. Configuration

In [None]:
# ======================================================================================
# CONFIGURATION
# ======================================================================================

class Config:
    def __init__(self):
        # Dataset paths
        self.train_images = "./data/SKU110K_modified/images"
        self.val_images = "./data/SKU110K_modified/images"
        self.test_images = "./data/SKU110K_modified/images"

        self.train_annotations = "./data/SKU110K_modified/annotations/COCO_json/annotations_train.json"
        self.val_annotations = "./data/SKU110K_modified/annotations/COCO_json/annotations_val.json"
        self.test_annotations = "./data/SKU110K_modified/annotations/COCO_json/annotations_test.json"

        # Inference
        self.infer_image_path = "./data/SKU110K_modified/images/test_0.jpg"

        # Local RetinaNet (correlation-aware) hyperparameters
        self.num_classes_local = 1        # single "product" class
        self.lambda_reg_local = 0.1       # weight between cls and reg for local model
        self.num_epochs_local = 5
        self.batch_size_local = 1         # memory-safe for LocalRetinaNet

        # Baseline RetinaNet hyperparameters
        self.num_classes_retina = 2       # torchvision RetinaNet classes (background + product)
        self.num_epochs_retina = 5
        self.batch_size_retina = 1

        self.lr = 1e-4
        self.num_workers = 0

        # Save paths
        self.save_local_model_path = "./output/pt-models/retinanet_local_sku110k.pth"
        self.save_retinanet_model_path = "./output/pt-models/retinanet_sku110k.pth"

        self.device = torch.device(
            "mps" if torch.backends.mps.is_available()
            else "cuda" if torch.cuda.is_available()
            else "cpu"
        )


config = Config()
print("Using device:", config.device)


## 6. Dataset & transforms

In [4]:
# ======================================================================================
# DATASET (COCO FORMAT)
# ======================================================================================

class SKU110K_COCO(Dataset):
    def __init__(self, root, annotation_json, transforms=None):
        self.root = root
        self.transforms = transforms

        with open(annotation_json, "r") as f:
            data = json.load(f)

        self.images = {img["id"]: img for img in data["images"]}
        self.ids = sorted(self.images.keys())

        self.annos = {img_id: [] for img_id in self.ids}
        for ann in data["annotations"]:
            self.annos[ann["image_id"]].append(ann)

    def __getitem__(self, idx):
        img_id = self.ids[idx]
        info = self.images[img_id]

        fname = info["file_name"].split("/")[-1]
        img_path = os.path.join(self.root, fname)
        img = Image.open(img_path).convert("RGB")

        boxes = []
        labels = []
        for ann in self.annos[img_id]:
            x, y, w, h = ann["bbox"]
            boxes.append([x, y, x + w, y + h])
            labels.append(1)  # single product class (id=1)

        target = {
            "boxes": torch.tensor(boxes, dtype=torch.float32),
            "labels": torch.tensor(labels, dtype=torch.int64),
            "image_id": torch.tensor(int(img_id)),
        }

        if self.transforms:
            img, target = self.transforms(img, target)
        else:
            img = F.to_tensor(img)

        return img, target

    def __len__(self):
        return len(self.ids)


def collate_fn(batch):
    return tuple(zip(*batch))


class ToTensorOnly:
    """Transform that converts PIL to tensor and leaves boxes unchanged."""
    def __call__(self, img, target):
        img = F.to_tensor(img)
        return img, target


class ResizeForDetection:
    def __init__(self, max_side=1024, stride=32):
        self.max_side = max_side
        self.stride = stride  # ensure divisibility for FPN

    def __call__(self, img, target):
        # img is PIL
        w, h = img.size
        scale = self.max_side / max(h, w)
        if scale < 1.0:
            new_w = int(w * scale)
            new_h = int(h * scale)

            # Make divisible by FPN stride
            new_w = (new_w // self.stride) * self.stride
            new_h = (new_h // self.stride) * self.stride

            img = img.resize((new_w, new_h))

            # Resize boxes too:
            boxes = target["boxes"]
            boxes = boxes * scale
            target["boxes"] = boxes

        img = F.to_tensor(img)
        return img, target


## 7. Torchvision RetinaNet baseline

In [5]:
# ======================================================================================
# BASELINE RETINANET (torchvision)
# ======================================================================================

def create_retinanet(num_classes):
    model = torchvision.models.detection.retinanet_resnet50_fpn(weights="DEFAULT")
    num_anchors = model.head.classification_head.num_anchors

    # replace cls head to match num_classes
    model.head.classification_head.num_classes = num_classes
    model.head.classification_head.cls_logits = nn.Conv2d(
        256, num_anchors * num_classes, kernel_size=3, padding=1
    )

    torch.nn.init.normal_(model.head.classification_head.cls_logits.weight, std=0.01)
    torch.nn.init.constant_(model.head.classification_head.cls_logits.bias, -4.0)
    return model


def train_one_epoch_retinanet(model, loader, optimizer, device, epoch):
    model.train()
    total_loss = 0.0

    for step, (images, targets) in enumerate(loader):
        images = [img.to(device) for img in images]
        targets = [
            {
                "boxes": t["boxes"].to(device),
                "labels": t["labels"].to(device),
            }
            for t in targets
        ]

        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if step % 10 == 0:
            print(
                f"[RetinaNet-Baseline][Epoch {epoch}] Step {step} "
                f"loss: {loss.item():.4f}"
            )

    print(f"{timestamp()} — RetinaNet Baseline Epoch {epoch} Avg Loss: {total_loss / len(loader):.4f}")


### 5.1 Baseline inference helpers

In [6]:
def run_inference_retinanet(model_path, image_path, config, save_output=False):
    device = config.device
    print(f"Loading RetinaNet baseline: {model_path}")

    model = create_retinanet(config.num_classes_retina)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    img_bgr = cv2.imread(image_path)
    if img_bgr is None:
        raise ValueError("Image not found: " + image_path)

    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    tensor = F.to_tensor(rgb).to(device)

    with torch.no_grad():
        out = model([tensor])[0]

    vis = img_bgr.copy()
    for box, score in zip(out["boxes"], out["scores"]):
        if score < 0.4:
            continue
        x1, y1, x2, y2 = map(int, box.tolist())
        cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(vis, f"{score:.2f}", (x1, y1 - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    if save_output:
        ensure_dir("./output/inference_retina_baseline")
        out_path = "./output/inference_retina_baseline/" + Path(image_path).stem + "_retina_pred.jpg"
        cv2.imwrite(out_path, vis)
        print("Saved:", out_path)

    plt.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
    plt.axis("off")
    plt.show()


def batch_inference_retinanet(model_path, folder, config,
                              save_dir="./output/inference_retina_baseline"):
    device = config.device
    print("\n=== Batch Inference: RetinaNet Baseline ===")
    ensure_dir(save_dir)

    model = create_retinanet(config.num_classes_retina)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    image_files = [f for f in os.listdir(folder)
                   if f.lower().endswith((".jpg", ".jpeg", ".png"))]

    with torch.no_grad():
        for file in tqdm(image_files, desc="RetinaNet Baseline Batch Infer"):
            path_img = os.path.join(folder, file)
            img_bgr = cv2.imread(path_img)
            if img_bgr is None:
                print("Skipping unreadable image:", file)
                continue

            rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
            tensor = F.to_tensor(rgb).to(device)

            out = model([tensor])[0]
            vis = img_bgr.copy()

            for box, score in zip(out["boxes"], out["scores"]):
                if score < 0.4:
                    continue
                x1, y1, x2, y2 = map(int, box.tolist())
                cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(vis, f"{score:.2f}", (x1, y1 - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            save_path = os.path.join(save_dir, f"{Path(file).stem}_retina_pred.jpg")
            cv2.imwrite(save_path, vis)

    print("\n=== RetinaNet Baseline Batch Inference Completed ===")


## 8. Local focal loss & anchor utilities

In [7]:
# ======================================================================================
# LOCAL CORRELATION-AWARE FOCAL LOSS
# ======================================================================================

class LocalFocalLoss2d(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, kernel_size=5, lambda_local=1.0, reduction="mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.kernel_size = kernel_size
        self.lambda_local = lambda_local
        self.reduction = reduction
        self.avg_pool = nn.AvgPool2d(kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, logits, targets):
        logits = logits.float()
        targets = targets.float()

        prob = torch.sigmoid(logits)
        p_t = prob * targets + (1 - prob) * (1 - targets)
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)

        eps = 1e-8
        focal_weight = alpha_t * (1 - p_t).pow(self.gamma)
        base_loss = -focal_weight * torch.log(p_t.clamp(min=eps))

        base_loss_spatial = base_loss.mean(dim=1, keepdim=True)
        local_hardness = self.avg_pool(base_loss_spatial)

        global_mean = local_hardness.mean().detach()
        h_norm = local_hardness / (global_mean + eps)

        local_weight = 1 + self.lambda_local * (h_norm - 1)
        local_weight = torch.clamp(local_weight, min=0.1, max=3.0)
        local_weight = local_weight.expand_as(base_loss)

        final_loss = local_weight * base_loss

        if self.reduction == "mean":
            return final_loss.mean()
        elif self.reduction == "sum":
            return final_loss.sum()
        return final_loss


# ======================================================================================
# BOX ENCODING / NMS / ANCHORS
# ======================================================================================

def encode_boxes(anchors, gt_boxes):
    ax = (anchors[:, 0] + anchors[:, 2]) / 2
    ay = (anchors[:, 1] + anchors[:, 3]) / 2
    aw = (anchors[:, 2] - anchors[:, 0])
    ah = (anchors[:, 3] - anchors[:, 1])

    gx = (gt_boxes[:, 0] + gt_boxes[:, 2]) / 2
    gy = (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2
    gw = (gt_boxes[:, 2] - gt_boxes[:, 0])
    gh = (gt_boxes[:, 3] - gt_boxes[:, 1])

    tx = (gx - ax) / aw
    ty = (gy - ay) / ah
    tw = torch.log(gw / aw)
    th = torch.log(gh / ah)
    return torch.stack([tx, ty, tw, th], dim=1)


def decode_boxes(anchors, deltas):
    ax = (anchors[:, 0] + anchors[:, 2]) / 2
    ay = (anchors[:, 1] + anchors[:, 3]) / 2
    aw = (anchors[:, 2] - anchors[:, 0])
    ah = (anchors[:, 3] - anchors[:, 1])

    tx, ty, tw, th = deltas.unbind(dim=1)

    gx = tx * aw + ax
    gy = ty * ah + ay
    gw = aw * torch.exp(tw)
    gh = ah * torch.exp(th)

    x1 = gx - gw / 2
    y1 = gy - gh / 2
    x2 = gx + gw / 2
    y2 = gy + gh / 2
    return torch.stack([x1, y1, x2, y2], dim=1)


def assign_anchors_to_gt(anchors, gt_boxes, iou_pos_thresh=0.5, iou_neg_thresh=0.4):
    A = anchors.size(0)
    device = anchors.device

    labels = torch.full((A,), -1, dtype=torch.int64, device=device)
    matched_gt_boxes = torch.zeros((A, 4), dtype=torch.float32, device=device)

    if gt_boxes.numel() == 0:
        labels[:] = 0
        return labels, matched_gt_boxes

    ious = box_iou(anchors, gt_boxes)
    max_iou, max_idx = ious.max(dim=1)

    labels[max_iou < iou_neg_thresh] = 0
    labels[max_iou >= iou_pos_thresh] = 1

    matched_gt_boxes[:] = gt_boxes[max_idx]
    return labels, matched_gt_boxes


def nms(boxes, scores, threshold=0.5):
    return torchvision.ops.nms(boxes, scores, threshold)


class MultiLevelAnchorGenerator:
    """Multi-scale anchors for FPN levels P3..P5."""
    def __init__(self, sizes_per_level, ratios, strides):
        self.sizes_per_level = sizes_per_level
        self.ratios = ratios
        self.strides = strides

    def _grid_anchors(self, grid_size, stride, sizes, device):
        H, W = grid_size
        shifts_x = (torch.arange(W, device=device) + 0.5) * stride
        shifts_y = (torch.arange(H, device=device) + 0.5) * stride
        shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")

        shift_x = shift_x.reshape(-1)
        shift_y = shift_y.reshape(-1)

        anchor_boxes = []
        for size in sizes:
            for ratio in self.ratios:
                size = float(size)
                ratio = float(ratio)
                w = size * math.sqrt(ratio)
                h = size / math.sqrt(ratio)
                anchors = torch.stack([
                    shift_x - w / 2,
                    shift_y - h / 2,
                    shift_x + w / 2,
                    shift_y + h / 2
                ], dim=1)
                anchor_boxes.append(anchors)

        return torch.cat(anchor_boxes, dim=0)

    def __call__(self, feature_shapes, device):
        anchors_per_level = []
        for (H, W), stride, sizes in zip(feature_shapes, self.strides, self.sizes_per_level):
            anchors = self._grid_anchors((H, W), stride, sizes, device)
            anchors_per_level.append(anchors)
        return anchors_per_level


## 9. LocalRetinaNet model

In [8]:
# ======================================================================================
# BACKBONE + FPN FOR LOCAL RETINANET (P3–P5 ONLY)
# ======================================================================================

class ResNetFPN(nn.Module):
    """ResNet50 backbone with FPN producing P3, P4, P5."""
    def __init__(self, backbone_name="resnet50"):
        super().__init__()
        if backbone_name == "resnet50":
            backbone = torchvision.models.resnet50(weights="DEFAULT")
            c3_channels = 512
            c4_channels = 1024
            c5_channels = 2048
        else:
            backbone = torchvision.models.resnet18(weights=None)
            c3_channels = 128
            c4_channels = 256
            c5_channels = 512

        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool

        self.layer1 = backbone.layer1  # C2
        self.layer2 = backbone.layer2  # C3
        self.layer3 = backbone.layer3  # C4
        self.layer4 = backbone.layer4  # C5

        self.lateral3 = nn.Conv2d(c3_channels, 256, 1)
        self.lateral4 = nn.Conv2d(c4_channels, 256, 1)
        self.lateral5 = nn.Conv2d(c5_channels, 256, 1)

        self.p3_conv = nn.Conv2d(256, 256, 3, padding=1)
        self.p4_conv = nn.Conv2d(256, 256, 3, padding=1)
        self.p5_conv = nn.Conv2d(256, 256, 3, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        p5 = self.lateral5(c5)
        p4 = self.lateral4(c4) + Fnn.interpolate(p5, size=c4.shape[-2:], mode="nearest")
        p3 = self.lateral3(c3) + Fnn.interpolate(p4, size=c3.shape[-2:], mode="nearest")

        p3 = self.p3_conv(p3)
        p4 = self.p4_conv(p4)
        p5 = self.p5_conv(p5)

        return [p3, p4, p5]


# ======================================================================================
# RETINANET HEAD
# ======================================================================================

class RetinaNetHead(nn.Module):
    def __init__(self, in_channels, num_anchors, num_classes):
        super().__init__()
        self.num_anchors = num_anchors
        self.num_classes = num_classes

        cls_layers = []
        for _ in range(4):
            cls_layers.append(nn.Conv2d(in_channels, in_channels, 3, padding=1))
            cls_layers.append(nn.ReLU(inplace=True))
        self.cls_subnet = nn.Sequential(*cls_layers)
        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, 3, padding=1)

        box_layers = []
        for _ in range(4):
            box_layers.append(nn.Conv2d(in_channels, in_channels, 3, padding=1))
            box_layers.append(nn.ReLU(inplace=True))
        self.box_subnet = nn.Sequential(*box_layers)
        self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, 3, padding=1)

        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
        torch.nn.init.constant_(self.cls_logits.bias, -4.0)

        torch.nn.init.normal_(self.bbox_pred.weight, std=0.01)
        torch.nn.init.constant_(self.bbox_pred.bias, 0)

    def forward(self, features):
        cls_logits = []
        bbox_regs = []
        for f in features:
            cls = self.cls_subnet(f)
            cls_logits.append(self.cls_logits(cls))

            box = self.box_subnet(f)
            bbox_regs.append(self.bbox_pred(box))
        return cls_logits, bbox_regs


# ======================================================================================
# LOCAL RETINANET
# ======================================================================================

class LocalRetinaNet(nn.Module):
    def __init__(self, num_classes=1, lambda_reg=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.lambda_reg = lambda_reg

        self.backbone = ResNetFPN("resnet50")

        self.num_anchors = 9
        self.head = RetinaNetHead(256, self.num_anchors, num_classes)

        base_sizes = [32, 64, 128]
        scales = [1.0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]
        sizes_per_level = [[b * s for s in scales] for b in base_sizes]
        ratios = [0.5, 1.0, 2.0]
        strides = [8, 16, 32]

        self.anchor_gen = MultiLevelAnchorGenerator(sizes_per_level, ratios, strides)
        self.cls_loss_fn = LocalFocalLoss2d(alpha=0.25, gamma=2.0, lambda_local=1.0)
        self.reg_loss_fn = nn.SmoothL1Loss(reduction="sum")

    def _compute_loss_single_image(
        self,
        cls_logits_list,
        bbox_regs_list,
        anchors_per_level,
        gt_boxes,
        feature_shapes,
        b,
        device,
    ):
        num_levels = len(cls_logits_list)
        level_offsets = []
        logits_flat_by_level = []
        box_flat_by_level = []

        start = 0
        for lvl in range(num_levels):
            cls_l_b = cls_logits_list[lvl][b]
            box_l_b = bbox_regs_list[lvl][b]
            _, Hf, Wf = cls_l_b.shape
            A = self.num_anchors

            N_l = Hf * Wf * A

            cls_flat = cls_l_b.permute(1, 2, 0).reshape(N_l, self.num_classes)
            box_flat = box_l_b.permute(1, 2, 0).reshape(N_l, 4)

            logits_flat_by_level.append(cls_flat)
            box_flat_by_level.append(box_flat)

            end = start + N_l
            level_offsets.append((start, end, Hf, Wf))
            start = end

        logits_all = torch.cat(logits_flat_by_level, dim=0)
        box_all = torch.cat(box_flat_by_level, dim=0)
        anchors_all = torch.cat([a.to(device) for a in anchors_per_level], dim=0)

        labels, matched_gt_boxes = assign_anchors_to_gt(anchors_all, gt_boxes)
        fg_mask = labels == 1

        cls_target_all = torch.zeros_like(logits_all)
        if self.num_classes == 1:
            cls_target_all[fg_mask, 0] = 1.0
        else:
            raise NotImplementedError("Current pipeline assumes num_classes=1.")

        cls_loss = torch.tensor(0.0, device=device)
        for lvl in range(num_levels):
            start, end, Hf, Wf = level_offsets[lvl]
            targets_lvl_1d = cls_target_all[start:end]

            t_map = targets_lvl_1d.reshape(
                Hf, Wf, self.num_anchors * self.num_classes
            ).permute(2, 0, 1).unsqueeze(0)

            logits_lvl = cls_logits_list[lvl][b].unsqueeze(0)
            cls_loss = cls_loss + self.cls_loss_fn(logits_lvl, t_map)

        num_pos = int(fg_mask.sum().item())
        if num_pos > 0:
            pred_pos = box_all[fg_mask]
            tgt_pos = encode_boxes(anchors_all[fg_mask], matched_gt_boxes[fg_mask])
            reg_loss_sum = self.reg_loss_fn(pred_pos, tgt_pos)
        else:
            reg_loss_sum = torch.tensor(0.0, device=device)

        return cls_loss, reg_loss_sum, num_pos

    def forward(self, images, targets=None):
        device = images[0].device

        if not self.training or targets is None:
            return self._forward_inference(images)

        x = pad_images_to_batch(images).to(device)
        B = x.shape[0]

        features = self.backbone(x)
        cls_logits_list, bbox_regs_list = self.head(features)

        feature_shapes = [(f.shape[2], f.shape[3]) for f in features]
        anchors_per_level = self.anchor_gen(feature_shapes, device)

        total_cls_loss = torch.tensor(0.0, device=device)
        total_reg_sum = torch.tensor(0.0, device=device)
        total_pos = 0

        for b in range(B):
            gt_boxes = targets[b]["boxes"].to(device)
            cls_loss_b, reg_sum_b, num_pos_b = self._compute_loss_single_image(
                cls_logits_list,
                bbox_regs_list,
                anchors_per_level,
                gt_boxes,
                feature_shapes,
                b,
                device,
            )
            total_cls_loss = total_cls_loss + cls_loss_b
            total_reg_sum = total_reg_sum + reg_sum_b
            total_pos += num_pos_b

        if total_pos > 0:
            avg_reg_loss = total_reg_sum / float(total_pos)
        else:
            avg_reg_loss = torch.tensor(0.0, device=device)

        total_loss = total_cls_loss + self.lambda_reg * avg_reg_loss

        return {
            "loss_cls": total_cls_loss,
            "loss_reg": avg_reg_loss,
            "loss_total": total_loss,
        }

    def _forward_inference(self, images):
        device = images[0].device
        x = pad_images_to_batch(images).to(device)
        B = x.shape[0]

        features = self.backbone(x)
        cls_logits_list, bbox_regs_list = self.head(features)
        feature_shapes = [(f.shape[2], f.shape[3]) for f in features]
        anchors_per_level = self.anchor_gen(feature_shapes, device)

        outputs = []
        orig_sizes = [(img.shape[1], img.shape[2]) for img in images]

        for b in range(B):
            all_boxes = []
            all_scores = []

            for cls_l, box_l, anchors_l in zip(cls_logits_list, bbox_regs_list, anchors_per_level):
                cls_l_img = cls_l[b]
                box_l_img = box_l[b]

                A_C, Hf, Wf = cls_l_img.shape
                A = self.num_anchors
                N_l = Hf * Wf * A

                cls_flat = cls_l_img.permute(1, 2, 0).reshape(N_l, self.num_classes)
                box_flat = box_l_img.permute(1, 2, 0).reshape(N_l, 4)

                scores = torch.sigmoid(cls_flat[:, 0])
                keep = scores > 0.05
                if keep.sum() == 0:
                    continue

                scores_k = scores[keep]
                box_k = box_flat[keep]
                anchors_k = anchors_l[keep]

                decoded = decode_boxes(anchors_k, box_k)
                all_boxes.append(decoded)
                all_scores.append(scores_k)

            if len(all_boxes) == 0:
                outputs.append({
                    "boxes": torch.zeros((0, 4), device=device),
                    "scores": torch.zeros((0,), device=device),
                    "labels": torch.zeros((0,), dtype=torch.int64, device=device),
                })
                continue

            boxes_cat = torch.cat(all_boxes, dim=0)
            scores_cat = torch.cat(all_scores, dim=0)

            h, w = orig_sizes[b]
            boxes_cat[:, 0] = boxes_cat[:, 0].clamp(min=0, max=w - 1)
            boxes_cat[:, 2] = boxes_cat[:, 2].clamp(min=0, max=w - 1)
            boxes_cat[:, 1] = boxes_cat[:, 1].clamp(min=0, max=h - 1)
            boxes_cat[:, 3] = boxes_cat[:, 3].clamp(min=0, max=h - 1)

            keep_nms = nms(boxes_cat, scores_cat, 0.5)
            boxes_cat = boxes_cat[keep_nms]
            scores_cat = scores_cat[keep_nms]
            labels_cat = torch.ones_like(scores_cat, dtype=torch.int64, device=device)

            outputs.append({
                "boxes": boxes_cat,
                "scores": scores_cat,
                "labels": labels_cat,
            })

        return outputs


## 10. Training loop for LocalRetinaNet

In [9]:
# ======================================================================================
# TRAINING LOOP (LOCAL RETINANET)
# ======================================================================================

def train_one_epoch_local_retina(model, loader, optimizer, device, epoch):
    model.train()
    total_loss = 0.0

    for step, (images, targets) in enumerate(loader):
        images = [img.to(device) for img in images]

        loss_dict = model(images, targets)
        loss = loss_dict["loss_total"]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if step % 10 == 0:
            print(
                f"[Local-RetinaNet][Epoch {epoch}] Step {step} "
                f"cls: {loss_dict['loss_cls'].item():.4f}, "
                f"reg: {loss_dict['loss_reg'].item():.4f}, "
                f"total: {loss.item():.4f}"
            )

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"{timestamp()} — Local RetinaNet Epoch {epoch} Avg Loss: {total_loss / len(loader):.4f}")


## 11. COCO evaluation utilities

In [10]:
# ======================================================================================
# COCO EVAL (shared for both models)
# ======================================================================================

def convert_to_coco_predictions(outputs, image_ids):
    results = []
    for out, img_id in zip(outputs, image_ids):
        for box, score, label in zip(out["boxes"], out["scores"], out["labels"]):
            x1, y1, x2, y2 = box.tolist()
            results.append({
                "image_id": int(img_id),
                "category_id": int(label),
                "bbox": [x1, y1, x2 - x1, y2 - y1],
                "score": float(score),
            })
    return results


def evaluate_coco(model, loader, ann_file, device, out_name):
    fix_coco_json(ann_file)
    print(f"Evaluating COCO mAP → {out_name} ...")
    model.eval()
    coco = COCO(ann_file)
    results = []

    with torch.no_grad():
        for images, targets in tqdm(loader, desc="COCO Eval"):
            images = [img.to(device) for img in images]
            outputs = model(images)
            image_ids = [t["image_id"].item() for t in targets]
            results.extend(convert_to_coco_predictions(outputs, image_ids))

    ensure_dir("./output/coco_eval")
    pred_file = f"./output/coco_eval/{out_name}"
    with open(pred_file, "w") as f:
        json.dump(results, f)

    if len(results) == 0:
        print("No predictions generated — skipping COCOeval (AP will be 0).")
        return

    coco_dt = coco.loadRes(pred_file)
    ev = COCOeval(coco, coco_dt, "bbox")
    ev.evaluate()
    ev.accumulate()
    ev.summarize()


## 12. LocalRetinaNet inference helpers

In [11]:
# ======================================================================================
# LOCAL RETINANET INFERENCE HELPERS
# ======================================================================================

def run_single_inference_local(model_path, image_path, config):
    device = config.device
    print(f"Loading Local RetinaNet: {model_path}")

    # Load model
    model = LocalRetinaNet(
        num_classes=config.num_classes_local,
        lambda_reg=config.lambda_reg_local,
    ).to(device)

    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state, strict=True)
    model.eval()

    img_bgr = cv2.imread(image_path)
    if img_bgr is None:
        raise ValueError("Image not found: " + image_path)

    orig_h, orig_w = img_bgr.shape[:2]

    # Convert to PIL for ResizeForDetection
    img_pil = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))

    resize = ResizeForDetection(max_side=1024)

    # Fake target (ResizeForDetection expects a dict)
    fake_target = {"boxes": torch.zeros((0,4))}

    resized_img_tensor, _ = resize(img_pil, fake_target)  # tensor CHW

    resized_h, resized_w = resized_img_tensor.shape[1], resized_img_tensor.shape[2]

    scale_x = orig_w / resized_w
    scale_y = orig_h / resized_h

    # Move to device
    tensor = resized_img_tensor.to(device)

    with torch.no_grad():
        outputs = model([tensor])
        out = outputs[0]

    boxes_scaled = out["boxes"].clone()
    boxes_scaled[:, 0] *= scale_x   # x1
    boxes_scaled[:, 2] *= scale_x   # x2
    boxes_scaled[:, 1] *= scale_y   # y1
    boxes_scaled[:, 3] *= scale_y   # y2

    vis = img_bgr.copy()
    for box, score in zip(boxes_scaled, out["scores"]):
        if score < 0.24:
            continue

        x1, y1, x2, y2 = map(int, box.tolist())
        cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(vis, f"{score:.2f}", (x1, y1 - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    plt.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
    plt.axis("off")
    plt.show()


def batch_inference_local(model_path, folder, config,
                          save_dir="./output/inference_retina_local"):
    device = config.device
    print("\n=== Batch Inference: Local RetinaNet ===")
    ensure_dir(save_dir)

    model = LocalRetinaNet(
        num_classes=config.num_classes_local,
        lambda_reg=config.lambda_reg_local,
    ).to(device)

    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state, strict=True)
    model.eval()

    image_files = [f for f in os.listdir(folder)
                   if f.lower().endswith((".jpg", ".jpeg", ".png"))]

    with torch.no_grad():
        for file in tqdm(image_files, desc="Local RetinaNet Batch Infer"):
            path_img = os.path.join(folder, file)
            img_bgr = cv2.imread(path_img)
            if img_bgr is None:
                print("Skipping unreadable image:", file)
                continue

            rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
            tensor = F.to_tensor(rgb).to(device)

            outputs = model([tensor])
            out_det = outputs[0]

            vis = img_bgr.copy()
            for box, score in zip(out_det["boxes"], out_det["scores"]):
                if score < 0.3:
                    continue
                x1, y1, x2, y2 = map(int, box.tolist())
                cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(vis, f"{score:.2f}", (x1, y1 - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            save_path = os.path.join(save_dir, f"{Path(file).stem}_retina_local.jpg")
            cv2.imwrite(save_path, vis)

    print("\n=== Local RetinaNet Batch Inference Completed ===")


## 13. [Training] Train LocalRetinaNet

In [12]:
# ============================
# Example: Train LocalRetinaNet
# ============================

device = config.device
ensure_dir("./output/pt-models")

train_ds_local = SKU110K_COCO(
    config.train_images,
    config.train_annotations,
    transforms=ResizeForDetection(max_side=1024),
)
train_loader_local = DataLoader(
    train_ds_local,
    batch_size=config.batch_size_local,
    shuffle=True,
    num_workers=config.num_workers,
    collate_fn=collate_fn,
)

local_model = LocalRetinaNet(
    num_classes=config.num_classes_local,
    lambda_reg=config.lambda_reg_local,
).to(device)

optimizer_local = optim.Adam(local_model.parameters(), lr=config.lr)

for ep in range(1, config.num_epochs_local + 1):
    train_one_epoch_local_retina(local_model, train_loader_local, optimizer_local, device, ep)

torch.save(local_model.state_dict(), config.save_local_model_path)
print("Saved Local RetinaNet model:", config.save_local_model_path)


[Local-RetinaNet][Epoch 1] Step 0 cls: 0.0592, reg: 0.0824, total: 0.0675
[Local-RetinaNet][Epoch 1] Step 10 cls: 0.0698, reg: 0.0820, total: 0.0781
[Local-RetinaNet][Epoch 1] Step 20 cls: 0.0844, reg: 0.0782, total: 0.0922
[Local-RetinaNet][Epoch 1] Step 30 cls: 0.0775, reg: 0.0863, total: 0.0862


KeyboardInterrupt: 

## 14. [Training] Train RetinaNet baseline

In [None]:
# ============================
# Example: Train torchvision RetinaNet baseline
# ============================

device = config.device
ensure_dir("./output/pt-models")

train_ds_retina = SKU110K_COCO(
    config.train_images,
    config.train_annotations,
    transforms=None,
)
train_loader_retina = DataLoader(
    train_ds_retina,
    batch_size=config.batch_size_retina,
    shuffle=True,
    num_workers=config.num_workers,
    collate_fn=collate_fn,
)

retina_model = create_retinanet(config.num_classes_retina).to(device)
optimizer_retina = optim.Adam(retina_model.parameters(), lr=config.lr)

for ep in range(1, config.num_epochs_retina + 1):
    train_one_epoch_retinanet(retina_model, train_loader_retina, optimizer_retina, device, ep)

torch.save(retina_model.state_dict(), config.save_retinanet_model_path)
print("Saved RetinaNet baseline:", config.save_retinanet_model_path)


## 15. [Evaluation] COCO evaluation (LocalRetinaNet)

In [None]:
# ============================
# COCO Evaluation (LocalRetinaNet)
# ============================

device = config.device

test_ds_local = SKU110K_COCO(
    config.test_images,
    config.test_annotations,
    transforms=ResizeForDetection(max_side=1024),
)
test_loader_local = DataLoader(
    test_ds_local,
    batch_size=config.batch_size_local,
    shuffle=False,
    num_workers=config.num_workers,
    collate_fn=collate_fn,
)

local_model_eval = LocalRetinaNet(
    num_classes=config.num_classes_local,
    lambda_reg=config.lambda_reg_local,
).to(device)
local_model_eval.load_state_dict(torch.load(config.save_local_model_path, map_location=device))

evaluate_coco(local_model_eval, test_loader_local, config.test_annotations, device,
              out_name="pred_retinanet_local.json")


Evaluating COCO mAP → pred_retinanet_local.json ...
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


COCO Eval: 100%|██████████| 50/50 [03:49<00:00,  4.59s/it]


Loading and preparing results...
DONE (t=1.37s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *bbox*
DONE (t=4.20s).
Accumulating evaluation results...
DONE (t=0.01s).
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.001
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.007
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=

## 16. [Evaluation] COCO evaluation (RetinaNet baseline)

In [16]:
# ================================================================
# COCO Evaluation (RetinaNet baseline)
# ================================================================

test_ds = SKU110K_COCO(
    root=config.test_images,
    annotation_json=config.test_annotations,
    transforms=None,   # Baseline RetinaNet expects raw image sizes
)

test_loader = DataLoader(
    test_ds,
    batch_size=config.batch_size_retina,
    shuffle=False,
    num_workers=config.num_workers,
    collate_fn=collate_fn,
)

device = config.device
print("Using device:", device)

model = create_retinanet(config.num_classes_retina).to(device)

print("Loading baseline checkpoint:", config.save_retinanet_model_path)
state = torch.load(config.save_retinanet_model_path, map_location=device)
model.load_state_dict(state)
model.eval()

fix_coco_json(config.test_annotations)
coco = COCO(config.test_annotations)

results = []

with torch.no_grad():
    for images, targets in tqdm(test_loader, desc="COCO Eval (Baseline)"):
        images = [img.to(device) for img in images]

        outputs = model(images)
        img_ids = [t["image_id"].item() for t in targets]

        results.extend(convert_to_coco_predictions(outputs, img_ids))

ensure_dir("./output/coco_eval")
pred_file = "./output/coco_eval/pred_retinanet_baseline.json"

with open(pred_file, "w") as f:
    json.dump(results, f)

print("Saved predictions:", pred_file)

if len(results) == 0:
    print("No predictions were produced — COCOeval skipped.")
else:
    coco_dt = coco.loadRes(pred_file)
    evaluator = COCOeval(coco, coco_dt, "bbox")
    evaluator.evaluate()
    evaluator.accumulate()
    evaluator.summarize()

Using device: mps
Loading baseline checkpoint: ./output/pt-models/retinanet_sku110k.pth
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


COCO Eval (Baseline): 100%|██████████| 25/25 [00:10<00:00,  2.29it/s]


Saved predictions: ./output/coco_eval/pred_retinanet_baseline.json
Loading and preparing results...
DONE (t=0.01s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *bbox*
DONE (t=3.92s).
Accumulating evaluation results...
DONE (t=0.01s).
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.377
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.608
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.430
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.428
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.005
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.050
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.428

## 17. [Inference] Single-image inference (LocalRetinaNet)

In [None]:
run_single_inference_local(
    config.save_local_model_path,
    config.infer_image_path,
    config,
)

## 18. [Inference] Single-image inference (LocalRetina baseline)

In [None]:
run_inference_retinanet(
    config.save_retinanet_model_path,
    config.infer_image_path,
    config,
)