Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh6404 committed Jan 14, 2024
1 parent dc88f6b commit baa4ff8
Showing 1 changed file with 46 additions and 19 deletions.
65 changes: 46 additions & 19 deletions node_scripts/deva_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# -*- coding: utf-8 -*-

import numpy as np
import cv2
import torch
import torch.nn.functional as F
import torchvision
import supervision as sv

import rospy
Expand All @@ -16,7 +17,7 @@
from jsk_recognition_msgs.msg import Label, LabelArray

from deva.dataset.utils import im_normalization
from deva.ext.grounding_dino import segment_with_text
from deva.inference.object_info import ObjectInfo

from model_config import SAMConfig, GroundingDINOConfig, DEVAConfig
from utils import overlay_davis
Expand Down Expand Up @@ -74,24 +75,50 @@ def publish_result(self, mask, vis, frame_id):
def callback(self, img_msg):
self.image = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding="rgb8")
with torch.cuda.amp.autocast(enabled=self.cfg["amp"]):
h, w = self.image.shape[:2]
deva_input = F.interpolate(
im_normalization(torch.from_numpy(self.image).permute(2, 0, 1).float() / 255)
.unsqueeze(0)
.to(self.deva_config.device),
(h, w),
mode="bilinear",
align_corners=False,
)[0]
if self.cnt % self.cfg["detection_every"] == 0:
incorporate_mask, segments_info = segment_with_text(
self.cfg,
self.gd_predictor,
self.sam_predictor,
self.image,
self.classes,
min(h, w),
torch_image = im_normalization(torch.from_numpy(self.image).permute(2, 0, 1).float() / 255)
deva_input = torch_image.to(self.deva_config.device)
if self.cnt % self.cfg["detection_every"] == 0: # object detection query
self.sam_predictor.set_image(self.image, image_format="RGB")
# detect objects with GroundingDINO
detections = self.gd_predictor.predict_with_classes(
image=cv2.cvtColor(self.image, cv2.COLOR_RGB2BGR),
classes=self.classes,
box_threshold=self.cfg["DINO_THRESHOLD"],
text_threshold=self.cfg["DINO_THRESHOLD"],
)
nms_idx = (
torchvision.ops.nms(
torch.from_numpy(detections.xyxy),
torch.from_numpy(detections.confidence),
self.cfg["DINO_NMS_THRESHOLD"],
)
.numpy()
.tolist()
)
detections.xyxy = detections.xyxy[nms_idx]
detections.confidence = detections.confidence[nms_idx]
detections.class_id = detections.class_id[nms_idx]
# segment objects with SAM
result_masks = []
for box in detections.xyxy:
masks, scores, _ = self.sam_predictor.predict(box=box, multimask_output=True)
index = np.argmax(scores)
result_masks.append(masks[index])
detections.mask = np.array(result_masks)
incorporate_mask = torch.zeros(self.image.shape[:2], dtype=torch.int64, device=self.gd_predictor.device)
curr_id = 1
segments_info = []
# sort by descending area to preserve the smallest object
for i in np.flip(np.argsort(detections.area)):
mask = detections.mask[i]
confidence = detections.confidence[i]
class_id = detections.class_id[i]
mask = torch.from_numpy(mask.astype(np.float32))
mask = (mask > 0.5).float()
if mask.sum() > 0:
incorporate_mask[mask > 0] = curr_id
segments_info.append(ObjectInfo(id=curr_id, category_id=class_id, score=confidence))
curr_id += 1
prob = self.deva_predictor.incorporate_detection(deva_input, incorporate_mask, segments_info)
self.object_ids = [seg.id for seg in segments_info]
self.category_ids = [seg.category_ids[0] for seg in segments_info]
Expand Down

0 comments on commit baa4ff8

Please sign in to comment.