In [None]:
%reload_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import ray
from IPython.display import clear_output
from joblib import Parallel, delayed
from ray import tune
from ray.tune.search.hyperopt import HyperOptSearch
from ray.tune.sklearn import TuneSearchCV
from scipy import stats
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import SVC
from xgboost import XGBClassifier

from src.acnets.pipeline import MultiScaleClassifier, Parcellation


In [None]:
# Evaluation settings shared by every StratifiedShuffleSplit in this notebook.
N_RUNS = 100     # number of independent train/test runs (n_splits of each shuffle split)
TEST_SIZE = .25  # held-out fraction per run (8 subjects out of 32 subjects)

In [None]:
# Build the dataset: X holds the subject ids, y the encoded group labels.
parcellation = Parcellation(atlas_name='difumo_64_2mm')
subjects = parcellation.fit_transform(None).coords['subject'].values
X = subjects.reshape(-1, 1)  # one subject id per row, shape: (n_subjects, 1)

# The group label is the first four characters of each subject id.
# LabelEncoder numbers classes in sorted order, so AVGP/NVGP prefixes would map to
# AVGP=0 and NVGP=1 -- the mapping printed below is the authoritative one.
y_encoder = LabelEncoder()
group_labels = [subject_id[:4] for subject_id in subjects]
y = y_encoder.fit_transform(group_labels)  # shape: (n_subjects,)
class_codes = y_encoder.transform(y_encoder.classes_)
y_mapping = dict(zip(y_encoder.classes_, class_codes))

# DEBUG: show which integer code each group label received
print('[DEBUG] label mapping:', y_mapping)

# DEBUG: train and score on the same subjects; expected to overfit (accuracy of 1)
overfit_model = MultiScaleClassifier().fit(X, y)
overfit_score = overfit_model.score(X, y)
print(f'[DEBUG] overfit accuracy: {overfit_score:.3f}')

In [None]:
# Restart Ray so the tuning below begins from a fresh runtime.
ray.shutdown()
ray.init()

# All three search spaces tune over the same candidate atlases; each config
# receives its own list copy.
ATLAS_CHOICES = ('dosenbach2010', 'gordon2014_2mm', 'difumo_64_2mm')

# Each config pairs a classifier instance ('clf') with its search space ('params').
# Keys starting with 'clf__' are presumably routed to that classifier
# (scikit-learn nested-parameter naming) -- confirm against MultiScaleClassifier.

# XGBoost
xgb_config = dict(
    clf=XGBClassifier(),
    params={
        'atlas': list(ATLAS_CHOICES),
        'clf__base_score': [0.5],
        'clf__objective': ['binary:logistic'],
        'clf__n_estimators': tune.randint(100, 500),
        'clf__max_depth': tune.randint(1, 8),
        'clf__learning_rate': tune.uniform(0.01, 0.1),
        # TODO 'clf__colsample_bytree': tune.uniform(0.01, 1.0),
        # TODO 'clf__bagging_fraction': tune.uniform(0.01, 1.0),
        # NOTE(review): `bagging_fraction` looks like a LightGBM name; XGBoost's
        # row-subsampling parameter is `subsample` -- confirm before enabling.
        # TODO 'clf__min_child_weight': tune.randint(1, 10),
    },
)

# Random forest
rfc_config = dict(
    clf=RandomForestClassifier(),
    params={
        'atlas': list(ATLAS_CHOICES),
        'clf__n_estimators': tune.randint(100, 500),
        'clf__max_depth': tune.randint(1, 8),
        'clf__min_samples_split': tune.randint(2, 8),
        'clf__min_samples_leaf': tune.randint(1, 5),
        'clf__criterion': tune.choice(['gini', 'entropy']),
        'clf__max_features': tune.choice([None, 'sqrt']),
    },
)

# Support vector classifier
svc_config = dict(
    clf=SVC(),
    params={
        'atlas': list(ATLAS_CHOICES),
        'clf__C': [0.1, 1, 10, 100, 1000],
        'clf__kernel': ['linear', 'poly', 'rbf', 'sigmoid'],
        'clf__gamma': tune.choice(['scale']),
    },
)

# ---------------------------------------------
# HPO (hyper-parameter optimisation)
# ---------------------------------------------

# Select which classifier / search space pair to tune.
config = xgb_config

# Pieces handed to the tuner: the model to tune, its search space converted for
# HyperOpt, and the stratified shuffle splits each trial is scored on.
hpo_estimator = MultiScaleClassifier(classifier=config['clf'])
hpo_space = HyperOptSearch.convert_search_space(config['params'])
hpo_splitter = StratifiedShuffleSplit(n_splits=N_RUNS, test_size=TEST_SIZE)

