Skip to content

Commit

Permalink
topdown/copypropagation: keep refs into livevars (#4936)
Browse files Browse the repository at this point in the history
Before, a query of

    input.a == input.a

would not survive copypropagation.

With this change, it'll be recorded as removedEq, and subsequent processing
steps ensure that it's kept in the body.

Changing the sort order in sortBindings allows us to limit the unnecessary
variable bindings: with the previous ordering, we'd get

    __local0__1 = input; __localcp0__ = input.a

for the query `x := input; input.a == input.a`. Sorting the other way, we'll
process `__localcp0__ = input.a` first, add it to the body, and when we check
`__local0__1 = input`, we find that `input` is already contained in the body,
and is thus not needed.

Fixes #4848.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus committed Jul 26, 2022
1 parent 7f78653 commit 1c1957c
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 43 deletions.
7 changes: 4 additions & 3 deletions rego/rego.go
Expand Up @@ -2147,9 +2147,10 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries,

var unknowns []*ast.Term

if ectx.parsedUnknowns != nil {
switch {
case ectx.parsedUnknowns != nil:
unknowns = ectx.parsedUnknowns
} else if ectx.unknowns != nil {
case ectx.unknowns != nil:
unknowns = make([]*ast.Term, len(ectx.unknowns))
for i := range ectx.unknowns {
var err error
Expand All @@ -2158,7 +2159,7 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries,
return nil, err
}
}
} else {
default:
// Use input document as unknown if caller has not specified any.
unknowns = []*ast.Term{ast.NewTerm(ast.InputRootRef)}
}
Expand Down
58 changes: 47 additions & 11 deletions topdown/copypropagation/copypropagation.go
Expand Up @@ -5,6 +5,7 @@
package copypropagation

import (
"fmt"
"sort"

"github.com/open-policy-agent/opa/ast"
Expand All @@ -31,6 +32,18 @@ type CopyPropagator struct {
sorted []ast.Var // sorted copy of vars to ensure deterministic result
ensureNonEmptyBody bool
compiler *ast.Compiler
localvargen *localVarGenerator
}

type localVarGenerator struct {
next int
}

func (l *localVarGenerator) Generate() ast.Var {
result := ast.Var(fmt.Sprintf("__localcp%d__", l.next))
l.next++
return result

}

// New returns a new CopyPropagator that optimizes queries while preserving vars
Expand All @@ -46,7 +59,7 @@ func New(livevars ast.VarSet) *CopyPropagator {
return sorted[i].Compare(sorted[j]) < 0
})

return &CopyPropagator{livevars: livevars, sorted: sorted}
return &CopyPropagator{livevars: livevars, sorted: sorted, localvargen: &localVarGenerator{}}
}

