### Compute the test-retest reliability of COCO image presentations. Correlate odd and even runs within each participant, then average over the 9 participants. This validates our preprocessing by replicating Figure 5b of the NOD manuscript.

In [None]:
from dotenv import load_dotenv
load_dotenv()
import os
import sys
sys.path.append(os.getenv('PYTHONPATH')) 
import numpy as np
import pickle
import hcp_utils as hcp
import matplotlib.pyplot as plt
from nilearn import plotting

#local
from src.utils.helpers import vectorized_correlation

In [None]:
# Resolve dataset/project locations from the environment; the placeholder
# defaults are used whenever the corresponding variable is not set.
datasets_base = os.getenv("DATASETS_ROOT", "/default/path/to/datasets")
dataset_root = os.path.join(datasets_base, "NaturalObjectDataset")
project_root = os.getenv("PROJECT_ROOT", "/default/path/to/project")
print(f"dataset_root: {dataset_root}")
print(f"project_root: {project_root}")
fmri_path = os.path.join(dataset_root, "derivatives", "GLM")
task = 'coco'
dataset = 'NOD'

# Load the stimulus groupings. The filenames collected here define the stimulus
# order (rows) that the betas are placed into further down.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
groupings_path = os.path.join(
    dataset_root, "derivatives", "stimuli_metadata", "testtrain_split", "coco_groupings_rdm.pkl"
)
with open(groupings_path, 'rb') as f:
    data = pickle.load(f)

# Flatten every group's filenames into a single list; the order does not matter.
coco_filenames = [name for group in data.values() for name in group]
assert len(coco_filenames) == 120

subject_betas = {}  # maps "<subject>_odd" / "<subject>_even" to that subject's averaged betas
n_subjects = 9
nvertices = 91282
session_group = "sessiongroup-01"

In [None]:
def _mean_over_reps(reps_by_stim, n_vertices):
    """Average each stimulus' repetitions into one row of an (n_stimuli, n_vertices) array.

    Rows follow the insertion order of ``reps_by_stim``. NaNs inside individual
    repetitions are ignored (``np.nanmean``); a stimulus without any repetition
    is left as an all-NaN row. No fixed cap on the number of repetitions is
    assumed.
    """
    averaged = np.full((len(reps_by_stim), n_vertices), np.nan)
    for row, reps in enumerate(reps_by_stim.values()):
        if reps:
            # accumulate in float64 regardless of the dtype the betas are stored in
            averaged[row] = np.nanmean(np.asarray(reps, dtype=np.float64), axis=0)
    return averaged


# For every subject: load the single-trial betas, split the COCO trials by run
# parity, average over repetitions and store both halves in `subject_betas`.
coco_lookup = set(coco_filenames)  # constant-time membership tests in the trial loop

for sub in range(1, n_subjects + 1):
    subject = f"sub-{sub:02}"

    # NOTE: allow_pickle=True here and pickle.load below can execute arbitrary
    # code while loading; only point them at trusted derivative files.
    fmri_data_wb = np.load(os.path.join(fmri_path, subject, session_group, "TYPED_FITHRF_GLMDENOISE_RR.npy"), allow_pickle=True).item()
    fmri_data_wb = fmri_data_wb['betasmd'].squeeze() #squeezed to shape numvertices x numtrials
    fmri_data_wb = fmri_data_wb.T #transpose to shape numtrials x numvertices, more representative of the samples x features format
    print(f"shape of {subject} betas in the session (numtrials x numvertices): {fmri_data_wb.shape}")
    numvertices = fmri_data_wb.shape[1]

    with open(os.path.join(fmri_path, subject, session_group, f"{subject}_{session_group}_conditionOrderDM.pkl"), 'rb') as f:
        events_run, ses_conds = pickle.load(f)

    # One (run index, bare filename) pair per trial, in session order. The list
    # contains imagenet trials too; they are filtered out below.
    trial_labels = [
        (run_number, stim.split('/')[-1].split('.')[0])
        for run_number, event in enumerate(events_run)
        for stim in event['trial_type']
    ]
    if len(trial_labels) != fmri_data_wb.shape[0]:
        raise ValueError(
            f"{subject}: {len(trial_labels)} trials in the condition order but "
            f"{fmri_data_wb.shape[0]} beta estimates"
        )

    # Collect every COCO trial's betas under its stimulus, split by run parity.
    # run_number is the 0-based position in events_run, so the "even" half holds
    # the 1st, 3rd, ... runs. The labelling is arbitrary for a correlation
    # between the two halves.
    betas_tmp_odd = {stim: [] for stim in coco_filenames}
    betas_tmp_even = {stim: [] for stim in coco_filenames}
    for trial_idx, (run_number, filename) in enumerate(trial_labels):
        if filename not in coco_lookup:
            continue
        half = betas_tmp_even if run_number % 2 == 0 else betas_tmp_odd
        half[filename].append(fmri_data_wb[trial_idx, :])

    #no need to normalize since we are doing a pearson correlation
    #average over repetitions; both halves share the row order of coco_filenames
    betas_odd_mean = _mean_over_reps(betas_tmp_odd, numvertices)
    betas_even_mean = _mean_over_reps(betas_tmp_even, numvertices)
    print(f"Odd run betas shape: {betas_odd_mean.shape}")
    print(f"Even run betas shape: {betas_even_mean.shape}")
    subject_betas.update({f"{subject}_odd": betas_odd_mean, f"{subject}_even": betas_even_mean})

