From 0bc8f8abeb72870ea2fbf2b5d2a1279e993731d2 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Mon, 11 Jul 2022 16:26:39 -0700
Subject: [PATCH] * Check if the device assignment is the same across input
 and output shardings.

* Allow mixed inputs only if the sharding matches what is specified in
  in_axis_resources.

PiperOrigin-RevId: 460326054
---
 jax/experimental/pjit.py | 91 +++++++++++++++++++++-------------------
 tests/pjit_test.py       | 69 ++++++++++++++++++++++++++++--
 2 files changed, 115 insertions(+), 45 deletions(-)

diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py
index 8db381a21a13..9f4cbc5d928d 100644
--- a/jax/experimental/pjit.py
+++ b/jax/experimental/pjit.py
@@ -314,11 +314,8 @@ def infer_params(*args, _global_avals=False, **kwargs):
     donated_invars = (False,) * len(args_flat)
 
     if config.jax_array:
-      in_shardings = _get_and_check_in_shardings(
-          dyn_args, in_axis_resources, pjit_mesh, in_tree)
-      out_shardings = out_axis_resources
-      _check_array_device_assignment(
-          pjit_mesh, tuple(tree_flatten(out_shardings)[0]))
+      in_shardings, out_shardings = _get_and_check_in_and_out_shardings(
+          args_flat, in_axis_resources, out_axis_resources, pjit_mesh, in_tree)
     else:
       in_shardings = tree_map(lambda x: _create_mesh_pspec_sharding(pjit_mesh, x),
                               in_axis_resources)
@@ -406,34 +403,32 @@ def hashable_pytree(pytree):
       closure=(treedef, vals))
 
 
-def _get_and_check_in_shardings(args, pjit_in_shardings, pjit_mesh, in_tree):
-  try:
-    # tree_map over `args` to preserve the pytree structure of args.
-    arg_in_shardings = tree_map(lambda x: x.sharding, args)
-  except AttributeError:
-    arg_in_shardings = None
-
-  arg_in_shardings_flat = tuple(tree_flatten(arg_in_shardings)[0])
+def _get_and_check_in_and_out_shardings(args_flat, pjit_in_shardings, out_shardings,
+                                        pjit_mesh, in_tree):
+  arg_in_shardings_flat = tuple(a.sharding if hasattr(a, 'sharding') else _UNSPECIFIED
+                                for a in args_flat)
+  arg_ndims = tuple(a.ndim for a in args_flat)
 
   if _is_unspecified(pjit_in_shardings):
-    if arg_in_shardings is None:
-      raise ValueError('Please specify sharding either on the args or on pjit.')
-    else:
-      # This function is cached.
-      _check_array_device_assignment(pjit_mesh, arg_in_shardings_flat)
-      return arg_in_shardings
+    # If pjit_in_shardings is unspecified, then arg_in_shardings cannot
+    # contain unspecified entries either.
+    for a in arg_in_shardings_flat:
+      if _is_unspecified(a):
+        raise ValueError('Please specify sharding either on the arg or on '
+                         f'pjit. Found sharding {a} which is invalid.')
+    in_shardings_flat = arg_in_shardings_flat
  else:
-    if arg_in_shardings is None:
-      _check_array_device_assignment(
-          pjit_mesh, tuple(tree_flatten(pjit_in_shardings)[0]))
-      return pjit_in_shardings
-    else:
-      # This function is cached.
-      _check_pjit_arg_shardings(
-          hashable_pytree(pjit_in_shardings), arg_in_shardings_flat, in_tree)
-      return arg_in_shardings
+    # This function is cached.
+    in_shardings_flat = _get_and_check_pjit_arg_shardings(
+        hashable_pytree(pjit_in_shardings), arg_in_shardings_flat, arg_ndims,
+        in_tree)
 
-  assert False, "Please open a bug report!"  # This should be unreachable.
+  out_shardings_flat = tuple(tree_flatten(out_shardings)[0])
+  # Check if the device assignment is the same across inputs and outputs.
+  # This function is cached.
+  _check_array_device_assignment(pjit_mesh, in_shardings_flat + out_shardings_flat)
+
+  return tree_unflatten(in_tree, in_shardings_flat), out_shardings
 
 
 def _create_mesh_pspec_sharding(mesh, x):
@@ -1325,7 +1320,8 @@ def _check_array_device_assignment(pjit_mesh, shardings):
       # If mesh is empty, then check if all devices across shardings are
       # equal
       if first_device_assignment != arr_device_assignment:
-        raise ValueError("Devices of all `Array` inputs should be the same. "
+        raise ValueError("Devices of all `Array` inputs and outputs should be "
+                         "the same. "
                          f"Got array devices: {first_device_assignment},\n "
                          f"another array devices: {arr_device_assignment}")
     else:
@@ -1337,21 +1333,32 @@
                            f"Array devices: {arr_device_assignment}")
 
 @cache()
-def _check_pjit_arg_shardings(pjit_in_shardings, arg_in_shardings_flat,
-                              in_tree):
+def _get_and_check_pjit_arg_shardings(pjit_in_shardings, arg_in_shardings_flat,
+                                      arg_ndims, in_tree):
   pjit_in_shardings_flat = flatten_axis_resources(
       "pjit in_shardings", in_tree, pjit_in_shardings(), tupled_args=True)
 
-  if pxla._check_if_any_auto(pjit_in_shardings_flat):
-    raise ValueError('Passing sharding on pjit and on args while using the '
-                     'auto spmd partitioner is not allowed. Please call the '
-                     'compiled object on the inputs.')
-
-  for p, a in safe_zip(pjit_in_shardings_flat, arg_in_shardings_flat):
-    if p.normalize() != a.normalize():
-      raise ValueError('Sharding passed to pjit does not match the sharding '
-                       'on the respective arg. '
-                       f'Got pjit sharding: {p},\narg sharding: {a}')
+  out = []
+  for pjit_sharding, arg_sharding, ndim in safe_zip(
+      pjit_in_shardings_flat, arg_in_shardings_flat, arg_ndims):
+    # If the sharding of the arg is not known, replace it with the sharding on
+    # pjit.
+    if _is_unspecified(arg_sharding):
+      out.append(pjit_sharding)
+    elif _is_auto(pjit_sharding):
+      raise ValueError('Passing sharding on pjit and on args while using the '
+                       'auto spmd partitioner is not allowed. Please call the '
+                       'compiled object on the inputs.')
+    else:
+      if pjit_sharding._to_xla_op_sharding(ndim) != arg_sharding._to_xla_op_sharding(ndim):
+        raise ValueError('Sharding passed to pjit does not match the sharding '
+                         'on the respective arg. '
+                         f'Got pjit sharding: {pjit_sharding},\n'
+                         f'arg sharding: {arg_sharding}')
+      out.append(pjit_sharding)
+
+  assert not any(_is_unspecified(o) for o in out)
+  return tuple(out)
 
 
 def _maybe_check_pjit_gda_mesh(args, mesh):
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index a71d5316142a..bf359763606a 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -1451,10 +1451,28 @@ def test_non_array_input_error(self):
                out_axis_resources=MeshPspecSharding(
                    global_mesh, P('x', 'y')))
       with self.assertRaisesRegex(
-          ValueError, ('Please specify sharding either on the args or on '
-                       'pjit.')):
+          ValueError, 'Please specify sharding either on the arg or on pjit'):
         f(input_data)
 
+  def test_numpy_array_input(self):
+    input_shape = (8, 2)
+    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    input_data = np.arange(
+        prod(input_shape), dtype=np.float32).reshape(input_shape)
+    with jax._src.config.jax_array(True):
+      with global_mesh:
+        f = pjit(lambda x: x,
+                 in_axis_resources=MeshPspecSharding(
+                     global_mesh, P(None)),
+                 out_axis_resources=MeshPspecSharding(
+                     global_mesh, P('x', 'y')))
+        out = f(input_data)
+        self.assertIsInstance(out, array.Array)
+        for s in out.addressable_shards:
+          self.assertEqual(s.data.shape, (2, 1))
+          self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
+        self.assertArraysEqual(out._value, input_data)
+
   def test_unspecified_out_axis_resources(self):
     global_input_shape = (8, 2)
     global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
@@ -1561,13 +1579,25 @@ def test_in_axis_resources_same_as_array_sharding(self):
         in_axis_resources=MeshPspecSharding(global_mesh, P('x' ,'y')))(input_array)
     self.assertIsInstance(out, array.Array)
 
+  def test_in_axis_resources_error(self):
+    mesh = jtu.create_global_mesh((2,), ('x',))
+    with jax._src.config.jax_array(True):
+      with self.assertRaisesRegex(
+          ValueError,
+          ('When `config.jax_array` flag is enabled, '
+           'in_axis_resources should contain instances of `Sharding` '
+           'or `pjit.AUTO`.')):
+        pjit(lambda x: x,
+             in_axis_resources=(MeshPspecSharding(mesh, P('x')),
+                                pjit_lib._UNSPECIFIED))
+
   def test_out_axis_resources_error(self):
     with jax._src.config.jax_array(True):
       with self.assertRaisesRegex(
           ValueError,
           ('When `config.jax_array` flag is enabled, '
-           'out_axis_resources should contain instances of `Sharding`.')):
+           'out_axis_resources should contain instances of `Sharding` '
+           'or `pjit.AUTO`.')):
         pjit(lambda x: x, out_axis_resources=P('x'))
 
   def test_no_input_output(self):
@@ -1630,6 +1660,39 @@ def test_array_device_assignment_mismatch_out_shardings(self):
              out_axis_resources=(MeshPspecSharding(m1, spec),
                                  MeshPspecSharding(m2, spec)))(a1, a1)
 
+  def test_array_device_assignment_mismatch_in_and_out_shardings(self):
+    input_shape = (8, 2)
+    m1 = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    m2 = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    spec = P('x', 'y')
+
+    a1, _ = create_array(input_shape, m2, spec)
+
+    with jax._src.config.jax_array(True):
+      with m1:
+        with self.assertRaisesRegex(
+            ValueError, "Pjit's devices and Array's devices should be equal"):
+          pjit(lambda x, y: (x, y),
+               in_axis_resources=MeshPspecSharding(m2, spec),
+               out_axis_resources=MeshPspecSharding(m1, spec))(a1, a1)
+
+  def test_mixed_inputs(self):
+    input_shape = (8, 2)
+    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    spec = P('x', 'y')
+
+    a1, input_data = create_array(input_shape, global_mesh, spec)
+
+    with jax._src.config.jax_array(True):
+      with global_mesh:
+        f = pjit(lambda x, y: (x, y),
+                 in_axis_resources=MeshPspecSharding(global_mesh, P(None)))
+        with self.assertRaisesRegex(
+            ValueError,
+            ('Sharding passed to pjit does not match the sharding on the '
+             'respective arg')):
+          f(input_data, a1)
+
 
 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
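
Not part of the diff: a minimal sketch of the two behaviors this patch
enforces, written against the mid-2022 jax.experimental APIs that the tests
above exercise (pjit, MeshPspecSharding, the jax_array config flag). The 4x2
mesh, the identity functions, and all variable names are illustrative
assumptions, and a runtime with 8 devices is assumed.

import numpy as np

import jax
from jax.experimental import PartitionSpec as P
from jax.experimental import maps
from jax.experimental.pjit import pjit
from jax.experimental.sharding import MeshPspecSharding

# Assumption: 8 devices, arranged into the same 4x2 mesh the tests use.
mesh = maps.Mesh(np.asarray(jax.devices()).reshape(4, 2), ('x', 'y'))
inp = np.arange(16, dtype=np.float32).reshape(8, 2)

with jax._src.config.jax_array(True):  # flag gating this path, as in the tests
  with mesh:
    # A numpy input carries no sharding, so pjit takes the sharding from
    # in_axis_resources. The device assignments of all input and output
    # shardings are compared in one cached call before compilation.
    f = pjit(lambda x: x,
             in_axis_resources=MeshPspecSharding(mesh, P(None)),
             out_axis_resources=MeshPspecSharding(mesh, P('x', 'y')))
    out = f(inp)  # a sharded `Array` laid out as P('x', 'y')

    # Mixed inputs (numpy + `Array`) are allowed only if each `Array`
    # argument's sharding matches in_axis_resources. `out` is sharded as
    # P('x', 'y') while in_axis_resources says P(None), so this raises
    # "Sharding passed to pjit does not match the sharding on the
    # respective arg."
    g = pjit(lambda x, y: (x, y),
             in_axis_resources=MeshPspecSharding(mesh, P(None)))
    g(inp, out)  # ValueError

If neither the argument nor in_axis_resources carries a sharding (a bare
numpy array with in_axis_resources left unspecified), pjit now fails up
front with "Please specify sharding either on the arg or on pjit".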