Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] ENH: add connectome strength plot #2028

Merged
merged 11 commits into from
Oct 14, 2019
1 change: 1 addition & 0 deletions doc/modules/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ uses.
plot_stat_map
plot_glass_brain
plot_connectome
plot_connectome_strength
plot_prob_atlas
plot_surf
plot_surf_roi
Expand Down
13 changes: 13 additions & 0 deletions doc/plotting/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ different heuristics to find cutting coordinates.
:target: ../auto_examples/03_connectivity/plot_sphere_based_connectome.html
:scale: 50

.. |plot_strength| image:: ../auto_examples/03_connectivity/images/sphx_glr_plot_sphere_based_connectome_004.png
:target: ../auto_examples/03_connectivity/plot_sphere_based_connectome.html
:scale: 50

.. |plot_anat| image:: ../auto_examples/01_plotting/images/sphx_glr_plot_demo_plotting_003.png
:target: ../auto_examples/01_plotting/plot_demo_plotting.html
:scale: 50
Expand Down Expand Up @@ -102,6 +106,15 @@ different heuristics to find cutting coordinates.
are demonstrated in
**Example:** :ref:`sphx_glr_auto_examples_03_connectivity_plot_atlas_comparison.py`

|plot_strength| :func:`plot_connectome_strength`
|hack|
Plotting a connectome strength

Functions for automatic extraction of coords based on
brain parcellations useful for :func:`plot_connectome`
are demonstrated in
**Example:** :ref:`sphx_glr_auto_examples_03_connectivity_plot_atlas_comparison.py`

|plot_prob_atlas| :func:`plot_prob_atlas`
|hack|
Plotting 4D probabilistic atlas maps
Expand Down
4 changes: 3 additions & 1 deletion doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ NEW
Yields very fast & accurate models, without creation of giant
clusters.
:class:`nilearn.regions.ReNA`

- Plot connectome strength
Use :func:`nilearn.plotting.plot_connectome_strength` to plot the strength of a
connectome on a glass brain. Strength is absolute sum of the edges at a node.
- Optimization to image resampling
:func:`nilearn.image.resample_img` has been optimized to pad rather than
resample images in the special case when there is only a translation
Expand Down
81 changes: 61 additions & 20 deletions examples/03_connectivity/plot_sphere_based_connectome.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
covariance** and **partial_correlation**, to recover the functional brain
**networks structure**.

We'll start by extracting signals from Default Mode Network regions and computing a
connectome from them.
We'll start by extracting signals from Default Mode Network regions and
computing a connectome from them.

