Skip to content

Commit

Permalink
Add unlimited block unpacking (#11597)
Browse files Browse the repository at this point in the history
Co-authored-by: Johannes Müller <straightshoota@gmail.com>
  • Loading branch information
asterite and straight-shoota committed Jul 25, 2023
1 parent 8c2c248 commit b69838c
Show file tree
Hide file tree
Showing 11 changed files with 472 additions and 177 deletions.
2 changes: 2 additions & 0 deletions spec/compiler/formatter/formatter_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1840,6 +1840,8 @@ describe Crystal::Formatter do
assert_format "foo { | a, ( b , c ) | a + b + c }", "foo { |a, (b, c)| a + b + c }"
assert_format "foo { | a, ( b , c, ), | a + b + c }", "foo { |a, (b, c)| a + b + c }"
assert_format "foo { | a, ( _ , c ) | a + c }", "foo { |a, (_, c)| a + c }"
assert_format "foo { | a, ( b , (c, d) ) | a + b + c }", "foo { |a, (b, (c, d))| a + b + c }"
assert_format "foo { | ( a, *b , c ) | a }", "foo { |(a, *b, c)| a }"

assert_format "def foo\n {{@type}}\nend"

Expand Down
96 changes: 96 additions & 0 deletions spec/compiler/normalize/block_spec.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
require "../../spec_helper"

describe "Normalize: block" do
it "normalizes unpacking with empty body" do
assert_normalize <<-FROM, <<-TO
foo do |(x, y), z|
end
FROM
foo do |__temp_1, z|
x, y = __temp_1
end
TO
end

it "normalizes unpacking with single expression body" do
assert_normalize <<-FROM, <<-TO
foo do |(x, y), z|
z
end
FROM
foo do |__temp_1, z|
x, y = __temp_1
z
end
TO
end

it "normalizes unpacking with multiple body expressions" do
assert_normalize <<-FROM, <<-TO
foo do |(x, y), z|
x
y
z
end
FROM
foo do |__temp_1, z|
x, y = __temp_1
x
y
z
end
TO
end

it "normalizes unpacking with underscore" do
assert_normalize <<-FROM, <<-TO
foo do |(x, _), z|
end
FROM
foo do |__temp_1, z|
x, _ = __temp_1
end
TO
end

it "normalizes nested unpacking" do
assert_normalize <<-FROM, <<-TO
foo do |(a, (b, c))|
1
end
FROM
foo do |__temp_1|
a, __temp_2 = __temp_1
b, c = __temp_2
1
end
TO
end

it "normalizes multiple nested unpackings" do
assert_normalize <<-FROM, <<-TO
foo do |(a, (b, (c, (d, e)), f))|
1
end
FROM
foo do |__temp_1|
a, __temp_2 = __temp_1
b, __temp_3, f = __temp_2
c, __temp_4 = __temp_3
d, e = __temp_4
1
end
TO
end

it "normalizes unpacking with splat" do
assert_normalize <<-FROM, <<-TO
foo do |(x, *y, z)|
end
FROM
foo do |__temp_1|
x, *y, z = __temp_1
end
TO
end
end
74 changes: 57 additions & 17 deletions spec/compiler/parser/parser_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -770,28 +770,68 @@ module Crystal
assert_syntax_error "foo(&block) {}"

it_parses "foo { |a, (b, c), (d, e)| a; b; c; d; e }", Call.new(nil, "foo",
block: Block.new(["a".var, "__arg0".var, "__arg1".var],
block: Block.new(
["a".var, "".var, "".var],
Expressions.new([
Assign.new("b".var, Call.new("__arg0".var, "[]", 0.int32)),
Assign.new("c".var, Call.new("__arg0".var, "[]", 1.int32)),
Assign.new("d".var, Call.new("__arg1".var, "[]", 0.int32)),
Assign.new("e".var, Call.new("__arg1".var, "[]", 1.int32)),
"a".var, "b".var, "c".var, "d".var, "e".var,
] of ASTNode)))
"a".var,
"b".var,
"c".var,
"d".var,
"e".var,
] of ASTNode),
unpacks: {
1 => Expressions.new(["b".var, "c".var] of ASTNode),
2 => Expressions.new(["d".var, "e".var] of ASTNode),
},
),
)

