# Correlation Plots

Code and results for several correlation-based analyses, including:
- Single model correlation plots (within and across subject)
- Model comparison via correlation
- Subject identification accuracy

In [None]:
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')

import os

import numpy as np
import nibabel as nib
import nilearn as nil
import neuromaps as nm
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pickle

import sys
sys.path.insert(0, '..')
from utils.utilities import CONTRASTS, GROUP_CONTRAST_IDS, plot_corr_matrices_across_contrasts, compute_corr_coeff, scale

# Shared inputs used by the cells below: data root, the array loaded from
# glasser_medial_wall_mask.npy, and the subject list.
base_data_dir = '../../data/'
mask = np.load('../data/glasser_medial_wall_mask.npy')

# Positions into CONTRASTS of the subset this notebook calls the "paper subset".
paper_subset_indicies = np.array([2,3,4,6,7,9,10,12,13,14,15,16,18,19,20,21,23,24,25,26,31,38,44,45])
paper_subset_contrasts = [CONTRASTS[position] for position in paper_subset_indicies]

test_subj_ids = np.genfromtxt("../data/MICCAI2020/HCP_test_retest_subj_ids.csv", dtype='<U13')

# Load one contrast array per subject (in test_subj_ids order) and stack them
# along a new leading axis.
test_contrasts = np.asarray([
    np.load(os.path.join(base_data_dir, "test_contrasts", "%s_joint_LR_task_contrasts.npy" % subject_id))
    for subject_id in test_subj_ids
])
test_contrasts.shape

In [None]:
def plot_model_comparison_correlation(pred_by_model, colors=[ '#f99154', '#63bfa6', '#358cbb', '#3aac11', '#a89154','#e0d57e'], metric='Correlation with Groundtruth', indicies=np.arange(len(CONTRASTS)), legend_size=24, ymax=0.901, ymin=-0.1):
    """Draw a box + strip plot of a per-subject correlation metric per contrast and model.

    Args:
        pred_by_model: Mapping of model label -> sequence indexed by contrast
            position. Each entry is a 2-D array read as ``m[j, j]`` (subject j
            against itself) and ``m[j, k]`` (subject j against subject k).
        colors: Colours for the models; only the first ``len(pred_by_model)``
            are used. The default list is never mutated (it is sliced).
        metric: Column plotted on the y axis. The table built here has the
            numeric columns "Correlation with Groundtruth",
            "Correlation Difference" and "Rank".
        indicies: Positions into ``CONTRASTS`` to include.
        legend_size: Legend font size.
        ymax, ymin: Y-axis limits; ticks are placed every 0.2 from ``ymin``.

    Reads the module-level ``test_subj_ids`` for the subject count and labels,
    sets the global seaborn palette, and shows the figure. Returns None.
    """
    model_names = list(pred_by_model.keys())
    sns.set_palette(sns.color_palette(colors[:len(model_names)]))

    # Collect plain rows first and build the DataFrame once; enlarging a frame
    # row by row through df.loc re-allocates on every insert.
    records = []
    for model in model_names:
        all_subj_contrast_corr = pred_by_model[model]
        for i in indicies:
            task, _cope_id, contrast_label = CONTRASTS[i]
            key = "%s %s" % (task, contrast_label)
            contrast_corr = all_subj_contrast_corr[i]

            for j, subj_id in enumerate(test_subj_ids):
                corr_row = contrast_corr[j, :]
                self_corr = corr_row[j]
                mean_other_corr = np.mean(np.concatenate((corr_row[:j], corr_row[j + 1:])))

                # 1-based position of the subject's own entry when the row is
                # ordered from highest to lowest correlation.
                descending = np.flip(np.argsort(corr_row))
                rank = int(np.flatnonzero(descending == j)[0]) + 1

                records.append([model, key, subj_id, self_corr, self_corr - mean_other_corr, rank])

    df = pd.DataFrame(records, columns=["Model", "Task Contrast", "Subject", "Correlation with Groundtruth", "Correlation Difference", "Rank"])

    fig, ax = plt.subplots(1, 1, figsize=(36, 10))
    sns.boxplot(x="Task Contrast",
                y=metric, hue="Model",
                data=df, ax=ax,
                hue_order=model_names)
    # Individual subjects on top of the boxes; legend suppressed so the model
    # legend is not duplicated.
    sns.stripplot(x="Task Contrast",
                  y=metric, hue="Model",
                  data=df, ax=ax,
                  hue_order=model_names,
                  dodge=True, legend=False)
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    ax.legend(frameon=False, ncol=len(model_names), loc='upper center', bbox_to_anchor=(0.5, 1.1), fontsize=legend_size)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_ylim(ymin, ymax)
    ax.set_yticks(np.arange(ymin, ymax, 0.2))
    ax.set_xlim(-1, len(indicies))
    ax.tick_params(direction="in", labelsize=24)
    ax.xaxis.get_label().set_fontsize(40)
    ax.yaxis.get_label().set_fontsize(40)
    ax.tick_params(length=10)
    plt.show()

