Skip to content

Commit

Permalink
avoid redundant jit of lax.while_loop with static restriction
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Feb 3, 2022
1 parent eb6e75d commit ed1febe
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion jaxopt/_src/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def while_loop(cond_fun, body_fun, init_val, maxiter, unroll=False, jit=False):
else:
raise ValueError("unroll=False and jit=False cannot be used together")

if jit:
if jit and fun is not _while_loop_lax:
# jit of a lax while_loop is redundant, and this jit would only
# constrain maxiter to be static where it is not required.
fun = jax.jit(fun, static_argnums=(0, 1, 3))

return fun(cond_fun, body_fun, init_val, maxiter)

0 comments on commit ed1febe

Please sign in to comment.