Skip to content

Commit 6644cde

Browse files
committed
improve testings
1 parent 18a153a commit 6644cde

File tree

6 files changed

+77
-6
lines changed

6 files changed

+77
-6
lines changed

frites/simulations/sim_distant_mi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,4 @@ def sim_gauss_fit(stim_type='cont_linear', n_epochs=400, n_sti=4, n_pts=400,
141141
# add random noise to remove equal values for gcmi
142142
x = x + rnd_x.rand(*x.shape) / 100.
143143
y = y + rnd_y.rand(*x.shape) / 100.
144-
return x.T, y.T, stim
144+
return x.T, y.T, stim.squeeze()
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Test simulating distant MI."""
2+
import numpy as np
3+
4+
from frites.simulations import (sim_gauss_fit, sim_distant_cc_ms,
5+
sim_distant_cc_ss)
6+
7+
8+
class TestDistantMI(object):
9+
10+
def test_sim_gauss_fit(self):
11+
for stim_type in ['discrete_stim', 'cont_linear', 'cont_flat']:
12+
x, y, stim = sim_gauss_fit(stim_type=stim_type, n_epochs=20)
13+
assert x.shape == y.shape
14+
if stim_type is not 'discrete_stim':
15+
assert len(stim) == x.shape[0]
16+
17+
def test_sim_distant_cc_ss(self):
18+
n_epochs = 20
19+
x, y, roi = sim_distant_cc_ss(n_epochs=n_epochs)
20+
assert x.shape == (len(y), len(roi), 400)
21+
22+
def test_sim_distant_cc_ms(self):
23+
n_subjects = 3
24+
x, y, roi, times = sim_distant_cc_ms(n_subjects)
25+
assert len(x) == len(y) == len(roi) == n_subjects
26+
assert [k.shape == x[0].shape for k in x]
27+
assert [k.shape == y[0].shape for k in y]
28+
assert [k.shape == roi[0].shape for k in roi]
29+
assert x[0].shape == (len(y[0]), len(roi[0]), len(times))

frites/stats/stats_param.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def rfx_ttest(mi, mi_p, center=False, zscore=False, ttested=False):
5454
mi_p : array_like
5555
A list of array of permuted mutual information of shape
5656
(n_perm, n_suj, n_times). If `ttested` is True, n_suj shoud be 1.
57-
center : {'mean', "median", "trimmed"} | False
57+
center : {'mean', 'median', 'trimmed'} | False
5858
If True, substract the mean of the surrogates to the true and permuted
5959
mi. The median or the 20% trimmed mean can also be removed
6060
:cite:`wilcox2018guide`

frites/stats/tests/test_stats_param.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ def test_rfx_ttest(self):
2424
assert tv.shape == (n_roi, n_times)
2525
assert tv_p.shape == (n_perm, n_roi, n_times)
2626
# center
27-
tv, tv_p = rfx_ttest(x, x_p, center=True)
28-
assert tv.shape == (n_roi, n_times)
29-
assert tv_p.shape == (n_perm, n_roi, n_times)
27+
for center in [False, 'mean', 'median', 'trimmed']:
28+
tv, tv_p = rfx_ttest(x, x_p, center=center)
29+
assert tv.shape == (n_roi, n_times)
30+
assert tv_p.shape == (n_perm, n_roi, n_times)
3031
# zscore
3132
tv, tv_p = rfx_ttest(x, x_p, zscore=True)
3233
assert tv.shape == (n_roi, n_times)

frites/utils/tests/test_perf.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""Test performance tools."""
2+
import numpy as np
3+
from time import sleep
4+
5+
from frites.utils.perf import timeit, id, get_data_base, arrays_share_data
6+
7+
8+
class TestPerfTools(object):
9+
10+
def test_timeit(self):
11+
@timeit
12+
def fcn(sec): return sleep(sec) # noqa
13+
14+
def test_id(self):
15+
x = np.random.rand(1000)
16+
assert id(x) != id(x.copy())
17+
assert id(x) == id(x)
18+
19+
def test_get_data_base(self):
20+
x = np.random.rand(1000)
21+
np.testing.assert_array_almost_equal(x, get_data_base(x))
22+
23+
def test_arrays_share_data(self):
24+
x = np.random.rand(1000)
25+
y = np.random.rand(1000)
26+
assert not arrays_share_data(x, y)
27+
assert arrays_share_data(x, x)

frites/workflow/tests/test_wf_mi.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@ class TestWfMi(object): # noqa
3030

3131
def test_definition(self):
3232
"""Test workflow definition."""
33-
WfMi(mi_type='cc', inference='rfx')
33+
y, gt = sim_mi_cc(x, snr=1.)
34+
dt = DatasetEphy(x, y, roi, times=time)
35+
wf = WfMi(mi_type='cc', inference='rfx')
36+
wf.fit(dt, **kw_mi)
37+
wf.tvalues
3438

3539
def test_mi_cc(self):
3640
"""Test method fit."""
@@ -93,3 +97,13 @@ def test_no_stat(self):
9397
wf.fit(dt, mcp='maxstat', **kw_mi)
9498
t_end_2 = tst()
9599
assert t_end_1 - t_start_1 > t_end_2 - t_start_2
100+
101+
def test_conjunction_analysis(self):
102+
"""Test the conjunction analysis."""
103+
y, gt = sim_mi_cc(x, snr=1.)
104+
dt = DatasetEphy(x, y, roi, times=time)
105+
wf = WfMi(mi_type='cc', inference='rfx')
106+
mi, pv = wf.fit(dt, **kw_mi)
107+
cj_ss, cj = wf.conjunction_analysis(dt)
108+
assert cj_ss.shape == (n_subjects, n_times, n_roi)
109+
assert cj.shape == (n_times, n_roi)

0 commit comments

Comments
 (0)