# ACNets: Diagonal Connectivity Classifier

This notebook fits a binary classifier to predict participant's group, AVGP or NVGP, using functional connectivity matrices. As input, it takes upper-triangular connectivity matrices for each participant.

To address the concerns about small sample size and test/train splits, results are evaluated using 5-fold cross-validated permutation testing.

## 0. Setup

In [1]:

import math
import re

import numpy as np
import xarray as xr
import pandas as pd

from IPython.display import clear_output
import matplotlib.pyplot as plt
import seaborn as sns; sns.set('notebook')

from sklearn import preprocessing, model_selection, metrics, ensemble, multioutput
from sklearn import decomposition, cross_decomposition, feature_selection, dummy, svm

from sklearn.pipeline import Pipeline

import skopt
from skopt import BayesSearchCV
from skopt.space import Real, Integer, Categorical

from python.acnets.datasets import load_connectivity


from tqdm import tqdm


# Technical reproducibility
%reload_ext watermark
%watermark -iv -co -ituhmv

%reload_ext autoreload
%autoreload 3

Last updated: 2022-02-23T14:18:03.276085+01:00

Python implementation: CPython
Python version       : 3.9.10
IPython version      : 8.0.1

conda environment: acnets

Compiler    : Clang 11.1.0 
OS          : Darwin
Release     : 21.3.0
Machine     : x86_64
Processor   : i386
CPU cores   : 12
Architecture: 64bit

Hostname: MP0159

sys       : 3.9.10 | packaged by conda-forge | (main, Feb  1 2022, 21:28:27) 
[Clang 11.1.0 ]
numpy     : 1.21.5
seaborn   : 0.11.2
sklearn   : 1.0.2
pandas    : 1.4.0
matplotlib: 3.5.1
re        : 2.2.1
xarray    : 0.21.1
skopt     : 0.9.0



# Fit the model

In [2]:
atlases = ['dosenbach2010']
connectivity_measures = ['tangent']

In [3]:
cv = model_selection.StratifiedKFold(5)

pipe = Pipeline([
  ('zv', feature_selection.VarianceThreshold()),
  ('fa', decomposition.FactorAnalysis()),
  ('clf', svm.SVC(kernel='linear', probability=False))
])

param_space = {
  'fa__rotation': Categorical(['varimax']),
  'fa__n_components': Integer(1, 10),
  'clf__C': Real(1e-3, 1e3, 'log-uniform'),
}

opt = BayesSearchCV(pipe, param_space, cv=cv, n_jobs=1)

def fit(X, y, feature_names):
  # encode y as integers
  y_encoder = preprocessing.LabelEncoder()
  y = y_encoder.fit_transform(y)
      
  # test/train splits
  train, test = model_selection.train_test_split(
    range(len(X)),
    test_size=0.2,
    shuffle=True,
    stratify=y,)
      
  progress_bar = tqdm(total=opt.total_iterations)

  opt.fit(X[train], y[train],
          callback = [
            skopt.callbacks.DeadlineStopper(total_time=300),
            lambda _: False if progress_bar.update() else False,
  ])

  progress_bar.close()
  
  # evaluate
  score_train = opt.score(X[train], y[train])
  score_test = opt.score(X[test], y[test])

  # report scores and hyperparameters
  print(f'train set score (roc_auc): {score_train:.2f}')
  print(f'test set score (roc_auc): {score_test:.2f}')

  # we don't have a hyperparameter so we pass 'pipe' instead of 'grid'
  obs_score, perm_scores, p_value = model_selection.permutation_test_score(
    opt.best_estimator_, X, y,
    cv=cv,
    scoring='roc_auc',
    n_permutations=1000,
    n_jobs=-1, verbose=0)
  
  print(obs_score, perm_scores, p_value)

for atlas in atlases:
  for kind in connectivity_measures:
    X, y, feature_names = load_connectivity(
      parcellation=atlas,
      kind=kind,
      vectorize=False,
      return_y=True,
      only_diagonal=True,
      return_feature_names=True,
      discard_diagonal=True,
      discard_cerebellum=False,)
    
    if len(X.shape) == 3:
      # binarize
      X_threshold = np.array([np.median(x, keepdims=True) + x.std(keepdims=True) for x in X])
      X = np.where(np.abs(X) > X_threshold, 1, 0)

      X = np.array([x[np.triu_indices_from(x, k=1)] for x in X])
      feature_names = feature_names.values[np.triu_indices_from(feature_names.values, k=1)]

      # remove zero-variance features
      zv_mask = (X.std(axis=0) == 0)
      X = X[:,~zv_mask]
      feature_names = feature_names[~zv_mask]
    
    # fit the FA model
    fit(X, y, feature_names)

 33%|███▎      | 50/150 [01:21<02:42,  1.63s/it]

train set score (roc_auc): 0.80
test set score (roc_auc): 0.86





