In [1]:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.9"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

import supervision as sv
from supervision.draw.color import ColorPalette

# Select the compute device: prefer CUDA, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # Run the rest of the notebook under bfloat16 autocast. The context is entered here and
    # deliberately never exited; it is kept in `autocast_ctx` so it can be closed explicitly
    # (autocast_ctx.__exit__(None, None, None)). The globals() check stops a re-run of this
    # cell from entering a second, nested autocast context.
    # torch.autocast raises for bfloat16 on CUDA devices that lack bf16 support, so only
    # enable it where torch reports support.
    if torch.cuda.is_bf16_supported():
        if globals().get("autocast_ctx") is None:
            autocast_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
            autocast_ctx.__enter__()
    else:
        print("bfloat16 is not supported on this GPU; running without autocast.")
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. Models are typically trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )

using device: cuda


In [2]:
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

# Grounding DINO (base) checkpoint. The processor prepares image + text inputs and
# post-processes the raw outputs into boxes; the model itself is moved to `device`.
dino_id = "IDEA-Research/grounding-dino-base"
grounding_processor = AutoProcessor.from_pretrained(dino_id)
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(dino_id)
grounding_model = grounding_model.to(device)



In [3]:
# Phrases to detect, joined into the single text query the detector receives:
# every phrase is terminated by a "." and phrases are separated by a space.
object_list = ["table", "chair", "trash can"]
text_prompt = " ".join(f"{phrase}." for phrase in object_list)
text_prompt

'table. chair. trash can.'

In [4]:
from tqdm import tqdm
import rospy
import rosbag
from cv_bridge import CvBridge
import pathlib
import cv2

from byte_tracker import BYTETracker

In [7]:
from types import SimpleNamespace
# BYTETracker hyper-parameters, exposed as attributes (args.track_thresh, args.track_buffer, ...).
# See byte_tracker for how each one is consumed.
args = SimpleNamespace(
    track_thresh=0.5,
    track_buffer=30,
    match_thresh=0.75,
    mot20=False,
    # Also read by the processing loop, which hides tracks whose w * h is at most this value.
    min_box_area=100,
)

tracker = BYTETracker(args)

In [8]:
# Stream camera frames from a ROS bag, detect `object_list` with Grounding DINO, track the
# detections with BYTETracker, and show raw detections (top) stacked above tracks (bottom).
# Press "q" in the OpenCV window to stop early.

# NOTE(review): reading a bag file with rosbag likely does not need a ROS node; this call
# (which tries to register with a ROS master) may be removable.
rospy.init_node("image_extractor", anonymous=True)
bridge = CvBridge()

# Override with the BAG_FILE environment variable instead of editing the notebook.
bag_file = os.environ.get("BAG_FILE", "/home/chadwick/Downloads/system.bag")
# NOTE(review): not written to anywhere in this notebook yet.
result_filename = "./results.txt"

IMAGE_TOPIC = "/camera/image"
BOX_THRESHOLD = 0.4
TEXT_THRESHOLD = 0.4
# Tracks whose width / height exceeds this ratio are hidden (filter inherited from ByteTrack's
# pedestrian demo). NOTE(review): tables are often wider than 1.6x their height; set this to
# None to disable the filter.
MAX_ASPECT_RATIO = 1.6

# The annotators are configured identically for every frame, so build them once.
box_annotator = sv.BoxAnnotator(color=ColorPalette.DEFAULT)
label_annotator = sv.LabelAnnotator(color=ColorPalette.DEFAULT)


def detect_objects(image: np.ndarray) -> tuple:
    """Run Grounding DINO on one RGB frame using `text_prompt`.

    Returns:
        (class_names, boxes, scores): the detected phrases, their xyxy boxes as an
        (n, 4) array scaled to the frame size, and their (n,) confidence scores.
    """
    inputs = grounding_processor(
        images=image,
        text=text_prompt,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        outputs = grounding_model(**inputs)
    result = grounding_processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD,
        target_sizes=[image.shape[:2]],
    )[0]
    return (
        list(result["labels"]),
        result["boxes"].cpu().numpy(),
        result["scores"].cpu().numpy(),
    )


