From 28741b8e0da393b86e94eb5c2172cc331c76fb5c Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Thu, 15 Sep 2022 13:26:57 -0700
Subject: [PATCH] Some miscellaneous changes to make tests pass when jax.Array
 is enabled by default.

1. Add `device_buffer` and `device_buffers` properties to `Array` as a
   backwards-compatible change for DA and SDA (see the first sketch below).

2. Support `PartitionSpec`s as inputs to `in_axis_resources` and
   `out_axis_resources` when `jax_array` is enabled. This is a
   backwards-compatible change, since all current user code passes
   `PartitionSpec`s; a `MeshPspecSharding` is created internally (a second
   sketch follows the diffstat).

3. Some test changes to make them pass.
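For point 1, a minimal sketch of how the backwards-compatible accessors
behave (hypothetical usage; assumes the `jax_array` flag is enabled so that
jit returns an `Array`, and a single-device execution so the Array has one
shard):

    import jax
    import jax.numpy as jnp

    x = jax.jit(lambda a: a + 1)(jnp.arange(8.))
    buf = x.device_buffer    # DA-style access; valid while there is one shard
    bufs = x.device_buffers  # SDA-style access; all underlying buffers
    assert bufs[0] is buf

`device_buffer` raises a ValueError once there is more than one shard,
pointing callers at `device_buffers` instead.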
PiperOrigin-RevId: 474642889
---
 jax/_src/lax/lax.py       |  3 +-
 jax/_src/test_util.py     |  6 +--
 jax/experimental/array.py | 19 ++++++++
 jax/experimental/maps.py  |  2 +
 jax/experimental/pjit.py  | 94 +++++++++++++++++++++------------------
 tests/BUILD               |  1 +
 tests/lax_test.py         |  6 ++-
 tests/pjit_test.py        | 45 +++++++++----------
 tests/random_test.py      |  5 +--
 9 files changed, 102 insertions(+), 79 deletions(-)
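For point 2, a second minimal sketch: plain `PartitionSpec`s passed straight
to pjit under `jax_array` (the 4x2 mesh, the axis names, and the 8-device
count are illustrative assumptions; pjit builds the `MeshPspecSharding`
internally, and a mesh context manager is still required when specs rather
than shardings are given):

    import numpy as np
    import jax
    from jax.experimental import maps
    from jax.experimental.pjit import pjit, PartitionSpec as P

    devices = np.array(jax.devices()[:8]).reshape(4, 2)
    with maps.Mesh(devices, ('x', 'y')):  # no mesh -> the new RuntimeError below
      f = pjit(lambda a: a * 2,
               in_axis_resources=P('x', 'y'),
               out_axis_resources=P('x', 'y'))
      out = f(np.arange(8.).reshape(4, 2))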
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 30ffce41516c..1a7f9b238451 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -1331,7 +1331,8 @@ def full_like(x: Array, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None
   # (so it works in staged-out code as well as 'eager' code). Related to
   # equi-sharding.
   if (config.jax_array and hasattr(x, 'sharding') and
-      not dispatch.is_single_device_sharding(x.sharding)):
+      not dispatch.is_single_device_sharding(x.sharding) and
+      not isinstance(x.sharding, sharding.PmapSharding)):
     return array.make_array_from_callback(
         fill_shape, x.sharding, lambda idx: val[idx])  # type: ignore[arg-type]
   return val
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 44c7bb42ab07..491e456952a2 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -899,11 +899,11 @@ class BufferDonationTestCase(JaxTestCase):
   assertNotDeleted = lambda self, x: self._assertDeleted(x, False)
 
   def _assertDeleted(self, x, deleted):
-    if hasattr(x, "device_buffer"):
-      self.assertEqual(x.device_buffer.is_deleted(), deleted)
-    elif hasattr(x, "_arrays"):
+    if hasattr(x, "_arrays"):
       for buffer in x._arrays:
         self.assertEqual(buffer.is_deleted(), deleted)
+    elif hasattr(x, "device_buffer"):
+      self.assertEqual(x.device_buffer.is_deleted(), deleted)
     else:
       for buffer in x.device_buffers:
         self.assertEqual(buffer.is_deleted(), deleted)
diff --git a/jax/experimental/array.py b/jax/experimental/array.py
index ac87a020c92d..3db06e076de3 100644
--- a/jax/experimental/array.py
+++ b/jax/experimental/array.py
@@ -371,6 +371,23 @@ def devices(self) -> List[Device]:
     self._check_if_deleted()
     return list(self.sharding.device_set)
 
+  # TODO(https://github.com/google/jax/issues/12380): Remove this when DA is
+  # deleted.
+  @property
+  def device_buffer(self) -> DeviceArray:
+    self._check_if_deleted()
+    if len(self._arrays) == 1:
+      return self._arrays[0]
+    raise ValueError('Length of buffers is greater than 1. Please use '
+                     '`.device_buffers` instead.')
+
+  # TODO(https://github.com/google/jax/issues/12380): Remove this when SDA is
+  # deleted.
+  @property
+  def device_buffers(self) -> Sequence[DeviceArray]:
+    self._check_if_deleted()
+    return self._arrays
+
   @pxla.maybe_cached_property
   def addressable_shards(self) -> Sequence[Shard]:
     self._check_if_deleted()
@@ -433,6 +450,7 @@ def _value(self) -> np.ndarray:
     if self._npy_value is None:
       if self.is_fully_replicated():
         self._npy_value = np.asarray(self._arrays[0])  # type: ignore
+        self._npy_value.flags.writeable = False
         return cast(np.ndarray, self._npy_value)
 
       if not self.is_fully_addressable():
@@ -454,6 +472,7 @@ def _value(self) -> np.ndarray:
         if not replica_id_exists or s.replica_id == 0:
           npy_value[s.index] = np.asarray(s.data._arrays[0])  # type: ignore  # [union-attr]
       self._npy_value = npy_value  # type: ignore
+      self._npy_value.flags.writeable = False
     # https://docs.python.org/3/library/typing.html#typing.cast
     return cast(np.ndarray, self._npy_value)
 
diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py
index 07ba1214546e..b32fe3d1c3a1 100644
--- a/jax/experimental/maps.py
+++ b/jax/experimental/maps.py
@@ -1817,6 +1817,8 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
   for arg, xmap_array_mapping in safe_zip(args_flat, mesh_in_axes):
     if isinstance(arg, (GlobalDeviceArray, Array)):
       arr_flavor = 'GDA' if isinstance(arg, GlobalDeviceArray) else 'Array'
+      if arr_flavor == 'Array' and not isinstance(arg.sharding, MeshPspecSharding):
+        continue
       mesh = arg.mesh if arr_flavor == 'GDA' else arg.sharding.mesh
       if mesh != resource_env.physical_mesh:
         raise ValueError(f"xmap's mesh and {arr_flavor}'s mesh should be equal. "
diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py
index e042a424ca91..b867109976c7 100644
--- a/jax/experimental/pjit.py
+++ b/jax/experimental/pjit.py
@@ -261,31 +261,10 @@ def pjit(fun: Callable,
   # rather than raising an error. https://github.com/google/jax/issues/2367
   in_axis_resources = tuple(in_axis_resources)
 
-  in_any_auto: bool
-  if not config.jax_array:
-    in_axis_resources, _, _, in_any_auto = _prepare_axis_resources(
-        in_axis_resources, "in_axis_resources")
-    out_axis_resources, _, _, _ = _prepare_axis_resources(
-        out_axis_resources, "out_axis_resources")
-  else:
-    if not _is_unspecified(in_axis_resources):
-      # `pjit.AUTO` is allowed partially in `in_axis_resources` i.e. you can
-      # put sharding instances and `pjit.AUTO` together.
-      if not all(isinstance(s, Sharding) or _is_auto(s)
-                 for s in tree_flatten(in_axis_resources)[0]):
-        raise ValueError('When `config.jax_array` flag is enabled, '
-                         'in_axis_resources should contain instances of '
-                         '`Sharding` or `pjit.AUTO`.')
-
-    # `out_axis_resources` should be instances of `Sharding` if it's not
-    # unspecified. For `AUTO` sharding, it can only be used with
-    # MeshPspecSharding.
-    if not _is_unspecified(out_axis_resources):
-      if not all(isinstance(s, Sharding) or _is_auto(s)
-                 for s in tree_flatten(out_axis_resources)[0]):
-        raise ValueError('When `config.jax_array` flag is enabled, '
-                         'out_axis_resources should contain instances of '
-                         '`Sharding` or `pjit.AUTO`.')
+  in_axis_resources, _, _, in_any_auto = _prepare_axis_resources(
+      in_axis_resources, "in_axis_resources")
+  out_axis_resources, _, _, _ = _prepare_axis_resources(
+      out_axis_resources, "out_axis_resources")
 
   static_argnums = _ensure_index_tuple(static_argnums)
   donate_argnums = _ensure_index_tuple(donate_argnums)
@@ -323,7 +302,10 @@ def infer_params(*args, _global_avals=False, **kwargs):
       donated_invars = (False,) * len(args_flat)
 
     if config.jax_array:
-      in_shardings, out_shardings = in_axis_resources, out_axis_resources
+      in_shardings = tree_map(
+          lambda x: _create_sharding_for_array(pjit_mesh, x), in_axis_resources)
+      out_shardings = tree_map(
+          lambda x: _create_sharding_for_array(pjit_mesh, x), out_axis_resources)
     else:
       in_shardings = tree_map(
           lambda x: _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x),
@@ -424,6 +406,19 @@ def _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x):
   return pxla._create_mesh_pspec_sharding(mesh, x.user_spec, x)
 
 
+def _create_sharding_for_array(mesh, x):
+  if isinstance(x, XLACompatibleSharding) or _is_auto(x) or _is_unspecified(x):
+    return x
+  if mesh.empty:
+    raise RuntimeError("pjit requires a non-empty mesh! Is a mesh defined at "
+                       "the call site? Alternatively, provide a "
+                       "XLACompatibleSharding to pjit and then the "
+                       "mesh context manager is not required.")
+  # A nice user error is raised in _prepare_axis_resources.
+  assert isinstance(x, ParsedPartitionSpec)
+  return _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x)
+
+
 def flatten_axis_resources(what, tree, shardings, tupled_args):
   try:
     return tuple(flatten_axes(what, tree, shardings, tupled_args=tupled_args))
@@ -585,9 +580,6 @@ def pjit_check_aval_sharding(
   for aval, s in zip(flat_avals, shardings):
     if _is_unspecified_or_from_gda_or_auto(s):
       continue
-    if not isinstance(s, XLACompatibleSharding):
-      raise ValueError(f'One of {what_aval} got sharding {s} which is not a '
-                       'subclass of XLACompatibleSharding.')
     global_str = "" if s.is_fully_addressable else " global"
     shape = aval.shape
     try:
@@ -747,14 +739,25 @@ def _prepare_axis_resources(axis_resources,
   # be 1 entry for that since _UNSPECIFIED is a private API.
   _check_all_or_none_unspecified(entries, arg_name)
   any_auto = pxla._check_if_any_auto(entries)
-  entries = [
-      (entry if _is_unspecified_or_from_gda_or_auto(entry)
-       else ParsedPartitionSpec.from_user_input(
+  new_entries = []
+  for entry in entries:
+    if _is_unspecified_or_from_gda_or_auto(entry):
+      if config.jax_array and _is_from_gda(entry):
+        raise ValueError('`FROM_GDA` cannot be set when config.jax_array is '
+                         'enabled. Leave in_axis_resources empty or populate '
+                         'it with shardings.')
+      new_entries.append(entry)
+    elif isinstance(entry, Sharding):
+      if not isinstance(entry, XLACompatibleSharding):
+        raise ValueError(f'One of {what} got sharding {entry} which is not a '
+                         'subclass of XLACompatibleSharding.')
+      new_entries.append(entry)
+    else:
+      new_entries.append(ParsedPartitionSpec.from_user_input(
           entry, what, allow_unconstrained_dims=allow_unconstrained_dims))
-      for entry in entries
-  ]
-  _check_unique_resources(entries, arg_name)
-  return tree_unflatten(treedef, entries), entries, treedef, any_auto
+
+  _check_unique_resources(new_entries, arg_name)
+  return tree_unflatten(treedef, new_entries), new_entries, treedef, any_auto
 
 
 def _check_resources_mismatch(in_axis_resources_flat, is_gda):
@@ -765,7 +768,9 @@
 def _check_unique_resources(axis_resources, arg_name):
   for arg_axis_resources in axis_resources:
     if not arg_axis_resources: continue
-    if _is_unspecified_or_from_gda_or_auto(arg_axis_resources): continue
+    if (_is_unspecified_or_from_gda_or_auto(arg_axis_resources) or
+        isinstance(arg_axis_resources, XLACompatibleSharding)):
+      continue
     constrained_dims = [d for d in arg_axis_resources if d is not None]
     resource_counts = Counter(it.chain.from_iterable(constrained_dims))
     if not resource_counts: continue
@@ -806,7 +811,8 @@ def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
       if _is_unspecified(arg_s):
         resolved_in_shardings.append(OpShardingSharding.get_replicated(da))
       else:
-        resolved_in_shardings.append(to_op_sharding_sharding(arg_s, arg.ndim))
+        resolved_in_shardings.append(to_op_sharding_sharding(
+            cast(XLACompatibleSharding, arg_s), arg.ndim))
     else:
       if not _is_unspecified(arg_s):
         if committed and not pxla.are_op_shardings_equal(
@@ -1290,16 +1296,16 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
 
 def with_sharding_constraint(x, axis_resources):
   x_flat, tree = tree_flatten(x)
-  if not config.jax_array:
-    axis_resources, _, _, _ = _prepare_axis_resources(
-        axis_resources, "axis_resources", allow_unconstrained_dims=True)
+  axis_resources, _, _, _ = _prepare_axis_resources(
+      axis_resources, "axis_resources", allow_unconstrained_dims=True)
   axis_resources_flat = tuple(
       flatten_axes("with_sharding_constraint sharding", tree, axis_resources))
   resource_env = pxla.thread_resources.env
   mesh = resource_env.physical_mesh
   if config.jax_array:
-    sharding_flat = axis_resources_flat
+    sharding_flat = [_create_sharding_for_array(mesh, a)
+                     for a in axis_resources_flat]
     unconstrained_dims = [
         get_unconstrained_dims(s) if isinstance(s, MeshPspecSharding) else {}
         for s in sharding_flat
     ]
@@ -1410,11 +1416,11 @@ def get_array_mapping(
       if axes is not None for axis in axes)
 
 
-def to_op_sharding_sharding(s, ndim):
+def to_op_sharding_sharding(s: XLACompatibleSharding, ndim: int) -> OpShardingSharding:
   if isinstance(s, OpShardingSharding):
     return s
   op_sharding_sharding = OpShardingSharding(
-      s._device_assignment, s._to_xla_op_sharding(ndim))  # type: ignore
+      s._device_assignment, s._to_xla_op_sharding(ndim))
   op_sharding_sharding._original_sharding = s
   return op_sharding_sharding
 
diff --git a/tests/BUILD b/tests/BUILD
index c6347dcd9414..25b5913343d3 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -416,6 +416,7 @@ jax_test(
 jax_test(
     name = "lax_test",
     srcs = ["lax_test.py"],
+    pjrt_c_api_bypass = True,
     shard_count = {
         "cpu": 40,
         "gpu": 40,
diff --git a/tests/lax_test.py b/tests/lax_test.py
index 001fdef2a16c..4782551dbd91 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -2938,9 +2938,11 @@ def testConvertElementTypeAvoidsCopies(self, dtype_in, dtype_out):
     x_buf = x.device_buffer
     y_buf = y.device_buffer
     if np.dtype(dtype_in) == np.dtype(dtype_out):
-      self.assertIs(x_buf, y_buf)
+      self.assertEqual(x_buf.unsafe_buffer_pointer(),
+                       y_buf.unsafe_buffer_pointer())
     else:
-      self.assertIsNot(x_buf, y_buf)
+      self.assertNotEqual(x_buf.unsafe_buffer_pointer(),
+                          y_buf.unsafe_buffer_pointer())
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_fn={}_indexdtype={}"
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 2acc0b42e7e1..9562efbec952 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -1816,27 +1816,6 @@ def test_in_axis_resources_same_as_array_sharding(self):
           in_axis_resources=MeshPspecSharding(global_mesh, P('x' ,'y')))(input_array)
       self.assertIsInstance(out, array.Array)
 
-  def test_in_axis_resources_error(self):
-    mesh = jtu.create_global_mesh((2,), ('x'))
-    with jax_array(True):
-      with self.assertRaisesRegex(
-          ValueError,
-          ('When `config.jax_array` flag is enabled, '
-           'in_axis_resources should contain instances of `Sharding` '
-           'or `pjit.AUTO`.')):
-        pjit(lambda x: x,
-             in_axis_resources=(MeshPspecSharding(mesh, P('x')),
-                                pjit_lib._UNSPECIFIED))
-
-  def test_out_axis_resources_error(self):
-    with jax_array(True):
-      with self.assertRaisesRegex(
-          ValueError,
-          ('When `config.jax_array` flag is enabled, '
-           'out_axis_resources should contain instances of `Sharding` '
-           'or `pjit.AUTO`.')):
-        pjit(lambda x: x, out_axis_resources=P('x'))
-
   def test_no_input_output(self):
     with jax_array(True):
       def f():
@@ -2118,16 +2097,32 @@ def test_not_xlacompatible_sharding_error(self):
 
     with self.assertRaisesRegex(
         ValueError,
-        'One of pjit arguments got sharding.*which is not a subclass of '
-        'XLACompatibleSharding.'):
+        'One of in_axis_resources leaf specifications got sharding.*which is '
+        'not a subclass of XLACompatibleSharding.'):
       pjit(lambda x: x, in_axis_resources=ts)(arr)
 
     with self.assertRaisesRegex(
         ValueError,
-        'One of pjit outputs got sharding.*which is not a subclass of '
-        'XLACompatibleSharding.'):
+        'One of out_axis_resources leaf specifications got sharding.*which is '
+        'not a subclass of XLACompatibleSharding.'):
       pjit(lambda x: x, out_axis_resources=ts)(arr)
 
+  @jax_array(True)
+  def test_array_enabled_non_empty_mesh_with_pspec(self):
+    arr = jnp.array([1, 2, 3])
+    with self.assertRaisesRegex(
+        RuntimeError,
+        "pjit requires a non-empty mesh!.*Alternatively, provide a "
+        "XLACompatibleSharding to pjit and then the mesh context manager is "
+        "not required."):
+      pjit(lambda x: x, in_axis_resources=P('x'))(arr)
+
+    with self.assertRaisesRegex(
+        TypeError,
+        "in_axis_resources leaf specifications are expected to be PartitionSpec "
+        "instances or None, but got x"):
+      pjit(lambda x: x, in_axis_resources='x')
+
 
 class TempSharding(Sharding):
diff --git a/tests/random_test.py b/tests/random_test.py
index 9f2f8f0e76c9..2517ca093c1f 100644
--- a/tests/random_test.py
+++ b/tests/random_test.py
@@ -1451,10 +1451,7 @@ def test_random_split_doesnt_device_put_during_tracing(self):
     key = self.seed_prng(1).block_until_ready()
     with jtu.count_device_put() as count:
       jax.jit(random.split)(key)
-    if config.jax_array:
-      self.assertEqual(count[0], 0)
-    else:
-      self.assertEqual(count[0], 1)  # 1 for the argument device_put
+    self.assertLessEqual(count[0], 1)  # 1 for the argument device_put
 
   @parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={dtype}", "dtype": dtype}