Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Let _check_embedded_nifti_masker work with surface masker #4120

Merged
merged 14 commits into from
Dec 5, 2023
5 changes: 0 additions & 5 deletions examples/08_experimental/plot_surface_image_and_maskers.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,6 @@ def adjust_screening_percentile(screening_percentile, *args, **kwargs):

param_validation.adjust_screening_percentile = adjust_screening_percentile

def check_embedded_nifti_masker(estimator, *args, **kwargs):
return estimator.mask

decoding.decoder._check_embedded_nifti_masker = check_embedded_nifti_masker


monkeypatch_masker_checks()

Expand Down
6 changes: 1 addition & 5 deletions nilearn/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,7 @@ def all_modules(modules_to_ignore=None, modules_to_consider=None):
"cannot be both specified."
)
if modules_to_ignore is None:
modules_to_ignore = {
"data",
"tests",
"externals",
}
modules_to_ignore = {"data", "tests", "externals", "conftest"}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

functions in conftest.py should be ignored by default

all_modules = []
root = str(Path(__file__).parent.parent)
with warnings.catch_warnings():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@

import numpy as np

from .._utils.cache_mixin import _check_memory
from .._utils.class_inspect import get_params
from .multi_nifti_masker import MultiNiftiMasker
from .nifti_masker import NiftiMasker
from nilearn.experimental.surface import SurfaceMasker
from nilearn.maskers import MultiNiftiMasker, NiftiMasker

from .cache_mixin import _check_memory
from .class_inspect import get_params

def _check_embedded_nifti_masker(estimator, multi_subject=True):

def check_embedded_masker(estimator, masker_type="multi_nii"):
"""Create a masker from instance parameters.

Base function for using a masker within a BaseEstimator class
Expand All @@ -29,20 +30,26 @@ def _check_embedded_nifti_masker(estimator, multi_subject=True):
instance : object, instance of BaseEstimator
The object that gives us the values of the parameters

multi_subject : boolean, default=True
Indicates whether to return a MultiNiftiMasker or a NiftiMasker
masker_type : {"multi_nii", "nii", "surface"}, default="mutli_nii"
Indicates whether to return a MultiNiftiMasker, NiftiMasker, or a
SurfaceMasker

Returns
-------
masker : MultiNiftiMasker or NiftiMasker
masker : MultiNiftiMasker, NiftiMasker, or SurfaceMasker
New masker

"""
masker_type = MultiNiftiMasker if multi_subject else NiftiMasker
if masker_type == "surface":
masker_type = SurfaceMasker
elif masker_type == "multi_nii":
masker_type = MultiNiftiMasker
else:
masker_type = NiftiMasker
estimator_params = get_params(masker_type, estimator)
mask = getattr(estimator, "mask", None)

if isinstance(mask, (NiftiMasker, MultiNiftiMasker)):
if isinstance(mask, (NiftiMasker, MultiNiftiMasker, SurfaceMasker)):
# Creating (Multi)NiftiMasker from provided masker
masker_params = get_params(masker_type, mask)
new_masker_params = masker_params
Expand All @@ -52,7 +59,7 @@ def _check_embedded_nifti_masker(estimator, multi_subject=True):
new_masker_params = estimator_params
new_masker_params["mask_img"] = mask
# Forwarding system parameters of instance to new masker in all case
if multi_subject and hasattr(estimator, "n_jobs"):
if masker_type == "multi_nii" and hasattr(estimator, "n_jobs"):
# For MultiNiftiMasker only
new_masker_params["n_jobs"] = estimator.n_jobs

Expand Down Expand Up @@ -86,7 +93,6 @@ def _check_embedded_nifti_masker(estimator, multi_subject=True):
warning_msg.substitute(attribute="verbose", default_value="0")
)
new_masker_params["verbose"] = 0

