In [None]:
#!/usr/bin/env python
import sys
import threading
import time

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import rospy
from centermask.config import get_cfg
from detectron2.data import MetadataCatalog
from cv_bridge import CvBridge, CvBridgeError
# import some common detectron2 utilities
from detectron2.engine import DefaultPredictor
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer
from detectron2_ros.msg import Result
from sensor_msgs.msg import Image, RegionOfInterest

import centermask.modeling.backbone

In [None]:
def load_param(param, default=None):
    """Read a ROS parameter, log what was resolved, and return it.

    Args:
        param: Name of the parameter, passed unchanged to ``rospy.get_param``.
        default: Value handed to ``rospy.get_param`` as the fallback.

    Returns:
        Whatever ``rospy.get_param(param, default)`` returned.
    """
    value = rospy.get_param(param, default)
    # Log every lookup so the effective configuration is visible in the node output.
    rospy.loginfo("[Centermask2] %s: %s", param, value)
    return value

def convert_to_cv_image(image_msg):
    """Decode a ``sensor_msgs/Image``-style message into a 3-channel BGR array.

    Supported encodings (case-insensitive): ``rgb8``, ``bgr8``, ``mono8`` and
    ``32fc1``. Colour images are returned in BGR channel order; single-channel
    images are replicated to three identical channels so that every result has
    the same ``(height, width, 3)`` layout.

    Args:
        image_msg: Object exposing ``encoding``, ``width``, ``height`` and
            ``data`` (a bytes-like buffer), and optionally ``step`` (row length
            in bytes). May be None.

    Returns:
        None when ``image_msg`` is None. Otherwise a newly allocated, writable
        array of shape ``(height, width, 3)``: ``uint8`` for the 8-bit
        encodings, ``float32`` for ``32fc1``.

    Raises:
        ValueError: If the encoding is not one of the supported ones.
        TypeError: Raised by numpy when ``data`` is too small for the image.
    """
    if image_msg is None:
        return None

    encoding = image_msg.encoding.lower()
    # Element type and channel count implied by each supported encoding.
    layouts = {
        'rgb8': (np.uint8, 3),
        'bgr8': (np.uint8, 3),
        'mono8': (np.uint8, 1),
        '32fc1': (np.float32, 1),
    }
    if encoding not in layouts:
        raise ValueError("Unsupported image encoding: %r" % (image_msg.encoding,))
    dtype, channels = layouts[encoding]
    dtype = np.dtype(dtype)

    height = int(image_msg.height)
    width = int(image_msg.width)
    pixel_bytes = channels * dtype.itemsize
    packed_row_bytes = width * pixel_bytes
    # Honour row padding when the message reports a usable `step`; otherwise
    # treat the rows as tightly packed.
    row_bytes = getattr(image_msg, 'step', 0) or packed_row_bytes
    if row_bytes < packed_row_bytes:
        row_bytes = packed_row_bytes

    # Zero-copy (read-only) view over the message payload.
    raw = np.ndarray(shape=(height, width, channels), dtype=dtype,
                     buffer=image_msg.data,
                     strides=(row_bytes, pixel_bytes, dtype.itemsize))

    if channels == 1:
        # Grey -> BGR: the same value in all three channels.
        return np.repeat(raw, 3, axis=2)
    if encoding == 'rgb8':
        # Reverse the channel axis (RGB -> BGR) and materialise a contiguous copy.
        return raw[:, :, ::-1].copy()
    # bgr8 is already in the target order; copy so the caller gets a writable array.
    return raw.copy()

def getResult(predictions, header, bridge, class_names):
    """Pack instance predictions into a ``Result`` message.

    Args:
        predictions: Object offering ``has(name)`` plus the attributes
            ``pred_masks`` and, optionally, ``pred_boxes``, ``pred_classes``
            and ``scores``.
        header: Value stored unchanged in ``result_msg.header``.
        bridge: Object whose ``cv2_to_imgmsg(array)`` turns a mask array into
            an image message.
        class_names: Sequence mapping class id -> name, or None.

    Returns:
        None when ``predictions`` carries no ``pred_masks``; otherwise the
        populated ``Result``. Fields whose source data is missing are left at
        the message's default value instead of being set to None.
    """
    if not predictions.has("pred_masks"):
        return None
    masks = np.asarray(predictions.pred_masks)

    result_msg = Result()
    result_msg.header = header

    if predictions.has("pred_classes"):
        class_ids = predictions.pred_classes
        result_msg.class_ids = class_ids
        # Names can only be resolved when a lookup table was provided.
        if class_names is not None:
            result_msg.class_names = np.array(class_names)[class_ids.numpy()]
    if predictions.has("scores"):
        result_msg.scores = predictions.scores

    # One 0/255 uint8 image per instance mask.
    for instance_mask in masks:
        mask = np.zeros(instance_mask.shape, dtype="uint8")
        mask[instance_mask.astype(bool)] = 255
        result_msg.masks.append(bridge.cv2_to_imgmsg(mask))

    # Boxes are optional; each one is unpacked as (x1, y1, x2, y2) corners.
    if predictions.has("pred_boxes"):
        for x1, y1, x2, y2 in predictions.pred_boxes:
            box = RegionOfInterest()
            box.x_offset = np.uint32(x1)
            box.y_offset = np.uint32(y1)
            box.height = np.uint32(y2 - y1)
            box.width = np.uint32(x2 - x1)
            result_msg.boxes.append(box)

    return result_msg


