From fc7a71dc89eeb943004d02d95ddfe28b071fcc7c Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Mon, 29 Aug 2022 14:49:17 -0700
Subject: [PATCH] Remove `device_replica_id_map` from the Sharding interface;
 the standalone function is more than enough.

The major use case of this is checkpointing, which accesses
`addressable_shards`; each shard's `replica_id` now calls the standalone
function, so that path keeps working.

PiperOrigin-RevId: 470820443
---
 jax/experimental/array.py    | 13 +++-------
 jax/experimental/sharding.py | 28 ++++++++-------------
 tests/array_test.py          | 49 ++++++++++++++++++++++++++++++++++++
 3 files changed, 63 insertions(+), 27 deletions(-)

diff --git a/jax/experimental/array.py b/jax/experimental/array.py
index fc788f35d0fa..1008a43e5353 100644
--- a/jax/experimental/array.py
+++ b/jax/experimental/array.py
@@ -15,6 +15,7 @@
 from __future__ import annotations
 
 import operator as op
+import functools
 import numpy as np
 from typing import Sequence, Tuple, Callable, Union, Optional, cast, List
 
@@ -31,8 +32,8 @@
 from jax._src.api import device_put
 from jax._src.numpy.ndarray import ndarray
 from jax.interpreters import pxla, xla, mlir
-from jax.experimental.sharding import (Sharding, SingleDeviceSharding,
-                                       XLACompatibleSharding)
+from jax.experimental.sharding import (
+    Sharding, SingleDeviceSharding, XLACompatibleSharding, device_replica_id_map)
 
 Shape = Tuple[int, ...]
 Device = xc.Device
@@ -81,13 +82,7 @@ def index(self) -> Index:
 
   @property
   def replica_id(self) -> int:
-    try:
-      device_replica_id_fn = self._sharding.device_replica_id_map  # pytype: disable=attribute-error
-    except AttributeError:
-      raise ValueError('Cannot calculate replica ids from sharding: '
-                       f'{self._sharding}. Please create a device to replica id '
-                       'mapping for your sharding.') from None
-    return device_replica_id_fn(self._global_shape)[self.device]
+    return device_replica_id_map(self._sharding, self._global_shape)[self.device]
 
 
 def _reconstruct_array(fun, args, arr_state, aval_state):
diff --git a/jax/experimental/sharding.py b/jax/experimental/sharding.py
index 59700b4f8087..1ac2343df66a 100644
--- a/jax/experimental/sharding.py
+++ b/jax/experimental/sharding.py
@@ -78,10 +78,6 @@ class XLACompatibleSharding(Sharding):
   def _device_assignment(self) -> XLADeviceAssignment:
     raise NotImplementedError('Subclasses should implement this method.')
 
-  @abc.abstractmethod
-  def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
-    raise NotImplementedError('Subclasses should implement this method.')
-
   @abc.abstractmethod
   def _to_xla_op_sharding(self, num_dimensions: int) -> Optional[xc.OpSharding]:
     raise NotImplementedError('Subclasses should implement this method.')
@@ -133,10 +129,18 @@ def _hashed_index(x) -> int:
 
 
 @functools.lru_cache(maxsize=4096)
-def _device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]:
+def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]:
+  try:
+    device_indices_map_fn = sharding.devices_indices_map
+  except AttributeError:
+    raise ValueError(
+        f'Cannot calculate replica ids from sharding: {sharding}. Please '
+        'create a device to index mapping for your sharding from which replica '
+        'ids will be calculated.') from None
+
   index_to_replica: Dict[int, int] = Counter()
   out = {}
-  for device, index in sharding.devices_indices_map(global_shape).items():
+  for device, index in device_indices_map_fn(global_shape).items():
     h_index = _hashed_index(index)
     replica_id = index_to_replica[h_index]
     index_to_replica[h_index] += 1
@@ -208,9 +212,6 @@ def devices_indices_map(
     # `get_shard_indices` is cached.
     return global_device_array.get_shard_indices(global_shape, self.mesh,
                                                  self.spec)
 
-  def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
-    return _device_replica_id_map(self, global_shape)
-
   @pxla.maybe_cached_property
   def _device_assignment(self) -> XLADeviceAssignment:
     return list(self.mesh.devices.flat)
@@ -270,9 +271,6 @@ def devices_indices_map(
       self, global_shape: Shape) -> Mapping[Device, Index]:
     return {self._device: (slice(None),) * len(global_shape)}
 
-  def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
-    return {self._device: 0}
-
   @property
   def _device_assignment(self) -> XLADeviceAssignment:
     return [self._device]
@@ -298,9 +296,6 @@ def devices_indices_map(
     indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
     return {d: i for d, i in safe_zip(self.devices.flat, indices)}  # type: ignore
 
-  def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
-    return _device_replica_id_map(self, global_shape)
-
   @pxla.maybe_cached_property
   def _device_assignment(self) -> XLADeviceAssignment:
     return list(self.devices.flat)
@@ -374,9 +369,6 @@ def devices_indices_map(
                                       len(self._devices))
     return dict(safe_zip(self._devices, indices))
 
-  def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
-    return _device_replica_id_map(self, global_shape)
-
   @property
   def _device_assignment(self) -> XLADeviceAssignment:
     return list(self._devices)
diff --git a/tests/array_test.py b/tests/array_test.py
index 86cfe11885fd..eb643cafb701 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -61,6 +61,54 @@ def test_jax_array_value(self, mesh_axes):
       self.assertArraysEqual(arr._value, global_data)
       self.assertArraysEqual(arr._npy_value, global_data)
 
+  @parameterized.named_parameters(
+      ("mesh_x_y", P("x", "y"),
+       # There are more slices, but for convenience we check only two;
+       # the indices + shard_shape + replica_id should be unique enough.
+       ((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
+       (2, 1),
+       [0, 0, 0, 0, 0, 0, 0, 0], False),
+      ("mesh_x", P("x"),
+       ((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
+       (2, 2),
+       [0, 1, 0, 1, 0, 1, 0, 1], False),
+      ("mesh_y", P("y"),
+       ((slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
+       (4, 2),
+       [0, 0, 1, 1, 2, 2, 3, 3], False),
+      ("mesh_none_y", P(None, "y"),
+       ((slice(None), slice(0, 1)), (slice(None), slice(1, 2))),
+       (8, 1),
+       [0, 0, 1, 1, 2, 2, 3, 3], False),
+      ("mesh_xy", P(("x", "y")),
+       ((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
+       (1, 2),
+       [0, 0, 0, 0, 0, 0, 0, 0], False),
+      ("mesh_fully_replicated", P(),
+       ((slice(None), slice(None)), (slice(None), slice(None))),
+       (8, 2),
+       [0, 1, 2, 3, 4, 5, 6, 7], True),
+  )
+  def test_array_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
+                          expected_replica_ids, expected_is_fully_replicated):
+    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    global_input_shape = (8, 2)
+    s = sharding.MeshPspecSharding(global_mesh, mesh_axes)
+    arr, global_input_data = create_array(global_input_shape, s)
+    self.assertEqual(arr.ndim, 2)
+    self.assertEqual(arr.size, 16)
+    self.assertEqual(arr.addressable_shards[0].index, expected_index[0])
+    self.assertEqual(arr.addressable_shards[1].index, expected_index[1])
+    replica_ids = [i.replica_id for i in arr.addressable_shards]
+    self.assertListEqual(replica_ids, expected_replica_ids)
+    self.assertListEqual([i.device.id for i in arr.addressable_shards],
+                         [0, 1, 2, 3, 4, 5, 6, 7])
+    self.assertEqual(arr.is_fully_replicated(), expected_is_fully_replicated)
+    for shard in arr.addressable_shards:
+      self.assertEqual(shard.data.aval,
+                       jax.ShapedArray(expected_shard_shape, shard.data.dtype))
+      self.assertArraysEqual(shard.data, global_input_data[shard.index])
+
   def test_array_delete(self):
     with jax_config.jax_array(True):
       global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
@@ -84,6 +132,7 @@ def test_device_put(self):
      self.assertArraysEqual(i.data, numpy_array)
      self.assertEqual(i.device, jax.devices()[0])
      self.assertEqual(i.index, (slice(None),))
+      self.assertEqual(i.replica_id, 0)
 
   def test_device_put_array_delete(self):
     with jax_config.jax_array(True):
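
A quick sketch of the standalone function this change settles on. This is a
minimal, hypothetical example, not part of the patch: it assumes a runtime
with exactly 8 devices arranged as a 4x2 mesh, and the `jax.experimental`
import paths as of this revision (they may move in later releases).

    import numpy as np
    import jax
    from jax.experimental import PartitionSpec as P
    from jax.experimental.maps import Mesh
    from jax.experimental.sharding import MeshPspecSharding, device_replica_id_map

    # Assumed setup: a 4x2 mesh over the 8 available devices, as in the tests.
    mesh = Mesh(np.asarray(jax.devices()).reshape(4, 2), ('x', 'y'))
    global_shape = (8, 2)

    # Sharded over both axes: every device holds a distinct shard, so every
    # replica id is 0.
    fully_sharded = MeshPspecSharding(mesh, P('x', 'y'))
    print(device_replica_id_map(fully_sharded, global_shape))

    # Sharded over 'x' only: the two devices along 'y' hold copies of each
    # shard, so replica ids alternate 0, 1 (matching the "mesh_x" test case).
    row_sharded = MeshPspecSharding(mesh, P('x'))
    print(device_replica_id_map(row_sharded, global_shape))

Because the replica-id computation only needs `devices_indices_map`, one
cached standalone function covers any sharding that can report a
device-to-index mapping, which is what lets the per-class
`device_replica_id_map` methods above be deleted.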