In [1]:
%reload_ext autoreload
%autoreload 3
from src.acnets.pipeline import Parcellation
import pandas as pd
import numpy as np

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score
from sklearn.inspection import permutation_importance

from tqdm import tqdm

from src.acnets.pipeline import MultiScaleClassifier

In [2]:
# Combined model: assemble inputs and labels, then sanity-check the classifier.

# Input/Output
# The subject identifiers are read from the `subject` coordinate of the
# parcellation output (the parcellation is fitted with no explicit input).
parcellation = Parcellation(atlas_name='dosenbach2010')
subjects = parcellation.fit_transform(None).coords['subject'].values

# The first four characters of every subject identifier act as its class label.
subject_labels = [subject_id[:4] for subject_id in subjects]

# The model input is the column of subject identifiers, shape: (n_subjects, 1)
X = subjects.reshape(-1, 1)

# Labels encoded as integers, shape: (n_subjects,)
y_encoder = LabelEncoder()
y = y_encoder.fit_transform(subject_labels)


model = MultiScaleClassifier()

# DEBUG (expected to overfit, i.e., score=1): fit and score on the very same data
fitted_model = model.fit(X, y)
overfit_score = fitted_model.score(X, y)
print(f'[DEBUG] overfit accuracy: {overfit_score:.3f}')

[DEBUG] overfit accuracy: 1.000


In [10]:
# Estimate out-of-sample accuracy with repeated stratified shuffle splits:
# 100 random splits, each holding out 8 subjects for testing.
# `random_state` is fixed so the same splits are drawn on every run; without it
# the reported accuracy changes each time the cell is executed.
cv = StratifiedShuffleSplit(n_splits=100, test_size=8, random_state=0)

# n_jobs=-1 evaluates the splits in parallel on all available processors.
cv_scores = cross_val_score(model, X, y, cv=cv, verbose=0, n_jobs=-1)

print(f'CV accuracy: {cv_scores.mean():.3f} +/- {cv_scores.std():.3f}')

CV accuracy: 0.703 +/- 0.127


In [None]:
# Permutation importance of the extracted features, measured on the held-out
# subjects of repeated stratified shuffle splits. One result is collected per split.
feature_importance_results = []

# Seeds are fixed so that both the splits and the feature shuffling are
# reproducible from a fresh kernel.
permutation_cv = StratifiedShuffleSplit(n_splits=100, test_size=8, random_state=0)
n_permutation_splits = permutation_cv.get_n_splits(X, y)

for train, test in tqdm(permutation_cv.split(X, y), total=n_permutation_splits):
    # Refit the full model on the training subjects of this split only.
    model.fit(X[train], y[train])

    # Features are extracted for all subjects, then restricted to the test subjects below.
    X_features = model.get_feature_extractor_head().transform(X)

    # Importance = drop in test accuracy of the classification head when a
    # feature column is shuffled.
    _results = permutation_importance(model.get_classification_head(), X_features[test], y[test],
                                      n_jobs=-1,
                                      random_state=0,
                                      scoring='accuracy')
    feature_importance_results.append(_results)

# Column labels for the importance table built in the next cell, which reads
# `feature_names`; previously this was only a commented-out TODO, so that cell
# raised NameError on a fresh kernel.
# NOTE(review): assumes the feature extractor head implements
# `get_feature_names_out()` and that names are identical across splits — confirm.
feature_names = model.get_feature_extractor_head().get_feature_names_out()


In [6]:
# TODO

# Aggregate the per-split permutation importances into one ranked table.
# `feature_importance_results` and `feature_names` must already exist in the kernel.

# Stack the mean importances of every split: one row per split, one column per feature.
_per_split_importances = np.stack(
    [split_result['importances_mean'] for split_result in feature_importance_results]
)
_per_split_frame = pd.DataFrame(data=_per_split_importances, columns=feature_names)

# Average each feature over the splits and order from most to least important.
_mean_importance = _per_split_frame.mean()
_ranked_importance = _mean_importance.sort_values(ascending=False)
importances = _ranked_importance.to_frame('importance')

importances.head(10)  # top 10 features

Unnamed: 0,importance
h2__default,0.1925
h3__fronto-parietal ↔ sensorimotor,0.0325
h1__med cerebellum 143,0.025
h1__post cingulate 108,0.0225
h2__fronto-parietal,0.0175
h2__cerebellum,0.0125
h1__occipital 142,0.01
h3__default ↔ sensorimotor,0.01
h3__cerebellum ↔ occipital,0.0075
h2__cingulo-opercular,0.0075
