Combine initial_style_batchers with collective_rules
sharadmv committed Sep 9, 2021
1 parent 8353d6d commit cc3e197
Showing 12 changed files with 172 additions and 114 deletions.
4 changes: 3 additions & 1 deletion jax/_src/api.py
@@ -1263,6 +1263,8 @@ def vmap(fun: F, in_axes=0, out_axes=0, axis_name=None) -> F:
     docstr += "\n\nOriginal documentation:\n\n"
     docstr += fun.__doc__
 
+  axis_name = core.no_axis_name if axis_name is None else axis_name
+
   if isinstance(in_axes, list):
     # To be a tree prefix of the positional args tuple, in_axes can never be a
     # list: if in_axes is not a leaf, it must be a tuple of trees. However,
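The hunk above normalizes axis_name to the core.no_axis_name sentinel so downstream batching code never has to special-case None. User-facing behavior is unchanged; a minimal sketch using only public API (the doubling and normalizing functions are illustrative):

import jax
import jax.numpy as jnp

# No explicit axis_name: internally it now becomes core.no_axis_name.
double = jax.vmap(lambda x: 2 * x)
print(double(jnp.arange(4)))  # [0 2 4 6]

# An explicit axis_name still enables collectives over the mapped axis.
normalize = jax.vmap(lambda x: x / jax.lax.psum(x, axis_name='i'), axis_name='i')
print(normalize(jnp.arange(1.0, 4.0)))  # [1/6 2/6 3/6]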
@@ -2460,7 +2462,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):
     raise ValueError("`devices` argument to `device_put_replicated must be "
                      "a non-empty sequence.")
   def _device_put_replicated(x):
-    aval = core.unmapped_aval(len(devices), None, 0,
+    aval = core.unmapped_aval(len(devices), core.no_axis_name, 0,
                               core.raise_to_shaped(core.get_aval(x)))
     assert isinstance(aval, core.ShapedArray) and aval._num_buffers == 1
     buf, = xla.device_put(x, devices[0])
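device_put_replicated is public API and its behavior is unchanged; the hunk only swaps the None placeholder for the core.no_axis_name sentinel when building the unmapped aval. A hedged usage sketch (the array shape is arbitrary):

import jax
import jax.numpy as jnp

# Replicate one array onto every local device; the result carries a
# leading axis of length len(devices).
x = jnp.ones((3, 4))
replicated = jax.device_put_replicated(x, jax.local_devices())
print(replicated.shape)  # (num_local_devices, 3, 4)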
20 changes: 9 additions & 11 deletions jax/_src/custom_derivatives.py
@@ -305,7 +305,7 @@ def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr, **p
   del args, params
   return fun_jaxpr.out_avals
 
-custom_jvp_call_jaxpr_p = core.Primitive('custom_jvp_call_jaxpr')
+custom_jvp_call_jaxpr_p = core.AxisPrimitive('custom_jvp_call_jaxpr')
 custom_jvp_call_jaxpr_p.multiple_results = True
 custom_jvp_call_jaxpr_p.def_impl(_custom_jvp_call_jaxpr_impl)
 custom_jvp_call_jaxpr_p.def_abstract_eval(_custom_jvp_call_jaxpr_abstract_eval)
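Upgrading the primitive from core.Primitive to core.AxisPrimitive lets it carry the named-axis information the batching rule below now receives. For jax.custom_jvp users nothing changes; a small sketch exercising the batcher through public API (smooth_relu is illustrative):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def smooth_relu(x):
  return jnp.logaddexp(x, 0.0)

@smooth_relu.defjvp
def smooth_relu_jvp(primals, tangents):
  x, = primals
  t, = tangents
  # d/dx logaddexp(x, 0) = sigmoid(x)
  return smooth_relu(x), jax.nn.sigmoid(x) * t

# vmap of a custom_jvp function routes through custom_jvp_call_jaxpr_p's
# batching rule.
xs = jnp.linspace(-3.0, 3.0, 7)
print(jax.vmap(jax.grad(smooth_relu))(xs))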
@@ -331,31 +331,30 @@ def _custom_jvp_call_jaxpr_jvp(
 ad.primitive_jvps[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_jvp
 
 def _custom_jvp_call_jaxpr_vmap(
-    args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
+    axis_size, axis_name, main_type, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr,
     jvp_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
     num_consts: int):
-  size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
   args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
           else x for x, d in zip(args, in_dims)]
   num_out = len(fun_jaxpr.out_avals)
 
   in_batched = [d is not not_mapped for d in in_dims]
   batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
-      fun_jaxpr, size, in_batched, False, axis_name, main_type)
+      fun_jaxpr, axis_size, in_batched, False, axis_name, main_type)
   out_dims1 = [0 if b else not_mapped for b in out_batched]
   out_dims2 = []  # mutable cell updated by batched_jvp_jaxpr_thunk
 
   @pe._memoize
   def batched_jvp_jaxpr_thunk():
     jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk())  # consts can be tracers
     _, args_batched = split_list(in_batched, [num_consts])
-    _, all_batched = batching.batch_jaxpr(jvp_jaxpr, size, args_batched * 2, False,
+    _, all_batched = batching.batch_jaxpr(jvp_jaxpr, axis_size, args_batched * 2, False,
                                           axis_name, main_type)
     primals_batched, tangents_batched = split_list(all_batched, [num_out])
     out_batched = map(op.or_, primals_batched, tangents_batched)
     out_dims2.append([0 if b else not_mapped for b in out_batched])
     batched_jvp_jaxpr, _ = batching.batch_jaxpr(
-        jvp_jaxpr, size, args_batched * 2, out_batched * 2,
+        jvp_jaxpr, axis_size, args_batched * 2, out_batched * 2,
         axis_name, main_type)
     return batched_jvp_jaxpr.jaxpr, batched_jvp_jaxpr.consts
 
@@ -364,7 +363,7 @@ def batched_jvp_jaxpr_thunk():
       jvp_jaxpr_thunk=batched_jvp_jaxpr_thunk, num_consts=num_consts)
   out_dims = out_dims2[0] if out_dims2 else out_dims1
   return batched_outs, out_dims
-batching.initial_style_batchers[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_vmap
+batching.axis_primitive_batchers[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_vmap
 
 xla.initial_style_translations[custom_jvp_call_jaxpr_p] = \
     xla.lower_fun_initial_style(_custom_jvp_call_jaxpr_impl)
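Note the new calling convention: rules in batching.axis_primitive_batchers receive axis_size, axis_name, and main_type ahead of the batched arguments, rather than reconstructing the size from the mapped inputs as the deleted `size, = {...}` line did. A sketch of that rule shape for a hypothetical unary primitive my_prim_p (not part of this commit), assuming it accepts a leading batch axis:

from jax.interpreters import batching
from jax.interpreters.batching import not_mapped

def _my_prim_batcher(axis_size, axis_name, main_type, args, dims, **params):
  # axis_size, axis_name, and main_type are supplied by the trace; the rule
  # no longer recomputes the size from the mapped arguments' shapes.
  x, = args
  d, = dims
  if d is not_mapped:
    return my_prim_p.bind(x, **params), not_mapped
  x = batching.moveaxis(x, d, 0)  # canonicalize the batch axis to position 0
  return my_prim_p.bind(x, **params), 0

batching.axis_primitive_batchers[my_prim_p] = _my_prim_batcher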
@@ -611,7 +610,7 @@ def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_):
 def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__):
   return fun_jaxpr.out_avals
 
-custom_vjp_call_jaxpr_p = core.Primitive('custom_vjp_call_jaxpr')
+custom_vjp_call_jaxpr_p = core.AxisPrimitive('custom_vjp_call_jaxpr')
 custom_vjp_call_jaxpr_p.multiple_results = True
 custom_vjp_call_jaxpr_p.def_impl(_custom_vjp_call_jaxpr_impl)
 custom_vjp_call_jaxpr_p.def_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval)
@@ -641,10 +640,9 @@ def _custom_vjp_call_jaxpr_jvp(
 ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
 
 def _custom_vjp_call_jaxpr_vmap(
-    args, in_dims, axis_name, main_type, *, fun_jaxpr: core.ClosedJaxpr,
+    axis_size, axis_name, main_type, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr,
     fwd_jaxpr_thunk: Callable[[], Tuple[core.Jaxpr, Sequence[Any]]],
     bwd: lu.WrappedFun, out_trees: Callable, num_consts: int):
-  axis_size, = {x.shape[d] for x, d in zip(args, in_dims) if d is not not_mapped}
   args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
           else x for x, d in zip(args, in_dims)]
 
@@ -674,7 +672,7 @@ def batched_fwd_jaxpr_thunk():
       out_trees=out_trees, num_consts=num_consts)
   out_dims = out_dims2[0] if out_dims2 else out_dims1
   return batched_outs, out_dims
-batching.initial_style_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
+batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
 
 xla.initial_style_translations[custom_vjp_call_jaxpr_p] = \
     xla.lower_fun_initial_style(_custom_vjp_call_jaxpr_impl)
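The registration above is what jax.vmap consults when it encounters a custom_vjp function. A quick end-to-end sketch through public API (clip_square and its rules are illustrative):

import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_square(x):
  return jnp.clip(x, -1.0, 1.0) ** 2

def clip_square_fwd(x):
  return clip_square(x), x  # save x as the residual

def clip_square_bwd(x, g):
  # Illustrative straight-through-style gradient.
  return (2.0 * x * g,)

clip_square.defvjp(clip_square_fwd, clip_square_bwd)

# vmapping the function and its grad exercises _custom_vjp_call_jaxpr_vmap.
xs = jnp.linspace(-2.0, 2.0, 5)
print(jax.vmap(clip_square)(xs))
print(jax.vmap(jax.grad(clip_square))(xs))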
