Skip to content

Commit

Permalink
Don't depend on mesh for UNSPECIFIED. Use OpShardingSharding for …
Browse files Browse the repository at this point in the history
…that since it's now available and pjit accepts it.

PiperOrigin-RevId: 465641117
  • Loading branch information
yashk2810 authored and jax authors committed Aug 5, 2022
1 parent cc4bd0f commit 4b6d4a4
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 44 deletions.
7 changes: 2 additions & 5 deletions jax/experimental/pjit.py
Expand Up @@ -849,12 +849,9 @@ def _pjit_lower(
for o in out_shardings
)

# TODO(yashkatariya): UNSPECIFIED should go through lower_sharding_computation.
# Also the `jaxpr_has_primitive` for xmap is temporary until xmap supports
# sharding instances.
# For `pjit(xmap)` cases, it needs to take the `lower_mesh_computation` path
# because `xmap` only supports SPMDAxisContext right now.
if (pxla._check_if_any_auto_or_unspecified(in_shardings + out_shardings) or
if (pxla._check_if_any_auto(it.chain(in_shardings, out_shardings)) or
dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap')):
return pxla.lower_mesh_computation(
fun, 'pjit', name, resource_env.physical_mesh,
Expand Down Expand Up @@ -1636,7 +1633,7 @@ def _get_ppspec_from_executable(executable, mesh) -> Tuple[Sequence[ParsedPartit
return in_ppspec, out_ppspec


def _get_sharding_from_executable(
def _get_pspec_from_executable(
executable, mesh: pxla.Mesh
) -> Tuple[Tuple[PartitionSpec, ...], Tuple[PartitionSpec, ...]]:
in_ppspec, out_ppspec = _get_ppspec_from_executable(executable, mesh)
Expand Down
70 changes: 38 additions & 32 deletions jax/interpreters/pxla.py
Expand Up @@ -2280,20 +2280,12 @@ class TileManual:
TilingMethod = Union[TileVectorize, TileManual]


def _check_if_any_auto(
    shardings: Iterable[Union[XLACompatibleSharding, _AUTOAxisResource]]) -> bool:
  """Returns True if any element of `shardings` is the AUTO placeholder.

  Accepts any iterable (e.g. an `itertools.chain` over in/out shardings) so
  callers don't have to materialize a concatenated sequence.
  """
  # `any` with a generator short-circuits on the first AUTO, matching the
  # early-return behavior of the original explicit loop.
  return any(_is_auto(s) for s in shardings)

# TODO(yashkatariya): Remove this once UNSPECIFIED can be used without mesh.
def _check_if_any_auto_or_unspecified(
shardings: Sequence[Union[XLACompatibleSharding, _AUTOAxisResource, _UnspecifiedValue]]) -> bool:
for s in shardings:
if _is_auto(s) or _is_unspecified(s):
return True
return False


class _UnconstrainedPartitionSingleton:

Expand Down Expand Up @@ -2329,10 +2321,17 @@ def __repr__(self):


def _get_backend_from_shardings(
    shardings: Iterable[XLACompatibleSharding]) -> Tuple[xb.XlaBackend, XLACompatibleSharding]:
  """Returns the backend and the first concrete sharding from `shardings`.

  Skips UNSPECIFIED entries and picks the device assignment from the first
  concrete (non-UNSPECIFIED) sharding. Device assignment consistency across
  all shardings is assumed to be checked by the caller (pjit does this).
  """
  da = None
  first_sharding = None
  for s in shardings:
    if _is_unspecified(s):
      # UNSPECIFIED carries no device assignment; keep scanning.
      continue
    da = s._device_assignment
    first_sharding = s
    break
  # At least one concrete sharding with a non-empty device assignment is
  # required; `da` is None only if every sharding was UNSPECIFIED.
  assert da is not None and len(da) > 0
  return xb.get_device_backend(da[0]), first_sharding  # type: ignore


@profiler.annotate_function
Expand All @@ -2348,7 +2347,7 @@ def lower_sharding_computation(
# Device assignment across all inputs and outputs should be the same. This
# is checked in pjit.
backend, first_sharding = _get_backend_from_shardings(
in_shardings + out_shardings) # type: ignore
it.chain(in_shardings, out_shardings)) # type: ignore
name_stack = new_name_stack(wrap_name(fun_name, api_name))

log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
Expand Down Expand Up @@ -2383,7 +2382,7 @@ def lower_sharding_computation(
for aval, i in safe_zip(global_in_avals, in_shardings)]
# TODO(yashkatariya): Fix the HLO produced if out_partitions is
# [None, OpShardingProto] has the sharding annotations.
out_op_shardings = [o._to_xla_op_sharding(aval.ndim)
out_op_shardings = [None if _is_unspecified(o) else o._to_xla_op_sharding(aval.ndim)
for aval, o in safe_zip(global_out_avals, out_shardings)]
replicated_args = [False] * len(in_jaxpr_avals)
axis_ctx = mlir.ShardingContext(first_sharding)
Expand Down Expand Up @@ -2631,11 +2630,22 @@ def _get_input_metadata(
return shardings, input_indices, input_avals


def _get_op_sharding_shardings_from_executable(xla_executable, device_assignment):
  """Builds OpShardingSharding in/out shardings from a compiled executable.

  Used when inputs/outputs were lowered with UNSPECIFIED shardings: the
  compiler-chosen XLA OpShardings are read back from `xla_executable` and
  wrapped with the given `device_assignment`. Unlike the mesh-pspec path,
  this does not require a Mesh.
  """
  # Local imports avoid a circular dependency between pxla and pjit/sharding.
  from jax.experimental import pjit
  from jax.experimental.sharding import OpShardingSharding

  in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)
  return ([OpShardingSharding(device_assignment, i) for i in in_op_shardings],
          [OpShardingSharding(device_assignment, o) for o in out_op_shardings])


# TODO(yashkatariya): Remove this function after `AUTO` can return shardings
# without mesh.
def _get_mesh_pspec_shardings_from_executable(xla_executable, mesh):
  """Builds MeshPspecSharding in/out shardings from a compiled executable.

  Used for AUTO shardings: the compiler-chosen partition specs are read back
  from `xla_executable` relative to `mesh` and wrapped as MeshPspecShardings.
  """
  # Local imports avoid a circular dependency between pxla and pjit/sharding.
  from jax.experimental import pjit
  from jax.experimental.sharding import MeshPspecSharding

  in_pspec, out_pspec = pjit._get_pspec_from_executable(xla_executable, mesh)
  return ([MeshPspecSharding(mesh, i) for i in in_pspec],
          [MeshPspecSharding(mesh, o) for o in out_pspec])

Expand All @@ -2658,10 +2668,8 @@ def __init__(self, xla_executable, unsafe_call, input_avals,
@staticmethod
def from_hlo(name: str,
computation: Union[ir.Module, xc.XlaComputation],
# mesh only needs to be set if in_shardings and out_shardings
# contain AUTO or UNSPECIFIED (unspecified is temporary here).
# TODO(yashkatariya): Remove `mesh` from here once AUTO and
# UNSPECIFIED work without mesh.
# TODO(yashkatariya): Remove `mesh` from here once AUTO can work
# without mesh.
mesh: Optional[Mesh],
global_in_avals: Sequence[ShapedArray],
global_out_avals: Sequence[ShapedArray],
Expand All @@ -2677,20 +2685,16 @@ def from_hlo(name: str,
unordered_effects: List[core.Effect],
host_callbacks: List[Any],
keepalive: Any) -> MeshExecutable:
auto_or_unspecified = (
auto_spmd_lowering or
(out_shardings and all(_is_unspecified(o) for o in out_shardings)))

if auto_or_unspecified:
if auto_spmd_lowering:
assert mesh is not None
assert not mesh.empty
backend = xb.get_device_backend(mesh.devices.flat[0])
else:
backend, first_sharding = _get_backend_from_shardings(
in_shardings + out_shardings) # type: ignore
it.chain(in_shardings, out_shardings)) # type: ignore

dev: np.ndarray
if auto_or_unspecified:
if auto_spmd_lowering:
assert mesh is not None and spmd_lowering
dev = mesh.devices
num_replicas, num_partitions = 1, mesh.size
Expand Down Expand Up @@ -2735,12 +2739,14 @@ def from_hlo(name: str,
xla_executable = dispatch.compile_or_get_cached(
backend, computation, compile_options, host_callbacks)

if auto_or_unspecified:
# TODO(yashkatariya): Make this work for UNSPECIFIED without mesh by
# returning `OpShardingSharding`.
if auto_spmd_lowering:
assert mesh is not None
in_shardings, out_shardings = _get_shardings_from_executable(
in_shardings, out_shardings = _get_mesh_pspec_shardings_from_executable(
xla_executable, mesh)
elif out_shardings and all(_is_unspecified(o) for o in out_shardings):
assert mesh is None
in_shardings, out_shardings = _get_op_sharding_shardings_from_executable(
xla_executable, first_sharding._device_assignment)

in_shardings, input_indices, input_avals = _get_input_metadata(
global_in_avals, in_shardings, in_is_global) # type: ignore
Expand Down
22 changes: 15 additions & 7 deletions tests/pjit_test.py
Expand Up @@ -1499,6 +1499,17 @@ def test_numpy_array_input(self):
self.assertArraysEqual(out._value, input_data)

def test_unspecified_out_axis_resources(self):

def _checks(out, input_data):
self.assertIsInstance(out, array.Array)
self.assertIsInstance(out.sharding, OpShardingSharding)
self.assertEqual(out.shape, (8, 2))
self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
for s in out.addressable_shards:
self.assertLen(s.data._arrays, 1)
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
self.assertArraysEqual(out._value, input_data)

global_input_shape = (8, 2)
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
mesh_axes = P('x', 'y')
Expand All @@ -1510,13 +1521,10 @@ def test_unspecified_out_axis_resources(self):
f = pjit(lambda x: x)

out = f(input_array)
self.assertIsInstance(out, array.Array)
self.assertEqual(out.shape, (8, 2))
self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
for s in out.addressable_shards:
self.assertLen(s.data._arrays, 1)
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
self.assertArraysEqual(out._value, input_data)
_checks(out, input_data)

out2 = f(out)
_checks(out2, input_data)

@parameterized.named_parameters(
('mesh1', (4, 2), (2, 1), (2, 2), (1, 2), (8, 2)),
Expand Down

0 comments on commit 4b6d4a4

Please sign in to comment.