Skip to content

Commit 860f3d4

Browse files
committed
Test internal copy of WfMi
1 parent 0c2228c commit 860f3d4

File tree

2 files changed

+29
-21
lines changed

2 files changed

+29
-21
lines changed

frites/workflow/tests/test_wf_mi.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_no_stat(self):
9999
wf.fit(dt, mcp='fdr', **kw_mi)
100100
t_end_1 = tst()
101101
t_start_2 = tst()
102-
wf.fit(dt, mcp='maxstat', **kw_mi)
102+
wf.fit(mcp='maxstat', **kw_mi)
103103
t_end_2 = tst()
104104
assert t_end_1 - t_start_1 > t_end_2 - t_start_2
105105

@@ -113,6 +113,14 @@ def test_conjunction_analysis(self):
113113
assert cj_ss.shape == (n_subjects, n_times, n_roi)
114114
assert cj.shape == (n_times, n_roi)
115115

116+
def test_copy(self):
117+
"""Test function copy."""
118+
y, gt = sim_mi_cc(x.copy(), snr=1.)
119+
dt = DatasetEphy(x.copy(), y=y, roi=roi, times=time)
120+
wf = WfMi(mi_type='cc', inference='rfx')
121+
_, _ = wf.fit(dt, **kw_mi)
122+
wf_2 = wf.copy()
123+
116124

117125
if __name__ == '__main__':
118-
TestWfMi().test_mi_ccd()
126+
TestWfMi().test_no_stat()

frites/workflow/wf_mi.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def _node_compute_mi(self, dataset, n_perm, n_jobs, random_state):
142142

143143
return mi, mi_p
144144

145-
def fit(self, dataset, mcp='cluster', n_perm=1000, cluster_th=None,
145+
def fit(self, dataset=None, mcp='cluster', n_perm=1000, cluster_th=None,
146146
cluster_alpha=0.05, n_jobs=-1, random_state=None, **kw_stats):
147147
"""Run the workflow on a dataset.
148148
@@ -160,7 +160,8 @@ def fit(self, dataset, mcp='cluster', n_perm=1000, cluster_th=None,
160160
Parameters
161161
----------
162162
dataset : :class:`frites.dataset.DatasetEphy`
163-
A dataset instance
163+
A dataset instance. If the workflow has already been fitted, then
164+
this parameter can remains to None.
164165
mcp : {'cluster', 'maxstat', 'fdr', 'bonferroni', 'nostat', None}
165166
Method to use for correcting p-values for the multiple comparison
166167
problem. Use either :
@@ -207,23 +208,6 @@ def fit(self, dataset, mcp='cluster', n_perm=1000, cluster_th=None,
207208
----------
208209
Maris and Oostenveld, 2007 :cite:`maris2007nonparametric`
209210
"""
210-
# ---------------------------------------------------------------------
211-
# prepare variables
212-
# ---------------------------------------------------------------------
213-
# don't compute permutations if mcp is either nostat / None
214-
if mcp in ['noperm', None]:
215-
n_perm = 0
216-
# get needed dataset's informations
217-
self._times, self._roi = dataset.times, dataset.roi_names
218-
self._mi_dims = dataset._mi_dims
219-
self._mi_coords = dict()
220-
for k in self._mi_dims:
221-
if k != 'roi':
222-
self._mi_coords[k] = dataset.x[0].coords[k].data
223-
else:
224-
self._mi_coords['roi'] = self._roi
225-
self._df_rs, self._n_subjects = dataset.df_rs, dataset._n_subjects
226-
227211
# ---------------------------------------------------------------------
228212
# compute mutual information
229213
# ---------------------------------------------------------------------
@@ -234,6 +218,22 @@ def fit(self, dataset, mcp='cluster', n_perm=1000, cluster_th=None,
234218
"arguments")
235219
mi, mi_p = self._mi, self._mi_p
236220
else:
221+
# don't compute permutations if mcp is either nostat / None
222+
if mcp in ['noperm', None]:
223+
n_perm = 0
224+
225+
# get needed dataset's informations
226+
self._times, self._roi = dataset.times, dataset.roi_names
227+
self._mi_dims = dataset._mi_dims
228+
self._mi_coords = dict()
229+
for k in self._mi_dims:
230+
if k != 'roi':
231+
self._mi_coords[k] = dataset.x[0].coords[k].data
232+
else:
233+
self._mi_coords['roi'] = self._roi
234+
self._df_rs, self._n_subjects = dataset.df_rs, dataset._n_subjects
235+
236+
# compute mi and permutations
237237
mi, mi_p = self._node_compute_mi(
238238
dataset, n_perm, n_jobs, random_state)
239239
"""

0 commit comments

Comments
 (0)