From 6825f654b105240032b70dc1a2b7b8bcae8cb462 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Mon, 4 Apr 2022 14:33:17 -0700
Subject: [PATCH] * Disallow any type other than GDA and ShapedArray for auto
 sharding.

* Raise errors in the following 4 cases when a GDA's sharding does not match
  the input sharding. **In all 4 cases below, the check only runs once; there
  is no double checking. I have added tests for these cases, please check
  them out.**

  * Auto sharding
    * f_pjitted(gda) -- `_pjit_call_impl` catches this mismatch. The check
      only runs when `compiled._auto_spmd_lowering` is True.
    * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this
      mismatch.
  * NO auto sharding
    * f_pjitted(gda) -- This is already covered and tested; it happens in
      `infer_params`.
    * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this
      mismatch.

PiperOrigin-RevId: 439413895
---
 jax/experimental/global_device_array.py |  2 +-
 jax/experimental/pjit.py                | 39 +++++++------------------
 jax/interpreters/pxla.py                | 26 ++++++++++++++---
 tests/pjit_test.py                      | 20 +++++++++++--
 4 files changed, 51 insertions(+), 36 deletions(-)

diff --git a/jax/experimental/global_device_array.py b/jax/experimental/global_device_array.py
index a400143ec9c2..fc40efd850f5 100644
--- a/jax/experimental/global_device_array.py
+++ b/jax/experimental/global_device_array.py
@@ -51,7 +51,7 @@ def _get_array_mapping(mesh_axes):
   # Import here to avoid cyclic import error when importing gda in pjit.py.
   from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources
 
-  parsed_pspec, _, _ = _prepare_axis_resources(mesh_axes, "GDA mesh_axes")
+  parsed_pspec, _, _, _ = _prepare_axis_resources(mesh_axes, "GDA mesh_axes")
   return get_array_mapping(parsed_pspec)
 
 
diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py
index ffffcbc1fdcc..0bdeca187056 100644
--- a/jax/experimental/pjit.py
+++ b/jax/experimental/pjit.py
@@ -196,9 +196,9 @@ def pjit(fun: Callable,
     # rather than raising an error. https://github.com/google/jax/issues/2367
     in_axis_resources = tuple(in_axis_resources)
 
-  in_axis_resources, _, _ = _prepare_axis_resources(
+  in_axis_resources, _, _, in_all_auto = _prepare_axis_resources(
       in_axis_resources, "in_axis_resources")
-  out_axis_resources, _, _ = _prepare_axis_resources(
+  out_axis_resources, _, _, _ = _prepare_axis_resources(
       out_axis_resources, "out_axis_resources")
 
   static_argnums = _ensure_index_tuple(static_argnums)
@@ -237,6 +237,12 @@ def infer_params(*args, _global_avals=False, **kwargs):
 
     _maybe_check_pjit_gda_mesh(args_flat, mesh)
 
+    # TODO(yashkatariya): Make sure you are not checking explicitly for `ShapedArray`.
+    # One possibility is to only allow GDA and fully replicated inputs for AUTO.
+    if in_all_auto:
+      assert all(isinstance(a, GDA) or (isinstance(a, core.ShapedArray) and _global_avals)
+                 for a in args_flat), args_flat
+
     local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
     # TODO(yashkatariya): This is a hack. This should go away when avals have
     # is_global attribute.
@@ -555,7 +561,7 @@ def _prepare_axis_resources(axis_resources,
       for entry in entries
   ]
   _check_unique_resources(entries, arg_name)
-  return tree_unflatten(treedef, entries), entries, treedef
+  return tree_unflatten(treedef, entries), entries, treedef, all_auto
 
 
 def _check_resources_mismatch(in_axis_resources_flat, is_gda):
@@ -621,12 +627,8 @@ def _pjit_call_impl(*args, jaxpr,
   compiled = _pjit_lower(
       jaxpr, in_axis_resources, out_axis_resources, resource_env,
       donated_invars, name, in_is_global).compile()
-  # Check the GDA sharding and the sharding returned by the auto spmd partitoner
-  # only if auto_spmd_lowering is enabled.
-  # TODO(yashkatariya): Move this check to `def call()` method of MeshExecutable.
   if compiled._auto_spmd_lowering:
-    in_pspec, _ = _get_sharding_from_executable(compiled.xla_executable, resource_env.physical_mesh)
-    _check_gda_xla_sharding_match(args, in_pspec)
+    pxla._check_gda_xla_sharding_match(args, compiled._in_axes)
   distributed_debug_log(("Running pjit'd function", name),
                         ("mesh", resource_env.physical_mesh))
   return compiled.unsafe_call(*args)
@@ -955,7 +957,7 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
 
 def with_sharding_constraint(x, axis_resources):
   x_flat, tree = tree_flatten(x)
-  parsed_axis_resources, entries, _ = _prepare_axis_resources(
+  parsed_axis_resources, entries, _, _ = _prepare_axis_resources(
       axis_resources, "axis_resources", allow_unconstrained_dims=True)
   axis_resources_flat = tuple(
       flatten_axes("with_sharding_constraint axis_resources",
@@ -1093,25 +1095,6 @@ def _calc_is_global_sequence(in_positional_semantics, in_axis_resources):
       ips == maps._PositionalSemantics.GLOBAL or p.partitions == ()
       for ips, p in safe_zip(in_positional_semantics, in_axis_resources))
 
-def _check_gda_xla_sharding_match(args, in_pspec):
-  for arg, ip in safe_zip(args, in_pspec):
-    if not isinstance(arg, GDA):
-      continue
-
-    gda_cpspec = CanonicalizedParsedPartitionSpec(
-        ParsedPartitionSpec.from_user_input(
-            arg.mesh_axes, arg_name="GDA mesh_axes"))
-    in_cpspec = CanonicalizedParsedPartitionSpec(
-        ParsedPartitionSpec.from_user_input(ip, arg_name="auto sharding pspec"))
-    if in_cpspec != gda_cpspec:
-      raise ValueError(
-          "GDA sharding does not match the sharding returned by auto spmd "
-          "partitioner. Did you create the GDA with the input sharding "
-          "returned by XLA? If yes, please file a bug. "
-          f"Got GDA spec: {gda_cpspec.user_spec} and "
-          f"auto sharding spec: {in_cpspec.user_spec} for GDA: {arg}")
-
-
 def _get_in_positional_semantics(global_avals: bool, arg) -> maps._PositionalSemantics:
   if isinstance(arg, GDA):
     return maps._PositionalSemantics.GLOBAL
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 5ff1be151ed7..472cba821f87 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -2361,15 +2361,17 @@ def _get_array_mapping_from_executable(
 
 class MeshExecutable(stages.Executable):
   __slots__ = ['xla_executable', 'unsafe_call', '_input_avals',
-               '_auto_spmd_lowering']
+               '_in_axes', '_out_axes', '_auto_spmd_lowering']
 
   def __init__(self, xla_executable, unsafe_call, input_avals,
-               auto_spmd_lowering):
+               in_axes, out_axes, auto_spmd_lowering):
     self.xla_executable = xla_executable
     self.unsafe_call = unsafe_call
     # input_avals is a list of global and local avals. Aval is global if input
     # is a GDA else local.
     self._input_avals = input_avals
+    self._in_axes = in_axes
+    self._out_axes = out_axes
     self._auto_spmd_lowering = auto_spmd_lowering
 
   @staticmethod
@@ -2429,7 +2431,8 @@ def from_hlo(name: str,
     handle_args = InputsHandler(xla_executable.local_devices(), input_specs,
                                 input_indices)
     unsafe_call = ExecuteReplicated(xla_executable, backend, handle_args, handle_outs)
-    return MeshExecutable(xla_executable, unsafe_call, input_avals, auto_spmd_lowering)
+    return MeshExecutable(xla_executable, unsafe_call, input_avals,
+                          in_axes, out_axes, auto_spmd_lowering)
 
   # -- stages.Executable protocol
 
@@ -2440,13 +2443,28 @@ def hlo_modules(self):
     return self.xla_executable.hlo_modules()
 
   def call(self, *args):
-    # TODO(yashkatariya): Add a AOT lowering test where GDA is an input.
     arg_avals = map(xla.abstractify, args)
     ref_avals = self._input_avals
     dispatch.check_arg_avals_for_call(ref_avals, arg_avals)
+    # Check that the GDA sharding matches the input sharding.
+    _check_gda_xla_sharding_match(args, self._in_axes)
     return self.unsafe_call(*args)
 
 
+def _check_gda_xla_sharding_match(args, in_array_mappings):
+  from jax.experimental.global_device_array import GlobalDeviceArray, _get_array_mapping
+
+  for arg, inp_array_mapping in safe_zip(args, in_array_mappings):
+    if not isinstance(arg, GlobalDeviceArray):
+      continue
+    gda_array_mapping = _get_array_mapping(arg.mesh_axes)
+    if inp_array_mapping != gda_array_mapping:
+      raise ValueError(
+          "GDA sharding does not match the input sharding. "
+          f"Got GDA spec: {array_mapping_to_axis_resources(gda_array_mapping)} and "
+          f"input sharding spec: {array_mapping_to_axis_resources(inp_array_mapping)} for GDA: {arg}")
+
+
 _forbidden_primitives = {
   'xla_pmap': 'pmap',
   'sharded_call': 'sharded_jit',
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index f95804fceda6..71d5d32d56d1 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -1171,6 +1171,18 @@ def test_no_recompilation_due_to_fully_replicated_and_gda_inputs(self):
     self.assertEqual(before_cache.hits + 1, after_cache.hits)
     self.assertEqual(before_cache.misses, after_cache.misses)
 
+  def test_pjit_gda_aot_sharding_mismatch(self):
+    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    global_input_shape = (8, 2)
+    input_gda = create_gda(global_input_shape, global_mesh, P('x', 'y'))
+
+    with global_mesh:
+      f = pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=P('x'))
+      compiled = f.lower(jax.ShapedArray(global_input_shape, jnp.float32)).compile()
+      with self.assertRaisesRegex(
+          ValueError, "GDA sharding does not match the input sharding."):
+        compiled(input_gda)
+
 
 class AutoShardingPjitTest(jtu.JaxTestCase):
 
@@ -1253,9 +1265,11 @@ def test_xla_gda_sharding_mismatch(self):
     gda = create_gda(global_input_shape, global_mesh, different_pspec,
                      global_input_data)
     with self.assertRaisesRegex(
-        ValueError,
-        "GDA sharding does not match the sharding returned by auto spmd "
-        "partitioner"):
+        ValueError, "GDA sharding does not match the input sharding."):
+      sharding_info.compiled(gda)
+
+    with self.assertRaisesRegex(
+        ValueError, "GDA sharding does not match the input sharding."):
       f(gda)
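
For reference, a minimal repro sketch of how the new check in
`MeshExecutable.call` surfaces to users, mirroring
`test_pjit_gda_aot_sharding_mismatch` above. This is a sketch, not part of
the patch: it assumes 8 available devices arranged in a 4x2 mesh and uses the
standard `Mesh` and `GlobalDeviceArray.from_callback` APIs.

  import numpy as np
  import jax
  import jax.numpy as jnp
  from jax.experimental import PartitionSpec as P
  from jax.experimental.maps import Mesh
  from jax.experimental.pjit import pjit
  from jax.experimental.global_device_array import GlobalDeviceArray

  # Build a 4x2 device mesh (assumes 8 available devices).
  mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))

  global_shape = (8, 2)
  global_data = np.arange(
      np.prod(global_shape), dtype=np.float32).reshape(global_shape)

  # A GDA sharded over both mesh axes: P('x', 'y'). The callback returns the
  # shard of the global array for each device's index.
  gda = GlobalDeviceArray.from_callback(
      global_shape, mesh, P('x', 'y'), lambda idx: global_data[idx])

  with mesh:
    # The executable is compiled ahead of time for a *different* input
    # sharding: P('x').
    f = pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=P('x'))
    compiled = f.lower(jax.ShapedArray(global_shape, jnp.float32)).compile()
    # MeshExecutable.call compares the GDA's mesh_axes against the
    # executable's in_axes and raises:
    #   ValueError: GDA sharding does not match the input sharding. ...
    compiled(gda)

Creating the GDA with P('x') instead makes the array mappings equal, so the
check passes and the call dispatches normally.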