|
11 | 11 | from frites.dataset import SubjectEphy
|
12 | 12 | from frites.dataset.ds_utils import multi_to_uni_conditions
|
13 | 13 | from frites.core import copnorm_cat_nd, copnorm_nd
|
14 |
| -from frites.utils import savgol_filter |
| 14 | +from frites.conn.conn_utils import conn_get_pairs |
| 15 | +from frites.utils import savgol_filter, nonsorted_unique |
15 | 16 |
|
16 | 17 |
|
17 | 18 |
|
@@ -161,8 +162,7 @@ def _update_internals(self):
|
161 | 162 | """Update internal variables."""
|
162 | 163 | # build a unique list of unsorted roi names
|
163 | 164 | merged_roi = np.r_[tuple([k['roi'].data for k in self._x])]
|
164 |
| - _, u_idx = np.unique(merged_roi, return_index=True) |
165 |
| - roi_names = merged_roi[np.sort(u_idx)] |
| 165 | + roi_names = nonsorted_unique(merged_roi) |
166 | 166 |
|
167 | 167 | # dataframe made of unique roi per subjects and subject id
|
168 | 168 | suj_r, roi_r = [], []
|
@@ -403,41 +403,19 @@ def get_connectivity_pairs(self, as_blocks=False, directed=False,
|
403 | 403 | targets : array_like
|
404 | 404 | Indices of the target
|
405 | 405 | """
|
406 |
| - set_log_level(verbose) |
407 |
| - bad, df_rs, nb_min_suj = [], self._df_rs, self._nb_min_suj |
408 |
| - n_roi_full = len(df_rs.index) |
409 |
| - # get all possible pairs |
410 |
| - if directed: |
411 |
| - pairs = np.where(~np.eye(n_roi_full, dtype=bool)) |
412 |
| - else: |
413 |
| - pairs = np.triu_indices(n_roi_full, k=1) |
414 |
| - # remove pairs where there's not enough subjects |
415 |
| - s_new, t_new = [], [] |
416 |
| - for s, t in zip(pairs[0], pairs[1]): |
417 |
| - n_suj_s = int(df_rs.iloc[s]['#subjects']) |
418 |
| - n_suj_t = int(df_rs.iloc[t]['#subjects']) |
419 |
| - if min(n_suj_s, n_suj_t) >= nb_min_suj: |
420 |
| - s_new += [s] |
421 |
| - t_new += [t] |
422 |
| - else: |
423 |
| - bad += [f"{str(df_rs.index[s])}-{str(df_rs.index[t])}"] |
424 |
| - if len(bad): |
425 |
| - logger.warning("The following connectivity pairs are going to " |
426 |
| - "be ignored because the number of subjects is " |
427 |
| - f"bellow {nb_min_suj} : {bad}") |
428 |
| - pairs = (np.asarray(s_new), np.asarray(t_new)) |
429 |
| - logger.info(f" {len(pairs[0])} remaining connectivity pairs / " |
430 |
| - f"{len(bad)} pairs have been ignored " |
431 |
| - f"(nb_min_suj={nb_min_suj})") |
| 406 | + rois = [k['roi'].data.astype(str) for k in self.x] |
| 407 | + # get the dataframe for connectivity |
| 408 | + self.df_conn = conn_get_pairs( |
| 409 | + rois, directed=directed, nb_min_suj=self._nb_min_suj, |
| 410 | + verbose=verbose) |
| 411 | + df_conn = self.df_conn.loc[self.df_conn['keep']] |
| 412 | + df_conn = df_conn.drop(columns='keep') |
| 413 | + |
| 414 | + # group by sources |
432 | 415 | if as_blocks:
|
433 |
| - blocks_s, blocks_t, u_sources = [], [], np.unique(pairs[0]) |
434 |
| - for s in u_sources: |
435 |
| - blocks_s += [s] |
436 |
| - blocks_t += [pairs[1][pairs[0] == s].tolist()] |
437 |
| - pairs = (blocks_s, blocks_t) |
438 |
| - |
| 416 | + df_conn = df_conn.groupby('sources').agg(list).reset_index() |
439 | 417 |
|
440 |
| - return pairs[0], pairs[1] |
| 418 | + return df_conn |
441 | 419 |
|
442 | 420 |
|
443 | 421 | ###########################################################################
|
|
0 commit comments