In [None]:
from dotenv import load_dotenv
load_dotenv()
import os
import sys
sys.path.append(os.getenv('PYTHONPATH')) 
import numpy as np
import pickle
from tqdm import tqdm
import hcp_utils as hcp
import matplotlib.pyplot as plt
from nilearn import plotting

#local
from src.utils.helpers import vectorized_correlation

### Compute the test-retest reliability of ImageNet image presentations: correlate sessions within each participant, then average over the 9 participants. This validates our preprocessing by replicating Figure 5a of the NOD manuscript.

In [None]:
# Paths fall back to placeholder defaults when the env variables are not set.
dataset_root = os.path.join(os.getenv("DATASETS_ROOT", "/default/path/to/datasets"),"NaturalObjectDataset") #use default if DATASETS_ROOT env variable is not set.
project_root = os.getenv("PROJECT_ROOT", "/default/path/to/project")
print(f"dataset_root: {dataset_root}")
print(f"project_root: {project_root}")
fmri_path = os.path.join(dataset_root,"derivatives", "GLM")
task='imagenet'
dataset = 'NOD'

# ImageNet condition names: the text before the first space (the 'n*' code) of every
# non-blank line of the synset file. Order doesn't matter downstream.
synset_file = os.path.join(dataset_root, "derivatives", "stimuli_metadata", "testtrain_split", "synset_words_edited.txt")
with open(synset_file, 'r') as f:
    # Blank lines (e.g. a trailing empty line) are skipped so they cannot add '' entries.
    imagenet_names = [line.strip().split(' ', 1)[0] for line in f if line.strip()]
assert len(imagenet_names) == 1000, f"expected 1000 ImageNet categories, found {len(imagenet_names)}"

subject_betas = {} #this will be a big dictionary holding all the beta estimates from the subjects
n_subjects = 9
nvertices = 91282  # vertices per beta map (presumably the HCP grayordinates expected by hcp_utils)
session_group = "sessiongroup-01"

In [None]:
# Runs 0-9 of the session group belong to session01, 10-19 to session02, and so on.
# Trials from runs outside these ranges are ignored.
session_divisions = {"session01": range(10), "session02": range(10,20),"session03": range(20,30),"session04": range(30,40)}


def _trial_filename(stim):
    """Return the condition name used to match a trial against `imagenet_names`.

    ImageNet entries map to their second '/'-separated component (presumably the
    n* synset code); COCO entries map to the file stem. Any other entry maps to
    "skip", so an unrecognised trial can never inherit another trial's condition.
    """
    if 'imagenet' in stim:
        return stim.split('/')[1]
    if 'coco' in stim:
        return stim.split('/')[-1].split('.')[0]
    return "skip"


def _session_mean_betas(trial_betas, trial_names, in_session, stim_names):
    """Average one session's single-trial betas per stimulus.

    trial_betas : array (numtrials, numvertices), one row per trial.
    trial_names : condition name of each trial, aligned with the rows of trial_betas.
    in_session  : boolean mask over trials, True for trials belonging to this session.
    stim_names  : stimuli to keep; defines the row order of the result.

    Returns a float64 array (len(stim_names), numvertices). Row i is the nanmean over every
    repetition of stim_names[i] within the session, or all-NaN if it was never shown.
    No normalization is applied because the downstream Pearson correlation is scale-invariant.
    """
    stim_row = {name: row for row, name in enumerate(stim_names)}
    trials_per_stim = [[] for _ in stim_names]
    for trial, (name, keep) in enumerate(zip(trial_names, in_session)):
        if keep and name in stim_row:  # excludes coco, "skip", and other-session trials
            trials_per_stim[stim_row[name]].append(trial)

    session_betas = np.full((len(stim_names), trial_betas.shape[1]), np.nan)
    for row, trials in enumerate(trials_per_stim):
        if trials:
            session_betas[row] = np.nanmean(trial_betas[trials], axis=0)
    return session_betas


for sub in range(1, n_subjects+1):
    subject = f"sub-{sub:02}"
    glm_dir = os.path.join(fmri_path, subject, session_group)

    # Squeeze 'betasmd' to (numvertices, numtrials), then transpose to (numtrials, numvertices),
    # i.e. samples x features. Only the betas are kept, so the rest of the GLM dict can be freed.
    fmri_data_wb = np.load(os.path.join(glm_dir, "TYPED_FITHRF_GLMDENOISE_RR.npy"), allow_pickle=True).item()['betasmd'].squeeze().T
    print(f"shape of {subject} betas in the session (numtrials x numvertices): {fmri_data_wb.shape}")
    assert fmri_data_wb.shape[1] == nvertices, f"{subject}: expected {nvertices} vertices, got {fmri_data_wb.shape[1]}"

    # NOTE: pickle.load can execute arbitrary code; only load condition files from a trusted dataset copy.
    with open(os.path.join(glm_dir, f"{subject}_{session_group}_conditionOrderDM.pkl"), 'rb') as f:
        events_run, ses_conds = pickle.load(f)

    # Flatten the per-run events into per-trial (run number, condition name) lists,
    # in the same order as the rows of fmri_data_wb.
    trial_runs = []
    trial_names = []
    for run_number, event in enumerate(events_run):
        for stim in event['trial_type']:
            trial_runs.append(run_number)
            trial_names.append(_trial_filename(stim))
    assert len(trial_names) == fmri_data_wb.shape[0], (
        f"{subject}: {len(trial_names)} conditions but {fmri_data_wb.shape[0]} beta trials")

    for session, runs in session_divisions.items():
        in_session = np.isin(trial_runs, list(runs))
        betas_session_mean = _session_mean_betas(fmri_data_wb, trial_names, in_session, imagenet_names)
        print(f"{session.capitalize()} betas shape: {betas_session_mean.shape}")
        subject_betas[f"{subject}_{session}"] = betas_session_mean

