Skip to content

Commit

Permalink
Ensure sharding-related array properties are documented
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Nov 3, 2023
1 parent c9db50c commit cd3ea05
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 17 deletions.
67 changes: 62 additions & 5 deletions jax/_src/basearray.py
Expand Up @@ -16,7 +16,11 @@

import abc
import numpy as np
from typing import Union
from typing import Any, Sequence, Union

# TODO(jakevdp): fix import cycles and define these.
Shard = Any
Sharding = Any

# Array is a type annotation for standard JAX arrays and tracers produced by
# core functions in jax.lax and jax.numpy; it is not meant to include
Expand Down Expand Up @@ -46,11 +50,64 @@ def f(x: Array) -> Array: # type annotations are valid for traced and non-trace

__slots__ = ['__weakref__']

# at property must be defined because we overwrite its docstring in
# lax_numpy.py
@property
def at(self):
raise NotImplementedError("property must be defined in subclasses")
@abc.abstractmethod
def dtype(self) -> np.dtype:
"""The data type (:class:`numpy.dtype`) of the array."""

@property
@abc.abstractmethod
def ndim(self) -> int:
"""The number of dimensions in the array."""

@property
@abc.abstractmethod
def size(self) -> int:
"""The total number of elements in the array."""

@property
@abc.abstractmethod
def shape(self) -> tuple[int, ...]:
"""The shape of the array."""

# Documentation for sharding-related methods and properties defined on ArrayImpl:
@abc.abstractmethod
def addressable_data(self, index: int) -> "Array":
"""Return an array of the addressable data at a particular index."""

@property
@abc.abstractmethod
def addressable_shards(self) -> Sequence[Shard]:
"""List of addressable shards."""

@property
@abc.abstractmethod
def global_shards(self) -> Sequence[Shard]:
"""List of global shards."""

@property
@abc.abstractmethod
def is_fully_addressable(self) -> bool:
"""Is this Array fully addressable?
A jax.Array is fully addressable if the current process can address all of
the devices named in the :class:`Sharding`. ``is_fully_addressable`` is
equivalent to "is_local" in multi-process JAX.
Note that fully replicated is not equal to fully addressable i.e.
a jax.Array which is fully replicated can span across multiple hosts and is
not fully addressable.
"""

@property
@abc.abstractmethod
def is_fully_replicated(self) -> bool:
"""Is this Array fully replicated?"""

@property
@abc.abstractmethod
def sharding(self) -> Sharding:
"""The sharding for the array."""


Array.__module__ = "jax"
Expand Down
11 changes: 4 additions & 7 deletions jax/_src/basearray.pyi
Expand Up @@ -16,7 +16,6 @@ from typing import Any, Callable, Optional, Sequence, Union
import numpy as np

from jax._src.sharding import Sharding
from jax._src import lib

Shard = Any

Expand All @@ -43,12 +42,6 @@ class Array(abc.ABC):
@property
def shape(self) -> tuple[int, ...]: ...

@property
def sharding(self) -> Sharding: ...

@property
def addressable_shards(self) -> Sequence[Shard]: ...

def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
order=None):
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
Expand Down Expand Up @@ -205,6 +198,10 @@ class Array(abc.ABC):
def device(self) -> Device: ...
def devices(self) -> set[Device]: ...
@property
def sharding(self) -> Sharding: ...
@property
def addressable_shards(self) -> Sequence[Shard]: ...
@property
def global_shards(self) -> Sequence[Shard]: ...
def is_deleted(self) -> bool: ...
@property
Expand Down
14 changes: 11 additions & 3 deletions jax/_src/core.py
Expand Up @@ -626,11 +626,22 @@ def check_bool_conversion(arr: Array, warn_on_empty=False):
"ambiguous. Use a.any() or a.all()")


def _aval_property(name):
return property(lambda self: getattr(self.aval, name))


class Tracer(typing.Array):
__array_priority__ = 1000
__slots__ = ['_trace', '_line_info']

dtype = _aval_property('dtype')
ndim = _aval_property('ndim')
size = _aval_property('size')
shape = _aval_property('shape')

def __init__(self, trace: Trace):
self._trace = trace

def _error_repr(self):
if self.aval is None:
return f"traced array with aval {self.aval}"
Expand All @@ -655,9 +666,6 @@ def tobytes(self, order="C"):
f"The tobytes() method was called on {self._error_repr()}."
f"{self._origin_msg()}")

def __init__(self, trace: Trace):
self._trace = trace

def __iter__(self):
return iter(self.aval._iter(self))

Expand Down
3 changes: 1 addition & 2 deletions jax/_src/numpy/array_methods.py
Expand Up @@ -240,6 +240,7 @@ def _view(arr: Array, dtype: Optional[DTypeLike] = None, type: None = None) -> A


def _notimplemented_flat(self):
"""Not implemented: Use :meth:`~jax.Array.flatten` instead."""
raise NotImplementedError("JAX Arrays do not implement the arr.flat property: "
"consider arr.flatten() instead.")

Expand Down Expand Up @@ -800,5 +801,3 @@ def register_jax_array_methods():
_set_array_attributes(ArrayImpl)

_set_array_abstract_methods(Array)

Array.at.__doc__ = _IndexUpdateHelper.__doc__

0 comments on commit cd3ea05

Please sign in to comment.