from frites import config
from frites.io import (set_log_level, logger, convert_dfc_outputs)
- from frites.core import get_core_mi_fun, permute_mi_trials
+ from frites.core import permute_mi_trials
from frites.workflow.wf_stats import WfStats
from frites.workflow.wf_base import WfBase
+ from frites.estimator import GCMIEstimator


class WfComod(WfBase):
@@ -29,16 +30,11 @@ class WfComod(WfBase):
        population.

        By default, the workflow uses group level inference ('rfx')
-    mi_method : {'gc', 'bin'}
-        Method for computing the mutual information. Use either :
-
-        * 'gc' : gaussian-copula based mutual information. This is the
-          fastest method but it can only captures monotonic relationships
-          between variables
-        * 'bin' : binning-based method that can captures any kind of
-          relationships but is much slower and also required to define the
-          number of bins to use. Note that if the Numba package is
-          installed computations should be much faster
+    estimator : MIEstimator | None
+        Estimator of mutual information. If None, the Gaussian-Copula
+        estimator is used. Note that here, since the mutual information is
+        computed between two time-series coming from two brain regions, the
+        estimator should have mi_type='cc'
    kernel : array_like | None
        Kernel for smoothing true and permuted MI. For example, use
        np.hanning(3) for a 3-time-point smoothing or np.ones((3)) for a
@@ -49,42 +45,42 @@ class WfComod(WfBase):
        Friston et al., 1996, 1999 :cite:`friston1996detecting,friston1999many`
    """

-    def __init__(self, inference='rfx', mi_method='gc', kernel=None,
+    def __init__(self, inference='rfx', estimator=None, kernel=None,
                 verbose=None):
        """Init."""
        WfBase.__init__(self)
        assert inference in ['ffx', 'rfx'], (
            "'inference' input parameter should either be 'ffx' or 'rfx'")
-        assert mi_method in ['gc', 'bin'], (
-            "'mi_method' input parameter should either be 'gc' or 'bin'")
        self._mi_type = 'cc'
+        if estimator is None:
+            estimator = GCMIEstimator(mi_type='cc', copnorm=False,
+                                      verbose=verbose)
+        assert estimator.settings['mi_type'] == self._mi_type
+        self._copnorm = isinstance(estimator, GCMIEstimator)
        self._inference = inference
-        self._mi_method = mi_method
-        self._need_copnorm = mi_method == 'gc'
+        self.estimator = estimator
        self._gcrn = inference == 'rfx'
        self._kernel = kernel
        set_log_level(verbose)
        self.clean()
        self._wf_stats = WfStats(verbose=verbose)
        # update internal config
        self.attrs.update(dict(mi_type=self._mi_type, inference=inference,
-                               mi_method=mi_method, kernel=kernel))
+                               kernel=kernel))

-        logger.info(f"Workflow for computing connectivity ({self._mi_type} - "
-                    f"{mi_method})")
+        logger.info(f"Workflow for computing comodulations between distant "
+                    f"brain areas ({inference})")

-    def _node_compute_mi(self, dataset, n_bins, n_perm, n_jobs, random_state):
+    def _node_compute_mi(self, dataset, n_perm, n_jobs, random_state):
        """Compute mi and permuted mi.

        Permutations are performed by randomizing the target roi. For the fixed
        effect, this randomization is performed across subjects. For the random
        effect, the randomization is performed per subject.
        """
        # get the function for computing mi
-        mi_fun = get_core_mi_fun(self._mi_method)[f"{self._mi_type}_conn"]
-        assert (f"mi_{self._mi_method}_ephy_conn_"
-                f"{self._mi_type}" == mi_fun.__name__)
+        core_fun = self.estimator.get_function()
        # get x, y, z and subject names per roi
        roi, inf = dataset.roi_names, self._inference
        # get the pairs for computing mi
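
The randomization scheme described in the docstring above can be sketched in
plain NumPy (permute_trials_sketch is a hypothetical stand-in, not the actual
frites.core.permute_mi_trials implementation):

    import numpy as np

    def permute_trials_sketch(suj, inference='rfx', n_perm=10, seed=0):
        """Return n_perm permuted trial orders."""
        rng = np.random.default_rng(seed)
        perms = []
        for _ in range(n_perm):
            p = np.arange(len(suj))
            if inference == 'ffx':
                # fixed effect: shuffle trials across all subjects
                rng.shuffle(p)
            else:
                # random effect: shuffle trials within each subject
                for s in np.unique(suj):
                    idx = np.where(suj == s)[0]
                    p[idx] = rng.permutation(idx)
            perms.append(p)
        return perms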
@@ -99,7 +95,7 @@ def _node_compute_mi(self, dataset, n_bins, n_perm, n_jobs, random_state):

        logger.info(f"    Evaluate true and permuted mi (n_perm={n_perm}, "
                    f"n_jobs={n_jobs}, n_pairs={len(x_s)})")
        mi, mi_p = [], []
-        kw_get = dict(mi_type=self._mi_type, copnorm=self._need_copnorm,
+        kw_get = dict(mi_type=self._mi_type, copnorm=self._copnorm,
                      gcrn_per_suj=self._gcrn)
        for s in x_s:
            # get source data
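
The copnorm flag now simply reflects whether the estimator is Gaussian-Copula
based. Copula normalization itself is a rank transform followed by the inverse
normal CDF; a rough illustrative sketch (copnorm_sketch is not frites' API):

    import numpy as np
    from scipy.stats import norm, rankdata

    def copnorm_sketch(x, axis=-1):
        """Make the marginals of x standard-normal along `axis`."""
        ranks = rankdata(x, axis=axis)
        return norm.ppf(ranks / (x.shape[axis] + 1))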
@@ -110,19 +106,18 @@ def _node_compute_mi(self, dataset, n_bins, n_perm, n_jobs, random_state):

            da_t = dataset.get_roi_data(roi[t], **kw_get)
            suj_t = da_t['subject'].data
            # compute mi
-            _mi = mi_fun(da_s.data, da_t.data, suj_s, suj_t, inf,
-                         n_bins=n_bins)
+            _mi = comod(da_s.data, da_t.data, suj_s, suj_t, inf, core_fun)
            mi += [_mi]
            # get the randomized version of y
            y_p = permute_mi_trials(suj_t, inference=self._inference,
                                    n_perm=n_perm)
            # run permutations using the randomized regressor
-            _mi_p = Parallel(n_jobs=n_jobs, **cfg_jobs)(delayed(mi_fun)(
+            _mi_p = Parallel(n_jobs=n_jobs, **cfg_jobs)(delayed(comod)(
                da_s.data, da_t.data[..., y_p[p]], suj_s, suj_t, inf,
-                n_bins=n_bins) for p in range(n_perm))
+                core_fun) for p in range(n_perm))
            mi_p += [np.asarray(_mi_p)]

-        # # smoothing
+        # smoothing
        if isinstance(self._kernel, np.ndarray):
            logger.info("    Apply smoothing to the true and permuted MI")
            for r in range(len(mi)):
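
When a kernel is provided, the smoothing amounts to a normalized convolution
of each MI time course along time. A minimal sketch of the assumed behavior:

    import numpy as np

    kernel = np.hanning(3)
    kernel /= kernel.sum()         # keep MI amplitudes comparable
    mi_tc = np.random.rand(100)    # one MI time course
    mi_smooth = np.convolve(mi_tc, kernel, mode='same')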
@@ -138,8 +133,7 @@ def _node_compute_mi(self, dataset, n_bins, n_perm, n_jobs, random_state):

        return mi, mi_p

    def fit(self, dataset, mcp='cluster', n_perm=1000, cluster_th=None,
-            cluster_alpha=0.05, n_bins=None, n_jobs=-1, random_state=None,
-            **kw_stats):
+            cluster_alpha=0.05, n_jobs=-1, random_state=None, **kw_stats):
        """Run the workflow on a dataset.

        In order to run the workflow, you must first provide a dataset instance
@@ -179,11 +173,6 @@ def fit(self, dataset, mcp='cluster', n_perm=1000, cluster_th=None,

        cluster_alpha : float | 0.05
            Control the percentile to use for forming the clusters. By default
            the 95th percentile of the permutations is used.
-        n_bins : int | None
-            Number of bins to use if the method for computing the mutual
-            information is based on binning (mi_method='bin'). If None, the
-            number of bins is going to be automatically inferred based on the
-            number of trials and variables
        n_jobs : int | -1
            Number of jobs to use for parallel computing (use -1 to use all
            jobs)
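
Assuming `dataset` is a frites DatasetEphy instance and `wf` the workflow
built above, a hedged call sketch (fit is expected to return the MI and the
corrected p-values; argument values are illustrative):

    mi, pv = wf.fit(dataset, mcp='cluster', n_perm=200, cluster_alpha=0.05,
                    n_jobs=-1, random_state=0)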
@@ -209,11 +198,6 @@ def fit(self, dataset, mcp='cluster', n_perm=1000, cluster_th=None,

        # don't compute permutations if mcp is either noperm / None
        if mcp in ['noperm', None]:
            n_perm = 0
-        # infer the number of bins if needed
-        if (self._mi_method == 'bin') and not isinstance(n_bins, int):
-            n_bins = 4
-            logger.info(f"    Use an automatic number of bins of {n_bins}")
-        self._n_bins = n_bins
        # get important dataset's variables
        self._times, self._roi = dataset.times, dataset.roi_names
@@ -228,7 +212,7 @@ def fit(self, dataset, mcp='cluster', n_perm=1000, cluster_th=None,

            mi, mi_p = self._mi, self._mi_p
        else:
            mi, mi_p = self._node_compute_mi(
-                dataset, self._n_bins, n_perm, n_jobs, random_state)
+                dataset, n_perm, n_jobs, random_state)

        # ---------------------------------------------------------------------
        # compute statistics
@@ -239,8 +223,7 @@ def fit(self, dataset, mcp='cluster', n_perm=1000, cluster_th=None,

            cluster_alpha=cluster_alpha, inference=self._inference,
            **kw_stats)
        # update internal config
-        self.attrs.update(dict(n_perm=n_perm, random_state=random_state,
-                               n_bins=n_bins))
+        self.attrs.update(dict(n_perm=n_perm, random_state=random_state))
        self.attrs.update(self._wf_stats.attrs)

        # ---------------------------------------------------------------------
@@ -291,3 +274,28 @@ def tvalues(self):

    def wf_stats(self):
        """Get the workflow of statistics."""
        return self._wf_stats
+
+
+def comod(x_1, x_2, suj_1, suj_2, inference, fun):
+    """I(C; C) for ffx / rfx inferences.
+
+    The returned mi array has a shape of (n_subjects, n_times) if inference is
+    "rfx", (1, n_times) if "ffx".
+    """
+    # proper shape of the regressor
+    n_times, _, n_trials = x_1.shape
+    # compute mi across (ffx) or per subject (rfx)
+    if inference == 'ffx':
+        mi = fun(x_1, x_2)
+    elif inference == 'rfx':
+        # get subject information
+        suj_u = np.intersect1d(suj_1, suj_2)
+        n_subjects = len(suj_u)
+        # compute mi per subject
+        mi = np.zeros((n_subjects, n_times), dtype=float)
+        for n_s, s in enumerate(suj_u):
+            is_suj_1 = suj_1 == s
+            is_suj_2 = suj_2 == s
+            mi[n_s, :] = fun(x_1[..., is_suj_1], x_2[..., is_suj_2])
+
+    return mi
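
The new comod helper can be exercised with synthetic data and a toy metric
standing in for the estimator function (toy_fun below is illustrative only;
the real callable comes from estimator.get_function()):

    import numpy as np

    n_times, n_trials = 10, 40
    rng = np.random.default_rng(0)
    x_1 = rng.standard_normal((n_times, 1, n_trials))
    x_2 = rng.standard_normal((n_times, 1, n_trials))
    suj = np.repeat([0, 1], 20)    # two subjects, 20 trials each

    def toy_fun(x, y):
        # absolute per-time-point correlation across trials
        return np.array([abs(np.corrcoef(x[t, 0], y[t, 0])[0, 1])
                         for t in range(x.shape[0])])

    mi = comod(x_1, x_2, suj, suj, 'rfx', toy_fun)
    print(mi.shape)    # (2, 10) -> (n_subjects, n_times)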