In [None]:
%load_ext autoreload
%autoreload 2
import sys
import pandas as pd
import os
import matplotlib.pyplot as plt
import cortex
from os.path import join
from collections import defaultdict
import numpy as np
import joblib
from tqdm import tqdm
import sys
from copy import deepcopy
sys.path.append('../notebooks')
from neuro.config import repo_dir, PROCESSED_DIR
from neuro import viz
from neurosynth import term_dict, term_dict_rev, get_neurosynth_flatmaps
import viz

Note: before running this notebook, first run `03_export_qa_flatmaps.ipynb` to export the `df_qa_dict.pkl` files for each subject.

### Compute correlations with QA flatmaps and plot averages

In [None]:
# Correlate each subject's QA flatmaps with the Neurosynth flatmaps that share
# the same key.
#
# Names produced here and read by later cells:
#   corrs_df                      one row per (subject, question) with its correlation
#   flatmaps_qa_dicts_by_subject  subject -> {question: unmasked, setting-averaged QA flatmap}
#   frac_voxels_to_keep(_list)    masking threshold(s)
#   setting                       last entry of `settings` (used in output filenames)

apply_mask = True
frac_voxels_to_keep = 0.1  # fraction of voxels kept when apply_mask is True
frac_voxels_to_keep_list = [frac_voxels_to_keep]

# QA flatmaps are averaged over every setting listed here. Other setting names
# that appeared in earlier versions of this cell:
# 'full_neurosynth', 'individual_gpt4', 'individual_neurosynth'
settings = ['shapley_neurosynth']

corrs_df_list = defaultdict(list)
flatmaps_qa_dicts_by_subject = {}
for subject in tqdm(['UTS01', 'UTS02', 'UTS03']):
    # processed files live under the subject id without the 'UT' prefix
    subject_dir = join(PROCESSED_DIR, subject.replace('UT', ''))
    flatmaps_gt = get_neurosynth_flatmaps(subject)

    # collect every setting's flatmap per question, then average them
    flatmaps_qa_dict_over_settings = defaultdict(list)
    for setting in settings:
        flatmaps_qa_dict = joblib.load(join(subject_dir, setting + '.pkl'))
        for q in flatmaps_qa_dict.keys():
            flatmaps_qa_dict_over_settings[q].append(flatmaps_qa_dict[q])
    flatmaps_qa_dict = {
        q: np.mean(flatmaps_qa_dict_over_settings[q], axis=0)
        for q in flatmaps_qa_dict_over_settings.keys()
    }

    if apply_mask:
        corrs_test = joblib.load(
            join(subject_dir, 'corrs_test_35.pkl')).values[0]
        if frac_voxels_to_keep < 1:
            # keep voxels whose test correlation is above the
            # (1 - frac_voxels_to_keep) percentile
            corrs_test_mask = (corrs_test > np.percentile(
                corrs_test, 100 * (1 - frac_voxels_to_keep))).astype(bool)
        else:
            corrs_test_mask = np.ones_like(corrs_test, dtype=bool)
        flatmaps_qa_dict_masked = {k: flatmaps_qa_dict[k][corrs_test_mask]
                                   for k in flatmaps_qa_dict.keys()}
        flatmaps_gt_masked = {k: flatmaps_gt[k][corrs_test_mask]
                              for k in flatmaps_gt.keys()}
    else:
        # No masking requested: compare the full flatmaps. Without this branch
        # the *_masked names were undefined and the code below raised NameError.
        flatmaps_qa_dict_masked = flatmaps_qa_dict
        flatmaps_gt_masked = flatmaps_gt

    # keep only questions present in both sources, sorted by correlation
    common_keys = set(flatmaps_gt_masked.keys()) & set(
        flatmaps_qa_dict_masked.keys())
    d = defaultdict(list)
    for k in common_keys:
        d['questions'].append(k)
        d['corr'].append(np.corrcoef(flatmaps_qa_dict_masked[k],
                                     flatmaps_gt_masked[k])[0, 1])
        d['flatmap_qa'].append(flatmaps_qa_dict_masked[k])
        d['flatmap_neurosynth'].append(flatmaps_gt_masked[k])
    d = pd.DataFrame(d).sort_values('corr', ascending=False)

    corrs = viz._calc_corrs(
        d['flatmap_qa'].values,
        d['flatmap_neurosynth'].values,
        titles_qa=d['questions'].values,
        titles_gt=d['questions'].values,
    )

    # the diagonal of `corrs` is stored as each question's correlation
    corrs_df_list[f'corrs_{frac_voxels_to_keep}'].extend(
        np.diag(corrs).tolist())
    corrs_df_list['questions'].extend(d['questions'].values.tolist())
    corrs_df_list['subject'].extend([subject] * len(d['questions'].values))

    # Disabled: bar plot of the correlations.
    # viz.corr_bars(
    #     corrs,
    #     out_dir_save=join(repo_dir, 'qa_results', 'neurosynth', setting),
    #     xlab='Neurosynth',
    # )

    # Disabled: export one PNG per question for the QA and Neurosynth flatmaps.
    # The merged-flatmap cell at the end of this notebook reads PNGs from these
    # directories. NOTE: `sasc` is not imported in this notebook.
    # for i in tqdm(range(len(d))):
    #     sasc.viz.quickshow(
    #         d.iloc[i]['flatmap_qa'],
    #         subject=subject,
    #         fname_save=join(repo_dir, 'qa_results', 'neurosynth', subject,
    #                         setting, f'{d.iloc[i]["questions"]}.png')
    #     )

    #     sasc.viz.quickshow(
    #         d.iloc[i]['flatmap_neurosynth'],
    #         subject=subject,
    #         fname_save=join(repo_dir, 'qa_results', 'neurosynth', subject,
    #                         'neurosynth', f'{d.iloc[i]["questions"]}.png')
    #     )

    # store the unmasked, setting-averaged flatmaps for later cells
    flatmaps_qa_dicts_by_subject[subject] = deepcopy(flatmaps_qa_dict)