def plot_model_comparison_correlation_w_performance(pred_by_model, performance, contrasts, colors=[ '#f99154', '#63bfa6', '#358cbb', '#3aac11', '#a89154','#e0d57e'], metric='Correlation with Groundtruth', indicies=np.arange(len(CONTRASTS)), legend_size=24):
    """Box plot of a correlation metric per contrast/model, with subjects coloured by performance.

    Args:
        pred_by_model: Mapping of model label -> sequence indexed by contrast
            position. Each entry is a 2-D array read as ``m[j, j]`` (subject j
            against itself) and ``m[j, k]`` (subject j against subject k).
        performance: Indexed as ``performance[i][j]`` (contrast position i,
            subject j). A subject is labelled 'Above' when its value exceeds
            the mean of ``performance[i]``, otherwise 'Below'.
        contrasts: Sequence of ``(task, cope_id, contrast_label)`` triples,
            indexed by the values in ``indicies``.
        colors: Colours for the models; only the first ``len(pred_by_model)``
            are used.
        metric: Numeric column plotted on the y axis
            ("Correlation with Groundtruth",
            "Correlation Increase of Self vs Mean Others" or "Rank").
        indicies: Contrast positions to include.
        legend_size: Legend font size.

    Reads the module-level ``test_subj_ids``, sets the global seaborn palette,
    and shows the figure. Returns None.
    """
    model_names = list(pred_by_model.keys())
    sns.set_palette(sns.color_palette(colors[:len(model_names)]))

    records = []
    for model in model_names:
        all_subj_contrast_corr = pred_by_model[model]
        for i in indicies:
            task, _cope_id, contrast_label = contrasts[i]
            key = "%s %s" % (task, contrast_label)

            contrast_corr = all_subj_contrast_corr[i]
            task_performance = performance[i]
            # Threshold is the same for every subject of this contrast.
            mean_performance = np.mean(task_performance)

            for j, subj_id in enumerate(test_subj_ids):
                corr_row = contrast_corr[j, :]
                self_corr = corr_row[j]
                mean_other_corr = np.mean(np.concatenate((corr_row[:j], corr_row[j + 1:])))

                # 1-based position of the subject's own entry when the row is
                # ordered from highest to lowest correlation.
                descending = np.flip(np.argsort(corr_row))
                rank = int(np.flatnonzero(descending == j)[0]) + 1

                subj_perf = 'Above' if task_performance[j] > mean_performance else 'Below'
                records.append([model, key, subj_id, self_corr, self_corr - mean_other_corr, subj_perf, rank])

    df = pd.DataFrame(records, columns=["Model", "Task Contrast", "Subject", "Correlation with Groundtruth", "Correlation Increase of Self vs Mean Others", "Performance", "Rank"])

    fig, ax = plt.subplots(1, 1, figsize=(36, 10))
    sns.boxplot(x="Task Contrast",
                y=metric, hue="Model",
                data=df, ax=ax,
                hue_order=model_names)
    # Fix: the points must share the numeric metric axis with the boxes and be
    # coloured by the Above/Below label. The previous call used
    # y="Performance" with hue="Model", so hue_order=['Above', 'Below'] matched
    # no hue level and category strings were sent to the numeric axis.
    # Two discrete colours are taken from "vlag" because the hue is categorical.
    sns.stripplot(x="Task Contrast",
                  y=metric, hue="Performance", hue_order=['Above', 'Below'],
                  data=df, ax=ax, palette=sns.color_palette("vlag", n_colors=2),
                  dodge=True, legend=True)
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    ax.legend(frameon=False, ncol=len(model_names), loc='upper center', bbox_to_anchor=(0.5, 1.1), fontsize=legend_size)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_ylim(-0.1, 0.901)
    ax.set_yticks(np.arange(-0.1, 0.901, 0.2))
    ax.set_xlim(-1, len(indicies))
    ax.tick_params(direction="in", labelsize=24)
    ax.xaxis.get_label().set_fontsize(40)
    ax.yaxis.get_label().set_fontsize(40)
    ax.tick_params(length=10)
    plt.show()

In [None]:
def compute_subj_contrast_corr(pred, ref, contrasts, contrast_ids, mask):
    """Correlate ``ref`` against ``pred`` per contrast, separately for two interleaved halves.

    Along axis 1, even positions are used for the "lh" arrays and odd
    positions for the "rh" arrays (hemisphere meaning is taken from the
    variable names). The last axis is filtered with ``mask[0, :]`` and
    ``mask[1, :]`` respectively.

    Args:
        pred, ref: 3-D arrays indexed as ``[:, 2*c or 2*c+1, :]``.
        contrasts: Only its length is used (number of contrasts to process).
        contrast_ids: Unused; kept for call compatibility.
        mask: 2-row index/boolean array applied to the last axis.

    Returns:
        Three lists with one entry per contrast: the "lh" results of
        ``compute_corr_coeff(ref, pred)``, the "rh" results, and their
        element-wise average.
    """
    lh_selector = mask[0, :]
    rh_selector = mask[1, :]

    lh_pred_all = pred[:, ::2, lh_selector]
    rh_pred_all = pred[:, 1::2, rh_selector]
    lh_ref_all = ref[:, ::2, lh_selector]
    rh_ref_all = ref[:, 1::2, rh_selector]

    lh_results, rh_results, mean_results = [], [], []
    for c in range(len(contrasts)):
        left = compute_corr_coeff(lh_ref_all[:, c, :], lh_pred_all[:, c, :])
        right = compute_corr_coeff(rh_ref_all[:, c, :], rh_pred_all[:, c, :])

        lh_results.append(left)
        rh_results.append(right)
        mean_results.append((left + right) / 2)

    return lh_results, rh_results, mean_results



## Retest Contrasts

In [None]:
# Retest contrasts: correlate the held-out test contrasts with the arrays
# stored under retest_contrasts/ and keep the per-contrast correlation
# matrices in `retest`.
# NOTE(review): an almost identical cell further down rebuilds `retest` from
# scratch; this one is needed only because the MSE comparison cell reads
# `retest` before that later cell runs.
retest = {}
for ic in [1]:  # single pass; `ic` is not used inside the loop
   print('--------------------------------------------------')
   label = f'Retest Contrasts'
   print(label)
   mse_brainsurf_path = f"../../data/retest_contrasts/contrasts/"

   # One array per subject, stacked in test_subj_ids order.
   multisample_brainsurfcnn_mse_pred = []
   for i in range(len(test_subj_ids)):
      subj = test_subj_ids[i]
      pred_file = os.path.join(mse_brainsurf_path, "%s_joint_LR_task_contrasts.npy" % subj)
      pred = np.load(pred_file)
      multisample_brainsurfcnn_mse_pred.append(pred)

   brainsurfcnn_mse_pred = np.asarray(multisample_brainsurfcnn_mse_pred)
   
   # Only the third return value (average of the two per-half results) is kept.
   _, _, test_brainsurfcnn_mse_corr = compute_subj_contrast_corr(test_contrasts, brainsurfcnn_mse_pred, CONTRASTS, GROUP_CONTRAST_IDS, mask)

   paper_subset_mse_corr = [test_brainsurfcnn_mse_corr[idx] for idx in paper_subset_indicies]

   # Per-contrast identification accuracy on the paper subset: a row counts as
   # correct when its maximum lies on the diagonal.
   # NOTE(review): `accs` is computed here but never printed or stored.
   total = 0
   correct = 0
   accs = []
   for matrix, contrast in zip(paper_subset_mse_corr, paper_subset_contrasts):
      matrix = scale(matrix, axis=0)  # scale comes from utils.utilities; presumably a normalisation — confirm
      matrix = matrix.T
      for i, row in enumerate(matrix):
         total += 1
         correct += int(np.argmax(row) == i)
      accs.append(correct/total * 100.0)
      correct = 0
      total = 0

   retest[label] = test_brainsurfcnn_mse_corr   

