Skip to content

Commit

Permalink
Add support for adapted message feeds (#105)
Browse files Browse the repository at this point in the history
Signed-off-by: Michel Hidalgo <mhidalgo@theaiinstitute.com>
  • Loading branch information
mhidalgo-bdai committed Jun 11, 2024
1 parent 93ecf4a commit 2c14e4b
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 4 deletions.
36 changes: 33 additions & 3 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from rclpy.task import Future

import bdai_ros2_wrappers.scope as scope
from bdai_ros2_wrappers.filters import TransformFilter
from bdai_ros2_wrappers.filters import SimpleAdapter, TransformFilter
from bdai_ros2_wrappers.utilities import Tape


Expand All @@ -27,8 +27,7 @@ def __init__(
Args:
link: Wrapped message filter, connecting this message feed with its source.
history_length: optional historic data size, defaults to 1
node: optional node for the underlying native subscription, defaults to
the current process node.
node: optional node for lifetime control, defaults to the current process node.
"""
if node is None:
node = scope.ensure_node()
Expand Down Expand Up @@ -103,6 +102,37 @@ def close(self) -> None:
self._tape.close()


class AdaptedMessageFeed(MessageFeed):
"""A message feed decorator to simplify adapter patterns."""

def __init__(
self,
feed: MessageFeed,
fn: Callable,
**kwargs: Any,
) -> None:
"""Initializes the message feed.
Args:
feed: the upstream (ie. decorated) message feed.
fn: message adapting callable.
kwargs: all other keyword arguments are forwarded
for `MessageFeed` initialization.
"""
super().__init__(SimpleAdapter(feed.link, fn), **kwargs)
self._feed = feed

@property
def feed(self) -> MessageFeed:
"""Gets the upstream message feed."""
return self._feed

def close(self) -> None:
"""Closes this message feed and the upstream one as well."""
self._feed.close()
super().close()


class FramedMessageFeed(MessageFeed):
"""A message feed decorator, incorporating transforms using a `TransformFilter` instance."""

Expand Down
21 changes: 20 additions & 1 deletion bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import functools
import threading
from collections.abc import Sequence
from typing import Any, Optional
from typing import Any, Callable, Optional

import tf2_ros
from message_filters import SimpleFilter
Expand Down Expand Up @@ -115,3 +115,22 @@ def add(self, *messages: Any) -> None:
)
self._ongoing_wait_time = time
self._ongoing_wait.add_done_callback(functools.partial(self._wait_callback, messages))


class SimpleAdapter(SimpleFilter):
"""A message filter for data adaptation."""

def __init__(self, f: SimpleFilter, fn: Callable) -> None:
"""Initializes the adapter.
Args:
f: the upstream message filter.
fn: adapter implementation as a callable.
"""
super().__init__()
self.do_adapt = fn
self.incoming_connection = f.registerCallback(self.add)

def add(self, *messages: Any) -> None:
"""Adds new `messages` to the adapter."""
self.signalMessage(self.do_adapt(*messages))
21 changes: 21 additions & 0 deletions bdai_ros2_wrappers/test/test_feeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from message_filters import SimpleFilter

from bdai_ros2_wrappers.feeds import (
AdaptedMessageFeed,
FramedMessageFeed,
MessageFeed,
SynchronizedMessageFeed,
Expand Down Expand Up @@ -73,3 +74,23 @@ def test_synchronized_message_feed(ros: ROSAwareScope) -> None:
pose_message, twist_message = ensure(synchronized_message_feed.latest)
assert pose_message.pose.position.x == expected_pose_message.pose.position.x
assert twist_message.twist.linear.x == expected_twist_message.twist.linear.x


def test_adapted_message_feed(ros: ROSAwareScope) -> None:
pose_message_feed = MessageFeed(SimpleFilter())
position_message_feed = AdaptedMessageFeed(
pose_message_feed,
fn=lambda message: message.pose.position,
)

expected_pose_message = PoseStamped()
expected_pose_message.header.frame_id = "odom"
expected_pose_message.header.stamp.sec = 1
expected_pose_message.pose.position.x = 1.0
expected_pose_message.pose.position.z = -1.0
expected_pose_message.pose.orientation.w = 1.0
pose_message_feed.link.signalMessage(expected_pose_message)

position_message = ensure(position_message_feed.latest)
# no copies are expected, thus an identity check is valid
assert position_message is expected_pose_message.pose.position

0 comments on commit 2c14e4b

Please sign in to comment.