Skip to content

Commit

Permalink
conn_net for computing the net connectivity
Browse files Browse the repository at this point in the history
  • Loading branch information
EtienneCmb committed Jan 14, 2022
1 parent 4a6afa6 commit c86b19f
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/api/api_connectivity.rst
Expand Up @@ -33,6 +33,7 @@ Reshaping connectivity outputs
conn_reshape_undirected
conn_reshape_directed
conn_ravel_directed
conn_net
conn_get_pairs

Metrics to apply on FC
Expand Down
2 changes: 1 addition & 1 deletion frites/conn/__init__.py
Expand Up @@ -24,4 +24,4 @@
from .conn_fcd_corr import conn_fcd_corr # noqa
from .conn_sliding_windows import define_windows, plot_windows # noqa
from .conn_utils import (conn_get_pairs, conn_reshape_undirected, # noqa
conn_reshape_directed, conn_ravel_directed)
conn_reshape_directed, conn_ravel_directed, conn_net)
76 changes: 76 additions & 0 deletions frites/conn/conn_utils.py
Expand Up @@ -414,3 +414,79 @@ def conn_ravel_directed(da, sep='-', drop_within=False):
da_ravel = da_ravel.sel(roi=to_keep)

return da_ravel


def conn_net(da, roi='roi', order=None, sep='-', invert=False, verbose=None):
"""Compute the net on directed connectivity.
This function can be used to compute the net difference on directed
connectivity (i.e. A - B = A->B - B->A).
Parameters
----------
da : xr.DataArray
Xarray DataArray containing the connectivity array
roi : 'roi'
Name of the spatial dimension
order : list | None
List of names for specifying the final order
sep : string | '-'
Separator between brain region names (e.g. if 'Insula->Thalamus' then
sep is '->')
invert : bool | False
Specify whether the difference should be computed with A - B or B - A
Returns
-------
out : xr.DataArray
DataArray, with the same dimension names as the input, representing the
net difference of directed connexions.
"""
set_log_level(verbose)
assert roi in da.dims
roi_names = da[roi].data

# get roi order from sources
if order is None:
roi_s, roi_t = [], []
for r in roi_names:
_rs, _rt = r.split(sep)
roi_s.append(_rs)
roi_t.append(_rt)
order = nonsorted_unique(roi_s + roi_t)
order = np.asarray(order)

# build names of the difference
x_s, x_t = np.triu_indices(len(order), k=1)
roi_s, roi_t = order[x_s], order[x_t]
if invert:
_roi_st = roi_s.copy()
roi_s = roi_t
roi_t = _roi_st

# build pairs names
roi_st, p_s, p_t, ignored = [], [], [], []
for s, t in zip(roi_s, roi_t):
name_s, name_t = f"{s}{sep}{t}", f"{t}{sep}{s}"
if (name_s in da[roi]) and (name_t in da[roi]):
roi_st.append(f"{s}-{t}")
p_s.append(name_s)
p_t.append(name_t)
else:
ignored.append(f"{s}-{t}")
# ignored.append(name_s)
if len(ignored):
logger.warning("The following pairs have been ignored in the "
f"subtraction : {ignored}")

# prepare the output
out = da.isel(**{roi: slice(0, len(roi_st))}).copy()
out[roi] = roi_st
out.data = da.sel(**{roi: p_s}).data - da.sel(**{roi: p_t}).data

# update attributes to track operations
out.attrs['net_source'] = p_s
out.attrs['net_target'] = p_t
out.name = da.name + '_net' if da.name else 'Net conn'

return out
37 changes: 36 additions & 1 deletion frites/conn/tests/test_conn_utils.py
Expand Up @@ -4,7 +4,7 @@

from frites.conn import (conn_reshape_undirected, conn_reshape_directed,
conn_ravel_directed, define_windows, plot_windows,
conn_dfc, conn_covgc, conn_get_pairs)
conn_dfc, conn_covgc, conn_get_pairs, conn_net)


class TestConnUtils(object):
Expand Down Expand Up @@ -187,3 +187,38 @@ def test_conn_ravel_directed(self):
assert len(conn_r.shape) == 3
np.testing.assert_array_equal(conn_r['roi'].data, roi_dir)
np.testing.assert_array_equal(conn_r.data, conn_c)

def test_conn_net(self):
"""Test function conn_net."""
conn_xy = np.full((10, 3, 12), 2)
conn_yx = np.ones((10, 4, 12))
conn = np.concatenate((conn_xy, conn_yx), axis=1)
roi = ['x->y', 'x->z', 'y->z', 'y->x', 'z->x', 'z->y', 'z->a']
trials, times = np.arange(10), np.arange(12)
conn = xr.DataArray(conn, dims=('trials', 'space', 'times'),
coords=(trials, roi, times))

# test normal usage
net = conn_net(conn, roi='space', sep='->', invert=False)
np.testing.assert_array_equal(net.shape, (10, 3, 12))
np.testing.assert_array_equal(net['trials'], trials)
np.testing.assert_array_equal(net['times'], times)
np.testing.assert_array_equal(net['space'], ['x-y', 'x-z', 'y-z'])
np.testing.assert_array_equal(
net.attrs['net_source'], ['x->y', 'x->z', 'y->z'])
np.testing.assert_array_equal(
net.attrs['net_target'], ['y->x', 'z->x', 'z->y'])
np.testing.assert_array_equal(net.data, np.full((10, 3, 12), 1))

# test inverted
net = conn_net(conn, roi='space', sep='->', invert=True)
np.testing.assert_array_equal(net['space'], ['y-x', 'z-x', 'z-y'])
np.testing.assert_array_equal(
net.attrs['net_source'], ['y->x', 'z->x', 'z->y'])
np.testing.assert_array_equal(
net.attrs['net_target'], ['x->y', 'x->z', 'y->z'])
np.testing.assert_array_equal(net.data, np.full((10, 3, 12), -1))

# test order
net = conn_net(conn, roi='space', sep='->', order=['z', 'x', 'y'])
np.testing.assert_array_equal(net['space'], ['z-x', 'z-y', 'x-y'])

0 comments on commit c86b19f

Please sign in to comment.