In [None]:
from neuromaps import plotting as neuroplot
import numpy as np
import os
from glob import glob
from matplotlib import colors as mcolors, pyplot as plt
from neuromaps.images import load_gifti
from neuromaps.datasets import fetch_atlas
import sys
from surfplot import Plot
from neuromaps.datasets import fetch_fslr
import seaborn as sns
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.cm import ScalarMappable
sys.path.insert(0, '..')
from utils.utilities import CONTRASTS, GROUP_CONTRAST_IDS, plot_corr_matrices_across_contrasts, compute_corr_coeff, scale


# fsLR template surfaces from neuromaps; the 'inflated' entry is unpacked into
# left/right hemisphere surfaces (later cells pass surfaces["inflated"] to surfplot).
surfaces = fetch_fslr()
lh, rh = surfaces['inflated']

In [None]:
model_numpy_dir = '../../aim3_results/HCP_feat64_s8_c15_lr0.01_seed28_epochs50/finetuned_feat64_s8_c15_lr0.01_seed28/predict_on_test_subj/best_corr/'

# Medial-wall mask (indexed per hemisphere as mask[0, :] / mask[1, :] elsewhere) and the subject IDs.
mask = np.load('../data/glasser_medial_wall_mask.npy')
test_subj_ids = np.genfromtxt("../data/MICCAI2020/HCP_test_retest_subj_ids.csv", dtype='<U13')

# Three categorical colours (ground truth / overlap / prediction) taken from a 7-step Spectral palette.
spectral_palette = sns.color_palette("Spectral", 7)
colors = [spectral_palette[idx] for idx in (0, 3, 6)]
alphas = [0.7] * 3

# RGBA versions of the same colours.
colors_alphas = [tuple(rgb[:3]) + (alpha,) for rgb, alpha in zip(colors, alphas)]

my_cmap = ListedColormap(colors)
# Bin edges centred on the integer colour-bar ticks 1, 2 and 3.
norm = BoundaryNorm([0.5, 1.5, 2.5, 3.5], my_cmap.N, clip=False)

In [None]:
def plot_cortex_overlap(data, surface, subj, contrast_index, experiment, output_path, cmap, norm, labels=('gt', 'overlap', 'pred'), dice=None):
    """Render a categorical ground-truth / overlap / prediction map on both hemispheres.

    Parameters
    ----------
    data : array-like
        Left and right hemisphere labels concatenated into one vector, typically the
        output of ``compute_cortex_overlap`` (0 = none, 0.5 = ground truth only,
        1.5 = both, 2.5 = prediction only).
    surface : pair
        (left, right) surfaces passed to surfplot's ``Plot``.
    subj, contrast_index, experiment, output_path, dice
        Accepted for call-site compatibility; they are not used and nothing is saved.
    cmap, norm
        Three-colour colormap and the matching ``BoundaryNorm`` used for the colour bar.
    labels : sequence of str
        Tick labels for the three colour-bar entries (ticks at 1, 2, 3).

    Returns
    -------
    matplotlib.figure.Figure
        The figure built by surfplot, with the horizontal categorical colour bar added.
    """
    p = Plot(surf_lh=surface[0], surf_rh=surface[1], size=(1600, 300), layout='row', zoom=1.6)
    p.add_layer(data, cmap=cmap, color_range=[0, 3])
    fig = p.build(colorbar=False)
    # The ScalarMappable is not attached to any Axes, so name the Axes to take colour-bar
    # space from explicitly; newer matplotlib raises instead of silently falling back to gca().
    cbar = fig.colorbar(ScalarMappable(norm=norm, cmap=cmap), ax=fig.gca(),
                        orientation='horizontal', shrink=10, ticks=[1, 2, 3])
    cbar.set_ticklabels(list(labels), fontsize=24)

    fig.tight_layout()
    return fig

def plot_cortex(data, surface, mask, subj, contrast_index, experiment, output_path=None, color_range=[-8, 8], layout='row', dice=None):
    """Draw a continuous two-hemisphere map (Spectral_r) with a right-hand colour bar.

    ``data`` is unpacked as ``(left, right)``. When ``mask`` is given, vertices where
    ``mask[h, :]`` is falsy are zeroed IN PLACE in those hemisphere arrays, so the
    caller's data is modified. ``layout`` selects the canvas: 'row' (1600x300),
    'col' (300x1600), anything else 1200x800. ``subj``, ``contrast_index``,
    ``experiment``, ``output_path`` and ``dice`` are accepted but unused; nothing is saved.
    """
    canvas_sizes = {'row': (1600, 300), 'col': (300, 1600)}
    plot = Plot(surf_lh=surface[0], surf_rh=surface[1],
                size=canvas_sizes.get(layout, (1200, 800)), layout=layout, zoom=1.6)

    left, right = data
    if mask is not None:
        for hemi_values, keep in ((left, mask[0, :]), (right, mask[1, :])):
            hemi_values[np.logical_not(keep)] = 0

    plot.add_layer(np.concatenate((left, right), axis=-1), cmap='Spectral_r', color_range=color_range)
    figure = plot.build(colorbar=True, cbar_kws={'location': 'right', 'shrink': 1})
    figure.tight_layout()

def subplot_cortex(data, surface, mask, ax, color_range=[-8, 8], layout='row', dice=None):
    """Render a two-hemisphere map with surfplot and draw the rendered image into ``ax``.

    NOTE(review): a later cell ("Contrast Frequency Plots") defines another
    ``subplot_cortex`` with a different signature; once that cell has run, this
    version is shadowed and only earlier calls reach it.

    Parameters
    ----------
    data : pair of arrays
        ``(left, right)`` hemisphere values. When ``mask`` is not None, vertices where
        ``mask[h, :]`` is falsy are set to 0 IN PLACE in these arrays.
    surface : pair
        Left/right surfaces for surfplot's ``Plot``.
    mask : array or None
        Indexed as ``mask[0, :]`` (left) and ``mask[1, :]`` (right).
    ax : matplotlib Axes
        Receives the rendered image; its axis decorations are turned off.
    color_range : list
        Colour limits for the ``Spectral_r`` layer.
    layout : str
        'row' (1600x300), 'col' (300x1600), anything else 1200x800.
    dice : unused

    Returns
    -------
    The ``ax`` that was drawn into.
    """
    if layout == 'row': p = Plot(surf_lh=surface[0], surf_rh=surface[1], size=(1600, 300), layout=layout, zoom=1.6)
    elif layout == 'col': p = Plot(surf_lh=surface[0], surf_rh=surface[1], size=(300, 1600), layout=layout, zoom=1.6)
    else: p = Plot(surf_lh=surface[0], surf_rh=surface[1], size=(1200, 800), layout=layout, zoom=1.6)
    lh, rh = data
    if type(mask) is not type(None):
        lh[np.logical_not(mask[0, :])] = 0
        rh[np.logical_not(mask[1, :])] = 0
    p.add_layer(np.concatenate((lh, rh), axis=-1), cmap='Spectral_r', color_range=color_range)

    p = p.render()
    p._check_offscreen()
    # NOTE(review): `scale` is not a parameter; it resolves to the notebook-level name imported
    # from utils.utilities in the first cell. Confirm it is a valid `scale` value for
    # `to_numpy` and not an unrelated helper function.
    x = p.to_numpy(transparent_bg=True, scale=scale)
    ax.imshow(x)
    ax.axis('off')
    cbar_kws={'location':'right'}
    # NOTE(review): `p` now holds the object returned by `render()`, not the `Plot`;
    # confirm that object provides the private `_add_colorbars` method.
    p._add_colorbars(**cbar_kws)
    return ax
    
