Skip to content

Commit

Permalink
Initial transpose implementation
Browse files: browse the repository at this point in the history
Co-authored-by: Matthew Johnson <mattjj@google.com>
  • Loading branch information
sharadmv and mattjj committed Aug 15, 2022
1 parent b90aa87 commit 72dbe31
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 2 deletions.
68 changes: 68 additions & 0 deletions jax/_src/lax/control_flow/for_loop.py
Expand Up @@ -34,6 +34,7 @@
from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
split_list)
import jax.numpy as jnp
import numpy as np

from jax._src.lax.control_flow import loops
from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr
Expand Down Expand Up @@ -468,6 +469,73 @@ def eval_jaxpr(i, *refs):
eval_jaxpr, [i_aval, *res_ref_avals, *orig_ref_avals])
return jaxpr

def transpose_jaxpr(jaxpr: core.Jaxpr, which_linear: List[bool]) -> core.Jaxpr:
  """Trace a loop body that transposes `jaxpr` with respect to its linear inputs.

  Args:
    jaxpr: the loop-body jaxpr to transpose. The traced function below takes
      the loop index `i` followed by one argument per remaining input, and is
      traced with the avals of `jaxpr.invars`.
    which_linear: one flag per non-index input of `jaxpr`. Every input from the
      first flagged one onwards is treated as linear.

  Returns:
    A jaxpr traced from `trans`. Its body evaluates the non-linear part of
    `jaxpr` to obtain residuals and then runs `ad.backward_pass` over the
    linear part. The traced function itself returns no outputs.
  """
  # cumsum turns the flags into a running count, so after the cast every
  # position at or after the first linear input is True. Build a real list:
  # the flags are consumed twice inside `trans`, which a one-shot iterator
  # (such as the result of the builtin `map`) would not survive.
  which_linear = [
      bool(flag) for flag in np.cumsum(which_linear).astype(np.bool_)]

  def trans(i, *args):
    # First we want to run the computation to read all the residual refs. We
    # can do that by using partial evaluation with all linear inputs unknown.
    # The leading False marks the loop index as known.
    res_jaxpr, tangent_jaxpr_, *_ = _partial_eval_jaxpr_custom(
        jaxpr, [False, *which_linear], _save_everything)
    res_args = [x for x, lin in zip(args, which_linear) if not lin]
    res = core.eval_jaxpr(res_jaxpr, (), i, *res_args)

    # Now that we have residual values, we run the tangent jaxpr. It takes as
    # input the residuals, the loop index, and all the refs (at least, the
    # ones that are used in the body). Luckily, `tangent_jaxpr_` has all known
    # and unknown inputs!
    tangent_jaxpr, used = pe.dce_jaxpr(tangent_jaxpr_, [])
    # `used` is laid out as [residuals..., loop index, refs...]; only the
    # inputs that survived dead-code elimination are passed on.
    used_res, (used_i,), used_ct = split_list(used, [len(res), 1])
    primals_args = [r for u, r in zip(used_res, res) if u]
    if used_i:
      primals_args = [*primals_args, i]
    ct_args = [x for x, u in zip(args, used_ct) if u]
    ad.backward_pass(
        tangent_jaxpr, (), False, (), (*primals_args, *ct_args), ())
    return []

  jaxpr_trans, _, _ = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
  return jaxpr_trans

def _for_transpose(in_cts, *args, jaxpr, nsteps, reverse, which_linear):
  """Transpose rule for `for_p`: bind a transposed loop in the opposite order.

  Every (arg, cotangent) pair falls into one of four cases:

    non-residual refs
      getting and setting => nonzero ct, UndefinedPrimal arg
      just setting        => nonzero ct, defined arg (a dummy value)
      just getting        => zero ct,    UndefinedPrimal arg
    residuals
                             zero ct,    defined arg

  Any nonzero cotangent is always forwarded to the transposed loop; the
  matching arg may or may not be an undefined primal.

  Returns:
    A list with one entry per arg: the matching output of the transposed loop
    where the arg is an undefined primal, and `None` everywhere else.
  """
  args_ = []
  which_linear_transpose = []
  for x, ct in zip(args, in_cts):
    primal_undefined = ad.is_undefined_primal(x)
    if type(ct) is ad_util.Zero:
      if primal_undefined:
        # The loop was 'just getting': plug in a zero.
        operand = ad_util.zeros_like_aval(x.aval)
      else:
        # This is a residual: take x itself.
        operand = x
      linear = False
    else:
      # The loop was 'just setting' (x is a dummy) or 'getting and setting'.
      # Either way we grab the cotangent; only the latter counts as linear.
      operand = ct
      linear = bool(primal_undefined)
    args_.append(operand)
    which_linear_transpose.append(linear)

  jaxpr_transpose = transpose_jaxpr(jaxpr, which_linear)
  # The transposed jaxpr takes the loop index in addition to the operands.
  assert len(args_) == len(jaxpr_transpose.invars) - 1
  all_outs = for_p.bind(
      *args_,
      jaxpr=jaxpr_transpose,
      nsteps=nsteps,
      reverse=not reverse,
      which_linear=tuple(which_linear_transpose))

  ct_outs = []
  for x, out in zip(args, all_outs):
    ct_outs.append(out if ad.is_undefined_primal(x) else None)
  return ct_outs
