Optimize jax.device_put() dispatch for 1:1 device-to-device transfers
* Cache the sharding index comparison in addition to the sharding index calculation. This helps when the list of indices is expensive to compare.
* Remove caching from `pxla.get_addressable_devices_for_shard_arg()` since `sharding._addressable_device_assignment` is already a cached property.
* Use `a is b` instead of `id(a) == id(b)` since the former is more concise.

PiperOrigin-RevId: 627080325
junwhanahn authored and jax authors committed Apr 22, 2024
1 parent 1b1c6e7 commit 4be25d7
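
For context, the fast path this commit tunes is a plain 1:1 device-to-device move. Below is a minimal usage sketch, assuming a runtime with at least one addressable device; the array contents and device indices are illustrative only.

import jax
import jax.numpy as jnp

# Place an array on one device, then move it 1:1 to another. The second
# device_put is the dispatch path optimized here: the source and destination
# shard indices match, so no resharding is needed.
x = jax.device_put(jnp.arange(8), jax.devices()[0])
y = jax.device_put(x, jax.devices()[-1])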
Showing 3 changed files with 19 additions and 19 deletions.
16 changes: 11 additions & 5 deletions jax/_src/array.py
@@ -928,20 +928,26 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
   return pxla.batched_device_put(x.aval, sharding, bufs, devices)
 
 
+@functools.lru_cache(maxsize=4096)
+def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
+  src_indices = src_sharding.addressable_devices_indices_map(shape).values()
+  dst_indices = dst_sharding.addressable_devices_indices_map(shape).values()
+  return dst_indices, tuple(src_indices) == tuple(dst_indices)
+
+
 def _array_shard_arg(x, sharding):
   x._check_if_deleted()
 
-  x_indices = x.sharding.addressable_devices_indices_map(x.shape).values()
-  indices = sharding.addressable_devices_indices_map(x.shape).values()
+  indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding)
   if not x.is_fully_addressable:
-    if tuple(x_indices) == tuple(indices):
+    if same_indices:
       return x
     else:
       raise NotImplementedError(
           "Cannot reshard an input that is not fully addressable")
   else:
-    devices = pxla.get_addressable_devices_for_shard_arg(sharding)
-    if tuple(x_indices) == tuple(indices):
+    devices = sharding._addressable_device_assignment
+    if same_indices:
       return xc.copy_array_to_devices_with_sharding(x, list(devices), sharding)
     # Resharding starts here:
     if dispatch.is_single_device_sharding(x.sharding):
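
The new `_sharding_indices_and_eq` helper memoizes both the per-shard index map and the result of comparing source and destination indices, so repeated transfers with the same (source sharding, shape, destination sharding) triple skip the tuple comparison entirely. Here is a standalone sketch of that memoization pattern, not JAX source; `make_indices` is a hypothetical stand-in for `Sharding.addressable_devices_indices_map()`.

import functools

def make_indices(sharding_key, shape):
  # Pretend this is expensive to compute and expensive to compare.
  return tuple((sharding_key, i) for i in range(shape[0]))

@functools.lru_cache(maxsize=4096)
def indices_and_eq(src_key, shape, dst_key):
  src = make_indices(src_key, shape)
  dst = make_indices(dst_key, shape)
  return dst, src == dst  # the equality result is cached alongside the indices

dst_indices, same = indices_and_eq("mesh_a", (8,), "mesh_a")
assert same  # identical later calls return the cached pair without re-comparing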
7 changes: 1 addition & 6 deletions jax/_src/interpreters/pxla.py
@@ -122,11 +122,6 @@ def shard_args(
 shard_arg_handlers: dict[Any, Callable[[Any, Any], Any]] = {}
 
 
-@lru_cache(maxsize=1024)
-def get_addressable_devices_for_shard_arg(
-    s: sharding_impls.XLACompatibleSharding) -> tuple[xc.Device, ...]:
-  return s._addressable_device_assignment
-
 @lru_cache(maxsize=1024)
 def _get_replicated_slices(num_addressable_devices: int):
   return ((slice(None),),) * num_addressable_devices
@@ -138,7 +133,7 @@ def _masked_array_error(x, sharding):
 shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error
 
 def _shard_array(x, sharding):
-  devices = get_addressable_devices_for_shard_arg(sharding)
+  devices = sharding._addressable_device_assignment
   if x.dtype == dtypes.float0:
     x = np.zeros(x.shape, dtype=np.dtype(bool))
   aval = api_util.shaped_abstractify(x)
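
Dropping the `lru_cache` wrapper is safe because `_addressable_device_assignment` is already a cached property on the sharding, so each instance computes it at most once. A minimal illustration of that property-level caching, using a toy class rather than the real `Sharding` hierarchy:

import functools

class ToySharding:
  def __init__(self, devices):
    self._devices = devices

  @functools.cached_property
  def _addressable_device_assignment(self):
    # Computed on first access, then stored on the instance.
    return tuple(self._devices)

s = ToySharding(["tpu:0", "tpu:1"])
assert s._addressable_device_assignment is s._addressable_device_assignment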
15 changes: 7 additions & 8 deletions jax/_src/sharding_impls.py
@@ -317,13 +317,13 @@ def __hash__(self):
   def __eq__(self, other):
     if not isinstance(other, NamedSharding):
       return False
-    if id(self) == id(other):
+    if self is other:
       return True
     if (self._parsed_pspec != other._parsed_pspec
         or self.memory_kind != other.memory_kind
         or self._manual_axes != other._manual_axes):
       return False
-    return id(self.mesh) == id(other.mesh) or self.mesh == other.mesh
+    return self.mesh is other.mesh or self.mesh == other.mesh
 
   def is_compatible_aval(self, aval_shape: Shape):
     assert self._parsed_pspec is not None
@@ -422,7 +422,7 @@ def __hash__(self):
   def __eq__(self, other):
     if not isinstance(other, SingleDeviceSharding):
       return False
-    if id(self) == id(other):
+    if self is other:
       return True
     return (self._device == other._device and
             self.memory_kind == other.memory_kind)
@@ -485,7 +485,7 @@ def __reduce__(self):
   def __eq__(self, other):
     if not isinstance(other, PmapSharding):
       return False
-    if id(self) == id(other):
+    if self is other:
       return True
     return (self.sharding_spec == other.sharding_spec and
             self.devices.shape == other.devices.shape and
@@ -741,12 +741,11 @@ def __hash__(self) -> int:
   def __eq__(self, other) -> bool:
     if not isinstance(other, PositionalSharding):
       return False
-    if id(self) == id(other):
+    if self is other:
       return True
     all_ids_equal = np.array_equal(self._ids, other._ids)
     mem_kind_equal = self.memory_kind == other.memory_kind
-    if (id(self._devices) == id(other._devices) and mem_kind_equal and
-        all_ids_equal):
+    if self._devices is other._devices and mem_kind_equal and all_ids_equal:
       return True
     return (mem_kind_equal and all_ids_equal and
             self._internal_device_list == other._internal_device_list)
@@ -852,7 +851,7 @@ def _hlo_sharding_hash(self):
   def __eq__(self, other):
     if not isinstance(other, GSPMDSharding):
       return False
-    if id(self) == id(other):
+    if self is other:
       return True
     return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
             and self.memory_kind == other.memory_kind
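
The `__eq__` changes above keep the same identity fast path but spell it with the idiomatic `is` operator, which compares object identity directly instead of going through `id()`. A small sketch of the pattern, using a toy class:

class ToySpec:
  def __init__(self, fields):
    self.fields = fields

  def __eq__(self, other):
    if not isinstance(other, ToySpec):
      return False
    if self is other:  # same object, trivially equal; skips the field comparison
      return True
    return self.fields == other.fields  # potentially expensive

a = ToySpec((1, 2, 3))
assert a == a                   # hits the identity fast path
assert a == ToySpec((1, 2, 3))  # falls through to the field comparison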