def threshold_activation(hemi, threshold=0.25, abs=True):
    """Return indices of the strongest ``int(n * threshold)`` entries, strongest first.

    ``n`` is the size of the last axis of ``hemi``. Entries are ranked by magnitude
    when ``abs`` is true, otherwise by signed value.
    """
    n_keep = int(hemi.shape[-1] * threshold)
    ranking_values = np.abs(hemi) if abs else hemi
    descending_order = np.argsort(-ranking_values)
    return descending_order[:n_keep]

def compute_cortex_overlap(prediction, groundtruth, mask=None, threshold=0.10, abs=True, score=False):
    """Label vertices by membership in the top-``threshold`` fraction of two maps.

    Each hemisphere is handled separately: ``threshold_activation`` selects the
    ``int(n_vertices * threshold)`` strongest vertices (by magnitude when ``abs``)
    of the ground truth and of the prediction.

    Parameters
    ----------
    prediction, groundtruth : indexable of length 2
        ``[0]`` holds left-hemisphere values and ``[1]`` right-hemisphere values.
        The inputs are NOT modified; masking is applied to copies.
    mask : array or None
        If given, vertices where ``mask[h, :]`` is falsy are zeroed before
        thresholding and forced to 0 in the output.
    threshold : float
        Fraction of vertices per hemisphere to select.
    abs : bool
        Rank by absolute value (True) or by signed value (False).
    score : bool
        Unused; kept for call-site compatibility.

    Returns
    -------
    np.ndarray
        Left and right label arrays concatenated along the last axis:
        0 = selected in neither, 0.5 = ground truth only, 1.5 = both,
        2.5 = prediction only. Floating inputs keep their dtype; other dtypes
        yield float64 so the half-step labels are not truncated.
    """
    labelled = []
    for hemi in (0, 1):
        # Copy so masking never leaks back into the caller's arrays.
        gt = np.array(groundtruth[hemi], copy=True)
        pred = np.array(prediction[hemi], copy=True)

        medial_wall = None
        if mask is not None:
            # Zero masked vertices before thresholding so they are ranked last by magnitude.
            medial_wall = np.logical_not(mask[hemi, :])
            gt[medial_wall] = 0
            pred[medial_wall] = 0

        gt_vertices = threshold_activation(gt, threshold=threshold, abs=abs)
        pred_vertices = threshold_activation(pred, threshold=threshold, abs=abs)

        out_dtype = gt.dtype if np.issubdtype(gt.dtype, np.floating) else np.float64
        labels = np.zeros(gt.shape, dtype=out_dtype)
        labels[gt_vertices] = 0.5
        labels[pred_vertices] = 2.5
        labels[np.intersect1d(pred_vertices, gt_vertices)] = 1.5

        if medial_wall is not None:
            # Masked vertices can still be selected when few vertices are non-zero; clear them again.
            labels[medial_wall] = 0
        labelled.append(labels)

    return np.concatenate(labelled, axis=-1)

In [None]:
def make_and_save_plots(pred_dir, gt_dir, pred_end, gt_end, test_subj_ids, experiment, surfaces, mask, cmap, norm, labels=['gt', 'overlap', 'pred'], output_path='../../data/', threshold=0.1, contrast_index=9):
    """Plot, per subject, the top-``threshold`` overlap between a prediction and the ground truth.

    For every subject, ``<gt_dir>/<subj>_<gt_end>`` is loaded and rows
    ``2*contrast_index`` and ``2*contrast_index + 1`` are used.

    * With ``pred_dir``: ``<pred_dir>/<subj>_<pred_end>`` is also loaded (averaged over its
      leading axis when it has more than two dimensions). The gt / overlap / prediction
      labels from ``compute_cortex_overlap`` are drawn with ``plot_cortex_overlap``, and the
      raw prediction is drawn with ``plot_cortex``.
    * Without ``pred_dir``: only the ground truth's own selected vertices are drawn.

    ``output_path`` is the directory used to build ``<subj>_<contrast_index>.png``. That path
    is passed on to the plotting helpers; whether anything is written depends on them.

    Raises
    ------
    ValueError
        If ``gt_dir`` is empty or None; the ground truth is needed on both code paths.
    """
    if not gt_dir:
        raise ValueError("gt_dir is required: the ground-truth contrasts are loaded for every subject")

    # Each contrast occupies two consecutive rows (left and right hemisphere).
    rows = slice(2 * contrast_index, 2 * contrast_index + 2)

    for subj in test_subj_ids:
        print(subj)
        if pred_dir:
            pred_contrast = np.load(os.path.join(pred_dir, f"{subj}_{pred_end}"))
        gt_contrast = np.load(os.path.join(gt_dir, f"{subj}_{gt_end}"))
        figure_path = os.path.join(output_path, f'{subj}_{contrast_index}.png')

        if pred_dir:
            if pred_contrast.ndim > 2:
                pred_contrast = pred_contrast.mean(0)  # average across the predicted samples
            overlap = compute_cortex_overlap(pred_contrast[rows], gt_contrast[rows], mask=mask, threshold=threshold)
            plot_cortex_overlap(overlap,
                                surface=surfaces["inflated"],
                                subj=subj,
                                contrast_index=contrast_index,
                                experiment=experiment,
                                output_path=figure_path,
                                labels=labels,
                                cmap=cmap, norm=norm)
            plot_cortex(pred_contrast[rows],
                        surface=surfaces["inflated"],
                        subj=subj,
                        mask=mask,
                        contrast_index=contrast_index,
                        experiment=experiment,
                        output_path=figure_path)
        else:
            overlap = compute_cortex_overlap(gt_contrast[rows], gt_contrast[rows], mask=mask, threshold=threshold)
            # Comparing the ground truth with itself labels every selected vertex 1.5 ("overlap").
            # Dividing by 2.5 moves it to 0.6, which presumably falls in the first colour bin of
            # `cmap` on the [0, 3] range used by plot_cortex_overlap.
            overlap /= 2.5
            plot_cortex_overlap(overlap,
                                surface=surfaces["inflated"],
                                subj=subj,
                                contrast_index=contrast_index,
                                experiment=experiment,
                                output_path=figure_path,
                                labels=labels,
                                cmap=cmap, norm=norm)
        

## Plotting Params

In [None]:
# Single example subject, contrast index and top-vertex fraction used by the per-subject cells below.
test_subj_ids, contrast_index, threshold = ['917255'], 5, 0.10

In [None]:
retest_numpy_dir = '../../data/retest_contrasts/contrasts/'
test_numpy_dir = '../../data/test_contrasts/'
output_dir = '../../PaperPlots/retest/cortex'

