**Note:** use pyspi conda environment to run this code.

In [None]:
from tqdm.auto import tqdm
import xarray as xr
from pyspi.calculator import CalculatorFrame
from pyspi.data import Data

from src.multimodal.preprocessing import TimeseriesAggregator

# we only aggregate the region-wise timeseries into network-wise timeseries
preproc_pipe = TimeseriesAggregator(strategy='network')

atlas = 'dosenbach2010'
connectivity_kind = 'tangent'

with xr.open_dataset(f'data/Julia2018/timeseries_{atlas}.nc5') as ds:
    ds.load()
    ds = preproc_pipe.fit_transform(ds)


datasets = []
for subject in ds.coords['subject'].values:
    ts = ds.sel(subject=subject)['timeseries'].values.T
    feature_names = ds['network'].values
    dataset = Data(ts, procnames=feature_names, name=subject)
    datasets.append(dataset)

calc = CalculatorFrame(datasets=datasets, subset='fast',
                       name=f'Julia2018_{atlas}',
                       names=[d.name for d in datasets])
calc.compute()

In [130]:

# spis = spi_calc.table
# spis = s.columns.get_level_values(0).unique()

# spi_calc._get_correlation_df()

tables = {
    c.name: c.table
    for i, c in calc.calculators.itertuples()
}

import pandas as pd

spis = []

for i, c in tqdm(calc.calculators.itertuples(), total=calc.n_calculators):
    feats = c.table
    feats.index.name = 'process_1'
    feats = feats.reset_index()
    feats.columns.names = ['spi', 'process_2']
    melted = pd.melt(feats, id_vars='process_1', var_name=['spi', 'process_2'], value_name='value')
    melted['process'] = melted.apply(lambda x: set(x[['process_1', 'process_2']]), axis=1)
    melted = melted.groupby('spi').apply(lambda x: x.drop_duplicates('process'))
    # melted.dropna(subset=['value'], inplace=True)
    melted['process'] = melted['process_1'] + '-' + melted['process_2']
    melted.drop(columns=['spi', 'process_1', 'process_2'], inplace=True)
    melted.reset_index(level=0, inplace=True)
    melted.reset_index(drop=True, inplace=True)
    melted = melted.assign(subject=c.name, label=c.name[:4])
    spis.append(melted)
spi_df = pd.concat(spis)

100%|██████████| 32/32 [01:40<00:00,  3.13s/it]


In [136]:
from sklearn.model_selection import cross_val_score, StratifiedShuffleSplit
from sklearn.svm import SVC
from sklearn.preprocessing import LabelEncoder

def score_spi(s):
    estimator = SVC(kernel='linear', C=1)
    X = s.drop(columns=['subject', 'spi', 'label']).values
    y = s['label'].values
    y = LabelEncoder().fit_transform(y)
    CV = StratifiedShuffleSplit(n_splits=100, test_size=8)
    score = cross_val_score(estimator, X, y, cv=CV, groups=y, n_jobs=-1, scoring='accuracy')
    return score.mean()

spi_df_wide = spi_df.pivot_table(index=['subject', 'label', 'spi'], columns=['process'], values='value', aggfunc='mean').reset_index()
spi_df_wide.groupby(['spi']).apply(score_spi).sort_values(ascending=False)



100 fits failed out of a total of 100.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
100 fits failed with the following error:
Traceback (most recent call last):
  File "/home/morteza/miniforge3/envs/pyspi/lib/python3.9/site-packages/sklearn/model_selection/_validation.py", line 681, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/home/morteza/miniforge3/envs/pyspi/lib/python3.9/site-packages/sklearn/svm/_base.py", line 190, in fit
    X, y = self._validate_data(
  File "/home/morteza/miniforge3/envs/pyspi/lib/python3.9/site-packages/sklearn/base.py", line 576, in _validate_data
    X, y = check_X_y(X, y, **check_params)
  File "/home/morteza/miniforge3/envs/pyspi/lib/python3.9/site-packages/sklearn/utils/

spi
phase_multitaper_max_fs-1_fmin-0_fmax-0-5                0.74250
sgc_parametric_mean_fs-1_fmin-1e-05_fmax-0-5_order-1     0.70875
ddtf_multitaper_mean_fs-1_fmin-0_fmax-0-5                0.69750
cov_GraphicalLassoCV                                     0.69500
prec_OAS                                                 0.69375
                                                          ...   
gd_multitaper_delay_fs-1_fmin-0_fmax-0-5                     NaN
sgc_parametric_max_fs-1_fmin-0_fmax-0-25_order-20            NaN
sgc_parametric_max_fs-1_fmin-1e-05_fmax-0-5_order-20         NaN
sgc_parametric_mean_fs-1_fmin-0_fmax-0-25_order-20           NaN
sgc_parametric_mean_fs-1_fmin-1e-05_fmax-0-5_order-20        NaN
Length: 216, dtype: float64