From baa4ff8faa76dea10c702e6f150334cec6e87a4e Mon Sep 17 00:00:00 2001 From: Jihoon Oh Date: Sun, 14 Jan 2024 22:03:24 +0900 Subject: [PATCH] update --- node_scripts/deva_node.py | 65 +++++++++++++++++++++++++++------------ 1 file changed, 46 insertions(+), 19 deletions(-) diff --git a/node_scripts/deva_node.py b/node_scripts/deva_node.py index e211288..83c281c 100644 --- a/node_scripts/deva_node.py +++ b/node_scripts/deva_node.py @@ -2,8 +2,9 @@ # -*- coding: utf-8 -*- import numpy as np +import cv2 import torch -import torch.nn.functional as F +import torchvision import supervision as sv import rospy @@ -16,7 +17,7 @@ from jsk_recognition_msgs.msg import Label, LabelArray from deva.dataset.utils import im_normalization -from deva.ext.grounding_dino import segment_with_text +from deva.inference.object_info import ObjectInfo from model_config import SAMConfig, GroundingDINOConfig, DEVAConfig from utils import overlay_davis @@ -74,24 +75,50 @@ def publish_result(self, mask, vis, frame_id): def callback(self, img_msg): self.image = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding="rgb8") with torch.cuda.amp.autocast(enabled=self.cfg["amp"]): - h, w = self.image.shape[:2] - deva_input = F.interpolate( - im_normalization(torch.from_numpy(self.image).permute(2, 0, 1).float() / 255) - .unsqueeze(0) - .to(self.deva_config.device), - (h, w), - mode="bilinear", - align_corners=False, - )[0] - if self.cnt % self.cfg["detection_every"] == 0: - incorporate_mask, segments_info = segment_with_text( - self.cfg, - self.gd_predictor, - self.sam_predictor, - self.image, - self.classes, - min(h, w), + torch_image = im_normalization(torch.from_numpy(self.image).permute(2, 0, 1).float() / 255) + deva_input = torch_image.to(self.deva_config.device) + if self.cnt % self.cfg["detection_every"] == 0: # object detection query + self.sam_predictor.set_image(self.image, image_format="RGB") + # detect objects with GroundingDINO + detections = self.gd_predictor.predict_with_classes( + image=cv2.cvtColor(self.image, cv2.COLOR_RGB2BGR), + classes=self.classes, + box_threshold=self.cfg["DINO_THRESHOLD"], + text_threshold=self.cfg["DINO_THRESHOLD"], ) + nms_idx = ( + torchvision.ops.nms( + torch.from_numpy(detections.xyxy), + torch.from_numpy(detections.confidence), + self.cfg["DINO_NMS_THRESHOLD"], + ) + .numpy() + .tolist() + ) + detections.xyxy = detections.xyxy[nms_idx] + detections.confidence = detections.confidence[nms_idx] + detections.class_id = detections.class_id[nms_idx] + # segment objects with SAM + result_masks = [] + for box in detections.xyxy: + masks, scores, _ = self.sam_predictor.predict(box=box, multimask_output=True) + index = np.argmax(scores) + result_masks.append(masks[index]) + detections.mask = np.array(result_masks) + incorporate_mask = torch.zeros(self.image.shape[:2], dtype=torch.int64, device=self.gd_predictor.device) + curr_id = 1 + segments_info = [] + # sort by descending area to preserve the smallest object + for i in np.flip(np.argsort(detections.area)): + mask = detections.mask[i] + confidence = detections.confidence[i] + class_id = detections.class_id[i] + mask = torch.from_numpy(mask.astype(np.float32)) + mask = (mask > 0.5).float() + if mask.sum() > 0: + incorporate_mask[mask > 0] = curr_id + segments_info.append(ObjectInfo(id=curr_id, category_id=class_id, score=confidence)) + curr_id += 1 prob = self.deva_predictor.incorporate_detection(deva_input, incorporate_mask, segments_info) self.object_ids = [seg.id for seg in segments_info] self.category_ids = [seg.category_ids[0] for seg in segments_info]