## Setup

In [None]:
import functools

import numpy as np
import xarray as xr
import pandas as pd

from sklearn import feature_selection, svm, preprocessing, model_selection, ensemble
from sklearn import neighbors
from sklearn.pipeline import Pipeline

import xgboost
import shap

from nilearn import datasets as nilean_datasets

import matplotlib.pyplot as plt
import seaborn as sns; sns.set('paper')
import plotly.express as px
from tqdm import tqdm

# Load SHAP's JavaScript helpers so its interactive plots can render inline.
shap.initjs()

# Technical reproducibility
# Print the versions of the imported packages together with environment details,
# so results can be traced back to the software stack that produced them.
%reload_ext watermark
%watermark -iv -co -ituhmv

In [None]:
# parameters

# Diagonal offset handed as `k` to the np.triu_indices* calls in this notebook:
# 0 keeps the diagonal (self-connections) among the features.
TRIU_K = 0
# Key into DATASETS; keys are built as f'{atlas}_{connectivity}'.
DATASET_NAME = 'dosenbach2010_tangent'

## Data

In [None]:
# Atlases and connectivity measures to load. Every (atlas, measure) pair becomes
# one entry of DATASETS under the key f'{atlas}_{connectivity}'.
atlases = ['dosenbach2007', 'dosenbach2010', 'difumo_64_2',]# 'difumo_128_2',]# 'difumo_1024_2']

connectivity_measures = ['tangent', 'precision', 'correlation',
                         'covariance', 'partial_correlation']

DATASETS = dict()

for atlas in atlases:
  # The file path depends only on the atlas, so open it once per atlas instead of
  # re-opening the same file for every connectivity measure.
  ds = xr.open_dataset(f'data/julia2018_resting/connectivity_{atlas}.nc')

  for connectivity in connectivity_measures:
    _conn_key = f'{connectivity}_connectivity'
    _conn = ds[_conn_key]

    # Attach the class label and the behavioral score to the connectivity array
    # so downstream cells can read them from the selected dataset directly.
    _conn.coords['group'] = ds.group
    _conn['inverse_efficiency_score_ms'] = ds['inverse_efficiency_score_ms']

    # When the file provides DiFuMo component names, use them as region labels.
    if 'difumo_names' in ds.coords:
      _conn.coords['region'] = ds.coords['difumo_names'].values

    DATASETS[f'{atlas}_{connectivity}'] = _conn


In [None]:
# Pick the connectivity matrices configured through DATASET_NAME.
dataset = DATASETS[DATASET_NAME]

# Behavioral score that was attached to every dataset when DATASETS was built.
behavioral_scores = dataset['inverse_efficiency_score_ms'].values

# remove subjects with missing behavioral data or duplicate scanning sessions
# NOTE: the duplicate-session filter below is commented out, so currently only
# entries with a NaN behavioral score are flagged as invalid.
# subject_labels = xr.concat([dataset['subject'], dataset['subject'] + 'NEW'], dim='subject')
# invalid_subjects = subject_labels.to_series().duplicated(keep='first')[32:]
# invalid_subjects = invalid_subjects | np.isnan(behavioral_scores)
invalid_subjects = np.isnan(behavioral_scores)

In [None]:
# --- feature names -----------------------------------------------------------
# One "<row region> ↔ <column region>" label per upper-triangular matrix entry,
# in the same order in which the entries are flattened into X below.
regions = dataset.coords['region'].values

_name_rows, _name_cols = np.triu_indices(len(regions), k=TRIU_K)
_region_index = pd.Index(regions)
feature_names = (
  _region_index[_name_rows] + ' \N{left right arrow} ' + _region_index[_name_cols]
).to_numpy(dtype=object)

# --- X -----------------------------------------------------------------------
# Flatten the upper triangle of every connectivity matrix into one feature row.
_conn_matrices = dataset.values
_x_rows, _x_cols = np.triu_indices(
  _conn_matrices.shape[-2], k=TRIU_K, m=_conn_matrices.shape[-1])
