Merge pull request #21203 from gnecula:export_device_poly
PiperOrigin-RevId: 633253709
jax authors committed May 13, 2024
2 parents e4f3b3f + 98aead7 commit 85e91c2
Showing 2 changed files with 114 additions and 10 deletions.
48 changes: 38 additions & 10 deletions jax/experimental/export/_export.py
@@ -817,14 +817,15 @@ def _check_lowering(lowering) -> None:
 check_sharding_pattern = re.compile(r"^({replicated}|{unknown shard_as.*}|"")$")
 
 def _check_module(mod: ir.Module, *,
-                  disabled_checks: Sequence[DisabledSafetyCheck]) -> None:
+                  disabled_checks: Sequence[DisabledSafetyCheck]) -> bool:
   """Run a number of checks on the module.
 
   Args:
-    allow_non_replicated_sharding: whether the module is allowed to contain
-      non_replicated sharding annotations.
     disabled_checks: the safety checks that are disabled.
+
+  Returns True if the module uses non-replicated shardings.
   """
+  sharding_attr = ir.StringAttr.get("Sharding", mod.context)
   allowed_custom_call_targets: set[str] = copy.copy(_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
   for dc in disabled_checks:
     target = dc.is_custom_call()
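
For context, the new sharding check hinges on check_sharding_pattern above. A minimal standalone sketch of what it accepts (the pattern is copied from the diff; the sample sharding strings are illustrative assumptions):

import re

# An mhlo.sharding value counts as replicated if it is "{replicated}",
# an "{unknown shard_as...}" marker, or empty; anything else is flagged.
check_sharding_pattern = re.compile(r"^({replicated}|{unknown shard_as.*}|"")$")

print(bool(check_sharding_pattern.match("{replicated}")))        # True: replicated
print(bool(check_sharding_pattern.match("")))                    # True: no annotation
print(bool(check_sharding_pattern.match("{devices=[2,1]0,1}")))  # False: non-replicated
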
@@ -835,13 +836,28 @@ def _check_module(mod: ir.Module, *,
       ir.StringAttr.get(target, mod.context)
       for target in allowed_custom_call_targets}
   disallowed_custom_call_ops: list[str] = []
+  module_uses_non_replicated_sharding = False
+  def check_sharding(op: ir.Operation, loc: ir.Location):
+    try:
+      sharding = op.attributes["mhlo.sharding"]
+    except KeyError:
+      pass
+    else:
+      if not re.match(check_sharding_pattern, ir.StringAttr(sharding).value):
+        nonlocal module_uses_non_replicated_sharding
+        module_uses_non_replicated_sharding = True
 
   def check_op(op: ir.Operation):
     op_name = op.operation.name
-    if op_name == "stablehlo.custom_call":
+    if op_name == "func.func":
+      check_sharding(op.operation, op.location)
+
+    elif op_name == "stablehlo.custom_call":
       call_target_name_attr = op.operation.attributes["call_target_name"]
       if (call_target_name_attr not in allowed_custom_call_targets_attrs):
         disallowed_custom_call_ops.append(f"{op} at {op.location}")
+      if call_target_name_attr == sharding_attr:
+        check_sharding(op, op.location)
 
   def walk_operations(op):
     check_op(op)
@@ -858,6 +874,7 @@ def walk_operations(op):
           f"{disallowed_custom_call_ops_str}.\n"
           "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls")
     raise ValueError(msg)
+  return module_uses_non_replicated_sharding
 
 def expand_in_shardings(in_shardings: Sequence[LoweringSharding],
                         module_kept_var_idx: Sequence[int],
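
One observable consequence of the check_sharding walk above: a module exported from an unsharded, single-device function should carry no non-replicated mhlo.sharding annotations in its StableHLO text. A hedged sketch using the public export API (the lambda and shapes are illustrative; the exact module text may vary by JAX version):

import jax
import jax.numpy as jnp
from jax.experimental import export

exp = export.export(jax.jit(lambda x: jnp.sum(x ** 2, axis=0)))(
    jnp.ones((4, 10), jnp.float32))
print(exp.nr_devices)                        # expected: 1 for an unsharded export
print("mhlo.sharding" in exp.mlir_module())  # expected: False for this module
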
@@ -1094,6 +1111,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
                             exported: Exported):
   if exported.uses_shape_polymorphism:
     ctx.module_context.shape_poly_state.uses_dim_vars = True
+  submodule = ir.Module.parse(exported.mlir_module())
 
   axis_context = ctx.module_context.axis_context
   if isinstance(axis_context, sharding_impls.ShardingContext):
@@ -1103,17 +1121,27 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
   else:
     raise NotImplementedError(type(axis_context))
   if num_devices != exported.nr_devices:
-    raise NotImplementedError(
-      f"Exported module {exported.fun_name} was lowered for "
-      f"{exported.nr_devices} devices and is called in a context with "
-      f"{num_devices} devices"
-    )
+    # In some special cases we allow running with a different number of devices
+    # than the function was exported for.
+    err_msg = ""
+    if exported.nr_devices != 1:
+      err_msg = "the module was lowered for more than 1 device."
+    elif (_check_module(submodule, disabled_checks=()) or
+          any(s is not None and not s.is_replicated()
+              for s in exported.in_shardings + exported.out_shardings)):
+      err_msg = "the module contains non-replicated sharding annotations."
+    if err_msg:
+      raise NotImplementedError(
+        f"Exported module {exported.fun_name} was lowered for "
+        f"{exported.nr_devices} devices and is called in a context with "
+        f"{num_devices} devices. This is disallowed because: {err_msg}"
+      )
 
   # Apply in_shardings
   args = tuple(
     wrap_with_sharding(ctx, x, x_aval, x_sharding)
     for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings))
-  submodule = ir.Module.parse(exported.mlir_module())
 
   symtab = ir.SymbolTable(submodule.operation)
   # The called function may have been exported with polymorphic shapes and called
   # now with more refined shapes. We insert hlo.ConvertOp to ensure the module
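The new block above can be read as a three-way rule; a distilled sketch follows (this helper is illustrative, not part of the commit; module_has_nonreplicated stands in for the result of _check_module):

def may_call_with_other_device_count(exported, module_has_nonreplicated: bool) -> bool:
  # Only single-device exports may run under a different device count.
  if exported.nr_devices != 1:
    return False
  # The module body must not mention any non-replicated sharding.
  if module_has_nonreplicated:
    return False
  # Declared input/output shardings must be absent or replicated.
  return all(s is None or s.is_replicated()
             for s in exported.in_shardings + exported.out_shardings)
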
76 changes: 76 additions & 0 deletions tests/export_test.py
@@ -900,6 +900,82 @@ def f_jax(b): # b: f32[16 // DEVICES, 4]
         in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),)
     )(a)
 
+  def test_call_with_different_no_of_devices(self):
+    if jax.local_device_count() < 2:
+      self.skipTest("Need at least 2 devices")
+
+    @jax.jit
+    def f_without_shardings(x):
+      return jnp.sum(x ** 2, axis=0)
+
+    a = jnp.arange(jax.local_device_count() * 10, dtype=np.float32).reshape(
+        (jax.local_device_count(), 10)
+    )
+    res_native = f_without_shardings(a)
+    exp = get_exported(f_without_shardings)(a)
+    self.assertEqual(exp.nr_devices, 1)
+
+    run_devices = jax.local_devices()
+    run_mesh = Mesh(run_devices, "i")
+    b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))
+
+    res_exported = export.call_exported(exp)(b)
+    self.assertAllClose(res_native, res_exported)
+
+  def test_call_with_different_no_of_devices_error_has_in_shardings(self):
+    if jax.local_device_count() < 2:
+      self.skipTest("Need at least 2 devices")
+
+    mesh_1 = Mesh(jax.local_devices()[:1], "i")
+    @functools.partial(pjit.pjit,
+                       in_shardings=NamedSharding(mesh_1, P("i")))
+    def f_with_sharding(x):
+      return jnp.sum(x ** 2, axis=0)
+
+    a = jnp.arange(jax.device_count() * 10, dtype=np.float32).reshape(
+        (jax.device_count(), 10)
+    )
+    exp = get_exported(f_with_sharding)(a)
+    self.assertEqual(exp.nr_devices, 1)
+
+    run_devices = jax.local_devices()
+    run_mesh = Mesh(run_devices, "i")
+    b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))
+
+    with self.assertRaisesRegex(
+        NotImplementedError,
+        "Exported module .* was lowered for 1 devices and is called in a "
+        f"context with {jax.local_device_count()} devices.* module contains "
+        "non-replicated sharding annotations"):
+      export.call_exported(exp)(b)
+
+  def test_call_with_different_no_of_devices_error_has_sharding_constraint(self):
+    if jax.device_count() < 2:
+      self.skipTest("Need at least 2 devices")
+
+    mesh_1 = Mesh(jax.local_devices()[:1], "i")
+    @jax.jit
+    def f_with_sharding(x):
+      x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh_1, P("i")))
+      return jnp.sum(x ** 2, axis=0)
+
+    a = jnp.arange(jax.device_count() * 10, dtype=np.float32).reshape(
+        (jax.device_count(), 10)
+    )
+    exp = get_exported(f_with_sharding)(a)
+    self.assertEqual(exp.nr_devices, 1)
+
+    run_devices = jax.local_devices()
+    run_mesh = Mesh(run_devices, "i")
+    b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))
+
+    with self.assertRaisesRegex(
+        NotImplementedError,
+        "Exported module .* was lowered for 1 devices and is called in a "
+        f"context with {jax.local_device_count()} devices.* module contains "
+        "non-replicated sharding annotations"):
+      export.call_exported(exp)(b)
+
   @jtu.parameterized_filterable(
       kwargs=[
           dict(testcase_name=f"_poly={poly}", poly=poly)
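The third test exercises check_sharding via with_sharding_constraint, which lowers to a stablehlo.custom_call @Sharding carrying an mhlo.sharding attribute. A hedged way to see that annotation (the mesh construction mirrors the tests; the printed text depends on JAX version and device count):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.local_devices()[:1], "i")

@jax.jit
def f(x):
  return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("i")))

# The lowered StableHLO should contain a custom_call @Sharding with an
# mhlo.sharding attribute, which is exactly what check_sharding inspects.
print(f.lower(jnp.ones((8,), jnp.float32)).as_text())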
