Skip to content

Commit

Permalink
Merge pull request ros2#114 from irobot-ros/asoragna/defer-callback
Browse files Browse the repository at this point in the history
cherry-pick defer callbacks
  • Loading branch information
alsora committed Jul 18, 2023
2 parents db5e7f1 + 88ef5fa commit 578d5a4
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 67 deletions.
4 changes: 2 additions & 2 deletions rclcpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ find_package(rosidl_typesupport_cpp REQUIRED)
find_package(statistics_msgs REQUIRED)
find_package(tracetools REQUIRED)

# Default to C++14
# Default to C++17
if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD 17)
endif()
if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
# About -Wno-sign-conversion: With Clang, -Wconversion implies -Wsign-conversion. There are a number of
Expand Down
158 changes: 100 additions & 58 deletions rclcpp/include/rclcpp/any_service_callback.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
#ifndef RCLCPP__ANY_SERVICE_CALLBACK_HPP_
#define RCLCPP__ANY_SERVICE_CALLBACK_HPP_

#include <variant>
#include <functional>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <utility>

#include "rclcpp/function_traits.hpp"
#include "rclcpp/visibility_control.hpp"
Expand All @@ -29,93 +31,133 @@
namespace rclcpp
{

template<typename ServiceT>
class AnyServiceCallback
namespace detail
{
private:
using SharedPtrCallback = std::function<
void (
const std::shared_ptr<typename ServiceT::Request>,
std::shared_ptr<typename ServiceT::Response>
)>;
using SharedPtrWithRequestHeaderCallback = std::function<
void (
const std::shared_ptr<rmw_request_id_t>,
const std::shared_ptr<typename ServiceT::Request>,
std::shared_ptr<typename ServiceT::Response>
)>;
template<typename T, typename = void>
struct can_be_nullptr : std::false_type {};

// Some lambdas define a comparison with nullptr,
// but we see a warning that they can never be null when using it.
// We also test if `T &` can be assigned to `nullptr` to avoid the issue.
template<typename T>
struct can_be_nullptr<T, std::void_t<
decltype(std::declval<T>() == nullptr), decltype(std::declval<T &>() = nullptr)>>
: std::true_type {};
} // namespace detail

SharedPtrCallback shared_ptr_callback_;
SharedPtrWithRequestHeaderCallback shared_ptr_with_request_header_callback_;
// Forward declare
template<typename ServiceT>
class Service;

template<typename ServiceT>
class AnyServiceCallback
{
public:
AnyServiceCallback()
: shared_ptr_callback_(nullptr), shared_ptr_with_request_header_callback_(nullptr)
: callback_(std::monostate{})
{}

AnyServiceCallback(const AnyServiceCallback &) = default;

template<
typename CallbackT,
typename std::enable_if<
rclcpp::function_traits::same_arguments<
CallbackT,
SharedPtrCallback
>::value
>::type * = nullptr
>
void set(CallbackT callback)
typename std::enable_if_t<!detail::can_be_nullptr<CallbackT>::value, int> = 0>
void
set(CallbackT && callback)
{
shared_ptr_callback_ = callback;
callback_ = std::forward<CallbackT>(callback);
}

template<
typename CallbackT,
typename std::enable_if<
rclcpp::function_traits::same_arguments<
CallbackT,
SharedPtrWithRequestHeaderCallback
>::value
>::type * = nullptr
>
void set(CallbackT callback)
typename std::enable_if_t<detail::can_be_nullptr<CallbackT>::value, int> = 0>
void
set(CallbackT && callback)
{
shared_ptr_with_request_header_callback_ = callback;
if (!callback) {
throw std::invalid_argument("AnyServiceCallback::set(): callback cannot be nullptr");
}
callback_ = std::forward<CallbackT>(callback);
}

void dispatch(
std::shared_ptr<rmw_request_id_t> request_header,
std::shared_ptr<typename ServiceT::Request> request,
std::shared_ptr<typename ServiceT::Response> response)
// template<typename Allocator = std::allocator<typename ServiceT::Response>>
std::shared_ptr<typename ServiceT::Response>
dispatch(
const std::shared_ptr<rclcpp::Service<ServiceT>> & service_handle,
const std::shared_ptr<rmw_request_id_t> & request_header,
std::shared_ptr<typename ServiceT::Request> request)
{
TRACEPOINT(callback_start, static_cast<const void *>(this), false);
if (shared_ptr_callback_ != nullptr) {
if (std::holds_alternative<std::monostate>(callback_)) {
// TODO(ivanpauno): Remove the set method, and force the users of this class
// to pass a callback at construnciton.
throw std::runtime_error{"unexpected request without any callback set"};
}
if (std::holds_alternative<SharedPtrDeferResponseCallback>(callback_)) {
const auto & cb = std::get<SharedPtrDeferResponseCallback>(callback_);
cb(request_header, std::move(request));
return nullptr;
}
if (std::holds_alternative<SharedPtrDeferResponseCallbackWithServiceHandle>(callback_)) {
const auto & cb = std::get<SharedPtrDeferResponseCallbackWithServiceHandle>(callback_);
cb(service_handle, request_header, std::move(request));
return nullptr;
}
// auto response = allocate_shared<typename ServiceT::Response, Allocator>();
auto response = std::make_shared<typename ServiceT::Response>();
if (std::holds_alternative<SharedPtrCallback>(callback_)) {
(void)request_header;
shared_ptr_callback_(request, response);
} else if (shared_ptr_with_request_header_callback_ != nullptr) {
shared_ptr_with_request_header_callback_(request_header, request, response);
} else {
throw std::runtime_error("unexpected request without any callback set");
const auto & cb = std::get<SharedPtrCallback>(callback_);
cb(std::move(request), response);
} else if (std::holds_alternative<SharedPtrWithRequestHeaderCallback>(callback_)) {
const auto & cb = std::get<SharedPtrWithRequestHeaderCallback>(callback_);
cb(request_header, std::move(request), response);
}
TRACEPOINT(callback_end, static_cast<const void *>(this));
return response;
}

void register_callback_for_tracing()
{
#ifndef TRACETOOLS_DISABLED
if (shared_ptr_callback_) {
TRACEPOINT(
rclcpp_callback_register,
static_cast<const void *>(this),
tracetools::get_symbol(shared_ptr_callback_));
} else if (shared_ptr_with_request_header_callback_) {
TRACEPOINT(
rclcpp_callback_register,
static_cast<const void *>(this),
tracetools::get_symbol(shared_ptr_with_request_header_callback_));
}
std::visit(
[this](auto && arg) {
TRACEPOINT(
rclcpp_callback_register,
static_cast<const void *>(this),
tracetools::get_symbol(arg));
}, callback_);
#endif // TRACETOOLS_DISABLED
}

private:
using SharedPtrCallback = std::function<
void (
std::shared_ptr<typename ServiceT::Request>,
std::shared_ptr<typename ServiceT::Response>
)>;
using SharedPtrWithRequestHeaderCallback = std::function<
void (
std::shared_ptr<rmw_request_id_t>,
std::shared_ptr<typename ServiceT::Request>,
std::shared_ptr<typename ServiceT::Response>
)>;
using SharedPtrDeferResponseCallback = std::function<
void (
std::shared_ptr<rmw_request_id_t>,
std::shared_ptr<typename ServiceT::Request>
)>;
using SharedPtrDeferResponseCallbackWithServiceHandle = std::function<
void (
std::shared_ptr<rclcpp::Service<ServiceT>>,
std::shared_ptr<rmw_request_id_t>,
std::shared_ptr<typename ServiceT::Request>
)>;

std::variant<
std::monostate,
SharedPtrCallback,
SharedPtrWithRequestHeaderCallback,
SharedPtrDeferResponseCallback,
SharedPtrDeferResponseCallbackWithServiceHandle> callback_;
};

} // namespace rclcpp
Expand Down
11 changes: 7 additions & 4 deletions rclcpp/include/rclcpp/service.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,9 @@ class ServiceBase
};

