Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding keep_masked_labels parameter to Labels_masker #3722

Merged
merged 17 commits into from
Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ Enhancements
Changes
-------

:bdg-danger:`Deprecation` Empty region signals resulting from applying `mask_img` in :class:`~maskers.NiftiLabelsMasker` will no longer be kept in release 0.15. Meanwhile, use `keep_masked_labels` parameter when initializing the :class:`~maskers.NiftiLabelsMasker` object to enable/disable this behavior. (:gh:`3722` by `Mohammad Torabi`_).
25 changes: 23 additions & 2 deletions nilearn/maskers/nifti_labels_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ class _ExtractionFunctor:
func_name = 'nifti_labels_masker_extractor'

def __init__(self, _resampled_labels_img_, background_label, strategy,
mask_img):
keep_masked_labels, mask_img):
self._resampled_labels_img_ = _resampled_labels_img_
self.background_label = background_label
self.strategy = strategy
self.keep_masked_labels = keep_masked_labels
self.mask_img = mask_img

def __call__(self, imgs):
Expand All @@ -25,7 +26,7 @@ def __call__(self, imgs):
return signal_extraction.img_to_signals_labels(
imgs, self._resampled_labels_img_,
background_label=self.background_label, strategy=self.strategy,
mask_img=self.mask_img)
keep_masked_labels=self.keep_masked_labels, mask_img=self.mask_img)


@_utils.fill_doc
Expand Down Expand Up @@ -89,6 +90,22 @@ class NiftiLabelsMasker(BaseMasker, _utils.CacheMixin):
Must be one of: sum, mean, median, minimum, maximum, variance,
standard_deviation. Default='mean'.

keep_masked_labels : :obj:`bool`, optional
When a mask is supplied through the "mask_img" parameter, some
labels in the atlas may not have any brain coverage within the
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can still improve the docstring a little bit. maybe instead of "may not have any brain coverage within the masked region" we could say "some atlas regions may lie entirely outside of the brain mask"?

masked region, resulting in empty time series for those labels.
If True, the masked atlas with these empty labels will be retained
in the output, resulting in corresponding time series containing
zeros only. If False, the empty labels will be removed from the
output, ensuring no empty time series are present.
Default=True.

.. deprecated:: 0.9.2

The 'True' option for ``keep_masked_labels`` is deprecated.
The default value will change to 'False' in 0.13,
and the ``keep_masked_labels`` parameter will be removed in 0.15.

reports : :obj:`bool`, optional
If set to True, data is saved in order to produce a report.
Default=True.
Expand Down Expand Up @@ -137,6 +154,7 @@ def __init__(
memory_level=1,
verbose=0,
strategy='mean',
keep_masked_labels=True,
reports=True,
**kwargs,
):
Expand Down Expand Up @@ -199,6 +217,8 @@ def __init__(
f"parameter: {resampling_target}"
)

self.keep_masked_labels = keep_masked_labels

def generate_report(self):
"""Generate a report."""
from nilearn.reporting.html_report import generate_report
Expand Down Expand Up @@ -591,6 +611,7 @@ def transform_single_imgs(self, imgs, confounds=None, sample_mask=None):
self._resampled_labels_img_,
self.background_label,
self.strategy,
self.keep_masked_labels,
self._resampled_mask_img,
),
# Pre-processing
Expand Down
97 changes: 88 additions & 9 deletions nilearn/regions/signal_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
# Author: Philippe Gervais
# License: simplified BSD
import warnings

import numpy as np
from scipy import linalg, ndimage
Expand Down Expand Up @@ -102,7 +103,8 @@ def _check_shape_and_affine_compatibility(img1, img2=None, dim=None):


