Skip to content

Commit

Permalink
Merge pull request #11278 from sharadmv:for-loop
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 457798502
  • Loading branch information
jax authors committed Jun 28, 2022
2 parents 6d8c6f8 + e1ba52b commit 90af8e8
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 1 deletion.
69 changes: 68 additions & 1 deletion jax/_src/lax/control_flow/for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Module for the `for_loop` primitive."""
from functools import partial
import operator

from typing import Any, Callable, Dict, Generic, List, Sequence, Tuple, TypeVar

Expand All @@ -31,6 +32,8 @@
from jax._src.util import safe_map, safe_zip, split_list
import jax.numpy as jnp

from jax._src.lax.control_flow import loops

## JAX utilities

map, unsafe_map = safe_map, map
Expand Down Expand Up @@ -268,7 +271,7 @@ def _swap_jvp(primals: List[Any], tangents: List[Any]):
ref_tangent, x_tangent, *_ = tangents
assert isinstance(ref_tangent.aval, ShapedArrayRef)
x_tangent = ad_util.instantiate(x_tangent)
return (ref_swap(ref_tangent, idx, x_primal), # type: ignore[arg-type]
return (ref_swap(ref_primal, idx, x_primal), # type: ignore[arg-type]
ref_swap(ref_tangent, idx, x_tangent)) # type: ignore[arg-type]
ad.primitive_jvps[swap_p] = _swap_jvp

Expand Down Expand Up @@ -498,3 +501,67 @@ def body(carry):
return state
mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True))
for_p.def_impl(partial(xla.apply_primitive, for_p))

def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear):
  """JVP rule for `for_p`.

  Determines which inputs carry a nonzero tangent once the loop has run,
  builds the JVP of the body jaxpr for that pattern, and binds a new `for_p`
  over the primals followed by the instantiated tangents.

  Returns:
    A pair `(out_primals, out_tangents)`; entries of `out_tangents` whose
    tangent is known to be zero are symbolic `ad_util.Zero` values.

  Raises:
    Exception: if the nonzero-tangent pattern does not stabilise within
      `len(tangents)` rounds.
  """
  nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
  # The body jaxpr of a `for` is stateful and has no outputs, so the usual
  # fixed point over "which outputs have nonzero tangents" cannot be run on it
  # directly. Discharging the state effect yields a "symmetric" jaxpr whose
  # outputs line up with its inputs; the fixed point is computed on that.
  discharged_jaxpr, body_consts = discharge_state(jaxpr, ())
  max_rounds = len(nonzero_tangents)
  converged = False
  for _ in range(max_rounds):
    _, out_nonzero_tangents = ad.jvp_jaxpr(
        core.ClosedJaxpr(discharged_jaxpr, body_consts),
        [False] + nonzero_tangents, instantiate=nonzero_tangents)
    if out_nonzero_tangents == nonzero_tangents:
      converged = True
      break
    nonzero_tangents = map(operator.or_, nonzero_tangents, out_nonzero_tangents)
  if not converged:
    raise Exception("Invalid fixpoint")

  # Materialise zeros wherever the fixed point requires a real tangent, and
  # drop the tangents that remain symbolic zeros.
  live_tangents = []
  for t, needs_tangent in zip(tangents, nonzero_tangents):
    if needs_tangent:
      t = ad.instantiate_zeros(t)
    if type(t) is not ad_util.Zero:
      live_tangents.append(t)
  tangents = live_tangents

  # The leading `False` is for the loop index, which never has a tangent.
  jvp_closed, _ = ad.jvp_jaxpr(
      core.ClosedJaxpr(jaxpr, ()), [False] + nonzero_tangents, [])
  jvp_jaxpr = jvp_closed.jaxpr
  jvp_consts = jvp_closed.consts
  num_consts = len(jvp_consts)
  jvp_which_linear = (
      (False,) * num_consts + which_linear + (True,) * len(tangents))
  out_flat = for_p.bind(
      *jvp_consts, *primals, *tangents,
      jaxpr=jvp_jaxpr, nsteps=nsteps, reverse=reverse,
      which_linear=jvp_which_linear)
  # `out_flat` also contains the constant inputs of the `for_loop`, which are
  # turned into outputs as well. They are irrelevant for AD, so discard them.
  _, out_primals, out_tangents = split_list(
      out_flat, [num_consts, len(primals)])
  remaining = iter(out_tangents)
  out_tangents = [
      next(remaining) if has_tangent else ad_util.Zero.from_value(primal)
      for primal, has_tangent in zip(out_primals, nonzero_tangents)]
  return out_primals, out_tangents
ad.primitive_jvps[for_p] = _for_jvp


### Testing utility

def discharged_for_loop(nsteps, body, init_state):
  """Run `body` for `nsteps` steps on a jaxpr whose state is discharged upfront.

  The body is traced with `Ref` avals, its state effect is discharged
  immediately, and the resulting jaxpr is evaluated inside a `fori_loop`.
  Potentially useful for testing and benchmarking.

  Raises:
    Exception: if `body` returns anything other than `None`.
  """
  leaves, treedef = tree_flatten(init_state)
  ref_avals = map(val_to_ref_aval, leaves)
  index_aval = core.ShapedArray((), jnp.dtype("int32"))
  jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs(
      body, treedef, [index_aval, *ref_avals])
  if out_tree != tree_structure(None):
    raise Exception("`body` should not return anything.")
  pure_jaxpr, pure_consts = discharge_state(jaxpr, consts)

  def step(i, carry):
    # The discharged jaxpr takes the loop index followed by the state values.
    return core.eval_jaxpr(pure_jaxpr, pure_consts, jnp.int32(i), *carry)

  final_leaves = loops.fori_loop(0, nsteps, step, leaves)
  return tree_unflatten(treedef, final_leaves)
93 changes: 93 additions & 0 deletions tests/lax_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2882,5 +2882,98 @@ def body(i, refs):
x = jax.random.normal(key, (8,))
np.testing.assert_allclose(cumsum(x), jnp.cumsum(x))

def for_body_swap(i, refs):
  """Exchange element `i` between the two refs in `refs`."""
  first_ref, second_ref = refs
  from_first, from_second = first_ref[i], second_ref[i]
  second_ref[i] = from_first
  first_ref[i] = from_second

def swap_ref(a, b):
  """Reference result for `for_body_swap`: the two inputs exchanged."""
  swapped = (b, a)
  return swapped

def for_body_swap_swap(i, refs):
  """Apply `for_body_swap` twice at index `i`."""
  for _ in range(2):
    for_body_swap(i, refs)

def swap_swap_ref(a, b):
  """Reference result for `for_body_swap_swap`: inputs returned unchanged."""
  return a, b

def for_body_sincos(i, refs):
  """Write `sin(cos(.))` of element `i` of the first ref into the second."""
  src_ref, dst_ref = refs
  value = src_ref[i]
  dst_ref[i] = jnp.sin(jnp.cos(value))

def sincos_ref(x, y):
  """Reference result for `for_body_sincos`; `y` is ignored."""
  return x, jnp.sin(jnp.cos(x))

def for_body_sincostan(i, refs):
  """Write `tan(sin(cos(.)))` of element `i` of the first ref into the second."""
  src_ref, dst_ref = refs
  value = src_ref[i]
  dst_ref[i] = jnp.tan(jnp.sin(jnp.cos(value)))

def sincostan_ref(x, y):
  """Reference result for `for_body_sincostan`; `y` is ignored."""
  return x, jnp.tan(jnp.sin(jnp.cos(x)))

def for_body_accum(i, refs):
  """Store `acc[i] + x[i]` into `acc[i + 1]` (one step of a running sum)."""
  x_ref, acc_ref = refs
  running_total = acc_ref[i] + x_ref[i]
  acc_ref[i + 1] = running_total

def accum_ref(x, accum):
  """Reference result for `for_body_accum` run for `x.shape[0] - 1` steps."""
  num_steps = x.shape[0] - 1
  for step in range(num_steps):
    updated = accum[step] + x[step]
    accum = accum.at[step + 1].set(updated)
  return x, accum

def for_body_sin_sq(i, refs):
  """Compute `sin(x[i] ** 2)` into `y[i]`, staging the value through `y[i]`.

  The value is first written to the output ref and read back before the final
  write, so the body performs a write, a read and another write at one index.
  """
  src_ref, dst_ref = refs
  dst_ref[i] = src_ref[i]
  staged = dst_ref[i]
  dst_ref[i] = jnp.sin(staged * staged)

def sin_sq_ref(x, y):
  """Reference result for `for_body_sin_sq`; `y` is ignored."""
  return x, jnp.sin(x * x)

def for_body_reverse(i, refs):
  """Copy the element mirrored from the end of `x` into position `i` of `y`."""
  x_ref, y_ref = refs
  length = y_ref.shape[0]
  y_ref[i] = x_ref[length - i - 1]

def reverse_ref(x, y):
  """Reference result for `for_body_reverse`; `y` is ignored."""
  return x, x[::-1]

def identity(x, y):
  """Return both arguments unchanged as a pair."""
  return x, y
# Alias for the eagerly discharged loop implementation; `test_for_jvp` compares
# `for_loop.for_loop` against it.
for_reference = for_loop.discharged_for_loop


class ForLoopTransformationTest(jtu.JaxTestCase):
  """Checks `for_loop` under JAX transformations against reference results."""

  @parameterized.named_parameters(
      dict(
          testcase_name="_jit_for={}_f={}_nsteps={}".format(
              jit_for, for_body_name, nsteps),
          jit_for=jit_for,
          f=for_body,
          body_shapes=body_shapes,
          ref=ref,
          n=nsteps)
      for jit_for in [False, True]
      for for_body_name, for_body, ref, body_shapes, nsteps in [
          ("swap", for_body_swap, swap_ref, [(4,), (4,)], 4),
          ("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4),
          ("sincos", for_body_sincos, sincos_ref, [(4,), (4,)], 4),
          ("sincostan", for_body_sincostan, sincostan_ref, [(4,), (4,)], 4),
          ("accum", for_body_accum, accum_ref, [(4,), (4,)], 3),
          ("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
          ("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
      ])
  def test_for_jvp(self, jit_for, f, ref, body_shapes, n):
    """JVP of `for_loop` matches the discharged loop and the reference.

    The same random arrays are used both as primals and as tangents.
    """
    for_ = for_loop.for_loop
    rng = self.rng()

    args = [rng.randn(*shape) for shape in body_shapes]

    if jit_for:
      # `nsteps` and `body` are not array arguments, so mark them static.
      for_ = jax.jit(for_, static_argnums=(0, 1))
    tol = {np.float64: 1e-12, np.float32: 1e-4}

    def run_for(*args):
      return for_(n, f, args)

    def run_discharged(*args):
      return for_reference(n, f, args)

    ans = jax.jvp(run_for, args, args)
    ans_discharged = jax.jvp(run_discharged, args, args)
    expected = jax.jvp(ref, args, args)
    self.assertAllClose(
        ans, ans_discharged, check_dtypes=True, rtol=tol, atol=tol)
    self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
    jtu.check_grads(partial(for_, n, f), (args,), order=3, modes=["fwd"])


# When executed as a script, run the tests through absltest with JAX's loader.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 90af8e8

Please sign in to comment.