# MultiScale Classifier

Sections:

1. Data
2. Hyper-parameter space
3. HPO
4. Cross-validation scores


## Setup

In [1]:
%reload_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import json
import ray
from IPython.display import clear_output
import matplotlib.pyplot as plt
import seaborn as sns
from joblib import Parallel, delayed
from ray import tune
from ray.tune.search.hyperopt import HyperOptSearch
from scipy import stats
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score, StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import LinearSVC
from xgboost import XGBClassifier

import skexplain
from skexplain.common.importance_utils import to_skexplain_importance

from src.acnets.pipeline import MultiScaleClassifier, Parcellation
from functools import partial

In [2]:
N_RUNS = 100      # 10 independent train/test runs
TEST_SIZE = .25  # proportion of test subjects out of 32 subjects

## Prepare data

In [3]:
# Xy
subjects = Parcellation(atlas_name='difumo_64_2mm').fit_transform(None).coords['subject'].values
X = subjects.reshape(-1,1)                                  # subjects ids, shape: (n_subjects, 1)

y_encoder = LabelEncoder()
y = y_encoder.fit_transform([s[:4] for s in subjects])      # labels (AVGP=1 or NVGP=1), shape: (n_subjects,)
y_mapping = dict(zip(y_encoder.classes_, y_encoder.transform(y_encoder.classes_)))

# DEBUG (report label mapping)
print('[DEBUG] label mapping:', y_mapping)

# DEBUG (expected to overfit, i.e., accuracy is 1)
overfit_score = MultiScaleClassifier().fit(X, y).score(X, y)
print(f'[DEBUG] overfit accuracy: {overfit_score:.3f}')

[DEBUG] label mapping: {'AVGP': 0, 'NVGP': 1}
[DEBUG] overfit accuracy: 1.000


## Hyper-parameter tuning

### Parameter space

In [4]:
xgb_param_space = {
    'clf': XGBClassifier(base_score=.5, objective='binary:logistic'),
    # 'atlas': ['dosenbach2010', 'gordon2014_2mm', 'difumo_64_2mm'],
    'atlas': tune.choice(['dosenbach2010']),
    'kind': tune.choice(['partial correlation', 'tangent', 'precision', 'correlation', 'covariance']),
    'extract_h1_features': tune.grid_search([True, False]),
    'extract_h2_features': tune.grid_search([True, False]),
    'extract_h3_features': tune.grid_search([True]),
    # 'clf__subsample': tune.choice([.5, .8, 1]),
    'clf__n_estimators': tune.grid_search([100, 200]),
    'clf__max_depth': tune.grid_search([2, 4, 6, 8]),
    'clf__learning_rate': tune.grid_search([.1, .3]),
}

rfc_param_space = {
    'clf': RandomForestClassifier(),
    # 'atlas': tune.choice(['dosenbach2010', 'gordon2014_2mm', 'difumo_64_2mm']),
    'atlas': tune.choice(['dosenbach2010']),
    'clf__n_estimators': tune.randint(100, 500),
    'clf__max_depth': tune.randint(1, 8),
    'clf__min_samples_split': tune.randint(2, 8),
    'clf__min_samples_leaf': tune.randint(1, 5),
    'clf__criterion': tune.choice(['gini', 'entropy']),
    'clf__max_features': tune.choice([None, 'sqrt'])
}

svm_param_space = {
    'clf': LinearSVC(max_iter=100000),
    # 'atlas': tune.choice(['dosenbach2010', 'gordon2014_2mm', 'difumo_64_2mm']),
    'atlas': tune.choice(['dosenbach2010']),
    'clf__penalty': ['l1'],
    'clf__dual': [False],
    'clf__C': tune.choice([.01, .1, 1, 10, 100, 1000]),
    # 'clf__kernel': ['linear','poly','rbf','sigmoid'],
    # 'clf__gamma': tune.choice(['scale'])
}


In [5]:
# now we define the objective function

def eval_multiscale_model(config, classifier, X, y):

    model = MultiScaleClassifier(classifier=classifier).set_params(**config)

    # outer CV (for test set), and inner CV (for validation set)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE, stratify=y)
    inner_cv = StratifiedKFold(n_splits=8, shuffle=True)

    # fit and score the validation set
    val_score = cross_val_score(model, X_train, y_train, scoring='accuracy', cv=inner_cv).mean()

    # test score (we only report this and do not use it during HPO)
    test_score = model.fit(X_train, y_train).score(X_test, y_test)

    metrics = {
        'val_accuracy': val_score,
        'test_accuracy': test_score
    }

    return metrics

# DEBUG
debug_config = dict(atlas='dosenbach2010', extract_h1_features=False, extract_h2_features=True)
'DEBUG', eval_multiscale_model(debug_config, classifier=XGBClassifier(), X=X, y=y)

