Skip to content

Commit 74dc664

Browse files
committed
Progressbar WfMi
1 parent 6570bac commit 74dc664

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

frites/workflow/wf_mi.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Workflow for computing MI and evaluate statistics."""
22
import numpy as np
33
import xarray as xr
4-
from joblib import Parallel, delayed
54

6-
from frites import config
5+
from mne.parallel import parallel_func
6+
from mne.utils import ProgressBar
7+
78
from frites.io import set_log_level, logger
89
from frites.core import get_core_mi_fun, permute_mi_vector
910
from frites.workflow.wf_stats_ephy import WfStatsEphy
@@ -117,21 +118,24 @@ def _node_compute_mi(self, dataset, n_bins, n_perm, n_jobs, random_state):
117118
# evaluate true mi
118119
logger.info(f" Evaluate true and permuted mi (n_perm={n_perm}, "
119120
f"n_jobs={n_jobs})")
120-
mi = [mi_fun(x[k], y[k], z[k], suj[k], inf,
121-
n_bins=n_bins) for k in range(n_roi)]
122-
# get joblib configuration
123-
cfg_jobs = config.CONFIG["JOBLIB_CFG"]
121+
# parallel function for computing permutations
122+
parallel, p_fun, _ = parallel_func(mi_fun, n_jobs=n_jobs,
123+
verbose=False)
124+
pbar = ProgressBar(range(n_roi), mesg='MI estimation')
124125
# evaluate permuted mi
125-
mi_p = []
126+
mi, mi_p = [], []
126127
for r in range(n_roi):
128+
# compute the true mi
129+
mi += [mi_fun(x[r], y[r], z[r], suj[r], inf, n_bins=n_bins)]
130+
127131
# get the randomize version of y
128132
y_p = permute_mi_vector(y[r], suj[r], mi_type=self._mi_type,
129133
inference=self._inference, n_perm=n_perm)
130134
# run permutations using the randomize regressor
131-
_mi = Parallel(n_jobs=n_jobs, **cfg_jobs)(delayed(mi_fun)(
132-
x[r], y_p[p], z[r], suj[r], inf,
133-
n_bins=n_bins) for p in range(n_perm))
135+
_mi = parallel(p_fun(x[r], y_p[p], z[r], suj[r], inf,
136+
n_bins=n_bins) for p in range(n_perm))
134137
mi_p += [np.asarray(_mi)]
138+
pbar.update_with_increment_value(1)
135139
# smoothing
136140
if isinstance(self._kernel, np.ndarray):
137141
logger.info(" Apply smoothing to the true and permuted MI")

0 commit comments

Comments
 (0)