# Connectivity Classifier

This notebook trains a binary classifier to predict the group of a participant (AVGP or NVGP) based on their connectivity matrices.

It implements the following steps:

1. Load the data
2. Cross-validated classification
3. Permutation testing
4. Permutation importance
5. SHAP


## Inputs

Connectivity matrices including five connectivity metrics (correlation, partial correlation, tangent, precision, and covariance), five possible parcellations (DiFuMo64, Dosenbach2010, Gordon2014, Friedman2020, and Seitzman2018), and several aggregation mode (region-level, network-level, randomized network assignment, network-connectivity, region-connectivity, and randomized network-connectivity).

## Outputs

Prediction accuracies on the test set for each combination of connectivity metric, parcellation, and aggregation mode. The results are stored in the following file:
  - `models/connectivity_classifier_*.nc5`


## Requirements

To run this notebook, you need to have a few packages installed. The easiest way to do this is to use mamba to create a new environment from the `environment.yml` file in the root of this repository:

```bash
mamba env create -f environment.yml
mamba activate acnets
```

In [4]:
# 0. SETUP

%reload_ext autoreload
%autoreload 3

import pandas as pd
from pathlib import Path
import xarray as xr
import scipy.stats as stats
import numpy as np
from src.acnets.pipeline import ConnectivityPipeline, ConnectivityVectorizer
from sklearn.feature_selection import SelectFromModel, VarianceThreshold
from sklearn.inspection import permutation_importance
from sklearn.model_selection import (GridSearchCV, StratifiedShuffleSplit,
                                     cross_val_score, learning_curve,
                                     permutation_test_score)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.svm import LinearSVC
from tqdm.auto import tqdm
from IPython.display import clear_output

In [5]:
import nilearn
nilearn._version.version

'0.10.1'

In [6]:
# PARAMETERS

CV = StratifiedShuffleSplit(n_splits=100, test_size=8)
N_PERMUTATIONS = 10  # for permutation test
N_TOP_MODELS = 15  # number of top models to run permutation test, SHAP, etc.
AGG_METHOD = 'network'  # whether aggregate regions into networks or not

# Analysis flags
ENABLE_SHAP = True  # whether to run SHAP analysis or not as it takes a long time
ENABLE_PERMUTATION_TEST = False  # whether to run permutation test or not as it takes a long time
ENABLE_LEARNING_CURVE = False  # whether to run learning curve or not as it takes a long time
ENABLE_PERMUTATION_IMPORTANCE = True  # whether to run permutation importance or not as it takes a long time

MODELS_DIR= Path('models')

In [7]:
# DATA PREPARATION

subjects = ConnectivityPipeline().transform('all').coords['subject'].values
groups = [s[:4] for s in subjects]  # AVGP or NVGP

X = subjects.reshape(-1, 1)

y_encoder = LabelEncoder()
y = y_encoder.fit_transform(groups)

  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covarian

In [8]:
# PREPARE OUTPUT DATASET FILE

n_splits = CV.get_n_splits()
n_folds = int(X.shape[0] / CV.test_size)

model_output_name = ('connectivities'
                     '_classifier-SVM'
                     '_measure-accuracy'
                     f'_shap-{"enabled" if ENABLE_SHAP else "disabled"}'
                     f'_agg-{AGG_METHOD}'
                     f'_top-{N_TOP_MODELS}'
                     f'_cv-{n_splits}x{n_folds}fold.nc5'
                     )

OUTPUT_PATH = MODELS_DIR / model_output_name

In [9]:
# DEFINE PIPELINE

from sklearn.svm import SVC

pipe  = Pipeline([
    ('connectivity', ConnectivityPipeline()),
    ('vectorize', ConnectivityVectorizer()),
    ('scale', StandardScaler()),
    ('zerovar', VarianceThreshold()),
    ('select', SelectFromModel(LinearSVC(penalty='l1', dual=False, max_iter=10000),
                               max_features=lambda x: min(10, x.shape[1]))),
    ('clf', LinearSVC(penalty='l1', dual=False, max_iter=10000))
    # ('clf', SVC(kernel='linear', C=1))
])

# DEBUG (expected to overfit, i.e., score=1)
pipe.fit(X, y).score(X, y)

  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covarian

0.90625

In [10]:
# VERIFY THE MODEL (calculate cross-validated accuracy and bootstrap CI)
pipe.set_params(connectivity__atlas='dosenbach2010',
                connectivity__kind='covariance')

