Create Primitive with only impl and vjp #3415

Closed
samuela opened this issue Jun 11, 2020 · 8 comments
Assignees
Labels
question Questions for the JAX team

Comments

@samuela
Contributor

samuela commented Jun 11, 2020

I have an external solver that I'd like to use within JAX, but I'm having trouble implementing it along with its VJP.

Looking at the docs I see here that I should implement a new core.Primitive:

  1. defining new core.Primitive instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.

So I attempted to follow the tutorial How JAX primitives work, but it only defines the VJP in terms of the JVP and the transpose, and I have no idea how that works. How can I define a Primitive with just the forward and VJP functions, i.e. something like jax.custom_vjp with myfun.defvjp(fwd, bwd)?
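
For reference, the kind of thing I mean is roughly the following, with jnp.tanh as a hypothetical stand-in for my solver (myfun, myfun_fwd, and myfun_bwd are just placeholder names; the real solver isn't traceable, which is part of my confusion):

import jax
import jax.numpy as jnp

@jax.custom_vjp
def myfun(x):
  return jnp.tanh(x)  # stand-in for the external solver call

def myfun_fwd(x):
  y = myfun(x)
  return y, (y,)  # residuals saved for the backward pass

def myfun_bwd(res, ct):
  y, = res
  return ((1.0 - y ** 2) * ct,)  # hand-written VJP of tanh

myfun.defvjp(myfun_fwd, myfun_bwd)

print(jax.grad(myfun)(0.5))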

@shoyer
Collaborator

shoyer commented Jun 12, 2020

I think you might actually find writing just a custom_vjp easiest, e.g., as shown in this notebook: https://gist.github.com/shoyer/5f72853c2788e99e785f4737ee8a6ae1

The other option would be to write both a primitive to implement the underlying operation and a custom_vjp on top of it. This might make sense if you want other transformation rules besides auto-diff, e.g., vmap or jit support.
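
A rough sketch of that second option (the names solve_p and external_solve are made up for illustration, and this only covers plain grad; jit or vmap would need extra rules on the primitive):

import jax
import jax.numpy as jnp
import numpy as onp
from jax import core

def external_solve(x):
  return onp.tanh(x)  # stand-in for the real external solver

solve_p = core.Primitive('solve')
solve_p.def_impl(external_solve)  # the foreign call lives in the primitive's impl

@jax.custom_vjp
def solve(x):
  return solve_p.bind(x)

def solve_fwd(x):
  y = solve(x)
  return y, (y,)

def solve_bwd(res, ct):
  y, = res
  return ((1.0 - y ** 2) * ct,)

solve.defvjp(solve_fwd, solve_bwd)

print(jax.grad(solve)(0.5))  # works under plain grad; more rules needed for jit/vmap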

@samuela
Contributor Author

samuela commented Jun 12, 2020

@shoyer I attempted custom_vjp (or was it custom_transforms...) and ran into this issue. It looks like you circumvented this somehow though?

@mattjj
Collaborator

mattjj commented Jun 12, 2020

I think a Primitive is the right tool here, not a custom_vjp, because Primitives are the way we delineate where tracers can't go (whereas custom_vjp is for decorating traceable Python code). And tracers can't go into external solvers, just like they can't go into CUDA routines! (That's why the tutorial includes that bullet: to discourage using custom_vjp for this use case.)

As for how to do this, the VJP part is a bit tricky. JAX's AD system separates linearization from transposition, and in particular doesn't natively have VJP rules (and instead only natively has JVP rules for nonlinear primitives and transposition rules for linear ones). That has several benefits, and it's fruitful to organize things that way (I suspect even in this case!), but it's unfamiliar and perhaps annoying if you already have a VJP worked out.

There is an internal API for attaching VJP rules to primitives, but I'd planned to deprecate it. Maybe we should keep it around for precisely this important use case.

So with the caveat that this has no warranty and might break in the future, here's the kind of thing you can do:

import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import ad


def sam(x):
  return sam_p.bind(x)
sam_p = core.Primitive('sam')

# define eval rule where you can call into a foreign function, in this case onp
import numpy as onp

@sam_p.def_impl
def sam_impl(x):
  return onp.sin(x)


# define a VJP rule, which calls traceable functions (not foreign functions)
def sam_vjp_maker(x):
  y = sam(x)
  vjp = lambda cotangent: (jnp.cos(x) * cotangent,)
  return y, vjp
ad.defvjp_all(sam_p, sam_vjp_maker)


###


def f(x):
  y = sam(x)
  z = jnp.cos(y)
  return sam(z)

from jax import grad
print(grad(f)(3.))
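
(For contrast, the "native" route described above, a JVP rule plus JAX's built-in transposition, would look roughly like this for the same primitive; a sketch to use instead of the ad.defvjp_all call, not alongside it:)

# JVP rule: tangent_out = cos(x) * tangent_in; reverse mode then falls out of
# JAX's linearize-plus-transpose machinery.
ad.defjvp(sam_p, lambda g, x: jnp.cos(x) * g)

print(grad(f)(3.))  # same gradient via the JVP-then-transpose path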

@mattjj mattjj self-assigned this Jun 12, 2020
@mattjj mattjj added the question Questions for the JAX team label Jun 12, 2020
@shoyer
Collaborator

shoyer commented Jun 12, 2020

I agree that Primitive is probably the right tool for this, but the new custom_vjp would actually work here because you don’t need to be able to trace the vjp calculation.

@mattjj
Collaborator

mattjj commented Jun 12, 2020

That's true only assuming our good friend @samuela doesn't want higher-order autodiff, or vmap, or jit, or pretty much anything other than grad. If we write a primitive like in the preceding comment, then at least higher-order reverse-mode autodiff will work, and we will get semi-sane error messages for the other things asking us to implement more rules (which may be readily implementable).
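
For example, a batching rule for the sam primitive above can be a short sketch like this (against the internals as of this writing, so the usual no-warranty caveat applies):

from jax import vmap
from jax.interpreters import batching

# onp.sin is already elementwise, so batching just forwards the batched
# argument and reports which axis carries the batch dimension.
def sam_batch_rule(batched_args, batch_dims):
  x, = batched_args
  bdim, = batch_dims
  return sam(x), bdim

batching.primitive_batchers[sam_p] = sam_batch_rule

print(vmap(f)(jnp.arange(3.)))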

@mattjj
Collaborator

mattjj commented Jun 12, 2020

@shoyer I see now that your notebook calls this out, "Conceivably if we implemented this via a JAX Primitive instead, we could define a batching rule with tf.vectorized_map.", but still I wanted to surface it as explicitly as possible in this issue thread for clarity and posterity!

@mattjj
Collaborator

mattjj commented Jun 16, 2020

I think we covered the original question in this thread, so I'm going to close the issue, but @samuela let me know if I'm mistaken.

@mattjj mattjj closed this as completed Jun 16, 2020
@samuela
Contributor Author

samuela commented Jun 16, 2020

Thanks guys! This answers my question! I ended up just doing the forward and backward passes "manually" for simplicity, but I may revisit this in the future if I need something more composable.

3 participants