Unnecessary recompilation of _while_loop_lax #563
CC @froystig
This is mostly relevant when reusing the same solver (I've been using BFGS), since re-instantiating it creates new references anyway (see line 80 at commit 58bac0a).
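The identity effect can be reproduced with plain `jax.jit`: its compilation cache is keyed on the wrapped Python callable (plus argument shapes and dtypes), so a function rebuilt on every call is traced again, while a reused jitted function is traced once. This is a minimal sketch; `make_step` and the `traces` list are hypothetical, and tracing is counted through a Python side effect, which only runs while JAX traces the function.

```python
import jax
import jax.numpy as jnp

traces = []  # appended to only while JAX is tracing, not on cached calls


def make_step(scale):
    # Each call returns a brand-new function object, similar to re-instantiating a solver.
    def step(x):
        traces.append("step")
        return scale * x + 1.0
    return jax.jit(step)


x = jnp.ones(3)

# Rebuilding the jitted function each time: every call traces (and compiles) again.
for _ in range(3):
    make_step(2.0)(x)
fresh_traces = len(traces)

# Reusing one jitted function: traced on the first call, served from the cache afterwards.
traces.clear()
step = make_step(2.0)
for _ in range(3):
    step(x)
reused_traces = len(traces)

print(fresh_traces, reused_traces)  # 3 1
```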
Do you have a minimal example that reproduces the slowdown, and would you mind posting it here if so?
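```python
import time
import jax
import jax.numpy as jnp
import jaxopt

def rosenbrock(x):
    # Standard Rosenbrock function; minimum at x = (1, ..., 1).
    return jnp.sum(100. * (x[1:] - x[:-1] ** 2) ** 2 + (1. - x[:-1]) ** 2)

solver = jaxopt.BFGS(rosenbrock)
x0 = jnp.zeros(2)

# First run: includes tracing and compilation.
_time = time.time()
sol = solver.run(x0)
sol.params.block_until_ready()  # JAX dispatches asynchronously; wait for the result
_time = time.time() - _time
print(f'Total {_time} s')

# Second run reuses the same solver; any compilation logged now is a recompilation.
jax.config.update('jax_log_compiles', True)
_time = time.time()
sol2 = solver.run(x0)
sol2.params.block_until_ready()
_time = time.time() - _time
print(f'Total2 {_time} s')
```

With the original library (jax 0.4.23, jaxopt 0.8.2):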
More generally, what do you think about jaxopt caching all solvers, so that recompilation would be reduced automatically when not using nested functions?
What do you propose, more concretely?
At the moment I essentially use jaxopt with ed1febe reverted and make sure that I cache my solvers:

```python
from functools import lru_cache

@lru_cache
def get_solver(solver, *args, **kwargs):
    return solver(*args, **kwargs)
```

My suggestion would be to
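A self-contained sketch of the solver-caching workaround. `functools.lru_cache` keys on the solver class and every constructor argument, so all arguments must be hashable (functions hash by identity; lists and dicts do not hash at all). `DummySolver` and `objective` are hypothetical stand-ins so the snippet runs without jaxopt; in practice you would pass e.g. `jaxopt.BFGS` and the real objective.

```python
from functools import lru_cache


@lru_cache
def get_solver(solver, *args, **kwargs):
    # Returns the same instance for identical (class, args, kwargs), so any
    # jitted functions held by the solver keep their compilation caches.
    return solver(*args, **kwargs)


class DummySolver:
    """Hypothetical stand-in for a solver class such as jaxopt.BFGS."""

    def __init__(self, fun, maxiter=500):
        self.fun = fun
        self.maxiter = maxiter


def objective(x):
    return x ** 2


s1 = get_solver(DummySolver, objective, maxiter=100)
s2 = get_solver(DummySolver, objective, maxiter=100)
s3 = get_solver(DummySolver, objective, maxiter=200)

print(s1 is s2, s1 is s3)  # True False
print(get_solver.cache_info().hits)  # 1
```

Note that the cache key also depends on how arguments are passed: `get_solver(DummySolver, objective, 100)` and `get_solver(DummySolver, objective, maxiter=100)` create separate entries.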
Would it be possible to detect that you're inside a jitted context rather than accepting the Boolean
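One way such a check could be approximated, sketched under assumptions: inside `jax.jit` the inputs are abstract tracers, so a solver could inspect its initial parameters with `isinstance(..., jax.core.Tracer)`. This is a heuristic, not an existing jaxopt feature; `inside_jit` is a hypothetical helper, it only sees the values it is given, and it also returns True under other transformations such as `grad` or `vmap`.

```python
import jax
import jax.numpy as jnp


def inside_jit(params):
    # Hypothetical helper: True if any leaf of `params` is a JAX tracer,
    # which holds under jax.jit (and also under grad, vmap, etc.).
    return any(isinstance(leaf, jax.core.Tracer)
               for leaf in jax.tree_util.tree_leaves(params))


x0 = jnp.zeros(2)
eager = inside_jit(x0)

traced = []


@jax.jit
def f(x):
    traced.append(inside_jit(x))  # evaluated once, while f is being traced
    return x + 1.0


f(x0)
print(eager, traced)  # False [True]
```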
ed1febe claims that jitting `_while_loop_lax` is redundant. However, this change also seems to prevent the result from being cached, causing recompilation of loops like the one at jaxopt/jaxopt/_src/base.py, line 314 (commit 58bac0a).
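To illustrate why jitting the loop wrapper can matter when a solver is called outside `jit`: if the loop driver is wrapped in `jax.jit` with the loop callables marked static, the compiled loop is reused whenever the same callables are passed again, because static arguments are part of the cache key. Passing new callables (as happens after re-instantiating a solver) misses the cache. This is a minimal sketch, not jaxopt's actual `_while_loop_lax`; `run_loop` and its signature are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

traces = []  # appended to only while run_loop is being traced


@partial(jax.jit, static_argnums=(0, 1))
def run_loop(cond_fun, body_fun, init_val):
    traces.append("run_loop")
    return lax.while_loop(cond_fun, body_fun, init_val)


def cond_fun(state):
    i, _ = state
    return i < 10


def body_fun(state):
    i, x = state
    return i + 1, 0.5 * x


init = (jnp.asarray(0), jnp.ones(2))

# Same callables every time: traced and compiled once, then cached.
for _ in range(3):
    i, x = run_loop(cond_fun, body_fun, init)
reused = len(traces)

# A new lambda on every call is a different static argument, so each call retraces.
traces.clear()
for _ in range(3):
    run_loop(lambda s: s[0] < 10, body_fun, init)
fresh = len(traces)

print(reused, fresh, int(i))  # 1 3 10
```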