diff --git a/ast/compile.go b/ast/compile.go index 65907eebc1..7292926acc 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -2171,28 +2171,6 @@ func (c *Compiler) rewriteLocalVars() { gen := c.localvargen WalkRules(mod, func(rule *Rule) bool { - // Rewrite assignments contained in head of rule. Assignments can - // occur in rule head if they're inside a comprehension. Note, - // assigned vars in comprehensions in the head will be rewritten - // first to preserve scoping rules. For example: - // - // p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 } - // - // This behaviour is consistent scoping inside the body. For example: - // - // p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] } - nestedXform := &rewriteNestedHeadVarLocalTransform{ - gen: gen, - RewrittenVars: c.RewrittenVars, - strict: c.strict, - } - - NewGenericVisitor(nestedXform.Visit).Walk(rule.Head) - - for _, err := range nestedXform.errs { - c.err(err) - } - argsStack := newLocalDeclaredVars() args := NewVarVisitor() @@ -2235,6 +2213,28 @@ func (c *Compiler) rewriteLocalVars() { } func (c *Compiler) rewriteLocalVarsInRule(rule *Rule, unusedArgs VarSet, argsStack *localDeclaredVars, gen *localVarGenerator) (*localDeclaredVars, Errors) { + // Rewrite assignments contained in head of rule. Assignments can + // occur in rule head if they're inside a comprehension. Note, + // assigned vars in comprehensions in the head will be rewritten + // first to preserve scoping rules. For example: + // + // p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 } + // + // This behaviour is consistent scoping inside the body. For example: + // + // p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] } + nestedXform := &rewriteNestedHeadVarLocalTransform{ + gen: gen, + RewrittenVars: c.RewrittenVars, + strict: c.strict, + } + + NewGenericVisitor(nestedXform.Visit).Walk(rule.Head) + + for _, err := range nestedXform.errs { + c.err(err) + } + // Rewrite assignments in body. used := NewVarSet() diff --git a/ast/compile_test.go b/ast/compile_test.go index 2aa0185e70..14c9da7fd0 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -3378,6 +3378,50 @@ q = [true | true] { true } `), exp: MustParseRule(`q = __local0__ { true; __local0__ = [true | true] }`), }, + { + note: "array comprehension value in else head", + mod: MustParseModule(`package head +q { + false +} else = [true | true] { + true +} +`), + exp: MustParseRule(`q = true { false } else = __local0__ { true; __local0__ = [true | true] }`), + }, + { + note: "array comprehension value in head (comprehension-local var)", + mod: MustParseModule(`package head +q = [a | a := true] { + false +} else = [a | a := true] { + true +} +`), + exp: MustParseRule(`q = __local2__ { false; __local2__ = [__local0__ | __local0__ = true] } else = __local3__ { true; __local3__ = [__local1__ | __local1__ = true] }`), + }, + { + note: "array comprehension value in function head (comprehension-local var)", + mod: MustParseModule(`package head +f(x) = [a | a := true] { + false +} else = [a | a := true] { + true +} +`), + exp: MustParseRule(`f(__local0__) = __local3__ { false; __local3__ = [__local1__ | __local1__ = true] } else = __local4__ { true; __local4__ = [__local2__ | __local2__ = true] }`), + }, + { + note: "array comprehension value in else-func head (reused arg rewrite)", + mod: MustParseModule(`package head +f(x, y) = [x | y] { + false +} else = [x | y] { + true +} +`), + exp: MustParseRule(`f(__local0__, __local1__) = __local2__ { false; __local2__ = [__local0__ | __local1__] } else = __local3__ { true; __local3__ = [__local0__ | __local1__] }`), + }, { note: "object comprehension value", mod: MustParseModule(`package head @@ -3385,6 +3429,50 @@ r = {"true": true | true} { true } `), exp: MustParseRule(`r = __local0__ { true; __local0__ = {"true": true | true} }`), }, + { + note: "object comprehension value in else head", + mod: MustParseModule(`package head +q { + false +} else = {"true": true | true} { + true +} +`), + exp: MustParseRule(`q = true { false } else = __local0__ { true; __local0__ = {"true": true | true} }`), + }, + { + note: "object comprehension value in head (comprehension-local var)", + mod: MustParseModule(`package head +q = {"a": a | a := true} { + false +} else = {"a": a | a := true} { + true +} +`), + exp: MustParseRule(`q = __local2__ { false; __local2__ = {"a": __local0__ | __local0__ = true} } else = __local3__ { true; __local3__ = {"a": __local1__ | __local1__ = true} }`), + }, + { + note: "object comprehension value in function head (comprehension-local var)", + mod: MustParseModule(`package head +f(x) = {"a": a | a := true} { + false +} else = {"a": a | a := true} { + true +} +`), + exp: MustParseRule(`f(__local0__) = __local3__ { false; __local3__ = {"a": __local1__ | __local1__ = true} } else = __local4__ { true; __local4__ = {"a": __local2__ | __local2__ = true} }`), + }, + { + note: "object comprehension value in else-func head (reused arg rewrite)", + mod: MustParseModule(`package head +f(x, y) = {x: y | true} { + false +} else = {x: y | true} { + true +} +`), + exp: MustParseRule(`f(__local0__, __local1__) = __local2__ { false; __local2__ = {__local0__: __local1__ | true} } else = __local3__ { true; __local3__ = {__local0__: __local1__ | true} }`), + }, { note: "set comprehension value", mod: MustParseModule(`package head @@ -3392,6 +3480,50 @@ s = {true | true} { true } `), exp: MustParseRule(`s = __local0__ { true; __local0__ = {true | true} }`), }, + { + note: "set comprehension value in else head", + mod: MustParseModule(`package head +q = {false | false} { + false +} else = {true | true} { + true +} +`), + exp: MustParseRule(`q = __local0__ { false; __local0__ = {false | false} } else = __local1__ { true; __local1__ = {true | true} }`), + }, + { + note: "set comprehension value in head (comprehension-local var)", + mod: MustParseModule(`package head +q = {a | a := true} { + false +} else = {a | a := true} { + true +} +`), + exp: MustParseRule(`q = __local2__ { false; __local2__ = {__local0__ | __local0__ = true} } else = __local3__ { true; __local3__ = {__local1__ | __local1__ = true} }`), + }, + { + note: "set comprehension value in function head (comprehension-local var)", + mod: MustParseModule(`package head +f(x) = {a | a := true} { + false +} else = {a | a := true} { + true +} +`), + exp: MustParseRule(`f(__local0__) = __local3__ { false; __local3__ = {__local1__ | __local1__ = true} } else = __local4__ { true; __local4__ = {__local2__ | __local2__ = true} }`), + }, + { + note: "set comprehension value in else-func head (reused arg rewrite)", + mod: MustParseModule(`package head +f(x, y) = {x | y} { + false +} else = {x | y} { + true +} +`), + exp: MustParseRule(`f(__local0__, __local1__) = __local2__ { false; __local2__ = {__local0__ | __local1__} } else = __local3__ { true; __local3__ = {__local0__ | __local1__} }`), + }, { note: "import in else value", mod: MustParseModule(`package head @@ -5157,6 +5289,17 @@ func TestCheckUnusedFunctionArgVars(t *testing.T) { }`, expectedErrors: Errors{}, }, + { + note: "argvar not used in body but in else-head value comprehension", + module: `package test + a := {"foo": 1} + func(x) { + input.test == "foo" + } else := { x: v | v := a[x] } { + input.test == "bar" + }`, + expectedErrors: Errors{}, + }, { note: "argvar not used in body and shadowed in head value comprehension", module: `package test