In [1]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import sys
from typing import List
import numpy as np
import joblib
from pprint import pprint
from math import ceil
import cortex
from sasc.config import CACHE_DIR, RESULTS_DIR, cache_ngrams_dir, regions_idxs_dir, FMRI_DIR, SAVE_DIR_FMRI
from neuro.config import repo_dir, PROCESSED_DIR
from collections import defaultdict
from scipy.stats import norm
from statsmodels.stats.multitest import multipletests
import gemv
from neuro.flatmaps_helper import load_flatmaps
import sasc.viz
import nibabel as nib

# Subject identifier (written without the 'UT' prefix) that selects which
# processed flatmaps the loading cell below reads.
subject = 'S02'

In [2]:
# Choose which pickled flatmap settings get averaged for this subject.
if subject in ['S01', 'S02', 'S03']:
    settings = ['individual_gpt4', 'individual_gpt4_wordrate', 'shapley_35']
else:
    settings = ['individual_gpt4_ndel=1_pc_new']

# Gather, for every question, one flatmap per setting ...
flatmaps_qa_list = defaultdict(list)
subject_dir = join(PROCESSED_DIR, subject.replace('UT', ''))
for setting in settings:
    flatmaps_qa_dict = joblib.load(join(subject_dir, setting + '.pkl'))
    for q, flatmap in flatmaps_qa_dict.items():
        flatmaps_qa_list[q].append(flatmap)

# ... then collapse the per-setting flatmaps into their mean.
flatmaps_qa_dict = {
    question: np.mean(flatmaps, axis=0)
    for question, flatmaps in flatmaps_qa_list.items()
}

In [3]:
# Copy of the flatmap stored under the first question key.
first_question = list(flatmaps_qa_dict)[0]
flatmap_subject = flatmaps_qa_dict[first_question].copy()

### FreeSurfer transform

In [4]:
# Environment variables pointing at the FreeSurfer / FSL installation.
# All locations are resolved relative to the current user's home directory.
_freesurfer_home = os.path.expanduser('~/freesurfer')
FREESURFER_VARS = {
    'FREESURFER_HOME': _freesurfer_home,
    'FSL_DIR': os.path.expanduser('~/fsl'),
    'FSFAST_HOME': os.path.expanduser('~/freesurfer/fsfast'),
    'MNI_DIR': os.path.expanduser('~/freesurfer/mni'),
    # alternative location: join(repo_dir, 'notebooks_gt_flatmaps')
    'SUBJECTS_DIR': os.path.expanduser('~/freesurfer/subjects'),
}
# Export every entry into this process's environment.
os.environ.update(FREESURFER_VARS)

In [16]:
def mni_vol_to_subj_vol(
    mni_vol,
    subject='S02',
    pycortex_db_dir='/home/chansingh/mntv1/deep-fMRI/data/ds003020/derivative/pycortex-db/',
):
    """Carry an MNI-space volume into a subject's volume via the fsaverage surface.

    The volume is sampled onto the fsaverage surface, each hemisphere is
    pushed through the fsaverage -> subject surf2surf matrix, and the
    resulting subject-surface data is mapped back into a subject volume.

    Params
    ------
    mni_vol
        Volume handed to the ("fsaverage", "atlas_2mm") mapper.
    subject: str
        Subject ID without the 'UT' prefix; 'UT' is prepended for the
        pycortex mapper, vertex and mask lookups.
    pycortex_db_dir: str
        Directory joined with 'UT<subject>/transforms/UT<subject>_auto/' to
        form the second argument of the subject mapper.

    Returns
    -------
    subject_volume
        Output of the subject mapper's `backwards` call.
    masked_values
        `subject_volume.data` indexed by the 'UT<subject>_auto' mask.
    """
    pycortex_subject = 'UT' + subject

    # MNI volume -> fsaverage surface
    fsaverage_surface = cortex.get_mapper("fsaverage", "atlas_2mm")(mni_vol)

    subject_mapper = cortex.get_mapper(
        pycortex_subject,
        join(pycortex_db_dir, f'UT{subject}/transforms/UT{subject}_auto/'),
    )

    # NOTE(review): target_subj receives the bare `subject` (no 'UT' prefix),
    # unlike the other pycortex calls here -- confirm this matches the
    # subject directory name under $SUBJECTS_DIR.
    left_xfm, right_xfm = cortex.db.get_mri_surf2surf_matrix(
        subject="fsaverage",
        surface_type="pial",
        target_subj=subject,
    )

    # fsaverage surface -> subject surface, left hemisphere first
    hemispheres = [
        left_xfm @ fsaverage_surface.left,
        right_xfm @ fsaverage_surface.right,
    ]
    subject_surface = cortex.Vertex(np.hstack(hemispheres), pycortex_subject)

    # subject surface -> subject volume
    subject_volume = subject_mapper.backwards(subject_surface)
    mask = cortex.db.get_mask(pycortex_subject, pycortex_subject + '_auto')
    return subject_volume, subject_volume.data[mask]


# Example: transform the association-test z-map for one term into S02's space.
term = 'location'
mni_filename = f'/home/chansingh/mntv1/deep-fMRI/qa/neurosynth_data/all_association-test_z/{term}_association-test_z.nii.gz'
mni_vol = cortex.Volume(mni_filename, "fsaverage", "atlas_2mm")
subj_vol, subj_arr = mni_vol_to_subj_vol(mni_vol, subject='S02')

# Report the shapes of the input volume and of both outputs.
print(
    'mni shape', mni_vol.shape,
    'subj shape', subj_vol.shape,
    'subj_arr shape', subj_arr.shape,
)

Caching mapper...
mni shape (91, 109, 91) subj shape (54, 84, 84) subj_arr shape (94251,)


