Skip to content

Commit

Permalink
cmd & format: Adding rego-v1 mode to opa fmt (#6413)
Browse files Browse the repository at this point in the history
Fixes: #6297
Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling committed Nov 30, 2023
1 parent 4f9058b commit 187d688
Show file tree
Hide file tree
Showing 26 changed files with 917 additions and 97 deletions.
108 changes: 30 additions & 78 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ func (c *Compiler) buildRequiredCapabilities() {
for _, imp := range c.imports[name] {
path := imp.Path.Value.(Ref)
switch {
case path.Equal(regoV1CompatibleRef):
case path.Equal(RegoV1CompatibleRef):
features[FeatureRegoV1Import] = struct{}{}
case path.HasPrefix(futureKeywordsPrefix):
if len(path) == 2 {
Expand Down Expand Up @@ -1537,9 +1537,11 @@ func (c *Compiler) checkUnsafeBuiltins() {
func (c *Compiler) checkDeprecatedBuiltins() {
for _, name := range c.sorted {
mod := c.Modules[name]
errs := checkDeprecatedBuiltins(c.deprecatedBuiltinsMap, mod, c.strict || mod.regoV1Compatible)
for _, err := range errs {
c.err(err)
if c.strict || mod.regoV1Compatible {
errs := checkDeprecatedBuiltins(c.deprecatedBuiltinsMap, mod)
for _, err := range errs {
c.err(err)
}
}
}
}
Expand Down Expand Up @@ -1677,67 +1679,31 @@ func (c *Compiler) GetAnnotationSet() *AnnotationSet {
}

func (c *Compiler) checkDuplicateImports() {
modules := make([]*Module, 0, len(c.Modules))

for _, name := range c.sorted {
mod := c.Modules[name]
if !c.strict && !mod.regoV1Compatible {
continue
if c.strict || mod.regoV1Compatible {
modules = append(modules, mod)
}
}

processedImports := map[Var]*Import{}

for _, imp := range mod.Imports {
name := imp.Name()

if processed, conflict := processedImports[name]; conflict {
c.err(NewError(CompileErr, imp.Location, "import must not shadow %v", processed))
} else {
processedImports[name] = imp
}
}
errs := checkDuplicateImports(modules)
for _, err := range errs {
c.err(err)
}
}

func (c *Compiler) checkKeywordOverrides() {
for _, name := range c.sorted {
mod := c.Modules[name]
errs := checkKeywordOverrides(mod, c.strict || mod.regoV1Compatible)
for _, err := range errs {
c.err(err)
}
}
}

func checkKeywordOverrides(node interface{}, strict bool) Errors {
if !strict {
return nil
}

errs := Errors{}

WalkRules(node, func(rule *Rule) bool {
var name string
if len(rule.Head.Reference) > 0 {
name = rule.Head.Reference[0].Value.(Var).String()
} else {
name = rule.Head.Name.String()
}
if RootDocumentRefs.Contains(RefTerm(VarTerm(name))) {
errs = append(errs, NewError(CompileErr, rule.Location, "rules must not shadow %v (use a different rule name)", name))
}
return true
})

WalkExprs(node, func(expr *Expr) bool {
if expr.IsAssignment() {
name := expr.Operand(0).String()
if RootDocumentRefs.Contains(RefTerm(VarTerm(name))) {
errs = append(errs, NewError(CompileErr, expr.Location, "variables must not shadow %v (use a different variable name)", name))
if c.strict || mod.regoV1Compatible {
errs := checkRootDocumentOverrides(mod)
for _, err := range errs {
c.err(err)
}
}
return false
})

return errs
}
}

// resolveAllRefs resolves references in expressions to their fully qualified values.
Expand All @@ -1757,7 +1723,6 @@ func checkKeywordOverrides(node interface{}, strict bool) Errors {
// c.d[e] := 1 if e := "e"
//
// The reference "c.d.e" would be resolved to "data.a.b.c.d.e".

func (c *Compiler) resolveAllRefs() {

rules := c.getExports()
Expand Down Expand Up @@ -2848,8 +2813,10 @@ func (qc *queryCompiler) applyErrorLimit(err error) error {
}

func (qc *queryCompiler) checkKeywordOverrides(_ *QueryContext, body Body) (Body, error) {
if errs := checkKeywordOverrides(body, qc.compiler.strict); len(errs) > 0 {
return nil, errs
if qc.compiler.strict {
if errs := checkRootDocumentOverrides(body); len(errs) > 0 {
return nil, errs
}
}
return body, nil
}
Expand Down Expand Up @@ -2985,9 +2952,11 @@ func (qc *queryCompiler) unsafeBuiltinsMap() map[string]struct{} {
}

func (qc *queryCompiler) checkDeprecatedBuiltins(_ *QueryContext, body Body) (Body, error) {
errs := checkDeprecatedBuiltins(qc.compiler.deprecatedBuiltinsMap, body, qc.compiler.strict)
if len(errs) > 0 {
return nil, errs
if qc.compiler.strict {
errs := checkDeprecatedBuiltins(qc.compiler.deprecatedBuiltinsMap, body)
if len(errs) > 0 {
return nil, errs
}
}
return body, nil
}
Expand Down Expand Up @@ -4226,6 +4195,8 @@ func resolveRefsInRule(globals map[Var]*usedRef, rule *Rule) error {
// this would require rewriting terms in the head and body.
// Preventing root document shadowing is simpler, and
// arguably, will prevent confusing names from being used.
// NOTE: this check is also performed as part of strict-mode in
// checkRootDocumentOverrides.
err = fmt.Errorf("args must not shadow %v (use a different variable name)", x)
return true
}
Expand Down Expand Up @@ -5701,25 +5672,6 @@ func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node interface{}
return errs
}

func checkDeprecatedBuiltins(deprecatedBuiltinsMap map[string]struct{}, node interface{}, strict bool) Errors {
// Early out; deprecatedBuiltinsMap is only populated in strict-mode.
if !strict {
return nil
}

errs := make(Errors, 0)
WalkExprs(node, func(x *Expr) bool {
if x.IsCall() {
operator := x.Operator().String()
if _, ok := deprecatedBuiltinsMap[operator]; ok {
errs = append(errs, NewError(TypeErr, x.Loc(), "deprecated built-in function calls in expression: %v", operator))
}
}
return false
})
return errs
}

func rewriteVarsInRef(vars ...map[Var]Var) varRewriter {
return func(node Ref) Ref {
i, _ := TransformVars(node, func(v Var) (Value, error) {
Expand Down
12 changes: 6 additions & 6 deletions ast/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
"github.com/open-policy-agent/opa/ast/location"
)

var regoV1CompatibleRef = Ref{VarTerm("rego"), StringTerm("v1")}
var RegoV1CompatibleRef = Ref{VarTerm("rego"), StringTerm("v1")}

// Note: This state is kept isolated from the parser so that we
// can do efficient shallow copies of these values when doing a
Expand Down Expand Up @@ -2528,7 +2528,7 @@ func (p *Parser) futureImport(imp *Import, allowedFutureKeywords map[string]toke
}

if p.s.s.RegoV1Compatible() {
p.errorf(imp.Path.Location, "the `%s` import implies `future.keywords`, these are therefore mutually exclusive", regoV1CompatibleRef)
p.errorf(imp.Path.Location, "the `%s` import implies `future.keywords`, these are therefore mutually exclusive", RegoV1CompatibleRef)
return
}

Expand Down Expand Up @@ -2562,14 +2562,14 @@ func (p *Parser) futureImport(imp *Import, allowedFutureKeywords map[string]toke

func (p *Parser) regoV1Import(imp *Import) {
if !p.po.Capabilities.ContainsFeature(FeatureRegoV1Import) {
p.errorf(imp.Path.Location, "invalid import, `%s` is not supported by current capabilities", regoV1CompatibleRef)
p.errorf(imp.Path.Location, "invalid import, `%s` is not supported by current capabilities", RegoV1CompatibleRef)
return
}

path := imp.Path.Value.(Ref)

if len(path) == 1 || !path[1].Equal(regoV1CompatibleRef[1]) || len(path) > 2 {
p.errorf(imp.Path.Location, "invalid import, must be `%s`", regoV1CompatibleRef)
if len(path) == 1 || !path[1].Equal(RegoV1CompatibleRef[1]) || len(path) > 2 {
p.errorf(imp.Path.Location, "invalid import, must be `%s`", RegoV1CompatibleRef)
return
}

Expand All @@ -2586,7 +2586,7 @@ func (p *Parser) regoV1Import(imp *Import) {

if p.s.s.HasKeyword(futureKeywords) && !p.s.s.RegoV1Compatible() {
// We have imported future keywords, but they didn't come from another `rego.v1` import.
p.errorf(imp.Path.Location, "the `%s` import implies `future.keywords`, these are therefore mutually exclusive", regoV1CompatibleRef)
p.errorf(imp.Path.Location, "the `%s` import implies `future.keywords`, these are therefore mutually exclusive", RegoV1CompatibleRef)
return
}

Expand Down
2 changes: 1 addition & 1 deletion ast/parser_ext.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ func parseModule(filename string, stmts []Statement, comments []*Comment) (*Modu
switch stmt := stmt.(type) {
case *Import:
mod.Imports = append(mod.Imports, stmt)
if Compare(stmt.Path.Value, regoV1CompatibleRef) == 0 {
if Compare(stmt.Path.Value, RegoV1CompatibleRef) == 0 {
mod.regoV1Compatible = true
}
case *Rule:
Expand Down
126 changes: 126 additions & 0 deletions ast/rego_v1.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package ast

func checkDuplicateImports(modules []*Module) (errors Errors) {
for _, module := range modules {
processedImports := map[Var]*Import{}

for _, imp := range module.Imports {
name := imp.Name()

if processed, conflict := processedImports[name]; conflict {
errors = append(errors, NewError(CompileErr, imp.Location, "import must not shadow %v", processed))
} else {
processedImports[name] = imp
}
}
}
return
}

func checkRootDocumentOverrides(node interface{}) Errors {
errors := Errors{}

WalkRules(node, func(rule *Rule) bool {
var name string
if len(rule.Head.Reference) > 0 {
name = rule.Head.Reference[0].Value.(Var).String()
} else {
name = rule.Head.Name.String()
}
if RootDocumentRefs.Contains(RefTerm(VarTerm(name))) {
errors = append(errors, NewError(CompileErr, rule.Location, "rules must not shadow %v (use a different rule name)", name))
}

for _, arg := range rule.Head.Args {
if _, ok := arg.Value.(Ref); ok {
if RootDocumentRefs.Contains(arg) {
errors = append(errors, NewError(CompileErr, arg.Location, "args must not shadow %v (use a different variable name)", arg))
}
}
}

return true
})

WalkExprs(node, func(expr *Expr) bool {
if expr.IsAssignment() {
name := expr.Operand(0).String()
if RootDocumentRefs.Contains(RefTerm(VarTerm(name))) {
errors = append(errors, NewError(CompileErr, expr.Location, "variables must not shadow %v (use a different variable name)", name))
}
}
return false
})

return errors
}

func walkCalls(node interface{}, f func(interface{}) bool) {
vis := &GenericVisitor{func(x interface{}) bool {
switch x := x.(type) {
case Call:
return f(x)
case *Expr:
if x.IsCall() {
return f(x)
}
case *Head:
// GenericVisitor doesn't walk the rule head ref
walkCalls(x.Reference, f)
}
return false
}}
vis.Walk(node)
}

func checkDeprecatedBuiltins(deprecatedBuiltinsMap map[string]struct{}, node interface{}) Errors {
errs := make(Errors, 0)

walkCalls(node, func(x interface{}) bool {
var operator string
var loc *Location

switch x := x.(type) {
case *Expr:
operator = x.Operator().String()
loc = x.Loc()
case Call:
terms := []*Term(x)
if len(terms) > 0 {
operator = terms[0].Value.String()
loc = terms[0].Loc()
}
}

if operator != "" {
if _, ok := deprecatedBuiltinsMap[operator]; ok {
errs = append(errs, NewError(TypeErr, loc, "deprecated built-in function calls in expression: %v", operator))
}
}

return false
})

return errs
}

func checkDeprecatedBuiltinsForCurrentVersion(node interface{}) Errors {
deprecatedBuiltins := make(map[string]struct{})
capabilities := CapabilitiesForThisVersion()
for _, bi := range capabilities.Builtins {
if bi.IsDeprecated() {
deprecatedBuiltins[bi.Name] = struct{}{}
}
}

return checkDeprecatedBuiltins(deprecatedBuiltins, node)
}

// CheckRegoV1 checks the given module for errors that are specific to Rego v1
func CheckRegoV1(module *Module) Errors {
var errors Errors
errors = append(errors, checkDuplicateImports([]*Module{module})...)
errors = append(errors, checkRootDocumentOverrides(module)...)
errors = append(errors, checkDeprecatedBuiltinsForCurrentVersion(module)...)
return errors
}

0 comments on commit 187d688

Please sign in to comment.