AWN-enabled reduction over named axes in reverse-mode AD
Previously, reverse-mode AD operators inside JAX maps always meant "compute
a gradient (or VJP, etc.) for each axis index in the map". For instance,
`vmap(grad(f))` is the standard JAX spelling of the per-example gradient of `f`.

In batching-tracer terms, this "elementwise" behavior means that if any inputs
to a function being transposed are mapped, the cotangents of all inputs, even
unmapped ones, are also mapped. But a user might want them to be unmapped
(if, for instance, they're interested in a total gradient rather than a
per-example gradient). They could always reduce (`psum`) the cotangents
afterwards, but computing mapped cotangents in the first place is likely an
unacceptable waste of memory, and it can't necessarily be optimized away.
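
To make the workaround concrete, here is a minimal sketch (not part of this
change; the toy `loss`, `w`, and `xs` are illustrative) of the per-example and
psum-afterwards versions inside `xmap`:

```python
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

def loss(w, x):
  # Under the xmap below, x carries the named axis 'batch'; w does not.
  return (jnp.dot(x, w) - 1.0) ** 2

w = jnp.ones(3)
xs = jnp.arange(12.).reshape(4, 3)

# Per-example gradients: one (3,)-shaped gradient per 'batch' index.
per_example_grads = xmap(jax.grad(loss),
                         in_axes=([...], ['batch', ...]),
                         out_axes=['batch', ...])(w, xs)   # shape (4, 3)

# Total gradient via the workaround: materialize per-example gradients,
# then psum them over the named axis.
def total_grad(w, x):
  return jax.lax.psum(jax.grad(loss)(w, x), 'batch')

total = xmap(total_grad,
             in_axes=([...], ['batch', ...]),
             out_axes=[...])(w, xs)                        # shape (3,)
```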

If we want to fuse these reductions into reverse-mode autodiff itself, we need
the backward_pass logic and/or transpose rules to know whether primal values
are mapped or unmapped. This is made possible by avals-with-names, which
encodes that information in the avals of the primal jaxpr.

Putting things together, **this change adds an option to reverse-mode AD APIs
that indicates which named axes should be reduced over in the backward pass in
situations where they were broadcasted over in the forward pass**. All other
named axes will be treated in the current elementwise way. This has the effect
of making APIs like `grad` behave akin to collectives like `psum`: they act
collectively over axes that are named explicitly, and elementwise otherwise.
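
Continuing the sketch above (again illustrative rather than taken from this
change), naming the axis in `grad` fuses the reduction into the backward pass:

```python
# The backward pass itself psums w's cotangent over 'batch', so no
# per-example gradient of w is materialized; the result is unmapped
# along 'batch' and matches `total` from the previous sketch.
total_fused = xmap(jax.grad(loss, reduce_axes=('batch',)),
                   in_axes=([...], ['batch', ...]),
                   out_axes=[...])(w, xs)                  # shape (3,)
```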

Since avals-with-names is currently enabled only in `xmap`, this behavior is
only available in that context for now. It's also missing some optimizations:
  - reductions aren't fused into any first-order primitives (e.g. a `pdot`
    should have a named contracting axis added rather than being followed by a
    `psum`; this can be implemented by putting these primitives into
    `reducing_transposes`)
  - reductions are performed eagerly, even over axes that are mapped to
    hardware resources (the optimal thing to do would be to reduce eagerly
    over any vectorized axis component while delaying the reduction over any
    hardware-mapped component until the end of the overall backward pass; this
    would require a way to represent these partially-reduced values)

PiperOrigin-RevId: 383685336
jekbradbury authored and jax authors committed Jul 8, 2021
1 parent f0c3049 commit 8e86952
Showing 7 changed files with 234 additions and 56 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,13 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...main).
* New features:
* New SciPy function {py:func}`jax.scipy.special.sph_harm`.
* Reverse-mode autodiff functions ({func}`jax.grad`,
{func}`jax.value_and_grad`, {func}`jax.vjp`, and
{func}`jax.linear_transpose`) support a parameter that indicates which named
axes should be summed over in the backward pass if they were broadcasted
over in the forward pass. This enables use of these APIs in a
non-per-example way inside maps (initially only
{func}`jax.experimental.maps.xmap`) ({jax-issue}`#6950`).

## jaxlib 0.1.69 (unreleased)

97 changes: 70 additions & 27 deletions jax/_src/api.py
@@ -578,9 +578,9 @@ def xla_computation(fun: Callable,
return_shape: Optional boolean, defaults to ``False``. If ``True``, the
wrapped function returns a pair where the first element is the XLA
computation and the second element is a pytree with the same structure as
the output of ``fun`` and where the leaves are objects with ``shape`` and
``dtype`` attributes representing the corresponding types of the output
leaves.
the output of ``fun`` and where the leaves are objects with ``shape``,
``dtype``, and ``named_shape`` attributes representing the corresponding
types of the output leaves.
donate_argnums: Specify which arguments are "donated" to the computation.
It is safe to donate arguments if you no longer need them once the
computation has finished. In some cases XLA can make use of donated
@@ -597,7 +597,7 @@ def xla_computation(fun: Callable,
``as_hlo_dot_graph``. If the argument ``return_shape`` is ``True``, then the
wrapped function returns a pair where the first element is the XLA
Computation and the second element is a pytree representing the structure,
shapes, and dtypes of the output of ``fun``.
shapes, dtypes, and named shapes of the output of ``fun``.
Concrete example arguments are not always necessary. For those arguments not
indicated by ``static_argnums``, any object with ``shape`` and ``dtype``
@@ -750,7 +750,8 @@ def computation_maker(*args, **kwargs):
shapes = [str(c.GetShape(a)) for a, d in zip(xla_args, donated_invars) if d]
warn(f"Some donated buffers were not usable: {', '.join(shapes)}")
built = c.build(out_tuple)
out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
out_shapes_flat = [
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
out_shape = tree_unflatten(out_tree(), out_shapes_flat)
for out_aval in out_avals:
if not isinstance(out_aval, xla.ShapedArray):
@@ -767,7 +768,8 @@ def computation_maker(*args, **kwargs):

def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
has_aux: bool = False, holomorphic: bool = False,
allow_int: bool = False) -> Callable:
allow_int: bool = False,
reduce_axes: Sequence[AxisName] = ()) -> Callable:
"""Creates a function which evaluates the gradient of ``fun``.
Args:
@@ -787,6 +789,13 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
``fun`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding gradient. Otherwise, the
gradient will be per-example over named axes. For example, if ``'batch'``
is a named batch axis, ``grad(f, reduce_axes=('batch',))`` will create a
function that computes the total gradient while ``grad(f)`` will create
one that computes the per-example gradient.
Returns:
A function with the same arguments as ``fun``, that evaluates the gradient
@@ -806,7 +815,8 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
"""
value_and_grad_f = value_and_grad(fun, argnums, has_aux=has_aux,
holomorphic=holomorphic,
allow_int=allow_int)
allow_int=allow_int,
reduce_axes=reduce_axes)

docstr = ("Gradient of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
@@ -829,7 +839,8 @@ def grad_f_aux(*args, **kwargs):

def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
has_aux: bool = False, holomorphic: bool = False,
allow_int: bool = False) -> Callable[..., Tuple[Any, Any]]:
allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()
) -> Callable[..., Tuple[Any, Any]]:
"""Create a function which evaluates both ``fun`` and the gradient of ``fun``.
Args:
@@ -847,6 +858,14 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
``fun`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding gradient. Otherwise, the
gradient will be per-example over named axes. For example, if ``'batch'``
is a named batch axis, ``value_and_grad(f, reduce_axes=('batch',))`` will
create a function that computes the total gradient while
``value_and_grad(f)`` will create one that computes the per-example
gradient.
Returns:
A function with the same arguments as ``fun`` that evaluates both ``fun``
@@ -879,9 +898,10 @@ def value_and_grad_f(*args, **kwargs):
f_partial, dyn_args = argnums_partial(f, argnums, args)
tree_map(partial(_check_input_dtype_grad, holomorphic, allow_int), dyn_args)
if not has_aux:
ans, vjp_py = _vjp(f_partial, *dyn_args)
ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
else:
ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
ans, vjp_py, aux = _vjp(
f_partial, *dyn_args, has_aux=True, reduce_axes=reduce_axes)
_check_scalar(ans)
dtype = dtypes.result_type(ans)
tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
@@ -1891,12 +1911,14 @@ def _vjp_pullback_wrapper(cotangent_dtypes, io_tree, fun, py_args):
@overload # type: ignore
def vjp(fun: Callable[..., T],
*primals: Any,
has_aux: Literal[False] = False) -> Tuple[T, Callable]:
has_aux: Literal[False] = False,
reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable]:
...

@overload
def vjp(fun: Callable[..., Tuple[T, U]], *primals: Any,
has_aux: Literal[True]) -> Tuple[T, Callable, U]:
has_aux: Literal[True],
reduce_axes: Sequence[AxisName] = ()) -> Tuple[T, Callable, U]:
...
else:

@@ -1907,12 +1929,14 @@ def vjp(fun: Callable[..., T], *primals: Any) -> Tuple[T, Callable]:
@overload
def vjp(
fun: Callable[..., Any], *primals: Any,
has_aux: bool) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
has_aux: bool,
reduce_axes: Sequence[AxisName] = ()
) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
...


def vjp( # type: ignore
fun: Callable, *primals, has_aux: bool = False,
fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
"""Compute a (reverse-mode) vector-Jacobian product of ``fun``.
@@ -1929,6 +1953,13 @@ def vjp(  # type: ignore
has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
``fun`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding gradient. Otherwise, the
VJP will be per-example over named axes. For example, if ``'batch'``
is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will
create a VJP function that sums over the batch while ``vjp(f, *args)``
will create a per-example VJP.
Returns:
If ``has_aux`` is ``False``, returns a ``(primals_out, vjpfun)`` pair, where
@@ -1953,19 +1984,22 @@ def vjp(  # type: ignore
-0.2524413
"""
_check_callable(fun)
return _vjp(lu.wrap_init(fun), *primals, has_aux=has_aux)
return _vjp(
lu.wrap_init(fun), *primals, has_aux=has_aux, reduce_axes=reduce_axes)

def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
def _vjp(fun: lu.WrappedFun, *primals, has_aux=False, reduce_axes=()):
"""Variant of vjp() that takes an lu.WrappedFun."""
primals_flat, in_tree = tree_flatten(primals)
for arg in primals_flat: _check_arg(arg)
if not has_aux:
flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
out_primal, out_vjp = ad.vjp(
flat_fun, primals_flat, reduce_axes=reduce_axes)
out_tree = out_tree()
else:
flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree)
out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
out_primal, out_vjp, aux = ad.vjp(
flat_fun, primals_flat, has_aux=True, reduce_axes=reduce_axes)
out_tree, aux_tree = out_aux_trees()
out_primal_py = tree_unflatten(out_tree, out_primal)
ct_dtypes = [core.primal_dtype_to_tangent_dtype(_dtype(x)) for x in out_primal]
@@ -1981,7 +2015,7 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux)


def linear_transpose(fun: Callable, *primals) -> Callable:
def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
"""Transpose a function that is promised to be linear.
For linear functions, this transformation is equivalent to ``vjp``, but
@@ -2002,6 +2036,14 @@ def linear_transpose(fun: Callable, *primals) -> Callable:
is not required: only the ``shape`` and ``dtype`` attributes are accessed.
See below for an example. (Note that the duck-typed objects cannot be
namedtuples because those are treated as standard Python containers.)
reduce_axes: Optional, tuple of axis names. If an axis is listed here, and
``fun`` implicitly broadcasts a value over that axis, the backward pass
will perform a ``psum`` of the corresponding cotangent. Otherwise, the
transposed function will be per-example over named axes. For example, if
``'batch'`` is a named batch axis, ``linear_transpose(f, *args,
reduce_axes=('batch',))`` will create a transpose function that sums over
the batch while ``linear_transpose(f, *args)`` will create a per-example
transpose.
Returns:
A callable that calculates the transpose of ``fun``. Valid input into this
@@ -2046,7 +2088,7 @@ def transposed_fun(out_cotangent):
dummies = [ad.UndefinedPrimal(a) for a in in_avals]
in_cotangents = map(
ad.instantiate_zeros,
ad.backward_pass(jaxpr, consts, dummies, out_cotangents))
ad.backward_pass(jaxpr, reduce_axes, consts, dummies, out_cotangents))
return tree_unflatten(in_tree, in_cotangents)

return transposed_fun
Expand All @@ -2071,19 +2113,19 @@ def make_jaxpr(fun: Callable,
specifies the axis name/size environment that would be set up by
applications of :py:func:`jax.pmap`.
return_shape: Optional boolean, defaults to ``False``. If ``True``, the
wrapped function returns a pair where the first element is the ``jaxpr``
and the second element is a pytree with the same structure as
the output of ``fun`` and where the leaves are objects with ``shape`` and
``dtype`` attributes representing the corresponding types of the output
leaves.
wrapped function returns a pair where the first element is the ``jaxpr``
and the second element is a pytree with the same structure as
the output of ``fun`` and where the leaves are objects with ``shape``,
``dtype``, and ``named_shape`` attributes representing the corresponding
types of the output leaves.
Returns:
A wrapped version of ``fun`` that when applied to example arguments returns
a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
argument ``return_shape`` is ``True``, then the returned function instead
returns a pair where the first element is the ``ClosedJaxpr``
representation of ``fun`` and the second element is a pytree representing
the structure, shape, and dtypes of the output of ``fun``.
the structure, shape, dtypes, and named shapes of the output of ``fun``.
A ``jaxpr`` is JAX's intermediate representation for program traces. The
``jaxpr`` language is based on the simply-typed first-order lambda calculus
@@ -2135,7 +2177,8 @@ def jaxpr_maker(*args, **kwargs):
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
if return_shape:
out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
out_shapes_flat = [
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
return closed_jaxpr, tree_unflatten(out_tree(), out_shapes_flat)
return closed_jaxpr

9 changes: 5 additions & 4 deletions jax/_src/custom_derivatives.py
@@ -369,11 +369,12 @@ def batched_jvp_jaxpr_thunk():
# If a (multi)linear function is defined with a custom jvp, then
# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
# already been linearized, we can drop the jvp rule.
def _custom_jvp_call_jaxpr_transpose(cts, *args, fun_jaxpr, jvp_jaxpr_thunk,
num_consts):
def _custom_jvp_call_jaxpr_transpose(reduce_axes, cts, *args, fun_jaxpr,
jvp_jaxpr_thunk, num_consts):
del jvp_jaxpr_thunk, num_consts
return ad.backward_pass(fun_jaxpr.jaxpr, fun_jaxpr.consts, args, cts)
ad.primitive_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose
return ad.backward_pass(
fun_jaxpr.jaxpr, reduce_axes, fun_jaxpr.consts, args, cts)
ad.reducing_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose


### VJPs
25 changes: 14 additions & 11 deletions jax/_src/lax/control_flow.py
@@ -1020,7 +1020,7 @@ def _ordered_unique(xs):
d = collections.OrderedDict((x, None) for x in xs)
return list(d.keys())

def _transpose_cond_jaxpr(jaxpr, num_res):
def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
primal_avals = _map(raise_to_shaped, primal_avals)

@@ -1029,19 +1029,19 @@ def transposed(*args):
res, cts_out = split_list(args, [num_res])
primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
cts_in = ad.backward_pass(
jaxpr.jaxpr, jaxpr.consts, primals, cts_out)
jaxpr.jaxpr, reduce_axes, jaxpr.consts, primals, cts_out)
_, cts_in = split_list(cts_in, [num_res])
return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)

return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)

def _cond_transpose(cts, *args, branches, linear):
def _cond_transpose(reduce_axes, cts, *args, branches, linear):
index, *ops = args
in_avals = _map(raise_to_shaped, branches[0].in_avals)
num_res = len(ops) - sum(linear)

branches_trans = tuple(
_transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches)
_transpose_cond_jaxpr(jaxpr, num_res, reduce_axes) for jaxpr in branches)
lin_in_avals = [raise_to_shaped(a, weak_type=False)
for a, l in zip(in_avals, linear) if l]
assert all(core.typematch(out_aval, lin_in_aval)
@@ -1131,7 +1131,7 @@ def cond_bind(*args, branches, linear):
cond_p.def_abstract_eval(_cond_abstract_eval)
cond_p.def_custom_bind(cond_bind)
ad.primitive_jvps[cond_p] = _cond_jvp
ad.primitive_transposes[cond_p] = _cond_transpose
ad.reducing_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.initial_style_batchers[cond_p] = _cond_batching_rule
xla.initial_style_translations[cond_p] = _cond_translation_rule
@@ -1673,8 +1673,8 @@ def _maybe_device_put(x):
else:
return x

def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr,
linear, unroll):
def _scan_transpose(reduce_axes, cts, *args, reverse, length, num_consts,
num_carry, jaxpr, linear, unroll):
# we've only implemented transposing scans with specific lin/nonlin patterns
consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
num_ires = len(consts_lin) - sum(consts_lin)
@@ -1702,7 +1702,7 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr,
# jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b])
# jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a])
jaxpr_trans = _transpose_scan_jaxpr(
num_ires, num_consts - num_ires, num_eres, jaxpr)
num_ires, num_consts - num_ires, num_eres, jaxpr, reduce_axes)
linear_trans = ([False] * num_ires +
[True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) +
[False] * num_eres)
@@ -1717,8 +1717,10 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, num_carry, jaxpr,

# transpose_scan_jaxpr :: ([res1, c, a, res2] -> b)
# -> ([res1, CT c, CT b, res2] -> [CT c, CT a])
def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr):
def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr, reduce_axes):
num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2
# TODO: allow input cotangent avals to be batched relative to jaxpr.in_avals
# if an axis isn't reduced
res1_avals, c_avals, a_avals, res2_avals = split_list(
jaxpr.in_avals, [num_res1, num_c, num_a])
num_b = len(jaxpr.out_avals)
@@ -1730,7 +1732,8 @@ def transposed(*res1_cbar_bbar_res2):
res1_cbar_bbar_res2, [num_res1, num_c, num_b])
primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] +
[ad.UndefinedPrimal(aval) for aval in a_avals] + res2)
cbar_abar = ad.backward_pass(jaxpr.jaxpr, jaxpr.consts, primals, b_bar)
cbar_abar = ad.backward_pass(jaxpr.jaxpr, reduce_axes, jaxpr.consts,
primals, b_bar)
_, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a])
a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar)
c_bar = _map(ad.instantiate_zeros_aval, c_avals,
@@ -1879,7 +1882,7 @@ def scan_bind(*args, **params):
scan_p.def_impl(partial(xla.apply_primitive, scan_p))
scan_p.def_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp
ad.primitive_transposes[scan_p] = _scan_transpose
ad.reducing_transposes[scan_p] = _scan_transpose
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
xla.initial_style_translations[scan_p] = xla.lower_fun_initial_style(_scan_impl)
batching.initial_style_batchers[scan_p] = _scan_batching_rule
