Skip to content

Commit

Permalink
Add addressable_ counterparts of local_ to GDA to make it easier …
Browse files Browse the repository at this point in the history
…for users to move to Array as both will have the same API.

PiperOrigin-RevId: 477332697
  • Loading branch information
yashk2810 authored and jax authors committed Sep 28, 2022
1 parent e4f2bff commit c8bff11
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions jax/experimental/global_device_array.py
Expand Up @@ -394,6 +394,9 @@ def _create_local_shards(self) -> Sequence[Shard]:
def local_shards(self) -> Sequence[Shard]:
return self._create_local_shards()

def addressable_shards(self) -> Sequence[Shard]:
return self.local_shards

@property
def global_shards(self) -> Sequence[Shard]:
if self.mesh.size == len(self._local_devices):
Expand Down Expand Up @@ -440,6 +443,9 @@ def __array__(self, dtype=None, context=None):
def local_data(self, index) -> DeviceArray:
return pxla._set_aval(self._device_buffers[index])

def addressable_data(self, index) -> DeviceArray:
return self.local_data(index)

def block_until_ready(self):
# self._sharded_buffer can be None if xla_extension_version < 90 or
# _DeviceArray is used.
Expand Down

0 comments on commit c8bff11

Please sign in to comment.