Skip to content

Commit

Permalink
compile: Find dependents of entrypoints and compile them
Browse files Browse the repository at this point in the history
Signed-off-by: Torin Sandall <torinsandall@gmail.com>
  • Loading branch information
tsandall committed Nov 6, 2020
1 parent 76f232c commit 0402892
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 5 deletions.
36 changes: 33 additions & 3 deletions compile/compile.go
Expand Up @@ -370,6 +370,32 @@ func (c *Compiler) compileWasm(ctx context.Context) error {
}
}

// Find transitive dependents of entrypoints and add them to the set to compile.
//
// NOTE(tsandall): We compile entrypoints because the evaluator does not support
// evaluation of wasm-compiled rules when 'with' statements are in-scope. Compiling
// out the dependents avoids the need to support that case for now.
deps := map[*ast.Rule]struct{}{}
for i := range c.entrypointrefs {
transitiveDocumentDependents(c.compiler, c.entrypointrefs[i], deps)
}

extras := ast.NewSet()
for rule := range deps {
extras.Add(ast.NewTerm(rule.Path()))
}

sorted := extras.Sorted()

for i := 0; i < sorted.Len(); i++ {
p, err := sorted.Elem(i).Value.(ast.Ref).Ptr()
if err != nil {
return err
}
c.entrypoints = append(c.entrypoints, p)
c.entrypointrefs = append(c.entrypointrefs, sorted.Elem(i))
}

// Create query sets for each of the entrypoints.
resultSym := ast.NewTerm(wasmResultVar)
queries := make([]planner.QuerySet, len(c.entrypointrefs))
Expand Down Expand Up @@ -641,9 +667,7 @@ func (o *optimizer) findRequiredDocuments(ref *ast.Term) []string {
keep := map[string]*ast.Location{}
deps := map[*ast.Rule]struct{}{}

for _, r := range o.compiler.GetRules(ref.Value.(ast.Ref)) {
transitiveDependents(o.compiler, r, deps)
}
transitiveDocumentDependents(o.compiler, ref, deps)

for rule := range deps {
ast.WalkExprs(rule, func(expr *ast.Expr) bool {
Expand Down Expand Up @@ -840,6 +864,12 @@ func compile(c *ast.Capabilities, b *bundle.Bundle) (*ast.Compiler, error) {
return compiler, nil
}

func transitiveDocumentDependents(compiler *ast.Compiler, ref *ast.Term, deps map[*ast.Rule]struct{}) {
for _, rule := range compiler.GetRules(ref.Value.(ast.Ref)) {
transitiveDependents(compiler, rule, deps)
}
}

func transitiveDependents(compiler *ast.Compiler, rule *ast.Rule, deps map[*ast.Rule]struct{}) {
for x := range compiler.Graph.Dependents(rule) {
other := x.(*ast.Rule)
Expand Down
52 changes: 50 additions & 2 deletions compile/compile_test.go
Expand Up @@ -429,10 +429,10 @@ func TestCompilerWasmTargetMultipleEntrypoints(t *testing.T) {
p = true`,
"policy.rego": `package policy
authz = true`,
"mask.rego": `package system.log
mask["/input/password"]`,
}

Expand Down Expand Up @@ -470,6 +470,54 @@ func TestCompilerWasmTargetMultipleEntrypoints(t *testing.T) {
})
}

func TestCompilerWasmTargetEntrypointDependents(t *testing.T) {
files := map[string]string{
"test.rego": `package test
p { q }
q { r }
r = 1
s = 2`}

test.WithTempFS(files, func(root string) {

compiler := New().WithPaths(root).WithTarget("wasm").WithEntrypoints("test/r")
err := compiler.Build(context.Background())
if err != nil {
t.Fatal(err)
}

if len(compiler.bundle.WasmModules) != 1 {
t.Fatalf("expected 1 Wasm modules, got: %d", len(compiler.bundle.WasmModules))
}

expManifest := bundle.Manifest{}
expManifest.Init()
expManifest.WasmResolvers = []bundle.WasmResolver{
{
Entrypoint: "test/r",
Module: "/policy.wasm",
},
{
Entrypoint: "test/p",
Module: "/policy.wasm",
},
{
Entrypoint: "test/q",
Module: "/policy.wasm",
},
}

if !compiler.bundle.Manifest.Equal(expManifest) {
t.Fatalf("\nExpected manifest: %+v\nGot: %+v\n", expManifest, compiler.bundle.Manifest)
}

ensureEntrypointRemoved(t, compiler.bundle, "test/p")
ensureEntrypointRemoved(t, compiler.bundle, "test/q")
ensureEntrypointRemoved(t, compiler.bundle, "test/r")
})
}

func TestCompilerWasmTargetLazyCompile(t *testing.T) {
files := map[string]string{
"test.rego": `package test
Expand Down
11 changes: 11 additions & 0 deletions topdown/resolver.go
Expand Up @@ -49,6 +49,9 @@ func (t *resolverTrie) Resolve(e *eval, ref ast.Ref) (ast.Value, error) {
Metrics: e.metrics,
}
e.traceWasm(e.query[e.index], &in.Ref)
if e.data != nil {
return nil, errInScopeWithStmt
}
result, err := node.r.Eval(e.ctx, in)
if err != nil {
return nil, err
Expand All @@ -73,6 +76,9 @@ func (t *resolverTrie) Resolve(e *eval, ref ast.Ref) (ast.Value, error) {
func (t *resolverTrie) mktree(e *eval, in resolver.Input) (ast.Value, error) {
if t.r != nil {
e.traceWasm(e.query[e.index], &in.Ref)
if e.data != nil {
return nil, errInScopeWithStmt
}
result, err := t.r.Eval(e.ctx, in)
if err != nil {
return nil, err
Expand All @@ -94,3 +100,8 @@ func (t *resolverTrie) mktree(e *eval, in resolver.Input) (ast.Value, error) {
}
return obj, nil
}

var errInScopeWithStmt = &Error{
Code: InternalErr,
Message: "wasm cannot be executed when 'with' statements are in-scope",
}

0 comments on commit 0402892

Please sign in to comment.