scores = cross_val_score(pipe, X, y,
                         cv=CV,
                         scoring='accuracy',
                         n_jobs=-1)
bootstrap_ci = stats.bootstrap(scores.reshape(1,-1), np.mean)

print('Test accuracy (mean ± std): {:.2f} ± {:.2f}'.format(scores.mean(), scores.std()))
print(bootstrap_ci.confidence_interval)

Test accuracy (mean ± std): 0.52 ± 0.17
ConfidenceInterval(low=0.48375, high=0.55125)


In [11]:
# RUN PIPELINE ON ALL ATLASES AND METRICS

param_grid = {
    'connectivity__agg_method': [AGG_METHOD],
    # 'connectivity__atlas': ['dosenbach2010', 'gordon2014_2mm', 'difumo_64_2mm'],
    'connectivity__atlas': ['dosenbach2010'],
    'connectivity__kind': ['partial correlation', 'correlation', 'covariance', 'precision', 'tangent'],
}

grid = GridSearchCV(
    pipe,
    param_grid,
    cv=CV,
    verbose=2,
    n_jobs=-2,
    scoring='accuracy')

grid.fit(X, y)
clear_output(wait=True)


print('best estimator:', grid.best_estimator_, '\n', 'best score:', grid.best_score_)

# STORE RESULTS

def generate_model_name(params):
    """Helper function to generate a model name from the parameters."""

    atlas_name = params['connectivity__atlas'].replace('_2mm', '').replace('_','')
    agg_method = params['connectivity__agg_method']
    kind_name = params['connectivity__kind'].replace(' ', '-')
    name = f'{atlas_name}_{agg_method}_{kind_name}'

    return name

grid_results = pd.DataFrame(grid.cv_results_)

grid_results['grid_model_name'] = grid_results['params'].apply(generate_model_name)

grid_results.set_index('grid_model_name', inplace=True)
grid_results.drop(columns=['params'], inplace=True)

ds_grid = grid_results.to_xarray()
ds_grid['scoring'] = grid.scoring
ds_grid['cv_test_size'] = CV.test_size
ds_grid['cv_n_splits'] = CV.n_splits
ds_grid['n_subjects'] = len(X)

# Store results for region-based analysis and skip the rest of the notebook
if not AGG_METHOD == 'network':

    results = xr.merge([
        {'X': xr.DataArray(X.flatten(), dims=['subject'])},
        {'y': xr.DataArray(y_encoder.inverse_transform(y), dims='subject')},
        {'y_classes': y_encoder.classes_},
        ds_grid
    ])

    with open(OUTPUT_PATH, 'wb') as f:
        results.to_netcdf(f, engine='h5netcdf')
        results.close()


best estimator: Pipeline(steps=[('connectivity',
                 ConnectivityPipeline(kind='partial correlation')),
                ('vectorize', ConnectivityVectorizer()),
                ('scale', StandardScaler()), ('zerovar', VarianceThreshold()),
                ('select',
                 SelectFromModel(estimator=LinearSVC(dual=False, max_iter=10000,
                                                     penalty='l1'),
                                 max_features=<function <lambda> at 0x7f8b813d01f0>)),
                ('clf', LinearSVC(dual=False, max_iter=10000, penalty='l1'))]) 
 best score: 0.71625


In [12]:
# PERMUTATION TEST (SHUFFLE Y)

if not ENABLE_PERMUTATION_TEST:
    raise ValueError('ENABLE_PERMUTATION_TEST must be True to run permutation test.')

perm_scores_agg = []
cv_scores_agg = []
pvalues = []
model_names = []

# sort by rank and take top models
top_models = pd.DataFrame(grid.cv_results_).sort_values('rank_test_score')[:N_TOP_MODELS].loc[:,'params'].to_list()

for p in (progress_bar := tqdm(top_models)):

    pipe.set_params(**p)

    model_name = generate_model_name(p)
    progress_bar.set_description(model_name)

    _, perm_scores, pvalue = permutation_test_score(pipe, X, y,
                                                    scoring='accuracy',
                                                    n_permutations=N_PERMUTATIONS,
                                                    cv=CV,
                                                    n_jobs=-2,
                                                    verbose=1)

    cv_scores = cross_val_score(pipe, X, y,
                                cv=CV,
                                scoring='accuracy',
                                n_jobs=-2,
                                verbose=0)

    perm_scores_agg.append(perm_scores)
    cv_scores_agg.append(cv_scores)
    pvalues.append(pvalue)
    model_names.append(model_name)

