Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ast: Importing rego.v1 in v0 support modules when applicable #6698

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions ast/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -2570,6 +2570,11 @@ var futureKeywords = map[string]tokens.Token{
"if": tokens.If,
}

func IsFutureKeyword(s string) bool {
_, ok := futureKeywords[s]
return ok
}

func (p *Parser) futureImport(imp *Import, allowedFutureKeywords map[string]tokens.Token) {
path := imp.Path.Value.(Ref)

Expand Down
6 changes: 6 additions & 0 deletions ast/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,12 @@ func (mod *Module) RegoVersion() RegoVersion {
return mod.regoVersion
}

// SetRegoVersion sets the RegoVersion for the module.
// Note: Setting a rego-version that does not match the module's rego-version might have unintended consequences.
func (mod *Module) SetRegoVersion(v RegoVersion) {
mod.regoVersion = v
}

// NewComment returns a new Comment object.
func NewComment(text []byte) *Comment {
return &Comment{
Expand Down
16 changes: 8 additions & 8 deletions bundle/bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -1082,8 +1082,15 @@ func (b *Bundle) FormatModulesForRegoVersion(version ast.RegoVersion, preserveMo
var err error

for i, module := range b.Modules {
opts := format.Opts{}
if preserveModuleRegoVersion {
opts.RegoVersion = module.Parsed.RegoVersion()
} else {
opts.RegoVersion = version
}

if module.Raw == nil {
module.Raw, err = format.AstWithOpts(module.Parsed, format.Opts{RegoVersion: version})
module.Raw, err = format.AstWithOpts(module.Parsed, opts)
if err != nil {
return err
}
Expand All @@ -1093,13 +1100,6 @@ func (b *Bundle) FormatModulesForRegoVersion(version ast.RegoVersion, preserveMo
path = module.Path
}

opts := format.Opts{}
if preserveModuleRegoVersion {
opts.RegoVersion = module.Parsed.RegoVersion()
} else {
opts.RegoVersion = version
}

module.Raw, err = format.SourceWithOpts(path, module.Raw, opts)
if err != nil {
return err
Expand Down
86 changes: 78 additions & 8 deletions cmd/build_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1792,14 +1792,70 @@ allow if {
}
}

func TestBuildWithV1CompatibleFlagOptimized(t *testing.T) {
func TestBuildOptimizedWithRegoVersion(t *testing.T) {
tests := []struct {
note string
files map[string]string
expectedFiles map[string]string
note string
v1Compatible bool
regoV1ImportCapable bool
files map[string]string
expectedFiles map[string]string
}{
{
note: "No imports",
note: "v0, no future keywords",
v1Compatible: false,
regoV1ImportCapable: true,
files: map[string]string{
"test.rego": `package test
# METADATA
# entrypoint: true
p[v] {
v := input.v
}
`,
},
expectedFiles: map[string]string{
"/.manifest": `{"revision":"","roots":[""],"rego_version":0}
`,
// rego.v1 import added to optimized support module
"/optimized/test.rego": `package test

import rego.v1

p contains __local0__1 if {
__local0__1 = input.v
}
`,
},
},
{
note: "v0, No future keywords, not rego.v1 import capable",
v1Compatible: false,
regoV1ImportCapable: false,
files: map[string]string{
"test.rego": `package test
# METADATA
# entrypoint: true
p[v] {
v := input.v
}
`,
},
expectedFiles: map[string]string{
"/.manifest": `{"revision":"","roots":[""],"rego_version":0}
`,
// rego.v1 import NOT added to optimized support module
"/optimized/test.rego": `package test

p[__local0__1] {
__local0__1 = input.v
}
`,
},
},
{
note: "v1, No imports",
v1Compatible: true,
regoV1ImportCapable: true,
files: map[string]string{
"test.rego": `package test
# METADATA
Expand All @@ -1822,7 +1878,9 @@ foo contains __local1__1 if {
},
},
{
note: "rego.v1 imported",
note: "v1, rego.v1 imported",
v1Compatible: true,
regoV1ImportCapable: true,
files: map[string]string{
"test.rego": `package test
import rego.v1
Expand All @@ -1849,7 +1907,9 @@ foo contains __local1__1 if {
},
},
{
note: "future.keywords imported",
note: "v1, future.keywords imported",
v1Compatible: true,
regoV1ImportCapable: true,
files: map[string]string{
"test.rego": `package test
import future.keywords
Expand Down Expand Up @@ -1879,9 +1939,19 @@ foo contains __local1__1 if {
test.WithTempFS(tc.files, func(root string) {
params := newBuildParams()
params.outputFile = path.Join(root, "bundle.tar.gz")
params.v1Compatible = true
params.v1Compatible = tc.v1Compatible
params.optimizationLevel = 1

if !tc.regoV1ImportCapable {
caps := newcapabilitiesFlag()
caps.C = ast.CapabilitiesForThisVersion()
caps.C.Features = []string{
ast.FeatureRefHeadStringPrefixes,
ast.FeatureRefHeads,
}
params.capabilities = caps
}

err := dobuild(params, []string{root})

if err != nil {
Expand Down
137 changes: 137 additions & 0 deletions cmd/eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,143 @@ time.clock(input.y, time.clock(input.x))
}
}

func TestEvalPartialRegoVersionOutput(t *testing.T) {
tests := []struct {
note string
regoV1ImportCapable bool
v1Compatible bool
query string
module string
expected string
}{
{
note: "v0, no future keywords",
regoV1ImportCapable: true,
query: "data.test.p",
module: `package test

p[v] {
v := input.v
}
`,
expected: `# Query 1
data.partial.test.p = _term_0_0
_term_0_0

# Module 1
package partial.test

import rego.v1

p contains __local0__1 if __local0__1 = input.v
`,
},
{
note: "v0, no future keywords, not rego.v1 import capable",
regoV1ImportCapable: false,
query: "data.test.p",
module: `package test

p[v] {
v := input.v
}
`,
expected: `# Query 1
data.partial.test.p = _term_0_0
_term_0_0

# Module 1
package partial.test

p[__local0__1] {
__local0__1 = input.v
}
`,
},
{
note: "v0, future keywords",
regoV1ImportCapable: true,
query: "data.test.p",
module: `package test

import rego.v1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the module had the future import would the optimized module use the rego import if regoV1ImportCapable: true?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. When we build the support module, we don't have this information retained. We could probably trace the content of a single support rule back to it's origin module and do whatever relevant imports it does, but I'm not sure that's worth the effort, given that multiple origin modules could be contributing to the same support module, and they can all have their own import scheme.
I think the important thing here is that we produce Rego that is in line with the code style we recommend, which is to use the rego.v1 import.


p contains v if {
v := input.v
}
`,
expected: `# Query 1
data.partial.test.p = _term_0_0
_term_0_0

# Module 1
package partial.test

import rego.v1

p contains __local0__1 if __local0__1 = input.v
`,
},
{
note: "v1",
regoV1ImportCapable: true,
v1Compatible: true,
query: "data.test.p",
module: `package test

p contains v if {
v := input.v
}
`,
expected: `# Query 1
data.partial.test.p = _term_0_0
_term_0_0

# Module 1
package partial.test

p contains __local0__1 if __local0__1 = input.v
`,
},
}

for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
files := map[string]string{
"test.rego": tc.module,
}

test.WithTempFS(files, func(path string) {
params := newEvalCommandParams()
_ = params.dataPaths.Set(filepath.Join(path, "test.rego"))
params.partial = true
params.shallowInlining = true
params.v1Compatible = tc.v1Compatible
_ = params.outputFormat.Set(evalSourceOutput)

if !tc.regoV1ImportCapable {
caps := newcapabilitiesFlag()
caps.C = ast.CapabilitiesForThisVersion()
caps.C.Features = []string{
ast.FeatureRefHeadStringPrefixes,
ast.FeatureRefHeads,
}
params.capabilities = caps
}

buf := new(bytes.Buffer)
_, err := eval([]string{tc.query}, params, buf)
if err != nil {
t.Fatal("unexpected error:", err)
}
if actual := buf.String(); actual != tc.expected {
t.Errorf("expected output %q\ngot %q", tc.expected, actual)
}
})
})
}
}

func TestEvalDiscardOutput(t *testing.T) {
tests := map[string]struct {
query, format, expected string
Expand Down
11 changes: 10 additions & 1 deletion compile/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,8 @@ func (c *Compiler) optimize(ctx context.Context) error {
WithEntrypoints(c.entrypointrefs).
WithDebug(c.debug.Writer()).
WithShallowInlining(c.optimizationLevel <= 1).
WithEnablePrintStatements(c.enablePrintStatements)
WithEnablePrintStatements(c.enablePrintStatements).
WithRegoVersion(c.regoVersion)

if c.ns != "" {
o = o.WithPartialNamespace(c.ns)
Expand Down Expand Up @@ -869,6 +870,7 @@ type optimizer struct {
shallow bool
debug debug.Debug
enablePrintStatements bool
regoVersion ast.RegoVersion
}

func newOptimizer(c *ast.Capabilities, b *bundle.Bundle) *optimizer {
Expand Down Expand Up @@ -909,6 +911,11 @@ func (o *optimizer) WithPartialNamespace(ns string) *optimizer {
return o
}

func (o *optimizer) WithRegoVersion(regoVersion ast.RegoVersion) *optimizer {
o.regoVersion = regoVersion
return o
}

func (o *optimizer) Do(ctx context.Context) error {

// NOTE(tsandall): if there are multiple entrypoints, copy the bundle because
Expand Down Expand Up @@ -958,6 +965,8 @@ func (o *optimizer) Do(ctx context.Context) error {
rego.ParsedUnknowns(unknowns),
rego.Compiler(o.compiler),
rego.Store(store),
rego.Capabilities(o.capabilities),
rego.SetRegoVersion(o.regoVersion),
)

o.debug.Printf("optimizer: entrypoint: %v", e)
Expand Down