map_coordinates mode='mirror' does not match scipy #11097

Closed
pabloduque0 opened this issue Jun 14, 2022 · 3 comments
Labels
bug Something isn't working



pabloduque0 commented Jun 14, 2022

Hello there! I'm using map_coordinates for some affine transformations, and I use scipy as the "ground truth" in my tests. Please let me know if there is any reason why the two should not match within some tolerance.

I create the target coordinates and then apply both interpolations. This works for transform matrices that don't transform the image much, but fails for more severe transformations.

Both packages are up to date, and I was able to reproduce the issue on Colab:
jax.__version__ == '0.3.13'
scipy.__version__ == '1.7.3'

from scipy import ndimage
import jax
import numpy as np
import jax.numpy as jnp

src = np.random.uniform(0., 1., size=(100, 100, 3)).astype(np.float32)
# matrix = jnp.array([[-0.2, 0.5, 0], [0.2, 0.1, 0], [0, 0, 1]]) # Runs correctly
matrix = jnp.array([[-0.5, 0.5, 0], [0.8, 0.5, 0], [0, 0, 1]]) # Fails

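# One (i, j, k) index triple per output element: shape (100, 100, 3, 3).
meshgrid = jnp.meshgrid(*[jnp.arange(size) for size in src.shape],
                        indexing="ij")
indices = jnp.concatenate(
    [jnp.expand_dims(x, axis=-1) for x in meshgrid], axis=-1)

# Apply the linear map to every index, then move the coordinate axis to the
# front, as map_coordinates expects: shape (3, 100, 100, 3).
coordinates = indices @ matrix.T
coordinates = jnp.moveaxis(coordinates, source=-1, destination=0)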


def scipy_map_coordinates():
  return ndimage.map_coordinates(src, coordinates, order=0, mode="mirror")

def jax_map_coordinates():
  return jax.scipy.ndimage.map_coordinates(src, coordinates, order=0, mode="mirror")


from_jax = jax_map_coordinates()
from_scipy = scipy_map_coordinates()

print(from_jax.dtype, from_scipy.dtype)
np.testing.assert_allclose(from_jax, from_scipy, rtol=1e-4, atol=1e-4)

Both outputs are float32, but the results do not match:

AssertionError: 
Not equal to tolerance rtol=0.0001, atol=0.0001

Mismatched elements: 294 / 30000 (0.98%)
Max absolute difference: 0.9023502
Max relative difference: 119.13727
 x: array([[[0.915364, 0.563342, 0.219187],
        [0.65514 , 0.092821, 0.330349],
        [0.65514 , 0.092821, 0.330349],...
 y: array([[[0.915364, 0.563342, 0.219187],
        [0.65514 , 0.092821, 0.330349],
        [0.65514 , 0.092821, 0.330349],...

I may be missing something or have made a mistake along the way; please let me know if so.
Thanks a lot in advance!

pabloduque0 added the bug (Something isn't working) label Jun 14, 2022
jakevdp (Collaborator) commented Jun 14, 2022

Hi, I think this is probably expected behavior: the JAX implementation fixes a known boundary-handling bug that is present in the scipy implementation (see the note at https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.ndimage.map_coordinates.html).

The scipy bug can be viewed here: scipy/scipy#2640
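One way to check this explanation against the repro above is to ask whether the mismatches coincide with coordinates that fall outside the input array, since that is where the boundary mode comes into play. A minimal NumPy sketch; `outside_sample_grid` is a hypothetical helper for illustration, not part of JAX or SciPy:

```python
import numpy as np

def outside_sample_grid(coordinates, shape):
    """Boolean mask of output positions whose coordinate lies outside
    [0, size - 1] along at least one input axis.

    coordinates: array of shape (ndim, ...), laid out as for map_coordinates.
    shape: shape of the input array, with len(shape) == ndim.
    """
    coordinates = np.asarray(coordinates)
    mask = np.zeros(coordinates.shape[1:], dtype=bool)
    for c, size in zip(coordinates, shape):
        # Flag samples before the first or after the last sample centre.
        mask |= (c < 0) | (c > size - 1)
    return mask

# Toy check: axis 0 has size 4 (valid range [0, 3]), axis 1 has size 2.
coords = np.array([[-0.5, 0.0, 2.0, 3.2],
                   [ 1.0, 1.0, 1.0, 1.0]])
mask = outside_sample_grid(coords, (4, 2))
# expected mask: True, False, False, True
```

With the repro's names, `mismatch = ~np.isclose(from_jax, from_scipy, rtol=1e-4, atol=1e-4)` followed by `np.any(mismatch & ~outside_sample_grid(coordinates, src.shape))` would be `False` if every discrepancy sits outside the array. The mask is conservative for `order=0`: a coordinate such as -0.3 rounds back into the array but is still flagged.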

You can see the "ground truth" that JAX tests against here:

import numpy as np
import scipy.ndimage as osp_ndimage

def _fixed_ref_map_coordinates(input, coordinates, order, mode, cval=0.0):
  # SciPy's implementation of map_coordinates handles boundaries incorrectly,
  # unless mode='reflect'. For order=1, this only affects interpolation outside
  # the bounds of the original array.
  # https://github.com/scipy/scipy/issues/2640
  assert order <= 1
  # Pad each axis far enough that every coordinate lands inside the padded array.
  padding = [(max(-np.floor(c.min()).astype(int) + 1, 0),
              max(np.ceil(c.max()).astype(int) + 1 - size, 0))
             for c, size in zip(coordinates, input.shape)]
  # Shift the coordinates into the padded array's frame.
  shifted_coords = [c + p[0] for p, c in zip(padding, coordinates)]
  # Translate ndimage mode names to the equivalent np.pad mode names.
  pad_mode = {
      'nearest': 'edge', 'mirror': 'reflect', 'reflect': 'symmetric'
  }.get(mode, mode)
  if mode == 'constant':
    padded = np.pad(input, padding, mode=pad_mode, constant_values=cval)
  else:
    padded = np.pad(input, padding, mode=pad_mode)
  result = osp_ndimage.map_coordinates(
      padded, shifted_coords, order=order, mode=mode, cval=cval)
  return result
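The `pad_mode` table in this reference relies on a naming mismatch between the two libraries: scipy.ndimage's `'mirror'` reflects about the edge sample without repeating it, which `np.pad` calls `'reflect'`, while ndimage's `'reflect'` repeats the edge sample, which `np.pad` calls `'symmetric'`. A minimal NumPy sketch of the two index rules, checked against `np.pad`; `mirror_index` and `reflect_index` are hypothetical helpers for illustration, not JAX or SciPy functions:

```python
import numpy as np

def mirror_index(i, n):
    """ndimage mode='mirror' (d c b | a b c d | c b a): reflect about the
    edge samples without repeating them. Period 2 * (n - 1); needs n >= 2."""
    period = 2 * (n - 1)
    i = np.mod(i, period)  # np.mod returns values in [0, period)
    return np.where(i < n, i, period - i)

def reflect_index(i, n):
    """ndimage mode='reflect' (d c b a | a b c d | d c b a): reflect about
    the outer pixel edges, repeating the edge sample. Period 2 * n."""
    period = 2 * n
    i = np.mod(i, period)
    return np.where(i < n, i, period - 1 - i)

x = np.arange(5) * 10  # toy 1-D input: [0, 10, 20, 30, 40]
pad = 3                # small enough that np.pad needs only one reflection
idx = np.arange(-pad, x.size + pad)

# Same correspondence as the reference: 'mirror' -> 'reflect', 'reflect' -> 'symmetric'.
assert np.array_equal(x[mirror_index(idx, x.size)], np.pad(x, pad, mode="reflect"))
assert np.array_equal(x[reflect_index(idx, x.size)], np.pad(x, pad, mode="symmetric"))
```

Because both rules are periodic, folding indices this way handles coordinates arbitrarily far outside the array; the reference reaches the same place by enlarging the input with `np.pad` so that SciPy never has to apply its own boundary handling.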

hawkinsp (Member)

I'm going to close this because I think it's working as intended. Please feel free to reopen if you disagree!

pabloduque0 (Author)

Thanks for pointing me to the reference that JAX tests against; I'll use that as well!
