Skip to content

Commit

Permalink
format: don't group iterable when one has defaulted location (#4260)
Browse files Browse the repository at this point in the history
As mentioned in the comment, empty file names happen when the format
package's Ast() function does a sweep of its input, and adds a
"default location" to everything that has a nil location.

During PE, when generated the pairs to save in saveUnify, we'll
return Var Terms without locations. Fixing that seemed like a bigger
hurdle, so I went this route.

The new check is such that if any term has the default file in
its location, such as would happen if we're formatting code that
was created programmatically (not parsed), we'll group the terms'
elements, but print them in one line.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus committed Jan 25, 2022
1 parent d496d92 commit 932e4ff
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 11 deletions.
30 changes: 20 additions & 10 deletions format/format.go
Expand Up @@ -14,6 +14,11 @@ import (
"github.com/open-policy-agent/opa/ast"
)

// defaultLocationFile is the file name used in `Ast()` for terms
// without a location, as could happen when pretty-printing the
// results of partial eval.
const defaultLocationFile = "__format_default__"

// Source formats a Rego source file. The bytes provided must describe a complete
// Rego module. If they don't, Source will return an error resulting from the attempt
// to parse the bytes.
Expand Down Expand Up @@ -154,7 +159,7 @@ func squashTrailingNewlines(bs []byte) []byte {
}

func defaultLocation(x ast.Node) *ast.Location {
return ast.NewLocation([]byte(x.String()), "", 1, 1)
return ast.NewLocation([]byte(x.String()), defaultLocationFile, 1, 1)
}

type writer struct {
Expand Down Expand Up @@ -895,22 +900,27 @@ func (w *writer) listWriter() entryWriter {
// location: anything on the same line will be put into a slice.
func groupIterable(elements []interface{}, last *ast.Location) [][]interface{} {
// Generated vars occur in the AST when we're rendering the result of
// partial evaluation in a bundle build with optimization. For those vars,
// there is no location, and the grouping based on source location will
// yield a bad result. So if there's a generated variable among elements,
// we'll render the elements all in one line.
vis := ast.NewVarVisitor()
// partial evaluation in a bundle build with optimization.
// Those variables, and wildcard variables have the "default location",
// set in `Ast()`). That is no proper file location, and the grouping
// based on source location will yield a bad result.
def := false // default location found?
for _, elem := range elements {
vis.Walk(elem)
}
for v := range vis.Vars() {
if v.IsGenerated() {
ast.WalkTerms(elem, func(t *ast.Term) bool {
if t.Location.File == defaultLocationFile {
def = true
return true
}
return false
})
if def { // return as-is
return [][]interface{}{elements}
}
}
sort.Slice(elements, func(i, j int) bool {
return locLess(elements[i], elements[j])
})

var lines [][]interface{}
var cur []interface{}
for i, t := range elements {
Expand Down
48 changes: 47 additions & 1 deletion format/format_test.go
Expand Up @@ -336,6 +336,52 @@ a[_x[y]]`,
expected: `_x
a[_x[y][[z, w]]]`,
},
{
note: "expr with wildcard that has a default location",
toFmt: func() *ast.Expr {
expr := ast.MustParseExpr(`["foo", _] = split(input.foo, ":")`)
ast.WalkTerms(expr, func(term *ast.Term) bool {
v, ok := term.Value.(ast.Var)
if ok && v.IsWildcard() {
term.Location = defaultLocation(term)
return true
}
term.Location.File = "foo.rego"
term.Location.Row = 2
return false
})
return expr
}(),
expected: `["foo", _] = split(input.foo, ":")`,
},
{
note: "expr all terms having empty-file locations",
toFmt: ast.MustParseExpr(`[
"foo",
_
] = split(input.foo, ":")`),
expected: `
[
"foo",
_,
] = split(input.foo, ":")`,
},
{
note: "expr where all terms having empty-file locations, and one is a default location",
toFmt: func() *ast.Expr {
expr := ast.MustParseExpr(`
["foo", __local1__] = split(input.foo, ":")`)
ast.WalkTerms(expr, func(term *ast.Term) bool {
if ast.VarTerm("__local1__").Equal(term) {
term.Location = defaultLocation(term)
return true
}
return false
})
return expr
}(),
expected: `["foo", __local1__] = split(input.foo, ":")`,
},
}

for _, tc := range cases {
Expand All @@ -347,7 +393,7 @@ a[_x[y][[z, w]]]`,
expected := strings.TrimSpace(tc.expected)
actual := strings.TrimSpace(string(bs))
if actual != expected {
t.Fatalf("Expected:\n\n%s\n\nGot:\n\n%s\n\n", expected, actual)
t.Fatalf("Expected:\n\n%q\n\nGot:\n\n%q\n\n", expected, actual)
}
})
}
Expand Down

0 comments on commit 932e4ff

Please sign in to comment.