In [None]:
import sys
sys.path.insert(0, '..')

In [None]:
from neuromaps import plotting as neuroplot
import numpy as np
import os
from glob import glob
from matplotlib import colors as mcolors, pyplot as plt
from neuromaps.images import load_gifti
from neuromaps.datasets import fetch_atlas
import sys
from surfplot import Plot
from neuromaps.datasets import fetch_fslr
import seaborn as sns
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.cm import ScalarMappable

from utils.utilities import CONTRASTS, GROUP_CONTRAST_IDS, plot_corr_matrices_across_contrasts, compute_corr_coeff, scale


# fsLR template surfaces; the inflated (left, right) pair is what the cortex plots render on.
surfaces = fetch_fslr()
inflated = surfaces['inflated']
lh, rh = inflated

In [None]:
# Medial-wall mask and the test/retest subject list, both loaded from disk.
mask = np.load('../data/glasser_medial_wall_mask.npy')
test_subj_ids = np.genfromtxt("../data/MICCAI2020/HCP_test_retest_subj_ids.csv", dtype='<U13')

# Three colours from a 7-step Spectral palette: one per category (gt / overlap / pred).
spectral = sns.color_palette("Spectral", 7)
colors = [spectral[idx] for idx in (0, 3, 6)]
alphas = [0.7] * 3
colors_alphas = [(*rgb[:3], alpha) for rgb, alpha in zip(colors, alphas)]

# Boundaries split [0.5, 3.5) into three bins, one per colour of my_cmap.
my_cmap = ListedColormap(colors)
norm = BoundaryNorm([0.5, 1.5, 2.5, 3.5], my_cmap.N, clip=False)

In [None]:
def plot_cortex_overlap(data, surface, subj, contrast_index, experiment, output_path, cmap, norm, labels=['gt', 'overlap', 'pred'], dice=None, title=None):
    """Render a categorical gt / overlap / pred vertex map on both hemispheres.

    Parameters
    ----------
    data : array
        Concatenated lh+rh vertex values, typically the output of
        ``compute_cortex_overlap`` (codes 0.5 / 1.5 / 2.5, 0 elsewhere).
    surface : pair
        ``(left, right)`` surfaces passed to ``surfplot.Plot``.
    subj, experiment, dice :
        Unused; kept for call-site compatibility.
    contrast_index : int
        Row of ``CONTRASTS`` used to build the default title.
    output_path : str or None
        When truthy, the figure is saved there.
    cmap, norm :
        Colormap and norm used for the 3-tick categorical colorbar.
    labels : sequence of 3 str
        Tick labels for colorbar ticks 1, 2, 3.
    title : str, bool or None
        A string is used verbatim as the figure title; any other truthy value
        uses ``"<CONTRASTS[i][0]> <CONTRASTS[i][-1]>"``; falsy means no title.
    """
    p = Plot(surf_lh=surface[0], surf_rh=surface[1], size=(1600, 300), layout='row', zoom=1.6)
    p.add_layer(data, cmap=cmap, color_range=[0, 3])
    fig = p.build(colorbar=False)

    # The ScalarMappable is not attached to any Axes, so name the image Axes
    # explicitly instead of relying on whichever Axes happens to be current.
    cbar = fig.colorbar(ScalarMappable(norm=norm, cmap=cmap), ax=fig.axes[0],
                        orientation='horizontal', shrink=0.1, ticks=[1, 2, 3])
    cbar.set_ticklabels(list(labels))

    if title:
        if isinstance(title, str):
            fig.suptitle(title, fontsize=50)
        else:
            fig.suptitle(f'{CONTRASTS[contrast_index][0]} {CONTRASTS[contrast_index][-1]}', fontsize=50)

    fig.tight_layout()
    if output_path:
        fig.savefig(fname=output_path)

def plot_cortex(data, surface, mask, subj, contrast_index, experiment, output_path=None, color_range=[-8, 8], layout='row', dice=None):
    """Render a two-hemisphere vertex map (Spectral_r, colorbar on the right).

    Parameters
    ----------
    data : pair of arrays
        ``(lh, rh)`` vertex values. Not modified: the medial-wall masking is
        applied to copies.
    surface : pair
        ``(left, right)`` surfaces passed to ``surfplot.Plot``.
    mask : array or None
        Vertices where ``mask[0, :]`` / ``mask[1, :]`` is falsy are drawn as 0.
    subj, contrast_index, experiment, dice :
        Unused; kept for call-site compatibility.
    output_path : str or None
        When truthy, the figure is saved there.
    color_range : [min, max]
        Colormap limits.
    layout : str
        surfplot layout. 'row' and 'col' get a wide or tall canvas; any other
        value gets 1200x800.
    """
    sizes = {'row': (1600, 300), 'col': (300, 1600)}
    p = Plot(surf_lh=surface[0], surf_rh=surface[1], size=sizes.get(layout, (1200, 800)), layout=layout, zoom=1.6)

    # Copy: callers pass row slices (views) of larger arrays, which masking would otherwise overwrite.
    lh, rh = (np.array(hemi, copy=True) for hemi in data)
    if mask is not None:
        lh[np.logical_not(mask[0, :])] = 0
        rh[np.logical_not(mask[1, :])] = 0

    p.add_layer(np.concatenate((lh, rh), axis=-1), cmap='Spectral_r', color_range=color_range)
    fig = p.build(colorbar=True, cbar_kws={'location': 'right'})
    fig.tight_layout()
    if output_path:
        fig.savefig(fname=output_path)

