Skip to content

Commit

Permalink
Add publisher matching checks to Subscriptions (#106)
Browse files Browse the repository at this point in the history
Signed-off-by: Michel Hidalgo <mhidalgo@theaiinstitute.com>
  • Loading branch information
mhidalgo-bdai committed Jun 12, 2024
1 parent 2c14e4b commit 2224e6d
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 18 deletions.
54 changes: 51 additions & 3 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,56 @@ def __init__(
self._topic_name = topic_name
self._node = node

@property
def subscriber(self) -> Subscriber:
"""Gets the underlying subscriber.
Type-casted alias of `Subscription.link`.
"""
return cast(Subscriber, self.link)

def publisher_matches(self, num_publishers: int) -> Future:
"""Gets a future to next publisher matching status update.
Note that in ROS 2 Humble and ealier distributions, this method relies on
polling the number of known publishers for the topic subscribed, as subscription
matching events are missing.
Args:
num_publishers: lower bound on the number of publishers to match.
Returns:
a future, done if the current number of publishers already matches
the specified lower bound.
"""
future_match = Future()
num_matched_publishers = self._node.count_publishers(self._topic_name)
if num_matched_publishers < num_publishers:

def _poll_publisher_matches() -> None:
nonlocal future_match, num_publishers
if future_match.cancelled():
return
num_matched_publishers = self._node.count_publishers(self._topic_name)
if num_publishers <= num_matched_publishers:
future_match.set_result(num_matched_publishers)

timer = self._node.create_timer(0.1, _poll_publisher_matches)
future_match.add_done_callback(lambda _: self._node.destroy_timer(timer))
else:
future_match.set_result(num_matched_publishers)
return future_match

@property
def matched_publishers(self) -> int:
"""Gets the number publishers matched and linked to.
Note that in ROS 2 Humble and earlier distributions, this property
relies on the number of known publishers for the topic subscribed
as subscription matching status info is missing.
"""
return self._node.count_publishers(self._topic_name)

@property
def message_type(self) -> Type[MessageT]:
"""Gets the type of the message subscribed."""
Expand All @@ -77,9 +127,7 @@ def topic_name(self) -> str:

def close(self) -> None:
"""Closes the subscription."""
self._node.destroy_subscription(
cast(Subscriber, self.link).sub,
)
self._node.destroy_subscription(self.subscriber.sub)
super().close()

# Aliases for improved readability
Expand Down
21 changes: 6 additions & 15 deletions bdai_ros2_wrappers/test/test_logging.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# Copyright (c) 2023 Boston Dynamics AI Institute Inc. All rights reserved.
import logging
import threading
from typing import List, Optional
from typing import List

from rcl_interfaces.msg import Log
from rclpy.clock import ROSClock
from rclpy.task import Future
from rclpy.time import Time

from bdai_ros2_wrappers.futures import wait_for_future
from bdai_ros2_wrappers.futures import unwrap_future
from bdai_ros2_wrappers.logging import LoggingSeverity, as_memoizing_logger, logs_to_ros
from bdai_ros2_wrappers.scope import ROSAwareScope
from bdai_ros2_wrappers.subscription import Subscription


def test_memoizing_logger(verbose_ros: ROSAwareScope) -> None:
Expand Down Expand Up @@ -76,26 +76,17 @@ def all_messages_arrived() -> bool:


def test_log_forwarding(verbose_ros: ROSAwareScope) -> None:
future: Optional[Future] = None

def callback(message: Log) -> None:
nonlocal future
if future and not future.done():
future.set_result(message)

assert verbose_ros.node is not None
verbose_ros.node.create_subscription(Log, "/rosout", callback, 10)
rosout = Subscription(Log, "/rosout", 10, node=verbose_ros.node)
assert unwrap_future(rosout.publisher_matches(1), timeout_sec=5.0) > 0

future = Future()
with logs_to_ros(verbose_ros.node):
logger = logging.getLogger("my_logger")
logger.setLevel(logging.INFO)
logger.propagate = True # ensure propagation is enabled
logger.info("test")

assert wait_for_future(future, timeout_sec=10)
assert future.done()
log = future.result()
log = unwrap_future(rosout.update, timeout_sec=5.0)
# NOTE(hidmic) why are log levels of bytestring type !?
assert log.level == int.from_bytes(Log.INFO, byteorder="little")
assert log.name == verbose_ros.node.get_logger().name
Expand Down
22 changes: 22 additions & 0 deletions bdai_ros2_wrappers/test/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,27 @@ def deferred_publish() -> None:
assert message.data == 1


def test_subscription_matching_publishers(ros: ROSAwareScope) -> None:
"""Asserts that checking for publisher matching on a subscription works as expected."""
assert ros.node is not None
sequence = Subscription(Int8, "sequence", DEFAULT_QOS_PROFILE, node=ros.node)
assert sequence.matched_publishers == 0
future = sequence.publisher_matches(1)
assert not future.done()
future.cancel()

ros.node.create_publisher(Int8, "sequence", DEFAULT_QOS_PROFILE)
assert wait_for_future(sequence.publisher_matches(1), timeout_sec=5.0)
assert sequence.matched_publishers == 1


def test_subscription_future_wait(ros: ROSAwareScope) -> None:
"""Asserts that waiting for a subscription update works as expected."""
assert ros.node is not None
pub = ros.node.create_publisher(Int8, "sequence", DEFAULT_QOS_PROFILE)
sequence = Subscription(Int8, "sequence", DEFAULT_QOS_PROFILE, node=ros.node)
assert wait_for_future(sequence.publisher_matches(1), timeout_sec=5.0)
assert sequence.matched_publishers == 1

pub.publish(Int8(data=1))

Expand All @@ -53,6 +69,8 @@ def test_subscription_matching_future_wait(ros: ROSAwareScope) -> None:
assert ros.node is not None
pub = ros.node.create_publisher(Int8, "sequence", DEFAULT_QOS_PROFILE)
sequence = Subscription(Int8, "sequence", DEFAULT_QOS_PROFILE, node=ros.node)
assert wait_for_future(sequence.publisher_matches(1), timeout_sec=5.0)
assert sequence.matched_publishers == 1

def deferred_publish() -> None:
time.sleep(0.5)
Expand Down Expand Up @@ -84,6 +102,8 @@ def test_subscription_iteration(ros: ROSAwareScope) -> None:
history_length=3,
node=ros.node,
)
assert wait_for_future(sequence.publisher_matches(1), timeout_sec=5.0)
assert sequence.matched_publishers == 1

expected_sequence_numbers = [1, 10, 100]

Expand All @@ -108,6 +128,8 @@ def test_subscription_cancelation(ros: ROSAwareScope) -> None:
assert ros.node is not None
pub = ros.node.create_publisher(Int8, "sequence", DEFAULT_QOS_PROFILE)
sequence = Subscription(Int8, "sequence", DEFAULT_QOS_PROFILE, node=ros.node)
assert wait_for_future(sequence.publisher_matches(1), timeout_sec=5.0)
assert sequence.matched_publishers == 1

pub.publish(Int8(data=1))

Expand Down

0 comments on commit 2224e6d

Please sign in to comment.