it_parses "foo { |(_, c)| c }", Call.new(nil, "foo",
block: Block.new(["__arg0".var],
Expressions.new([
Assign.new("c".var, Call.new("__arg0".var, "[]", 1.int32)),
"c".var,
] of ASTNode)))
block: Block.new(["".var],
"c".var,
unpacks: {0 => Expressions.new([Underscore.new, "c".var] of ASTNode)},
)
)

it_parses "foo { |(_, c, )| c }", Call.new(nil, "foo",
block: Block.new(["__arg0".var],
Expressions.new([
Assign.new("c".var, Call.new("__arg0".var, "[]", 1.int32)),
"c".var,
] of ASTNode)))
block: Block.new(["".var],
"c".var,
unpacks: {0 => Expressions.new([Underscore.new, "c".var] of ASTNode)},
)
)

it_parses "foo { |(a, (b, (c, d)))| }", Call.new(nil, "foo",
block: Block.new(
["".var],
Nop.new,
unpacks: {
0 => Expressions.new([
"a".var,
Expressions.new([
"b".var,
Expressions.new([
"c".var,
"d".var,
] of ASTNode),
]),
]),
},
),
)

it_parses "foo { |(a, *b, c)| }", Call.new(nil, "foo",
block: Block.new(
["".var],
Nop.new,
unpacks: {
0 => Expressions.new([
"a".var,
Splat.new("b".var),
"c".var,
]),
},
),
)

assert_syntax_error "foo { |a b| }", "expecting ',' or '|', not b"
assert_syntax_error "foo { |(a b)| }", "expecting ',' or ')', not b"
Expand Down
10 changes: 10 additions & 0 deletions spec/compiler/parser/to_s_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,16 @@ describe "ASTNode#to_s" do
expect_to_s "->::foo(Int32, String)"
expect_to_s "->::Foo::Bar.foo"
expect_to_s "yield(1)"
expect_to_s "foo { |(x, y)| x }", <<-CODE
foo do |(x, y)|
x
end
CODE
expect_to_s "foo { |(x, (y, z))| x }", <<-CODE
foo do |(x, (y, z))|
x
end
CODE
expect_to_s "def foo\n yield\nend", "def foo(&)\n yield\nend"
expect_to_s "def foo(x)\n yield\nend", "def foo(x, &)\n yield\nend"
expect_to_s "def foo(**x)\n yield\nend", "def foo(**x, &)\n yield\nend"
Expand Down
12 changes: 12 additions & 0 deletions spec/compiler/semantic/block_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1489,6 +1489,18 @@ describe "Block inference" do
CRYSTAL
end

it "unpacks block argument" do
assert_type(%(
def foo
yield({1, 'a'})
end
foo do |(x, y)|
{x, y}
end
)) { tuple_of([int32, char]) }
end