X = _conn_matrices[:, _x_rows, _x_cols]

# Zero out entries whose magnitude is below the row's own median + 1 std.
X_threshold = (np.median(X, axis=1, keepdims=True)
               + np.std(X, axis=1, keepdims=True))
X = np.where(np.abs(X) >= X_threshold, X, 0)

# --- y -----------------------------------------------------------------------
# Integer-encode the group labels; the encoder is kept to recover class names.
y_encoder = preprocessing.LabelEncoder()
y = y_encoder.fit_transform(dataset['group'])

# remove subjects with missing behavioral data
_valid_subjects = ~invalid_subjects
X = X[_valid_subjects]
y = y[_valid_subjects]

## Model

In [None]:

# NOTE: the three ALT assignments below all bind the same name `model`; each one
# replaces the previous, so only the last one (ALT3, XGBoost) is actually fitted.

# ALT 1: SVM
model = Pipeline([
    ('zerovar', feature_selection.VarianceThreshold(.01)),
    ('model', svm.SVC(kernel='rbf', C=1, probability=True)),
], verbose=False)

# ALT2: RandomForest
model = ensemble.RandomForestClassifier(n_estimators=100, max_depth=10, n_jobs=-1)

# ALT3: XGBoost
model = xgboost.XGBClassifier(
    n_estimators=100, max_depth=10,
    use_label_encoder=False,
    eval_metric='auc',
    n_jobs=-1)

# TODO hyper-parameter tuning

# Stratified 50/50 split over row indices of X. No random_state is passed, so
# the split (and everything derived from it) differs between runs.
train, test = model_selection.train_test_split(
  range(len(X)),
  test_size=0.5,
  shuffle=True,
  stratify=y,
)

# Initial fit on the training half; the CV-SHAP loop further down refits `model`
# on every fold and rebinds `train` / `test`.
model.fit(X[train], y[train])

In [None]:
%%script echo skipping...

# NOTE: the `%%script echo` cell magic above hands this cell to `echo` instead
# of the Python kernel, so none of the code below is executed.

# DEBUG /start

# Cross-validated ROC-AUC of `model` (stratified 5-fold, repeated 20 times).
test_score = model_selection.cross_val_score(
    model, X, y, n_jobs=-1, scoring='roc_auc',
    cv = model_selection.RepeatedStratifiedKFold(n_splits=5, n_repeats=20),
)

print(f'mean(CV-AUC): {test_score.mean():.2f}')

# Permutation test: compare the CV score against scores on label-shuffled data.
# NOTE(review): with n_permutations=10 the smallest attainable p-value is
# 1/11 (~0.09), given scikit-learn's (C + 1) / (n_permutations + 1) formula;
# increase n_permutations before reporting significance.
perm_score, _, pvalue = model_selection.permutation_test_score(
    model, X, y,
    cv = model_selection.RepeatedStratifiedKFold(n_splits=5, n_repeats=10),
    n_jobs=-1,
    n_permutations=10,
    # cv=5,
    scoring='roc_auc')

print(f'Permutation test AUC: {perm_score:.2f} (p-value={pvalue:.3f})')

# DEBUG /end

## CV-SHAP

In [None]:
# Per-fold accumulators; entry i of every list belongs to the i-th CV split.
shap_values_cv = []
X_test_indices_cv = []
y_test_cv = []
y_pred_cv = []
expected_value_cv = []

# Stratified 5-fold CV repeated 200 times. No random_state is set, so the folds
# differ between runs.
cv = model_selection.RepeatedStratifiedKFold(n_splits=5, n_repeats=200)

# Total number of (train, test) splits; only used for the progress bar.
n_splits = cv.get_n_splits(X, y)

