Skip to content

Commit

Permalink
adding keep_masked_labels param
Browse files Browse the repository at this point in the history
  • Loading branch information
mtorabi59 committed Apr 26, 2023
1 parent 1b79941 commit 1819e24
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 11 deletions.
15 changes: 13 additions & 2 deletions nilearn/maskers/nifti_labels_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ class _ExtractionFunctor:
func_name = 'nifti_labels_masker_extractor'

def __init__(self, _resampled_labels_img_, background_label, strategy,
mask_img):
keep_masked_labels, mask_img):
self._resampled_labels_img_ = _resampled_labels_img_
self.background_label = background_label
self.strategy = strategy
self.keep_masked_labels = keep_masked_labels
self.mask_img = mask_img

def __call__(self, imgs):
Expand All @@ -25,7 +26,7 @@ def __call__(self, imgs):
return signal_extraction.img_to_signals_labels(
imgs, self._resampled_labels_img_,
background_label=self.background_label, strategy=self.strategy,
mask_img=self.mask_img)
keep_masked_labels=self.keep_masked_labels, mask_img=self.mask_img)


@_utils.fill_doc
Expand Down Expand Up @@ -89,6 +90,12 @@ class NiftiLabelsMasker(BaseMasker, _utils.CacheMixin):
Must be one of: sum, mean, median, minimum, maximum, variance,
standard_deviation. Default='mean'.
keep_masked_labels : :obj:`bool`, optional
If False, the labels in labels_img that are masked by mask_img
will be removed from the output. If True, they are kept, meaning
that they will be filled with zero in signals in the output.
Default=True.
reports : :obj:`bool`, optional
If set to True, data is saved in order to produce a report.
Default=True.
Expand Down Expand Up @@ -137,6 +144,7 @@ def __init__(
memory_level=1,
verbose=0,
strategy='mean',
keep_masked_labels=True,
reports=True,
**kwargs,
):
Expand Down Expand Up @@ -199,6 +207,8 @@ def __init__(
f"parameter: {resampling_target}"
)

self.keep_masked_labels = keep_masked_labels

def generate_report(self):
"""Generate a report."""
from nilearn.reporting.html_report import generate_report
Expand Down Expand Up @@ -591,6 +601,7 @@ def transform_single_imgs(self, imgs, confounds=None, sample_mask=None):
self._resampled_labels_img_,
self.background_label,
self.strategy,
self.keep_masked_labels,
self._resampled_mask_img,
),
# Pre-processing
Expand Down
38 changes: 29 additions & 9 deletions nilearn/regions/signal_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def _check_shape_and_affine_compatibility(img1, img2=None, dim=None):


def _get_labels_data(
target_img, labels_img, mask_img=None, background_label=0, dim=None
target_img, labels_img, mask_img=None, background_label=0, dim=None,
keep_masked_labels=True
):
"""Get the label data.
Expand Down Expand Up @@ -133,6 +134,11 @@ def _get_labels_data(
dim : :obj:`int`, optional
Integer slices mask for a specific dimension.
keep_masked_labels : :obj:`bool`, optional
If False, the labels in labels_img that are masked by mask_img
will be removed from output labels. If True, they are kept.
Default=True.
Returns
-------
labels : :obj:`list` or :obj:`tuple`
Expand All @@ -154,9 +160,8 @@ def _get_labels_data(

labels_data = _safe_get_data(labels_img, ensure_finite=True)

labels = list(np.unique(labels_data))
if background_label in labels:
labels.remove(background_label)
if keep_masked_labels:
labels = list(np.unique(labels_data))

# Consider only data within the mask
use_mask = _check_shape_and_affine_compatibility(target_img, mask_img, dim)
Expand All @@ -166,6 +171,12 @@ def _get_labels_data(
labels_data = labels_data.copy()
labels_data[np.logical_not(mask_data)] = background_label

if not keep_masked_labels:
labels = list(np.unique(labels_data))

if background_label in labels:
labels.remove(background_label)

return labels, labels_data


Expand Down Expand Up @@ -206,6 +217,7 @@ def img_to_signals_labels(
background_label=0,
order="F",
strategy="mean",
keep_masked_labels=True
):
"""Extract region signals from image.
Expand Down Expand Up @@ -241,6 +253,12 @@ def img_to_signals_labels(
Must be one of: sum, mean, median, minimum, maximum, variance,
standard_deviation. Default="mean".
keep_masked_labels : :obj:`bool`, optional
If False, the labels in labels_img that are masked by mask_img
will be removed from the output. If True, they are kept, meaning
that they will be filled with zero in signals in the output.
Default=True.
Returns
-------
signals : :class:`numpy.ndarray`
Expand Down Expand Up @@ -269,7 +287,8 @@ def img_to_signals_labels(
# (load one image at a time).
imgs = _utils.check_niimg_4d(imgs)
labels, labels_data = _get_labels_data(
imgs, labels_img, mask_img, background_label
imgs, labels_img, mask_img, background_label,
keep_masked_labels=keep_masked_labels
)

data = _safe_get_data(imgs, ensure_finite=True)
Expand All @@ -284,10 +303,11 @@ def img_to_signals_labels(
reduction_function(img, labels=labels_data, index=labels)
)
# Set to zero signals for missing labels. Workaround for Scipy behaviour
missing_labels = set(labels) - set(np.unique(labels_data))
labels_index = {l: n for n, l in enumerate(labels)}
for this_label in missing_labels:
signals[:, labels_index[this_label]] = 0
if keep_masked_labels:
missing_labels = set(labels) - set(np.unique(labels_data))
labels_index = {l: n for n, l in enumerate(labels)}
for this_label in missing_labels:
signals[:, labels_index[this_label]] = 0
return signals, labels


Expand Down
12 changes: 12 additions & 0 deletions nilearn/regions/tests/test_signal_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,18 @@ def test_signal_extraction_with_maps_and_labels(labeled_regions, fmri_img):
maps_img_r = signals_to_img_maps(maps_signals, maps_img, mask_img=mask_img)
assert maps_img_r.shape == SHAPE + (N_TIMEPOINTS,)

# apply img_to_signals_labels with a masking,
# containing only 3 regions, but
# not keeping the masked labels

labels_signals, labels_labels = img_to_signals_labels(
imgs=fmri_img, labels_img=labeled_regions, mask_img=mask_img,
keep_masked_labels=False
)
# only 3 regions must be kept, others must be removed
assert labels_signals.shape == (N_TIMEPOINTS, 3)
assert len(labels_labels) == 3


def test_signal_extraction_nans_in_regions_are_replaced_with_zeros():
shape = (4, 5, 6)
Expand Down

0 comments on commit 1819e24

Please sign in to comment.