def _get_labels_data(
target_img, labels_img, mask_img=None, background_label=0, dim=None
target_img, labels_img, mask_img=None, background_label=0, dim=None,
keep_masked_labels=True
):
"""Get the label data.

Expand Down Expand Up @@ -133,6 +135,22 @@ def _get_labels_data(
dim : :obj:`int`, optional
Integer slices mask for a specific dimension.

keep_masked_labels : :obj:`bool`, optional
When a mask is supplied through the "mask_img" parameter, some
labels in the atlas may not have any brain coverage within the
masked region, resulting in empty time series for those labels.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you may want to take advantage of the _utils.fill_doc mechanism to avoid repeating this docstring

If True, the masked atlas with these empty labels will be retained
in the output, resulting in corresponding time series containing
zeros only. If False, the empty labels will be removed from the
output, ensuring no empty time series are present.
Default=True.

.. deprecated:: 0.9.2

The 'True' option for ``keep_masked_labels`` is deprecated.
The default value will change to 'False' in 0.13,
and the ``keep_masked_labels`` parameter will be removed in 0.15.

Returns
-------
labels : :obj:`list` or :obj:`tuple`
Expand All @@ -154,17 +172,59 @@ def _get_labels_data(

labels_data = _safe_get_data(labels_img, ensure_finite=True)

labels = list(np.unique(labels_data))
if background_label in labels:
labels.remove(background_label)
if keep_masked_labels:
labels = list(np.unique(labels_data))
ymzayek marked this conversation as resolved.
Show resolved Hide resolved
warnings.warn(
'Starting in version 0.15, the behavior of "NiftiLabelsMasker" '
'will change when a mask is supplied through the "mask_img" '
'parameter. Applying "mask_img" before '
'signal extraction may result in empty region signals in the '
'output. These empty region signals used to be kept. '
'In the new behavior, they will be removed from the output.'
'\n'
'To explicitly enable/disable the retention of empty '
'region signals, set the parameter "keep_masked_labels" '
'to True/False when '
'initializing the "NiftiLabelsMasker" object. '
'Starting from version 0.13, the default behavior will be '
'changed to "keep_masked_labels=False". '
'"keep_masked_labels" parameter will be removed '
'in version 0.15.',
DeprecationWarning,
stacklevel=3
)

# Consider only data within the mask
use_mask = _check_shape_and_affine_compatibility(target_img, mask_img, dim)
if use_mask:
mask_img = _utils.check_niimg_3d(mask_img)
mask_data = _safe_get_data(mask_img, ensure_finite=True)
labels_data = labels_data.copy()
labels_before_mask = set(np.unique(labels_data))
# Applying mask on labels_data
labels_data[np.logical_not(mask_data)] = background_label
labels_after_mask = set(np.unique(labels_data))
labels_diff = labels_before_mask.difference(
labels_after_mask
)
# Raising a warning if any label is removed due to the mask
if len(labels_diff) > 0 and (not keep_masked_labels):
warnings.warn(
"After applying mask to the labels image, "
"the following labels were "
f"removed: {labels_diff}. "
f"Out of {len(labels_before_mask)} labels, the "
"masked labels image only contains "
f"{len(labels_after_mask)} labels "
"(including background).",
stacklevel=3
)

if not keep_masked_labels:
labels = list(np.unique(labels_data))

if background_label in labels:
labels.remove(background_label)

return labels, labels_data

Expand Down Expand Up @@ -206,6 +266,7 @@ def img_to_signals_labels(
background_label=0,
order="F",
strategy="mean",
keep_masked_labels=True
):
"""Extract region signals from image.

Expand Down Expand Up @@ -241,6 +302,22 @@ def img_to_signals_labels(
Must be one of: sum, mean, median, minimum, maximum, variance,
standard_deviation. Default="mean".

keep_masked_labels : :obj:`bool`, optional
When a mask is supplied through the "mask_img" parameter, some
labels in the atlas may not have any brain coverage within the
masked region, resulting in empty time series for those labels.
If True, the masked atlas with these empty labels will be retained
in the output, resulting in corresponding time series containing
zeros only. If False, the empty labels will be removed from the
output, ensuring no empty time series are present.
Default=True.

.. deprecated:: 0.9.2

The 'True' option for ``keep_masked_labels`` is deprecated.
The default value will change to 'False' in 0.13,
and the ``keep_masked_labels`` parameter will be removed in 0.15.

Returns
-------
signals : :class:`numpy.ndarray`
Expand Down Expand Up @@ -269,7 +346,8 @@ def img_to_signals_labels(
# (load one image at a time).
imgs = _utils.check_niimg_4d(imgs)
labels, labels_data = _get_labels_data(
imgs, labels_img, mask_img, background_label
imgs, labels_img, mask_img, background_label,
keep_masked_labels=keep_masked_labels
)

data = _safe_get_data(imgs, ensure_finite=True)
Expand All @@ -284,10 +362,11 @@ def img_to_signals_labels(
reduction_function(img, labels=labels_data, index=labels)
)
# Set to zero signals for missing labels. Workaround for Scipy behaviour
missing_labels = set(labels) - set(np.unique(labels_data))
labels_index = {l: n for n, l in enumerate(labels)}
for this_label in missing_labels:
signals[:, labels_index[this_label]] = 0
if keep_masked_labels:
missing_labels = set(labels) - set(np.unique(labels_data))
labels_index = {l: n for n, l in enumerate(labels)}
for this_label in missing_labels:
signals[:, labels_index[this_label]] = 0
return signals, labels


Expand Down
12 changes: 12 additions & 0 deletions nilearn/regions/tests/test_signal_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,18 @@ def test_signal_extraction_with_maps_and_labels(labeled_regions, fmri_img):
maps_img_r = signals_to_img_maps(maps_signals, maps_img, mask_img=mask_img)
assert maps_img_r.shape == SHAPE + (N_TIMEPOINTS,)

# apply img_to_signals_labels with a masking,
# containing only 3 regions, but
# not keeping the masked labels

labels_signals, labels_labels = img_to_signals_labels(
imgs=fmri_img, labels_img=labeled_regions, mask_img=mask_img,
keep_masked_labels=False
)
# only 3 regions must be kept, others must be removed
assert labels_signals.shape == (N_TIMEPOINTS, 3)
assert len(labels_labels) == 3


def test_signal_extraction_nans_in_regions_are_replaced_with_zeros():
shape = (4, 5, 6)
Expand Down