Skip to content

Commit

Permalink
Return PositionalSharding if input's rank is >= 3 or a NamedSharding …
Browse files Browse the repository at this point in the history
…if a mesh is available via the context from inspect_array_sharding. Never return GSPMDSharding from inspect_array_sharding.

PiperOrigin-RevId: 573048344
  • Loading branch information
yashk2810 authored and jax authors committed Oct 12, 2023
1 parent 489cd44 commit ef20526
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
6 changes: 3 additions & 3 deletions jax/_src/debugging.py
Expand Up @@ -40,8 +40,7 @@
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (GSPMDSharding, NamedSharding,
parse_flatten_op_sharding)
from jax._src.sharding_impls import NamedSharding, parse_flatten_op_sharding

# pytype: disable=import-error
try:
Expand Down Expand Up @@ -336,7 +335,8 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
# partitioner calls back with the `HloSharding.
def _hlo_sharding_callback(hlo_sharding: xc.HloSharding):
if mesh.empty:
return callback(GSPMDSharding(devices, hlo_sharding))
return callback(
sharding_impls._op_sharding_to_pos_sharding(hlo_sharding, devices))
pspec = parse_flatten_op_sharding(hlo_sharding, mesh)[0].get_partition_spec()
return callback(NamedSharding(mesh, pspec))

Expand Down
33 changes: 33 additions & 0 deletions tests/debugging_primitives_test.py
Expand Up @@ -1113,6 +1113,39 @@ def f_(x):
f(np.arange(8, dtype=jnp.float32))
self.assertTrue(is_called)

def test_inspect_sharding_3d_input_pos_sharding(self):
def _cb(sd):
self.assertIsInstance(sd, jax.sharding.PositionalSharding)
self.assertLen(sd.device_set, 2)

def f_(x):
debugging.inspect_array_sharding(x, callback=_cb)
return jnp.square(x)

f = jax.jit(f_)
mesh = jtu.create_global_mesh((2,), ('x'))
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
arr = jax.device_put(np.arange(8).reshape(2, 2, 2), s)

f(arr)

def test_inspect_sharding_3d_input_named_sharding(self):
def _cb(sd):
self.assertIsInstance(sd, jax.sharding.NamedSharding)
self.assertLen(sd.device_set, 2)

def f_(x):
debugging.inspect_array_sharding(x, callback=_cb)
return jnp.square(x)

f = pjit.pjit(f_)
mesh = jtu.create_global_mesh((2,), ('x'))
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
arr = jax.device_put(np.arange(8).reshape(2, 2, 2), s)

with mesh:
f(arr)


if not rich:
del VisualizeShardingTest
Expand Down

0 comments on commit ef20526

Please sign in to comment.