From c0d4653fc9ff19004dab79de85d5acbfbde566dc Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Tue, 2 Jan 2024 13:12:44 -0800
Subject: [PATCH] Delete sharding spec to HloSharding conversion since it's not used anymore.

PiperOrigin-RevId: 595192496
---
 jax/_src/interpreters/pxla.py |  6 ++-
 jax/_src/sharding_specs.py    | 92 +----------------------------------
 tests/pickle_test.py          | 11 +++--
 tests/pmap_test.py            | 22 +--------
 4 files changed, 14 insertions(+), 117 deletions(-)

diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 62eddfd0a50f..d5dbc23ccabe 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2475,6 +2475,10 @@ def _get_layouts_from_executable(
   return new_in_layouts, new_out_layouts  # type: ignore
 
 
+def get_logical_mesh_ids(mesh_shape):
+  return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
+
+
 @weakref_lru_cache
 def _cached_compilation(computation, name, mesh, spmd_lowering,
                         tuple_args, auto_spmd_lowering,
@@ -2528,7 +2532,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
       assert mesh is not None
       opts.auto_spmd_partitioning_mesh_shape = list(mesh.shape.values())
       opts.auto_spmd_partitioning_mesh_ids = (
-          sharding_specs.get_logical_mesh_ids(list(mesh.shape.values()))
+          get_logical_mesh_ids(list(mesh.shape.values()))
           .reshape(-1))
   compile_options.parameter_is_tupled_arguments = tuple_args
   opts.allow_spmd_sharding_propagation_to_output = list(_allow_propagation_to_outputs)
diff --git a/jax/_src/sharding_specs.py b/jax/_src/sharding_specs.py
index 938ebe868add..46439c6105d7 100644
--- a/jax/_src/sharding_specs.py
+++ b/jax/_src/sharding_specs.py
@@ -29,18 +29,15 @@
 
 from __future__ import annotations
 
-import collections
-from collections.abc import Mapping, Sequence
+from collections.abc import Sequence
 import itertools
 import math
-from typing import Any, Union, cast
+from typing import Union
 
 import numpy as np
 
-from jax._src import op_shardings
 from jax._src import util
 from jax._src.lib import pmap_lib
-from jax._src.lib import xla_client as xc
 
 unsafe_map, map = map, util.safe_map
 
@@ -56,9 +53,6 @@
 
 ShardingSpec = pmap_lib.ShardingSpec
 
-OpShardingType = Any
-
-
 def _sharding_spec_mesh_shape(self):
   sharded_axis_sizes = []
@@ -76,79 +70,6 @@ def _sharding_spec_mesh_shape(self):
                for a in self.mesh_mapping)
 
 
-def get_logical_mesh_ids(mesh_shape):
-  return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
-
-
-_MeshAxisName = Any
-
-def sharding_spec_sharding_proto(
-    self, special_axes: Mapping[int, OpShardingType] | None = None
-) -> xc.HloSharding:
-  """Converts a ShardingSpec to an OpSharding proto.
-
-  See
-  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601
-  for details on the OpSharding proto.
-  Unfortunately the semantics are not very well described in the proto spec, but
-  the code here might help:
-  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/compiler/xla/experimental/xla_sharding/xla_sharding.py
-  """
-  special_axes_dict = {} if special_axes is None else special_axes
-  mesh_shape = cast(tuple[int, ...], self.mesh_shape)
-
-  sharded_axes = {}  # maps sharded axis identifiers to mesh axis indices to which they're mapped
-  replicated_maxes = []  # lists mesh axis identifiers to replicate over
-  for maxis, assignment in enumerate(self.mesh_mapping):
-    if isinstance(assignment, Replicated):
-      replicated_maxes.append((maxis, assignment.replicas))
-    elif isinstance(assignment, ShardedAxis):
-      sharded_axes[assignment.axis] = maxis
-    else:
-      util.assert_unreachable(assignment)
-
-  if len(replicated_maxes) == len(self.mesh_mapping) and not special_axes_dict:
-    return xc.HloSharding.replicate()
-
-  mesh_permutation = []
-  new_mesh_shape = []
-  next_sharded_axis = 0
-  for axis, sharding in enumerate(self.sharding):
-    if isinstance(sharding, NoSharding):
-      new_mesh_shape.append(1)  # Add a dummy mesh axis we won't be sharding over
-    elif isinstance(sharding, Chunked):
-      for nchunks in sharding.chunks:
-        maxis = sharded_axes[next_sharded_axis]
-        assert mesh_shape[maxis] == nchunks
-        mesh_permutation.append(maxis)
-        next_sharded_axis += 1
-      new_mesh_shape.append(math.prod(sharding.chunks))
-    elif isinstance(sharding, Unstacked):
-      raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding")
-    else:
-      util.assert_unreachable(sharding)
-
-  # Create a partial sharding proto if tensor is replicated or partitioned
-  # specially over some mesh axes.
-  last_tile_dims = []
-  if replicated_maxes:
-    axes_by_type: dict[OpShardingType, list[_MeshAxisName]] = {}
-    size_by_type: dict[OpShardingType, int] = collections.defaultdict(lambda: 1)
-    assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes_dict.keys()))
-    for axis, size in replicated_maxes:
-      ty = special_axes_dict.get(axis, xc.OpSharding.Type.REPLICATED)
-      axes_by_type.setdefault(ty, []).append(axis)
-      size_by_type[ty] *= size
-    for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
-      last_tile_dims.append(ty)
-      new_mesh_shape.append(size_by_type[ty])
-      mesh_permutation.extend(axes)
-
-  return xc.HloSharding.iota_tile(
-      dims=new_mesh_shape, reshape_dims=mesh_shape,
-      transpose_perm=mesh_permutation, subgroup_types=last_tile_dims)
-
-
 def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray:
   """Returns NumPy-style indices corresponding to a sharding spec.
 
@@ -163,14 +84,6 @@ def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray:
   """
   assert len(shape) == len(self.sharding), (shape, self.sharding)
 
-  has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding)
-  # Take the op sharding indices generation route for pjit/xmap cases.
-  if not has_unstacked:
-    hlo_sharding = sharding_spec_sharding_proto(self)
-    return op_shardings.op_sharding_to_numpy_indices(
-        hlo_sharding, shape, math.prod(self.mesh_shape)
-    ).reshape(self.mesh_shape)
-
   axis_indices: list[Sequence[Index]] = []
   shard_indices_shape = []
   for dim, sharding in enumerate(self.sharding):
@@ -221,7 +134,6 @@ def _sharding_spec_repr(self):
 
 ShardingSpec.mesh_shape = property(_sharding_spec_mesh_shape)
-ShardingSpec.sharding_proto = sharding_spec_sharding_proto
 ShardingSpec.indices = _sharding_spec_indices
 # mypy raises: error: Cannot assign to a method [assignment]
 ShardingSpec.__repr__ = _sharding_spec_repr  # type: ignore
 
diff --git a/tests/pickle_test.py b/tests/pickle_test.py
index 929a21e01dc1..8fa6613cf895 100644
--- a/tests/pickle_test.py
+++ b/tests/pickle_test.py
@@ -164,12 +164,13 @@ def testPickleSharding(self):
     self.assertEqual(pickle.loads(pickle.dumps(sharding)), sharding)
 
   def testPickleOpSharding(self):
-    sharding = pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
-                                 (pxla.ShardedAxis(0), pxla.ShardedAxis(1)))
-    op_sharding = sharding.sharding_proto().to_proto()
+    op = xc.OpSharding()
+    op.type = xc.OpSharding.Type.OTHER
+    op.tile_assignment_dimensions = [4, 2]
+    op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
     self.assertTrue(
-        xc.HloSharding.from_proto(pickle.loads(pickle.dumps(op_sharding))),
-        xc.HloSharding.from_proto(op_sharding))
+        xc.HloSharding.from_proto(pickle.loads(pickle.dumps(op))),
+        xc.HloSharding.from_proto(op))
 
   def test_pickle_single_device_sharding(self):
     s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index f60ffe699ef2..81040590f096 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -2988,25 +2988,12 @@ def device_array(x):
       # unsharded
       [(4, 8), pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()),
                                  mesh_mapping=())],
-      # partitioned, 1 axis
-      [(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.NoSharding()),
-                                 mesh_mapping=(pxla.ShardedAxis(0),))],
-      # partitioned, 2 axes
-      [(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.Chunked([2])),
-                                 mesh_mapping=map(pxla.ShardedAxis, (0, 1)))],
-      # partitioned, 2 axes, permuted
-      [(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.Chunked([2])),
-                                 mesh_mapping=map(pxla.ShardedAxis, (1, 0)))],
       # replication + sharding
       [(2, 8), pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.NoSharding()),
                                  mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3)))],
       # replication, no sharding
       [(2, 8), pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()),
                                  mesh_mapping=(pxla.Replicated(3),))],
-      # multiple replicated axes
-      [(1, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([1]), pxla.Chunked([2])),
-                                 mesh_mapping=(pxla.Replicated(2), pxla.ShardedAxis(0),
-                                               pxla.Replicated(2), pxla.ShardedAxis(1)))],
       # replicated scalar
       [(), pxla.ShardingSpec(sharding=(),
                              mesh_mapping=(pxla.Replicated(2), pxla.Replicated(3)))],
@@ -3018,14 +3005,7 @@ def testShardArgs(self, shape, spec, make_arg):
       raise SkipTest
     x = np.arange(math.prod(shape)).reshape(shape)
     arg = make_arg(x)
-    sharding = None
-    if any(isinstance(s, pxla.Unstacked) for s in spec.sharding):
-      sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec)
-    else:
-      sharding = jax.sharding.GSPMDSharding(
-          jax.devices()[:nshards],
-          sharding_specs.sharding_spec_sharding_proto(spec))
-
+    sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec)
     results = pxla.shard_args(
         jax.devices()[:nshards], [indices], [sharding], [arg]
     )
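
Note on the retained helper: get_logical_mesh_ids, which this patch inlines into
jax/_src/interpreters/pxla.py, just enumerates logical device ids in row-major
order over the mesh shape; _cached_compilation then flattens the result into
opts.auto_spmd_partitioning_mesh_ids. A minimal standalone sketch of that
behaviour (plain NumPy only; the 2x4 mesh shape is a hypothetical example, not
part of the patch):

    import math
    import numpy as np

    def get_logical_mesh_ids(mesh_shape):
      # Row-major enumeration of logical device ids over the mesh, mirroring
      # the helper added to pxla.py above.
      return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)

    ids = get_logical_mesh_ids([2, 4])  # hypothetical 2x4 logical mesh
    print(ids)
    # [[0 1 2 3]
    #  [4 5 6 7]]
    print(ids.reshape(-1))
    # [0 1 2 3 4 5 6 7]  <- the flattened form handed to XLA's
    #                       auto_spmd_partitioning_mesh_ids option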