('DEBUG', {'val_accuracy': 0.7916666666666666, 'test_accuracy': 0.5})

### HPO

In [6]:
# prep the hyper-parameter space and init the objective function
param_space = xgb_param_space.copy()
clf = param_space.pop('clf')
output_name = f'models/multiscale_classifier-{clf.__class__.__name__}-hpo.json'

objective_func = partial(eval_multiscale_model, classifier=clf, X=X, y=y)

ray.shutdown(); ray.init()

tuner = tune.Tuner(
    objective_func,
    param_space=param_space,
    tune_config=tune.TuneConfig(
        metric='val_accuracy',
        mode='max',      # FIXME we need to mean the CV scores
        num_samples=10,  # FIXME change to N_RUNS?
    )
)

tuning_results = tuner.fit()
ray.shutdown()

clear_output()
best_score = tuner.get_results().get_best_result().metrics['val_accuracy']
best_params = tuner.get_results().get_best_result(metric='val_accuracy', mode='max').config

# store the best hyper-parameters
best_params['classifier'] = clf.__class__.__name__
with open(output_name, 'w') as f:
    json.dump(best_params, f, indent=2)
del best_params['classifier']

print('[DEBUG] Best HPO validation score:', best_score)

# plot the tuned model
MultiScaleClassifier(classifier=clf).set_params(**best_params)

0,1
Current time:,2023-10-11 12:47:59
Running for:,00:05:28.39
Memory:,9.2/62.7 GiB

Trial name,status,loc,atlas,clf__learning_rate,clf__max_depth,clf__n_estimators,extract_h1_features,extract_h2_features,extract_h3_features,kind,iter,total time (s),val_accuracy,test_accuracy
eval_multiscale_model_e14c5_00216,RUNNING,10.184.42.15:102851,dosenbach2010,0.1,2,200,False,True,True,tangent,,,,
eval_multiscale_model_e14c5_00220,RUNNING,10.184.42.15:102848,dosenbach2010,0.1,6,200,False,True,True,tangent,,,,
eval_multiscale_model_e14c5_00222,RUNNING,10.184.42.15:102849,dosenbach2010,0.1,8,200,False,True,True,tangent,,,,
eval_multiscale_model_e14c5_00237,RUNNING,10.184.42.15:102847,dosenbach2010,0.3,6,200,True,False,True,partial correlation,,,,
eval_multiscale_model_e14c5_00238,RUNNING,10.184.42.15:102850,dosenbach2010,0.1,8,200,True,False,True,covariance,,,,
eval_multiscale_model_e14c5_00239,RUNNING,10.184.42.15:102845,dosenbach2010,0.3,8,200,True,False,True,correlation,,,,
eval_multiscale_model_e14c5_00240,RUNNING,10.184.42.15:102844,dosenbach2010,0.1,2,100,False,False,True,precision,,,,
eval_multiscale_model_e14c5_00242,PENDING,,dosenbach2010,0.1,4,100,False,False,True,covariance,,,,
eval_multiscale_model_e14c5_00243,PENDING,,dosenbach2010,0.3,4,100,False,False,True,correlation,,,,
eval_multiscale_model_e14c5_00244,PENDING,,dosenbach2010,0.1,6,100,False,False,True,covariance,,,,




## Cross-validation accuracy and CI

In [None]:
r = tuner.get_results().get_dataframe()

from IPython.display import display

for c in r.columns:
    if c.startswith('config/'):
        ac = r.groupby(c).mean('val_accuracy')
        display(ac)

Unnamed: 0_level_0,val_accuracy,test_accuracy,timestamp,done,training_iteration,time_this_iter_s,time_total_s,pid,time_since_restore,iterations_since_restore,config/extract_h1_features,config/extract_h2_features,config/extract_h3_features,config/clf__n_estimators,config/clf__max_depth,config/clf__learning_rate
config/atlas,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
dosenbach2010,0.661849,0.711328,1697020000.0,0.0,1.0,6.314029,6.314029,93895.484375,6.314029,1.0,0.5,0.5,1.0,150.0,5.0,0.2


Unnamed: 0_level_0,val_accuracy,test_accuracy,timestamp,done,training_iteration,time_this_iter_s,time_total_s,pid,time_since_restore,iterations_since_restore,config/extract_h2_features,config/extract_h3_features,config/clf__n_estimators,config/clf__max_depth,config/clf__learning_rate
config/extract_h1_features,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
False,0.699089,0.740234,1697020000.0,0.0,1.0,6.472659,6.472659,93895.540625,6.472659,1.0,0.5,1.0,150.0,5.0,0.2
True,0.624609,0.682422,1697020000.0,0.0,1.0,6.155398,6.155398,93895.428125,6.155398,1.0,0.5,1.0,150.0,5.0,0.2


