Create Primitive with only impl and vjp #3415
Comments
I think you might actually find writing just a … The other option would be to write both a primitive to implement the underlying operation and a …
@shoyer I attempted …
I think a … As for how to do this, the VJP part is a bit tricky. JAX's AD system separates linearization from transposition; in particular, it doesn't natively have VJP rules (it natively has only JVP rules for nonlinear primitives and transposition rules for linear ones). That has several benefits, and it's fruitful to organize things that way (I suspect even in this case!), but it's unfamiliar and perhaps annoying if you already have a VJP worked out. There is an internal API for attaching VJP rules to primitives, but I'd planned to deprecate it. Maybe we should keep it around for precisely this important use case. So, with the caveat that this has no warranty and might break in the future, here's the kind of thing you can do:

```python
import numpy as onp
import jax
import jax.numpy as jnp
from jax import core, grad
from jax.interpreters import ad

def sam(x):
  return sam_p.bind(x)

sam_p = core.Primitive('sam')

# define eval rule where you can call into a foreign function, in this case onp
@sam_p.def_impl
def sam_impl(x):
  return onp.sin(x)

# define a VJP rule, which calls traceable functions (not foreign functions)
def sam_vjp_maker(x):
  y = sam(x)
  vjp = lambda cotangent: (jnp.cos(x) * cotangent,)
  return y, vjp

# ad.defvjp_all is the internal VJP-rule API mentioned above (no stability guarantee)
ad.defvjp_all(sam_p, sam_vjp_maker)

###

def f(x):
  y = sam(x)
  z = jnp.cos(y)
  return sam(z)

print(grad(f)(3.))
```
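To make the "linearize, then transpose" organization concrete, the public jax.linearize and jax.linear_transpose functions expose both steps. The sketch below uses jnp.sin as a traceable stand-in for sam (an assumption for illustration; a primitive that calls foreign code could not be traced this way). It shows that transposing the linear JVP map at x gives the same cotangent as jax.vjp.

```python
import jax
import jax.numpy as jnp

def sam(x):
  # Traceable stand-in for the primitive's forward function (assumption).
  return jnp.sin(x)

x = jnp.float32(3.0)

# Step 1, linearization: y = sam(x) plus the linear JVP map t -> cos(x) * t.
y, sam_jvp = jax.linearize(sam, x)

# Step 2, transposition: turn the linear map into ct -> cos(x) * ct (the VJP).
sam_vjp_from_transpose = jax.linear_transpose(sam_jvp, x)
(ct_x,) = sam_vjp_from_transpose(jnp.ones_like(y))

# Reference: jax.vjp computes the same cotangent.
_, sam_vjp = jax.vjp(sam, x)
(ct_ref,) = sam_vjp(jnp.ones_like(y))
print(ct_x, ct_ref)
```

With this split, a primitive only needs a JVP rule; reverse mode comes from transposing that JVP.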
I agree that Primitive is probably the right tool for this, but the new custom_vjp would actually work here, because you don't need to be able to trace the VJP calculation.
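A minimal sketch of the custom_vjp route for an external solver, assuming the solver exposes separate forward and VJP routines. Here _external_sin and _external_sin_vjp are hypothetical NumPy stand-ins for those routines. jax.pure_callback is swapped in as the bridge to host-side code: it is not mentioned in this thread, but it lets the same function also run under jit. The sketch follows the fwd/bwd pattern from the original question.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for an external solver's forward and VJP routines.
# They use plain NumPy, so JAX cannot trace through them.
def _external_sin(x):
  x = np.asarray(x)
  return np.asarray(np.sin(x), dtype=x.dtype)

def _external_sin_vjp(x, g):
  x, g = np.asarray(x), np.asarray(g)
  return np.asarray(np.cos(x) * g, dtype=x.dtype)

@jax.custom_vjp
def sam(x):
  x = jnp.asarray(x)
  # pure_callback runs host-side code; the result shape/dtype must be declared.
  return jax.pure_callback(_external_sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

def sam_fwd(x):
  # Primal output plus residuals (here just x) saved for the backward pass.
  return sam(x), jnp.asarray(x)

def sam_bwd(x, g):
  # Return one cotangent per primal input, as a tuple.
  gx = jax.pure_callback(_external_sin_vjp, jax.ShapeDtypeStruct(x.shape, x.dtype), x, g)
  return (gx,)

sam.defvjp(sam_fwd, sam_bwd)

x = jnp.float32(0.7)
print(jax.grad(sam)(x))           # approximately cos(0.7)
print(jax.jit(jax.grad(sam))(x))  # same value, under jit
```

Because the backward pass is itself an opaque callback, this supports first-order reverse mode but not forward mode or higher-order derivatives. That is the trade-off raised in the reply below.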
That's true only assuming our good friend @samuela doesn't want higher-order autodiff, or vmap, or jit, or pretty much anything other than …
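When the derivative of the external computation can be written with traceable operations, a middle ground is the JVP-first organization described earlier in the thread. The forward pass stays opaque, and the derivative is supplied as a JVP that JAX can transpose and differentiate again. The sketch below uses jax.custom_jvp with jax.pure_callback in place of a Primitive. _external_sin is a hypothetical NumPy stand-in, and the approach assumes the derivative (here cos) can be expressed in jnp, which is not true of every external solver.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical stand-in for an external forward solver (plain NumPy, not traceable).
def _external_sin(x):
  x = np.asarray(x)
  return np.asarray(np.sin(x), dtype=x.dtype)

@jax.custom_jvp
def sam(x):
  x = jnp.asarray(x)
  return jax.pure_callback(_external_sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

@sam.defjvp
def sam_jvp(primals, tangents):
  (x,), (t,) = primals, tangents
  # The tangent output uses traceable jnp ops and is linear in t, so JAX can
  # transpose it for reverse mode and differentiate it again for higher order.
  return sam(x), jnp.cos(x) * t

x = jnp.float32(0.7)
print(jax.grad(sam)(x))            # approximately cos(0.7)
print(jax.grad(jax.grad(sam))(x))  # approximately -sin(0.7)
print(jax.jit(jax.grad(sam))(x))
```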
@shoyer I see now that your notebook calls this out ("Conceivably if we implemented this via a JAX Primitive instead, we could define a batching rule with tf.vectorized_map."), but I still wanted to surface it as explicitly as possible in this issue thread for clarity and posterity!
I think we covered the original question in this thread, so I'm going to close the issue, but @samuela let me know if I'm mistaken.
Thanks guys! This answers my question! I ended up just doing the forward and backward passes "manually" for simplicity, but I may revisit this in the future if I need something more composable.
I have an external solver that I'd like to use within JAX, but I'm having trouble implementing it along with its VJP.
Looking at the docs, I see here that I should implement a new core.Primitive: …
So I attempted to follow the tutorial "How JAX primitives work", but it only defines the VJP in terms of the JVP and the transpose. I have no idea how this works. How can I define a Primitive with just the forward and VJP functions, especially something like jax.custom_vjp and myfun.defvjp(fwd, bwd)?