Found that `jax.numpy.nonzero` becomes very slow from jax>=0.2.26, which significantly increases the runtime when updating to a newer version of jax.
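To check the slowdown on a given jax version, a minimal timing sketch like the one below can compare `jnp.nonzero` against `np.nonzero` on the same host array. The helper `mean_seconds`, the 500x500 random mask, and the repeat count are illustrative assumptions, not code from jaxmapp, and no timings are claimed here.

```python
import time

import numpy as np
import jax.numpy as jnp


def mean_seconds(fn, mask, repeats=10):
    """Mean wall-clock seconds per call of fn(mask), after one untimed warm-up call."""
    def run():
        # np.asarray pulls each result onto the host, which also waits for
        # JAX's asynchronous dispatch, so the timing covers the real work.
        return [np.asarray(a) for a in fn(mask)]

    run()  # warm-up call (tracing/caching), excluded from the timing
    start = time.perf_counter()
    for _ in range(repeats):
        run()
    return (time.perf_counter() - start) / repeats


# Hypothetical sparse boolean matrix standing in for the roadmap data.
rng = np.random.default_rng(0)
mask = rng.random((500, 500)) < 0.01

t_jnp = mean_seconds(jnp.nonzero, mask)
t_np = mean_seconds(np.nonzero, mask)
print(f"jnp.nonzero: {t_jnp:.6f} s per call, np.nonzero: {t_np:.6f} s per call")
```

Running the same script under jax<0.2.26 and jax>=0.2.26 would show whether the gap appears in this setting.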
Right now the `nonzero` function is used in `generate_from_nontimed_roadmap` and `generate_from_timed_roadmap` in `TimedRoadmap`, e.g. `jaxmapp/jaxmapp/roadmap/timed_roadmap.py`, line 107 (commit 109c9b7).
Replacing this with `np.argwhere` (acceptable because these two functions are not jit-compiled) would resolve the problem.
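The contents of line 107 are not shown here, so the sketch below only assumes the pattern is "take the indices of the nonzero entries of a boolean matrix". The function names `edges_with_jnp` / `edges_with_np` and the small adjacency matrix are hypothetical. It shows that `np.argwhere` returns the same indices as `jnp.nonzero`, but as a single `(N, ndim)` host array instead of a tuple of per-axis arrays. This only works outside `jax.jit`, where the input has concrete values.

```python
import numpy as np
import jax.numpy as jnp


def edges_with_jnp(adjacency):
    # Assumed current pattern: jnp.nonzero returns a tuple (rows, cols).
    rows, cols = jnp.nonzero(adjacency)
    return np.stack([np.asarray(rows), np.asarray(cols)], axis=1)


def edges_with_np(adjacency):
    # Proposed pattern: convert to a NumPy array once, then let np.argwhere
    # return an (N, 2) array of (row, col) indices in row-major order.
    # Not usable inside jax.jit, since it needs concrete values.
    return np.argwhere(np.asarray(adjacency))


# Hypothetical 3-node roadmap adjacency matrix.
adjacency = jnp.array([[0, 1, 0],
                       [1, 0, 1],
                       [0, 0, 0]], dtype=bool)
print(edges_with_np(adjacency))
# [[0 1]
#  [1 0]
#  [1 2]]
```

Call sites that unpack `rows, cols = jnp.nonzero(...)` would need `rows, cols = np.argwhere(...).T` (or `np.nonzero`, which keeps the tuple form) to preserve their shape.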