# NOTE: this loop rebinds the notebook-level names `train`, `test`, `y_pred`
# and `shap_values`, and refits `model` in place on every fold.
for train, test in tqdm(cv.split(X, y), total=n_splits):

    # train the model
    model.fit(X[train], y[train])
    y_pred = model.predict(X[test])
    
    # fit explainer (tree algorithm; the fold's training rows are passed as the
    # explainer's second argument)
    explainer = shap.Explainer(
        model, X[train],
        feature_names=feature_names,
        algorithm='tree',
        # output_names=dataset['group'].values[train],
        # feature_perturbation='tree_path_dependent'
    )

    # evaluate explainer on the held-out rows of this fold
    # Note: for Permutation explainer, add max_evals= 100 * X.shape[1] + 1
    shap_values = explainer(X[test])
    # shap_interaction_values = explainer.shap_interaction_values(X[test])

    shap_values_cv.append(shap_values)
    expected_value_cv.append(explainer.expected_value)
    X_test_indices_cv.append(test)
    y_test_cv.append(y[test])
    y_pred_cv.append(y_pred)

# merge CV results
# All arrays below are concatenated in fold order, so row i of `shap_values`,
# `X_test`, `y_test` and `y_pred` refer to the same held-out observation.
# NOTE: `shap_values` is a plain array here; the next cell rebinds the name to
# a shap.Explanation built from `shap_values_cv`.
shap_values = np.vstack([sh_val.values for sh_val in shap_values_cv])
X_test = pd.DataFrame(X[np.hstack(X_test_indices_cv)], columns=feature_names)
y_test = np.hstack(y_test_cv)
y_pred = np.hstack(y_pred_cv)

In [None]:
# Pool the per-fold explanations into one Explanation covering every held-out
# observation, concatenated in fold order.
_fold_values = [fold.values for fold in shap_values_cv]
_fold_base_values = [fold.base_values for fold in shap_values_cv]
_fold_data = [fold.data for fold in shap_values_cv]
_fold_compute_times = [fold.compute_time for fold in shap_values_cv]

shap_values = shap.Explanation(
  values=np.vstack(_fold_values),
  base_values=np.hstack(_fold_base_values),
  data=np.vstack(_fold_data),
  # feature names are the same for every fold, so take them from the first one
  feature_names=shap_values_cv[0].feature_names,
  compute_time=np.sum(_fold_compute_times),
  output_names=y_encoder.classes_,
  output_indexes=y_pred,
)

In [None]:
# Beeswarm summary of the pooled CV SHAP values, showing at most 20 features.
shap.plots.beeswarm(shap_values, max_display=20, alpha=.7)

In [None]:
# TODO use explainers to calculate base and subsample to speed up plotting
# shap.force_plot(np.mean(expected_value_cv), shap_values, X_test, feature_names=feature_names)

In [None]:
# NOTE: `clustering` is not passed to the plot call below; it is only a
# placeholder for the commented-out hierarchical clustering.
clustering = None
# clustering = shap.utils.hclust(X_test.iloc[:,:10], y_test)

# Bar-type summary plot of the pooled SHAP values against the pooled test rows.
shap.summary_plot(shap_values, X_test, plot_type='bar')

In [None]:
# Mark pooled held-out predictions that disagree with the true label.
misclassified = y_pred != y_test

# Decision plot over all pooled held-out observations, with misclassified ones
# highlighted. The base value is the mean of the per-fold expected values.
# NOTE(review): rows are pooled over every CV repeat, so the observation count
# can be large; shap's decision plot may refuse or be slow on large inputs —
# consider subsampling, TODO confirm.
# NOTE(review): `legend_labels` receives the class names, not one label per
# plotted observation — confirm this produces the intended legend.
shap.plots.decision(np.mean(expected_value_cv),
                    shap_values.values,#[misclassified],
                    feature_names=feature_names.tolist(),
                    # feature_display_range=range(10, -1, -1),
                    link='logit',
                    # feature_order='hclust',
                    highlight=misclassified,
                    legend_labels=y_encoder.classes_.tolist()
                    )

