Skip to content

Commit 3d24c98

Browse files
committed
Conn input conversion function
1 parent dce6cc7 commit 3d24c98

File tree

7 files changed

+70
-15
lines changed

7 files changed

+70
-15
lines changed

examples/conn/plot_dfc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
# each of the temporal window
7373

7474
# compute DFC
75-
dfc, pairs, roi_p = conn_dfc(x, times, roi, win_sample, n_jobs=1)
75+
dfc, pairs, roi_p = conn_dfc(x, win_sample, times=times, roi=roi, n_jobs=1)
7676

7777
plt.figure(figsize=(10, 8))
7878
plt.plot(times_p, dfc.mean('trials').T)

examples/tutorials/plot_stim_spec_network.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ def plot_mi(mi, pv):
148148
# compute the DFC for each subject
149149
dfc = []
150150
for n_s in range(n_subjects):
151-
_dfc = conn_dfc(x[n_s].data, times, roi, win_sample, verbose=False)[0]
151+
_dfc = conn_dfc(x[n_s].data, win_sample, times=times, roi=roi,
152+
verbose=False)[0]
152153
# reset trials dimension
153154
_dfc['trials'] = x[n_s]['trials'].data
154155
dfc += [_dfc]

frites/conn/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Information-theritical measures of connectivity."""
22
from .conn_covgc import conn_covgc # noqa
33
from .conn_dfc import conn_dfc # noqa
4-
from .conn_transfer_entropy import conn_transfer_entropy # noqa
54
from .conn_fit import conn_fit # noqa
5+
from .conn_io import conn_io # noqa
6+
from .conn_transfer_entropy import conn_transfer_entropy # noqa
67
from .conn_utils import (conn_reshape_undirected, conn_reshape_directed) # noqa

frites/conn/conn_dfc.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77
from frites.config import CONFIG
88
from frites.utils import parallel_func
99

10+
from frites.conn.conn_io import conn_io
11+
1012
from mne.utils import ProgressBar
1113

1214

1315

14-
def conn_dfc(data, times, roi, win_sample, n_jobs=1, gcrn=True, verbose=None):
16+
def conn_dfc(data, win_sample, times=None, roi=None, n_jobs=1, gcrn=True,
17+
verbose=None):
1518
"""Compute the Dynamic Functional Connectivity using the GCMI.
1619
1720
This function computes the Dynamic Functional Connectivity (DFC) using the
@@ -24,14 +27,14 @@ def conn_dfc(data, times, roi, win_sample, n_jobs=1, gcrn=True, verbose=None):
2427
data : array_like
2528
Electrophysiological data array of a single subject organized as
2629
(n_epochs, n_roi, n_times)
27-
times : array_like
28-
Time vector array of shape (n_times,)
29-
roi : array_like
30-
ROI names of a single subject
3130
win_sample : array_like
3231
Array of shape (n_windows, 2) describing where each window start and
3332
finish. You can use the function :func:`frites.utils.define_windows`
3433
to define either manually either sliding windows.
34+
times : array_like | None
35+
Time vector array of shape (n_times,)
36+
roi : array_like | None
37+
ROI names of a single subject
3538
n_jobs : int | 1
3639
Number of jobs to use for parallel computing (use -1 to use all
3740
jobs). The parallel loop is set at the pair level.
@@ -54,9 +57,12 @@ def conn_dfc(data, times, roi, win_sample, n_jobs=1, gcrn=True, verbose=None):
5457
define_windows, conn_covgc
5558
"""
5659
set_log_level(verbose)
60+
# -------------------------------------------------------------------------
61+
# inputs conversion
62+
da, roi, times = conn_io(data, roi=roi, times=times, verbose=verbose)
63+
5764
# -------------------------------------------------------------------------
5865
# data checking
59-
assert isinstance(data, np.ndarray) and (data.ndim == 3)
6066
n_epochs, n_roi, n_pts = data.shape
6167
assert (len(roi) == n_roi) and (len(times) == n_pts)
6268
assert isinstance(win_sample, np.ndarray) and (win_sample.ndim == 2)

frites/conn/conn_io.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Manage I/O for connectivity functions."""
2+
import numpy as np
3+
import xarray as xr
4+
5+
6+
from frites.io import set_log_level, logger
7+
8+
9+
def conn_io(da, roi=None, times=None, verbose=None):
10+
"""I/O conversion for connectivity functions.
11+
12+
Parameters
13+
----------
14+
da : array_like
15+
Array of electrophysiological data of shape (n_trials, n_roi, n_times)
16+
roi : array_like | None
17+
List of roi names or string corresponding to the dimension name in a
18+
DataArray
19+
times : array_like | None
20+
Time vector or string corresponding to the dimension name in a
21+
DataArray
22+
"""
23+
set_log_level(verbose)
24+
assert isinstance(da, np.ndarray) or isinstance(da, xr.DataArray)
25+
assert da.ndim == 3
26+
n_trials, n_roi, n_times = da.shape
27+
logger.info(f"Inputs conversion (n_trials={n_trials}, n_roi={n_roi}, "
28+
"n_times={n_times})")
29+
30+
# _____________________________ Empty inputs ______________________________
31+
if roi is None:
32+
roi = [f"roi_{k}" for k in range(n_roi)]
33+
if times is None:
34+
times = np.arange(n_times)
35+
36+
# _______________________________ Xarray case _____________________________
37+
if isinstance(da, xr.DataArray):
38+
# get roi and times
39+
if isinstance(roi, str):
40+
roi = da[roi].data
41+
if isinstance(times, str):
42+
times = da[times].data
43+
da = da.data
44+
45+
# _______________________________ Final check _____________________________
46+
assert isinstance(da, np.ndarray)
47+
assert da.shape == (n_trials, len(roi), len(times))
48+
49+
return da, roi, times

frites/conn/tests/test_conn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ def test_conn_dfc(self):
4444
roi = [f"roi_{k}" for k in range(n_roi)]
4545
x = np.random.rand(n_epochs, n_roi, n_times)
4646

47-
dfc = conn_dfc(x, times, roi, win_sample)[0]
47+
dfc = conn_dfc(x, win_sample, times=times, roi=roi)[0]
4848
assert dfc.shape == (n_epochs, 3, 2)
49-
dfc = conn_dfc(x, times, roi, win_sample)[0]
49+
dfc = conn_dfc(x, win_sample, times=times, roi=roi)[0]
5050
assert isinstance(dfc, xr.DataArray)
5151

5252
def test_conn_covgc(self):
@@ -75,7 +75,7 @@ def test_conn_reshape_undirected(self):
7575
roi = [f"roi_{k}" for k in range(n_roi)]
7676
order = ['roi_2', 'roi_1']
7777
x = np.random.rand(n_epochs, n_roi, n_times)
78-
dfc = conn_dfc(x, times, roi, win_sample)[0].mean('trials')
78+
dfc = conn_dfc(x, win_sample, times=times, roi=roi)[0].mean('trials')
7979
# reshape it without the time dimension
8080
dfc_mean = conn_reshape_undirected(dfc.mean('times'))
8181
assert dfc_mean.shape == (n_roi, n_roi, 1)

frites/dataset/ds_ephy_io.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@
33

44
import numpy as np
55

6-
from frites.io import set_log_level
6+
from frites.io import set_log_level, logger
77
from frites.config import CONFIG
88

9-
logger = logging.getLogger("frites")
10-
119

1210
###############################################################################
1311
###############################################################################

0 commit comments

Comments
 (0)