Skip to content

Commit

Permalink
planner: Adding support for general ref rule heads (#6235)
Browse files Browse the repository at this point in the history
Fixes: #5995

Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling committed Sep 27, 2023
1 parent 391cb0c commit c9d1a8d
Show file tree
Hide file tree
Showing 9 changed files with 353 additions and 36 deletions.
14 changes: 14 additions & 0 deletions ast/term.go
Expand Up @@ -1032,6 +1032,20 @@ func (ref Ref) ConstantPrefix() Ref {
return ref[:i]
}

func (ref Ref) StringPrefix() Ref {
r := ref.Copy()

for i := 1; i < len(ref); i++ {
switch r[i].Value.(type) {
case String: // pass
default: // cut off
return r[:i]
}
}

return r
}

// GroundPrefix returns the ground portion of the ref starting from the head. By
// definition, the head of the reference is always ground.
func (ref Ref) GroundPrefix() Ref {
Expand Down
131 changes: 109 additions & 22 deletions internal/planner/planner.go
Expand Up @@ -25,6 +25,8 @@ type QuerySet struct {
}

type planiter func() error
type planLocalIter func(ir.Local) error
type stmtFactory func(ir.Local) ir.Stmt

// Planner implements a query planner for Rego queries.
type Planner struct {
Expand Down Expand Up @@ -147,32 +149,31 @@ func (p *Planner) buildFunctrie() error {
}

for _, rule := range module.Rules {
r := rule.Ref()
switch r[len(r)-1].Value.(type) {
case ast.String: // pass
default: // cut off
r = r[:len(r)-1]
}
r := rule.Ref().StringPrefix()
val := p.rules.LookupOrInsert(r)

val.rules = val.DescendantRules()
val.rules = append(val.rules, rule)
val.children = nil
}
}
return nil
}

func (p *Planner) planRules(rules []*ast.Rule) (string, error) {
// We know the rules with closer to the root (shorter static path) are ordered first.
pathRef := rules[0].Ref()

// figure out what our rules' collective name/path is:
// if we're planning both p.q.r and p.q[s], we'll name
// the function p.q (for the mapping table)
// TODO(sr): this has to change when allowing `p[v].q.r[w]` ref rules
// including the mapping lookup structure and lookup functions
pieces := len(pathRef)
for i := range rules {
r := rules[i].Ref()
if _, ok := r[len(r)-1].Value.(ast.String); !ok {
pieces = len(r) - 1
for j, t := range r {
if _, ok := t.Value.(ast.String); !ok && j > 0 && j < pieces {
pieces = j
}
}
}
// control if p.a = 1 is to return 1 directly; or insert 1 under key "a" into an object
Expand Down Expand Up @@ -236,7 +237,11 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) {
fn.Blocks = append(fn.Blocks, p.blockWithStmt(&ir.MakeObjectStmt{Target: fn.Return}))
}
case ast.MultiValue:
fn.Blocks = append(fn.Blocks, p.blockWithStmt(&ir.MakeSetStmt{Target: fn.Return}))
if buildObject {
fn.Blocks = append(fn.Blocks, p.blockWithStmt(&ir.MakeObjectStmt{Target: fn.Return}))
} else {
fn.Blocks = append(fn.Blocks, p.blockWithStmt(&ir.MakeSetStmt{Target: fn.Return}))
}
}

// For complete document rules, allocate one local variable for output
Expand All @@ -252,6 +257,12 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) {
var defaultRule *ast.Rule
var ruleLoc *location.Location

// We sort rules by ref length, to ensure that when merged, we can detect conflicts when one
// rule attempts to override values (deep and shallow) defined by another rule.
sort.Slice(rules, func(i, j int) bool {
return len(rules[i].Ref()) > len(rules[j].Ref())
})

// Generate function blocks for rules.
for i := range rules {

Expand Down Expand Up @@ -320,18 +331,19 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) {
switch rule.Head.RuleKind() {
case ast.SingleValue:
if buildObject {
ref := rule.Head.Ref()
last := ref[len(ref)-1]
return p.planTerm(last, func() error {
key := p.ltarget
return p.planTerm(rule.Head.Value, func() error {
value := p.ltarget
p.appendStmt(&ir.ObjectInsertOnceStmt{
Object: fn.Return,
Key: key,
Value: value,
ref := rule.Ref()
return p.planTerm(rule.Head.Value, func() error {
value := p.ltarget
return p.planNestedObjects(fn.Return, ref[pieces:len(ref)-1], func(obj ir.Local) error {
return p.planTerm(ref[len(ref)-1], func() error {
key := p.ltarget
p.appendStmt(&ir.ObjectInsertOnceStmt{
Object: obj,
Key: key,
Value: value,
})
return nil
})
return nil
})
})
}
Expand All @@ -343,6 +355,28 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) {
return nil
})
case ast.MultiValue:
if buildObject {
ref := rule.Ref()
// we drop the trailing set key from the ref
return p.planNestedObjects(fn.Return, ref[pieces:len(ref)-1], func(obj ir.Local) error {
// Last term on rule ref is the key an which the set is assigned in the deepest nested object
return p.planTerm(ref[len(ref)-1], func() error {
key := p.ltarget
return p.planTerm(rule.Head.Key, func() error {
value := p.ltarget
factory := func(v ir.Local) ir.Stmt { return &ir.MakeSetStmt{Target: v} }
return p.planDotOr(obj, key, factory, func(set ir.Local) error {
p.appendStmt(&ir.SetAddStmt{
Set: set,
Value: value,
})
p.appendStmt(&ir.ObjectInsertStmt{Key: key, Value: op(set), Object: obj})
return nil
})
})
})
})
}
return p.planTerm(rule.Head.Key, func() error {
p.appendStmt(&ir.SetAddStmt{
Set: fn.Return,
Expand Down Expand Up @@ -422,6 +456,59 @@ func (p *Planner) planRules(rules []*ast.Rule) (string, error) {
return fn.Name, nil
}

func (p *Planner) planDotOr(obj ir.Local, key ir.Operand, or stmtFactory, iter planLocalIter) error {
// We're constructing the following plan:
//
// | block a
// | | block b
// | | | dot &{Source:Local<obj> Key:{Value:Local<key>} Target:Local<val>}
// | | | break 1
// | | or &{Target:Local<val>}
// | | *ir.ObjectInsertOnceStmt &{Key:{Value:Local<key>} Value:{Value:Local<val>} Object:Local<obj>}

prev := p.curr
dotBlock := &ir.Block{}
p.curr = dotBlock

val := p.newLocal()
p.appendStmt(&ir.DotStmt{
Source: op(obj),
Key: key,
Target: val,
})
p.appendStmt(&ir.BreakStmt{Index: 1})

outerBlock := &ir.Block{
Stmts: []ir.Stmt{
&ir.BlockStmt{Blocks: []*ir.Block{dotBlock}}, // FIXME: Set Location
or(val),
&ir.ObjectInsertOnceStmt{Key: key, Value: op(val), Object: obj},
},
}

p.curr = prev
p.appendStmt(&ir.BlockStmt{Blocks: []*ir.Block{outerBlock}})
return iter(val)
}

func (p *Planner) planNestedObjects(obj ir.Local, ref ast.Ref, iter planLocalIter) error {
if len(ref) == 0 {
//return fmt.Errorf("nested object construction didn't create object")
return iter(obj)
}

t := ref[0]

return p.planTerm(t, func() error {
key := p.ltarget

factory := func(v ir.Local) ir.Stmt { return &ir.MakeObjectStmt{Target: v} }
return p.planDotOr(obj, key, factory, func(childObj ir.Local) error {
return p.planNestedObjects(childObj, ref[1:], iter)
})
})
}

func (p *Planner) planFuncParams(params []ir.Local, args ast.Args, idx int, iter planiter) error {
if idx >= len(args) {
return iter()
Expand Down
41 changes: 41 additions & 0 deletions internal/planner/planner_test.go
Expand Up @@ -16,6 +16,7 @@ import (
)

func TestPlannerHelloWorld(t *testing.T) {
t.Setenv("EXPERIMENTAL_GENERAL_RULE_REFS", "true")

// NOTE(tsandall): These tests are not meant to give comprehensive coverage
// of the planner. Currently we have a suite of end-to-end tests in the
Expand Down Expand Up @@ -151,6 +152,46 @@ func TestPlannerHelloWorld(t *testing.T) {
p[v] = 2 { v := "b" }
`},
},
{
note: "partial object (ref-head) with var",
queries: []string{`data.test.p.q = x`},
modules: []string{`
package test
p.q.r["a"] = 1
p.q[v] = 2 { v := "b" }
`},
},
{
note: "partial object (ref-head) with var (shallow query)",
queries: []string{`data.test.p = x`},
modules: []string{`
package test
p.q["a"] = 1
p.q[v] = 2 { v := "b" }
p.r["c"] = 3
p.r[v] = 4 { v := "d" }
`},
},
{
note: "partial object (ref-head) with var (multiple)",
queries: []string{`data.test.p.q = x`},
modules: []string{`
package test
p.q["a"] = 1
p.q[v] = x { l1 := ["b", "c", "d"]; l2 := ["foo", "bar"]; l3 := [2, 3]; v := l1[_]; x := l2[_]; z := l3[_] }
`},
},
{
note: "partial object (general ref-head) with var",
queries: []string{`data.test.p.q = x`},
modules: []string{`
package test
p.q["a"] = 1
p.q.b.s.baz = 2
p.q.b.s.foo.c = 3
p.q[r].s[t].u = v { x := ["foo", "bar"]; r := "b"; t := x[v]}
`},
},
{
note: "every",
queries: []string{`data.test.p`},
Expand Down
36 changes: 35 additions & 1 deletion internal/planner/rules.go
Expand Up @@ -98,6 +98,7 @@ func (t *ruletrie) Rules() []*ast.Rule {
//
// and we're retrieving a.b, we want Rules() to include the rule body
// of a.b.c.
// FIXME: We need to go deeper than just immediate children (?)
for _, rs := range t.children {
if r := rs[len(rs)-1].rules; r != nil {
rules = append(rules, r...)
Expand Down Expand Up @@ -157,13 +158,46 @@ func (t *ruletrie) Lookup(key ast.Ref) *ruletrie {
return node
}

func (t *ruletrie) LookupShallowest(key ast.Ref) *ruletrie {
node := t
for _, elem := range key {
node = node.Get(elem.Value)
if node == nil {
return nil
}
if len(node.rules) > 0 {
return node
}
}
return node
}

// TODO: Collapse rules with overlapping extent to same node(?)
func (t *ruletrie) LookupOrInsert(key ast.Ref) *ruletrie {
if val := t.Lookup(key); val != nil {
if val := t.LookupShallowest(key); val != nil {

return val
}
return t.Insert(key)
}

func (t *ruletrie) DescendantRules() []*ast.Rule {
if len(t.children) == 0 {
return t.rules
}

rules := make([]*ast.Rule, len(t.rules), len(t.rules)+len(t.children)) // could be too little
copy(rules, t.rules)

for _, cs := range t.children {
for _, c := range cs {
rules = append(rules, c.DescendantRules()...)
}
}

return rules
}

func (t *ruletrie) ChildrenCount() int {
return len(t.children)
}
Expand Down
12 changes: 1 addition & 11 deletions internal/wasm/sdk/test/e2e/exceptions.yaml
Expand Up @@ -2,14 +2,4 @@
"data/toplevel integer": "https://github.com/open-policy-agent/opa/issues/3711"
"data/nested integer": "https://github.com/open-policy-agent/opa/issues/3711"
"withkeyword/function: indirect call, arity 1, replacement is value that needs eval (array comprehension)": "https://github.com/open-policy-agent/opa/issues/5311"
"withkeyword/builtin: indirect call, arity 1, replacement is value that needs eval (array comprehension)": "https://github.com/open-policy-agent/opa/issues/5311"
"refheads/general, single var": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet"
"refheads/general, multiple vars": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet"
"refheads/general, deep query": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet"
"refheads/general, overlapping rule, no conflict": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet"
"refheads/general, overlapping rule, conflict": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet"
"refheads/general, set leaf": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet"
"refheads/general, set leaf, deep query": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet"
"refheads/general, input var": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet"
"refheads/general, external non-ground var": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet"
"refheads/general, multiple result-set entries": "Tests with arbitrary vars in rule refs (general refs) are not supported by the planner yet"
"withkeyword/builtin: indirect call, arity 1, replacement is value that needs eval (array comprehension)": "https://github.com/open-policy-agent/opa/issues/5311"
4 changes: 4 additions & 0 deletions internal/wasm/sdk/test/e2e/external_test.go
Expand Up @@ -65,6 +65,10 @@ func TestWasmE2E(t *testing.T) {
t.SkipNow()
}

for k, v := range tc.Env {
t.Setenv(k, v)
}

opts := []func(*rego.Rego){
rego.Query(tc.Query),
}
Expand Down
@@ -0,0 +1,22 @@
cases:
- note: partialobjectdoc/ref
modules:
- |
package generated
p.q[k] = v {
k := ["foo", "bar"][v]
}
p.baz := 2
q {
x := "bar"
y := "q"
p[y][x] == 1
}
query: data.generated.q = x
want_result:
- x:
true

0 comments on commit c9d1a8d

Please sign in to comment.