diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py
index fd9bb6b7c8f2..6cb71540749b 100644
--- a/jax/experimental/jax2tf/jax_export.py
+++ b/jax/experimental/jax2tf/jax_export.py
@@ -527,8 +527,7 @@ def is_token(attrs):
     nr_array_results = len(orig_output_types) - nr_token_results
     assert nr_array_results >= 0
     assert not any(
-        is_token(attrs) for attrs in result_attrs[-nr_array_results:]
-    )
+        is_token(attrs) for attrs in result_attrs[-nr_array_results:])
     new_main_output_types = orig_output_types[-nr_array_results:]
     new_main_ftype = ir.FunctionType.get(new_main_input_types, new_main_output_types)
     new_main_op = func_dialect.FuncOp(
@@ -540,38 +539,42 @@ def is_token(attrs):
       pass  # TODO: better detection if orig_main.arg_attrs does not exist
     try:
       new_main_op.result_attrs = ir.ArrayAttr.get(
-          result_attrs[-nr_array_results:]
-      )
+          result_attrs[-nr_array_results:])
     except KeyError:
       pass
     symbol_table.insert(new_main_op)
     entry_block = new_main_op.add_entry_block()
     with ir.InsertionPoint(entry_block):
-      orig_main_args: List[ir.Value] = []
       module_context = mlir.ModuleContext(
           "cpu", "cpu", sharding_impls.ShardingContext([]),
           source_info_util.new_name_stack(),
           [], itertools.count(1), [], module=wrapped_module, context=context)
-      ctx = mlir.LoweringRuleContext(module_context=module_context,
-          primitive=None, avals_in=args_avals_flat, avals_out=None,
-          tokens_in=mlir.TokenSet(), tokens_out=None)
+      ctx = mlir.LoweringRuleContext(
+          module_context=module_context, primitive=None,
+          avals_in=args_avals_flat, avals_out=None,
+          tokens_in=mlir.TokenSet(), tokens_out=None)
       dim_values = mlir.lower_fun(
           functools.partial(shape_poly.compute_dim_vars_from_arg_shapes,
                             args_avals_flat, args_kwargs_tree=args_kwargs_tree),
           multiple_results=True)(ctx, *new_main_op.arguments)
-
-      dim_args = []
+      # The arguments to pass to the call to orig_main
+      orig_main_args: List[ir.Value] = []
+      # The first arguments are the dimension variable
       for dim_arg, dim_arg_type in zip(util.flatten(dim_values), dim_var_input_types):
         if dim_arg.type != dim_arg_type:
-          dim_args.append(hlo.ConvertOp(dim_arg_type, dim_arg).result)
+          orig_main_args.append(hlo.ConvertOp(dim_arg_type, dim_arg).result)
         else:
-          dim_args.append(dim_arg)
-      # The first arguments are the dimension variable
-      orig_main_args.extend(dim_args)
+          orig_main_args.append(dim_arg)
       # Then the token arguments
       orig_main_args.extend(list(mlir.dummy_token()) * nr_token_args)
-      # Then the array arguments
-      orig_main_args.extend(new_main_op.arguments)
+      # Then the array arguments. We insert a ConvertOp as the only use of
+      # an input argument. This helps the downstream shape refinement because
+      # it will set the type of input arguments to static shapes, and this
+      # can invalidate the module if the argument is used as the result of a
+      # function, or if it appears as the input to a custom_call with
+      # output_operand_alias attribute. See b/287386268.
+      for a in new_main_op.arguments:
+        orig_main_args.append(hlo.ConvertOp(a.type, a).result)
       call = func_dialect.CallOp(orig_output_types,
                                  ir.FlatSymbolRefAttr.get(orig_main_name),
                                  orig_main_args)
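
A minimal standalone sketch of the identity-ConvertOp pattern this change introduces, shown outside the diff for clarity. The import paths and the helper name `wrap_args_with_identity_convert` are assumptions for illustration; only the `hlo.ConvertOp(a.type, a).result` call comes from the change itself.

    # Sketch only; import paths are assumptions, not part of the change.
    from typing import List, Sequence

    from jax._src.lib.mlir import ir            # assumed binding path
    from jax._src.lib.mlir.dialects import hlo  # assumed binding path


    def wrap_args_with_identity_convert(
        args: Sequence[ir.Value]) -> List[ir.Value]:
      """Returns one identity ConvertOp result per block argument.

      A ConvertOp whose result type equals its operand type is a numeric
      no-op, but it becomes the argument's only use. When shape refinement
      later rewrites the argument's type to a static shape, only the
      ConvertOp has to be reconciled, not every downstream consumer (e.g. a
      func return, or a custom_call with an output_operand_alias attribute).
      """
      return [hlo.ConvertOp(a.type, a).result for a in args]

As in the diff, such a helper would have to run under `with ir.InsertionPoint(entry_block):` so that the ConvertOps are created inside the body of the new `main` function.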