def subplot_cortex(data, surface, mask, ax, color_range=[-8, 8], layout='row', dice=None):
    """Draw a two-hemisphere vertex map into an existing matplotlib Axes.

    Parameters
    ----------
    data : pair of arrays
        ``(lh, rh)`` vertex values; masking is applied to copies, so the inputs
        are not modified.
    surface : pair
        ``(left, right)`` surfaces passed to ``surfplot.Plot``.
    mask : array or None
        Vertices where ``mask[0, :]`` / ``mask[1, :]`` is falsy are drawn as 0.
    ax : matplotlib.axes.Axes
        Target Axes; the rendered image and the colorbar are attached to it.
    color_range : [min, max]
        Colormap limits.
    layout : str
        surfplot layout ('row', 'col', or anything else for a 1200x800 canvas).
    dice :
        Unused; kept for call-site compatibility.

    Returns
    -------
    The same ``ax``.
    """
    sizes = {'row': (1600, 300), 'col': (300, 1600)}
    p = Plot(surf_lh=surface[0], surf_rh=surface[1], size=sizes.get(layout, (1200, 800)), layout=layout, zoom=1.6)

    lh, rh = (np.array(hemi, copy=True) for hemi in data)
    if mask is not None:
        lh[np.logical_not(mask[0, :])] = 0
        rh[np.logical_not(mask[1, :])] = 0
    p.add_layer(np.concatenate((lh, rh), axis=-1), cmap='Spectral_r', color_range=color_range)

    # Keep the surfplot Plot in ``p``: ``_add_colorbars`` is a Plot method, not a
    # method of the renderer returned by ``render()`` (reassigning p broke it).
    plotter = p.render()
    plotter._check_offscreen()
    # NOTE(review): ``scale`` is the module-level name imported from utils.utilities;
    # to_numpy expects an (x, y) magnification here. Confirm that is what it holds.
    image = plotter.to_numpy(transparent_bg=True, scale=scale)
    ax.imshow(image)
    ax.axis('off')

    # ``_add_colorbars`` (private surfplot API) draws on the current Axes, so make ``ax`` current.
    plt.sca(ax)
    p._add_colorbars(location='right')
    return ax
    
def threshold_activation(hemi, threshold=0.25, abs=True):
    """Return the indices of the top ``threshold`` fraction of vertices in ``hemi``.

    Vertices are ranked in descending order of value, or of absolute value when
    ``abs`` is true. ``int(n * threshold)`` indices are returned, where ``n`` is
    the length of the last axis.
    """
    values = np.abs(hemi) if abs else hemi
    n_keep = int(hemi.shape[-1] * threshold)
    descending = np.argsort(-values)
    return descending[:n_keep]

def compute_cortex_overlap(prediction, groundtruth, mask=None, threshold=0.10, abs=True, score=False):
    """Compare the top-``threshold`` vertex sets of a prediction and a ground truth.

    For each hemisphere, vertices are ranked by value (by absolute value when
    ``abs`` is true) and the top ``int(n_vertices * threshold)`` are kept.

    Parameters
    ----------
    prediction, groundtruth : array-like
        ``[0]`` is the left and ``[1]`` the right hemisphere vertex array.
        Neither input is modified.
    mask : array or None
        ``mask[0, :]`` / ``mask[1, :]`` mark the vertices to keep (truthy) in each
        hemisphere. Excluded vertices rank below every kept vertex, so they only
        enter the top set once all kept vertices have. They are zeroed in the
        returned map.
    threshold : float
        Fraction of each hemisphere's vertices to keep.
    abs : bool
        Rank by absolute value.
    score : bool
        Return the Dice score instead of the map.

    Returns
    -------
    If ``score`` is false: a float array with both hemispheres concatenated
    (lh then rh). 0.5 marks ground truth only, 1.5 both, 2.5 prediction only,
    and 0 everything else.
    If ``score`` is true: the mean of the left and right Dice coefficients.
    """
    def _top_vertices(values, keep):
        # Float copy: never modify the caller's arrays, and allow NaN.
        ranked = np.array(values, dtype=float)
        if abs:
            ranked = np.abs(ranked)
        if keep is not None:
            # NaN sorts last in argsort. Zeroing instead would let the medial wall
            # outrank the cortex whenever the kept values are negative (abs=False).
            ranked[np.logical_not(keep)] = np.nan
        n_keep = int(ranked.shape[-1] * threshold)
        return np.argsort(-ranked)[:n_keep]

    hemi_masks = (None, None) if mask is None else (mask[0, :], mask[1, :])

    maps, dices = [], []
    for hemi in range(2):
        keep = hemi_masks[hemi]
        gt_vertices = _top_vertices(groundtruth[hemi], keep)
        pred_vertices = _top_vertices(prediction[hemi], keep)
        overlap = np.intersect1d(pred_vertices, gt_vertices)

        hemi_map = np.zeros(np.shape(groundtruth[hemi]), dtype=float)
        hemi_map[gt_vertices] = 0.5
        hemi_map[pred_vertices] = 2.5
        hemi_map[overlap] = 1.5
        if keep is not None:
            # Excluded vertices can still be selected when threshold exceeds the kept fraction.
            hemi_map[np.logical_not(keep)] = 0
        maps.append(hemi_map)

        dices.append(2 * len(overlap) / (len(pred_vertices) + len(gt_vertices)))

    if not score:
        return np.concatenate(maps, axis=-1)
    return (dices[0] + dices[1]) / 2

