In [None]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import sys
from typing import List
import numpy as np
import joblib
from pprint import pprint
from math import ceil
import cortex
from sasc.config import CACHE_DIR, RESULTS_DIR, cache_ngrams_dir, regions_idxs_dir, FMRI_DIR, SAVE_DIR_FMRI
from neuro.config import repo_dir, PROCESSED_DIR
from collections import defaultdict
from scipy.stats import norm
from statsmodels.stats.multitest import multipletests
import gemv
from neuro.flatmaps_helper import load_flatmaps
import sasc.viz
import nibabel as nib
import neurosynth


def prepend_to_path(entry: str, current: str) -> str:
    """Return the search path `current` with `entry` placed first, exactly once.

    Args:
        entry: Directory to put at the front of the search path.
        current: Existing search path, entries separated by ``os.pathsep``
            (may be empty).

    Returns:
        The new search path. Existing occurrences of `entry` are dropped
        before prepending, so calling this again on its own output (e.g. when
        this cell is re-run) does not make the path grow.
    """
    kept = [p for p in current.split(os.pathsep) if p != entry] if current else []
    return os.pathsep.join([entry] + kept)


# Environment variables pointing at the per-user FreeSurfer / FSL install
FREESURFER_VARS = {
    'FREESURFER_HOME': os.path.expanduser('~/freesurfer'),
    'FSL_DIR': os.path.expanduser('~/fsl'),
    'FSFAST_HOME': os.path.expanduser('~/freesurfer/fsfast'),
    'MNI_DIR': os.path.expanduser('~/freesurfer/mni'),
    # alternative: 'SUBJECTS_DIR': join(repo_dir, 'notebooks_gt_flatmaps'),
    'SUBJECTS_DIR': os.path.expanduser('~/freesurfer/subjects'),
    # put the freesurfer bin directory first on PATH; tolerate an unset PATH
    'PATH': prepend_to_path(
        os.path.expanduser('~/freesurfer/bin'), os.environ.get('PATH', '')),
}
os.environ.update(FREESURFER_VARS)

# Subject used by the cells below
subject = 'S02'

### Load flatmaps

In [None]:
# Choose which processed settings to load (and average) for this subject
settings = (
    ['individual_gpt4', 'individual_gpt4_wordrate', 'shapley_35']
    if subject in ['S01', 'S02', 'S03']
    else ['individual_gpt4_ndel=1_pc_new']
)

# Collect, per question, one flatmap from every setting's pickle
subject_dir = join(PROCESSED_DIR, subject.replace('UT', ''))
flatmaps_qa_list = defaultdict(list)
for setting in settings:
    flatmaps_qa_dict = joblib.load(join(subject_dir, setting + '.pkl'))
    for q, flatmap in flatmaps_qa_dict.items():
        flatmaps_qa_list[q].append(flatmap)

# Average the collected flatmaps of each question across settings
flatmaps_qa_dict = {
    question: np.mean(flatmaps, axis=0)
    for question, flatmaps in flatmaps_qa_list.items()
}

In [None]:
# Take a copy of the first question's averaged flatmap as an example volume
first_question = list(flatmaps_qa_dict)[0]
flatmap_subject = flatmaps_qa_dict[first_question].copy()

In [31]:
# Wrap the example flatmap in a cortex.Volume for this subject, then convert
# it with the neurosynth helper and compare the two volume shapes.
pycortex_subject = 'UT' + subject
subj_vol = cortex.Volume(
    flatmap_subject,
    pycortex_subject,
    xfmname=f"UT{subject}_auto",
)
mni_vol = neurosynth.subj_vol_to_mni_surf(subj_vol, subject)
print(f'subj shape {subj_vol.shape} mni shape {mni_vol.shape}')

Caching mapper...
subj shape (54, 84, 84) mni shape (91, 109, 91)