// WithEnsureNonEmptyBody configures p to ensure that results are always non-empty.
Expand Down Expand Up @@ -282,12 +295,16 @@ func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast.
// updateBindings returns false if the expression can be killed. If the
// expression is killed, the binding list is updated to map a var to value.
func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool {
if pctx.negated || len(expr.With) > 0 {
switch {
case pctx.negated || len(expr.With) > 0:
return true
}
if expr.IsEquality() {

case expr.IsEquality():
a, b := expr.Operand(0), expr.Operand(1)
if a.Equal(b) {
if p.livevarRef(a) {
pctx.removedEqs.Put(p.localvargen.Generate(), a.Value)
}
return false
}
k, v, keep := p.updateBindingsEq(a, b)
Expand All @@ -297,7 +314,8 @@ func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool
}
return false
}
} else if expr.IsCall() {

case expr.IsCall():
terms := expr.Terms.([]*ast.Term)
if p.compiler.GetArity(expr.Operator()) == len(terms)-2 { // with captured output
output := terms[len(terms)-1]
Expand All @@ -310,6 +328,21 @@ func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool
return !isNoop(expr)
}

func (p *CopyPropagator) livevarRef(a *ast.Term) bool {
ref, ok := a.Value.(ast.Ref)
if !ok {
return false
}

for _, v := range p.sorted {
if ref[0].Value.Compare(v) == 0 {
return true
}
}

return false
}

func (p *CopyPropagator) updateBindingsEq(a, b *ast.Term) (ast.Var, ast.Value, bool) {
k, v, keep := p.updateBindingsEqAsymmetric(a, b)
if !keep {
Expand Down Expand Up @@ -340,8 +373,7 @@ type plugContext struct {
}

type binding struct {
k ast.Value
v ast.Value
k, v ast.Value
}

func containedIn(value ast.Value, x interface{}) bool {
Expand Down Expand Up @@ -374,7 +406,7 @@ func sortbindings(bindings *ast.ValueMap) []*binding {
return false
})
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].k.Compare(sorted[j].k) < 0
return sorted[i].k.Compare(sorted[j].k) > 0
})
return sorted
}
Expand All @@ -397,17 +429,21 @@ func makeDisjointSets(livevars ast.VarSet, query ast.Body) (*unionFind, bool) {
a, b := expr.Operand(0), expr.Operand(1)
varA, ok1 := a.Value.(ast.Var)
varB, ok2 := b.Value.(ast.Var)
if ok1 && ok2 {

switch {
case ok1 && ok2:
if _, ok := uf.Merge(varA, varB); !ok {
return nil, false
}
} else if ok1 && ast.IsConstant(b.Value) {

case ok1 && ast.IsConstant(b.Value):
root := uf.MakeSet(varA)
if root.constant != nil && !root.constant.Equal(b) {
return nil, false
}
root.constant = b
} else if ok2 && ast.IsConstant(a.Value) {

case ok2 && ast.IsConstant(a.Value):
root := uf.MakeSet(varB)
if root.constant != nil && !root.constant.Equal(a) {
return nil, false
Expand Down
8 changes: 8 additions & 0 deletions topdown/query.go
Expand Up @@ -340,6 +340,14 @@ func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support []
defer q.metrics.Timer(metrics.RegoPartialEval).Stop()

livevars := ast.NewVarSet()
for _, t := range q.unknowns {
switch v := t.Value.(type) {
case ast.Var:
livevars.Add(v)
case ast.Ref:
livevars.Add(v[0].Value.(ast.Var))
}
}

ast.WalkVars(q.query, func(x ast.Var) bool {
if !x.IsGenerated() {
Expand Down
129 changes: 100 additions & 29 deletions topdown/topdown_partial_test.go
Expand Up @@ -1826,6 +1826,100 @@ func TestTopDownPartialEval(t *testing.T) {
),
},
},
{
note: "copy propagation: circular reference (bug 3559)",
query: "data.test.p",
modules: []string{`package test
p {
q[_]
}
q[x] {
x = input[x]
}`,
},
wantQueries: []string{`x_term_1_01; x_term_1_01 = input[x_term_1_01]`},
},
{
note: "copy propagation: circular reference (bug 3071)",
query: "data.test.p",
modules: []string{`package test
p[y] {
s := { i | input[i] }
s & set() != s
y := sprintf("%v", [s])
}`,
},
wantQueries: []string{`data.partial.test.p`},
wantSupport: []string{`package partial.test
p[__local1__1] { __local0__1 = {i1 | input[i1]}; neq(and(__local0__1, set()), __local0__1); sprintf("%v", [__local0__1], __local1__1) }
`},
},
{
note: "copy propagation: tautology in query, input ref",
query: "input.a == input.a",
wantQueries: []string{`__localq1__ = input.a`},
},
{
note: "copy propagation: tautology in query, var ref, var is input",
query: "x := input; x.a == x.a",
wantQueries: []string{`__localq2__ = input.a`},
},
{
note: "copy propagation: tautology, input ref",
query: "data.test.p",
modules: []string{`package test
p {
input.a == input.a
}`,
},
wantQueries: []string{`__localcp0__ = input.a`},
},
{
note: "copy propagation: tautology, var ref, ref is input",
query: "data.test.p",
modules: []string{`package test
p {
x := input
x.a == x.a
}`,
},
wantQueries: []string{`__localcp0__ = input.a`},
},
{
note: "copy propagation: tautology, var ref, ref is unknown data",
query: "data.test.p",
unknowns: []string{"data.bar.foo"},
modules: []string{`package test
p {
data.bar.foo.a == data.bar.foo.a
}`,
},
wantQueries: []string{`__localcp0__ = data.bar.foo.a`},
},
{
note: "copy propagation: tautology, var ref, ref is input, via unknown",
// NOTE(sr): If we were having unkowns: [input.foo] and the rule body was
// input.a == input.a, we'd never reach copy-propagation -- partial eval would
// have failed before.
query: "data.test.p",
unknowns: []string{"input"},
modules: []string{`package test
p {
input.foo.a == input.foo.a
}`,
},
wantQueries: []string{`__localcp0__ = input.foo.a`},
},
{
note: "copy propagation: tautology, var ref, ref is head var",
query: "data.test.p(input)",
modules: []string{`package test
p(x) {
x.a == x.a
}`,
},
wantQueries: []string{`__localcp1__ = input.a`},
},
{
note: "save set vars are namespaced",
query: "input = x; data.test.f(1)",
Expand Down Expand Up @@ -2985,7 +3079,12 @@ func TestTopDownPartialEval(t *testing.T) {
x = true
}`,
},
wantQueries: []string{"a1 = input.foo1; b1 = input.foo2; c1 = input.foo3; d1 = input.foo4; e1 = input.foo5"},
wantQueries: []string{`
e1 = input.foo5
d1 = input.foo4
c1 = input.foo3
b1 = input.foo2
a1 = input.foo1`},
},
{
note: "partial object rules not memoized",
Expand Down Expand Up @@ -3054,34 +3153,6 @@ func TestTopDownPartialEval(t *testing.T) {
shallow: true,
skipPartialNamespace: true,
},
{
note: "copypropagation: circular reference (bug 3559)",
query: "data.test.p",
modules: []string{`package test
p {
q[_]
}
q[x] {
x = input[x]
}`,
},
wantQueries: []string{`x_term_1_01; x_term_1_01 = input[x_term_1_01]`},
},
{
note: "copypropagation: circular reference (bug 3071)",
query: "data.test.p",
modules: []string{`package test
p[y] {
s := { i | input[i] }
s & set() != s
y := sprintf("%v", [s])
}`,
},
wantQueries: []string{`data.partial.test.p`},
wantSupport: []string{`package partial.test
p[__local1__1] { __local0__1 = {i1 | input[i1]}; neq(and(__local0__1, set()), __local0__1); sprintf("%v", [__local0__1], __local1__1) }
`},
},
{
note: "every: empty domain, no unknowns",
query: "data.test.p",
Expand Down

0 comments on commit 1c1957c

Please sign in to comment.