In [None]:
def make_and_save_plots(pred_dir, gt_dir, pred_end, gt_end, test_subj_ids, experiment, surfaces, mask, cmap, norm, labels=['gt', 'overlap', 'pred'], output_path='../../data/', threshold=0.1, contrast_index=9):
    """Save per-subject overlap figures (and predicted maps) for one contrast.

    Arrays are read from ``<dir>/<subj>_<end>``, and rows ``2*contrast_index`` and
    ``2*contrast_index + 1`` are used.

    - With ``pred_dir``: the top-``threshold`` overlap between prediction and
      ground truth is saved to ``<subj>_<contrast_index>.png``. The predicted map
      itself goes to ``<subj>_<contrast_index>_pred.png``. A prediction with more
      than 2 dims is averaged over its first axis.
    - Without ``pred_dir``: the ground truth is compared with itself and the
      overlap codes are divided by 2.5 before plotting.

    Raises
    ------
    ValueError
        If ``gt_dir`` is empty; both branches need the ground truth.
    """
    if not gt_dir:
        raise ValueError("gt_dir is required: every plot is built from the ground truth")

    rows = slice(2 * contrast_index, 2 * contrast_index + 2)
    surface = surfaces["inflated"]

    for subj in test_subj_ids:
        print(subj)
        gt_contrast = np.load(os.path.join(gt_dir, f"{subj}_{gt_end}"))
        overlap_path = os.path.join(output_path, f'{subj}_{contrast_index}.png')

        if pred_dir:
            pred_contrast = np.load(os.path.join(pred_dir, f"{subj}_{pred_end}"))
            if pred_contrast.ndim > 2:
                pred_contrast = pred_contrast.mean(0)  # average over the leading (sample) axis
            overlap = compute_cortex_overlap(pred_contrast[rows], gt_contrast[rows], mask=mask, threshold=threshold)
            plot_cortex_overlap(overlap,
                                surface=surface,
                                subj=subj,
                                contrast_index=contrast_index,
                                experiment=experiment,
                                output_path=overlap_path,
                                labels=labels,
                                cmap=cmap, norm=norm)
            # Distinct file name so the predicted map does not overwrite the overlap figure.
            plot_cortex(pred_contrast[rows],
                        surface=surface,
                        subj=subj,
                        mask=mask,
                        contrast_index=contrast_index,
                        experiment=experiment,
                        output_path=os.path.join(output_path, f'{subj}_{contrast_index}_pred.png'))
        else:
            overlap = compute_cortex_overlap(gt_contrast[rows], gt_contrast[rows], mask=mask, threshold=threshold)
            overlap /= 2.5
            plot_cortex_overlap(overlap,
                                surface=surface,
                                subj=subj,
                                contrast_index=contrast_index,
                                experiment=experiment,
                                output_path=overlap_path,
                                labels=labels,
                                cmap=cmap, norm=norm)
        

## Plotting Params

In [None]:
# Subject, contrast and top-vertex fraction used by the cells below.
test_subj_ids = ['917255']
contrast_index, threshold = 4, 0.50

In [None]:
def compute_dice(prediction, groundtruth, abs=True):
    """Mean over hemispheres of the Dice coefficient between nonzero-vertex sets.

    ``prediction[h]`` and ``groundtruth[h]`` (h = 0 left, 1 right) are vertex
    arrays. A vertex belongs to a hemisphere's set when its value is nonzero.
    ``abs`` is unused and kept for call-site compatibility.

    Raises ZeroDivisionError if, for some hemisphere, neither input has a nonzero vertex.
    """
    def _hemi_dice(gt, pred):
        # flatnonzero always yields a 1-D index array. np.squeeze(np.argwhere(...))
        # collapsed a single hit to a 0-d array, and len() then raised TypeError.
        gt_idx = np.flatnonzero(gt)
        pred_idx = np.flatnonzero(pred)
        overlap = np.intersect1d(gt_idx, pred_idx)
        return 2 * len(overlap) / (len(gt_idx) + len(pred_idx))

    lh_dice = _hemi_dice(groundtruth[0], prediction[0])
    rh_dice = _hemi_dice(groundtruth[1], prediction[1])
    return (lh_dice + rh_dice) / 2

# Error Analysis

