diff --git a/examples/async_node/README.md b/examples/async_node/README.md new file mode 100644 index 0000000..f259214 --- /dev/null +++ b/examples/async_node/README.md @@ -0,0 +1,45 @@ +# Async Node samples + +The following samples are based off the ROS2 documented [tutorials](https://docs.ros.org/en/humble/Tutorials.html). + +## Usage + +*All scripts assume terminals are open in the root of this repository and have installed `rclpy_async`* + +A simple publisher and subscriber: +```sh +# Terminal 1: +python3 -m examples.async_node.subscriber + +# Terminal 2: +python3 -m examples.async_node.publisher +``` + +A simple service and client: +```sh +# Terminal 1: +python3 -m examples.async_node.service + +# Terminal 2: +python3 -m examples.async_node.client 1 2 +``` + +Using parameters in a class: +```sh +# Start the node +python3 -m examples.async_node.param +# [Optionally] start the node with the parameter as something other that "world" by adding: +python3 -m examples.async_node.param --ros-args -p my_parameter:=world2 + +# In another terminal set the parameter again +ros2 param set /minimal_param_node my_parameter earth +``` + +An action server and client: +```sh +# Terminal 1: +python3 -m examples.async_node.action_server + +# Terminal 2: +python3 -m examples.async_node.action_client +``` diff --git a/examples/async_node/action_client.py b/examples/async_node/action_client.py new file mode 100644 index 0000000..a1f4011 --- /dev/null +++ b/examples/async_node/action_client.py @@ -0,0 +1,31 @@ +from typing import Callable + +import anyio +import rclpy +from example_interfaces.action import Fibonacci + +import rclpy_async +from rclpy_async import AsyncNode + +node = AsyncNode("fibonacci_action_client") + + +async def main(): + rclpy.init() + node.initialize() + + async with rclpy_async.start_executor() as xtor: + xtor.add_node(node) + + with rclpy_async.action_client(node, Fibonacci, "fibonacci") as action_client: + print_feedback: Callable[[Fibonacci.Impl.FeedbackMessage], None] = ( + lambda msg: node.get_logger().info( + f"Fibonacci feedback: {msg.feedback}" + ) + ) + result = await action_client(Fibonacci.Goal(order=10), print_feedback) + node.get_logger().info(f"Fibonacci result: {result}") + + +if __name__ == "__main__": + anyio.run(main) diff --git a/examples/async_node/action_server.py b/examples/async_node/action_server.py new file mode 100644 index 0000000..cbf4964 --- /dev/null +++ b/examples/async_node/action_server.py @@ -0,0 +1,44 @@ +import anyio +import rclpy +from example_interfaces.action import Fibonacci +from rclpy.action.server import ServerGoalHandle + +import rclpy_async +from rclpy_async import AsyncNode, async_run + +node = AsyncNode("fibonacci_action_server_node") + + +@node.action(Fibonacci, "fibonacci") +async def execute_callback(goal_handle: ServerGoalHandle): + node.get_logger().info("Executing goal...") + + feedback_msg = Fibonacci.Feedback() + feedback_msg.sequence = [0, 1] + + for i in range(1, goal_handle.request.order): + feedback_msg.sequence.append( + feedback_msg.sequence[i] + feedback_msg.sequence[i - 1] + ) + node.get_logger().info("Feedback: {0}".format(feedback_msg.sequence)) + goal_handle.publish_feedback(feedback_msg) + await anyio.sleep(1) + + goal_handle.succeed() + + result = Fibonacci.Result() + result.sequence = feedback_msg.sequence + return result + + +async def main(): + rclpy.init() + node.initialize() + + async with rclpy_async.start_executor() as xtor: + xtor.add_node(node) + await async_run(node) + + +if __name__ == "__main__": + anyio.run(main) diff --git a/examples/async_node/client.py b/examples/async_node/client.py new file mode 100644 index 0000000..59a344a --- /dev/null +++ b/examples/async_node/client.py @@ -0,0 +1,35 @@ +import sys + +import anyio +import rclpy +from example_interfaces.srv import AddTwoInts + +import rclpy_async +from rclpy_async import AsyncNode + +node = AsyncNode("minimal_client_async") + + +async def main(): + rclpy.init() + node.initialize() + + async with rclpy_async.start_executor() as executor: + executor.add_node(node) + with rclpy_async.service_client( + node, + AddTwoInts, + "add_two_ints", + ) as cli: + req = AddTwoInts.Request() + req.a = int(sys.argv[1]) + req.b = int(sys.argv[2]) + response = await cli(req) + node.get_logger().info( + "Result of add_two_ints: for %d + %d = %d" + % (int(sys.argv[1]), int(sys.argv[2]), response.sum) + ) + + +if __name__ == "__main__": + anyio.run(main) diff --git a/examples/async_node/param.py b/examples/async_node/param.py new file mode 100644 index 0000000..d21adfc --- /dev/null +++ b/examples/async_node/param.py @@ -0,0 +1,43 @@ +import anyio +import rclpy +from rclpy.parameter import Parameter + +import rclpy_async +from rclpy_async import AsyncNode, async_run + +node = AsyncNode("minimal_param_node") + + +@node.timer( + lambda: node.get_parameter("timer_period").get_parameter_value().double_value +) +async def timer_callback(): + node.get_logger().info( + f"Hello {node.get_parameter('my_parameter').get_parameter_value().string_value}!" + ) + node.set_parameters( + [ + Parameter("my_parameter", Parameter.Type.STRING, "world"), + Parameter( + "timer_period", + Parameter.Type.DOUBLE, + node.get_parameter("timer_period").get_parameter_value().double_value, + ), + ] + ) + + +async def main(): + rclpy.init() + node.initialize() + node.declare_parameters( + namespace="", parameters=[("my_parameter", "world"), ("timer_period", 2.0)] + ) + + async with rclpy_async.start_executor() as xtor: + xtor.add_node(node) + await async_run(node) + + +if __name__ == "__main__": + anyio.run(main) diff --git a/examples/async_node/publisher.py b/examples/async_node/publisher.py new file mode 100644 index 0000000..ebdf5c3 --- /dev/null +++ b/examples/async_node/publisher.py @@ -0,0 +1,31 @@ +import anyio +import rclpy +import rclpy_async +from rclpy_async import AsyncNode, async_run +from std_msgs.msg import String + +node = AsyncNode("minimal_publisher") + + +@node.timer(0.5) +async def timer_callback(): + msg = String(data="Hello World: %d" % node.state.i) + node.state.publisher_.publish(msg) + node.get_logger().info('Publishing: "%s"' % msg.data) + node.state.i += 1 + + +async def main(): + rclpy.init() + node.initialize() + + node.state.i = 0 + node.state.publisher_ = node.create_publisher(String, "topic", 10) + + async with rclpy_async.start_executor() as xtor: + xtor.add_node(node) + await async_run(node) + + +if __name__ == "__main__": + anyio.run(main) diff --git a/examples/async_node/service.py b/examples/async_node/service.py new file mode 100644 index 0000000..567d54c --- /dev/null +++ b/examples/async_node/service.py @@ -0,0 +1,29 @@ +import anyio +import rclpy +from example_interfaces.srv import AddTwoInts + +import rclpy_async +from rclpy_async import AsyncNode, async_run + +node = AsyncNode("minimal_service") + + +@node.service(AddTwoInts, "add_two_ints") +async def add_two_ints_callback(request, response): + response.sum = request.a + request.b + node.get_logger().info("Incoming request \ta: %d b: %d" % (request.a, request.b)) + + return response + + +async def main(): + rclpy.init() + node.initialize() + + async with rclpy_async.start_executor() as xtor: + xtor.add_node(node) + await async_run(node) + + +if __name__ == "__main__": + anyio.run(main) diff --git a/examples/async_node/subscriber.py b/examples/async_node/subscriber.py new file mode 100644 index 0000000..55cb776 --- /dev/null +++ b/examples/async_node/subscriber.py @@ -0,0 +1,30 @@ +import anyio +import rclpy +from std_msgs.msg import String + +import rclpy_async +from rclpy_async import AsyncNode, BackpressureHandlerSpec + +node = AsyncNode("minimal_subscriber") + + +@node.subscription( + String, + "topic", + backpressure_handler=BackpressureHandlerSpec(max_queue_size=5, drop_oldest=True), +) +async def listener_callback(msg: String): + node.get_logger().info('I heard: "%s"' % msg.data) + + +async def main(): + rclpy.init() + node.initialize() + + async with rclpy_async.start_executor() as xtor: + xtor.add_node(node) + await rclpy_async.async_run(node) + + +if __name__ == "__main__": + anyio.run(main) diff --git a/src/rclpy_async/__init__.py b/src/rclpy_async/__init__.py index 34eae4b..b53a300 100644 --- a/src/rclpy_async/__init__.py +++ b/src/rclpy_async/__init__.py @@ -1,8 +1,11 @@ from importlib.metadata import version +from rclpy_async._async_node import BackpressureHandlerSpec from rclpy_async.action_client import action_client from rclpy_async.action_server import action_server from rclpy_async.async_executor import start_executor +from rclpy_async.async_node import AsyncNode +from rclpy_async.async_node import run as async_run from rclpy_async.service_client import service_client from rclpy_async.utilities import ( future_result, @@ -21,4 +24,7 @@ "service_client", "action_client", "action_server", + "async_run", + "AsyncNode", + "BackpressureHandlerSpec", ] diff --git a/src/rclpy_async/_async_node/__init__.py b/src/rclpy_async/_async_node/__init__.py new file mode 100644 index 0000000..4ad8621 --- /dev/null +++ b/src/rclpy_async/_async_node/__init__.py @@ -0,0 +1,17 @@ +from .action_handler_spec import ActionHandlerSpec +from .backpressure_handler_spec import BackpressureHandlerSpec +from .node_proto import NodeProto +from .service_handler_spec import ServiceHandlerSpec +from .state import State +from .timer_handler_spec import TimerHandlerSpec +from .topic_handler_spec import TopicHandlerSpec + +__all__ = [ + "ActionHandlerSpec", + "BackpressureHandlerSpec", + "NodeProto", + "ServiceHandlerSpec", + "State", + "TimerHandlerSpec", + "TopicHandlerSpec", +] diff --git a/src/rclpy_async/_async_node/_base_handler_spec.py b/src/rclpy_async/_async_node/_base_handler_spec.py new file mode 100644 index 0000000..03439bc --- /dev/null +++ b/src/rclpy_async/_async_node/_base_handler_spec.py @@ -0,0 +1,23 @@ +from abc import ABC +from dataclasses import dataclass +from typing import Any, Awaitable, Callable, Generic, Optional, TypeVar, ParamSpec + +from .backpressure_handler_spec import BackpressureHandlerSpec + +PInput = ParamSpec("PInput") +TOutput = TypeVar("TOutput") + + +@dataclass +class BaseHandlerSpec(ABC, Generic[PInput, TOutput]): + """Base Handler spec for registering ROS callbacks + + Args: + ABC (_type_): Abstract Base Class + Generic (_type_): Generic class + """ + + backpressure_handler: Optional[BackpressureHandlerSpec] + kwargs: dict[str, Any] + + async_fn: Callable[PInput, Awaitable[TOutput]] diff --git a/src/rclpy_async/_async_node/action_handler_spec.py b/src/rclpy_async/_async_node/action_handler_spec.py new file mode 100644 index 0000000..60530d4 --- /dev/null +++ b/src/rclpy_async/_async_node/action_handler_spec.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass +from typing import Any + +from rclpy.action.server import ServerGoalHandle + +from ._base_handler_spec import BaseHandlerSpec + + +@dataclass +class ActionHandlerSpec(BaseHandlerSpec[[ServerGoalHandle], Any]): + """Specification for a ROS action handler + + async_fn signature: (goal: ServerGoalHandle) -> Awaitable[Any] + + Args: + BaseHandlerSpec (_type_): Inherits from BaseHandlerSpec + """ + + action_type: type + action_name: str diff --git a/src/rclpy_async/_async_node/backpressure_handler_spec.py b/src/rclpy_async/_async_node/backpressure_handler_spec.py new file mode 100644 index 0000000..ab7ca74 --- /dev/null +++ b/src/rclpy_async/_async_node/backpressure_handler_spec.py @@ -0,0 +1,9 @@ +from dataclasses import dataclass + + +@dataclass +class BackpressureHandlerSpec: + """Specification for a backpressure handler""" + + max_queue_size: int = 0 + drop_oldest: bool = True diff --git a/src/rclpy_async/_async_node/node_proto.py b/src/rclpy_async/_async_node/node_proto.py new file mode 100644 index 0000000..9003b59 --- /dev/null +++ b/src/rclpy_async/_async_node/node_proto.py @@ -0,0 +1,372 @@ +"""Abstract prototype for a ROS 2 Node. + +This file intentionally strips all implementation details from the original +`rclpy.node.Node` class and leaves only the public interface shape so it can be +used as a base (InnerProto) for delegation or inheritance without bringing in +ROS graph side-effects. Each method raises NotImplementedError and MUST be +implemented by a concrete subclass or provided via composition. + +Only a subset of the original interface is preserved here (the commonly used +APIs). Extend as needed by adding further stubs mirroring upstream signatures. +""" + +from __future__ import annotations + +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Optional, + Protocol, + Sequence, + Tuple, + Type, + TypeVar, + Union, +) + +from rcl_interfaces.msg import ParameterDescriptor, ParameterValue, SetParametersResult +from rclpy.callback_groups import CallbackGroup +from rclpy.client import Client +from rclpy.clock import Clock +from rclpy.context import Context +from rclpy.executors import Executor +from rclpy.guard_condition import GuardCondition +from rclpy.impl.rcutils_logger import RcutilsLogger +from rclpy.parameter import Parameter + +# ROS 2 interface types (mirroring rclpy.node.Node signatures) +from rclpy.publisher import Publisher +from rclpy.qos import QoSProfile +from rclpy.qos_event import PublisherEventCallbacks, SubscriptionEventCallbacks +from rclpy.qos_overriding_options import QoSOverridingOptions +from rclpy.service import Service +from rclpy.subscription import Subscription +from rclpy.timer import Rate, Timer +from rclpy.topic_endpoint_info import TopicEndpointInfo +from rclpy.waitable import Waitable + +# Type variables (kept for signature parity; concrete types resolved in subclasses) +MsgType = TypeVar("MsgType") +SrvType = TypeVar("SrvType") +SrvTypeRequest = TypeVar("SrvTypeRequest") +SrvTypeResponse = TypeVar("SrvTypeResponse") + + +class NodeProto(Protocol): + """Structural interface of a ROS 2 node (typing.Protocol). + + This protocol mirrors the public surface actually consumed by wrappers. + Concrete implementations (e.g. `rclpy.node.Node`) satisfy it implicitly. + Using a Protocol avoids inheritance side-effects and enables lightweight + delegation where unimplemented members simply fall through to the inner node. + """ + + # ---- Lifecycle / basic introspection ------------------------------------------------- + def __init__( + self, + node_name: str, + *, + context: Optional[Any] = None, + cli_args: Optional[List[str]] = None, + namespace: Optional[str] = None, + use_global_arguments: bool = True, + enable_rosout: bool = True, + start_parameter_services: bool = True, + parameter_overrides: Optional[List[Any]] = None, + allow_undeclared_parameters: bool = False, + automatically_declare_parameters_from_overrides: bool = False, + ) -> None: ... + + # ---- Entity collections --------------------------------------------------------------- + @property + def publishers(self) -> Iterator[Publisher]: ... + + @property + def subscriptions(self) -> Iterator[Subscription]: ... + + @property + def clients(self) -> Iterator[Client]: ... + + @property + def services(self) -> Iterator[Service]: ... + + @property + def timers(self) -> Iterator[Timer]: ... + + @property + def guards(self) -> Iterator[GuardCondition]: ... + + @property + def waitables(self) -> Iterator[Waitable]: ... + + # ---- Executor linkage ----------------------------------------------------------------- + @property + def executor(self) -> Optional[Executor]: ... + + @executor.setter + def executor(self, new_executor: Any) -> None: # type: ignore[override] + ... + + # ---- Core properties ------------------------------------------------------------------ + @property + def context(self) -> Context: ... + + @property + def default_callback_group(self) -> CallbackGroup: ... + + @property + def handle(self) -> Any: ... + + @handle.setter + def handle(self, value: Any) -> None: # type: ignore[override] + ... + + # ---- Introspection -------------------------------------------------------------------- + def get_name(self) -> str: ... + + def get_namespace(self) -> str: ... + + def get_clock(self) -> Clock: ... + + def get_logger(self) -> RcutilsLogger: ... + + # ---- Parameters ----------------------------------------------------------------------- + def declare_parameter( + self, + name: str, + value: Any = None, + descriptor: Optional[ParameterDescriptor] = None, + ignore_override: bool = False, + ) -> Parameter: ... + + def declare_parameters( + self, + namespace: str, + parameters: List[ + Union[ + Tuple[str], + Tuple[str, Parameter.Type], + Tuple[str, Any, ParameterDescriptor], + ] + ], + ignore_override: bool = False, + ) -> List[Parameter]: ... + + def undeclare_parameter(self, name: str) -> None: ... + + def has_parameter(self, name: str) -> bool: ... + + def get_parameter_types(self, names: List[str]) -> List[Parameter.Type]: ... + + def get_parameter_type(self, name: str) -> Parameter.Type: ... + + def get_parameters(self, names: List[str]) -> List[Parameter]: ... + + def get_parameter(self, name: str) -> Parameter: ... + + def get_parameter_or( + self, name: str, alternative_value: Optional[Parameter] = None + ) -> Parameter: ... + + def get_parameters_by_prefix( + self, + prefix: str, + ) -> Dict[ + str, + Optional[ + Union[ + bool, + int, + float, + str, + bytes, + Sequence[bool], + Sequence[int], + Sequence[float], + Sequence[str], + ] + ], + ]: ... + + def set_parameters( + self, parameter_list: List[Parameter] + ) -> List[SetParametersResult]: ... + + def set_parameters_atomically( + self, parameter_list: List[Parameter] + ) -> SetParametersResult: ... + + def add_on_set_parameters_callback( + self, callback: Callable[[List[Parameter]], SetParametersResult] + ) -> None: ... + + def remove_on_set_parameters_callback( + self, callback: Callable[[List[Parameter]], SetParametersResult] + ) -> None: ... + + def describe_parameter(self, name: str) -> ParameterDescriptor: ... + + def describe_parameters(self, names: List[str]) -> List[ParameterDescriptor]: ... + + def set_descriptor( + self, + name: str, + descriptor: ParameterDescriptor, + alternative_value: Optional[ParameterValue] = None, + ) -> ParameterValue: ... + + # ---- Name resolution ------------------------------------------------------------------ + def resolve_topic_name(self, topic: str, *, only_expand: bool = False) -> str: ... + + def resolve_service_name( + self, service: str, *, only_expand: bool = False + ) -> str: ... + + # ---- Creation of entities ------------------------------------------------------------- + def create_publisher( + self, + msg_type: Type[MsgType], + topic: str, + qos_profile: Union[QoSProfile, int], + *, + callback_group: Optional[CallbackGroup] = None, + event_callbacks: Optional[PublisherEventCallbacks] = None, + qos_overriding_options: Optional[QoSOverridingOptions] = None, + publisher_class: Type[Publisher] = Publisher, + ) -> Publisher: ... + + def create_subscription( + self, + msg_type: Type[MsgType], + topic: str, + callback: Callable[[MsgType], None], + qos_profile: Union[QoSProfile, int], + *, + callback_group: Optional[CallbackGroup] = None, + event_callbacks: Optional[SubscriptionEventCallbacks] = None, + qos_overriding_options: Optional[QoSOverridingOptions] = None, + raw: bool = False, + ) -> Subscription: ... + + def create_client( + self, + srv_type: Type[SrvType], + srv_name: str, + *, + qos_profile: QoSProfile = QoSProfile(depth=10), + callback_group: Optional[CallbackGroup] = None, + ) -> Client: ... + + def create_service( + self, + srv_type: Type[SrvType], + srv_name: str, + callback: Callable[[SrvTypeRequest, SrvTypeResponse], SrvTypeResponse], + *, + qos_profile: QoSProfile = QoSProfile(depth=10), + callback_group: Optional[CallbackGroup] = None, + ) -> Service: ... + + def create_timer( + self, + timer_period_sec: float, + callback: Callable, + callback_group: Optional[CallbackGroup] = None, + clock: Optional[Clock] = None, + ) -> Timer: ... + + def create_guard_condition( + self, + callback: Callable, + callback_group: Optional[CallbackGroup] = None, + ) -> GuardCondition: ... + + def create_rate( + self, + frequency: float, + clock: Optional[Clock] = None, + ) -> Rate: ... + + # ---- Destruction of entities ---------------------------------------------------------- + def destroy_publisher(self, publisher: Any) -> bool: ... + + def destroy_subscription(self, subscription: Any) -> bool: ... + + def destroy_client(self, client: Any) -> bool: ... + + def destroy_service(self, service: Any) -> bool: ... + + def destroy_timer(self, timer: Any) -> bool: ... + + def destroy_guard_condition(self, guard: Any) -> bool: ... + + def destroy_rate(self, rate: Any) -> bool: ... + + def destroy_node(self) -> None: ... + + # ---- Discovery / graph introspection -------------------------------------------------- + def get_publisher_names_and_types_by_node( + self, + node_name: str, + node_namespace: str, + no_demangle: bool = False, + ) -> List[Tuple[str, List[str]]]: ... + + def get_subscriber_names_and_types_by_node( + self, + node_name: str, + node_namespace: str, + no_demangle: bool = False, + ) -> List[Tuple[str, List[str]]]: ... + + def get_service_names_and_types_by_node( + self, + node_name: str, + node_namespace: str, + ) -> List[Tuple[str, List[str]]]: ... + + def get_client_names_and_types_by_node( + self, + node_name: str, + node_namespace: str, + ) -> List[Tuple[str, List[str]]]: ... + + def get_topic_names_and_types( + self, no_demangle: bool = False + ) -> List[Tuple[str, List[str]]]: ... + + def get_service_names_and_types(self) -> List[Tuple[str, List[str]]]: ... + + def get_node_names(self) -> List[str]: ... + + def get_node_names_and_namespaces(self) -> List[Tuple[str, str]]: ... + + def get_node_names_and_namespaces_with_enclaves( + self, + ) -> List[Tuple[str, str, str]]: ... + + def get_fully_qualified_name(self) -> str: ... + + # ---- Counting / endpoint info --------------------------------------------------------- + def count_publishers(self, topic_name: str) -> int: ... + + def count_subscribers(self, topic_name: str) -> int: ... + + def get_publishers_info_by_topic( + self, + topic_name: str, + no_mangle: bool = False, + ) -> List[TopicEndpointInfo]: ... + + def get_subscriptions_info_by_topic( + self, + topic_name: str, + no_mangle: bool = False, + ) -> List[TopicEndpointInfo]: ... + + # ---- Utility -------------------------------------------------------------------------- + def __repr__(self) -> str: # Helpful debug hook + return f"<{self.__class__.__name__} (Protocol)>" diff --git a/src/rclpy_async/_async_node/service_handler_spec.py b/src/rclpy_async/_async_node/service_handler_spec.py new file mode 100644 index 0000000..984f2f2 --- /dev/null +++ b/src/rclpy_async/_async_node/service_handler_spec.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass +from typing import Any + +from ._base_handler_spec import BaseHandlerSpec + + +@dataclass +class ServiceHandlerSpec(BaseHandlerSpec[[Any, Any], Any]): + """Specification for a ROS service handler + + async_fn signature: (request: SrvType.Request, response: SrvType.Response) -> Awaitable[SrvType.Response] + + Args: + BaseHandlerSpec (_type_): Inherits from BaseHandlerSpec + """ + + srv_type: type + srv_name: str diff --git a/src/rclpy_async/_async_node/state.py b/src/rclpy_async/_async_node/state.py new file mode 100644 index 0000000..67887a7 --- /dev/null +++ b/src/rclpy_async/_async_node/state.py @@ -0,0 +1,29 @@ +from typing import Any + + +class State: + """ + An object that can be used to store arbitrary state. + + Used for `request.state` and `app.state`. + """ + + _state: dict[str, Any] + + def __init__(self, state: dict[str, Any] | None = None): + if state is None: + state = {} + super().__setattr__("_state", state) + + def __setattr__(self, key: Any, value: Any) -> None: + self._state[key] = value + + def __getattr__(self, key: Any) -> Any: + try: + return self._state[key] + except KeyError: + message = "'{}' object has no attribute '{}'" + raise AttributeError(message.format(self.__class__.__name__, key)) + + def __delattr__(self, key: Any) -> None: + del self._state[key] diff --git a/src/rclpy_async/_async_node/timer_handler_spec.py b/src/rclpy_async/_async_node/timer_handler_spec.py new file mode 100644 index 0000000..3f80db3 --- /dev/null +++ b/src/rclpy_async/_async_node/timer_handler_spec.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass +from typing import Callable, Union + +from ._base_handler_spec import BaseHandlerSpec + + +@dataclass +class TimerHandlerSpec(BaseHandlerSpec[[], None]): + """Specification for a ROS service handler + + async_fn signature: () -> Awaitable[None] + + Args: + BaseHandlerSpec (_type_): Inherits from BaseHandlerSpec + """ + + timer_period_sec: Union[float, Callable[[], float]] diff --git a/src/rclpy_async/_async_node/topic_handler_spec.py b/src/rclpy_async/_async_node/topic_handler_spec.py new file mode 100644 index 0000000..934eaba --- /dev/null +++ b/src/rclpy_async/_async_node/topic_handler_spec.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import Any, Awaitable, Callable, Optional + +from rclpy.qos import QoSProfile + +from ._base_handler_spec import BaseHandlerSpec + + +@dataclass +class TopicHandlerSpec(BaseHandlerSpec[[Any], None]): + """Specification for a subscription handler. + + async_fn signature: (message: MsgType) -> Awaitable[None] + + Args: + BaseHandlerSpec (_type_): Inherits from BaseHandlerSpec + """ + + msg_type: type + topic_name: str + qos_profile: QoSProfile diff --git a/src/rclpy_async/async_node.py b/src/rclpy_async/async_node.py new file mode 100644 index 0000000..a9eca5a --- /dev/null +++ b/src/rclpy_async/async_node.py @@ -0,0 +1,373 @@ +import inspect +from contextlib import ExitStack +from typing import Any, Awaitable, Callable, List, Optional, Tuple, Type, Union + +import anyio +import rclpy +from rclpy.action.server import ServerGoalHandle +from rclpy.node import Node +from rclpy.qos import QoSProfile + +import rclpy_async + +from ._async_node import ( + ActionHandlerSpec, + BackpressureHandlerSpec, + NodeProto, + ServiceHandlerSpec, + State, + TimerHandlerSpec, + TopicHandlerSpec, +) + +# Public interface member names gathered from the Protocol; used for delegation. +_DELEGATE_NAMES = { + name + for name, member in vars(NodeProto).items() + if not name.startswith("_") + and (inspect.isfunction(member) or isinstance(member, property)) +} + + +def _is_overridden(cls: Type[Any], name: str) -> bool: + if not hasattr(NodeProto, name): + return True # Not part of proto; treat as overridden/local. + member_cls = getattr(cls, name, None) + member_proto = getattr(NodeProto, name, None) + if inspect.isfunction(member_cls) and inspect.isfunction(member_proto): + return member_cls is not member_proto + if isinstance(member_proto, property) and isinstance(member_cls, property): + return member_cls.fget is not member_proto.fget + return True + + +class AsyncNode(NodeProto): + """Asynchronous wrapper around a ROS 2 `Node` using dynamic delegation. + + Delegates all protocol-defined attributes/methods to `__inner` unless overridden. + Provides decorators for subscription and timer handlers executed with anyio. + """ + + __inner: Optional[Node] = None # Underlying rclpy Node instance + state = State() # Arbitrary user state container + __node_name: str + __exit_stack = ExitStack() + __fallback_values: dict[str, Any] = ( + {} + ) # Local storage for protocol-only properties when inner lacks them. + + __action_handler_specs: List[ActionHandlerSpec] = [] + __service_handler_specs: List[ServiceHandlerSpec] = [] + __timer_handler_specs: List[TimerHandlerSpec] = [] + __topic_handler_specs: List[TopicHandlerSpec] = [] + + def __init__(self, node_name: str): + self.__node_name = node_name + + def initialize(self) -> None: + """Create underlying rclpy Node and declare parameters if a schema is provided.""" + if self.__inner is not None: + raise RuntimeError("AsyncNode already initialized") + self.__inner = rclpy.create_node(self.__node_name) + + def __getattr__(self, name: str) -> Any: + """Delegate protocol members to inner node when not overridden locally.""" + inner = self.__inner + if ( + inner is not None + and name in _DELEGATE_NAMES + and not _is_overridden(type(self), name) + ): + return getattr(inner, name) + if inner is not None and hasattr(inner, name): + return getattr(inner, name) + raise AttributeError(name) + + def __getattribute__(self, name: str) -> Any: + """Attribute access with delegation for Protocol stub members.""" + # Fast-path for private/internal attributes to avoid recursion. + if name.startswith("_AsyncNode__"): + return super().__getattribute__(name) + + inner = super().__getattribute__("_AsyncNode__inner") + # For protocol-defined members not overridden locally, attempt delegation. + if name in _DELEGATE_NAMES and not _is_overridden(type(self), name): + if inner is not None and hasattr(inner, name): + return getattr(inner, name) + # Fallback to locally stored value if inner doesn't provide it. + fb = super().__getattribute__("_AsyncNode__fallback_values") + if name in fb: + return fb[name] + return super().__getattribute__(name) + + def __setattr__(self, name: str, value: Any) -> None: + # Allow normal setting for private / explicitly implemented attributes on AsyncNode itself. + if ( + name.startswith("_AsyncNode__") + or name in {"params", "state"} + or name in type(self).__dict__ + ): + super().__setattr__(name, value) + return + # Delegation path for protocol-defined members not overridden here. + if name in _DELEGATE_NAMES and not _is_overridden(type(self), name): + inner = getattr(self, "_AsyncNode__inner", None) + if inner is not None and hasattr(inner, name): + setattr(inner, name, value) + return + # Store locally when inner lacks the attribute (e.g. protocol stub property). + fb = getattr(self, "_AsyncNode__fallback_values") + fb[name] = value + return + # Ordinary attribute (not protocol-defined): store on wrapper instance. + super().__setattr__(name, value) + + def __dir__(self) -> List[str]: + inner = self.__inner + names = set(super().__dir__()) + if inner is not None: + names.update(dir(inner)) + return sorted(names) + + def __repr__(self) -> str: + return f"" + + def action( + self, + action_type: type, + action_name: str, + backpressure_handler: Optional[BackpressureHandlerSpec] = None, + **kwargs, + ) -> Callable[ + [Callable[[ServerGoalHandle], Awaitable[Any]]], + Callable[[ServerGoalHandle], Awaitable[Any]], + ]: + """Decorator registering an async action handler.""" + + def _decorator(async_fn: Callable[[ServerGoalHandle], Awaitable[Any]]): + self.__action_handler_specs.append( + ActionHandlerSpec( + action_name=action_name, + action_type=action_type, + async_fn=async_fn, + backpressure_handler=backpressure_handler, + kwargs=kwargs, + ) + ) + return async_fn + + return _decorator + + def service( + self, + srv_type: type, + srv_name: str, + backpressure_handler: Optional[BackpressureHandlerSpec] = None, + **kwargs, + ) -> Callable[[Callable[[Any], Awaitable[None]]], Callable[[Any], Awaitable[None]]]: + """Decorator registering an async service handler.""" + + def _decorator(async_fn: Callable[[Any], Awaitable[None]]): + self.__service_handler_specs.append( + ServiceHandlerSpec( + async_fn=async_fn, + backpressure_handler=backpressure_handler, + kwargs=kwargs, + srv_name=srv_name, + srv_type=srv_type, + ) + ) + return async_fn + + return _decorator + + def subscription( + self, + msg_type: type, + topic_name: str, + qos_profile: QoSProfile = 10, + backpressure_handler: Optional[BackpressureHandlerSpec] = None, + **kwargs, + ) -> Callable[[Callable[[Any], Awaitable[None]]], Callable[[Any], Awaitable[None]]]: + """Decorator registering an async subscription handler.""" + + def _decorator(async_fn: Callable[[Any], Awaitable[None]]): + self.__topic_handler_specs.append( + TopicHandlerSpec( + async_fn=async_fn, + backpressure_handler=backpressure_handler, + kwargs=kwargs, + msg_type=msg_type, + qos_profile=qos_profile, + topic_name=topic_name, + ) + ) + return async_fn + + return _decorator + + def timer( + self, + timer_period_sec: Union[float, Callable[[], float]], + backpressure_handler: Optional[BackpressureHandlerSpec] = None, + **kwargs, + ) -> Callable[[Callable[[], Awaitable[None]]], Callable[[], Awaitable[None]]]: + """Decorator registering an async timer handler.""" + + def _decorator(async_fn: Callable[[], Awaitable[None]]): + self.__timer_handler_specs.append( + TimerHandlerSpec( + async_fn=async_fn, + backpressure_handler=backpressure_handler, + kwargs=kwargs, + timer_period_sec=timer_period_sec, + ) + ) + return async_fn + + return _decorator + + def destroy_node(self) -> None: + """Ensure underlying node is properly destroyed.""" + self.__exit_stack.close() + inner = self.__inner + if inner is not None: + inner.destroy_node() + self.__inner = None + + def gather_coroutines( + self, + ) -> List[Callable[[], Awaitable[None]]]: + """Start processing subscription and timer handlers until cancelled.""" + if self.__inner is None: + raise RuntimeError("AsyncNode not initialized; call initialize() first") + + _attached_consumers = [] + + for spec in self.__action_handler_specs: + handler, coro = _wrap_handler_with_backpressure( + spec.async_fn, spec.backpressure_handler + ) + + self.__exit_stack.enter_context( + rclpy_async.action_server( + self.__inner, + spec.action_type, + spec.action_name, + handler, + **spec.kwargs, + ) + ) + + if coro is not None: + _attached_consumers.append(coro) + + for spec in self.__service_handler_specs: + async_fn = spec.async_fn + srv_type = spec.srv_type + srv_name = spec.srv_name + + handler, coro = _wrap_handler_with_backpressure( + async_fn, spec.backpressure_handler + ) + self.__inner.create_service( + srv_type, + srv_name, + handler, + **spec.kwargs, + ) + + if coro is not None: + _attached_consumers.append(coro) + + for spec in self.__topic_handler_specs: + async_fn = spec.async_fn + msg_type = spec.msg_type + topic_name = spec.topic_name + qos_profile = spec.qos_profile + + handler, coro = _wrap_handler_with_backpressure( + async_fn, spec.backpressure_handler + ) + self.__inner.create_subscription( + msg_type, + topic_name, + handler, + qos_profile=qos_profile, + **spec.kwargs, + ) + + if coro is not None: + _attached_consumers.append(coro) + + for spec in self.__timer_handler_specs: + timer_period_sec = spec.timer_period_sec + async_fn = spec.async_fn + if callable(timer_period_sec): + try: + period_value = float(timer_period_sec()) + except Exception: + period_value = 1.0 + else: + period_value = float(timer_period_sec) + + handler, coro = _wrap_handler_with_backpressure( + async_fn, spec.backpressure_handler + ) + self.__inner.create_timer(period_value, handler, **spec.kwargs) + + if coro is not None: + _attached_consumers.append(coro) + + return _attached_consumers + + +def gather_nodes(*nodes: AsyncNode) -> List[Callable[[], Awaitable[None]]]: + """Gather coroutines from multiple AsyncNode instances.""" + coroutines = [] + for node in nodes: + coroutines.extend(node.gather_coroutines()) + return coroutines + + +async def run(*nodes: AsyncNode) -> None: + async with anyio.create_task_group() as tg: + coroutines = gather_nodes(*nodes) + for consumer_task in coroutines: + tg.start_soon(consumer_task) + + await anyio.sleep(float("inf")) + + +def _wrap_handler_with_backpressure( + async_fn: Callable[..., Awaitable[None]], + spec: Optional[BackpressureHandlerSpec] = None, +) -> Tuple[Callable[..., None], Callable[[], Awaitable[None]]]: + """Wrap an async function with a memory object stream for backpressure handling.""" + + if spec is None: + return async_fn, None + + send_stream, receive_stream = anyio.create_memory_object_stream(spec.max_queue_size) + + def _sub_callback(*args, **kwargs): + try: + send_stream.send_nowait((args, kwargs)) + except anyio.WouldBlock: + if spec.max_queue_size == 0: + return + if spec.drop_oldest: + try: + _ = receive_stream.receive_nowait() + except anyio.WouldBlock: + return + try: + send_stream.send_nowait((args, kwargs)) + except anyio.WouldBlock: + pass + + async def _consumer_task(): + async for args, kwargs in receive_stream: + await async_fn(*args, **kwargs) + + return _sub_callback, _consumer_task