Following up on #15504, I ended up needing a more general solution because I kept running into the same issue with other function transformations, which often don't support static arguments. My solution ended up being this snippet, which adds static argument support to any transformation. A meta transformation! :D
A simplifying design choice here is that static arguments are required to be passed as keyword arguments and cannot be passed as positional arguments, which I find much easier to specify. When the transformed function is already the output of some other transformation, counting argument names can be painful.
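For readers without access to the linked snippet, here is a minimal sketch of what such a keyword-only meta transformation could look like. It is not the original code: `with_static`, `numeric_only`, and `scale` are hypothetical names, and `numeric_only` is only a stand-in for a transformation that rejects non-array arguments. The sketch assumes static values are hashable and that the wrapped transformation forwards keyword arguments.

```python
import functools


def with_static(transform, static_names, **transform_kwargs):
    """Hypothetical helper: make `transform` ignore the named keyword arguments."""
    static_names = frozenset(static_names)

    def apply(fun):
        cache = {}  # one transformed function per combination of static values

        def wrapper(*args, **kwargs):
            # Split keyword arguments into static and dynamic parts by name.
            static = {k: v for k, v in kwargs.items() if k in static_names}
            dynamic = {k: v for k, v in kwargs.items() if k not in static_names}
            # Assumption: static values are hashable so they can form a cache key.
            key = tuple(sorted(static.items()))
            if key not in cache:
                # Bind the static values first, then transform what remains.
                bound = functools.partial(fun, **static)
                cache[key] = transform(bound, **transform_kwargs)
            return cache[key](*args, **dynamic)

        return wrapper

    return apply


def numeric_only(fun):
    """Stand-in transformation that only accepts numeric arguments."""
    def inner(*args, **kwargs):
        for value in (*args, *kwargs.values()):
            if not isinstance(value, (int, float)):
                raise TypeError(f"unsupported argument: {value!r}")
        return fun(*args, **kwargs)
    return inner


def scale(x, mode="double"):
    return 2 * x if mode == "double" else x


# Applying `numeric_only` directly would reject the string `mode`;
# wrapping it lets `mode` bypass the transformation entirely.
safe_scale = with_static(numeric_only, ["mode"])(scale)
result = safe_scale(3.0, mode="double")
```

Because static arguments are identified by name rather than position, the wrapper never has to know how many positional arguments the function takes, which is what makes it composable with functions that are already the output of another transformation.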
I'm wondering whether a helper like the one above would be a useful addition to JAX. It might simplify the existing JAX transformations, provide a unified API to users, and allow users to use third-party transformations that may not support static arguments. Longer term, users could even be asked to apply the meta transformation themselves instead of repeating the API across the provided transformations.
danijar changed the title from "Functionality to support static arguments in any function transformation" to "Support static arguments for any function transformation" on May 19, 2023.
You might find Equinox interesting, which does something similar: it has equinox.filter_{jit, grad, ...}, which automatically determines dynamic vs static arguments based on their type. After all, it mostly only makes sense to JIT wrt arrays, to grad wrt floating-point arrays, etc.