Skip to content

Commit

Permalink
[shape_poly] Change lowering for shape polymorphism to simplify shape…
Browse files Browse the repository at this point in the history
… refinement

We insert a ConvertOp as the only use of an input argument in a shape polymorphic
`main` function. This helps the downstream shape refinement because it will set the type
of input arguments to static shapes, and this can invalidate the
module if the argument appears as the result of a function, or if
it appears as the input to a custom_call with output_operand_alias
attribute.
See b/287386268.
  • Loading branch information
gnecula committed Jun 15, 2023
1 parent bfe8acb commit 6e09d4a
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions jax/experimental/jax2tf/jax_export.py
Expand Up @@ -527,8 +527,7 @@ def is_token(attrs):
nr_array_results = len(orig_output_types) - nr_token_results
assert nr_array_results >= 0
assert not any(
is_token(attrs) for attrs in result_attrs[-nr_array_results:]
)
is_token(attrs) for attrs in result_attrs[-nr_array_results:])
new_main_output_types = orig_output_types[-nr_array_results:]
new_main_ftype = ir.FunctionType.get(new_main_input_types, new_main_output_types)
new_main_op = func_dialect.FuncOp(
Expand All @@ -540,38 +539,42 @@ def is_token(attrs):
pass # TODO: better detection if orig_main.arg_attrs does not exist
try:
new_main_op.result_attrs = ir.ArrayAttr.get(
result_attrs[-nr_array_results:]
)
result_attrs[-nr_array_results:])
except KeyError:
pass
symbol_table.insert(new_main_op)
entry_block = new_main_op.add_entry_block()
with ir.InsertionPoint(entry_block):
orig_main_args: List[ir.Value] = []
module_context = mlir.ModuleContext(
"cpu", "cpu", sharding_impls.ShardingContext([]),
source_info_util.new_name_stack(),
[], itertools.count(1), [], module=wrapped_module, context=context)
ctx = mlir.LoweringRuleContext(module_context=module_context,
primitive=None, avals_in=args_avals_flat, avals_out=None,
tokens_in=mlir.TokenSet(), tokens_out=None)
ctx = mlir.LoweringRuleContext(
module_context=module_context, primitive=None,
avals_in=args_avals_flat, avals_out=None,
tokens_in=mlir.TokenSet(), tokens_out=None)
dim_values = mlir.lower_fun(
functools.partial(shape_poly.compute_dim_vars_from_arg_shapes,
args_avals_flat, args_kwargs_tree=args_kwargs_tree),
multiple_results=True)(ctx, *new_main_op.arguments)

dim_args = []
# The arguments to pass to the call to orig_main
orig_main_args: List[ir.Value] = []
# The first arguments are the dimension variable
for dim_arg, dim_arg_type in zip(util.flatten(dim_values), dim_var_input_types):
if dim_arg.type != dim_arg_type:
dim_args.append(hlo.ConvertOp(dim_arg_type, dim_arg).result)
orig_main_args.append(hlo.ConvertOp(dim_arg_type, dim_arg).result)
else:
dim_args.append(dim_arg)
# The first arguments are the dimension variable
orig_main_args.extend(dim_args)
orig_main_args.append(dim_arg)
# Then the token arguments
orig_main_args.extend(list(mlir.dummy_token()) * nr_token_args)
# Then the array arguments
orig_main_args.extend(new_main_op.arguments)
# Then the array arguments. We insert a ConvertOp as the only use of
# an input argument. This helps the downstream shape refinement because
# it will set the type of input arguments to static shapes, and this
# can invalidate the module if the argument is used as the result of a
# function, or if it appears as the input to a custom_call with
# output_operand_alias attribute. See b/287386268.
for a in new_main_op.arguments:
orig_main_args.append(hlo.ConvertOp(a.type, a).result)
call = func_dialect.CallOp(orig_output_types,
ir.FlatSymbolRefAttr.get(orig_main_name),
orig_main_args)
Expand Down

0 comments on commit 6e09d4a

Please sign in to comment.