Skip to content

Commit e7a9c23

Browse files
committed
Fix copnorm in WfMi + adapt WfComod to estimators
1 parent 1944336 commit e7a9c23

File tree

2 files changed

+56
-47
lines changed

2 files changed

+56
-47
lines changed

frites/workflow/wf_comod.py

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44

55
from frites import config
66
from frites.io import (set_log_level, logger, convert_dfc_outputs)
7-
from frites.core import get_core_mi_fun, permute_mi_trials
7+
from frites.core import permute_mi_trials
88
from frites.workflow.wf_stats import WfStats
99
from frites.workflow.wf_base import WfBase
10+
from frites.estimator import GCMIEstimator
1011

1112

1213
class WfComod(WfBase):
@@ -29,16 +30,11 @@ class WfComod(WfBase):
2930
population.
3031
3132
By default, the workflow uses group level inference ('rfx')
32-
mi_method : {'gc', 'bin'}
33-
Method for computing the mutual information. Use either :
34-
35-
* 'gc' : gaussian-copula based mutual information. This is the
36-
fastest method but it can only captures monotonic relationships
37-
between variables
38-
* 'bin' : binning-based method that can captures any kind of
39-
relationships but is much slower and also required to define the
40-
number of bins to use. Note that if the Numba package is
41-
installed computations should be much faster
33+
estimator : MIEstimator | None
34+
Estimator of mutual-information. If None, the Gaussian-Copula is used
35+
instead. Note that here, since the mutual information is computed
36+
between two time-series coming from two brain regions, the estimator
37+
should have a mi_type='cc'
4238
kernel : array_like | None
4339
Kernel for smoothing true and permuted MI. For example, use
4440
np.hanning(3) for a 3 time points smoothing or np.ones((3)) for a
@@ -49,42 +45,42 @@ class WfComod(WfBase):
4945
Friston et al., 1996, 1999 :cite:`friston1996detecting,friston1999many`
5046
"""
5147

52-
def __init__(self, inference='rfx', mi_method='gc', kernel=None,
48+
def __init__(self, inference='rfx', estimator=None, kernel=None,
5349
verbose=None):
5450
"""Init."""
5551
WfBase.__init__(self)
5652
assert inference in ['ffx', 'rfx'], (
5753
"'inference' input parameter should either be 'ffx' or 'rfx'")
58-
assert mi_method in ['gc', 'bin'], (
59-
"'mi_method' input parameter should either be 'gc' or 'bin'")
6054
self._mi_type = 'cc'
55+
if estimator is None:
56+
estimator = GCMIEstimator(mi_type='cc', copnorm=False,
57+
verbose=verbose)
58+
assert estimator.settings['mi_type'] == self._mi_type
59+
self._copnorm = isinstance(estimator, GCMIEstimator)
6160
self._inference = inference
62-
self._mi_method = mi_method
63-
self._need_copnorm = mi_method == 'gc'
61+
self.estimator = estimator
6462
self._gcrn = inference == 'rfx'
6563
self._kernel = kernel
6664
set_log_level(verbose)
6765
self.clean()
6866
self._wf_stats = WfStats(verbose=verbose)
6967
# update internal config
7068
self.attrs.update(dict(mi_type=self._mi_type, inference=inference,
71-
mi_method=mi_method, kernel=kernel))
69+
kernel=kernel))
7270

73-
logger.info(f"Workflow for computing connectivity ({self._mi_type} - "
74-
f"{mi_method})")
71+
logger.info(f"Workflow for computing comodulations between distant "
72+
f"brain areas ({inference})")
7573

7674

77-
def _node_compute_mi(self, dataset, n_bins, n_perm, n_jobs, random_state):
75+
def _node_compute_mi(self, dataset, n_perm, n_jobs, random_state):
7876
"""Compute mi and permuted mi.
7977
8078
Permutations are performed by randomizing the target roi. For the fixed
8179
effect, this randomization is performed across subjects. For the random
8280
effect, the randomization is performed per subject.
8381
"""
8482
# get the function for computing mi
85-
mi_fun = get_core_mi_fun(self._mi_method)[f"{self._mi_type}_conn"]
86-
assert (f"mi_{self._mi_method}_ephy_conn_"
87-
f"{self._mi_type}" == mi_fun.__name__)
83+
core_fun = self.estimator.get_function()
8884
# get x, y, z and subject names per roi
8985
roi, inf = dataset.roi_names, self._inference
9086
# get the pairs for computing mi
@@ -99,7 +95,7 @@ def _node_compute_mi(self, dataset, n_bins, n_perm, n_jobs, random_state):
9995
logger.info(f" Evaluate true and permuted mi (n_perm={n_perm}, "
10096
f"n_jobs={n_jobs}, n_pairs={len(x_s)})")
10197
mi, mi_p = [], []
102-
kw_get = dict(mi_type=self._mi_type, copnorm=self._need_copnorm,
98+
kw_get = dict(mi_type=self._mi_type, copnorm=self._copnorm,
10399
gcrn_per_suj=self._gcrn)
104100
for s in x_s:
105101
# get source data
@@ -110,19 +106,18 @@ def _node_compute_mi(self, dataset, n_bins, n_perm, n_jobs, random_state):
110106
da_t = dataset.get_roi_data(roi[t], **kw_get)
111107
suj_t = da_t['subject'].data
112108
# compute mi
113-
_mi = mi_fun(da_s.data, da_t.data, suj_s, suj_t, inf,
114-
n_bins=n_bins)
109+
_mi = comod(da_s.data, da_t.data, suj_s, suj_t, inf, core_fun)
115110
mi += [_mi]
116111
# get the randomize version of y
117112
y_p = permute_mi_trials(suj_t, inference=self._inference,
118113
n_perm=n_perm)
119114
# run permutations using the randomize regressor
120-
_mi_p = Parallel(n_jobs=n_jobs, **cfg_jobs)(delayed(mi_fun)(
115+
_mi_p = Parallel(n_jobs=n_jobs, **cfg_jobs)(delayed(comod)(
121116
da_s.data, da_t.data[..., y_p[p]], suj_s, suj_t, inf,
122-
n_bins=n_bins) for p in range(n_perm))
117+
core_fun) for p in range(n_perm))
123118
mi_p += [np.asarray(_mi_p)]
124119

125-
# # smoothing
120+
# smoothing
126121
if isinstance(self._kernel, np.ndarray):
127122
logger.info(" Apply smoothing to the true and permuted MI")
128123
for r in range(len(mi)):
@@ -138,8 +133,7 @@ def _node_compute_mi(self, dataset, n_bins, n_perm, n_jobs, random_state):
138133
return mi, mi_p
139134

140135
def fit(self, dataset, mcp='cluster', n_perm=1000, cluster_th=None,
141-
cluster_alpha=0.05, n_bins=None, n_jobs=-1, random_state=None,
142-
**kw_stats):
136+
cluster_alpha=0.05, n_jobs=-1, random_state=None, **kw_stats):
143137
"""Run the workflow on a dataset.
144138
145139
In order to run the workflow, you must first provide a dataset instance
@@ -179,11 +173,6 @@ def fit(self, dataset, mcp='cluster', n_perm=1000, cluster_th=None,
179173
cluster_alpha : float | 0.05
180174
Control the percentile to use for forming the clusters. By default
181175
the 95th percentile of the permutations is used.
182-
n_bins : int | None
183-
Number of bins to use if the method for computing the mutual
184-
information is based on binning (mi_method='bin'). If None, the
185-
number of bins is going to be automatically inferred based on the
186-
number of trials and variables
187176
n_jobs : int | -1
188177
Number of jobs to use for parallel computing (use -1 to use all
189178
jobs)
@@ -209,11 +198,6 @@ def fit(self, dataset, mcp='cluster', n_perm=1000, cluster_th=None,
209198
# don't compute permutations if mcp is either nostat / None
210199
if mcp in ['noperm', None]:
211200
n_perm = 0
212-
# infer the number of bins if needed
213-
if (self._mi_method == 'bin') and not isinstance(n_bins, int):
214-
n_bins = 4
215-
logger.info(f" Use an automatic number of bins of {n_bins}")
216-
self._n_bins = n_bins
217201
# get important dataset's variables
218202
self._times, self._roi = dataset.times, dataset.roi_names
219203

@@ -228,7 +212,7 @@ def fit(self, dataset, mcp='cluster', n_perm=1000, cluster_th=None,
228212
mi, mi_p = self._mi, self._mi_p
229213
else:
230214
mi, mi_p = self._node_compute_mi(
231-
dataset, self._n_bins, n_perm, n_jobs, random_state)
215+
dataset, n_perm, n_jobs, random_state)
232216

233217
# ---------------------------------------------------------------------
234218
# compute statistics
@@ -239,8 +223,7 @@ def fit(self, dataset, mcp='cluster', n_perm=1000, cluster_th=None,
239223
cluster_alpha=cluster_alpha, inference=self._inference,
240224
**kw_stats)
241225
# update internal config
242-
self.attrs.update(dict(n_perm=n_perm, random_state=random_state,
243-
n_bins=n_bins))
226+
self.attrs.update(dict(n_perm=n_perm, random_state=random_state))
244227
self.attrs.update(self._wf_stats.attrs)
245228

246229
# ---------------------------------------------------------------------
@@ -291,3 +274,28 @@ def tvalues(self):
291274
def wf_stats(self):
292275
"""Get the workflow of statistics."""
293276
return self._wf_stats
277+
278+
279+
def comod(x_1, x_2, suj_1, suj_2, inference, fun):
280+
"""I(C; C) for rfx.
281+
282+
The returned mi array has a shape of (n_subjects, n_times) if inference is
283+
"rfx", (1, n_times) if "ffx".
284+
"""
285+
# proper shape of the regressor
286+
n_times, _, n_trials = x_1.shape
287+
# compute mi across (ffx) or per subject (rfx)
288+
if inference == 'ffx':
289+
mi = fun(x_1, x_2)
290+
elif inference == 'rfx':
291+
# get subject information
292+
suj_u = np.intersect1d(suj_1, suj_2)
293+
n_subjects = len(suj_u)
294+
# compute mi per subject
295+
mi = np.zeros((n_subjects, n_times), dtype=float)
296+
for n_s, s in enumerate(suj_u):
297+
is_suj_1 = suj_1 == s
298+
is_suj_2 = suj_2 == s
299+
mi[n_s, :] = fun(x_1[..., is_suj_1], x_2[..., is_suj_2])
300+
301+
return mi

frites/workflow/wf_mi.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class WfMi(WfBase):
4141
By default, the workflow uses group level inference ('rfx')
4242
estimator : MIEstimator | None
4343
Estimator of mutual-information. If None, the Gaussian-Copula is used
44-
instead
44+
instead.
4545
kernel : array_like | None
4646
Kernel for smoothing true and permuted MI. For example, use
4747
np.hanning(3) for a 3 time points smoothing or np.ones((3)) for a
@@ -66,6 +66,7 @@ def __init__(self, mi_type='cc', inference='rfx', estimator=None,
6666
estimator = GCMIEstimator(mi_type=mi_type, copnorm=False,
6767
verbose=verbose)
6868
self.estimator = estimator
69+
self._copnorm = isinstance(estimator, GCMIEstimator)
6970
self._gcrn = inference == 'rfx'
7071
self._kernel = kernel
7172
set_log_level(verbose)
@@ -76,7 +77,7 @@ def __init__(self, mi_type='cc', inference='rfx', estimator=None,
7677
mi_type=mi_type, inference=inference, kernel=kernel))
7778

7879
logger.info(f"Workflow for computing mutual information ({inference} -"
79-
f" {estimator.name} - {mi_type})")
80+
f" {mi_type})")
8081

8182
def _node_compute_mi(self, dataset, n_perm, n_jobs, random_state):
8283
"""Compute mi and permuted mi.
@@ -106,7 +107,7 @@ def _node_compute_mi(self, dataset, n_perm, n_jobs, random_state):
106107
for r in range(n_roi):
107108
# get the data of selected roi
108109
da = dataset.get_roi_data(
109-
self._roi[r], copnorm=True, mi_type=self._mi_type,
110+
self._roi[r], copnorm=self._copnorm, mi_type=self._mi_type,
110111
gcrn_per_suj=self._gcrn)
111112
x, y, suj = da.data, da['y'].data, da['subject'].data
112113
kw_mi = dict()

0 commit comments

Comments
 (0)