template<typename ServiceT>
class Service : public ServiceBase
class Service
: public ServiceBase,
public std::enable_shared_from_this<Service<ServiceT>>
{
public:
using CallbackType = std::function<
Expand Down Expand Up @@ -439,9 +441,10 @@ class Service : public ServiceBase
std::shared_ptr<void> request) override
{
auto typed_request = std::static_pointer_cast<typename ServiceT::Request>(request);
auto response = std::make_shared<typename ServiceT::Response>();
any_callback_.dispatch(request_header, typed_request, response);
send_response(*request_header, *response);
auto response = any_callback_.dispatch(this->shared_from_this(), request_header, typed_request);
if (response) {
send_response(*request_header, *response);
}
}

void
Expand Down
37 changes: 34 additions & 3 deletions rclcpp/test/rclcpp/test_any_service_callback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <utility>

#include "rclcpp/any_service_callback.hpp"
#include "rclcpp/service.hpp"
#include "test_msgs/srv/empty.hpp"
#include "test_msgs/srv/empty.h"

Expand All @@ -44,7 +45,7 @@ class TestAnyServiceCallback : public ::testing::Test

TEST_F(TestAnyServiceCallback, no_set_and_dispatch_throw) {
EXPECT_THROW(
any_service_callback_.dispatch(request_header_, request_, response_),
any_service_callback_.dispatch(nullptr, request_header_, request_),
std::runtime_error);
}