In [None]:
@functools.lru_cache(maxsize=None)
def _load_network_lookup(atlas_kind):
  """Return a Series mapping region label -> network name for one atlas family.

  The result is cached so repeated lookups do not fetch and parse the atlas again.
  """
  if atlas_kind == 'difumo':
    atlas = nilean_datasets.fetch_atlas_difumo(
      dimension=64, resolution_mm=2, legacy_format=False)
    return atlas.labels.set_index('difumo_names')['yeo_networks17']

  if atlas_kind == 'dosenbach2010':
    atlas = nilean_datasets.fetch_coords_dosenbach_2010(legacy_format=False)
    # Pair labels and networks strictly by position. Joining the atlas fields
    # with pd.concat(axis=1) aligns rows on their index labels instead, which
    # mislabels regions whenever the fields do not share the same index.
    # NOTE(review): the fetcher appears to return `networks` with a re-ordered
    # index while `labels` is a plain array — confirm for the installed nilearn.
    return pd.Series(
      np.asarray(atlas['networks']),
      index=pd.Index(np.asarray(atlas['labels'])),
      name='network')

  raise ValueError(f'Unknown atlas kind: {atlas_kind!r}')


def get_network_name(region, dataset_name=None):
  """Look up the brain network of one or several atlas regions.

  Parameters
  ----------
  region : str or list-like of str
    Region label(s) as used by the atlas. A list-like returns one network per
    entry, indexed by region label.
  dataset_name : str
    Dataset key containing the atlas name; matched case-insensitively against
    'difumo' and 'dosenbach2010'.

  Returns
  -------
  str or pandas.Series
    Network name for a single region, or a Series of network names.

  Raises
  ------
  ValueError
    If `dataset_name` is missing or does not name a supported atlas.
  KeyError
    If `region` is not a label of the selected atlas.
  """
  if not isinstance(dataset_name, str):
    raise ValueError(f'Invalid atlas name: {dataset_name!r}.')

  name = dataset_name.lower()

  if 'difumo' in name:
    return _load_network_lookup('difumo').loc[region]

  if 'dosenbach2010' in name:
    return _load_network_lookup('dosenbach2010').loc[region]

  raise ValueError(
    f'Invalid atlas name: {dataset_name!r} '
    "(expected a name containing 'difumo' or 'dosenbach2010').")

In [None]:

# Side length of the square region-by-region SHAP matrix.
region_coord = dataset.coords['region']
shap2d_size = len(region_coord)

# Overall importance per connectivity feature: sum of |SHAP| over all pooled rows.
agg_shap_values = np.abs(shap_values.values).sum(axis=0)

# Scatter the flat feature vector back into the upper triangle, then mirror it
# onto the lower triangle without counting the diagonal twice.
shap2d_triu_indices = np.triu_indices(shap2d_size, k=TRIU_K)
shap2d_values = np.zeros((shap2d_size, shap2d_size))
shap2d_values[shap2d_triu_indices] = agg_shap_values
shap2d_values = (shap2d_values
                 + shap2d_values.T
                 - np.diag(np.diag(shap2d_values)))

shap2d = pd.DataFrame(shap2d_values, index=region_coord, columns=region_coord)

# Flat positions of the top-N features (most important first) and their
# (row, column) location in the square matrix.
top_n = 10
sorted_shap_indices = np.argsort(agg_shap_values)[::-1]
triu_idx = sorted_shap_indices[:top_n]
_triu_rows, _triu_cols = shap2d_triu_indices
row_idx = _triu_rows[triu_idx]
col_idx = _triu_cols[triu_idx]

# DEBUG make sure indices are mapped correctly
assert np.all(shap2d_values[row_idx,col_idx] == agg_shap_values[triu_idx])

