In [41]:
%load_ext autoreload
%autoreload 2
import sys
import pandas as pd
import os
import matplotlib.pyplot as plt
import cortex
import seaborn as sns
from os.path import join
from collections import defaultdict
import numpy as np
from sklearn.preprocessing import StandardScaler
import joblib
import dvu
from copy import deepcopy
import sys
sys.path.append('../notebooks')
from tqdm import tqdm
from sasc.config import FMRI_DIR, STORIES_DIR, RESULTS_DIR, CACHE_DIR, RESULTS_DIR, cache_ngrams_dir, regions_idxs_dir
from neuro.config import repo_dir, PROCESSED_DIR
from neuro.features.qa_questions import get_questions, get_merged_questions_v3_boostexamples
import sasc.viz
from sasc import config
# Total voxel count per subject; sets the length of each dense per-subject voxel mask.
VOX_COUNTS = dict(
    S02=94251,
    S03=95556,
)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Load avg-response flatmaps

In [49]:
# S02
# subject = 'S02'
# gemv_flatmaps_default = joblib.load(join(
#     RESULTS_DIR, "processed", "flatmaps", 'resps_avg_dict_pilot.pkl'))
# gemv_flatmaps_roi_qa = joblib.load(join(
#     RESULTS_DIR, "processed", "flatmaps", 'resps_avg_dict_pilot5.pkl'))
# gemv_flatmaps_roi_custom = joblib.load(join(
#     RESULTS_DIR, "processed", "flatmaps_all", 'UTS02', 'roi_pilot7', 'resps_avg_dict_pilot6.pkl'))
# gemv_flatmaps_dict = gemv_flatmaps_default | gemv_flatmaps_roi_qa | gemv_flatmaps_roi_custom


# S03
subject = 'S03'
# NOTE: the default-pilot flatmaps are loaded but currently NOT merged into
# gemv_flatmaps_dict (see the commented-out union below).
gemv_flatmaps_default = joblib.load(join(
    RESULTS_DIR, "processed", "flatmaps_all", 'UTS03', 'default', 'resps_avg_dict_pilot3.pkl'))
gemv_flatmaps_roi_custom1 = joblib.load(join(
    RESULTS_DIR, "processed", "flatmaps_all", 'UTS03', 'roi_pilot7', 'resps_avg_dict_pilot7.pkl'))
gemv_flatmaps_roi_custom2 = joblib.load(join(
    RESULTS_DIR, "processed", "flatmaps_all", 'UTS03', 'roi_pilot8', 'resps_avg_dict_pilot8.pkl'))
# gemv_flatmaps_dict = gemv_flatmaps_default | gemv_flatmaps_roi_custom1 | gemv_flatmaps_roi_custom2
# Later pilots win on duplicate keys (dict union keeps the right-hand value).
gemv_flatmaps_dict = gemv_flatmaps_roi_custom1 | gemv_flatmaps_roi_custom2

# Optionally z-score each flatmap over all of its entries. This builds a new
# dict rather than mutating the one being iterated.
normalize_flatmaps = False
if normalize_flatmaps:
    gemv_flatmaps_dict = {
        k: (flatmap - flatmap.mean()) / flatmap.std()
        for k, flatmap in gemv_flatmaps_dict.items()
    }

### Group regions to analyze into a dictionary of voxel masks
`rois_dict` maps each region name to a NumPy array that masks that region's voxels (entries > 0 belong to the region).

In [50]:
def load_custom_rois(subject, suffix_setting='_fedorenko'):
    '''Load a family of custom ROIs for a subject as a dict of name -> voxel mask.

    Params
    ------
    subject: str
        'S02' or 'S03'
    suffix_setting: str
        '' - load custom communication rois (S02 or S03)
        '_fedorenko' - load fedorenko rois (only available for S03)
        '_spotlights' - load spotlights rois (there are a ton of these)

    Returns
    -------
    dict
        '' -> {'comm0', ..., 'comm3'}: each value is the element-wise sum of
            two raw communication-roi masks, so overlapping voxels exceed 1
            (threshold with `> 0` to get a binary mask).
        '_fedorenko' -> {'fed0', 'fed1', ...}
        '_spotlights' -> {'spot0', 'spot1', ...}

    Raises
    ------
    ValueError
        If suffix_setting is not recognized, or the requested roi family is
        not available for this subject.
    '''
    # Subject-specific pairs of raw communication-roi indexes; each pair is
    # merged into a single 'comm{i}' roi.
    comm_pairs = {
        'S02': [[0, 7], [3, 4], [1, 5], [2, 6]],
        'S03': [[0, 7], [3, 4], [2, 5], [1, 6]],
    }
    fedorenko_subjects = {'S03'}

    if suffix_setting == '':
        # validate before touching the filesystem
        if subject not in comm_pairs:
            raise ValueError(
                f'communication rois are only defined for {sorted(comm_pairs)}, got {subject!r}')
        rois = joblib.load(join(FMRI_DIR, 'brain_tune/voxel_neighbors_and_pcs/',
                                f'communication_rois_v2_UT{subject}.jbl'))
        return {
            'comm' + str(i): np.vstack([rois[j] for j in idxs]).sum(axis=0)
            for i, idxs in enumerate(comm_pairs[subject])
        }
    elif suffix_setting == '_fedorenko':
        if subject not in fedorenko_subjects:
            raise ValueError(
                f'fedorenko rois are only available for {sorted(fedorenko_subjects)}, got {subject!r}')
        rois_fedorenko = joblib.load(join(
            FMRI_DIR, 'brain_tune/voxel_neighbors_and_pcs/', f'lang_localizer_UT{subject}.jbl'))
        return {
            'fed' + str(i): rois_fedorenko[i] for i in range(len(rois_fedorenko))
        }
    elif suffix_setting == '_spotlights':
        # NOTE(review): this path is relative to the current working directory,
        # unlike the other roi families which live under FMRI_DIR.
        rois_spotlights = joblib.load(f'all_spotlights_UT{subject}.jbl')
        # the last element of each spotlight entry is used as its mask
        # (presumably the mask itself; confirm against the file's format)
        return {'spot' + str(i): rois_spotlights[i][-1]
                for i in range(len(rois_spotlights))}
    else:
        raise ValueError(
            f"unknown suffix_setting {suffix_setting!r}; expected '', '_fedorenko' or '_spotlights'")


