In [None]:
%reload_ext autoreload
%autoreload 3
import numpy as np
import pandas as pd
from tqdm import tqdm

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.decomposition import FastICA
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel, VarianceThreshold
from sklearn.inspection import permutation_importance
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.pipeline import make_union
from sklearn.preprocessing import FunctionTransformer
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.svm import LinearSVC, SVC
from xgboost import XGBClassifier

from src.acnets.pipeline import Parcellation
from src.acnets.pipeline import TimeseriesAggregator, ConnectivityExtractor
from src.acnets.pipeline import ConnectivityAggregator, ConnectivityVectorizer



In [None]:
# Registry of per-hypothesis feature pipelines, keyed by hypothesis id
# ('h1', 'h2', ...); the cells below fill it in.
pipelines = dict()

In [None]:
# H1

class ExtractH1Features(TransformerMixin, BaseEstimator):
    """H1 features: the time-averaged signal of every region.

    The transformer learns nothing in ``fit``. Feature names are captured from
    the input while transforming, so ``get_feature_names_out`` is only usable
    after ``transform`` has been called at least once.
    """

    def __init__(self):
        pass

    def fit(self, dataset, y=None):
        """No-op; returns ``self`` for scikit-learn API compatibility."""
        return self

    def transform(self, dataset):
        """Average ``dataset['timeseries']`` over its ``timepoint`` dimension.

        Parameters
        ----------
        dataset : mapping with a ``'timeseries'`` entry
            The entry must expose ``coords['region']`` and a ``timepoint``
            dimension (presumably an xarray Dataset -- TODO confirm).

        Returns
        -------
        numpy.ndarray
            The averaged values; presumably one row per subject and one
            column per region -- TODO confirm against the dataset layout.
        """
        timeseries = dataset['timeseries']
        # Remember the region labels so they can be reported as feature names.
        self.feature_names = timeseries.coords['region'].values

        return timeseries.mean('timepoint').values

    def get_feature_names_out(self, input_features=None):
        """Return the region labels seen by the most recent ``transform``.

        ``input_features`` is accepted (and ignored) to match the scikit-learn
        signature; it defaults to ``None`` so the method can also be called
        without arguments.
        """
        return self.feature_names

# Register the H1 pipeline; it consists of the feature-extraction step only.
pipelines['h1'] = Pipeline(
    steps=[
        ('extract_features', ExtractH1Features()),
        # TODO normalize timeseries
    ],
)


In [None]:
# H2
# within-network connectivity

class ExtractH2Features(TransformerMixin, BaseEstimator):
    """H2 features: the main diagonal of each connectivity matrix.

    The transformer learns nothing in ``fit``. Feature names are captured from
    the input while transforming, so ``get_feature_names_out`` is only usable
    after ``transform`` has been called at least once.
    """

    def __init__(self):
        pass

    def fit(self, dataset, y=None):
        """No-op; returns ``self`` for scikit-learn API compatibility."""
        return self

    def transform(self, dataset):
        """Extract the diagonal of every matrix in ``dataset['connectivity']``.

        Parameters
        ----------
        dataset : mapping with a ``'connectivity'`` entry
            The entry must expose ``dims``, ``coords`` and ``values``
            (presumably an xarray DataArray holding a stack of square
            matrices along its first axis -- TODO confirm).

        Returns
        -------
        numpy.ndarray
            One row per matrix, one column per diagonal element.
        """
        connectivity = dataset['connectivity']

        # The last dimension names the nodes; its labels become feature names.
        node_type = connectivity.dims[-1]
        self.feature_names = connectivity.coords[node_type].values.tolist()

        # Diagonal over the last two axes for all matrices at once. np.diagonal
        # returns a read-only view, so copy to hand back a writable array.
        conn_vec = np.diagonal(connectivity.values, axis1=-2, axis2=-1).copy()

        return conn_vec

    def get_feature_names_out(self, input_features=None):
        """Return the node labels seen by the most recent ``transform``.

        ``input_features`` is accepted (and ignored) to match the scikit-learn
        signature; it defaults to ``None`` so the method can also be called
        without arguments.
        """
        return self.feature_names

# Register the H2 pipeline: timeseries -> partial-correlation connectivity ->
# connectivity aggregated with the 'network' strategy -> matrix diagonal.
pipelines['h2'] = Pipeline(
    steps=[
        ('aggregate_ts', TimeseriesAggregator(strategy=None)),
        ('extract_conn', ConnectivityExtractor(kind='partial correlation')),
        ('aggregate_conn', ConnectivityAggregator(strategy='network')),
        ('extract_features', ExtractH2Features()),
    ],
)


In [None]:
# H3: between-network connectivity
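# upper-triangular entries of the network-by-network connectivity matrix; `k` is the diagonal offset:
#   k=0 (the default) keeps the diagonal  -> N_networks * (N_networks + 1) / 2 features
#   k=1 drops it (between-network only)   -> N_networks * (N_networks - 1) / 2 features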

class ExtractH3Features(TransformerMixin, BaseEstimator):
    """H3 features: the upper triangle of each connectivity matrix.

    Parameters
    ----------
    k : int, default=0
        Diagonal offset forwarded to ``numpy.triu_indices``. ``k=0`` keeps the
        main diagonal; ``k=1`` keeps only the entries strictly above it.

    The transformer learns nothing in ``fit``. Feature names are captured from
    the input while transforming, so ``get_feature_names_out`` is only usable
    after ``transform`` has been called at least once.
    """

    def __init__(self, k=0):
        self.k = k

    def fit(self, dataset, y=None):
        """No-op; returns ``self`` for scikit-learn API compatibility."""
        return self

    def transform(self, dataset):
        """Vectorize the upper triangle of every matrix in ``dataset['connectivity']``.

        Parameters
        ----------
        dataset : mapping with a ``'connectivity'`` entry
            The entry's last dimension name must end in a 4-character suffix
            (``'_src'`` or ``'_dst'``), and ``dataset`` must provide the
            matching ``<prefix>_src`` and ``<prefix>_dst`` label arrays.
            Matrices are assumed to be square and stacked along the first
            axis -- TODO confirm.

        Returns
        -------
        numpy.ndarray
            One row per matrix, one column per selected upper-triangle entry.
        """
        connectivity = dataset['connectivity']
        conns = connectivity.values

        # The same (row, col) pairs select both the values and their names,
        # so the two can never get out of step.
        rows, cols = np.triu_indices(conns.shape[1], k=self.k)
        conn_vectorized = conns[:, rows, cols]

        # Strip the '_src' / '_dst' suffix to recover the node-type prefix.
        node_type = connectivity.dims[-1][:-4]
        dst_labels = np.asarray(dataset[node_type + '_dst'].values)
        src_labels = np.asarray(dataset[node_type + '_src'].values)

        # Name each feature "<row label> <-> <column label>"; rows carry the
        # '_dst' labels and columns the '_src' labels.
        sep = ' \N{left right arrow} '
        self.feature_names = [
            f'{dst}{sep}{src}'
            for dst, src in zip(dst_labels[rows], src_labels[cols])
        ]

        return conn_vectorized

    def get_feature_names_out(self, input_features=None):
        """Return the pair labels built by the most recent ``transform``.

        ``input_features`` is accepted (and ignored) to match the scikit-learn
        signature; it defaults to ``None`` so the method can also be called
        without arguments.
        """
        return self.feature_names

# Register the H3 pipeline: timeseries aggregated with the 'network' strategy
# -> partial-correlation connectivity -> upper-triangle vectorization.
# NOTE(review): ExtractH3Features() uses its default k=0, which keeps the matrix
# diagonal; pass k=1 if only between-network terms are intended -- confirm.
pipelines['h3'] = Pipeline(
    steps=[
        ('aggregate_ts', TimeseriesAggregator(strategy='network')),
        ('extract_conn', ConnectivityExtractor(kind='partial correlation')),
        ('extract_features', ExtractH3Features()),
    ],
)


In [None]:
# Combined model

# Input/Output
# Fit a parcellation here to read the subject IDs from its dataset.
parcellation = Parcellation(atlas_name='dosenbach2010').fit()
subjects = parcellation.dataset_.coords['subject'].values
# The first four characters of a subject ID serve as its class label.
subject_labels = [s[:4] for s in subjects]
X = subjects.reshape(-1, 1)                     # subjects, shape: (n_subjects, 1)
y_encoder = LabelEncoder()
y = y_encoder.fit_transform(subject_labels)     # labels, shape: (n_subjects,)

# End-to-end classifier: parcellate, concatenate the features of every
# registered hypothesis pipeline, standardize, drop constant columns, classify.
# (Alternatives previously sketched for the tail of this pipeline: FastICA,
# and SelectFromModel + RandomForestClassifier / LinearSVC.)
model = Pipeline([
    ('parcellation', Parcellation(atlas_name='dosenbach2010')),
    ('extract_features', make_union(*pipelines.values())),
    ('scale', StandardScaler()),
    ('zerovar', VarianceThreshold()),
    ('clf', XGBClassifier()),
])

# DEBUG (expected to overfit, i.e., score=1)
overfit_score = model.fit(X, y).score(X, y)
print(f'[DEBUG] overfit accuracy: {overfit_score:.3f}')

In [None]:

# Seed the splitter so the 20 random train/test splits -- and therefore the
# mean accuracy shown below -- are identical on every run.
CV = StratifiedShuffleSplit(n_splits=20, test_size=8, random_state=0)
cross_val_score(model, X, y, cv=CV, verbose=3).mean()

In [None]:
# Permutation importance of every extracted feature, estimated on the held-out
# part of repeated stratified splits. Both the splitter and the permutation
# routine are seeded so the resulting ranking is reproducible.
feature_importance_results = []
feature_names = model[:2].get_feature_names_out()

permutation_cv = StratifiedShuffleSplit(n_splits=20, test_size=8, random_state=0)

for train, test in tqdm(permutation_cv.split(X, y), total=permutation_cv.get_n_splits(X, y)):
    # NOTE: this refits the shared `model` in place; after the loop it holds
    # the fit from the last split.
    model.fit(X[train], y[train])
    # First two steps (parcellation + feature union) produce the feature matrix.
    X_features = model[:2].transform(X)

    # Permute columns of the feature matrix and score the remaining steps
    # (scaling, variance filter, classifier) on the test rows only.
    _results = permutation_importance(model[2:], X_features[test], y[test],
                                      scoring='accuracy', random_state=0)
    feature_importance_results.append(_results)


In [None]:
# One row per split, one column per feature; average over the splits and rank
# the features from most to least important.
importances = (
    pd.DataFrame(
        np.vstack([result['importances_mean'] for result in feature_importance_results]),
        columns=feature_names,
    )
    .mean()
    .sort_values(ascending=False)
    .to_frame('importance')
)

importances.head(10)  # top 10 features