In [None]:
# Diagnostic only -- no longer needed with vectorized correlation.
# For every entry of subject_betas, report the vertices whose mean over the
# first axis is NaN, then count the unique NaN vertices across all entries.
nan_vertices_list = []
for subject, betas in subject_betas.items():
    vertex_means = np.mean(betas, axis=0)
    sub_nans = np.flatnonzero(np.isnan(vertex_means))
    nan_vertices_list.extend(sub_nans)
    print(f"subject {subject} has {len(sub_nans)} nans")

nan_vertices = set(nan_vertices_list)
print(f"all subjects have {len(nan_vertices)} unique nans")

In [None]:
# Sum each subject's odd/even correlation into one accumulator, then divide by
# the number of subjects to obtain the group-average reliability.
reliability = np.zeros(nvertices)
subject_ids = [f"sub-{sub:02}" for sub in range(1, n_subjects + 1)]
for subject in subject_ids:
    print(f"running odd/even split correlation on subject: {subject}")
    betas_odd = subject_betas[f"{subject}_odd"]
    betas_even = subject_betas[f"{subject}_even"]
    # Optional sanity check (shuffle the stimulus order of one half to make sure
    # we are capturing signal across categories):
    #shuffled_indices = np.random.permutation(len(coco_filenames))
    #reliability += vectorized_correlation(betas_odd[shuffled_indices,:], betas_even)
    subject_reliability = vectorized_correlation(betas_odd, betas_even, axis=0, ddof=1)
    reliability += subject_reliability

reliability = reliability / n_subjects #average the individual reliabilities

In [None]:
# Plot the group-averaged reliability on inflated surfaces (one figure per
# hemisphere and view) and on a combined flat map, optionally saving each figure.
ext_list = ['png'] #only matters if save_flag is True

save_flag = True #set to True to save plots or False to not save plots

save_root = os.path.join(project_root, "src", "fmriDatasetPreparation", "datasets", "NaturalObjectDataset", "validation", "output", "coco_reliability")
os.makedirs(save_root, exist_ok=True)  # no exists()/makedirs() race


def _save_current_figure(stem):
    """Save the active matplotlib figure as ``<save_root>/<stem>.<ext>`` for each ext in ext_list.

    Does nothing when save_flag is False. PNG output is written at 300 dpi;
    every other extension uses matplotlib's default.
    """
    if not save_flag:
        return
    for ext in ext_list:
        extra = {'dpi': 300} if ext == 'png' else {}
        plt.savefig(os.path.join(save_root, f"{stem}.{ext}"), **extra)


stat = reliability.copy()
print(f"Min, Max Group Averaged Odd/Even Reliability Pearson Correlation: {np.nanmin(stat)}, {np.nanmax(stat)}")
threshold = None
cmap = 'coolwarm' #'hot'

#save inflated surfaces
cortex_data = hcp.cortex_data(stat)
#color limits used by the flat maps and their colorbar: symmetric around zero,
#spanning +/- the largest value on the cortex
datamin = np.nanmin(cortex_data)
datamax = np.nanmax(cortex_data)
vmin = -datamax #datamin
vmax = datamax

# NOTE(review): output file names carry a "groupISC" prefix although the map is
# an odd/even reliability; kept unchanged so existing outputs keep their names.
views = ['lateral', 'ventral', 'dorsal'] #['lateral', 'medial', 'dorsal', 'ventral', 'anterior', 'posterior']
mesh = hcp.mesh.inflated
bg = hcp.mesh.sulc
for hemi in ['left', 'right']:
    for view in views:
        plotting.plot_surf_stat_map(mesh, cortex_data, hemi=hemi,
                                    threshold=threshold, bg_map=bg, view=view, cmap=cmap)
        _save_current_figure(f"groupISC_{dataset}_task-{task}_mesh-inflated_view-{view}_hemi-{hemi}")

#Save flat maps. hemispheres are combined in one plot
#get the data for both hemispheres
cortex_data_left = hcp.left_cortex_data(stat)
cortex_data_right = hcp.right_cortex_data(stat)

#create a figure with one 3d axis per hemisphere
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4), subplot_kw={'projection': '3d'})
plt.subplots_adjust(wspace=0)
plotting.plot_surf(hcp.mesh.flat_left, cortex_data_left,
                   threshold=threshold, bg_map=hcp.mesh.sulc_left,
                   colorbar=False, cmap=cmap,
                   vmin=vmin, vmax=vmax,
                   axes=axes[0])
plotting.plot_surf(hcp.mesh.flat_right, cortex_data_right,
                   threshold=threshold, bg_map=hcp.mesh.sulc_right,
                   colorbar=False, cmap=cmap,
                   vmin=vmin, vmax=vmax,
                   axes=axes[1])

#flip along the horizontal
axes[0].invert_yaxis()
axes[1].invert_yaxis()

#create a shared colorbar for both hemispheres
norm = plt.Normalize(vmin=vmin, vmax=vmax)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = fig.colorbar(sm, ax=axes.ravel().tolist(), shrink=0.6)

#put the ticks at the exact limits (a rounded position could fall outside
#[vmin, vmax] and not be drawn) and round only the labels
cbar.set_ticks([vmin, 0, vmax])
cbar.set_ticklabels([round(vmin, 2), 0, round(vmax, 2)])
_save_current_figure(f"groupISC_{dataset}_task-{task}_mesh-flat")
plt.show()