From 4d698c30b97b6c23cf34083965f05db5c6bbda79 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Mon, 12 Jun 2023 11:51:47 -0700
Subject: [PATCH] Return PositionalSharding instead of GSPMDSharding in
 custom_partitioning when mesh is not defined

PiperOrigin-RevId: 539719517
---
 jax/_src/interpreters/pxla.py           |  2 +-
 jax/_src/sharding_impls.py              |  2 +-
 jax/experimental/custom_partitioning.py | 12 ++++++------
 tests/array_test.py                     |  6 +++---
 4 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 550c61600d71..1ee26ac6eba3 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2439,7 +2439,7 @@ def _gspmd_to_named_sharding(
 def _gspmd_to_positional_sharding(
     op_sharding: xc.OpSharding,
     self: sharding_impls.PositionalSharding) -> sharding_impls.PositionalSharding:
-  return sharding_impls._from_op_sharding_to_pos_sharding(
+  return sharding_impls._op_sharding_to_pos_sharding(
       op_sharding, self._device_assignment)
 
 orig_out_sharding_handlers[sharding_impls.PositionalSharding] = _gspmd_to_positional_sharding
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index 04898da32ebe..6314764148da 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -495,7 +495,7 @@ def shard_shape(self, global_shape: Shape) -> Shape:
     return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
 
 
-def _from_op_sharding_to_pos_sharding(
+def _op_sharding_to_pos_sharding(
     op_sharding: Union[xc.OpSharding, xc.HloSharding],
     device_assignment: Sequence[xc.Device]) -> PositionalSharding:
   if isinstance(op_sharding, xc.HloSharding):
diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py
index fbf19629c13f..b2a44471c75c 100644
--- a/jax/experimental/custom_partitioning.py
+++ b/jax/experimental/custom_partitioning.py
@@ -25,6 +25,7 @@
 from jax._src.lib.mlir import ir
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
+from jax._src.sharding_impls import _op_sharding_to_pos_sharding
 from jax._src import custom_api_util
 from jax._src.lib import xla_client as xc
 from jax._src.api_util import flatten_fun_nokwargs
@@ -475,15 +476,14 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
     return mlir.lower_fun(
         core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values)
 
-  def to_mesh_pspec_sharding(op_sharding: Optional[xc.OpSharding]):
-    if op_sharding is None:
-      return op_sharding
+  def to_mesh_pspec_sharding(hlo_sharding: Optional[xc.HloSharding]):
+    if hlo_sharding is None:
+      return hlo_sharding
     if mesh.empty or not decode_shardings:
-      from jax._src.sharding_impls import GSPMDSharding
       assert devices is not None
-      return GSPMDSharding(devices, op_sharding.to_proto())
+      return _op_sharding_to_pos_sharding(hlo_sharding, devices)
     pspec = sharding_impls.parse_flatten_op_sharding(
-        op_sharding.to_proto(), mesh)[0].get_partition_spec()
+        hlo_sharding, mesh)[0].get_partition_spec()
     return jax.sharding.NamedSharding(mesh, pspec)
 
   sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding,
diff --git a/tests/array_test.py b/tests/array_test.py
index a47d7cca3c1a..bf2113ff0f5b 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -31,7 +31,7 @@
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src.util import safe_zip
-from jax._src.sharding_impls import _from_op_sharding_to_pos_sharding
+from jax._src.sharding_impls import _op_sharding_to_pos_sharding
 from jax.experimental.pjit import pjit
 from jax.experimental import multihost_utils
 from jax.sharding import PartitionSpec as P
@@ -922,7 +922,7 @@ def test_positional_sharding_from_op_sharding(self, mesh_shape, pspec):
         mesh_shape, ('x', 'y') if ndim == 2 else ('x', 'y', 'z'))
     mps = jax.sharding.NamedSharding(mesh, pspec)
     original_op_sharding = mps._to_xla_hlo_sharding(ndim)
-    ps = _from_op_sharding_to_pos_sharding(original_op_sharding,
+    ps = _op_sharding_to_pos_sharding(original_op_sharding,
                                            mps._device_assignment)
     out_op_sharding = ps._to_xla_hlo_sharding(ndim)
     self.assertTrue(op_shardings.are_op_shardings_equal(
@@ -958,7 +958,7 @@ def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec):
     ops_ifr = op_shardings.is_op_sharding_replicated(mps_op_sharding)
     self.assertEqual(mps.is_fully_replicated, ops_ifr)
 
-    ps = _from_op_sharding_to_pos_sharding(mps_op_sharding,
+    ps = _op_sharding_to_pos_sharding(mps_op_sharding,
                                            mps._device_assignment)
     self.assertEqual(ps.is_fully_replicated,
                      op_shardings.is_op_sharding_replicated(
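
Note on the behavioral change: when a custom_partitioning callback runs without a
usable mesh (mesh.empty, or decode_shardings is false), to_mesh_pspec_sharding now
hands it a PositionalSharding built by the renamed _op_sharding_to_pos_sharding
helper, instead of a raw GSPMDSharding. Below is a minimal sketch of the helper's
round-trip, mirroring the patched tests; it assumes a JAX build that already
contains this change and a runtime with at least four devices (on CPU, e.g., via
XLA_FLAGS=--xla_force_host_platform_device_count=4).

import jax
import numpy as np
from jax._src.sharding_impls import _op_sharding_to_pos_sharding
from jax.sharding import PartitionSpec as P

# Build a 2x2 mesh and a NamedSharding over it.
mesh = jax.sharding.Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
named = jax.sharding.NamedSharding(mesh, P('x', 'y'))

# Lower to an HloSharding, then rebuild it as a PositionalSharding over the
# same device assignment: the conversion custom_partitioning now performs
# when no mesh is available for decoding.
hlo_sharding = named._to_xla_hlo_sharding(2)
pos = _op_sharding_to_pos_sharding(hlo_sharding, named._device_assignment)

# Both objects describe the same partitioning of a rank-2 array.
assert isinstance(pos, jax.sharding.PositionalSharding)
assert pos.shard_shape((8, 8)) == named.shard_shape((8, 8))  # (4, 4) each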