In [None]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import sys
from typing import List
import numpy as np
from copy import deepcopy
import joblib
from pprint import pprint
from math import ceil
import cortex
from neuro.config import repo_dir, PROCESSED_DIR, setup_freesurfer
from collections import defaultdict
from scipy.stats import norm
from statsmodels.stats.multitest import multipletests
import gemv
from neuro.flatmaps_helper import load_flatmaps
import sasc.viz
from neuro import analyze_helper
import nibabel as nib
neurosynth_compare = __import__('04_neurosynth_compare')
import neurosynth
from neuro.features.questions.gpt4 import QS_35_STABLE
setup_freesurfer()

# Every subject except S06 (its surf2surf mapping has a known problem).
subjects = [f'S0{i}' for i in range(1, 9) if i != 6]
# subjects = ['S01', 'S02', 'S03']
# Default subject used by the single-subject cells below.
subject = 'S02'

### Single-subject flatmaps: Neurosynth MNI maps projected into subject space

In [None]:
# Project Neurosynth association-test z maps (MNI / fsaverage space) onto `subject`'s
# cortex, save a flatmap per term in both spaces, then save the across-term average.
# NOTE(review): defaults to the original author's mount; set NEUROSYNTH_Z_DIR to override.
neurosynth_z_dir = os.environ.get(
    'NEUROSYNTH_Z_DIR',
    '/home/chansingh/mntv1/deep-fMRI/qa/neurosynth_data/all_association-test_z',
)
intersubject_dir = 'intersubject'
os.makedirs(intersubject_dir, exist_ok=True)

subj_vols = []
for term in ['place', 'location']:  # , 'locations']:
    mni_filename = join(neurosynth_z_dir, f'{term}_association-test_z.nii.gz')
    mni_vol = cortex.Volume(mni_filename, "fsaverage", "atlas_2mm")
    subj_vol, subj_arr = neurosynth.mni_vol_to_subj_vol_surf(
        mni_vol, subject=subject)
    print('mni shape', mni_vol.shape, 'subj shape',
          subj_vol.shape, 'subj_arr shape', subj_arr.shape)

    sasc.viz.quickshow(
        subj_vol.data,
        subject='UT' + subject,
        fname_save=join(intersubject_dir, f'{term}_subj.png'),
    )
    sasc.viz.quickshow(
        mni_vol,
        subject='fsaverage',
        fname_save=join(intersubject_dir, f'{term}_mni.png'),
    )
    subj_vols.append(subj_vol.data)

# Element-wise mean of the subject-space maps across all terms above.
sasc.viz.quickshow(
    np.array(subj_vols).mean(axis=0),
    subject='UT' + subject,
    fname_save=join(intersubject_dir, 'avg_subj.png'),
)

In [None]:
# Export one flatmap per question for each subject, as PNG (read back by the grid
# cell below) and as PDF. For each question, the maps from all `settings` are averaged.
import matplotlib.image as mpimg  # used by the grid cell below

qs = QS_35_STABLE
# other settings previously used here: 'individual_gpt4', 'individual_gpt4_wordrate', 'shapley_35'
settings = ['individual_gpt4_pc_new']
export_exts = ['.png', '.pdf']
for subject in ['S02']:  # add 'S01', 'S03' so the grid cell finds their PNGs
    # Load each settings pickle once per subject rather than once per question.
    flatmaps_by_setting = [
        joblib.load(join(PROCESSED_DIR, subject.replace('UT', ''), setting + '.pkl'))
        for setting in settings
    ]
    export_dir = join(repo_dir, 'qa_results', 'figs', 'flatmaps_export', subject)
    os.makedirs(export_dir, exist_ok=True)
    for q in tqdm(qs):
        flatmap_q = np.mean(
            [flatmaps[q] for flatmaps in flatmaps_by_setting], axis=0)
        for ext in export_exts:
            sasc.viz.quickshow(
                flatmap_q,
                subject='UT' + subject,
                fname_save=join(export_dir, q + ext),
                with_colorbar=False,
            )

