Skip to content

Commit

Permalink
Set the sharding of uncommitted single device sharding Arrays correct…
Browse files Browse the repository at this point in the history
…ly and fix some miscellaneous tests with Array too. Enable pjit_test and xmap_test with Array too (all of them are mechanical changes).

PiperOrigin-RevId: 474858389
  • Loading branch information
yashk2810 authored and jax authors committed Sep 16, 2022
1 parent e010ae7 commit eec1b4a
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 67 deletions.
11 changes: 9 additions & 2 deletions jax/experimental/pjit.py
Expand Up @@ -811,8 +811,15 @@ def _resolve_in_shardings(args, pjit_in_shardings, out_shardings, pjit_mesh):
if _is_unspecified(arg_s):
resolved_in_shardings.append(OpShardingSharding.get_replicated(da))
else:
resolved_in_shardings.append(to_op_sharding_sharding(
cast(XLACompatibleSharding, arg_s), arg.ndim))
if committed:
resolved_in_shardings.append(to_op_sharding_sharding(
cast(XLACompatibleSharding, arg_s), arg.ndim))
else:
if dispatch.is_single_device_sharding(arg_s):
resolved_in_shardings.append(OpShardingSharding.get_replicated(da))
else:
raise NotImplementedError('Having uncommitted Array sharded on '
'multiple devices is not supported.')
else:
if not _is_unspecified(arg_s):
if committed and not pxla.are_op_shardings_equal(
Expand Down
8 changes: 4 additions & 4 deletions jax/interpreters/pxla.py
Expand Up @@ -3367,8 +3367,8 @@ def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
from jax.experimental.array import Array

@lru_cache(maxsize=4096)
def _cached_check(arg_sharding, in_xla_sharding, arg_type, ndim):
if not are_op_shardings_equal(
def _cached_check(arg_sharding, in_xla_sharding, arg_type, ndim, committed):
if committed and not are_op_shardings_equal(
arg_sharding._to_xla_op_sharding(ndim),
in_xla_sharding._to_xla_op_sharding(ndim)):
raise ValueError(
Expand All @@ -3381,9 +3381,9 @@ def _cached_check(arg_sharding, in_xla_sharding, arg_type, ndim):
continue
if isinstance(arg, GlobalDeviceArray):
_cached_check(_create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes), xs,
'GDA', arg.ndim)
'GDA', arg.ndim, True)
else:
_cached_check(arg.sharding, xs, 'Array', arg.ndim)
_cached_check(arg.sharding, xs, 'Array', arg.ndim, arg._committed)


def _get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
Expand Down
2 changes: 0 additions & 2 deletions tests/BUILD
Expand Up @@ -174,7 +174,6 @@ py_test(
jax_test(
name = "xmap_test",
srcs = ["xmap_test.py"],
disable_configs = ["cpu_jax_array"],
pjrt_c_api_bypass = True,
shard_count = {
"cpu": 10,
Expand All @@ -190,7 +189,6 @@ jax_test(
jax_test(
name = "pjit_test",
srcs = ["pjit_test.py"],
disable_configs = ["cpu_jax_array"],
pjrt_c_api_bypass = True,
shard_count = {
"cpu": 5,
Expand Down

0 comments on commit eec1b4a

Please sign in to comment.