Unnecessary recompilation of _while_loop_lax #563

Open
hrdl-github opened this issue Jan 3, 2024 · 8 comments

@hrdl-github

Commit ed1febe claims that jitting _while_loop_lax is redundant. However, the change also seems to prevent the compiled loop from being cached, causing recompilation of loops like

many_step = loop.while_loop(
Reverting ed1febe drastically reduces compilation times for my use case, so it probably makes sense to address this cache issue.

@mblondel
Collaborator

mblondel commented Jan 3, 2024

CC @froystig

@hrdl-github
Author

This is mostly relevant when reusing the same solver instance (I've been using BFGS). I think reinstantiating the solver creates new _cond_fun and _body_fun objects, and these are static arguments in

fun = jax.jit(fun, static_argnums=(0, 1, 3))

so a freshly created solver would miss the cache even with jit restored.
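
To illustrate the point about static arguments, here is a minimal sketch that is independent of jaxopt. The names `apply`, `make_fn`, and `trace_count` are hypothetical and exist only for this example. It counts how often JAX traces a jitted function whose first argument is a static callable, assuming the jit cache is keyed on the hash and equality of that callable.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented each time JAX traces `apply`


def apply(fn, x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return fn(x)


# `fn` is a static argument, so it becomes part of the jit cache key.
jitted_apply = jax.jit(apply, static_argnums=0)


def make_fn():
    # Hypothetical stand-in for a solver that builds its own loop body:
    # every call returns a distinct function object.
    def fn(x):
        return x + 1.0
    return fn


x = jnp.zeros(2)

f = make_fn()
jitted_apply(f, x)
jitted_apply(f, x)            # same function object: served from the cache
traces_same_fn = trace_count

jitted_apply(make_fn(), x)    # new function object: traced again
traces_new_fn = trace_count

print(traces_same_fn, traces_new_fn)
```

Functions compare by identity, so two closures with identical bodies are still different cache keys. That is consistent with the observation above that caching mainly pays off when the same solver instance is reused.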

@froystig
Member

froystig commented Jan 4, 2024

Do you have a minimal example that reproduces the slowdown, and would you mind posting it here if so?

@hrdl-github
Author

hrdl-github commented Jan 4, 2024

import time, jaxopt, jax, jax.numpy as jnp

def rosenbrock(x):
    return jnp.sum(100. * jnp.diff(x) ** 2 + (1. - x[:-1]) ** 2)

solver = jaxopt.BFGS(rosenbrock)
x0 = jnp.zeros(2)

_time = time.time()
sol = solver.run(x0)
_time = time.time() - _time
print(f'Total {_time} s')

jax.config.update('jax_log_compiles', True)

_time = time.time()
sol2 = solver.run(x0)
_time = time.time() - _time
print(f'Total2 {_time} s')

With jax.jit restored (ed1febe reverted):

Total 1.3411040306091309 s
[1. 1.]
Total2 0.007392406463623047 s

Original library (jax 0.4.23, jaxopt 0.8.2):

Total 1.3537604808807373 s
Finished tracing + transforming while for pjit in 0.0003256797790527344 sec
Compiling while for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(float32[2]), ShapedArray(int32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[2]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[2,2]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
Finished jaxpr to MLIR module conversion jit(while) in 0.1349470615386963 sec
Finished XLA compilation of jit(while) in 0.3022458553314209 sec
[1. 1.]
Total2 0.45375514030456543 s

@hrdl-github
Author

More generally, what do you think about jaxopt caching all solvers, so recompilation would be reduced automatically when not using nested functions?

@mblondel
Collaborator

mblondel commented Jan 9, 2024

What do you propose, more concretely?

@hrdl-github
Author

hrdl-github commented Jan 9, 2024

At the moment I essentially use jaxopt with ed1febe reverted and make sure that I cache my solvers:

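from functools import lru_cache

# Return the same solver instance for identical (hashable) arguments.
@lru_cache
def get_solver(solver, *args, **kwargs):
    return solver(*args, **kwargs)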

My suggestion would be to

  1. Implement caching of _while_loop_lax cleanly without relying on jax.jit. I haven't dug deep enough into jax yet to know the best way to do this.
  2. Make it easy for users to reuse solvers, or at least document that reusing a solver reduces or avoids recompilation, so that they can benefit from this change. A related topic would be advising against nested functions, which I've seen in a lot of non-official examples.
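
The interaction between solver caching and nested functions can be sketched in plain Python. `DummySolver`, `objective`, and `solve_with_nested` are hypothetical names used only for illustration; `DummySolver` stands in for a real solver class and does no optimization. The sketch assumes all arguments passed to `get_solver` are hashable.

```python
from functools import lru_cache


class DummySolver:
    """Hypothetical stand-in for a solver class; it only stores its arguments."""

    def __init__(self, fun, maxiter=100):
        self.fun = fun
        self.maxiter = maxiter


@lru_cache(maxsize=None)
def get_solver(solver_cls, *args, **kwargs):
    # Identical (hashable) arguments return the same solver instance.
    return solver_cls(*args, **kwargs)


def objective(x):
    # Module-level function: one stable object for the whole program.
    return x * x


def solve_with_nested():
    def nested_objective(x):
        # Re-created on every call, so it is a new cache key each time.
        return x * x
    return get_solver(DummySolver, nested_objective)


same = get_solver(DummySolver, objective) is get_solver(DummySolver, objective)
nested_same = solve_with_nested() is solve_with_nested()

print(same, nested_same)
```

With a module-level objective the cache returns one shared solver; with a nested objective every call creates a new entry. Note that `lru_cache` keeps strong references to its keys, so an unbounded cache fed with nested functions grows on every call.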

@NeilGirdhar
Contributor

Would it be possible to detect that you're inside a jitted context rather than accepting the Boolean jit parameter? That way, only the user would ever call jit, and would totally control caching and compilation.
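
One possible way to approach such detection, shown only as a sketch and not as anything jaxopt does, is to check whether an input is a JAX tracer. The helper `is_traced` is a hypothetical name. The check relies on `jax.core.Tracer` and fires under any tracing transformation (for example `jax.grad` or `jax.vmap`), not only under `jax.jit`, so it is an approximation of "inside a jitted context".

```python
import jax
import jax.numpy as jnp


def is_traced(x):
    # True when `x` is an abstract tracer, i.e. the surrounding code is
    # being traced by a JAX transformation such as jit.
    return isinstance(x, jax.core.Tracer)


seen = []  # records the result of the check on each Python-level call


def double(x):
    seen.append(is_traced(x))
    return x * 2


x = jnp.ones(2)
double(x)            # eager call: `x` is a concrete array
jax.jit(double)(x)   # jitted call: `x` is a tracer while tracing

print(seen)
```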