# Retest session drawn as the "prediction" against the test session as ground truth.
make_and_save_plots(
    pred_dir=retest_numpy_dir,
    gt_dir=test_numpy_dir,
    pred_end='joint_LR_task_contrasts.npy',
    gt_end='joint_LR_task_contrasts.npy',
    test_subj_ids=test_subj_ids,
    experiment='Target',
    surfaces=surfaces,
    mask=None,
    cmap=my_cmap,
    norm=norm,
    labels=['groundtruth', 'overlap', 'prediction'],
    threshold=threshold,
    contrast_index=contrast_index,
    output_path=output_dir,
)

## Retest Plots

In [None]:
retest_numpy_dir = '../../data/retest_contrasts/contrasts/'
test_numpy_dir = '../../data/test_contrasts/'
output_dir = '../../PaperPlots/retest/cortex'

# Test-retest agreement: retest session as "prediction", test session as ground truth.
make_and_save_plots(
    pred_dir=retest_numpy_dir,
    gt_dir=test_numpy_dir,
    pred_end='joint_LR_task_contrasts.npy',
    gt_end='joint_LR_task_contrasts.npy',
    test_subj_ids=test_subj_ids,
    experiment='',
    surfaces=surfaces,
    mask=None,
    cmap=my_cmap,
    norm=norm,
    labels=['test', 'overlap', 'retest'],
    threshold=threshold,
    contrast_index=contrast_index,
    output_path=output_dir,
)

## BrainSurfCNN

In [None]:
ic = 25
brainsurf_numpy_dir = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/finetuned_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"
test_numpy_dir = '../../data/test_contrasts/'
output_dir = '../../PaperPlots/brainsurf_ft/cortex'

# Fine-tuned BrainSurfCNN predictions vs. the test-session contrasts (medial wall masked).
make_and_save_plots(
    pred_dir=brainsurf_numpy_dir,
    gt_dir=test_numpy_dir,
    pred_end='pred.npy',
    gt_end='joint_LR_task_contrasts.npy',
    test_subj_ids=test_subj_ids,
    experiment='',
    surfaces=surfaces,
    mask=mask,
    cmap=my_cmap,
    norm=norm,
    labels=['test', 'overlap', 'pred'],
    threshold=threshold,
    contrast_index=contrast_index,
    output_path=output_dir,
)

## BrainSERF

In [None]:
ic = 25
brainserf_numpy_dir = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/se_attn_finetuned_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"
test_numpy_dir = '../../data/test_contrasts/'
output_dir = '../../PaperPlots/brainserf_ft/cortex'

# Fine-tuned BrainSERF (se_attn) predictions vs. the test-session contrasts (medial wall masked).
make_and_save_plots(
    pred_dir=brainserf_numpy_dir,
    gt_dir=test_numpy_dir,
    pred_end='pred.npy',
    gt_end='joint_LR_task_contrasts.npy',
    test_subj_ids=test_subj_ids,
    experiment='',
    surfaces=surfaces,
    mask=mask,
    cmap=my_cmap,
    norm=norm,
    labels=['test', 'overlap', 'pred'],
    threshold=threshold,
    contrast_index=contrast_index,
    output_path=output_dir,
)

## BrainSurfGCN

In [None]:
ic = 50
brainsurfgcn_numpy_dir = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/gnn_finetuned_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"
test_numpy_dir = '../../data/test_contrasts/'
output_dir = '../../PaperPlots/brainsurfgcn_ft/cortex'

# Fine-tuned BrainSurfGCN (gnn) predictions vs. the test-session contrasts (medial wall masked).
make_and_save_plots(
    pred_dir=brainsurfgcn_numpy_dir,
    gt_dir=test_numpy_dir,
    pred_end='pred.npy',
    gt_end='joint_LR_task_contrasts.npy',
    test_subj_ids=test_subj_ids,
    experiment='',
    surfaces=surfaces,
    mask=mask,
    cmap=my_cmap,
    norm=norm,
    labels=['test', 'overlap', 'pred'],
    threshold=threshold,
    contrast_index=contrast_index,
    output_path=output_dir,
)

## Group Average

Test

In [None]:
test_subj_ids = np.genfromtxt("../data/MICCAI2020/HCP_test_retest_subj_ids.csv", dtype='<U13')
dir = '../../data/test_contrasts/'
end = 'joint_LR_task_contrasts.npy'
output_dir = '../../PaperPlots/groupavg'
contrasts = []
contrast_index = 4
print(CONTRASTS[contrast_index][0], CONTRASTS[contrast_index][-1])
for i, subj in enumerate(test_subj_ids):
    # Keep the two hemisphere rows of the chosen contrast for this subject.
    contrast = np.load(os.path.join(dir, f"{subj}_{end}"))[2*contrast_index: 2*contrast_index+2]
    contrasts.append(contrast)

# Group average over subjects: one row per hemisphere.
contrasts = np.mean(contrasts, axis=0)
# Zero the medial wall on the *averaged* map. The previous code zeroed only `contrast`,
# i.e. the last subject's slice, which is never used again.
contrasts[0, np.logical_not(mask[0, :])] = 0
contrasts[1, np.logical_not(mask[1, :])] = 0
plot_cortex(contrasts,
            surface=surfaces["inflated"],
            subj=subj,
            mask=mask,
            contrast_index=contrast_index,
            experiment='GroupAverage',
            color_range=[0, 1],
            layout='grid',
            output_path=os.path.join(output_dir, f'{subj}_{contrast_index}.png'))

test_ga = contrasts
    

BrainSurfCNN

In [None]:
test_subj_ids = np.genfromtxt("../data/MICCAI2020/HCP_test_retest_subj_ids.csv", dtype='<U13')
ic = 25
dir = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/finetuned_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"
end = 'pred.npy'
output_dir = '../../PaperPlots/groupavg'
contrast_index = 5

print(CONTRASTS[contrast_index][0], CONTRASTS[contrast_index][-1])

# Group average of BrainSurfCNN predictions: each subject's file is averaged over its leading
# axis, the two hemisphere rows of the contrast are kept, then subjects are averaged.
contrasts = []
for i, subj in enumerate(test_subj_ids):
    subject_pred = np.load(os.path.join(dir, f"{subj}_{end}")).mean(0)
    contrast = subject_pred[2 * contrast_index: 2 * contrast_index + 2]
    contrasts.append(contrast)
contrasts = np.mean(contrasts, axis=0)

# plot_cortex also zeroes the medial wall of `contrasts` in place.
plot_cortex(
    contrasts,
    surface=surfaces["inflated"],
    subj=subj,
    mask=mask,
    contrast_index=contrast_index,
    experiment='GroupAverage',
    color_range=[-1, 1],
    layout='grid',
    output_path=os.path.join(output_dir, f'{subj}_{contrast_index}.png'),
)

brainsurf_ga = contrasts

In [None]:
test_subj_ids = np.genfromtxt("../data/MICCAI2020/HCP_test_retest_subj_ids.csv", dtype='<U13')
ic = 25
dir = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/se_attn_finetuned_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"
end = 'pred.npy'
output_dir = '../../PaperPlots/groupavg'
contrast_index = 9

