## Test the parcellation approach with subj05 (note: `subjid` below is currently set to subj06)

In [1]:
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt
import h5py
import numpy as np
import scipy as sp
import scipy.stats as stats
import nibabel as nibabel
import pandas as pd
import nibabel.freesurfer.mghformat as mgh
import scipy.io
import itertools 
import pickle

In [2]:
import sys
utils_dir = '/oak/stanford/groups/kalanit/biac2/kgs/projects/Dawn/NSD/code/streams/utils/'
# Make the project's rsm_utils module importable. Guard the append so re-running this
# cell does not push the same directory onto sys.path again each time.
if utils_dir not in sys.path:
    sys.path.append(utils_dir)

In [3]:
from rsm_utils import get_flat_lower_tri, make_flat_rsms

In [4]:
# Data locations, relative to this notebook.
data_dir, local_data_dir = '../../../data/', '../../../local_data/'

# Subject(s) to analyse, and how many presentations an image needs to be kept.
subjid = ['06']
n_repeats = 3

# Voxels whose split-half reliability falls below this value are dropped later on.
# (Values tried previously: 0.1, 0.2.)
thresh = 0

In [5]:
# Right-hemisphere parcel label of every surface vertex, one array per subject,
# read from each subject's rh.tessellate.mgz.
rh_parcels = [
    mgh.load(f"{local_data_dir}freesurfer/subj{sid}/rh.tessellate.mgz").get_fdata()[:, 0, 0]
    for sid in subjid
]

In [6]:
# Number of parcels = the largest parcel label over all subjects (ROIs are later selected
# as labels 1..num_rois). Take each subject's max first: np.max over a list of arrays with
# different vertex counts cannot build a rectangular array and fails for >1 subject.
num_rois = int(max(np.max(parcels) for parcels in rh_parcels))

In [7]:
# Analyse exactly 20 ROIs (parcel labels 1..20). This deliberately overrides the count
# derived from the parcel labels, so flag any mismatch: labels above 20 are ignored, and
# labels that do not exist would yield empty ROIs downstream.
num_rois = 20
_max_parcel_label = int(max(np.max(parcels) for parcels in rh_parcels))
if _max_parcel_label != num_rois:
    print(f'warning: using num_rois={num_rois} but the largest parcel label is {_max_parcel_label}')

In [8]:
# Voxel-level split-half reliability for each subject: the 'mean' entry of that
# subject's rh_split_half.mat.
reliability = [
    scipy.io.loadmat(f"{local_data_dir}freesurfer/subj{sid}/rh_split_half.mat")['mean']
    for sid in subjid
]

In [9]:
# Group the reliability values by ROI: sh_by_ROI[subject][roi] holds the reliability
# columns of the vertices whose parcel label is roi + 1, and total_vox counts them.
sh_by_ROI = []
total_vox = np.zeros((len(subjid), num_rois))

for sidx, sid in enumerate(subjid):
    labels = rh_parcels[sidx]
    per_roi = [reliability[sidx][:, labels == label] for label in range(1, num_rois + 1)]
    sh_by_ROI.append(per_roi)
    total_vox[sidx] = [len(roi_rel[0]) for roi_rel in per_roi]

In [10]:
# Vertex count per ROI (rows: subjects, columns: parcel labels 1..num_rois).
total_vox

array([[1087., 1072., 1094., 1173., 1114., 1269., 1300., 1571., 1165.,
        1707., 1560.,  920., 1333., 1540., 1220., 2202., 1956., 1700.,
        1015., 1000.]])

In [11]:
# Read each subject's behavioural trial log: keep the highest SESSION number and the
# 73k image ID shown on every trial (in trial order).
all_ids = []
max_session = np.zeros(len(subjid))
for sidx, sid in enumerate(subjid):
    data = pd.read_csv(f"{data_dir}nsddata/ppdata/subj{sid}/behav/responses.tsv", sep='\t')
    max_session[sidx] = data['SESSION'].to_numpy().max()
    all_ids.append(data['73KID'].to_numpy())

In [12]:
# Image IDs that each subject saw exactly n_repeats times.
which_reps = []
for sidx, sid in enumerate(subjid):
    image_ids, times_shown = np.unique(all_ids[sidx], return_counts=True)
    which_reps.append(image_ids[times_shown == n_repeats])

# The subject with the fewest such images (its array of IDs, not a count).
least_trials = min(which_reps, key=len)

In [13]:
# For each subject: a trial-level boolean mask selecting trials whose image was shown
# exactly n_repeats times, and the image IDs of those trials (in trial order).
# all_ids already holds the '73KID' column of responses.tsv, so the file is not re-read.
# The loop variable `sidx` is kept on purpose: the next cell reads its final value.
id_nums_3reps = []
mask_3reps = []
for sidx, sid in enumerate(subjid):
    trial_mask = np.isin(all_ids[sidx], which_reps[sidx])
    mask_3reps.append(trial_mask)
    id_nums_3reps.append(all_ids[sidx][trial_mask])