In [None]:
def plot_cortex_err(data, surface, mask, subj, contrast_index, experiment, output_path=None, color_range=[-8, 8], layout='row', dice=None, title=False):
    """Render a two-hemisphere error map (Spectral_r, colorbar on the right).

    Parameters
    ----------
    data : pair of arrays
        ``(lh, rh)`` vertex values; masking is applied to copies, so the inputs
        are not modified.
    surface : pair
        ``(left, right)`` surfaces passed to ``surfplot.Plot``.
    mask : array or None
        Vertices where ``mask[0, :]`` / ``mask[1, :]`` is falsy are drawn as 0.
    contrast_index : int
        Row of ``CONTRASTS`` used for the default title.
    subj, experiment, dice :
        Unused; kept for call-site compatibility.
    output_path : str or None
        When truthy, the figure is saved there.
    color_range : [min, max]
        Colormap limits.
    layout : str
        surfplot layout ('row', 'col', or anything else for a 1200x800 canvas).
    title : str or bool
        A string is used verbatim; any other truthy value uses
        ``"<CONTRASTS[i][0]> <CONTRASTS[i][-1]>"``.
    """
    sizes = {'row': (1600, 300), 'col': (300, 1600)}
    p = Plot(surf_lh=surface[0], surf_rh=surface[1], size=sizes.get(layout, (1200, 800)), layout=layout, zoom=1.6)

    lh, rh = (np.array(hemi, copy=True) for hemi in data)
    if mask is not None:
        lh[np.logical_not(mask[0, :])] = 0
        rh[np.logical_not(mask[1, :])] = 0

    p.add_layer(np.concatenate((lh, rh), axis=-1), cmap='Spectral_r', color_range=color_range)
    fig = p.build(colorbar=True, cbar_kws={'location': 'right'})
    if title:
        if isinstance(title, str):
            fig.suptitle(title, fontsize=64)
        else:
            fig.suptitle(f'{CONTRASTS[contrast_index][0]} {CONTRASTS[contrast_index][-1]}', fontsize=64)
    fig.tight_layout()
    if output_path:
        fig.savefig(fname=output_path)

def analyze_error_plot(pred_dir, gt_dir, pred_end, gt_end, test_subj_ids, experiment, surfaces, mask, ax=None, contrast_index=9, title=False):
    """Plot the subject-averaged squared error between prediction and ground truth.

    For each subject, ``<pred_dir>/<subj>_<pred_end>`` and
    ``<gt_dir>/<subj>_<gt_end>`` are loaded. A prediction with more than 2 dims is
    averaged over its first axis. Rows ``2*contrast_index`` and
    ``2*contrast_index + 1`` are compared. The per-vertex squared errors are
    averaged across subjects and drawn with ``plot_cortex_err`` (colour range
    [0, 10]). ``ax`` is unused.

    Raises
    ------
    ValueError
        If either directory is empty or ``test_subj_ids`` is empty.
    """
    if not pred_dir or not gt_dir:
        raise ValueError("analyze_error_plot needs both pred_dir and gt_dir")
    subj_ids = list(test_subj_ids)
    if not subj_ids:
        raise ValueError("test_subj_ids is empty")

    rows = slice(2 * contrast_index, 2 * contrast_index + 2)
    squared_errors = []
    for subj in subj_ids:
        pred_contrast = np.load(os.path.join(pred_dir, f"{subj}_{pred_end}"))
        gt_contrast = np.load(os.path.join(gt_dir, f"{subj}_{gt_end}"))
        if pred_contrast.ndim > 2:
            pred_contrast = pred_contrast.mean(0)  # average over the leading (sample) axis
        squared_errors.append(np.square(gt_contrast[rows] - pred_contrast[rows]))

    # Mean over subjects, consistent with analyze_error_subplot and compute_mse.
    # A std over subjects is identically zero for a one-subject list, giving a blank map.
    mean_sq_error = np.array(squared_errors).mean(0)
    plot_cortex_err(mean_sq_error,
                    surface=surfaces["inflated"],
                    mask=mask,
                    subj=subj_ids[-1],
                    contrast_index=contrast_index,
                    experiment=experiment,
                    color_range=[0, 10],
                    title=title)

