# Connectivity Classifier

This notebook trains a binary classifier to predict the group of a participant (AVGP or NVGP) based on their connectivity matrices.

It implements the following steps:

1. Load the data
2. Cross-validated classification pipeline
3. Permutation testing
4. Permutation importance
5. SHAP
6. Learning curve analysis


## Inputs

Region-level time-series are extracted using one of several parcellation atlases (DiFuMo64, Dosenbach2010, Gordon2014, Friedman2020, and Seitzman2018) and one of several aggregation strategies (regions, networks, and randomized network assignment). The time-series are then used to compute connectivity matrices of several kinds: correlation, partial correlation, tangent, precision, and covariance.
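To make the connectivity kinds concrete, here is a minimal NumPy sketch that computes four of them from one synthetic set of time-series. The data are random and hypothetical, and the formulas are the textbook definitions; the estimator actually used inside `ConnectivityPipeline` may differ (for example, by using a regularized covariance). The tangent kind is left out because it needs a group-level reference matrix.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data: 200 time points for 4 regions (not from the dataset).
n_timepoints, n_regions = 200, 4
timeseries = rng.normal(size=(n_timepoints, n_regions))
timeseries[:, 1] += 0.8 * timeseries[:, 0]  # make regions 0 and 1 covary

# Covariance and Pearson correlation between regions (columns).
covariance = np.cov(timeseries, rowvar=False)
correlation = np.corrcoef(timeseries, rowvar=False)

# Precision is the inverse of the covariance matrix.
precision = np.linalg.inv(covariance)

# Partial correlation: rescale the precision matrix by its diagonal
# and flip the sign of the off-diagonal entries.
d = np.sqrt(np.diag(precision))
partial_correlation = -precision / np.outer(d, d)
np.fill_diagonal(partial_correlation, 1.0)

print(np.round(correlation, 2))
print(np.round(partial_correlation, 2))
```

Each result is a symmetric `n_regions × n_regions` matrix, which is why only one triangle needs to be kept when the matrices are later turned into feature vectors.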

## Outputs

Prediction accuracies on the test set for each combination of connectivity metric, parcellation, and aggregation mode. The results are stored in the following file:
  - `models/connectivity_classifier_*.nc5`


## Requirements

To run this notebook, you need to have a few packages installed. The easiest way to do this is to use mamba to create a new environment from the `environment.yml` file in the root of this repository:

```bash
mamba env create -f environment.yml
mamba activate acnets
```

In [81]:
# 0. SETUP

%reload_ext autoreload
%autoreload 2

import pandas as pd
from pathlib import Path
import xarray as xr
import scipy.stats as stats
import numpy as np
from src.acnets.pipeline import ConnectivityPipeline, ConnectivityVectorizer
from sklearn.feature_selection import SelectFromModel, VarianceThreshold
from sklearn.inspection import permutation_importance
from sklearn.model_selection import (GridSearchCV, StratifiedShuffleSplit,
                                     cross_val_score, learning_curve,
                                     permutation_test_score)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.svm import LinearSVC
from tqdm.auto import tqdm
from IPython.display import clear_output

from src.acnets.pipeline import Parcellation

## Parameters

These parameters can be set from the command line when running the notebook, or edited directly in the cell below.

In [82]:
# PARAMETERS

N_CV_SPLITS = 10                    # number of cross-validation splits
N_TEST_SUBJECTS = 8                 # test size for cross-validation (number of subjects)
CV = StratifiedShuffleSplit(n_splits=N_CV_SPLITS, test_size=N_TEST_SUBJECTS)

N_PERMUTATIONS = 10                 # for permutation test

# Analysis flags
ENABLE_SHAP = False                     # run SHAP analysis or not as it takes a long time
ENABLE_PERMUTATION_TEST = False         # run permutation test or not as it takes a long time
ENABLE_LEARNING_CURVE = False           # run learning curve or not as it takes a long time
ENABLE_PERMUTATION_IMPORTANCE = False   # run permutation importance or not as it takes a long time

MODELS_DIR = Path('models/')             # directory to save models
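The following sketch shows what the cross-validation settings above produce on toy labels. The cohort of 32 subjects with 16 per group is an assumption made for illustration only; the real number of subjects comes from the dataset. The `random_state` argument is also added here just to make the sketch reproducible.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Hypothetical cohort: 32 subjects, 16 per group (the real count may differ).
y = np.array([0] * 16 + [1] * 16)
X = np.arange(len(y)).reshape(-1, 1)

# Same settings as N_CV_SPLITS and N_TEST_SUBJECTS, plus a fixed seed.
cv = StratifiedShuffleSplit(n_splits=10, test_size=8, random_state=0)

# Count how many subjects of each class land in every test split.
test_class_counts = [np.bincount(y[test], minlength=2) for _, test in cv.split(X, y)]
train_sizes = [len(train) for train, _ in cv.split(X, y)]

print(test_class_counts[0], train_sizes[0])
```

Stratification keeps the class proportions of the full sample in every test split. Without a fixed `random_state`, each call to `split` draws new random splits, so separate analyses that share the same `CV` object are not guaranteed to be evaluated on identical partitions.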

## Data

Here we load the data from the `data/julia2018/` dataset. These files contain the connectivity matrices for each participant, for each combination of parcellation and connectivity metric. For the remainder of this notebook, we focus only on the `dosenbach2010` parcellation atlas.

In [83]:
# DATA PREPARATION
parcellation = Parcellation(atlas_name='dosenbach2010').fit()

subjects = parcellation.dataset_.coords['subject'].values

# extract group labels (AVGP or NVGP) from subject ids (e.g. AVGP-01)
subject_labels = [s[:4] for s in subjects]  

X = subjects.reshape(-1, 1)  # subject ids, shape: (n_subjects, 1)

y_encoder = LabelEncoder()
y = y_encoder.fit_transform(subject_labels)     # labels, shape: (n_subjects,)

In [84]:
# PREPARE OUTPUT DATASET FILE

n_splits = CV.get_n_splits()
n_folds = int(X.shape[0] / CV.test_size)

model_output_name = ('connectivities'
                     '_classifier-SVM'
                     '_measure-accuracy'
                     f'_cv-{n_splits}x{n_folds}fold.nc5'
                     )

OUTPUT_PATH = MODELS_DIR / model_output_name

## Aggregation Maps

We aggregate the region-level time-series using three strategies. The first uses the region time-series as is (no change). The second averages all region-level time-series within each network. The third randomly assigns each region to a network by shuffling the network labels across regions, and then averages the time-series within each network (random network assignment).
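As a rough illustration of the three strategies, the sketch below averages synthetic region time-series according to a region-to-group mapping. The region names, network labels, and the `aggregate` helper are hypothetical; the actual aggregation happens inside `ConnectivityPipeline` and may be implemented differently.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Hypothetical region-level time-series: 100 time points x 6 regions.
regions = [f'region_{i}' for i in range(6)]
timeseries = pd.DataFrame(rng.normal(size=(100, 6)), columns=regions)

# Hypothetical region-to-network mapping (three networks, two regions each).
region_to_network = pd.Series(
    ['default', 'default', 'occipital', 'occipital', 'cerebellum', 'cerebellum'],
    index=regions, name='group')


def aggregate(ts, mapping):
    """Average the columns of `ts` that share the same label in `mapping`."""
    return ts.T.groupby(mapping).mean().T


# Strategy 1: identity mapping, every region stays its own group.
by_region = aggregate(timeseries, pd.Series(regions, index=regions))

# Strategy 2: average the regions within each network.
by_network = aggregate(timeseries, region_to_network)

# Strategy 3: shuffle the network labels across regions, then average.
shuffled = pd.Series(rng.permutation(region_to_network.values), index=regions)
by_random_network = aggregate(timeseries, shuffled)

print(by_region.shape, by_network.shape, by_random_network.shape)
```

Because the third strategy only shuffles the existing labels, it keeps the same number of networks and the same network sizes as the second. That makes it a size-matched control for the network-level results.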

In [117]:
# network strategy: map each region to its network
region_to_network = parcellation.labels_['network'].to_dict()

# region strategy: map each region to itself (identity mapping, no aggregation)
region_to_region = {k: k for k in region_to_network.keys()}

# convert them to dataframes
region_to_network = pd.DataFrame.from_dict(region_to_network, orient='index',columns=['group'])
region_to_region = pd.DataFrame.from_dict(region_to_region, orient='index', columns=['group'])

# strategy 3: random network assignment
region_to_random_network = region_to_network.copy().rename(columns={'group': 'original_group'})
region_to_random_network['group'] = region_to_random_network['original_group'].sample(frac=1).values

aggregation_strategies = {
    'region': region_to_region,
    'network': region_to_network,
    'random-network': region_to_random_network
}

# set names for aggregation strategies; this will be handy later
for k, v in aggregation_strategies.items():
    v.name = k

## Pipeline

The pipeline is composed of the following steps:

1. Extract connectivity matrices from the data
2. Vectorize the connectivity matrices
3. Scale the connectivity matrices
4. Remove zero-variance features
5. Select at most 10 features based on the coefficients of an L1-penalized linear SVM classifier
6. SVM binary classifier
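Step 2 above turns each subject's symmetric connectivity matrix into a flat feature vector. A minimal sketch of one common way to do this, keeping only the upper triangle, is shown below. The `vectorize` helper and its `discard_diagonal` flag are hypothetical; `ConnectivityVectorizer` may order features or treat the diagonal differently.

```python
import numpy as np


def vectorize(conn, discard_diagonal=True):
    """Flatten the upper triangle of each (n, n) matrix in a stack of matrices."""
    k = 1 if discard_diagonal else 0
    rows, cols = np.triu_indices(conn.shape[-1], k=k)
    return conn[..., rows, cols]


rng = np.random.default_rng(0)

# Hypothetical stack of symmetric matrices: 5 subjects, 4 groups each.
a = rng.normal(size=(5, 4, 4))
conn = (a + a.transpose(0, 2, 1)) / 2

features = vectorize(conn)
print(features.shape)
```

With `n` groups this yields `n * (n - 1) / 2` features per subject, so the feature count grows quadratically with the number of regions or networks. That growth is the reason for the feature-selection step later in the pipeline.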

In [118]:
# DEFINE PIPELINE

pipe  = Pipeline([
    ('connectivity', ConnectivityPipeline(region_to_network=region_to_network, kind='partial correlation')),
    ('vectorize', ConnectivityVectorizer()),
    ('scale', StandardScaler()),
    ('zerovar', VarianceThreshold()),
    ('select', SelectFromModel(LinearSVC(penalty='l1', dual=False, max_iter=10000),
                               max_features=lambda x: min(10, x.shape[1]))),
    ('clf', LinearSVC(penalty='l1', dual=False, max_iter=10000))
    # ('clf', SVC(kernel='linear', C=1))
])

# DEBUG (expected to overfit, i.e., score=1)
overfit_score = pipe.fit(X, y).score(X, y)
print(f'[DEBUG] overfit accuracy: {overfit_score:.3f}')

[DEBUG] overfit accuracy: 1.000


## Verify the pipeline

Here we verify that the pipeline works by running it on all aggregation strategies.

In [119]:
# VERIFY THE MODEL (calculate cross-validated accuracy and bootstrap CI)

for agg_strategy, agg_mapping in tqdm(aggregation_strategies.items()):

    pipe.set_params(connectivity__atlas='dosenbach2010',
                    connectivity__kind='partial correlation',
                    connectivity__region_to_network=agg_mapping)

    scores = cross_val_score(pipe, X, y,
                            cv=CV,
                            scoring='accuracy',
                            n_jobs=-1)
    bootstrap_ci = stats.bootstrap(scores.reshape(1,-1), np.mean)

    print(f'[{agg_strategy}]')
    print('Test accuracy (mean ± std): {:.2f} ± {:.2f}'.format(scores.mean(), scores.std()))
    print(bootstrap_ci.confidence_interval, '\n')

[region]
Test accuracy (mean ± std): 0.50 ± 0.23
ConfidenceInterval(low=0.3375, high=0.625) 

[network]
Test accuracy (mean ± std): 0.69 ± 0.13
ConfidenceInterval(low=0.6125, high=0.7625) 

[random-network]
Test accuracy (mean ± std): 0.59 ± 0.15
ConfidenceInterval(low=0.5125, high=0.7125) 



In [120]:
# RUN PIPELINE ON ALL METRICS

param_grid = {
    'connectivity__region_to_network': list(aggregation_strategies.values()),
    'connectivity__atlas': ['dosenbach2010'],
    # 'connectivity__atlas': ['dosenbach2010', 'gordon2014_2mm', 'difumo_64_2mm'],
    'connectivity__kind': ['partial correlation', 'correlation', 'covariance', 'precision', 'tangent'],
}

grid = GridSearchCV(
    pipe,
    param_grid,
    cv=CV,
    verbose=1,
    n_jobs=-2,
    scoring='accuracy')

grid.fit(X, y)
clear_output(wait=True)


print('best estimator:', grid.best_estimator_, '\n', 'best score:', grid.best_score_)

best estimator: Pipeline(steps=[('connectivity',
                 ConnectivityPipeline(atlas='dosenbach2010',
                                      kind='partial correlation',
                                      bids_dir='data/julia2018',
                                      parcellation_cache_dir='data/julia2018/derivatives/resting_timeseries/',
                                      region_to_network=                     original_group              group
vmPFC 1                     default            default
aPFC 2              fronto-parietal          occipital
aPFC 3              fronto-parietal  cingulo-opercular
mPFC 4                      default         cerebellum
aPFC 5                      d...
post occipital 159        occipital  cingulo-opercular
post occipital 160        occipital            default

[160 rows x 2 columns])),
                ('vectorize', ConnectivityVectorizer()),
                ('scale', StandardScaler()), ('zerovar', VarianceThreshold()),
           

In [121]:

# TODO rewrite the rest of the code to support the new aggregation strategy

def get_model_name(params):
    """Helper function to generate a model name from the parameters."""

    atlas_name = params['connectivity__atlas']
    agg_name = params['connectivity__region_to_network']
    kind_name = params['connectivity__kind']
    name = f'{atlas_name}_{agg_name}_{kind_name}'

    return name

mappings = []
for c in grid.cv_results_['params']:
    if isinstance(c['connectivity__region_to_network'], pd.DataFrame):
        c['connectivity__region_to_network'] = c['connectivity__region_to_network'].name
    mappings.append(c['connectivity__region_to_network'])
grid.cv_results_['param_connectivity__region_to_network'] = mappings

results = pd.DataFrame(grid.cv_results_)
results['model_name'] = results['params'].apply(get_model_name)

results = results.sort_values(by='mean_test_score', ascending=False)

In [None]:
# PERMUTATION TEST (SHUFFLE Y)

# if not ENABLE_PERMUTATION_TEST:
#     raise ValueError('ENABLE_PERMUTATION_TEST must be True to run permutation test.')

perm_scores_agg = []
cv_scores_agg = []
pvalues = []
model_names = []

tqdm.pandas()

def permutation_test(model):

    p = model['params']
    p['connectivity__region_to_network'] = aggregation_strategies[p['connectivity__region_to_network']]
    pipe.set_params(**p)

    _, perm_scores, pvalue = permutation_test_score(pipe, X, y,
                                                    scoring='accuracy',
                                                    n_permutations=N_PERMUTATIONS,
                                                    cv=CV,
                                                    n_jobs=-2,
                                                    verbose=1)
    cv_scores = cross_val_score(pipe, X, y,
                                cv=CV,
                                scoring='accuracy',
                                n_jobs=-2,
                                verbose=0)
    perm_scores_agg.append(perm_scores)
    cv_scores_agg.append(cv_scores)
    pvalues.append(pvalue)
    model_names.append(model['model_name'])
    print(model['model_name'])

results.progress_apply(permutation_test, axis=1)

ds_perm_test = xr.Dataset({
    'perm_scores': (('model_name', 'permutation_dim'), perm_scores_agg),
    'cv_scores': (('model_name', 'cv_dim'), cv_scores_agg),
    'pvalue': (('model_name'), pvalues)},
    coords={'model_name': model_names})
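The p-value returned by `permutation_test_score` is the fraction of permuted-label scores that are at least as high as the score on the true labels, with one added to both numerator and denominator. The sketch below reproduces that formula by hand. The permuted scores and the two observed scores are made-up numbers for illustration only, not results from this notebook.

```python
import numpy as np


def permutation_pvalue(score, perm_scores):
    """Empirical p-value: (C + 1) / (n_permutations + 1)."""
    perm_scores = np.asarray(perm_scores)
    n_as_good = np.sum(perm_scores >= score)
    return (n_as_good + 1.0) / (len(perm_scores) + 1.0)


# Made-up accuracies from 10 label permutations (illustrative only).
perm_scores = np.array([0.45, 0.50, 0.55, 0.40, 0.60, 0.50, 0.48, 0.52, 0.58, 0.47])

p_best = permutation_pvalue(0.70, perm_scores)  # no permuted score reaches 0.70
p_mid = permutation_pvalue(0.55, perm_scores)   # three permuted scores are >= 0.55

print(p_best, p_mid)
```

With `N_PERMUTATIONS = 10`, the smallest attainable p-value is 1/11 ≈ 0.091, so a test at the 0.05 level can never be significant. A much larger number of permutations is needed for a meaningful test.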


In [None]:
# PERMUTATION FEATURE IMPORTANCE (SHUFFLE X)

if not ENABLE_PERMUTATION_IMPORTANCE:
    raise ValueError('ENABLE_PERMUTATION_IMPORTANCE must be True to run permutation feature importance.')

# sort by rank and take top models
top_models = pd.DataFrame(grid.cv_results_).sort_values('rank_test_score')[:N_TOP_MODELS].loc[:,'params'].to_list()

importances_cv = []

for p in (progress_bar := tqdm(top_models)):

    pipe.set_params(**p)

    model_name = get_model_name(p)
    progress_bar.set_description(model_name)

    # get feature names for the connectivity vector
    X_conn = pipe[:2].transform(X)
    feature_names = pipe[:2].get_feature_names_out()

    importances = []

    # cross-validated permutation importance
    for train, test in tqdm(CV.split(X,y), total=CV.get_n_splits(X,y), desc='CV', leave=False):
        pipe.fit(X[train], y[train])

        # use a separate name so the `results` DataFrame is not overwritten
        perm_result = permutation_importance(pipe[2:], X_conn[test], y[test],
                                             scoring=grid.scoring,
                                             n_jobs=-1)
        importances.append(perm_result.importances.T)

    feature_dim_name = f'{"_".join(model_name.split("_")[0:2])}_feature'

    importances_cv_dataset = xr.Dataset({
        f'{model_name} importances': (('permutation_importance_num', feature_dim_name), np.vstack(importances))},
        coords={feature_dim_name: feature_names}
    )

    importances_cv.append(importances_cv_dataset)
    
    # sort by mean importance
    importances = pd.DataFrame(np.vstack(importances), columns=feature_names)
    sorted_columns = importances.mean(axis=0).sort_values(ascending=False).index
    importances = importances[sorted_columns]

ds_perm_importance = xr.merge(importances_cv)
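For reference, the sketch below runs `permutation_importance` on a small synthetic problem to show the shape of what it returns. The data and the logistic-regression model are stand-ins chosen for illustration; they are not part of this notebook's pipeline.

```python
import numpy as np
from sklearn.inspection import permutation_importance
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Synthetic data (not from the dataset): only feature 0 carries signal.
X_toy = rng.normal(size=(200, 5))
y_toy = (X_toy[:, 0] > 0).astype(int)

model = LogisticRegression().fit(X_toy, y_toy)

# Shuffle each feature column in turn and record the drop in accuracy.
result = permutation_importance(model, X_toy, y_toy, scoring='accuracy',
                                n_repeats=10, random_state=0)

# `importances` has shape (n_features, n_repeats);
# transposing gives one row per repeat and one column per feature.
per_repeat = result.importances.T
most_important = int(result.importances_mean.argmax())

print(per_repeat.shape, most_important)
```

The transpose mirrors the `.importances.T` call in the cell above. It puts features on the column axis, so the per-split arrays can be stacked with `np.vstack` and labelled with feature names.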

In [None]:
# SHAP

if not ENABLE_SHAP:
    raise ValueError('ENABLE_SHAP must be True to run SHAP analysis.')


import shap
import logging

# turn off shap info-level logs while using progress bars
logging.getLogger('shap').setLevel(logging.WARNING)

top_models = pd.DataFrame(grid.cv_results_).sort_values('rank_test_score')[:N_TOP_MODELS].loc[:,'params'].to_list()
shap_agg = []

for p in (progress_bar := tqdm(top_models)):

    pipe.set_params(**p)

    model_name = get_model_name(p)
    progress_bar.set_description(model_name)

    shap_values_cv = []
    test_indices = []
    y_test_cv = []
    y_pred_cv = []

    feature_names = pipe[:2].get_feature_names_out()

    X_conn = pipe[:2].fit_transform(X, y)

    for train, test in tqdm(CV.split(X, y), total=CV.get_n_splits(X, y), desc='CV', leave=False):

        shap_model = pipe[2:].fit(X_conn[train], y[train])

        y_pred = shap_model.predict(X_conn[test])

        test_indices.extend(test)
        y_test_cv.append(y[test])
        y_pred_cv.append(y_pred)

        explainer = shap.Explainer(
            shap_model.predict, X_conn[train],
            feature_names=feature_names,
            # approximate=True,
            # model_output='raw',
            # feature_perturbation='interventional',
        )

        shap_values = explainer(X_conn[test], max_evals=2*len(feature_names) + 1)#, check_additivity=True)

        shap_values_cv.append(shap_values)

    # merge CV SHAPs

    # X = subjects.reshape(-1, 1)
    # X_test = pd.DataFrame(X[np.hstack(test_indices)], columns=['subject'])
    y_test = np.hstack(y_test_cv)
    y_pred = np.hstack(y_pred_cv)

    shap_values = shap.Explanation(
      values = np.vstack([sh.values for sh in shap_values_cv]),
      base_values = np.hstack([sh.base_values for sh in shap_values_cv]),
      data = np.vstack([sh.data for sh in shap_values_cv]),
      feature_names=feature_names,
      compute_time=np.sum([sh.compute_time for sh in shap_values_cv]),
      output_names=y_encoder.classes_,
      output_indexes=y_pred,
    )

    feature_dim_name = f'{"_".join(model_name.split("_")[0:2])}_feature'

    shap_ds = xr.Dataset({
      f'{model_name} shap': (('shap_dim', feature_dim_name), shap_values.values),
      f'{model_name} shap data': (('shap_dim', feature_dim_name), shap_values.data),
      f'{model_name} shap y_test': (('shap_dim'), y_encoder.inverse_transform(y_test)),
      f'{model_name} shap y_pred': (('shap_dim'), y_encoder.inverse_transform(y_pred)),
      },
      coords={feature_dim_name: feature_names}
    )

    shap_agg.append(shap_ds)

ds_shap = xr.merge(shap_agg)

In [None]:
# LEARNING CURVE ANALYSIS (HOW DOES TRAINING SET SIZE IMPACT ACCURACY?)
# Note: this only analyzes the best model

if not ENABLE_LEARNING_CURVE:
    raise ValueError('ENABLE_LEARNING_CURVE must be True to run learning curve analysis.')

train_sizes, train_scores, test_scores = learning_curve(grid.best_estimator_, X, y,
                                                        cv=CV,
                                                        scoring='accuracy',
                                                        n_jobs=-1,
                                                        shuffle=True,
                                                        train_sizes=np.array([16, 18, 20, 22, 24]))


learning_curve_results = pd.DataFrame({
    'learning_curve_train_size': train_sizes,
    'learning_curve_mean_train_score': train_scores.mean(axis=1),
    'learning_curve_mean_test_score': test_scores.mean(axis=1)
})

learning_curve_results.index.name  = 'learning_curve_num'

ds_learning_curve = learning_curve_results.to_xarray()

In [None]:
# %%script echo Skipping...

# STORE RESULTS

datasets = [
    {'X': xr.DataArray(X.flatten(), dims=['subject'])},
    {'y': xr.DataArray(y_encoder.inverse_transform(y), dims='subject')},
    {'y_classes': y_encoder.classes_},
    ds_grid
]

if ENABLE_PERMUTATION_TEST:
    datasets.append(ds_perm_test)
if ENABLE_PERMUTATION_IMPORTANCE:
    datasets.append(ds_perm_importance)
if ENABLE_SHAP:
    datasets.append(ds_shap)
if ENABLE_LEARNING_CURVE:
    datasets.append(ds_learning_curve)

results = xr.merge(datasets)

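# write to disk (h5netcdf produces an HDF5-based netCDF4 file)
results.to_netcdf(OUTPUT_PATH, engine='h5netcdf')
results.close()

# reload from disk with the same engine; the scipy engine only reads netCDF3 files
results = xr.open_dataset(OUTPUT_PATH, engine='h5netcdf').load()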

results