Skip to content

Commit

Permalink
Fix lowering bug in effectful batched cond and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sharadmv committed Sep 28, 2022
1 parent 1bcf8d6 commit ddeaa8d
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 1 deletion.
6 changes: 5 additions & 1 deletion jax/_src/lax/control_flow/loops.py
Expand Up @@ -1439,7 +1439,11 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
core.ordered_effects]
if cond_ordered_effects:
def cond(args):
return core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0]
# Pred can be batched
pred = core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0]
if batched:
pred = lax._reduce_or(pred, tuple(range(len(pred_aval.shape))))
return pred
def body(args):
return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args))
def new_cond(pred_args):
Expand Down
60 changes: 60 additions & 0 deletions tests/debugging_primitives_test.py
Expand Up @@ -575,6 +575,66 @@ def _body(x):
x: 10
"""))

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
for ordered in [False, True]))
@jtu.skip_on_devices(*disabled_backends)
def test_can_print_in_batched_while_cond(self, ordered):
def f(x):
def _cond(x):
debug_print("x: {x}", x=x, ordered=ordered)
return x < 5
def _body(x):
return x + 1
return lax.while_loop(_cond, _body, x)
with jtu.capture_stdout() as output:
jax.vmap(f)(jnp.arange(2))
jax.effects_barrier()
if ordered:
expected = _format_multiline("""
x: 0
x: 1
x: 1
x: 2
x: 2
x: 3
x: 3
x: 4
x: 4
x: 5
x: 5
x: 6
""")
self.assertEqual(output(), expected)
else:
# When the print is unordered, the `cond` is called an additional time
# after the `_body` runs, so we get more prints.
expected = _format_multiline("""
x: 0
x: 1
x: 0
x: 1
x: 1
x: 2
x: 1
x: 2
x: 2
x: 3
x: 2
x: 3
x: 3
x: 4
x: 3
x: 4
x: 4
x: 5
x: 4
x: 5
x: 5
x: 5
""")
self._assertLinesEqual(output(), expected)

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
for ordered in [False, True]))
Expand Down

0 comments on commit ddeaa8d

Please sign in to comment.