def class_ids_for(class_names) -> np.ndarray:
    """Map detected phrases to their index in `object_list`, so boxes are colored by class.

    Any phrase that is not exactly one of the prompts (if the detector returns one) gets
    the shared extra id len(object_list).
    """
    index_of = {name: i for i, name in enumerate(object_list)}
    return np.array(
        [index_of.get(name, len(object_list)) for name in class_names], dtype=int
    )


def select_tracks(online_targets):
    """Keep tracks that pass the area and aspect-ratio filters.

    Returns:
        (boxes, ids, scores) lists for the kept tracks; boxes are each track's tlbr corners.
    """
    boxes, ids, scores = [], [], []
    for track in online_targets:
        width, height = track.tlwh[2], track.tlwh[3]
        # Area check first: with a non-negative min_box_area it also guarantees a
        # non-zero height for the ratio below.
        if width * height <= args.min_box_area:
            continue
        if MAX_ASPECT_RATIO is not None and width / height > MAX_ASPECT_RATIO:
            continue
        boxes.append(track.tlbr)
        ids.append(track.track_id)
        scores.append(track.score)
    return boxes, ids, scores


def annotate(scene: np.ndarray, xyxy, class_ids, labels) -> np.ndarray:
    """Draw boxes and text labels onto a copy of `scene` (the input frame is not modified)."""
    detections = sv.Detections(
        # reshape keeps an empty box list valid as shape (0, 4); supervision rejects shape (0,).
        xyxy=np.asarray(xyxy, dtype=float).reshape(-1, 4),
        class_id=np.asarray(class_ids, dtype=int),
    )
    frame = box_annotator.annotate(scene=scene.copy(), detections=detections)
    return label_annotator.annotate(scene=frame, detections=detections, labels=labels)


# A fresh tracker per run, so re-executing this cell does not carry tracks over from a
# previous pass through the bag.
tracker = BYTETracker(args)

try:
    with rosbag.Bag(bag_file, "r") as bag:
        total_messages = bag.get_message_count(IMAGE_TOPIC)
        for _topic, msg, _stamp in tqdm(
            bag.read_messages(topics=[IMAGE_TOPIC]),
            total=total_messages,
            desc="Processing messages",
        ):
            # cv_bridge decodes to BGR; the detector and annotators are fed RGB frames.
            image = cv2.cvtColor(
                bridge.imgmsg_to_cv2(msg, desired_encoding="bgr8"), cv2.COLOR_BGR2RGB
            )
            img_size = image.shape[:2]  # (height, width)

            # Detection.
            class_names, input_boxes, confidences = detect_objects(image)
            annotated_frame = annotate(
                image,
                input_boxes,
                class_ids_for(class_names),
                [f"{name} {score:.2f}" for name, score in zip(class_names, confidences)],
            )

            # Tracking: (n, 5) rows of x1, y1, x2, y2, score.
            detection_data = np.hstack((input_boxes, confidences.reshape(-1, 1)))
            # The same size is passed as img_info and img_size, presumably so the tracker
            # does not rescale the boxes (TODO confirm against byte_tracker).
            online_targets = tracker.update(detection_data, img_size, img_size)
            track_boxes, track_ids, track_scores = select_tracks(online_targets)
            annotated_frame_1 = annotate(
                image,
                track_boxes,
                track_ids,  # track id doubles as class id, so boxes are colored per track
                [f"{tid} {score:.2f}" for tid, score in zip(track_ids, track_scores)],
            )

            # Detections on top, tracks below; convert back to BGR for OpenCV display.
            final_frame = cv2.cvtColor(
                np.vstack((annotated_frame, annotated_frame_1)), cv2.COLOR_RGB2BGR
            )
            cv2.imshow("Annotated Frame", final_frame)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
finally:
    # Close the window even when the loop is interrupted or raises.
    cv2.destroyAllWindows()


Processing messages:   1%|          | 17/2972 [00:02<06:44,  7.31it/s]
