
ravel_pytree now produces jit-compatible unravel functions #13834

Closed


patrick-kidger (Collaborator) commented Dec 31, 2022

Previously,

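import jax
import jax.numpy as jnp
from functools import partial
from jax.flatten_util import ravel_pytree

pytree = {"a": jnp.zeros(3), "b": jnp.ones(2)}  # any example pytree
flat, unravel1 = ravel_pytree(pytree)
_, unravel2 = ravel_pytree(pytree)

@partial(jax.jit, static_argnums=0)
def run(unravel, flat):
    # `unravel` is static (part of the cache key); `flat` is traced
    return unravel(flat)

run(unravel1, flat)
run(unravel2, flat)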

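would unnecessarily trigger recompilation: each call to ravel_pytree returned a fresh closure, so unravel2 did not compare equal to unravel1 as a static argument.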
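To make the idea concrete, here is a minimal sketch of a hashable unravel function, assuming equality is keyed on the tree structure plus each leaf's shape and dtype. The names `HashableUnravel` and `hashable_ravel_pytree` are hypothetical stand-ins for illustration, not the code in this PR or in JAX.

```python
import numpy as np
import jax.numpy as jnp
from jax import tree_util


class HashableUnravel:
    """Hypothetical unravel callable whose equality depends only on the
    tree structure and the leaves' shapes and dtypes, so instances built
    from identically structured pytrees compare (and hash) equal."""

    def __init__(self, treedef, shapes, dtypes):
        self.treedef = treedef
        self.shapes = tuple(shapes)
        self.dtypes = tuple(dtypes)

    def _key(self):
        return (self.treedef, self.shapes, self.dtypes)

    def __eq__(self, other):
        return isinstance(other, HashableUnravel) and self._key() == other._key()

    def __hash__(self):
        return hash(self._key())

    def __call__(self, flat):
        sizes = [int(np.prod(s)) for s in self.shapes]
        # Split points are the running totals of the leaf sizes, minus the last.
        chunks = jnp.split(flat, np.cumsum(sizes)[:-1])
        leaves = [c.reshape(s).astype(d)
                  for c, s, d in zip(chunks, self.shapes, self.dtypes)]
        return tree_util.tree_unflatten(self.treedef, leaves)


def hashable_ravel_pytree(pytree):
    """Hypothetical stand-in for ravel_pytree (assumes at least one leaf)."""
    leaves, treedef = tree_util.tree_flatten(pytree)
    flat = jnp.concatenate([jnp.ravel(x) for x in leaves])
    unravel = HashableUnravel(treedef,
                              [jnp.shape(x) for x in leaves],
                              [jnp.result_type(x) for x in leaves])
    return flat, unravel
```

Passed as a static argument to `jax.jit`, two such objects built from the same pytree would share one cache entry, since JAX looks static arguments up by hash and equality rather than by identity.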

patrick-kidger (Collaborator, Author)

Ping @mattjj

mattjj (Member) commented Mar 13, 2023

Thanks for this, Patrick. Sorry I've been so slow to respond for the last N months. (I wish I could say only "recently"!)

This is a great idea. I adapted it in #14954 to reuse the Partial (edit: HashablePartial) class, and to use a tuple in place of a hashably-wrapped numpy.ndarray (though we can switch to the latter if the need arises, e.g. if we want different hash functions or performance characteristics). I copied over the same test case.

Thanks for this improvement!
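A minimal sketch of the HashablePartial-plus-tuple design, assuming equality compares the wrapped function's code object and its bound arguments. The class and the helper `_unravel_flat` below are hypothetical re-creations for illustration, not the JAX-internal code. It also shows why a tuple is convenient: a `numpy.ndarray` is unhashable and its `==` is elementwise, whereas a tuple of ints hashes and compares directly.

```python
import numpy as np
import jax.numpy as jnp


class HashablePartial:
    """Hypothetical sketch: a partial whose equality and hash depend on the
    wrapped function's code object and its bound arguments."""

    def __init__(self, f, *args):
        self.f = f
        self.args = args

    def __eq__(self, other):
        return (type(other) is HashablePartial
                and self.f.__code__ == other.f.__code__
                and self.args == other.args)

    def __hash__(self):
        return hash((self.f.__code__, self.args))

    def __call__(self, *args):
        return self.f(*self.args, *args)


def _unravel_flat(split_points, flat):
    # Hypothetical helper: split a flat vector at the given offsets.
    return jnp.split(flat, np.asarray(split_points))


# A tuple of ints is hashable, so two partials built independently compare equal.
p1 = HashablePartial(_unravel_flat, (3, 5))
p2 = HashablePartial(_unravel_flat, (3, 5))

# A raw ndarray cannot be hashed, which is why it would need a wrapper.
try:
    hash(np.array([3, 5]))
    ndarray_hashable = True
except TypeError:
    ndarray_hashable = False
```

Comparing `__code__` rather than the function object lets freshly created closures with the same body compare equal, at the cost of ignoring any captured variables, so all bound state has to live in `args`.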

patrick-kidger (Collaborator, Author)

> Thanks for this, Patrick. Sorry I've been so slow to respond for the last N months. (I wish I could say only "recently"!)

Haha!

Thanks, glad to see this in. FWIW the docs on jnp.split only mention supporting ints or arrays, not tuples. (I know that JAX departs from numpy in often disallowing arraylikes in place of arrays.)

mattjj (Member) commented Mar 13, 2023

> FWIW the docs on jnp.split only mention supporting ints or arrays, not tuples. (I know that JAX departs from numpy in often disallowing arraylikes in place of arrays.)

Good catch!

Also I done goofed in another way: I was thinking of HashablePartial from shard_map.py (to be moved into util.py), not Partial.

patrick-kidger (Collaborator, Author)

Hah, I feel like the various kinds of JAX-internal function wrappers are starting to get a bit complicated. Off the top of my head there's Partial, HashablePartial, HashableFunction, _HashableCallableShim, and callable pytrees (e.g. checkify.Error, or this PR).

Outside of core JAX I just use eqx.Module for all of these use cases. It should probably be possible to standardise on a single approach here too.
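As one point of comparison, the public `jax.tree_util.Partial` avoids hashing altogether: it is a pytree whose wrapped function is static structure and whose bound arguments are leaves, so it can be passed to a jitted function as an ordinary argument. The function names `scale` and `call` are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def scale(c, x):
    return c * x


@jax.jit
def call(fn, x):
    # fn is a Partial pytree: `scale` is static, the bound array is traced
    return fn(x)


out1 = call(Partial(scale, jnp.array(2.0)), jnp.arange(3.0))
out2 = call(Partial(scale, jnp.array(3.0)), jnp.arange(3.0))
```

Because the bound value is a leaf rather than part of the cache key, the second call can reuse the compiled function even though it wraps a different constant.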


Side note: the issue fixed in this PR is a pretty common one in JAX; e.g. jax.{jit, grad, ...} all return a fresh closure on each call.
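A small way to observe this with `jax.grad` (the names `apply`, `f`, and `trace_count` are illustrative): each call to `jax.grad` builds a new function object, and function objects compare by identity, so a freshly built gradient passed as a static argument misses the jit cache. A Python side effect inside the jitted body runs only while tracing, so it counts the cache misses.

```python
import functools

import jax
import jax.numpy as jnp

trace_count = 0


@functools.partial(jax.jit, static_argnums=0)
def apply(fn, x):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX (re)traces
    return fn(x)


def f(x):
    return jnp.sum(x ** 2)


x = jnp.arange(3.0)
apply(jax.grad(f), x)  # first trace
apply(jax.grad(f), x)  # a fresh closure, so this traces again

g = jax.grad(f)
apply(g, x)  # new object: traced once more
apply(g, x)  # same object: cache hit, no retrace
```

Holding on to a single `g`, as in the last two calls, or using a hashable wrapper like the one this PR introduced for `unravel`, avoids the repeated compilation.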

mattjj (Member) commented Mar 13, 2023

> It should probably be possible to standardise on a single approach here too.

You might be a lumper. Internal JAX utilities are set up to make lumpers out of themselves. 😛

mattjj (Member) commented Mar 13, 2023

We effectively merged this as #14954!

mattjj closed this Mar 13, 2023
patrick-kidger deleted the ravel-pytree-no-recompile branch March 13, 2023 22:31