def subplot_cortex_err(data, surface, mask, subj, contrast_index, experiment, ax, output_path=None, color_range=[-8, 8], layout='row', dice=None, title=False):
    """Draw a two-hemisphere error map into an existing Axes ``ax``.

    Parameters
    ----------
    data : pair of arrays
        ``(lh, rh)`` vertex values; masking is applied to copies, so the inputs
        are not modified.
    surface : pair
        ``(left, right)`` surfaces passed to ``surfplot.Plot``.
    mask : array or None
        Vertices where ``mask[0, :]`` / ``mask[1, :]`` is falsy are drawn as 0.
    contrast_index : int
        Row of ``CONTRASTS`` used for the default title.
    ax : matplotlib.axes.Axes
        Target Axes for the image and its colorbar.
    subj, experiment, dice :
        Unused; kept for call-site compatibility.
    output_path : str or None
        When truthy, ``ax``'s figure is saved there.
    color_range : [min, max]
        Colormap limits.
    layout : str
        surfplot layout ('row', 'col', or anything else for a 1200x800 canvas).
    title : str or bool
        A string is used verbatim; any other truthy value uses the CONTRASTS name.
    """
    sizes = {'row': (1600, 300), 'col': (300, 1600)}
    p = Plot(surf_lh=surface[0], surf_rh=surface[1], size=sizes.get(layout, (1200, 800)), layout=layout, zoom=1.6)

    lh, rh = (np.array(hemi, copy=True) for hemi in data)
    if mask is not None:
        lh[np.logical_not(mask[0, :])] = 0
        rh[np.logical_not(mask[1, :])] = 0
    p.add_layer(np.concatenate((lh, rh), axis=-1), cmap='Spectral_r', color_range=color_range)

    # NOTE(review): surfplot's Plot.build() creates its own figure and has no ``ax``
    # parameter, so build(ax=...) raised TypeError. Instead, render to an image and
    # draw it into ``ax``. scale=(2, 2) matches build()'s default; confirm for your
    # surfplot version.
    plotter = p.render()
    plotter._check_offscreen()
    ax.imshow(plotter.to_numpy(transparent_bg=True, scale=(2, 2)))
    ax.axis('off')
    plt.sca(ax)  # _add_colorbars (private surfplot API) draws on the current Axes
    p._add_colorbars(location='right')

    if title:
        if isinstance(title, str):
            ax.set_title(title, fontsize=64)
        else:
            ax.set_title(f'{CONTRASTS[contrast_index][0]} {CONTRASTS[contrast_index][-1]}', fontsize=64)
    if output_path:
        ax.figure.savefig(fname=output_path)
    
def analyze_error_subplot(pred_dir, gt_dir, pred_end, gt_end, test_subj_ids, experiment, surfaces, mask, ax=None, contrast_index=9, title=False):
    """Draw the subject-averaged squared error for one contrast into ``ax``.

    Per subject, the prediction and ground truth are loaded from
    ``<dir>/<subj>_<end>``. A prediction with more than 2 dims is averaged over
    its first axis. Rows ``2*contrast_index`` and ``2*contrast_index + 1`` are
    compared. The squared errors are averaged over subjects and passed to
    ``subplot_cortex_err`` with colour range [0, 30].
    """
    rows = slice(2 * contrast_index, 2 * contrast_index + 2)
    squared_errors = []
    for subj in test_subj_ids:
        if pred_dir:
            pred_contrast = np.load(os.path.join(pred_dir, f"{subj}_{pred_end}"))
        if gt_dir:
            gt_contrast = np.load(os.path.join(gt_dir, f"{subj}_{gt_end}"))
        if len(pred_contrast.shape) > 2:
            pred_contrast = pred_contrast.mean(0)
        squared_errors.append(np.square(gt_contrast[rows] - pred_contrast[rows]))

    subplot_cortex_err(np.array(squared_errors).mean(0),
                       surface=surfaces["inflated"],
                       mask=mask,
                       ax=ax,
                       subj=subj,
                       contrast_index=contrast_index,
                       experiment=experiment,
                       color_range=[0, 30],
                       title=title)

In [None]:
def compute_subj_contrast_corr(pred, ref, contrasts, contrast_ids, mask):
    """Per-contrast correlation between predicted and reference maps.

    ``pred`` and ``ref`` are indexed ``[subject, row, vertex]``. Even rows are
    paired with ``mask[0, :]`` (left hemisphere) and odd rows with ``mask[1, :]``
    (right hemisphere). ``contrast_ids`` is unused.

    Returns three lists (left, right, left/right average), each holding one
    ``compute_corr_coeff`` result per entry of ``contrasts``.
    """
    lh_keep, rh_keep = mask[0, :], mask[1, :]
    pred_lh, pred_rh = pred[:, ::2, lh_keep], pred[:, 1::2, rh_keep]
    ref_lh, ref_rh = ref[:, ::2, lh_keep], ref[:, 1::2, rh_keep]

    all_lh_corr, all_rh_corr, all_avg_corr = [], [], []
    for idx in range(len(contrasts)):
        lh_corr = compute_corr_coeff(ref_lh[:, idx, :], pred_lh[:, idx, :])
        rh_corr = compute_corr_coeff(ref_rh[:, idx, :], pred_rh[:, idx, :])
        all_lh_corr.append(lh_corr)
        all_rh_corr.append(rh_corr)
        all_avg_corr.append((lh_corr + rh_corr) / 2)
    return all_lh_corr, all_rh_corr, all_avg_corr