## BrainSurf CNN - MSE

In [None]:
# BrainSurfCNN (MSE): for each IC count, load the saved test-subject
# predictions, correlate them with the test contrasts, plot a few correlation
# matrices and print identification-accuracy statistics.
brainsurf_mse_models = {}
for ic in [15, 25, 50, 100]:
   print('--------------------------------------------------')
   label = f'BrainSurfCNN: {ic} ICS - MSE'
   print(f'BrainSurfCNN: {ic} ICS - MSE')
   mse_brainsurf_path = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/hcp_{ic}_sample8_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"

   # One prediction array per subject, stacked in test_subj_ids order.
   multisample_brainsurfcnn_mse_pred = []
   for i in range(len(test_subj_ids)):
      subj = test_subj_ids[i]
      pred_file = os.path.join(mse_brainsurf_path, "%s_pred.npy" % subj)
      pred = np.load(pred_file)
      multisample_brainsurfcnn_mse_pred.append(pred)

   multisample_brainsurfcnn_mse_pred = np.asarray(multisample_brainsurfcnn_mse_pred)
   # Average over axis 1 — presumably the multiple samples per subject (per the variable name); confirm.
   brainsurfcnn_mse_pred = np.mean(multisample_brainsurfcnn_mse_pred, 1)

   # NOTE(review): test_contrasts is passed in the `pred` position and the model
   # output in the `ref` position of compute_subj_contrast_corr — confirm this
   # orientation is intended.
   _, _, test_brainsurfcnn_mse_corr = compute_subj_contrast_corr(test_contrasts, brainsurfcnn_mse_pred, CONTRASTS, GROUP_CONTRAST_IDS, mask)

   plot_corr_matrices_across_contrasts(test_brainsurfcnn_mse_corr[0:3], CONTRASTS[0:3], vmin=0.35 , vmax=0.9)
   plt.show()
   plot_corr_matrices_across_contrasts([test_brainsurfcnn_mse_corr[7], test_brainsurfcnn_mse_corr[4], test_brainsurfcnn_mse_corr[12]], [CONTRASTS[7], CONTRASTS[4], CONTRASTS[12]], vmin=0.35 , vmax=0.9)
   plt.show()

   # Computed but not used below; the accuracy loop runs over all contrasts.
   paper_subset_mse_corr = [test_brainsurfcnn_mse_corr[idx] for idx in paper_subset_indicies]

   # Per-contrast identification accuracy: a row of the scaled, transposed
   # matrix is correct when its maximum lies on the diagonal.
   total = 0
   correct = 0
   accs = []
   for matrix, contrast in zip(test_brainsurfcnn_mse_corr, CONTRASTS):
      matrix = scale(matrix, axis=0)
      matrix = matrix.T
      for i, row in enumerate(matrix):
         total += 1
         correct += int(np.argmax(row) == i)
      accs.append(correct/total * 100.0)
      # Reset the counters so the next contrast starts from zero.
      correct = 0
      total = 0

   brainsurf_mse_models[label] = test_brainsurfcnn_mse_corr   
   print('Max ACC:', CONTRASTS[np.argmax(accs)], np.max(accs))
   print('Min ACC:', CONTRASTS[np.argmin(accs)], np.min(accs))
   print('Avg ACC:',np.mean(accs))

In [None]:
# Compare the BrainSurfCNN MSE variants across IC counts.
# NOTE(review): plot_model_comparison_accuracy is neither defined in the visible
# cells nor imported from utils.utilities above — confirm where it comes from.
plot_model_comparison_correlation(pred_by_model=brainsurf_mse_models)
plot_model_comparison_accuracy(pred_by_model=brainsurf_mse_models, indicies=paper_subset_indicies)
plot_model_comparison_accuracy(pred_by_model=brainsurf_mse_models)

## Bagged 5-Fold Cross Validation - MSE
25 ICs

In [None]:
# Bagged cross-validation (25 ICs): iterations 0-4 evaluate the individual
# folds; iteration 5 evaluates the element-wise mean of the five folds'
# predictions ("Bagged").
brainsurf_cv_models = {}
ic = 25
all_fold_predictions = []
for fold in range(6):
   print('--------------------------------------------------')
   if fold < 5:
      label = f'BaggedSurfCNN: {ic} ICS - Fold {fold+1}'
      print(label)
      cv_brainsurf_path = f"../../brainsurf_model/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/kfold_{ic}_sample8/predict_on_test_subj_{fold}/best_corr/"

      # One prediction array per subject, stacked in test_subj_ids order.
      multisample_brainsurfcnn_ft_pred = []
      for i in range(len(test_subj_ids)):
         subj = test_subj_ids[i]
         pred_file = os.path.join(cv_brainsurf_path, "%s_pred.npy" % subj)
         pred = np.load(pred_file)
         multisample_brainsurfcnn_ft_pred.append(pred)
         
      multisample_brainsurfcnn_ft_pred = np.asarray(multisample_brainsurfcnn_ft_pred)
      # Keep this fold's stack for the bagged average computed on the last pass.
      all_fold_predictions.append(multisample_brainsurfcnn_ft_pred)

   else:
      label = f'BaggedSurfCNN: {ic} ICS - Bagged'
      print(label)
      # Average across the fold axis (axis 0 of the list of five stacks).
      multisample_brainsurfcnn_ft_pred = np.mean(all_fold_predictions, 0)

   
   # Average over axis 1 — presumably the multiple samples per subject; confirm.
   brainsurfcnn_mse_pred = np.mean(multisample_brainsurfcnn_ft_pred, 1)

   _, _, test_brainsurfcnn_mse_corr = compute_subj_contrast_corr(test_contrasts, brainsurfcnn_mse_pred, CONTRASTS, GROUP_CONTRAST_IDS, mask)

   plot_corr_matrices_across_contrasts(test_brainsurfcnn_mse_corr[0:3], CONTRASTS[0:3], vmin=0.35 , vmax=0.9)
   plt.show()
   plot_corr_matrices_across_contrasts([test_brainsurfcnn_mse_corr[7], test_brainsurfcnn_mse_corr[4], test_brainsurfcnn_mse_corr[12]], [CONTRASTS[7], CONTRASTS[4], CONTRASTS[12]], vmin=0.35 , vmax=0.9)
   plt.show()

   # Computed but not used below; the accuracy loop runs over all contrasts.
   paper_subset_mse_corr = [test_brainsurfcnn_mse_corr[idx] for idx in paper_subset_indicies]

   # Per-contrast identification accuracy: a row of the scaled, transposed
   # matrix is correct when its maximum lies on the diagonal.
   total = 0
   correct = 0
   accs = []
   for matrix, contrast in zip(test_brainsurfcnn_mse_corr, CONTRASTS):
      matrix = scale(matrix, axis=0)
      matrix = matrix.T
      for i, row in enumerate(matrix):
         total += 1
         correct += int(np.argmax(row) == i)
      accs.append(correct/total * 100.0)
      # Reset the counters so the next contrast starts from zero.
      correct = 0
      total = 0

   brainsurf_cv_models[label] = test_brainsurfcnn_mse_corr  
   
   print('Max ACC:', CONTRASTS[np.argmax(accs)], np.max(accs))
   print('Min ACC:', CONTRASTS[np.argmin(accs)], np.min(accs))
   print('Avg ACC:',np.mean(accs))

