|
| 1 | +"""Cross-correlation function.""" |
| 2 | +import numpy as np |
| 3 | +import xarray as xr |
| 4 | + |
| 5 | +from frites.conn import conn_io |
| 6 | +from frites.io import logger |
| 7 | +from frites.estimator import GCMIEstimator |
| 8 | +from frites.utils import parallel_func |
| 9 | +from frites.utils.preproc import _acf |
| 10 | + |
| 11 | + |
| 12 | + |
def conn_ccf(data, times=None, roi=None, normalized=True, n_jobs=1,
             verbose=None):
    """Single trial Cross-Correlation Function.

    This function computes the pairwise Cross-Correlation Function (CCF) at
    the single trial level. This can be particularly useful to find whether
    there are temporal delays between time series.

    Parameters
    ----------
    data : array_like
        Electrophysiological data. Several input types are supported :

            * Standard NumPy arrays of shape (n_epochs, n_roi, n_times)
            * mne.Epochs
            * xarray.DataArray of shape (n_epochs, n_roi, n_times)

    times : array_like | None
        Time vector array of shape (n_times,). If the input is an xarray, the
        name of the time dimension can be provided
    roi : array_like | None
        ROI names of a single subject. If the input is an xarray, the
        name of the ROI dimension can be provided
    normalized : bool | True
        Z-score normalization of the data along the last axis, applied
        before the CCF is computed. By default, it is set to True.
    n_jobs : int | 1
        Number of jobs to use for parallel computing (use -1 to use all
        jobs). The parallel loop is set at the pair level. When there is a
        single pair, the computation is forced to run with one job.
    verbose : bool | None
        Verbosity setting, forwarded as-is to the input conversion
        (`conn_io`) and to the parallel helper (`parallel_func`).

    Returns
    -------
    ccf : xarray.DataArray
        The Cross-Correlation array of shape (n_epochs, n_pairs, n_lags)
        with n_lags = 2 * n_times - 1 and dimensions named
        ('trials', 'roi', 'times'). The 'times' coordinate contains the
        lags, centered on zero and expressed in samples (not in seconds).
        When the peak of correlation occurs at a negative time it means
        that the target has to be moved **toward** the source. On the
        contrary, if the peak occurs at a positive time it means that the
        target is moved **away** from the source.
    """
    # ________________________________ INPUTS _________________________________
    # inputs conversion
    data, cfg = conn_io(
        data, times=times, roi=roi, agg_ch=False, win_sample=None, pairs=None,
        sort=True, name='CCF', verbose=verbose,
    )

    # extract variables
    x, trials, attrs = data.data, data['y'].data, cfg['attrs']
    x_s, x_t = cfg['x_s'], cfg['x_t']
    roi_p = cfg['roi_p']
    times = data['times'].data
    n_pairs = len(x_s)
    # each single trial CCF is stored on 2 * n_times - 1 lags
    n_lags = 2 * len(times) - 1

    # data normalization (z-score along the last axis)
    # NOTE: a series that is constant along the last axis has a zero
    # standard deviation, hence the division below produces NaN for it.
    if normalized:
        x = (x - x.mean(-1, keepdims=True)) / x.std(-1, keepdims=True)

    # __________________________________ CCF __________________________________
    # function to put in parallel (one call per pair, loop over trials)
    def para_fun(xs, xt):
        n_trials = xs.shape[0]
        corr = np.zeros((n_trials, n_lags))
        for n_t in range(n_trials):
            corr[n_t, :] = _acf(xs[n_t, :], xt[n_t, :])
        return corr

    # prepare parallel function (no need for several jobs with one pair)
    n_jobs = 1 if n_pairs == 1 else n_jobs
    parallel, p_fun = parallel_func(para_fun, n_jobs=n_jobs, verbose=verbose,
                                    total=n_pairs, mesg='Estimating CCF')

    logger.info(f'Computing CCF between {n_pairs} pairs')

    # compute ccf and stack the pairs along the second axis
    ccf = parallel(
        p_fun(x[:, i_s, :], x[:, i_t, :]) for i_s, i_t in zip(x_s, x_t))
    ccf = np.stack(ccf, axis=1)

    # ________________________________ OUTPUTS ________________________________
    # dataarray conversion; the lag vector is centered on zero and kept in
    # samples
    times_n = np.arange(ccf.shape[-1]).astype(float)
    times_n -= times_n.mean()
    ccf = xr.DataArray(ccf, dims=('trials', 'roi', 'times'), name='CCF',
                       coords=(trials, roi_p, times_n))

    # add the method settings to the attributes; keys coming from the input
    # conversion take precedence in case of duplicates
    ccf.attrs = {**dict(type='ccf', normalized=int(normalized)), **attrs}

    return ccf
| 101 | + |
if __name__ == '__main__':
    # Visual sanity check: simulate three ROIs sharing the same bump at
    # increasing delays, then plot the trial-averaged data next to the
    # trial-averaged CCF.
    import matplotlib.pyplot as plt

    n_trials = 20
    n_roi = 3
    n_times = 1000
    # create coordinates
    trials = np.arange(n_trials)
    roi = [f"roi_{k}" for k in range(n_roi)]
    # presumably a 64 Hz sampling rate with time zero at sample 200
    times = (np.arange(n_times) - 200) / 64.
    # data creation (uniform noise)
    x = np.random.rand(n_trials, n_roi, n_times)
    # inject relation : the same 200-sample Hanning bump is added to each
    # ROI, starting 20 samples (roi_1) and 60 samples (roi_2) after roi_0
    bump = np.hanning(200).reshape(1, -1)
    x[:, 0, 200:400] += bump
    x[:, 1, 220:420] += bump
    x[:, 2, 260:460] += bump
    # xarray conversion
    x = xr.DataArray(x, dims=('trials', 'roi', 'times'),
                     coords=(trials, roi, times))
    plt.figure(figsize=(15, 6))

    # compute the single trial cross-correlation function
    ccf = conn_ccf(x, times='times', roi='roi', n_jobs=1, verbose=False)

    # left : trial-averaged simulated data, right : trial-averaged CCF
    plt.subplot(121)
    x.mean('trials').plot(x='times', hue='roi')
    plt.subplot(122)
    ccf.mean('trials').plot(x='times', hue='roi')
    plt.show()
0 commit comments