Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions eval/compiler/flat_expr_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "eval/compiler/flat_expr_builder.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <deque>
Expand Down Expand Up @@ -668,6 +669,52 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
return absl::nullopt;
}

void MaybeMakeTernaryRecursive(const cel::ast_internal::Expr* expr) {
if (options_.max_recursion_depth == 0) {
return;
}
if (expr->call_expr().args().size() != 3) {
SetProgressStatusError(absl::InvalidArgumentError(
"unexpected number of args for builtin ternary"));
}

const cel::ast_internal::Expr* condition_expr =
&expr->call_expr().args()[0];
const cel::ast_internal::Expr* left_expr = &expr->call_expr().args()[1];
const cel::ast_internal::Expr* right_expr = &expr->call_expr().args()[2];

auto* condition_plan = program_builder_.GetSubexpression(condition_expr);
auto* left_plan = program_builder_.GetSubexpression(left_expr);
auto* right_plan = program_builder_.GetSubexpression(right_expr);

int max_depth = 0;
if (condition_plan == nullptr || !condition_plan->IsRecursive()) {
return;
}
max_depth = std::max(max_depth, condition_plan->recursive_program().depth);

if (left_plan == nullptr || !left_plan->IsRecursive()) {
return;
}
max_depth = std::max(max_depth, left_plan->recursive_program().depth);

if (right_plan == nullptr || !right_plan->IsRecursive()) {
return;
}
max_depth = std::max(max_depth, right_plan->recursive_program().depth);

if (max_depth >= options_.max_recursion_depth) {
return;
}

SetRecursiveStep(
CreateDirectTernaryStep(condition_plan->ExtractRecursiveProgram().step,
left_plan->ExtractRecursiveProgram().step,
right_plan->ExtractRecursiveProgram().step,
expr->id(), options_.short_circuiting),
max_depth + 1);
}