ds_perm_test = xr.Dataset({
    'perm_scores': (('model_name', 'permutation_dim'), perm_scores_agg),
    'cv_scores': (('model_name', 'cv_dim'), cv_scores_agg),
    'pvalue': (('model_name'), pvalues)},
    coords={'model_name': model_names})


ValueError: ENABLE_PERMUTATION_TEST must be True to run permutation test.

In [13]:
# PERMUTATION FEATURE IMPORTANCE (SHUFFLE X)

if not ENABLE_PERMUTATION_IMPORTANCE:
    raise ValueError('ENABLE_PERMUTATION_IMPORTANCE must be True to run permutation feature importance.')

# sort by rank and take top models
top_models = pd.DataFrame(grid.cv_results_).sort_values('rank_test_score')[:N_TOP_MODELS].loc[:,'params'].to_list()

importances_cv = []

for p in (progress_bar := tqdm(top_models)):

    pipe.set_params(**p)

    model_name = generate_model_name(p)
    progress_bar.set_description(model_name)

    # get feature names for the connectivity vector
    X_conn = pipe[:2].transform(X)
    feature_names = pipe[:2].get_feature_names_out()

    importances = []

    # cross-validated permutation importance
    for train, test in tqdm(CV.split(X,y), total=CV.get_n_splits(X,y), desc='CV', leave=False):
        pipe.fit(X[train], y[train])

        results = permutation_importance(pipe[2:], X_conn[test], y[test],
                                        scoring=grid.scoring,
                                        n_jobs=-1)
        importances.append(results.importances.T)

    feature_dim_name = f'{"_".join(model_name.split("_")[0:2])}_feature'

    importances_cv_dataset = xr.Dataset({
        f'{model_name} importances': (('permutation_importance_num', feature_dim_name), np.vstack(importances))},
        coords={feature_dim_name: feature_names}
    )

    importances_cv.append(importances_cv_dataset)
    
    # sort by mean importance
    importances = pd.DataFrame(np.vstack(importances), columns=feature_names)
    sorted_columns = importances.mean(axis=0).sort_values(ascending=False).index
    importances = importances[sorted_columns]

ds_perm_importance = xr.merge(importances_cv)

  0%|          | 0/5 [00:00<?, ?it/s]

CV:   0%|          | 0/100 [00:00<?, ?it/s]

CV:   0%|          | 0/100 [00:00<?, ?it/s]

CV:   0%|          | 0/100 [00:00<?, ?it/s]

  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covarian