print(CONTRASTS[contrast_index][0], CONTRASTS[contrast_index][-1])

# Group average of BrainSERF predictions (sample-averaged per subject, then across subjects).
contrasts = []
for i, subj in enumerate(test_subj_ids):
    subject_pred = np.load(os.path.join(dir, f"{subj}_{end}")).mean(0)
    contrast = subject_pred[2 * contrast_index: 2 * contrast_index + 2]
    contrasts.append(contrast)
contrasts = np.mean(contrasts, axis=0)

# plot_cortex also zeroes the medial wall of `contrasts` in place.
plot_cortex(
    contrasts,
    surface=surfaces["inflated"],
    subj=subj,
    mask=mask,
    contrast_index=contrast_index,
    experiment='GroupAverage',
    color_range=[-4, 4],
    layout='grid',
    output_path=os.path.join(output_dir, f'{subj}_{contrast_index}.png'),
)

brainserf_ga = contrasts

In [None]:
test_subj_ids = np.genfromtxt("../data/MICCAI2020/HCP_test_retest_subj_ids.csv", dtype='<U13')
ic = 50
dir = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/gnn_finetuned_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"
end = 'pred.npy'
output_dir = '../../PaperPlots/groupavg'
contrast_index = 9

print(CONTRASTS[contrast_index][0], CONTRASTS[contrast_index][-1])

# Group average of BrainSurfGCN predictions; no figure is drawn for this model.
# NOTE(review): unlike brainsurf_ga / brainserf_ga, this average never goes through
# plot_cortex, so its medial wall is not zeroed. That matters for compute_dice below.
contrasts = []
for i, subj in enumerate(test_subj_ids):
    subject_pred = np.load(os.path.join(dir, f"{subj}_{end}")).mean(0)
    contrast = subject_pred[2 * contrast_index: 2 * contrast_index + 2]
    contrasts.append(contrast)
contrasts = np.mean(contrasts, axis=0)
gcn_ga = contrasts

In [None]:
def compute_dice(prediction, groundtruth, abs=True):
    """Mean Dice coefficient over the two hemispheres between the non-zero vertex sets.

    Parameters
    ----------
    prediction, groundtruth : indexable of length 2
        ``[0]`` is the left-hemisphere array and ``[1]`` the right; every non-zero
        entry counts as active. Arrays of any rank are handled via flat indices.
    abs : bool
        Unused; kept for call-site compatibility.

    Returns
    -------
    float
        Average of the left- and right-hemisphere Dice scores
        ``2 * |A & B| / (|A| + |B|)``. A hemisphere where neither input has an active
        vertex scores 1.0 (perfect agreement) instead of dividing by zero.
    """
    per_hemi = []
    for hemi in (0, 1):
        # flatnonzero always yields a 1-D index array. squeeze(argwhere(...)) collapsed to
        # 0-d for a single active vertex and made len() fail.
        gt_active = np.flatnonzero(groundtruth[hemi])
        pred_active = np.flatnonzero(prediction[hemi])
        total = len(gt_active) + len(pred_active)
        if total == 0:
            per_hemi.append(1.0)
            continue
        overlap = np.intersect1d(gt_active, pred_active)
        per_hemi.append(2 * len(overlap) / total)
    return (per_hemi[0] + per_hemi[1]) / 2

In [None]:
# Dice between the test-session group average and each model's group average.
# NOTE(review): in the cells above, test_ga uses contrast_index 4, brainsurf_ga uses 5, and
# brainserf_ga / gcn_ga use 9, so these scores compare different contrasts. Confirm intent.
for model_ga in (brainsurf_ga, brainserf_ga, gcn_ga):
    print(compute_dice(test_ga, model_ga))

## Contrast Frequency Plots

In [None]:
def plot_cnr(dirs, ends, models, contrast_index, mask=None):
    """Plot, per source, how often each vertex is among a subject's selected top vertices.

    For each (directory, file suffix) source, every subject in the notebook-level
    ``test_subj_ids`` is loaded; ``pred.npy`` files are first averaged over their
    leading axis. Rows ``2*contrast_index`` and ``2*contrast_index + 1`` are kept.
    ``compute_cortex_overlap`` of the map with itself (default threshold) marks the
    selected vertices as 1.5; these become 1 and all others 0, and the result is
    averaged over subjects.

    The first source is the reference. For every later source, the Pearson correlation
    of its frequency map with the reference is printed next to ``models[p]``. Each
    source is drawn in its own row of axes with ``subplot_cortex``.
    """
    # squeeze=False keeps `axes` indexable even when there is only one source.
    _, axes = plt.subplots(len(dirs), 1, figsize=(13, 10), squeeze=False)
    axes = axes[:, 0]

    test_activation = None
    for p, (contrast_dir, end) in enumerate(zip(dirs, ends)):
        selected = []
        for subj in test_subj_ids:
            contrast = np.load(os.path.join(contrast_dir, f"{subj}_{end}"))
            if end == 'pred.npy':
                contrast = contrast.mean(0)
            contrast = contrast[2 * contrast_index: 2 * contrast_index + 2]
            overlap = compute_cortex_overlap(prediction=contrast, groundtruth=contrast)
            # 1 where the vertex was selected ("both" label 1.5), 0 elsewhere.
            selected.append((overlap == 1.5).astype(overlap.dtype))
        frequency = np.mean(selected, axis=0)

        if p == 0:
            test_activation = frequency
        else:
            print(models[p], np.corrcoef(test_activation, frequency)[0, 1])

        # subplot_cortex may zero masked vertices in place; hand it a copy so the
        # reference map used for the correlations above is never altered.
        subplot_cortex(frequency.reshape(2, -1).copy(),
                       surface=surfaces["inflated"],
                       subj=subj,
                       mask=mask,
                       contrast_index=contrast_index,
                       experiment='Contrast',
                       color_range=[0, 1],
                       layout='row',
                       ax=axes[p],
                       title=p == 0)

    plt.tight_layout()

def get_cnr(dirs, ends, models, contrast_index, mask=None):
    """Print one LaTeX table-row fragment of frequency-map correlations for a contrast.

    For each (directory, file suffix) source, every subject in the notebook-level
    ``test_subj_ids`` is loaded (``pred.npy`` files are first averaged over their leading
    axis) and rows ``2*contrast_index`` and ``2*contrast_index + 1`` are kept. A 0/1 map
    of the vertices selected by ``compute_cortex_overlap`` (default threshold) is then
    averaged over subjects.

    The first source is the reference. For each later source, the Pearson correlation
    with the reference is printed, rounded to 3 decimals. It is followed by `` & ``, or by
    the LaTeX row terminator when ``p`` is the last position in ``models``.

    ``mask`` is accepted for call-site compatibility but not used.
    """
    # Decide "last column" by position; comparing names misfired on duplicate model names.
    last_column = len(models) - 1
    test_activation = None
    for p, (contrast_dir, end) in enumerate(zip(dirs, ends)):
        selected = []
        for subj in test_subj_ids:
            contrast = np.load(os.path.join(contrast_dir, f"{subj}_{end}"))
            if end == 'pred.npy':
                contrast = contrast.mean(0)
            contrast = contrast[2 * contrast_index: 2 * contrast_index + 2]
            overlap = compute_cortex_overlap(prediction=contrast, groundtruth=contrast)
            # 1 where the vertex was selected ("both" label 1.5), 0 elsewhere.
            selected.append((overlap == 1.5).astype(overlap.dtype))
        frequency = np.mean(selected, axis=0)

        if p == 0:
            test_activation = frequency
            continue
        corr = np.corrcoef(test_activation, frequency)[0, 1].round(3)
        if p == last_column:
            print(corr, end = ' \\\\ \n')
        else:
            print(corr, end=' & ')

