Skip to content

Commit 3ec033e

Browse files
committed
change FIT input directed for net + fix testing for it
1 parent 0688cf1 commit 3ec033e

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

frites/workflow/tests/test_wf_fit.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,15 @@ def test_mi_methods(self):
5252
WfFit(mi_type='cc', inference='ffx', mi_method=meth).fit(
5353
ds, **kw_fit)
5454

55-
def test_directed(self):
56-
# directed
55+
def test_biunidirected(self):
56+
# bidirected
5757
ds = DatasetEphy(x, y, roi=roi, times=times)
5858
fd = WfFit(mi_type='cd').fit(
59-
ds, directed=True, output_type='2d_array', **kw_fit)[0]
60-
# non-directed
59+
ds, net=False, output_type='2d_array', **kw_fit)[0]
60+
# unidirected
6161
ds = DatasetEphy(x, y, roi=roi, times=times)
6262
nd = WfFit(mi_type='cd').fit(
63-
ds, directed=False, output_type='2d_array', **kw_fit)[0]
63+
ds, net=True, output_type='2d_array', **kw_fit)[0]
6464
assert fd.shape[1] > nd.shape[1]
6565

6666
def test_output_type(self):

frites/workflow/wf_fit.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(self, mi_type='cc', inference='rfx', gcrn_per_suj=True,
7676
f"Workflow for computing the FIT ({mi_type} - {mi_method}) and "
7777
f"statistics ({inference}) has been defined")
7878

79-
def _node_compute_fit(self, dataset, n_perm, n_jobs, max_delay, directed,
79+
def _node_compute_fit(self, dataset, n_perm, n_jobs, max_delay, net,
8080
random_state):
8181
# ---------------------------------------------------------------------
8282
# compute mi and permuted mi
@@ -92,15 +92,15 @@ def _node_compute_fit(self, dataset, n_perm, n_jobs, max_delay, directed,
9292
# ---------------------------------------------------------------------
9393
# get the number of pairs (source, target)
9494
n_roi = len(mi)
95-
if directed:
95+
if not net:
9696
all_s, all_t = np.where(~np.eye(n_roi, dtype=bool))
9797
tail = 1
9898
else:
9999
all_s, all_t = np.triu_indices(n_roi, k=1)
100100
tail = 0 # two tail test
101-
# pairs = np.c_[all_s, all_t]
102-
logger.info(f" Compute FIT (directed={directed}; max_delay="
103-
f"{max_delay}; n_pairs={len(all_s)})")
101+
direction = 'bidirectional' if not net else 'unidirectional'
102+
logger.info(f" Compute {direction} FIT (max_delay={max_delay}; "
103+
f"n_pairs={len(all_s)})")
104104
# get the unique subjects across roi (depends on inference type)
105105
inference = self._inference
106106
if inference is 'ffx':
@@ -112,7 +112,7 @@ def _node_compute_fit(self, dataset, n_perm, n_jobs, max_delay, directed,
112112
cfg_jobs = config.CONFIG["JOBLIB_CFG"]
113113
arch = Parallel(n_jobs=n_jobs, **cfg_jobs)(delayed(fcn_fit)(
114114
mi[s], mi[t], mi_p[s], mi_p[t], sujr[s], sujr[t], times,
115-
max_delay, directed, inference) for s, t in zip(all_s, all_t))
115+
max_delay, net, inference) for s, t in zip(all_s, all_t))
116116
fit_roi, fitp_roi, fit_m = [list(k) for k in zip(*arch)]
117117
"""
118118
For sEEG data, the ROI repartition is not the same across subjects.
@@ -138,7 +138,7 @@ def _node_compute_fit(self, dataset, n_perm, n_jobs, max_delay, directed,
138138
self._fit_roi, self._fitp_roi = fit_roi, fitp_roi
139139
self._fit_m = fit_m
140140

141-
def fit(self, dataset, max_delay=0.3, directed=True, level='cluster',
141+
def fit(self, dataset, max_delay=0.3, net=False, level='cluster',
142142
mcp='maxstat', cluster_th=None, cluster_alpha=0.05, n_perm=1000,
143143
n_jobs=-1, random_state=None, output_type='3d_dataframe',
144144
**kw_stats):
@@ -161,9 +161,11 @@ def fit(self, dataset, max_delay=0.3, directed=True, level='cluster',
161161
A dataset instance
162162
max_delay : float | 0.3
163163
Maximum delay to use for defining the past of the source and target
164-
directed : bool | True
165-
Use either a directed FIT (True) or un-directed (False) which is
166-
defined as FIT(s -> t) - FIT(t -> s)
164+
net : bool | False
165+
Compute either the bidirectional FIT (i.e A->B and B->A which
166+
correspond to `net=False`) either the unidirectional FIT
167+
(i.e A->B - B->A which correspond to `net=True`). By default, the
168+
bidirectional FIT is computed.
167169
level : {'testwise', 'cluster'}
168170
Inference level. If 'testwise', inferences are made for each region
169171
of interest and at each time point. If 'cluster', cluster-based
@@ -228,7 +230,7 @@ def fit(self, dataset, max_delay=0.3, directed=True, level='cluster',
228230
# compute fit (only if not already computed)
229231
if not len(self._fit_roi):
230232
self._node_compute_fit(dataset, n_perm, n_jobs, max_delay,
231-
directed, random_state)
233+
net, random_state)
232234
else:
233235
logger.warning(" True and permuted FIT already computed. "
234236
"Use WfFit.clean() to reset arguments")
@@ -325,7 +327,7 @@ def mi_p(self):
325327
return self._wf_mi._mi_p
326328

327329

328-
def fcn_fit(x_s, x_t, xp_s, xp_t, suj_s, suj_t, times, max_delay, directed,
330+
def fcn_fit(x_s, x_t, xp_s, xp_t, suj_s, suj_t, times, max_delay, net,
329331
inference):
330332
"""Compute FIT in parallel."""
331333
# find the unique list of subjects for the source and target
@@ -349,15 +351,15 @@ def fcn_fit(x_s, x_t, xp_s, xp_t, suj_s, suj_t, times, max_delay, directed,
349351
# FIT on true and permuted gcmi
350352
_fit_suj = it_fit(x_s_suj, x_t_suj, times, max_delay)[0, ...]
351353
_fitp_suj = it_fit(xp_s_suj, xp_t_suj, times, max_delay)
352-
# compute unidirected FIT
353-
if not directed:
354+
# compute unidirectional FIT
355+
if net:
354356
# compute target -> source
355357
_fit_ts = it_fit(x_t_suj, x_s_suj, times, max_delay)[0, ...]
356358
_fitp_ts = it_fit(xp_t_suj, xp_s_suj, times, max_delay)
357359
# subtract to source -> target
358360
_fit_suj -= _fit_ts
359361
_fitp_suj -= _fitp_ts
360-
# keep the computed (uni/bi)directed FIT
362+
# keep the computed (uni/bi) directed FIT
361363
fit_suj += [_fit_suj]
362364
fitp_suj += [_fitp_suj]
363365
# if not subjects, return empty arrays

0 commit comments

Comments
 (0)