Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh6404 committed Jan 14, 2024
1 parent 46a90ab commit 9978fcc
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions node_scripts/deva_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import torch
import torch.nn.functional as F
import supervision as sv

import rospy
Expand All @@ -14,7 +15,7 @@
from jsk_recognition_msgs.msg import ClassificationResult
from jsk_recognition_msgs.msg import Label, LabelArray

from deva.inference.demo_utils import get_input_frame_for_deva
from deva.dataset.utils import im_normalization
from deva.ext.grounding_dino import segment_with_text

from model_config import SAMConfig, GroundingDINOConfig, DEVAConfig
Expand Down Expand Up @@ -73,16 +74,17 @@ def publish_result(self, mask, vis, frame_id):
def callback(self, img_msg):
self.image = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding="rgb8")
with torch.cuda.amp.autocast(enabled=self.cfg["amp"]):
min_size = min(self.image.shape[:2])
deva_input = get_input_frame_for_deva(self.image, min_size).to(self.deva_config.device)
h, w = self.image.shape[:2]
deva_input = im_normalization(torch.from_numpy(self.image).permute(2, 0, 1).float() / 255).unsqueeze(0).to(self.deva_config.device)
deva_input = F.interpolate(deva_input, (h, w), mode='bilinear', align_corners=False)[0]
if self.cnt % self.cfg["detection_every"] == 0:
incorporate_mask, segments_info = segment_with_text(
self.cfg,
self.gd_predictor,
self.sam_predictor,
self.image,
self.classes,
min_size,
min(h, w),
)
prob = self.deva_predictor.incorporate_detection(deva_input, incorporate_mask, segments_info)
self.object_ids = [seg.id for seg in segments_info]
Expand Down

0 comments on commit 9978fcc

Please sign in to comment.