Skip to content

Commit a33d536

Browse files
committed
move every connectivity measures to frites.conn
1 parent 8378b6e commit a33d536

File tree

14 files changed

+301
-299
lines changed

14 files changed

+301
-299
lines changed

docs/source/api.rst

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,28 @@ Workflow
4242
WfFit
4343
WfStatsEphy
4444

45+
.. raw:: html
46+
47+
<hr>
48+
49+
Connectivity
50+
------------
51+
52+
.. currentmodule:: frites.conn
53+
54+
.. automodule:: frites.conn
55+
:no-members:
56+
:no-inherited-members:
57+
58+
.. autosummary::
59+
:toctree: generated/
60+
61+
conn_dfc
62+
conn_covgc
63+
conn_fit
64+
conn_transfer_entropy
65+
66+
4567
.. raw:: html
4668

4769
<hr>
@@ -254,14 +276,3 @@ Gaussian-Copula based measures to apply to multidimensional vectors
254276
gccmi_nd_ccnd
255277
gccmi_model_nd_cdnd
256278
gccmi_nd_ccc
257-
258-
Core connectivity
259-
+++++++++++++++++
260-
261-
.. autosummary::
262-
:toctree: generated/
263-
264-
dfc_gc
265-
it_transfer_entropy
266-
it_fit
267-
covgc

docs/source/references.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ References
22
----------
33

44
.. bibliography:: refs.bib
5-
:style: plain
5+
:style: plain

examples/conn/plot_covgc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from itertools import product
1010

1111
from frites.simulations import sim_single_suj_ephy
12-
from frites.core import covgc
12+
from frites.conn import conn_covgc
1313

1414
import matplotlib.pyplot as plt
1515
plt.style.use('seaborn-white')
@@ -58,8 +58,8 @@
5858
t0 = np.arange(100, 900, 10)
5959
lag = 10
6060
dt = 100
61-
gc, pairs, roi_p, times_p = covgc(x, dt, lag, t0, times=times, roi=roi,
62-
n_jobs=1)
61+
gc, pairs, roi_p, times_p = conn_covgc(x, dt, lag, t0, times=times, roi=roi,
62+
n_jobs=1)
6363
# take the mean across trials
6464
gc = gc.mean('trials')
6565

examples/conn/plot_dfc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from itertools import product
1111

1212
from frites.simulations import sim_single_suj_ephy
13-
from frites.core import dfc_gc
13+
from frites.conn import conn_dfc
1414
from frites.utils import define_windows, plot_windows
1515

1616
import matplotlib.pyplot as plt
@@ -72,7 +72,7 @@
7272
# each of the temporal window
7373

7474
# compute DFC
75-
dfc, pairs, roi_p = dfc_gc(x, times, roi, win_sample)
75+
dfc, pairs, roi_p = conn_dfc(x, times, roi, win_sample)
7676

7777
# sphinx_gallery_thumbnail_number = 2
7878
plt.figure(figsize=(10, 8))

frites/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77
import logging
88

9-
from frites import io, core, stats, utils, workflow, simulations # noqa
9+
from frites import io, core, conn, stats, utils, workflow, simulations # noqa
1010

1111
__version__ = "0.3.4"
1212