// Invoked after all child nodes are processed.
void PostVisitCall(const cel::ast_internal::Call* call_expr,
const cel::ast_internal::Expr* expr,
Expand All @@ -680,6 +727,9 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
if (cond_visitor) {
cond_visitor->PostVisit(expr);
cond_visitor_stack_.pop();
if (call_expr->function() == cel::builtin::kTernary) {
MaybeMakeTernaryRecursive(expr);
}
return;
}

Expand Down
41 changes: 41 additions & 0 deletions eval/eval/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -409,11 +409,17 @@ cc_library(
"logic_step.h",
],
deps = [
":attribute_trail",
":direct_expression_step",
":evaluator_core",
":expression_step_base",
"//base:builtins",
"//common:casting",
"//common:value",
"//common:value_kind",
"//eval/internal:errors",
"//internal:status_macros",
"//runtime/internal:errors",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
Expand Down Expand Up @@ -642,17 +648,30 @@ cc_test(
"logic_step_test.cc",
],
deps = [
":attribute_trail",
":cel_expression_flat_impl",
":const_value_step",
":direct_expression_step",
":evaluator_core",
":ident_step",
":logic_step",
"//base:attributes",
"//base:data",
"//base/ast_internal:expr",
"//common:casting",
"//common:value",
"//eval/public:activation",
"//eval/public:unknown_attribute_set",
"//eval/public:unknown_set",
"//extensions/protobuf:memory_manager",
"//internal:status_macros",
"//internal:testing",
"//runtime:activation",
"//runtime:managed_value_factory",
"//runtime:runtime_options",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
"@com_google_protobuf//:protobuf",
],
)

Expand Down Expand Up @@ -816,10 +835,14 @@ cc_library(
":attribute_trail",
"//base:attributes",
"//base:function_descriptor",
"//base:function_result",
"//base:function_result_set",
"//base/internal:unknown_set",
"//common:value",
"//eval/internal:errors",
"//internal:status_macros",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
Expand Down Expand Up @@ -854,11 +877,16 @@ cc_library(
"ternary_step.h",
],
deps = [
":attribute_trail",
":direct_expression_step",
":evaluator_core",
":expression_step_base",
"//base:builtins",
"//common:casting",
"//common:value",
"//eval/internal:errors",
"//internal:status_macros",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)
Expand All @@ -870,17 +898,30 @@ cc_test(
"ternary_step_test.cc",
],
deps = [
":attribute_trail",
":cel_expression_flat_impl",
":const_value_step",
":direct_expression_step",
":evaluator_core",
":ident_step",
":ternary_step",
"//base:attributes",
"//base:data",
"//base/ast_internal:expr",
"//common:casting",
"//common:value",
"//eval/public:activation",
"//eval/public:cel_value",
"//eval/public:unknown_attribute_set",
"//eval/public:unknown_set",
"//extensions/protobuf:memory_manager",
"//internal:status_macros",
"//internal:testing",
"//runtime:activation",
"//runtime:managed_value_factory",
"//runtime:runtime_options",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/status",
"@com_google_protobuf//:protobuf",
],
)
Expand Down
125 changes: 125 additions & 0 deletions eval/eval/ternary_step.cc
Original file line number Diff line number Diff line change
@@ -1,26 +1,136 @@
#include "eval/eval/ternary_step.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "base/builtins.h"
#include "common/casting.h"
#include "common/value.h"
#include "eval/eval/attribute_trail.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/expression_step_base.h"
#include "eval/internal/errors.h"
#include "internal/status_macros.h"

namespace google::api::expr::runtime {

namespace {

using ::cel::BoolValue;
using ::cel::Cast;
using ::cel::ErrorValue;
using ::cel::InstanceOf;
using ::cel::UnknownValue;
using ::cel::builtin::kTernary;
using ::cel::runtime_internal::CreateNoMatchingOverloadError;

inline constexpr size_t kTernaryStepCondition = 0;
inline constexpr size_t kTernaryStepTrue = 1;
inline constexpr size_t kTernaryStepFalse = 2;

class ExhaustiveDirectTernaryStep : public DirectExpressionStep {
public:
ExhaustiveDirectTernaryStep(std::unique_ptr<DirectExpressionStep> condition,
std::unique_ptr<DirectExpressionStep> left,
std::unique_ptr<DirectExpressionStep> right,
int64_t expr_id)
: DirectExpressionStep(expr_id),
condition_(std::move(condition)),
left_(std::move(left)),
right_(std::move(right)) {}

absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result,
AttributeTrail& attribute) const override {
cel::Value condition;
cel::Value lhs;
cel::Value rhs;

AttributeTrail condition_attr;
AttributeTrail lhs_attr;
AttributeTrail rhs_attr;

CEL_RETURN_IF_ERROR(condition_->Evaluate(frame, condition, condition_attr));
CEL_RETURN_IF_ERROR(left_->Evaluate(frame, lhs, lhs_attr));
CEL_RETURN_IF_ERROR(right_->Evaluate(frame, rhs, rhs_attr));

if (InstanceOf<ErrorValue>(condition) ||
InstanceOf<UnknownValue>(condition)) {
result = std::move(condition);
attribute = std::move(condition_attr);
return absl::OkStatus();
}

if (!InstanceOf<BoolValue>(condition)) {
result = frame.value_manager().CreateErrorValue(
CreateNoMatchingOverloadError(kTernary));
return absl::OkStatus();
}

if (Cast<BoolValue>(condition).NativeValue()) {
result = std::move(lhs);
attribute = std::move(lhs_attr);
} else {
result = std::move(rhs);
attribute = std::move(rhs_attr);
}
return absl::OkStatus();
}

private:
std::unique_ptr<DirectExpressionStep> condition_;
std::unique_ptr<DirectExpressionStep> left_;
std::unique_ptr<DirectExpressionStep> right_;
};

class ShortcircuitingDirectTernaryStep : public DirectExpressionStep {
public:
ShortcircuitingDirectTernaryStep(
std::unique_ptr<DirectExpressionStep> condition,
std::unique_ptr<DirectExpressionStep> left,
std::unique_ptr<DirectExpressionStep> right, int64_t expr_id)
: DirectExpressionStep(expr_id),
condition_(std::move(condition)),
left_(std::move(left)),
right_(std::move(right)) {}

absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result,
AttributeTrail& attribute) const override {
cel::Value condition;

AttributeTrail condition_attr;

CEL_RETURN_IF_ERROR(condition_->Evaluate(frame, condition, condition_attr));

if (InstanceOf<ErrorValue>(condition) ||
InstanceOf<UnknownValue>(condition)) {
result = std::move(condition);
attribute = std::move(condition_attr);
return absl::OkStatus();
}

if (!InstanceOf<BoolValue>(condition)) {
result = frame.value_manager().CreateErrorValue(
CreateNoMatchingOverloadError(kTernary));
return absl::OkStatus();
}

if (Cast<BoolValue>(condition).NativeValue()) {
return left_->Evaluate(frame, result, attribute);
}
return right_->Evaluate(frame, result, attribute);
}

private:
std::unique_ptr<DirectExpressionStep> condition_;
std::unique_ptr<DirectExpressionStep> left_;
std::unique_ptr<DirectExpressionStep> right_;
};

class TernaryStep : public ExpressionStepBase {
public:
// Constructs FunctionStep that uses overloads specified.
Expand Down Expand Up @@ -72,6 +182,21 @@ absl::Status TernaryStep::Evaluate(ExecutionFrame* frame) const {

} // namespace

// Factory method for ternary (_?_:_) recursive execution step
std::unique_ptr<DirectExpressionStep> CreateDirectTernaryStep(
std::unique_ptr<DirectExpressionStep> condition,
std::unique_ptr<DirectExpressionStep> left,
std::unique_ptr<DirectExpressionStep> right, int64_t expr_id,
bool shortcircuiting) {
if (shortcircuiting) {
return std::make_unique<ShortcircuitingDirectTernaryStep>(
std::move(condition), std::move(left), std::move(right), expr_id);
}

return std::make_unique<ExhaustiveDirectTernaryStep>(
std::move(condition), std::move(left), std::move(right), expr_id);
}

absl::StatusOr<std::unique_ptr<ExpressionStep>> CreateTernaryStep(
int64_t expr_id) {
return std::make_unique<TernaryStep>(expr_id);
Expand Down
9 changes: 9 additions & 0 deletions eval/eval/ternary_step.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,21 @@
#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_TERNARY_STEP_H_

#include <cstdint>
#include <memory>

#include "absl/status/statusor.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"

namespace google::api::expr::runtime {

// Factory method for ternary (_?_:_) recursive execution step
std::unique_ptr<DirectExpressionStep> CreateDirectTernaryStep(
std::unique_ptr<DirectExpressionStep> condition,
std::unique_ptr<DirectExpressionStep> left,
std::unique_ptr<DirectExpressionStep> right, int64_t expr_id,
bool shortcircuiting = true);

// Factory method for ternary (_?_:_) execution step
absl::StatusOr<std::unique_ptr<ExpressionStep>> CreateTernaryStep(
int64_t expr_id);
Expand Down
Loading