
Commit

Delete sharding spec to HloSharding conversion since it's not used anymore.

PiperOrigin-RevId: 595192496
yashk2810 authored and jax authors committed Jan 2, 2024
1 parent fff5ea5 commit c0d4653
Showing 4 changed files with 14 additions and 117 deletions.
6 changes: 5 additions & 1 deletion jax/_src/interpreters/pxla.py
@@ -2475,6 +2475,10 @@ def _get_layouts_from_executable(
return new_in_layouts, new_out_layouts # type: ignore


def get_logical_mesh_ids(mesh_shape):
return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)


@weakref_lru_cache
def _cached_compilation(computation, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering,
@@ -2528,7 +2532,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
assert mesh is not None
opts.auto_spmd_partitioning_mesh_shape = list(mesh.shape.values())
opts.auto_spmd_partitioning_mesh_ids = (
sharding_specs.get_logical_mesh_ids(list(mesh.shape.values()))
get_logical_mesh_ids(list(mesh.shape.values()))
.reshape(-1))
compile_options.parameter_is_tupled_arguments = tuple_args
opts.allow_spmd_sharding_propagation_to_output = list(_allow_propagation_to_outputs)
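
For context, a minimal sketch (not part of this commit) of what the now-inlined get_logical_mesh_ids helper returns, assuming a hypothetical 2x4 mesh, and how the call site above flattens it into auto_spmd_partitioning_mesh_ids:

    import math
    import numpy as np

    def get_logical_mesh_ids(mesh_shape):
      return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)

    ids = get_logical_mesh_ids([2, 4])   # hypothetical mesh.shape.values()
    # ids == [[0, 1, 2, 3],
    #         [4, 5, 6, 7]]
    flat_ids = ids.reshape(-1)           # [0, 1, 2, 3, 4, 5, 6, 7]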
92 changes: 2 additions & 90 deletions jax/_src/sharding_specs.py
@@ -29,18 +29,15 @@

from __future__ import annotations

import collections
from collections.abc import Mapping, Sequence
from collections.abc import Sequence
import itertools
import math
from typing import Any, Union, cast
from typing import Union

import numpy as np

from jax._src import op_shardings
from jax._src import util
from jax._src.lib import pmap_lib
from jax._src.lib import xla_client as xc

unsafe_map, map = map, util.safe_map

@@ -56,9 +53,6 @@

ShardingSpec = pmap_lib.ShardingSpec

OpShardingType = Any



def _sharding_spec_mesh_shape(self):
sharded_axis_sizes = []
@@ -76,79 +70,6 @@ def _sharding_spec_mesh_shape(self):
for a in self.mesh_mapping)


def get_logical_mesh_ids(mesh_shape):
return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)


_MeshAxisName = Any

def sharding_spec_sharding_proto(
self, special_axes: Mapping[int, OpShardingType] | None = None
) -> xc.HloSharding:
"""Converts a ShardingSpec to an OpSharding proto.
See
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601
for details on the OpSharding proto.
Unfortunately the semantics are not very well described in the proto spec, but
the code here might help:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/compiler/xla/experimental/xla_sharding/xla_sharding.py
"""
special_axes_dict = {} if special_axes is None else special_axes
mesh_shape = cast(tuple[int, ...], self.mesh_shape)

sharded_axes = {} # maps sharded axis identifiers to mesh axis indices to which they're mapped
replicated_maxes = [] # lists mesh axis identifiers to replicate over
for maxis, assignment in enumerate(self.mesh_mapping):
if isinstance(assignment, Replicated):
replicated_maxes.append((maxis, assignment.replicas))
elif isinstance(assignment, ShardedAxis):
sharded_axes[assignment.axis] = maxis
else:
util.assert_unreachable(assignment)

if len(replicated_maxes) == len(self.mesh_mapping) and not special_axes_dict:
return xc.HloSharding.replicate()

mesh_permutation = []
new_mesh_shape = []
next_sharded_axis = 0
for axis, sharding in enumerate(self.sharding):
if isinstance(sharding, NoSharding):
new_mesh_shape.append(1) # Add a dummy mesh axis we won't be sharding over
elif isinstance(sharding, Chunked):
for nchunks in sharding.chunks:
maxis = sharded_axes[next_sharded_axis]
assert mesh_shape[maxis] == nchunks
mesh_permutation.append(maxis)
next_sharded_axis += 1
new_mesh_shape.append(math.prod(sharding.chunks))
elif isinstance(sharding, Unstacked):
raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding")
else:
util.assert_unreachable(sharding)

# Create a partial sharding proto if tensor is replicated or partitioned
# specially over some mesh axes.
last_tile_dims = []
if replicated_maxes:
axes_by_type: dict[OpShardingType, list[_MeshAxisName]] = {}
size_by_type: dict[OpShardingType, int] = collections.defaultdict(lambda: 1)
assert {x[0] for x in replicated_maxes}.issuperset(set(special_axes_dict.keys()))
for axis, size in replicated_maxes:
ty = special_axes_dict.get(axis, xc.OpSharding.Type.REPLICATED)
axes_by_type.setdefault(ty, []).append(axis)
size_by_type[ty] *= size
for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
last_tile_dims.append(ty)
new_mesh_shape.append(size_by_type[ty])
mesh_permutation.extend(axes)

return xc.HloSharding.iota_tile(
dims=new_mesh_shape, reshape_dims=mesh_shape,
transpose_perm=mesh_permutation, subgroup_types=last_tile_dims)


def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray:
"""Returns NumPy-style indices corresponding to a sharding spec.
@@ -163,14 +84,6 @@ def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray:
"""
assert len(shape) == len(self.sharding), (shape, self.sharding)

has_unstacked = any(isinstance(s, Unstacked) for s in self.sharding)
# Take the op sharding indices generation route for pjit/xmap cases.
if not has_unstacked:
hlo_sharding = sharding_spec_sharding_proto(self)
return op_shardings.op_sharding_to_numpy_indices(
hlo_sharding, shape, math.prod(self.mesh_shape)
).reshape(self.mesh_shape)

axis_indices: list[Sequence[Index]] = []
shard_indices_shape = []
for dim, sharding in enumerate(self.sharding):
@@ -221,7 +134,6 @@ def _sharding_spec_repr(self):


ShardingSpec.mesh_shape = property(_sharding_spec_mesh_shape)
ShardingSpec.sharding_proto = sharding_spec_sharding_proto
ShardingSpec.indices = _sharding_spec_indices
# mypy raises: error: Cannot assign to a method [assignment]
ShardingSpec.__repr__ = _sharding_spec_repr # type: ignore
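
For reference, a minimal sketch (not part of this commit) of the xc.HloSharding that the deleted sharding_spec_sharding_proto would have produced for a simple spec (sharding=(NoSharding(), Chunked([2])), mesh_mapping=(ShardedAxis(0),)), using only the iota_tile keyword arguments the deleted code itself passed:

    from jax._src.lib import xla_client as xc

    hlo_sharding = xc.HloSharding.iota_tile(
        dims=[1, 2],         # 1 tile along array axis 0, 2 tiles along axis 1
        reshape_dims=(2,),   # the logical mesh shape (2 devices)
        transpose_perm=[0],  # mesh axes in the order they tile the array
        subgroup_types=[])   # no replicated or otherwise special trailing dims
    # Device 0 holds the left half of axis 1, device 1 the right half.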
11 changes: 6 additions & 5 deletions tests/pickle_test.py
@@ -164,12 +164,13 @@ def testPickleSharding(self):
self.assertEqual(pickle.loads(pickle.dumps(sharding)), sharding)

def testPickleOpSharding(self):
sharding = pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
(pxla.ShardedAxis(0), pxla.ShardedAxis(1)))
op_sharding = sharding.sharding_proto().to_proto()
op = xc.OpSharding()
op.type = xc.OpSharding.Type.OTHER
op.tile_assignment_dimensions = [4, 2]
op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
self.assertTrue(
xc.HloSharding.from_proto(pickle.loads(pickle.dumps(op_sharding))),
xc.HloSharding.from_proto(op_sharding))
xc.HloSharding.from_proto(pickle.loads(pickle.dumps(op))),
xc.HloSharding.from_proto(op))

def test_pickle_single_device_sharding(self):
s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
22 changes: 1 addition & 21 deletions tests/pmap_test.py
@@ -2988,25 +2988,12 @@ def device_array(x):
# unsharded
[(4, 8), pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()),
mesh_mapping=())],
# partitioned, 1 axis
[(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.NoSharding()),
mesh_mapping=(pxla.ShardedAxis(0),))],
# partitioned, 2 axes
[(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.Chunked([2])),
mesh_mapping=map(pxla.ShardedAxis, (0, 1)))],
# partitioned, 2 axes, permuted
[(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.Chunked([2])),
mesh_mapping=map(pxla.ShardedAxis, (1, 0)))],
# replication + sharding
[(2, 8), pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.NoSharding()),
mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3)))],
# replication, no sharding
[(2, 8), pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()),
mesh_mapping=(pxla.Replicated(3),))],
# multiple replicated axes
[(1, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([1]), pxla.Chunked([2])),
mesh_mapping=(pxla.Replicated(2), pxla.ShardedAxis(0),
pxla.Replicated(2), pxla.ShardedAxis(1)))],
# replicated scalar
[(), pxla.ShardingSpec(sharding=(),
mesh_mapping=(pxla.Replicated(2), pxla.Replicated(3)))],
@@ -3018,14 +3005,7 @@ def testShardArgs(self, shape, spec, make_arg):
raise SkipTest
x = np.arange(math.prod(shape)).reshape(shape)
arg = make_arg(x)
sharding = None
if any(isinstance(s, pxla.Unstacked) for s in spec.sharding):
sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec)
else:
sharding = jax.sharding.GSPMDSharding(
jax.devices()[:nshards],
sharding_specs.sharding_spec_sharding_proto(spec))

sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec)
results = pxla.shard_args(
jax.devices()[:nshards], [indices], [sharding], [arg]
)
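
For context, a minimal sketch (not part of this commit) of the PmapSharding construction the test now uses for every remaining spec, shown here for the "replication + sharding" case, assuming at least 6 local devices and the test's pxla import (here taken to be jax._src.interpreters.pxla):

    import jax
    from jax._src.interpreters import pxla

    spec = pxla.ShardingSpec(
        sharding=(pxla.Unstacked(2), pxla.NoSharding()),
        mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3)))
    nshards = 6  # 2 unstacked shards x 3-way replication
    sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec)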