frites/conn/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Information-theritical measures of connectivity."""
2+
from .conn_covgc import conn_covgc # noqa
3+
from .conn_dfc import conn_dfc # noqa
4+
from .conn_transfer_entropy import conn_transfer_entropy # noqa
5+
from .conn_fit import conn_fit # noqa

frites/core/covgc.py renamed to frites/conn/conn_covgc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ def _gccovgc(d_s, d_t, ind_tx, t0):
128128

129129

130130

131-
def covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gauss',
132-
verbose=None, n_jobs=-1):
131+
def conn_covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gauss',
132+
n_jobs=-1, verbose=None):
133133
r"""Single-trial covariance-based Granger Causality for gaussian variables.
134134
135135
This function computes the covariance-based Granger Causality (covgc) for
@@ -202,7 +202,7 @@ def covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gauss',
202202
203203
See also
204204
--------
205-
dfc_gc
205+
conn_dfc
206206
"""
207207
set_log_level(verbose)
208208
# -------------------------------------------------------------------------

frites/conn/conn_dfc.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""Dynamic Functional Connectivity."""
2+
import numpy as np
3+
import xarray as xr
4+
5+
from frites.io import set_log_level, logger
6+
7+
from frites.core import mi_nd_gg, copnorm_nd
8+
from frites.config import CONFIG
9+
10+
11+
12+
def conn_dfc(data, times, roi, win_sample, verbose=None):
13+
"""Compute the Dynamic Functional Connectivity using the GCMI.
14+
15+
This function computes the Dynamic Functional Connectivity (DFC) using the
16+
Gaussian Copula Mutual Information (GCMI). The DFC is computed across time
17+
points for each trial. Note that the DFC can either be computed on windows
18+
manually defined or on sliding windows.
19+
20+
Parameters
21+
----------
22+
data : array_like
23+
Electrophysiological data array of a single subject organized as
24+
(n_epochs, n_roi, n_times)
25+
times : array_like
26+
Time vector array of shape (n_times,)
27+
roi : array_like
28+
ROI names of a single subject
29+
win_sample : array_like
30+
Array of shape (n_windows, 2) describing where each window start and
31+
finish. You can use the function :func:`frites.utils.define_windows`
32+
to define either manually either sliding windows.
33+
34+
Returns
35+
-------
36+
dfc : array_like
37+
The DFC array of shape (n_epochs, n_pairs, n_windows)
38+
pairs : array_like
39+
Array of pairs of shape (n_pairs, 2)
40+
roi_p : array_like
41+
Array of shape (n_pairs,) describing the name of each pair
42+
43+
See also
44+
--------
45+
define_windows, covgc
46+
"""
47+
set_log_level(verbose)
48+
# -------------------------------------------------------------------------
49+
# data checking
50+
assert isinstance(data, np.ndarray) and (data.ndim == 3)
51+
n_epochs, n_roi, n_pts = data.shape
52+
assert (len(roi) == n_roi) and (len(times) == n_pts)
53+
assert isinstance(win_sample, np.ndarray) and (win_sample.ndim == 2)
54+
assert win_sample.dtype in CONFIG['INT_DTYPE']
55+
n_win = win_sample.shape[0]
56+
# get the non-directed pairs
57+
x_s, x_t = np.triu_indices(n_roi, k=1)
58+
n_pairs = len(x_s)
59+
pairs = np.c_[x_s, x_t]
60+
# build roi pairs names
61+
roi_p = [f"{roi[s]}-{roi[t]}" for s, t in zip(x_s, x_t)]
62+
63+
# -------------------------------------------------------------------------
64+
# compute dfc
65+
logger.info(f'Computing DFC between {n_pairs} pairs')
66+
dfc = np.zeros((n_epochs, n_pairs, n_win), dtype=np.float32)
67+
for n_w, w in enumerate(win_sample):
68+
# select the data in the window and copnorm across time points
69+
data_w = copnorm_nd(data[..., w[0]:w[1]], axis=2)
70+
# compute mi between pairs
71+
for n_p, (s, t) in enumerate(zip(x_s, x_t)):
72+
dfc[:, n_p, n_w] = mi_nd_gg(data_w[:, [s], :], data_w[:, [t], :],
73+
**CONFIG["KW_GCMI"])
74+
75+
# -------------------------------------------------------------------------
76+
# dataarray conversion
77+
trials = np.arange(n_epochs)
78+
win_times = times[win_sample]
79+
dfc = xr.DataArray(dfc, dims=('trials', 'roi', 'times'),
80+
coords=(trials, roi_p, win_times.mean(1)))
81+
# add the windows used in the attributes
82+
dfc.attrs['win_sample'] = np.r_[tuple(win_sample)]
83+
dfc.attrs['win_times'] = np.r_[tuple(win_times)]
84+
85+
return dfc, pairs, roi_p

