diff --git a/nilearn/_utils/niimg.py b/nilearn/_utils/niimg.py index 918b5e8f6b..05a21ebb33 100644 --- a/nilearn/_utils/niimg.py +++ b/nilearn/_utils/niimg.py @@ -146,6 +146,29 @@ def load_niimg(niimg, dtype=None): return niimg +def _is_binary_niimg(niimg): + """Returns whether a given niimg is binary or not. + + Parameters + ---------- + niimg: Niimg-like object + See http://nilearn.github.io/manipulating_images/input_output.html + Image to test. + + Returns + ------- + is_binary: Boolean + True if binary, False otherwise. + + """ + niimg = load_niimg(niimg) + data = _safe_get_data(niimg, ensure_finite=True) + unique_values = np.unique(data) + if len(unique_values) != 2: + return False + return sorted(list(unique_values)) == [0,1] + + def copy_img(img): """Copy an image to a nibabel.Nifti1Image. diff --git a/nilearn/image/resampling.py b/nilearn/image/resampling.py index ba0159a1da..e2cd71cab1 100644 --- a/nilearn/image/resampling.py +++ b/nilearn/image/resampling.py @@ -267,6 +267,14 @@ def _resample_one_img(data, A, b, target_shape, data = _extrapolate_out_mask(data, np.logical_not(not_finite), iterations=2)[0] + # If data is binary and interpolation is continuous or linear, + # warn the user as this might be unintentional + if sorted(list(np.unique(data))) == [0,1] and interpolation_order != 0: + warnings.warn("Resampling binary images with continuous or " + "linear interpolation. This might lead to " + "unexpected results. You might consider using " + "nearest interpolation instead.") + # Suppresses warnings in https://github.com/nilearn/nilearn/issues/1363 with warnings.catch_warnings(): if LooseVersion(scipy.__version__) >= LooseVersion('0.18'): diff --git a/nilearn/image/tests/test_resampling.py b/nilearn/image/tests/test_resampling.py index 81f540e2bb..3ce3d4b7b2 100644 --- a/nilearn/image/tests/test_resampling.py +++ b/nilearn/image/tests/test_resampling.py @@ -13,6 +13,7 @@ from nibabel import Nifti1Image +from nilearn import _utils from nilearn.image.resampling import resample_img, resample_to_img, reorder_img from nilearn.image.resampling import from_matrix_vector, coord_transform from nilearn.image.resampling import get_bounds @@ -249,6 +250,27 @@ def test_resampling_error_checks(): interpolation="an_invalid_interpolation" ) + # Resampling a binary image with continuous or + # linear interpolation should raise a warning. + data_binary = rng.randint(4, size=(1, 4, 4)) + data_binary[data_binary>0] = 1 + assert sorted(list(np.unique(data_binary))) == [0,1] + + rot = rotation(0, np.pi / 4) + img_binary = Nifti1Image(data_binary, np.eye(4)) + assert _utils.niimg._is_binary_niimg(img_binary) + + with pytest.warns(Warning, match="Resampling binary images with"): + rot_img = resample_img(img_binary, + target_affine=rot, + interpolation='continuous') + + with pytest.warns(Warning, match="Resampling binary images with"): + rot_img = resample_img(img_binary, + target_affine=rot, + interpolation='linear') + + # Noop target_shape = shape[:3] diff --git a/nilearn/plotting/displays.py b/nilearn/plotting/displays.py index d04e822f2a..1c61e6625b 100644 --- a/nilearn/plotting/displays.py +++ b/nilearn/plotting/displays.py @@ -28,7 +28,7 @@ from .edge_detect import _edge_map from .find_cuts import find_xyz_cut_coords, find_cut_slices from .. import _utils -from ..image import new_img_like +from ..image import new_img_like, load_img from ..image.resampling import (get_bounds, reorder_img, coord_transform, get_mask_bounds) from nilearn.image import get_data @@ -834,7 +834,17 @@ def add_contours(self, img, threshold=1e-6, filled=False, **kwargs): def _map_show(self, img, type='imshow', resampling_interpolation='continuous', threshold=None, **kwargs): - img = reorder_img(img, resample=resampling_interpolation) + # In the special case where the affine of img is not diagonal, + # the function `reorder_img` will trigger a resampling + # of the provided image with a continuous interpolation + # since this is the default value here. In the special + # case where this image is binary, such as when this function + # is called from `add_contours`, continuous interpolation + # does not make sense and we turn to nearest interpolation instead. + if _utils.niimg._is_binary_niimg(img): + img = reorder_img(img, resample='nearest') + else: + img = reorder_img(img, resample=resampling_interpolation) threshold = float(threshold) if threshold is not None else None if threshold is not None: diff --git a/nilearn/plotting/tests/test_img_plotting.py b/nilearn/plotting/tests/test_img_plotting.py index 51c05b9d2f..5d85e088e0 100644 --- a/nilearn/plotting/tests/test_img_plotting.py +++ b/nilearn/plotting/tests/test_img_plotting.py @@ -14,6 +14,7 @@ from scipy import sparse +from nilearn import _utils from nilearn.image.resampling import coord_transform, reorder_img from nilearn._utils import data_gen from nilearn.image import get_data @@ -509,6 +510,7 @@ def test_plot_img_with_resampling(testdata_3d): [0., 0., 1., 0.], [0., 0., 0., 1.]]) img = nibabel.Nifti1Image(data, affine) + assert not _utils.niimg._is_binary_niimg(img) display = plot_img(img) display.add_overlay(img) display.add_contours(img, contours=2, linewidth=4, @@ -518,6 +520,21 @@ def test_plot_img_with_resampling(testdata_3d): # Save execution time and memory plt.close() +def test_plot_binary_img_with_resampling(testdata_3d): + data = get_data(testdata_3d['img']) + data[data > 0] = 1 + data[data < 0] = 0 + affine = np.array([[1., -1., 0., 0.], + [1., 1., 0., 0.], + [0., 0., 1., 0.], + [0., 0., 0., 1.]]) + img = nibabel.Nifti1Image(data, affine) + assert _utils.niimg._is_binary_niimg(img) + display = plot_img(img) + display.add_overlay(img) + display.add_contours(img) + plt.close() + def test_plot_noncurrent_axes(): """Regression test for Issue #450"""