Add spmd_axis_name to vmap to allow constraining mapped PartitionSpecs.
pschuh committed Aug 9, 2022
1 parent c4192dc commit 8fb9573
Showing 6 changed files with 68 additions and 27 deletions.
8 changes: 6 additions & 2 deletions jax/_src/api.py
@@ -1396,7 +1396,8 @@ def vmap(fun: F,
in_axes: Union[int, Sequence[Any]] = 0,
out_axes: Any = 0,
axis_name: Optional[Hashable] = None,
axis_size: Optional[int] = None) -> F:
axis_size: Optional[int] = None,
spmd_axis_name: Optional[Hashable] = None) -> F:
"""Vectorizing map. Creates a function which maps ``fun`` over argument axes.
Args:
@@ -1441,6 +1442,8 @@ def vmap(fun: F,
axis so that parallel collectives can be applied.
axis_size: Optional, an integer indicating the size of the axis to be
mapped. If not provided, the mapped axis size is inferred from arguments.
spmd_axis_name: Optional, a hashable Python object to insert into any
mapped PartitionSpecs.
Returns:
Batched/vectorized version of ``fun`` with arguments that correspond to
@@ -1564,7 +1567,8 @@ def vmap_f(*args, **kwargs):
kws=True))
out_flat = batching.batch(
flat_fun, axis_name, axis_size_, in_axes_flat,
lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
lambda: flatten_axes("vmap out_axes", out_tree(), out_axes),
spmd_axis_name=spmd_axis_name
).call_wrapped(*args_flat)
return tree_unflatten(out_tree(), out_flat)

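For orientation, a minimal usage sketch distilled from the test added at the bottom of this commit; the two-device mesh and the axis name 'x' are illustrative, and the import paths assume the JAX version of this commit:

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit, with_sharding_constraint

# Inside the vmapped function the constraint is written rank-relative
# (P(None) leaves the per-example value unconstrained); spmd_axis_name='x'
# makes the batcher insert 'x' at the mapped dim, sharding it over mesh
# axis 'x'.
f = pjit(
    jax.vmap(lambda x: with_sharding_constraint(x, P(None)),
             spmd_axis_name='x'),
    in_axis_resources=P('x'),
    out_axis_resources=P('x'))

devices = np.array(jax.devices()[:2])  # assumes at least two devices
with Mesh(devices, ('x',)):
  y = f(jnp.arange(16 * 4.).reshape((16, 4)))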
2 changes: 1 addition & 1 deletion jax/_src/custom_batching.py
@@ -115,7 +115,7 @@ def maybe_bdim_at_front(x, bdim):
def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size):
f, out_axes = batching.batch_subtrace(f)
f = batching._batch_outer(f, axis_name, axis_size, in_axes,
batching.BatchTrace)
batching.BatchTrace, None)
outs = f.call_wrapped(*args)
return outs, out_axes()

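Note: vmap_unrestricted simply pins the new trailing _batch_outer argument to None, so custom-vmap calls do not thread an spmd axis name.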
7 changes: 4 additions & 3 deletions jax/_src/custom_derivatives.py
@@ -727,7 +727,7 @@ def _custom_vjp_call_jaxpr_jvp(
return primals_out, tangents_out
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp

def _custom_vjp_call_jaxpr_vmap(
def _custom_vjp_call_jaxpr_vmap(spmd_axis_name,
axis_size, axis_name, main_type, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
@@ -752,15 +752,16 @@ def batched_fwd_jaxpr_thunk():
fwd_args_batched = [0 if b else not_mapped for b in args_batched]
fwd_out_dims = lambda: out_dims2[0]
batched_bwd = batching.batch_custom_vjp_bwd(bwd, axis_name, axis_size, fwd_out_dims,
fwd_args_batched, main_type)
fwd_args_batched, main_type, spmd_axis_name)

batched_outs = custom_vjp_call_jaxpr_p.bind(
*args, fun_jaxpr=batched_fun_jaxpr,
fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
out_trees=out_trees, num_consts=num_consts)
out_dims = out_dims2[0] if out_dims2 else out_dims1
return batched_outs, out_dims
batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(_custom_vjp_call_jaxpr_vmap, None)

xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)

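The registration pattern just above recurs in pjit.py below: each rule lands in the new batching.spmd_axis_primitive_batchers table, and the plain axis table gets the same rule with spmd_axis_name pre-bound to None. A hedged sketch of the pattern for a hypothetical primitive my_prim:

from functools import partial
from jax import core
from jax.interpreters import batching

my_prim = core.Primitive('my_prim')  # hypothetical, for illustration only

def _my_prim_vmap(spmd_axis_name, axis_size, axis_name, main_type,
                  vals_in, dims_in):
  # A real rule would thread spmd_axis_name into any shardings it builds;
  # this placeholder just forwards its inputs unchanged.
  return vals_in, dims_in

# Consulted when the trace carries a non-None spmd_axis_name:
batching.spmd_axis_primitive_batchers[my_prim] = _my_prim_vmap
# Fallback for ordinary vmap traces, with spmd_axis_name fixed to None:
batching.axis_primitive_batchers[my_prim] = partial(_my_prim_vmap, None)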
27 changes: 17 additions & 10 deletions jax/experimental/pjit.py
@@ -912,7 +912,7 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
mlir.register_lowering(pjit_p, _pjit_lowering)


def _pjit_batcher(insert_axis,
def _pjit_batcher(insert_axis, spmd_axis_name,
axis_size, axis_name, main_type,
vals_in, dims_in,
jaxpr, in_shardings, out_shardings,
@@ -927,7 +927,8 @@ def _pjit_batcher(insert_axis,
instantiate=False, axis_name=axis_name, main_type=main_type)

# `insert_axis` is set to True only for some `xmap` uses.
new_parts = (axis_name,) if insert_axis else ()
new_parts = (axis_name,) if insert_axis else (
() if spmd_axis_name is None else (spmd_axis_name,))
mesh = resource_env.physical_mesh
in_shardings = tuple(
_pjit_batcher_for_sharding(i, 0, new_parts, mesh, aval.ndim) if is_mapped else i
@@ -947,8 +948,9 @@ def _pjit_batcher(insert_axis,
out_positional_semantics=out_positional_semantics)
dims_out = [0 if batched else batching.not_mapped for batched in is_mapped_out]
return vals_out, dims_out
batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False)
pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True)
batching.spmd_axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False)
batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False, None)
pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None)

def _pjit_batcher_for_sharding(
s: OpShardingSharding, dim: int, val: Tuple[str, ...], mesh, ndim: int):
@@ -1259,22 +1261,27 @@ def _sharding_constraint_mhlo_lowering(ctx, x_node, *, sharding,
_sharding_constraint_mhlo_lowering)


def _sharding_constraint_batcher(insert_axis, axis_size, axis_name, main_type,
vals_in, dims_in, sharding, resource_env,
unconstrained_dims):
def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size,
axis_name, main_type, vals_in, dims_in,
sharding, resource_env, unconstrained_dims):
x, = vals_in
d, = dims_in
# None means unconstrained in ParsedPartitionSpec
new_parts = (axis_name,) if insert_axis else None
new_parts = (axis_name,) if insert_axis else (
None if spmd_axis_name is None else (spmd_axis_name,))
y = sharding_constraint_p.bind(
x,
sharding=_pjit_batcher_for_sharding(
sharding, d, new_parts, resource_env.physical_mesh, x.ndim),
resource_env=resource_env,
unconstrained_dims={ud + (d <= ud) for ud in unconstrained_dims})
return y, d
batching.axis_primitive_batchers[sharding_constraint_p] = partial(_sharding_constraint_batcher, False)
pxla.spmd_primitive_batchers[sharding_constraint_p] = partial(_sharding_constraint_batcher, True)
batching.spmd_axis_primitive_batchers[sharding_constraint_p] = partial(
_sharding_constraint_batcher, False)
batching.axis_primitive_batchers[sharding_constraint_p] = partial(
_sharding_constraint_batcher, False, None)
pxla.spmd_primitive_batchers[sharding_constraint_p] = partial(
_sharding_constraint_batcher, True, None)


def _resource_typing_sharding_constraint(avals, params, source_info,
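Note: combined with the vmap change above, a with_sharding_constraint(x, P(None)) inside a function vmapped with spmd_axis_name='x' becomes a constraint whose PartitionSpec has 'x' inserted at the batch dimension, so the mapped dim is sharded over mesh axis 'x' while the remaining dims stay unconstrained. The new test below asserts exactly this: a tile assignment of [2, 1] over devices [0, 1] on a one-axis mesh of size 2.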
32 changes: 21 additions & 11 deletions jax/interpreters/batching.py
@@ -13,8 +13,8 @@
# limitations under the License.

from functools import partial
from typing import (Any, Callable, Dict, Set, Optional, Tuple, Union, Iterable,
Type, Sequence)
from typing import (Any, Callable, Dict, Hashable, Iterable, Optional, Sequence,
Set, Tuple, Type, Union)

import numpy as np

@@ -161,9 +161,11 @@ def _contents(self):
return [('val', self.val), ('batch_dim', self.batch_dim)]

class BatchTrace(Trace):
def __init__(self, *args, axis_name):

def __init__(self, *args, axis_name, spmd_axis_name = None):
super().__init__(*args)
self.axis_name = axis_name
self.spmd_axis_name = spmd_axis_name

def pure(self, val):
return BatchTracer(self, val, not_mapped, source_info_util.current())
@@ -177,6 +179,10 @@ def sublift(self, val):
def get_primitive_batcher(self, primitive, frame):
if primitive in primitive_batchers:
return primitive_batchers[primitive]
elif self.spmd_axis_name is not None and primitive in spmd_axis_primitive_batchers:
return partial(spmd_axis_primitive_batchers[primitive],
self.spmd_axis_name, frame.size, frame.name,
frame.main_trace.trace_type)
elif primitive in axis_primitive_batchers:
return self.get_axis_primitive_batcher(primitive, frame)
msg = "Batching rule for '{}' not implemented"
@@ -335,7 +341,7 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
out_dims2, in_dims, self.main.trace_type)
out_dims2, in_dims, self.main.trace_type, self.spmd_axis_name)
out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
if not fst:
@@ -367,7 +373,7 @@ def todo(vals):
return map(partial(BatchTracer, trace), vals, primal_dims, primal_srcs)
def bwd_transform(bwd):
return batch_custom_vjp_bwd(bwd, axis_name, axis_size, dims, (None,),
trace_type)
trace_type, self.spmd_axis_name)
return vals, todo, bwd_transform

def _main_trace_for_axis_names(main_trace: core.MainTrace,
Expand All @@ -382,14 +388,17 @@ def _main_trace_for_axis_names(main_trace: core.MainTrace,

def batch(fun: lu.WrappedFun, axis_name: core.AxisName, axis_size,
in_dims, out_dim_dests, main_type: Type[BatchTrace] = BatchTrace,
) -> lu.WrappedFun:
spmd_axis_name: Optional[Hashable] = None) -> lu.WrappedFun:
# we split up _batch_inner and _batch_outer for the leak checker
f = _batch_inner(fun, axis_size, out_dim_dests)
return _batch_outer(f, axis_name, axis_size, in_dims, main_type)
return _batch_outer(f, axis_name, axis_size, in_dims, main_type,
spmd_axis_name)

@lu.transformation
def _batch_outer(axis_name, axis_size, in_dims, main_type, *in_vals):
with core.new_main(main_type, axis_name=axis_name) as main:
def _batch_outer(axis_name, axis_size, in_dims, main_type, spmd_axis_name,
*in_vals):
with core.new_main(
main_type, axis_name=axis_name, spmd_axis_name=spmd_axis_name) as main:
with core.extend_axis_env(axis_name, axis_size, main):
with source_info_util.transform_name_stack('vmap'):
outs = yield (main, in_dims, *in_vals), {}
@@ -558,9 +567,9 @@ def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
out_tangent_bds, out_dims, out_tangents)
yield out_primals + out_tangents, out_dims * 2

def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type):
def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, main_type, spmd_axis_name):
bwd, out_dims_thunk = batch_subtrace(bwd)
bwd_ = _batch_outer(bwd, axis_name, axis_size, in_dims, main_type)
bwd_ = _batch_outer(bwd, axis_name, axis_size, in_dims, main_type, spmd_axis_name)
return _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk, out_dim_dests)

@lu.transformation
@@ -594,6 +603,7 @@ def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False)
BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
axis_primitive_batchers: Dict[core.Primitive, Callable] = {}
spmd_axis_primitive_batchers: Dict[core.Primitive, Callable] = {}

def defvectorized(prim):
primitive_batchers[prim] = partial(vectorized_batcher, prim)
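Note: the lookup order in get_primitive_batcher above is deliberate. Exact primitive_batchers win; the new spmd_axis_primitive_batchers table is consulted only when the trace actually carries an spmd_axis_name; and axis_primitive_batchers remain the fallback, so primitives without an spmd-aware rule keep their existing behavior.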
19 changes: 19 additions & 0 deletions tests/pjit_test.py
@@ -554,6 +554,25 @@ def testVMapShardingConstraint(self):
self.assertListEqual(op.tile_assignment_devices, [0, 1])
self.assertFalse(pxla.is_op_sharding_replicated(op))

@jtu.with_mesh([('x', 2)])
def testVMapShardingConstraintWithSpmdAxis(self):
f = pjit(
jax.vmap(
lambda x: with_sharding_constraint(x, P(None)),
spmd_axis_name='x',
),
in_axis_resources=P('x'),
out_axis_resources=P('x'))
x = jnp.arange(16 * 4).reshape((16, 4))
jaxpr = jax.make_jaxpr(f)(x)
pjit_eqn, = jaxpr.eqns
constraint_eqn, = pjit_eqn.params['jaxpr'].eqns
op = constraint_eqn.params['sharding']._op_sharding
self.assertEqual(op.type, xc.OpSharding.Type.OTHER)
self.assertListEqual(op.tile_assignment_dimensions, [2, 1])
self.assertListEqual(op.tile_assignment_devices, [0, 1])
self.assertFalse(pxla.is_op_sharding_replicated(op))

@jtu.with_mesh([('x', 2), ('y', 1)])
def testShardingInXMap(self):
h = pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=None)