In [None]:
# Register this process as the ROS node 'centermask2_ros'.
rospy.init_node('centermask2_ros')

bridge = CvBridge()
# Shared state for the image subscriber: the latest frame and the lock guarding it.
_last_msg = None
_msg_lock = threading.Lock()
_image_counter = 0

# Model selection. The lite V-39 model is used here; the heavier alternative is
#   /root/centermask2/configs/centermask/centermask_V_99_eSE_FPN_ms_3x.yaml
#   /root/centermask2-V-99-eSE-FPN-ms-3x.pth
# and both values could instead be read with load_param('~config') / load_param('~model').
CONFIG_FILE = "/root/centermask2/configs/centermask/centermask_lite_V_39_eSE_FPN_ms_4x.yaml"
WEIGHTS_FILE = "/root/centermask2-lite-V-39-eSE-FPN-ms-4x.pth"

cfg = get_cfg()
cfg.merge_from_file(CONFIG_FILE)
# load_param returns None when the parameter is not set; in that case keep the
# threshold that came from the config file instead of overwriting it with None.
_detection_threshold = load_param('~detection_threshold')
if _detection_threshold is not None:
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = _detection_threshold  # set threshold for this model
cfg.MODEL.WEIGHTS = WEIGHTS_FILE
predictor = DefaultPredictor(cfg)
# Class-id -> name table of the first training dataset (None if it has no "thing_classes").
_class_names = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).get("thing_classes", None)

_visualization = load_param('~visualization',True)
_result_pub = rospy.Publisher('~result', Result, queue_size=1)
_vis_pub = rospy.Publisher('~visualization', Image, queue_size=1)


def callback_image(msg):
    """Subscriber callback: remember the newest image message and its header.

    Updates the module-level ``_last_msg`` and ``_header``. The lock is taken
    without blocking, so a frame that arrives while another holder owns
    ``_msg_lock`` is dropped rather than making the callback wait.

    Args:
        msg: Incoming message; its ``header`` attribute is stored separately.
    """
    # Without this declaration the assignments below would only create locals
    # and the shared state would never change.
    global _last_msg, _header
    rospy.logdebug("Get an image")
    if not _msg_lock.acquire(False):
        return
    try:
        _last_msg = msg
        _header = msg.header
    finally:
        # Always release, even if reading msg.header raises.
        _msg_lock.release()
        
# Resolve the input topic once so the parameter is read (and logged) a single time.
_input_topic = load_param('~input')
_sub = rospy.Subscriber(_input_topic, Image, callback_image, queue_size=1)

# Block until one frame is available and work on that exact message, so the
# frame being processed cannot change if the subscriber callback updates the
# shared state in the meantime.
_first_msg = rospy.wait_for_message(_input_topic, Image, timeout=None)
with _msg_lock:
    _last_msg = _first_msg
    _header = _first_msg.header

np_image = convert_to_cv_image(_first_msg)
start_time = time.time()
outputs = predictor(np_image)
finish_time = time.time()
print("Time to predict:" + str(finish_time- start_time))

start_time = time.time()
result = outputs["instances"].to("cpu")
finish_time = time.time()
print("CPU transfer:" + str(finish_time- start_time))

start_time = time.time()
result_msg = getResult(result, _first_msg.header, bridge, _class_names)
finish_time = time.time()
print("Results msg time:" + str(finish_time- start_time))

# _result_pub.publish(result_msg)

# Visualize results
start_time = time.time()
# The channel axis is reversed before drawing and reversed back afterwards
# (BGR <-> RGB, given that convert_to_cv_image returns BGR for rgb8 input).
v = Visualizer(np_image[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
# Reuse the instances already moved to the CPU instead of transferring them again.
v = v.draw_instance_predictions(result)
img = v.get_image()[:, :, ::-1]
image_msg = bridge.cv2_to_imgmsg(img)
finish_time = time.time()
print("Visualisation time:" + str(finish_time- start_time))

# matplotlib interprets 3-channel arrays as RGB, so flip the BGR image back for display.
plt.imshow(img[:, :, ::-1])
# _vis_pub.publish(image_msg)