Skip to content

Commit

Permalink
[EHN] Implement signal.clean to SurfMasker (#4117)
Browse files Browse the repository at this point in the history
This is an inital attempt to implement signal.clean to SurfMasker
Concerns:
  clean_kwarg  - should filter out confounds and sample mask to prevent
user adding those parameters upon creating the masker
  all the cleaning related parameters - what do we want to choose to
expose?
  • Loading branch information
htwangtw committed Nov 24, 2023
1 parent b607726 commit d6d822d
Showing 1 changed file with 78 additions and 5 deletions.
83 changes: 78 additions & 5 deletions nilearn/experimental/surface/_maskers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
from typing import Any

import numpy as np
import pandas as pd
from joblib import Memory
from sklearn.base import BaseEstimator, TransformerMixin

from nilearn import signal
from nilearn._utils.cache_mixin import CacheMixin, cache
from nilearn._utils.class_inspect import get_params
from nilearn.experimental.surface._surface_image import PolyMesh, SurfaceImage


Expand All @@ -25,16 +31,43 @@ def check_same_n_vertices(mesh_1: PolyMesh, mesh_2: PolyMesh) -> None:
)


class SurfaceMasker:
class SurfaceMasker(BaseEstimator, TransformerMixin, CacheMixin):
"""Extract data from a SurfaceImage."""

mask_img: SurfaceImage | None

mask_img_: SurfaceImage | None
output_dimension_: int | None

def __init__(self, mask_img=None):
def __init__(
self,
mask_img=None,
standardize=False,
standardize_confounds=True,
detrend=False,
high_variance_confounds=False,
low_pass=None,
high_pass=None,
t_r=None,
memory_level=1,
memory=Memory(location=None),
**kwargs,
):
self.mask_img = mask_img
self.standardize = standardize
self.standardize_confounds = standardize_confounds
self.high_variance_confounds = high_variance_confounds
self.detrend = detrend
self.low_pass = low_pass
self.high_pass = high_pass
self.t_r = t_r

self.memory = memory
self.memory_level = memory_level
self._shelving = False
self.clean_kwargs = {
k[7:]: v for k, v in kwargs.items() if k.startswith("clean__")
}

def _fit_mask_img(self, img: SurfaceImage | None) -> None:
if self.mask_img is not None:
Expand Down Expand Up @@ -92,7 +125,12 @@ def _check_fitted(self):
"before calling transform."
)

def transform(self, img: SurfaceImage) -> np.ndarray:
def transform(
self,
img: SurfaceImage,
confounds: pd.DataFrame | None = None,
sample_mask: np.ndarray | None = None,
) -> np.ndarray:
"""Extract signals from fitted surface object.
Parameters
Expand All @@ -106,6 +144,15 @@ def transform(self, img: SurfaceImage) -> np.ndarray:
Signal for each element.
shape: (img data shape, total number of vertices)
"""
parameters = get_params(
self.__class__,
self,
ignore=[
"mask_img",
],
)
parameters["clean_kwargs"] = self.clean_kwargs

self._check_fitted()
assert self.mask_img_ is not None
assert self.output_dimension_ is not None
Expand All @@ -115,9 +162,35 @@ def transform(self, img: SurfaceImage) -> np.ndarray:
mask = self.mask_img_.data[part_name]
assert isinstance(mask, np.ndarray)
output[..., start:stop] = img.data[part_name][..., mask]

# signal cleaning here
output = cache(
signal.clean,
memory=self.memory,
func_memory_level=2,
memory_level=self.memory_level,
shelve=self._shelving,
)(
output,
detrend=parameters["detrend"],
standardize=parameters["standardize"],
standardize_confounds=parameters["standardize_confounds"],
t_r=parameters["t_r"],
low_pass=parameters["low_pass"],
high_pass=parameters["high_pass"],
confounds=confounds,
sample_mask=sample_mask,
**parameters["clean_kwargs"],
)
return output

def fit_transform(self, img: SurfaceImage, y: Any = None) -> np.ndarray:
def fit_transform(
self,
img: SurfaceImage,
y: Any = None,
confounds: pd.DataFrame | None = None,
sample_mask: np.ndarray | None = None,
) -> np.ndarray:
"""Prepare and perform signal extraction from regions.
Parameters
Expand All @@ -136,7 +209,7 @@ def fit_transform(self, img: SurfaceImage, y: Any = None) -> np.ndarray:
shape: (img data shape, total number of vertices)
"""
del y
return self.fit(img).transform(img)
return self.fit(img).transform(img, confounds, sample_mask)

def inverse_transform(self, masked_img: np.ndarray) -> SurfaceImage:
"""Transform extracted signal back to surface object.
Expand Down

0 comments on commit d6d822d

Please sign in to comment.