Skip to content

Commit 43fceb0

Browse files
committed
Implementation + test of the cross-correlation function
1 parent 70b184c commit 43fceb0

File tree

5 files changed

+182
-6
lines changed

5 files changed

+182
-6
lines changed

docs/source/api/api_connectivity.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Connectivity metrics
1717

1818
conn_dfc
1919
conn_covgc
20+
conn_ccf
2021
conn_transfer_entropy
2122

2223
Utility functions

frites/conn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .conn_io import conn_io # noqa
1515

1616
# connectivity metrics
17+
from .conn_ccf import conn_ccf # noqa
1718
from .conn_covgc import conn_covgc # noqa
1819
from .conn_dfc import conn_dfc # noqa
1920
from .conn_transfer_entropy import conn_transfer_entropy # noqa

frites/conn/conn_ccf.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""Cross-correlation function."""
2+
import numpy as np
3+
import xarray as xr
4+
5+
from frites.conn import conn_io
6+
from frites.io import logger
7+
from frites.estimator import GCMIEstimator
8+
from frites.utils import parallel_func
9+
from frites.utils.preproc import _acf
10+
11+
12+
13+
def conn_ccf(data, times=None, roi=None, normalized=True, n_jobs=1,
14+
verbose=None):
15+
"""Single trial Cross-Correlation Function.
16+
17+
This function computes the pairwise Cross Correlation (CCF) at the single
18+
trial level. This can be particulary usefull to find whether there are
19+
temporal delays between times series.
20+
21+
Parameters
22+
----------
23+
data : array_like
24+
Electrophysiological data. Several input types are supported :
25+
26+
* Standard NumPy arrays of shape (n_epochs, n_roi, n_times)
27+
* mne.Epochs
28+
* xarray.DataArray of shape (n_epochs, n_roi, n_times)
29+
30+
times : array_like | None
31+
Time vector array of shape (n_times,). If the input is an xarray, the
32+
name of the time dimension can be provided
33+
roi : array_like | None
34+
ROI names of a single subject. If the input is an xarray, the
35+
name of the ROI dimension can be provided
36+
normalized : bool | True
37+
Z-score normalization of the data. By default, it set to true.
38+
n_jobs : int | 1
39+
Number of jobs to use for parallel computing (use -1 to use all
40+
jobs). The parallel loop is set at the pair level.
41+
42+
Returns
43+
-------
44+
ccf : array_like
45+
The Cross-Correlation array of shape (n_epochs, n_pairs, n_times). When
46+
the peak of correlation occurs at a negative time it means that the
47+
target has to be moved **toward** the source. On the contrary, if the
48+
peak occurs at positive time it means that the target is moved **away**
49+
of the source.
50+
"""
51+
# ________________________________ INPUTS _________________________________
52+
# inputs conversion
53+
data, cfg = conn_io(
54+
data, times=times, roi=roi, agg_ch=False, win_sample=None, pairs=None,
55+
sort=True, name='CCF', verbose=verbose,
56+
)
57+
58+
# extract variables
59+
x, trials, attrs = data.data, data['y'].data, cfg['attrs']
60+
x_s, x_t = cfg['x_s'], cfg['x_t']
61+
roi_p, roi_idx = cfg['roi_p'], cfg['roi_idx']
62+
times = data['times'].data
63+
n_pairs = len(x_s)
64+
65+
# data normalization
66+
if normalized:
67+
x = (x - x.mean(-1, keepdims=True)) / x.std(-1, keepdims=True)
68+
69+
# __________________________________ CCF __________________________________
70+
# function to put in parallel
71+
def para_fun(xs, xt):
72+
n_trials = xs.shape[0]
73+
corr = np.zeros((n_trials, int(2 * len(times)) - 1))
74+
for n_t in range(n_trials):
75+
corr[n_t, :] = _acf(xs[n_t, :], xt[n_t, :])
76+
return corr
77+
78+
# prepare parallel function
79+
n_jobs = 1 if n_pairs == 1 else n_jobs
80+
parallel, p_fun = parallel_func(para_fun, n_jobs=n_jobs, verbose=verbose,
81+
total=n_pairs, mesg='Estimating CCF')
82+
83+
logger.info(f'Computing CCF between {n_pairs} pairs')
84+
85+
# compute ccf
86+
ccf = parallel(
87+
p_fun(x[:, i_s, :], x[:, i_t, :]) for i_s, i_t in zip(x_s, x_t))
88+
ccf = np.stack(ccf, axis=1)
89+
90+
# ________________________________ OUTPUTS ________________________________
91+
# dataarray conversion
92+
times_n = np.arange(ccf.shape[-1]).astype(float)# / cfg['sfreq']
93+
times_n -= times_n.mean()
94+
ccf = xr.DataArray(ccf, dims=('trials', 'roi', 'times'), name=f'CCF',
95+
coords=(trials, roi_p, times_n))
96+
97+
# add the windows used in the attributes
98+
ccf.attrs = {**dict(type='ccf', normalized=int(normalized)), **attrs}
99+
100+
return ccf
101+
102+
if __name__ == '__main__':
103+
import matplotlib.pyplot as plt
104+
from frites.estimator import CorrEstimator
105+
106+
n_trials = 20
107+
n_roi = 3
108+
n_times = 1000
109+
# create coordinates
110+
trials = np.arange(n_trials)
111+
roi = [f"roi_{k}" for k in range(n_roi)]
112+
times = (np.arange(n_times) - 200) / 64.
113+
# data creation
114+
x = np.random.rand(n_trials, n_roi, n_times)
115+
# inject relation
116+
bump = np.hanning(200).reshape(1, -1)
117+
x[:, 0, 200:400] += bump
118+
x[:, 1, 220:420] += bump
119+
x[:, 2, 260:460] += bump
120+
# xarray conversion
121+
x = xr.DataArray(x, dims=('trials', 'roi', 'times'),
122+
coords=(trials, roi, times))
123+
plt.figure(figsize=(15, 6))
124+
125+
# compute delayed dfc
126+
ccf = conn_ccf(x, times='times', roi='roi', n_jobs=1, verbose=False)
127+
128+
plt.subplot(121)
129+
x.mean('trials').plot(x='times', hue='roi')
130+
plt.subplot(122)
131+
ccf.mean('trials').plot(x='times', hue='roi')
132+
plt.show()

frites/conn/tests/test_conn.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
import xarray as xr
44

5-
from frites.conn import (conn_covgc, conn_transfer_entropy, conn_dfc)
5+
from frites.conn import (conn_covgc, conn_transfer_entropy, conn_dfc, conn_ccf)
66

77

88
class TestConn(object):
@@ -70,3 +70,41 @@ def test_conn_covgc(self):
7070
assert isinstance(gc, xr.DataArray)
7171
gc = conn_covgc(x, dt, lag, t0, n_jobs=1, method='gc',
7272
conditional=True, norm=False)
73+
74+
def test_conn_ccf(self):
75+
"""Test function conn_ccf."""
76+
n_trials, n_roi, n_times = 20, 3, 1000
77+
# create coordinates
78+
trials = np.arange(n_trials)
79+
roi = [f"roi_{k}" for k in range(n_roi)]
80+
times = (np.arange(n_times) - 200) / 64.
81+
# data creation
82+
rnd = np.random.RandomState(0)
83+
x = .1 * rnd.rand(n_trials, n_roi, n_times)
84+
# inject relation
85+
bump = np.hanning(200).reshape(1, -1)
86+
x[:, 0, 200:400] += bump
87+
x[:, 1, 220:420] += bump
88+
x[:, 2, 150:350] += bump
89+
# xarray conversion
90+
x = xr.DataArray(x, dims=('trials', 'roi', 'times'),
91+
coords=(trials, roi, times))
92+
# compute delayed dfc
93+
conn_ccf(x, times='times', roi='roi', n_jobs=1, verbose=False,
94+
normalized=False)
95+
ccf = conn_ccf(x, times='times', roi='roi', n_jobs=1, verbose=False)
96+
# shape and dimension checking
97+
assert ccf.ndim == 3
98+
assert ccf.dims == ('trials', 'roi', 'times')
99+
assert len(ccf['trials']) == len(trials)
100+
np.testing.assert_array_equal(ccf['trials'].data, trials)
101+
assert len(ccf['roi']) == 3
102+
# peak detection
103+
ccf_m = ccf.mean('trials')
104+
is_peaks = np.where(ccf_m == ccf_m.max('times'))
105+
peaks = ccf['times'].data[is_peaks[1]]
106+
# peak checking
107+
tol = 5
108+
assert -20 - tol <= peaks[0] <= -20 + tol
109+
assert 50 - tol <= peaks[1] <= 50 + tol
110+
assert 70 - tol <= peaks[2] <= 70 + tol

frites/utils/preproc.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,15 @@ def kernel_smoothing(x, kernel, axis=-1):
120120
return x
121121

122122

123-
def _acf(xd):
124-
"""Auto-correlation on a single time-series."""
125-
n = len(xd)
126-
acov = np.correlate(xd, xd, "full")[n - 1:] / n
127-
return acov[: n + 1] / acov[0]
123+
def _acf(xs, xt=None):
124+
"""Auto- or cross-correlation of time-series."""
125+
n = len(xs)
126+
127+
if xt is None: # auto-correlation
128+
acov = np.correlate(xs, xs, "full")[n - 1:] / n
129+
return acov[: n + 1] / acov[0]
130+
else: # cross-correlation
131+
return np.correlate(xs, xt, mode="full") / len(xs)
128132

129133

130134
def acf(x, axis=-1, demean=True):

0 commit comments

Comments
 (0)