print('Top contributing connectivities:')
_region_values = region_coord.values
for rank, (row, col) in enumerate(zip(row_idx, col_idx), start=1):
  row_region, col_region = _region_values[row], _region_values[col]
  row_net = get_network_name(row_region, DATASET_NAME)
  col_net = get_network_name(col_region, DATASET_NAME)
  print(f'{rank}) {row_region} \N{left right arrow} {col_region} '
        f'[{row_net} \N{left right arrow} {col_net}]')
  
# Top contributing connectivities:
# 1) sup parietal 86 ↔ sup parietal 86 [occipital ↔ occipital]
# 2) vFC 40 ↔ vFC 40 [cingulo-opercular ↔ cingulo-opercular]
# 3) vlPFC 12 ↔ vlPFC 12 [default ↔ default]
# 4) IPS 134 ↔ IPS 134 [sensorimotor ↔ sensorimotor]
# 5) occipital 92 ↔ occipital 92 [sensorimotor ↔ sensorimotor]
# 6) post occipital 153 ↔ post occipital 153 [default ↔ default]
# 7) mid insula 56 ↔ mid insula 56 [occipital ↔ occipital]
# 8) occipital 137 ↔ occipital 137 [sensorimotor ↔ sensorimotor]
# 9) angular gyrus 124 ↔ angular gyrus 124 [sensorimotor ↔ sensorimotor]
# 10) mPFC 4 ↔ mPFC 4 [sensorimotor ↔ sensorimotor]

# Top contributing connectivities:
# 1) sup parietal 86 ↔ sup parietal 86 [occipital ↔ occipital]
# 2) vFC 40 ↔ vFC 40 [cingulo-opercular ↔ cingulo-opercular]
# 3) IPS 134 ↔ IPS 134 [sensorimotor ↔ sensorimotor]
# 4) vlPFC 12 ↔ vlPFC 12 [default ↔ default]
# 5) occipital 137 ↔ occipital 137 [sensorimotor ↔ sensorimotor]
# 6) post occipital 153 ↔ post occipital 153 [default ↔ default]
# 7) mid insula 56 ↔ mid insula 56 [occipital ↔ occipital]
# 8) occipital 92 ↔ occipital 92 [sensorimotor ↔ sensorimotor]
# 9) angular gyrus 124 ↔ angular gyrus 124 [sensorimotor ↔ sensorimotor]
# 10) mPFC 4 ↔ mPFC 4 [sensorimotor ↔ sensorimotor]

In [None]:
import matplotlib.patches as mpatches

# Network of every region (row of shap2d), as a one-column frame.
_region_frame = shap2d.index.to_frame(name='network')
network_names = _region_frame.apply(
  lambda region_col: get_network_name(region_col, DATASET_NAME))

# One 'Set1' colour per distinct network.
_n_networks = network_names.nunique()['network']
palt = dict(zip(network_names['network'].unique(),
                sns.color_palette('Set1', _n_networks)))


def _color_and_network(row):
  """Return (colour, network) for one row of `network_names`."""
  network = row['network']
  return pd.Series((palt[network], network))


# Per-region colour table used for the heatmap's row annotation.
row_colors = network_names.apply(_color_and_network, axis=1)
row_colors.rename(columns={0:'color', 1:'network'}, inplace=True)
row_colors.index.name = 'region'

# One legend patch per network (first occurrence of each network).
_legend_rows = row_colors.drop_duplicates('network')
row_colors_legend = [
  mpatches.Patch(color=color, label=network)
  for color, network in zip(_legend_rows['color'], _legend_rows['network'])]

# Clustered heatmap of the symmetric SHAP matrix; the dendrograms are shrunk to
# (almost) nothing so only the reordered heatmap and colour bar remain visible.
g = sns.clustermap(
  shap2d,
  cmap='Blues',
  robust=True,
  figsize=(10,10),
  dendrogram_ratio=0.0001,
  cbar_pos=(1.1, .79, 0.01, 0.2),
  row_colors=row_colors[['color']])

