# Cerebellum Connectivity Classifier

Steps:
1. Load the data
2. Extract the cerebellum features (from the DiFuMo atlas)
3. Fit an SVM with hyperparameter optimization (HPO)


## Inputs

Cerebellum activities from the DiFuMo atlas.

## Outputs

- Classification output: Participant's label, either AVGP or NVGP.
- Results:
  - `models/cerebellum_classifier_*.nc`


## Requirements

To run this notebook, first activate the `acnets` environment with `conda activate acnets`.

# TODO:
- Add support for cerebellum in the ConnectivityPipeline


In [1]:
# 0. SETUP

%reload_ext autoreload
%autoreload 3

import numpy as np
import pandas as pd
from pathlib import Path
import scipy.stats as st
import xarray as xr
from src.acnets.pipeline import CerebellumPipeline, ConnectivityVectorizer
from sklearn.feature_selection import SelectFromModel, VarianceThreshold
from sklearn.inspection import permutation_importance
from sklearn.model_selection import (GridSearchCV, StratifiedShuffleSplit,
                                     cross_val_score, learning_curve,
                                     permutation_test_score)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.svm import LinearSVC
from tqdm.auto import tqdm

In [2]:
from nilearn import datasets

# Fetch the DiFuMo atlas; available dimensions are 64, 128, 256, 512 and 1024.
atlas = datasets.fetch_atlas_difumo(dimension=256,
                                    resolution_mm=2,
                                    legacy_format=False)

# Shape of the label table restricted to components whose name mentions
# "cerebellum" (case-insensitive); the first number is the component count.
atlas.labels.loc[
    atlas.labels['difumo_names'].str.lower().str.contains('cerebellum')
].shape
# atlas.labels['yeo_networks17'].unique()

(23, 7)

In [3]:
# 0.1. PARAMETERS

# Fixed seed: with an integer `random_state` the splitter yields the same
# folds on every call, so all cells that reuse `CV` (model check, grid search,
# permutation-test CV scores, feature importance) evaluate on identical splits
# and the notebook is reproducible under Restart & Run All.
RANDOM_STATE = 0

CV = StratifiedShuffleSplit(n_splits=100, test_size=8, random_state=RANDOM_STATE)
N_PERMUTATIONS = 10  # number of label permutations per model in the permutation test
N_TOP_MODELS = 5     # number of best-ranked grid-search candidates analysed further

MODELS_DIR = Path('models')  # directory the result file is written to

In [4]:
# 1. DATA

# Subject identifiers are read from the `subject` coordinate of the
# transformed dataset.
subjects = CerebellumPipeline().transform('all').coords['subject'].values

# The first four characters of each identifier give the group (AVGP or NVGP).
groups = [subject_id[:4] for subject_id in subjects]

# The model input is a single-column array of subject identifiers.
X = np.reshape(subjects, (-1, 1))

# Encode the group names as integer class labels.
y_encoder = LabelEncoder().fit(groups)
y = y_encoder.transform(groups)

In [5]:
# PREPARE OUTPUT

# Number of test sets of size `CV.test_size` that fit into the sample
# (used only to label the output file).
n_cv_fold = X.shape[0] // CV.test_size

# NOTE: every component starts with an underscore separator; the original name
# was missing the one before `scoring-accuracy`.
model_output_name = ('cerebellum'
                     '_classifier-SVML1'
                     '_scoring-accuracy'
                     f'_top-{N_TOP_MODELS}'
                     f'_cv-{CV.get_n_splits()}x{n_cv_fold}fold.nc'
                     )

OUTPUT_PATH = MODELS_DIR / model_output_name

In [6]:
# 2. PIPELINE
# Steps, in order: cerebellum connectivity extraction -> vectorization ->
# standardization -> variance threshold -> L1-SVM based feature selection
# (at most 10 features) -> L1-penalized linear SVM classifier.

pipe = Pipeline(steps=[
    ('connectivity', CerebellumPipeline()),
    ('vectorize', ConnectivityVectorizer()),
    ('scale', StandardScaler()),
    ('zerovar', VarianceThreshold()),
    ('select', SelectFromModel(
        LinearSVC(penalty='l1', dual=False, max_iter=10000),
        max_features=10)),
    ('clf', LinearSVC(penalty='l1', dual=False, max_iter=10000)),
])

# DEBUG: fit on all subjects and report the (training) accuracy
pipe.fit(X, y)
pipe.score(X, y)

1.0

In [7]:
# 2.1. VERIFY THE MODEL
# Cross-validate one fixed configuration before running the full grid search.
pipe.set_params(
    connectivity__difumo_dimension=128,
    connectivity__kind='tangent',
    connectivity__agg_networks=True,
)

scores = cross_val_score(estimator=pipe, X=X, y=y,
                         scoring='accuracy',
                         cv=CV,
                         n_jobs=-1)

# Bootstrap confidence interval of the mean accuracy; the scores are passed
# as a single sample (one row).
bootstrap_ci = st.bootstrap(scores[np.newaxis, :], np.mean)
scores.mean(), scores.std(), bootstrap_ci

Downloading data from https://osf.io/wjvd5/download ...
Downloading data from https://osf.io/wjvd5/download ...
Downloading data from https://osf.io/wjvd5/download ...
Downloading data from https://osf.io/wjvd5/download ...
Downloading data from https://osf.io/wjvd5/download ...
Downloading data from https://osf.io/wjvd5/download ...
Downloading data from https://osf.io/wjvd5/download ...
Downloading data from https://osf.io/wjvd5/download ...


 ...done. (2 seconds, 0 min)
Extracting data from /home/morteza/nilearn_data/difumo_atlases/21d764636e41d335113bb464d8fdcb60/download..... done.


