Skip to content

Commit

Permalink
fix ad_checkpoint.checkpoint vmap rule
Browse files (browse the repository at this point in the history)
  • Loading branch information
mattjj committed May 5, 2022
1 parent 4618f9c commit b92c6b1
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
8 changes: 4 additions & 4 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
policy = params['policy'] or (lambda *_, **__: False)
# unzip into jaxpr_known and jaxpr_unknown
in_unknowns = [not t.is_known() for t in tracers]
# TODO(mattjj): use cached version of pe.partial_eval_jaxpr_custom
jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
pe._partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy)
jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
Expand Down Expand Up @@ -374,11 +375,10 @@ def transposed(*args):

# vmap (batching) rule for the remat primitive: batches the wrapped jaxpr and
# re-binds remat_p on the batched jaxpr, returning the outputs together with
# their batch dimensions.
# NOTE(review): this text is a rendered diff whose +/- markers and indentation
# were lost. Two consecutive statements below bind `jaxpr_batched_, out_batched`;
# the `batching.batch_jaxpr(...)` call and the `in_batched` list that only it
# uses appear to be the removed (pre-commit) side, and the
# `batching.batch_jaxpr_axes(...)` call the added side -- confirm against the
# commit before treating this as runnable code.
def remat_vmap(axis_size, axis_name, main_type, args, dims, *, jaxpr, **params):
# The jaxpr is wrapped below as a ClosedJaxpr with an empty consts tuple, so it
# must not have any constvars of its own.
assert not jaxpr.constvars
# One boolean per input: True when that input carries a batch dimension. This
# keeps only whether an input is mapped, not which axis it is mapped along.
in_batched = [d is not batching.not_mapped for d in dims]
jaxpr_ = core.ClosedJaxpr(jaxpr, ())
jaxpr_batched_, out_batched = batching.batch_jaxpr(
jaxpr_, axis_size, in_batched, instantiate=False, axis_name=axis_name,
main_type=main_type)
# This call receives `dims` itself (the per-input batch axes) rather than
# booleans, and passes `zero_if_mapped` as the destination for every output.
jaxpr_batched_, out_batched = batching.batch_jaxpr_axes(
jaxpr_, axis_size, dims, [batching.zero_if_mapped] * len(jaxpr.outvars),
axis_name=axis_name, main_type=main_type)
jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts
# Outputs reported as batched get batch dimension 0; the rest are unmapped.
out_dims = [0 if b else None for b in out_batched]
# Consts produced by batching are passed as leading operands to the new bind.
return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims
Expand Down
12 changes: 7 additions & 5 deletions jax/interpreters/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
from jax import linear_util as lu
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, wrap_name,
split_list, canonicalize_axis, moveaxis,
as_hashable_function, curry, memoize, cache)
as_hashable_function, curry, memoize,
weakref_lru_cache)
from jax.interpreters import partial_eval as pe

map = safe_map
Expand Down Expand Up @@ -473,9 +474,9 @@ def batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, axis_name,
return _batch_jaxpr_axes(closed_jaxpr, axis_size, tuple(in_axes),
tuple(out_axes_dest), axis_name, main_type)

@cache()
def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, axis_name,
main_type):
@weakref_lru_cache
def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
axis_name, main_type):
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
f, out_batched = _batch_jaxpr_inner(f, axis_size, out_axes_dest)
f = _batch_jaxpr_outer(f, axis_name, axis_size, in_axes, main_type)
Expand Down Expand Up @@ -527,7 +528,8 @@ def _merge_bdims(x, y):
else:
return x # arbitrary

# Module-level sentinel; remat_vmap in this commit passes it as each entry of
# the `out_axes_dest` argument of batch_jaxpr_axes.
# NOTE(review): diff markers were lost in this text. The `object()` assignment
# appears to be the pre-commit definition and the ZeroIfMapped instance its
# replacement -- confirm against the commit. As written, the later assignment
# is the one that takes effect.
zero_if_mapped = object()
class ZeroIfMapped: pass
zero_if_mapped = ZeroIfMapped()

### functions for handling custom_vjp

Expand Down
18 changes: 18 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3631,6 +3631,24 @@ def g(x):
expected = np.diag(np.cos(np.sin(x)) * np.cos(x))
self.assertAllClose(ans, expected, check_dtypes=False)

# Test added alongside the remat vmap-rule fix: vmaps a remat-wrapped function
# over a non-leading input axis. It runs once per remat variant: api.remat,
# api.remat with a policy that always returns False, and new_checkpoint with
# the same policy.
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
('', api.remat),
('_policy', partial(api.remat, policy=lambda *_, **__: False)),
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
])
def test_remat_vmap_not_leading_dim(self, remat):
@remat
def g(x):
return lax.sin(lax.sin(x))

x = np.arange(3 * 5.).reshape(3, 5)

# in_axes=1, out_axes=0: map over axis 1 of the (3, 5) input and stack the
# per-slice results along axis 0 of the output.
ans = api.vmap(g, 1, 0)(x)
# g is elementwise, so moving the mapped axis from 1 to 0 yields the transpose
# of the unbatched result.
expected = np.sin(np.sin(x)).T
self.assertAllClose(ans, expected, check_dtypes=False)

@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}
for suffix, remat in [
Expand Down

0 comments on commit b92c6b1

Please sign in to comment.