In [None]:
# Compare the individual folds and the bagged average.
# NOTE(review): plot_model_comparison_accuracy is neither defined in the visible
# cells nor imported from utils.utilities above — confirm where it comes from.
plot_model_comparison_correlation(pred_by_model=brainsurf_cv_models, legend_size=18)
plot_model_comparison_accuracy(pred_by_model=brainsurf_cv_models, indicies=paper_subset_indicies, legend_size=12)
plot_model_comparison_accuracy(pred_by_model=brainsurf_cv_models)

## BrainSurfATN - MSE

In [None]:
# BrainSurfATN (MSE): for each IC count, load the saved test-subject
# predictions, correlate them with the test contrasts, plot a few correlation
# matrices and print identification-accuracy statistics.
brainsurfatn_mse_models = {}
for ic in [15, 25, 50, 100]:
   print('--------------------------------------------------')
   label = f'BrainSurfATN: {ic} ICS - MSE'
   print(label)
   mse_brainsurf_path = f"../../brainsurf_model/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/attn_{ic}_sample8_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"

   # One prediction array per subject, stacked in test_subj_ids order.
   multisample_brainsurfcnn_mse_pred = []
   for i in range(len(test_subj_ids)):
      subj = test_subj_ids[i]
      pred_file = os.path.join(mse_brainsurf_path, "%s_pred.npy" % subj)
      pred = np.load(pred_file)
      multisample_brainsurfcnn_mse_pred.append(pred)

   multisample_brainsurfcnn_mse_pred = np.asarray(multisample_brainsurfcnn_mse_pred)
   # Average over axis 1 — presumably the multiple samples per subject; confirm.
   brainsurfcnn_mse_pred = np.mean(multisample_brainsurfcnn_mse_pred, 1)

   _, _, test_brainsurfcnn_mse_corr = compute_subj_contrast_corr(test_contrasts, brainsurfcnn_mse_pred, CONTRASTS, GROUP_CONTRAST_IDS, mask)

   plot_corr_matrices_across_contrasts(test_brainsurfcnn_mse_corr[0:3], CONTRASTS[0:3], vmin=0.35 , vmax=0.9)
   plt.show()
   plot_corr_matrices_across_contrasts([test_brainsurfcnn_mse_corr[7], test_brainsurfcnn_mse_corr[4], test_brainsurfcnn_mse_corr[12]], [CONTRASTS[7], CONTRASTS[4], CONTRASTS[12]], vmin=0.35 , vmax=0.9)
   plt.show()

   # Computed but not used below; the accuracy loop runs over all contrasts.
   paper_subset_mse_corr = [test_brainsurfcnn_mse_corr[idx] for idx in paper_subset_indicies]

   # Per-contrast identification accuracy: a row of the scaled, transposed
   # matrix is correct when its maximum lies on the diagonal.
   total = 0
   correct = 0
   accs = []
   for matrix, contrast in zip(test_brainsurfcnn_mse_corr, CONTRASTS):
      matrix = scale(matrix, axis=0)
      matrix = matrix.T
      for i, row in enumerate(matrix):
         total += 1
         correct += int(np.argmax(row) == i)
      accs.append(correct/total * 100.0)
      # Reset the counters so the next contrast starts from zero.
      correct = 0
      total = 0

   brainsurfatn_mse_models[label] = test_brainsurfcnn_mse_corr   
   print('Max ACC:', CONTRASTS[np.argmax(accs)], np.max(accs))
   print('Min ACC:', CONTRASTS[np.argmin(accs)], np.min(accs))
   print('Avg ACC:',np.mean(accs))

In [None]:
# Compare the BrainSurfATN MSE variants across IC counts.
# NOTE(review): plot_model_comparison_accuracy is neither defined in the visible
# cells nor imported from utils.utilities above — confirm where it comes from.
plot_model_comparison_correlation(pred_by_model=brainsurfatn_mse_models, legend_size=18)
plot_model_comparison_accuracy(pred_by_model=brainsurfatn_mse_models, indicies=paper_subset_indicies, legend_size=12)
plot_model_comparison_accuracy(pred_by_model=brainsurfatn_mse_models)

## BrainSERF - MSE

