Skip to content

Commit

Permalink
Split for_loop_test out of lax_control_flow_test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 473848277
  • Loading branch information
sharadmv authored and jax authors committed Sep 12, 2022
1 parent ef48533 commit e5725f1
Show file tree
Hide file tree
Showing 3 changed files with 375 additions and 335 deletions.
10 changes: 10 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,16 @@ jax_test(
],
)

jax_test(
name = "for_loop_test",
srcs = ["for_loop_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
"tpu": 10,
},
)

jax_test(
name = "clear_backends_test",
srcs = ["clear_backends_test.py"],
Expand Down
365 changes: 365 additions & 0 deletions tests/for_loop_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,365 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np

import jax
from jax._src import test_util as jtu
import jax.numpy as jnp
from jax._src.lax.control_flow import for_loop

from jax.config import config
config.parse_flags_with_absl()

def remat_of_for_loop(nsteps, body, state, **kwargs):
return jax.remat(lambda state: for_loop.for_loop(nsteps, body, state,
**kwargs))(state)

FOR_LOOP_IMPLS = [
(for_loop.for_loop, 'for_loop'),
(jax.jit(for_loop.for_loop, static_argnums=(0, 1)), 'jit_for_loop'),
(remat_of_for_loop, 'remat_for_loop'),
]


def _for_loop_impls(f):
return parameterized.named_parameters(
dict(testcase_name=impl_name, for_impl=for_impl)
for for_impl, impl_name in FOR_LOOP_IMPLS
)(f)


class ForLoopTest(jtu.JaxTestCase):

@_for_loop_impls
def test_for_loop_impl_trivial(self, for_impl):
out = for_impl(5, lambda i, _: None, None)
self.assertIsNone(out)

@_for_loop_impls
def test_for_loop_can_write_to_ref(self, for_impl):
def body(_, x_ref):
x_ref[()] = jnp.float32(1.)
out = for_impl(1, body, jnp.float32(0.))
self.assertEqual(out, 1.)

def body2(i, x_ref):
x_ref[()] = jnp.float32(i)
out = for_impl(2, body2, jnp.float32(0.))
self.assertEqual(out, 1.)

def body3(i, x_ref):
x_ref[()] = jnp.float32(i) * 2.
out = for_impl(2, body3, jnp.float32(0.))
self.assertEqual(out, 2.)

@_for_loop_impls
def test_for_loop_can_write_to_multiple_refs(self, for_impl):
def body(_, refs):
x_ref, y_ref = refs
x_ref[()] = jnp.float32(1.)
y_ref[()] = jnp.float32(2.)
x, y = for_impl(1, body, (jnp.float32(0.), jnp.float32(0.)))
self.assertEqual(x, 1.)
self.assertEqual(y, 2.)

@_for_loop_impls
def test_for_loop_can_read_from_ref(self, for_impl):
def body(_, x_ref):
x_ref[()] # pylint: disable=pointless-statement
x = for_impl(1, body, jnp.float32(0.))
self.assertEqual(x, 0.)

@_for_loop_impls
def test_for_loop_can_read_from_and_write_to_ref(self, for_impl):
def body(_, x_ref):
x = x_ref[()]
x_ref[()] = x + jnp.float32(1.)
x = for_impl(5, body, jnp.float32(0.))
self.assertEqual(x, 5.)

@_for_loop_impls
def test_for_loop_can_read_from_and_write_to_refs(self, for_impl):
def body2(_, refs):
x_ref, y_ref = refs
x = x_ref[()]
y_ref[()] = x + 1.
x_ref[()] = x + 1.
x, y = for_impl(5, body2, (0., 0.))
self.assertEqual(x, 5.)
self.assertEqual(y, 5.)

@_for_loop_impls
def test_for_loop_can_read_from_and_write_to_ref_slice(self, for_impl):
def body(i, x_ref):
x = x_ref[i]
x_ref[i] = x + jnp.float32(1.)
x = for_impl(4, body, jnp.ones(4, jnp.float32))
np.testing.assert_allclose(x, 2 * jnp.ones(4, jnp.float32))

def body2(i, x_ref):
x = x_ref[i, 0]
x_ref[i, 1] = x + x_ref[i, 1]
x = for_impl(4, body2, jnp.arange(8.).reshape((4, 2)))
np.testing.assert_allclose(
x, jnp.array([[0., 1.], [2., 5.], [4., 9.], [6., 13.]]))

@_for_loop_impls
def test_for_loop_can_implement_cumsum(self, for_impl):
def cumsum(x):
def body(i, refs):
x_ref, accum_ref = refs
accum_ref[i + 1] = accum_ref[i] + x_ref[i]
accum = jnp.zeros(x.shape[0] + 1, x.dtype)
_, accum_out = for_impl(x.shape[0], body, (x, accum))
return accum_out[1:]

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8,))
np.testing.assert_allclose(cumsum(x), jnp.cumsum(x))

