Skip to content

Commit

Permalink
[MRG] Nans in view connectome (nilearn#2166)
Browse files Browse the repository at this point in the history
* nan_to_num in view_connectome

* update warning test

* whatsnew
  • Loading branch information
jeromedockes authored and kchawla-pi committed Oct 10, 2019
1 parent 6c24005 commit d566678
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
3 changes: 3 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ Changes
colormap. These arguments were already accepted in `kwargs` but not documented
before.

- :func:`nilearn.plotting.view_connectome` now converts NaNs in the adjacency
matrix to 0.

Fixes
-----

Expand Down
2 changes: 1 addition & 1 deletion nilearn/plotting/html_connectome.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _get_connectome(adjacency_matrix, coords, threshold=None,
marker_size=None, cmap=cm.cold_hot, symmetric_cmap=True):
connectome = {}
coords = np.asarray(coords, dtype='<f4')
adjacency_matrix = adjacency_matrix.copy()
adjacency_matrix = np.nan_to_num(adjacency_matrix, copy=True)
colors = colorscale(
cmap, adjacency_matrix.ravel(), threshold=threshold,
symmetric_cmap=symmetric_cmap)
Expand Down
14 changes: 10 additions & 4 deletions nilearn/plotting/tests/test_html_connectome.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def test_get_connectome():
assert {'_con_x', '_con_y', '_con_z', '_con_w', 'colorscale'
}.issubset(connectome.keys())
assert (connectome['cmin'], connectome['cmax']) == (-2.5, 2.5)
adj[adj == 0] = np.nan
connectome = html_connectome._get_connectome(adj, coord)
con_x = decode(connectome['_con_x'], '<f4')
assert (con_x == expected_x).all()
assert (connectome['cmin'], connectome['cmax']) == (-2.5, 2.5)


def test_view_connectome():
Expand Down Expand Up @@ -129,10 +134,11 @@ def test_params_deprecation_view_connectome():
)
old_params = ['coords', 'threshold', 'cmap', 'marker_size']

assert len(raised_warnings) == 4
for old_param_, raised_warning_ in zip(old_params, raised_warnings):
assert warning_msgs[old_param_] == str(raised_warning_.message)
assert raised_warning_.category is DeprecationWarning
raised_warning_messages = ''.join(
str(warning.message) for warning in raised_warnings)
print(raised_warning_messages)
for old_param_ in old_params:
assert warning_msgs[old_param_] in raised_warning_messages


def test_get_markers():
Expand Down

0 comments on commit d566678

Please sign in to comment.