In [None]:
# BrainSERF (MSE): for each IC count, load the saved test-subject predictions,
# correlate them with the test contrasts, plot a few correlation matrices and
# print identification-accuracy statistics.
brainserf_mse_models = {}
for ic in [15, 25, 50]:
   print('--------------------------------------------------')
   label = f'BrainSERF: {ic} ICS - MSE'
   print(f'BrainSERF: {ic} ICS - MSE')
   mse_brainsurf_path = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/se_attn_{ic}_sample8_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"

   # One prediction array per subject, stacked in test_subj_ids order.
   multisample_brainsurfcnn_mse_pred = []
   for i in range(len(test_subj_ids)):
      subj = test_subj_ids[i]
      pred_file = os.path.join(mse_brainsurf_path, "%s_pred.npy" % subj)
      pred = np.load(pred_file)
      multisample_brainsurfcnn_mse_pred.append(pred)

   multisample_brainsurfcnn_mse_pred = np.asarray(multisample_brainsurfcnn_mse_pred)
   # Average over axis 1 — presumably the multiple samples per subject; confirm.
   brainsurfcnn_mse_pred = np.mean(multisample_brainsurfcnn_mse_pred, 1)

   _, _, test_brainsurfcnn_mse_corr = compute_subj_contrast_corr(test_contrasts, brainsurfcnn_mse_pred, CONTRASTS, GROUP_CONTRAST_IDS, mask)

   plot_corr_matrices_across_contrasts(test_brainsurfcnn_mse_corr[0:3], CONTRASTS[0:3], vmin=0.35 , vmax=0.9)
   plt.show()
   plot_corr_matrices_across_contrasts([test_brainsurfcnn_mse_corr[7], test_brainsurfcnn_mse_corr[4], test_brainsurfcnn_mse_corr[12]], [CONTRASTS[7], CONTRASTS[4], CONTRASTS[12]], vmin=0.35 , vmax=0.9)
   plt.show()

   # Computed but not used below; the accuracy loop runs over all contrasts.
   paper_subset_mse_corr = [test_brainsurfcnn_mse_corr[idx] for idx in paper_subset_indicies]

   # Per-contrast identification accuracy: a row of the scaled, transposed
   # matrix is correct when its maximum lies on the diagonal.
   total = 0
   correct = 0
   accs = []
   for matrix, contrast in zip(test_brainsurfcnn_mse_corr, CONTRASTS):
      matrix = scale(matrix, axis=0)
      matrix = matrix.T
      for i, row in enumerate(matrix):
         total += 1
         correct += int(np.argmax(row) == i)
      accs.append(correct/total * 100.0)
      # Reset the counters so the next contrast starts from zero.
      correct = 0
      total = 0

   brainserf_mse_models[label] = test_brainsurfcnn_mse_corr   
   print('Max ACC:', CONTRASTS[np.argmax(accs)], np.max(accs))
   print('Min ACC:', CONTRASTS[np.argmin(accs)], np.min(accs))
   print('Avg ACC:',np.mean(accs))

In [None]:
# Compare the BrainSERF MSE variants across IC counts.
# NOTE(review): plot_model_comparison_accuracy is neither defined in the visible
# cells nor imported from utils.utilities above — confirm where it comes from.
plot_model_comparison_correlation(pred_by_model=brainserf_mse_models, legend_size=18)
plot_model_comparison_accuracy(pred_by_model=brainserf_mse_models, indicies=paper_subset_indicies, legend_size=12)
plot_model_comparison_accuracy(pred_by_model=brainserf_mse_models)

## BrainSurfGNN - MSE 

In [None]:
# BrainSurfGNN (MSE): for each IC count, load the saved test-subject
# predictions, correlate them with the test contrasts, plot a few correlation
# matrices and print identification-accuracy statistics.
brainsurfgnn_mse_models = {}
for ic in [15, 25, 50, 100]:
   print('--------------------------------------------------')
   label = f'BrainSurfGNN: {ic} ICS - MSE'
   print(f'BrainSurfGNN: {ic} ICS - MSE')
   mse_brainsurf_path = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/gnn_mse_larger_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"

   # One prediction array per subject, stacked in test_subj_ids order.
   multisample_brainsurfcnn_mse_pred = []
   for i in range(len(test_subj_ids)):
      subj = test_subj_ids[i]
      pred_file = os.path.join(mse_brainsurf_path, "%s_pred.npy" % subj)
      pred = np.load(pred_file)
      multisample_brainsurfcnn_mse_pred.append(pred)

   multisample_brainsurfcnn_mse_pred = np.asarray(multisample_brainsurfcnn_mse_pred)
   # Average over axis 1 — presumably the multiple samples per subject; confirm.
   brainsurfcnn_mse_pred = np.mean(multisample_brainsurfcnn_mse_pred, 1)

   _, _, test_brainsurfcnn_mse_corr = compute_subj_contrast_corr(test_contrasts, brainsurfcnn_mse_pred, CONTRASTS, GROUP_CONTRAST_IDS, mask)

   plot_corr_matrices_across_contrasts(test_brainsurfcnn_mse_corr[0:3], CONTRASTS[0:3], vmin=0.35 , vmax=0.9)
   plt.show()
   plot_corr_matrices_across_contrasts([test_brainsurfcnn_mse_corr[7], test_brainsurfcnn_mse_corr[4], test_brainsurfcnn_mse_corr[12]], [CONTRASTS[7], CONTRASTS[4], CONTRASTS[12]], vmin=0.35 , vmax=0.9)
   plt.show()

   # Computed but not used below; the accuracy loop runs over all contrasts.
   paper_subset_mse_corr = [test_brainsurfcnn_mse_corr[idx] for idx in paper_subset_indicies]

   # Per-contrast identification accuracy: a row of the scaled, transposed
   # matrix is correct when its maximum lies on the diagonal.
   total = 0
   correct = 0
   accs = []
   for matrix, contrast in zip(test_brainsurfcnn_mse_corr, CONTRASTS):
      matrix = scale(matrix, axis=0)
      matrix = matrix.T
      for i, row in enumerate(matrix):
         total += 1
         correct += int(np.argmax(row) == i)
      accs.append(correct/total * 100.0)
      # Reset the counters so the next contrast starts from zero.
      correct = 0
      total = 0

   brainsurfgnn_mse_models[label] = test_brainsurfcnn_mse_corr   
   print('Max ACC:', CONTRASTS[np.argmax(accs)], np.max(accs))
   print('Min ACC:', CONTRASTS[np.argmin(accs)], np.min(accs))
   print('Avg ACC:',np.mean(accs))

## Compare all MSE models

25 ICs Only

