Skip to content

Commit

Permalink
Fix symbol collision when merging MLIR modules.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 542039479
  • Loading branch information
hawkinsp authored and jax authors committed Jun 20, 2023
1 parent eca3b97 commit e99ca46
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 27 deletions.
78 changes: 56 additions & 22 deletions jax/_src/interpreters/mlir.py
Expand Up @@ -1613,35 +1613,69 @@ def xla_computation_to_mlir_module(xla_computation: xc.XlaComputation
def merge_mlir_modules(dst_module: ir.Module,
sym_name: str,
src_module: ir.Module) -> str:
"""Returns the name of src_module's main() function, after renaming."""
callee_name = None
"""
Args:
dst_module: the module into which the contents of src_module should be
moved. Nothing in dst_module will be renamed.
sym_name: the desired name for the "main" function of src_module after
merging. This is a hint: the true name may be different because of symbol
uniquification, and the true name is returned by this function.
src_module: the module whose contents are to be alpha-renamed, set to
private visibility, and merged into dst_module. src_module must contain
exactly one symbol named "main".
Functions in src_module will be renamed such that they do not collide with
functions in dst_module.
This function mutates `src_module`. On return, `src_module` is left in an
undefined state.
Returns:
the name of src_module's main() function, after renaming.
"""
assert dst_module.context == src_module.context

src_symtab = ir.SymbolTable(src_module.operation)
dst_symtab = ir.SymbolTable(dst_module.operation)
used_names = set()

n = len(dst_module.body.operations)
# Rename all symbols in src_module that clash with names in dst_module, or
# are the "main" symbol.
renamings = {}
for op in src_module.body.operations:
dst_module.body.append(op)
ops = list(dst_module.body.operations)[n:]

for op in ops:
op = typing.cast(func_dialect.FuncOp, op)
old_name = op.name.value
if op.name.value == "main":
dst_symtab.set_symbol_name(op, sym_name)
op.attributes["sym_visibility"] = ir.StringAttr.get("private")
callee_name = ir.StringAttr(dst_symtab.insert(op)).value
new_name = callee_name
else:
new_name = ir.StringAttr(dst_symtab.insert(op)).value
name = op.name.value
should_rename = name in dst_symtab or name == "main"
if should_rename:
base_name = sym_name if name == "main" else name
new_name = base_name
i = 0
# Replacements are chosen such that the new names are present in neither
# src_module, dst_module, or the set of fresh names we've already used.
# Since we rename names one at a time, if new names were in src_module,
# they might themselves collide with a later renaming.
while (new_name in src_symtab or new_name in dst_symtab or
new_name in used_names):
new_name = f"{base_name}_{i}"
i += 1
renamings[name] = new_name
used_names.add(new_name)

# Apply the symbol renamings to symbol definitions.
private = ir.StringAttr.get("private")
for op in src_module.body.operations:
if op.name.value in renamings:
src_symtab.set_symbol_name(op, renamings[op.name.value])
op.attributes["sym_visibility"] = private

# Replace references to the symbol with the new name
for other_op in ops:
dst_symtab.replace_all_symbol_uses(
old_name, new_name, other_op.operation)
# Apply the symbol renamings to symbol uses.
for old_name, new_name in renamings.items():
for op in src_module.body.operations:
src_symtab.replace_all_symbol_uses(old_name, new_name, op)

for op in src_module.body.operations:
dst_module.body.append(op)

assert callee_name is not None
return callee_name
return renamings["main"]


def xla_fallback_lowering(prim: core.Primitive):
Expand Down
39 changes: 34 additions & 5 deletions tests/filecheck/subcomputations.filecheck.py
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Tests for lowering of array origami ops into MLIR.
# Tests for MLIR helpers.

# RUN: %PYTHON %s | FileCheck %s

Expand Down Expand Up @@ -44,10 +44,10 @@ def cumsum_only_once(x, y):
# Test merging modules
# CHECK-LABEL: TEST: merge_modules
# CHECK: module @jit_g
# CHECK: func public @main
# CHECK: func private @f
# CHECK: func private @m2_main_renamed
# CHECK: func private @f_0
# CHECK: func public @main(
# CHECK: func private @f(
# CHECK: func private @m2_main_renamed(
# CHECK: func private @f_0(
def make_module(c):
@jax.jit
def f(x):
Expand All @@ -69,5 +69,34 @@ def g(x):
print("\nTEST: merge_modules")
print(str(m1))


# Test symbol renaming when merging modules
# CHECK-LABEL: TEST: merge_modules_2
# CHECK: module @jit_f
# CHECK: func public @main(
# CHECK: call @f(
# CHECK: func private @f(
# CHECK: func private @f_0(
# CHECK: call @f_1(
# CHECK: func private @f_1(

with mlir.make_ir_context():
m_str = """
module @jit_f {
func.func public @main(%arg0: tensor<i64>) -> tensor<i64> {
%0 = call @f(%arg0) : (tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
}
func.func private @f(%arg0: tensor<i64>) -> tensor<i64> {
return %arg0 : tensor<i64>
}
}"""
m1 = ir.Module.parse(m_str)
m2 = ir.Module.parse(m_str)
mlir.merge_mlir_modules(m1, "f", m2)
print("\nTEST: merge_modules_2")
print(str(m1))


if __name__ == "__main__":
app.run(main)

0 comments on commit e99ca46

Please sign in to comment.