Skip to content

Commit aab68e1

Browse files
committed
rfx_center and rfx_sigma input parameters to better control RFX
1 parent abea667 commit aab68e1

File tree

4 files changed

+61
-36
lines changed

4 files changed

+61
-36
lines changed

frites/stats/stats_param.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,25 @@
88
logger = logging.getLogger("frites")
99

1010

11-
RECENTER = dict(mean=np.mean, median=np.median,
12-
trimmed=lambda x, axis=0: trim_mean(x, .2, axis=axis))
11+
def _trimmed(x, prop=.2, axis=0, keepdims=False):
12+
trm = trim_mean(x, prop, axis=axis)
13+
if keepdims:
14+
ax = [slice(None)] * x.ndim
15+
ax[axis] = np.newaxis
16+
trm = trm[tuple(ax)]
17+
return trm
18+
19+
20+
RECENTER = {
21+
'mean': np.mean, 'median': np.median, 'zscore': np.mean,
22+
'trimmed': _trimmed
23+
}
24+
25+
26+
def _recenter(x, fcn_mean, zscore, axis=-1):
27+
"""Recentering function."""
28+
_std = x.std(axis=axis, keepdims=True) if zscore else 1.
29+
return (x - fcn_mean(x, axis=axis, keepdims=True)) / _std
1330

1431

1532
def ttest_1samp(x, pop_mean, axis=0, implementation='mne', sigma=0.001, **kw):
@@ -50,7 +67,7 @@ def fcn(x, pop_mean, axis): # noqa
5067
return fcn(x, pop_mean, axis)
5168

5269

53-
def rfx_ttest(mi, mi_p, center=False, zscore=False, ttested=False):
70+
def rfx_ttest(mi, mi_p, center=False, sigma=0.001, ttested=False):
5471
"""Perform the t-test across subjects.
5572
5673
Parameters
@@ -61,12 +78,12 @@ def rfx_ttest(mi, mi_p, center=False, zscore=False, ttested=False):
6178
mi_p : array_like
6279
A list of array of permuted mutual information of shape
6380
(n_perm, n_suj, n_times). If `ttested` is True, n_suj shoud be 1.
64-
center : {'mean', 'median', 'trimmed'} | False
65-
If True, substract the mean of the surrogates to the true and permuted
66-
mi. The median or the 20% trimmed mean can also be removed
67-
:cite:`wilcox2018guide`
68-
zscore : bool | False
69-
Apply z-score normalization to the true and permuted mutual information
81+
sigma : float | 0.001
82+
Hat adjustment method, a value of 1e-3 may be a reasonable choice
83+
center : {False, 'mean', 'median', 'trimmed', 'zscore'}
84+
Re-center the time-series of effect arround 0 before computing the
85+
t-test. This parameters can be useful in case of a different number
86+
of data per brain region.
7087
ttested : bool | False
7188
Specify if the inputs have already been t-tested
7289
@@ -93,14 +110,17 @@ def rfx_ttest(mi, mi_p, center=False, zscore=False, ttested=False):
93110
n_roi = len(mi_p)
94111

95112
# remove the mean / median / trimmed
113+
zscore = center == 'zscore'
96114
if center in RECENTER.keys():
97-
logger.info(f" RFX recenter distributions (center={center}, "
98-
f"z-score={zscore})")
99-
for k in range(len(mi)):
100-
_med = RECENTER[center](mi_p[k], axis=0)
101-
_std = mi_p[k].std(axis=0) if zscore else 1.
102-
mi[k] = (mi[k] - _med) / _std
103-
mi_p[k] = (mi_p[k] - _med) / _std
115+
# get the centering function
116+
fcn_mean = RECENTER[center]
117+
118+
# here, we need to make a copy of the effect sizes to avoid changing
119+
# the ouputs
120+
mi, mi_p = mi.copy(), mi_p.copy()
121+
for k in range(n_roi):
122+
mi[k] = _recenter(mi[k], fcn_mean, zscore, axis=-1).copy()
123+
mi_p[k] = _recenter(mi_p[k], fcn_mean, zscore, axis=-1).copy()
104124

105125
# get the mean of surrogates (low ram method)
106126
n_element = np.sum([np.prod(k.shape) for k in mi_p])
@@ -112,12 +132,14 @@ def rfx_ttest(mi, mi_p, center=False, zscore=False, ttested=False):
112132
that the MNE t-test is going to evaluate one sigma per roi. To fix that,
113133
we estimate this sigma using the variance of all of the data
114134
"""
115-
from frites.config import CONFIG
116-
s_hat = CONFIG['TTEST_MNE_SIGMA']
135+
s_hat = sigma
117136

