Skip to content

Commit

Permalink
ast: Fixing bug where comprehensions in rule else-heads weren't rewri…
Browse files Browse the repository at this point in the history
…tten correctly (#5772)

Previously, vars in the rule head were only rewritten for the "primary" rule,
and not for else branches. This has been fixed by walking the head of each branch
and rewriting all vars found.

Fixes: #5771

Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling committed Mar 21, 2023
1 parent a9cbc04 commit bdd4604
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 22 deletions.
44 changes: 22 additions & 22 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -2171,28 +2171,6 @@ func (c *Compiler) rewriteLocalVars() {
gen := c.localvargen

WalkRules(mod, func(rule *Rule) bool {
// Rewrite assignments contained in head of rule. Assignments can
// occur in rule head if they're inside a comprehension. Note,
// assigned vars in comprehensions in the head will be rewritten
// first to preserve scoping rules. For example:
//
// p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 }
//
// This behaviour is consistent scoping inside the body. For example:
//
// p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] }
nestedXform := &rewriteNestedHeadVarLocalTransform{
gen: gen,
RewrittenVars: c.RewrittenVars,
strict: c.strict,
}

NewGenericVisitor(nestedXform.Visit).Walk(rule.Head)

for _, err := range nestedXform.errs {
c.err(err)
}

argsStack := newLocalDeclaredVars()

args := NewVarVisitor()
Expand Down Expand Up @@ -2235,6 +2213,28 @@ func (c *Compiler) rewriteLocalVars() {
}

func (c *Compiler) rewriteLocalVarsInRule(rule *Rule, unusedArgs VarSet, argsStack *localDeclaredVars, gen *localVarGenerator) (*localDeclaredVars, Errors) {
// Rewrite assignments contained in head of rule. Assignments can
// occur in rule head if they're inside a comprehension. Note,
// assigned vars in comprehensions in the head will be rewritten
// first to preserve scoping rules. For example:
//
// p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 }
//
// This behaviour is consistent scoping inside the body. For example:
//
// p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] }
nestedXform := &rewriteNestedHeadVarLocalTransform{
gen: gen,
RewrittenVars: c.RewrittenVars,
strict: c.strict,
}

NewGenericVisitor(nestedXform.Visit).Walk(rule.Head)

for _, err := range nestedXform.errs {
c.err(err)
}

// Rewrite assignments in body.
used := NewVarSet()

Expand Down
143 changes: 143 additions & 0 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3378,20 +3378,152 @@ q = [true | true] { true }
`),
exp: MustParseRule(`q = __local0__ { true; __local0__ = [true | true] }`),
},
{
note: "array comprehension value in else head",
mod: MustParseModule(`package head
q {
false
} else = [true | true] {
true
}
`),
exp: MustParseRule(`q = true { false } else = __local0__ { true; __local0__ = [true | true] }`),
},
{
note: "array comprehension value in head (comprehension-local var)",
mod: MustParseModule(`package head
q = [a | a := true] {
false
} else = [a | a := true] {
true
}
`),
exp: MustParseRule(`q = __local2__ { false; __local2__ = [__local0__ | __local0__ = true] } else = __local3__ { true; __local3__ = [__local1__ | __local1__ = true] }`),
},
{
note: "array comprehension value in function head (comprehension-local var)",
mod: MustParseModule(`package head
f(x) = [a | a := true] {
false
} else = [a | a := true] {
true
}
`),
exp: MustParseRule(`f(__local0__) = __local3__ { false; __local3__ = [__local1__ | __local1__ = true] } else = __local4__ { true; __local4__ = [__local2__ | __local2__ = true] }`),
},
{
note: "array comprehension value in else-func head (reused arg rewrite)",
mod: MustParseModule(`package head
f(x, y) = [x | y] {
false
} else = [x | y] {
true
}
`),
exp: MustParseRule(`f(__local0__, __local1__) = __local2__ { false; __local2__ = [__local0__ | __local1__] } else = __local3__ { true; __local3__ = [__local0__ | __local1__] }`),
},
{
note: "object comprehension value",
mod: MustParseModule(`package head
r = {"true": true | true} { true }
`),
exp: MustParseRule(`r = __local0__ { true; __local0__ = {"true": true | true} }`),
},
{
note: "object comprehension value in else head",
mod: MustParseModule(`package head
q {
false
} else = {"true": true | true} {
true
}
`),
exp: MustParseRule(`q = true { false } else = __local0__ { true; __local0__ = {"true": true | true} }`),
},
{
note: "object comprehension value in head (comprehension-local var)",
mod: MustParseModule(`package head
q = {"a": a | a := true} {
false
} else = {"a": a | a := true} {
true
}
`),
exp: MustParseRule(`q = __local2__ { false; __local2__ = {"a": __local0__ | __local0__ = true} } else = __local3__ { true; __local3__ = {"a": __local1__ | __local1__ = true} }`),
},
{
note: "object comprehension value in function head (comprehension-local var)",
mod: MustParseModule(`package head
f(x) = {"a": a | a := true} {
false
} else = {"a": a | a := true} {
true
}
`),
exp: MustParseRule(`f(__local0__) = __local3__ { false; __local3__ = {"a": __local1__ | __local1__ = true} } else = __local4__ { true; __local4__ = {"a": __local2__ | __local2__ = true} }`),
},
{
note: "object comprehension value in else-func head (reused arg rewrite)",
mod: MustParseModule(`package head
f(x, y) = {x: y | true} {
false
} else = {x: y | true} {
true
}
`),
exp: MustParseRule(`f(__local0__, __local1__) = __local2__ { false; __local2__ = {__local0__: __local1__ | true} } else = __local3__ { true; __local3__ = {__local0__: __local1__ | true} }`),
},
{
note: "set comprehension value",
mod: MustParseModule(`package head
s = {true | true} { true }
`),
exp: MustParseRule(`s = __local0__ { true; __local0__ = {true | true} }`),
},
{
note: "set comprehension value in else head",
mod: MustParseModule(`package head
q = {false | false} {
false
} else = {true | true} {
true
}
`),
exp: MustParseRule(`q = __local0__ { false; __local0__ = {false | false} } else = __local1__ { true; __local1__ = {true | true} }`),
},
{
note: "set comprehension value in head (comprehension-local var)",
mod: MustParseModule(`package head
q = {a | a := true} {
false
} else = {a | a := true} {
true
}
`),
exp: MustParseRule(`q = __local2__ { false; __local2__ = {__local0__ | __local0__ = true} } else = __local3__ { true; __local3__ = {__local1__ | __local1__ = true} }`),
},
{
note: "set comprehension value in function head (comprehension-local var)",
mod: MustParseModule(`package head
f(x) = {a | a := true} {
false
} else = {a | a := true} {
true
}
`),
exp: MustParseRule(`f(__local0__) = __local3__ { false; __local3__ = {__local1__ | __local1__ = true} } else = __local4__ { true; __local4__ = {__local2__ | __local2__ = true} }`),
},
{
note: "set comprehension value in else-func head (reused arg rewrite)",
mod: MustParseModule(`package head
f(x, y) = {x | y} {
false
} else = {x | y} {
true
}
`),
exp: MustParseRule(`f(__local0__, __local1__) = __local2__ { false; __local2__ = {__local0__ | __local1__} } else = __local3__ { true; __local3__ = {__local0__ | __local1__} }`),
},
{
note: "import in else value",
mod: MustParseModule(`package head
Expand Down Expand Up @@ -5157,6 +5289,17 @@ func TestCheckUnusedFunctionArgVars(t *testing.T) {
}`,
expectedErrors: Errors{},
},
{
note: "argvar not used in body but in else-head value comprehension",
module: `package test
a := {"foo": 1}
func(x) {
input.test == "foo"
} else := { x: v | v := a[x] } {
input.test == "bar"
}`,
expectedErrors: Errors{},
},
{
note: "argvar not used in body and shadowed in head value comprehension",
module: `package test
Expand Down

0 comments on commit bdd4604

Please sign in to comment.