0.861111111111111 [0.36666667 0.37222222 0.57222222 0.40555556 0.50555556 0.44444444
 0.76666667 0.51666667 0.41111111 0.63888889 0.42777778 0.72222222
 0.35555556 0.46111111 0.88888889 0.37777778 0.52222222 0.87777778
 0.58888889 0.48333333 0.66111111 0.61666667 0.53333333 0.38333333
 0.50555556 0.40555556 0.56111111 0.56666667 0.71111111 0.37777778
 0.63333333 0.29444444 0.59444444 0.41111111 0.62777778 0.80555556
 0.40555556 0.48888889 0.32777778 0.54444444 0.3        0.41666667
 0.24444444 0.68888889 0.56111111 0.58333333 0.35       0.59444444
 0.5        0.50555556 0.46666667 0.49444444 0.63888889 0.49444444
 0.3        0.50555556 0.4        0.72777778 0.77777778 0.3
 0.45555556 0.46111111 0.52777778 0.42222222 0.51666667 0.50555556
 0.74444444 0.47222222 0.6        0.57222222 0.54444444 0.46111111
 0.45       0.24444444 0.33888889 0.44444444 0.53333333 0.56666667
 0.41111111 0.48333333 0.46666667 0.36111111 0.18888889 0.48888889
 0.4        0.48333333 0.56111111 0.50555556 0.65  

Extract network names from the @Dosenbach2010 atlas:

In [None]:
from nilearn import datasets as nilean_datasets

atlas = nilean_datasets.fetch_coords_dosenbach_2010(legacy_format=False)
labels = pd.concat(
  [pd.DataFrame(v) for k, v in atlas.items() if k != 'description'], axis=1)
labels.set_index(0, inplace=True)

feature_network_names = labels.loc[feature_names, 'network']

Assign colors to the networks and their corresponding regions:

In [None]:
palt = dict(zip(
  feature_network_names.unique(),
  sns.color_palette('Set1', feature_network_names.nunique())))

feature_network_colors = feature_network_names.apply(
  lambda x: pd.Series((palt[x], x)))
feature_network_colors.rename(columns={0:'color', 1:'network'}, inplace=True)
feature_network_colors.index.name = 'region'

Now plot the Factor Analysis components:

In [None]:
fa_comps = pd.DataFrame(
  opt.best_estimator_.named_steps['fa'].components_,
  columns=feature_network_names
)

sns.clustermap(fa_comps.T, figsize=(5,36),
               col_cluster=False,
               robust=True,
               dendrogram_ratio=(0.2, 0.000001),
               cbar_pos=(.96, .967, 0.01, 0.03),
               row_colors=feature_network_colors['color'].tolist())

# Extra

-[ ] TODO: replicate https://www.frontiersin.org/articles/10.3389/fnhum.2014.00425/full

## FIXME: Permutation Importance

In [None]:
%%script echo skipping

from sklearn.inspection import permutation_importance

pipe.fit(X, y)
perm_imp_result = permutation_importance(pipe, X, y, 
                                         n_repeats=100,
                                         scoring='roc_auc', n_jobs=-1)

perm_sorted_idx = perm_imp_result.importances_mean.argsort()

# sns.boxplot(
#     result.importances[perm_sorted_idx].T,
#     vert=False,
#     labels=data.feature_names[perm_sorted_idx],
# )


perm_imp_result.importances[perm_sorted_idx]

In [None]:
%%script echo skipping

from scipy.stats import spearmanr
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform

corr = spearmanr(X).correlation

# Ensure the correlation matrix is symmetric
corr = (corr + corr.T) / 2
np.fill_diagonal(corr, 1)

fig, (ax1,ax2) = plt.subplots(1,2,figsize=(10, 10))

# We convert the correlation matrix to a distance matrix before performing
# hierarchical clustering using Ward's linkage.
distance_matrix = 1 - np.abs(corr)
dist_linkage = hierarchy.ward(squareform(distance_matrix))
dendro = hierarchy.dendrogram(
    dist_linkage, labels=feature_names.tolist(), ax=ax1, leaf_rotation=90
)

dendro_idx = np.arange(0, len(dendro["ivl"]))



ax2.imshow(corr[dendro["leaves"], :][:, dendro["leaves"]])
ax2.set_xticks(dendro_idx)
ax2.set_yticks(dendro_idx)
ax2.set_xticklabels(dendro["ivl"], rotation="vertical")
ax2.set_yticklabels(dendro["ivl"])
fig.tight_layout()
plt.show()

In [None]:
%%script echo skipping

from collections import defaultdict

cluster_ids = hierarchy.fcluster(dist_linkage, 1, criterion="distance")
cluster_id_to_feature_ids = defaultdict(list)
for idx, cluster_id in enumerate(cluster_ids):
    cluster_id_to_feature_ids[cluster_id].append(idx)
selected_features = [v[0] for v in cluster_id_to_feature_ids.values()]

X_clustered = X[:, selected_features]

# Permmutation Importance

from sklearn.inspection import permutation_importance

grid.fit(X_clustered[train], y[train])
perm_imp_result = permutation_importance(grid, X_clustered[test], y[test],
                                         scoring='roc_auc',
                                         n_repeats=100)

sorted_idx = perm_imp_result.importances_mean.argsort()[::-1][:10][::-1]

fig, ax = plt.subplots(figsize=(10, 10))
ax.boxplot(
    perm_imp_result.importances[sorted_idx].T,
    labels = feature_names[np.array(selected_features)[sorted_idx]],
    vert=False,
)
ax.set_title("Permutation Importances (test set)")
fig.tight_layout()
plt.show()