Skip to content

Commit

Permalink
add keep_masked_maps to NiftiMapsMasker (#3732)
Browse files Browse the repository at this point in the history
* add keep_masked_maps to NiftiMapsMasker

* add stacklevel to deprecation warning

* improve warning message and docstring

* raise warning when maps are removed

* make warning shorter

* minor change

* minor change

* update doc

* run black

* update warning message

* fix isort and flake8

* change the warning message

* minor formatting
  • Loading branch information
mtorabi59 committed Jul 8, 2023
1 parent 021255c commit 95086d8
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 4 deletions.
1 change: 1 addition & 0 deletions doc/changes/latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ Changes
- Removed old files and test code from deprecated datasets COBRE and NYU resting state (:gh:`3743` by `Michelle Wang`_).
- :bdg-secondary:`Maint` PEP8 and isort compliance extended to the whole nilearn codebase. (:gh:`3538`, :gh:`3566`, :gh:`3548`, :gh:`3556`, :gh:`3601`, :gh:`3609`, :gh:`3646`, :gh:`3650`, :gh:`3647`, :gh:`3640`, :gh:`3615`, :gh:`3614`, :gh:`3648`, :gh:`#3651` by `Rémi Gau`_).
- :bdg-danger:`Deprecation` Empty region signals resulting from applying `mask_img` in :class:`~maskers.NiftiLabelsMasker` will no longer be kept in release 0.15. Meanwhile, use `keep_masked_labels` parameter when initializing the :class:`~maskers.NiftiLabelsMasker` object to enable/disable this behavior. (:gh:`3722` by `Mohammad Torabi`_).
- :bdg-danger:`Deprecation` Empty region signals resulting from applying `mask_img` in :class:`~maskers.NiftiMapsMasker` will no longer be kept in release 0.15. Meanwhile, use `keep_masked_maps` parameter when initializing the :class:`~maskers.NiftiMapsMasker` object to enable/disable this behavior. (:gh:`3732` by `Mohammad Torabi`_).
16 changes: 16 additions & 0 deletions nilearn/_utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,22 @@ def custom_function(vertices):
and the ``keep_masked_labels`` parameter will be removed in 0.15.
"""

# keep_masked_maps
docdict["keep_masked_maps"] = """
keep_masked_maps : :obj:`bool`, optional
If True, masked atlas with invalid maps (maps that contain only
zeros after applying the mask) will be retained in the output, resulting
in corresponding time series containing zeros only. If False, the
invalid maps will be removed from the trimmed atlas, resulting in
no empty time series in the output.
.. deprecated:: 0.10.2.dev
The 'True' option for ``keep_masked_maps`` is deprecated.
The default value will change to 'False' in 0.13,
and the ``keep_masked_maps`` parameter will be removed in 0.15.
"""

##############################################################################

docdict_indented = {}
Expand Down
13 changes: 11 additions & 2 deletions nilearn/maskers/nifti_maps_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@ class _ExtractionFunctor:

func_name = 'nifti_maps_masker_extractor'

def __init__(self, _resampled_maps_img_, _resampled_mask_img_):
def __init__(self, _resampled_maps_img_, _resampled_mask_img_,
keep_masked_maps):
self._resampled_maps_img_ = _resampled_maps_img_
self._resampled_mask_img_ = _resampled_mask_img_
self.keep_masked_maps = keep_masked_maps

def __call__(self, imgs):
from ..regions import signal_extraction

return signal_extraction.img_to_signals_maps(
imgs, self._resampled_maps_img_,
mask_img=self._resampled_mask_img_)
mask_img=self._resampled_mask_img_,
keep_masked_maps=self.keep_masked_maps)


@_utils.fill_doc
Expand Down Expand Up @@ -75,6 +78,8 @@ class NiftiMapsMasker(BaseMasker, _utils.CacheMixin):
%(memory)s
%(memory_level)s
%(verbose0)s
%(keep_masked_maps)s
reports : :obj:`bool`, optional
If set to True, data is saved in order to produce a report.
Default=True.
Expand Down Expand Up @@ -122,6 +127,7 @@ def __init__(
t_r=None,
dtype=None,
resampling_target="data",
keep_masked_maps=True,
memory=Memory(location=None, verbose=0),
memory_level=0,
verbose=0,
Expand Down Expand Up @@ -180,6 +186,8 @@ def __init__(
"Set resampling_target to something else or provide a mask."
)

self.keep_masked_maps = keep_masked_maps

def generate_report(self, displayed_maps=10):
"""Generate an HTML report for the current ``NiftiMapsMasker`` object.
Expand Down Expand Up @@ -566,6 +574,7 @@ def transform_single_imgs(self, imgs, confounds=None, sample_mask=None):
imgs, _ExtractionFunctor(
self._resampled_maps_img_,
self._resampled_mask_img_,
self.keep_masked_maps,
),
# Pre-treatments
params,
Expand Down
33 changes: 31 additions & 2 deletions nilearn/regions/signal_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def signals_to_img_labels(


@_utils.fill_doc
def img_to_signals_maps(imgs, maps_img, mask_img=None):
def img_to_signals_maps(imgs, maps_img, mask_img=None, keep_masked_maps=True):
"""Extract region signals from image.
This function is applicable to regions defined by maps.
Expand All @@ -443,6 +443,7 @@ def img_to_signals_maps(imgs, maps_img, mask_img=None):
Mask to apply to regions before extracting signals.
Every point outside the mask is considered
as background (i.e. outside of any region).
%(keep_masked_maps)s
Returns
-------
Expand Down Expand Up @@ -474,12 +475,40 @@ def img_to_signals_maps(imgs, maps_img, mask_img=None):
use_mask = _check_shape_and_affine_compatibility(imgs, mask_img)
if use_mask:
mask_img = _utils.check_niimg_3d(mask_img)
labels_before_mask = set(labels)
maps_data, maps_mask, labels = _trim_maps(
maps_data,
_safe_get_data(mask_img, ensure_finite=True),
keep_empty=True,
keep_empty=keep_masked_maps,
)
maps_mask = _utils.as_ndarray(maps_mask, dtype=bool)
if keep_masked_maps:
warnings.warn(
'Applying "mask_img" before '
"signal extraction may result in empty region signals in the "
"output. These are currently kept. "
"Starting from version 0.13, the default behavior will be "
"changed to remove them by setting "
'"keep_masked_maps=False". '
'"keep_masked_maps" parameter will be removed '
"in version 0.15.",
DeprecationWarning,
stacklevel=2,
)
else:
labels_after_mask = set(labels)
labels_diff = labels_before_mask.difference(labels_after_mask)
# Raising a warning if any map is removed due to the mask
if len(labels_diff) > 0:
warnings.warn(
"After applying mask to the maps image, "
"maps with the following indices were "
f"removed: {labels_diff}. "
f"Out of {len(labels_before_mask)} maps, the "
"masked map image only contains "
f"{len(labels_after_mask)} maps.",
stacklevel=2,
)

data = _safe_get_data(imgs, ensure_finite=True)
region_signals = linalg.lstsq(maps_data[maps_mask, :], data[maps_mask, :])[
Expand Down
43 changes: 43 additions & 0 deletions nilearn/regions/tests/test_signal_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,49 @@ def test_signal_extraction_with_maps_and_labels(labeled_regions, fmri_img):
assert labels_signals.shape == (N_TIMEPOINTS, 8)
assert len(labels_labels) == 8

# apply img_to_signals_maps with a masking,
# containing only 3 regions, but
# not keeping the masked maps
with pytest.warns(
UserWarning,
match="After applying mask to the maps image, "
"maps with the following indices were "
r"removed: \{2, 3, 5, 6, 7\}. "
"Out of 8 maps, the "
"masked map image only contains "
"3 maps.",
):
maps_signals, maps_labels = img_to_signals_maps(
fmri_img, maps_img, mask_img=mask_img, keep_masked_maps=False
)

# only 3 regions must be kept, others must be removed
assert maps_signals.shape == (N_TIMEPOINTS, 3)
assert len(maps_labels) == 3

# apply img_to_signals_labels with a masking,
# containing only 3 regions, and
# keeping the masked labels
# test if the warning is raised
with pytest.warns(
DeprecationWarning,
match='Applying "mask_img" before '
"signal extraction may result in empty region signals in the "
"output. These are currently kept. "
"Starting from version 0.13, the default behavior will be "
"changed to remove them by setting "
'"keep_masked_maps=False". '
'"keep_masked_maps" parameter will be removed '
"in version 0.15.",
):
maps_signals, maps_labels = img_to_signals_maps(
fmri_img, maps_img, mask_img=mask_img, keep_masked_maps=True
)

# all regions must be kept
assert maps_signals.shape == (N_TIMEPOINTS, 8)
assert len(maps_labels) == 8


def test_signal_extraction_nans_in_regions_are_replaced_with_zeros():
shape = (4, 5, 6)
Expand Down

0 comments on commit 95086d8

Please sign in to comment.