In [None]:
%reload_ext autoreload
%autoreload 3
from src.acnets.pipeline import Parcellation
import pandas as pd
import numpy as np

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score
from sklearn.inspection import permutation_importance

from tqdm import tqdm

from src.acnets.pipeline import MultiScaleClassifier

In [None]:
# Combined model: one MultiScaleClassifier built on the Dosenbach-2010 parcellation.

# --- Inputs / outputs --------------------------------------------------------
# Fit the parcellation once; its dataset carries the subject IDs as a coordinate.
parcellation = Parcellation(atlas_name='dosenbach2010').fit()
subjects = parcellation.dataset_.coords['subject'].values

# Group label = first 4 characters of each subject ID.
# NOTE(review): assumes every ID starts with a 4-character group code — confirm.
subject_labels = [subject_id[:4] for subject_id in subjects]

# The pipeline consumes subject IDs as its single input column.
X = subjects.reshape(-1, 1)                     # shape: (n_subjects, 1)

# Integer-encode the group labels for the classifier.
y_encoder = LabelEncoder()
y = y_encoder.fit_transform(subject_labels)     # shape: (n_subjects,)

model = MultiScaleClassifier()

# --- Sanity check --------------------------------------------------------------
# Training and scoring on the same subjects should overfit (accuracy close to 1).
# Note: this also leaves `model` fitted on all subjects.
overfit_score = model.fit(X, y).score(X, y)
print(f'[DEBUG] overfit accuracy: {overfit_score:.3f}')

In [None]:

# Cross-validated accuracy over 10 stratified random splits, each holding out 8 subjects
# (an integer `test_size` is an absolute number of test samples).
# A fixed `random_state` makes the splits, and therefore the reported score, reproducible
# across Restart & Run All. `cross_val_score` clones `model`, so the fitted `model` is untouched.
CV = StratifiedShuffleSplit(n_splits=10, test_size=8, random_state=0)
cross_val_score(model, X, y, cv=CV, verbose=3).mean()

In [None]:
# Permutation feature importance of the classification head, measured on held-out
# subjects over 10 stratified random splits (8 test subjects per split).
feature_importance_results = []
# Read from the already-fitted `model`; used to label the importance columns.
feature_names = model.get_encoder().get_feature_names_out()

# Fixed seeds make both the splits and the permutation shuffles reproducible across re-runs.
permutation_cv = StratifiedShuffleSplit(n_splits=10, test_size=8, random_state=0)

for train, test in tqdm(permutation_cv.split(X, y), total=permutation_cv.get_n_splits(X, y)):
    # Refits the shared `model` in place: after the loop it holds the last fold's fit.
    model.fit(X[train], y[train])

    # Guard against silently mislabelled importances if a refit changes the feature layout.
    assert np.array_equal(model.get_encoder().get_feature_names_out(), feature_names), \
        'encoder feature names changed between fits'

    # All subjects are encoded, but only the held-out rows are scored below.
    X_features = model.get_encoder().transform(X)

    _results = permutation_importance(model.get_classification_head(), X_features[test], y[test],
                                      scoring='accuracy', random_state=0)
    feature_importance_results.append(_results)


In [None]:
# One row per CV fold, one column per feature: the fold's mean permutation importance.
fold_importances = np.vstack([result['importances_mean'] for result in feature_importance_results])

# Average across folds, then rank features from most to least important.
mean_importance = pd.DataFrame(fold_importances, columns=feature_names).mean()
importances = mean_importance.sort_values(ascending=False).to_frame('importance')

importances.head(10)  # top 10 features