AssertionError when computing grad of function with @jax.custom_gradient #1875

Closed
jburnim opened this issue Dec 17, 2019 · 3 comments

jburnim commented Dec 17, 2019

With JAX 0.1.55, when I run:

import jax
import jax.numpy as np
from jax import grad

@jax.custom_gradient
def multiply_no_nan(x, y):
  def grad(dz):
    return (multiply_no_nan(dz, y), multiply_no_nan(x, dz))
  ret = np.where(np.equal(y, 0.), np.zeros_like(y), np.multiply(x, y))
  return ret, grad

grad(multiply_no_nan)(1., 1.)

The final line raises an AssertionError: "If you see this error, please let us know by opening an issue at https://github.com/google/jax/issues since we thought this was unreachable!"

The full traceback is:

>>> grad(multiply_no_nan)(1., 1.)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/api.py", line 355, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/api.py", line 410, in value_and_grad_f
    ans, vjp_py = vjp(f_partial, *dyn_args)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/api.py", line 1272, in vjp
    out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/interpreters/ad.py", line 108, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/interpreters/ad.py", line 97, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 315, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/linear_util.py", line 153, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/api.py", line 1429, in __call__
    num_consts=len(consts))
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/core.py", line 153, in bind
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/interpreters/ad.py", line 314, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/interpreters/ad.py", line 442, in fun_jvp
    primals_and_tangents = fun_jvp_p.bind(*it.chain(xs, ts), **params)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/core.py", line 153, in bind
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 89, in process_primitive
    return custom_partial_eval_rules[primitive](self, *tracers, **params)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/interpreters/ad.py", line 461, in fun_jvp_partial_eval
    jaxpr, _, res = pe.trace_to_jaxpr(wrap_init(vjp_py), ct_pvals, instantiate=True)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 315, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/linear_util.py", line 153, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/api.py", line 1709, in vjp_flat
    args_cts_flat, in_tree2 = tree_flatten(vjp(cts))
  File "<stdin>", line 4, in grad
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/api.py", line 1429, in __call__
    num_consts=len(consts))
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/core.py", line 153, in bind
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 96, in process_primitive
    out_aval = primitive.abstract_eval(*avals, **params)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/api.py", line 1492, in fun_abstract_eval
    return pe.abstract_eval_fun(fun_impl, *avals, **params)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 234, in abstract_eval_fun
    instantiate=True)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 315, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/linear_util.py", line 153, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/api.py", line 1480, in fun_impl
    return core.eval_jaxpr(params['jaxpr'], consts, (), *args)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/core.py", line 205, in eval_jaxpr
    ans = eqn.primitive.bind(*(subfuns + in_vals), **eqn.params)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/core.py", line 153, in bind
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 96, in process_primitive
    out_aval = primitive.abstract_eval(*avals, **params)
  File "/Users/jburnim/jax-playground/venv/lib/python3.6/site-packages/jax/lax/lax.py", line 1493, in standard_abstract_eval
    assert pe._thread_local_state.remat, msg
AssertionError: If you see this error, please let us know by opening an issue at
https://github.com/google/jax/issues 
since we thought this was unreachable!
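For comparison, on a current JAX release the same NaN-safe multiply can be sketched with `jax.custom_vjp` instead of `jax.custom_gradient`. This is a workaround sketch that assumes a recent JAX version; it is not the patch that closed this issue, and the helper names `_fwd` and `_bwd` are illustrative. The backward rule reuses the masked multiply, so the cotangents follow d(x*y)/dx = y and d(x*y)/dy = x.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def multiply_no_nan(x, y):
    # Forward pass: return 0 wherever y == 0, even if x is inf or nan.
    return jnp.where(jnp.equal(y, 0.), jnp.zeros_like(y), jnp.multiply(x, y))

def _fwd(x, y):
    # Primal output plus the inputs saved as residuals for the backward pass.
    return multiply_no_nan(x, y), (x, y)

def _bwd(res, dz):
    x, y = res
    # d(x*y)/dx = y and d(x*y)/dy = x, masked the same way as the forward pass.
    return (multiply_no_nan(dz, y), multiply_no_nan(x, dz))

multiply_no_nan.defvjp(_fwd, _bwd)
```

With this definition, `jax.grad(multiply_no_nan)(1., 1.)` runs the custom backward rule and returns 1.0, and `multiply_no_nan(jnp.inf, 0.)` returns 0.0 rather than nan.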
mattjj self-assigned this Dec 17, 2019
fehiepsi (Member) commented:

We hit this issue with custom transforms that define a JVP rule. For example:

import jax
from jax import lax, numpy as np, random

def f(key):
    return random.gamma(key, 1.)

lax.map(f, random.split(random.PRNGKey(0), 10))
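To read the snippet above: `lax.map(f, xs)` applies `f` to each slice of `xs` along its leading axis (it lowers to a `scan`), so here it draws one Gamma(1) sample per split key. A minimal sketch of that equivalence, assuming a current JAX release and only the public `jax.random` and `jax.lax` APIs, compares it with a plain Python loop:

```python
import jax.numpy as jnp
from jax import lax, random

def f(key):
    # One Gamma(a=1.0) sample per PRNG key, as in the snippet above.
    return random.gamma(key, 1.)

keys = random.split(random.PRNGKey(0), 10)  # 10 keys stacked along axis 0

mapped = lax.map(f, keys)                 # scan over the leading axis
looped = jnp.stack([f(k) for k in keys])  # same computation, one key at a time

print(mapped.shape)  # (10,)
print(jnp.allclose(mapped, looped))
```

The loop version is only a reference for what `lax.map` computes; `lax.map` traces `f` once and runs it inside a compiled loop, which is the code path that triggered the assertion on JAX 0.1.55.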

shoyer (Member) commented Dec 31, 2019

I'm also hitting this same assertion error, though I'm pretty sure my code does not use any custom transforms.

mattjj (Member) commented Mar 22, 2020

I merged #2026, which should fix this.

mattjj closed this as completed Mar 22, 2020