|
1 | 1 | """Workflow for computing MI and evaluate statistics."""
|
2 | 2 | import numpy as np
|
3 | 3 | import xarray as xr
|
4 |
| -from joblib import Parallel, delayed |
5 | 4 |
|
6 |
| -from frites import config |
| 5 | +from mne.parallel import parallel_func |
| 6 | +from mne.utils import ProgressBar |
| 7 | + |
7 | 8 | from frites.io import set_log_level, logger
|
8 | 9 | from frites.core import get_core_mi_fun, permute_mi_vector
|
9 | 10 | from frites.workflow.wf_stats_ephy import WfStatsEphy
|
@@ -117,21 +118,24 @@ def _node_compute_mi(self, dataset, n_bins, n_perm, n_jobs, random_state):
|
117 | 118 | # evaluate true mi
|
118 | 119 | logger.info(f" Evaluate true and permuted mi (n_perm={n_perm}, "
|
119 | 120 | f"n_jobs={n_jobs})")
|
120 |
| - mi = [mi_fun(x[k], y[k], z[k], suj[k], inf, |
121 |
| - n_bins=n_bins) for k in range(n_roi)] |
122 |
| - # get joblib configuration |
123 |
| - cfg_jobs = config.CONFIG["JOBLIB_CFG"] |
| 121 | + # parallel function for computing permutations |
| 122 | + parallel, p_fun, _ = parallel_func(mi_fun, n_jobs=n_jobs, |
| 123 | + verbose=False) |
| 124 | + pbar = ProgressBar(range(n_roi), mesg='MI estimation') |
124 | 125 | # evaluate permuted mi
|
125 |
| - mi_p = [] |
| 126 | + mi, mi_p = [], [] |
126 | 127 | for r in range(n_roi):
|
| 128 | + # compute the true mi |
| 129 | + mi += [mi_fun(x[r], y[r], z[r], suj[r], inf, n_bins=n_bins)] |
| 130 | + |
127 | 131 | # get the randomize version of y
|
128 | 132 | y_p = permute_mi_vector(y[r], suj[r], mi_type=self._mi_type,
|
129 | 133 | inference=self._inference, n_perm=n_perm)
|
130 | 134 | # run permutations using the randomize regressor
|
131 |
| - _mi = Parallel(n_jobs=n_jobs, **cfg_jobs)(delayed(mi_fun)( |
132 |
| - x[r], y_p[p], z[r], suj[r], inf, |
133 |
| - n_bins=n_bins) for p in range(n_perm)) |
| 135 | + _mi = parallel(p_fun(x[r], y_p[p], z[r], suj[r], inf, |
| 136 | + n_bins=n_bins) for p in range(n_perm)) |
134 | 137 | mi_p += [np.asarray(_mi)]
|
| 138 | + pbar.update_with_increment_value(1) |
135 | 139 | # smoothing
|
136 | 140 | if isinstance(self._kernel, np.ndarray):
|
137 | 141 | logger.info(" Apply smoothing to the true and permuted MI")
|
|
0 commit comments