def subplot_cortex(data, surface, mask, subj, contrast_index, experiment, ax, output_path=None, color_range=[-8, 8], layout='row', title=False):
    """Draw a two-hemisphere map into an existing matplotlib Axes with surfplot.

    Parameters
    ----------
    data : pair of arrays
        ``(left, right)`` hemisphere values. Masking is applied to copies, so the
        caller's arrays are left unchanged.
    surface : pair
        Left/right surfaces for surfplot's ``Plot``.
    mask : array or None
        If given, vertices where ``mask[h, :]`` is falsy are shown as 0.
    subj, experiment
        Unused; kept for call-site compatibility.
    contrast_index : int
        Row of ``CONTRASTS`` used for the optional title.
    ax : matplotlib Axes
        Target Axes passed to ``Plot.build``.
    output_path : str or None
        If truthy, the figure that owns ``ax`` is saved there.
    color_range : list
        Colour limits of the ``Spectral_r`` layer.
    layout : str
        'row' (1600x300), 'col' (300x1600), anything else 1200x800.
    title : bool
        If true, set ``ax``'s title from ``CONTRASTS[contrast_index]``.
    """
    sizes = {'row': (1600, 300), 'col': (300, 1600)}
    p = Plot(surf_lh=surface[0], surf_rh=surface[1], size=sizes.get(layout, (1200, 800)), layout=layout, zoom=1.6)

    # Copy before masking: callers (e.g. plot_cnr) keep using the array they passed in.
    lh, rh = (np.array(hemi, copy=True) for hemi in data)
    if mask is not None:
        lh[np.logical_not(mask[0, :])] = 0
        rh[np.logical_not(mask[1, :])] = 0
    p.add_layer(np.concatenate((lh, rh), axis=-1), cmap='Spectral_r', color_range=color_range)

    # NOTE(review): confirm the installed surfplot's Plot.build accepts `ax=`.
    p.build(ax=ax, colorbar=True, cbar_kws={'location': 'right'})

    if title:
        ax.set_title(f'{CONTRASTS[contrast_index][0]} {CONTRASTS[contrast_index][-1]}', fontsize=50)
    if output_path:
        # Save the figure that owns `ax`, not whichever figure pyplot considers current.
        ax.figure.savefig(output_path)

Retest

In [None]:
# Sources for the frequency-map comparison: test session (reference), retest, then three models.
# NOTE(review): the BrainSurfGCN entry points at the c25 run, while the BrainSurfGCN cells above use ic = 50.
dirs = [
    '../../data/test_contrasts/',
    '../../data/retest_contrasts/contrasts/',
    "../../aim3_results/HCP_feat64_s8_c25_lr0.01_seed28_epochs50/finetuned_feat64_s8_c25_lr0.01_seed28/predict_on_test_subj/best_corr/",
    "../../aim3_results/HCP_feat64_s8_c25_lr0.01_seed28_epochs50/se_attn_finetuned_feat64_s8_c25_lr0.01_seed28/predict_on_test_subj/best_corr/",
    "../../aim3_results/HCP_feat64_s8_c25_lr0.01_seed28_epochs50/gnn_finetuned_feat64_s8_c25_lr0.01_seed28/predict_on_test_subj/best_corr/",
]
ends = ['joint_LR_task_contrasts.npy'] * 2 + ['pred.npy'] * 3
models = ['test', 'retest', 'BrainSurfCNN', 'BrainSERF', 'BrainSurfGCN']
plot_cnr(dirs, ends, models, contrast_index=5)


In [None]:
# Same sources as the plot above; print one LaTeX table row per contrast.
dirs = [
    '../../data/test_contrasts/',
    '../../data/retest_contrasts/contrasts/',
    "../../aim3_results/HCP_feat64_s8_c25_lr0.01_seed28_epochs50/finetuned_feat64_s8_c25_lr0.01_seed28/predict_on_test_subj/best_corr/",
    "../../aim3_results/HCP_feat64_s8_c25_lr0.01_seed28_epochs50/se_attn_finetuned_feat64_s8_c25_lr0.01_seed28/predict_on_test_subj/best_corr/",
    "../../aim3_results/HCP_feat64_s8_c25_lr0.01_seed28_epochs50/gnn_finetuned_feat64_s8_c25_lr0.01_seed28/predict_on_test_subj/best_corr/",
]
ends = ['joint_LR_task_contrasts.npy'] * 2 + ['pred.npy'] * 3
models = ['test', 'retest', 'BrainSurfCNN', 'BrainSERF', 'BrainSurfGCN']
for c in range(47):
    # Row label first; get_cnr then appends the correlations and the row terminator.
    print(f'{CONTRASTS[c][0]} {CONTRASTS[c][-1]}', end=' & ')
    get_cnr(dirs, ends, models, contrast_index=c)

In [None]:
test_subj_ids = np.genfromtxt("../data/MICCAI2020/HCP_test_retest_subj_ids.csv", dtype='<U13')
test_dir = '../../data/test_contrasts/'
test_end = 'joint_LR_task_contrasts.npy'
pred_dir = '../../data/retest_contrasts/contrasts/'
ic=25
# Alternative source (BrainSERF predictions). Those files would also need `.mean(0)` before slicing:
# pred_dir = brainserf_numpy_dir = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/se_attn_finetuned_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"
pred_end = 'joint_LR_task_contrasts.npy'#'pred.npy'
output_dir = '../../PaperPlots/groupavg'
contrasts = []
contrast_index = 4
print(CONTRASTS[contrast_index][0], CONTRASTS[contrast_index][-1])
for i, subj in enumerate(test_subj_ids):
    pred_contrast = np.load(os.path.join(pred_dir, f"{subj}_{pred_end}"))[2*contrast_index: 2*contrast_index+2]
    overlap = compute_cortex_overlap(prediction=pred_contrast, groundtruth=pred_contrast)
    # 1 where the vertex was selected ("both" label 1.5), 0 elsewhere.
    overlap[overlap == 0.5] = 0
    overlap[overlap == 2.5] = 0
    overlap[overlap == 1.5] = 1
    contrasts.append(overlap)