"""

Expand All @@ -37,7 +37,7 @@
# connectivity dataset.

from nilearn import datasets
dataset = datasets.fetch_development_fmri(n_subjects=1)
dataset = datasets.fetch_development_fmri(n_subjects=20)

# print basic information on the dataset
print('First subject functional nifti image (4D) is at: %s' %
Expand All @@ -49,21 +49,21 @@
# ------------------------------------
dmn_coords = [(0, -52, 18), (-46, -68, 32), (46, -68, 32), (1, 50, -5)]
labels = [
'Posterior Cingulate Cortex',
'Left Temporoparietal junction',
'Right Temporoparietal junction',
'Medial prefrontal cortex',
]
'Posterior Cingulate Cortex',
'Left Temporoparietal junction',
'Right Temporoparietal junction',
'Medial prefrontal cortex'
]

##########################################################################
# Extracts signal from sphere around DMN seeds
# ----------------------------------------------
#
# We can compute the mean signal within **spheres** of a fixed radius
# We can compute the mean signal within **spheres** of a fixed radius
# around a sequence of (x, y, z) coordinates with the object
# :class:`nilearn.input_data.NiftiSpheresMasker`.
# The resulting signal is then prepared by the masker object: Detrended,
# band-pass filtered and **standardized to 1 variance**.
# The resulting signal is then prepared by the masker object: Detrended,
# band-pass filtered and **standardized to 1 variance**.

from nilearn import input_data

Expand All @@ -73,7 +73,7 @@
low_pass=0.1, high_pass=0.01, t_r=2,
memory='nilearn_cache', memory_level=1, verbose=2)

# Additionally, we pass confound information so ensure our extracted
# Additionally, we pass confound information to ensure our extracted
# signal is cleaned from confounds.

func_filename = dataset.func[0]
Expand Down Expand Up @@ -114,7 +114,8 @@
# Display connectome
# -------------------
#
# We display the graph of connections with `:func: nilearn.plotting.plot_connectome`.
# We display the graph of connections with
# `:func: nilearn.plotting.plot_connectome`.

from nilearn import plotting

Expand Down Expand Up @@ -172,7 +173,8 @@
#
#
# You can retrieve the coordinates for any atlas, including atlases
# not included in nilearn, using :func:`nilearn.plotting.find_parcellation_cut_coords`.
# not included in nilearn, using
# :func:`nilearn.plotting.find_parcellation_cut_coords`.


###############################################################################
Expand Down Expand Up @@ -231,8 +233,8 @@


###############################################################################
# Plot matrix and graph
# ---------------------
# Plot matrix, graph, and strength
# --------------------------------
#
# We use `:func: nilearn.plotting.plot_matrix` to visualize our correlation matrix
# and display the graph of connections with `nilearn.plotting.plot_connectome`.
Expand All @@ -247,13 +249,41 @@


###############################################################################
# .. note::
#
# Note the 1. on the matrix diagonal: These are the signals variances, set to
# 1. by the `spheres_masker`. Hence the covariance of the signal is a
# .. note::
#
# Note the 1. on the matrix diagonal: These are the signals variances, set
# to 1. by the `spheres_masker`. Hence the covariance of the signal is a
# correlation matrix.


###############################################################################
# Sometimes, the information in the correlation matrix is overwhelming and
# aggregating edge strength from the graph would help. Use the function
# `nilearn.plotting.plot_connectome_strength` to visualize this information.

plotting.plot_connectome_strength(
matrix, coords, title='Connectome strength for Power atlas'
)

###############################################################################
# From the correlation matrix, we observe that there is a positive and negative
# structure. We could make two different plots by plotting these strengths
# separately.

from matplotlib.pyplot import cm

# plot the positive part of of the matrix
plotting.plot_connectome_strength(
np.clip(matrix, 0, matrix.max()), coords, cmap=cm.YlOrRd,
title='Strength of the positive edges of the Power correlation matrix'
)

# plot the negative part of of the matrix
plotting.plot_connectome_strength(
np.clip(matrix, matrix.min(), 0), coords, cmap=cm.PuBu,
title='Strength of the negative edges of the Power correlation matrix'
)

###############################################################################
# Connectome extracted from Dosenbach's atlas
# -------------------------------------------
Expand Down Expand Up @@ -283,6 +313,17 @@

plotting.plot_connectome(matrix, coords, title='Dosenbach correlation graph',
edge_threshold="99.7%", node_size=20, colorbar=True)
plotting.plot_connectome_strength(
matrix, coords, title='Connectome strength for Power atlas'
)
plotting.plot_connectome_strength(
np.clip(matrix, 0, matrix.max()), coords, cmap=cm.YlOrRd,
title='Strength of the positive edges of the Power correlation matrix'
)
plotting.plot_connectome_strength(
np.clip(matrix, matrix.min(), 0), coords, cmap=cm.PuBu,
title='Strength of the negative edges of the Power correlation matrix'
)


###############################################################################
Expand Down
4 changes: 2 additions & 2 deletions nilearn/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _set_mpl_backend():
from . import cm
from .img_plotting import plot_img, plot_anat, plot_epi, \
plot_roi, plot_stat_map, plot_glass_brain, plot_connectome, \
plot_prob_atlas, show
plot_connectome_strength, plot_prob_atlas, show
from .find_cuts import find_xyz_cut_coords, find_cut_slices, \
find_parcellation_cut_coords, find_probabilistic_atlas_cut_coords
from .matrix_plotting import plot_matrix
Expand All @@ -48,7 +48,7 @@ def _set_mpl_backend():

__all__ = ['cm', 'plot_img', 'plot_anat', 'plot_epi',
'plot_roi', 'plot_stat_map', 'plot_glass_brain',
'plot_connectome', 'plot_prob_atlas',
'plot_connectome_strength', 'plot_connectome', 'plot_prob_atlas',
'find_xyz_cut_coords', 'find_cut_slices',
'show', 'plot_matrix', 'view_surf', 'view_img_on_surf',
'view_img', 'view_connectome', 'view_markers',
Expand Down
136 changes: 136 additions & 0 deletions nilearn/plotting/img_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# delayed, so that the part module can be used without them).
import numpy as np
from scipy import ndimage
from scipy import sparse
from nibabel.spatialimages import SpatialImage

from .._utils.numpy_conversions import as_ndarray
Expand Down Expand Up @@ -1305,3 +1306,138 @@ def plot_connectome(adjacency_matrix, node_coords,
display = None

return display


def plot_connectome_strength(adjacency_matrix, node_coords, node_size="auto",
cmap=None, output_file=None, display_mode="ortho",
figure=None, axes=None, title=None):
"""Plot connectome strength on top of the brain glass schematics.

The strength of a connection is define as the sum of absolute values of
the edges arriving to a node.