In [11]:
# 3. HPO: GRID SEARCH
# Exhaustively evaluate every atlas dimension with the shared CV splitter.

param_grid = dict(
    connectivity__difumo_dimension=[64, 128, 256, 512, 1024],
    # connectivity__atlas=['seitzman2018'],
    connectivity__kind=['tangent'],
)

grid = GridSearchCV(estimator=pipe,
                    param_grid=param_grid,
                    scoring='accuracy',
                    cv=CV,
                    verbose=1)

grid.fit(X, y)

print('best estimator:', grid.best_estimator_)


Fitting 100 folds for each of 5 candidates, totalling 500 fits


: 

: 

In [None]:
# 3.1. STORE GRID SEARCH RESULTS
# Convert the grid-search table to an xarray Dataset indexed by a readable
# model name and attach the evaluation settings next to it.

grid_results = pd.DataFrame(grid.cv_results_)

# Build the model name from the parameter values. Values are cast to `str`
# first because the grid contains integers (`difumo_dimension`) and
# `str.join` raises TypeError on non-string items.
grid_results['grid_model_name'] = grid_results['params'].apply(
    lambda params: ' '.join(map(str, params.values())))
grid_results = (grid_results
                .set_index('grid_model_name')
                .drop(columns=['params']))

ds_grid = grid_results.to_xarray()
ds_grid['scoring'] = grid.scoring
ds_grid['cv_test_size'] = CV.test_size
ds_grid['cv_n_splits'] = CV.n_splits
ds_grid['n_subjects'] = len(X)

In [None]:
# 4. PERMUTATION TEST (SHUFFLE Y)
# For each of the top-ranked grid-search candidates, collect the
# permutation-test scores and p-value together with the cross-validated
# accuracy scores.

perm_scores_agg = []
cv_scores_agg = []
pvalues = []
model_names = []

# sort by rank and take top n_top_models
top_models = (pd.DataFrame(grid.cv_results_)
              .sort_values('rank_test_score')
              .head(N_TOP_MODELS)['params']
              .to_list())

for p in tqdm(top_models):
    # Cast values to `str` before joining: the grid contains integers
    # (`difumo_dimension`) and `str.join` raises TypeError on non-strings.
    model_name = ' '.join(map(str, p.values()))

    # Note: this reconfigures the shared `pipe` object in place.
    pipe.set_params(**p)

    # TODO: break if it's a low score

    # NOTE(review): the permutation test uses `cv=4` rather than the shared
    # `CV` splitter — presumably to limit run time; confirm this is intended.
    _, perm_scores, pvalue = permutation_test_score(pipe, X, y,
                                                    scoring='accuracy',
                                                    n_permutations=N_PERMUTATIONS,
                                                    cv=4,
                                                    n_jobs=-1, verbose=0)

    cv_scores = cross_val_score(pipe, X, y,
                                cv=CV,
                                scoring='accuracy', n_jobs=-1)

    perm_scores_agg.append(perm_scores)
    cv_scores_agg.append(cv_scores)
    pvalues.append(pvalue)
    model_names.append(model_name)

ds_perm = xr.Dataset({
    'perm_scores': (('model_name', 'permutation_dim'), perm_scores_agg),
    'cv_scores': (('model_name', 'cv_dim'), cv_scores_agg),
    'pvalue': (('model_name'), pvalues)},
    coords={'model_name': model_names})

In [None]:
# 5. FEATURE IMPORTANCE (SHUFFLE X)
# For each top model, compute permutation importances of the vectorized
# connectivity features on the held-out subjects of every CV split.

importances_agg = []

for p in top_models:
    # Cast values to `str` before joining: the grid contains integers
    # (`difumo_dimension`) and `str.join` raises TypeError on non-strings.
    model_name = ' '.join(map(str, p.values()))

    pipe.set_params(**p)

    # The first two steps (`connectivity`, `vectorize`) turn subject IDs into
    # feature vectors; importances are computed on these features.
    X_conn = pipe[:2].transform(X)
    feature_names = pipe[:2].get_feature_names_out()

    importances = []

    for train, test in tqdm(CV.split(X, y), total=CV.get_n_splits(X, y)):
        pipe.fit(X[train], y[train])

        # Score only the remaining steps (scaling onwards) on the
        # pre-computed features of the test subjects.
        results = permutation_importance(pipe[2:], X_conn[test], y[test],
                                         scoring=grid.scoring,
                                         n_jobs=1)
        # Transpose so that features are along the columns.
        importances.append(results.importances.T)

    # Each model gets its own feature dimension, named after the first
    # parameter value in the model name.
    feature_dim_name = f'{model_name.split(" ")[0]}_feature'

    importances_ds = xr.Dataset({
        f'{model_name} importances': (('permutation_importance_num', feature_dim_name), np.vstack(importances))},
        coords={feature_dim_name: feature_names}
    )

    importances_agg.append(importances_ds)

ds_imp = xr.merge(importances_agg)

In [None]:
# 8. STORE RESULTS
# Merge inputs, labels and all analysis outputs into one dataset, write it to
# `OUTPUT_PATH`, and read it back to confirm the stored file loads.

results = xr.merge([
    {'X': xr.DataArray(X.flatten(), dims=['subject'])},
    {'y': xr.DataArray(y_encoder.inverse_transform(y), dims='subject')},
    {'y_classes': y_encoder.classes_},
    ds_grid, ds_imp, ds_perm])

# Create the output directory first: otherwise the write fails with
# FileNotFoundError only after all the long-running cells have finished.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)

# Writing by path lets xarray manage the file handle itself.
results.to_netcdf(OUTPUT_PATH, engine='scipy')
results.close()

results = xr.open_dataset(OUTPUT_PATH, engine='scipy').load()
results