CV:   0%|          | 0/100 [00:00<?, ?it/s]

  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covariances_std = [
  covarian

CV:   0%|          | 0/100 [00:00<?, ?it/s]

In [14]:
# SHAP

if not ENABLE_SHAP:
    raise ValueError('ENABLE_SHAP must be True to run SHAP analysis.')


import shap
import logging

# turn off shap info-level logs while using progress bars
logging.getLogger('shap').setLevel(logging.WARNING)

top_models = pd.DataFrame(grid.cv_results_).sort_values('rank_test_score')[:N_TOP_MODELS].loc[:,'params'].to_list()
shap_agg = []

for p in (progress_bar := tqdm(top_models)):

    pipe.set_params(**p)

    model_name = generate_model_name(p)
    progress_bar.set_description(model_name)

    shap_values_cv = []
    test_indices = []
    y_test_cv = []
    y_pred_cv = []

    feature_names = pipe[:2].get_feature_names_out()

    X_conn = pipe[:2].fit_transform(X, y)

    for train, test in tqdm(CV.split(X, y), total=CV.get_n_splits(X, y), desc='CV', leave=False):

        shap_model = pipe[2:].fit(X_conn[train], y[train])

        y_pred = shap_model.predict(X_conn[test])

        test_indices.extend(test)
        y_test_cv.append(y[test])
        y_pred_cv.append(y_pred)

        explainer = shap.Explainer(
            shap_model.predict, X_conn[train],
            feature_names=feature_names,
            # approximate=True,
            # model_output='raw',
            # feature_perturbation='interventional',
        )

        shap_values = explainer(X_conn[test], max_evals=2*len(feature_names) + 1)#, check_additivity=True)

        shap_values_cv.append(shap_values)

    # merge CV SHAPs

    # X = subjects.reshape(-1, 1)
    # X_test = pd.DataFrame(X[np.hstack(test_indices)], columns=['subject'])
    y_test = np.hstack(y_test_cv)
    y_pred = np.hstack(y_pred_cv)

    shap_values = shap.Explanation(
      values = np.vstack([sh.values for sh in shap_values_cv]),
      base_values = np.hstack([sh.base_values for sh in shap_values_cv]),
      data = np.vstack([sh.data for sh in shap_values_cv]),
      feature_names=feature_names,
      compute_time=np.sum([sh.compute_time for sh in shap_values_cv]),
      output_names=y_encoder.classes_,
      output_indexes=y_pred,
    )

    feature_dim_name = f'{"_".join(model_name.split("_")[0:2])}_feature'

    shap_ds = xr.Dataset({
      f'{model_name} shap': (('shap_dim', feature_dim_name), shap_values.values),
      f'{model_name} shap data': (('shap_dim', feature_dim_name), shap_values.data),
      f'{model_name} shap y_test': (('shap_dim'), y_encoder.inverse_transform(y_test)),
      f'{model_name} shap y_pred': (('shap_dim'), y_encoder.inverse_transform(y_pred)),
      },
      coords={feature_dim_name: feature_names}
    )

    shap_agg.append(shap_ds)

ds_shap = xr.merge(shap_agg)

  0%|          | 0/5 [00:00<?, ?it/s]

CV:   0%|          | 0/100 [00:00<?, ?it/s]

CV:   0%|          | 0/100 [00:00<?, ?it/s]

CV:   0%|          | 0/100 [00:00<?, ?it/s]

The default strategy for standardize is currently 'zscore' which incorrectly uses population std to calculate sample zscores. The new strategy 'zscore_sample' corrects this behavior by using the sample std. In release 0.13, the default strategy will be replaced by the new strategy and the 'zscore' option will be removed. Please use 'zscore_sample' instead.
The default strategy for standardize is currently 'zscore' which incorrectly uses population std to calculate sample zscores. The new strategy 'zscore_sample' corrects this behavior by using the sample std. In release 0.13, the default strategy will be replaced by the new strategy and the 'zscore' option will be removed. Please use 'zscore_sample' instead.
The default strategy for standardize is currently 'zscore' which incorrectly uses population std to calculate sample zscores. The new strategy 'zscore_sample' corrects this behavior by using the sample std. In release 0.13, the default strategy will be replaced by the new strategy 

CV:   0%|          | 0/100 [00:00<?, ?it/s]

CV:   0%|          | 0/100 [00:00<?, ?it/s]

In [15]:
# LEARNING CURVE ANALYSIS (HOW DOES TRAIN/TRAIN SIZE IMPACT ACCURACY?)
# Note: this only analyze the best model

if not ENABLE_LEARNING_CURVE:
    raise ValueError('ENABLE_LEARNING_CURVE must be True to run learning curve analysis.')

train_sizes, train_scores, test_scores = learning_curve(grid.best_estimator_, X, y,
                                                        cv=CV,
                                                        scoring='accuracy',
                                                        n_jobs=-1,
                                                        shuffle=True,
                                                        train_sizes=np.array([16, 18, 20, 22, 24]))


learning_curve_results = pd.DataFrame({
    'learning_curve_train_size': train_sizes,
    'learning_curve_mean_train_score': train_scores.mean(axis=1),
    'learning_curve_mean_test_score': test_scores.mean(axis=1)
})

learning_curve_results.index.name  = 'learning_curve_num'

ds_learning_curve = learning_curve_results.to_xarray()

ValueError: ENABLE_LEARNING_CURVE must be True to run learning curve analysis.

In [16]:
# %%script echo Skipping...

# STORE RESULTS

datasets = [
    {'X': xr.DataArray(X.flatten(), dims=['subject'])},
    {'y': xr.DataArray(y_encoder.inverse_transform(y), dims='subject')},
    {'y_classes': y_encoder.classes_},
    ds_grid
]

datasets.append(ds_perm_test) if ENABLE_PERMUTATION_TEST else None
datasets.append(ds_perm_importance) if ENABLE_PERMUTATION_IMPORTANCE else None
datasets.append(ds_shap) if ENABLE_SHAP else None
datasets.append(ds_learning_curve) if ENABLE_LEARNING_CURVE else None

results = xr.merge(datasets)

with open(OUTPUT_PATH, 'wb') as f:
    results.to_netcdf(f, engine='h5netcdf')
    results.close()

# reload from disk
results = xr.open_dataset(OUTPUT_PATH, engine='scipy').load()

results