Traceback (most recent call last):
File "tmp.py", line 4, in <module>
jit(multigammaln)(1, 2)
File "/Users/vanderplas/github/google/jax/jax/_src/scipy/special.py", line 162, in multigammaln
lax.div(jnp.arange(d), _constant_like(a, 2))),
File "/Users/vanderplas/github/google/jax/jax/_src/numpy/lax_numpy.py", line 2749, in arange
start = require(start, msg("stop"))
jax._src.traceback_util.FilteredStackTrace: jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
It arose in jax.numpy.arange argument `stop`.
While tracing the function multigammaln at /Users/vanderplas/github/google/jax/jax/_src/scipy/special.py:154, this concrete value was not available in Python because it depends on the value of the arguments to multigammaln at /Users/vanderplas/github/google/jax/jax/_src/scipy/special.py:154 at flattened positions [1], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.
See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "tmp.py", line 4, in <module>
jit(multigammaln)(1, 2)
File "/Users/vanderplas/github/google/jax/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/vanderplas/github/google/jax/jax/api.py", line 218, in f_jitted
out = xla.xla_call(
File "/Users/vanderplas/github/google/jax/jax/core.py", line 1226, in bind
return call_bind(self, fun, *args, **params)
File "/Users/vanderplas/github/google/jax/jax/core.py", line 1217, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/Users/vanderplas/github/google/jax/jax/core.py", line 1229, in process
return trace.process_call(self, fun, tracers, params)
File "/Users/vanderplas/github/google/jax/jax/core.py", line 595, in process_call
return primitive.impl(f, *tracers, **params)
File "/Users/vanderplas/github/google/jax/jax/interpreters/xla.py", line 569, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/Users/vanderplas/github/google/jax/jax/linear_util.py", line 251, in memoized_fun
ans = call(fun, *args)
File "/Users/vanderplas/github/google/jax/jax/interpreters/xla.py", line 645, in _xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
File "/Users/vanderplas/github/google/jax/jax/interpreters/partial_eval.py", line 1230, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/Users/vanderplas/github/google/jax/jax/interpreters/partial_eval.py", line 1211, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/Users/vanderplas/github/google/jax/jax/linear_util.py", line 160, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/vanderplas/github/google/jax/jax/_src/scipy/special.py", line 162, in multigammaln
lax.div(jnp.arange(d), _constant_like(a, 2))),
File "/Users/vanderplas/github/google/jax/jax/_src/numpy/lax_numpy.py", line 2749, in arange
start = require(start, msg("stop"))
File "/Users/vanderplas/github/google/jax/jax/core.py", line 919, in concrete_or_error
raise_concretization_error(val, context)
File "/Users/vanderplas/github/google/jax/jax/core.py", line 896, in raise_concretization_error
raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
It arose in jax.numpy.arange argument `stop`.
While tracing the function multigammaln at /Users/vanderplas/github/google/jax/jax/_src/scipy/special.py:154, this concrete value was not available in Python because it depends on the value of the arguments to multigammaln at /Users/vanderplas/github/google/jax/jax/_src/scipy/special.py:154 at flattened positions [1], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions, though at the cost of more recompiles.
See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/1)>
I'm going to close this: I think requiring a static `d` is fine, and #5074 improved the error message. In the SciPy version, only an integer value is accepted.
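As the error message itself suggests, the workaround under this resolution is to mark `d` as static with `static_argnums`. A minimal sketch, assuming a JAX release in which `jax.scipy.special.multigammaln(a, d)` requires a concrete integer `d`:

```python
import jax
from jax.scipy.special import multigammaln

# Mark positional argument 1 (d) as static so it is a concrete Python int
# at trace time. Each distinct value of d triggers a separate compilation.
multigammaln_jit = jax.jit(multigammaln, static_argnums=1)

# For a = 1, d = 2: Gamma_2(1) = pi**(1/2) * Gamma(1) * Gamma(1/2) = pi,
# so the result is approximately log(pi).
result = multigammaln_jit(1.0, 2)
print(result)
```

The trade-off is the one the error message mentions: `a` stays traced, but every new `d` costs a recompile.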
Brief example below.
The need for `d` to be static is an implementation detail; the shape of the function result does not depend on the value of `d`. The easiest fix would be to have better errors for a traced `d`. A harder fix (and maybe not necessary) would be to implement the function in a way such that `d` need not be static.
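For the "harder fix", one possible approach is sketched below. This is an illustration only, not JAX's implementation, and the name `multigammaln_traced_d` is hypothetical. It uses the identity log Γ_d(a) = d(d−1)/4 · log π + Σ_{j=0}^{d−1} log Γ(a − j/2) and accumulates the sum with `jax.lax.fori_loop`. When the loop bound is traced, `fori_loop` falls back to `lax.while_loop`, so `d` can stay an ordinary traced integer:

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import gammaln


def multigammaln_traced_d(a, d):
    """Hypothetical multigammaln variant whose `d` may be a traced integer.

    Uses log Gamma_d(a) = d(d-1)/4 * log(pi) + sum_{j=0}^{d-1} gammaln(a - j/2).
    """
    a = jnp.asarray(a, dtype=jnp.float32)
    d = jnp.asarray(d, dtype=jnp.int32)

    def body(j, acc):
        # Add the j-th term, gammaln(a - j/2), to the running sum.
        return acc + gammaln(a - j / 2.0)

    # With a traced upper bound, fori_loop is lowered to a while_loop,
    # so the number of iterations need not be known at trace time.
    total = lax.fori_loop(0, d, body, jnp.zeros_like(a))
    return total + (d * (d - 1) / 4.0) * jnp.log(jnp.pi)


# d is now a regular traced argument: no static_argnums is needed,
# and calling with a different d reuses the same compiled function.
f = jax.jit(multigammaln_traced_d)
print(f(1.0, 2))  # approximately log(pi)
```

One caveat of this design: reverse-mode differentiation (`jax.grad`) is not supported through a `while_loop`, so gradients with respect to `a` would need forward mode or a different formulation. This may be one reason requiring a static `d` is the simpler choice.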