it "correctly types unpacked tuple block arg after block (#3339)" do
assert_type(%(
def foo
Expand Down
57 changes: 38 additions & 19 deletions src/compiler/crystal/semantic/main_visitor.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,6 @@ module Crystal

before_block_vars = node.vars.try(&.dup) || MetaVars.new

arg_counter = 0
body_exps = node.body.as?(Expressions).try(&.expressions)

# Variables that we don't want to get their type merged
Expand All @@ -1025,28 +1024,32 @@ module Crystal
ignored_vars_after_block = nil

meta_vars = @meta_vars.dup
node.args.each do |arg|
# The parser generates __argN block arguments for tuple unpacking,
# and they need a special treatment because they shouldn't override
# local variables. So we search the unpacked vars in the body.
if arg.name.starts_with?("__arg") && body_exps
ignored_vars_after_block = node.args.dup

while arg_counter < body_exps.size &&
(assign = body_exps[arg_counter]).is_a?(Assign) &&
(target = assign.target).is_a?(Var) &&
(call = assign.value).is_a?(Call) &&
(call_var = call.obj).is_a?(Var) &&
call_var.name == arg.name
bind_block_var(node, target, meta_vars, before_block_vars)
ignored_vars_after_block << Var.new(target.name)
arg_counter += 1
end
end

node.args.each do |arg|
bind_block_var(node, arg, meta_vars, before_block_vars)
end

# If the block has unpacking, like:
#
# do |(x, y)|
# ...
# end
#
# it was transformed to unpack the block vars inside the body:
#
# do |__temp_1|
# x, y = __temp_1
# ...
# end
#
# We need to treat these variables as block arguments (so they don't override existing local variables).
if unpacks = node.unpacks
ignored_vars_after_block = node.args.dup
unpacks.each_value do |unpack|
handle_unpacked_block_argument(node, unpack, meta_vars, before_block_vars, ignored_vars_after_block)
end
end

@block_nest += 1

block_visitor = MainVisitor.new(program, before_block_vars, @typed_def, meta_vars)
Expand Down Expand Up @@ -1096,6 +1099,22 @@ module Crystal
false
end

def handle_unpacked_block_argument(node, arg, meta_vars, before_block_vars, ignored_vars_after_block)
case arg
when Var
bind_block_var(node, arg, meta_vars, before_block_vars)
ignored_vars_after_block << Var.new(arg.name)
when Underscore
# Nothing
when Splat
handle_unpacked_block_argument(node, arg.exp, meta_vars, before_block_vars, ignored_vars_after_block)
when Expressions
arg.expressions.each do |exp|
handle_unpacked_block_argument(node, exp, meta_vars, before_block_vars, ignored_vars_after_block)
end
end
end

def bind_block_var(node, target, meta_vars, before_block_vars)
meta_var = new_meta_var(target.name, context: node)
meta_var.bind_to(target)
Expand Down
76 changes: 76 additions & 0 deletions src/compiler/crystal/semantic/normalizer.cr
Original file line number Diff line number Diff line change
Expand Up @@ -428,5 +428,81 @@ module Crystal

super
end

# Turn block argument unpacking to multi assigns at the beginning
# of a block.
#
# So this:
#
# foo do |(x, y), z|
# x + y + z
# end
#
# is transformed to:
#
# foo do |__temp_1, z|
# x, y = __temp_1
# x + y + z
# end
def transform(node : Block)
node = super

unpacks = node.unpacks
return node unless unpacks

extra_expressions = [] of ASTNode
next_unpacks = [] of {String, Expressions}

unpacks.each do |index, expressions|
temp_name = program.new_temp_var_name
node.args[index] = Var.new(temp_name).at(node.args[index])

extra_expressions << block_unpack_multiassign(temp_name, expressions, next_unpacks)
end

if next_unpacks
while next_unpack = next_unpacks.shift?
var_name, expressions = next_unpack

extra_expressions << block_unpack_multiassign(var_name, expressions, next_unpacks)
end
end

body = node.body
case body
when Nop
node.body = Expressions.new(extra_expressions).at(node.body)
when Expressions
body.expressions = extra_expressions + body.expressions
else
extra_expressions << node.body
node.body = Expressions.new(extra_expressions).at(node.body)
end

node
end

private def block_unpack_multiassign(var_name, expressions, next_unpacks)
targets = expressions.expressions.map do |exp|
case exp
when Var
exp
when Underscore
exp
when Splat
exp
when Expressions
next_temp_name = program.new_temp_var_name

next_unpacks << {next_temp_name, exp}

Var.new(next_temp_name).at(exp)
else
raise "BUG: unexpected block var #{exp} (#{exp.class})"
end
end
values = [Var.new(var_name).at(expressions)] of ASTNode
MultiAssign.new(targets, values).at(expressions)
end
end
end
Loading

0 comments on commit b69838c

Please sign in to comment.