diff --git a/hphp/hack/src/parser/coroutine/coroutine_state_machine_generator.ml b/hphp/hack/src/parser/coroutine/coroutine_state_machine_generator.ml index 39cfaca76d46a..137dc71bd1c67 100644 --- a/hphp/hack/src/parser/coroutine/coroutine_state_machine_generator.ml +++ b/hphp/hack/src/parser/coroutine/coroutine_state_machine_generator.ml @@ -294,6 +294,117 @@ let rewrite_for next_loop_label node = next_loop_label, Rewriter.Result.Keep in Rewriter.aggregating_rewrite_post rewrite node next_loop_label +(** + * Rewrites foreach blocks + * + * foreach (expression as $k => $v) { + * // User code here + * } + * + * gets rewritten as: + * + * $__iterator_0 = new \CoroutineForeachHelper($x); + * $__iterator_index_0 = 0; + * while ($__iterator_0->valid($__iterator_index_0))) { + * $key = $__iterator_0->key($__iterator_index_0); + * $val = $__iterator_0->current($__iterator_index_0); + * + * // User code goes here + * + * $__iterator_index_0 = $__iterator_0->next($__iterator_index_0); + * } + *) +let rewrite_foreach node = + let rewrite node ((foreach_counter, iterator_variables) as acc) = + match syntax node with + | ForeachStatement { + foreach_collection; + foreach_key; + foreach_value; + foreach_body; + _; + } when not (SuspendRewriter.has_no_suspends foreach_body) -> + let no_key = is_missing foreach_key in + (* $__iterator_# *) + let iterator_variable_string = + make_iterator_variable_string foreach_counter in + let iterator_variable_syntax = + make_variable_expression_syntax iterator_variable_string in + (* $__iterator_index_# *) + let iterator_index_string = + make_iterator_index_string foreach_counter in + let iterator_index_variable_syntax = + make_variable_expression_syntax iterator_index_string in + (* $__iterator_# = new \CoroutineForeachHelper(collection); *) + let iterator_assignment_syntax = + make_assignment_syntax_variable + iterator_variable_syntax + (make_new_coroutine_foreach_helper_object + foreach_collection) in + (* $__iterator_index_# = 0 *) + let zero_index_assignment_syntax = + make_assignment_syntax_variable + iterator_index_variable_syntax + (make_int_literal_syntax 0) in + (* $k = $__iterator_#->key($__iterator_index_); *) + let get_key_syntax = + make_method_call_expression_syntax + iterator_variable_syntax + key_name_syntax + [iterator_index_variable_syntax] + |> make_assignment_syntax_variable + foreach_key in + (* $v = $__iterator_#->current($__iterator_index_); *) + let get_value_syntax = + make_method_call_expression_syntax + iterator_variable_syntax + current_name_syntax + [iterator_index_variable_syntax] + |> make_assignment_syntax_variable + foreach_value in + (* $__iterator_index_# = $__iterator_#->next($__iterator_index_); *) + let iterator_next_syntax = + make_method_call_expression_syntax + iterator_variable_syntax + next_name_syntax + [iterator_index_variable_syntax] + |> make_assignment_syntax_variable + iterator_index_variable_syntax in + (** while ($__iterator_#->valid($__iterator_index_)) { block } *) + let while_block = + [ get_value_syntax + ; foreach_body + ; iterator_next_syntax + ] in + let while_block = + if no_key then while_block + else get_key_syntax :: while_block in + let while_syntax = + make_while_syntax + (make_method_call_expression_syntax + iterator_variable_syntax + valid_name_syntax + [iterator_index_variable_syntax]) + while_block in + (* Putting it all together *) + let foreach_replacement_syntax = + make_compound_statement_syntax + [ iterator_assignment_syntax + ; zero_index_assignment_syntax + ; while_syntax + ] in + let new_iterator_variables = + iterator_variable_string :: + iterator_index_string :: + iterator_variables in + (foreach_counter + 1, new_iterator_variables), + Rewriter.Result.Replace foreach_replacement_syntax + | _ -> acc, Rewriter.Result.Keep + in + let (_, iterator_variables), rewritten_body = + Rewriter.aggregating_rewrite_post rewrite node (0, []) in + (iterator_variables, rewritten_body) + let get_token node = match Syntax.get_token node with | Some token -> token @@ -508,9 +619,13 @@ let unnest_compound_statements node = Rewriter.rewrite_post rewrite node let lower_body body = - if is_missing body then (body, []) else - let used_locals = Lambda_analyzer.all_locals body in + if is_missing body then (body, [], []) else let body = add_missing_return body in + (* Get developer locals first, rewriting introduces additional variables *) + let used_locals = Lambda_analyzer.all_locals body in + let used_locals = SSet.elements used_locals in + let generated_to_be_saved_variables, body = rewrite_foreach body in + let used_locals = generated_to_be_saved_variables @ used_locals in let (next_loop_label, body) = rewrite_do 0 body in let body = rewrite_while body in let (next_loop_label, body) = rewrite_for next_loop_label body in @@ -518,14 +633,13 @@ let lower_body body = let (next_loop_label, temp_count), body = SuspendRewriter.rewrite_suspends body in let body = add_switch (next_loop_label, body) in - let used_locals = SSet.elements used_locals in let body = add_try_finally used_locals body in let body = unnest_compound_statements body in let coroutine_result_data_variables = temp_count |> Core_list.range 1 |> Core_list.map ~f:make_coroutine_result_data_variable in - (body, coroutine_result_data_variables) + (body, coroutine_result_data_variables, generated_to_be_saved_variables) let lower_synchronous_body body = if is_missing body then body else @@ -577,10 +691,12 @@ let make_outer_params outer_variables = let compute_state_machine_data context - coroutine_result_data_variables = + coroutine_result_data_variables + generated_to_be_saved_variables = (* TODO: Add a test case for "..." param. *) let inner_variables = SSet.elements context.Coroutine_context.inner_variables in + let inner_variables = generated_to_be_saved_variables @ inner_variables in let saved_inner_variables = Core_list.map ~f:make_saved_variable inner_variables in let properties = saved_inner_variables @ coroutine_result_data_variables in @@ -604,11 +720,15 @@ let generate_coroutine_state_machine if SuspendRewriter.only_tail_call_suspends original_body then lower_synchronous_body original_body, None else - let new_body, coroutine_result_data_variables = - lower_body original_body in - let state_machine_data = compute_state_machine_data - context - coroutine_result_data_variables in + let new_body + , coroutine_result_data_variables + , generated_to_be_saved_variables = + lower_body original_body in + let state_machine_data = + compute_state_machine_data + context + coroutine_result_data_variables + generated_to_be_saved_variables in let closure_syntax = CoroutineClosureGenerator.generate_coroutine_closure context diff --git a/hphp/hack/src/parser/coroutine/coroutine_suspend_rewriter.mli b/hphp/hack/src/parser/coroutine/coroutine_suspend_rewriter.mli index 3fc135f31568d..f0ad281d3acdd 100644 --- a/hphp/hack/src/parser/coroutine/coroutine_suspend_rewriter.mli +++ b/hphp/hack/src/parser/coroutine/coroutine_suspend_rewriter.mli @@ -10,6 +10,9 @@ module Syntax = Full_fidelity_editable_positioned_syntax val fix_up_lambda_body: Syntax.t -> Syntax.t +val has_no_suspends: + Syntax.t -> bool + val only_tail_call_suspends: Syntax.t -> bool diff --git a/hphp/hack/src/parser/coroutine/coroutine_syntax.ml b/hphp/hack/src/parser/coroutine/coroutine_syntax.ml index c2803719480bf..3693bbbb17249 100644 --- a/hphp/hack/src/parser/coroutine/coroutine_syntax.ml +++ b/hphp/hack/src/parser/coroutine/coroutine_syntax.ml @@ -442,6 +442,21 @@ let make_member_selection_expression_syntax receiver_syntax member_syntax = member_selection_syntax member_syntax +(** + * $obj_variable_syntax->method_name_syntax(argument_syntax_list, ...) + *) +let make_method_call_expression_syntax + obj_variable_syntax + method_name_syntax + argument_syntax_list = + let member_selection_expression_syntax = + make_member_selection_expression_syntax + obj_variable_syntax + method_name_syntax in + make_function_call_expression_syntax + member_selection_expression_syntax + argument_syntax_list + let make_parameter_declaration_syntax ?(visibility_syntax = make_missing ()) parameter_type_syntax @@ -982,6 +997,32 @@ let set_next_label_syntax number = let number = make_int_literal_syntax number in make_assignment_syntax_variable label_syntax number +(** + * For rewriting foreach + *) +let make_iterator_variable_string number = + Printf.sprintf "$__iterator_%d" number + +let make_iterator_index_string number = + Printf.sprintf "$__iterator_index_%d" number + +let valid_name_syntax = + make_name_syntax "valid" + +let key_name_syntax = + make_name_syntax "key" + +let current_name_syntax = + make_name_syntax "current" + +let next_name_syntax = + make_name_syntax "next" + +let make_new_coroutine_foreach_helper_object collection = + make_object_creation_expression_syntax + "\\CoroutineForeachHelper" + [collection] + (** * $saved_... * @@ -995,8 +1036,6 @@ let make_saved_field var = let make_saved_variable var = "$" ^ (make_saved_field var) - - type label = | StateLabel of int | ErrorStateLabel