Return PositionalSharding instead of GSPMDSharding in custom_partitioning when mesh is not defined

PiperOrigin-RevId: 539719517
yashk2810 authored and jax authors committed Jun 12, 2023
1 parent 79a1bc9 commit 4d698c3
Showing 4 changed files with 11 additions and 11 deletions.
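
Note (not part of the commit): a minimal sketch of the user-visible effect, assuming a backend with more than one device; names and shapes below are illustrative only. With this change, the shardings handed to custom_partitioning callbacks when no mesh is defined are PositionalSharding objects rather than GSPMDSharding wrappers, so they support the public sharding API:

import jax
from jax.sharding import PositionalSharding

# Illustrative only: build the kind of sharding the callbacks now receive
# when no mesh is in scope.
devices = jax.devices()
sharding = PositionalSharding(devices).reshape(len(devices), 1)

# PositionalSharding supports the public sharding API:
print(sharding.shape)                # (num_devices, 1)
print(sharding.is_fully_replicated)  # False when more than one device is used
print(sharding.replicate(axis=1))    # replicate along the second dimension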
2 changes: 1 addition & 1 deletion jax/_src/interpreters/pxla.py
@@ -2439,7 +2439,7 @@ def _gspmd_to_named_sharding(
 def _gspmd_to_positional_sharding(
     op_sharding: xc.OpSharding,
     self: sharding_impls.PositionalSharding) -> sharding_impls.PositionalSharding:
-  return sharding_impls._from_op_sharding_to_pos_sharding(
+  return sharding_impls._op_sharding_to_pos_sharding(
       op_sharding, self._device_assignment)
 orig_out_sharding_handlers[sharding_impls.PositionalSharding] = _gspmd_to_positional_sharding
 
2 changes: 1 addition & 1 deletion jax/_src/sharding_impls.py
@@ -495,7 +495,7 @@ def shard_shape(self, global_shape: Shape) -> Shape:
     return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
 
 
-def _from_op_sharding_to_pos_sharding(
+def _op_sharding_to_pos_sharding(
     op_sharding: Union[xc.OpSharding, xc.HloSharding],
     device_assignment: Sequence[xc.Device]) -> PositionalSharding:
   if isinstance(op_sharding, xc.HloSharding):
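
For reference, a minimal sketch (mirroring the updated test in tests/array_test.py below, not part of the commit) of how the renamed helper converts an HloSharding into a PositionalSharding. It assumes at least two devices, and both _op_sharding_to_pos_sharding and _to_xla_hlo_sharding are private APIs that may change:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src.sharding_impls import _op_sharding_to_pos_sharding

# Build a 2x1 mesh from the first two devices (illustrative only).
devices = np.asarray(jax.devices()[:2]).reshape(2, 1)
mesh = Mesh(devices, ('x', 'y'))
named = NamedSharding(mesh, P('x', 'y'))

# Lower the named sharding to an HloSharding, then rebuild it as a
# PositionalSharding over the same device assignment.
hlo_sharding = named._to_xla_hlo_sharding(2)
pos = _op_sharding_to_pos_sharding(hlo_sharding, named._device_assignment)
print(pos.shape, pos.is_fully_replicated)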
12 changes: 6 additions & 6 deletions jax/experimental/custom_partitioning.py
@@ -25,6 +25,7 @@
 from jax._src.lib.mlir import ir
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
+from jax._src.sharding_impls import _op_sharding_to_pos_sharding
 from jax._src import custom_api_util
 from jax._src.lib import xla_client as xc
 from jax._src.api_util import flatten_fun_nokwargs
@@ -475,15 +476,14 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
     return mlir.lower_fun(
         core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values)
 
-  def to_mesh_pspec_sharding(op_sharding: Optional[xc.OpSharding]):
-    if op_sharding is None:
-      return op_sharding
+  def to_mesh_pspec_sharding(hlo_sharding: Optional[xc.HloSharding]):
+    if hlo_sharding is None:
+      return hlo_sharding
     if mesh.empty or not decode_shardings:
-      from jax._src.sharding_impls import GSPMDSharding
       assert devices is not None
-      return GSPMDSharding(devices, op_sharding.to_proto())
+      return _op_sharding_to_pos_sharding(hlo_sharding, devices)
     pspec = sharding_impls.parse_flatten_op_sharding(
-        op_sharding.to_proto(), mesh)[0].get_partition_spec()
+        hlo_sharding, mesh)[0].get_partition_spec()
     return jax.sharding.NamedSharding(mesh, pspec)
 
   sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding,
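
Read as a whole, the updated helper now looks roughly like the sketch below (the closure variables mesh, devices, and decode_shardings are lifted into explicit parameters for readability; this is not a verbatim copy). The only behavioral change is the mesh-less branch, which returns a PositionalSharding instead of a GSPMDSharding:

from jax.sharding import NamedSharding
from jax._src import sharding_impls
from jax._src.sharding_impls import _op_sharding_to_pos_sharding

def to_mesh_pspec_sharding(hlo_sharding, mesh, devices, decode_shardings):
  # hlo_sharding is an Optional[xc.HloSharding] supplied by XLA.
  if hlo_sharding is None:
    return hlo_sharding
  if mesh.empty or not decode_shardings:
    # No mesh to decode against: convert the HloSharding into a
    # PositionalSharding over the lowering's device assignment
    # (previously a GSPMDSharding was returned here).
    assert devices is not None
    return _op_sharding_to_pos_sharding(hlo_sharding, devices)
  # A mesh is available: decode into a NamedSharding, as before.
  pspec = sharding_impls.parse_flatten_op_sharding(
      hlo_sharding, mesh)[0].get_partition_spec()
  return NamedSharding(mesh, pspec)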
6 changes: 3 additions & 3 deletions tests/array_test.py
@@ -31,7 +31,7 @@
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src.util import safe_zip
-from jax._src.sharding_impls import _from_op_sharding_to_pos_sharding
+from jax._src.sharding_impls import _op_sharding_to_pos_sharding
 from jax.experimental.pjit import pjit
 from jax.experimental import multihost_utils
 from jax.sharding import PartitionSpec as P
@@ -922,7 +922,7 @@ def test_positional_sharding_from_op_sharding(self, mesh_shape, pspec):
         mesh_shape, ('x', 'y') if ndim == 2 else ('x', 'y', 'z'))
     mps = jax.sharding.NamedSharding(mesh, pspec)
     original_op_sharding = mps._to_xla_hlo_sharding(ndim)
-    ps = _from_op_sharding_to_pos_sharding(original_op_sharding,
+    ps = _op_sharding_to_pos_sharding(original_op_sharding,
                                            mps._device_assignment)
     out_op_sharding = ps._to_xla_hlo_sharding(ndim)
     self.assertTrue(op_shardings.are_op_shardings_equal(
@@ -958,7 +958,7 @@ def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec):
     ops_ifr = op_shardings.is_op_sharding_replicated(mps_op_sharding)
     self.assertEqual(mps.is_fully_replicated, ops_ifr)
 
-    ps = _from_op_sharding_to_pos_sharding(mps_op_sharding,
+    ps = _op_sharding_to_pos_sharding(mps_op_sharding,
                                            mps._device_assignment)
     self.assertEqual(ps.is_fully_replicated,
                      op_shardings.is_op_sharding_replicated(
