Any devices passed to jax.sharding.Mesh are required to be hashable.
This holds for mock devices and other user-created devices, as well as for the devices returned by jax.devices().

Fix the tests so that the mock devices are hashable.

PiperOrigin-RevId: 561103167
yashk2810 authored and jax authors committed Aug 29, 2023
1 parent ff5b480 commit 6072d59
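
As a rough illustration of the requirement (not part of this commit): every device object placed in the mesh array must support `hash()`, and the devices returned by `jax.devices()` already do.

```python
import numpy as np
import jax
from jax.sharding import Mesh

devices = np.array(jax.devices())  # platform devices are hashable objects
hash(devices.flat[0])              # must not raise for anything passed to Mesh
mesh = Mesh(devices, ('x',))       # a 1-D mesh over all available devices
```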
Showing 3 changed files with 40 additions and 44 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,9 @@ Remember to align the itemized text with the first line of an item within a list
`JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
* jax2tf default serialization version is now 7, which introduces new shape
[safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism).
* Devices passed to `jax.sharding.Mesh` must be hashable. This applies in
particular to mock devices and user-created devices; devices returned by
`jax.devices()` are already hashable.

* Breaking changes:
* jax2tf now uses native serialization by default. See
79 changes: 36 additions & 43 deletions jax/_src/mesh.py
@@ -90,6 +90,35 @@ def __repr__(self):
    return f"ResourceEnv({self.physical_mesh!r}, {self.loops!r})"


+@functools.lru_cache(maxsize=128)
+def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
+  if global_mesh.empty:
+    return global_mesh
+  is_local_device = np.vectorize(
+      lambda d: d.process_index == process_index, otypes=[bool])(global_mesh.devices)
+  subcube_indices = []
+  # We take the smallest slice of each dimension that doesn't skip any local device.
+  for axis in range(global_mesh.devices.ndim):
+    other_axes = util.tuple_delete(tuple(range(global_mesh.devices.ndim)), axis)
+    # NOTE: This re-reduces over many axes multiple times, so we could definitely
+    # optimize it, but I hope it won't be a bottleneck anytime soon.
+    local_slices = is_local_device.any(other_axes, keepdims=False)
+    nonzero_indices = np.flatnonzero(local_slices)
+    start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
+    subcube_indices.append(slice(start, end + 1))
+  subcube_indices = tuple(subcube_indices)
+  # We only end up with all conditions being true if the local devices formed a
+  # subcube of the full array. This is because we were biased towards taking a
+  # "hull" spanned by the devices, and in case the local devices don't form a
+  # subcube that hull will contain non-local devices.
+  if not is_local_device[subcube_indices].all():
+    raise ValueError(
+        "When passing host local inputs to pjit or xmap, devices "
+        "connected to a single host must form a contiguous subcube of the "
+        "global device mesh")
+  return Mesh(global_mesh.devices[subcube_indices], global_mesh.axis_names)


_mesh_object_dict = {} # type: ignore
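
For readers unfamiliar with the subcube logic above, here is a self-contained sketch of the same per-axis slicing, using a hand-written boolean mask in place of real devices (the 2x4 layout and the process ownership are made up for illustration):

```python
import numpy as np

# Hypothetical 2x4 device mesh in which the current process owns the first
# two columns; True marks a local device.
is_local_device = np.array([[True, True, False, False],
                            [True, True, False, False]])

subcube_indices = []
for axis in range(is_local_device.ndim):
  other_axes = tuple(a for a in range(is_local_device.ndim) if a != axis)
  # Collapse all other axes: True wherever this index contains a local device.
  local_slices = is_local_device.any(other_axes)
  nonzero = np.flatnonzero(local_slices)
  subcube_indices.append(slice(int(nonzero.min()), int(nonzero.max()) + 1))
subcube_indices = tuple(subcube_indices)

print(subcube_indices)  # (slice(0, 2, None), slice(0, 2, None))
# The local devices form a contiguous subcube, so the sliced mask is all True.
assert is_local_device[subcube_indices].all()
```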


@@ -156,28 +185,16 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
    axis_names = tuple(axis_names)
    assert devices.ndim == len(axis_names)

-    flat_devices = tuple(devices.flat)
-
-    # TODO(yashkatariya): Make Mock Devices hashable and them remove this
-    # workaround
-    _use_cache = True
-    try:
-      hash(flat_devices[0])
-    except:
-      _use_cache = False
-
-    if _use_cache:
-      key = (axis_names, devices.shape, flat_devices)
-      val = _mesh_object_dict.get(key, None)
-      if val is not None:
-        return val
+    key = (axis_names, devices.shape, tuple(devices.flat))
+    val = _mesh_object_dict.get(key, None)
+    if val is not None:
+      return val

    self = super(Mesh, cls).__new__(cls)
    self.devices = devices.copy()
    self.devices.flags.writeable = False
    self.axis_names = axis_names
-    if _use_cache:
-      _mesh_object_dict[key] = self
+    _mesh_object_dict[key] = self
    return self

  def __reduce__(self):
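
A small sketch of the interning behavior this simplification relies on (assuming the cache shown above and at least one available device): constructing the same mesh twice now always returns the cached object, instead of only when the devices happened to be hashable.

```python
import numpy as np
import jax
from jax.sharding import Mesh

devices = np.array(jax.devices())
m1 = Mesh(devices, ('x',))
m2 = Mesh(devices, ('x',))
# Both constructions hit the same (axis_names, shape, devices) cache key.
assert m1 is m2
```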
@@ -248,36 +265,12 @@ def empty(self):
  def is_multi_process(self):
    return self.devices.size != len(self.local_devices)

-  @functools.cached_property
+  @property
  def local_mesh(self):
    return self._local_mesh(xb.process_index())

  def _local_mesh(self, process_index):
-    if self.empty:
-      return self
-    is_local_device = np.vectorize(
-        lambda d: d.process_index == process_index, otypes=[bool])(self.devices)
-    subcube_indices = []
-    # We take the smallest slice of each dimension that doesn't skip any local device.
-    for axis in range(self.devices.ndim):
-      other_axes = util.tuple_delete(tuple(range(self.devices.ndim)), axis)
-      # NOTE: This re-reduces over many axes multiple times, so we could definitely
-      # optimize it, but I hope it won't be a bottleneck anytime soon.
-      local_slices = is_local_device.any(other_axes, keepdims=False)
-      nonzero_indices = np.flatnonzero(local_slices)
-      start, end = int(np.min(nonzero_indices)), int(np.max(nonzero_indices))
-      subcube_indices.append(slice(start, end + 1))
-    subcube_indices = tuple(subcube_indices)
-    # We only end up with all conditions being true if the local devices formed a
-    # subcube of the full array. This is because we were biased towards taking a
-    # "hull" spanned by the devices, and in case the local devices don't form a
-    # subcube that hull will contain non-local devices.
-    if not is_local_device[subcube_indices].all():
-      raise ValueError(
-          "When passing host local inputs to pjit or xmap, devices "
-          "connected to a single host must form a contiguous subcube of the "
-          "global device mesh")
-    return Mesh(self.devices[subcube_indices], self.axis_names)
+    return _get_local_mesh(self, process_index)

  @functools.cached_property
  def device_ids(self):
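
The refactor above moves the old cached_property body into the module-level `_get_local_mesh`, which is wrapped in `functools.lru_cache`; since `lru_cache` keys on its arguments, the `Mesh` instance itself, and therefore every device inside it, has to be hashable. A generic sketch of that constraint (the names here are illustrative, not JAX APIs):

```python
import functools

@functools.lru_cache(maxsize=128)
def cached_view(obj):
  return obj  # stand-in for an expensive derived value

class Unhashable:
  __hash__ = None  # mimics a mutable mock device with no hash

cached_view("a hashable argument")  # fine: cached on the argument's hash
try:
  cached_view(Unhashable())
except TypeError as err:
  print(err)  # unhashable type: 'Unhashable'
```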
2 changes: 1 addition & 1 deletion tests/mesh_utils_test.py
@@ -28,7 +28,7 @@
from jax._src import test_util


-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
class MockTpuDevice:
"""Mock TPU device for testing."""
id: int
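
This one-line test fix is what makes the mock devices hashable: a dataclass with the default `eq=True` and `frozen=False` sets `__hash__` to `None`, while `frozen=True` generates a `__hash__` from the fields. A minimal sketch (with the field list trimmed down from the real `MockTpuDevice`):

```python
import dataclasses

@dataclasses.dataclass
class MutableDevice:
  id: int

@dataclasses.dataclass(frozen=True)
class FrozenDevice:
  id: int

hash(FrozenDevice(0))  # works: frozen dataclasses are hashable by value
try:
  hash(MutableDevice(0))
except TypeError as err:
  print(err)  # unhashable type: 'MutableDevice'
```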
