Skip to content

Commit

Permalink
Test get connectivity paris
Browse files Browse the repository at this point in the history
  • Loading branch information
EtienneCmb committed Feb 6, 2021
1 parent bfdf2db commit b1ff8c3
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
52 changes: 51 additions & 1 deletion frites/conn/tests/test_conn_utils.py
Expand Up @@ -2,7 +2,8 @@
import numpy as np

from frites.conn import (conn_reshape_undirected, conn_reshape_directed,
define_windows, plot_windows, conn_dfc, conn_covgc)
define_windows, plot_windows, conn_dfc, conn_covgc,
conn_get_pairs)


class TestConnUtils(object):
Expand Down Expand Up @@ -106,3 +107,52 @@ def test_plot_windows(self):
kw = dict(verbose=False)
ts = define_windows(times, slwin_len=.1, **kw)[0]
plot_windows(times, ts)

def test_conn_get_pairs(self):
"""Test function conn_get_pairs."""
roi = [np.array(['r1', 'r0']), np.array(['r0', 'r2', 'r1'])]
# test non-directed
df = conn_get_pairs(roi, directed=False)
rundir = np.c_[['r0', 'r0', 'r1'], ['r1', 'r2', 'r2']]
names = [f'{k}-{i}' for k, i in zip(rundir[:, 0], rundir[:, 1])]
suj = [0, 1, 1, 1]
nsuj = [2, 1, 1]
assert np.all(df['keep'])
np.testing.assert_array_equal(df['sources'], rundir[:, 0])
np.testing.assert_array_equal(df['targets'], rundir[:, 1])
np.testing.assert_array_equal(df['#subjects'], nsuj)
np.testing.assert_array_equal(df['names'], names)
np.testing.assert_array_equal(np.concatenate(df['subjects']), suj)
# test directed
df = conn_get_pairs(roi, directed=True)
rdir = np.c_[['r0', 'r0', 'r1', 'r1', 'r2', 'r2'],
['r1', 'r2', 'r0', 'r2', 'r0', 'r1']]
names = [f'{k}->{i}' for k, i in zip(rdir[:, 0], rdir[:, 1])]
suj = [0, 1, 1, 0, 1, 1, 1, 1]
nsuj = [2, 1, 2, 1, 1, 1]
assert np.all(df['keep'])
np.testing.assert_array_equal(df['sources'], rdir[:, 0])
np.testing.assert_array_equal(df['targets'], rdir[:, 1])
np.testing.assert_array_equal(df['#subjects'], nsuj)
np.testing.assert_array_equal(df['names'], names)
np.testing.assert_array_equal(np.concatenate(df['subjects']), suj)
# test nb_min_suj filtering (non-directed)
df = conn_get_pairs(roi, directed=False, nb_min_suj=2)
np.testing.assert_array_equal(df['keep'], [True, False, False])
df = df.loc[df['keep']]
np.testing.assert_array_equal(df['sources'], ['r0'])
np.testing.assert_array_equal(df['targets'], ['r1'])
np.testing.assert_array_equal(df['#subjects'], [2])
np.testing.assert_array_equal(np.concatenate(df['subjects']), [0, 1])
np.testing.assert_array_equal(df['names'], ['r0-r1'])
# test nb_min_suj filtering (directed)
df = conn_get_pairs(roi, directed=True, nb_min_suj=2)
np.testing.assert_array_equal(
df['keep'], [True, False, True, False, False, False])
df = df.loc[df['keep']]
np.testing.assert_array_equal(df['sources'], ['r0', 'r1'])
np.testing.assert_array_equal(df['targets'], ['r1', 'r0'])
np.testing.assert_array_equal(df['#subjects'], [2, 2])
np.testing.assert_array_equal(
np.concatenate(list(df['subjects'])), [0, 1, 0, 1])
np.testing.assert_array_equal(df['names'], ['r0->r1', 'r1->r0'])
12 changes: 11 additions & 1 deletion frites/dataset/tests/test_ds_ephy.py
Expand Up @@ -190,6 +190,16 @@ def test_savgol_filter(self):
ds = DatasetEphy(d_3d, times='times', **kw)
ds.savgol_filter(10., verbose=False)

def test_get_connectivity_pairs(self):
"""Test function get_connectivity_pairs."""
d_3d = self._get_data(3)
ds = DatasetEphy(d_3d, times='times', **kw)
for direction in [True, False]:
for blocks in [True, False]:
df = ds.get_connectivity_pairs(
directed=direction, as_blocks=blocks, verbose=False)
assert isinstance(df, pd.DataFrame)


if __name__ == '__main__':
TestDatasetEphy().test_definition()
TestDatasetEphy().test_get_connectivity_pairs()

0 comments on commit b1ff8c3

Please sign in to comment.