Skip to content

Commit

Permalink
Allow unevenly partitioned sharding_constraints.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 429601508
  • Loading branch information
ukoxyz authored and jax authors committed Feb 18, 2022
1 parent 1baa59c commit 8cb1692
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 27 deletions.
34 changes: 22 additions & 12 deletions jax/experimental/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,11 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
# first one is a valid spec for a scalar value, while the second is not!
_check_shapes_against_resources(
"pjit arguments", mesh.is_multi_process, mesh.shape, local_in_avals,
in_axis_resources_flat)
in_axis_resources_flat, allow_uneven_sharding=False)
else:
_check_shapes_against_resources("pjit arguments", False, mesh.local_mesh.shape,
local_in_avals, in_axis_resources_flat)
local_in_avals, in_axis_resources_flat,
allow_uneven_sharding=False)

global_in_avals = local_to_global(in_positional_semantics, mesh,
local_in_avals, canonicalized_in_axis_resources_flat)
Expand All @@ -400,7 +401,8 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
"pjit out_axis_resources", out_tree(),
out_axis_resources_thunk(), tupled_args=False)
_check_shapes_against_resources("pjit outputs", mesh.is_multi_process, mesh.shape,
global_out_avals, out_axis_resources_flat)
global_out_avals, out_axis_resources_flat,
allow_uneven_sharding=False)
canonicalized_out_axis_resources_flat = tree_map(_create_cpspec, out_axis_resources_flat)
# lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
return _ListWithW([jaxpr, canonicalized_in_axis_resources_flat,
Expand Down Expand Up @@ -553,8 +555,9 @@ def _check_unique_resources(axis_resources, arg_name):
f"to at most one positional dimension, but {arg_axis_resources.user_spec} "
f"has duplicate entries for {pxla.show_axes(multiple_uses)}")

def _check_shapes_against_resources(what: str, is_global_shape: bool, mesh_shape,
flat_avals, flat_axis_resources):
def _check_shapes_against_resources(what: str, is_global_shape: bool,
mesh_shape, flat_avals, flat_axis_resources,
allow_uneven_sharding: bool):
global_str = " global" if is_global_shape else ""
for aval, aval_axis_resources in zip(flat_avals, flat_axis_resources):
if _is_from_gda(aval_axis_resources):
Expand All @@ -574,7 +577,7 @@ def _check_shapes_against_resources(what: str, is_global_shape: bool, mesh_shape
raise ValueError(f"One of {what} was given the resource assignment "
f"of {aval_axis_resources.user_spec}, but resource axis "
f"{e.args[0]} is undefined. Did you forget to declare the mesh?") from None
if shape[i] % size != 0:
if not allow_uneven_sharding and shape[i] % size != 0:
raise ValueError(f"One of {what} was given the resource assignment "
f"of {aval_axis_resources.user_spec}, which implies that "
f"the{global_str} size of its dimension {i} should be "
Expand Down Expand Up @@ -930,7 +933,7 @@ def with_sharding_constraint(x, axis_resources):
_check_shapes_against_resources(
"with_sharding_constraint arguments",
mesh.is_multi_process, mesh.shape,
x_flat, axis_resources_flat)
x_flat, axis_resources_flat, allow_uneven_sharding=True)
outs = [sharding_constraint_p.bind(y, axis_resources=r, resource_env=resource_env)
for y, r in safe_zip(x_flat, axis_resources_flat)]
return tree_unflatten(tree, outs)
Expand All @@ -957,7 +960,8 @@ def _sharding_constraint_translation_rule(ctx, avals_in, avals_out, x_node, *,
xla.set_sharding_proto(
ctx.builder,
x_node,
get_aval_sharding_proto(aval, axis_resources, mesh),
get_aval_sharding_proto(
aval, axis_resources, mesh, allow_uneven_axes=True),
unspecified_dims=get_unconstrained_dims(axis_resources))
]
xla.register_translation(sharding_constraint_p, _sharding_constraint_translation_rule)
Expand All @@ -969,7 +973,12 @@ def _sharding_constraint_mhlo_lowering(ctx, x_node, *, axis_resources,
return [
mlir.wrap_with_sharding_op(
x_node,
get_aval_sharding_proto(aval, axis_resources, mesh, ctx.module_context.axis_context),
get_aval_sharding_proto(
aval,
axis_resources,
mesh,
ctx.module_context.axis_context,
allow_uneven_axes=True),
unspecified_dims=get_unconstrained_dims(axis_resources))
]
mlir.register_lowering(sharding_constraint_p,
Expand Down Expand Up @@ -1009,10 +1018,11 @@ def get_array_mapping(axis_resources: ParsedPartitionSpec) -> pxla.ArrayMapping:
def get_aval_sharding_proto(aval: core.AbstractValue,
axis_resources: ParsedPartitionSpec,
mesh: maps.Mesh,
axis_ctx: Optional[mlir.SPMDAxisContext] = None) -> xc.OpSharding:
axis_ctx: Optional[mlir.SPMDAxisContext] = None,
allow_uneven_axes: bool = False) -> xc.OpSharding:
array_mapping = get_array_mapping(axis_resources)
sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(
aval, array_mapping)
sharding_spec = pxla.mesh_sharding_specs(
mesh.shape, mesh.axis_names, allow_uneven_axes=True)(aval, array_mapping)
special_axes = {}
if axis_ctx is not None:
axis_names = mesh.axis_names
Expand Down
5 changes: 3 additions & 2 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -2414,7 +2414,7 @@ def _check_aval(aval, what_thunk):
_check_aval(v.aval, what_thunk)


def mesh_sharding_specs(axis_sizes, axis_names):
def mesh_sharding_specs(axis_sizes, axis_names, allow_uneven_axes=False):
mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
# NOTE: This takes in the non-sharded avals!
def mk_sharding_spec(aval, aval_axes):
Expand All @@ -2428,7 +2428,8 @@ def mk_sharding_spec(aval, aval_axes):
# NOTE: sorted is stable, which is important when multiple resources
# map to the same axis.
for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
assert aval_shape[axis] % axis_sizes[name] == 0, (axis_sizes[name], aval.shape[axis])
if not allow_uneven_axes:
assert aval_shape[axis] % axis_sizes[name] == 0, (axis_sizes[name], aval.shape[axis])
aval_shape[axis] //= axis_sizes[name]
chunked = sharding[axis]
if isinstance(chunked, NoSharding):
Expand Down
36 changes: 23 additions & 13 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,29 @@ def f(x, y):
self.assertAllClose(actual.device_buffers[0].to_py(), expected,
check_dtypes=False)

@jtu.with_mesh([('x', 2)])
def testUnevenShardingConstraint(self):
  """with_sharding_constraint on a dim not divisible by the mesh axis size.

  Slices a length-4 input down to 3 elements (uneven on a 2-device 'x'
  axis), applies sharding constraints to the uneven intermediates, and
  pads the result back to the evenly-shardable length.
  """
  @partial(pjit,
           in_axis_resources=(P('x'), P('x')),
           out_axis_resources=None)
  def f(x, y):
    # Truncate to an uneven length, then constrain the uneven values.
    x = with_sharding_constraint(x[:3], P('x'))
    y = with_sharding_constraint(y[:3], P('x'))
    # Pad back to length 4 so the pjit output is evenly partitioned.
    return jnp.pad(x + y, [[0, 1]])

  shape = (4,)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  expected = x + (x + 1)
  actual = f(x, x + 1)
  # Only the first 3 entries are meaningful; the 4th is padding.
  self.assertAllClose(actual[:3], expected[:3], check_dtypes=False)
  self.assertIsInstance(actual, pxla.ShardedDeviceArray)
  self.assertLen(actual.device_buffers, 2)
  self.assertAllClose(actual.device_buffers[0].to_py()[:3], expected[:3],
                      check_dtypes=False)

def testBasic1DWithMeshContextManager(self):
@partial(pjit,
in_axis_resources=(P('x'), P('x')),
Expand Down Expand Up @@ -1057,19 +1080,6 @@ def testNonDivisibleOuts(self, mesh, resources):
r"divisible by " + mesh_size + r", but it is equal to 3"):
pjit(lambda x: x, in_axis_resources=None, out_axis_resources=P(resources, None))(x)

@check_1d_2d_mesh(set_mesh=True)
def testNonDivisibleConstraint(self, mesh, resources):
x = jnp.ones((3, 2))
spec = P(resources,)
mesh_size = str(np.prod([dim[1] for dim in mesh], dtype=np.int64))
with self.assertRaisesRegex(ValueError,
r"One of with_sharding_constraint arguments"
r".*" + spec_regex(spec) + r".*implies that the size of "
r"its dimension 0 should be divisible by " + mesh_size +
r", but it is equal to 3"):
pjit(lambda x: with_sharding_constraint(x, spec),
in_axis_resources=None, out_axis_resources=None)(x)

@check_1d_2d_mesh(set_mesh=False)
@jtu.with_mesh([('z', 1)])
def testUndefinedResourcesArgs(self, mesh, resources):
Expand Down

0 comments on commit 8cb1692

Please sign in to comment.