frites/conn/conn_fit.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""Feature specific information transfer (Numba compliant)."""
2+
import numpy as np
3+
4+
from frites.utils import jit
5+
6+
7+
@jit("f4[:,:,:](f4[:,:,:], f4[:,:,:], f4[:], f4)")
8+
def conn_fit(x_s, x_t, times, max_delay): # noqa
9+
"""Compute Feature-specific Information Transfer (FIT).
10+
11+
This function has been written for supporting 3D arrays. If Numba is
12+
installed, performances of this function can be greatly improved.
13+
14+
Parameters
15+
----------
16+
x_s : array_like
17+
Array to use as source. Must be a 3d array of shape (:, :, n_times)
18+
and of type np.float32
19+
x_t : array_like
20+
Array to use as target. Must be a 3d array of shape (:, :, n_times)
21+
and of type np.float32
22+
times : array_like
23+
Time vector of shape (n_times,) and of type np.float32
24+
max_delay : float | .3
25+
Maximum delay (must be a np.float32)
26+
27+
Returns
28+
-------
29+
fit : array_like
30+
Array of FIT of shape (:, :, n_times - max_delay)
31+
"""
32+
# ---------------------------------------------------------------------
33+
n_dim, n_suj, n_times = x_s.shape
34+
# time indices for target roi
35+
t_start = np.where(times > times[0] + max_delay)[0]
36+
# max delay index
37+
max_delay = n_times - len(t_start)
38+
39+
# ---------------------------------------------------------------------
40+
# Compute FIT on original MI values
41+
fit = np.zeros((n_dim, n_suj, n_times - max_delay), dtype=np.float32)
42+
43+
# mi at target roi in the present
44+
x_t_pres = x_t[:, :, t_start]
45+
46+
# Loop over delays for past of target and sources
47+
for delay in range(1, max_delay):
48+
# get past delay indices
49+
past_delay = t_start - delay
50+
# mi at target roi in the past
51+
x_t_past = x_t[:, :, past_delay]
52+
# mi at sources roi in the past
53+
x_s_past = x_s[:, :, past_delay]
54+
# redundancy between sources and target (min MI)
55+
red_s_t = np.minimum(x_t_pres, x_s_past)
56+
# redundancy between sources, target present and target past
57+
red_all = np.minimum(red_s_t, x_t_past)
58+
# sum delay-specific FIT (source, target)
59+
fit += red_s_t - red_all
60+
61+
return fit

frites/conn/conn_transfer_entropy.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""Transfer entropy using the Gaussian-Copula."""
2+
import numpy as np
3+
import xarray as xr
4+
5+
from frites.core import cmi_nd_ggg, copnorm_nd
6+
from frites.config import CONFIG
7+
8+
9+
def conn_transfer_entropy(x, max_delay=30, pairs=None, gcrn=True):
10+
"""Compute the transfer entropy.
11+
12+
The transfer entropy represents the amount of information that is send
13+
from a source to a target. It is defined as :
14+
15+
.. math::
16+
17+
TE = I(source_{past}; target_{present} | target_{past})
18+
19+
Where :math:`past` is defined using the `max_delay` input parameter. Note
20+
that the transfer entropy only provides about the amount of information
21+
that is sent, not on the content.
22+
23+
Parameters
24+
----------
25+
x : array_like
26+
Array of data of shape (n_roi, n_times, n_epochs). Must be a gaussian
27+
variable
28+
max_delay : int | 30
29+
Number of time points defining where to stop looking at in the past.
30+
Increasing this maximum delay input can lead to slower computations
31+
pairs : array_like
32+
Array of pairs to consider for computing the transfer entropy. It
33+
should be an array of shape (n_pairs, 2) where the first column refers
34+
to sources and the second to targets. If None, all pairs will be
35+
computed
36+
gcrn : bool | True
37+
Apply a Gaussian Copula rank normalization
38+
39+
Returns
40+
-------
41+
te : array_like
42+
The transfer entropy array of shape (n_pairs, n_times - max_delay)
43+
pairs : array_like
44+
Pairs vector use for computations of shape (n_pairs, 2)
45+
"""
46+
# -------------------------------------------------------------------------
47+
# check pairs
48+
n_roi, n_times, n_epochs = x.shape
49+
if not isinstance(pairs, np.ndarray):
50+
pairs = np.c_[np.where(~np.eye(n_roi, dtype=bool))]
51+
assert isinstance(pairs, np.ndarray) and (pairs.ndim == 2) and (
52+
pairs.shape[1] == 2), ("`pairs` should be a 2d array of shape "
53+
"(n_pairs, 2) where the first column refers to "
54+
"sources and the second to targets")
55+
x_all_s, x_all_t = pairs[:, 0], pairs[:, 1]
56+
n_pairs = len(x_all_s)
57+
# check max_delay
58+
assert isinstance(max_delay, (int, np.int)), ("`max_delay` should be an "
59+
"integer")
60+
# check input data
61+
assert (x.ndim == 3), ("input data `x` should be a 3d array of shape "
62+
"(n_roi, n_times, n_epochs)")
63+
x = x[..., np.newaxis, :]
64+
65+
# -------------------------------------------------------------------------
66+
# apply copnorm
67+
if gcrn:
68+
x = copnorm_nd(x, axis=-1)
69+
70+
# -------------------------------------------------------------------------
71+
# compute the transfer entropy
72+
te = np.zeros((n_pairs, n_times - max_delay), dtype=float)
73+
for n_s, x_s in enumerate(x_all_s):
74+
# select targets
75+
is_source = x_all_s == x_s
76+
x_t = x_all_t[is_source]
77+
targets = x[x_t, ...]
78+
# tile source
79+
source = np.tile(x[[x_s], ...], (targets.shape[0], 1, 1, 1))
80+
# loop over remaining time points
81+
for n_d, d in enumerate(range(max_delay + 1, n_times)):
82+
t_pres = np.tile(targets[:, [d], :], (1, max_delay, 1, 1))
83+
past = slice(d - max_delay - 1, d - 1)
84+
s_past = source[:, past, ...]
85+
t_past = targets[:, past, ...]
86+
# compute the transfer entropy
87+
_te = cmi_nd_ggg(s_past, t_pres, t_past, **CONFIG["KW_GCMI"])
88+
# take the sum over delays
89+
te[is_source, n_d] = _te.mean(1)
90+
91+
return te, pairs

0 commit comments

Comments
 (0)