[JAX] Use DeviceList in JAX Sharding implementations
XLA-compatible `Sharding` implementations keep a `DeviceList` object as
`_internal_device_list`. This is used for finding the default memory kind more
quickly in C++, and it enables caching of the default memory kind across multiple
`NamedSharding` objects that share the same `Mesh`. It also uses an
addressable device within the `DeviceList`, which will be required for supporting
multiple device types with different default memory kinds.

PiperOrigin-RevId: 556969789
hyeontaek authored and jax authors committed Aug 15, 2023
1 parent 5ffca47 commit 423c8d8
Showing 3 changed files with 27 additions and 35 deletions.
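To make the caching described in the commit message concrete, here is a minimal, self-contained sketch of the pattern. The `FakeDevice`, `FakeDeviceList`, `FakeMesh`, and `FakeNamedSharding` classes are illustrative stand-ins, not the real jaxlib/JAX objects: a device list is built once per mesh via `functools.cached_property`, so every sharding that shares the mesh reuses the same default-memory-kind computation.

```python
import functools
from dataclasses import dataclass

# Illustrative stand-ins for xc.Device / xc.DeviceList; the real objects are
# provided by jaxlib and do this work in C++.
@dataclass(frozen=True)
class FakeDevice:
  id: int

class FakeDeviceList:
  def __init__(self, devices):
    self._devices = tuple(devices)

  @functools.cached_property
  def default_memory_kind(self):
    # Computed at most once per device list. The real DeviceList resolves
    # this from an addressable device, as the commit message describes.
    print("computing default memory kind")
    return "device"  # placeholder value

class FakeMesh:
  def __init__(self, devices):
    self.devices = tuple(devices)

  @functools.cached_property
  def _internal_device_list(self):
    # Mirrors the mesh.py change below: one device list cached per mesh.
    return FakeDeviceList(self.devices)

class FakeNamedSharding:
  def __init__(self, mesh):
    self.mesh = mesh

  @property
  def default_memory_kind(self):
    # Every sharding built from the same mesh hits the same cached value.
    return self.mesh._internal_device_list.default_memory_kind

mesh = FakeMesh([FakeDevice(i) for i in range(4)])
s1, s2 = FakeNamedSharding(mesh), FakeNamedSharding(mesh)
assert s1.default_memory_kind == s2.default_memory_kind
# "computing default memory kind" is printed exactly once.
```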
12 changes: 12 additions & 0 deletions jax/_src/interpreters/pxla.py
@@ -1905,6 +1905,18 @@ def addressable_device_list(self) -> _DeviceAssignment:  # type: ignore
tuple(d for d in self._device_assignment
if d.process_index == d.client.process_index()))

@cached_property
def memory_kinds(self) -> tuple[str, ...]:
# Keep this method unimplemented as it will not be called if
# xla_extension_version is low.
raise NotImplementedError("memory_kinds is not supported")

@cached_property
def default_memory_kind(self) -> Optional[str]:
# Keep this method unimplemented as it will not be called if
# xla_extension_version is low.
raise NotImplementedError("default_memory_kind is not supported")


@lru_cache(maxsize=2048)
def _create_da_object( # pytype: disable=invalid-annotation
4 changes: 4 additions & 0 deletions jax/_src/mesh.py
@@ -255,6 +255,10 @@ def _local_devices_set(self):
def _flat_devices_tuple(self):
return tuple(self.devices.flat)

@functools.cached_property
def _internal_device_list(self):
return xc.DeviceList(self._flat_devices_tuple)

@functools.cached_property
def _flat_devices_set(self):
return set(self.devices.flat)
46 changes: 11 additions & 35 deletions jax/_src/sharding_impls.py
@@ -166,24 +166,6 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
return out


# This is an optimization to get the memory kinds associated with the local
# devices. This is because in McJAX, checking if the memory kind input by user
# is correct requires doing `local_devices()[0].memory(inp)` which is expensive
# because calculating the local devices is expensive. So cache on xc.Client and
# find all the memories associated only once since the client does not change.
@functools.lru_cache
def _mem_kinds(client: xc.Client) -> set[str]:
return set(m.kind for m in client.local_devices()[0].addressable_memories())

def _check_mem_kind(device: xc.Device, mk):
mem_kinds = _mem_kinds(device.client)
if mk not in mem_kinds:
raise ValueError(
f'Could not find memory addressable by device {device.device_kind}.'
f' Device {device.device_kind} can address the following memory kinds:'
f' {mem_kinds}. Got memory kind: {mk}')


@use_cpp_class(xc.NamedSharding)
class NamedSharding(XLACompatibleSharding):
r"""A :class:`NamedSharding` expresses sharding using named axes.
@@ -237,10 +219,6 @@ def __init__(
self._preprocess()

def _preprocess(self):
if self.memory_kind is not None:
# Will error if memory_kind does not exist on the device.
_check_mem_kind(self.mesh.devices.flat[0], self.memory_kind)

# This split exists because you can pass `_parsed_pspec` that has been
# modified from the original. For example: Adding extra dimension to
# axis_resources for vmap handlers. In such cases you need to preserve the
@@ -584,9 +562,6 @@ def _op_sharding_to_pos_sharding(
ids = np.array(
[DeviceIdSet(name, i) for i in op_sharding.tile_assignment_devices()]
)
if memory_kind is not None:
# Will error if memory_kind does not exist on the device.
_check_mem_kind(device_assignment[0], memory_kind)
p = PositionalSharding._remake(tuple(device_assignment), ids,
memory_kind=memory_kind)
p = p.reshape(op_sharding.tile_assignment_dimensions())
@@ -612,12 +587,10 @@ def __init__(self, devices: Sequence[xc.Device] | np.ndarray,
name = self._devices[0].platform.upper()
self._ids = np.array([DeviceIdSet(name, i) for i in range(devices.size)],
dtype='object').reshape(devices.shape)
if self._memory_kind is not None:
# Will error if memory_kind does not exist on the device.
_check_mem_kind(self._devices[0], self._memory_kind)
if xla_extension_version >= 177:
self._memory_kind = xc.canonicalize_memory_kind(
self._memory_kind, self._devices[0])
if xla_extension_version >= 182:
self._internal_device_list = xc.DeviceList(self._devices)
self._memory_kind = xc.check_and_canonicalize_memory_kind(
self._memory_kind, self._internal_device_list)

@property
def shape(self):
@@ -768,10 +741,13 @@ def __init__(self, devices: Sequence[Device],
self._hlo_sharding = op_sharding
self._memory_kind = memory_kind

def _preprocess(self):
if self._memory_kind is not None:
# Will error if memory_kind does not exist on the device.
_check_mem_kind(self._devices[0], self._memory_kind)
if xla_extension_version < 182:
def _preprocess(self):
# Preprocessing is no longer necessary, but the method must exist for a
# previous release of jaxlib that calls this method back from C++.
# TODO(yashkatariya): Remove this method once jaxlib with
# xla_extension_version >= 182 is released.
pass

def __reduce__(self):
return (type(self), (self._devices, self._hlo_sharding.to_proto()),
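The sharding_impls.py changes replace the Python-side `_mem_kinds`/`_check_mem_kind` helpers with a single call that validates and canonicalizes a memory kind against the whole device list. Below is a minimal sketch of that call path, assuming an installed jaxlib with `xla_extension_version >= 182` (the same guard the diff uses) and at least one available device; `canonicalize_memory_kind_for` is a hypothetical helper, not part of JAX.

```python
import jax
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version

def canonicalize_memory_kind_for(devices, memory_kind):
  """Checks `memory_kind` against all `devices` at once and returns its
  canonical form, mirroring the __init__ paths changed in this commit."""
  # One DeviceList per device tuple; the shardings in the diff cache this
  # object as `_internal_device_list`.
  device_list = xc.DeviceList(tuple(devices))
  # Expected to reject a memory kind the devices cannot address (the check
  # that previously lived in the removed `_check_mem_kind`); passing None
  # keeps the default memory kind.
  return xc.check_and_canonicalize_memory_kind(memory_kind, device_list)

if xla_extension_version >= 182:
  kind = canonicalize_memory_kind_for(jax.devices(), None)
```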
