* Disallow any type other than GDA and ShapedArray for auto sharding.

* Raise errors in the following four cases when a GDA's sharding does not match the input sharding. **In all four cases below, the check runs exactly once; there is no double-checking. Tests for these cases are included; please check them out.** A hedged sketch of the two dispatch paths follows this list.
  * Auto sharding
    * f_pjitted(gda) -- `_pjit_call_impl` catches this mismatch. This check runs only when `compiled._auto_spmd_lowering` is True.
    * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch.
  * No auto sharding
    * f_pjitted(gda) -- This case was already covered and tested; the check happens in `infer_params`.
    * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch.
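
The sketch below (not part of the commit) illustrates the two dispatch paths the checks guard. The mesh and GDA setup is assumed: eight devices and a `GlobalDeviceArray.from_callback` construction mirroring the tests in this diff.

```python
# Hedged sketch: assumes 8 devices; mirrors the test setup, not commit code.
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import PartitionSpec as P
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit

shape = (8, 2)
mesh = Mesh(np.asarray(jax.devices()).reshape(4, 2), ('x', 'y'))
gda = GlobalDeviceArray.from_callback(
    shape, mesh, P('x', 'y'),
    lambda idx: np.arange(np.prod(shape), dtype=np.float32).reshape(shape)[idx])

with mesh:
  # Deliberately mismatched: pjit expects P('x'), the GDA carries P('x', 'y').
  f = pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=P('x'))
  compiled = f.lower(jax.ShapedArray(shape, jnp.float32)).compile()
  try:
    compiled(gda)  # compiled(gda): checked in MeshExecutable.call.
  except ValueError as e:
    print(e)  # "GDA sharding does not match the input sharding. ..."
  try:
    f(gda)  # f_pjitted(gda): the eager path raises its own mismatch error.
  except ValueError as e:
    print(e)
```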

PiperOrigin-RevId: 439413895
yashk2810 authored and jax authors committed Apr 4, 2022
1 parent 4949e78 commit 6825f65
Showing 4 changed files with 51 additions and 36 deletions.
2 changes: 1 addition & 1 deletion jax/experimental/global_device_array.py
@@ -51,7 +51,7 @@ def _get_array_mapping(mesh_axes):
# Import here to avoid cyclic import error when importing gda in pjit.py.
from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

-  parsed_pspec, _, _ = _prepare_axis_resources(mesh_axes, "GDA mesh_axes")
+  parsed_pspec, _, _, _ = _prepare_axis_resources(mesh_axes, "GDA mesh_axes")
return get_array_mapping(parsed_pspec)


39 changes: 11 additions & 28 deletions jax/experimental/pjit.py
@@ -196,9 +196,9 @@ def pjit(fun: Callable,
# rather than raising an error. https://github.com/google/jax/issues/2367
in_axis_resources = tuple(in_axis_resources)

-  in_axis_resources, _, _ = _prepare_axis_resources(
+  in_axis_resources, _, _, in_all_auto = _prepare_axis_resources(
       in_axis_resources, "in_axis_resources")
-  out_axis_resources, _, _ = _prepare_axis_resources(
+  out_axis_resources, _, _, _ = _prepare_axis_resources(
       out_axis_resources, "out_axis_resources")

static_argnums = _ensure_index_tuple(static_argnums)
@@ -237,6 +237,12 @@ def infer_params(*args, _global_avals=False, **kwargs):

_maybe_check_pjit_gda_mesh(args_flat, mesh)

+    # TODO(yashkatariya): Make sure you are not checking explicitly for `ShapedArray`.
+    # One possibility is to only allow GDA and fully replicated inputs for AUTO.
+    if in_all_auto:
+      assert all(isinstance(a, GDA) or (isinstance(a, core.ShapedArray) and _global_avals)
+                 for a in args_flat), args_flat

local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
# TODO(yashkatariya): This is a hack. This should go away when avals have
# is_global attribute.
@@ -555,7 +561,7 @@ def _prepare_axis_resources(axis_resources,
for entry in entries
]
_check_unique_resources(entries, arg_name)
-  return tree_unflatten(treedef, entries), entries, treedef
+  return tree_unflatten(treedef, entries), entries, treedef, all_auto


def _check_resources_mismatch(in_axis_resources_flat, is_gda):
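
For context on the hunk above: `_prepare_axis_resources` now returns a fourth value reporting whether every entry requested auto sharding. A minimal, self-contained sketch of the new calling convention; the specs here are illustrative and the function is a private helper, so treat this as an assumption about internals rather than a public API.

```python
from jax.experimental import PartitionSpec as P
from jax.experimental.pjit import _prepare_axis_resources

# The helper now returns a 4-tuple; the new final element reports whether
# every entry requested auto sharding (False here, since the specs are explicit).
parsed, entries, treedef, all_auto = _prepare_axis_resources(
    (P('x'), P('x', 'y')), "in_axis_resources")
assert not all_auto
```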
@@ -621,12 +627,8 @@ def _pjit_call_impl(*args, jaxpr,
compiled = _pjit_lower(
jaxpr, in_axis_resources, out_axis_resources,
resource_env, donated_invars, name, in_is_global).compile()
-  # Check the GDA sharding and the sharding returned by the auto spmd partitioner
-  # only if auto_spmd_lowering is enabled.
-  # TODO(yashkatariya): Move this check to `def call()` method of MeshExecutable.
-  if compiled._auto_spmd_lowering:
-    in_pspec, _ = _get_sharding_from_executable(compiled.xla_executable, resource_env.physical_mesh)
-    _check_gda_xla_sharding_match(args, in_pspec)
+  pxla._check_gda_xla_sharding_match(args, compiled._in_axes)
distributed_debug_log(("Running pjit'd function", name),
("mesh", resource_env.physical_mesh))
return compiled.unsafe_call(*args)
@@ -955,7 +957,7 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_resources):

def with_sharding_constraint(x, axis_resources):
x_flat, tree = tree_flatten(x)
-  parsed_axis_resources, entries, _ = _prepare_axis_resources(
+  parsed_axis_resources, entries, _, _ = _prepare_axis_resources(
axis_resources, "axis_resources", allow_unconstrained_dims=True)
axis_resources_flat = tuple(
flatten_axes("with_sharding_constraint axis_resources",
@@ -1093,25 +1095,6 @@ def _calc_is_global_sequence(in_positional_semantics, in_axis_resources):
ips == maps._PositionalSemantics.GLOBAL or p.partitions == ()
for ips, p in safe_zip(in_positional_semantics, in_axis_resources))

-def _check_gda_xla_sharding_match(args, in_pspec):
-  for arg, ip in safe_zip(args, in_pspec):
-    if not isinstance(arg, GDA):
-      continue
-
-    gda_cpspec = CanonicalizedParsedPartitionSpec(
-        ParsedPartitionSpec.from_user_input(
-            arg.mesh_axes, arg_name="GDA mesh_axes"))
-    in_cpspec = CanonicalizedParsedPartitionSpec(
-        ParsedPartitionSpec.from_user_input(ip, arg_name="auto sharding pspec"))
-    if in_cpspec != gda_cpspec:
-      raise ValueError(
-          "GDA sharding does not match the sharding returned by auto spmd "
-          "partitioner. Did you create the GDA with the input sharding "
-          "returned by XLA? If yes, please file a bug. "
-          f"Got GDA spec: {gda_cpspec.user_spec} and "
-          f"auto sharding spec: {in_cpspec.user_spec} for GDA: {arg}")


def _get_in_positional_semantics(global_avals: bool, arg) -> maps._PositionalSemantics:
if isinstance(arg, GDA):
return maps._PositionalSemantics.GLOBAL
26 changes: 22 additions & 4 deletions jax/interpreters/pxla.py
@@ -2361,15 +2361,17 @@ def _get_array_mapping_from_executable(

class MeshExecutable(stages.Executable):
__slots__ = ['xla_executable', 'unsafe_call', '_input_avals',
-               '_auto_spmd_lowering']
+               '_in_axes', '_out_axes', '_auto_spmd_lowering']

def __init__(self, xla_executable, unsafe_call, input_avals,
-               auto_spmd_lowering):
+               in_axes, out_axes, auto_spmd_lowering):
self.xla_executable = xla_executable
self.unsafe_call = unsafe_call
# input_avals is a list of global and local avals. Aval is global if input
# is a GDA else local.
self._input_avals = input_avals
+    self._in_axes = in_axes
+    self._out_axes = out_axes
self._auto_spmd_lowering = auto_spmd_lowering

@staticmethod
@@ -2429,7 +2431,8 @@ def from_hlo(name: str,
handle_args = InputsHandler(xla_executable.local_devices(), input_specs, input_indices)
unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args, handle_outs)

-    return MeshExecutable(xla_executable, unsafe_call, input_avals, auto_spmd_lowering)
+    return MeshExecutable(xla_executable, unsafe_call, input_avals,
+                          in_axes, out_axes, auto_spmd_lowering)

# -- stages.Executable protocol

@@ -2440,13 +2443,28 @@ def hlo_modules(self):
return self.xla_executable.hlo_modules()

def call(self, *args):
+    # TODO(yashkatariya): Add an AOT lowering test where GDA is an input.
arg_avals = map(xla.abstractify, args)
ref_avals = self._input_avals
dispatch.check_arg_avals_for_call(ref_avals, arg_avals)
+    # Check the GDA sharding and the input sharding.
+    _check_gda_xla_sharding_match(args, self._in_axes)
return self.unsafe_call(*args)


+def _check_gda_xla_sharding_match(args, in_array_mappings):
+  from jax.experimental.global_device_array import GlobalDeviceArray, _get_array_mapping
+
+  for arg, inp_array_mapping in safe_zip(args, in_array_mappings):
+    if not isinstance(arg, GlobalDeviceArray):
+      continue
+    gda_array_mapping = _get_array_mapping(arg.mesh_axes)
+    if inp_array_mapping != gda_array_mapping:
+      raise ValueError(
+          "GDA sharding does not match the input sharding. "
+          f"Got GDA spec: {array_mapping_to_axis_resources(gda_array_mapping)} and "
+          f"auto sharding spec: {array_mapping_to_axis_resources(inp_array_mapping)} for GDA: {arg}")


_forbidden_primitives = {
'xla_pmap': 'pmap',
'sharded_call': 'sharded_jit',
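
Worth noting about the new `_check_gda_xla_sharding_match`: it compares array mappings produced by `_get_array_mapping` rather than raw `PartitionSpec`s, so spellings that canonicalize to the same mapping do not trip the check. A small sketch of that behavior, assuming the private helper `_get_array_mapping` behaves as shown in the first file of this diff:

```python
from jax.experimental import PartitionSpec as P
from jax.experimental.global_device_array import _get_array_mapping

# A trailing None canonicalizes away: both specs map mesh axis 'x' to dim 0.
assert _get_array_mapping(P('x', None)) == _get_array_mapping(P('x'))
# A genuinely different sharding yields a different mapping, which triggers
# the "GDA sharding does not match the input sharding" ValueError above.
assert _get_array_mapping(P('x', 'y')) != _get_array_mapping(P('x'))
```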
20 changes: 17 additions & 3 deletions tests/pjit_test.py
@@ -1171,6 +1171,18 @@ def test_no_recompilation_due_to_fully_replicated_and_gda_inputs(self):
self.assertEqual(before_cache.hits + 1, after_cache.hits)
self.assertEqual(before_cache.misses, after_cache.misses)

+  def test_pjit_gda_aot_sharding_mismatch(self):
+    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    global_input_shape = (8, 2)
+    input_gda = create_gda(global_input_shape, global_mesh, P('x', 'y'))
+
+    with global_mesh:
+      f = pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=P('x'))
+      compiled = f.lower(jax.ShapedArray(global_input_shape, jnp.float32)).compile()
+      with self.assertRaisesRegex(
+          ValueError, "GDA sharding does not match the input sharding."):
+        compiled(input_gda)


class AutoShardingPjitTest(jtu.JaxTestCase):

@@ -1253,9 +1265,11 @@ def test_xla_gda_sharding_mismatch(self):
gda = create_gda(global_input_shape, global_mesh, different_pspec,
global_input_data)
    with self.assertRaisesRegex(
-        ValueError,
-        "GDA sharding does not match the sharding returned by auto spmd "
-        "partitioner"):
+        ValueError, "GDA sharding does not match the input sharding."):
      sharding_info.compiled(gda)

+    with self.assertRaisesRegex(
+        ValueError, "GDA sharding does not match the input sharding."):
+      f(gda)