def for_body_swap(i, refs):
a_ref, b_ref = refs
a, b = a_ref[i], b_ref[i]
b_ref[i] = a
a_ref[i] = b

def swap_ref(a, b):
return b, a

def for_body_swap_swap(i, refs):
for_body_swap(i, refs)
for_body_swap(i, refs)

swap_swap_ref = lambda a, b: (a, b)

def for_body_sincos(i, refs):
a_ref, b_ref = refs
a = a_ref[i]
b_ref[i] = jnp.sin(jnp.cos(a))

sincos_ref = lambda x, y: (x, jnp.sin(jnp.cos(x)))

def for_body_sincostan(i, refs):
a_ref, b_ref = refs
a = a_ref[i]
b_ref[i] = jnp.tan(jnp.sin(jnp.cos(a)))

sincostan_ref = lambda x, y: (x, jnp.tan(jnp.sin(jnp.cos(x))))

def for_body_accum(i, refs):
x_ref, accum_ref = refs
accum_ref[i + 1] = accum_ref[i] + x_ref[i]

def accum_ref(x, accum):
for i in range(x.shape[0] - 1):
accum = accum.at[i + 1].set(accum[i] + x[i])
return x, accum

def for_body_sin_sq(i, refs):
x_ref, y_ref = refs
x = x_ref[i]
y = x
y_ref[i] = y
y = y_ref[i]
y_ref[i] = jnp.sin(y * y)

sin_sq_ref = lambda x, y: (x, jnp.sin(x * x))

def for_body_reverse(i, refs):
x_ref, y_ref = refs
j = y_ref.shape[0] - i - 1
y_ref[i] = x_ref[j]

reverse_ref = lambda x, y: (x, x[::-1])

def for_body_noop(i, refs):
pass
noop_ref = lambda x, y: (x, y)
for_reference = for_loop.discharged_for_loop


class ForLoopTransformationTest(jtu.JaxTestCase):

@parameterized.named_parameters(
{"testcase_name": "_f={}_nsteps={}_impl={}".format(
for_body_name, nsteps, impl_name),
"f": for_body, "body_shapes": body_shapes,
"ref": ref, "n": nsteps, "for_impl": for_impl}
for for_impl, impl_name in FOR_LOOP_IMPLS
for for_body_name, for_body, ref, body_shapes, nsteps in [
("swap", for_body_swap, swap_ref, [(4,), (4,)], 4),
("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4),
("sincos", for_body_sincos, sincos_ref, [(4,), (4,)], 4),
("sincostan", for_body_sincostan, sincostan_ref, [(4,), (4,)], 4),
("accum", for_body_accum, accum_ref, [(4,), (4,)], 3),
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
])
def test_for_jvp(self, f, ref, body_shapes, n, for_impl):
for_ = for_impl
rng = self.rng()

args = [rng.randn(*s) for s in body_shapes]

tol = {np.float64: 1e-12, np.float32: 1e-4}
ans = jax.jvp( lambda *args: for_( n, f, args), args, args)
ans_discharged = jax.jvp(lambda *args: for_reference(n, f, args), args, args)
expected = jax.jvp(ref, args, args)
self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol, atol=tol)
self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
jtu.check_grads(partial(for_, n, f), (args,), order=3, modes=["fwd"])

@parameterized.named_parameters(
{"testcase_name": "_f={}_nsteps={}_impl={}".format(
for_body_name, nsteps, impl_name),
"f": for_body, "body_shapes": body_shapes,
"ref": ref, "n": nsteps, "for_impl": for_impl}
for for_impl, impl_name in FOR_LOOP_IMPLS
for for_body_name, for_body, ref, body_shapes, nsteps in [
("swap", for_body_swap, swap_ref, [(4,), (4,)], 4),
("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4),
("sincos", for_body_sincos, sincos_ref, [(4,), (4,)], 4),
("sincostan", for_body_sincostan, sincostan_ref, [(4,), (4,)], 4),
("accum", for_body_accum, accum_ref, [(4,), (4,)], 3),
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
])
def test_for_linearize(self, f, ref, body_shapes, n, for_impl):
for_ = for_impl
rng = self.rng()

args = [rng.randn(*s) for s in body_shapes]

tol = {np.float64: 1e-12, np.float32: 1e-4}
ans = jax.linearize(lambda *args: for_( n, f, args), *args)[1](*args)
ans_discharged = jax.linearize(lambda *args: for_reference(n, f, args),
*args)[1](*args)
expected = jax.linearize(ref, *args)[1](*args)
self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol, atol=tol)
self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)

