# Connectivity Classifier

This notebook trains a binary classifier to predict the group of a participant (AVGP or NVGP) based on their connectivity matrices.

It implements the following steps:

1. Load the data
2. Cross-validated classification pipeline
3. Permutation testing
4. Permutation importance
5. SHAP
6. Learning curve analysis


## Inputs

Region-level time-series are extracted using parcellation atlases (e.g., Dosenbach2010), and several aggregation strategies (regions, networks, and random network assignment) are applied to the time-series. The time-series are then used to compute connectivity matrices: correlation, partial correlation, tangent, precision, and covariance. Connectivity matrices are either aggregated (into networks or random networks) or used directly as features for classification.
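The metrics listed above can be sketched with NumPy for a single subject. This is a minimal illustration under stated assumptions, not the `ConnectivityPipeline` implementation. It assumes a time-series array of shape `(n_timepoints, n_regions)`, uses the empirical covariance without shrinkage, and derives partial correlation from the precision (inverse covariance) matrix. Tangent-space connectivity is left out because it also needs a group-level reference matrix.

```python
import numpy as np


def connectivity_matrices(timeseries: np.ndarray) -> dict:
    """Return connectivity matrices for one subject's (n_timepoints, n_regions) time-series.

    Sketch only: uses the empirical covariance without shrinkage and omits
    tangent-space connectivity, which needs a group-level reference matrix.
    """
    covariance = np.cov(timeseries, rowvar=False)

    # correlation: covariance normalized by the region standard deviations
    std = np.sqrt(np.diag(covariance))
    correlation = covariance / np.outer(std, std)

    # precision: inverse of the covariance matrix
    precision = np.linalg.inv(covariance)

    # partial correlation: negated, normalized precision with a unit diagonal
    d = np.sqrt(np.diag(precision))
    partial_correlation = -precision / np.outer(d, d)
    np.fill_diagonal(partial_correlation, 1.0)

    return {
        'covariance': covariance,
        'correlation': correlation,
        'precision': precision,
        'partial correlation': partial_correlation,
    }


rng = np.random.default_rng(0)
ts = rng.standard_normal((200, 5))  # hypothetical: 200 time points, 5 regions
matrices = connectivity_matrices(ts)
```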

### Aggregation

We aggregate the region-level time-series using three strategies. The first uses the region-level time-series as-is (no change). The second averages the region-level time-series within each network. The third randomly assigns each region to a network and then averages the time-series within each of these random networks (random network assignment).
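A minimal sketch of the two averaging strategies, assuming the time-series is an `(n_timepoints, n_regions)` array and each region carries one network label. The labels and helper names are hypothetical. Random assignment is implemented here by shuffling the region labels, which keeps every network's size; the notebook's exact randomization scheme may differ.

```python
import numpy as np


def aggregate_timeseries(timeseries, region_networks):
    """Average region time-series within each network.

    timeseries: array of shape (n_timepoints, n_regions)
    region_networks: network label of each region, length n_regions
    Returns an (n_timepoints, n_networks) array and the sorted network labels.
    """
    region_networks = np.asarray(region_networks)
    networks = np.unique(region_networks)
    aggregated = np.column_stack(
        [timeseries[:, region_networks == net].mean(axis=1) for net in networks])
    return aggregated, networks


def random_network_assignment(region_networks, rng):
    """Randomly reassign regions to networks by shuffling the labels (network sizes are kept)."""
    return rng.permutation(np.asarray(region_networks))


rng = np.random.default_rng(42)
ts = rng.standard_normal((100, 6))  # hypothetical: 100 time points, 6 regions
labels = ['default', 'default', 'fronto-parietal', 'fronto-parietal',
          'fronto-parietal', 'occipital']  # hypothetical region->network labels

net_ts, nets = aggregate_timeseries(ts, labels)  # network aggregation
rand_ts, _ = aggregate_timeseries(ts, random_network_assignment(labels, rng))  # random networks
```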

## Outputs

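Cross-validated test accuracies for each combination of connectivity metric, parcellation, and aggregation mode, together with the permutation-test, permutation-importance, and (optionally) SHAP and learning-curve results. The results are stored in the following file: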
  - `models/connectivity_*.nc5`


## Requirements

To run this notebook, you need to have a few packages installed. The easiest way to do this is to use mamba to create a new environment from the `environment.yml` file in the root of this repository:

```bash
mamba env create -f environment.yml
mamba activate acnets
```

In [1]:
# 0. SETUP

%reload_ext autoreload
%autoreload 2

import pandas as pd
from pathlib import Path
import xarray as xr
import scipy.stats as stats
import numpy as np
from src.acnets.pipeline import ConnectivityPipeline, ConnectivityVectorizer
from sklearn.feature_selection import SelectFromModel, VarianceThreshold
from sklearn.inspection import permutation_importance
from sklearn.model_selection import (GridSearchCV, StratifiedShuffleSplit,
                                     cross_val_score, learning_curve,
                                     permutation_test_score)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.svm import LinearSVC
from tqdm.auto import tqdm
from IPython.display import clear_output

from src.acnets.pipeline import Parcellation

tqdm.pandas()

## Parameters

These parameters can be set in the command line when running the notebook, or in the notebook itself.

In [2]:
# PARAMETERS

N_CV_SPLITS = 100                       # number of cross-validation splits
N_TEST_SUBJECTS = 8                     # test size for cross-validation (number of subjects)

N_PERMUTATIONS = 100                    # for permutation test

ENABLE_SHAP_ANALYSIS = False            # enable SHAP analysis
ENABLE_LEARNING_CURVE_ANALYSIS = False  # enable learning curves

MODELS_DIR = Path('models/')            # directory to save models

## Data

Here we load the data from the `data/julia2018/` dataset. These files contain the connectivity matrices for each participant, for each combination of parcellation and connectivity metric. For the remainder of this notebook, we focus only on the `dosenbach2010` parcellation atlas.

In [3]:
# DATA PREPARATION
parcellation = Parcellation(atlas_name='dosenbach2010')

subjects = parcellation.fit_transform(X=None).coords['subject'].values

# extract group labels (AVGP or NVGP) from subject ids (e.g. AVGP-01)
subject_labels = [s[:4] for s in subjects]  

X = subjects.reshape(-1, 1)  # subject ids, shape: (n_subjects, 1)

y_encoder = LabelEncoder()
y = y_encoder.fit_transform(subject_labels)
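Group labels come from the first four characters of each subject id. `LabelEncoder` assigns integer codes in sorted order, so `AVGP` becomes 0 and `NVGP` becomes 1, which matters when reading the sign of the SVM coefficients. A small check with hypothetical subject ids:

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

# hypothetical subject ids following the "<GROUP>-<number>" convention
subjects = np.array(['AVGP-01', 'AVGP-02', 'NVGP-01', 'NVGP-02'])

subject_labels = [s[:4] for s in subjects]  # 'AVGP' or 'NVGP'

y_encoder = LabelEncoder()
y = y_encoder.fit_transform(subject_labels)

print(y_encoder.classes_)  # ['AVGP' 'NVGP']
print(y)                   # [0 0 1 1]
```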

In [4]:
# PREPARE CV AND OUTPUT MODEL FILE

CV = StratifiedShuffleSplit(n_splits=N_CV_SPLITS, test_size=N_TEST_SUBJECTS)

n_splits = CV.get_n_splits()
n_folds = int(X.shape[0] / CV.test_size)

model_output_name = ('connectivity'
                     '_classifier-SVM'
                     f'_cv-{n_splits}x{n_folds}fold.nc5'
                     )

OUTPUT_PATH = MODELS_DIR / model_output_name
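A sketch of what `CV` produces, assuming a hypothetical cohort of 32 subjects (16 per group): every split holds out 8 subjects with the group ratio preserved. Unlike K-fold, shuffle splits are drawn independently and can overlap, so the `4fold` suffix in the file name (number of subjects divided by test size) is a nominal label rather than a partition of the data.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# hypothetical cohort: 32 subjects, 16 per group
X = np.arange(32).reshape(-1, 1)
y = np.repeat([0, 1], 16)

cv = StratifiedShuffleSplit(n_splits=100, test_size=8, random_state=0)
splits = list(cv.split(X, y))

test_sizes = {len(test) for _, test in splits}                          # size of every test set
test_class_counts = {tuple(np.bincount(y[test])) for _, test in splits}  # subjects per group

n_folds = X.shape[0] // cv.test_size  # 32 // 8 = 4, the "4fold" label
```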

## Pipeline

The pipeline is composed of the following steps:

1. Extract connectivity matrices from the data
2. Vectorize the connectivity matrices
3. Standardize the features
4. Remove zero-variance features
5. Select up to 10 features with the largest absolute coefficients of an L1-penalized linear SVM
6. Classify the participants (AVGP vs. NVGP) with an L1-penalized linear SVM
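Steps 3 to 6 can be sketched with scikit-learn alone on synthetic features; `make_classification` stands in for the vectorized connectivity matrices, and the sizes are hypothetical. With an L1-penalized estimator, `SelectFromModel` uses a default threshold of `1e-5`, so it keeps only non-zero coefficients and then caps them at the `max_features` largest in absolute value. The notebook passes `max_features` as a callable, which recent scikit-learn releases accept; the sketch uses a plain integer.

```python
from sklearn.datasets import make_classification
from sklearn.feature_selection import SelectFromModel, VarianceThreshold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

# synthetic stand-in for vectorized connectivity features (hypothetical sizes)
X_feat, y = make_classification(n_samples=32, n_features=50, n_informative=5, random_state=0)

n_max = min(10, X_feat.shape[1])  # same cap as the lambda in the notebook

head = Pipeline([
    ('scale', StandardScaler()),
    ('zerovar', VarianceThreshold()),
    # L1 penalty -> default threshold 1e-5: keep non-zero coefficients, at most n_max of them
    ('select', SelectFromModel(LinearSVC(penalty='l1', dual=False, max_iter=10000),
                               max_features=n_max)),
    ('clf', LinearSVC(penalty='l1', dual=False, max_iter=10000)),
])
head.fit(X_feat, y)

n_selected = head.named_steps['select'].get_support().sum()
```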

In [5]:
# DEFINE PIPELINE

pipe = Pipeline([
    ('connectivity', ConnectivityPipeline(kind='partial correlation')),
    ('vectorize', ConnectivityVectorizer()),
    ('scale', StandardScaler()),
    ('zerovar', VarianceThreshold()),
    ('select', SelectFromModel(LinearSVC(penalty='l1', dual=False, max_iter=10000),
                               max_features=lambda x: min(10, x.shape[1]))),
    ('clf', LinearSVC(penalty='l1', dual=False, max_iter=10000))
    # ('clf', SVC(kernel='linear', C=1))
])

# DEBUG (expected to overfit, i.e., score=1)
overfit_score = pipe.fit(X, y).score(X, y)
print(f'[DEBUG] overfit accuracy: {overfit_score:.3f}')

[DEBUG] overfit accuracy: 1.000


## Verify the pipeline

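Here we check the pipeline on the partial-correlation metric under each aggregation strategy. For each strategy, we report the mean and standard deviation of the cross-validated test accuracy, and a bootstrap 95% confidence interval of the mean accuracy.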

In [6]:
# TEST VARIOUS AGGREGATIONS (calculate cross-validated accuracy and bootstrap CI)

for timeseries_aggregation, connectivity_aggregation in [
    (None, None),                # no aggregation (regions)
    ('network', None),           # time-series aggregation region->network
    ('random_network', None),    # time-series aggregation region->random_network
    (None, 'network'),           # connectivity matrix aggregation region->network
    (None, 'random_network'),    # connectivity matrix aggregation region->random_network
    ]:

    pipe.set_params(connectivity__atlas='dosenbach2010',
                    connectivity__kind='partial correlation',
                    connectivity__timeseries_aggregation=timeseries_aggregation,
                    connectivity__connectivity_aggregation=connectivity_aggregation)

    scores = cross_val_score(pipe, X, y,
                             cv=CV,
                             scoring='accuracy',
                             n_jobs=-1)
    bootstrap_ci = stats.bootstrap((scores,), np.mean)  # 95% CI of the mean accuracy

    print(f'[timeseries={timeseries_aggregation}, connectivity={connectivity_aggregation}]')
    print('Test accuracy (mean ± std): {:.2f} ± {:.2f}'.format(scores.mean(), scores.std()))
    print(bootstrap_ci.confidence_interval, '\n')

[timeseries=None, connectivity=None]
Test accuracy (mean ± std): 0.48 ± 0.16
ConfidenceInterval(low=0.4475, high=0.51) 

[timeseries=network, connectivity=None]
Test accuracy (mean ± std): 0.67 ± 0.15
ConfidenceInterval(low=0.635, high=0.69375) 

[timeseries=random_network, connectivity=None]
Test accuracy (mean ± std): 0.50 ± 0.12
ConfidenceInterval(low=0.48, high=0.52875) 

[timeseries=None, connectivity=network]
Test accuracy (mean ± std): 0.71 ± 0.16
ConfidenceInterval(low=0.6775, high=0.7375) 

[timeseries=None, connectivity=random_network]
Test accuracy (mean ± std): 0.54 ± 0.14
ConfidenceInterval(low=0.51375, high=0.5675) 
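The bootstrap intervals above resample the per-split accuracies. A sketch with hypothetical scores (100 splits of 8 test subjects each) follows. Because the shuffle splits reuse the same subjects, the scores are not independent, so the interval reflects the variability of the mean over these splits and is likely narrower than the uncertainty of the generalization accuracy.

```python
import numpy as np
import scipy.stats as stats

rng = np.random.default_rng(0)
# hypothetical per-split accuracies: 100 splits, 8 test subjects each
scores = rng.binomial(n=8, p=0.7, size=100) / 8

# BCa bootstrap (scipy's default method) of the mean accuracy
res = stats.bootstrap((scores,), np.mean, confidence_level=0.95)
low, high = res.confidence_interval
```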



In [7]:
# RUN PIPELINE ON ALL METRICS

param_grid = [
    {
        # no time-series aggregation: region-level or connectivity-matrix aggregation
        'connectivity__timeseries_aggregation': [None],
        'connectivity__connectivity_aggregation': [None, 'network', 'random_network'],
        'connectivity__atlas': ['dosenbach2010'],  # choices: ['dosenbach2010', 'gordon2014_2mm', 'difumo_64_2mm', 'aal'],
        'connectivity__kind': ['partial correlation', 'correlation', 'tangent', 'covariance', "precision"],        
    },
    {
        # only time-series aggregation
        'connectivity__timeseries_aggregation': ['network', 'random_network'],
        'connectivity__connectivity_aggregation': [None],
        'connectivity__atlas': ['dosenbach2010'],
        'connectivity__kind': ['partial correlation', 'correlation', 'tangent', 'covariance', "precision"],        
    }
]

grid = GridSearchCV(
    pipe,
    param_grid,
    cv=CV,
    verbose=1,
    refit='accuracy',
    n_jobs=-2,
    scoring=('accuracy', 'f1'))

grid.fit(X, y)


clear_output(wait=True)
print('best estimator:', grid.best_estimator_, '\n', 'best score:', grid.best_score_)

best estimator: Pipeline(steps=[('connectivity',
                 ConnectivityPipeline(atlas='dosenbach2010',
                                      kind='partial correlation',
                                      timeseries_aggregation=None,
                                      connectivity_aggregation='network',
                                      bids_dir='data/julia2018')),
                ('vectorize', ConnectivityVectorizer()),
                ('scale', StandardScaler()), ('zerovar', VarianceThreshold()),
                ('select',
                 SelectFromModel(estimator=LinearSVC(dual=False, max_iter=10000,
                                                     penalty='l1'),
                                 max_features=<function <lambda> at 0x7f7d47460fe0>)),
                ('clf', LinearSVC(dual=False, max_iter=10000, penalty='l1'))]) 
 best score: 0.715
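With several scoring metrics, `GridSearchCV` stores one set of `cv_results_` columns per metric (`mean_test_accuracy`, `rank_test_accuracy`, `mean_test_f1`, and so on) and uses the `refit` metric to pick `best_estimator_`; the later cells rely on these columns. A small sketch on synthetic data with a hypothetical grid over `C`:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, StratifiedShuffleSplit
from sklearn.svm import LinearSVC

X_feat, y = make_classification(n_samples=32, n_features=20, random_state=0)

grid = GridSearchCV(
    LinearSVC(penalty='l1', dual=False, max_iter=10000),
    param_grid={'C': [0.1, 1.0]},                # hypothetical grid
    cv=StratifiedShuffleSplit(n_splits=10, test_size=8, random_state=0),
    scoring=('accuracy', 'f1'),                  # both metrics are recorded...
    refit='accuracy',                            # ...but accuracy selects the best model
)
grid.fit(X_feat, y)

results = grid.cv_results_
```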


In [8]:

def get_model_name(params):
    """Helper function to generate a unique model name from the parameters."""

    atlas = params['connectivity__atlas']
    kind = params['connectivity__kind'].replace(' ', '')
    tagg = params['connectivity__timeseries_aggregation'] or 'region'  # none = region
    cagg = params['connectivity__connectivity_aggregation'] or 'none'  # none = ts-aggregation
    tagg = tagg.replace('random_network', 'random')  # random_network -> random
    cagg = cagg.replace('random_network', 'random')  # random_network -> random
    name = f'{atlas}_kind-{kind}_tagg-{tagg}_cagg-{cagg}'

    return name
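`get_model_name` produces names of the form `<atlas>_kind-<kind>_tagg-<tagg>_cagg-<cagg>`. When reading the saved results back, a hypothetical inverse helper can recover the parameters. It splits from the right because atlas names such as `gordon2014_2mm` contain underscores while the other three fields do not.

```python
def parse_model_name(name: str) -> dict:
    """Hypothetical inverse of get_model_name."""
    # split off the last three fields; the atlas name may itself contain underscores
    atlas, *fields = name.rsplit('_', 3)
    parsed = {'atlas': atlas}
    for field in fields:
        key, value = field.split('-', 1)  # e.g. 'tagg-network' -> ('tagg', 'network')
        parsed[key] = value
    return parsed


parse_model_name('dosenbach2010_kind-partialcorrelation_tagg-network_cagg-none')
# {'atlas': 'dosenbach2010', 'kind': 'partialcorrelation', 'tagg': 'network', 'cagg': 'none'}
```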


In [9]:
# select network models and sort them by mean_test_accuracy

results = pd.DataFrame(grid.cv_results_)
results['model_name'] = results['params'].apply(get_model_name)

results = results.sort_values(by='mean_test_accuracy', ascending=False).reset_index()

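# select models with time-series network aggregation (tagg-network);
# plain boolean indexing works regardless of the query() engine
network_results = results[results['model_name'].str.contains('tagg-network')].copy()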
network_results.rename(columns={'model_name': 'network_model_name'}, inplace=True)
network_model_params = network_results.sort_values('rank_test_accuracy').loc[:,'params'].to_list()
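The `tagg-network` filter keeps only models with time-series network aggregation. Models with connectivity-matrix network aggregation (`cagg-network`, which includes the best model from the grid search) and random-network models are not selected, so with five connectivity metrics, five models go through the analyses below. A sketch with hypothetical names:

```python
import pandas as pd

# hypothetical model names covering the aggregation modes
names = pd.Series([
    'dosenbach2010_kind-correlation_tagg-region_cagg-none',
    'dosenbach2010_kind-correlation_tagg-region_cagg-network',
    'dosenbach2010_kind-correlation_tagg-network_cagg-none',
    'dosenbach2010_kind-correlation_tagg-random_cagg-none',
])

selected = names[names.str.contains('tagg-network')].tolist()
# ['dosenbach2010_kind-correlation_tagg-network_cagg-none']
```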

In [10]:
# PERMUTATION TEST (SHUFFLE Y) + CV Scores for all models

# perform the permutation test only for models with time-series network aggregation

_model_names = []
_perm_scores = []
_cv_scores = []
_p_values = []

for params in (progress_bar := tqdm(network_model_params)):

    model_name = get_model_name(params)
    progress_bar.set_description(model_name)

    pipe.set_params(**params)

    _, perm_score, p_value = permutation_test_score(pipe, X, y,
                                                   scoring='accuracy',
                                                   n_permutations=N_PERMUTATIONS,
                                                   cv=4,
                                                   n_jobs=-2,
                                                   verbose=1)
    cv_score = cross_val_score(pipe, X, y,
                               cv=CV,
                               scoring='accuracy',
                               n_jobs=-2,
                               verbose=0)

    _model_names.append(model_name)
    _perm_scores.append(perm_score)
    _cv_scores.append(cv_score)
    _p_values.append(p_value)


results_permutation_test = xr.Dataset({
    'permutationtest-scores': (('network_model_name', 'permutation_dim'), _perm_scores),
    'permutationtest-cvscores': (('network_model_name', 'cv_dim'), _cv_scores),
    'permutationtest-pvalue': (('network_model_name'), _p_values)},
    coords={'network_model_name': _model_names})



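`permutation_test_score` refits the pipeline on shuffled labels and compares the resulting scores to the score on the true labels. Its p-value is (C + 1) / (n_permutations + 1), where C counts the permuted scores that are greater than or equal to the true score, so with `N_PERMUTATIONS = 100` the smallest attainable p-value is 1/101 ≈ 0.0099. The test above uses `cv=4`, which for a classifier means stratified 4-fold cross-validation without shuffling, not the `CV` splitter. The function below is a hypothetical re-implementation of that formula:

```python
import numpy as np


def permutation_pvalue(true_score, permutation_scores):
    """p-value as defined for sklearn's permutation_test_score:
    (number of permuted scores >= true score + 1) / (n_permutations + 1)."""
    permutation_scores = np.asarray(permutation_scores)
    n_greater_equal = np.sum(permutation_scores >= true_score)
    return (n_greater_equal + 1.0) / (len(permutation_scores) + 1)


p = permutation_pvalue(0.7, [0.5, 0.6, 0.7, 0.8])  # (2 + 1) / (4 + 1) = 0.6
smallest = permutation_pvalue(1.0, np.zeros(100))  # 1 / 101, the floor for 100 permutations
```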
In [11]:
# PERMUTATION FEATURE IMPORTANCE (SHUFFLE X)

_importance_reports = []

for p in (progress_bar := tqdm(network_model_params)):

    model_name = get_model_name(p)
    progress_bar.set_description(model_name)

    pipe.set_params(**p)

    # fit the connectivity steps with the current parameters, then get the
    # connectivity vectors and their feature names
    X_conn = pipe[:2].fit_transform(X, y)
    feature_names = pipe[:2].get_feature_names_out()

    _importance_cv = []

    # cross-validated permutation importance; the classifier part of the pipeline
    # is fitted on the same connectivity vectors it is evaluated on
    for train, test in tqdm(CV.split(X, y), total=CV.get_n_splits(X, y), desc='CV', leave=False):
        model = pipe[2:].fit(X_conn[train], y[train])

        _results = permutation_importance(model, X_conn[test], y[test],
                                          scoring='accuracy',
                                          n_jobs=-1)
        _importance_cv.append(_results.importances.T)
        # # sort by mean importance
        # importances = pd.DataFrame(np.vstack(_importance_cv), columns=feature_names)
        # sorted_columns = importances.mean(axis=0).sort_values(ascending=False).index
        # importances = importances[sorted_columns]

    feature_dim_name = f'{"_".join(model_name.split("_")[0:2])}_feature'

    importance_report = xr.Dataset({
        f'{model_name}_importance': (
            ('permutationimportance_num', feature_dim_name), np.vstack(_importance_cv))},
        coords={feature_dim_name: feature_names}
    )

    _importance_reports.append(importance_report)
    

results_permutation_importance = xr.merge(_importance_reports)


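Permutation importance shuffles one feature column of the test set at a time and records the drop in accuracy. In the cell above it is applied to the post-connectivity part of the pipeline (`pipe[2:]`), so each feature is a connectivity edge, and with the default `n_repeats=5` every CV split contributes five rows. A sketch on synthetic features with hypothetical sizes:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.inspection import permutation_importance
from sklearn.svm import LinearSVC

# synthetic stand-in for connectivity features (hypothetical sizes)
X_feat, y = make_classification(n_samples=64, n_features=8, n_informative=3,
                                random_state=0)
train, test = np.arange(48), np.arange(48, 64)

model = LinearSVC(dual=False).fit(X_feat[train], y[train])

# shuffle each test-set column n_repeats times and record the drop in accuracy
result = permutation_importance(model, X_feat[test], y[test],
                                scoring='accuracy', n_repeats=5, random_state=0)

# result.importances has shape (n_features, n_repeats); transposing gives one row
# per repeat, the layout the notebook stacks across CV splits
importances = result.importances.T
```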
In [12]:
# SHAP

def run_shap(pipeline):

  import shap
  import logging

  # turn off shap info-level logs while using progress bars
  logging.getLogger('shap').setLevel(logging.WARNING)

  shap_agg = []

  for p in (progress_bar := tqdm(network_model_params)):

      model_name = get_model_name(p)
      progress_bar.set_description(model_name)

      pipeline.set_params(**p)

      shap_values_cv = []
      test_indices = []
      y_test_cv = []
      y_pred_cv = []

      X_conn = pipeline[:2].fit_transform(X, y)

      # feature names are read after fitting, so they match the current parameters
      feature_names = pipeline[:2].get_feature_names_out()

      for train, test in tqdm(CV.split(X, y), total=CV.get_n_splits(X, y), desc='CV', leave=False):

          shap_model = pipeline[2:].fit(X_conn[train], y[train])

          y_pred = shap_model.predict(X_conn[test])

          test_indices.extend(test)
          y_test_cv.append(y[test])
          y_pred_cv.append(y_pred)

          explainer = shap.Explainer(
              shap_model.predict, X_conn[train],
              feature_names=feature_names,
              # approximate=True,
              # model_output='raw',
              # feature_perturbation='interventional',
          )

          shap_values = explainer(X_conn[test], max_evals=2*len(feature_names) + 1)#, check_additivity=True)

          shap_values_cv.append(shap_values)

      # merge CV SHAPs

      # X = subjects.reshape(-1, 1)
      # X_test = pd.DataFrame(X[np.hstack(test_indices)], columns=['subject'])
      y_test = np.hstack(y_test_cv)
      y_pred = np.hstack(y_pred_cv)

      shap_values = shap.Explanation(
        values = np.vstack([sh.values for sh in shap_values_cv]),
        base_values = np.hstack([sh.base_values for sh in shap_values_cv]),
        data = np.vstack([sh.data for sh in shap_values_cv]),
        feature_names=feature_names,
        compute_time=np.sum([sh.compute_time for sh in shap_values_cv]),
        output_names=y_encoder.classes_,
        output_indexes=y_pred,
      )

      feature_dim_name = f'{"_".join(model_name.split("_")[0:2])}_feature'

      shap_ds = xr.Dataset({
        f'{model_name}_shap-value': (('shap_dim', feature_dim_name), shap_values.values),
        f'{model_name}_shap-data': (('shap_dim', feature_dim_name), shap_values.data),
        f'{model_name}_shap-ytest': (('shap_dim'), y_encoder.inverse_transform(y_test)),
        f'{model_name}_shap-ypred': (('shap_dim'), y_encoder.inverse_transform(y_pred)),
        },
        coords={feature_dim_name: feature_names}
      )

      shap_agg.append(shap_ds)

  results_shap = xr.merge(shap_agg)
  return results_shap

if ENABLE_SHAP_ANALYSIS:
  results_shap = run_shap(pipe)
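The SHAP cell explains `predict`, i.e. the 0/1 class labels, with a model-agnostic explainer. For a linear model's decision function and independent features, SHAP values have a closed form: feature i contributes w_i (x_i - E[x_i]), with the expectation taken over the background (training) data, and these contributions plus the base value add up to the model output. The sketch below uses hypothetical coefficients to illustrate that additivity; it does not reproduce the notebook's explainer.

```python
import numpy as np

rng = np.random.default_rng(0)
X_background = rng.standard_normal((24, 4))  # hypothetical training features
x = rng.standard_normal(4)                   # one hypothetical test sample
w, b = np.array([0.8, 0.0, -1.2, 0.3]), 0.1  # hypothetical linear-SVM coefficients and intercept


def decision_function(X):
    """Linear margin w.x + b."""
    return X @ w + b


# base value: average model output over the background data
base_value = decision_function(X_background).mean()
# linear SHAP values with independent features: w_i * (x_i - mean of feature i)
shap_values = w * (x - X_background.mean(axis=0))
```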

In [13]:
# LEARNING CURVE ANALYSIS (HOW DOES TRAINING-SET SIZE AFFECT ACCURACY?)
# Note: this only analyzes the best model

def run_learning_curve_analysis(model, X, y):
    
    train_sizes, train_scores, test_scores = learning_curve(model, X, y,
                                                            cv=CV,
                                                            scoring='accuracy',
                                                            n_jobs=-1,
                                                            shuffle=True,
                                                            train_sizes=np.array([16, 18, 20, 22, 24]))

    results_learning_curve = pd.DataFrame({
        'learningcurve-trainsize': train_sizes,
        'learningcurve_trainscore': train_scores.mean(axis=1),
        'learningcurve_testscore': test_scores.mean(axis=1)
    })

    results_learning_curve.index.name  = 'learning_curve_index'

    return results_learning_curve.to_xarray()

if ENABLE_LEARNING_CURVE_ANALYSIS:
    results_learning_curve = run_learning_curve_analysis(grid.best_estimator_, X, y) 
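`learning_curve` draws each training subset from the training part of a split, so absolute `train_sizes` cannot exceed the number of subjects minus the 8 held out. A sketch with synthetic data and a hypothetical cohort of 32 subjects, where 24 is the maximum:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedShuffleSplit, learning_curve
from sklearn.svm import LinearSVC

# hypothetical cohort of 32 subjects; 8 held out per split leaves 24 for training
X_feat, y = make_classification(n_samples=32, n_features=10, random_state=0)
cv = StratifiedShuffleSplit(n_splits=20, test_size=8, random_state=0)

train_sizes, train_scores, test_scores = learning_curve(
    LinearSVC(dual=False), X_feat, y, cv=cv, scoring='accuracy',
    train_sizes=np.array([16, 18, 20, 22, 24]),  # absolute sizes; 24 is the maximum here
    shuffle=True, random_state=0)
```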

In [14]:
# %%script echo Skipping...

# STORE RESULTS

datasets = [
    {'X': xr.DataArray(X.flatten(), dims=['subject'])},
    {'y': xr.DataArray(y_encoder.inverse_transform(y), dims='subject')},
    {'y_classes': y_encoder.classes_},
    results.drop(columns=['params']).to_xarray(),
]

datasets.append(results_permutation_test)
datasets.append(results_permutation_importance)
if ENABLE_SHAP_ANALYSIS:
    datasets.append(results_shap)
if ENABLE_LEARNING_CURVE_ANALYSIS:
    datasets.append(results_learning_curve)

results_ds = xr.merge(datasets)

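OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
results_ds.to_netcdf(OUTPUT_PATH, engine='h5netcdf')

# reload from disk; the file is HDF5-based, so read it with the same engine
# (the 'scipy' engine only reads netCDF3 files)
results_ds = xr.open_dataset(OUTPUT_PATH, engine='h5netcdf').load()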

results_ds