In [None]:
%reload_ext autoreload
%autoreload 3
from src.acnets.pipeline import Parcellation
import pandas as pd
import numpy as np

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score
from sklearn.inspection import permutation_importance

from tqdm import tqdm
from joblib import delayed, Parallel

from IPython.display import clear_output

from src.acnets.pipeline import MultiScaleClassifier

from ray import tune
from ray.tune.search.hyperopt import HyperOptSearch

from ray.tune.sklearn import TuneSearchCV
from sklearn.model_selection import GridSearchCV
from scipy import stats

In [None]:
# Inputs and targets
# The classifier is fed subject identifiers, which are read off the parcellation output.
parcellation = Parcellation(atlas_name='dosenbach2010')
subjects = (
    parcellation
    .fit_transform(None)
    .coords['subject']
    .values
)

# The first four characters of each subject identifier are used as its class label.
subject_labels = [subject_id[:4] for subject_id in subjects]

y_encoder = LabelEncoder()
y = y_encoder.fit_transform(subject_labels)     # integer-encoded labels, shape: (n_subjects,)
X = subjects.reshape(-1, 1)                     # one subject per row, shape: (n_subjects, 1)


# DEBUG: fit and score on the very same data (expected to overfit, i.e., score=1)
overfit_score = (
    MultiScaleClassifier()
    .fit(X, y)
    .score(X, y)
)
print(f'[DEBUG] overfit accuracy: {overfit_score:.3f}')

In [None]:

# Search space over the `clf__*` parameters of the classifier.
param_space = {
    'clf__objective': ['binary:logistic'],
    'clf__max_depth': tune.randint(1, 20),
    'clf__min_child_weight': tune.randint(1, 10),
    'clf__subsample': tune.uniform(0.01, 1.0),
    # NOTE(review): if `clf` is an XGBoost model, `eta` and `learning_rate` are aliases
    # of the same setting; confirm that searching both is intended.
    'clf__eta': tune.loguniform(1e-4, 1e-1),
    # NOTE(review): unlike its neighbours this entry is a (low, high, prior) tuple, not a
    # `tune.*` domain; confirm it is interpreted as a log-uniform range.
    'clf__learning_rate': (0.01, 1.0, 'log-uniform'),
    'clf__n_estimators': tune.choice([25, 50, 100, 1000]),
    # Candidate parameters that are currently NOT searched:
    # 'clf__max_depth': (0, 50),
    # 'clf__max_delta_step': (0, 20),
    # 'clf__reg_lambda': (1e-9, 1000, 'log-uniform'),
    # 'clf__reg_alpha': (1e-9, 1.0, 'log-uniform'),
    # 'clf__gamma': (1e-9, 0.5, 'log-uniform'),
    # 'clf__scale_pos_weight': (1e-6, 500, 'log-uniform')
}

# Translate the Ray Tune domains above for the hyperopt-backed search.
hyperopt_space = HyperOptSearch.convert_search_space(param_space)

# Each of the 5 tuning splits holds out 16 subjects for scoring.
tuning_cv = StratifiedShuffleSplit(n_splits=5, test_size=16)

tuner = TuneSearchCV(
    MultiScaleClassifier(),
    hyperopt_space,
    search_optimization='hyperopt',
    n_trials=10,
    cv=tuning_cv,
    scoring='accuracy',
    refit=True,
    n_jobs=-1,
    verbose=2,
)

tuner.fit(X, y)
clear_output()  # drop the verbose tuning log from the cell output

# create (and display) a model configured with the best params found
MultiScaleClassifier().set_params(**tuner.best_params_)

In [None]:
# Evaluate the tuned model on 10 stratified shuffle splits (8 held-out subjects each).
# The splitter is seeded so that the same splits are drawn on every run.
cv = StratifiedShuffleSplit(n_splits=10, test_size=8, random_state=0)

model = MultiScaleClassifier().set_params(**tuner.best_params_)
cv_scores = cross_val_score(model, X, y, cv=cv, verbose=3, n_jobs=-1)

# Bootstrap confidence interval of the mean score. `stats.bootstrap` expects a
# sequence of samples, so the scores are passed as a one-sample tuple; the
# resampling is seeded so the interval is reproducible for a given set of scores.
bootstrap_ci = stats.bootstrap((cv_scores,), np.mean, random_state=0)


clear_output(wait=True)  # replace the verbose CV log with the summary below
print(f'Test accuracy (mean ± std): {cv_scores.mean():.2f} ± {cv_scores.std():.2f}')
print(bootstrap_ci.confidence_interval)

In [None]:
# Permutation Feature Importance

# Configure a fresh model with the tuned parameters ...
model = MultiScaleClassifier().set_params(**tuner.best_params_)

# ... and fit it on all subjects to read off the names of the features it produces.
feature_names = (
    model
    .fit(X, y)
    .get_feature_names_out()
)

def do_permutation_importance(estimator, X_train, y_train, X_test, y_test, scoring='accuracy',
                              n_repeats=5, random_state=None):
    """Fit `estimator` on a training split and measure permutation importance on a test split.

    The estimator is fitted in place on the training data. The test inputs are then
    passed through the estimator's feature extractor, and the permutation importance
    of the resulting features is computed against the estimator's classification head.

    Parameters
    ----------
    estimator
        Model exposing `fit`, `get_feature_extractor()` and `get_classification_head()`.
        Note that it is (re)fitted by this function.
    X_train, y_train
        Training inputs and labels used to fit `estimator`.
    X_test, y_test
        Held-out inputs and labels on which importance is measured.
    scoring
        Scorer forwarded to `permutation_importance` (default: 'accuracy').
    n_repeats
        Number of times each feature is permuted; forwarded to `permutation_importance`.
        The default (5) is the same as scikit-learn's own default.
    random_state
        Seed forwarded to `permutation_importance`; pass an int for reproducible
        importances. The default (None) leaves the permutations unseeded.

    Returns
    -------
    The `importances_mean` entry of the permutation importance result, i.e. one
    value per extracted feature, averaged over the `n_repeats` permutations.
    """
    estimator.fit(X_train, y_train)

    # Importance is measured on the extracted features, not on the raw inputs,
    # hence the classification head (rather than the full estimator) is scored.
    X_test_features = estimator.get_feature_extractor().transform(X_test)
    results = permutation_importance(estimator.get_classification_head(),
                                     X_test_features,
                                     y_test,
                                     n_jobs=-1,
                                     scoring=scoring,
                                     n_repeats=n_repeats,
                                     random_state=random_state)
    return results['importances_mean']


# Run permutation importance in parallel, one task per train/test split.
# The splitter is seeded so that the same splits are drawn on every run.
permutation_cv = StratifiedShuffleSplit(n_splits=10, test_size=8, random_state=0)
fold_importances = Parallel(n_jobs=-1, verbose=2)(
    delayed(do_permutation_importance)(
        estimator=model,
        X_train=X[train],
        y_train=y[train],
        X_test=X[test],
        y_test=y[test])
    for train, test in permutation_cv.split(X, y)
)

# Stack the per-split importances (one row per split, one column per feature),
# average them over splits and rank the features, most important first.
# The per-split list keeps its own name so `importance_scores` only ever
# refers to the final ranked table.
importance_scores = (
    pd.DataFrame(data=np.stack(fold_importances, axis=0), columns=feature_names)
    .mean()
    .sort_values(ascending=False)
    .to_frame('importance')
)

importance_scores.head(20)  # top 20 features