Custom VJPs for external functions #1142

Closed
yaugenst opened this issue Aug 8, 2019 · 16 comments

@yaugenst

yaugenst commented Aug 8, 2019

Hi! I want to define custom gradients for a simulation for sensitivity analysis. I have been using autograd for this, but since it is not actively being developed anymore I wanted to switch to jax.
In autograd I would write something like this:

from autograd import grad
from autograd.extend import primitive, defvjp
import simulation

@primitive
def sim(params):
    results = simulation.run(params)
    return results

def sim_vjp(ans, params):
    def vjp(g):
        # custom gradient code goes here
        return gradient
    return vjp

defvjp(sim, sim_vjp)

In autograd, this worked fine and I was able to chain this together with some other differentiable transformations and get gradients out of the whole thing.
From what I was able to gather, the above would be written in jax as follows:

import jax
import simulation

@jax.custom_transforms
def sim(params):
    results = simulation.run(params)
    return results

def sim_vjp(ans, params):
    def vjp(g):
        # custom gradient code goes here
        return gradient
    return vjp

jax.defvjp_all(sim, sim_vjp)

However, this throws "Exception: Tracer can't be used with raw numpy functions.", which I assume is because the simulation code does not use jax. Are the custom gradients in jax no longer black boxes as in autograd, i.e. is this a fundamental limitation or have I screwed something up? Do I need to implement this using lax primitives, and if so, how?

I would be grateful for a minimal example implementing this for some arbitrary non-jax function. This code here for example works in autograd:

from autograd import grad
from autograd.extend import primitive, defvjp
from scipy.ndimage import gaussian_filter

@primitive
def filter(img):
    return gaussian_filter(img, 1)

def filter_vjp(ans, img):
    def vjp(g):
        return gaussian_filter(g, 1)
    return vjp

defvjp(filter, filter_vjp)

How would one translate this so it works in jax?
Thanks so much!

@tpr0p

tpr0p commented Aug 8, 2019

I'm having the same issue. I cannot get the example code from the defvjp_all documentation to work.

System information:
OS: Linux - Ubuntu 18.04
Python: 3.7.2

Build information:
I built from source following the instructions in the README.

git clone https://github.com/google/jax
cd jax
python build/build.py
pip install -e build  # install jaxlib (includes XLA)
pip install -e .      # install jax

Minimal code to reproduce (jax_test.py):

import jax
import numpy as np

@jax.custom_transforms
def f(x):
    return np.square(x)

def f_vjp(x):
    return f(x), lambda g: 2 * x * g

jax.defvjp_all(f, f_vjp)

def main():
    jax.grad(f, 0)(1.)

if __name__ == "__main__":
    main()

Stack trace:

Traceback (most recent call last):
  File "jax_test.py", line 23, in <module>
    main()
  File "jax_test.py", line 19, in main
    jax.grad(f, 0)(1.)
  File "/home/tcpropson/repos/jax/jax/api.py", line 341, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/home/tcpropson/repos/jax/jax/api.py", line 387, in value_and_grad_f
    ans, vjp_py = vjp(f_partial, *dyn_args)
  File "/home/tcpropson/repos/jax/jax/api.py", line 1002, in vjp
    out_primal, out_vjp = ad.vjp(jaxtree_fun, primals_flat)
  File "/home/tcpropson/repos/jax/jax/interpreters/ad.py", line 105, in vjp
    out_primal, pval, jaxpr, consts = linearize(traceable, *primals)
  File "/home/tcpropson/repos/jax/jax/interpreters/ad.py", line 94, in linearize
    jaxpr, out_pval, consts = pe.trace_to_jaxpr(jvpfun, in_pvals)
  File "/home/tcpropson/repos/jax/jax/interpreters/partial_eval.py", line 400, in trace_to_jaxpr
    jaxpr, (out_pval, consts, env) = fun.call_wrapped(pvals)
  File "/home/tcpropson/repos/jax/jax/linear_util.py", line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/tcpropson/repos/jax/jax/api.py", line 1175, in __call__
    jaxpr, _, consts = pe.trace_to_jaxpr(jaxtree_fun, pvals_in, instantiate=True)
  File "/home/tcpropson/repos/jax/jax/interpreters/partial_eval.py", line 400, in trace_to_jaxpr
    jaxpr, (out_pval, consts, env) = fun.call_wrapped(pvals)
  File "/home/tcpropson/repos/jax/jax/linear_util.py", line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "jax_test.py", line 10, in f
    return np.square(x)
  File "/home/tcpropson/repos/jax/jax/core.py", line 287, in __array__
    raise Exception("Tracer can't be used with raw numpy functions. "
Exception: Tracer can't be used with raw numpy functions. You might have
  import numpy as np
instead of
  import jax.numpy as np

@hawkinsp
Member

hawkinsp commented Aug 9, 2019

Currently custom gradients have the same restrictions as jit; i.e., they must be jit-able. We agree that we should relax this restriction.

@hawkinsp added the "enhancement" label Aug 9, 2019
@shoyer
Member

shoyer commented Aug 13, 2019

Even if we get this working, I suspect you may still find performance to be painfully slow if you can't jit the rest of your code. This is why we'd also want support for JIT compilation with custom NumPy ops (#766).

@mattjj
Member

mattjj commented Aug 18, 2019

Hi all, sorry for the slow response! @tpr0p @MRBaozi

The issue here is the difference between a custom_transforms function and a Primitive. You want a Primitive.

From the custom_transforms docstring (emphasis mine):

A primary use case of custom_transforms is defining custom VJP rules (aka custom gradients) for a Python function, while still supporting other transformations like jax.jit and jax.vmap.

Let me unpack that, because it's not very detailed.

The custom_transforms function is useful when you have a Python function that JAX can handle just fine (to compile, differentiate, batch, etc.) but you still want to override how it behaves under one (or more) of those transformations while retaining the default behavior for the others. So a custom_transforms function isn't totally opaque to the tracing/transforming machinery: in fact, if you don't override any of its transformation rules, then it's traced and transformed just like a regular function. That's different from Autograd's primitives, because those were always totally opaque. The main use case for custom_transforms is where you have a Python function implemented with jax.numpy and you like how it behaves under jit, but you want to control how it behaves under grad.
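
(On current JAX versions that same use case — a jax.numpy function whose gradient you want to override while keeping the default jit/vmap behavior — is covered by jax.custom_vjp. Here is a minimal sketch; the function squash and its hand-written backward rule are made up for illustration:)

import jax
import jax.numpy as jnp

@jax.custom_vjp
def squash(x):
    return jnp.tanh(x)

def squash_fwd(x):
    return squash(x), x                      # save x as the residual

def squash_bwd(x, g):
    return (g * (1.0 - jnp.tanh(x) ** 2),)   # one cotangent per primal input

squash.defvjp(squash_fwd, squash_bwd)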

In contrast, a JAX Primitive (defined in core.py) is directly analogous to Autograd's primitive, in that it sets up an opaque function. When you define a Primitive you need to define a rule for every transformation you want to use (rather than just the ones you want to override). Most of JAX's primitives are in the lax package, and we implement everything on top of those.

We haven't documented how to set up your own Primitives yet (it's the venerable issue #116), but it's not too hard. Here's an adaptation of @tpr0p's example:

from jax import core
import numpy as onp  # I changed this name out of habit

# Set up a Primitive, using a handy level of indirection
def foo(x):
  return foo_p.bind(x)
foo_p = core.Primitive('foo')

At this point there are no rules defined for foo_p, not even an evaluation rule (we consider eval to be just another transformation!). Here's the error we get if we try to call it:

In [2]: foo(3)
NotImplementedError: Evaluation rule for 'foo' not implemented

Let's define an evaluation rule in terms of onp, a totally opaque un-traceable call into C code:

foo_p.def_impl(onp.square)

And now:

In [4]: foo(3)
Out[4]: 9

Woohoo! But we can't do anything else with it yet. We can add a VJP rule like this (though for all of our own primitives we actually define a JVP rule instead; this form might be more familiar, cf. #636):

from jax.interpreters import ad
ad.defvjp(foo_p, lambda g, x: 2 * x * g)

And now:

In [5]: from jax import grad
In [6]: grad(foo)(3.)
Out[6]: DeviceArray(6., dtype=float32)
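
(If you prefer the JVP form mentioned above, the same rule would look roughly like this with the ad.defjvp helper, where the rule maps the input tangent g and the primal x to the output tangent; use it instead of the ad.defvjp call above:)

from jax.interpreters import ad
ad.defjvp(foo_p, lambda g, x: 2 * x * g)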

There's also an API closer to the one in @tpr0p's original example:

def f_vjp(x):
  return foo(x), lambda g: (2 * g * x,)
ad.defvjp_all(foo_p, f_vjp)

To use jit you'll also need to define a translation rule (i.e. tell JAX how to lower foo_p to XLA).

Does that make sense? What'd I miss?

@yaugenst
Author

Thanks so much for the thorough explanation @mattjj! I managed to get everything working now. I can see how this approach is more general; the solution is just somewhat non-obvious from an outsider's perspective ;-)

@HamletWantToCode

Hi @mattjj, the example code you gave works fine. However, if my self-implemented VJP function contains raw NumPy functions, I run into the same issue.
Here is an example:

import numpy as onp
from jax.core import Primitive
from jax.interpreters.ad import defvjp
from jax import grad

# Define the function to be differentiated
def foo(x):
    return foo_p.bind(x)
foo_p = Primitive('foo')

def f(x):
    return onp.sin(x)
foo_p.def_impl(f)

# Define the derivative, USING A PURE NUMPY FUNCTION!!!
def dfoo(g, x):
    return g*onp.cos(x)

defvjp(foo_p, dfoo)

When I run the code:

# evaluating foo works fine
In [1]: foo(onp.pi/6)                                                                                                                        
Out[1]: 0.49999999999999994

# evaluating the gradient does not work
In [2]: grad(foo)(onp.pi/6) 
---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
<ipython-input-11-bac41fa563b2> in <module>
----> 1 grad(foo)(onp.pi/6)

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/api.py in grad_f(*args, **kwargs)
    336   @wraps(fun, docstr=docstr, argnums=argnums)
    337   def grad_f(*args, **kwargs):
--> 338     _, g = value_and_grad_f(*args, **kwargs)
    339     return g
    340 

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
    384     f_partial, dyn_args = _argnums_partial(f, argnums, args)
    385     if not has_aux:
--> 386       ans, vjp_py = vjp(f_partial, *dyn_args)
    387     else:
    388       ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/api.py in vjp(fun, *primals, **kwargs)
   1044   if not has_aux:
   1045     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1046     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1047     out_tree = out_tree()
   1048   else:

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    103 def vjp(traceable, primals, has_aux=False):
    104   if not has_aux:
--> 105     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    106   else:
    107     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     92   _, in_tree = tree_flatten(((primals, primals), {}))
     93   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
---> 94   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
     95   pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
     96   aval_primals, const_primals = unzip2(pval_primals)

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, **kwargs)
    316   with new_master(JaxprTrace) as master:
    317     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 318     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    319     assert not env
    320     del master

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    163 
    164     del gen
--> 165     ans = self.f(*args, **dict(self.params, **kwargs))
    166     del args
    167     while stack:

<ipython-input-4-139a3384ba6d> in foo(x)
      1 def foo(x):
----> 2   return foo_p.bind(x)
      3 foo_p = Primitive('foo')

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/core.py in bind(self, *args, **kwargs)
    131 
    132     tracers = map(top_trace.full_raise, args)
--> 133     out_tracer = top_trace.process_primitive(self, tracers, kwargs)
    134     if self.multiple_results:
    135       return map(full_lower, out_tracer)

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/interpreters/ad.py in process_primitive(self, primitive, tracers, params)
    218           "Forward-mode differentiation rule for '{}' not implemented"
    219           .format(primitive))
--> 220     primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
    221     if primitive.multiple_results:
    222       return [JVPTracer(self, x, t) for x, t in zip(primal_out, tangent_out)]

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/interpreters/ad.py in fun_jvp(xs, ts, **params)
    346   def fun_jvp(xs, ts, **params):
    347     ts = map(instantiate_zeros, xs, ts)
--> 348     primals_and_tangents = fun_jvp_p.bind(*it.chain(xs, ts), **params)
    349     primals, tangents = split_list(primals_and_tangents, [len(primals_and_tangents) // 2])
    350     if prim.multiple_results:

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/core.py in bind(self, *args, **kwargs)
    131 
    132     tracers = map(top_trace.full_raise, args)
--> 133     out_tracer = top_trace.process_primitive(self, tracers, kwargs)
    134     if self.multiple_results:
    135       return map(full_lower, out_tracer)

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in process_primitive(self, primitive, tracers, params)
     86   def process_primitive(self, primitive, tracers, params):
     87     if primitive in custom_partial_eval_rules:
---> 88       return custom_partial_eval_rules[primitive](self, *tracers, **params)
     89     else:
     90       pvs, consts = unzip2(t.pval for t in tracers)

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/interpreters/ad.py in fun_jvp_partial_eval(trace, *tracers, **params)
    365     out_avals = [raise_to_shaped(get_aval(x)) for x in primals_out]
    366     ct_pvals = [pe.PartialVal((aval, core.unit)) for aval in out_avals]
--> 367     jaxpr, _, res = pe.trace_to_jaxpr(wrap_init(vjp_py), ct_pvals, instantiate=True)
    368     tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr,
    369                                   num_res=len(res), out_avals=out_avals)

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, **kwargs)
    316   with new_master(JaxprTrace) as master:
    317     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 318     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    319     assert not env
    320     del master

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    163 
    164     del gen
--> 165     ans = self.f(*args, **dict(self.params, **kwargs))
    166     del args
    167     while stack:

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/interpreters/ad.py in <lambda>(ct)
    386     ans = prim.bind(*primals)
    387     vjpfun = lambda ct: [vjp(ct, *primals) if vjp else zeros_like_jaxval(x)
--> 388                          for x, vjp in zip(primals, vjps)]
    389     return ans, vjpfun
    390   defvjp_all(prim, vjpmaker)

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/interpreters/ad.py in <listcomp>(.0)
    386     ans = prim.bind(*primals)
    387     vjpfun = lambda ct: [vjp(ct, *primals) if vjp else zeros_like_jaxval(x)
--> 388                          for x, vjp in zip(primals, vjps)]
    389     return ans, vjpfun
    390   defvjp_all(prim, vjpmaker)

<ipython-input-9-d0c4d02422d6> in dfoo(g, x)
      1 def dfoo(g, x):
----> 2     return g*onp.cos(x)
      3 

~/miniconda3/envs/jax-env/lib/python3.7/site-packages/jax/core.py in __array__(self)
    262 
    263   def __array__(self):
--> 264     raise Exception("Tracer can't be used with raw numpy functions. "
    265                     "You might have\n  import numpy as np\ninstead of\n  import jax.numpy as np")
    266 

Exception: Tracer can't be used with raw numpy functions. You might have
  import numpy as np
instead of
  import jax.numpy as np

System information:
OS: MacOS 10.14.6
Python: 3.7.4
jax: 0.1.46

@shoyer
Member

shoyer commented Oct 20, 2019

@HamletWantToCode my understanding is that JVP rules cannot make use of NumPy functions directly, because JAX wants to support higher order differentiation. You could make your pure-NumPy example work by defining another Primitive for cos, e.g.,

import numpy as onp
from jax.core import Primitive
from jax.interpreters.ad import defvjp
from jax import grad

# Define the function to be differentiated
def foo(x):
    return foo_p.bind(x)
foo_p = Primitive('foo')

def f(x):
    return onp.sin(x)
foo_p.def_impl(f)

def dfoo(g, x):
    return g*bar(x)

defvjp(foo_p, dfoo)

def bar(x):
    return bar_p.bind(x)
bar_p = Primitive('bar')

def g(x):
    return onp.cos(x)
bar_p.def_impl(g)

def dbar(g, x):
    return -g*foo(x)

defvjp(bar_p, dbar)
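
A quick first-order check of the rules above (assuming the definitions exactly as written):

# The gradient of foo is evaluated through bar's impl, so no tracer ever
# reaches a raw NumPy call:
grad(foo)(onp.pi / 6)   # ~0.866, i.e. cos(pi/6)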

@HamletWantToCode

Thank you very much @shoyer

@hawkinsp
Member

@mattjj should we consider this fixed with the new custom gradients code?

@mattjj
Member

mattjj commented Apr 16, 2020

Actually, for external functions a new primitive should be used, not custom_jvp/vjp stuff. That is, external functions fall into case 2 articulated at the top of the Custom derivative rules for JAX-transformable Python functions tutorial.

I think this topic is important enough that it needs its own tutorial explanation (i.e. I don't think the "How JAX primitives work" tutorial is quite the right explanation for people looking to solve this particular issue, just because we should have more direct examples for this use case).

@mattjj added the "documentation" label and removed the "enhancement" label Apr 16, 2020
@mattjj
Member

mattjj commented Apr 16, 2020

I changed the issue label to "documentation" so that we can add such a tutorial.

@shoyer
Member

shoyer commented Apr 16, 2020

The new custom gradients API does make this easier in some cases, though indeed a primitive would allow for a more complete solution.

Here's a rough prototype I worked out for TensorFlow 2 <-> JAX, which may be a useful point of reference:
https://gist.github.com/shoyer/5f72853c2788e99e785f4737ee8a6ae1

@shoyer
Member

shoyer commented Apr 6, 2021

The current recommendation for wrapping external functions in a way that is compatible with JIT would be to use jax.experimental.host_callback.call: https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html

If you need to use a pre-existing VJP rule, then I think you need to use custom_vjp. But for full capability with JAX (e.g., for vmap and jvp support), the recommendation would be to try to write a full JAX primitive, which requires decomposing auto-diff into jvp & transpose rules.
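
As a rough sketch of that combination (external_fn and external_vjp are hypothetical stand-ins for an external simulation and its known VJP; jit and grad should work through this, while vmap/jvp would still need a full primitive as noted above):

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb

def external_fn(x):           # non-JAX code, e.g. a simulation
    return np.sin(x)

def external_vjp(x, g):       # its hand-written VJP
    return np.cos(x) * g

@jax.custom_vjp
def sim(x):
    shape = jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(x))
    return hcb.call(external_fn, x, result_shape=shape)

def sim_fwd(x):
    return sim(x), x          # save the primal input as the residual

def sim_bwd(x, g):
    shape = jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(x))
    ct = hcb.call(lambda args: external_vjp(*args), (x, g), result_shape=shape)
    return (ct,)              # one cotangent per primal input

sim.defvjp(sim_fwd, sim_bwd)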

@wq0729

wq0729 commented Jan 9, 2022

(Quotes @yaugenst's original post in full.)

Hello. It seems that you were doing something like inverse photonic design with a solver. I am currently working on a project about topology optimization for photonic structures, and I am struggling to apply Autograd to a solver so that I can get the gradient. If possible, could I ask how you did that? I still have no clue, and I would really appreciate it if we could discuss it.

@jakevdp
Collaborator

jakevdp commented Jul 27, 2023

JAX now has a supported/recommended way of doing this: pure_callback along with custom_jvp. There's a more complete example at https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp
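
A minimal sketch of that pattern (plain np.sin stands in for an arbitrary external function; the names are made up, and the notebook linked above has a fuller version):

import numpy as np
import jax
import jax.numpy as jnp

@jax.custom_jvp
def foo(x):
    x = jnp.asarray(x)
    shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(np.sin, shape, x)   # external call, jit-compatible

@foo.defjvp
def foo_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    x = jnp.asarray(x)
    shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    # The callback only computes the derivative coefficient; the tangent
    # enters through a plain multiplication, so reverse-mode grad works too.
    cos_x = jax.pure_callback(np.cos, shape, x)
    return foo(x), cos_x * t

With this, foo composes with jit and grad.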

@jakevdp closed this as completed Jul 27, 2023
@mfschubert

In case anyone is still interested in using autograd functions in a jit/vmap/jacrev compatible way, I have an experimental wrapper in the agjax package which addresses this: https://github.com/mfschubert/agjax
https://github.com/mfschubert/agjax/blob/main/src/agjax/experimental/wrapper.py
