Skip to content

Commit

Permalink
Add tests for jvp(for_loop)
Browse files Browse the repository at this point in the history
  • Loading branch information
sharadmv committed Jun 28, 2022
1 parent c4b938f commit e1ba52b
Showing 1 changed file with 93 additions and 0 deletions.
93 changes: 93 additions & 0 deletions tests/lax_control_flow_test.py
Expand Up @@ -2882,5 +2882,98 @@ def body(i, refs):
x = jax.random.normal(key, (8,))
np.testing.assert_allclose(cumsum(x), jnp.cumsum(x))

def for_body_swap(i, refs):
a_ref, b_ref = refs
a, b = a_ref[i], b_ref[i]
b_ref[i] = a
a_ref[i] = b

def swap_ref(a, b):
return b, a

def for_body_swap_swap(i, refs):
for_body_swap(i, refs)
for_body_swap(i, refs)

swap_swap_ref = lambda a, b: (a, b)

def for_body_sincos(i, refs):
a_ref, b_ref = refs
a = a_ref[i]
b_ref[i] = jnp.sin(jnp.cos(a))

sincos_ref = lambda x, y: (x, jnp.sin(jnp.cos(x)))

def for_body_sincostan(i, refs):
a_ref, b_ref = refs
a = a_ref[i]
b_ref[i] = jnp.tan(jnp.sin(jnp.cos(a)))

sincostan_ref = lambda x, y: (x, jnp.tan(jnp.sin(jnp.cos(x))))

def for_body_accum(i, refs):
x_ref, accum_ref = refs
accum_ref[i + 1] = accum_ref[i] + x_ref[i]

def accum_ref(x, accum):
for i in range(x.shape[0] - 1):
accum = accum.at[i + 1].set(accum[i] + x[i])
return x, accum

def for_body_sin_sq(i, refs):
x_ref, y_ref = refs
x = x_ref[i]
y = x
y_ref[i] = y
y = y_ref[i]
y_ref[i] = jnp.sin(y * y)

sin_sq_ref = lambda x, y: (x, jnp.sin(x * x))

def for_body_reverse(i, refs):
x_ref, y_ref = refs
j = y_ref.shape[0] - i - 1
y_ref[i] = x_ref[j]

reverse_ref = lambda x, y: (x, x[::-1])

identity = lambda x, y: (x, y)
for_reference = for_loop.discharged_for_loop


class ForLoopTransformationTest(jtu.JaxTestCase):

@parameterized.named_parameters(
{"testcase_name": "_jit_for={}_f={}_nsteps={}".format(
jit_for, for_body_name, nsteps),
"jit_for": jit_for, "f": for_body, "body_shapes": body_shapes,
"ref": ref, "n": nsteps}
for jit_for in [False, True]
for for_body_name, for_body, ref, body_shapes, nsteps in [
("swap", for_body_swap, swap_ref, [(4,), (4,)], 4),
("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4),
("sincos", for_body_sincos, sincos_ref, [(4,), (4,)], 4),
("sincostan", for_body_sincostan, sincostan_ref, [(4,), (4,)], 4),
("accum", for_body_accum, accum_ref, [(4,), (4,)], 3),
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
])
def test_for_jvp(self, jit_for, f, ref, body_shapes, n):
for_ = for_loop.for_loop
rng = self.rng()

args = [rng.randn(*s) for s in body_shapes]

if jit_for:
for_ = jax.jit(for_, static_argnums=(0, 1))
tol = {np.float64: 1e-12, np.float32: 1e-4}
ans = jax.jvp( lambda *args: for_( n, f, args), args, args)
ans_discharged = jax.jvp(lambda *args: for_reference(n, f, args), args, args)
expected = jax.jvp(ref, args, args)
self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol, atol=tol)
self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
jtu.check_grads(partial(for_, n, f), (args,), order=3, modes=["fwd"])


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit e1ba52b

Please sign in to comment.