Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions orc-rt/include/orc-rt/CallableTraitsHelper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
//===- CallableTraitsHelper.h - Callable arg/ret type extractor -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// CallableTraitsHelper API.
//
//===----------------------------------------------------------------------===//

#ifndef ORC_RT_CALLABLETRAITSHELPER_H
#define ORC_RT_CALLABLETRAITSHELPER_H

#include <tuple>
#include <type_traits>

namespace orc_rt {

/// CallableTraitsHelper takes an implementation class template Impl and some
/// callable type C and passes the return and argument types of C to the Impl
/// class template.
///
/// This can be used to simplify the implementation of classes that need to
/// operate on callable types.
template <template <typename...> typename ImplT, typename C>
struct CallableTraitsHelper
: public CallableTraitsHelper<
ImplT,
decltype(&std::remove_cv_t<std::remove_reference_t<C>>::operator())> {
};

template <template <typename...> typename ImplT, typename RetT,
typename... ArgTs>
struct CallableTraitsHelper<ImplT, RetT(ArgTs...)>
: public ImplT<RetT, ArgTs...> {};

template <template <typename...> typename ImplT, typename RetT,
typename... ArgTs>
struct CallableTraitsHelper<ImplT, RetT (*)(ArgTs...)>
: public CallableTraitsHelper<ImplT, RetT(ArgTs...)> {};

template <template <typename...> typename ImplT, typename RetT,
typename... ArgTs>
struct CallableTraitsHelper<ImplT, RetT (&)(ArgTs...)>
: public CallableTraitsHelper<ImplT, RetT(ArgTs...)> {};

template <template <typename...> typename ImplT, typename ClassT, typename RetT,
typename... ArgTs>
struct CallableTraitsHelper<ImplT, RetT (ClassT::*)(ArgTs...)>
: public CallableTraitsHelper<ImplT, RetT(ArgTs...)> {};

template <template <typename...> typename ImplT, typename ClassT, typename RetT,
typename... ArgTs>
struct CallableTraitsHelper<ImplT, RetT (ClassT::*)(ArgTs...) const>
: public CallableTraitsHelper<ImplT, RetT(ArgTs...)> {};

namespace detail {
template <typename RetT, typename... ArgTs> struct CallableArgInfoImpl {
typedef RetT return_type;
typedef std::tuple<ArgTs...> args_tuple_type;
};
} // namespace detail

/// CallableArgInfo provides typedefs for the return type and argument types
/// (as a tuple) of the given callable type.
template <typename Callable>
struct CallableArgInfo
: public CallableTraitsHelper<detail::CallableArgInfoImpl, Callable> {};

} // namespace orc_rt

#endif // ORC_RT_CALLABLETRAITSHELPER_H
77 changes: 32 additions & 45 deletions orc-rt/include/orc-rt/WrapperFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define ORC_RT_WRAPPERFUNCTION_H

#include "orc-rt-c/WrapperFunction.h"
#include "orc-rt/CallableTraitsHelper.h"
#include "orc-rt/Error.h"
#include "orc-rt/bind.h"

