Skip to content

Commit

Permalink
AOT sharding mismatch error shouldn't have GSPMDSharding in it.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 576668290
  • Loading branch information
yashk2810 authored and jax authors committed Oct 25, 2023
1 parent ba9fd77 commit 4d15375
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
22 changes: 13 additions & 9 deletions jax/_src/interpreters/pxla.py
Expand Up @@ -2921,30 +2921,34 @@ def check_gda_or_array_xla_sharding_match(
if not isinstance(arg, ArrayImpl):
continue

db_xs = check_device_backend_on_shardings([xs])
if not db_xs:
xs = getattr(xs, '_original_sharding', xs)

# Raise memory kind mismatch error even if the arg is uncommitted.
if arg.sharding.memory_kind != xs.memory_kind:
errors.append(
f"Got Array sharding: {arg.sharding} and input sharding: {xs} for "
f"arg {name} with shape: {arg.aval.str_short()}")
"Got input sharding(s) that compiled object was called with: "
f"{arg.sharding} and sharding(s) the computation was compiled "
f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")

# No need to cache this check since MeshExecutable has a C++ fast path
# for AOT compiled call.
if (not check_device_backend_on_shardings([xs]) and
arg._committed and
if (not db_xs and arg._committed and
not op_shardings.are_op_shardings_equal(
arg.sharding._to_xla_hlo_sharding(arg.ndim),
xs._to_xla_hlo_sharding(arg.ndim))):
errors.append(
f"Got Array sharding: {arg.sharding} and input sharding: {xs} for "
f"arg {name} with shape: {arg.aval.str_short()}")
"Got input sharding(s) that compiled object was called with: "
f"{arg.sharding} and sharding(s) the computation was compiled "
f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")

if errors:
str_errors = '\n'.join(errors[:num_errors])
num_mismatch_str = (
f'the {len(errors)} mismatches' if len(errors) < num_errors else
f"{num_errors} mismatches out of {len(errors)}")
raise ValueError(
"Array(s) sharding does not match the input(s) sharding. "
"Compiled object called with input sharding(s) does not match the "
"sharding(s) the computation was compiled with. "
f"Here are {num_mismatch_str}:\n{str_errors}")


Expand Down
10 changes: 6 additions & 4 deletions tests/pjit_test.py
Expand Up @@ -1491,8 +1491,8 @@ def test_xla_arr_sharding_mismatch(self):
input_data)
with self.assertRaisesRegex(
ValueError,
r"Array\(s\) sharding does not match the input\(s\) "
r"sharding.*\n.*for arg x"):
r"Compiled object called with input sharding\(s\) does not match the "
r"sharding\(s\) the computation was compiled with.*\n.*for arg x"):
compiled(arr)

def test_gda_auto_shardings_len(self):
Expand Down Expand Up @@ -1806,7 +1806,8 @@ def test_array_lower_compile(self):

with self.assertRaisesRegex(
ValueError,
r"Array\(s\) sharding does not match the input\(s\) sharding. "
r"Compiled object called with input sharding\(s\) does not match the "
r"sharding\(s\) the computation was compiled with. "
"Here are 5 mismatches out of 6"):
compiled(a2, a2, a2, a2, a2, a2)

Expand All @@ -1819,7 +1820,8 @@ def test_array_lower_compile(self):
inp2 = {'x': a2, 'y': {'y1': a2}}
with self.assertRaisesRegex(
ValueError,
r"Array\(s\) sharding does not match the input\(s\) sharding. "
r"Compiled object called with input sharding\(s\) does not match the "
r"sharding\(s\) the computation was compiled with. "
"Here are the 2 mismatches"):
compiled(inp2)

Expand Down

0 comments on commit 4d15375

Please sign in to comment.