diff --git a/ast/parser.go b/ast/parser.go index 98f9c56bd6..be4724a5e1 100644 --- a/ast/parser.go +++ b/ast/parser.go @@ -2570,6 +2570,11 @@ var futureKeywords = map[string]tokens.Token{ "if": tokens.If, } +func IsFutureKeyword(s string) bool { + _, ok := futureKeywords[s] + return ok +} + func (p *Parser) futureImport(imp *Import, allowedFutureKeywords map[string]tokens.Token) { path := imp.Path.Value.(Ref) diff --git a/ast/policy.go b/ast/policy.go index 270e9aaf72..051eccc1e6 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -407,6 +407,12 @@ func (mod *Module) RegoVersion() RegoVersion { return mod.regoVersion } +// SetRegoVersion sets the RegoVersion for the module. +// Note: Setting a rego-version that does not match the module's rego-version might have unintended consequences. +func (mod *Module) SetRegoVersion(v RegoVersion) { + mod.regoVersion = v +} + // NewComment returns a new Comment object. func NewComment(text []byte) *Comment { return &Comment{ diff --git a/bundle/bundle.go b/bundle/bundle.go index 09fd18d493..9c02568b53 100644 --- a/bundle/bundle.go +++ b/bundle/bundle.go @@ -1082,8 +1082,15 @@ func (b *Bundle) FormatModulesForRegoVersion(version ast.RegoVersion, preserveMo var err error for i, module := range b.Modules { + opts := format.Opts{} + if preserveModuleRegoVersion { + opts.RegoVersion = module.Parsed.RegoVersion() + } else { + opts.RegoVersion = version + } + if module.Raw == nil { - module.Raw, err = format.AstWithOpts(module.Parsed, format.Opts{RegoVersion: version}) + module.Raw, err = format.AstWithOpts(module.Parsed, opts) if err != nil { return err } @@ -1093,13 +1100,6 @@ func (b *Bundle) FormatModulesForRegoVersion(version ast.RegoVersion, preserveMo path = module.Path } - opts := format.Opts{} - if preserveModuleRegoVersion { - opts.RegoVersion = module.Parsed.RegoVersion() - } else { - opts.RegoVersion = version - } - module.Raw, err = format.SourceWithOpts(path, module.Raw, opts) if err != nil { return err diff --git a/cmd/build_test.go b/cmd/build_test.go index 758df258ae..8a47d66109 100644 --- a/cmd/build_test.go +++ b/cmd/build_test.go @@ -1792,14 +1792,70 @@ allow if { } } -func TestBuildWithV1CompatibleFlagOptimized(t *testing.T) { +func TestBuildOptimizedWithRegoVersion(t *testing.T) { tests := []struct { - note string - files map[string]string - expectedFiles map[string]string + note string + v1Compatible bool + regoV1ImportCapable bool + files map[string]string + expectedFiles map[string]string }{ { - note: "No imports", + note: "v0, no future keywords", + v1Compatible: false, + regoV1ImportCapable: true, + files: map[string]string{ + "test.rego": `package test +# METADATA +# entrypoint: true +p[v] { + v := input.v +} +`, + }, + expectedFiles: map[string]string{ + "/.manifest": `{"revision":"","roots":[""],"rego_version":0} +`, + // rego.v1 import added to optimized support module + "/optimized/test.rego": `package test + +import rego.v1 + +p contains __local0__1 if { + __local0__1 = input.v +} +`, + }, + }, + { + note: "v0, No future keywords, not rego.v1 import capable", + v1Compatible: false, + regoV1ImportCapable: false, + files: map[string]string{ + "test.rego": `package test +# METADATA +# entrypoint: true +p[v] { + v := input.v +} +`, + }, + expectedFiles: map[string]string{ + "/.manifest": `{"revision":"","roots":[""],"rego_version":0} +`, + // rego.v1 import NOT added to optimized support module + "/optimized/test.rego": `package test + +p[__local0__1] { + __local0__1 = input.v +} +`, + }, + }, + { + note: "v1, No imports", + v1Compatible: true, + regoV1ImportCapable: true, files: map[string]string{ "test.rego": `package test # METADATA @@ -1822,7 +1878,9 @@ foo contains __local1__1 if { }, }, { - note: "rego.v1 imported", + note: "v1, rego.v1 imported", + v1Compatible: true, + regoV1ImportCapable: true, files: map[string]string{ "test.rego": `package test import rego.v1 @@ -1849,7 +1907,9 @@ foo contains __local1__1 if { }, }, { - note: "future.keywords imported", + note: "v1, future.keywords imported", + v1Compatible: true, + regoV1ImportCapable: true, files: map[string]string{ "test.rego": `package test import future.keywords @@ -1879,9 +1939,19 @@ foo contains __local1__1 if { test.WithTempFS(tc.files, func(root string) { params := newBuildParams() params.outputFile = path.Join(root, "bundle.tar.gz") - params.v1Compatible = true + params.v1Compatible = tc.v1Compatible params.optimizationLevel = 1 + if !tc.regoV1ImportCapable { + caps := newcapabilitiesFlag() + caps.C = ast.CapabilitiesForThisVersion() + caps.C.Features = []string{ + ast.FeatureRefHeadStringPrefixes, + ast.FeatureRefHeads, + } + params.capabilities = caps + } + err := dobuild(params, []string{root}) if err != nil { diff --git a/cmd/eval_test.go b/cmd/eval_test.go index c5adbdc9c2..6eb432a96e 100755 --- a/cmd/eval_test.go +++ b/cmd/eval_test.go @@ -1313,6 +1313,143 @@ time.clock(input.y, time.clock(input.x)) } } +func TestEvalPartialRegoVersionOutput(t *testing.T) { + tests := []struct { + note string + regoV1ImportCapable bool + v1Compatible bool + query string + module string + expected string + }{ + { + note: "v0, no future keywords", + regoV1ImportCapable: true, + query: "data.test.p", + module: `package test + +p[v] { + v := input.v +} +`, + expected: `# Query 1 +data.partial.test.p = _term_0_0 +_term_0_0 + +# Module 1 +package partial.test + +import rego.v1 + +p contains __local0__1 if __local0__1 = input.v +`, + }, + { + note: "v0, no future keywords, not rego.v1 import capable", + regoV1ImportCapable: false, + query: "data.test.p", + module: `package test + +p[v] { + v := input.v +} +`, + expected: `# Query 1 +data.partial.test.p = _term_0_0 +_term_0_0 + +# Module 1 +package partial.test + +p[__local0__1] { + __local0__1 = input.v +} +`, + }, + { + note: "v0, future keywords", + regoV1ImportCapable: true, + query: "data.test.p", + module: `package test + +import rego.v1 + +p contains v if { + v := input.v +} +`, + expected: `# Query 1 +data.partial.test.p = _term_0_0 +_term_0_0 + +# Module 1 +package partial.test + +import rego.v1 + +p contains __local0__1 if __local0__1 = input.v +`, + }, + { + note: "v1", + regoV1ImportCapable: true, + v1Compatible: true, + query: "data.test.p", + module: `package test + +p contains v if { + v := input.v +} +`, + expected: `# Query 1 +data.partial.test.p = _term_0_0 +_term_0_0 + +# Module 1 +package partial.test + +p contains __local0__1 if __local0__1 = input.v +`, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + files := map[string]string{ + "test.rego": tc.module, + } + + test.WithTempFS(files, func(path string) { + params := newEvalCommandParams() + _ = params.dataPaths.Set(filepath.Join(path, "test.rego")) + params.partial = true + params.shallowInlining = true + params.v1Compatible = tc.v1Compatible + _ = params.outputFormat.Set(evalSourceOutput) + + if !tc.regoV1ImportCapable { + caps := newcapabilitiesFlag() + caps.C = ast.CapabilitiesForThisVersion() + caps.C.Features = []string{ + ast.FeatureRefHeadStringPrefixes, + ast.FeatureRefHeads, + } + params.capabilities = caps + } + + buf := new(bytes.Buffer) + _, err := eval([]string{tc.query}, params, buf) + if err != nil { + t.Fatal("unexpected error:", err) + } + if actual := buf.String(); actual != tc.expected { + t.Errorf("expected output %q\ngot %q", tc.expected, actual) + } + }) + }) + } +} + func TestEvalDiscardOutput(t *testing.T) { tests := map[string]struct { query, format, expected string diff --git a/compile/compile.go b/compile/compile.go index 3f1f1d4212..9e80d01988 100644 --- a/compile/compile.go +++ b/compile/compile.go @@ -544,7 +544,8 @@ func (c *Compiler) optimize(ctx context.Context) error { WithEntrypoints(c.entrypointrefs). WithDebug(c.debug.Writer()). WithShallowInlining(c.optimizationLevel <= 1). - WithEnablePrintStatements(c.enablePrintStatements) + WithEnablePrintStatements(c.enablePrintStatements). + WithRegoVersion(c.regoVersion) if c.ns != "" { o = o.WithPartialNamespace(c.ns) @@ -869,6 +870,7 @@ type optimizer struct { shallow bool debug debug.Debug enablePrintStatements bool + regoVersion ast.RegoVersion } func newOptimizer(c *ast.Capabilities, b *bundle.Bundle) *optimizer { @@ -909,6 +911,11 @@ func (o *optimizer) WithPartialNamespace(ns string) *optimizer { return o } +func (o *optimizer) WithRegoVersion(regoVersion ast.RegoVersion) *optimizer { + o.regoVersion = regoVersion + return o +} + func (o *optimizer) Do(ctx context.Context) error { // NOTE(tsandall): if there are multiple entrypoints, copy the bundle because @@ -958,6 +965,8 @@ func (o *optimizer) Do(ctx context.Context) error { rego.ParsedUnknowns(unknowns), rego.Compiler(o.compiler), rego.Store(store), + rego.Capabilities(o.capabilities), + rego.SetRegoVersion(o.regoVersion), ) o.debug.Printf("optimizer: entrypoint: %v", e) diff --git a/compile/compile_test.go b/compile/compile_test.go index 85505281d3..3238d2247d 100644 --- a/compile/compile_test.go +++ b/compile/compile_test.go @@ -1608,11 +1608,18 @@ update { for _, useMemoryFS := range []bool{false, true} { test.WithTestFS(tc.files, useMemoryFS, func(root string, fsys fs.FS) { + caps := ast.CapabilitiesForThisVersion() + caps.Features = []string{ + ast.FeatureRefHeadStringPrefixes, + ast.FeatureRefHeads, + } + compiler := New(). WithFS(fsys). WithPaths(root). WithOptimizationLevel(1). - WithEntrypoints(tc.entrypoint) + WithEntrypoints(tc.entrypoint). + WithCapabilities(caps) err := compiler.Build(context.Background()) if err != nil { @@ -1641,6 +1648,150 @@ update { } } +func TestCompilerOptimizationSupportRegoVersion(t *testing.T) { + tests := []struct { + note string + modulesRegoVersion ast.RegoVersion + regoV1ImportCapable bool + entrypoint string + files map[string]string + expected []string + }{ + { + note: "v0 module, rego.v1 capable", + modulesRegoVersion: ast.RegoV0, + regoV1ImportCapable: true, + entrypoint: "test/p", + files: map[string]string{ + "test.rego": `package test +p { + input.x == 1 +}`, + }, + expected: []string{ + `package test + +import rego.v1 + +p if { + input.x = 1 +} +`, + }, + }, + { + note: "v0 module, not rego.v1 capable", + modulesRegoVersion: ast.RegoV0, + regoV1ImportCapable: false, + entrypoint: "test/p", + files: map[string]string{ + "test.rego": `package test +p { + input.x == 1 +}`, + }, + expected: []string{ + `package test + +p { + input.x = 1 +} +`, + }, + }, + { + note: "v0-compat_v1 module, rego.v1 capable", + modulesRegoVersion: ast.RegoV0CompatV1, + regoV1ImportCapable: true, + entrypoint: "test/p", + files: map[string]string{ + "test.rego": `package test + +import rego.v1 + +p if { + input.x == 1 +}`, + }, + expected: []string{ + `package test + +import rego.v1 + +p if { + input.x = 1 +} +`, + }, + }, + { + note: "v1 module, rego.v1 capable", + modulesRegoVersion: ast.RegoV1, + regoV1ImportCapable: true, + entrypoint: "test/p", + files: map[string]string{ + "test.rego": `package test +p if { + input.x == 1 +}`, + }, + expected: []string{ + `package test + +p if { + input.x = 1 +} +`, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + test.WithTestFS(tc.files, true, func(root string, fsys fs.FS) { + capabilities := ast.CapabilitiesForThisVersion() + capabilities.Features = []string{ + ast.FeatureRefHeadStringPrefixes, + ast.FeatureRefHeads, + } + if tc.regoV1ImportCapable { + capabilities.Features = append(capabilities.Features, ast.FeatureRegoV1Import) + } + + compiler := New(). + WithCapabilities(capabilities). + WithRegoVersion(tc.modulesRegoVersion). + WithFS(fsys). + WithPaths(root). + WithOptimizationLevel(1). + WithEntrypoints(tc.entrypoint) + + err := compiler.Build(context.Background()) + if err != nil { + t.Fatal(err) + } + + if len(compiler.bundle.Modules) != len(tc.expected) { + t.Fatalf("expected %v modules but got: %v:\n\n%v", + len(tc.expected), len(compiler.bundle.Modules), modulesToString(compiler.bundle.Modules)) + } + + actual := make(map[string]struct{}) + for _, m := range compiler.bundle.Modules { + actual[string(m.Raw)] = struct{}{} + } + + for _, e := range tc.expected { + if _, ok := actual[e]; !ok { + t.Fatalf("expected to find module:\n\n%v\n\nin bundle but got:\n\n%v", + e, modulesToString(compiler.bundle.Modules)) + } + } + }) + }) + } +} + func modulesToString(modules []bundle.ModuleFile) string { var buf bytes.Buffer for i, m := range modules { diff --git a/internal/presentation/presentation.go b/internal/presentation/presentation.go index 57b6b63dd7..42fef24600 100644 --- a/internal/presentation/presentation.go +++ b/internal/presentation/presentation.go @@ -349,7 +349,7 @@ func Source(w io.Writer, r Output) error { for i := range r.Partial.Support { fmt.Fprintf(w, "# Module %d\n", i+1) - bs, err := format.AstWithOpts(r.Partial.Support[i], format.Opts{IgnoreLocations: true}) + bs, err := format.AstWithOpts(r.Partial.Support[i], format.Opts{IgnoreLocations: true, RegoVersion: r.Partial.Support[i].RegoVersion()}) if err != nil { return err } diff --git a/rego/rego.go b/rego/rego.go index 17873e380a..2c7601d586 100644 --- a/rego/rego.go +++ b/rego/rego.go @@ -2405,6 +2405,51 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries, return nil, err } + if r.regoVersion == ast.RegoV0 && (r.capabilities == nil || r.capabilities.ContainsFeature(ast.FeatureRegoV1Import)) { + // If the target rego-version in v0, and the rego.v1 import is available, then we attempt to apply it to support modules. + + for i, mod := range support { + if mod.RegoVersion() != ast.RegoV0 { + continue + } + + // We can't apply the RegoV0CompatV1 version to the support module if it contains rules or vars that + // conflict with future keywords. + applyRegoVersion := true + + ast.WalkRules(mod, func(r *ast.Rule) bool { + name := r.Head.Name + if name == "" && len(r.Head.Reference) > 0 { + name = r.Head.Reference[0].Value.(ast.Var) + } + if ast.IsFutureKeyword(name.String()) { + applyRegoVersion = false + return true + } + return false + }) + + if applyRegoVersion { + ast.WalkVars(mod, func(v ast.Var) bool { + if ast.IsFutureKeyword(v.String()) { + applyRegoVersion = false + return true + } + return false + }) + } + + if applyRegoVersion { + support[i].SetRegoVersion(ast.RegoV0CompatV1) + } + } + } else { + // If the target rego-version is not v0, then we apply the target rego-version to the support modules. + for i := range support { + support[i].SetRegoVersion(r.regoVersion) + } + } + pq := &PartialQueries{ Queries: queries, Support: support,