tuner = TuneSearchCV(
    hpo_estimator,
    hpo_space,
    scoring='accuracy',
    cv=hpo_splitter,
    search_optimization='hyperopt',
    n_jobs=-1,
    refit=True,
    n_trials=10,
    verbose=2,
)
tuner.fit(X, y)
ray.shutdown()

# Drop the tuning logs from the cell output and report only the best score.
clear_output()
print('[DEBUG] Best HPO score:', tuner.best_score_)

# Build a new model configured with the best hyper-parameters found above.
best_params = tuner.best_params_
tuned_model = MultiScaleClassifier(
    atlas=best_params['atlas'],
    classifier=config['clf'],
).set_params(**best_params)

tuned_model

In [None]:
# Score the tuned model on a new set of stratified shuffle splits.
eval_splitter = StratifiedShuffleSplit(n_splits=N_RUNS, test_size=TEST_SIZE)
cv_scores = cross_val_score(tuned_model, X, y, cv=eval_splitter, n_jobs=-1, verbose=3)

# Bootstrap confidence interval of the mean score; the per-split scores form the
# single sample passed to scipy.
bootstrap_ci = stats.bootstrap((cv_scores,), np.mean)

# Replace the cross-validation progress logs with the summary.
clear_output(wait=True)
print(f'Test accuracy (mean ± std): {cv_scores.mean():.2f} ± {cv_scores.std():.2f}')
print(bootstrap_ci.confidence_interval)

In [None]:
# Permutation Feature Importance

# Fit the tuned model on all subjects and read back its output feature names;
# they label the columns of the importance table built later in this cell.
# Note: `fit` is called on `tuned_model` itself, not on a copy.
feature_names = tuned_model.fit(X, y).get_feature_names_out()

def do_permutation_importance(estimator, X_train, y_train, X_test, y_test, scoring='accuracy',
                              n_repeats=5, n_jobs=-1, random_state=None):
    """Perform permutation importance analysis on a given estimator.

    The estimator is fitted on the training split, the test split is passed
    through the estimator's feature extractor, and scikit-learn's
    ``permutation_importance`` is then run on the estimator's classification
    head using those extracted test features.

    Parameters
    ----------
    estimator
        Model exposing ``fit``, ``get_feature_extractor()`` and
        ``get_classification_head()``. ``fit`` is called on this object
        directly, not on a copy.
    X_train, y_train
        Training inputs and labels used to fit ``estimator``.
    X_test, y_test
        Held-out inputs and labels on which importance is measured.
    scoring : str, default='accuracy'
        Scorer forwarded to ``permutation_importance``.
    n_repeats : int, default=5
        Number of times each feature is permuted (scikit-learn's default).
    n_jobs : int, default=-1
        Parallelism of the permutation step. Pass ``1`` when this function is
        itself dispatched from a parallel loop to avoid nested parallelism.
    random_state : int, RandomState instance or None, default=None
        Seed for the permutations; set it for reproducible importances.

    Returns
    -------
    The ``importances_mean`` entry of the ``permutation_importance`` result,
    i.e. the mean importance of each extracted feature over ``n_repeats``.
    """
    estimator.fit(X_train, y_train)

    # Importance is measured on the classification head alone, so the test
    # inputs must first be mapped into the head's feature space.
    X_test_features = estimator.get_feature_extractor().transform(X_test)

    results = permutation_importance(estimator.get_classification_head(),
                                     X_test_features,
                                     y_test,
                                     scoring=scoring,
                                     n_repeats=n_repeats,
                                     n_jobs=n_jobs,
                                     random_state=random_state)
    return results['importances_mean']


# Run the permutation-importance analysis once per train/test split, in parallel.
importance_splitter = StratifiedShuffleSplit(n_splits=N_RUNS, test_size=TEST_SIZE)
importance_tasks = (
    delayed(do_permutation_importance)(
        estimator=tuned_model,
        X_train=X[train_idx],
        y_train=y[train_idx],
        X_test=X[test_idx],
        y_test=y[test_idx],
    )
    for train_idx, test_idx in importance_splitter.split(X, y)
)
per_split_importances = Parallel(n_jobs=-1, verbose=2)(importance_tasks)

# Stack the per-split results (one row per split, one column per feature),
# average over splits and rank features from most to least important.
importance_matrix = np.stack(per_split_importances, axis=0)
mean_importance = pd.DataFrame(importance_matrix, columns=feature_names).mean()
importance_scores = mean_importance.sort_values(ascending=False).to_frame('importance')

# DEBUG: report top 20 features
importance_scores.head(20)