In [7]:
# save a RedBlue horizontal colorbar
from matplotlib import cm
from matplotlib.colors import Normalize


def save_colorbar(cmap, norm, orientation='horizontal', filename='colorbar.png',
                  label='Normalized coefficient', figsize=None):
    """Render a standalone colorbar and save it to an image file.

    Parameters
    ----------
    cmap : matplotlib colormap
        Colormap drawn in the bar.
    norm : matplotlib.colors.Normalize
        Maps data values onto the colormap; sets the tick range.
    orientation : {'horizontal', 'vertical'}
        Direction of the bar.
    filename : str
        Output path. Saved at 300 dpi with a transparent background.
    label : str
        Text label drawn along the colorbar.
    figsize : tuple of float or None
        Figure size in inches. If None, a thin strip matching `orientation`
        is used: (2.5, 0.5) when horizontal, (0.5, 2.5) otherwise.
    """
    if figsize is None:
        figsize = (2.5, 0.5) if orientation == 'horizontal' else (0.5, 2.5)
    fig, ax = plt.subplots(figsize=figsize)
    # Shrink the bar's thickness; bbox_inches='tight' keeps ticks/label in frame.
    if orientation == 'horizontal':
        fig.subplots_adjust(bottom=0.5)
    else:
        fig.subplots_adjust(right=0.5)
    cbar = fig.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap),
                        cax=ax, orientation=orientation)
    cbar.set_label(label)
    # Save this figure explicitly rather than whatever pyplot considers "current".
    fig.savefig(filename, bbox_inches='tight', dpi=300, transparent=True)
    plt.close(fig)


# Red-blue colorbar over a symmetric [-1, 1] range.
# Named `cbar_norm` (not `norm`) so it does not shadow `scipy.stats.norm`,
# which is imported in the first cell.
cmap = cm.RdBu_r
cbar_norm = Normalize(vmin=-1, vmax=1)
save_colorbar(cmap, cbar_norm, orientation='horizontal',
              filename='colorbar_norm.png')

In [None]:
# Tile the per-question PNG flatmaps into one grid figure per subject and save it as PDF.
# Requires `qs` and `mpimg` from the export cell above, and a `<q>.png` for every
# question of every subject listed here.
C = 4
R = ceil(len(qs) / C)  # enough rows for every question
for subject in ['S01', 'S02', 'S03']:
    export_dir = join(repo_dir, 'qa_results', 'figs', 'flatmaps_export', subject)
    fig, axs = plt.subplots(
        R, C, figsize=(12 * 1.7 * C / 5, 12 * R / 7), squeeze=False)
    for ax, q in zip(axs.flat, qs):
        ax.imshow(mpimg.imread(join(export_dir, q + '.png')))
        ax.set_title(analyze_helper.abbrev_question(q))
    # Hide axes on every panel, including unused trailing slots.
    for ax in axs.flat:
        ax.axis('off')
    fig.tight_layout()
    fig.savefig(join(repo_dir, 'qa_results', 'figs',
                     'flatmaps_export', subject + '_grid.pdf'))

In [None]:
# Export the S02 GEMV flatmap for the key `q_tup`.
# The output folder is derived from `q_tup`, not from the loop variable `q` left
# over from earlier cells (whose value depends on which cell ran last).
gemv_flatmaps_dict_S02, gemv_flatmaps_dict_S03 = load_flatmaps(
    normalize_flatmaps=False, load_timecourse=False)
q_tup = ('locations', 368)
gemv_export_dir = join(repo_dir, 'qa_results', 'figs',
                       'flatmaps_export', q_tup[0])
os.makedirs(gemv_export_dir, exist_ok=True)
sasc.viz.quickshow(
    gemv_flatmaps_dict_S02[q_tup],
    subject='UTS02',
    fname_save=join(gemv_export_dir, 'gemv_S02.png'),
    # cmap='RdYlBu_r',
)