Skip to content

Commit

Permalink
test JVP of while loop, and fix the nonzero tangent calculation in th…
Browse files Browse the repository at this point in the history
…e JVP rule
  • Loading branch information
froystig committed Jan 16, 2020
1 parent afb8af1 commit 335ecb9
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 4 deletions.
8 changes: 4 additions & 4 deletions jax/lax/lax_control_flow.py
Expand Up @@ -314,16 +314,16 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,

carry_nz = init_nz
for _ in range(1 + len(carry_nz)):
nonzeros = bconst_nz + carry_nz
body_nonzeros = bconst_nz + carry_nz
body_jvp, nonzeros_out = ad.jvp_jaxpr(
body_jaxpr, nonzeros, instantiate=carry_nz)
body_jaxpr, body_nonzeros, instantiate=carry_nz)
if nonzeros_out == carry_nz:
break
else:
carry_nz = _map(operator.or_, carry_nz, nonzeros_out)
carry_nz = _map(operator.or_, carry_nz, nonzeros_out)
else:
assert False, "Fixpoint not reached"

nonzeros = cconst_nz + body_nonzeros
tangents = [ad.instantiate_zeros(x, t) if t is ad_util.zero and nz else t
for x, t, nz in zip(primals, tangents, nonzeros)]

Expand Down
73 changes: 73 additions & 0 deletions tests/lax_control_flow_test.py
Expand Up @@ -42,6 +42,12 @@
config.parse_flags_with_absl()


def while_loop_reference(cond, body, carry):
while cond(carry):
carry = body(carry)
return carry


def scan_reference(f, init, xs):
carry = init
ys = []
Expand Down Expand Up @@ -1117,6 +1123,73 @@ def testWhileCondConstant(self):
out = lax.while_loop(lambda _: False, lambda _: (), ()) # doesn't crash
self.assertEqual(out, ())

@parameterized.named_parameters(
{"testcase_name": "_jit_loop={}_jit_body={}_jit_cond={}".format(
jit_loop, jit_body, jit_cond),
"jit_loop": jit_loop, "jit_body": jit_body, "jit_cond": jit_cond}
for jit_loop in [False, True]
for jit_body in [False, True]
for jit_cond in [False, True])
def testWhileJVP(self, jit_loop, jit_body, jit_cond):
cond = lambda x: x[0, 2] <= 8
body = lambda x: x * x

if jit_cond:
cond = api.jit(cond)
if jit_body:
body = api.jit(body)

loop = partial(lax.while_loop, cond, body)
if jit_loop:
loop = api.jit(loop)

loop_ref = partial(while_loop_reference, cond, body)

x = np.arange(9.).reshape((3, 3))
ans = api.jvp(loop, (x,), (x,))
expected = api.jvp(loop_ref, (x,), (x,))
self.assertAllClose(ans, expected, check_dtypes=False)

jtu.check_grads(loop, (x,), order=2, modes=["fwd"])

def testWhileJVPViaForiLoop(self):
f = lambda x: lax.fori_loop(0, 3, lambda i, x: x * 2, x)
self.assertAllClose(f(2.), 16., check_dtypes=False)
self.assertAllClose(api.jvp(f, (2.,), (1.,)), (16., 8.), check_dtypes=False)
jtu.check_grads(f, (2.,), order=2, modes=["fwd"])

f = lambda x: lax.fori_loop(0, 3, lambda i, x: x * (i + 1), x)
self.assertAllClose(f(2.), 12., check_dtypes=False)
self.assertAllClose(api.jvp(f, (2.,), (1.,)), (12., 6.), check_dtypes=False)
jtu.check_grads(f, (2.,), order=2, modes=["fwd"])

def testWhileJVPWithGrowingNonzeroTangents(self):
rng = onp.random.RandomState(0)

def cond(state):
i, x, y, z = state
return i < 2

def body(state):
i, x, y, z = state
y = x * x
z = y * y
return i + 1, x, y, z

y, z = rng.randn(2), rng.randn(2)
def loop(loop_impl, x):
return loop_impl(cond, body, (0, x, y, z))[1]

loop_lax = partial(loop, lax.while_loop)
loop_ref = partial(loop, while_loop_reference)

x = rng.randn(2)
ans = api.jvp(loop_lax, (x,), (x,))
expected = api.jvp(loop_ref, (x,), (x,))
self.assertAllClose(ans, expected, check_dtypes=False)

jtu.check_grads(loop_lax, (x,), order=2, modes=["fwd"])

def testIssue1316(self):
def f(carry, _):
c, key = carry
Expand Down

0 comments on commit 335ecb9

Please sign in to comment.