In [14]:
# Permutation that sorts the repeated-image trials by image ID, so that each image's
# n_repeats trials become adjacent. kind='stable' keeps tied trials in presentation
# order, making the assignment of trials to pseudo-repeats deterministic (the default
# sort leaves ties in an unspecified order).
# NOTE(review): the original indexed with the loop variable `sidx` leaked from the
# previous cell, i.e. the last subject; [-1] makes that explicit. This is only correct
# while subjid has a single entry.
arr1inds = id_nums_3reps[-1].argsort(kind='stable')

In [15]:
# Right hemisphere: for each ROI, gather the z-scored betas of every trial whose image
# was shown exactly n_repeats times, concatenated across sessions in trial order.
# NOTE(review): betas_by_ROI is not indexed by subject, so with more than one entry in
# subjid only the last subject's betas survive this cell.
betas_by_ROI = [[] for j in range(num_rois)]

for sidx, sid in enumerate(subjid):

    print(sidx)
    mask = mask_3reps[sidx]
    # Vertex masks depend only on the subject, so build them once rather than per session.
    roi_masks = [rh_parcels[sidx] == roi_idx + 1 for roi_idx in range(num_rois)]
    roi_chunks = [[] for _ in range(num_rois)]

    # get all betas across all sessions
    for sess in range(1, int(max_session[sidx]) + 1):
        print(sess)

        beta_path = (local_data_dir + 'freesurfer/subj' + sid
                     + f'/betas/rh.zscore_betas_session{sess:02d}.hdf5')
        # Entries of the trial-level mask belonging to this session (750 trials per
        # session) -- assumes the file's first axis holds exactly those trials.
        sess_mask = mask[(sess - 1) * 750:sess * 750]

        # `with` closes the HDF5 handle even if the read fails; `del` alone did not.
        with h5py.File(beta_path, 'r') as raw_betas:
            sess_betas = raw_betas['zscore_betas'][:][sess_mask]

        for roi_idx, roi_mask in enumerate(roi_masks):
            roi_chunks[roi_idx].append(sess_betas[:, roi_mask])

        del sess_betas

    # Concatenate once per ROI instead of re-copying a growing array every session
    # (np.append in the loop was quadratic in the number of sessions).
    for roi_idx, chunks in enumerate(roi_chunks):
        if chunks:
            betas_by_ROI[roi_idx] = np.concatenate(chunks, axis=0)


0
1
2


OSError: Can't read data (file read failed: time = Fri Feb  5 16:03:48 2021
, filename = '../../../local_data/freesurfer/subj06/betas/rh.zscore_betas_session02.hdf5', file descriptor = 59, errno = 5, error message = 'Input/output error', buf = 0x7fb756a94810, total read size = 1556093984, bytes this sub-read = 1556093984, bytes actually read = 18446744073709551615, offset = 344064)

In [None]:
# Split each ROI's trials into n_repeats pseudo-repeats. After sorting by image ID, each
# image's n_repeats trials are adjacent, so taking every n_repeats-th row from offset r
# yields one trial per image, in the same image order for every r.
# NOTE(review): betas_by_ROI and arr1inds hold a single subject's data, so every sidx
# gets that subject's split; only correct while subjid has one entry.
betas_by_repeat_by_ROI = [[[] for j in range(num_rois)] for i in range(len(subjid))]
for sidx, sid in enumerate(subjid):
    for roi_idx in range(num_rois):

        sorted_betas = betas_by_ROI[roi_idx][arr1inds[::-1]]

        for r in range(n_repeats):
            # Stride by n_repeats (previously a hard-coded 3) so the split tracks the setting.
            betas_by_repeat_by_ROI[sidx][roi_idx].append(sorted_betas[r::n_repeats])

In [None]:
# Quick look at subject index 0, parcel label 1: its list of per-repeat beta arrays.
betas_by_repeat_by_ROI[0][0]

In [None]:
# Drop voxels whose split-half reliability is below `thresh`: set their columns to NaN
# in every repeat, then remove all-NaN columns.
# Vectorized with a boolean mask (was a per-voxel Python loop with repeats 0-2
# hard-coded). A NaN reliability compares False, so such voxels are kept, as before.
# NOTE(review): assigning NaN requires float betas -- confirm the dtype of zscore_betas.
for sidx, sid in enumerate(subjid):
    for roi_idx in range(num_rois):
        unreliable = sh_by_ROI[sidx][roi_idx][0] < thresh
        for r in range(n_repeats):
            temp = betas_by_repeat_by_ROI[sidx][roi_idx][r]
            temp[:, unreliable] = np.nan  # in place, like the original
            # Columns that are entirely NaN (including those just blanked) are removed.
            trimmed = temp[:, ~np.all(np.isnan(temp), axis=0)]

            betas_by_repeat_by_ROI[sidx][roi_idx][r] = trimmed

