-
Notifications
You must be signed in to change notification settings - Fork 580
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] ENH: add connectome strength plot #2028
Changes from all commits
b90eeb8
7dc05ff
cfd97b5
db26e2e
c3c283d
1beafbf
ce72196
836c4d9
ee46da0
56c579b
54edbcb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
# delayed, so that the part module can be used without them). | ||
import numpy as np | ||
from scipy import ndimage | ||
from scipy import sparse | ||
from nibabel.spatialimages import SpatialImage | ||
|
||
from .._utils.numpy_conversions import as_ndarray | ||
|
@@ -1305,3 +1306,138 @@ def plot_connectome(adjacency_matrix, node_coords, | |
display = None | ||
|
||
return display | ||
|
||
|
||
def plot_connectome_strength(adjacency_matrix, node_coords, node_size="auto", | ||
cmap=None, output_file=None, display_mode="ortho", | ||
figure=None, axes=None, title=None): | ||
"""Plot connectome strength on top of the brain glass schematics. | ||
|
||
The strength of a connection is define as the sum of absolute values of | ||
the edges arriving to a node. | ||
|
||
Parameters | ||
---------- | ||
adjacency_matrix : numpy array of shape (n, n) | ||
represents the link strengths of the graph. Assumed to be | ||
a symmetric matrix. | ||
node_coords : numpy array_like of shape (n, 3) | ||
3d coordinates of the graph nodes in world space. | ||
node_size : 'auto' or scalar | ||
size(s) of the nodes in points^2. By default the size of the node is | ||
inversely propertionnal to the number of nodes. | ||
cmap : str or colormap | ||
colormap used to represent the strength of a node. | ||
output_file : string, or None, optional | ||
The name of an image file to export the plot to. Valid extensions | ||
are .png, .pdf, .svg. If output_file is not None, the plot | ||
is saved to a file, and the display is closed. | ||
display_mode : string, optional. Default is 'ortho'. | ||
Choose the direction of the cuts: 'x' - sagittal, 'y' - coronal, | ||
'z' - axial, 'l' - sagittal left hemisphere only, | ||
'r' - sagittal right hemisphere only, 'ortho' - three cuts are | ||
performed in orthogonal directions. Possible values are: 'ortho', | ||
'x', 'y', 'z', 'xz', 'yx', 'yz', 'l', 'r', 'lr', 'lzr', 'lyr', | ||
'lzry', 'lyrz'. | ||
figure : integer or matplotlib figure, optional | ||
Matplotlib figure used or its number. If None is given, a | ||
new figure is created. | ||
axes : matplotlib axes or 4 tuple of float: (xmin, ymin, width, height), \ | ||
optional | ||
The axes, or the coordinates, in matplotlib figure space, | ||
of the axes used to display the plot. If None, the complete | ||
figure is used. | ||
title : string, optional | ||
The title displayed on the figure. | ||
|
||
Notes | ||
----- | ||
The plotted image should in MNI space for this function to work properly. | ||
""" | ||
|
||
# input validation | ||
if cmap is None: | ||
cmap = plt.cm.viridis_r | ||
elif isinstance(cmap, str): | ||
cmap = plt.get_cmap(cmap) | ||
else: | ||
cmap = cmap | ||
|
||
node_size = (1 / len(node_coords) * 1e4 | ||
if node_size == 'auto' else node_size) | ||
|
||
node_coords = np.asarray(node_coords) | ||
|
||
if sparse.issparse(adjacency_matrix): | ||
adjacency_matrix = adjacency_matrix.toarray() | ||
|
||
adjacency_matrix = np.nan_to_num(adjacency_matrix) | ||
|
||
adjacency_matrix_shape = adjacency_matrix.shape | ||
if (len(adjacency_matrix_shape) != 2 or | ||
adjacency_matrix_shape[0] != adjacency_matrix_shape[1]): | ||
raise ValueError( | ||
"'adjacency_matrix' is supposed to have shape (n, n)." | ||
' Its shape was {0}'.format(adjacency_matrix_shape)) | ||
|
||
node_coords_shape = node_coords.shape | ||
if len(node_coords_shape) != 2 or node_coords_shape[1] != 3: | ||
message = ( | ||
"Invalid shape for 'node_coords'. You passed an " | ||
"'adjacency_matrix' of shape {0} therefore " | ||
"'node_coords' should be a array with shape ({0[0]}, 3) " | ||
"while its shape was {1}").format(adjacency_matrix_shape, | ||
node_coords_shape) | ||
|
||
raise ValueError(message) | ||
|
||
if node_coords_shape[0] != adjacency_matrix_shape[0]: | ||
raise ValueError( | ||
"Shape mismatch between 'adjacency_matrix' " | ||
"and 'node_coords'" | ||
"'adjacency_matrix' shape is {0}, 'node_coords' shape is {1}" | ||
.format(adjacency_matrix_shape, node_coords_shape)) | ||
|
||
if not np.allclose(adjacency_matrix, adjacency_matrix.T, rtol=1e-3): | ||
raise ValueError("'adjacency_matrix' should be symmetric") | ||
|
||
# For a masked array, masked values are replaced with zeros | ||
if hasattr(adjacency_matrix, 'mask'): | ||
if not (adjacency_matrix.mask == adjacency_matrix.mask.T).all(): | ||
raise ValueError( | ||
"'adjacency_matrix' was masked with a non symmetric mask") | ||
adjacency_matrix = adjacency_matrix.filled(0) | ||
|
||
# plotting | ||
region_strength = np.sum(np.abs(adjacency_matrix), axis=0) | ||
region_strength /= np.sum(region_strength) | ||
|
||
region_idx_sorted = np.argsort(region_strength)[::-1] | ||
strength_sorted = region_strength[region_idx_sorted] | ||
coords_sorted = node_coords[region_idx_sorted] | ||
|
||
display = plot_glass_brain( | ||
None, display_mode=display_mode, figure=figure, axes=axes, title=title | ||
) | ||
|
||
for coord, region in zip(coords_sorted, strength_sorted): | ||
color = list( | ||
cmap((region - strength_sorted.min()) / strength_sorted.max()) | ||
) | ||
# reduce alpha for the least strong regions | ||
color[-1] = ( | ||
(region - strength_sorted.min()) * | ||
(1 / (strength_sorted.max() - strength_sorted.min())) | ||
) | ||
# make color to be a 2D array | ||
color = [color] | ||
display.add_markers( | ||
[coord], marker_color=color, marker_size=node_size | ||
) | ||
|
||
if output_file is not None: | ||
display.savefig(output_file) | ||
display.close() | ||
display = None | ||
|
||
return display | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey @glemaitre this is a super long method, refactoring it would be helpful. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But there is no code duplication. It would not be really meaningful to cut the function into smaller functions if they are not reused. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They improve ease of understanding. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like the 4 checks below (adjacency shape, nodes shape, both shapes, adjacency symmetry, mask symmetry) could be useful for other connectome plotting functions?