In [None]:
# Test-retest reliability: for every subject, correlate the per-stimulus betas of each pair of
# sessions at every vertex (correlation along axis 0, i.e. across the ImageNet stimuli), average
# over the session pairs, then average over subjects.
# Sanity check (not run here): permuting the stimulus rows of one session before correlating
# should push the reliability toward zero if the signal is stimulus-specific.
splits = [(a, b) for a in range(4) for b in range(a + 1, 4)]  # 0-based pairs: (0,1), (0,2), (0,3), (1,2), (1,3), (2,3)
reliability = np.zeros((nvertices,))
for sub in range(1, n_subjects + 1):
    subject = f"sub-{sub:02}"
    subject_reliability = np.zeros((nvertices,))
    print(f"running pairwise split correlation between sessions on subject: {subject}")
    for first, second in tqdm(splits):
        subject_reliability += vectorized_correlation(
            subject_betas[f"{subject}_session{first + 1:02}"],
            subject_betas[f"{subject}_session{second + 1:02}"],
            axis=0,
            ddof=1,
        )
    reliability += subject_reliability / len(splits)
reliability = reliability / n_subjects  # mean of the per-subject reliabilities

In [None]:
ext_list = ['png'] #only matters if save_flag is True 

save_flag=True #set to True to save plots or False to not save plots

save_root = os.path.join(project_root, "src", "fmriDatasetPreparation", "datasets", "NaturalObjectDataset", "validation", "output", "imagenet_reliability")
os.makedirs(save_root, exist_ok=True)


def _save_figure(fig, stem):
    """If save_flag is set, save `fig` to save_root as `<stem>.<ext>` for each ext in ext_list.

    PNGs are written at 300 dpi; other formats use matplotlib's default savefig dpi
    (dpi=None behaves exactly like omitting the argument).
    """
    if not save_flag:
        return
    for ext in ext_list:
        fig.savefig(os.path.join(save_root, f"{stem}.{ext}"), dpi=300 if ext == 'png' else None)


stat = reliability.copy()
print(f"Min, Max Group Averaged Pairwise Session Reliability Pearson Correlation: {np.nanmin(stat)}, {np.nanmax(stat)}")
threshold = None
cmap = 'coolwarm' #'hot'

cortex_data = hcp.cortex_data(stat)
# Symmetric color range around 0, shared by every inflated view and by the flat maps
# so that all saved figures use the same color scale.
vmax = np.nanmax(cortex_data)
vmin = -vmax

# NOTE: output files keep the historical "groupISC" prefix, although the map is a test-retest reliability.
# Inflated surfaces, one figure per hemisphere/view.
views = ['lateral', 'ventral', 'dorsal'] #['lateral', 'medial', 'dorsal', 'ventral', 'anterior', 'posterior']
for hemi in ['left','right']:
    for view in views:
        surf_fig = plotting.plot_surf_stat_map(hcp.mesh.inflated, cortex_data, hemi=hemi,
            threshold=threshold, bg_map=hcp.mesh.sulc, view=view, cmap=cmap, vmax=vmax)
        _save_figure(surf_fig, f"groupISC_{dataset}_task-{task}_mesh-inflated_view-{view}_hemi-{hemi}")

# Flat maps: both hemispheres side by side in one figure.
cortex_data_left = hcp.left_cortex_data(stat)
cortex_data_right = hcp.right_cortex_data(stat)

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4), subplot_kw={'projection': '3d'})
fig.subplots_adjust(wspace=0)
flat_panels = [(axes[0], hcp.mesh.flat_left, cortex_data_left, hcp.mesh.sulc_left),
               (axes[1], hcp.mesh.flat_right, cortex_data_right, hcp.mesh.sulc_right)]
for ax, flat_mesh, hemi_data, sulc in flat_panels:
    plotting.plot_surf(flat_mesh, hemi_data,
            threshold=threshold, bg_map=sulc,
            colorbar=False, cmap=cmap,
            vmin=vmin, vmax=vmax,
            axes=ax)
    ax.invert_yaxis()  # flip along the horizontal

# A single colorbar shared by both hemispheres.
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
sm.set_array([])
cbar = fig.colorbar(sm, ax=axes.ravel().tolist(), shrink=0.6)
ticks = [round(vmin,2), 0, round(vmax,2)]
cbar.set_ticks(ticks)
cbar.set_ticklabels(ticks)
_save_figure(fig, f"groupISC_{dataset}_task-{task}_mesh-flat")
plt.show()