In [None]:
# Drop the global reference to the last sorted per-ROI beta array left by earlier cells.
del sorted_betas

In [None]:
# Later cells only use betas_by_repeat_by_ROI, so release the per-ROI trial betas.
del betas_by_ROI

In [None]:
# Remove the loop leftover from the voxel-trimming cell.
del trimmed

In [None]:
# Remove the loop leftover from the voxel-trimming cell.
del temp

In [None]:
# Allocate storage for the flattened RSMs of every ROI, one array per repeat.
# An n x n RSM has n*(n-1)/2 strictly-lower-triangular entries; integer arithmetic avoids
# the float round-trip. Every ROI/repeat shares the image count of [0][0][0] because all
# were split with the same trial ordering.
n_images = betas_by_repeat_by_ROI[0][0][0].shape[0]
tril_flat_shape = n_images * (n_images - 1) // 2
flat_rsm_r1 = np.zeros((num_rois, tril_flat_shape))
flat_rsm_r2 = np.zeros((num_rois, tril_flat_shape))
flat_rsm_r3 = np.zeros((num_rois, tril_flat_shape))


In [None]:
# For subject index 0, build each ROI's image-by-image RSM (np.corrcoef treats each row,
# i.e. each image, as a variable) for each repeat, and store its lower triangle without
# the diagonal. Loops over num_rois rather than a hard-coded 20.
sidx = 0
flat_rsms_by_repeat = (flat_rsm_r1, flat_rsm_r2, flat_rsm_r3)
for roi_idx in range(num_rois):
    for r, flat_rsm in enumerate(flat_rsms_by_repeat):
        rsm = np.corrcoef(betas_by_repeat_by_ROI[sidx][roi_idx][r])
        flat_rsm[roi_idx, :] = get_flat_lower_tri(rsm, diagonal=False)


In [None]:
# Inspect the repeat-1 flattened RSMs (one row per ROI).
flat_rsm_r1

In [None]:
# Every ordered pair (first, second) of distinct repeats among 0, 1, 2, split into the
# list of first members and the list of second members.
_repeat_pairs = list(itertools.permutations(range(3), 2))
r1_trial_order = [first for first, _ in _repeat_pairs]
r2_trial_order = [second for _, second in _repeat_pairs]

In [None]:
# The flattened RSMs are built; the per-repeat trial betas are no longer used.
del betas_by_repeat_by_ROI

In [None]:
# Remove the last full RSM left over from the RSM-building cell.
del rsm

In [None]:
# ROI x ROI RSM similarity, corrected by each ROI's noise ceiling.
# NC(roi) = 100 * mean Pearson r between that ROI's RSMs from different repeats.
# mega_matrix[i, j] = mean Pearson r between ROI i's RSMs (rows, "model candidate") and
# ROI j's RSMs (columns, "target data") over all ordered pairs of different repeats,
# multiplied by sqrt(100/NC_i) * sqrt(100/NC_j). A non-positive NC gives inf/NaN.
flat_rsms = [flat_rsm_r1, flat_rsm_r2, flat_rsm_r3]
# (0,1), (0,2), (1,2): the three within-ROI split-half comparisons
within_pairs = list(itertools.combinations(range(len(flat_rsms)), 2))
# (0,1), (0,2), (1,0), (1,2), (2,0), (2,1): cross-ROI comparisons, never same repeat
cross_pairs = list(itertools.permutations(range(len(flat_rsms)), 2))

# Each ROI's noise ceiling is computed once here; the original recomputed the target's
# ceiling inside the inner loop, i.e. num_rois**2 times.
noise_ceiling = np.array([
    np.mean([stats.pearsonr(flat_rsms[a][roi_idx, :], flat_rsms[b][roi_idx, :])[0]
             for a, b in within_pairs]) * 100
    for roi_idx in range(num_rois)
])

mega_matrix = np.zeros((num_rois, num_rois))
for roi_idx1 in range(num_rois):  # rows - i.e. model candidate
    NC_model = noise_ceiling[roi_idx1]
    for roi_idx2 in range(num_rois):  # columns - i.e. target data
        NC_target = noise_ceiling[roi_idx2]
        rsm_corr = np.array([
            stats.pearsonr(flat_rsms[a][roi_idx1, :], flat_rsms[b][roi_idx2, :])[0]
            for a, b in cross_pairs
        ])
        mega_matrix[roi_idx1, roi_idx2] = (np.mean(rsm_corr)
                                           * np.sqrt(100 / NC_model)
                                           * np.sqrt(100 / NC_target))
        

In [None]:
# Show the noise-ceiling-corrected ROI x ROI RSM similarity matrix. Row/column index i
# corresponds to parcel label i + 1. Without a colorbar the values cannot be read off.
fig, ax = plt.subplots()
im = ax.imshow(mega_matrix)
ax.set_xlabel('target ROI (parcel label - 1)')
ax.set_ylabel('model ROI (parcel label - 1)')
fig.colorbar(im, ax=ax, label='corrected RSM correlation')