diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py index 3a4c865..c8c7494 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py @@ -8,7 +8,7 @@ from rclpy.task import Future import bdai_ros2_wrappers.scope as scope -from bdai_ros2_wrappers.filters import TransformFilter +from bdai_ros2_wrappers.filters import SimpleAdapter, TransformFilter from bdai_ros2_wrappers.utilities import Tape @@ -27,8 +27,7 @@ def __init__( Args: link: Wrapped message filter, connecting this message feed with its source. history_length: optional historic data size, defaults to 1 - node: optional node for the underlying native subscription, defaults to - the current process node. + node: optional node for lifetime control, defaults to the current process node. """ if node is None: node = scope.ensure_node() @@ -103,6 +102,37 @@ def close(self) -> None: self._tape.close() +class AdaptedMessageFeed(MessageFeed): + """A message feed decorator to simplify adapter patterns.""" + + def __init__( + self, + feed: MessageFeed, + fn: Callable, + **kwargs: Any, + ) -> None: + """Initializes the message feed. + + Args: + feed: the upstream (ie. decorated) message feed. + fn: message adapting callable. + kwargs: all other keyword arguments are forwarded + for `MessageFeed` initialization. + """ + super().__init__(SimpleAdapter(feed.link, fn), **kwargs) + self._feed = feed + + @property + def feed(self) -> MessageFeed: + """Gets the upstream message feed.""" + return self._feed + + def close(self) -> None: + """Closes this message feed and the upstream one as well.""" + self._feed.close() + super().close() + + class FramedMessageFeed(MessageFeed): """A message feed decorator, incorporating transforms using a `TransformFilter` instance.""" diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py index dd4fb80..e58b4c2 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py @@ -4,7 +4,7 @@ import functools import threading from collections.abc import Sequence -from typing import Any, Optional +from typing import Any, Callable, Optional import tf2_ros from message_filters import SimpleFilter @@ -115,3 +115,22 @@ def add(self, *messages: Any) -> None: ) self._ongoing_wait_time = time self._ongoing_wait.add_done_callback(functools.partial(self._wait_callback, messages)) + + +class SimpleAdapter(SimpleFilter): + """A message filter for data adaptation.""" + + def __init__(self, f: SimpleFilter, fn: Callable) -> None: + """Initializes the adapter. + + Args: + f: the upstream message filter. + fn: adapter implementation as a callable. + """ + super().__init__() + self.do_adapt = fn + self.incoming_connection = f.registerCallback(self.add) + + def add(self, *messages: Any) -> None: + """Adds new `messages` to the adapter.""" + self.signalMessage(self.do_adapt(*messages)) diff --git a/bdai_ros2_wrappers/test/test_feeds.py b/bdai_ros2_wrappers/test/test_feeds.py index eb6f285..6450c7b 100644 --- a/bdai_ros2_wrappers/test/test_feeds.py +++ b/bdai_ros2_wrappers/test/test_feeds.py @@ -10,6 +10,7 @@ from message_filters import SimpleFilter from bdai_ros2_wrappers.feeds import ( + AdaptedMessageFeed, FramedMessageFeed, MessageFeed, SynchronizedMessageFeed, @@ -73,3 +74,23 @@ def test_synchronized_message_feed(ros: ROSAwareScope) -> None: pose_message, twist_message = ensure(synchronized_message_feed.latest) assert pose_message.pose.position.x == expected_pose_message.pose.position.x assert twist_message.twist.linear.x == expected_twist_message.twist.linear.x + + +def test_adapted_message_feed(ros: ROSAwareScope) -> None: + pose_message_feed = MessageFeed(SimpleFilter()) + position_message_feed = AdaptedMessageFeed( + pose_message_feed, + fn=lambda message: message.pose.position, + ) + + expected_pose_message = PoseStamped() + expected_pose_message.header.frame_id = "odom" + expected_pose_message.header.stamp.sec = 1 + expected_pose_message.pose.position.x = 1.0 + expected_pose_message.pose.position.z = -1.0 + expected_pose_message.pose.orientation.w = 1.0 + pose_message_feed.link.signalMessage(expected_pose_message) + + position_message = ensure(position_message_feed.latest) + # no copies are expected, thus an identity check is valid + assert position_message is expected_pose_message.pose.position