From 0402892d8ee5a223453ee41cf4bfad446f11aea2 Mon Sep 17 00:00:00 2001 From: Torin Sandall Date: Thu, 5 Nov 2020 17:58:21 -0500 Subject: [PATCH] compile: Find dependents of entrypoints and compile them Signed-off-by: Torin Sandall --- compile/compile.go | 36 +++++++++++++++++++++++++--- compile/compile_test.go | 52 +++++++++++++++++++++++++++++++++++++++-- topdown/resolver.go | 11 +++++++++ 3 files changed, 94 insertions(+), 5 deletions(-) diff --git a/compile/compile.go b/compile/compile.go index 2383f1cb13..af06ad3485 100644 --- a/compile/compile.go +++ b/compile/compile.go @@ -370,6 +370,32 @@ func (c *Compiler) compileWasm(ctx context.Context) error { } } + // Find transitive dependents of entrypoints and add them to the set to compile. + // + // NOTE(tsandall): We compile entrypoints because the evaluator does not support + // evaluation of wasm-compiled rules when 'with' statements are in-scope. Compiling + // out the dependents avoids the need to support that case for now. + deps := map[*ast.Rule]struct{}{} + for i := range c.entrypointrefs { + transitiveDocumentDependents(c.compiler, c.entrypointrefs[i], deps) + } + + extras := ast.NewSet() + for rule := range deps { + extras.Add(ast.NewTerm(rule.Path())) + } + + sorted := extras.Sorted() + + for i := 0; i < sorted.Len(); i++ { + p, err := sorted.Elem(i).Value.(ast.Ref).Ptr() + if err != nil { + return err + } + c.entrypoints = append(c.entrypoints, p) + c.entrypointrefs = append(c.entrypointrefs, sorted.Elem(i)) + } + // Create query sets for each of the entrypoints. resultSym := ast.NewTerm(wasmResultVar) queries := make([]planner.QuerySet, len(c.entrypointrefs)) @@ -641,9 +667,7 @@ func (o *optimizer) findRequiredDocuments(ref *ast.Term) []string { keep := map[string]*ast.Location{} deps := map[*ast.Rule]struct{}{} - for _, r := range o.compiler.GetRules(ref.Value.(ast.Ref)) { - transitiveDependents(o.compiler, r, deps) - } + transitiveDocumentDependents(o.compiler, ref, deps) for rule := range deps { ast.WalkExprs(rule, func(expr *ast.Expr) bool { @@ -840,6 +864,12 @@ func compile(c *ast.Capabilities, b *bundle.Bundle) (*ast.Compiler, error) { return compiler, nil } +func transitiveDocumentDependents(compiler *ast.Compiler, ref *ast.Term, deps map[*ast.Rule]struct{}) { + for _, rule := range compiler.GetRules(ref.Value.(ast.Ref)) { + transitiveDependents(compiler, rule, deps) + } +} + func transitiveDependents(compiler *ast.Compiler, rule *ast.Rule, deps map[*ast.Rule]struct{}) { for x := range compiler.Graph.Dependents(rule) { other := x.(*ast.Rule) diff --git a/compile/compile_test.go b/compile/compile_test.go index c67adc5706..d230b3af77 100644 --- a/compile/compile_test.go +++ b/compile/compile_test.go @@ -429,10 +429,10 @@ func TestCompilerWasmTargetMultipleEntrypoints(t *testing.T) { p = true`, "policy.rego": `package policy - + authz = true`, "mask.rego": `package system.log - + mask["/input/password"]`, } @@ -470,6 +470,54 @@ func TestCompilerWasmTargetMultipleEntrypoints(t *testing.T) { }) } +func TestCompilerWasmTargetEntrypointDependents(t *testing.T) { + files := map[string]string{ + "test.rego": `package test + + p { q } + q { r } + r = 1 + s = 2`} + + test.WithTempFS(files, func(root string) { + + compiler := New().WithPaths(root).WithTarget("wasm").WithEntrypoints("test/r") + err := compiler.Build(context.Background()) + if err != nil { + t.Fatal(err) + } + + if len(compiler.bundle.WasmModules) != 1 { + t.Fatalf("expected 1 Wasm modules, got: %d", len(compiler.bundle.WasmModules)) + } + + expManifest := bundle.Manifest{} + expManifest.Init() + expManifest.WasmResolvers = []bundle.WasmResolver{ + { + Entrypoint: "test/r", + Module: "/policy.wasm", + }, + { + Entrypoint: "test/p", + Module: "/policy.wasm", + }, + { + Entrypoint: "test/q", + Module: "/policy.wasm", + }, + } + + if !compiler.bundle.Manifest.Equal(expManifest) { + t.Fatalf("\nExpected manifest: %+v\nGot: %+v\n", expManifest, compiler.bundle.Manifest) + } + + ensureEntrypointRemoved(t, compiler.bundle, "test/p") + ensureEntrypointRemoved(t, compiler.bundle, "test/q") + ensureEntrypointRemoved(t, compiler.bundle, "test/r") + }) +} + func TestCompilerWasmTargetLazyCompile(t *testing.T) { files := map[string]string{ "test.rego": `package test diff --git a/topdown/resolver.go b/topdown/resolver.go index 50ed33880c..5ed6c1e443 100644 --- a/topdown/resolver.go +++ b/topdown/resolver.go @@ -49,6 +49,9 @@ func (t *resolverTrie) Resolve(e *eval, ref ast.Ref) (ast.Value, error) { Metrics: e.metrics, } e.traceWasm(e.query[e.index], &in.Ref) + if e.data != nil { + return nil, errInScopeWithStmt + } result, err := node.r.Eval(e.ctx, in) if err != nil { return nil, err @@ -73,6 +76,9 @@ func (t *resolverTrie) Resolve(e *eval, ref ast.Ref) (ast.Value, error) { func (t *resolverTrie) mktree(e *eval, in resolver.Input) (ast.Value, error) { if t.r != nil { e.traceWasm(e.query[e.index], &in.Ref) + if e.data != nil { + return nil, errInScopeWithStmt + } result, err := t.r.Eval(e.ctx, in) if err != nil { return nil, err @@ -94,3 +100,8 @@ func (t *resolverTrie) mktree(e *eval, in resolver.Input) (ast.Value, error) { } return obj, nil } + +var errInScopeWithStmt = &Error{ + Code: InternalErr, + Message: "wasm cannot be executed when 'with' statements are in-scope", +}