In [12]:
# Render the transformed volume for UTS02 and save the figure.
# (join() with a single argument returns it unchanged, so the path is passed directly.)
sasc.viz.quickshow(
    subj_vol.data,
    subject='UTS02',
    fname_save='viz/mni_to_subj.png',
)

  def imshow_diverging(mat, clab="Mean response ($\sigma$)", clab_size='medium', vabs_multiplier=1):
  )
  plt.grid(axis='y')
  plt.grid(axis='y')
  plt.grid(axis='y')
  


In [13]:
# Render the original MNI volume on fsaverage for comparison.
# (join() with a single argument returns it unchanged, so the path is passed directly.)
sasc.viz.quickshow(
    mni_vol,
    subject='fsaverage',
    fname_save='viz/mni.png',
)

In [None]:
## TEMPLATE -- not runnable as written: every <...> placeholder below must be
## replaced with a real value before this cell can execute.
##
## Prerequisite: download the individual subjects' FreeSurfer directories and
## untar them under $SUBJECTS_DIR (usually ~/freesurfer/subjects).

subjects = [f"S0{ii}" for ii in np.arange(8)+1]
xfms = <however you usually load subj xfms>
terms = <list of terms>

## Get the fsaverage -> subject surface transforms.
## *** YOU'LL NEED FSL FOR THIS STEP ***
## Layout: {subject: (left_hemisphere_xfm, right_hemisphere_xfm)}
fs2subj_xfms = {}
for subject in subjects:
    ## The first argument is the source subject; keep "pial" as the surface type.
    fs2subj_xfms[subject] = cortex.db.get_mri_surf2surf_matrix("fsaverage", "pial", 
                                                        target_subj=subject)

## Volume -> surface mapper for the MNI ("atlas_2mm") volumes on fsaverage.
fsmapper = cortex.get_mapper("fsaverage", "atlas_2mm")

## Results, layout: {subject: {term: cortex.Vertex}} and {subject: {term: volume}}
subj_term_surfs = {}
subj_term_vols = {}
for subject, (ltrans, rtrans) in fs2subj_xfms.items(): 
    subjmapper = cortex.get_mapper(subject, xfms[subject])

    ## Per-subject results, layout: {term: data}
    subj_surfs = {}
    subj_vols = {}
    for term in terms:
        
        ## Map the MNI volume for this term onto the fsaverage surface.
        mnivol = cortex.Volume(<filename>, "fsaverage", "atlas_2mm")
        fssurf = fsmapper(mnivol)

        ## Apply each hemisphere's transform (xtrans @ surfdata), stack left then
        ## right, and wrap the result in a cortex.Vertex for the subject.
        subjsurf = cortex.Vertex(np.hstack([ltrans@fssurf.left, rtrans@fssurf.right]), subject)
        ## Map the subject-surface data back into a subject volume.
        subjvol = subjmapper.backwards(subjsurf)
        
        subj_surfs[term] = subjsurf
        subj_vols[term] = subjvol
#         break
        
    subj_term_surfs[subject] = subj_surfs
    subj_term_vols[subject] = subj_vols
#     break

# Save flatmaps

In [None]:
# Export the QA flatmap for a single question, one figure per subject.
# Note: this loop rebinds the notebook-level `subject` variable.
q = 'Does the sentence mention a specific location?'
for subject in ['S03']:  # ['S02']: #['S01', 'S02', 'S03']:
    # other settings available for S03: 'individual_gpt4', 'individual_gpt4_wordrate'
    settings = (
        ['shapley_35']
        if subject in ['S03']
        else ['individual_gpt4_ndel=1_pc_new']
    )

    # Collect this question's flatmap from every setting, then average them.
    flatmaps_qa_list = defaultdict(list)
    for setting in settings:
        pkl_path = join(PROCESSED_DIR, subject.replace('UT', ''), setting + '.pkl')
        flatmaps_qa_dict = joblib.load(pkl_path)
        flatmaps_qa_list[q].append(flatmaps_qa_dict[q])
    flatmaps_qa_dict = {
        question: np.mean(flatmaps, axis=0)
        for question, flatmaps in flatmaps_qa_list.items()
    }

    print('visualizing...')
    fig_path = join(
        repo_dir, 'qa_results', 'figs', 'flatmaps_export', q, f'QA_{subject}.png')
    sasc.viz.quickshow(
        flatmaps_qa_dict[q],
        subject='UT'+subject,
        fname_save=fig_path,
        # cmap='RdYlBu_r',
    )

In [None]:
# Load the GEM-V flatmaps for S02 and S03; only the S02 one is exported here.
gemv_flatmaps_dict_S02, gemv_flatmaps_dict_S03 = load_flatmaps(
    normalize_flatmaps=False,
    load_timecourse=False,
)
q_tup = ('locations', 368)

# `q` (the question string from the QA export cell) names the output folder,
# so this figure is saved next to the QA flatmap for the same question.
gemv_fig_path = join(
    repo_dir, 'qa_results', 'figs', 'flatmaps_export', q, 'gemv_S02.png')
sasc.viz.quickshow(
    gemv_flatmaps_dict_S02[q_tup],
    subject='UTS02',
    fname_save=gemv_fig_path,
    # cmap='RdYlBu_r',
)

In [71]:
# Quick visual check of the MNI -> subject volume for UTS02.
# Uses `subj_vol`, the volume returned by mni_vol_to_subj_vol(...) earlier in
# the notebook. The previous name `subjvol` is only ever bound inside that
# function (and in the non-runnable template cell), so it raised NameError on a
# fresh kernel.
sasc.viz.quickshow(
    subj_vol.data,
    subject='UTS02',
    fname_save=join('viz/test.png'),
)