diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 6ff0efae5cd1..6e97e72d29cb 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1613,35 +1613,69 @@ def xla_computation_to_mlir_module(xla_computation: xc.XlaComputation def merge_mlir_modules(dst_module: ir.Module, sym_name: str, src_module: ir.Module) -> str: - """Returns the name of src_module's main() function, after renaming.""" - callee_name = None + """ + Args: + dst_module: the module into which the contents of src_module should be + moved. Nothing in dst_module will be renamed. + sym_name: the desired name for the "main" function of src_module after + merging. This is a hint: the true name may be different because of symbol + uniquification, and the true name is returned by this function. + src_module: the module whose contents are to be alpha-renamed, set to + private visibility, and merged into dst_module. src_module must contain + exactly one symbol named "main". + + Functions in src_module will be renamed such that they do not collide with + functions in dst_module. + + This function mutates `src_module`. On return, `src_module` is left in an + undefined state. + + Returns: + the name of src_module's main() function, after renaming. + """ assert dst_module.context == src_module.context + + src_symtab = ir.SymbolTable(src_module.operation) dst_symtab = ir.SymbolTable(dst_module.operation) + used_names = set() - n = len(dst_module.body.operations) + # Rename all symbols in src_module that clash with names in dst_module, or + # are the "main" symbol. + renamings = {} for op in src_module.body.operations: - dst_module.body.append(op) - ops = list(dst_module.body.operations)[n:] - - for op in ops: - op = typing.cast(func_dialect.FuncOp, op) - old_name = op.name.value - if op.name.value == "main": - dst_symtab.set_symbol_name(op, sym_name) - op.attributes["sym_visibility"] = ir.StringAttr.get("private") - callee_name = ir.StringAttr(dst_symtab.insert(op)).value - new_name = callee_name - else: - new_name = ir.StringAttr(dst_symtab.insert(op)).value + name = op.name.value + should_rename = name in dst_symtab or name == "main" + if should_rename: + base_name = sym_name if name == "main" else name + new_name = base_name + i = 0 + # Replacements are chosen such that the new names are present in neither + # src_module, dst_module, or the set of fresh names we've already used. + # Since we rename names one at a time, if new names were in src_module, + # they might themselves collide with a later renaming. + while (new_name in src_symtab or new_name in dst_symtab or + new_name in used_names): + new_name = f"{base_name}_{i}" + i += 1 + renamings[name] = new_name + used_names.add(new_name) + + # Apply the symbol renamings to symbol definitions. + private = ir.StringAttr.get("private") + for op in src_module.body.operations: + if op.name.value in renamings: + src_symtab.set_symbol_name(op, renamings[op.name.value]) + op.attributes["sym_visibility"] = private - # Replace references to the symbol with the new name - for other_op in ops: - dst_symtab.replace_all_symbol_uses( - old_name, new_name, other_op.operation) + # Apply the symbol renamings to symbol uses. + for old_name, new_name in renamings.items(): + for op in src_module.body.operations: + src_symtab.replace_all_symbol_uses(old_name, new_name, op) + for op in src_module.body.operations: + dst_module.body.append(op) - assert callee_name is not None - return callee_name + return renamings["main"] def xla_fallback_lowering(prim: core.Primitive): diff --git a/tests/filecheck/subcomputations.filecheck.py b/tests/filecheck/subcomputations.filecheck.py index 155c75f2482f..1f8e9d32e5b1 100644 --- a/tests/filecheck/subcomputations.filecheck.py +++ b/tests/filecheck/subcomputations.filecheck.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Tests for lowering of array origami ops into MLIR. +# Tests for MLIR helpers. # RUN: %PYTHON %s | FileCheck %s @@ -44,10 +44,10 @@ def cumsum_only_once(x, y): # Test merging modules # CHECK-LABEL: TEST: merge_modules # CHECK: module @jit_g - # CHECK: func public @main - # CHECK: func private @f - # CHECK: func private @m2_main_renamed - # CHECK: func private @f_0 + # CHECK: func public @main( + # CHECK: func private @f( + # CHECK: func private @m2_main_renamed( + # CHECK: func private @f_0( def make_module(c): @jax.jit def f(x): @@ -69,5 +69,34 @@ def g(x): print("\nTEST: merge_modules") print(str(m1)) + + # Test symbol renaming when merging modules + # CHECK-LABEL: TEST: merge_modules_2 + # CHECK: module @jit_f + # CHECK: func public @main( + # CHECK: call @f( + # CHECK: func private @f( + # CHECK: func private @f_0( + # CHECK: call @f_1( + # CHECK: func private @f_1( + + with mlir.make_ir_context(): + m_str = """ +module @jit_f { + func.func public @main(%arg0: tensor) -> tensor { + %0 = call @f(%arg0) : (tensor) -> tensor + return %0 : tensor + } + func.func private @f(%arg0: tensor) -> tensor { + return %arg0 : tensor + } +}""" + m1 = ir.Module.parse(m_str) + m2 = ir.Module.parse(m_str) + mlir.merge_mlir_modules(m1, "f", m2) + print("\nTEST: merge_modules_2") + print(str(m1)) + + if __name__ == "__main__": app.run(main)