Commit

Cache the replacement of FROM_GDA to actual shardings present on GDA and check for sharding equality via opsharding.

PiperOrigin-RevId: 463388009
yashk2810 authored and jax authors committed Jul 26, 2022
1 parent b4bf8e5 commit 97f2f4e
Showing 3 changed files with 43 additions and 37 deletions.
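The change replaces the pspec-string comparison for FROM_GDA substitution with one that derives a sharding from each GDA argument, memoizes the derive-and-check step with functools.lru_cache, and decides equality on the canonical XLA OpSharding form rather than the user-facing spec. The following is an illustrative plain-Python sketch of that pattern, not JAX code: Sharding, FROM_GDA, check_and_get_sharding, and canonical() are hypothetical stand-ins (canonical() playing the role of _to_xla_op_sharding).

from functools import lru_cache

class Sharding:
  """Hashable stand-in for a sharding, usable as an lru_cache key."""
  def __init__(self, spec):
    self.spec = tuple(spec)
  def canonical(self):
    # Flatten nested axis tuples so ('x',) and (('x',),) canonicalize
    # identically -- the role _to_xla_op_sharding plays in the commit.
    return tuple(a for axes in self.spec
                 for a in (axes if isinstance(axes, tuple) else (axes,)))
  def __hash__(self):
    return hash(self.spec)
  def __eq__(self, other):
    return isinstance(other, Sharding) and self.spec == other.spec

FROM_GDA = object()  # placeholder a user may pass in in_axis_resources

@lru_cache()
def check_and_get_sharding(gda_sharding, in_sharding):
  # Skip the check for FROM_GDA; otherwise compare canonical forms (as
  # pxla.are_op_shardings_equal does) and return the GDA's own sharding.
  if in_sharding is not FROM_GDA and (
      gda_sharding.canonical() != in_sharding.canonical()):
    raise ValueError(f"GDA sharding {gda_sharding.spec} does not match "
                     f"pjit sharding {in_sharding.spec}")
  return gda_sharding

# ('x',) and (('x',),) spell the same partitioning, so this passes, and
# repeated calls with the same pair are answered from the cache.
print(check_and_get_sharding(Sharding(['x']), Sharding([('x',)])).spec)

Because differently spelled specs canonicalize identically, the check accepts them as equal, which is exactly the looseness the op-sharding comparison buys over literal spec equality.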
55 changes: 30 additions & 25 deletions jax/experimental/pjit.py
@@ -311,10 +311,12 @@ def infer_params(*args, _global_avals=False, **kwargs):
       in_shardings, out_shardings = _get_and_check_in_and_out_shardings(
           args_flat, in_axis_resources, out_axis_resources, pjit_mesh, in_tree)
     else:
-      in_shardings = tree_map(lambda x: _create_mesh_pspec_sharding(pjit_mesh, x),
-                              in_axis_resources)
-      out_shardings = tree_map(lambda x: x if _is_unspecified(x) else
-                               _create_mesh_pspec_sharding(pjit_mesh, x), out_axis_resources)
+      in_shardings = tree_map(
+          lambda x: _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x),
+          in_axis_resources)
+      out_shardings = tree_map(
+          lambda x: x if _is_unspecified(x) else
+          _create_mesh_pspec_sharding_from_parsed_pspec(pjit_mesh, x), out_axis_resources)
     # This check fails extrememly rarely and has a huge cost in the dispatch
     # path. So hide it behind the jax_enable_checks flag.
     if config.jax_enable_checks:
@@ -346,9 +348,8 @@ def infer_params(*args, _global_avals=False, **kwargs):
                            HashableFunction(out_tree, closure=()))
 
     if not config.jax_array:
-      normalized_in_shardings_flat = tree_map(
-          _maybe_replace_from_gda_with_pspec, normalized_in_shardings_flat,
-          tuple(args_flat))
+      normalized_in_shardings_flat = _maybe_replace_from_gda_with_pspec(
+          normalized_in_shardings_flat, args_flat)
 
     params = dict(
         jaxpr=jaxpr,
@@ -429,10 +430,10 @@ def _get_and_check_in_and_out_shardings(args_flat, pjit_in_shardings, out_shardi


 @lru_cache(maxsize=4096)
-def _create_mesh_pspec_sharding(mesh, x):
+def _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x):
   if _is_unspecified_or_from_gda_or_auto(x):
     return x
-  return MeshPspecSharding(mesh, x.user_spec, x)
+  return pxla._create_mesh_pspec_sharding(mesh, x.user_spec, x)
 
 
 def flatten_axis_resources(what, tree, shardings, tupled_args):
@@ -1277,26 +1278,30 @@ def _get_in_positional_semantics(arg) -> maps._PositionalSemantics:


 def _maybe_replace_from_gda_with_pspec(
-    in_sharding_flat, arg) -> MeshPspecSharding:
-  if isinstance(arg, GDA):
-    if _is_auto(in_sharding_flat):
-      return in_sharding_flat
-    gda_cpspec = CanonicalizedParsedPartitionSpec(
-        ParsedPartitionSpec.from_user_input(arg.mesh_axes, arg_name="GDA spec"))
-    if (not _is_from_gda(in_sharding_flat) and
-        in_sharding_flat._parsed_pspec != gda_cpspec):
+    in_shardings_flat, args_flat) -> Sequence[XLACompatibleSharding]:
+
+  @lru_cache()
+  def _gda_check_and_get_sharding(gda_sharding, in_sharding, ndim):
+    if not _is_from_gda(in_sharding) and not pxla.are_op_shardings_equal(
+        gda_sharding._to_xla_op_sharding(ndim),
+        in_sharding._to_xla_op_sharding(ndim)):
       raise ValueError(
           f"Got an input GDA to pjit with different partitioning than specified in "
           "the in_axis_resources argument to pjit. The partitioning must match, or "
           "use `jax.experimental.pjit.FROM_GDA` in `in_axis_resources` for GDA. "
-          f"Got GDA spec: {gda_cpspec.user_spec} and "
-          f"pjit spec: {in_sharding_flat.spec} for GDA: {arg}")
-    # This is an optimization to return the original sharding if its not
-    # FROM_GDA.
-    if not _is_from_gda(in_sharding_flat):
-      return in_sharding_flat
-    return MeshPspecSharding(arg.mesh, arg.mesh_axes).normalize()
-  return in_sharding_flat
+          f"Got GDA sharding: {gda_sharding} and pjit sharding: {in_sharding}")
+    return gda_sharding.normalize()
+
+  out = []
+  for in_sharding_flat, arg in safe_zip(in_shardings_flat, args_flat):
+    if _is_auto(in_sharding_flat):
+      out.append(in_sharding_flat)
+    elif isinstance(arg, GDA):
+      gda_sharding = pxla._create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes)
+      out.append(_gda_check_and_get_sharding(gda_sharding, in_sharding_flat, arg.ndim))
+    else:
+      out.append(in_sharding_flat)
+  return tuple(out)
 
 
 @lru_cache(maxsize=4096)
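Note that the rewritten _maybe_replace_from_gda_with_pspec now processes the whole flat argument list in one call and memoizes the per-sharding check with a nested lru_cache, so the expensive comparison runs once per distinct (gda_sharding, in_sharding, ndim) triple rather than once per argument. A minimal stand-alone illustration of that dedup, in plain Python with hypothetical names (replace_placeholders, check):

from functools import lru_cache

def replace_placeholders(in_shardings, args):
  @lru_cache()                       # fresh cache per call; dedupes within it
  def check(arg_sharding, in_sharding):
    print("checked:", arg_sharding, in_sharding)  # once per distinct pair
    return arg_sharding
  return tuple(check(a, s) for a, s in zip(args, in_shardings))

# Four arguments but only two distinct sharding pairs -> two "checked:" lines.
print(replace_placeholders(("p", "q", "p", "q"), ("r", "s", "r", "s")))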
11 changes: 6 additions & 5 deletions jax/interpreters/pxla.py
@@ -2672,15 +2672,16 @@ def call(self, *args):
     return self.unsafe_call(*args)
 
 
+@lru_cache()
+def _create_mesh_pspec_sharding(mesh, pspec, parsed_pspec=None):
+  from jax.experimental.sharding import MeshPspecSharding
+  return MeshPspecSharding(mesh, pspec, parsed_pspec)
+
+
 def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
   from jax.experimental.global_device_array import GlobalDeviceArray
   from jax.experimental.array import Array
 
-  @lru_cache
-  def _create_mesh_pspec_sharding(mesh, pspec):
-    from jax.experimental.sharding import MeshPspecSharding
-    return MeshPspecSharding(mesh, pspec)
-
   @lru_cache(maxsize=4096)
   def _cached_check(arg_sharding, in_xla_sharding, arg_type, ndim):
     if not are_op_shardings_equal(
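This hunk hoists _create_mesh_pspec_sharding out of _check_gda_or_array_xla_sharding_match to module level (adding a parsed_pspec parameter so pjit.py can reuse it). One reason hoisting matters: lru_cache applied to a nested def builds a brand-new, empty cache on every call of the enclosing function, so the nested version never got cache hits across calls. A minimal plain-Python demonstration (function names are illustrative):

from functools import lru_cache

def outer_with_nested_cache(x):
  @lru_cache()           # a brand-new cache is created on every call of outer
  def helper(v):
    return v * 2
  return helper(x)

@lru_cache()             # one cache, shared by every call site
def module_level_helper(v):
  return v * 2

outer_with_nested_cache(3)
outer_with_nested_cache(3)   # helper recomputed: its cache did not survive

module_level_helper(3)
module_level_helper(3)       # second call is a cache hit
print(module_level_helper.cache_info())  # CacheInfo(hits=1, misses=1, ...)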
14 changes: 7 additions & 7 deletions tests/pjit_test.py
@@ -1119,14 +1119,14 @@ def cb(index):

     gda_obj = global_device_array.GlobalDeviceArray.from_callback(
         global_input_shape, global_mesh, mesh_axes, cb)
-    with self.assertRaisesWithLiteralMatch(
+
+    with self.assertRaisesRegex(
         ValueError,
-        "Got an input GDA to pjit with different partitioning than specified "
-        'in the in_axis_resources argument to pjit. The partitioning must match, or '
-        'use `jax.experimental.pjit.FROM_GDA` in `in_axis_resources` for GDA. '
-        "Got GDA spec: PartitionSpec('x',) and "
-        "pjit spec: PartitionSpec(('x',), ('y',)) "
-        'for GDA: GlobalDeviceArray(shape=(8, 2), dtype=float32)'):
+        r"Got an input GDA to pjit with different partitioning than specified "
+        r'in the in_axis_resources argument to pjit. The partitioning must match, or '
+        r'use `jax.experimental.pjit.FROM_GDA` in `in_axis_resources` for GDA. '
+        r"Got GDA sharding.*PartitionSpec\('x',\).*and "
+        r"pjit sharding.*PartitionSpec\(\('x',\), \('y',\)\).*"):
       @partial(pjit, in_axis_resources=P('x', 'y'), out_axis_resources=P('x', 'y'))
       def f(x):
         return x
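The test switches from absl's assertRaisesWithLiteralMatch to the standard assertRaisesRegex because the error message now interpolates sharding objects whose full repr is not worth pinning down literally; the pattern must therefore escape regex metacharacters such as the parentheses in PartitionSpec('x',). A tiny self-contained sketch of that escaping, using a hypothetical test case with only the standard unittest module:

import unittest

class RegexMatchExample(unittest.TestCase):
  def test_escaped_pattern(self):
    # assertRaisesRegex uses re.search, so parentheses and other regex
    # metacharacters in the expected text must be escaped (or wrapped
    # with re.escape) to match literally.
    with self.assertRaisesRegex(
        ValueError, r"Got GDA sharding.*PartitionSpec\('x',\).*"):
      raise ValueError("Got GDA sharding: PartitionSpec('x',) on mesh")

if __name__ == "__main__":
  unittest.main()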
