Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add keep_masked_maps to NiftiMapsMasker #3732

Merged
merged 15 commits into from
Jul 8, 2023
1 change: 1 addition & 0 deletions doc/changes/latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ Changes
-------

- Removed old files and test code from deprecated datasets COBRE and NYU resting state (:gh:`3743` by `Michelle Wang`_).
- :bdg-danger:`Deprecation` Empty region signals resulting from applying `mask_img` in :class:`~maskers.NiftiMapsMasker` will no longer be kept in release 0.15. Meanwhile, use `keep_masked_maps` parameter when initializing the :class:`~maskers.NiftiMapsMasker` object to enable/disable this behavior. (:gh:`3732` by `Mohammad Torabi`_).
15 changes: 15 additions & 0 deletions nilearn/_utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,21 @@ def custom_function(vertices):
Passed to :func:`matplotlib.pyplot.imshow`.
"""

# keep_masked_maps
docdict["keep_masked_maps"] = """
keep_masked_maps : :obj:`bool`, optional
If True, masked atlas with invalid maps (maps that contain only
zeros after applying the mask) will be retained in the output, resulting
in corresponding time series containing zeros only. If False, the
invalid maps will be removed from the trimmed atlas, resulting in
no empty time series in the output.

.. deprecated:: 0.10.2.dev
The 'True' option for ``keep_masked_maps`` is deprecated.
The default value will change to 'False' in 0.13,
and the ``keep_masked_maps`` parameter will be removed in 0.15.
"""

##############################################################################

docdict_indented = {}
Expand Down
13 changes: 11 additions & 2 deletions nilearn/maskers/nifti_maps_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@ class _ExtractionFunctor:

func_name = 'nifti_maps_masker_extractor'

def __init__(self, _resampled_maps_img_, _resampled_mask_img_):
def __init__(self, _resampled_maps_img_, _resampled_mask_img_,
keep_masked_maps):
self._resampled_maps_img_ = _resampled_maps_img_
self._resampled_mask_img_ = _resampled_mask_img_
self.keep_masked_maps = keep_masked_maps

def __call__(self, imgs):
from ..regions import signal_extraction

return signal_extraction.img_to_signals_maps(
imgs, self._resampled_maps_img_,
mask_img=self._resampled_mask_img_)
mask_img=self._resampled_mask_img_,
keep_masked_maps=self.keep_masked_maps)


@_utils.fill_doc
Expand Down Expand Up @@ -75,6 +78,8 @@ class NiftiMapsMasker(BaseMasker, _utils.CacheMixin):
%(memory)s
%(memory_level)s
%(verbose0)s
%(keep_masked_maps)s

reports : :obj:`bool`, optional
If set to True, data is saved in order to produce a report.
Default=True.
Expand Down Expand Up @@ -122,6 +127,7 @@ def __init__(
t_r=None,
dtype=None,
resampling_target="data",
keep_masked_maps=True,
memory=Memory(location=None, verbose=0),
memory_level=0,
verbose=0,
Expand Down Expand Up @@ -180,6 +186,8 @@ def __init__(
"Set resampling_target to something else or provide a mask."
)

self.keep_masked_maps = keep_masked_maps

def generate_report(self, displayed_maps=10):
"""Generate an HTML report for the current ``NiftiMapsMasker`` object.

Expand Down Expand Up @@ -572,6 +580,7 @@ def transform_single_imgs(self, imgs, confounds=None, sample_mask=None):
imgs, _ExtractionFunctor(
self._resampled_maps_img_,
self._resampled_mask_img_,
self.keep_masked_maps,
),
# Pre-treatments
params,
Expand Down
35 changes: 33 additions & 2 deletions nilearn/regions/signal_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# Author: Philippe Gervais
# License: simplified BSD

import warnings

import numpy as np
from scipy import linalg, ndimage

Expand Down Expand Up @@ -376,7 +378,7 @@ def signals_to_img_labels(


@_utils.fill_doc
def img_to_signals_maps(imgs, maps_img, mask_img=None):
def img_to_signals_maps(imgs, maps_img, mask_img=None, keep_masked_maps=True):
"""Extract region signals from image.

This function is applicable to regions defined by maps.
Expand All @@ -396,6 +398,7 @@ def img_to_signals_maps(imgs, maps_img, mask_img=None):
Mask to apply to regions before extracting signals.
Every point outside the mask is considered
as background (i.e. outside of any region).
%(keep_masked_maps)s

Returns
-------
Expand Down Expand Up @@ -427,12 +430,40 @@ def img_to_signals_maps(imgs, maps_img, mask_img=None):
use_mask = _check_shape_and_affine_compatibility(imgs, mask_img)
if use_mask:
mask_img = _utils.check_niimg_3d(mask_img)
labels_before_mask = set(labels)
maps_data, maps_mask, labels = _trim_maps(
maps_data,
_safe_get_data(mask_img, ensure_finite=True),
keep_empty=True,
keep_empty=keep_masked_maps,
)
maps_mask = _utils.as_ndarray(maps_mask, dtype=bool)
if keep_masked_maps:
warnings.warn(
bthirion marked this conversation as resolved.
Show resolved Hide resolved
'Applying "mask_img" before '
"signal extraction may result in empty region signals in the "
"output. These are currently kept. "
"Starting from version 0.13, the default behavior will be "
"changed to remove them by setting "
'"keep_masked_maps=False". '
'"keep_masked_maps" parameter will be removed '
"in version 0.15.",
DeprecationWarning,
stacklevel=2,
)
else:
labels_after_mask = set(labels)
labels_diff = labels_before_mask.difference(labels_after_mask)
# Raising a warning if any map is removed due to the mask
if len(labels_diff) > 0:
warnings.warn(
"After applying mask to the maps image, "
"the following maps were "
f"removed: {labels_diff}. "
f"Out of {len(labels_before_mask)} maps, the "
"masked map image only contains "
f"{len(labels_after_mask)} maps.",
stacklevel=2,
)

data = _safe_get_data(imgs, ensure_finite=True)
region_signals = linalg.lstsq(maps_data[maps_mask, :], data[maps_mask, :])[
Expand Down
11 changes: 11 additions & 0 deletions nilearn/regions/tests/test_signal_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,17 @@ def test_signal_extraction_with_maps_and_labels(labeled_regions, fmri_img):
maps_img_r = signals_to_img_maps(maps_signals, maps_img, mask_img=mask_img)
assert maps_img_r.shape == SHAPE + (N_TIMEPOINTS,)

# apply img_to_signals_maps with a masking,
# containing only 3 regions, but
# not keeping the masked maps

maps_signals, maps_labels = img_to_signals_maps(
fmri_img, maps_img, mask_img=mask_img, keep_masked_maps=False
)
# only 3 regions must be kept, others must be removed
assert maps_signals.shape == (N_TIMEPOINTS, 3)
assert len(maps_labels) == 3


def test_signal_extraction_nans_in_regions_are_replaced_with_zeros():
shape = (4, 5, 6)
Expand Down