In [None]:
def compute_mse(pred_dir, gt_dir, pred_end, gt_end, test_subj_ids, experiment, surfaces, mask, ax=None, contrast_index=9, title=False):
    """Stack per-subject squared errors for one contrast.

    For each subject, ``<pred_dir>/<subj>_<pred_end>`` and
    ``<gt_dir>/<subj>_<gt_end>`` are loaded. A prediction with more than 2 dims is
    averaged over its first axis. Rows ``2*contrast_index`` and
    ``2*contrast_index + 1`` are compared.

    Returns
    -------
    np.ndarray
        One element-wise squared-difference array per subject, stacked along
        axis 0. ``experiment``, ``surfaces``, ``mask``, ``ax`` and ``title`` are unused.
    """
    rows = slice(2 * contrast_index, 2 * contrast_index + 2)
    per_subject = []
    for subj in test_subj_ids:
        if pred_dir:
            pred_contrast = np.load(os.path.join(pred_dir, f"{subj}_{pred_end}"))
        if gt_dir:
            gt_contrast = np.load(os.path.join(gt_dir, f"{subj}_{gt_end}"))
        if len(pred_contrast.shape) > 2:
            pred_contrast = pred_contrast.mean(0)
        per_subject.append(np.square(gt_contrast[rows] - pred_contrast[rows]))
    return np.array(per_subject)

## Retest

In [None]:
# Test-retest error map: the retest session is passed as the "prediction".
retest_numpy_dir = '../../data/retest_contrasts/contrasts/'
test_numpy_dir = '../../data/test_contrasts/'
output_dir = '../../PaperPlots/retest/cortex'

analyze_error_plot(pred_dir=retest_numpy_dir,
                   gt_dir=test_numpy_dir,
                   pred_end='joint_LR_task_contrasts.npy',
                   gt_end='joint_LR_task_contrasts.npy',
                   test_subj_ids=test_subj_ids,
                   experiment='',
                   surfaces=surfaces,
                   mask=mask,
                   contrast_index=contrast_index,
                   title=True)

## BrainSurfCNN

In [None]:
# Error map of the fine-tuned model predictions ("finetuned" run directory) against the test session.
ic = 25
brainsurf_numpy_dir = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/finetuned_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"
test_numpy_dir = '../../data/test_contrasts/'
output_dir = '../../PaperPlots/brainsurf_ft/cortex'

analyze_error_plot(pred_dir=brainsurf_numpy_dir,
                   gt_dir=test_numpy_dir,
                   pred_end='pred.npy',
                   gt_end='joint_LR_task_contrasts.npy',
                   test_subj_ids=test_subj_ids,
                   experiment='',
                   surfaces=surfaces,
                   mask=mask,
                   contrast_index=contrast_index)

## BrainSERF

In [None]:
# Error map of the "se_attn_finetuned" model predictions against the test session.
ic = 25
brainserf_numpy_dir = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/se_attn_finetuned_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"
test_numpy_dir = '../../data/test_contrasts/'

analyze_error_plot(pred_dir=brainserf_numpy_dir,
                   gt_dir=test_numpy_dir,
                   pred_end='pred.npy',
                   gt_end='joint_LR_task_contrasts.npy',
                   test_subj_ids=test_subj_ids,
                   experiment='',
                   surfaces=surfaces,
                   mask=mask,
                   contrast_index=contrast_index)

In [None]:
# Error map of the "gnn_finetuned" model predictions against the test session.
ic = 50
brainsurfgcn_numpy_dir = f"../../aim3_results/HCP_feat64_s8_c{ic}_lr0.01_seed28_epochs50/gnn_finetuned_feat64_s8_c{ic}_lr0.01_seed28/predict_on_test_subj/best_corr/"
test_numpy_dir = '../../data/test_contrasts/'

analyze_error_plot(pred_dir=brainsurfgcn_numpy_dir,
                   gt_dir=test_numpy_dir,
                   pred_end='pred.npy',
                   gt_end='joint_LR_task_contrasts.npy',
                   test_subj_ids=test_subj_ids,
                   experiment='',
                   surfaces=surfaces,
                   mask=mask,
                   contrast_index=contrast_index)

### Dice overlap between test-retest error and negative tSNR

In [None]:
def compute_dice(prediction, groundtruth, abs=True):
    """Mean over hemispheres of the Dice coefficient between nonzero-vertex sets.

    ``prediction[h]`` and ``groundtruth[h]`` (h = 0 left, 1 right) are vertex
    arrays. A vertex belongs to a hemisphere's set when its value is nonzero.
    ``abs`` is unused and kept for call-site compatibility.

    Raises ZeroDivisionError if, for some hemisphere, neither input has a nonzero vertex.
    """
    def _hemi_dice(gt, pred):
        # flatnonzero always yields a 1-D index array. np.squeeze(np.argwhere(...))
        # collapsed a single hit to a 0-d array, and len() then raised TypeError.
        gt_idx = np.flatnonzero(gt)
        pred_idx = np.flatnonzero(pred)
        overlap = np.intersect1d(gt_idx, pred_idx)
        return 2 * len(overlap) / (len(gt_idx) + len(pred_idx))

    lh_dice = _hemi_dice(groundtruth[0], prediction[0])
    rh_dice = _hemi_dice(groundtruth[1], prediction[1])
    return (lh_dice + rh_dice) / 2