# Per-vertex selection frequency across subjects, in [0, 1].
contrasts = np.mean(contrasts, axis=0)
# Keep only vertices selected in at least half of the subjects.
contrasts[contrasts < 0.5] = 0
print(contrasts.shape)
# The map holds frequencies in [0, 1], so the colour range must be [0, 1], as in plot_cnr.
# The previous [0, 100] squeezed every value into the bottom 1% of the colormap.
plot_cortex(contrasts.reshape(2, -1),
            surface=surfaces["inflated"],
            subj=subj,
            mask=mask,
            contrast_index=contrast_index,
            experiment='GroupAverage',
            color_range=[0, 1],
            layout='grid',
            output_path=os.path.join(output_dir, f'{subj}_{contrast_index}.png'))

# Error Analysis

In [None]:
def plot_cortex_err(data, surface, mask, subj, contrast_index, experiment, output_path=None, color_range=[-8, 8], layout='row', dice=None, title=False):
    """Draw a per-vertex error map on both hemispheres, optionally titled and saved.

    ``data`` is unpacked as ``(left, right)``. If ``mask`` is given, vertices where
    ``mask[h, :]`` is falsy are zeroed in place in those arrays. When ``title`` is true,
    the figure gets a ``"<first>-<last>"`` suptitle built from ``CONTRASTS[contrast_index]``.
    When ``output_path`` is truthy, the current pyplot figure is saved there.
    ``subj``, ``experiment`` and ``dice`` are not used.
    """
    canvas = {'row': (1600, 300), 'col': (300, 1600)}.get(layout, (1200, 800))
    p = Plot(surf_lh=surface[0], surf_rh=surface[1], size=canvas, layout=layout, zoom=1.6)

    left, right = data
    if mask is not None:
        left[np.logical_not(mask[0, :])] = 0
        right[np.logical_not(mask[1, :])] = 0
    p.add_layer(np.concatenate((left, right), axis=-1), cmap='Spectral_r', color_range=color_range)

    fig = p.build(colorbar=True, cbar_kws={'location': 'right'})
    if title:
        heading = f'{CONTRASTS[contrast_index][0]}-{CONTRASTS[contrast_index][-1]}'
        fig.suptitle(heading, fontsize=64)
    fig.tight_layout()
    if output_path:
        plt.savefig(fname=output_path)

def analyze_error_plot(pred_dir, gt_dir, pred_end, gt_end, test_subj_ids, experiment, surfaces, mask, ax=None, contrast_index=9, title=False):
    """Plot the per-vertex mean squared error between predicted and ground-truth contrasts.

    For each subject, ``<pred_dir>/<subj>_<pred_end>`` and ``<gt_dir>/<subj>_<gt_end>``
    are loaded. Predictions with more than two dimensions are averaged over their leading
    axis, and rows ``2*contrast_index`` and ``2*contrast_index + 1`` are kept. The squared
    differences are averaged over subjects and drawn with ``plot_cortex_err``
    (colour range [0, 10]).

    ``ax`` is accepted for call-site compatibility but not used.

    Raises
    ------
    ValueError
        If ``pred_dir`` or ``gt_dir`` is empty/None, or ``test_subj_ids`` is empty.
    """
    if not pred_dir or not gt_dir:
        raise ValueError("both pred_dir and gt_dir are required")
    if len(test_subj_ids) == 0:
        raise ValueError("test_subj_ids is empty; there is nothing to average")

    rows = slice(2 * contrast_index, 2 * contrast_index + 2)
    squared_errors = []
    for subj in test_subj_ids:
        pred_contrast = np.load(os.path.join(pred_dir, f"{subj}_{pred_end}"))
        gt_contrast = np.load(os.path.join(gt_dir, f"{subj}_{gt_end}"))
        if pred_contrast.ndim > 2:
            pred_contrast = pred_contrast.mean(0)  # average across the predicted samples
        squared_errors.append(np.square(gt_contrast[rows] - pred_contrast[rows]))

    mean_squared_error = np.mean(squared_errors, axis=0)
    plot_cortex_err(mean_squared_error,
                    surface=surfaces["inflated"],
                    mask=mask,
                    subj=subj,
                    contrast_index=contrast_index,
                    experiment=experiment,
                    color_range=[0, 10],
                    title=title)

def subplot_cortex_err(data, surface, mask, subj, contrast_index, experiment, ax, output_path=None, color_range=[-8, 8], layout='row', dice=None, title=False):
    """Draw a per-vertex error map into an existing Axes ``ax``.

    ``data`` is unpacked as ``(left, right)``. If ``mask`` is given, vertices where
    ``mask[h, :]`` is falsy are zeroed in place. When ``title`` is true, ``ax`` gets a
    ``"<first>-<last>"`` title from ``CONTRASTS[contrast_index]``. When ``output_path``
    is truthy, the current pyplot figure is saved. ``subj``, ``experiment`` and
    ``dice`` are not used.
    """
    if layout == 'row':
        canvas = (1600, 300)
    elif layout == 'col':
        canvas = (300, 1600)
    else:
        canvas = (1200, 800)
    p = Plot(surf_lh=surface[0], surf_rh=surface[1], size=canvas, layout=layout, zoom=1.6)

    lh, rh = data
    if mask is not None:
        for values, keep in ((lh, mask[0, :]), (rh, mask[1, :])):
            values[np.logical_not(keep)] = 0
    p.add_layer(np.concatenate((lh, rh), axis=-1), cmap='Spectral_r', color_range=color_range)

    p.build(ax=ax, colorbar=True, cbar_kws={'location': 'right'})
    if title:
        ax.set_title(f'{CONTRASTS[contrast_index][0]}-{CONTRASTS[contrast_index][-1]}', fontsize=64)
    if output_path:
        plt.savefig(fname=output_path)
    
