Improve the sharding documentation.
* do some proofreading.
* add PmapSharding and GSPMDSharding, which are both missing.
hawkinsp committed Aug 3, 2023
1 parent a184b5e commit d4336c1
Showing 4 changed files with 70 additions and 48 deletions.
11 changes: 10 additions & 1 deletion docs/jax.sharding.rst
@@ -13,10 +13,19 @@ Classes
.. autoclass:: XLACompatibleSharding
:members:
:show-inheritance:
.. autoclass:: SingleDeviceSharding
:members:
:show-inheritance:
.. autoclass:: NamedSharding
:members:
:show-inheritance:
.. autoclass:: SingleDeviceSharding
.. autoclass:: PositionalSharding
:members:
:show-inheritance:
.. autoclass:: PmapSharding
:members:
:show-inheritance:
.. autoclass:: GSPMDSharding
:members:
:show-inheritance:
.. autoclass:: PartitionSpec
10 changes: 5 additions & 5 deletions jax/_src/partition_spec.py
@@ -25,13 +25,13 @@ def __str__(self):


class PartitionSpec(tuple):
"""Tuple describing how to partition tensor into mesh .
"""Tuple describing how to partition an array across a mesh of devices.
Each element is either None, string or a tuple of strings.
See``NamedSharding`` class for more details.
Each element is either ``None``, a string, or a tuple of strings.
See the documentation of :class:`jax.sharding.NamedSharding` for more details.
We create a separate class for this so JAX's pytree utilities can distinguish
it from a tuple that should be treated as a pytree.
This class exists so JAX's pytree utilities can distinguish a partition
specification from tuples that should be treated as pytrees.
"""

# A sentinel value representing a dim is unconstrained.
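For reference, a partition spec is built directly from mesh axis names; a minimal sketch, where the axis names ``'x'`` and ``'y'`` are placeholders for whatever mesh axes are in use::

    from jax.sharding import PartitionSpec as P

    # Shard dimension 0 across mesh axis 'x'; replicate dimension 1.
    spec = P('x', None)

    # Shard dimension 0 jointly across mesh axes 'x' and 'y'.
    joint_spec = P(('x', 'y'))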
42 changes: 25 additions & 17 deletions jax/_src/sharding.py
@@ -38,14 +38,13 @@ def _addressable_devices_indices_map(

@util.use_cpp_class(xc.Sharding)
class Sharding:
"""Abstract ``Sharding`` interface which describes how a ``jax.Array`` is laid out
across devices.
"""Describes how a :class:`jax.Array` is laid out across devices.
"""

# Abstract methods below that subclasses should implement.
@property
def device_set(self) -> set[Device]:
"""A ``set`` of global devices that this ``Sharding`` spans.
"""The set of devices that this :class:`Sharding` spans.
In multi-controller JAX, the set of devices is global, i.e., includes
non-addressable devices from other processes.
Expand All @@ -54,35 +53,39 @@ def device_set(self) -> set[Device]:

def devices_indices_map(
self, global_shape: Shape) -> Mapping[Device, Index | None]:
"""A global mapping from device to the slice of the global data it contains.
"""Returns a mapping from devices to the array slices each contains.
The devices in this mapping are global devices i.e. includes
The mapping includes all global devices, i.e., including
non-addressable devices from other processes.
"""
raise NotImplementedError('Subclasses should implement this method.')

def shard_shape(self, global_shape: Shape) -> Shape:
"""Returns the shape of the data on each device.
The shard shape returned by this function is calculated from the global
shape (it takes as an input) and the properties of the sharding.
The shard shape returned by this function is calculated from
``global_shape`` and the properties of the sharding.
"""
raise NotImplementedError('Subclasses should implement this method.')

def is_equivalent_to(self, other: Sharding, ndim: int) -> bool:
"""Returns True if two shardings put the same logical array
(sharded/unsharded) on the same device(s).
"""Returns ``True`` if two shardings are equivalent.
For example, every XLACompatibleSharding lowers to GSPMDSharding which
is a general representation. So `jax.sharding.NamedSharding` is equivalent
to `jax.sharding.PositionalSharding` if both of them lower to the same
GSPMDSharding.
Two shardings are equivalent if they place the same logical array shards on
the same devices.
For example, a :class:`NamedSharding` may be equivalent
to a :class:`PositionalSharding` if both place the same shards of the array
on the same devices.
"""
raise NotImplementedError('Subclasses should implement this method.')

@property
def is_fully_replicated(self) -> bool:
"""Returns if a sharding is fully replicated on all the devices."""
"""Is this sharding fully replicated?
A sharding is fully replicated if each device has a complete copy of the
entire data."""
raise NotImplementedError('Subclasses should implement this method.')

@property
@@ -95,7 +98,9 @@ def memory_kind(self) -> str | None:

@functools.cached_property
def addressable_devices(self) -> set[Device]:
"""A set of devices that are addressable by the current process."""
"""The set of devices in the :class:`Sharding` that are addressable by the
current process.
"""
# Add a fast path for single controller runtimes.
if xb.process_count() == 1:
return self.device_set
@@ -104,14 +109,17 @@ def addressable_devices(self) -> set[Device]:

@functools.cached_property
def is_fully_addressable(self) -> bool:
"""True if the current process can address all of the devices in device_set.
"""Is this sharding fully addressable?
A sharding is fully addressable if the current process can address all of
the devices named in the :class:`Sharding`.
"""
# The pytype disable is because pytype can't recognize a cached property.
return len(self.device_set) == len(self.addressable_devices) # type: ignore

def addressable_devices_indices_map(
self, global_shape: Shape) -> Mapping[Device, Index | None]:
"""A mapping from addressable device to the slice of global data it contains.
"""A mapping from addressable devices to the slice of array data each contains.
``addressable_devices_indices_map`` contains that part of
``devices_indices_map`` that applies to the addressable devices.
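As a quick illustration of the properties documented above, the sharding attached to any ``jax.Array`` can be inspected directly; a sketch using the default single-device placement::

    import jax
    import jax.numpy as jnp

    x = jnp.arange(16).reshape(4, 4)       # placed on the default device
    s = x.sharding

    print(s.device_set)                    # devices the sharding spans
    print(s.addressable_devices)           # devices visible to this process
    print(s.is_fully_replicated)           # True for a single-device placement
    print(s.shard_shape(x.shape))          # per-device shard shape, here (4, 4)
    print(s.devices_indices_map(x.shape))  # device -> index slices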
55 changes: 30 additions & 25 deletions jax/_src/sharding_impls.py
@@ -22,7 +22,6 @@
import functools
import itertools
import math
import operator as op
from typing import Any, NamedTuple, Union, cast

from jax._src import mesh as mesh_lib
@@ -51,10 +50,10 @@
# `_device_assignment` property and `_to_xla_hlo_sharding` method.
@use_cpp_class(xc.XLACompatibleSharding)
class XLACompatibleSharding(sharding.Sharding):
"""A `Sharding` that describes shardings expressible to XLA.
"""A :class:`Sharding` that describes shardings expressible to XLA.
Any ``Sharding`` that is a subclass of ``XLACompatibleSharding`` will work
with all JAX APIs and transformations that use XLA.
Subclasses of :class:`XLACompatibleSharding` work with
all JAX APIs and transformations that use XLA.
"""

# Abstract methods below that subclasses should implement.
@@ -157,29 +156,30 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]

@use_cpp_class(xc.NamedSharding)
class NamedSharding(XLACompatibleSharding):
r"""NamedSharding is a way to express ``Sharding``\s using named axes.
r"""A :class:`NamedSharding` expresses sharding using named axes.
``Mesh`` and ``PartitionSpec`` can be used to express a ``Sharding`` with a name.
A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and
:class:`PartitionSpec` which describes how to shard an array across that
mesh.
``Mesh`` is a NumPy array of JAX devices in a multi-dimensional grid,
where each axis of the mesh has a name, e.g. 'x' or 'y'. Another name for
``Mesh`` is "logical mesh".
A :class:`Mesh` is a multidimensional NumPy array of JAX devices,
where each axis of the mesh has a name, e.g. ``'x'`` or ``'y'``.
``PartitionSpec`` is a tuple, whose elements can be a ``None``,
a mesh axis or a tuple of mesh axes. Each element describes how an input
A :class:`PartitionSpec` is a tuple, whose elements can be a ``None``,
a mesh axis, or a tuple of mesh axes. Each element describes how an input
dimension is partitioned across zero or more mesh dimensions. For example,
PartitionSpec('x', 'y') is a PartitionSpec where the first dimension of data
``PartitionSpec('x', 'y')`` says that the first dimension of data
is sharded across ``x`` axis of the mesh, and the second dimension is sharded
across ``y`` axis of the mesh.
The Distributed arrays and automatic parallelization
(https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names)
goes into more details and has diagrams to help explain the concept about
``Mesh`` and ``PartitionSpec``.
tutorial has more details and diagrams that explain how
:class:`Mesh` and :class:`PartitionSpec` are used.
Args:
mesh: A ``jax.sharding.Mesh`` object.
spec: A ``jax.sharding.PartitionSpec`` object.
mesh: A :class:`jax.sharding.Mesh` object.
spec: A :class:`jax.sharding.PartitionSpec` object.
Example:
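A minimal sketch of constructing such a sharding, assuming eight available devices arranged into a 4x2 mesh; the mesh shape and the axis names ``'x'`` and ``'y'`` are placeholders::

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Arrange the available devices into a 4x2 logical mesh.
    devices = np.array(jax.devices()).reshape(4, 2)
    mesh = Mesh(devices, axis_names=('x', 'y'))

    # Shard dimension 0 across 'x' and dimension 1 across 'y'.
    sharding = NamedSharding(mesh, P('x', 'y'))
    x = jax.device_put(np.arange(32.0).reshape(8, 4), sharding)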
@@ -334,7 +334,7 @@ def get_replicated_hlo_sharding():

@use_cpp_class(xc.SingleDeviceSharding)
class SingleDeviceSharding(XLACompatibleSharding):
"""A subclass of ``XLACompatibleSharding`` that places its data on a single device.
"""A :class:`Sharding` that places its data on a single device.
Args:
device: A single :py:class:`Device`.
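A minimal usage sketch::

    import jax
    import jax.numpy as jnp
    from jax.sharding import SingleDeviceSharding

    s = SingleDeviceSharding(jax.devices()[0])
    x = jax.device_put(jnp.ones((4, 4)), s)  # commits x to that device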
@@ -398,6 +398,7 @@ def is_fully_replicated(self) -> bool:

@use_cpp_class(xc.PmapSharding)
class PmapSharding(XLACompatibleSharding):
"""Describes a sharding used by :func:`jax.pmap`."""
devices: np.ndarray
sharding_spec: sharding_specs.ShardingSpec

@@ -443,16 +444,15 @@ def is_equivalent_to(self: PmapSharding, other: PmapSharding,  # type: ignore
@classmethod
def default(cls, shape: Shape, sharded_dim: int = 0,
devices: Sequence[xc.Device] | None = None) -> PmapSharding:
"""Creates a `PmapSharding` which matches the implicit device order used by
`pmap` if devices is None. If devices is specified, it will use those
devices.
"""Creates a :class:`PmapSharding` which matches the default placement
used by :func:`jax.pmap`.
Args:
shape: The shape of the input array.
sharded_dim: Dimension the input array is sharded on. Defaults to 0.
devices: Optional sequence of devices used to create PmapSharding. If not
specified, it will use the implicit device order used by pmap which is
the order of jax.local_devices()
devices: Optional sequence of devices to use. If omitted, the implicit
device order used by pmap is used, which is the order of
:func:`jax.local_devices`.
"""
# The dtype doesn't matter here. It's only used for creating the
# sharding_spec.
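As a sketch of how ``default`` might be called; the trailing dimension of 4 is purely illustrative::

    import jax
    from jax.sharding import PmapSharding

    n = jax.local_device_count()
    # Matches the placement jax.pmap would use for an input of shape (n, 4),
    # split along dimension 0 across jax.local_devices().
    s = PmapSharding.default((n, 4), sharded_dim=0)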
@@ -571,8 +571,13 @@ def __init__(self, devices: Sequence[xc.Device] | np.ndarray,
# Will error if memory_kind does not exist on the device.
self._devices[0].memory(self._memory_kind)

shape = property(op.attrgetter('_ids.shape'))
ndim = property(op.attrgetter('_ids.ndim'))
@property
def shape(self):
return self._ids.shape

@property
def ndim(self):
return self._ids.ndim

def __repr__(self) -> str:
cls_name = self.__class__.__name__
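For context, the ``shape`` and ``ndim`` properties above reflect how a ``PositionalSharding``'s devices have been arranged; a sketch assuming eight available devices::

    import jax
    import numpy as np
    from jax.sharding import PositionalSharding

    s = PositionalSharding(jax.devices()).reshape(4, 2)
    print(s.shape)  # (4, 2)
    print(s.ndim)   # 2

    # Shard an (8, 2) array over the 4x2 device grid.
    x = jax.device_put(np.zeros((8, 2)), s)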
