From e1ba52bb256fdcfd6b9ba08c265c82cd91efc876 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 27 Jun 2022 18:20:32 -0700 Subject: [PATCH] Add tests for `jvp(for_loop)` --- tests/lax_control_flow_test.py | 93 ++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index d7c7b6a17eb2..e2adebcde632 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2882,5 +2882,98 @@ def body(i, refs): x = jax.random.normal(key, (8,)) np.testing.assert_allclose(cumsum(x), jnp.cumsum(x)) +def for_body_swap(i, refs): + a_ref, b_ref = refs + a, b = a_ref[i], b_ref[i] + b_ref[i] = a + a_ref[i] = b + +def swap_ref(a, b): + return b, a + +def for_body_swap_swap(i, refs): + for_body_swap(i, refs) + for_body_swap(i, refs) + +swap_swap_ref = lambda a, b: (a, b) + +def for_body_sincos(i, refs): + a_ref, b_ref = refs + a = a_ref[i] + b_ref[i] = jnp.sin(jnp.cos(a)) + +sincos_ref = lambda x, y: (x, jnp.sin(jnp.cos(x))) + +def for_body_sincostan(i, refs): + a_ref, b_ref = refs + a = a_ref[i] + b_ref[i] = jnp.tan(jnp.sin(jnp.cos(a))) + +sincostan_ref = lambda x, y: (x, jnp.tan(jnp.sin(jnp.cos(x)))) + +def for_body_accum(i, refs): + x_ref, accum_ref = refs + accum_ref[i + 1] = accum_ref[i] + x_ref[i] + +def accum_ref(x, accum): + for i in range(x.shape[0] - 1): + accum = accum.at[i + 1].set(accum[i] + x[i]) + return x, accum + +def for_body_sin_sq(i, refs): + x_ref, y_ref = refs + x = x_ref[i] + y = x + y_ref[i] = y + y = y_ref[i] + y_ref[i] = jnp.sin(y * y) + +sin_sq_ref = lambda x, y: (x, jnp.sin(x * x)) + +def for_body_reverse(i, refs): + x_ref, y_ref = refs + j = y_ref.shape[0] - i - 1 + y_ref[i] = x_ref[j] + +reverse_ref = lambda x, y: (x, x[::-1]) + +identity = lambda x, y: (x, y) +for_reference = for_loop.discharged_for_loop + + +class ForLoopTransformationTest(jtu.JaxTestCase): + + @parameterized.named_parameters( + {"testcase_name": "_jit_for={}_f={}_nsteps={}".format( + jit_for, for_body_name, nsteps), + "jit_for": jit_for, "f": for_body, "body_shapes": body_shapes, + "ref": ref, "n": nsteps} + for jit_for in [False, True] + for for_body_name, for_body, ref, body_shapes, nsteps in [ + ("swap", for_body_swap, swap_ref, [(4,), (4,)], 4), + ("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4), + ("sincos", for_body_sincos, sincos_ref, [(4,), (4,)], 4), + ("sincostan", for_body_sincostan, sincostan_ref, [(4,), (4,)], 4), + ("accum", for_body_accum, accum_ref, [(4,), (4,)], 3), + ("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4), + ("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4), + ]) + def test_for_jvp(self, jit_for, f, ref, body_shapes, n): + for_ = for_loop.for_loop + rng = self.rng() + + args = [rng.randn(*s) for s in body_shapes] + + if jit_for: + for_ = jax.jit(for_, static_argnums=(0, 1)) + tol = {np.float64: 1e-12, np.float32: 1e-4} + ans = jax.jvp( lambda *args: for_( n, f, args), args, args) + ans_discharged = jax.jvp(lambda *args: for_reference(n, f, args), args, args) + expected = jax.jvp(ref, args, args) + self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol, atol=tol) + self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol) + jtu.check_grads(partial(for_, n, f), (args,), order=3, modes=["fwd"]) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())