def load_known_rois(subject):
    '''Load the known ROIs for a subject and expand them into dense voxel masks.

    Params
    ------
    subject: str
        subject id; must be a key of VOX_COUNTS (e.g. 'S02', 'S03')

    Returns
    -------
    dict
        roi name -> float array of length VOX_COUNTS[subject], 1 at the roi's
        entries and 0 elsewhere
    '''
    # roi name -> entries to set in the mask (presumably integer voxel
    # indexes; confirm against the saved file)
    nonzero_entries_dict = joblib.load(
        join(regions_idxs_dir, f'rois_{subject}.jbl'))
    n_vox = VOX_COUNTS[subject]
    rois_dict = {}
    for roi_name, vox_idxs in nonzero_entries_dict.items():
        # a fresh array is allocated per roi, so no defensive copy is needed
        mask = np.zeros(n_vox)
        mask[vox_idxs] = 1
        rois_dict[roi_name] = mask
    return rois_dict


# Assemble the ROIs analyzed below. Known rois are loaded but left out of
# the union; spotlight rois are not loaded at all (they are very numerous).
rois_dict_known = load_known_rois(subject)
rois_dict_comm = load_custom_rois(subject, suffix_setting='')
rois_dict_fedorenko = load_custom_rois(subject, suffix_setting='_fedorenko')
# rois_dict_spotlights = load_custom_rois(subject, suffix_setting='_spotlights')
# To include more families, add e.g. **rois_dict_known or **rois_dict_spotlights.
rois_dict = {
    **rois_dict_comm,
    **rois_dict_fedorenko,
}

### Visualize averages over different regions

In [51]:
# Mean GEM-V response within each roi (entries where roi > 0).
# avg_defaultdict[roi_name] is a list aligned with gemv_flatmaps_dict's key
# order. An empty roi yields NaN (numpy mean of an empty selection).
avg_defaultdict = defaultdict(list)
for roi_idx, roi in rois_dict.items():
    roi_mask = roi > 0  # loop-invariant: same mask for every flatmap
    for explanation, flatmap in gemv_flatmaps_dict.items():
        avg_defaultdict[roi_idx].append(np.mean(flatmap[roi_mask]))

In [52]:
# One row per GEM-V flatmap key, one column per roi; 'AVG' is the row mean
# across roi columns (pandas skips NaN by default).
df = pd.DataFrame(avg_defaultdict, index=gemv_flatmaps_dict.keys())
df['AVG'] = df.mean(axis=1)

df = df.round(3).sort_values('AVG', ascending=False)
# Symmetric color limits so the diverging colormap is centered at 0.
# Use a nan-aware max so one NaN cell (e.g. an empty roi) can't make the
# limits NaN.
vabs = np.nanmax(np.abs(df.to_numpy()))
df.style.background_gradient(cmap='coolwarm', axis=None, vmax=vabs, vmin=-vabs).format("{:.3f}").set_caption(
    'Average GEM-V driving response averaged over ROI'
)

Unnamed: 0,Unnamed: 1,comm0,comm1,comm2,comm3,fed0,fed1,fed2,fed3,fed4,AVG
START,,0.263,0.222,0.559,0.479,0.83,1.174,1.154,0.93,0.964,0.73
Introspection,,0.213,0.279,0.282,0.264,0.245,0.313,0.399,0.399,0.416,0.312
Professions and Personal Backgrounds,,0.227,0.218,0.242,0.338,0.251,0.206,0.332,0.241,0.245,0.256
Dialogue,,-0.261,-0.049,0.171,0.007,0.184,0.522,0.475,0.637,0.614,0.255
Gruesome body imagery,,0.01,0.038,0.163,0.153,0.152,0.291,0.358,0.362,0.31,0.204
Clothing and Physical Appearance,,-0.035,0.05,0.323,0.252,0.046,0.009,0.429,0.408,0.241,0.191
Dialogue and responses,,0.037,0.064,0.189,-0.026,0.102,0.253,0.209,0.349,0.275,0.161
END,,0.147,0.157,0.141,0.318,0.187,0.274,-0.016,0.075,0.144,0.159
Relationships,,0.03,0.101,0.176,0.172,0.11,0.052,0.301,0.293,0.119,0.15
Positive Emotional Reactions,,0.134,0.226,0.106,0.195,0.141,0.055,0.123,0.142,0.191,0.146


In [None]:
# sasc.viz._save_flatmap(
#     gemv_flatmaps_dict[('relationships between people', 'qa')], subject, fname_save=f'gemv_flatmaps/relationships_{subject}.png')