Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Let _check_embedded_nifti_masker work with surface masker #4120

Merged
merged 14 commits into from
Dec 5, 2023
5 changes: 0 additions & 5 deletions examples/08_experimental/plot_surface_image_and_maskers.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,6 @@ def adjust_screening_percentile(screening_percentile, *args, **kwargs):

param_validation.adjust_screening_percentile = adjust_screening_percentile

def check_embedded_nifti_masker(estimator, *args, **kwargs):
return estimator.mask

decoding.decoder._check_embedded_nifti_masker = check_embedded_nifti_masker


monkeypatch_masker_checks()

Expand Down
9 changes: 7 additions & 2 deletions nilearn/decoding/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from nilearn._utils import CacheMixin, fill_doc
from nilearn._utils.cache_mixin import _check_memory
from nilearn._utils.param_validation import check_feature_screening
from nilearn.experimental.surface import SurfaceMasker
from nilearn.maskers._masker_validation import _check_embedded_nifti_masker
from nilearn.regions.rena_clustering import ReNA

Expand Down Expand Up @@ -920,8 +921,12 @@
return scores

def _apply_mask(self, X):
# Nifti masking
self.masker_ = _check_embedded_nifti_masker(self, multi_subject=False)
masker_type = "nifti"
if isinstance(self.mask, SurfaceMasker):
masker_type = "surface"

Check warning on line 926 in nilearn/decoding/decoder.py

View check run for this annotation

Codecov / codecov/patch

nilearn/decoding/decoder.py#L926

Added line #L926 was not covered by tests
self.masker_ = _check_embedded_nifti_masker(
self, masker_type=masker_type
)
X = self.masker_.fit_transform(X)
self.mask_img_ = self.masker_.mask_img_

Expand Down
9 changes: 7 additions & 2 deletions nilearn/decoding/space_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from sklearn.utils import check_array, check_X_y
from sklearn.utils.extmath import safe_sparse_dot

from nilearn.experimental.surface import SurfaceMasker
from nilearn.image import get_data
from nilearn.maskers._masker_validation import _check_embedded_nifti_masker
from nilearn.masking import _unmask_from_to_3d_array
Expand Down Expand Up @@ -850,8 +851,12 @@
if self.verbose:
tic = time.time()

# nifti masking
self.masker_ = _check_embedded_nifti_masker(self, multi_subject=False)
masker_type = "nifti"
if isinstance(self.mask, SurfaceMasker):
masker_type = "surface"

Check warning on line 856 in nilearn/decoding/space_net.py

View check run for this annotation

Codecov / codecov/patch

nilearn/decoding/space_net.py#L856

Added line #L856 was not covered by tests
self.masker_ = _check_embedded_nifti_masker(
self, masker_type=masker_type
)
X = self.masker_.fit_transform(X)

X, y = check_X_y(
Expand Down
16 changes: 11 additions & 5 deletions nilearn/maskers/_masker_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@

import numpy as np

from nilearn.experimental.surface import SurfaceMasker

from .._utils.cache_mixin import _check_memory
from .._utils.class_inspect import get_params
from .multi_nifti_masker import MultiNiftiMasker
from .nifti_masker import NiftiMasker


def _check_embedded_nifti_masker(estimator, multi_subject=True):
def _check_embedded_nifti_masker(estimator, masker_type="multi_nii"):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we just remove this keyword argument completely and just set masker_type based on isinstance checks?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with that. ATM the masker_type could be "multi_nii", 'surface', 'nifti', but this is fare from clear.

"""Create a masker from instance parameters.

Base function for using a masker within a BaseEstimator class
Expand Down Expand Up @@ -38,11 +40,16 @@ def _check_embedded_nifti_masker(estimator, multi_subject=True):
New masker

"""
masker_type = MultiNiftiMasker if multi_subject else NiftiMasker
if masker_type == "surface":
masker_type = SurfaceMasker
elif masker_type == "multi_nii":
masker_type = MultiNiftiMasker
else:
masker_type = NiftiMasker
estimator_params = get_params(masker_type, estimator)
mask = getattr(estimator, "mask", None)

if isinstance(mask, (NiftiMasker, MultiNiftiMasker)):
if isinstance(mask, (NiftiMasker, MultiNiftiMasker, SurfaceMasker)):
# Creating (Multi)NiftiMasker from provided masker
masker_params = get_params(masker_type, mask)
new_masker_params = masker_params
Expand All @@ -52,7 +59,7 @@ def _check_embedded_nifti_masker(estimator, multi_subject=True):
new_masker_params = estimator_params
new_masker_params["mask_img"] = mask
# Forwarding system parameters of instance to new masker in all case
if multi_subject and hasattr(estimator, "n_jobs"):
if masker_type == "multi_nii" and hasattr(estimator, "n_jobs"):
# For MultiNiftiMasker only
new_masker_params["n_jobs"] = estimator.n_jobs

Expand Down Expand Up @@ -86,7 +93,6 @@ def _check_embedded_nifti_masker(estimator, multi_subject=True):
warning_msg.substitute(attribute="verbose", default_value="0")
)
new_masker_params["verbose"] = 0

# Raising warning if masker override parameters
conflict_string = ""
for param_key in sorted(estimator_params):
Expand Down
12 changes: 6 additions & 6 deletions nilearn/maskers/tests/test_masker_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from joblib import Memory
from sklearn.base import BaseEstimator

from nilearn.experimental.surface import SurfaceMasker
from nilearn.maskers import MultiNiftiMasker, NiftiMasker
from nilearn.maskers._masker_validation import _check_embedded_nifti_masker

Expand Down Expand Up @@ -80,14 +81,13 @@ def test_check_embedded_nifti_masker():
masker = _check_embedded_nifti_masker(owner)
assert type(masker) is MultiNiftiMasker

for mask, multi_subject in (
(MultiNiftiMasker(), True),
(NiftiMasker(), False),
for mask, masker_type in (
(MultiNiftiMasker(), "multi_nii"),
(NiftiMasker(), "nii"),
(SurfaceMasker(), "surface"),
):
owner = OwningClass(mask=mask)
masker = _check_embedded_nifti_masker(
owner, multi_subject=multi_subject
)
masker = _check_embedded_nifti_masker(owner, masker_type=masker_type)
assert isinstance(masker, type(mask))
for param_key in masker.get_params():
if param_key not in [
Expand Down