def analyze_error_subplot(pred_dir, gt_dir, pred_end, gt_end, test_subj_ids, experiment, surfaces, mask, ax=None, contrast_index=9, title=False, aggregate='std', color_range=(0, 30)):
    """Draw the across-subject spread of squared prediction error for one contrast on ``ax``.

    For each subject, ``{pred_dir}/{subj}_{pred_end}`` and ``{gt_dir}/{subj}_{gt_end}`` are
    loaded with ``np.load``. Rows ``2*contrast_index`` and ``2*contrast_index + 1`` are kept
    (presumably the left/right-hemisphere maps of that contrast -- TODO confirm the on-disk
    layout), and the element-wise squared error is taken. The per-subject error maps are
    reduced across subjects and drawn with ``subplot_cortex_err``.

    Parameters
    ----------
    pred_dir, gt_dir : str
        Directories holding the prediction and ground-truth ``.npy`` files. Both are required.
    pred_end, gt_end : str
        File-name suffixes appended to ``f"{subj}_"``.
    test_subj_ids : sequence of str
        Subject IDs to include; must be non-empty.
    experiment : str
        Forwarded to ``subplot_cortex_err``.
    surfaces : mapping
        Must contain ``"inflated"`` -> ``(lh_surface, rh_surface)``.
    mask : array or None
        Vertex mask forwarded to ``subplot_cortex_err``.
    ax : matplotlib Axes, optional
        Axis to draw into.
    contrast_index : int
        Which contrast (row pair) to analyze.
    title : bool
        If True, label the axis with the contrast name.
    aggregate : {'std', 'mean'}
        Across-subject reduction of the squared errors. ``'std'`` (default) keeps the
        previous behavior; ``'mean'`` gives the mean squared error.
    color_range : pair of float
        Colour limits for the map (default ``(0, 30)``, the previous fixed value).

    Raises
    ------
    ValueError
        If a directory is missing, ``test_subj_ids`` is empty, or ``aggregate`` is unknown.
    """
    if not pred_dir or not gt_dir:
        raise ValueError("pred_dir and gt_dir must both be provided")
    if len(test_subj_ids) == 0:
        raise ValueError("test_subj_ids is empty; nothing to analyze")
    reducers = {'std': np.std, 'mean': np.mean}
    if aggregate not in reducers:
        raise ValueError(f"aggregate must be one of {sorted(reducers)}, got {aggregate!r}")

    rows = slice(2 * contrast_index, 2 * contrast_index + 2)
    squared_errors = []
    for subj in test_subj_ids:
        pred_contrast = np.load(os.path.join(pred_dir, f"{subj}_{pred_end}"))
        gt_contrast = np.load(os.path.join(gt_dir, f"{subj}_{gt_end}"))
        if pred_contrast.ndim > 2:
            # Average over the leading axis (the samples used to predict; 8 per the
            # original comment).
            pred_contrast = pred_contrast.mean(0)
        squared_errors.append(np.square(gt_contrast[rows] - pred_contrast[rows]))

    error_map = reducers[aggregate](np.stack(squared_errors), axis=0)
    # `subj` (the last subject in the loop) is forwarded only to satisfy the callee's signature.
    subplot_cortex_err(error_map,
                       surface=surfaces["inflated"],
                       mask=mask,
                       ax=ax,
                       subj=subj,
                       contrast_index=contrast_index,
                       experiment=experiment,
                       color_range=list(color_range),
                       title=title)

## Retest

In [None]:
# Test-retest baseline: each subject's retest-session contrasts stand in for the
# "prediction" and are scored against the same subject's test-session contrasts.
# NOTE(review): `contrast_index` is not assigned in this cell; it must already exist in the
# kernel (the loop cells further down also assign it) -- confirm for Restart & Run All.
test_numpy_dir = '../../data/test_contrasts/'
retest_numpy_dir = '../../data/retest_contrasts/contrasts/'
output_dir = '../../PaperPlots/retest/cortex'

analyze_error_plot(
    retest_numpy_dir,
    test_numpy_dir,
    test_subj_ids=test_subj_ids,
    pred_end='joint_LR_task_contrasts.npy',
    gt_end='joint_LR_task_contrasts.npy',
    contrast_index=contrast_index,
    surfaces=surfaces,
    mask=mask,
    experiment='',
    title=True,
)

## BrainSurfCNN

In [None]:
# BrainSurfCNN (the `finetuned` run with c25): model predictions vs. test-session contrasts.
# NOTE(review): `contrast_index` is not assigned in this cell; it must already exist in the
# kernel -- confirm for Restart & Run All.
ic = 25
test_numpy_dir = '../../data/test_contrasts/'
brainsurf_numpy_dir = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/finetuned_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"
output_dir = '../../PaperPlots/brainsurf_ft/cortex'

analyze_error_plot(
    brainsurf_numpy_dir,
    test_numpy_dir,
    test_subj_ids=test_subj_ids,
    pred_end='pred.npy',
    gt_end='joint_LR_task_contrasts.npy',
    contrast_index=contrast_index,
    surfaces=surfaces,
    mask=mask,
    experiment='',
)

## BrainSERF

In [None]:
# BrainSERF (the `se_attn_finetuned` run with c25): model predictions vs. test-session contrasts.
# NOTE(review): `contrast_index` is not assigned in this cell; it must already exist in the
# kernel -- confirm for Restart & Run All.
ic = 25
test_numpy_dir = '../../data/test_contrasts/'
brainserf_numpy_dir = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/se_attn_finetuned_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"

analyze_error_plot(
    brainserf_numpy_dir,
    test_numpy_dir,
    test_subj_ids=test_subj_ids,
    pred_end='pred.npy',
    gt_end='joint_LR_task_contrasts.npy',
    contrast_index=contrast_index,
    surfaces=surfaces,
    mask=mask,
    experiment='',
)

In [None]:
# Graph-network variant (the `gnn_finetuned` run with c50): predictions vs. test-session contrasts.
# NOTE(review): `contrast_index` is not assigned in this cell; it must already exist in the
# kernel -- confirm for Restart & Run All.
ic = 50
test_numpy_dir = '../../data/test_contrasts/'
brainsurfgcn_numpy_dir = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/gnn_finetuned_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"

analyze_error_plot(
    brainsurfgcn_numpy_dir,
    test_numpy_dir,
    test_subj_ids=test_subj_ids,
    pred_end='pred.npy',
    gt_end='joint_LR_task_contrasts.npy',
    contrast_index=contrast_index,
    surfaces=surfaces,
    mask=mask,
    experiment='',
)

In [None]:
# Four-row figure per contrast: squared-error spread for the retest baseline (top row,
# titled with the contrast name) followed by the BrainSurfCNN, BrainSERF and GNN predictions.
test_subj_ids = np.genfromtxt("../data/MICCAI2020/HCP_test_retest_subj_ids.csv", dtype='<U13')
test_numpy_dir = '../../data/test_contrasts/'
retest_numpy_dir = '../../data/retest_contrasts/contrasts/'
brainsurf_numpy_dir = "../../aim3_results/HCP_feat64_s8_c25_lr0.01_seed28_epochs50/finetuned_feat64_s8_c25_lr0.01_seed28/predict_on_test_subj/best_corr/"
brainserf_numpy_dir = "../../aim3_results/HCP_feat64_s8_c25_lr0.01_seed28_epochs50/se_attn_finetuned_feat64_s8_c25_lr0.01_seed28/predict_on_test_subj/best_corr/"
brainsurfgcn_numpy_dir = "../../aim3_results/HCP_feat64_s8_c50_lr0.01_seed28_epochs50/gnn_finetuned_feat64_s8_c50_lr0.01_seed28/predict_on_test_subj/best_corr/"

# (prediction directory, prediction file suffix, draw contrast title) for each figure row.
error_rows = [
    (retest_numpy_dir, 'joint_LR_task_contrasts.npy', True),
    (brainsurf_numpy_dir, 'pred.npy', False),
    (brainserf_numpy_dir, 'pred.npy', False),
    (brainsurfgcn_numpy_dir, 'pred.npy', False),
]

# Contrasts 0-11; the following cells cover 12-46.
for contrast_index in range(0, 12):
    print(contrast_index)
    fig, axes = plt.subplots(len(error_rows), 1, figsize=(13, 10))
    for ax, (pred_dir, pred_end, show_title) in zip(axes, error_rows):
        analyze_error_subplot(pred_dir, test_numpy_dir,
                              pred_end=pred_end,
                              gt_end='joint_LR_task_contrasts.npy',
                              test_subj_ids=test_subj_ids,
                              experiment='',
                              ax=ax,
                              surfaces=surfaces,
                              mask=mask,
                              contrast_index=contrast_index,
                              title=show_title)
    fig.tight_layout()
    # Render and release this figure now so it appears under its printed index and the
    # figures do not all stay open until the cell finishes.
    plt.show()

