* Remove AUTO from MeshPspecSharding and treat it like _UNSPECIFIED singleton value.

* Support partial mentions of AUTO, which GDA currently supports and which is used in Pax; a usage sketch follows below. Added tests for all of this.
  * As a consequence of this, I lifted the restriction that `in_axis_resources` could not be provided to pjit under `config.jax_array`.

* Made all auto sharding tests parameterized to test both GDA and Array.
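
A minimal sketch of the usage this enables, under stated assumptions: the import paths, the 8-device mesh shape, the `jax_array` config call, and the lower()/compile() flow are typical of JAX around this change but are not taken from this commit. It mixes a concrete `MeshPspecSharding` with `pjit.AUTO` in `in_axis_resources`, leaving the AUTO entry to XLA's auto-SPMD partitioner.

    import numpy as np
    import jax
    from jax import core
    from jax.experimental import maps
    from jax.experimental.pjit import pjit, PartitionSpec as P, AUTO
    from jax.experimental.sharding import MeshPspecSharding

    # Assumption: the jax_array flag is flipped via the standard config API.
    jax.config.update('jax_array', True)

    # Assumption: 8 local devices; adjust the reshape to whatever is available.
    mesh = maps.Mesh(np.asarray(jax.devices()).reshape(2, 4), ('x', 'y'))
    explicit = MeshPspecSharding(mesh, P('x', 'y'))

    def f(x, y):
      return x @ y

    with mesh:
      # Mixing a concrete Sharding with pjit.AUTO in in_axis_resources is now
      # accepted under config.jax_array; previously in_axis_resources had to
      # be left unspecified entirely on the Array path.
      pjitted = pjit(f, in_axis_resources=(explicit, AUTO),
                     out_axis_resources=AUTO)
      # With AUTO present, the shardings come from the auto partitioner, so
      # the function is compiled ahead of time and the compiled object is then
      # called on inputs laid out per the shardings it reports.
      avals = (core.ShapedArray((8, 8), np.float32),
               core.ShapedArray((8, 8), np.float32))
      compiled = pjitted.lower(*avals, _global_avals=True).compile()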

PiperOrigin-RevId: 459776152
yashk2810 authored and jax authors committed Jul 8, 2022
1 parent 5a7bedc commit 229ddec
Showing 4 changed files with 263 additions and 106 deletions.
109 changes: 80 additions & 29 deletions jax/experimental/pjit.py
@@ -15,7 +15,7 @@
from enum import IntEnum
import numpy as np
from collections import OrderedDict, Counter
from typing import Callable, Sequence, Tuple, Union, Optional, cast, List, Iterable
from typing import Callable, Sequence, Tuple, Union, Optional, cast, List
import itertools as it
from functools import partial

@@ -265,20 +265,24 @@ def pjit(fun: Callable,
out_axis_resources, _, _, _ = _prepare_axis_resources(
out_axis_resources, "out_axis_resources")
else:
# TODO(yashkatariya): Relax this restriction once the transition to
# sharding instances finishes.
if not _is_unspecified(in_axis_resources):
raise ValueError('in_axis_resources should be empty for Array. The sharding '
'should be specified on the arguments as pjit follows '
'computation follows data semantics.')
# `pjit.AUTO` is allowed partially in `in_axis_resources` i.e. you can
# put sharding instances and `pjit.AUTO` together.
if not all(isinstance(s, Sharding) or _is_auto(s)
for s in tree_flatten(in_axis_resources)[0]):
raise ValueError('When `config.jax_array` flag is enabled, '
'in_axis_resources should contain instances of '
'`Sharding` or `pjit.AUTO`.')

# `out_axis_resources` should be instances of `Sharding` if it's not
# unspecified. For `AUTO` sharding, it can only be used with
# MeshPspecSharding.
if not _is_unspecified(out_axis_resources):
if not all(isinstance(s, Sharding) for s in tree_flatten(out_axis_resources)[0]):
if not all(isinstance(s, Sharding) or _is_auto(s)
for s in tree_flatten(out_axis_resources)[0]):
raise ValueError('When `config.jax_array` flag is enabled, '
'out_axis_resources should contain instances of '
'`Sharding`.')
'`Sharding` or `pjit.AUTO`.')

static_argnums = _ensure_index_tuple(static_argnums)
donate_argnums = _ensure_index_tuple(donate_argnums)
@@ -311,14 +315,10 @@ def infer_params(*args, _global_avals=False, **kwargs):
donated_invars = (False,) * len(args_flat)

if config.jax_array:
if any(not isinstance(a, Array) for a in args_flat):
raise ValueError('All arguments to pjit when `config.jax_array` is '
'enabled should be `Array`s.')
# tree_map over `dyn_args` to preserve the pytree structure of args.
in_shardings = tree_map(lambda x: x.sharding, dyn_args)
in_shardings = _get_and_check_in_shardings(
dyn_args, in_axis_resources, pjit_mesh, in_tree)
# TODO(yashkatariya): Add a device assignment check for out_shardings too.
out_shardings = out_axis_resources
# This function is cached which is an improvement over the old codepath.
_check_array_device_assignment(pjit_mesh, tuple(tree_flatten(in_shardings)[0]))
else:
in_shardings = tree_map(lambda x: _create_mesh_pspec_sharding(pjit_mesh, x),
in_axis_resources)
@@ -406,10 +406,38 @@ def hashable_pytree(pytree):
closure=(treedef, vals))


def _get_and_check_in_shardings(args, pjit_in_shardings, pjit_mesh, in_tree):
try:
# tree_map over `args` to preserve the pytree structure of args.
arg_in_shardings = tree_map(lambda x: x.sharding, args)
except AttributeError:
arg_in_shardings = None

arg_in_shardings_flat = tuple(tree_flatten(arg_in_shardings)[0])

if _is_unspecified(pjit_in_shardings):
if arg_in_shardings is None:
raise ValueError('Please specify sharding either on the args or on pjit.')
else:
# This function is cached.
_check_array_device_assignment(pjit_mesh, arg_in_shardings_flat)
return arg_in_shardings
else:
if arg_in_shardings is None:
# TODO(yashkatariya): Add a check here to check against the device
# assignment of all pjit_in_shardings.
return pjit_in_shardings
else:
# This function is cached.
_check_pjit_arg_shardings(
hashable_pytree(pjit_in_shardings), arg_in_shardings_flat, in_tree)
return arg_in_shardings

assert False, "Please open a bug report!" # This should be unreachable.


def _create_mesh_pspec_sharding(mesh, x):
if _is_unspecified(x):
return x
if _is_from_gda(x):
if _is_unspecified_or_from_gda_or_auto(x):
return x
return MeshPspecSharding._from_parsed_pspec(mesh, x)

@@ -486,17 +514,20 @@ def _process_in_axis_resources(in_shardings_thunk, local_in_avals,
# complexity below.
if config.jax_array:
for aval, i in safe_zip(local_in_avals, in_shardings_flat):
if _is_auto(i): continue
pjit_check_aval_sharding(i, aval, "pjit arguments",
allow_uneven_sharding=False)
global_in_avals = local_in_avals
return tuple(global_in_avals), tuple(i.normalize() for i in in_shardings_flat)
return tuple(global_in_avals), tuple(i if _is_auto(i) else i.normalize()
for i in in_shardings_flat)

if not local_in_avals:
assert not in_shardings_flat
return (), ()

in_axis_resources_flat = tuple(i if _is_from_gda(i) else i._parsed_pspec
for i in in_shardings_flat)
in_axis_resources_flat = tuple(
i if _is_from_gda(i) or _is_auto(i) else i._parsed_pspec
for i in in_shardings_flat)

# This check should be above local_to_global call below otherwise if
# `FROM_GDA` is passed to any input other than GDA, a ugly error message
@@ -518,17 +549,17 @@
# For example, partitions of () and ((),) are not equivalent, since the
# first one is a valid spec for a scalar value, while the second is not!
for aval, i in safe_zip(local_in_avals, in_shardings_flat):
if _is_from_gda(i): continue
if _is_from_gda(i) or _is_auto(i): continue
pjit_check_aval_sharding(i, aval, "pjit arguments",
allow_uneven_sharding=False)
else:
for aval, i in safe_zip(local_in_avals, in_shardings_flat):
if _is_from_gda(i): continue
if _is_from_gda(i) or _is_auto(i): continue
pjit_check_aval_sharding(i, aval, "pjit arguments",
allow_uneven_sharding=False, local=True)

normalized_in_shardings_flat = tuple(
i if _is_from_gda(i) else i.normalize() for i in in_shardings_flat)
i if _is_from_gda(i) or _is_auto(i) else i.normalize() for i in in_shardings_flat)

global_in_avals = local_to_global(in_positional_semantics,
local_in_avals, normalized_in_shardings_flat)
@@ -551,11 +582,11 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree):
"pjit out_axis_resources", out_tree(), out_shardings_thunk(), tupled_args=False)

for aval, o in safe_zip(global_out_avals, out_shardings_flat):
if _is_unspecified(o): continue
if _is_unspecified(o) or _is_auto(o): continue
pjit_check_aval_sharding(o, aval, "pjit outputs", allow_uneven_sharding=False)

normalized_out_shardings_flat = tuple(
o if _is_unspecified(o) else o.normalize()
o if _is_unspecified(o) or _is_auto(o) else o.normalize()
for o in out_shardings_flat
)

@@ -778,7 +809,7 @@ def _pjit_call_impl(*args, jaxpr,
_allow_propagation_to_outputs=_allow_propagation_to_outputs)
# This check is expensive so only do it if enable_checks is on.
if compiled._auto_spmd_lowering and config.jax_enable_checks:
pxla._check_gda_xla_sharding_match(args, compiled._in_axes)
pxla._check_gda_or_array_xla_sharding_match(args, compiled._in_axes)
if config.jax_distributed_debug:
# Defensively only perform fingerprint logic if debug logging is enabled
# NOTE(skyewm): I didn't benchmark this
@@ -808,8 +839,9 @@ def _pjit_lower(
# recompilation (since pjit_lower is cached) if its compiled with `None` but
# in the next call `P(None)` is passed. Those are the same thing so should be
# treat as equivalent and pjit_lower's cache shouldn't be invalidated.
in_axes = [get_array_mapping(i._parsed_pspec) for i in in_shardings]
out_axes = [get_array_mapping(o if _is_unspecified(o) else o._parsed_pspec)
in_axes = [get_array_mapping(i if _is_auto(i) else i._parsed_pspec)
for i in in_shardings]
out_axes = [get_array_mapping(o if _is_unspecified(o) or _is_auto(o) else o._parsed_pspec)
for o in out_shardings]
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
f = core.jaxpr_as_fun(jaxpr)
@@ -1259,6 +1291,8 @@ def _get_in_positional_semantics(arg) -> maps._PositionalSemantics:
def _maybe_replace_from_gda_with_pspec(
in_sharding_flat, arg) -> MeshPspecSharding:
if isinstance(arg, GDA):
if _is_auto(in_sharding_flat):
return in_sharding_flat
gda_cpspec = CanonicalizedParsedPartitionSpec(
ParsedPartitionSpec.from_user_input(arg.mesh_axes, arg_name="GDA spec"))
if (not _is_from_gda(in_sharding_flat) and
@@ -1296,6 +1330,23 @@ def _check_array_device_assignment(pjit_mesh, in_shardings):
f"Got Pjit devices: {list(pjit_mesh.devices.flat)},\n "
f"Array devices: {arr_device_assignment}")

@cache()
def _check_pjit_arg_shardings(pjit_in_shardings, arg_in_shardings_flat,
in_tree):
pjit_in_shardings_flat = flatten_axis_resources(
"pjit in_shardings", in_tree, pjit_in_shardings(), tupled_args=True)

if pxla._check_if_any_auto(pjit_in_shardings_flat):
raise ValueError('Passing sharding on pjit and on args while using the '
'auto spmd partitioner is not allowed. Please call the '
'compiled object on the inputs.')

for p, a in safe_zip(pjit_in_shardings_flat, arg_in_shardings_flat):
if p.normalize() != a.normalize():
raise ValueError('Sharding passed to pjit does not match the sharding '
'on the respective arg. '
f'Got pjit sharding: {p},\narg sharding: {a}')


def _maybe_check_pjit_gda_mesh(args, mesh):
for x in args:
23 changes: 6 additions & 17 deletions jax/experimental/sharding.py
@@ -14,7 +14,7 @@

import abc
from collections import Counter
from typing import Sequence, Tuple, Optional, Mapping, Dict, Set, Union
from typing import Sequence, Tuple, Optional, Mapping, Dict, Set

from jax._src.util import cache, safe_zip
from jax._src.lib import xla_bridge as xb
@@ -83,9 +83,7 @@ def _to_xla_op_sharding(self, num_dimensions: int) -> Optional[xc.OpSharding]:
class MeshPspecSharding(XLACompatibleSharding):

def __init__(
self, mesh: pxla.Mesh,
spec: Union[pxla.PartitionSpec, pxla._AUTOAxisResource],
_parsed_pspec = None):
self, mesh: pxla.Mesh, spec: pxla.PartitionSpec, _parsed_pspec = None):

self.mesh = mesh
self.spec = spec
@@ -115,17 +113,13 @@ def __eq__(self, other):

def normalize(self):
from jax.experimental import pjit
cp = (self._parsed_pspec if pjit._is_auto(self._parsed_pspec) else
pjit.CanonicalizedParsedPartitionSpec(self._parsed_pspec))
cp = pjit.CanonicalizedParsedPartitionSpec(self._parsed_pspec)
return MeshPspecSharding._from_parsed_pspec(self.mesh, cp)

@classmethod
def _from_parsed_pspec(cls, mesh, parsed_pspec):
from jax.experimental import pjit
parsed_pspec, spec = ((parsed_pspec, parsed_pspec)
if pjit._is_auto(parsed_pspec) else
(parsed_pspec, pjit._get_single_pspec(parsed_pspec)))
return cls(mesh, spec, parsed_pspec)
return cls(mesh, pjit._get_single_pspec(parsed_pspec), parsed_pspec)

@pxla.maybe_cached_property
def device_set(self) -> Set[Device]:
@@ -139,11 +133,6 @@ def devices_indices_map(
# TODO(yashkatariya): Remove this when utilities are moved to pxla.py.
from jax.experimental import global_device_array

if pxla._is_auto(self.spec):
raise ValueError('Getting indices when the sharding is not known is not '
'possible. Please get the sharding from XLA and then '
'create the Array with that sharding.')

# `get_shard_indices` is cached.
return global_device_array.get_shard_indices(global_shape, self.mesh, self.spec)

@@ -161,6 +150,7 @@ def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
out[device] = replica_id
return out

@cache()
def _device_assignment(self) -> XLADeviceAssignment:
return list(self.mesh.devices.flat)

@@ -170,8 +160,6 @@ def _to_xla_op_sharding(
axis_ctx: Optional[mlir.SPMDAxisContext] = None) -> Optional[xc.OpSharding]:
from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

assert not pxla._is_auto(self.spec)

parsed_spec, _, _, _ = _prepare_axis_resources(self.spec, "spec")
array_mapping = get_array_mapping(parsed_spec)
# TODO(yashkatariya): Move away from sharding spec in MeshPspecSharding
@@ -277,6 +265,7 @@ def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
out[device] = replica_id
return out

@cache()
def _device_assignment(self) -> XLADeviceAssignment:
return list(self.devices.flat)

23 changes: 15 additions & 8 deletions jax/interpreters/pxla.py
@@ -2534,7 +2534,7 @@ def call(self, *args):
ref_avals = self._input_avals
dispatch.check_arg_avals_for_call(ref_avals, arg_avals)
# Check the GDA sharding and the input sharding.
_check_gda_xla_sharding_match(args, self._in_axes)
_check_gda_or_array_xla_sharding_match(args, self._in_axes)
return self.unsafe_call(*args)


@@ -2545,18 +2545,25 @@ def _get_mesh_pspec_sharding(mesh, out_axes):
for o in out_axes]


def _check_gda_xla_sharding_match(args, in_array_mappings):
def _check_gda_or_array_xla_sharding_match(args, in_array_mappings):
from jax.experimental.global_device_array import GlobalDeviceArray, _get_array_mapping
from jax.experimental.array import Array

for arg, inp_array_mapping in safe_zip(args, in_array_mappings):
if not isinstance(arg, GlobalDeviceArray):
if not isinstance(arg, (GlobalDeviceArray, Array)):
continue
gda_array_mapping = _get_array_mapping(arg.mesh_axes)
if inp_array_mapping != gda_array_mapping:
# TODO(yashkatariya): For `Array` check the `sharding` directly when pxla
# takes sharding instances.
arr_type, arr_mapping = (
('GDA', _get_array_mapping(arg.mesh_axes)) if isinstance(arg, GlobalDeviceArray)
else ('Array', _get_array_mapping(arg.sharding.spec))
)
if inp_array_mapping != arr_mapping:
raise ValueError(
"GDA sharding does not match the input sharding. "
f"Got GDA spec: {array_mapping_to_axis_resources(gda_array_mapping)} and "
f"auto sharding spec: {array_mapping_to_axis_resources(inp_array_mapping)} for GDA: {arg}")
f"{arr_type} sharding does not match the input sharding. "
f"Got {arr_type} spec: {array_mapping_to_axis_resources(arr_mapping)} and "
f"auto sharding spec: {array_mapping_to_axis_resources(inp_array_mapping)} "
f"for {arr_type}: {arg}")


_forbidden_primitives = {
