diff --git a/runtime/function_adapter.h b/runtime/function_adapter.h index e64edd232..641ad0609 100644 --- a/runtime/function_adapter.h +++ b/runtime/function_adapter.h @@ -18,20 +18,21 @@ #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_ +#include #include #include +#include +#include #include #include "absl/base/nullability.h" #include "absl/functional/any_invocable.h" -#include "absl/functional/bind_front.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "common/function_descriptor.h" -#include "common/kind.h" #include "common/value.h" #include "internal/status_macros.h" #include "runtime/function.h" @@ -94,79 +95,73 @@ struct AdaptedTypeTraits { static T ToArg(AssignableType v) { return v; } }; -template -struct KindAdderImpl; - -template -struct KindAdderImpl { - static void AddTo(std::vector& args) { - args.push_back(AdaptedKind()); - KindAdderImpl::AddTo(args); +template +struct AdaptHelperImpl { + template + static absl::Status Apply(absl::Span input, T& output) { + static_assert(sizeof...(Args) > 0); + static_assert(std::tuple_size_v == sizeof...(Args)); + CEL_RETURN_IF_ERROR(HandleToAdaptedVisitor{input[I]}(&std::get(output))); + if constexpr (I == sizeof...(Args) - 1) { + return absl::OkStatus(); + } else { + CEL_RETURN_IF_ERROR( + (AdaptHelperImpl::template Apply(input, output))); + } + return absl::OkStatus(); } }; -template <> -struct KindAdderImpl<> { - static void AddTo(std::vector& args) {} -}; - template -struct KindAdder { - static std::vector Kinds() { - std::vector args; - KindAdderImpl::AddTo(args); - return args; +struct AdaptHelper { + template + static absl::Status Apply(absl::Span input, T& output) { + return AdaptHelperImpl<0, Args...>::template Apply(input, output); } }; -template -struct ApplyReturnType { - using type = absl::StatusOr; -}; - -template -struct ApplyReturnType> { - using type = absl::StatusOr; -}; - -template -struct IndexerImpl { - using type = typename IndexerImpl::type; -}; - -template -struct IndexerImpl<0, Arg, Args...> { - using type = Arg; -}; +template +struct ToArgsImpl { + template + struct El { + using type = T; + constexpr static size_t index = I; + }; -template -struct Indexer { - static_assert(N < sizeof...(Args) && N >= 0); - using type = typename IndexerImpl::type; -}; + template + struct ZipHolder { + template + static ResultType ToArgs( + Op&& op, const TupleType& argbuffer, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return std::forward(op)( + runtime_internal::AdaptedTypeTraits::ToArg( + std::get(argbuffer))..., + descriptor_pool, message_factory, arena); + } + }; -template -struct ApplyHelper { - template - static typename ApplyReturnType::type Apply( - Op&& op, absl::Span input) { - constexpr int idx = sizeof...(Args) - N; - using Arg = typename Indexer::type; - using ArgTraits = AdaptedTypeTraits; - typename ArgTraits::AssignableType arg_i; - CEL_RETURN_IF_ERROR(HandleToAdaptedVisitor{input[idx]}(&arg_i)); - - return ApplyHelper::template Apply( - absl::bind_front(std::forward(op), ArgTraits::ToArg(arg_i)), input); + template + static ZipHolder...> MakeZip(const std::index_sequence&) { + return ZipHolder...>{}; } }; template -struct ApplyHelper<0, Args...> { - template - static typename ApplyReturnType::type Apply( - Op&& op, absl::Span input) { - return op(); +struct ToArgsHelper { + template + static ResultType Apply( + Op&& op, const TupleType& argbuffer, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + using Impl = ToArgsImpl; + using Zip = decltype(Impl::MakeZip(std::index_sequence_for{})); + return Zip::template ToArgs(std::forward(op), argbuffer, + descriptor_pool, message_factory, + arena); } }; @@ -629,6 +624,98 @@ class QuaternaryFunctionAdapter }; }; +// Primary template for n-ary adapter. +template +class NaryFunctionAdapter; + +template +class NaryFunctionAdapter : public NullaryFunctionAdapter {}; + +template +class NaryFunctionAdapter : public UnaryFunctionAdapter {}; + +template +class NaryFunctionAdapter : public BinaryFunctionAdapter {}; + +template +class NaryFunctionAdapter + : public TernaryFunctionAdapter {}; + +template +class NaryFunctionAdapter + : public QuaternaryFunctionAdapter {}; + +// N-ary function adapter. +// +// Prefer using one of the specific count adapters above for readability and +// better error messages. +template +class NaryFunctionAdapter + : public RegisterHelper> { + public: + using FunctionType = absl::AnyInvocable; + + static FunctionDescriptor CreateDescriptor(absl::string_view name, + bool receiver_style, + bool is_strict = true) { + return FunctionDescriptor(name, receiver_style, + {runtime_internal::AdaptedKind()...}, + is_strict); + } + + static std::unique_ptr WrapFunction(FunctionType fn) { + return std::make_unique(std::move(fn)); + } + + static std::unique_ptr WrapFunction( + absl::AnyInvocable function) { + return WrapFunction( + [function = std::move(function)]( + Args... args, const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull) -> T { return function(args...); }); + } + + private: + class NaryFunctionImpl : public cel::Function { + private: + using ArgBuffer = std::tuple< + typename runtime_internal::AdaptedTypeTraits::AssignableType...>; + + public: + explicit NaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {} + absl::StatusOr Invoke( + absl::Span args, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const override { + if (args.size() != sizeof...(Args)) { + return absl::InvalidArgumentError( + absl::StrCat("unexpected number of arguments for ", sizeof...(Args), + "-ary function")); + } + ArgBuffer arg_buffer; + CEL_RETURN_IF_ERROR( + runtime_internal::AdaptHelper::Apply(args, arg_buffer)); + if constexpr (std::is_same_v || + std::is_same_v>) { + return runtime_internal::ToArgsHelper::template Apply( + fn_, arg_buffer, descriptor_pool, message_factory, arena); + } else { + T result = runtime_internal::ToArgsHelper::template Apply( + fn_, arg_buffer, descriptor_pool, message_factory, arena); + return runtime_internal::AdaptedToHandleVisitor{}(std::move(result)); + } + } + + private: + FunctionType fn_; + }; +}; + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_ diff --git a/runtime/function_adapter_test.cc b/runtime/function_adapter_test.cc index 820a08600..aad5e4086 100644 --- a/runtime/function_adapter_test.cc +++ b/runtime/function_adapter_test.cc @@ -686,7 +686,7 @@ TEST_F(FunctionAdapterTest, BinaryFunctionAdapterCreateDescriptorNonStrict) { EXPECT_THAT(desc.types(), ElementsAre(Kind::kAny, Kind::kAny)); } -TEST_F(FunctionAdapterTest, VariadicFunctionAdapterCreateDescriptor0Args) { +TEST_F(FunctionAdapterTest, NaryFunctionAdapterCreateDescriptor0Args) { FunctionDescriptor desc = NullaryFunctionAdapter>::CreateDescriptor( "ZeroArgs", false); @@ -697,7 +697,7 @@ TEST_F(FunctionAdapterTest, VariadicFunctionAdapterCreateDescriptor0Args) { EXPECT_THAT(desc.types(), IsEmpty()); } -TEST_F(FunctionAdapterTest, VariadicFunctionAdapterWrapFunction0Args) { +TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction0Args) { std::unique_ptr fn = NullaryFunctionAdapter>::WrapFunction( []() { return StringValue("abc"); }); @@ -708,7 +708,7 @@ TEST_F(FunctionAdapterTest, VariadicFunctionAdapterWrapFunction0Args) { EXPECT_EQ(result.GetString().ToString(), "abc"); } -TEST_F(FunctionAdapterTest, VariadicFunctionAdapterCreateDescriptor3Args) { +TEST_F(FunctionAdapterTest, NaryFunctionAdapterCreateDescriptor3Args) { FunctionDescriptor desc = TernaryFunctionAdapter< absl::StatusOr, int64_t, bool, const StringValue&>::CreateDescriptor("MyFormatter", false); @@ -720,8 +720,8 @@ TEST_F(FunctionAdapterTest, VariadicFunctionAdapterCreateDescriptor3Args) { ElementsAre(Kind::kInt64, Kind::kBool, Kind::kString)); } -TEST_F(FunctionAdapterTest, VariadicFunctionAdapterWrapFunction3Args) { - std::unique_ptr fn = TernaryFunctionAdapter< +TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction3Args) { + std::unique_ptr fn = NaryFunctionAdapter< absl::StatusOr, int64_t, bool, const StringValue&>::WrapFunction([](int64_t int_val, bool bool_val, const StringValue& string_val) @@ -738,9 +738,8 @@ TEST_F(FunctionAdapterTest, VariadicFunctionAdapterWrapFunction3Args) { EXPECT_EQ(result.GetString().ToString(), "42_false_abcd"); } -TEST_F(FunctionAdapterTest, - VariadicFunctionAdapterWrapFunction3ArgsBadArgType) { - std::unique_ptr fn = TernaryFunctionAdapter< +TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction3ArgsBadArgType) { + std::unique_ptr fn = NaryFunctionAdapter< absl::StatusOr, int64_t, bool, const StringValue&>::WrapFunction([](int64_t int_val, bool bool_val, const StringValue& string_val) @@ -756,9 +755,8 @@ TEST_F(FunctionAdapterTest, HasSubstr("expected string value"))); } -TEST_F(FunctionAdapterTest, - VariadicFunctionAdapterWrapFunction3ArgsBadArgCount) { - std::unique_ptr fn = TernaryFunctionAdapter< +TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction3ArgsBadArgCount) { + std::unique_ptr fn = NaryFunctionAdapter< absl::StatusOr, int64_t, bool, const StringValue&>::WrapFunction([](int64_t int_val, bool bool_val, const StringValue& string_val) @@ -773,5 +771,82 @@ TEST_F(FunctionAdapterTest, HasSubstr("unexpected number of arguments"))); } +TEST_F(FunctionAdapterTest, NaryFunctionAdapterCreateDescriptor5Args) { + FunctionDescriptor desc = + NaryFunctionAdapter, int64_t, bool, + const StringValue&, int64_t, + int64_t>::CreateDescriptor("MyFormatter", false); + + EXPECT_EQ(desc.name(), "MyFormatter"); + EXPECT_TRUE(desc.is_strict()); + EXPECT_FALSE(desc.receiver_style()); + EXPECT_THAT(desc.types(), + ElementsAre(Kind::kInt64, Kind::kBool, Kind::kString, + Kind::kInt64, Kind::kInt64)); +} + +TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction5Args) { + std::unique_ptr fn = NaryFunctionAdapter< + absl::StatusOr, int64_t, bool, const StringValue&, int64_t, + int64_t>::WrapFunction([](int64_t int_val, bool bool_val, + const StringValue& string_val, + int64_t extra_arg, + int64_t extra_arg2) -> absl::StatusOr { + return StringValue(absl::StrCat(int_val, "_", (bool_val ? "true" : "false"), + "_", string_val.ToString(), "_", extra_arg, + "_", extra_arg2)); + }); + + std::vector args{IntValue(42), BoolValue(false)}; + args.emplace_back() = StringValue("abcd"); + args.push_back(IntValue(123)); + args.push_back(IntValue(456)); + ASSERT_OK_AND_ASSIGN(auto result, fn->Invoke(args, descriptor_pool(), + message_factory(), arena())); + ASSERT_TRUE(result->Is()); + EXPECT_EQ(result.GetString().ToString(), "42_false_abcd_123_456"); +} + +TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction5ArgsBadArgType) { + std::unique_ptr fn = NaryFunctionAdapter< + absl::StatusOr, int64_t, bool, const StringValue&, int64_t, + int64_t>::WrapFunction([](int64_t int_val, bool bool_val, + const StringValue& string_val, + int64_t extra_arg, + int64_t extra_arg2) -> absl::StatusOr { + static_cast(extra_arg); + static_cast(extra_arg2); + return StringValue(absl::StrCat(int_val, "_", (bool_val ? "true" : "false"), + "_", string_val.ToString())); + }); + + std::vector args{IntValue(42), BoolValue(false)}; + args.emplace_back() = TimestampValue(absl::UnixEpoch()); + args.push_back(IntValue(123)); + args.push_back(IntValue(456)); + EXPECT_THAT(fn->Invoke(args, descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("expected string value"))); +} + +TEST_F(FunctionAdapterTest, NaryFunctionAdapterWrapFunction5ArgsBadArgCount) { + std::unique_ptr fn = NaryFunctionAdapter< + absl::StatusOr, int64_t, bool, const StringValue&, int64_t, + int64_t>::WrapFunction([](int64_t int_val, bool bool_val, + const StringValue& string_val, + int64_t extra_arg, + int64_t extra_arg2) -> absl::StatusOr { + static_cast(extra_arg); + static_cast(extra_arg2); + return StringValue(absl::StrCat(int_val, "_", (bool_val ? "true" : "false"), + "_", string_val.ToString())); + }); + + std::vector args{IntValue(42), BoolValue(false)}; + EXPECT_THAT(fn->Invoke(args, descriptor_pool(), message_factory(), arena()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("unexpected number of arguments"))); +} + } // namespace } // namespace cel