This notebook processes images and YOLO label files to generate bounding box annotations for Hexbug objects.
The pipeline works as follows:
1. Load YOLO labels containing subpart bounding boxes (head positions of Hexbugs).
2. Use the Segment Anything Model (SAM) to segment the entire frame and detect potential object bounding boxes.
3. Match each YOLO bounding box with SAM-detected bounding boxes based on area overlap:
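   - Among the SAM bounding boxes whose intersection with the YOLO bounding box covers at least 50% of the YOLO box area and whose own area is more than twice the YOLO box area, the smallest one is selected.
   - Refine the selected bounding box by expanding it by 20% in width and height (10% on each side, clipped to the image bounds) to capture the "whole object."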
4. Save annotations:
   - Subpart bounding boxes (class 0) from YOLO labels.
   - Whole object bounding boxes (class 1) refined from SAM detections.
5. Annotate and save images with bounding boxes for visualization.


In [None]:
import cv2
import torch
from segment_anything import sam_model_registry, SamPredictor
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import SamAutomaticMaskGenerator
import supervision as sv
import os

Directories for images and labels

In [14]:
# Input locations: the images to process and their YOLO label files.
img_dir = "../data/dataset/images/train/"
label_dir = "../data/dataset/labels/train/"

# Output tree: every output location below is a subdirectory of output_dir.
output_dir = "../data/dataset/new/"
output_dir_img = f"{output_dir}images/"
output_dir_img_bbox = f"{output_dir}images_bbox/"
output_dir_label = f"{output_dir}labels/"
output_dir_label_seg_only = f"{output_dir}labels_seg_only/"

# Create the whole output tree up front; directories that already exist are
# left untouched (exist_ok=True).
for _out_dir in (
    output_dir,
    output_dir_img,
    output_dir_img_bbox,
    output_dir_label,
    output_dir_label_seg_only,
):
    os.makedirs(_out_dir, exist_ok=True)


Load SAM

In [4]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')
MODEL_TYPE = "vit_h"

# Build the SAM variant named by MODEL_TYPE from the local checkpoint file,
# move it to DEVICE, and wrap it in an automatic mask generator created with
# its default settings.
build_sam = sam_model_registry[MODEL_TYPE]
sam = build_sam(checkpoint="./sam_vit_h_4b8939.pth")
sam.to(device=DEVICE)
mask_generator = SamAutomaticMaskGenerator(sam)


In [None]:
# Function to read YOLO label files
def read_yolo_labels(file_path):
    """
    Reads a YOLO label file and parses it into a list of tuples.

    Each non-blank line must hold exactly five whitespace-separated numbers:
    - class_id: the class label (parsed as a float, like the other fields)
    - x, y: relative center coordinates
    - w, h: relative width and height

    Blank or whitespace-only lines (for example a trailing empty line at the
    end of the file) are skipped instead of aborting the whole read.

    - Input:
        file_path: path to the label file
    - Output:
        list of (class_id, x, y, w, h) tuples, one per non-blank line
    - Raises:
        ValueError: if a non-blank line does not consist of exactly five
        numeric fields; the message names the file and line number.
    """
    labels = []
    with open(file_path, 'r') as file:
        for line_no, line in enumerate(file, start=1):
            fields = line.split()
            if not fields:
                # Nothing on this line; tolerate it rather than failing the unpack.
                continue
            try:
                class_id, x, y, w, h = map(float, fields)
            except ValueError as err:
                raise ValueError(
                    f"{file_path}:{line_no}: expected 5 numeric fields, got {line.strip()!r}"
                ) from err
            labels.append((class_id, x, y, w, h))
    return labels

# Function to convert relative YOLO coordinates to absolute pixel coordinates
def rel_to_abs_center(x, y, w, h, img_w, img_h):
    """
    Turns a center-based relative box into a top-left-based pixel box.
    - Input:
        x, y: box center, as fractions of the image width / height
        w, h: box size, as fractions of the image width / height
        img_w, img_h: image size in pixels
    - Output:
        (left, top, width, height) in pixels, where (left, top) is the
        top-left corner of the box. Values are not rounded or clipped.
    """
    # Scale the size first; the corner is the scaled center minus half of it.
    width_px = w * img_w
    height_px = h * img_h
    left_px = x * img_w - width_px / 2
    top_px = y * img_h - height_px / 2
    return left_px, top_px, width_px, height_px

# Function to calculate the intersection area between two bounding boxes
def intersection_area(box1, box2):
    """
    Returns the area shared by two axis-aligned bounding boxes.
    - Input:
        box1, box2: bounding boxes given as (x1, y1, x2, y2) corner coordinates
    - Output:
        Area of the overlapping rectangle; 0 when the boxes do not overlap
        (including when they only touch along an edge or corner).
    """
    (a_x1, a_y1, a_x2, a_y2), (b_x1, b_y1, b_x2, b_y2) = box1, box2

    # The overlap spans from the larger of the two starts to the smaller of
    # the two ends on each axis; a negative extent means no overlap, so it is
    # clamped to 0.
    overlap_w = max(0, min(a_x2, b_x2) - max(a_x1, b_x1))
    overlap_h = max(0, min(a_y2, b_y2) - max(a_y1, b_y1))

    return overlap_w * overlap_h

# Loop through all images in the directory and process them.
# For every .jpg with a matching label file this:
#   1. copies the original image to output_dir_img,
#   2. runs SAM on the frame and, for each YOLO (subpart) box, picks the
#      smallest SAM box that is more than twice the YOLO box area and covers
#      at least 50% of it, then grows that box by 10% per side,
#   3. writes class-0 (subpart) and class-1 (whole object) YOLO labels to
#      output_dir_label and an annotated preview to output_dir_img_bbox.
# Files are visited in sorted order so runs are reproducible.
for img_name in sorted(os.listdir(img_dir)):
    if not img_name.endswith('.jpg'):
        continue

    # Derive the label name from the image stem; str.replace would rewrite
    # every '.jpg' occurrence in the name, not only the extension.
    label_name = os.path.splitext(img_name)[0] + '.txt'
    img_path = os.path.join(img_dir, img_name)
    label_path = os.path.join(label_dir, label_name)

    # cv2.imread signals failure by returning None rather than raising.
    image_bgr = cv2.imread(img_path)
    if image_bgr is None:
        print(f"Skipping {img_name}: image could not be read")
        continue

    # Without labels there is nothing to match; skip before writing any output
    # and before the expensive SAM pass.
    if not os.path.isfile(label_path):
        print(f"Skipping {img_name}: no label file at {label_path}")
        continue

    # Read YOLO labels first so a malformed file fails before SAM runs.
    labels = read_yolo_labels(label_path)

    # SAM expects RGB, OpenCV loads BGR.
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    # Save original image to output directory
    output_img_path = os.path.join(output_dir_img, img_name)
    cv2.imwrite(output_img_path, image_bgr)

    # Get image dimensions
    img_h, img_w, _ = image_rgb.shape

    # Generate segmentation masks using SAM
    result = mask_generator.generate(image_rgb)
    detections = sv.Detections.from_sam(result)

    # Copy of the frame that the bounding boxes are drawn on
    annotated_image = image_bgr.copy()

    # Lists to store subpart (class 0) and whole object (class 1) labels
    subpart_labels = []
    whole_object_labels = []

    for class_id, rel_x, rel_y, rel_w, rel_h in labels:
        # Convert YOLO relative coordinates to absolute coordinates
        abs_x, abs_y, abs_w, abs_h = rel_to_abs_center(rel_x, rel_y, rel_w, rel_h, img_w, img_h)
        yolo_bbox = (abs_x, abs_y, abs_x + abs_w, abs_y + abs_h)
        yolo_area = abs_w * abs_h

        # Find the smallest SAM bounding box that contains at least 50% of the
        # YOLO bounding box area and is more than twice as large as it.
        best_bbox = None
        best_bbox_area = float('inf')

        # A zero-area YOLO box would satisfy the 50% test against every SAM
        # box, so no whole-object match is attempted for it.
        candidate_bboxes = detections.xyxy if yolo_area > 0 else ()

        for detection_bbox in candidate_bboxes:
            detection_bbox_area = (detection_bbox[2] - detection_bbox[0]) * (detection_bbox[3] - detection_bbox[1])
            inter_area = intersection_area(yolo_bbox, detection_bbox)
            if (
                inter_area >= 0.5 * yolo_area
                and detection_bbox_area > yolo_area * 2
                and detection_bbox_area < best_bbox_area
            ):
                best_bbox_area = detection_bbox_area
                best_bbox = detection_bbox

        if best_bbox is not None:
            x1, y1, x2, y2 = best_bbox

            # Expand bounding box by 20% (10% per side) and ensure it stays within image bounds
            new_x1 = max(0, x1 - 0.1 * (x2 - x1))
            new_y1 = max(0, y1 - 0.1 * (y2 - y1))
            new_x2 = min(img_w, x2 + 0.1 * (x2 - x1))
            new_y2 = min(img_h, y2 + 0.1 * (y2 - y1))

            # Add to "whole object" labels
            whole_object_labels.append((1, (new_x1, new_y1, new_x2, new_y2)))

            # Draw green bounding box for whole object
            cv2.rectangle(annotated_image, (int(new_x1), int(new_y1)), (int(new_x2), int(new_y2)), (0, 255, 0), 2)

        # Add YOLO subpart bounding box to labels; it is always written as
        # class 0, regardless of the class id in the input file.
        subpart_labels.append((0, yolo_bbox))

        # Draw blue bounding box for subpart
        cv2.rectangle(annotated_image, (int(abs_x), int(abs_y)), (int(abs_x + abs_w), int(abs_y + abs_h)), (255, 0, 0), 2)

    # Save annotated image
    output_img_path = os.path.join(output_dir_img_bbox, img_name)
    cv2.imwrite(output_img_path, annotated_image)

    # Save labels for subpart (class 0) and whole object (class 1), converted
    # back to relative center/size YOLO format.
    with open(os.path.join(output_dir_label, label_name), 'w') as file:
        for class_id, bbox in subpart_labels + whole_object_labels:
            x1, y1, x2, y2 = bbox
            rel_x_center = (x1 + (x2 - x1) / 2) / img_w
            rel_y_center = (y1 + (y2 - y1) / 2) / img_h
            rel_w = (x2 - x1) / img_w
            rel_h = (y2 - y1) / img_h
            file.write(f"{class_id} {rel_x_center} {rel_y_center} {rel_w} {rel_h}\n")

    # Optionally display the annotated image
    plt.imshow(cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB))
    plt.title(f"Annotated Image - {img_name}")
    plt.show()
