[export] Improve the calling of multi-platform exported module
Previously we declared the lowering rule for call_exported to be
platform specific. This was correct, but when the caller function is
itself lowered for multiple platforms, it produces multiple copies of
the inner called Exported. Now we instead make the call_exported rule
platform independent and compute the platform index for the called
module from the platform index in the caller module. This yields a
single copy of the called module's HLO in the output.
gnecula committed Oct 19, 2023
1 parent a40e7ed commit 82a2793
Showing 2 changed files with 59 additions and 22 deletions.
71 changes: 51 additions & 20 deletions jax/experimental/export/export.py
@@ -1048,16 +1048,7 @@ def _call_exported_impl(*args, exported: Exported):
call_exported_p.def_impl(_call_exported_impl)

def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
-                           platform: str,
                            exported: Exported):
-  # TODO: implement true multi-platform lowering for call_exported
-  if (platform not in exported.lowering_platforms and
-      DisabledSafetyCheck.platform() not in exported.disabled_checks):
-    raise ValueError(
-        f"The exported function '{exported.fun_name}' was lowered for "
-        f"platforms '{exported.lowering_platforms}' but it is used "
-        f"on '{platform}'.")

  if exported.uses_shape_polymorphism:
    ctx.module_context.shape_poly_state.uses_dim_vars = True

@@ -1089,15 +1080,54 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra
convert_shape(a, a_aval, exported_in_aval)
for i, (a, a_aval, exported_in_aval) in enumerate(zip(args, ctx.avals_in, exported.in_avals))
if i in exported.module_kept_var_idx]

+  # All the platforms for the current lowering must be among the platforms
+  # for which the callee was lowered.
+  if ctx.module_context.lowering_parameters.is_multi_platform:
+    assert ctx.module_context.lowering_parameters.platforms is not None
+    lowering_platforms = ctx.module_context.lowering_parameters.platforms
+  else:
+    lowering_platforms = (ctx.module_context.platform,)
+
+  callee_lowering_platform_index: list[int] = []
+  for platform in lowering_platforms:
+    if platform in exported.lowering_platforms:
+      callee_lowering_platform_index.append(
+        exported.lowering_platforms.index(platform))
+    elif DisabledSafetyCheck.platform() in exported.disabled_checks:
+      callee_lowering_platform_index.append(0)
+    else:
+      raise ValueError(
+          f"The exported function '{exported.fun_name}' was lowered for "
+          f"platforms '{exported.lowering_platforms}' but it is used "
+          f"on '{lowering_platforms}'.")

  if len(exported.lowering_platforms) > 1:
    # The exported module takes a platform index argument
-    # TODO: implement proper handling of the platform_index when we are
-    # in a multi-platform lowering context.
-    platform_index = exported.lowering_platforms.index(platform)
-    arg_width = callee_type.inputs[0].element_type.width
-    assert arg_width in [32, 64]
-    platform_index = np.int32(platform_index) if arg_width == 32 else np.int64(platform_index)  # type: ignore
-    kept_args = [mlir.ir_constant(platform_index)] + kept_args
+    if len(lowering_platforms) > 1:
+      current_platform_idx = ctx.dim_var_values[0]
+    else:
+      current_platform_idx = mlir.ir_constant(np.int32(0))
+    # Compute the rule index based on the current platform
+    i32_type = mlir.aval_to_ir_types(core.ShapedArray((), dtype=np.int32))[0]
+    if current_platform_idx.type != i32_type:
+      current_platform_idx = hlo.ConvertOp(i32_type, current_platform_idx)
+    callee_platform_idx = hlo.CaseOp([i32_type],
+                                     index=current_platform_idx,
+                                     num_branches=len(lowering_platforms))
+    for i in range(len(lowering_platforms)):
+      branch = callee_platform_idx.regions[i].blocks.append()
+      with ir.InsertionPoint(branch):
+        hlo.ReturnOp(mlir.ir_constants(
+          np.int32(callee_lowering_platform_index[i])))
+    if callee_platform_idx.result.type != callee_type.inputs[0]:
+      callee_platform_idx = hlo.ConvertOp(callee_type.inputs[0],
+                                          callee_platform_idx)
+
+    kept_args = [callee_platform_idx] + kept_args
+  else:
+    assert len(lowering_platforms) == 1

call = func_dialect.CallOp(callee_type.results,
ir.FlatSymbolRefAttr.get(fn),
kept_args)
@@ -1114,10 +1144,11 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra
return results


-for _p in ("cpu", "tpu", "cuda", "rocm"):
-  mlir.register_lowering(call_exported_p,
-                         functools.partial(_call_exported_lowering, platform=_p),
-                         platform=_p)
+# for _p in ("cpu", "tpu", "cuda", "rocm"):
+#   mlir.register_lowering(call_exported_p,
+#                          functools.partial(_call_exported_lowering, platform=_p),
+#                          platform=_p)
+mlir.register_lowering(call_exported_p, _call_exported_lowering)

def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
exported: Exported,
10 changes: 8 additions & 2 deletions tests/export_test.py
@@ -769,7 +769,7 @@ def test_multi_platform(self):

def test_multi_platform_nested(self):
x = np.arange(5, dtype=np.float32)
-    exp = export.export(_testing_multi_platform_func,
+    exp = export.export(lambda x: _testing_multi_platform_func(jnp.sin(x)),
lowering_platforms=("cpu", "tpu", "cuda"))(x)
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))

@@ -778,14 +778,20 @@ def test_multi_platform_nested(self):
# nested exported.
exp2 = export.export(export.call_exported(exp),
lowering_platforms=("cpu", "cuda"))(x)

+    # Ensure that we do not have multiple lowerings of the exported function
+    exp2_module_str = str(exp2.mlir_module())
+    count_sine = len(re.findall("stablehlo.sine", exp2_module_str))
+    self.assertEqual(1, count_sine)

    # Call with the argument placed on different platforms
for platform in self.__class__.platforms:
if platform == "tpu": continue
x_device = jax.device_put(x, jax.devices(platform)[0])
res_exp = export.call_exported(exp2)(x_device)
self.assertAllClose(
res_exp,
-          _testing_multi_platform_fun_expected(x, platform=platform))
+          _testing_multi_platform_fun_expected(np.sin(x), platform=platform))

def test_multi_platform_nested_inside_single_platform_export(self):
x = np.arange(5, dtype=np.float32)
