|
6 | 6 | from frites.core import mi_nd_gg, copnorm_nd
|
7 | 7 | from frites.config import CONFIG
|
8 | 8 |
|
9 |
| -from joblib import Parallel, delayed |
| 9 | +from mne.parallel import parallel_func |
| 10 | +from mne.utils import ProgressBar |
10 | 11 |
|
11 | 12 |
|
12 | 13 |
|
@@ -67,15 +68,20 @@ def conn_dfc(data, times, roi, win_sample, n_jobs=1, verbose=None):
|
67 | 68 | # -------------------------------------------------------------------------
|
68 | 69 | # compute dfc
|
69 | 70 | logger.info(f'Computing DFC between {n_pairs} pairs')
|
| 71 | + # get the parallel function |
| 72 | + parallel, p_fun, _ = parallel_func(mi_nd_gg, n_jobs=n_jobs, verbose=False) |
| 73 | + pbar = ProgressBar(range(n_win), mesg='Estimating DFC') |
| 74 | + |
70 | 75 | dfc = np.zeros((n_epochs, n_pairs, n_win), dtype=np.float32)
|
71 | 76 | for n_w, w in enumerate(win_sample):
|
72 | 77 | # select the data in the window and copnorm across time points
|
73 | 78 | data_w = copnorm_nd(data[..., w[0]:w[1]], axis=2)
|
74 | 79 | # compute mi between pairs
|
75 |
| - _dfc = Parallel(n_jobs=n_jobs)(delayed(mi_nd_gg)( |
76 |
| - data_w[:, [s], :], data_w[:, [t], :], |
77 |
| - **CONFIG["KW_GCMI"]) for s, t in zip(x_s, x_t)) |
| 80 | + _dfc = parallel( |
| 81 | + p_fun(data_w[:, [s], :], data_w[:, [t], :], |
| 82 | + **CONFIG["KW_GCMI"]) for s, t in zip(x_s, x_t)) |
78 | 83 | dfc[..., n_w] = np.stack(_dfc, axis=1)
|
| 84 | + pbar.update_with_increment_value(1) |
79 | 85 |
|
80 | 86 | # -------------------------------------------------------------------------
|
81 | 87 | # dataarray conversion
|
|
0 commit comments