map_coordinates mode='mirror' does not match scipy #11097
Comments
Hi - I think this might be expected behavior, because the JAX implementation fixes a known bug that is present in the scipy implementation (see the note at https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.ndimage.map_coordinates.html). The scipy bug is tracked in scipy/scipy#2640. You can see the "ground truth" that JAX tests against in jax/tests/scipy_ndimage_test.py, lines 38 to 57 at commit cd565f8.
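For readers comparing outputs by hand, here is a minimal pure-Python sketch of what mode='mirror' means for linear (order=1) interpolation in one dimension. It assumes the convention given in the scipy documentation, where the signal is reflected about the center of the edge samples without repeating them (d c b | a b c d | c b a). The helper names mirror_fold and interp_mirror_1d are hypothetical and are not part of JAX or scipy; this is not the code either library runs.

```python
import math


def mirror_fold(x, n):
    """Fold a coordinate into [0, n - 1] by reflecting about the edge samples.

    Assumes the 'mirror' convention: the edge sample is not repeated,
    so the pattern has period 2 * (n - 1).
    """
    if n == 1:
        return 0.0
    period = 2 * (n - 1)
    x = x % period  # Python's % returns a value in [0, period) here
    return period - x if x > n - 1 else x


def interp_mirror_1d(values, x):
    """Linearly interpolate `values` at coordinate `x` with mirror boundaries."""
    n = len(values)
    x = mirror_fold(x, n)
    lo = int(math.floor(x))
    hi = min(lo + 1, n - 1)  # clamp so x == n - 1 stays in range
    w = x - lo
    return (1 - w) * values[lo] + w * values[hi]


values = [0.0, 10.0, 20.0, 30.0]
for x in (-1.0, -0.5, 1.25, 3.5, 4.0):
    print(x, interp_mirror_1d(values, x))
```

Checking a few out-of-bounds coordinates against a small reference like this can show which implementation deviates near the boundary before any affine transform is involved.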
I'm going to close this because I think it's working as intended. Please feel free to reopen if you disagree!
Thanks for providing the reference that JAX uses, I will defer to using that as well!
Hello there! I'm using map_coordinates for some affine transformations and I use scipy as "ground truth" for testing. Please let me know if there is any reason why they should not match within a certain tolerance.
I create the target coordinates and then apply both interpolations. This works for transform matrices that don't transform too much, but fails for more severe transformations.
Both libraries are up to date, and I was able to reproduce it on colab.
jax.__version__ == '0.3.13'
scipy.__version__ == '1.7.3'
Both outputs are float32, but the two calls do not produce the same result.
I could be missing something or have messed something up along the way; please let me know if that is the case.
Thanks a lot in advance!