diff --git a/jax/_src/api.py b/jax/_src/api.py
index ce275955ab43..d14701da6495 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -1396,7 +1396,8 @@ def vmap(fun: F,
          in_axes: Union[int, Sequence[Any]] = 0,
          out_axes: Any = 0,
          axis_name: Optional[Hashable] = None,
-         axis_size: Optional[int] = None) -> F:
+         axis_size: Optional[int] = None,
+         spmd_axis_name: Optional[Hashable] = None) -> F:
   """Vectorizing map. Creates a function which maps ``fun`` over argument axes.

   Args:
@@ -1441,6 +1442,8 @@ def vmap(fun: F,
       axis so that parallel collectives can be applied.
     axis_size: Optional, an integer indicating the size of the axis to be
       mapped. If not provided, the mapped axis size is inferred from arguments.
+    spmd_axis_name: Optional, a hashable Python object to insert into any
+      mapped PartitionSpecs.

   Returns:
     Batched/vectorized version of ``fun`` with arguments that correspond to
@@ -1564,7 +1567,8 @@ def vmap_f(*args, **kwargs):
                                     kws=True))
     out_flat = batching.batch(
         flat_fun, axis_name, axis_size_, in_axes_flat,
-        lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
+        lambda: flatten_axes("vmap out_axes", out_tree(), out_axes),
+        spmd_axis_name=spmd_axis_name
     ).call_wrapped(*args_flat)
     return tree_unflatten(out_tree(), out_flat)

diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py
index 05450e0388cd..aa639a00e607 100644
--- a/jax/_src/custom_batching.py
+++ b/jax/_src/custom_batching.py
@@ -115,7 +115,7 @@ def maybe_bdim_at_front(x, bdim):
 def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size):
   f, out_axes = batching.batch_subtrace(f)
   f = batching._batch_outer(f, axis_name, axis_size, in_axes,
-                            batching.BatchTrace)
+                            batching.BatchTrace, None)
   outs = f.call_wrapped(*args)
   return outs, out_axes()

diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py
index 6af5643f1a5c..c02ab22f93c4 100644
--- a/jax/_src/custom_derivatives.py
+++ b/jax/_src/custom_derivatives.py
@@ -727,7 +727,7 @@ def _custom_vjp_call_jaxpr_jvp(
   return primals_out, tangents_out
 ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp

-def _custom_vjp_call_jaxpr_vmap(
+def _custom_vjp_call_jaxpr_vmap(spmd_axis_name,
     axis_size, axis_name, main_type, args, in_dims, *,
     fun_jaxpr: core.ClosedJaxpr,
     fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
     bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
@@ -752,7 +752,7 @@ def batched_fwd_jaxpr_thunk():
   fwd_args_batched = [0 if b else not_mapped for b in args_batched]
   fwd_out_dims = lambda: out_dims2[0]
   batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size, fwd_out_dims,
-                                              fwd_args_batched, main_type)
+                                              fwd_args_batched, main_type, spmd_axis_name)
   batched_outs = custom_vjp_call_jaxpr_p.bind(
       *args, fun_jaxpr=batched_fun_jaxpr,
@@ -760,7 +760,8 @@
       out_trees=out_trees, num_consts=num_consts)
   out_dims = out_dims2[0] if out_dims2 else out_dims1
   return batched_outs, out_dims
-batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
+batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
+batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(_custom_vjp_call_jaxpr_vmap, None)

 xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)

diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py
index 5da12f101d83..9ded8075ea2a 100644
--- a/jax/experimental/pjit.py
+++ b/jax/experimental/pjit.py
@@ -912,7 +912,7 @@ def _pjit_lowering(ctx, *args, name,
                    jaxpr, in_shardings,
 mlir.register_lowering(pjit_p, _pjit_lowering)


-def _pjit_batcher(insert_axis,
+def _pjit_batcher(insert_axis, spmd_axis_name,
                   axis_size, axis_name, main_type,
                   vals_in, dims_in,
                   jaxpr, in_shardings, out_shardings,
@@ -927,7 +928,8 @@
       instantiate=False, axis_name=axis_name, main_type=main_type)
   # `insert_axis` is set to True only for some `xmap` uses.
-  new_parts = (axis_name,) if insert_axis else ()
+  new_parts = (axis_name,) if insert_axis else (
+      () if spmd_axis_name is None else (spmd_axis_name,))
   mesh = resource_env.physical_mesh
   in_shardings = tuple(
       _pjit_batcher_for_sharding(i, 0, new_parts, mesh, aval.ndim)
       if is_mapped else i
@@ -947,8 +948,9 @@
       out_positional_semantics=out_positional_semantics)
   dims_out = [0 if batched else batching.not_mapped for batched in is_mapped_out]
   return vals_out, dims_out
-batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False)
-pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True)
+batching.spmd_axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False)
+batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False, None)
+pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None)

 def _pjit_batcher_for_sharding(
     s: OpShardingSharding, dim: int, val: Tuple[str, ...], mesh, ndim: int):
@@ -1259,13 +1261,14 @@ def _sharding_constraint_mhlo_lowering(ctx, x_node, *, sharding,
                        _sharding_constraint_mhlo_lowering)


-def _sharding_constraint_batcher(insert_axis, axis_size, axis_name, main_type,
-                                 vals_in, dims_in, sharding, resource_env,
-                                 unconstrained_dims):
+def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size,
+                                 axis_name, main_type, vals_in, dims_in,
+                                 sharding, resource_env, unconstrained_dims):
   x, = vals_in
   d, = dims_in
   # None means unconstrained in ParsedPartitionSpec
-  new_parts = (axis_name,) if insert_axis else None
+  new_parts = (axis_name,) if insert_axis else (
+      None if spmd_axis_name is None else (spmd_axis_name,))
   y = sharding_constraint_p.bind(
       x,
       sharding=_pjit_batcher_for_sharding(
@@ -1273,8 +1276,12 @@ def _sharding_constraint_batcher(insert_axis, axis_size, axis_name, main_type,
       resource_env=resource_env,
       unconstrained_dims={ud + (d <= ud) for ud in unconstrained_dims})
   return y, d
-batching.axis_primitive_batchers[sharding_constraint_p] = partial(_sharding_constraint_batcher, False)
-pxla.spmd_primitive_batchers[sharding_constraint_p] = partial(_sharding_constraint_batcher, True)
+batching.spmd_axis_primitive_batchers[sharding_constraint_p] = partial(
+    _sharding_constraint_batcher, False)
+batching.axis_primitive_batchers[sharding_constraint_p] = partial(
+    _sharding_constraint_batcher, False, None)
+pxla.spmd_primitive_batchers[sharding_constraint_p] = partial(
+    _sharding_constraint_batcher, True, None)


 def _resource_typing_sharding_constraint(avals, params, source_info,

diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py
index 7746e24c2601..d08e6cd09470 100644
--- a/jax/interpreters/batching.py
+++ b/jax/interpreters/batching.py
@@ -13,8 +13,8 @@
 # limitations under the License.

 from functools import partial
-from typing import (Any, Callable, Dict, Set, Optional, Tuple, Union, Iterable,
-                    Type, Sequence)
+from typing import (Any, Callable, Dict, Hashable, Iterable, Optional, Sequence,
+                    Set, Tuple, Type, Union)

 import numpy as np

@@ -161,9 +161,11 @@ def _contents(self):
     return [('val', self.val), ('batch_dim', self.batch_dim)]

 class BatchTrace(Trace):
-  def __init__(self, *args, axis_name):
+
+  def __init__(self, *args, axis_name, spmd_axis_name = None):
     super().__init__(*args)
     self.axis_name = axis_name
+    self.spmd_axis_name = spmd_axis_name

   def pure(self, val):
     return BatchTracer(self, val, not_mapped, source_info_util.current())
@@ -177,6 +179,10 @@ def sublift(self, val):
   def get_primitive_batcher(self, primitive, frame):
     if primitive in primitive_batchers:
       return primitive_batchers[primitive]
+    elif self.spmd_axis_name is not None and primitive in spmd_axis_primitive_batchers:
+      return partial(spmd_axis_primitive_batchers[primitive],
+                     self.spmd_axis_name, frame.size, frame.name,
+                     frame.main_trace.trace_type)
     elif primitive in axis_primitive_batchers:
       return self.get_axis_primitive_batcher(primitive, frame)
     msg = "Batching rule for '{}' not implemented"
@@ -335,7 +341,7 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
     fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
     fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
     bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
-                               out_dims2, in_dims, self.main.trace_type)
+                               out_dims2, in_dims, self.main.trace_type, self.spmd_axis_name)
     out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
     fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
     if not fst:
@@ -367,7 +373,7 @@ def todo(vals):
       return map(partial(BatchTracer, trace), vals, primal_dims, primal_srcs)
     def bwd_transform(bwd):
       return batch_custom_vjp_bwd(bwd, axis_name, axis_size, dims, (None,),
-                                  trace_type)
+                                  trace_type, self.spmd_axis_name)
     return vals, todo, bwd_transform

 def _main_trace_for_axis_names(main_trace: core.MainTrace,
@@ -382,14 +388,17 @@

 def batch(fun: lu.WrappedFun, axis_name: core.AxisName, axis_size,
           in_dims, out_dim_dests, main_type: Type[BatchTrace] = BatchTrace,
-          ) -> lu.WrappedFun:
+          spmd_axis_name: Optional[Hashable] = None) -> lu.WrappedFun:
   # we split up _batch_inner and _batch_outer for the leak checker
   f = _batch_inner(fun, axis_size, out_dim_dests)
-  return _batch_outer(f, axis_name, axis_size, in_dims, main_type)
+  return _batch_outer(f, axis_name, axis_size, in_dims, main_type,
+                      spmd_axis_name)

 @lu.transformation
-def _batch_outer(axis_name, axis_size, in_dims, main_type, *in_vals):
-  with core.new_main(main_type, axis_name=axis_name) as main:
+def _batch_outer(axis_name, axis_size, in_dims, main_type, spmd_axis_name,
+                 *in_vals):
+  with core.new_main(
+      main_type, axis_name=axis_name, spmd_axis_name=spmd_axis_name) as main:
     with core.extend_axis_env(axis_name, axis_size, main):
       with source_info_util.transform_name_stack('vmap'):
         outs = yield (main, in_dims, *in_vals), {}
@@ -558,9 +567,9 @@ def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
                          out_tangent_bds, out_dims, out_tangents)
   yield out_primals + out_tangents, out_dims * 2

-def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type):
+def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type, spmd_axis_name):
   bwd, out_dims_thunk = batch_subtrace(bwd)
-  bwd_ = _batch_outer(bwd, axis_name, axis_size, in_dims, main_type)
+  bwd_ = _batch_outer(bwd, axis_name, axis_size, in_dims, main_type, spmd_axis_name)
   return _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk, out_dim_dests)

 @lu.transformation
@@ -594,6 +603,7 @@ def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False)
 BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
 primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
 axis_primitive_batchers: Dict[core.Primitive, Callable] = {}
+spmd_axis_primitive_batchers: Dict[core.Primitive, Callable] = {}

 def defvectorized(prim):
   primitive_batchers[prim] = partial(vectorized_batcher, prim)

diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index ee2903a40295..2443eb1eca4b 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -554,6 +554,25 @@ def testVMapShardingConstraint(self):
     self.assertListEqual(op.tile_assignment_devices, [0, 1])
     self.assertFalse(pxla.is_op_sharding_replicated(op))

+  @jtu.with_mesh([('x', 2)])
+  def testVMapShardingConstraintWithSpmdAxis(self):
+    f = pjit(
+        jax.vmap(
+            lambda x: with_sharding_constraint(x, P(None)),
+            spmd_axis_name='x',
+        ),
+        in_axis_resources=P('x'),
+        out_axis_resources=P('x'))
+    x = jnp.arange(16 * 4).reshape((16, 4))
+    jaxpr = jax.make_jaxpr(f)(x)
+    pjit_eqn, = jaxpr.eqns
+    constraint_eqn, = pjit_eqn.params['jaxpr'].eqns
+    op = constraint_eqn.params['sharding']._op_sharding
+    self.assertEqual(op.type, xc.OpSharding.Type.OTHER)
+    self.assertListEqual(op.tile_assignment_dimensions, [2, 1])
+    self.assertListEqual(op.tile_assignment_devices, [0, 1])
+    self.assertFalse(pxla.is_op_sharding_replicated(op))
+
   @jtu.with_mesh([('x', 2), ('y', 1)])
   def testShardingInXMap(self):
     h = pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=None)
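
Usage note, not part of the patch: the new `spmd_axis_name` argument flows from `jax.vmap` through `batching.batch` into `BatchTrace`, and the `pjit`/`sharding_constraint` batchers insert the given name into the mapped dimension of each `PartitionSpec`. The sketch below exercises this end to end, mirroring `testVMapShardingConstraintWithSpmdAxis`; the two-device mesh, the array shape, and the import paths (the `jax.experimental.pjit` era this patch targets) are assumptions for illustration, not part of the change.

```python
# Hedged usage sketch; assumes at least two devices and the import paths of
# the jax.experimental.pjit era that this patch targets.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit, with_sharding_constraint

# Per-example function: constrains its (unbatched) argument to be replicated.
per_example = lambda x: with_sharding_constraint(x, P(None))

# spmd_axis_name='x' makes vmap insert the mesh axis 'x' into the mapped
# dimension of the constraint, so the batched spec behaves like P('x', None)
# rather than replicating the batch dimension.
f = pjit(jax.vmap(per_example, spmd_axis_name='x'),
         in_axis_resources=P('x'),
         out_axis_resources=P('x'))

with Mesh(np.array(jax.devices()[:2]), ('x',)):
  print(jax.make_jaxpr(f)(jnp.arange(16 * 4).reshape(16, 4)))
```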
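For primitive authors, a hypothetical sketch of the new registration pattern: the patch adds a second rule table, `batching.spmd_axis_primitive_batchers`, which `BatchTrace.get_primitive_batcher` consults before `axis_primitive_batchers` whenever the trace carries a `spmd_axis_name`. Following the `pjit_p` and `sharding_constraint_p` registrations above, a rule is installed once in the spmd-aware table and once in the plain table with `spmd_axis_name` fixed to `None`. The primitive `my_prim` and the rule `_my_prim_batcher` below are made-up names for illustration, not real JAX APIs.

```python
# Hedged sketch of the dual registration pattern; my_prim is hypothetical.
from functools import partial

from jax import core
from jax.interpreters import batching

my_prim = core.Primitive('my_prim')      # hypothetical unary primitive
my_prim.def_impl(lambda x: x)            # identity impl, for illustration only
my_prim.def_abstract_eval(lambda x: x)   # output aval equals input aval

def _my_prim_batcher(spmd_axis_name, axis_size, axis_name, main_type,
                     vals_in, dims_in):
  # The leading four arguments match what get_primitive_batcher partially
  # applies: (spmd_axis_name, frame.size, frame.name, main trace type).
  x, = vals_in
  d, = dims_in
  # A real rule would thread spmd_axis_name into any shardings it builds,
  # as _pjit_batcher does; here we simply rebind and keep the batch dim.
  return my_prim.bind(x), d

# Consulted first when vmap was called with spmd_axis_name=...:
batching.spmd_axis_primitive_batchers[my_prim] = _my_prim_batcher
# The plain axis-aware table gets the same rule with spmd_axis_name fixed to
# None, exactly as the patch does for pjit_p and sharding_constraint_p:
batching.axis_primitive_batchers[my_prim] = partial(_my_prim_batcher, None)
```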