In [None]:
# Pick one variant per architecture (plus the retest baseline) and plot them
# side by side.
# NOTE(review): the heading above says "25 ICs Only", but the keys selected here
# use 100, 25, 100, 25 and 15 ICs respectively — confirm which is intended.
mse_comparison = {}
mse_comparison['BrainSurfCNN'] = brainsurf_mse_models['BrainSurfCNN: 100 ICS - MSE']
mse_comparison['BrainSERF'] = brainserf_mse_models['BrainSERF: 25 ICS - MSE']
mse_comparison['BrainSurfGCN'] = brainsurfgnn_mse_models['BrainSurfGNN: 100 ICS - MSE']
mse_comparison['BaggedSurfCNN'] = brainsurf_cv_models['BaggedSurfCNN: 25 ICS - Bagged']
mse_comparison['BrainSurfATN'] = brainsurfatn_mse_models['BrainSurfATN: 15 ICS - MSE']
mse_comparison['Retest'] = retest['Retest Contrasts']

# NOTE(review): plot_model_comparison_accuracy is neither defined in the visible
# cells nor imported from utils.utilities above — confirm where it comes from.
plot_model_comparison_correlation(pred_by_model=mse_comparison, legend_size=18)
plot_model_comparison_accuracy(pred_by_model=mse_comparison, indicies=paper_subset_indicies, legend_size=12)
plot_model_comparison_accuracy(pred_by_model=mse_comparison)


## BrainSurfCNN - FineTuned

In [None]:
# Fine-tuned BrainSurfCNN: evaluate every IC count on its own, then the
# element-wise mean of the per-IC predictions ("All ICS").
brainsurf_ft_models = {}
predictions = []  # subject-level predictions of each IC variant, kept for the ensemble below
for ic in [15, 25, 50, 100]:
   print('--------------------------------------------------')
   label = f'BrainSurfCNN: {ic} ICS - Fine Tuned'
   print(label)
   mse_brainsurf_path = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/finetuned_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"

   # One prediction array per subject, stacked in test_subj_ids order.
   multisample_brainsurfcnn_mse_pred = []
   for subj in test_subj_ids:
      pred_file = os.path.join(mse_brainsurf_path, "%s_pred.npy" % subj)
      pred = np.load(pred_file)
      multisample_brainsurfcnn_mse_pred.append(pred)

   multisample_brainsurfcnn_mse_pred = np.asarray(multisample_brainsurfcnn_mse_pred)
   # Average over axis 1 — presumably the multiple samples per subject; confirm.
   brainsurfcnn_mse_pred = np.mean(multisample_brainsurfcnn_mse_pred, 1)
   predictions.append(brainsurfcnn_mse_pred)

   _, _, test_brainsurfcnn_mse_corr = compute_subj_contrast_corr(test_contrasts, brainsurfcnn_mse_pred, CONTRASTS, GROUP_CONTRAST_IDS, mask)

   # plot_corr_matrices_across_contrasts(test_brainsurfcnn_mse_corr[0:3], CONTRASTS[0:3], vmin=0.35 , vmax=0.9)
   # plt.show()
   plot_corr_matrices_across_contrasts([test_brainsurfcnn_mse_corr[7], test_brainsurfcnn_mse_corr[4], test_brainsurfcnn_mse_corr[12]], [CONTRASTS[7], CONTRASTS[4], CONTRASTS[12]], vmin=0.35 , vmax=0.9)
   plt.show()

   paper_subset_mse_corr = [test_brainsurfcnn_mse_corr[idx] for idx in paper_subset_indicies]

   # Per-contrast identification accuracy: a row of the scaled, transposed
   # matrix is correct when its maximum lies on the diagonal. Counters are
   # initialised per contrast so no state carries over between iterations.
   accs = []
   for matrix, _contrast in zip(test_brainsurfcnn_mse_corr, CONTRASTS):
      matrix = scale(matrix, axis=0).T
      total = 0
      correct = 0
      for i, row in enumerate(matrix):
         total += 1
         correct += int(np.argmax(row) == i)
      accs.append(correct/total * 100.0)

   brainsurf_ft_models[label] = test_brainsurfcnn_mse_corr
   print('Max ACC:', CONTRASTS[np.argmax(accs)], np.max(accs))
   print('Min ACC:', CONTRASTS[np.argmin(accs)], np.min(accs))
   print('Avg ACC:',np.mean(accs))

# Ensemble: mean of the per-IC subject-level predictions.
print('--------------------------------------------------')
label = 'BrainSurfCNN: All ICS - Fine Tuned'
print(label)
brainsurfcnn_mse_pred = np.array(predictions).mean(0)
_, _, test_brainsurfcnn_mse_corr = compute_subj_contrast_corr(test_contrasts, brainsurfcnn_mse_pred, CONTRASTS, GROUP_CONTRAST_IDS, mask)

# Same accuracy computation as above. The counters are initialised here
# explicitly instead of relying on the values left behind by the loop above.
accs = []
for matrix, _contrast in zip(test_brainsurfcnn_mse_corr, CONTRASTS):
   matrix = scale(matrix, axis=0).T
   total = 0
   correct = 0
   for i, row in enumerate(matrix):
      total += 1
      correct += int(np.argmax(row) == i)
   accs.append(correct/total * 100.0)
print('Max ACC:', CONTRASTS[np.argmax(accs)], np.max(accs))
print('Min ACC:', CONTRASTS[np.argmin(accs)], np.min(accs))
print('Avg ACC:',np.mean(accs))
brainsurf_ft_models[label] = test_brainsurfcnn_mse_corr


In [None]:
# Compare the fine-tuned BrainSurfCNN variants (per IC count and "All ICS").
# NOTE(review): plot_model_comparison_accuracy is neither defined in the visible
# cells nor imported from utils.utilities above — confirm where it comes from.
plot_model_comparison_correlation(pred_by_model=brainsurf_ft_models, legend_size=18)
plot_model_comparison_accuracy(pred_by_model=brainsurf_ft_models, indicies=paper_subset_indicies, legend_size=12)
plot_model_comparison_accuracy(pred_by_model=brainsurf_ft_models)

## BrainSERF - FineTuned