def dice_auc(prediction, groundtruth, contrast_index, mask=None, dx=0.05, abs=True, return_all=True):
    """Area under the Dice-vs-threshold curve for one contrast.

    Dice (``compute_cortex_overlap(..., score=True)``) is evaluated on rows
    ``2*contrast_index`` and ``2*contrast_index + 1`` at the thresholds
    ``np.arange(0.05, 0.51, dx)``. It is then integrated with the trapezoidal
    rule using step ``dx``.

    Returns ``(auc, dices_array)`` when ``return_all`` is true, otherwise ``auc``.
    """
    rows = slice(2 * contrast_index, 2 * contrast_index + 2)
    pred_rows, gt_rows = prediction[rows], groundtruth[rows]
    dices = [
        compute_cortex_overlap(pred_rows, gt_rows, mask=mask, threshold=th, abs=abs, score=True)
        for th in np.arange(0.05, 0.51, dx)
    ]

    auc = sum(((lo + hi) / 2 * dx for lo, hi in zip(dices, dices[1:])), 0.0)
    return (auc, np.array(dices)) if return_all else auc

In [None]:
# Mean tSNR map over all test/retest subjects, negated (tsnr = -mean tSNR) for the
# "negative tSNR" comparison with the test-retest error below.
test_subj_ids = np.genfromtxt("../data/MICCAI2020/HCP_test_retest_subj_ids.csv", dtype='<U13')
tsnr_dir = "../../data/test_tsnr/"  # not `dir`, which would shadow the builtin
end = 'tsnr.npy'
output_dir = '../../PaperPlots/tsnr/'

contrasts = []
for subj in test_subj_ids:
    contrast = np.load(os.path.join(tsnr_dir, f"{subj}_{end}"))
    contrasts.append([contrast[0], contrast[1]])

contrast = np.mean(contrasts, axis=0)
tsnr = contrast * -1

# Pass a copy so any in-place medial-wall masking during plotting leaves tsnr untouched.
plot_cortex(tsnr.copy(),
            surface=surfaces["inflated"],
            subj=subj,
            mask=mask,
            contrast_index=contrast_index,
            experiment='tSNR',
            color_range=[-50, 0],
            layout='grid')
plt.show()
plt.close('all')

In [None]:
from copy import deepcopy

# Keep only the top-25% |value| vertices of the negative-tSNR map (medial wall
# zeroed first); every other vertex becomes 0.
tsnr_thresh = deepcopy(tsnr)
for hemi_idx in (0, 1):
    tsnr_thresh[hemi_idx, np.logical_not(mask[hemi_idx, :])] = 0
lh_top_verts, rh_top_verts = (threshold_activation(tsnr_thresh[h], threshold=0.25) for h in (0, 1))

contrasts = np.zeros_like(tsnr_thresh)
for hemi_idx, top_verts in enumerate((lh_top_verts, rh_top_verts)):
    contrasts[hemi_idx, top_verts] += tsnr_thresh[hemi_idx, top_verts]
tsnr_thresh = contrasts

In [None]:
# For a few contrasts, compare the top-25% test-retest squared-error vertices with
# the top-25% negative-tSNR vertices: plot the overlap and record the Dice AUC.
test_subj_ids = np.genfromtxt("../data/MICCAI2020/HCP_test_retest_subj_ids.csv", dtype='<U13')
retest_numpy_dir = '../../data/retest_contrasts/contrasts/'
test_numpy_dir = '../../data/test_contrasts/'
output_dir = '../../PaperPlots/retest/cortex'

dices = []
for contrast_index in [2, 28, 43, 46, 27]:
    # Subject-averaged squared error; keeps only this contrast's two (lh, rh) rows.
    retest_mse = compute_mse(retest_numpy_dir, test_numpy_dir,
                             pred_end='joint_LR_task_contrasts.npy',
                             gt_end='joint_LR_task_contrasts.npy',
                             test_subj_ids=test_subj_ids,
                             experiment='',
                             surfaces=surfaces,
                             mask=mask,
                             contrast_index=contrast_index,
                             title=True).mean(0)
    error_map = np.abs(retest_mse)
    contrast_name = f'{CONTRASTS[contrast_index][0]} {CONTRASTS[contrast_index][-1]}'

    overlap = compute_cortex_overlap(tsnr, error_map, mask=mask, threshold=0.25, abs=False)
    plot_cortex_overlap(overlap,
                        surface=surfaces["inflated"],
                        subj=subj,
                        contrast_index=contrast_index,
                        experiment=None,
                        output_path=None,
                        labels=['error', 'overlap', 'tsnr'],
                        cmap=my_cmap, norm=norm, title=contrast_name)

    # error_map/tsnr already hold only this contrast's rows, hence contrast_index=0.
    auc, scores = dice_auc(prediction=error_map,
                           groundtruth=tsnr,
                           abs=False,
                           mask=mask,
                           contrast_index=0,
                           dx=0.05)
    dices.append(auc)

tsnr_dices = np.array(dices)

In [None]:

