Skip to content

Commit 65bf08e

Browse files
committed
Improve conn_io and adapt all connectivity functions
1 parent 60cc160 commit 65bf08e

File tree

7 files changed

+49
-53
lines changed

7 files changed

+49
-53
lines changed

examples/conn/plot_covgc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@
6060
t0 = np.arange(100, 900, 10)
6161
lag = 10
6262
dt = 100
63-
gc, pairs, roi_p, times_p = conn_covgc(x, dt, lag, t0, times=times, roi=roi,
64-
n_jobs=1)
63+
gc = conn_covgc(x, dt, lag, t0, times=times, roi=roi, n_jobs=1)
64+
roi_p = gc['roi'].data
6565

6666
###############################################################################
6767
# Below we plot the mean time series of both directed and undirected covgc

examples/conn/plot_dfc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@
7272
# each of the temporal window
7373

7474
# compute DFC
75-
dfc, pairs, roi_p = conn_dfc(x, win_sample, times=times, roi=roi, n_jobs=1)
75+
dfc = conn_dfc(x, win_sample, times=times, roi=roi, n_jobs=1)
76+
print(dfc)
7677

7778
plt.figure(figsize=(10, 8))
7879
plt.plot(times_p, dfc.mean('trials').T)

frites/conn/conn_covgc.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,6 @@ def conn_covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gc',
284284
* 0 : pairs[:, 0] -> pairs[:, 1] (x->y)
285285
* 1 : pairs[:, 1] -> pairs[:, 0] (y->x)
286286
* 2 : instantaneous (x.y)
287-
pairs : array_like
288-
Array of pairs of shape (n_pairs, 2)
289287
290288
References
291289
----------
@@ -304,15 +302,12 @@ def conn_covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gc',
304302
t0 = np.asarray(t0).astype(int)
305303
dt, lag, step, trials = int(dt), int(lag), int(step), None
306304
# handle dataarray input
307-
if isinstance(data, xr.DataArray):
308-
trials = data['trials'].data
309-
data, roi, times = conn_io(data, roi=roi, times=times, verbose=verbose)
305+
data, trials, roi, times, attrs = conn_io(data, roi=roi, times=times,
306+
verbose=verbose)
310307
# force C contiguous array because operations on row-major
311308
if not data.flags.c_contiguous:
312309
data = np.ascontiguousarray(data)
313310
n_epochs, n_roi, n_times = data.shape
314-
if trials is None:
315-
trials = np.arange(n_epochs)
316311
# default roi vector
317312
if roi is None:
318313
roi = np.array([f"roi_{k}" for k in range(n_roi)])
@@ -370,32 +365,28 @@ def conn_covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gc',
370365
# change output type
371366
dire = np.array(['x->y', 'y->x', 'x.y'])
372367
gc = xr.DataArray(gc, dims=('trials', 'roi', 'times', 'direction'),
373-
coords=(trials, roi_p, times_p, dire))
368+
coords=(trials, roi_p, times_p, dire), name='covgc')
374369
# set attributes
375-
gc.attrs['lag'] = lag
376-
gc.attrs['step'] = step
377-
gc.attrs['dt'] = dt
378-
gc.attrs['t0'] = t0
379-
gc.attrs['conditional'] = conditional
380-
gc.attrs['type'] = 'covgc'
381-
gc.name = 'covgc'
370+
cfg = dict(lag='lag', step='step', dt='dt', t0='t0',
371+
conditional='conditional', type='covgc')
372+
gc.attrs = {**attrs, **cfg}
382373

383-
return gc, pairs, roi_p, times_p
374+
return gc
384375

385376

386377
if __name__ == '__main__':
387378
from frites.simulations import StimSpecAR
388379
import matplotlib.pyplot as plt
389380

390381
ss = StimSpecAR()
391-
ar = ss.fit(ar_type='ding_3', n_stim=2, n_epochs=20)
382+
ar = ss.fit(ar_type='ding_3_direct', n_stim=2, n_epochs=20)
392383
# plot the model
393384
# plt.figure(figsize=(7, 8))
394385
# ss.plot()
395386
# compute covgc
396387
dt, lag, step = 50, 5, 2
397388
t0 = np.arange(lag, ar.shape[-1] - dt, step)
398389
gc = conn_covgc(ar, roi='roi', times='times', dt=dt, lag=lag, t0=t0,
399-
n_jobs=-1, conditional=False)[0]
390+
n_jobs=-1, conditional=False)
400391
ss.plot_covgc(gc=gc)
401392
plt.show()

