Remove fast_path_args from Array and add id checks to Sharding's `__eq__` method as a fast shortcut.

The C++ pjit path should also help optimize the dispatch path.

PiperOrigin-RevId: 475163903
yashk2810 authored and jax authors committed Sep 18, 2022
1 parent 9d8363a commit a24726d
Showing 5 changed files with 96 additions and 56 deletions.
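
As a rough illustration of the identity shortcut this commit adds to the `__eq__` methods in jax/experimental/sharding.py (see the diff below), here is a minimal sketch. The class name and fields are hypothetical stand-ins, not JAX's actual sharding classes; only the `id(self) == id(other)` early return mirrors the change.

class DemoSharding:
  """Hypothetical sharding-like class; only the identity fast path mirrors the commit."""

  def __init__(self, devices, spec):
    self._devices = tuple(devices)
    self._spec = spec

  def __hash__(self):
    return hash((self._devices, self._spec))

  def __eq__(self, other):
    if not isinstance(other, DemoSharding):
      return False
    # Fast shortcut: an object is always equal to itself, so an id comparison
    # avoids the structural comparison whenever the same instance is reused
    # (e.g. the same output sharding object across repeated pjit calls).
    if id(self) == id(other):
      return True
    return self._devices == other._devices and self._spec == other._spec
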
67 changes: 15 additions & 52 deletions jax/experimental/array.py
@@ -43,11 +43,6 @@
ArrayLike = Union[np.ndarray, DeviceArray]


class _ArrayFastPathArgs(NamedTuple):
devices_indices_map: Mapping[Device, Optional[Index]]
addressable_device_assignment: Sequence[Device]


class Shard:
"""A single data shard of an Array.
@@ -61,13 +56,11 @@ class Shard:
"""

def __init__(self, device: Device, sharding: Sharding, global_shape: Shape,
data: Optional[Array] = None,
_fast_path_args: Optional[_ArrayFastPathArgs] = None):
data: Optional[Array] = None):
self.device = device
self._sharding = sharding
self._global_shape = global_shape
self.data = data
self._fast_path_args = _fast_path_args

def __repr__(self):
try:
@@ -78,16 +71,13 @@ def __repr__(self):

@property
def index(self) -> Index:
if self._fast_path_args is None:
try:
device_indices_map_fn = self._sharding.devices_indices_map
except AttributeError:
raise ValueError('Cannot calculate indices from sharding: '
f'{self._sharding}. Please create a device to index '
'mapping for your sharding.') from None
index = device_indices_map_fn(self._global_shape)[self.device]
else:
index = self._fast_path_args.devices_indices_map[self.device]
try:
device_indices_map_fn = self._sharding.devices_indices_map
except AttributeError:
raise ValueError('Cannot calculate indices from sharding: '
f'{self._sharding}. Please create a device to index '
'mapping for your sharding.') from None
index = device_indices_map_fn(self._global_shape)[self.device]
assert index is not None
return index

@@ -112,8 +102,7 @@ class Array:
@pxla.use_cpp_method
def __init__(self, aval: core.ShapedArray, sharding: Sharding,
arrays: Union[Sequence[DeviceArray], Sequence[Array]],
committed: bool, _skip_checks: bool = False,
_fast_path_args: Optional[_ArrayFastPathArgs] = None):
committed: bool, _skip_checks: bool = False):
# NOTE: the actual implementation of the constructor is moved to C++.

self.aval = aval
@@ -127,8 +116,6 @@ def __init__(self, aval: core.ShapedArray, sharding: Sharding,
# See https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices
# for what committed means.
self._committed = committed
# Optionally precomputed for performance.
self._fast_path_args = _fast_path_args
self._npy_value = None

if not _skip_checks or config.jax_enable_checks:
@@ -161,10 +148,7 @@ def _rearrange(self):
# XLACompatibleSharding. But leave the rearrangement to XLACompatibleSharding
# only.
if isinstance(self.sharding, XLACompatibleSharding):
if self._fast_path_args is None:
addressable_da = cast(XLACompatibleSharding, self.sharding)._addressable_device_assignment
else:
addressable_da = self._fast_path_args.addressable_device_assignment
addressable_da = self.sharding._addressable_device_assignment
if len(self._arrays) != len(addressable_da):
raise ValueError(
f"Expected {len(addressable_da)} per-device arrays "
@@ -264,10 +248,7 @@ def __getitem__(self, idx):
else:
cidx = idx + (slice(None),) * (len(self.shape) - len(idx))
if self._npy_value is None:
if self._fast_path_args is None:
indices = tuple(self.sharding.devices_indices_map(self.shape).values())
else:
indices = tuple(self._fast_path_args.devices_indices_map.values())
indices = tuple(self.sharding.devices_indices_map(self.shape).values())
try:
buf_idx = indices.index(cidx)
except ValueError:
@@ -399,8 +380,7 @@ def addressable_shards(self) -> Sequence[Shard]:
# of a DA.
array = Array(db.aval, SingleDeviceSharding(device), [db],
committed=self._committed, _skip_checks=True)
out.append(Shard(
device, self.sharding, self.shape, array, self._fast_path_args))
out.append(Shard(device, self.sharding, self.shape, array))
return out

def delete(self):
@@ -528,14 +508,10 @@ def _array_pmap_shard_arg(x, devices, indices, mode):
if dispatch.is_single_device_sharding(x.sharding):
return pxla._shard_device_array(x, devices, indices, mode)

if x._fast_path_args is None:
x_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
else:
x_indices = tuple(x._fast_path_args.devices_indices_map.values())

# If the sharding of Array does not match pmap's sharding then take the slow
# path which is similar to what SDA does. This slow path reroute only happens
# for `pmap`.
x_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
if indices == x_indices:
return [buf if buf.device() == d else buf.copy_to_device(d)
for buf, d in safe_zip(x._arrays, devices)]
@@ -591,15 +567,8 @@ def _array_global_result_handler(global_aval, out_sharding, committed,
if core.is_opaque_dtype(global_aval.dtype):
return global_aval.dtype._rules.global_sharded_result_handler(
global_aval, out_sharding, committed, is_out_sharding_from_xla)

# Calculate the indices and addressable device assignment once during
# compilation and pass it to the constructor.
_array_fast_path_args = _ArrayFastPathArgs(
out_sharding.devices_indices_map(global_aval.shape),
out_sharding._addressable_device_assignment)
return lambda bufs: Array(global_aval, out_sharding, bufs,
committed=committed, _skip_checks=True,
_fast_path_args=_array_fast_path_args)
committed=committed, _skip_checks=True)
pxla.global_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_global_result_handler
pxla.global_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_global_result_handler
pxla.global_result_handlers[(core.AbstractToken, pxla.OutputType.Array)] = lambda *_: lambda *_: core.token
@@ -610,13 +579,7 @@ def _array_local_result_handler(aval, sharding, indices):
if core.is_opaque_dtype(aval.dtype):
return aval.dtype._rules.local_sharded_result_handler(
aval, sharding, indices)

# Calculate the indices and addressable device assignment once during
# compilation and pass it to the constructor.
_array_fast_path_args = _ArrayFastPathArgs(
sharding.devices_indices_map(aval.shape),
sharding._addressable_device_assignment)
return lambda bufs: Array(aval, sharding, bufs, committed=True,
_skip_checks=True, _fast_path_args=_array_fast_path_args)
_skip_checks=True)
pxla.local_result_handlers[(core.ShapedArray, pxla.OutputType.Array)] = _array_local_result_handler
pxla.local_result_handlers[(core.ConcreteArray, pxla.OutputType.Array)] = _array_local_result_handler
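
With `_ArrayFastPathArgs` removed, `Shard.index`, `Array.__getitem__`, and the result handlers above always call `sharding.devices_indices_map(...)` directly. This stays cheap only because the map is memoized on the sharding (the new pjit test below checks `cache_info()` hits). A minimal sketch of that memoization pattern, with a hypothetical class and a deliberately simplified index computation:

import functools

class CachedDemoSharding:
  """Hypothetical sketch: memoize the device-to-index map per (sharding, shape)."""

  def __init__(self, devices):
    self._devices = tuple(devices)

  def __hash__(self):
    return hash(self._devices)

  def __eq__(self, other):
    return (isinstance(other, CachedDemoSharding)
            and self._devices == other._devices)

  @functools.lru_cache(maxsize=None)
  def devices_indices_map(self, global_shape):
    # Split the leading axis evenly across devices. Real shardings compute
    # richer maps, but the idea is the same: repeated calls with the same
    # (sharding, shape) pair return the memoized dict instead of recomputing.
    step = global_shape[0] // len(self._devices)
    rest = (slice(None),) * (len(global_shape) - 1)
    return {d: (slice(i * step, (i + 1) * step),) + rest
            for i, d in enumerate(self._devices)}

# Repeated lookups on the same object hit the cache, e.g.:
#   s = CachedDemoSharding(range(4))
#   s.devices_indices_map((8, 2)); s.devices_indices_map((8, 2))
#   CachedDemoSharding.devices_indices_map.cache_info().hits  # 1
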
4 changes: 1 addition & 3 deletions jax/experimental/gda_serialization/serialization.py
@@ -56,9 +56,7 @@ async def create_async_array_from_callback(
dbs = [jax.device_put(array, device)
for array, device in zip(local_arrays, addressable_da)]
aval = jax.ShapedArray(global_shape, dbs[0].dtype)
return array.Array(
aval, inp_sharding, dbs, committed=True,
_fast_path_args=array._ArrayFastPathArgs(device_to_index_map, addressable_da))
return array.Array(aval, inp_sharding, dbs, committed=True)


async def create_async_gda_from_callback(
9 changes: 8 additions & 1 deletion jax/experimental/sharding.py
@@ -185,7 +185,8 @@ def __hash__(self):
def __eq__(self, other):
if not isinstance(other, MeshPspecSharding):
return False

if id(self) == id(other):
return True
if id(self.mesh) == id(other.mesh) and self._parsed_pspec == other._parsed_pspec:
return True
return self.mesh == other.mesh and self._parsed_pspec == other._parsed_pspec
@@ -262,6 +263,8 @@ def __hash__(self):
def __eq__(self, other):
if not isinstance(other, SingleDeviceSharding):
return False
if id(self) == id(other):
return True
return self._device == other._device

@property
@@ -292,6 +295,8 @@ def __init__(self, devices: np.ndarray, sharding_spec: pxla.ShardingSpec):
def __eq__(self, other):
if not isinstance(other, PmapSharding):
return False
if id(self) == id(other):
return True
return (self.sharding_spec == other.sharding_spec and
np.array_equal(self.devices, other.devices))

@@ -353,6 +358,8 @@ def _op_sharding_hash(self):
def __eq__(self, other):
if not isinstance(other, OpShardingSharding):
return False
if id(self) == id(other):
return True
return (pxla.are_op_shardings_equal(self._op_sharding, other._op_sharding) and
self._devices == other._devices)

48 changes: 48 additions & 0 deletions tests/pjit_test.py
@@ -2238,6 +2238,54 @@ def test_pjit_committed_array_different_devices(self):
"Devices of all `Array` inputs and outputs should be the same"):
pjit(lambda x, y: (x, y))(a, b)

@jax_array(True)
def test_same_out_sharding_id(self):
shape = (8, 2)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
arr, inp_data = create_array(shape, mesh, P('x', 'y'))

f = pjit(lambda x: x)
out1 = f(arr)
self.assertArraysEqual(out1, inp_data)
out1_sharding_id = id(out1.sharding)

out2 = f(out1)
self.assertArraysEqual(out2, inp_data)
out2_sharding_id = id(out2.sharding)

out3 = f(out2)
self.assertArraysEqual(out3, inp_data)
out3_sharding_id = id(out3.sharding)

self.assertEqual(out1_sharding_id, out2_sharding_id)
self.assertEqual(out1_sharding_id, out3_sharding_id)
self.assertEqual(out2_sharding_id, out3_sharding_id)

@jax_array(True)
def test_out_sharding_indices_id_cache_hit(self):
shape = (8, 2)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
arr, _ = create_array(shape, mesh, P('x', 'y'))

f = pjit(lambda x: x)
out1 = f(arr)
self.assertIsInstance(out1.sharding, OpShardingSharding)
out1.sharding.devices_indices_map(shape)
cache_info1 = OpShardingSharding.devices_indices_map.cache_info()

out2 = f(out1)
self.assertIsInstance(out2.sharding, OpShardingSharding)
out2.sharding.devices_indices_map(shape)
cache_info2 = OpShardingSharding.devices_indices_map.cache_info()
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)

out3 = f(out2)
self.assertIsInstance(out3.sharding, OpShardingSharding)
out3.sharding.devices_indices_map(shape)
cache_info3 = OpShardingSharding.devices_indices_map.cache_info()
self.assertEqual(cache_info3.hits, cache_info2.hits + 1)


class TempSharding(Sharding):

def __init__(self, devices):
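
Taken together, the two tests above check the property the commit message relies on: when pjit reuses one sharding object for its outputs, the `__eq__` identity shortcut fires and `devices_indices_map` lookups are cache hits. An illustrative restatement of the asserted invariant (reusing the test's variable names purely for exposition):

# out1, out2 as produced by f = pjit(lambda x: x) in test_same_out_sharding_id.
assert id(out1.sharding) == id(out2.sharding)   # the same sharding object is reused
assert out1.sharding is out2.sharding           # equivalent statement for live objects
assert out1.sharding == out2.sharding           # __eq__ returns True via the id check
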
24 changes: 24 additions & 0 deletions tests/pmap_test.py
@@ -3012,6 +3012,30 @@ def dynamic_shape_function(y):

self.assertArraysEqual(w, jnp.cos(jnp.sin(x) ** 2))

@jax_config.jax_array(True)
def test_same_out_sharding_id(self):
if config.jax_disable_jit:
self.skipTest('Skip this under eager pmap mode.')
shape = (jax.device_count(), 2)
arr, inp_data = create_input_array_for_pmap(shape)

f = pmap(lambda x: x)
out1 = f(arr)
self.assertArraysEqual(out1, inp_data)
out1_sharding_id = id(out1.sharding)

out2 = f(out1)
self.assertArraysEqual(out2, inp_data)
out2_sharding_id = id(out2.sharding)

out3 = f(out2)
self.assertArraysEqual(out3, inp_data)
out3_sharding_id = id(out3.sharding)

self.assertEqual(out1_sharding_id, out2_sharding_id)
self.assertEqual(out1_sharding_id, out3_sharding_id)
self.assertEqual(out2_sharding_id, out3_sharding_id)


class EagerPmapMixin:

