Skip to content

Commit

Permalink
deps: Improving deps command performance (#6688)
Browse files Browse the repository at this point in the history
Improving memory footprint and execution time of deps command for policies with high dependency connectivity.

Fixes: #6685
Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling committed Apr 9, 2024
1 parent a94e585 commit e7e5b6b
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 21 deletions.
89 changes: 68 additions & 21 deletions dependencies/deps.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,38 +89,40 @@ func Minimal(x interface{}) (resolved []ast.Ref, err error) {
// The returned refs are always constant and are truncated at any point where they become
// dynamic. That is, a ref like data.a.b[x] will be truncated to data.a.b.
func Base(compiler *ast.Compiler, x interface{}) ([]ast.Ref, error) {
baseRefs, err := base(compiler, x)
baseRefs := newRefSet()
err := base(compiler, x, baseRefs)
if err != nil {
return nil, err
}

return dedup(baseRefs), nil
return dedup(baseRefs.toSlice()), nil
}

func base(compiler *ast.Compiler, x interface{}) ([]ast.Ref, error) {
func base(compiler *ast.Compiler, x interface{}, baseRefs *dependencies) error {
refs, err := Minimal(x)
if err != nil {
return nil, err
return err
}

var baseRefs []ast.Ref
for _, r := range refs {
r = r.ConstantPrefix()
if rules := compiler.GetRules(r); len(rules) > 0 {
for _, rule := range rules {
bases, err := base(compiler, rule)
if err != nil {
if baseRefs.visited(rule) {
continue
}
baseRefs.visit(rule)
if err := base(compiler, rule, baseRefs); err != nil {
panic("not reached")
}

baseRefs = append(baseRefs, bases...)
}
} else {
baseRefs = append(baseRefs, r)
baseRefs.add(r)
}
}

return baseRefs, nil
return nil
}

// Virtual returns the list of virtual data documents that the given AST element depends
Expand All @@ -129,37 +131,82 @@ func base(compiler *ast.Compiler, x interface{}) ([]ast.Ref, error) {
// The returned refs are always constant and are truncated at any point where they become
// dynamic. That is, a ref like data.a.b[x] will be truncated to data.a.b.
func Virtual(compiler *ast.Compiler, x interface{}) ([]ast.Ref, error) {
virtualRefs, err := virtual(compiler, x)
virtualRefs := newRefSet()
err := virtual(compiler, x, virtualRefs)
if err != nil {
return nil, err
}

return dedup(virtualRefs), nil
return dedup(virtualRefs.toSlice()), nil
}

func virtual(compiler *ast.Compiler, x interface{}) ([]ast.Ref, error) {
func virtual(compiler *ast.Compiler, x interface{}, virtualRefs *dependencies) error {
refs, err := Minimal(x)
if err != nil {
return nil, err
return err
}

var virtualRefs []ast.Ref
for _, r := range refs {
r = r.ConstantPrefix()
if rules := compiler.GetRules(r); len(rules) > 0 {
for _, rule := range rules {
virtuals, err := virtual(compiler, rule)
if virtualRefs.visited(rule) {
continue
}
virtualRefs.visit(rule)
err := virtual(compiler, rule, virtualRefs)
if err != nil {
panic("not reached")
}

virtualRefs = append(virtualRefs, rule.Path())
virtualRefs = append(virtualRefs, virtuals...)
virtualRefs.add(rule.Path())
}
}
}

return virtualRefs, nil
return nil
}

type dependencies struct {
refs *util.HashMap
visitedRules *util.HashMap
}

func newRefSet() *dependencies {
return &dependencies{
refs: util.NewHashMap(func(a, b util.T) bool {
return a.(ast.Ref).Equal(b.(ast.Ref))
}, func(a util.T) int {
return a.(ast.Ref).Hash()
}),
visitedRules: util.NewHashMap(func(a, b util.T) bool {
return a.(*ast.Rule).Equal(b.(*ast.Rule))
}, func(a util.T) int {
return a.(*ast.Rule).Ref().Hash()
}),
}
}

func (rs *dependencies) add(r ast.Ref) {
rs.refs.Put(r, r)
}

func (rs *dependencies) visit(rule *ast.Rule) {
rs.visitedRules.Put(rule, rule)
}

func (rs *dependencies) visited(rule *ast.Rule) bool {
_, found := rs.visitedRules.Get(rule)
return found
}

func (rs *dependencies) toSlice() []ast.Ref {
var result []ast.Ref
rs.refs.Iter(func(k, _ util.T) bool {
result = append(result, k.(ast.Ref))
return false
})
return result
}

func dedup(refs []ast.Ref) []ast.Ref {
Expand All @@ -172,9 +219,9 @@ func dedup(refs []ast.Ref) []ast.Ref {
})
}

// filter removes all items from the list that cause pref to return true. It is
// filter removes all items from the list that cause pred to return true. It is
// called on adjacent pairs of elements, and the one passed as the second argument
// to pref is considered the current one being examined. The first argument will
// to pred is considered the current one being examined. The first argument will
// be the element immediately preceding it.
func filter(rs []ast.Ref, pred func(ast.Ref, ast.Ref) bool) (filtered []ast.Ref) {
if len(rs) == 0 {
Expand Down
80 changes: 80 additions & 0 deletions dependencies/deps_bench_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package dependencies

import (
"fmt"
"strings"
"testing"

"github.com/open-policy-agent/opa/ast"
)

func BenchmarkBase(b *testing.B) {
ruleCounts := []int{10, 20, 50}
for _, ruleCount := range ruleCounts {
b.Run(fmt.Sprint(ruleCount), func(b *testing.B) {
policy := makePolicy(ruleCount)
module := ast.MustParseModule(policy)
compiler := ast.NewCompiler()
if compiler.Compile(map[string]*ast.Module{"test": module}); compiler.Failed() {
b.Fatalf("Failed to compile policy: %v", compiler.Errors)
}

ref := ast.MustParseRef("data.test.main")

b.ResetTimer()

_, err := Base(compiler, ref)
if err != nil {
b.Fatalf("Failed to compute base doc deps: %v", err)
}
})
}
}

func BenchmarkVirtual(b *testing.B) {
ruleCounts := []int{10, 20, 50}
for _, ruleCount := range ruleCounts {
b.Run(fmt.Sprint(ruleCount), func(b *testing.B) {
policy := makePolicy(ruleCount)
module := ast.MustParseModule(policy)
compiler := ast.NewCompiler()
if compiler.Compile(map[string]*ast.Module{"test": module}); compiler.Failed() {
b.Fatalf("Failed to compile policy: %v", compiler.Errors)
}

ref := ast.MustParseRef("data.test.main")

b.ResetTimer()

_, err := Virtual(compiler, ref)
if err != nil {
b.Fatalf("Failed to compute virtual doc deps: %v", err)
}
})
}
}

// makePolicy constructs a policy with ruleCount number of rules.
// Each rule will depend on as many other rules as possible without creating circular dependencies.
func makePolicy(ruleCount int) string {
var b strings.Builder
b.WriteString("package test\n\n")

b.WriteString("main {\n")
for i := 0; i < ruleCount; i++ {
b.WriteString(fmt.Sprintf(" p_%d\n", i))
}
b.WriteString("}\n\n")

for i := 0; i < ruleCount; i++ {
b.WriteString(fmt.Sprintf("p_%d {\n", i))
for j := i + 1; j < ruleCount; j++ {
b.WriteString(fmt.Sprintf(" p_%d\n", j))
}
b.WriteString(" input.x == 1\n")
b.WriteString(" input.y == 2\n")
b.WriteString(" input.z == 3\n")
b.WriteString("}\n")
}
return b.String()
}

0 comments on commit e7e5b6b

Please sign in to comment.