diff --git a/docs/jax.sharding.rst b/docs/jax.sharding.rst
index e627eb6254c3..7b1393d8e2c4 100644
--- a/docs/jax.sharding.rst
+++ b/docs/jax.sharding.rst
@@ -13,10 +13,19 @@ Classes
 .. autoclass:: XLACompatibleSharding
   :members:
   :show-inheritance:
+.. autoclass:: SingleDeviceSharding
+  :members:
+  :show-inheritance:
 .. autoclass:: NamedSharding
   :members:
   :show-inheritance:
-.. autoclass:: SingleDeviceSharding
+.. autoclass:: PositionalSharding
+  :members:
+  :show-inheritance:
+.. autoclass:: PmapSharding
+  :members:
+  :show-inheritance:
+.. autoclass:: GSPMDSharding
   :members:
   :show-inheritance:
 .. autoclass:: PartitionSpec
diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py
index b24a6c4d67cf..05200fd347a7 100644
--- a/jax/_src/partition_spec.py
+++ b/jax/_src/partition_spec.py
@@ -25,13 +25,13 @@ def __str__(self):
 
 
 class PartitionSpec(tuple):
-  """Tuple describing how to partition tensor into mesh .
+  """Tuple describing how to partition an array across a mesh of devices.
 
-  Each element is either None, string or a tuple of strings.
-  See``NamedSharding`` class for more details.
+  Each element is either ``None``, a string, or a tuple of strings.
+  See the documentation of :class:`jax.sharding.NamedSharding` for more details.
 
-  We create a separate class for this so JAX's pytree utilities can distinguish
-  it from a tuple that should be treated as a pytree.
+  This class exists so JAX's pytree utilities can distinguish partition
+  specifications from tuples that should be treated as pytrees.
   """
 
   # A sentinel value representing a dim is unconstrained.
diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py
index b5d74e8fb5a2..c4624ccfb9d9 100644
--- a/jax/_src/sharding.py
+++ b/jax/_src/sharding.py
@@ -38,14 +38,13 @@ def _addressable_devices_indices_map(
 @util.use_cpp_class(xc.Sharding)
 class Sharding:
-  """Abstract ``Sharding`` interface which describes how a ``jax.Array`` is laid out
-  across devices.
+  """Describes how a :class:`jax.Array` is laid out across devices.
   """
 
   # Abstract methods below that subclasses should implement.
 
   @property
   def device_set(self) -> set[Device]:
-    """A ``set`` of global devices that this ``Sharding`` spans.
+    """The set of devices that this :class:`Sharding` spans.
 
     In multi-controller JAX, the set of devices is global, i.e., includes
     non-addressable devices from other processes.
@@ -54,9 +53,9 @@ def device_set(self) -> set[Device]:
 
   def devices_indices_map(
       self, global_shape: Shape) -> Mapping[Device, Index | None]:
-    """A global mapping from device to the slice of the global data it contains.
+    """Returns a mapping from devices to the array slices each contains.
 
-    The devices in this mapping are global devices i.e. includes
+    The mapping includes all global devices, including
     non-addressable devices from other processes.
     """
     raise NotImplementedError('Subclasses should implement this method.')
@@ -64,25 +63,29 @@ def devices_indices_map(
   def shard_shape(self, global_shape: Shape) -> Shape:
     """Returns the shape of the data on each device.
 
-    The shard shape returned by this function is calculated from the global
-    shape (it takes as an input) and the properties of the sharding.
+    The shard shape returned by this function is calculated from
+    ``global_shape`` and the properties of the sharding.
     """
     raise NotImplementedError('Subclasses should implement this method.')
 
   def is_equivalent_to(self, other: Sharding, ndim: int) -> bool:
-    """Returns True if two shardings put the same logical array
-    (sharded/unsharded) on the same device(s).
+    """Returns ``True`` if two shardings are equivalent.
 
-    For example, every XLACompatibleSharding lowers to GSPMDSharding which
-    is a general representation. So `jax.sharding.NamedSharding` is equivalent
-    to `jax.sharding.PositionalSharding` if both of them lower to the same
-    GSPMDSharding.
+    Two shardings are equivalent if they place the same logical array shards on
+    the same devices.
+
+    For example, a :class:`NamedSharding` may be equivalent
+    to a :class:`PositionalSharding` if both place the same shards of the array
+    on the same devices.
     """
     raise NotImplementedError('Subclasses should implement this method.')
 
   @property
   def is_fully_replicated(self) -> bool:
-    """Returns if a sharding is fully replicated on all the devices."""
+    """Is this sharding fully replicated?
+
+    A sharding is fully replicated if each device has a complete copy of the
+    entire data."""
     raise NotImplementedError('Subclasses should implement this method.')
 
   @property
@@ -95,7 +98,9 @@ def memory_kind(self) -> str | None:
 
   @functools.cached_property
   def addressable_devices(self) -> set[Device]:
-    """A set of devices that are addressable by the current process."""
+    """The set of devices in the :class:`Sharding` that are addressable by the
+    current process.
+    """
     # Add a fast path for single controller runtimes.
     if xb.process_count() == 1:
       return self.device_set
@@ -104,14 +109,17 @@ def addressable_devices(self) -> set[Device]:
 
   @functools.cached_property
   def is_fully_addressable(self) -> bool:
-    """True if the current process can address all of the devices in device_set.
+    """Is this sharding fully addressable?
+
+    A sharding is fully addressable if the current process can address all of
+    the devices named in the :class:`Sharding`.
     """
     # The pytype disable is because pytype can't recognize a cached property.
     return len(self.device_set) == len(self.addressable_devices)  # type: ignore
 
   def addressable_devices_indices_map(
       self, global_shape: Shape) -> Mapping[Device, Index | None]:
-    """A mapping from addressable device to the slice of global data it contains.
+    """A mapping from addressable devices to the slice of array data each contains.
 
     ``addressable_devices_indices_map`` contains that part of
     ``device_indices_map`` that applies to the addressable devices.
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index 7eb14a7b424f..243ed5eb45ca 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -22,7 +22,6 @@
 import functools
 import itertools
 import math
-import operator as op
 from typing import Any, NamedTuple, Union, cast
 
 from jax._src import mesh as mesh_lib
@@ -51,10 +50,10 @@
 # `_device_assignment` property and `_to_xla_hlo_sharding` method.
 @use_cpp_class(xc.XLACompatibleSharding)
 class XLACompatibleSharding(sharding.Sharding):
-  """A `Sharding` that describes shardings expressible to XLA.
+  """A :class:`Sharding` that describes shardings expressible to XLA.
 
-  Any ``Sharding`` that is a subclass of ``XLACompatibleSharding`` will work
-  with all JAX APIs and transformations that use XLA.
+  Subclasses of :class:`XLACompatibleSharding` work with
+  all JAX APIs and transformations that use XLA.
   """
 
   # Abstract methods below that subclasses should implement.
@@ -157,29 +156,30 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int
 
 @use_cpp_class(xc.NamedSharding)
 class NamedSharding(XLACompatibleSharding):
-  r"""NamedSharding is a way to express ``Sharding``\s using named axes.
+  r"""A :class:`NamedSharding` expresses sharding using named axes.
 
-  ``Mesh`` and ``PartitionSpec`` can be used to express a ``Sharding`` with a name.
+  A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and a
+  :class:`PartitionSpec` which describes how to shard an array across that
+  mesh.
 
-  ``Mesh`` is a NumPy array of JAX devices in a multi-dimensional grid,
-  where each axis of the mesh has a name, e.g. 'x' or 'y'. Another name for
-  ``Mesh`` is "logical mesh".
+  A :class:`Mesh` is a multidimensional NumPy array of JAX devices,
+  where each axis of the mesh has a name, e.g. ``'x'`` or ``'y'``.
 
-  ``PartitionSpec`` is a tuple, whose elements can be a ``None``,
-  a mesh axis or a tuple of mesh axes. Each element describes how an input
+  A :class:`PartitionSpec` is a tuple, whose elements can be ``None``,
+  a mesh axis, or a tuple of mesh axes. Each element describes how an input
   dimension is partitioned across zero or more mesh dimensions. For example,
-  PartitionSpec('x', 'y') is a PartitionSpec where the first dimension of data
+  ``PartitionSpec('x', 'y')`` says that the first dimension of data
   is sharded across ``x`` axis of the mesh, and the second dimension is
   sharded across ``y`` axis of the mesh.
 
   The Distributed arrays and automatic parallelization
   (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names)
-  goes into more details and has diagrams to help explain the concept about
-  ``Mesh`` and ``PartitionSpec``.
+  tutorial has more details and diagrams that explain how
+  :class:`Mesh` and :class:`PartitionSpec` are used.
 
   Args:
-    mesh: A ``jax.sharding.Mesh`` object.
-    spec: A ``jax.sharding.PartitionSpec`` object.
+    mesh: A :class:`jax.sharding.Mesh` object.
+    spec: A :class:`jax.sharding.PartitionSpec` object.
 
   Example:
 
@@ -334,7 +334,7 @@ def get_replicated_hlo_sharding():
 
 @use_cpp_class(xc.SingleDeviceSharding)
 class SingleDeviceSharding(XLACompatibleSharding):
-  """A subclass of ``XLACompatibleSharding`` that places its data on a single device.
+  """A :class:`Sharding` that places its data on a single device.
 
   Args:
     device: A single :py:class:`Device`.
@@ -398,6 +398,7 @@ def is_fully_replicated(self) -> bool:
 
 @use_cpp_class(xc.PmapSharding)
 class PmapSharding(XLACompatibleSharding):
+  """Describes a sharding used by :func:`jax.pmap`."""
   devices: np.ndarray
   sharding_spec: sharding_specs.ShardingSpec
 
@@ -443,16 +444,15 @@ def is_equivalent_to(self: PmapSharding, other: PmapSharding,  # type: ignore
   @classmethod
   def default(cls, shape: Shape, sharded_dim: int = 0,
               devices: Sequence[xc.Device] | None = None) -> PmapSharding:
-    """Creates a `PmapSharding` which matches the implicit device order used by
-    `pmap` if devices is None. If devices is specified, it will use those
-    devices.
+    """Creates a :class:`PmapSharding` which matches the default placement
+    used by :func:`jax.pmap`.
 
     Args:
       shape: The shape of the input array.
      sharded_dim: Dimension the input array is sharded on. Defaults to 0.
-      devices: Optional sequence of devices used to create PmapSharding. If not
-        specified, it will use the implicit device order used by pmap which is
-        the order of jax.local_devices()
+      devices: Optional sequence of devices to use. If omitted, the implicit
+        device order of :func:`jax.pmap` is used, which is the order of
+        :func:`jax.local_devices`.
     """
     # The dtype doesn't matter here. Its only used for creating the
     # sharding_spec.
@@ -571,8 +571,13 @@ def __init__(self, devices: Sequence[xc.Device] | np.ndarray,
       # Will error if memory_kind does not exist on the device.
       self._devices[0].memory(self._memory_kind)
 
-  shape = property(op.attrgetter('_ids.shape'))
-  ndim = property(op.attrgetter('_ids.ndim'))
+  @property
+  def shape(self):
+    return self._ids.shape
+
+  @property
+  def ndim(self):
+    return self._ids.ndim
 
   def __repr__(self) -> str:
     cls_name = self.__class__.__name__
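
The snippet below is not part of the patch; it is a minimal sketch of how the classes whose docstrings are edited above fit together. It assumes eight local devices (for example, simulated on CPU with `XLA_FLAGS=--xla_force_host_platform_device_count=8`); the shard shape and index map noted in the comments follow from the 4x2 mesh chosen here.

```python
# Illustrative example only; not part of this patch.
# Assumes 8 devices, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=8.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec, PositionalSharding

# A Mesh is a NumPy array of devices with named axes; a PartitionSpec maps
# array dimensions onto those mesh axis names.
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), axis_names=('x', 'y'))
sharding = NamedSharding(mesh, PartitionSpec('x', 'y'))

x = jax.device_put(np.arange(32.0).reshape(8, 4), sharding)

# The Sharding interface documented above: per-device shard shape,
# device-to-slice mapping, and replication status.
print(sharding.shard_shape(x.shape))          # (2, 2)
print(sharding.is_fully_replicated)           # False
print(sharding.devices_indices_map(x.shape))  # {device: (slice, slice), ...}

# PositionalSharding exposes .shape and .ndim (plain properties in this patch).
positional = PositionalSharding(jax.devices()).reshape(4, 2)
print(positional.shape, positional.ndim)      # (4, 2) 2
```

With `PartitionSpec('x', None)` the second array dimension would instead be left unsharded, i.e. replicated across the `'y'` mesh axis.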