Expand All @@ -57,7 +58,7 @@ TEST_F(TestAnyServiceCallback, set_and_dispatch_no_header) {

any_service_callback_.set(callback);
EXPECT_NO_THROW(
any_service_callback_.dispatch(request_header_, request_, response_));
EXPECT_NE(nullptr, any_service_callback_.dispatch(nullptr, request_header_, request_)));
EXPECT_EQ(callback_calls, 1);
}

Expand All @@ -73,6 +74,36 @@ TEST_F(TestAnyServiceCallback, set_and_dispatch_header) {

any_service_callback_.set(callback_with_header);
EXPECT_NO_THROW(
any_service_callback_.dispatch(request_header_, request_, response_));
EXPECT_NE(nullptr, any_service_callback_.dispatch(nullptr, request_header_, request_)));
EXPECT_EQ(callback_with_header_calls, 1);
}

TEST_F(TestAnyServiceCallback, set_and_dispatch_defered) {
int callback_with_header_calls = 0;
auto callback_with_header =
[&callback_with_header_calls](const std::shared_ptr<rmw_request_id_t>,
const std::shared_ptr<test_msgs::srv::Empty::Request>) {
callback_with_header_calls++;
};

any_service_callback_.set(callback_with_header);
EXPECT_NO_THROW(
EXPECT_EQ(nullptr, any_service_callback_.dispatch(nullptr, request_header_, request_)));
EXPECT_EQ(callback_with_header_calls, 1);
}

TEST_F(TestAnyServiceCallback, set_and_dispatch_defered_with_service_handle) {
int callback_with_header_calls = 0;
auto callback_with_header =
[&callback_with_header_calls](std::shared_ptr<rclcpp::Service<test_msgs::srv::Empty>>,
const std::shared_ptr<rmw_request_id_t>,
const std::shared_ptr<test_msgs::srv::Empty::Request>)
{
callback_with_header_calls++;
};

any_service_callback_.set(callback_with_header);
EXPECT_NO_THROW(
EXPECT_EQ(nullptr, any_service_callback_.dispatch(nullptr, request_header_, request_)));
EXPECT_EQ(callback_with_header_calls, 1);
}

0 comments on commit 578d5a4

Please sign in to comment.