Skip to content

Commit bfdf2db

Browse files
committed
Adapt DatasetEphy to get the pair of ROI
1 parent 8b955a1 commit bfdf2db

File tree

1 file changed: +14 additions, -36 deletions

frites/dataset/ds_ephy.py

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from frites.dataset import SubjectEphy
1212
from frites.dataset.ds_utils import multi_to_uni_conditions
1313
from frites.core import copnorm_cat_nd, copnorm_nd
14-
from frites.utils import savgol_filter
14+
from frites.conn.conn_utils import conn_get_pairs
15+
from frites.utils import savgol_filter, nonsorted_unique
1516

1617

1718

@@ -161,8 +162,7 @@ def _update_internals(self):
161162
"""Update internal variables."""
162163
# build a unique list of unsorted roi names
163164
merged_roi = np.r_[tuple([k['roi'].data for k in self._x])]
164-
_, u_idx = np.unique(merged_roi, return_index=True)
165-
roi_names = merged_roi[np.sort(u_idx)]
165+
roi_names = nonsorted_unique(merged_roi)
166166

167167
# dataframe made of unique roi per subjects and subject id
168168
suj_r, roi_r = [], []
@@ -403,41 +403,19 @@ def get_connectivity_pairs(self, as_blocks=False, directed=False,
403403
targets : array_like
404404
Indices of the target
405405
"""
406-
set_log_level(verbose)
407-
bad, df_rs, nb_min_suj = [], self._df_rs, self._nb_min_suj
408-
n_roi_full = len(df_rs.index)
409-
# get all possible pairs
410-
if directed:
411-
pairs = np.where(~np.eye(n_roi_full, dtype=bool))
412-
else:
413-
pairs = np.triu_indices(n_roi_full, k=1)
414-
# remove pairs where there's not enough subjects
415-
s_new, t_new = [], []
416-
for s, t in zip(pairs[0], pairs[1]):
417-
n_suj_s = int(df_rs.iloc[s]['#subjects'])
418-
n_suj_t = int(df_rs.iloc[t]['#subjects'])
419-
if min(n_suj_s, n_suj_t) >= nb_min_suj:
420-
s_new += [s]
421-
t_new += [t]
422-
else:
423-
bad += [f"{str(df_rs.index[s])}-{str(df_rs.index[t])}"]
424-
if len(bad):
425-
logger.warning("The following connectivity pairs are going to "
426-
"be ignored because the number of subjects is "
427-
f"bellow {nb_min_suj} : {bad}")
428-
pairs = (np.asarray(s_new), np.asarray(t_new))
429-
logger.info(f" {len(pairs[0])} remaining connectivity pairs / "
430-
f"{len(bad)} pairs have been ignored "
431-
f"(nb_min_suj={nb_min_suj})")
406+
rois = [k['roi'].data.astype(str) for k in self.x]
407+
# get the dataframe for connectivity
408+
self.df_conn = conn_get_pairs(
409+
rois, directed=directed, nb_min_suj=self._nb_min_suj,
410+
verbose=verbose)
411+
df_conn = self.df_conn.loc[self.df_conn['keep']]
412+
df_conn = df_conn.drop(columns='keep')
413+
414+
# group by sources
432415
if as_blocks:
433-
blocks_s, blocks_t, u_sources = [], [], np.unique(pairs[0])
434-
for s in u_sources:
435-
blocks_s += [s]
436-
blocks_t += [pairs[1][pairs[0] == s].tolist()]
437-
pairs = (blocks_s, blocks_t)
438-
416+
df_conn = df_conn.groupby('sources').agg(list).reset_index()
439417

440-
return pairs[0], pairs[1]
418+
return df_conn
441419

442420

443421
###########################################################################

0 commit comments

Comments (0)