legend2=g.ax_heatmap.legend(
  handles=row_colors_legend,
  loc='center left',
  bbox_to_anchor=(1.25,0.65),
  frameon=True)

_title = ('Clustered SHAP values\n'
          'Notes: Color bar shows the brain networks. '
          'Only a subset of labels are shown.')
plt.suptitle(_title, x=0.02, y=1.02, ha='left')

plt.show()

## interactive heatmap (but not clustered)
# fig = px.imshow(shap2d, aspect='auto', height=800)
# fig.show()

In [None]:

from nilearn import plotting

# Resolve the spatial coordinates of the atlas regions for the brain plots below.
if 'dosenbach2010' in DATASET_NAME:
  atlas = nilean_datasets.fetch_coords_dosenbach_2010(legacy_format=False)
  atlas_coordinates = atlas['rois'].values
elif 'difumo' in DATASET_NAME:
  atlas = nilean_datasets.fetch_atlas_difumo(64, 2, legacy_format=False)
  # DiFuMo is a probabilistic atlas, so derive one cut coordinate per map.
  atlas_coordinates = plotting.find_probabilistic_atlas_cut_coords(maps_img=atlas.maps)
  labels = atlas.labels.set_index('difumo_names')
else:
  # Fail here rather than letting the plotting cells hit an undefined (or stale,
  # from a previous run) `atlas_coordinates`.
  raise ValueError(
    f'No atlas coordinates available for dataset {DATASET_NAME!r}; '
    "expected a name containing 'dosenbach2010' or 'difumo'.")

In [None]:

# METHOD 1: aggregate all shaps for each node
# node_strength = np.sum(shap2d.values, axis=1).reshape(-1, 1)
# node_strength = preprocessing.StandardScaler().fit_transform(shap2d.values)
# node_strength = node_strength.sum(axis=1) * 4

# METHOD 2: just use the node's self edge strength
# NOTE: the diagonal of shap2d is only populated when TRIU_K == 0; with a
# positive offset every node size computed here would be 0.
node_strength = np.diag(shap2d) / 2

# Connectome plot: edges are the aggregated |SHAP| values, nodes are coloured by
# network and sized by `node_strength`; edges are thresholded with '95%'.
# NOTE(review): the title mentions aggregated node values, but METHOD 2 above
# uses only each node's self edge.
plotting.plot_connectome(
  shap2d, atlas_coordinates,
  node_color=row_colors['color'],
  colorbar=True,
  node_size=node_strength,
  title='SHAP values of the edges and aggregated node values.',
  edge_threshold='95%',)

# Network colour legend, reusing the patches built for the clustered heatmap.
plt.gca().legend(
  loc='center left',
  bbox_to_anchor=(2,0.5),
  handles=row_colors_legend,
  frameon=True)

plt.show()

In [None]:
# Node strength = each region's self-edge value (diagonal of shap2d).
# NOTE: this rebinds `node_strength` from the previous cell (which divided by 2).
node_strength = np.diag(shap2d)

# Marker plot: the same values drive both the marker colour and (scaled by 1/4)
# the marker size.
plotting.plot_markers(
    node_strength,
    atlas_coordinates,
    node_size=node_strength / 4,
    title='Node strength (SHAP values)',
    node_cmap='Blues'
)

plt.show()

In [None]:
# Interactive connectome view of the SHAP matrix. Nodes are coloured by network
# and sized by the diagonal of shap2d (scaled by 1/20); edge_threshold is given
# as the number 10 rather than a percentile string as in the static plot.
view = plotting.view_connectome(
    shap2d.values, atlas_coordinates,
    node_color=row_colors['color'],
    node_size=np.diag(shap2d)/20, edge_threshold=10,
    colorbar_fontsize=12,
    title=f'SHAP values ({DATASET_NAME})')

# Last expression of the cell, so the view is rendered as the cell output.
view