diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index c27d72adafca..c96374636d61 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -16,7 +16,11 @@ import abc import numpy as np -from typing import Union +from typing import Any, Sequence, Union + +# TODO(jakevdp): fix import cycles and define these. +Shard = Any +Sharding = Any # Array is a type annotation for standard JAX arrays and tracers produced by # core functions in jax.lax and jax.numpy; it is not meant to include @@ -46,11 +50,64 @@ def f(x: Array) -> Array: # type annotations are valid for traced and non-trace __slots__ = ['__weakref__'] - # at property must be defined because we overwrite its docstring in - # lax_numpy.py @property - def at(self): - raise NotImplementedError("property must be defined in subclasses") + @abc.abstractmethod + def dtype(self) -> np.dtype: + """The data type (:class:`numpy.dtype`) of the array.""" + + @property + @abc.abstractmethod + def ndim(self) -> int: + """The number of dimensions in the array.""" + + @property + @abc.abstractmethod + def size(self) -> int: + """The total number of elements in the array.""" + + @property + @abc.abstractmethod + def shape(self) -> tuple[int, ...]: + """The shape of the array.""" + + # Documentation for sharding-related methods and properties defined on ArrayImpl: + @abc.abstractmethod + def addressable_data(self, index: int) -> "Array": + """Return an array of the addressable data at a particular index.""" + + @property + @abc.abstractmethod + def addressable_shards(self) -> Sequence[Shard]: + """List of addressable shards.""" + + @property + @abc.abstractmethod + def global_shards(self) -> Sequence[Shard]: + """List of global shards.""" + + @property + @abc.abstractmethod + def is_fully_addressable(self) -> bool: + """Is this Array fully addressable? + + A jax.Array is fully addressable if the current process can address all of + the devices named in the :class:`Sharding`. ``is_fully_addressable`` is + equivalent to "is_local" in multi-process JAX. + + Note that fully replicated is not equal to fully addressable i.e. + a jax.Array which is fully replicated can span across multiple hosts and is + not fully addressable. + """ + + @property + @abc.abstractmethod + def is_fully_replicated(self) -> bool: + """Is this Array fully replicated?""" + + @property + @abc.abstractmethod + def sharding(self) -> Sharding: + """The sharding for the array.""" Array.__module__ = "jax" diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index 437bc594a64d..b2e5d934028a 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -16,7 +16,6 @@ from typing import Any, Callable, Optional, Sequence, Union import numpy as np from jax._src.sharding import Sharding -from jax._src import lib Shard = Any @@ -43,12 +42,6 @@ class Array(abc.ABC): @property def shape(self) -> tuple[int, ...]: ... - @property - def sharding(self) -> Sharding: ... - - @property - def addressable_shards(self) -> Sequence[Shard]: ... - def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None, order=None): raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly." @@ -205,6 +198,10 @@ class Array(abc.ABC): def device(self) -> Device: ... def devices(self) -> set[Device]: ... @property + def sharding(self) -> Sharding: ... + @property + def addressable_shards(self) -> Sequence[Shard]: ... + @property def global_shards(self) -> Sequence[Shard]: ... def is_deleted(self) -> bool: ... @property diff --git a/jax/_src/core.py b/jax/_src/core.py index 6ceb3d48ec9f..fdae6f39c750 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -626,11 +626,22 @@ def check_bool_conversion(arr: Array, warn_on_empty=False): "ambiguous. Use a.any() or a.all()") +def _aval_property(name): + return property(lambda self: getattr(self.aval, name)) + class Tracer(typing.Array): __array_priority__ = 1000 __slots__ = ['_trace', '_line_info'] + dtype = _aval_property('dtype') + ndim = _aval_property('ndim') + size = _aval_property('size') + shape = _aval_property('shape') + + def __init__(self, trace: Trace): + self._trace = trace + def _error_repr(self): if self.aval is None: return f"traced array with aval {self.aval}" @@ -655,9 +666,6 @@ def tobytes(self, order="C"): f"The tobytes() method was called on {self._error_repr()}." f"{self._origin_msg()}") - def __init__(self, trace: Trace): - self._trace = trace - def __iter__(self): return iter(self.aval._iter(self)) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 84b34d5fc32e..8f3043b096e9 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -240,6 +240,7 @@ def _view(arr: Array, dtype: Optional[DTypeLike] = None, type: None = None) -> A def _notimplemented_flat(self): + """Not implemented: Use :meth:`~jax.Array.flatten` instead.""" raise NotImplementedError("JAX Arrays do not implement the arr.flat property: " "consider arr.flatten() instead.") @@ -800,5 +801,3 @@ def register_jax_array_methods(): _set_array_attributes(ArrayImpl) _set_array_abstract_methods(Array) - - Array.at.__doc__ = _IndexUpdateHelper.__doc__