In [None]:
# Load the association-test z-map for one term as an fsaverage volume, convert
# it to the subject's space, and save a flatmap image of both versions.
term = 'locations'
# NOTE(review): absolute, user-specific path; consider moving it to a config value
neurosynth_dir = '/home/chansingh/mntv1/deep-fMRI/qa/neurosynth_data/all_association-test_z'
mni_filename = join(neurosynth_dir, f'{term}_association-test_z.nii.gz')
mni_vol = cortex.Volume(mni_filename, "fsaverage", "atlas_2mm")
subj_vol, subj_arr = neurosynth.mni_vol_to_subj_vol_surf(
    mni_vol, subject=subject)
print(
    f'mni shape {mni_vol.shape} subj shape {subj_vol.shape} '
    f'subj_arr shape {subj_arr.shape}'
)

# (what to plot, pycortex subject name, suffix of the output file)
exports = [
    (subj_vol.data, 'UT' + subject, 'subj'),
    (mni_vol, 'fsaverage', 'mni'),
]
for vol_to_show, pycortex_subject, suffix in exports:
    sasc.viz.quickshow(
        vol_to_show,
        subject=pycortex_subject,
        fname_save=f'viz/{term}_{suffix}.png',
    )

Caching mapper...
mni shape (91, 109, 91) subj shape (54, 84, 84) subj_arr shape (94251,)


# Save flatmaps

In [None]:
# Export the flatmap of a single question for each listed subject.
q = 'Does the sentence mention a specific location?'
export_dir = join(repo_dir, 'qa_results', 'figs', 'flatmaps_export', q)

# NOTE(review): this loop rebinds the notebook-level `subject`, so cells run
# afterwards see the last subject listed here.
# Other subject lists used previously: ['S02'], ['S01', 'S02', 'S03']
for subject in ['S03']:
    # Full setting list used previously for S03:
    # ['individual_gpt4', 'individual_gpt4_wordrate', 'shapley_35']
    settings = (
        ['shapley_35'] if subject in ['S03']
        else ['individual_gpt4_ndel=1_pc_new']
    )

    # Gather this question's flatmap from every setting, then average them
    subject_dir = join(PROCESSED_DIR, subject.replace('UT', ''))
    flatmaps_qa_list = defaultdict(list)
    for setting in settings:
        flatmaps_qa_dict = joblib.load(join(subject_dir, setting + '.pkl'))
        flatmaps_qa_list[q].append(flatmaps_qa_dict[q])
    flatmaps_qa_dict = {
        question: np.mean(flatmaps, axis=0)
        for question, flatmaps in flatmaps_qa_list.items()
    }

    print('visualizing...')
    sasc.viz.quickshow(
        flatmaps_qa_dict[q],
        subject='UT' + subject,
        fname_save=join(export_dir, f'QA_{subject}.png'),
        # cmap='RdYlBu_r',
    )

In [None]:
# Load the GEM-V flatmaps (unnormalized, without timecourses) and export the
# S02 map stored under the ('locations', 368) key.
gemv_flatmaps_dict_S02, gemv_flatmaps_dict_S03 = load_flatmaps(
    normalize_flatmaps=False,
    load_timecourse=False,
)
q_tup = ('locations', 368)
gemv_flatmap = gemv_flatmaps_dict_S02[q_tup]
# The output folder is named after `q`, which is set by the QA export cell
gemv_fname = join(
    repo_dir, 'qa_results', 'figs', 'flatmaps_export', q, 'gemv_S02.png')
sasc.viz.quickshow(
    gemv_flatmap,
    subject='UTS02',
    fname_save=gemv_fname,
    # cmap='RdYlBu_r',
)

In [None]:
# Quick visual check of the subject-space volume computed earlier.
# The volume is bound as `subj_vol` in this notebook; the previous spelling
# `subjvol` is never defined and raised NameError on a fresh kernel.
sasc.viz.quickshow(
    subj_vol.data,
    subject='UTS02',
    fname_save='viz/test.png',
)