def test_for_loop_invar(self):
def f(x):
s = jnp.ones((2, 32), x.dtype)
def body(i, refs):
x_ref, y_ref = refs
y_ref[i] = s * x_ref[i] * jnp.cos(s)
# We should save `s` and `jnp.cos(s)` as residuals and not broadcast
# them.
return for_loop.for_loop(x.shape[0], body, (x, jnp.zeros_like(x)))
_, f_vjp = jax.linearize(f, jnp.ones((5, 2, 32)))
jaxpr = jax.make_jaxpr(f_vjp)(jnp.ones((5, 2, 32)))
consts = [v.aval for v in jaxpr.jaxpr.constvars
if v.aval.shape == (2, 32)]
self.assertLen(consts, 2)

def loss(A):
def step(x, _):
return jnp.matmul(A, x), None
init_x = jnp.zeros(A.shape[-1:])
last_x, _ = for_loop.scan(step, init_x, jnp.arange(10))
return jnp.sum(last_x)

A = jnp.zeros((3, 3))
# The second DUS was unnecessarily replicating A across time.
# We check XLA because _scan_impl is "underneath" the jaxpr language.
s = str(jax.xla_computation(jax.grad(loss))(A).as_hlo_text())
assert s.count("dynamic-update-slice(") < 2

@_for_loop_impls
def test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals(
self, for_impl):

def body(i, refs):
a_ref, b_ref, c_ref = refs
a = a_ref[i]
b = b_ref[()]
x = jnp.sin(a)
b_ref[()] = jnp.sin(b * x)
c_ref[i] = x * b
def f(a, b):
c = jnp.zeros_like(a)
_, b, c = for_impl(5, body, (a, b, c))
return b, c
a = jnp.arange(5.) + 1.
b = 1.
_, f_lin = jax.linearize(f, a, b)
expected_tangents = f_lin(a, b)
_, actual_tangents = jax.jvp(f, (a, b), (a, b))
np.testing.assert_allclose(actual_tangents[0], expected_tangents[0])
np.testing.assert_allclose(actual_tangents[1], expected_tangents[1])

def body2(_, refs):
# Here we use `i_ref` as a loop counter
a_ref, b_ref, c_ref, i_ref = refs
i = i_ref[()]
a = a_ref[i]
b = b_ref[()]
x = jnp.sin(a)
b_ref[()] = jnp.sin(b * x)
c_ref[i] = x * b
i_ref[()] = i + 1

def g(a, b):
c = jnp.zeros_like(a)
_, b, c, _ = for_impl(5, body2, (a, b, c, 0))
return b, c
a = jnp.arange(5.) + 1.
b = 1.
_, g_lin = jax.linearize(f, a, b)
expected_tangents = g_lin(a, b)
_, actual_tangents = jax.jvp(g, (a, b), (a, b))
np.testing.assert_allclose(actual_tangents[0], expected_tangents[0])
np.testing.assert_allclose(actual_tangents[1], expected_tangents[1])

@parameterized.named_parameters(
{"testcase_name": "_f={}_nsteps={}_impl={}".format(
for_body_name, nsteps, impl_name),
"f": for_body, "body_shapes": body_shapes,
"ref": ref, "n": nsteps, "for_impl": for_impl}
for for_impl, impl_name in FOR_LOOP_IMPLS
for for_body_name, for_body, ref, body_shapes, nsteps in [
("noop", for_body_noop, noop_ref, [(4,), (4,)], 4),
("swap", for_body_swap, swap_ref, [(4,), (4,)], 4),
("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4),
("sincos", for_body_sincos, sincos_ref, [(4,), (4,)], 4),
("sincostan", for_body_sincostan, sincostan_ref, [(4,), (4,)], 4),
("accum", for_body_accum, accum_ref, [(4,), (4,)], 3),
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
])
def test_for_grad(self, f, ref, body_shapes, n, for_impl):
for_ = for_impl
rng = self.rng()

args = [rng.randn(*s) for s in body_shapes]

tol = {np.float64: 1e-12, np.float32: 1e-4}
ans = jax.grad(lambda args: for_( n, f, args)[1].sum())(args)
ans_discharged = jax.grad(
lambda args: for_reference(n, f, args)[1].sum())(args)
expected = jax.grad(lambda args: ref(*args)[1].sum())(args)
self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol,
atol=tol)
self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
jtu.check_grads(lambda *args: for_(n, f, args)[1].sum(), args, order=3,
rtol=7e-3, atol=1e-2)

if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit e5725f1

Please sign in to comment.