Document Haiku version of JAX transforms (hk.jit, ...) #14

Closed · ibab opened this issue Feb 29, 2020 · 4 comments
Labels: documentation (Improvements or additions to documentation)

ibab (Contributor) commented Feb 29, 2020

We currently don't explain what hk.jit, hk.remat, etc. are or why they exist. It would be good to extend the documentation to cover them.

sjmielke commented

As it's not documented, I'm not sure if this is a bug or (not) working as intended:

import jax
import jax.numpy as jnp
import haiku as hk

jnp.exp(0)           # 1.0
jax.jit(jnp.exp)(0)  # 1.0
hk.jit(jnp.exp)(0)   # IndexError: deque index out of range

Lmk if I should delete this comment and report it instead :)

ibab (Contributor, Author) commented Mar 1, 2020

@sjmielke: Thanks for finding that! It fails because the hk.* transforms currently assume they run inside hk.transform, but there's no reason they shouldn't work outside it. I have a fix for this in #17.

trevorcai (Contributor) commented

@sjmielke We anticipate that the situations in which you'd want to use hk.jit or hk.grad are limited! They exist as power-user workarounds for particular use cases. I regard them as much more alpha than the rest of Haiku.

Haiku provides a function, hk.transform, which converts impure, object-oriented code using magic functions like hk.get_parameter into JAX-transform-friendly pure functions.
However, sometimes we want to use JAX transformations inside a monolithic chunk of Haiku code.
hk.{grad,jit,remat} are exposed for these cases; in all other situations, prefer the JAX equivalent.
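
As a rough sketch of that contract (the layer size and variable names here are illustrative, not from this thread):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # Impure on its own: hk.Linear calls hk.get_parameter internally.
    return hk.Linear(8)(x)

# hk.transform turns forward into a pair of pure functions: init and apply.
model = hk.transform(forward)
rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 4])
params = model.init(rng, x)        # pure: creates and returns the parameters
out = model.apply(params, rng, x)  # pure: parameters are passed in explicitly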

For hk.jit:

  • We recommend against jitting functions inside of hk.transform that create or use parameters.
    • If you can jax.jit the entire hk.transform(my_fn), do that! (See the sketch after this list.)
    • If you can't, prefer to extract pure functions representing the expensive portions of your computation and JIT those.
    • In the rare case in which neither of these is possible (e.g. data-dependent control flow that is hard to express with JAX/XLA control-flow tools, and is hard to break apart from a code perspective), we provide hk.jit as a workaround.
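
For instance, a minimal sketch of the preferred pattern, reusing the model, params, rng, and x names from the sketch above:

# Preferred: JIT the pure function produced by hk.transform,
# instead of jitting parameter-using code inside the transform.
fast_apply = jax.jit(model.apply)
out = fast_apply(params, rng, x)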

For hk.grad:

  • If your model involves taking derivatives inside of your neural network, use hk.grad (see the sketch below).
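
A minimal sketch of that case, assuming a toy inner function (the name energy is illustrative) whose gradient with respect to its input is part of the forward pass:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    net = hk.Linear(1)

    def energy(x):
        return jnp.sum(net(x))

    # hk.grad is the Haiku-aware analogue of jax.grad: unlike jax.grad,
    # it is safe to apply to a function that creates/uses parameters.
    return hk.grad(energy)(x)

model = hk.transform(forward)
rng = jax.random.PRNGKey(0)
x = jnp.ones([1, 3])
params = model.init(rng, x)
dx = model.apply(params, rng, x)  # d(energy)/dx, same shape as x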

TODO: Add documentation for this stuff. Contributions welcome!

trevorcai added the documentation label on Mar 1, 2020