118137
# sigma on true mi and permuted mi
119-
sigma = s_hat * max([np.var(k, axis=0, ddof=1).max() for k in mi])
120-
sigma_p = s_hat * max([np.var(k, axis=1, ddof=1).max() for k in mi_p])
138+
if s_hat > 0:
139+
sigma = s_hat * max([np.var(k, axis=0, ddof=1).max() for k in mi])
140+
sigma_p = s_hat * max([np.var(k, axis=1, ddof=1).max() for k in mi_p])
141+
else:
142+
sigma = sigma_p = 0.
121143
logger.debug(f"sigma_true={sigma}; sigma_permuted={sigma_p}")
122144
kw = dict(implementation='mne', method='absolute', sigma=sigma)
123145
kw_p = dict(implementation='mne', method='absolute', sigma=sigma_p)

frites/stats/tests/test_stats_param.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,10 @@ def test_rfx_ttest(self):
2424
assert tv.shape == (n_roi, n_times)
2525
assert tv_p.shape == (n_perm, n_roi, n_times)
2626
# center
27-
for center in [False, 'mean', 'median', 'trimmed']:
27+
for center in [False, True, 'mean', 'median', 'trimmed', 'zscore']:
2828
tv, tv_p, _ = rfx_ttest(x, x_p, center=center)
2929
assert tv.shape == (n_roi, n_times)
3030
assert tv_p.shape == (n_perm, n_roi, n_times)
31-
# zscore
32-
tv, tv_p, _ = rfx_ttest(x, x_p, zscore=True)
33-
assert tv.shape == (n_roi, n_times)
34-
assert tv_p.shape == (n_perm, n_roi, n_times)
35-
# center + zscore
36-
tv, tv_p, _ = rfx_ttest(x, x_p, zscore=True, center=True)
37-
assert tv.shape == (n_roi, n_times)
38-
assert tv_p.shape == (n_perm, n_roi, n_times)
3931
# t-tested
4032
tv, tv_p, _ = rfx_ttest(x, x_p, ttested=True)
4133
assert tv.shape == (n_roi * n_suj, n_times)

frites/workflow/wf_mi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ def confidence_interval(self, dataset, ci=95, n_boots=200, rfx_es='mi',
451451
pbar = ProgressBar(range(n_boots), mesg='Estimating CI')
452452

453453
# get t-test related variables
454-
s_hat = CONFIG['TTEST_MNE_SIGMA']
454+
s_hat = self._wf_stats.attrs['ttest_sigma']
455455

456456
tt = []
457457
for n_p in range(n_boots):

frites/workflow/wf_stats.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ def __init__(self, verbose=None): # noqa
2828
logger.info("Definition of a non-parametric statistical workflow")
2929

3030
def fit(self, effect, perms, inference='rfx', mcp='cluster', tail=1,
31-
cluster_th=None, cluster_alpha=0.05, ttested=False):
31+
cluster_th=None, cluster_alpha=0.05, ttested=False,
32+
rfx_sigma=0.001, rfx_center=False):
3233
"""Fit the workflow on true data.
3334
3435
Parameters
@@ -71,6 +72,12 @@ def fit(self, effect, perms, inference='rfx', mcp='cluster', tail=1,
7172
the 95th percentile of the permutations is used.
7273
ttested : bool | False
7374
Specify if the inputs have already been t-tested
75+
rfx_sigma : float | 0.001
76+
Hat adjustment method, a value of 1e-3 may be a reasonable choice
77+
rfx_center : {False, 'mean', 'median', 'trimmed', 'zscore'}
78+
Re-center the time-series of effect arround 0 before computing the
79+
t-test. This parameters can be useful in case of a different number
80+
of data per brain region.
7481
7582
Returns
7683
-------
@@ -123,11 +130,15 @@ def fit(self, effect, perms, inference='rfx', mcp='cluster', tail=1,
123130
rfx_suj = np.min(nb_suj_roi) > 1
124131
assert rfx_suj, "For RFX, `n_subjects` should be > 1"
125132
# modelise how subjects are distributed
126-
es, es_p, pop_mean = rfx_ttest(effect, perms)
127-
from frites.config import CONFIG
128-
sigma = CONFIG['TTEST_MNE_SIGMA']
129-
self.attrs.update(dict(ttest_pop_mean=pop_mean,
130-
ttest_sigma=sigma))
133+
rfx_center = 'mean' if isinstance(
134+
rfx_center, bool) and rfx_center else rfx_center
135+
es, es_p, pop_mean = rfx_ttest(
136+
effect, perms, sigma=rfx_sigma, center=rfx_center
137+
)
138+
self.attrs.update(dict(
139+
ttest_pop_mean=pop_mean, ttest_sigma=rfx_sigma,
140+
ttest_center=rfx_center
141+
))
131142
tvalues = es
132143

133144
# ---------------------------------------------------------------------

0 commit comments

Comments
 (0)