corrs_df = pd.DataFrame(corrs_df_list)
# Disabled: cache the table on disk (a later cell reads this file back).
# corrs_df.to_pickle(join(repo_dir, 'qa_results',
#    'neurosynth', setting + '_corrs_df.pkl'))

### Plot correlations in corrs_df

In [None]:
# Dot plot of per-question flatmap correlations: one marker series per subject
# plus the across-subject mean, with questions ordered by the mean correlation.
corr_col = f'corrs_{frac_voxels_to_keep}'
xlab = f'Flatmap correlation (Top-{int(100*frac_voxels_to_keep)}% best-predicted voxels)'
colors = {'UTS01': 'C0', 'UTS02': 'C1', 'UTS03': 'C2', 'mean': 'black'}

# add one 'mean' row per question and index everything by question
d_mean = corrs_df.groupby('questions')[corr_col].mean().to_frame().reset_index()
d_mean['subject'] = 'mean'
c = pd.concat([corrs_df, d_mean]).set_index('questions')

plt.figure(figsize=(7, 5))
for subject in ['mean', 'UTS01', 'UTS02', 'UTS03']:
    corrs_df_subject = c[c['subject'] == subject]
    if subject == 'mean':
        # 'mean' is handled first, so its ordering is reused for every subject
        idx_sort = corrs_df_subject[corr_col].sort_values(
            ascending=False).index
    corrs_df_subject = corrs_df_subject.loc[idx_sort]

    # the mean is drawn in black on top; subjects are semi-transparent and
    # take their colours from the default colour cycle
    if subject == 'mean':
        marker_kwargs = dict(color='black', zorder=1000,
                             label=subject.capitalize())
    else:
        marker_kwargs = dict(alpha=0.5, label=subject.upper())
    plt.errorbar(
        corrs_df_subject[corr_col],
        range(len(corrs_df_subject)),
        fmt='o',
        **marker_kwargs,
    )

    # dashed vertical line at this series' average correlation
    series_mean = corrs_df_subject[corr_col].mean()
    plt.axvline(series_mean, linestyle='--',
                color=colors[subject], zorder=-1000)

    print('mean corr', series_mean)

# label rows with the human-readable term for each question
plt.yticks(range(len(corrs_df_subject)), [term_dict_rev[k] for k in idx_sort])
plt.xlabel(xlab, fontsize='large')
plt.grid(axis='y', alpha=0.2)
plt.axvline(0, color='gray')

# make the x-axis symmetric around zero
abs_lim = np.abs(plt.xlim()).max()
plt.xlim(-abs_lim, abs_lim)

plt.legend(fontsize='large')
plt.tight_layout()
plt.savefig(
    join(repo_dir, 'qa_results', 'neurosynth', 'corrs_' + setting + '.png'),
    dpi=300,
)

### Compute pvals

In [None]:
# p-values for one subject's flatmap correlations, computed by
# viz.compute_pvals (presumably against a null built from the eng1000
# weights -- confirm in viz).
subject = 'UTS03'
corrs_df_subject = corrs_df[corrs_df['subject']
                            == subject].set_index('questions')

# Look the flatmaps up by subject. The previous version read the leftover loop
# variable `flatmaps_qa_dict`, which only matched when `subject` happened to be
# the last subject processed in the correlation cell.
flatmaps_qa_subject = flatmaps_qa_dicts_by_subject[subject]
flatmaps_qa_list_subject = [flatmaps_qa_subject[q]
                            for q in corrs_df_subject.index]

# path to the eng1000 weights file (passed under compute_pvals' `eng1000_dir` keyword)
eng1000_dir = join(PROCESSED_DIR, subject.replace(
    'UT', ''), 'eng1000_weights.pkl')
