# Object Detection using RetinaNet

RetinaNet is a neural network architecture for object detection described in [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) by Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He and Piotr Dollár.

The following shows how to use the [TorchVision implementation](https://pytorch.org/vision/main/models/retinanet.html) with model parameters pretrained on the [COCO object detection dataset](http://cocodataset.org/).

## Loading a Pretrained Model

In [None]:
import torch
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights

# Choose where inference runs: CUDA if a GPU is visible, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# COCO-pretrained weights. The weights enum also exposes the category names
# (weights.meta) and the matching preprocessing (weights.transforms()).
weights = RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1

# eval() and to() both return the module itself, so they can be chained.
model = retinanet_resnet50_fpn_v2(weights=weights).eval().to(device)
print(f"Model loaded on {device}")

## Detecting Objects (Location and Classes) in Test Images

The COCO class names are stored in the pretrained weights' metadata:

In [None]:
# Category names ship with the weights; the detector's integer label ids are
# used directly as indices into this list when captioning detections.
labels_to_names = weights.meta["categories"]
print("Number of classes: {}".format(len(labels_to_names)))
print("First 10 classes: {}".format(labels_to_names[:10]))

In [None]:
import os
import time

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Get the preprocessing transforms from the weights
preprocess = weights.transforms()


def _load_image(image_path_or_array, bgr=True):
    """Normalize the supported input kinds to an RGB PIL image.

    - ``str`` / ``os.PathLike``: opened from disk and converted to RGB.
    - ``np.ndarray``: 2-D and (H, W, 1) arrays are grayscale; 3- and
      4-channel arrays are assumed to be OpenCV-style BGR / BGRA when ``bgr``
      is true and are reordered to RGB / RGBA before conversion.
    - PIL images are converted to RGB (drops alpha, expands grayscale) so the
      model always receives exactly three channels.
    - Anything else is returned unchanged and left to ``preprocess``.
    """
    if isinstance(image_path_or_array, (str, os.PathLike)):
        return Image.open(image_path_or_array).convert("RGB")

    if isinstance(image_path_or_array, np.ndarray):
        array = image_path_or_array
        if array.ndim == 3 and array.shape[2] == 1:
            array = array[..., 0]  # (H, W, 1) -> (H, W) grayscale
        elif bgr and array.ndim == 3 and array.shape[2] == 3:
            array = array[..., ::-1]  # BGR -> RGB
        elif bgr and array.ndim == 3 and array.shape[2] == 4:
            array = array[..., [2, 1, 0, 3]]  # BGRA -> RGBA
        # Channel reversal produces a negatively-strided view; hand PIL a
        # contiguous buffer instead.
        return Image.fromarray(np.ascontiguousarray(array)).convert("RGB")

    if isinstance(image_path_or_array, Image.Image):
        return image_path_or_array.convert("RGB")

    return image_path_or_array


def detect_and_visualize(image_path_or_array, score_threshold=0.5, bgr=True):
    """Detect objects in an image and draw the boxes and captions over it.

    Args:
        image_path_or_array: a file path (``str`` or ``os.PathLike``), a NumPy
            image array, or a PIL image. 3- and 4-channel arrays are treated
            as OpenCV BGR(A) unless ``bgr`` is False.
        score_threshold: detections scoring below this value are skipped.
        bgr: whether NumPy arrays use BGR(A) channel order (OpenCV default).

    Each kept detection's caption ("class: score") is also printed.
    """
    image = _load_image(image_path_or_array, bgr=bgr)

    # TorchVision detection models take a list of (C, H, W) tensors.
    image_tensor = preprocess(image).to(device)
    print(f"Input shape: {image_tensor.shape}, dtype: {image_tensor.dtype}")

    start = time.perf_counter()
    with torch.no_grad():
        predictions = model([image_tensor])
    if device.type == "cuda":
        # CUDA work is asynchronous; wait for it so the timing is meaningful.
        torch.cuda.synchronize(device)
    print(f"Processing time: {time.perf_counter() - start:.2f}s")

    # Filter on the device, then copy only the kept detections to the host.
    pred = predictions[0]
    keep = pred["scores"] >= score_threshold
    boxes = pred["boxes"][keep].cpu().numpy()
    scores = pred["scores"][keep].cpu().numpy()
    labels = pred["labels"][keep].cpu().numpy()

    fig, ax = plt.subplots(1, figsize=(10, 10))
    ax.imshow(image)

    # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap is stable.
    cmap = plt.get_cmap("tab20")

    for (x1, y1, x2, y2), score, label in zip(boxes, scores, labels):
        label = int(label)
        color = cmap(label % cmap.N)  # cycle through the qualitative palette
        ax.add_patch(
            patches.Rectangle(
                (x1, y1), x2 - x1, y2 - y1,
                linewidth=2, edgecolor=color, facecolor="none",
            )
        )

        caption = f"{labels_to_names[label]}: {score:.2f}"
        print(caption)
        ax.text(
            x1, y1 - 5, caption,
            fontsize=10, color="white",
            bbox=dict(boxstyle="round,pad=0.3", facecolor=color, alpha=0.8),
        )

    ax.axis("off")
    fig.tight_layout()
    plt.show()

In [None]:
# Try the detector on an image file from disk, using the default threshold.
detect_and_visualize("webcam_shot.jpeg", score_threshold=0.5)

## Real World Data

Let's try the detector on a live frame captured from the laptop webcam:

In [None]:
import cv2

def camera_grab(camera_id=0, fallback_filename="webcam_shot.jpeg", warmup_frames=10):
    """Capture one frame from a webcam, falling back to an image file.

    The camera is read ``warmup_frames`` times so it can automatically tune
    its contrast and lighting; the last frame that was read successfully is
    kept.

    Args:
        camera_id: device index passed to ``cv2.VideoCapture``.
        fallback_filename: image loaded with ``cv2.imread`` when no frame
            could be captured; pass a falsy value to disable the fallback.
        warmup_frames: number of reads before keeping the last good frame;
            values below 1 are treated as 1.

    Returns:
        The image as a NumPy array in OpenCV (BGR) channel order.

    Raises:
        RuntimeError: if no frame was captured and the fallback image is
            disabled or could not be read (``cv2.imread`` returns ``None``
            instead of raising).
    """
    image = None
    camera = cv2.VideoCapture(camera_id)
    try:
        if camera.isOpened():
            for _ in range(max(1, warmup_frames)):
                snapshot_ok, frame = camera.read()
                if snapshot_ok:
                    image = frame
    finally:
        camera.release()

    if image is None:
        print("WARNING: could not access camera, using fallback image")
        if fallback_filename:
            image = cv2.imread(fallback_filename)
        if image is None:
            raise RuntimeError(
                f"Could not capture a frame from camera {camera_id} and the "
                f"fallback image {fallback_filename!r} is unavailable"
            )
    return image

In [None]:
# Grab a frame from the default camera (index 0).
image = camera_grab(0)

# OpenCV frames are BGR while matplotlib expects RGB, so convert for display.
plt.figure(figsize=[8, 8])
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
plt.axis("off");

In [None]:
# Run detection on the captured frame (3-channel arrays are treated as BGR).
detect_and_visualize(image, score_threshold=0.5)