Parameters
----------
adjacency_matrix : numpy array of shape (n, n)
represents the link strengths of the graph. Assumed to be
a symmetric matrix.
node_coords : numpy array_like of shape (n, 3)
3d coordinates of the graph nodes in world space.
node_size : 'auto' or scalar
size(s) of the nodes in points^2. By default the size of the node is
inversely propertionnal to the number of nodes.
cmap : str or colormap
colormap used to represent the strength of a node.
output_file : string, or None, optional
The name of an image file to export the plot to. Valid extensions
are .png, .pdf, .svg. If output_file is not None, the plot
is saved to a file, and the display is closed.
display_mode : string, optional. Default is 'ortho'.
Choose the direction of the cuts: 'x' - sagittal, 'y' - coronal,
'z' - axial, 'l' - sagittal left hemisphere only,
'r' - sagittal right hemisphere only, 'ortho' - three cuts are
performed in orthogonal directions. Possible values are: 'ortho',
'x', 'y', 'z', 'xz', 'yx', 'yz', 'l', 'r', 'lr', 'lzr', 'lyr',
'lzry', 'lyrz'.
figure : integer or matplotlib figure, optional
Matplotlib figure used or its number. If None is given, a
new figure is created.
axes : matplotlib axes or 4 tuple of float: (xmin, ymin, width, height), \
optional
The axes, or the coordinates, in matplotlib figure space,
of the axes used to display the plot. If None, the complete
figure is used.
title : string, optional
The title displayed on the figure.

Notes
-----
The plotted image should in MNI space for this function to work properly.
"""

# input validation
if cmap is None:
cmap = plt.cm.viridis_r
elif isinstance(cmap, str):
cmap = plt.get_cmap(cmap)
else:
cmap = cmap

node_size = (1 / len(node_coords) * 1e4
if node_size == 'auto' else node_size)

node_coords = np.asarray(node_coords)

if sparse.issparse(adjacency_matrix):
adjacency_matrix = adjacency_matrix.toarray()

adjacency_matrix = np.nan_to_num(adjacency_matrix)

adjacency_matrix_shape = adjacency_matrix.shape
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the 4 checks below (adjacency shape, nodes shape, both shapes, adjacency symmetry, mask symmetry) could be useful for other connectome plotting functions?

if (len(adjacency_matrix_shape) != 2 or
adjacency_matrix_shape[0] != adjacency_matrix_shape[1]):
raise ValueError(
"'adjacency_matrix' is supposed to have shape (n, n)."
' Its shape was {0}'.format(adjacency_matrix_shape))

node_coords_shape = node_coords.shape
if len(node_coords_shape) != 2 or node_coords_shape[1] != 3:
message = (
"Invalid shape for 'node_coords'. You passed an "
"'adjacency_matrix' of shape {0} therefore "
"'node_coords' should be a array with shape ({0[0]}, 3) "
"while its shape was {1}").format(adjacency_matrix_shape,
node_coords_shape)

raise ValueError(message)

if node_coords_shape[0] != adjacency_matrix_shape[0]:
raise ValueError(
"Shape mismatch between 'adjacency_matrix' "
"and 'node_coords'"
"'adjacency_matrix' shape is {0}, 'node_coords' shape is {1}"
.format(adjacency_matrix_shape, node_coords_shape))

if not np.allclose(adjacency_matrix, adjacency_matrix.T, rtol=1e-3):
raise ValueError("'adjacency_matrix' should be symmetric")

# For a masked array, masked values are replaced with zeros
if hasattr(adjacency_matrix, 'mask'):
if not (adjacency_matrix.mask == adjacency_matrix.mask.T).all():
raise ValueError(
"'adjacency_matrix' was masked with a non symmetric mask")
adjacency_matrix = adjacency_matrix.filled(0)

# plotting
region_strength = np.sum(np.abs(adjacency_matrix), axis=0)
region_strength /= np.sum(region_strength)

region_idx_sorted = np.argsort(region_strength)[::-1]
strength_sorted = region_strength[region_idx_sorted]
coords_sorted = node_coords[region_idx_sorted]

display = plot_glass_brain(
None, display_mode=display_mode, figure=figure, axes=axes, title=title
)

for coord, region in zip(coords_sorted, strength_sorted):
color = list(
cmap((region - strength_sorted.min()) / strength_sorted.max())
)
# reduce alpha for the least strong regions
color[-1] = (
(region - strength_sorted.min()) *
(1 / (strength_sorted.max() - strength_sorted.min()))
)
# make color to be a 2D array
color = [color]
display.add_markers(
[coord], marker_color=color, marker_size=node_size
)

if output_file is not None:
display.savefig(output_file)
display.close()
display = None

return display
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @glemaitre this is a super long method, refactoring it would be helpful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But there is no code duplication. It would not be really meaningful to cut the function into smaller functions if they are not reused.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They improve ease of understanding.

Loading