ad.primitive_transposes[for_p] = _for_transpose

### Testing utility

def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False):
Expand Down
7 changes: 7 additions & 0 deletions jax/_src/state/primitives.py
Expand Up @@ -252,3 +252,10 @@ def _swap_transpose(g, ref, x, *idx):
x_bar = ref_swap(ref, idx, ad_util.instantiate(g))
return [None, x_bar] + [None] * len(idx)
ad.primitive_transposes[swap_p] = _swap_transpose

def addupdate_transpose(cts_in, ref, x, *idx):
  """Transpose rule for `addupdate_p`: the cotangent of `x` is a read of `ref`.

  Args:
    cts_in: incoming cotangents; unused.
    ref: the reference operand.
    x: the value operand; unused here, only its slot in the result matters.
    *idx: the index operands.

  Returns:
    A list with one entry per operand in `(ref, x, *idx)` order. Only the
    entry for `x` is populated, with `ref_get(ref, idx)`; the rest are `None`.
  """
  # addupdate transpose is get
  del cts_in, x
  g = ref_get(ref, idx)
  # One slot per operand, in operand order: `g` belongs to `x` (the second
  # slot), which is the same layout `_swap_transpose` returns.
  return [None, g] + [None] * len(idx)
ad.primitive_transposes[addupdate_p] = addupdate_transpose
2 changes: 1 addition & 1 deletion jax/interpreters/ad.py
Expand Up @@ -173,7 +173,7 @@ def recast_to_float0(primal, tangent):
# errors if you will)
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
consts, primals_in, cotangents_in):
if all(type(ct) is Zero for ct in cotangents_in):
if all(type(ct) is Zero for ct in cotangents_in) and not jaxpr.effects:
return map(lambda v: Zero(v.aval), jaxpr.invars)

def write_cotangent(prim, v, ct):
Expand Down
35 changes: 34 additions & 1 deletion tests/lax_control_flow_test.py
Expand Up @@ -1607,7 +1607,7 @@ def f(c, a):
"jit_scan": jit_scan, "jit_f": jit_f, "scan": scan_impl}
for jit_scan in [False, True]
for jit_f in [False, True]
for scan_impl, scan_name in SCAN_IMPLS)
for scan_impl, scan_name in SCAN_IMPLS_WITH_FOR)
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testScanGrad(self, jit_scan, jit_f, scan):
rng = self.rng()
Expand Down Expand Up @@ -2799,6 +2799,39 @@ def f(a, b):
np.testing.assert_allclose(actual_tangents[0], expected_tangents[0])
np.testing.assert_allclose(actual_tangents[1], expected_tangents[1])

@parameterized.named_parameters(
    {"testcase_name": "_jit_for={}_f={}_nsteps={}".format(
        use_jit, body_name, num_steps),
     "jit_for": use_jit, "f": body_fn, "body_shapes": shapes,
     "ref": ref_fn, "n": num_steps}
    for use_jit in [False, True]
    for body_name, body_fn, ref_fn, shapes, num_steps in [
        ("swap", for_body_swap, swap_ref, [(4,), (4,)], 4),
        ("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4),
        ("sincos", for_body_sincos, sincos_ref, [(4,), (4,)], 4),
        ("sincostan", for_body_sincostan, sincostan_ref, [(4,), (4,)], 4),
        ("accum", for_body_accum, accum_ref, [(4,), (4,)], 3),
        ("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
        ("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
    ])
def test_for_grad(self, jit_for, f, ref, body_shapes, n):
  """Compare `jax.grad` through `for_loop` against two reference gradients.

  The gradient of the summed second output is checked against the same
  quantity computed via `for_reference` and via the reference function
  `ref`, and then `jtu.check_grads` is run up to second order.
  """
  rng = self.rng()
  args = [rng.randn(*shape) for shape in body_shapes]

  # The step count and the body are static arguments when the loop is jitted.
  if jit_for:
    for_ = jax.jit(for_loop.for_loop, static_argnums=(0, 1))
  else:
    for_ = for_loop.for_loop

  def loop_loss(operands):
    return for_(n, f, operands)[1].sum()

  def discharged_loss(operands):
    return for_reference(n, f, operands)[1].sum()

  def expected_loss(operands):
    return ref(*operands)[1].sum()

  tol = {np.float64: 1e-12, np.float32: 1e-4}
  ans = jax.grad(loop_loss)(args)
  ans_discharged = jax.grad(discharged_loss)(args)
  expected = jax.grad(expected_loss)(args)

  self.assertAllClose(
      ans, ans_discharged, check_dtypes=True, rtol=tol, atol=tol)
  self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
  # check_grads passes the operands positionally, so regroup them first.
  jtu.check_grads(lambda *operands: loop_loss(operands), args, order=2)


# When executed as a script, run the tests through absltest with the test
# loader supplied by jtu.
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 72dbe31

Please sign in to comment.