frites/conn/conn_dfc.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,6 @@ def conn_dfc(data, win_sample, times=None, roi=None, n_jobs=1, gcrn=True,
4747
-------
4848
dfc : array_like
4949
The DFC array of shape (n_epochs, n_pairs, n_windows)
50-
pairs : array_like
51-
Array of pairs of shape (n_pairs, 2)
52-
roi_p : array_like
53-
Array of shape (n_pairs,) describing the name of each pair
5450
5551
See also
5652
--------
@@ -59,7 +55,8 @@ def conn_dfc(data, win_sample, times=None, roi=None, n_jobs=1, gcrn=True,
5955
set_log_level(verbose)
6056
# -------------------------------------------------------------------------
6157
# inputs conversion
62-
data, roi, times = conn_io(data, roi=roi, times=times, verbose=verbose)
58+
data, trials, roi, times, attrs = conn_io(data, roi=roi, times=times,
59+
verbose=verbose)
6360

6461
# -------------------------------------------------------------------------
6562
# data checking
@@ -100,14 +97,12 @@ def conn_dfc(data, win_sample, times=None, roi=None, n_jobs=1, gcrn=True,
10097

10198
# -------------------------------------------------------------------------
10299
# dataarray conversion
103-
trials = np.arange(n_epochs)
104100
win_times = times[win_sample]
105-
dfc = xr.DataArray(dfc, dims=('trials', 'roi', 'times'),
101+
dfc = xr.DataArray(dfc, dims=('trials', 'roi', 'times'), name='dfc',
106102
coords=(trials, roi_p, win_times.mean(1)))
107103
# add the windows used in the attributes
108-
dfc.attrs['win_sample'] = np.r_[tuple(win_sample)]
109-
dfc.attrs['win_times'] = np.r_[tuple(win_times)]
110-
dfc.attrs['type'] = 'dfc'
111-
dfc.name = 'dfc'
104+
cfg = dict(win_sample=np.r_[tuple(win_sample)],
105+
win_times=np.r_[tuple(win_times)], type='dfc')
106+
dfc.attrs = {**cfg, **attrs}
112107

113-
return dfc, pairs, roi_p
108+
return dfc

frites/conn/conn_io.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from frites.io import set_log_level, logger
77

88

9-
def conn_io(da, roi=None, times=None, verbose=None):
9+
def conn_io(da, trials=None, roi=None, times=None, verbose=None):
1010
"""I/O conversion for connectivity functions.
1111
1212
Parameters
@@ -24,26 +24,35 @@ def conn_io(da, roi=None, times=None, verbose=None):
2424
assert isinstance(da, np.ndarray) or isinstance(da, xr.DataArray)
2525
assert da.ndim == 3
2626
n_trials, n_roi, n_times = da.shape
27+
attrs = dict(n_trials=n_trials, n_roi=n_roi, n_times=n_times)
2728
logger.info(f"Inputs conversion (n_trials={n_trials}, n_roi={n_roi}, "
28-
"n_times={n_times})")
29-
30-
# _____________________________ Empty inputs ______________________________
31-
if roi is None:
32-
roi = [f"roi_{k}" for k in range(n_roi)]
33-
if times is None:
34-
times = np.arange(n_times)
29+
f"n_times={n_times})")
3530

3631
# _______________________________ Xarray case _____________________________
3732
if isinstance(da, xr.DataArray):
38-
# get roi and times
33+
# force using
34+
if trials is None:
35+
trials = da.dims[0]
36+
# get trials, roi and times
37+
if isinstance(trials, str):
38+
trials = da[trials].data
3939
if isinstance(roi, str):
4040
roi = da[roi].data
4141
if isinstance(times, str):
4242
times = da[times].data
43+
attrs = {**attrs, **da.attrs}
4344
da = da.data
4445

46+
# _____________________________ Empty inputs ______________________________
47+
if roi is None:
48+
roi = [f"roi_{k}" for k in range(n_roi)]
49+
if times is None:
50+
times = np.arange(n_times)
51+
if trials is None:
52+
trials = np.arange(n_trials)
53+
4554
# _______________________________ Final check _____________________________
4655
assert isinstance(da, np.ndarray)
47-
assert da.shape == (n_trials, len(roi), len(times))
56+
assert da.shape == (len(trials), len(roi), len(times))
4857

49-
return da, roi, times
58+
return da, trials, roi, times, attrs

frites/conn/tests/test_conn.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ def test_conn_dfc(self):
4444
roi = [f"roi_{k}" for k in range(n_roi)]
4545
x = np.random.rand(n_epochs, n_roi, n_times)
4646

47-
dfc = conn_dfc(x, win_sample, times=times, roi=roi)[0]
47+
dfc = conn_dfc(x, win_sample, times=times, roi=roi)
4848
assert dfc.shape == (n_epochs, 3, 2)
49-
dfc = conn_dfc(x, win_sample, times=times, roi=roi)[0]
49+
dfc = conn_dfc(x, win_sample, times=times, roi=roi)
5050
assert isinstance(dfc, xr.DataArray)
5151

5252
def test_conn_covgc(self):
@@ -59,12 +59,12 @@ def test_conn_covgc(self):
5959
lag = 2
6060
t0 = [50, 80]
6161

62-
_ = conn_covgc(x, dt, lag, t0, n_jobs=1, method='gc')[0]
63-
gc = conn_covgc(x, dt, lag, t0, n_jobs=1, method='gauss')[0]
62+
_ = conn_covgc(x, dt, lag, t0, n_jobs=1, method='gc')
63+
gc = conn_covgc(x, dt, lag, t0, n_jobs=1, method='gauss')
6464
assert gc.shape == (n_epochs, 3, len(t0), 3)
6565
assert isinstance(gc, xr.DataArray)
6666
gc = conn_covgc(x, dt, lag, t0, n_jobs=1, method='gc',
67-
conditional=True)[0]
67+
conditional=True)
6868

6969
def test_conn_reshape_undirected(self):
7070
"""Test function conn_reshape_undirected."""
@@ -75,7 +75,7 @@ def test_conn_reshape_undirected(self):
7575
roi = [f"roi_{k}" for k in range(n_roi)]
7676
order = ['roi_2', 'roi_1']
7777
x = np.random.rand(n_epochs, n_roi, n_times)
78-
dfc = conn_dfc(x, win_sample, times=times, roi=roi)[0].mean('trials')
78+
dfc = conn_dfc(x, win_sample, times=times, roi=roi).mean('trials')
7979
# reshape it without the time dimension
8080
dfc_mean = conn_reshape_undirected(dfc.mean('times'))
8181
assert dfc_mean.shape == (n_roi, n_roi, 1)
@@ -98,7 +98,7 @@ def test_conn_reshape_directed(self):
9898
dt, lag, t0 = 10, 2, [50, 80]
9999
order = ['roi_2', 'roi_1']
100100
# compute covgc
101-
gc = conn_covgc(x, dt, lag, t0, n_jobs=1, method='gauss')[0]
101+
gc = conn_covgc(x, dt, lag, t0, n_jobs=1, method='gauss')
102102
gc = gc.mean('trials')
103103
# reshape it without the time dimension
104104
gc_mean = conn_reshape_directed(gc.copy().mean('times'))

frites/simulations/sim_ar.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,8 +315,8 @@ def compute_covgc(self, ar, dt=50, lag=5, step=1, method='gc',
315315
"""
316316
# compute the granger causality
317317
t0 = np.arange(lag, ar.shape[-1] - dt, step)
318-
gc, _, _, _ = conn_covgc(ar, dt, lag, t0, times='times', method=method,
319-
roi='roi', step=1, conditional=conditional)
318+
gc = conn_covgc(ar, dt, lag, t0, times='times', method=method,
319+
roi='roi', step=1, conditional=conditional)
320320
gc['trials'] = ar['trials']
321321
self._gc = gc
322322
# compute the MI between stimulus / raw

0 commit comments

Comments
 (0)