Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX/ENH] Improve generate_report method to work with fit_transform for the maskers #3897

Merged
merged 8 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/changes/latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Fixes

- Fix pathlib.Path not being counted as Niimg-like object in :func:`~image.new_img_like` (:gh:`3723` by `Maximilian Cosmo Sitter`_).

- Fix ``fit_transform`` behavior to match when ``fit`` method is passed image data (:gh:`3897` by `Yasmin Mzayek`_)

Enhancements
------------
Expand All @@ -48,6 +49,8 @@ Enhancements

- Add ``vmin`` and ``symmetric_cbar`` arguments to :func:`~nilearn.plotting.plot_img_on_surf` (:gh:`3873` by `Michelle Wang`_).

- Improve ``generate_report`` method of maskers by allowing users to pass a cmap argument for plotting image (:gh:`3897` by `Yasmin Mzayek`_)

Changes
-------

Expand Down
18 changes: 14 additions & 4 deletions nilearn/maskers/nifti_labels_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ def __init__(

self.keep_masked_labels = keep_masked_labels

self.cmap = kwargs.get("cmap", "CMRmap_r")

def generate_report(self):
"""Generate a report."""
from nilearn.reporting.html_report import generate_report
Expand Down Expand Up @@ -299,7 +301,7 @@ def _reporting(self):
display = plotting.plot_img(
img,
black_bg=False,
cmap='CMRmap_r',
cmap=self.cmap,
)
plt.close()
display.add_contours(labels_image, filled=False, linewidths=3)
Expand Down Expand Up @@ -334,8 +336,15 @@ def _reporting(self):
def fit(self, imgs=None, y=None):
"""Prepare signal extraction from regions.

All parameters are unused, they are for scikit-learn compatibility.
Parameters
----------
imgs : :obj:`list` of Niimg-like objects
See :ref:`extracting_data`.
Image data passed to the reporter.

y : None
This parameter is unused. It is solely included for scikit-learn
compatibility.
"""
repr = _utils._repr_niimgs(self.labels_img,
shorten=(not self.verbose))
Expand Down Expand Up @@ -451,8 +460,9 @@ def fit_transform(self, imgs, confounds=None, sample_mask=None):
shape: (number of scans, number of labels)

"""
return self.fit().transform(imgs, confounds=confounds,
sample_mask=sample_mask)
return self.fit(imgs).transform(
imgs, confounds=confounds, sample_mask=sample_mask
)

def _check_fitted(self):
if not hasattr(self, 'labels_img_'):
Expand Down
15 changes: 12 additions & 3 deletions nilearn/maskers/nifti_maps_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ def __init__(

self.keep_masked_maps = keep_masked_maps

self.cmap = kwargs.get("cmap", "CMRmap_r")

def generate_report(self, displayed_maps=10):
"""Generate an HTML report for the current ``NiftiMapsMasker`` object.

Expand Down Expand Up @@ -343,7 +345,7 @@ def _reporting(self):
img,
cut_coords=cut_coords[idx],
black_bg=False,
cmap="CMRmap_r",
cmap=self.cmap,
)
display.add_overlay(
image.index_img(maps_image, idx),
Expand All @@ -356,8 +358,15 @@ def _reporting(self):
def fit(self, imgs=None, y=None):
"""Prepare signal extraction from regions.

All parameters are unused, they are for scikit-learn compatibility.
Parameters
----------
imgs : :obj:`list` of Niimg-like objects
See :ref:`extracting_data`.
Image data passed to the reporter.

y : None
This parameter is unused. It is solely included for scikit-learn
compatibility.
"""
# Load images
repr = _utils._repr_niimgs(self.mask_img, shorten=(not self.verbose))
Expand Down Expand Up @@ -439,7 +448,7 @@ def _check_fitted(self):

def fit_transform(self, imgs, confounds=None, sample_mask=None):
"""Prepare and perform signal extraction."""
return self.fit().transform(
return self.fit(imgs).transform(
imgs, confounds=confounds, sample_mask=sample_mask
)

Expand Down
6 changes: 4 additions & 2 deletions nilearn/maskers/nifti_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ def __init__(
k[7:]: v for k, v in kwargs.items() if k.startswith("clean__")
}

self.cmap = kwargs.get("cmap", "CMRmap_r")

def generate_report(self):
"""Generate a report of the masker."""
from nilearn.reporting.html_report import generate_report
Expand Down Expand Up @@ -357,7 +359,7 @@ def _reporting(self):
init_display = plotting.plot_img(
img,
black_bg=False,
cmap="CMRmap_r",
cmap=self.cmap,
)
plt.close()
if mask is not None:
Expand Down Expand Up @@ -386,7 +388,7 @@ def _reporting(self):
final_display = plotting.plot_img(
resampl_img,
black_bg=False,
cmap="CMRmap_r",
cmap=self.cmap,
)
plt.close()
final_display.add_contours(
Expand Down
5 changes: 3 additions & 2 deletions nilearn/maskers/nifti_spheres_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,9 @@ def fit_transform(self, imgs, confounds=None, sample_mask=None):
shape: (number of scans, number of spheres)

"""
return self.fit().transform(imgs, confounds=confounds,
sample_mask=sample_mask)
return self.fit(imgs).transform(
imgs, confounds=confounds, sample_mask=sample_mask
)

def _check_fitted(self):
if not hasattr(self, "seeds_"):
Expand Down