In [None]:
# Fine-tuned BrainSERF: evaluate every IC count on its own, then the
# element-wise mean of the per-IC predictions ("All ICS").
brainserf_ft_models = {}
predictions = []  # subject-level predictions of each IC variant, kept for the ensemble below
for ic in [15, 25, 50]:
   print('--------------------------------------------------')
   label = f'BrainSERF: {ic} ICS - Fine Tuned'
   print(label)
   mse_brainsurf_path = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/se_attn_finetuned_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"

   # One prediction array per subject, stacked in test_subj_ids order.
   multisample_brainsurfcnn_mse_pred = []
   for subj in test_subj_ids:
      pred_file = os.path.join(mse_brainsurf_path, "%s_pred.npy" % subj)
      pred = np.load(pred_file)
      multisample_brainsurfcnn_mse_pred.append(pred)

   multisample_brainsurfcnn_mse_pred = np.asarray(multisample_brainsurfcnn_mse_pred)
   # Average over axis 1 — presumably the multiple samples per subject; confirm.
   brainsurfcnn_mse_pred = np.mean(multisample_brainsurfcnn_mse_pred, 1)
   predictions.append(brainsurfcnn_mse_pred)

   _, _, test_brainsurfcnn_mse_corr = compute_subj_contrast_corr(test_contrasts, brainsurfcnn_mse_pred, CONTRASTS, GROUP_CONTRAST_IDS, mask)
   plot_corr_matrices_across_contrasts([test_brainsurfcnn_mse_corr[7], test_brainsurfcnn_mse_corr[4], test_brainsurfcnn_mse_corr[12]], [CONTRASTS[7], CONTRASTS[4], CONTRASTS[12]], vmin=0.35 , vmax=0.9)
   plt.show()

   paper_subset_mse_corr = [test_brainsurfcnn_mse_corr[idx] for idx in paper_subset_indicies]

   # Per-contrast identification accuracy: a row of the scaled, transposed
   # matrix is correct when its maximum lies on the diagonal. Counters are
   # initialised per contrast so no state carries over between iterations.
   accs = []
   for matrix, _contrast in zip(test_brainsurfcnn_mse_corr, CONTRASTS):
      matrix = scale(matrix, axis=0).T
      total = 0
      correct = 0
      for i, row in enumerate(matrix):
         total += 1
         correct += int(np.argmax(row) == i)
      accs.append(correct/total * 100.0)

   brainserf_ft_models[label] = test_brainsurfcnn_mse_corr
   print('Max ACC:', CONTRASTS[np.argmax(accs)], np.max(accs))
   print('Min ACC:', CONTRASTS[np.argmin(accs)], np.min(accs))
   print('Avg ACC:',np.mean(accs))

# Ensemble: mean of the per-IC subject-level predictions.
print('--------------------------------------------------')
label = 'BrainSERF: All ICS - Fine Tuned'
print(label)
brainsurfcnn_mse_pred = np.array(predictions).mean(0)
_, _, test_brainsurfcnn_mse_corr = compute_subj_contrast_corr(test_contrasts, brainsurfcnn_mse_pred, CONTRASTS, GROUP_CONTRAST_IDS, mask)

# Same accuracy computation as above. The counters are initialised here
# explicitly instead of relying on the values left behind by the loop above.
accs = []
for matrix, _contrast in zip(test_brainsurfcnn_mse_corr, CONTRASTS):
   matrix = scale(matrix, axis=0).T
   total = 0
   correct = 0
   for i, row in enumerate(matrix):
      total += 1
      correct += int(np.argmax(row) == i)
   accs.append(correct/total * 100.0)
print('Max ACC:', CONTRASTS[np.argmax(accs)], np.max(accs))
print('Min ACC:', CONTRASTS[np.argmin(accs)], np.min(accs))
print('Avg ACC:',np.mean(accs))
brainserf_ft_models[label] = test_brainsurfcnn_mse_corr

In [None]:
# Compare the fine-tuned BrainSERF variants (per IC count and "All ICS").
# NOTE(review): plot_model_comparison_accuracy is neither defined in the visible
# cells nor imported from utils.utilities above — confirm where it comes from.
plot_model_comparison_correlation(pred_by_model=brainserf_ft_models, legend_size=18)
plot_model_comparison_accuracy(pred_by_model=brainserf_ft_models, indicies=paper_subset_indicies, legend_size=12)
plot_model_comparison_accuracy(pred_by_model=brainserf_ft_models)

## BrainSurfGCN - FineTuned

In [None]:
# Fine-tuned BrainSurfGCN: evaluate every IC count on its own, then the
# element-wise mean of the per-IC predictions ("All ICS").
brainsurfgnn_ft_models = {}
predictions = []  # subject-level predictions of each IC variant, kept for the ensemble below
for ic in [15, 25, 50, 100]:
   print('--------------------------------------------------')
   label = f'BrainSurfGCN: {ic} ICS - Fine Tuned'
   print(label)
   mse_brainsurf_path = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/gnn_finetuned_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"

   # One prediction array per subject, stacked in test_subj_ids order.
   multisample_brainsurfcnn_mse_pred = []
   for subj in test_subj_ids:
      pred_file = os.path.join(mse_brainsurf_path, "%s_pred.npy" % subj)
      pred = np.load(pred_file)
      multisample_brainsurfcnn_mse_pred.append(pred)

   multisample_brainsurfcnn_mse_pred = np.asarray(multisample_brainsurfcnn_mse_pred)
   # Average over axis 1 — presumably the multiple samples per subject; confirm.
   brainsurfcnn_mse_pred = np.mean(multisample_brainsurfcnn_mse_pred, 1)
   predictions.append(brainsurfcnn_mse_pred)

   _, _, test_brainsurfcnn_mse_corr = compute_subj_contrast_corr(test_contrasts, brainsurfcnn_mse_pred, CONTRASTS, GROUP_CONTRAST_IDS, mask)

   paper_subset_mse_corr = [test_brainsurfcnn_mse_corr[idx] for idx in paper_subset_indicies]

   # Per-contrast identification accuracy: a row of the scaled, transposed
   # matrix is correct when its maximum lies on the diagonal. Counters are
   # initialised per contrast so no state carries over between iterations.
   accs = []
   for matrix, _contrast in zip(test_brainsurfcnn_mse_corr, CONTRASTS):
      matrix = scale(matrix, axis=0).T
      total = 0
      correct = 0
      for i, row in enumerate(matrix):
         total += 1
         correct += int(np.argmax(row) == i)
      accs.append(correct/total * 100.0)

   brainsurfgnn_ft_models[label] = test_brainsurfcnn_mse_corr
   print('Max ACC:', CONTRASTS[np.argmax(accs)], np.max(accs))
   print('Min ACC:', CONTRASTS[np.argmin(accs)], np.min(accs))
   print('Avg ACC:',np.mean(accs))

