Skip to content

Commit

Permalink
Support atlases in plot_carpet (#2702)
Browse files Browse the repository at this point in the history
* Get atlas working.

* Add mask_labels (unused).

* Add labels.

* Add test for plot_carpet with atlas.

* Fix test.

* Coerce datatypes more elegantly.

* Remove f-string and fix.

* Add atlas to example.

* Replace double-quotes with single-quotes.

* Add whatsnew entry.

* Remove unused colorbar argument.

* Add undocumented arguments to docstrings.
  • Loading branch information
tsalo committed Mar 4, 2021
1 parent f7e2c16 commit fa85bf0
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 18 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ Enhancements
colorbar of surface plots. The default format is scientific notation except for :func:`nilearn.plotting.plot_surf_roi`
for which it is set as integers.

- :func:`nilearn.plotting.plot_carpet` now supports discrete atlases.
When an atlas is used, a colorbar is added to the figure,
optionally with labels corresponding to the different values in the atlas.

- :class:`nilearn.input_data.NiftiMasker`, :class:`nilearn.input_data.NiftiLabelsMasker`,
:class:`nilearn.input_data.MultiNiftiMasker`, :class:`nilearn.input_data.NiftiMapsMasker`,
and :class:`nilearn.input_data.NiftiSpheresMasker` can now compute high variance confounds
Expand Down
40 changes: 40 additions & 0 deletions examples/01_plotting/plot_carpet.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,43 @@
display = plot_carpet(adhd_dataset.func[0], mask_img)

display.show()

###############################################################################
# Deriving a label-based mask
# ---------------------------
# Create a gray matter/white matter/cerebrospinal fluid mask from
# ICBM152 tissue probability maps.
import nibabel as nib
import numpy as np
from nilearn import image

atlas = datasets.fetch_icbm152_2009()
atlas_img = image.concat_imgs((atlas["gm"], atlas["wm"], atlas["csf"]))
map_labels = {"Gray Matter": 1, "White Matter": 2, "Cerebrospinal Fluid": 3}

atlas_data = atlas_img.get_fdata()
discrete_version = np.argmax(atlas_data, axis=3) + 1
discrete_version[np.max(atlas_data, axis=3) == 0] = 0
discrete_atlas_img = nib.Nifti1Image(
discrete_version,
atlas_img.affine,
atlas_img.header,
)

###############################################################################
# Visualizing global patterns, separated by tissue type
# -----------------------------------------------------
import matplotlib.pyplot as plt

from nilearn.plotting import plot_carpet

fig, ax = plt.subplots(figsize=(10, 10))

display = plot_carpet(
adhd_dataset.func[0],
discrete_atlas_img,
mask_labels=map_labels,
axes=ax,
)

fig.show()
143 changes: 127 additions & 16 deletions nilearn/plotting/img_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import gridspec as mgs

from .. import _utils
from .._utils.extmath import fast_abs_percentile
from .._utils.param_validation import check_threshold
from .._utils.ndimage import get_border_data
from ..datasets import load_mni152_template
from ..image import new_img_like, iter_img, get_data
from ..image import new_img_like, iter_img, get_data, math_img, resample_to_img
from ..input_data import NiftiMasker
from nilearn.image.resampling import reorder_img
from ..masking import compute_epi_mask, apply_mask
from .displays import get_slicer, get_projector
Expand Down Expand Up @@ -1909,8 +1911,10 @@ def plot_markers(node_values, node_coords, node_size='auto',
return display


def plot_carpet(img, mask_img=None, detrend=True, output_file=None,
figure=None, axes=None, vmin=None, vmax=None, title=None):
def plot_carpet(img, mask_img=None, mask_labels=None,
detrend=True, output_file=None,
figure=None, axes=None, vmin=None, vmax=None, title=None,
cmap=plt.cm.gist_ncar):
"""Plot an image representation of voxel intensities across time.
This figure is also known as a "grayplot" or "Power plot".
Expand All @@ -1923,9 +1927,17 @@ def plot_carpet(img, mask_img=None, detrend=True, output_file=None,
mask_img : Niimg-like object or None, optional
Limit plotted voxels to those inside the provided mask (default is
None). If not specified a new mask will be derived from data.
None). If a 3D atlas is provided, voxels will be grouped by atlas
value and a colorbar will be added to the left side of the figure
with atlas labels.
If not specified, a new mask will be derived from data.
See http://nilearn.github.io/manipulating_images/input_output.html.
mask_labels : :obj:`dict`, optional
If ``mask_img`` corresponds to an atlas, then this dictionary maps
values from the ``mask_img`` to labels. Dictionary keys are labels
and values are values within the atlas.
detrend : :obj:`bool`, optional
Detrend and z-score the data prior to plotting. Default=True.
Expand All @@ -1943,9 +1955,23 @@ def plot_carpet(img, mask_img=None, detrend=True, output_file=None,
The axes used to display the plot (default is None).
If None, the complete figure is used.
vmin : float or None, optional
Lower bound for plotting, passed to matplotlib.pyplot.imshow.
If None, vmin will be automatically determined based on the data.
Default=None.
vmax : float or None, optional
Upper bound for plotting, passed to matplotlib.pyplot.imshow.
If None, vmax will be automatically determined based on the data.
Default=None.
title : :obj:`str` or None, optional
The title displayed on the figure (default is None).
cmap : matplotlib colormap, optional
The colormap for the sidebar, if an atlas is used.
Default=plt.cm.gist_ncar.
Returns
-------
figure : :class:`matplotlib.figure.Figure`
Expand Down Expand Up @@ -1976,7 +2002,39 @@ def plot_carpet(img, mask_img=None, detrend=True, output_file=None,
else:
mask_img = _utils.check_niimg_3d(mask_img, dtype='auto')

data = apply_mask(img, mask_img)
is_atlas = len(np.unique(mask_img.get_fdata())) > 2
if is_atlas:
background_label = 0

atlas_img_res = resample_to_img(
mask_img,
img,
interpolation='nearest',
)
atlas_bin = math_img(
'img != {}'.format(background_label),
img=atlas_img_res,
)
masker = NiftiMasker(atlas_bin, target_affine=img.affine)

data = masker.fit_transform(img)
atlas_values = masker.transform(atlas_img_res)
atlas_values = np.squeeze(atlas_values)

if mask_labels:
label_dtype = type(list(mask_labels.values())[0])
if label_dtype != atlas_values.dtype:
print('Coercing atlas_values to {}'.format(label_dtype))
atlas_values = atlas_values.astype(label_dtype)

# Sort data and atlas by atlas values
order = np.argsort(atlas_values)
order = np.squeeze(order)
atlas_values = atlas_values[order]
data = data[:, order]
else:
data = apply_mask(img, mask_img)

# Detrend and standardize data
if detrend:
data = clean(data, t_r=tr, detrend=True, standardize='zscore')
Expand All @@ -1991,8 +2049,7 @@ def plot_carpet(img, mask_img=None, detrend=True, output_file=None,
if axes is None:
axes = figure.add_subplot(1, 1, 1)
else:
assert axes.figure is figure, ("The axes passed are not "
"in the figure")
assert axes.figure is figure, ('The axes passed are not in the figure')

# Determine vmin and vmax based on the full data
std = np.mean(data.std(axis=0))
Expand All @@ -2006,10 +2063,60 @@ def plot_carpet(img, mask_img=None, detrend=True, output_file=None,
n_decimations = int(np.ceil(np.log2(np.ceil(n_tsteps / LONG_CUTOFF))))
data = data[::2 ** n_decimations, :]

axes.imshow(data.T, interpolation='nearest',
aspect='auto', cmap='gray',
vmin=vmin or default_vmin,
vmax=vmax or default_vmax)
if is_atlas:
# Define nested GridSpec
legend = False
wratios = [2, 100, 20]
gs = mgs.GridSpecFromSubplotSpec(
1,
2 + int(legend),
subplot_spec=axes,
width_ratios=wratios[: 2 + int(legend)],
wspace=0.0,
)

ax0 = plt.subplot(gs[0])
ax0.set_xticks([])
ax0.imshow(
atlas_values[:, np.newaxis],
interpolation='none',
aspect='auto',
cmap=cmap
)
if mask_labels:
# Add labels to middle of each associated band
mask_labels_inv = {v: k for k, v in mask_labels.items()}
ytick_locs = [
np.mean(np.where(atlas_values == i)[0])
for i in np.unique(atlas_values)
]
ax0.set_yticks(ytick_locs)
ax0.set_yticklabels([
mask_labels_inv[i] for i in np.unique(atlas_values)
])
else:
ax0.set_yticks([])

# Carpet plot
axes = plt.subplot(gs[1]) # overwrite axes
axes.imshow(
data.T,
interpolation='nearest',
aspect='auto',
cmap='gray',
vmin=vmin or default_vmin,
vmax=vmax or default_vmax,
)
ax0.tick_params(axis='both', which='both', length=0)
else:
axes.imshow(
data.T,
interpolation='nearest',
aspect='auto',
cmap='gray',
vmin=vmin or default_vmin,
vmax=vmax or default_vmax,
)

axes.grid(False)
axes.set_yticks([])
Expand All @@ -2020,11 +2127,11 @@ def plot_carpet(img, mask_img=None, detrend=True, output_file=None,
(int(data.shape[0] + 1) // 10, int(data.shape[0] + 1) // 5, 1))
xticks = list(range(0, data.shape[0])[::interval])
axes.set_xticks(xticks)

axes.set_xlabel('time (s)')
axes.set_ylabel('voxels')

if title:
axes.set_title(title)

labels = tr * (np.array(xticks))
labels *= (2 ** n_decimations)
axes.set_xticklabels(['%.02f' % t for t in labels.tolist()])
Expand All @@ -2035,10 +2142,14 @@ def plot_carpet(img, mask_img=None, detrend=True, output_file=None,
axes.spines[side].set_color('none')
axes.spines[side].set_visible(False)

axes.yaxis.set_ticks_position('left')
axes.xaxis.set_ticks_position('bottom')
axes.spines['bottom'].set_position(('outward', 20))
axes.spines['left'].set_position(('outward', 20))
axes.spines['bottom'].set_position(('outward', 10))

if not mask_labels:
axes.yaxis.set_ticks_position('left')
buffer = 20 if is_atlas else 10
axes.spines['left'].set_position(('outward', buffer))
axes.set_ylabel('voxels')

if output_file is not None:
figure.savefig(output_file)
Expand Down
68 changes: 66 additions & 2 deletions nilearn/plotting/tests/test_img_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,21 @@ def testdata_4d():
rng.uniform(size=(7, 7, 3, 1777)), mni_affine
)
img_mask = nibabel.Nifti1Image(np.ones((7, 7, 3), int), mni_affine)
atlas = np.ones((7, 7, 3), int)
atlas[2:5, :, :] = 2
atlas[5:8, :, :] = 3
img_atlas = nibabel.Nifti1Image(atlas, mni_affine)
atlas_labels = {
"gm": 1,
"wm": 2,
"csf": 3,
}
data = {
'img_4d': img_4d,
'img_4d_long': img_4d_long,
'img_mask': img_mask,
'img_atlas': img_atlas,
'atlas_labels': atlas_labels,
}
return data

Expand Down Expand Up @@ -314,8 +325,7 @@ def test_plot_glass_brain_threshold_for_uint8(testdata_3d):


def test_plot_carpet(testdata_4d):
"""Check contents of plot_carpet figure against data in image.
"""
"""Check contents of plot_carpet figure against data in image."""
img_4d = testdata_4d['img_4d']
img_4d_long = testdata_4d['img_4d_long']
mask_img = testdata_4d['img_mask']
Expand Down Expand Up @@ -345,6 +355,60 @@ def test_plot_carpet(testdata_4d):
plt.close(display)


def test_plot_carpet_with_atlas(testdata_4d):
"""Test plot_carpet when using an atlas."""
img_4d = testdata_4d['img_4d']
mask_img = testdata_4d['img_atlas']
atlas_labels = testdata_4d['atlas_labels']

# Test atlas - labels
display = plot_carpet(img_4d, mask_img, detrend=False, title='TEST')

# Check the output
# Two axes: 1 for colorbar and 1 for imshow
assert len(display.axes) == 2
# The y-axis label of the imshow should be 'voxels' since atlas labels are
# unknown
ax = display.axes[1]
assert ax.get_ylabel() == 'voxels'

# Next two lines retrieve the numpy array from the plot
ax = display.axes[0]
colorbar = ax.images[0].get_array()
assert len(np.unique(colorbar)) == len(atlas_labels)

# Save execution time and memory
plt.close(display)

# Test atlas + labels
fig, ax = plt.subplots()
display = plot_carpet(
img_4d,
mask_img,
mask_labels=atlas_labels,
detrend=True,
title='TEST',
figure=fig,
axes=ax,
)
# Check the output
# Two axes: 1 for colorbar and 1 for imshow
assert len(display.axes) == 2
ax = display.axes[0]

# The ytick labels of the colorbar should match the atlas labels
yticklabels = ax.get_yticklabels()
yticklabels = [yt.get_text() for yt in yticklabels]
assert set(yticklabels) == set(atlas_labels.keys())

# Next two lines retrieve the numpy array from the plot
ax = display.axes[0]
colorbar = ax.images[0].get_array()
assert len(np.unique(colorbar)) == len(atlas_labels)

plt.close(display)


def test_save_plot(testdata_3d, tmpdir):
img = testdata_3d['img']

Expand Down

0 comments on commit fa85bf0

Please sign in to comment.