Skip to content

Commit 8ec636d

Browse files
committed
DFC with progressbar
1 parent 74dc664 commit 8ec636d

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

frites/conn/conn_dfc.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from frites.core import mi_nd_gg, copnorm_nd
77
from frites.config import CONFIG
88

9-
from joblib import Parallel, delayed
9+
from mne.parallel import parallel_func
10+
from mne.utils import ProgressBar
1011

1112

1213

@@ -67,15 +68,20 @@ def conn_dfc(data, times, roi, win_sample, n_jobs=1, verbose=None):
6768
# -------------------------------------------------------------------------
6869
# compute dfc
6970
logger.info(f'Computing DFC between {n_pairs} pairs')
71+
# get the parallel function
72+
parallel, p_fun, _ = parallel_func(mi_nd_gg, n_jobs=n_jobs, verbose=False)
73+
pbar = ProgressBar(range(n_win), mesg='Estimating DFC')
74+
7075
dfc = np.zeros((n_epochs, n_pairs, n_win), dtype=np.float32)
7176
for n_w, w in enumerate(win_sample):
7277
# select the data in the window and copnorm across time points
7378
data_w = copnorm_nd(data[..., w[0]:w[1]], axis=2)
7479
# compute mi between pairs
75-
_dfc = Parallel(n_jobs=n_jobs)(delayed(mi_nd_gg)(
76-
data_w[:, [s], :], data_w[:, [t], :],
77-
**CONFIG["KW_GCMI"]) for s, t in zip(x_s, x_t))
80+
_dfc = parallel(
81+
p_fun(data_w[:, [s], :], data_w[:, [t], :],
82+
**CONFIG["KW_GCMI"]) for s, t in zip(x_s, x_t))
7883
dfc[..., n_w] = np.stack(_dfc, axis=1)
84+
pbar.update_with_increment_value(1)
7985

8086
# -------------------------------------------------------------------------
8187
# dataarray conversion

0 commit comments

Comments
 (0)