Remove device_replica_id_map from the Sharding interface because the standalone function should be more than enough to use.

The major use-case of this is checkpointing and accessing addressable_shards; calling the standalone function makes both work.

PiperOrigin-RevId: 470820443
yashk2810 authored and jax authors committed Aug 29, 2022
1 parent c1217be commit fc7a71d
Showing 3 changed files with 63 additions and 27 deletions.
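For context, here is a minimal usage sketch (not part of this commit) of how replica ids are obtained after this change. It assumes at least 8 local devices and the jax.experimental APIs at this revision (MeshPspecSharding, Mesh, PartitionSpec); treat it as an illustration, not the library's documented example.

import numpy as np
import jax
from jax.experimental import PartitionSpec as P
from jax.experimental import sharding
from jax.experimental.maps import Mesh

# Assumes 8 devices arranged as a 4x2 mesh, as in the tests below.
devices = np.array(jax.devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, ('x', 'y'))
s = sharding.MeshPspecSharding(mesh, P('x', 'y'))

# device_replica_id_map is now a standalone function in
# jax.experimental.sharding; it maps each device to its replica id
# for the given global shape.
replica_ids = sharding.device_replica_id_map(s, (8, 2))

# Shard.replica_id (see array.py below) simply indexes into this mapping.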
13 changes: 4 additions & 9 deletions jax/experimental/array.py
@@ -15,6 +15,7 @@
from __future__ import annotations

import operator as op
import functools
import numpy as np
from typing import Sequence, Tuple, Callable, Union, Optional, cast, List

@@ -31,8 +32,8 @@
from jax._src.api import device_put
from jax._src.numpy.ndarray import ndarray
from jax.interpreters import pxla, xla, mlir
from jax.experimental.sharding import (Sharding, SingleDeviceSharding,
XLACompatibleSharding)
from jax.experimental.sharding import (
Sharding, SingleDeviceSharding, XLACompatibleSharding, device_replica_id_map)

Shape = Tuple[int, ...]
Device = xc.Device
@@ -81,13 +82,7 @@ def index(self) -> Index:

@property
def replica_id(self) -> int:
try:
device_replica_id_fn = self._sharding.device_replica_id_map # pytype: disable=attribute-error
except AttributeError:
raise ValueError('Cannot calculate replica ids from sharding: '
f'{self._sharding}. Please create a device to replica id '
'mapping for your sharding.') from None
return device_replica_id_fn(self._global_shape)[self.device]
return device_replica_id_map(self._sharding, self._global_shape)[self.device]


def _reconstruct_array(fun, args, arr_state, aval_state):
28 changes: 10 additions & 18 deletions jax/experimental/sharding.py
@@ -78,10 +78,6 @@ class XLACompatibleSharding(Sharding):
def _device_assignment(self) -> XLADeviceAssignment:
raise NotImplementedError('Subclasses should implement this method.')

@abc.abstractmethod
def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
raise NotImplementedError('Subclasses should implement this method.')

@abc.abstractmethod
def _to_xla_op_sharding(self, num_dimensions: int) -> Optional[xc.OpSharding]:
raise NotImplementedError('Subclasses should implement this method.')
@@ -133,10 +129,18 @@ def _hashed_index(x) -> int:


@functools.lru_cache(maxsize=4096)
def _device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]:
def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]:
try:
device_indices_map_fn = sharding.devices_indices_map
except AttributeError:
raise ValueError(
f'Cannot calculate replica ids from sharding: {sharding}. Please '
'create a device to index mapping for your sharding from which replica '
'ids will be calculated.') from None

index_to_replica: Dict[int, int] = Counter()
out = {}
for device, index in sharding.devices_indices_map(global_shape).items():
for device, index in device_indices_map_fn(global_shape).items():
h_index = _hashed_index(index)
replica_id = index_to_replica[h_index]
index_to_replica[h_index] += 1
@@ -208,9 +212,6 @@ def devices_indices_map(
# `get_shard_indices` is cached.
return global_device_array.get_shard_indices(global_shape, self.mesh, self.spec)

def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
return _device_replica_id_map(self, global_shape)

@pxla.maybe_cached_property
def _device_assignment(self) -> XLADeviceAssignment:
return list(self.mesh.devices.flat)
@@ -270,9 +271,6 @@ def devices_indices_map(
self, global_shape: Shape) -> Mapping[Device, Index]:
return {self._device: (slice(None),) * len(global_shape)}

def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
return {self._device: 0}

@property
def _device_assignment(self) -> XLADeviceAssignment:
return [self._device]
@@ -298,9 +296,6 @@ def devices_indices_map(
indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
return {d: i for d, i in safe_zip(self.devices.flat, indices)} # type: ignore

def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
return _device_replica_id_map(self, global_shape)

@pxla.maybe_cached_property
def _device_assignment(self) -> XLADeviceAssignment:
return list(self.devices.flat)
@@ -374,9 +369,6 @@ def devices_indices_map(
len(self._devices))
return dict(safe_zip(self._devices, indices))

def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
return _device_replica_id_map(self, global_shape)

@property
def _device_assignment(self) -> XLADeviceAssignment:
return list(self._devices)
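To illustrate the counting logic of device_replica_id_map above, here is a hedged, self-contained toy in plain Python with string stand-ins for devices (the real code uses jax.Device objects and _hashed_index): devices that map to the same index are replicas of one another, and each receives the next counter value for that index.

from collections import Counter

# Toy device->index mapping, like devices_indices_map would return for an
# (8, 2) array sharded only along 'x', so pairs of devices hold the same slice.
devices_to_index = {
    'd0': (slice(0, 2), slice(None)), 'd1': (slice(0, 2), slice(None)),
    'd2': (slice(2, 4), slice(None)), 'd3': (slice(2, 4), slice(None)),
}

index_to_replica = Counter()
replica_ids = {}
for device, index in devices_to_index.items():
    # Hashable stand-in for _hashed_index.
    key = tuple((sl.start, sl.stop) for sl in index)
    replica_ids[device] = index_to_replica[key]
    index_to_replica[key] += 1

print(replica_ids)  # {'d0': 0, 'd1': 1, 'd2': 0, 'd3': 1}

This matches the alternating [0, 1, 0, 1, ...] replica ids that the "mesh_x" test case below expects.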
49 changes: 49 additions & 0 deletions tests/array_test.py
@@ -61,6 +61,54 @@ def test_jax_array_value(self, mesh_axes):
self.assertArraysEqual(arr._value, global_data)
self.assertArraysEqual(arr._npy_value, global_data)

@parameterized.named_parameters(
("mesh_x_y", P("x", "y"),
# There are more slices, but for convenience only 2 are checked. The
# indices + shard_shape + replica_id should be unique enough.
((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
(2, 1),
[0, 0, 0, 0, 0, 0, 0, 0], False),
("mesh_x", P("x"),
((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
(2, 2),
[0, 1, 0, 1, 0, 1, 0, 1], False),
("mesh_y", P("y"),
((slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
(4, 2),
[0, 0, 1, 1, 2, 2, 3, 3], False),
("mesh_none_y", P(None, "y"),
((slice(None), slice(0, 1)), (slice(None), slice(1, 2))),
(8, 1),
[0, 0, 1, 1, 2, 2, 3, 3], False),
("mesh_xy", P(("x", "y")),
((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
(1, 2),
[0, 0, 0, 0, 0, 0, 0, 0], False),
("mesh_fully_replicated", P(),
((slice(None), slice(None)), (slice(None), slice(None))),
(8, 2),
[0, 1, 2, 3, 4, 5, 6, 7], True),
)
def test_array_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
expected_replica_ids, expected_is_fully_replicated):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
s = sharding.MeshPspecSharding(global_mesh, mesh_axes)
arr, global_input_data = create_array(global_input_shape, s)
self.assertEqual(arr.ndim, 2)
self.assertEqual(arr.size, 16)
self.assertEqual(arr.addressable_shards[0].index, expected_index[0])
self.assertEqual(arr.addressable_shards[1].index, expected_index[1])
replica_ids = [i.replica_id for i in arr.addressable_shards]
self.assertListEqual(replica_ids, expected_replica_ids)
self.assertListEqual([i.device.id for i in arr.addressable_shards],
[0, 1, 2, 3, 4, 5, 6, 7])
self.assertEqual(arr.is_fully_replicated(), expected_is_fully_replicated)
for s in arr.addressable_shards:
self.assertEqual(s.data.aval,
jax.ShapedArray(expected_shard_shape, s.data.dtype))
self.assertArraysEqual(s.data, global_input_data[s.index])

def test_array_delete(self):
with jax_config.jax_array(True):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
@@ -84,6 +132,7 @@ def test_device_put(self):
self.assertArraysEqual(i.data, numpy_array)
self.assertEqual(i.device, jax.devices()[0])
self.assertEqual(i.index, (slice(None),))
self.assertEqual(i.replica_id, 0)

def test_device_put_array_delete(self):
with jax_config.jax_array(True):