for frac_voxels_to_keep in tqdm(frac_voxels_to_keep_list):
    pvals = viz.compute_pvals(flatmaps_qa_list_subject, frac_voxels_to_keep,
                              corrs_df_subject[f'corrs_{frac_voxels_to_keep}'].values, eng1000_dir=eng1000_dir)
    corrs_df_subject[f'pval_{frac_voxels_to_keep}'] = pvals

# display the table sorted by the p-value of the last threshold processed
corrs_df_subject.sort_values(
    by=f'pval_{frac_voxels_to_keep}').style.background_gradient().format(precision=3)

### Merged flatmaps

In [30]:
# For every subject, list the QA flatmaps in the row order of
# `corrs_df_subject`, so position i refers to the same question in each list.
flatmaps_qa_dict_list_subjects = {
    subj: [flatmaps_qa_dicts_by_subject[subj][question]
           for question in corrs_df_subject.index]
    for subj in ('UTS01', 'UTS02', 'UTS03')
}

In [8]:
from cortex import mni
import os
# Point FSLDIR at an FSL install (presumably needed by cortex.mni -- confirm).
# setdefault keeps an FSLDIR that is already configured in the environment and
# only falls back to this machine-specific path when none is set.
os.environ.setdefault("FSLDIR", "/home/chansingh/fsl")

In [36]:
# first QA flatmap in UTS01's list (lists follow the order of corrs_df_subject.index)
arr = flatmaps_qa_dict_list_subjects['UTS01'][0]

In [37]:
# wrap the flattened flatmap in a cortex.Volume for subject UTS01 with the 'UTS01_auto' transform
vol = cortex.Volume(data=arr.flatten(), subject='UTS01', xfmname='UTS01_auto')

In [38]:
# fetch the MNI transform for UTS01 / 'UTS01_auto' from the cortex database
# (presumably a previously computed, cached transform -- confirm)
s1_to_mni_cached = cortex.db.get_mnixfm('UTS01', 'UTS01_auto')

In [39]:
# apply the transform to the volume via cortex.mni.transform_to_mni
mni_data = mni.transform_to_mni(vol, s1_to_mni_cached)

In [40]:
# pull the image data out of the transformed result
mni_data_vol = mni_data.get_fdata()  # the actual array, shape=(182,218,182)

### Look at merged flatmaps

In [None]:
# Replace the in-memory table with the copy cached on disk for `setting`.
# NOTE(review): the to_pickle call that writes this file is commented out in
# the correlation cell, so the file must already exist from an earlier run.
corrs_df_path = join(repo_dir, 'qa_results', 'neurosynth',
                     setting + '_corrs_df.pkl')
corrs_df = pd.read_pickle(corrs_df_path)

In [None]:
# For each subject, tile the saved flatmap PNGs into one summary figure:
# each panel shows the Neurosynth image (left) next to the QA image (right),
# titled with the term and its correlation.
setting = 'shapley_neurosynth'
C = 4  # panels per row
for subject in ['UTS01', 'UTS02', 'UTS03']:
    img_dir1 = join(repo_dir, 'qa_results', 'neurosynth',
                    subject, 'neurosynth')
    img_dir2 = join(repo_dir, 'qa_results', 'neurosynth',
                    subject, setting)

    # one panel per row of this subject's correlation table; the PNGs are
    # expected to be named '<question>.png' in both directories
    corrs = corrs_df[corrs_df['subject'] == subject]
    fnames = [v + '.png' for v in corrs['questions'].values]

    # The table built in this notebook stores correlations under
    # 'corrs_<frac_voxels_to_keep>'; a plain 'corrs' column is still honoured
    # if present. Indexing 'corrs' unconditionally raised KeyError.
    corr_col = 'corrs' if 'corrs' in corrs.columns else f'corrs_{frac_voxels_to_keep}'
    corr_values = corrs[corr_col].values

    n = len(fnames)
    # at least one row, so plt.subplots does not fail when a subject has no rows
    R = max(1, int(np.ceil(n / C)))

    fig, axs = plt.subplots(R, C, figsize=(C * 3.2, R * 1))
    axs = axs.flatten()
    for ax in axs:
        ax.axis('off')
    for i, fname in enumerate(fnames):
        img1 = plt.imread(join(img_dir1, fname))
        img2 = plt.imread(join(img_dir2, fname))
        # place the two images side by side in one panel
        axs[i].imshow(np.concatenate([img1, img2], axis=1))
        axs[i].set_title(
            f'{term_dict_rev[fname[:-4]]} ({corr_values[i]:0.3f})', fontsize=8)

    # add text in bottom right of figure
    fig.text(0.99, 0.01, f'{subject}\nNeurosynth on left, QA on right',
             ha='right', va='bottom', fontsize=8)
    fig.tight_layout()
    fig.savefig(join(repo_dir, 'qa_results', 'neurosynth',
                subject, f'flatmaps_{setting}_{subject}.png'), dpi=300)
    # close this figure explicitly so figures do not accumulate across subjects
    plt.close(fig)