Expand Down Expand Up @@ -105,37 +106,16 @@ class WrapperFunctionBuffer {

namespace detail {

template <typename C>
struct WFCallableTraits
: public WFCallableTraits<
decltype(&std::remove_cv_t<std::remove_reference_t<C>>::operator())> {
};

template <typename RetT> struct WFCallableTraits<RetT()> {
typedef void HeadArgType;
template <typename RetT, typename ReturnT, typename... ArgTs>
struct WFHandlerTraitsImpl {
static_assert(std::is_void_v<RetT>,
"Async wrapper function handler must return void");
typedef ReturnT YieldType;
typedef std::tuple<ArgTs...> ArgTupleType;
};

template <typename RetT, typename ArgT, typename... ArgTs>
struct WFCallableTraits<RetT(ArgT, ArgTs...)> {
typedef ArgT HeadArgType;
typedef std::tuple<ArgTs...> TailArgTuple;
};

template <typename RetT, typename... ArgTs>
struct WFCallableTraits<RetT (*)(ArgTs...)>
: public WFCallableTraits<RetT(ArgTs...)> {};

template <typename RetT, typename... ArgTs>
struct WFCallableTraits<RetT (&)(ArgTs...)>
: public WFCallableTraits<RetT(ArgTs...)> {};

template <typename ClassT, typename RetT, typename... ArgTs>
struct WFCallableTraits<RetT (ClassT::*)(ArgTs...)>
: public WFCallableTraits<RetT(ArgTs...)> {};

template <typename ClassT, typename RetT, typename... ArgTs>
struct WFCallableTraits<RetT (ClassT::*)(ArgTs...) const>
: public WFCallableTraits<RetT(ArgTs...)> {};
template <typename C>
using WFHandlerTraits = CallableTraitsHelper<WFHandlerTraitsImpl, C>;

template <typename Serializer> class StructuredYieldBase {
public:
Expand All @@ -151,8 +131,11 @@ template <typename Serializer> class StructuredYieldBase {
std::decay_t<Serializer> S;
};

template <typename RetT, typename Serializer> class StructuredYield;

template <typename RetT, typename Serializer>
class StructuredYield : public StructuredYieldBase<Serializer> {
class StructuredYield<std::tuple<RetT>, Serializer>
: public StructuredYieldBase<Serializer> {
public:
using StructuredYieldBase<Serializer>::StructuredYieldBase;
void operator()(RetT &&R) {
Expand All @@ -167,7 +150,7 @@ class StructuredYield : public StructuredYieldBase<Serializer> {
};

template <typename Serializer>
class StructuredYield<void, Serializer>
class StructuredYield<std::tuple<>, Serializer>
: public StructuredYieldBase<Serializer> {
public:
using StructuredYieldBase<Serializer>::StructuredYieldBase;
Expand All @@ -180,7 +163,7 @@ class StructuredYield<void, Serializer>
template <typename T, typename Serializer> struct ResultDeserializer;

template <typename T, typename Serializer>
struct ResultDeserializer<Expected<T>, Serializer> {
struct ResultDeserializer<std::tuple<Expected<T>>, Serializer> {
static Expected<T> deserialize(WrapperFunctionBuffer ResultBytes,
Serializer &S) {
T Val;
Expand All @@ -191,7 +174,8 @@ struct ResultDeserializer<Expected<T>, Serializer> {
}
};

template <typename Serializer> struct ResultDeserializer<Error, Serializer> {
template <typename Serializer>
struct ResultDeserializer<std::tuple<Error>, Serializer> {
static Error deserialize(WrapperFunctionBuffer ResultBytes, Serializer &S) {
assert(ResultBytes.empty());
return Error::success();
Expand All @@ -213,11 +197,13 @@ struct WrapperFunction {
typename... ArgTs>
static void call(Caller &&C, Serializer &&S, ResultHandler &&RH,
ArgTs &&...Args) {
typedef detail::WFCallableTraits<ResultHandler> ResultHandlerTraits;
typedef CallableArgInfo<ResultHandler> ResultHandlerTraits;
static_assert(std::is_void_v<typename ResultHandlerTraits::return_type>,
"Result handler should return void");
static_assert(
std::tuple_size_v<typename ResultHandlerTraits::TailArgTuple> == 0,
"Expected one argument to result-handler");
typedef typename ResultHandlerTraits::HeadArgType ResultType;
std::tuple_size_v<typename ResultHandlerTraits::args_tuple_type> == 1,
"Result-handler should have exactly one argument");
typedef typename ResultHandlerTraits::args_tuple_type ResultTupleType;

if (auto ArgBytes = S.argumentSerializer()(std::forward<ArgTs>(Args)...)) {
C(
Expand All @@ -227,9 +213,8 @@ struct WrapperFunction {
if (const char *ErrMsg = ResultBytes.getOutOfBandError())
RH(make_error<StringError>(ErrMsg));
else
RH(detail::ResultDeserializer<
ResultType, Serializer>::deserialize(std::move(ResultBytes),
S));
RH(detail::ResultDeserializer<ResultTupleType, Serializer>::
deserialize(std::move(ResultBytes), S));
},
std::move(*ArgBytes));
} else
Expand All @@ -246,10 +231,12 @@ struct WrapperFunction {
orc_rt_WrapperFunctionReturn Return,
WrapperFunctionBuffer ArgBytes, Serializer &&S,
Handler &&H) {
typedef detail::WFCallableTraits<Handler> HandlerTraits;
typedef typename HandlerTraits::HeadArgType Yield;
typedef typename HandlerTraits::TailArgTuple ArgTuple;
typedef typename detail::WFCallableTraits<Yield>::HeadArgType RetType;
typedef detail::WFHandlerTraits<Handler> HandlerTraits;
typedef typename HandlerTraits::ArgTupleType ArgTuple;
typedef typename HandlerTraits::YieldType Yield;
static_assert(std::is_void_v<typename CallableArgInfo<Yield>::return_type>,
"Return callback must return void");
typedef typename CallableArgInfo<Yield>::args_tuple_type RetTupleType;

if (ArgBytes.getOutOfBandError())
return Return(Session, CallCtx, ArgBytes.release());
Expand All @@ -258,7 +245,7 @@ struct WrapperFunction {
if (std::apply(bind_front(S.argumentDeserializer(), std::move(ArgBytes)),
Args))
std::apply(bind_front(std::forward<Handler>(H),
detail::StructuredYield<RetType, Serializer>(
detail::StructuredYield<RetTupleType, Serializer>(
Session, CallCtx, Return, std::move(S))),
std::move(Args));
else
Expand Down
1 change: 1 addition & 0 deletions orc-rt/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ endfunction()
add_orc_rt_unittest(CoreTests
AllocActionTest.cpp
BitmaskEnumTest.cpp
CallableTraitsHelperTest.cpp
CommonTestUtils.cpp
ErrorTest.cpp
ExecutorAddressTest.cpp
Expand Down
69 changes: 69 additions & 0 deletions orc-rt/unittests/CallableTraitsHelperTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
//===- CallableTraitsHelperTest.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Tests for orc-rt's CallableTraitsHelper.h APIs.
//
// NOTE: All tests in this file are testing compile-time functionality, so the
// tests at runtime all end up being noops. That's fine -- those are
// cheap.
//===----------------------------------------------------------------------===//

#include "orc-rt/CallableTraitsHelper.h"
#include "gtest/gtest.h"

using namespace orc_rt;

static void freeVoidVoid() {}

TEST(CallableTraitsHelperTest, FreeVoidVoid) {
(void)freeVoidVoid;
typedef CallableArgInfo<decltype(freeVoidVoid)> CAI;
static_assert(std::is_void_v<CAI::return_type>);
static_assert(std::is_same_v<CAI::args_tuple_type, std::tuple<>>);
}

static int freeBinaryOp(int, float) { return 0; }

TEST(CallableTraitsHelperTest, FreeBinaryOp) {
(void)freeBinaryOp;
typedef CallableArgInfo<decltype(freeBinaryOp)> CAI;
static_assert(std::is_same_v<CAI::return_type, int>);
static_assert(std::is_same_v<CAI::args_tuple_type, std::tuple<int, float>>);
}

TEST(CallableTraitsHelperTest, VoidVoidObj) {
auto VoidVoid = []() {};
typedef CallableArgInfo<decltype(VoidVoid)> CAI;
static_assert(std::is_void_v<CAI::return_type>);
static_assert(std::is_same_v<CAI::args_tuple_type, std::tuple<>>);
}

TEST(CallableTraitsHelperTest, BinaryOpObj) {
auto BinaryOp = [](int X, float Y) -> int { return X + Y; };
typedef CallableArgInfo<decltype(BinaryOp)> CAI;
static_assert(std::is_same_v<CAI::return_type, int>);
static_assert(std::is_same_v<CAI::args_tuple_type, std::tuple<int, float>>);
}

TEST(CallableTraitsHelperTest, PreservesLValueRef) {
auto RefOp = [](int &) {};
typedef CallableArgInfo<decltype(RefOp)> CAI;
static_assert(std::is_same_v<CAI::args_tuple_type, std::tuple<int &>>);
}

TEST(CallableTraitsHelperTest, PreservesLValueRefConstness) {
auto RefOp = [](const int &) {};
typedef CallableArgInfo<decltype(RefOp)> CAI;
static_assert(std::is_same_v<CAI::args_tuple_type, std::tuple<const int &>>);
}

TEST(CallableTraitsHelperTest, PreservesRValueRef) {
auto RefOp = [](int &&) {};
typedef CallableArgInfo<decltype(RefOp)> CAI;
static_assert(std::is_same_v<CAI::args_tuple_type, std::tuple<int &&>>);
}
Loading