Unnamed: 0_level_0,val_accuracy,test_accuracy,timestamp,done,training_iteration,time_this_iter_s,time_total_s,pid,time_since_restore,iterations_since_restore,config/extract_h1_features,config/extract_h3_features,config/clf__n_estimators,config/clf__max_depth,config/clf__learning_rate
config/extract_h2_features,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
False,0.642448,0.684766,1697020000.0,0.0,1.0,5.291049,5.291049,93895.4625,5.291049,1.0,0.5,1.0,150.0,5.0,0.2
True,0.68125,0.737891,1697020000.0,0.0,1.0,7.337008,7.337008,93895.50625,7.337008,1.0,0.5,1.0,150.0,5.0,0.2


Unnamed: 0_level_0,val_accuracy,test_accuracy,timestamp,done,training_iteration,time_this_iter_s,time_total_s,pid,time_since_restore,iterations_since_restore,config/extract_h1_features,config/extract_h2_features,config/clf__n_estimators,config/clf__max_depth,config/clf__learning_rate
config/extract_h3_features,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
True,0.661849,0.711328,1697020000.0,0.0,1.0,6.314029,6.314029,93895.484375,6.314029,1.0,0.5,0.5,150.0,5.0,0.2


Unnamed: 0_level_0,val_accuracy,test_accuracy,timestamp,done,training_iteration,time_this_iter_s,time_total_s,pid,time_since_restore,iterations_since_restore,config/extract_h1_features,config/extract_h2_features,config/extract_h3_features,config/clf__max_depth,config/clf__learning_rate
config/clf__n_estimators,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
100,0.665625,0.705078,1697020000.0,0.0,1.0,6.255447,6.255447,93895.446875,6.255447,1.0,0.5,0.5,1.0,5.0,0.2
200,0.658073,0.717578,1697020000.0,0.0,1.0,6.372611,6.372611,93895.521875,6.372611,1.0,0.5,0.5,1.0,5.0,0.2


Unnamed: 0_level_0,val_accuracy,test_accuracy,timestamp,done,training_iteration,time_this_iter_s,time_total_s,pid,time_since_restore,iterations_since_restore,config/extract_h1_features,config/extract_h2_features,config/extract_h3_features,config/clf__n_estimators,config/clf__learning_rate
config/clf__max_depth,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
2,0.647135,0.713281,1697020000.0,0.0,1.0,6.306796,6.306796,93895.19375,6.306796,1.0,0.5,0.5,1.0,150.0,0.2
4,0.660156,0.696875,1697020000.0,0.0,1.0,6.344633,6.344633,93895.76875,6.344633,1.0,0.5,0.5,1.0,150.0,0.2
6,0.682292,0.707812,1697020000.0,0.0,1.0,6.325249,6.325249,93895.8,6.325249,1.0,0.5,0.5,1.0,150.0,0.2
8,0.657813,0.727344,1697020000.0,0.0,1.0,6.279437,6.279437,93895.175,6.279437,1.0,0.5,0.5,1.0,150.0,0.2


Unnamed: 0_level_0,val_accuracy,test_accuracy,timestamp,done,training_iteration,time_this_iter_s,time_total_s,pid,time_since_restore,iterations_since_restore,config/extract_h1_features,config/extract_h2_features,config/extract_h3_features,config/clf__n_estimators,config/clf__max_depth
config/clf__learning_rate,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
0.1,0.664974,0.710547,1697020000.0,0.0,1.0,6.320232,6.320232,93895.346875,6.320232,1.0,0.5,0.5,1.0,150.0,5.0
0.3,0.658724,0.712109,1697020000.0,0.0,1.0,6.307825,6.307825,93895.621875,6.307825,1.0,0.5,0.5,1.0,150.0,5.0


In [None]:
tuned_model = MultiScaleClassifier(classifier=clf).set_params(**best_params)

cv_scores = cross_val_score(tuned_model, X, y,
                            cv=StratifiedShuffleSplit(n_splits=N_RUNS, test_size=TEST_SIZE),
                            verbose=3, n_jobs=-1)

# Calculate 95% confidence interval
bootstrap_ci = stats.bootstrap(cv_scores.reshape(1,-1), np.mean)

# Report
clear_output(wait=True)
print(f'Test accuracy (mean ± std): {cv_scores.mean():.3f} ± {cv_scores.std():.3f}')
print(bootstrap_ci.confidence_interval)

Test accuracy (mean ± std): 0.741 ± 0.146
ConfidenceInterval(low=0.71125, high=0.77)
