Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

deps: Improving deps command performance #6688

Merged
merged 6 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 68 additions & 21 deletions dependencies/deps.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,38 +89,40 @@ func Minimal(x interface{}) (resolved []ast.Ref, err error) {
// The returned refs are always constant and are truncated at any point where they become
// dynamic. That is, a ref like data.a.b[x] will be truncated to data.a.b.
func Base(compiler *ast.Compiler, x interface{}) ([]ast.Ref, error) {
baseRefs, err := base(compiler, x)
baseRefs := newRefSet()
err := base(compiler, x, baseRefs)
if err != nil {
return nil, err
}

return dedup(baseRefs), nil
return dedup(baseRefs.toSlice()), nil
}

func base(compiler *ast.Compiler, x interface{}) ([]ast.Ref, error) {
func base(compiler *ast.Compiler, x interface{}, baseRefs *dependencies) error {
refs, err := Minimal(x)
if err != nil {
return nil, err
return err
}

var baseRefs []ast.Ref
for _, r := range refs {
r = r.ConstantPrefix()
if rules := compiler.GetRules(r); len(rules) > 0 {
for _, rule := range rules {
bases, err := base(compiler, rule)
if err != nil {
if baseRefs.visited(rule) {
continue
}
baseRefs.visit(rule)
if err := base(compiler, rule, baseRefs); err != nil {
panic("not reached")
}

baseRefs = append(baseRefs, bases...)
}
} else {
baseRefs = append(baseRefs, r)
baseRefs.add(r)
}
}

return baseRefs, nil
return nil
}

// Virtual returns the list of virtual data documents that the given AST element depends
Expand All @@ -129,37 +131,82 @@ func base(compiler *ast.Compiler, x interface{}) ([]ast.Ref, error) {
// The returned refs are always constant and are truncated at any point where they become
// dynamic. That is, a ref like data.a.b[x] will be truncated to data.a.b.
func Virtual(compiler *ast.Compiler, x interface{}) ([]ast.Ref, error) {
virtualRefs, err := virtual(compiler, x)
virtualRefs := newRefSet()
err := virtual(compiler, x, virtualRefs)
if err != nil {
return nil, err
}

return dedup(virtualRefs), nil
return dedup(virtualRefs.toSlice()), nil
}

func virtual(compiler *ast.Compiler, x interface{}) ([]ast.Ref, error) {
func virtual(compiler *ast.Compiler, x interface{}, virtualRefs *dependencies) error {
refs, err := Minimal(x)
if err != nil {
return nil, err
return err
}

var virtualRefs []ast.Ref
for _, r := range refs {
r = r.ConstantPrefix()
if rules := compiler.GetRules(r); len(rules) > 0 {
for _, rule := range rules {
virtuals, err := virtual(compiler, rule)
if virtualRefs.visited(rule) {
continue
}
virtualRefs.visit(rule)
err := virtual(compiler, rule, virtualRefs)
if err != nil {
panic("not reached")
}

virtualRefs = append(virtualRefs, rule.Path())
virtualRefs = append(virtualRefs, virtuals...)
virtualRefs.add(rule.Path())
}
}
}

return virtualRefs, nil
return nil
}

// dependencies is the accumulator threaded through the recursive base and
// virtual walks. It collects the refs found so far and remembers which rules
// have already been expanded so that a rule is not walked more than once.
type dependencies struct {
// refs holds the collected ast.Ref values; each ref is stored as both key
// and value, so the map behaves as a set.
refs *util.HashMap
// visitedRules holds the *ast.Rule values that have already been expanded;
// each rule is stored as both key and value.
visitedRules *util.HashMap
}

// newRefSet returns an empty dependencies accumulator.
//
// Both internal hash maps are created with type-specific equality and hash
// callbacks: refs are compared with ast.Ref.Equal and hashed with
// ast.Ref.Hash, while rules are compared with (*ast.Rule).Equal and hashed by
// the hash of the rule's ref.
func newRefSet() *dependencies {
	refsEqual := func(a, b util.T) bool {
		return a.(ast.Ref).Equal(b.(ast.Ref))
	}
	refHash := func(a util.T) int {
		return a.(ast.Ref).Hash()
	}

	rulesEqual := func(a, b util.T) bool {
		return a.(*ast.Rule).Equal(b.(*ast.Rule))
	}
	// Rules are bucketed by the hash of their ref; rulesEqual then
	// distinguishes individual rules that land in the same bucket.
	ruleHash := func(a util.T) int {
		return a.(*ast.Rule).Ref().Hash()
	}

	return &dependencies{
		refs:         util.NewHashMap(refsEqual, refHash),
		visitedRules: util.NewHashMap(rulesEqual, ruleHash),
	}
}

// add records r in the set of collected refs. The ref is stored as both key
// and value of the underlying hash map.
func (rs *dependencies) add(r ast.Ref) {
rs.refs.Put(r, r)
}

// visit marks rule as already expanded, so that a later call to visited with
// an equal rule reports true.
func (rs *dependencies) visit(rule *ast.Rule) {
rs.visitedRules.Put(rule, rule)
}

// visited reports whether rule is present in the visited-rule map, i.e.
// whether it has previously been recorded through visit.
func (rs *dependencies) visited(rule *ast.Rule) bool {
	if _, seen := rs.visitedRules.Get(rule); seen {
		return true
	}
	return false
}

// toSlice returns the collected refs as a slice, in whatever order the
// underlying hash map yields them. The result is nil when nothing was
// appended.
func (rs *dependencies) toSlice() []ast.Ref {
	var collected []ast.Ref

	// NOTE(review): returning false presumably tells Iter to keep going;
	// confirm against util.HashMap.Iter.
	appendKey := func(key, _ util.T) bool {
		collected = append(collected, key.(ast.Ref))
		return false
	}
	rs.refs.Iter(appendKey)

	return collected
}

func dedup(refs []ast.Ref) []ast.Ref {
Expand All @@ -172,9 +219,9 @@ func dedup(refs []ast.Ref) []ast.Ref {
})
}

// filter removes all items from the list that cause pref to return true. It is
// filter removes all items from the list that cause pred to return true. It is
// called on adjacent pairs of elements, and the one passed as the second argument
// to pref is considered the current one being examined. The first argument will
// to pred is considered the current one being examined. The first argument will
// be the element immediately preceding it.
func filter(rs []ast.Ref, pred func(ast.Ref, ast.Ref) bool) (filtered []ast.Ref) {
if len(rs) == 0 {
Expand Down
80 changes: 80 additions & 0 deletions dependencies/deps_bench_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package dependencies

import (
"fmt"
"strings"
"testing"

"github.com/open-policy-agent/opa/ast"
)

// BenchmarkBase measures Base on generated policies of increasing size.
//
// For each rule count a policy is generated with makePolicy, parsed and
// compiled once (outside the timed region), and Base is then invoked b.N
// times for the ref data.test.main. Running the call inside the b.N loop is
// what lets the testing framework report a meaningful per-operation cost.
func BenchmarkBase(b *testing.B) {
	ruleCounts := []int{10, 20, 50}
	for _, ruleCount := range ruleCounts {
		b.Run(fmt.Sprint(ruleCount), func(b *testing.B) {
			policy := makePolicy(ruleCount)
			module := ast.MustParseModule(policy)
			compiler := ast.NewCompiler()
			if compiler.Compile(map[string]*ast.Module{"test": module}); compiler.Failed() {
				b.Fatalf("Failed to compile policy: %v", compiler.Errors)
			}

			ref := ast.MustParseRef("data.test.main")

			// Exclude policy generation and compilation from the measurement.
			b.ResetTimer()

			for i := 0; i < b.N; i++ {
				if _, err := Base(compiler, ref); err != nil {
					b.Fatalf("Failed to compute base doc deps: %v", err)
				}
			}
		})
	}
}

// BenchmarkVirtual measures Virtual on generated policies of increasing size.
//
// For each rule count a policy is generated with makePolicy, parsed and
// compiled once (outside the timed region), and Virtual is then invoked b.N
// times for the ref data.test.main. Running the call inside the b.N loop is
// what lets the testing framework report a meaningful per-operation cost.
func BenchmarkVirtual(b *testing.B) {
	ruleCounts := []int{10, 20, 50}
	for _, ruleCount := range ruleCounts {
		b.Run(fmt.Sprint(ruleCount), func(b *testing.B) {
			policy := makePolicy(ruleCount)
			module := ast.MustParseModule(policy)
			compiler := ast.NewCompiler()
			if compiler.Compile(map[string]*ast.Module{"test": module}); compiler.Failed() {
				b.Fatalf("Failed to compile policy: %v", compiler.Errors)
			}

			ref := ast.MustParseRef("data.test.main")

			// Exclude policy generation and compilation from the measurement.
			b.ResetTimer()

			for i := 0; i < b.N; i++ {
				if _, err := Virtual(compiler, ref); err != nil {
					b.Fatalf("Failed to compute virtual doc deps: %v", err)
				}
			}
		})
	}
}

// makePolicy returns the source text of a Rego module in package "test".
//
// The module has a "main" rule that references p_0 .. p_{ruleCount-1},
// followed by one rule per index. Rule p_i references every rule with a
// higher index (p_{i+1} .. p_{ruleCount-1}) and then three fixed input
// comparisons, so references only ever point forward and no cycle is formed.
// A ruleCount of zero or less yields a module containing only an empty main.
func makePolicy(ruleCount int) string {
	// Fixed tail appended to every generated p_i rule body.
	const inputChecks = " input.x == 1\n" +
		" input.y == 2\n" +
		" input.z == 3\n"

	var sb strings.Builder
	sb.WriteString("package test\n\n")

	// Entry-point rule referencing every generated rule.
	sb.WriteString("main {\n")
	for idx := 0; idx < ruleCount; idx++ {
		fmt.Fprintf(&sb, " p_%d\n", idx)
	}
	sb.WriteString("}\n\n")

	for idx := 0; idx < ruleCount; idx++ {
		fmt.Fprintf(&sb, "p_%d {\n", idx)
		for dep := idx + 1; dep < ruleCount; dep++ {
			fmt.Fprintf(&sb, " p_%d\n", dep)
		}
		sb.WriteString(inputChecks)
		sb.WriteString("}\n")
	}

	return sb.String()
}
Loading