Skip to content

Commit

Permalink
Disintegrate Array into DeviceBuffers inside GDA. This is required …
Browse files Browse the repository at this point in the history
…for backwards compatibility changes as users can create GDAs and pass that to pjit even when Array is switched on.

PiperOrigin-RevId: 477297406
  • Loading branch information
yashk2810 authored and jax authors committed Sep 27, 2022
1 parent 0919a67 commit e4f2bff
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions jax/experimental/global_device_array.py
Expand Up @@ -300,18 +300,22 @@ def __init__(self,
self.dtype = dtype

def _init_buffers(self, device_buffers):
from jax._src.array import ArrayImpl

self._maybe_device_buffers = None

# ShardedBuffer is the fast path for managing sharded buffers that avoids
# creating python objects for every device.
if xb.use_sharded_buffer:
if isinstance(device_buffers, xb.xla_client.ShardedBuffer):
if isinstance(device_buffers, xc.ShardedBuffer):
# if ShardedBuffer is provided, we don't need to use `_device_buffers`
self._sharded_buffer = device_buffers # type: ignore
elif isinstance(device_buffers[0], DeviceArray): # type: ignore
# if xla_client.Buffer is provided, we convert it to ShardedBuffer.
self._sharded_buffer = xb.xla_client.ShardedBuffer.create_sharded_buffer(
device_buffers)
self._sharded_buffer = xc.ShardedBuffer.create_sharded_buffer(device_buffers)
elif isinstance(device_buffers[0], ArrayImpl):
self._sharded_buffer = None
self._maybe_device_buffers = [db._arrays[0] for db in device_buffers]
else:
# if `device_buffers` is any other types that cannot
# be converted to ShardedBuffer, then we use `device_buffers`.
Expand Down

0 comments on commit e4f2bff

Please sign in to comment.