Skip to content

Commit 77aad9d

Browse files
jnthntatumcopybara-github
authored andcommitted
Draft: add support for cloning cel::Expr.
PiperOrigin-RevId: 723626957
1 parent 148c393 commit 77aad9d

File tree

5 files changed

+149
-0
lines changed

5 files changed

+149
-0
lines changed

common/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ cc_library(
3333
"@com_google_absl//absl/algorithm:container",
3434
"@com_google_absl//absl/base:core_headers",
3535
"@com_google_absl//absl/base:no_destructor",
36+
"@com_google_absl//absl/functional:overload",
3637
"@com_google_absl//absl/strings:string_view",
3738
"@com_google_absl//absl/types:optional",
3839
"@com_google_absl//absl/types:span",

common/expr.cc

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,124 @@
1414

1515
#include "common/expr.h"
1616

17+
#include <vector>
18+
1719
#include "absl/base/no_destructor.h"
20+
#include "absl/functional/overload.h"
21+
#include "absl/types/variant.h"
22+
#include "common/constant.h"
1823

1924
namespace cel {
2025

26+
namespace {
27+
28+
struct CopyStackRecord {
29+
const Expr* src;
30+
Expr* dst;
31+
};
32+
33+
void CopyNode(CopyStackRecord element, std::vector<CopyStackRecord>& stack) {
34+
const Expr* src = element.src;
35+
Expr* dst = element.dst;
36+
dst->set_id(src->id());
37+
absl::visit(
38+
absl::Overload(
39+
[](const UnspecifiedExpr&) {},
40+
[=](const IdentExpr& i) {
41+
dst->mutable_ident_expr().set_name(i.name());
42+
},
43+
[=](const Constant& c) { dst->mutable_const_expr() = c; },
44+
[&](const SelectExpr& s) {
45+
dst->mutable_select_expr().set_field(s.field());
46+
dst->mutable_select_expr().set_test_only(s.test_only());
47+
48+
if (s.has_operand()) {
49+
stack.push_back({&s.operand(),
50+
&dst->mutable_select_expr().mutable_operand()});
51+
}
52+
},
53+
[&](const CallExpr& c) {
54+
dst->mutable_call_expr().set_function(c.function());
55+
if (c.has_target()) {
56+
stack.push_back(
57+
{&c.target(), &dst->mutable_call_expr().mutable_target()});
58+
}
59+
dst->mutable_call_expr().mutable_args().resize(c.args().size());
60+
for (int i = 0; i < dst->mutable_call_expr().mutable_args().size();
61+
++i) {
62+
stack.push_back(
63+
{&c.args()[i], &dst->mutable_call_expr().mutable_args()[i]});
64+
}
65+
},
66+
[&](const ListExpr& c) {
67+
auto& dst_list = dst->mutable_list_expr();
68+
dst_list.mutable_elements().resize(c.elements().size());
69+
for (int i = 0; i < src->list_expr().elements().size(); ++i) {
70+
dst_list.mutable_elements()[i].set_optional(
71+
c.elements()[i].optional());
72+
stack.push_back({&c.elements()[i].expr(),
73+
&dst_list.mutable_elements()[i].mutable_expr()});
74+
}
75+
},
76+
[&](const StructExpr& s) {
77+
auto& dst_struct = dst->mutable_struct_expr();
78+
dst_struct.mutable_fields().resize(s.fields().size());
79+
dst_struct.set_name(s.name());
80+
for (int i = 0; i < s.fields().size(); ++i) {
81+
dst_struct.mutable_fields()[i].set_optional(
82+
s.fields()[i].optional());
83+
dst_struct.mutable_fields()[i].set_name(s.fields()[i].name());
84+
dst_struct.mutable_fields()[i].set_id(s.fields()[i].id());
85+
stack.push_back(
86+
{&s.fields()[i].value(),
87+
&dst_struct.mutable_fields()[i].mutable_value()});
88+
}
89+
},
90+
[&](const MapExpr& c) {
91+
auto& dst_map = dst->mutable_map_expr();
92+
dst_map.mutable_entries().resize(c.entries().size());
93+
for (int i = 0; i < c.entries().size(); ++i) {
94+
dst_map.mutable_entries()[i].set_optional(
95+
c.entries()[i].optional());
96+
dst_map.mutable_entries()[i].set_id(c.entries()[i].id());
97+
stack.push_back({&c.entries()[i].key(),
98+
&dst_map.mutable_entries()[i].mutable_key()});
99+
stack.push_back({&c.entries()[i].value(),
100+
&dst_map.mutable_entries()[i].mutable_value()});
101+
}
102+
},
103+
[&](const ComprehensionExpr& c) {
104+
auto& dst_comprehension = dst->mutable_comprehension_expr();
105+
dst_comprehension.set_iter_var(c.iter_var());
106+
dst_comprehension.set_iter_var2(c.iter_var2());
107+
dst_comprehension.set_accu_var(c.accu_var());
108+
if (c.has_accu_init()) {
109+
stack.push_back(
110+
{&c.accu_init(), &dst_comprehension.mutable_accu_init()});
111+
}
112+
if (c.has_iter_range()) {
113+
stack.push_back(
114+
{&c.iter_range(), &dst_comprehension.mutable_iter_range()});
115+
}
116+
if (c.has_loop_condition()) {
117+
stack.push_back({&c.loop_condition(),
118+
&dst_comprehension.mutable_loop_condition()});
119+
}
120+
if (c.has_loop_step()) {
121+
stack.push_back(
122+
{&c.loop_step(), &dst_comprehension.mutable_loop_step()});
123+
}
124+
if (c.has_result()) {
125+
stack.push_back(
126+
{&c.result(), &dst_comprehension.mutable_result()});
127+
}
128+
}
129+
130+
),
131+
src->kind());
132+
}
133+
} // namespace
134+
21135
const UnspecifiedExpr& UnspecifiedExpr::default_instance() {
22136
static const absl::NoDestructor<UnspecifiedExpr> instance;
23137
return *instance;
@@ -63,4 +177,16 @@ const Expr& Expr::default_instance() {
63177
return *instance;
64178
}
65179

180+
Expr CloneExpr(const Expr& expr) {
181+
Expr result;
182+
std::vector<CopyStackRecord> stack;
183+
stack.push_back({&expr, &result});
184+
while (!stack.empty()) {
185+
CopyStackRecord element = stack.back();
186+
stack.pop_back();
187+
CopyNode(element, stack);
188+
}
189+
return result;
190+
}
191+
66192
} // namespace cel

common/expr.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ class ComprehensionExpr;
4848

4949
inline constexpr absl::string_view kAccumulatorVariableName = "__result__";
5050

51+
// Returns a deep copy of the given expression node.
52+
Expr CloneExpr(const Expr& expr);
53+
5154
bool operator==(const Expr& lhs, const Expr& rhs);
5255

5356
inline bool operator!=(const Expr& lhs, const Expr& rhs) {

extensions/protobuf/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ cc_test(
7676
":ast_converters",
7777
"//base/ast_internal:ast_impl",
7878
"//base/ast_internal:expr",
79+
"//common:expr",
7980
"//internal:proto_matchers",
8081
"//internal:testing",
8182
"//parser",

extensions/protobuf/ast_converters_test.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "absl/types/variant.h"
3232
#include "base/ast_internal/ast_impl.h"
3333
#include "base/ast_internal/expr.h"
34+
#include "common/expr.h"
3435
#include "internal/proto_matchers.h"
3536
#include "internal/testing.h"
3637
#include "parser/options.h"
@@ -801,6 +802,23 @@ TEST_P(ConversionRoundTripTest, ParsedExprCopyable) {
801802
IsOkAndHolds(EqualsProto(parsed_expr)));
802803
}
803804

805+
TEST_P(ConversionRoundTripTest, ExprClonable) {
806+
ASSERT_OK_AND_ASSIGN(ParsedExprPb parsed_expr,
807+
Parse(GetParam().expr, "<input>", options_));
808+
809+
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> ast,
810+
CreateAstFromParsedExpr(parsed_expr));
811+
812+
auto& impl = ast_internal::AstImpl::CastFromPublicAst(*ast);
813+
impl.root_expr() = CloneExpr(impl.root_expr());
814+
815+
EXPECT_THAT(CreateCheckedExprFromAst(impl),
816+
StatusIs(absl::StatusCode::kInvalidArgument,
817+
HasSubstr("AST is not type-checked")));
818+
EXPECT_THAT(CreateParsedExprFromAst(impl),
819+
IsOkAndHolds(EqualsProto(parsed_expr)));
820+
}
821+
804822
TEST_P(ConversionRoundTripTest, CheckedExprCopyable) {
805823
ASSERT_OK_AND_ASSIGN(ParsedExprPb parsed_expr,
806824
Parse(GetParam().expr, "<input>", options_));

0 commit comments

Comments
 (0)