Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 149 additions & 62 deletions runtime/function_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,21 @@
#ifndef THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_
#define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_

#include <cstddef>
#include <functional>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/base/nullability.h"
#include "absl/functional/any_invocable.h"
#include "absl/functional/bind_front.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "common/function_descriptor.h"
#include "common/kind.h"
#include "common/value.h"
#include "internal/status_macros.h"
#include "runtime/function.h"
Expand Down Expand Up @@ -94,79 +95,73 @@ struct AdaptedTypeTraits<const T&> {
static T ToArg(AssignableType v) { return v; }
};

template <typename... Args>
struct KindAdderImpl;

template <typename Arg, typename... Args>
struct KindAdderImpl<Arg, Args...> {
static void AddTo(std::vector<cel::Kind>& args) {
args.push_back(AdaptedKind<Arg>());
KindAdderImpl<Args...>::AddTo(args);
template <size_t I, typename... Args>
struct AdaptHelperImpl {
template <typename T>
static absl::Status Apply(absl::Span<const Value> input, T& output) {
static_assert(sizeof...(Args) > 0);
static_assert(std::tuple_size_v<T> == sizeof...(Args));
CEL_RETURN_IF_ERROR(HandleToAdaptedVisitor{input[I]}(&std::get<I>(output)));
if constexpr (I == sizeof...(Args) - 1) {
return absl::OkStatus();
} else {
CEL_RETURN_IF_ERROR(
(AdaptHelperImpl<I + 1, Args...>::template Apply<T>(input, output)));
}
return absl::OkStatus();
}
};

template <>
struct KindAdderImpl<> {
static void AddTo(std::vector<cel::Kind>& args) {}
};

template <typename... Args>
struct KindAdder {
static std::vector<cel::Kind> Kinds() {
std::vector<cel::Kind> args;
KindAdderImpl<Args...>::AddTo(args);
return args;
struct AdaptHelper {
template <typename T>
static absl::Status Apply(absl::Span<const Value> input, T& output) {
return AdaptHelperImpl<0, Args...>::template Apply<T>(input, output);
}
};

template <typename T>
struct ApplyReturnType {
using type = absl::StatusOr<T>;
};

template <typename T>
struct ApplyReturnType<absl::StatusOr<T>> {
using type = absl::StatusOr<T>;
};

template <int N, typename Arg, typename... Args>
struct IndexerImpl {
using type = typename IndexerImpl<N - 1, Args...>::type;
};

template <typename Arg, typename... Args>
struct IndexerImpl<0, Arg, Args...> {
using type = Arg;
};
template <typename... Args>
struct ToArgsImpl {
template <int I, typename T>
struct El {
using type = T;
constexpr static size_t index = I;
};

template <int N, typename... Args>
struct Indexer {
static_assert(N < sizeof...(Args) && N >= 0);
using type = typename IndexerImpl<N, Args...>::type;
};
template <typename... Es>
struct ZipHolder {
template <typename ResultType, typename TupleType, typename Op>
static ResultType ToArgs(
Op&& op, const TupleType& argbuffer,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) {
return std::forward<Op>(op)(
runtime_internal::AdaptedTypeTraits<typename Es::type>::ToArg(
std::get<Es::index>(argbuffer))...,
descriptor_pool, message_factory, arena);
}
};

template <int N, typename... Args>
struct ApplyHelper {
template <typename T, typename Op>
static typename ApplyReturnType<T>::type Apply(
Op&& op, absl::Span<const Value> input) {
constexpr int idx = sizeof...(Args) - N;
using Arg = typename Indexer<idx, Args...>::type;
using ArgTraits = AdaptedTypeTraits<Arg>;
typename ArgTraits::AssignableType arg_i;
CEL_RETURN_IF_ERROR(HandleToAdaptedVisitor{input[idx]}(&arg_i));

return ApplyHelper<N - 1, Args...>::template Apply<T>(
absl::bind_front(std::forward<Op>(op), ArgTraits::ToArg(arg_i)), input);
template <size_t... Is>
static ZipHolder<El<Is, Args>...> MakeZip(const std::index_sequence<Is...>&) {
return ZipHolder<El<Is, Args>...>{};
}
};

template <typename... Args>
struct ApplyHelper<0, Args...> {
template <typename T, typename Op>
static typename ApplyReturnType<T>::type Apply(
Op&& op, absl::Span<const Value> input) {
return op();
struct ToArgsHelper {
template <typename ResultType, typename TupleType, typename Op>
static ResultType Apply(
Op&& op, const TupleType& argbuffer,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) {
using Impl = ToArgsImpl<Args...>;
using Zip = decltype(Impl::MakeZip(std::index_sequence_for<Args...>{}));
return Zip::template ToArgs<ResultType>(std::forward<Op>(op), argbuffer,
descriptor_pool, message_factory,
arena);
}
};

Expand Down Expand Up @@ -629,6 +624,98 @@ class QuaternaryFunctionAdapter
};
};

// Primary template for n-ary adapter.
template <typename T, typename... Args>
class NaryFunctionAdapter;

template <typename T>
class NaryFunctionAdapter<T> : public NullaryFunctionAdapter<T> {};

template <typename T, typename U>
class NaryFunctionAdapter<T, U> : public UnaryFunctionAdapter<T, U> {};

template <typename T, typename U, typename V>
class NaryFunctionAdapter<T, U, V> : public BinaryFunctionAdapter<T, U, V> {};

template <typename T, typename U, typename V, typename W>
class NaryFunctionAdapter<T, U, V, W>
: public TernaryFunctionAdapter<T, U, V, W> {};

template <typename T, typename U, typename V, typename W, typename X>
class NaryFunctionAdapter<T, U, V, W, X>
: public QuaternaryFunctionAdapter<T, U, V, W, X> {};

// N-ary function adapter.
//
// Prefer using one of the specific count adapters above for readability and
// better error messages.
template <typename T, typename... Args>
class NaryFunctionAdapter
: public RegisterHelper<NaryFunctionAdapter<T, Args...>> {
public:
using FunctionType = absl::AnyInvocable<T(
Args..., const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const>;

static FunctionDescriptor CreateDescriptor(absl::string_view name,
bool receiver_style,
bool is_strict = true) {
return FunctionDescriptor(name, receiver_style,
{runtime_internal::AdaptedKind<Args>()...},
is_strict);
}

static std::unique_ptr<cel::Function> WrapFunction(FunctionType fn) {
return std::make_unique<NaryFunctionImpl>(std::move(fn));
}

static std::unique_ptr<cel::Function> WrapFunction(
absl::AnyInvocable<T(Args...) const> function) {
return WrapFunction(
[function = std::move(function)](
Args... args, const google::protobuf::DescriptorPool* absl_nonnull,
google::protobuf::MessageFactory* absl_nonnull,
google::protobuf::Arena* absl_nonnull) -> T { return function(args...); });
}

private:
class NaryFunctionImpl : public cel::Function {
private:
using ArgBuffer = std::tuple<
typename runtime_internal::AdaptedTypeTraits<Args>::AssignableType...>;

public:
explicit NaryFunctionImpl(FunctionType fn) : fn_(std::move(fn)) {}
absl::StatusOr<Value> Invoke(
absl::Span<const Value> args,
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
google::protobuf::MessageFactory* absl_nonnull message_factory,
google::protobuf::Arena* absl_nonnull arena) const override {
if (args.size() != sizeof...(Args)) {
return absl::InvalidArgumentError(
absl::StrCat("unexpected number of arguments for ", sizeof...(Args),
"-ary function"));
}
ArgBuffer arg_buffer;
CEL_RETURN_IF_ERROR(
runtime_internal::AdaptHelper<Args...>::Apply(args, arg_buffer));
if constexpr (std::is_same_v<T, Value> ||
std::is_same_v<T, absl::StatusOr<Value>>) {
return runtime_internal::ToArgsHelper<Args...>::template Apply<T>(
fn_, arg_buffer, descriptor_pool, message_factory, arena);
} else {
T result = runtime_internal::ToArgsHelper<Args...>::template Apply<T>(
fn_, arg_buffer, descriptor_pool, message_factory, arena);
return runtime_internal::AdaptedToHandleVisitor{}(std::move(result));
}
}

private:
FunctionType fn_;
};
};

} // namespace cel

#endif // THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_ADAPTER_H_
Loading