In [None]:
# Four-row figure per contrast: squared-error spread for the retest baseline (top row,
# titled with the contrast name) followed by the BrainSurfCNN, BrainSERF and GNN predictions.
test_subj_ids = np.genfromtxt("../data/MICCAI2020/HCP_test_retest_subj_ids.csv", dtype='<U13')
test_numpy_dir = '../../data/test_contrasts/'
retest_numpy_dir = '../../data/retest_contrasts/contrasts/'
brainsurf_numpy_dir = "../../aim3_results/HCP_feat64_s8_c25_lr0.01_seed28_epochs50/finetuned_feat64_s8_c25_lr0.01_seed28/predict_on_test_subj/best_corr/"
brainserf_numpy_dir = "../../aim3_results/HCP_feat64_s8_c25_lr0.01_seed28_epochs50/se_attn_finetuned_feat64_s8_c25_lr0.01_seed28/predict_on_test_subj/best_corr/"
brainsurfgcn_numpy_dir = "../../aim3_results/HCP_feat64_s8_c50_lr0.01_seed28_epochs50/gnn_finetuned_feat64_s8_c50_lr0.01_seed28/predict_on_test_subj/best_corr/"

# (prediction directory, prediction file suffix, draw contrast title) for each figure row.
error_rows = [
    (retest_numpy_dir, 'joint_LR_task_contrasts.npy', True),
    (brainsurf_numpy_dir, 'pred.npy', False),
    (brainserf_numpy_dir, 'pred.npy', False),
    (brainsurfgcn_numpy_dir, 'pred.npy', False),
]

for contrast_index in range(12, 24):
    print(contrast_index)
    fig, axes = plt.subplots(len(error_rows), 1, figsize=(13, 10))
    for ax, (pred_dir, pred_end, show_title) in zip(axes, error_rows):
        analyze_error_subplot(pred_dir, test_numpy_dir,
                              pred_end=pred_end,
                              gt_end='joint_LR_task_contrasts.npy',
                              test_subj_ids=test_subj_ids,
                              experiment='',
                              ax=ax,
                              surfaces=surfaces,
                              mask=mask,
                              contrast_index=contrast_index,
                              title=show_title)
    fig.tight_layout()
    # Render and release this figure now so it appears under its printed index and the
    # figures do not all stay open until the cell finishes.
    plt.show()

In [None]:
# Four-row figure per contrast: squared-error spread for the retest baseline (top row,
# titled with the contrast name) followed by the BrainSurfCNN, BrainSERF and GNN predictions.
test_subj_ids = np.genfromtxt("../data/MICCAI2020/HCP_test_retest_subj_ids.csv", dtype='<U13')
test_numpy_dir = '../../data/test_contrasts/'
retest_numpy_dir = '../../data/retest_contrasts/contrasts/'
brainsurf_numpy_dir = "../../aim3_results/HCP_feat64_s8_c25_lr0.01_seed28_epochs50/finetuned_feat64_s8_c25_lr0.01_seed28/predict_on_test_subj/best_corr/"
brainserf_numpy_dir = "../../aim3_results/HCP_feat64_s8_c25_lr0.01_seed28_epochs50/se_attn_finetuned_feat64_s8_c25_lr0.01_seed28/predict_on_test_subj/best_corr/"
brainsurfgcn_numpy_dir = "../../aim3_results/HCP_feat64_s8_c50_lr0.01_seed28_epochs50/gnn_finetuned_feat64_s8_c50_lr0.01_seed28/predict_on_test_subj/best_corr/"

# (prediction directory, prediction file suffix, draw contrast title) for each figure row.
error_rows = [
    (retest_numpy_dir, 'joint_LR_task_contrasts.npy', True),
    (brainsurf_numpy_dir, 'pred.npy', False),
    (brainserf_numpy_dir, 'pred.npy', False),
    (brainsurfgcn_numpy_dir, 'pred.npy', False),
]

for contrast_index in range(24, 36):
    print(contrast_index)
    fig, axes = plt.subplots(len(error_rows), 1, figsize=(13, 10))
    for ax, (pred_dir, pred_end, show_title) in zip(axes, error_rows):
        analyze_error_subplot(pred_dir, test_numpy_dir,
                              pred_end=pred_end,
                              gt_end='joint_LR_task_contrasts.npy',
                              test_subj_ids=test_subj_ids,
                              experiment='',
                              ax=ax,
                              surfaces=surfaces,
                              mask=mask,
                              contrast_index=contrast_index,
                              title=show_title)
    fig.tight_layout()
    # Render and release this figure now so it appears under its printed index and the
    # figures do not all stay open until the cell finishes.
    plt.show()

In [None]:
# Four-row figure per contrast: squared-error spread for the retest baseline (top row,
# titled with the contrast name) followed by the BrainSurfCNN, BrainSERF and GNN predictions.
test_subj_ids = np.genfromtxt("../data/MICCAI2020/HCP_test_retest_subj_ids.csv", dtype='<U13')
test_numpy_dir = '../../data/test_contrasts/'
retest_numpy_dir = '../../data/retest_contrasts/contrasts/'
brainsurf_numpy_dir = "../../aim3_results/HCP_feat64_s8_c25_lr0.01_seed28_epochs50/finetuned_feat64_s8_c25_lr0.01_seed28/predict_on_test_subj/best_corr/"
brainserf_numpy_dir = "../../aim3_results/HCP_feat64_s8_c25_lr0.01_seed28_epochs50/se_attn_finetuned_feat64_s8_c25_lr0.01_seed28/predict_on_test_subj/best_corr/"
brainsurfgcn_numpy_dir = "../../aim3_results/HCP_feat64_s8_c50_lr0.01_seed28_epochs50/gnn_finetuned_feat64_s8_c50_lr0.01_seed28/predict_on_test_subj/best_corr/"

# (prediction directory, prediction file suffix, draw contrast title) for each figure row.
error_rows = [
    (retest_numpy_dir, 'joint_LR_task_contrasts.npy', True),
    (brainsurf_numpy_dir, 'pred.npy', False),
    (brainserf_numpy_dir, 'pred.npy', False),
    (brainsurfgcn_numpy_dir, 'pred.npy', False),
]

for contrast_index in range(36, 47):
    print(contrast_index)
    fig, axes = plt.subplots(len(error_rows), 1, figsize=(13, 10))
    for ax, (pred_dir, pred_end, show_title) in zip(axes, error_rows):
        analyze_error_subplot(pred_dir, test_numpy_dir,
                              pred_end=pred_end,
                              gt_end='joint_LR_task_contrasts.npy',
                              test_subj_ids=test_subj_ids,
                              experiment='',
                              ax=ax,
                              surfaces=surfaces,
                              mask=mask,
                              contrast_index=contrast_index,
                              title=show_title)
    fig.tight_layout()
    # Render and release this figure now so it appears under its printed index and the
    # figures do not all stay open until the cell finishes.
    plt.show()