jax.scipy.special.multigammaln fails under jit #5073

Closed
jakevdp opened this issue Dec 1, 2020 · 1 comment
jakevdp (Collaborator) commented Dec 1, 2020

Brief example below.

The need for `d` to be static is an implementation detail: the shape of the result does not depend on the value of `d`.
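For reference, this is the standard definition of the multivariate log-gamma function (the one documented for `scipy.special.multigammaln`). `d` only sets how many terms are summed, while the result takes the shape of `a`:

```latex
\log \Gamma_d(a) = \frac{d(d-1)}{4}\log \pi + \sum_{j=1}^{d} \log \Gamma\!\left(a - \frac{j-1}{2}\right)
```

The `jnp.arange(d)` call in the traceback builds the offsets (j-1)/2 for j = 1, ..., d, which is why `d` must be concrete when the function is traced.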

The easiest fix would be to raise a better error when `d` is traced.
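A minimal sketch of the "better error" option, for illustration only: the wrapper `checked_multigammaln` and its message are hypothetical, not the change that landed upstream. It relies on `operator.index`, which accepts concrete integers but raises a `TypeError` for floats and for values traced under `jit` (JAX's tracer-conversion errors subclass `TypeError`).

```python
import operator

import jax
from jax.scipy.special import multigammaln


def checked_multigammaln(a, d):
    """Hypothetical wrapper that validates `d` before calling multigammaln."""
    try:
        # Succeeds for concrete integers; raises TypeError for floats and
        # for traced values (e.g. when called under jax.jit).
        d = operator.index(d)
    except TypeError as err:
        raise TypeError(
            "multigammaln: `d` must be a concrete integer. Under jax.jit, "
            "mark it as static, e.g. jax.jit(f, static_argnums=1)."
        ) from err
    return multigammaln(a, d)


print(checked_multigammaln(1.0, 2))  # works eagerly
try:
    jax.jit(checked_multigammaln)(1.0, 2)
except TypeError as err:
    print(err)  # prints the message raised above
```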

A harder (and perhaps unnecessary) fix would be to implement the function so that `d` need not be static.
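A minimal sketch of the harder fix, assuming a loop is acceptable: since only the number of summed terms depends on `d`, `jax.lax.fori_loop` can run a trip count that is unknown at trace time over a fixed-shape accumulator. The name `multigammaln_traced_d` is hypothetical, the sketch assumes float32 inputs, and it is not how `jax.scipy.special.multigammaln` is implemented.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import gammaln


def multigammaln_traced_d(a, d):
    """Hypothetical multigammaln variant that accepts a traced integer `d`.

    Computes d*(d-1)/4 * log(pi) + sum_{j=0}^{d-1} gammaln(a - j/2);
    the result has the shape of `a` whatever the value of `d`.
    """
    a = jnp.asarray(a, dtype=jnp.float32)
    d = jnp.asarray(d, dtype=jnp.int32)
    d_float = d.astype(a.dtype)
    constant = 0.25 * d_float * (d_float - 1) * jnp.log(jnp.pi)

    def body(j, acc):
        # Add the j-th term, gammaln(a - j/2), for j = 0, ..., d-1.
        return acc + gammaln(a - jnp.asarray(j, dtype=a.dtype) / 2)

    # Lower and upper bounds share the int32 dtype; a traced upper bound
    # is allowed because the accumulator shape never depends on it.
    total = lax.fori_loop(jnp.int32(0), d, body, jnp.zeros_like(a))
    return total + constant


f = jax.jit(multigammaln_traced_d)
print(f(1.0, 2), f(3.0, 4))  # `d` is traced, so no recompile per value
```

One design caveat: according to the `fori_loop` documentation, a trip count that is not known at trace time makes it lower to `while_loop`, and reverse-mode differentiation (e.g. `jax.grad` with respect to `a`) through that loop is not supported.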

```python
from jax.scipy.special import multigammaln
from jax import jit
jit(multigammaln)(1, 2)
```
Traceback (most recent call last):
  File "tmp.py", line 4, in <module>
    jit(multigammaln)(1, 2)
  File "/Users/vanderplas/github/google/jax/jax/_src/scipy/special.py", line 162, in multigammaln
    lax.div(jnp.arange(d), _constant_like(a, 2))),
  File "/Users/vanderplas/github/google/jax/jax/_src/numpy/lax_numpy.py", line 2749, in arange
    start = require(start, msg("stop"))
jax._src.traceback_util.FilteredStackTrace: jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

It arose in jax.numpy.arange argument `stop`.

While tracing the function multigammaln at /Users/vanderplas/github/google/jax/jax/_src/scipy/special.py:154, this concrete value was not available in Python because it depends on the value of the arguments to multigammaln at /Users/vanderplas/github/google/jax/jax/_src/scipy/special.py:154 at flattened positions [1], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "tmp.py", line 4, in <module>
    jit(multigammaln)(1, 2)
  File "/Users/vanderplas/github/google/jax/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/vanderplas/github/google/jax/jax/api.py", line 218, in f_jitted
    out = xla.xla_call(
  File "/Users/vanderplas/github/google/jax/jax/core.py", line 1226, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/vanderplas/github/google/jax/jax/core.py", line 1217, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/vanderplas/github/google/jax/jax/core.py", line 1229, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Users/vanderplas/github/google/jax/jax/core.py", line 595, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/Users/vanderplas/github/google/jax/jax/interpreters/xla.py", line 569, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/Users/vanderplas/github/google/jax/jax/linear_util.py", line 251, in memoized_fun
    ans = call(fun, *args)
  File "/Users/vanderplas/github/google/jax/jax/interpreters/xla.py", line 645, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
  File "/Users/vanderplas/github/google/jax/jax/interpreters/partial_eval.py", line 1230, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/Users/vanderplas/github/google/jax/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/Users/vanderplas/github/google/jax/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/vanderplas/github/google/jax/jax/_src/scipy/special.py", line 162, in multigammaln
    lax.div(jnp.arange(d), _constant_like(a, 2))),
  File "/Users/vanderplas/github/google/jax/jax/_src/numpy/lax_numpy.py", line 2749, in arange
    start = require(start, msg("stop"))
  File "/Users/vanderplas/github/google/jax/jax/core.py", line 919, in concrete_or_error
    raise_concretization_error(val, context)
  File "/Users/vanderplas/github/google/jax/jax/core.py", line 896, in raise_concretization_error
    raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

It arose in jax.numpy.arange argument `stop`.

While tracing the function multigammaln at /Users/vanderplas/github/google/jax/jax/_src/scipy/special.py:154, this concrete value was not available in Python because it depends on the value of the arguments to multigammaln at /Users/vanderplas/github/google/jax/jax/_src/scipy/special.py:154 at flattened positions [1], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
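The workaround the error message suggests is to mark `d` as static. A minimal sketch, assuming `d` is the second positional argument of `multigammaln` (as in the call above); each distinct value of `d` triggers its own compilation, which is the recompile cost the message mentions.

```python
import math

import jax
from jax.scipy.special import multigammaln

# Treat argument 1 (`d`) as a compile-time constant so it stays a Python int
# while `a` is traced; each new value of `d` triggers a recompile.
multigammaln_jit = jax.jit(multigammaln, static_argnums=1)

result = multigammaln_jit(1.0, 2)
# For a = 1, d = 2 the definition reduces to log(pi).
print(float(result), math.log(math.pi))
```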
jakevdp (Collaborator, Author) commented Dec 3, 2020

I'm going to close this: I think requiring a static `d` is fine, and #5074 improved the error message. In the SciPy version, only an integer value of `d` is accepted.

jakevdp closed this as completed Dec 3, 2020