[Rollback] Convert _arrays to return PyArray instead of PyBuffer.
PiperOrigin-RevId: 508827908
yashk2810 authored and jax authors committed Feb 11, 2023
1 parent 61da781 commit 9316188
Showing 4 changed files with 27 additions and 71 deletions.
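
In short: this rollback makes ArrayImpl._arrays hold PyBuffer objects again rather than PyArray objects, and deletes the xla_extension_version >= 128 compatibility branches that bridged the two representations. A minimal sketch of the gating idiom being removed (copy_shard_to_host_async is a hypothetical name, not jax source):

from jax._src.lib import xla_extension_version

def copy_shard_to_host_async(shard_data):
  # Pre-rollback call sites picked a code path based on the installed jaxlib:
  if xla_extension_version >= 128:
    shard_data.copy_to_host_async()             # shard data is a PyArray
  else:
    shard_data._arrays[0].copy_to_host_async()  # shard data wraps PyBuffers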
37 changes: 7 additions & 30 deletions jax/_src/array.py

@@ -28,7 +28,6 @@
 from jax._src.config import config
 from jax._src.util import prod, safe_zip, use_cpp_class, use_cpp_method
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src import api
 from jax._src.typing import ArrayLike
 from jax.interpreters import mlir
@@ -105,8 +104,6 @@ def _reconstruct_array(fun, args, arr_state, aval_state):
 
 
 def _single_device_array_from_buf(buf, committed):
-  if isinstance(buf, ArrayImpl):
-    return buf
   db = dispatch._set_aval(buf)
   return ArrayImpl(db.aval, SingleDeviceSharding(db.device()), [db],
                    committed=committed, _skip_checks=True)
@@ -123,7 +120,7 @@ class ArrayImpl(basearray.Array):
   _skip_checks: bool
   _npy_value: Optional[np.ndarray]
 
-  @use_cpp_method()
+  @use_cpp_method
   def __init__(self, aval: core.ShapedArray, sharding: Sharding,
                arrays: Union[Sequence[DeviceArray], Sequence[ArrayImpl]],
                committed: bool, _skip_checks: bool = False):
@@ -363,7 +360,6 @@ def __reduce__(self):
             'named_shape': self.aval.named_shape}
     return (_reconstruct_array, (fun, args, arr_state, aval_state))
 
-  @use_cpp_method(xla_extension_version >= 128)
   def unsafe_buffer_pointer(self):
     if len(self._arrays) != 1:
       raise ValueError("unsafe_buffer_pointer() is supported only for unsharded"
@@ -379,12 +375,8 @@ def __cuda_array_interface__(self):
 
   def on_device_size_in_bytes(self):
     """Returns the total global on-device size of the array in bytes."""
-    arr = self._arrays[0]
-    if hasattr(arr, "_on_device_size_in_bytes"):
-      per_shard_size = arr._on_device_size_in_bytes()  # type: ignore
-    else:
-      per_shard_size = arr.on_device_size_in_bytes()  # type: ignore
-    return per_shard_size * len(self.sharding.device_set)
+    return (self._arrays[0].on_device_size_in_bytes() *
+            len(self.sharding.device_set))
 
   # TODO(yashkatariya): Remove this method when everyone is using devices().
   def device(self) -> Device:
@@ -455,7 +447,6 @@ def global_shards(self) -> Sequence[Shard]:
       out.append(Shard(global_d, self.sharding, self.shape, array))
     return out
 
-  @use_cpp_method(xla_extension_version >= 128)
   def delete(self):
     if self._arrays is None:
       return
@@ -464,7 +455,7 @@ def delete(self):
     self._arrays = None
     self._npy_value = None
 
-  @use_cpp_method()
+  @use_cpp_method
   def is_deleted(self):
     if self._arrays is None:
       return True
@@ -477,7 +468,7 @@ def _check_if_deleted(self):
     if self.is_deleted():
       raise RuntimeError("Array has been deleted.")
 
-  @use_cpp_method()
+  @use_cpp_method
   def block_until_ready(self):
     self._check_if_deleted()
     for db in self._arrays:
@@ -487,12 +478,6 @@ def block_until_ready(self):
   def copy_to_host_async(self):
     self._check_if_deleted()
     if self._npy_value is None:
-      if self.is_fully_replicated:
-        arr = self._arrays[0]  # type: ignore
-        # copy_to_host_async implemented in c++ only for single device arrays.
-        if hasattr(arr, "_copy_single_device_array_to_host_async"):
-          arr._copy_single_device_array_to_host_async()  # type: ignore
-          return
       try:
         self.addressable_shards[0].replica_id
        replica_id_exists = True
@@ -501,23 +486,15 @@ def copy_to_host_async(self):
 
       for s in self.addressable_shards:
         if not replica_id_exists or s.replica_id == 0:
-          if xla_extension_version >= 128:
-            s.data.copy_to_host_async()  # pytype: disable=attribute-error
-          else:
-            s.data._arrays[0].copy_to_host_async()  # pytype: disable=attribute-error
+          s.data._arrays[0].copy_to_host_async()  # pytype: disable=attribute-error
 
   @property
   def _value(self) -> np.ndarray:
     self._check_if_deleted()
 
     if self._npy_value is None:
       if self.is_fully_replicated:
-        arr = self._arrays[0]  # type: ignore
-        # Conversion to numpy implemented only for single device arrays.
-        if hasattr(arr, "_single_device_array_to_np_array"):
-          self._npy_value = arr._single_device_array_to_np_array()  # type: ignore
-        else:
-          self._npy_value = np.asarray(arr)  # type: ignore
+        self._npy_value = np.asarray(self._arrays[0])  # type: ignore
         self._npy_value.flags.writeable = False
         return cast(np.ndarray, self._npy_value)
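
A note on the copy_to_host_async/_value pair above: copy_to_host_async starts device-to-host copies for one replica of each addressable shard, and a later _value (reached via np.asarray) then reuses the in-flight copies instead of starting synchronous ones. A small usage sketch (runnable on any backend, including CPU):

import numpy as np
import jax.numpy as jnp

a = jnp.ones((1024, 1024))
a.copy_to_host_async()  # start the device-to-host transfer early
# ... overlap other device or host work with the transfer here ...
host = np.asarray(a)    # _value now blocks only on the already-started copy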
36 changes: 12 additions & 24 deletions jax/_src/dispatch.py

@@ -59,7 +59,6 @@
 from jax._src.lib import pmap_lib
 from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
                                OpShardingSharding, NamedSharding, PartitionSpec,
                                Sharding)
@@ -853,6 +852,7 @@ def _dynamic_array_result_handler(sticky_device, aval, env, buf):
     return core.DArray(aval.update(shape=tuple(shape)), data)
 
 
+
 result_handlers: Dict[
     Type[core.AbstractValue],
     Callable[[Optional[Device], Any], ResultHandler]] = {}
@@ -1098,7 +1098,7 @@ def compile_or_get_cached(backend, computation: ir.Module, compile_options,
 def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str,
                 compile_options: CompileOptions,
                 backend: Backend) -> Optional[XlaLoadedExecutable]:
-  """Looks up `computation` in the persistent compilation cache."""
+  """Looks up `computation` in the persisent compilation cache."""
   # Avoid import cycle between jax and jax.experimental
   from jax.experimental.compilation_cache import compilation_cache as cc
 
@@ -1304,11 +1304,8 @@ def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol, device_a
 
 def _device_put_jax_array(x, device: Optional[Device]):
   if is_single_device_sharding(x.sharding):
-    if xla_extension_version >= 128:
-      return (_copy_array_to_device(x, device),)
-    else:
-      x = _copy_device_array_to_device(_set_aval(x._arrays[0]), device)
-      return (x,)
+    x = _copy_device_array_to_device(_set_aval(x._arrays[0]), device)
+    return (x,)
   else:
     # Round trip via host if x is sharded. SDA also does a round trip via host.
     return _device_put_array(x._value, device)
@@ -1330,7 +1327,7 @@ def _copy_device_array_to_device(
     # source and target platforms are the same
     if x.device_buffer.device() == device:
       # no copying to be done because source equals target
-      if x.device == device:
+      if x._device == device:
         return x
       else:
         moved_buf = x.device_buffer  # We need to change stickyness
@@ -1350,34 +1347,25 @@ def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Arra
     # no copying to be done because there's no target specified
     return x
 
-  arr = x._arrays[0]
-  if xb.get_device_backend(device).platform == arr.platform():
+  buf = x._arrays[0]
+  if xb.get_device_backend(device).platform == buf.platform():
     # source and target platforms are the same
     if x.device() == device:
       # no copying to be done because source equals target
       if x._committed:
         return x
       else:
-        if isinstance(arr, array.ArrayImpl):
-          # Copy to device with the same device will update the stickyness.
-          moved_array = arr.copy_to_device(device)
-        else:
-          moved_array = arr  # We need to change stickyness
+        moved_buf = buf  # We need to change stickyness
     else:
       # move the buffer with a device-to-device copy
-      moved_array = arr.copy_to_device(device)
+      moved_buf = buf.copy_to_device(device)
   else:
     # buffers from different XLA backends are passed through the host.
     backend = xb.get_device_backend(device)
-    moved_array = backend.buffer_from_pyval(np.asarray(arr), device)
-    if isinstance(moved_array, array.ArrayImpl):
-      return moved_array
+    moved_buf = backend.buffer_from_pyval(np.asarray(buf), device)
   return array.ArrayImpl(
-      x.aval,
-      SingleDeviceSharding(moved_array.device()),
-      [moved_array],
-      committed=(device is not None),
-  )
+      x.aval, SingleDeviceSharding(moved_buf.device()), [moved_buf],
+      committed=(device is not None))
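
_copy_array_to_device above separates a real transfer from a mere committedness ("stickyness") change: putting an uncommitted array onto the device it already occupies just rewraps the same buffer as committed, while a cross-device put does a device-to-device copy. A usage sketch (assumes the default device is jax.devices()[0]; the commented line needs a second device):

import jax
import jax.numpy as jnp

x = jnp.arange(4)                        # uncommitted, on the default device
y = jax.device_put(x, jax.devices()[0])  # same device: rewrapped as committed, no copy
# z = jax.device_put(x, jax.devices()[1])  # other device: device-to-device copy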
8 changes: 4 additions & 4 deletions jax/_src/sharding.py

@@ -285,7 +285,7 @@ class NamedSharding(XLACompatibleSharding):
   >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
   """
 
-  @use_cpp_method()
+  @use_cpp_method
   def __init__(
       self, mesh: pxla.Mesh, spec: PartitionSpec, _parsed_pspec = None):
 
@@ -396,7 +396,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
   ...     jax.devices()[0])
   """
 
-  @use_cpp_method()
+  @use_cpp_method
   def __init__(self, device: Device):
     self._device = device
 
@@ -434,7 +434,7 @@ def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
 @use_cpp_class(xc.PmapSharding)
 class PmapSharding(XLACompatibleSharding):
 
-  @use_cpp_method()
+  @use_cpp_method
   def __init__(self, devices: np.ndarray, sharding_spec: pxla.ShardingSpec):
     self.devices = devices
     # The sharding spec should be pmap's sharding spec.
@@ -647,7 +647,7 @@ def __eq__(self, other) -> bool:
 @use_cpp_class(xc.OpShardingSharding)
 class OpShardingSharding(XLACompatibleSharding):
 
-  @use_cpp_method()
+  @use_cpp_method
   def __init__(self, devices: Sequence[Device], op_sharding: xc.OpSharding):
     self._devices = tuple(devices)
     self._op_sharding = op_sharding
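
The decorator change is invisible to callers; construction still works as in the NamedSharding docstring above. A sketch expanding that docstring (it assumes jax.sharding also exports Mesh and PartitionSpec at this revision):

import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding

mesh = Mesh(np.array(jax.devices()), axis_names=('x',))  # 1-D mesh over all devices
spec = PartitionSpec('x')                                # shard dim 0 along 'x'
named_sharding = NamedSharding(mesh, spec)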
17 changes: 4 additions & 13 deletions jax/_src/util.py

@@ -539,17 +539,8 @@ def wrapper(cls):
 
   return wrapper
 
-
-def use_cpp_method(is_enabled=True):
+def use_cpp_method(f):
   """A helper decorator to exclude methods from the set that are forwarded to C++ class"""
-  def decorator(f):
-    if is_enabled:
-      original_func = _original_func(f)
-      original_func._use_cpp = True
-    return f
-
-  if not isinstance(is_enabled, bool):
-    raise TypeError(
-        "Decorator got wrong type: @use_cpp_method(is_enabled: bool=True)"
-    )
-  return decorator
+  original_func = _original_func(f)
+  original_func._use_cpp = True
+  return f
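
After the rollback, use_cpp_method is a bare decorator again: it only tags the wrapped function, and, per the docstring above, use_cpp_class skips tagged methods when copying Python methods onto the C++ class, so the C++ implementations win for them. The parameterized form deleted here existed so call sites could make the tag conditional on the installed jaxlib, as in the @use_cpp_method(xla_extension_version >= 128) lines removed from jax/_src/array.py above. A runnable toy of the tagging mechanism (Toy and is_deleted are illustrative only; the real decorator tags the function unwrapped by _original_func):

def use_cpp_method(f):
  f._use_cpp = True  # simplified: the real version tags _original_func(f)
  return f

class Toy:
  @use_cpp_method
  def is_deleted(self):
    return False

# The tag is what a use_cpp_class-style wrapper would check:
assert getattr(Toy.is_deleted, "_use_cpp", False)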