From e4f2bff0a38fecd9fdd2333370a0c9ec98ad8320 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 27 Sep 2022 16:01:46 -0700 Subject: [PATCH] Disintegrate `Array` into DeviceBuffers inside GDA. This is required for backwards compatibility changes as users can create GDAs and pass that to pjit even when Array is switched on. PiperOrigin-RevId: 477297406 --- jax/experimental/global_device_array.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/jax/experimental/global_device_array.py b/jax/experimental/global_device_array.py index 1993f450a2d0..c6c086da0dfc 100644 --- a/jax/experimental/global_device_array.py +++ b/jax/experimental/global_device_array.py @@ -300,18 +300,22 @@ def __init__(self, self.dtype = dtype def _init_buffers(self, device_buffers): + from jax._src.array import ArrayImpl + self._maybe_device_buffers = None # ShardedBuffer is the fast path for managing sharded buffers that avoids # creating python objects for every device. if xb.use_sharded_buffer: - if isinstance(device_buffers, xb.xla_client.ShardedBuffer): + if isinstance(device_buffers, xc.ShardedBuffer): # if ShardedBuffer is provided, we don't need to use `_device_buffers` self._sharded_buffer = device_buffers # type: ignore elif isinstance(device_buffers[0], DeviceArray): # type: ignore # if xla_client.Buffer is provided, we convert it to ShardedBuffer. - self._sharded_buffer = xb.xla_client.ShardedBuffer.create_sharded_buffer( - device_buffers) + self._sharded_buffer = xc.ShardedBuffer.create_sharded_buffer(device_buffers) + elif isinstance(device_buffers[0], ArrayImpl): + self._sharded_buffer = None + self._maybe_device_buffers = [db._arrays[0] for db in device_buffers] else: # if `device_buffers` is any other types that cannot # be converted to ShardedBuffer, then we use `device_buffers`.