# Raising warning if masker override parameters
conflict_string = ""
for param_key in sorted(estimator_params):
Expand Down
2 changes: 1 addition & 1 deletion nilearn/_utils/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_number_public_functions():
If this is intentional, then the number should be updated in the test.
Otherwise it means that the public API of nilearn has changed by mistake.
"""
assert len({_[0] for _ in all_functions()}) == 227
assert len({_[0] for _ in all_functions()}) == 205
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consequence of ignoring functions in conftest.py



def test_number_public_classes():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from joblib import Memory
from sklearn.base import BaseEstimator

from nilearn._utils.masker_validation import check_embedded_masker
from nilearn.experimental.surface import SurfaceMasker
from nilearn.maskers import MultiNiftiMasker, NiftiMasker
from nilearn.maskers._masker_validation import _check_embedded_nifti_masker


class OwningClass(BaseEstimator):
Expand Down Expand Up @@ -53,10 +54,10 @@ def __init__(self, **kwargs):
setattr(self, k, v)

def fit(self, *args, **kwargs):
self.masker = _check_embedded_nifti_masker(self)
self.masker = check_embedded_masker(self)


def test_check_embedded_nifti_masker_defaults():
def test_check_embedded_masker_defaults():
dummy = DummyEstimator(memory=None, memory_level=1)
with pytest.warns(
Warning, match="Provided estimator has no verbose attribute set."
Expand All @@ -75,19 +76,18 @@ def test_check_embedded_nifti_masker_defaults():
assert dummy.masker.verbose == 1


def test_check_embedded_nifti_masker():
def test_check_embedded_masker():
owner = OwningClass()
masker = _check_embedded_nifti_masker(owner)
masker = check_embedded_masker(owner)
assert type(masker) is MultiNiftiMasker

for mask, multi_subject in (
(MultiNiftiMasker(), True),
(NiftiMasker(), False),
for mask, masker_type in (
(MultiNiftiMasker(), "multi_nii"),
(NiftiMasker(), "nii"),
(SurfaceMasker(), "surface"),
):
owner = OwningClass(mask=mask)
masker = _check_embedded_nifti_masker(
owner, multi_subject=multi_subject
)
masker = check_embedded_masker(owner, masker_type=masker_type)
assert isinstance(masker, type(mask))
for param_key in masker.get_params():
if param_key not in [
Expand All @@ -105,7 +105,7 @@ def test_check_embedded_nifti_masker():
affine = np.eye(4)
mask = nibabel.Nifti1Image(np.ones(shape[:3], dtype=np.int8), affine)
owner = OwningClass(mask=mask)
masker = _check_embedded_nifti_masker(owner)
masker = check_embedded_masker(owner)
assert masker.mask_img is mask

# Check attribute forwarding
Expand All @@ -115,11 +115,11 @@ def test_check_embedded_nifti_masker():
mask = MultiNiftiMasker()
mask.fit([[imgs]])
owner = OwningClass(mask=mask)
masker = _check_embedded_nifti_masker(owner)
masker = check_embedded_masker(owner)
assert masker.mask_img is mask.mask_img_

# Check conflict warning
mask = NiftiMasker(mask_strategy="epi")
owner = OwningClass(mask=mask)
with pytest.warns(UserWarning):
_check_embedded_nifti_masker(owner)
check_embedded_masker(owner)
8 changes: 8 additions & 0 deletions nilearn/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
from nilearn.datasets.tests._testing import request_mocker # noqa: F401
from nilearn.datasets.tests._testing import temp_nilearn_data_dir # noqa: F401

# TODO This import needs to be removed once the experimental surface API and
# its pytest fixtures are integrated into the stable API
from nilearn.experimental.surface.tests.conftest import ( # noqa: F401
make_mini_img,
mini_img,
mini_mesh,
)

collect_ignore = ["datasets/data/convert_templates.py"]
collect_ignore_glob = ["reporting/_visual_testing/*"]

Expand Down
20 changes: 12 additions & 8 deletions nilearn/decoding/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@

from nilearn._utils import CacheMixin, fill_doc
from nilearn._utils.cache_mixin import _check_memory
from nilearn._utils.masker_validation import check_embedded_masker
from nilearn._utils.param_validation import check_feature_screening
from nilearn.maskers._masker_validation import _check_embedded_nifti_masker
from nilearn.experimental.surface import SurfaceMasker
from nilearn.regions.rena_clustering import ReNA

SUPPORTED_ESTIMATORS = dict(
Expand Down Expand Up @@ -462,13 +463,14 @@ class _BaseDecoder(LinearRegression, CacheMixin):
For regression, choose among:
%(regressor_options)s

mask: filename, Nifti1Image, NiftiMasker, or MultiNiftiMasker, \
default=None
mask: filename, Nifti1Image, NiftiMasker, MultiNiftiMasker, or\
SurfaceMasker, default=None
Mask to be used on data. If an instance of masker is passed,
then its mask and parameters will be used. If no mask is given, mask
will be computed automatically from provided images by an inbuilt
masker with default parameters. Refer to NiftiMasker or
MultiNiftiMasker to check for default parameters.
MultiNiftiMasker to check for default parameters. For use with
SurfaceImage data, a SurfaceMasker instance must be passed.

cv: cross-validation generator or int, default=10
A cross-validation generator.
Expand Down Expand Up @@ -625,10 +627,10 @@ def fit(self, X, y, groups=None):

Attributes
----------
masker_ : instance of NiftiMasker or MultiNiftiMasker
masker_ : instance of NiftiMasker, MultiNiftiMasker, or SurfaceMasker
The NiftiMasker used to mask the data.

mask_img_ : Nifti1Image
mask_img_ : Nifti1Image or SurfaceImage
Mask computed by the masker object.

classes_ : numpy.ndarray
Expand Down Expand Up @@ -920,8 +922,10 @@ class would be predicted.
return scores

def _apply_mask(self, X):
# Nifti masking
self.masker_ = _check_embedded_nifti_masker(self, multi_subject=False)
masker_type = "nifti"
if isinstance(self.mask, SurfaceMasker):
masker_type = "surface"
self.masker_ = check_embedded_masker(self, masker_type=masker_type)
X = self.masker_.fit_transform(X)
self.mask_img_ = self.masker_.mask_img_

Expand Down
9 changes: 6 additions & 3 deletions nilearn/decoding/space_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
from sklearn.utils import check_array, check_X_y
from sklearn.utils.extmath import safe_sparse_dot

from nilearn._utils.masker_validation import check_embedded_masker
from nilearn.experimental.surface import SurfaceMasker
from nilearn.image import get_data
from nilearn.maskers._masker_validation import _check_embedded_nifti_masker
from nilearn.masking import _unmask_from_to_3d_array

from .._utils import fill_doc
Expand Down Expand Up @@ -850,8 +851,10 @@
if self.verbose:
tic = time.time()

# nifti masking
self.masker_ = _check_embedded_nifti_masker(self, multi_subject=False)
masker_type = "nifti"
if isinstance(self.mask, SurfaceMasker):
masker_type = "surface"

Check warning on line 856 in nilearn/decoding/space_net.py

View check run for this annotation

Codecov / codecov/patch

nilearn/decoding/space_net.py#L856

Added line #L856 was not covered by tests
self.masker_ = check_embedded_masker(self, masker_type=masker_type)
X = self.masker_.fit_transform(X)

X, y = check_X_y(
Expand Down
11 changes: 11 additions & 0 deletions nilearn/decoding/tests/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
_wrap_param_grid,
)
from nilearn.decoding.tests.test_same_api import to_niimgs
from nilearn.experimental.surface import SurfaceMasker
from nilearn.maskers import NiftiMasker

N_SAMPLES = 100
Expand Down Expand Up @@ -968,6 +969,16 @@ def test_decoder_multiclass_classification_apply_mask_attributes(affine_eye):
assert model.masker_.smoothing_fwhm == smoothing_fwhm


def test_decoder_apply_mask_surface(mini_img):
"""Test whether _apply_mask works for surface image."""
X = mini_img
model = Decoder(mask=SurfaceMasker())
X_masked = model._apply_mask(X)

assert X_masked.shape == X.shape
assert type(model.mask_img_).__name__ == "SurfaceImage"


def test_decoder_multiclass_error_incorrect_cv(multiclass_data):
"""Check whether ValueError is raised when cv is not set correctly."""
X, y, _ = multiclass_data
Expand Down
4 changes: 2 additions & 2 deletions nilearn/decomposition/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from sklearn.utils.extmath import randomized_svd, svd_flip

import nilearn
from nilearn._utils.masker_validation import check_embedded_masker
from nilearn.maskers import NiftiMapsMasker
from nilearn.maskers._masker_validation import _check_embedded_nifti_masker

from .._utils import fill_doc
from .._utils.cache_mixin import CacheMixin, cache
Expand Down Expand Up @@ -420,7 +420,7 @@ def fit(self, imgs, y=None, confounds=None):
"Need one or more Niimg-like objects as input, "
"an empty list was given."
)
self.masker_ = _check_embedded_nifti_masker(self)
self.masker_ = check_embedded_masker(self)

# Avoid warning with imgs != None
# if masker_ has been provided a mask_img
Expand Down