# `dices` holds one AUC per contrast of the tSNR dice loop, in loop order, so a
# position in it must be mapped back to its CONTRASTS index before labelling.
tsnr_contrast_ids = [2, 28, 43, 46, 27]  # keep in sync with the tSNR dice loop
print(tsnr_dices.shape)
for pos in np.argsort(dices)[:-11:-1]:  # up to 10 highest AUCs, descending
    cid = tsnr_contrast_ids[pos]
    print(cid, dices[pos], CONTRASTS[cid][0], CONTRASTS[cid][-1])


In [None]:
# Stack every subject's test-session contrast array (subjects on axis 0).
test_subj_ids = np.genfromtxt("../data/MICCAI2020/HCP_test_retest_subj_ids.csv", dtype='<U13')
test_contrasts = np.asarray([
    np.load(os.path.join('../../data/', "test_contrasts", "%s_joint_LR_task_contrasts.npy" % subj))
    for subj in test_subj_ids
])
test_contrasts.shape

In [None]:

# Test-retest reproducibility: for every subject and contrast, the Dice AUC between
# the retest map (as "prediction") and the test map (as ground truth).
retest_dice = []
label = 'Retest Contrasts'
print('--------------------------------------------------')
print(label)
retest_path = "../../data/retest_contrasts/contrasts/"

for i, subj in enumerate(test_subj_ids):
    retest_data = np.load(os.path.join(retest_path, "%s_joint_LR_task_contrasts.npy" % subj))
    test_data = test_contrasts[i]

    dices = []
    for c in range(len(CONTRASTS)):
        # Contrast c occupies rows 2*c and 2*c+1; dice_auc slices them from contrast_index=c.
        # Slicing [c:c+2] here would pair hemispheres of neighbouring contrasts for c > 0.
        auc, scores = dice_auc(prediction=retest_data,
                               groundtruth=test_data,
                               mask=mask,
                               contrast_index=c,
                               dx=0.05)
        dices.append(auc)

    retest_dice.append(dices)

In [None]:
retest_dice = np.asarray(retest_dice)  # rows: subjects, columns: contrasts
retest_dice.shape

In [None]:
import pandas as pd
def plot_model_comparison_accuracy(pred_by_model, colors=[ '#f99154', '#63bfa6', '#358cbb', '#3aac11', '#a89154', '#e0d57e'], metric='Dice', indicies=np.arange(len(CONTRASTS)), legend_size=16):
    """Grouped bar chart of a per-contrast score for several models.

    Parameters
    ----------
    pred_by_model : dict
        Model name -> scores indexed by CONTRASTS position. Arrays with more than
        one dimension are averaged over axis 0 first.
    colors : list of str
        Bar colours, truncated to one per model.
    metric : str
        Name of the score column and the y-axis label.
    indicies : sequence of int
        CONTRASTS positions to plot, in order.
    legend_size : int
        Legend font size.
    """
    colors = colors[:len(pred_by_model)]
    sns.set_palette(sns.color_palette(colors))

    # Build rows once; growing a DataFrame with .loc per row is quadratic.
    records = []
    for model, data in pred_by_model.items():
        data = np.asarray(data)
        if data.ndim > 1:
            data = np.mean(data, 0)
        for i in indicies:
            task, _cope_id, contrast_label = CONTRASTS[i]
            # Column named after `metric` so that y=metric below always resolves.
            records.append({"Model": model, "Task Contrast": f"{task} {contrast_label}", metric: data[i]})
    df = pd.DataFrame(records, columns=["Model", "Task Contrast", metric])

    width = 36 if len(indicies) == len(CONTRASTS) else 22
    fig, ax = plt.subplots(1, 1, figsize=(width, 10))
    sns.barplot(x="Task Contrast",
                y=metric, hue="Model",
                data=df, ax=ax,
                hue_order=list(pred_by_model.keys()))
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    ax.legend(frameon=False, ncol=len(pred_by_model), loc='upper center', bbox_to_anchor=(0.5, 1.1), fontsize=legend_size)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    y_bottom, y_top = 0, 0.45
    ax.set_ylim(y_bottom, y_top)
    ax.set_yticks(np.arange(y_bottom, y_top + 0.01, 0.1))
    ax.set_xlim(-1, len(indicies))
    ax.tick_params(direction="in", labelsize=24)
    ax.xaxis.get_label().set_fontsize(40)
    ax.yaxis.get_label().set_fontsize(40)
    ax.tick_params(length=10)
    plt.show()

In [None]:
(retest_dice.shape, tsnr_dices.shape)  # (subjects x contrasts, evaluated tSNR contrasts)

In [None]:
# tsnr_dices has one entry per contrast of the tSNR dice loop, but the plot indexes
# every model's scores by CONTRASTS position. Scatter it into a CONTRASTS-length
# array (NaN = not evaluated) and plot only the contrasts both curves cover.
tsnr_contrast_ids = [2, 28, 43, 46, 27]  # keep in sync with the tSNR dice loop
tsnr_dices_full = np.full(len(CONTRASTS), np.nan)
tsnr_dices_full[tsnr_contrast_ids] = tsnr_dices

models = {
    'groundtruth': retest_dice,
    'tsnr': tsnr_dices_full
}
plot_model_comparison_accuracy(pred_by_model=models, indicies=sorted(tsnr_contrast_ids))