Skip to content

Commit

Permalink
Add support for post_process of xmap in BatchTrace
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 459108183
  • Loading branch information
apaszke authored and jax authors committed Jul 5, 2022
1 parent 0719f98 commit 5777c1e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 14 deletions.
50 changes: 36 additions & 14 deletions jax/experimental/maps.py
Expand Up @@ -49,7 +49,7 @@
from jax.interpreters import ad
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2,
from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3,
as_hashable_function, distributed_debug_log,
tuple_insert, moveaxis, split_list, wrap_name,
merge_lists, partition_list)
Expand Down Expand Up @@ -848,7 +848,10 @@ def process(self, trace, fun, tracers, params):
return trace.process_xmap(self, fun, tracers, params)

def post_process(self, trace, out_tracers, params):
raise NotImplementedError
post_process = getattr(trace, 'post_process_xmap', None)
if post_process is None:
raise NotImplementedError
return post_process(self, out_tracers, params)

def get_bind_params(self, params):
new_params = dict(params)
Expand Down Expand Up @@ -1228,6 +1231,17 @@ def new_spmd_out_axes_thunk():

return new_spmd_in_axes, new_spmd_out_axes_thunk

def _axis_after_insertion(axis, inserted_named_axes):
for inserted_axis in sorted(inserted_named_axes.values()):
if inserted_axis >= axis:
break
axis += 1
return axis

def _fmap_dims(axes, f):
return AxisNamePos(((name, f(axis)) for name, axis in axes.items()),
user_repr=axes.user_repr)

def _batch_trace_process_xmap(self, is_spmd, primitive, f: lu.WrappedFun, tracers, params):
not_mapped = batching.not_mapped
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
Expand All @@ -1236,33 +1250,24 @@ def _batch_trace_process_xmap(self, is_spmd, primitive, f: lu.WrappedFun, tracer
return primitive.bind(f, *vals, **params)
else:
assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1
def fmap_dims(axes, f):
return AxisNamePos(((name, f(axis)) for name, axis in axes.items()),
user_repr=axes.user_repr)
new_in_axes = tuple(
fmap_dims(in_axes, lambda a: a + (d is not not_mapped and d <= a))
_fmap_dims(in_axes, lambda a: a + (d is not not_mapped and d <= a))
for d, in_axes in zip(dims, params['in_axes']))
mapped_dims_in = tuple(
d if d is not_mapped else d - sum(a < d for a in in_axis.values())
for d, in_axis in zip(dims, params['in_axes']))
f, mapped_dims_out = batching.batch_subtrace(f, self.main, mapped_dims_in)
out_axes_thunk: Callable[[], Sequence[AxisNamePos]] = params['out_axes_thunk']
dims_out_thunk = lambda: tuple(d if d is not_mapped else axis_after_insertion(d, out_axes)
dims_out_thunk = lambda: tuple(d if d is not_mapped else _axis_after_insertion(d, out_axes)
for d, out_axes in zip(mapped_dims_out(), out_axes_thunk()))
def axis_after_insertion(axis, inserted_named_axes):
for inserted_axis in sorted(inserted_named_axes.values()):
if inserted_axis >= axis:
break
axis += 1
return axis
# NOTE: This assumes that the choice of the dimensions over which outputs
# are batched is entirely dependent on the function and not e.g. on the
# data or its shapes.
@as_hashable_function(closure=out_axes_thunk)
def new_out_axes_thunk():
return tuple(
out_axes if d is not_mapped else
fmap_dims(out_axes, lambda a, nd=axis_after_insertion(d, out_axes): a + (nd <= a))
_fmap_dims(out_axes, lambda a, nd=_axis_after_insertion(d, out_axes): a + (nd <= a))
for out_axes, d in zip(out_axes_thunk(), mapped_dims_out()))

if not is_spmd:
Expand All @@ -1285,6 +1290,23 @@ def new_out_axes_thunk():
pxla.SPMDBatchTrace.process_xmap = partialmethod(_batch_trace_process_xmap, True) # type: ignore


def _batch_trace_post_process_xmap(self, primitive, out_tracers, params):
not_mapped = batching.not_mapped
BT = batching.BatchTracer
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info) for t in out_tracers)
main = self.main
def todo(vals):
trace = main.with_cur_sublevel()
return [BT(trace, v, d if d is not_mapped else _axis_after_insertion(d, oa), s)
for v, d, oa, s in zip(vals, dims, params['out_axes_thunk'](), srcs)]
def out_axes_transform(out_axes):
return tuple(oa if d is not_mapped else
_fmap_dims(oa, lambda a, nd=_axis_after_insertion(d, oa): a + (nd <= a))
for oa, d in zip(out_axes, dims))
return vals, (todo, out_axes_transform)
batching.BatchTrace.post_process_xmap = _batch_trace_post_process_xmap


# -------- nested xmap handling --------

def _xmap_lowering_rule(ctx, *args, **kwargs):
Expand Down
6 changes: 6 additions & 0 deletions tests/xmap_test.py
Expand Up @@ -552,6 +552,12 @@ def testNestedMap(self,
y = rng.randn(*yshape)
self.assertAllClose(fm(x, y), fref(x, y))

def testBatchingPostProcess(self):
x = jnp.arange(10).reshape(5, 2)
f = jax.vmap(lambda y: xmap(lambda x: x + y, in_axes=['i', ...], out_axes=['i', ...])(x))
ref = jax.vmap(lambda y: jax.vmap(lambda x: x + y)(x))
self.assertAllClose(f(x * 2), ref(x * 2))

def testAutodiffBroadcast(self):
f = xmap(lambda x, y: jnp.cos(lax.dot(x, jnp.sin(y),
precision=lax.Precision.HIGHEST)),
Expand Down

0 comments on commit 5777c1e

Please sign in to comment.