# Ensemble: mean of the per-IC subject-level predictions.
print('--------------------------------------------------')
label = 'BrainSurfGCN: All ICS - Fine Tuned'
print(label)
brainsurfcnn_mse_pred = np.array(predictions).mean(0)
_, _, test_brainsurfcnn_mse_corr = compute_subj_contrast_corr(test_contrasts, brainsurfcnn_mse_pred, CONTRASTS, GROUP_CONTRAST_IDS, mask)

# Same accuracy computation as above. The counters are initialised here
# explicitly instead of relying on the values left behind by the loop above.
accs = []
for matrix, _contrast in zip(test_brainsurfcnn_mse_corr, CONTRASTS):
   matrix = scale(matrix, axis=0).T
   total = 0
   correct = 0
   for i, row in enumerate(matrix):
      total += 1
      correct += int(np.argmax(row) == i)
   accs.append(correct/total * 100.0)
print('Max ACC:', CONTRASTS[np.argmax(accs)], np.max(accs))
print('Min ACC:', CONTRASTS[np.argmin(accs)], np.min(accs))
print('Avg ACC:',np.mean(accs))
brainsurfgnn_ft_models[label] = test_brainsurfcnn_mse_corr

In [None]:
# Compare the fine-tuned BrainSurfGCN variants (per IC count and "All ICS").
# NOTE(review): plot_model_comparison_accuracy is neither defined in the visible
# cells nor imported from utils.utilities above — confirm where it comes from.
plot_model_comparison_correlation(pred_by_model=brainsurfgnn_ft_models, legend_size=18)
plot_model_comparison_accuracy(pred_by_model=brainsurfgnn_ft_models, indicies=paper_subset_indicies, legend_size=12)
plot_model_comparison_accuracy(pred_by_model=brainsurfgnn_ft_models)

## Retest Contrasts

In [None]:
# Retest contrasts: correlate the held-out test contrasts with the arrays
# stored under retest_contrasts/, plot a few correlation matrices and report
# identification accuracy on the paper subset of contrasts.
retest = {}
print('--------------------------------------------------')
label = 'Retest Contrasts'
print(label)
mse_brainsurf_path = "../../data/retest_contrasts/contrasts/"

# One array per subject, stacked in test_subj_ids order.
multisample_brainsurfcnn_mse_pred = []
for subj in test_subj_ids:
   pred_file = os.path.join(mse_brainsurf_path, "%s_joint_LR_task_contrasts.npy" % subj)
   pred = np.load(pred_file)
   multisample_brainsurfcnn_mse_pred.append(pred)

brainsurfcnn_mse_pred = np.asarray(multisample_brainsurfcnn_mse_pred)

_, _, test_brainsurfcnn_mse_corr = compute_subj_contrast_corr(test_contrasts, brainsurfcnn_mse_pred, CONTRASTS, GROUP_CONTRAST_IDS, mask)

plot_corr_matrices_across_contrasts(test_brainsurfcnn_mse_corr[0:3], CONTRASTS[0:3], vmin=0.35 , vmax=0.9)
plt.show()
plot_corr_matrices_across_contrasts([test_brainsurfcnn_mse_corr[7], test_brainsurfcnn_mse_corr[4], test_brainsurfcnn_mse_corr[12]], [CONTRASTS[7], CONTRASTS[4], CONTRASTS[12]], vmin=0.35 , vmax=0.9)
plt.show()

paper_subset_mse_corr = [test_brainsurfcnn_mse_corr[idx] for idx in paper_subset_indicies]

# Per-contrast identification accuracy over the paper subset: a row of the
# scaled, transposed matrix is correct when its maximum lies on the diagonal.
accs = []
for matrix, _contrast in zip(paper_subset_mse_corr, paper_subset_contrasts):
   matrix = scale(matrix, axis=0).T
   total = 0
   correct = 0
   for i, row in enumerate(matrix):
      total += 1
      correct += int(np.argmax(row) == i)
   accs.append(correct/total * 100.0)

retest[label] = test_brainsurfcnn_mse_corr
# Fix: `accs` is aligned with paper_subset_contrasts, so the best/worst
# contrast must be looked up there; indexing the full CONTRASTS list with a
# subset position printed the wrong contrast label.
print('Max ACC:', paper_subset_contrasts[np.argmax(accs)], np.max(accs))
print('Min ACC:', paper_subset_contrasts[np.argmin(accs)], np.min(accs))
print('Avg ACC:',np.mean(accs))

## All Models - Fine Tuned

In [None]:
# Pick one fine-tuned variant per architecture (plus the retest baseline) and
# plot them side by side.
ft_comparison = {}
ft_comparison['BrainSurfCNN'] = brainsurf_ft_models['BrainSurfCNN: 25 ICS - Fine Tuned']
ft_comparison['BrainSERF (ours)'] = brainserf_ft_models['BrainSERF: 25 ICS - Fine Tuned']
ft_comparison['BrainSurfGCN (ours)'] = brainsurfgnn_ft_models['BrainSurfGCN: 50 ICS - Fine Tuned']
ft_comparison['Retest'] = retest['Retest Contrasts']


# First call plots the default metric; the second plots the
# "Correlation Difference" column with a narrower y-range.
# NOTE(review): plot_model_comparison_accuracy is neither defined in the visible
# cells nor imported from utils.utilities above — confirm where it comes from.
plot_model_comparison_correlation(pred_by_model=ft_comparison, legend_size=18)
plot_model_comparison_correlation(pred_by_model=ft_comparison, legend_size=18, metric='Correlation Difference', ymin=-0.2, ymax=0.51)
plot_model_comparison_accuracy(pred_by_model=ft_comparison, indicies=paper_subset_indicies, legend_size=12)
plot_model_comparison_accuracy(pred_by_model=ft_comparison)
