In [None]:
%matplotlib inline
import glob
import os
import sys

import matplotlib.pyplot as plt
import nibabel
import nilearn
import nilearn.image
import nistats
import nltools
import numpy as np
import pandas as pd
import scipy
import scipy.ndimage
import scipy.spatial.distance
import seaborn as sns
from nibabel.nifti1 import Nifti1Image
from nilearn.input_data import NiftiMasker
from nilearn.plotting import plot_stat_map, plot_roi, plot_img, plot_glass_brain
from nistats.reporting import get_clusters_table
from nltools.data import Brain_Data

In [None]:
# Project root, resolved relative to the notebook's working directory.
base_dir = os.path.realpath('../../..')
print(base_dir)
# Folder holding the voxelwise IS-RSA NIfTI results; ROI tables are written back here as well.
results_dir = os.path.realpath('../../../Results/voxelwise_ISRSA/nifti')

## Function definitions

In [None]:
def get_clusters(image, stat_threshold = 0, extent_threshold = 0, sort_by_size = False, include_peaks_only = True):
    """Tabulate the positive and negative clusters of a statistical image.

    `get_clusters_table` is run once on `image` and once on its sign-flipped
    copy; the peak values of the second table are flipped back so they carry
    their original (negative) sign.

    Parameters
    ----------
    image : Nifti image with the statistic map.
    stat_threshold : threshold handed to `get_clusters_table` (applied to the
        image and to its negation).
    extent_threshold : handed to `get_clusters_table` as `cluster_threshold`.
    sort_by_size : if True (and `include_peaks_only` is True), `peaks` is
        sorted by the 'size' column, largest first.
    include_peaks_only : if True, `peaks` keeps only rows whose ID does not
        end in a letter; if False, `peaks` holds every row of `clusters`
        (re-indexed). Previously `peaks` was left undefined in that case.

    Returns
    -------
    clusters : DataFrame with all rows of both tables, with columns renamed to
        'ID', 'size', 'peak_value', an added 'abs_peak_value', sorted by
        absolute peak value (descending).
    peaks : DataFrame as described under `include_peaks_only`.

    Notes
    -----
    Both tables number their clusters independently, so IDs of negative
    clusters are shifted by the largest positive cluster number. Without this
    a positive and a negative cluster can share an ID, and code that looks
    clusters up by ID silently loses one of them. IDs are unchanged when there
    are no positive clusters.
    """
    renames = {'Cluster ID':'ID', 'Cluster Size (mm3)':'size', 'Peak Stat':'peak_value'}

    # Compute negative image (np.asanyarray(img.dataobj) replaces the deprecated get_data())
    neg_image_arr = -np.asanyarray(image.dataobj)
    neg_image = Nifti1Image(neg_image_arr, image.affine.copy())

    # Find clusters
    pos_clusters = get_clusters_table(image, stat_threshold,
                                      cluster_threshold = extent_threshold).rename(columns = renames)
    neg_clusters = get_clusters_table(neg_image, stat_threshold,
                                      cluster_threshold = extent_threshold).rename(columns = renames)
    neg_clusters['peak_value'] = -neg_clusters['peak_value']

    # Keep IDs unique across the two tables
    id_offset = max((_split_cluster_id(cid)[0] or 0 for cid in pos_clusters['ID']), default = 0)
    if id_offset:
        neg_clusters['ID'] = [_shift_cluster_id(cid, id_offset) for cid in neg_clusters['ID']]

    # Append pos and neg (pd.concat replaces DataFrame.append, removed in pandas 2.0)
    clusters = pd.concat([pos_clusters, neg_clusters])
    clusters['abs_peak_value'] = clusters['peak_value'].abs()
    clusters = clusters.sort_values(by='abs_peak_value', ascending = False)

    # Get peaks
    if include_peaks_only:
        # Rows whose ID ends in a letter (e.g. '1a') are dropped
        is_main_peak = clusters['ID'].apply(lambda x: not str(x)[-1].isalpha()).astype(bool)
        peaks = clusters.loc[is_main_peak,:].reset_index(drop=True)
        if sort_by_size:
            peaks = peaks.sort_values(by = 'size', ascending = False).reset_index(drop=True)
    else:
        peaks = clusters.reset_index(drop=True)

    return clusters, peaks


def _split_cluster_id(cluster_id):
    """Split an ID such as 3 or '3a' into (3, '') or (3, 'a'); the number is None if it cannot be parsed."""
    text = str(cluster_id).strip()
    cut = len(text)
    while cut > 0 and text[cut - 1].isalpha():
        cut -= 1
    try:
        number = int(float(text[:cut]))
    except ValueError:
        return None, text[cut:]
    return number, text[cut:]


def _shift_cluster_id(cluster_id, offset):
    """Add `offset` to the numeric part of a cluster ID, keeping any letter suffix."""
    number, suffix = _split_cluster_id(cluster_id)
    if number is None:
        return cluster_id
    shifted = number + offset
    return '%i%s'%(shifted, suffix) if suffix else shifted

In [None]:
def xyz_to_ijk(MNI, image):
    """Convert a world-space (x, y, z) coordinate to the nearest voxel index.

    Parameters
    ----------
    MNI : sequence of three numbers, coordinate in the space defined by
        `image.affine`.
    image : object exposing a 4x4 `affine` attribute (voxel -> world).

    Returns
    -------
    list of three integers [i, j, k].

    The result is rounded to the nearest voxel. Plain integer casting
    truncates toward zero, so a coordinate that maps to e.g. 44.9999 (through
    floating-point error or an off-grid input) used to land one voxel off.
    """
    affine = np.asarray(image.affine, dtype = float)
    offset = np.asarray(MNI, dtype = float) - affine[:3,3]
    # Solve A @ ijk = xyz - origin instead of forming the explicit inverse
    voxel = np.linalg.solve(affine[:3,:3], offset)
    return list(np.rint(voxel).astype(int))

In [None]:
def expand_cluster(image, cluster_peak_MNI, verbose = True):
    """Grow a cluster mask from its peak coordinate.

    The peak is converted to a voxel index, all voxels with the same sign as
    the peak are labelled into face-connected components, and the component
    containing the peak is returned as a binary mask.

    Parameters
    ----------
    image : 3D Nifti image in which zero marks voxels outside any cluster
        (assumed to be a thresholded statistic map; confirm against caller).
    cluster_peak_MNI : sequence of three numbers, peak coordinate in the
        space of `image.affine`.
    verbose : if True, print the peak location, peak value and cluster size.

    Returns
    -------
    cluster_peak_IJK : list [i, j, k] with the peak voxel index.
    cluster_ROI : int32 array, 1 inside the cluster and 0 elsewhere.
    ROI_mask : Nifti1Image wrapping `cluster_ROI` with the affine of `image`.

    Raises
    ------
    ValueError : if the peak lies outside the image or on a zero/non-finite
        voxel. Previously a zero peak silently returned the whole background
        as the "cluster".

    Notes
    -----
    Only voxels with the same sign as the peak are considered, matching
    `get_clusters`, which tabulates positive and negative clusters separately;
    a plain `!= 0` mask would merge touching clusters of opposite sign.
    """
    # Find IJK of peak
    i,j,k = xyz_to_ijk(cluster_peak_MNI,image)
    cluster_peak_IJK = [i,j,k]
    if verbose: print('MNI: %s, IJK: %s'%(cluster_peak_MNI, cluster_peak_IJK))

    # np.asanyarray(img.dataobj) replaces the deprecated image.get_data()
    data = np.asanyarray(image.dataobj)
    if any(idx < 0 or idx >= dim for idx, dim in zip(cluster_peak_IJK, data.shape[:3])):
        # Negative indices would otherwise wrap around silently
        raise ValueError('Peak %s maps to voxel %s, outside image of shape %s'%(
            cluster_peak_MNI, cluster_peak_IJK, data.shape))

    # Double check that peak value is correct
    peakval = data[i,j,k]
    if verbose: print('Peak value extracted from image data array: %f'%peakval)
    if peakval == 0 or not np.isfinite(peakval):
        raise ValueError('Voxel %s (MNI %s) holds %s; no cluster to expand'%(
            cluster_peak_IJK, cluster_peak_MNI, peakval))

    # Binarize image, keeping only voxels with the same sign as the peak
    binarized = (np.sign(data) == np.sign(peakval)).astype(int)

    # Label each cluster with a different number, reserve 0 for empty voxels
    conn_mat = np.zeros((3, 3, 3), int)  # 6-connectivity, aka NN1 or "faces"
    conn_mat[1, 1, :] = 1
    conn_mat[1, :, 1] = 1
    conn_mat[:, 1, 1] = 1
    label_map = scipy.ndimage.label(binarized, conn_mat)[0]

    # Find voxels with same label as cluster peak
    cluster_label = label_map[i,j,k]
    # int32 rather than the platform default int: newer nibabel refuses int64 data
    cluster_ROI = (label_map == cluster_label).astype(np.int32)
    ROI_mask = nibabel.Nifti1Image(cluster_ROI, affine = image.affine)
    if verbose: print('Cluster ROI located')
    if verbose:
        n_voxels = int(cluster_ROI.sum())
        # Voxel volume from the affine instead of assuming 2 mm isotropic voxels (8 mm^3)
        voxel_volume = abs(np.linalg.det(np.asarray(image.affine, dtype = float)[:3,:3]))
        print('Cluster size as extracted from image data array: %i voxels = %i mm^3'%(
            n_voxels, int(round(float(n_voxels*voxel_volume)))))

    return cluster_peak_IJK, cluster_ROI, ROI_mask

## Select effect of interest

In [None]:
# Analysis selection. These values are used to build the results folder name
# and the file pattern of the thresholded map.
run = 3                 # run number; also selects the exclusion list and last TR
filter_TR = False       # if True, the results folder name includes the TR range below
TR_start = 1            # only used in the folder name when filter_TR is True
TR_end = 711            # only used in the folder name when filter_TR is True
model = 'ideology'      # model label in the results folder name
term = 'scale(ideology_similarity)'   # model term; matched in the map file name and used in output file names
threshold = 'thr-pval-fdr-0.05'       # threshold label matched in the map file name

In [None]:
# Results folder for the selected run/model (with the TR range when filter_TR is set)
run_model_dir = ('run-%i_TRs-%i-%i_model-%s'%(run,TR_start,TR_end,model) 
                     if filter_TR else 'run-%i_model-%s'%(run,model))

# Escape the literal parts so that '[', '*' or '?' in a path or term are not
# treated as wildcards; only the explicit '*' below act as wildcards.
fpattern = '%s/%s/*%s*%s*.nii.gz'%(glob.escape(results_dir), glob.escape(run_model_dir),
                                   glob.escape(term), glob.escape(threshold))
# Sorted so the choice is reproducible when several files match
fmatches = sorted(glob.glob(fpattern))
if not fmatches:
    raise FileNotFoundError('No file matches %s'%fpattern)
if len(fmatches) > 1:
    print('Warning: %i files match %s; using the first one'%(len(fmatches), fpattern))
fpath = fmatches[0]
print(fpath)

##### Load image

In [None]:
# Load the map found above; `image` is used for plotting and cluster extraction below.
image = nilearn.image.load_img(fpath)

In [None]:
# Quick look at the loaded map with default cut coordinates.
%matplotlib inline
plot_stat_map(image, colorbar=False)
plt.show()

In [None]:
# Sagittal slices every 10 units along x: first x = -60..0, then x = 0..60.
plot_stat_map(image, cut_coords=np.arange(-60,1,10), display_mode = 'x', colorbar = False)
plot_stat_map(image, cut_coords=np.arange(0,61,10), display_mode = 'x', colorbar = False)

##### Find clusters

In [None]:
# min_k is passed as extent_threshold; the statistic threshold is 0.
min_k = 5
clusters, peaks = get_clusters(image, 0, min_k)
display(peaks.head())
# Number of rows in the peaks table
print(peaks.shape[0])

##### Plot peaks

In [None]:
# One panel per peak, centred on the peak coordinate. Skipped when there are
# no peaks or 30 or more (the figure would be empty or very tall).
plot_n = np.shape(peaks)[0]
vmax = 2
if 0 < plot_n < 30:
    # squeeze=False always returns an array of axes, so a single peak works too
    fig, ax = plt.subplots(nrows = plot_n, ncols = 1, figsize = [8,plot_n*2], squeeze = False)
    ax = ax.ravel()
    # enumerate rather than the frame index: the index need not be 0..n-1
    for clust_i, (row_label, row) in enumerate(peaks.iterrows()):
        plot_stat_map(image, cut_coords = [float(row[c]) for c in ['X','Y','Z']],
                      axes = ax[clust_i], cmap = 'RdBu_r', vmax = vmax, colorbar = False)
        ax[clust_i].set(title = 'Cluster ID %i, %imm^3'%(row['ID'],row['size']))

## Find cluster locations and masks

In [None]:
# Build one ROI per cluster ID: the peak voxel index and the cluster mask.
cluster_list = peaks['ID'].unique()
ROIs = dict()
for ID in cluster_list:
    print(ID, end = ',')
    cluster_info = peaks.query('ID == @ID')
    if len(cluster_info) > 1:
        # ROIs is keyed by ID, so only the first matching row can be kept
        print('\nWarning: %i peaks share cluster ID %s; only the first is used'%(len(cluster_info), ID))
    # Plain floats rather than the object-dtype values of a mixed-type row
    cluster_peak_MNI = [float(v) for v in cluster_info.iloc[0].loc[['X','Y','Z']]]
    cluster_peak_IJK, cluster_voxels, cluster_mask = expand_cluster(image, cluster_peak_MNI)
    ROIs[ID] = {'peak_IJK':cluster_peak_IJK,'mask':cluster_mask}
ROIs

## Load data and compute inter-subject correlation values from mean cluster signal

In [None]:
def load_sub_funx(sub, run, nifti_dir = "/gpfs_home/jvanbaar/data/jvanbaar/polarization/derivatives/cleaning"):
    """Load one subject's cleaned video-watching image for the given run.

    The file is looked up under `nifti_dir` as
    sub-XXX/ses-1/func/sub-XXX_ses-1_task-videoWatching_run-R_space-..._bold.nii.gz,
    where XXX is `sub` zero-padded to three digits and R is `run`.
    Returns the image object produced by `nilearn.image.load_img`.
    """
    fname_template = ('%s/sub-%03d/ses-1/func/'
                      'sub-%03d_ses-1_task-videoWatching_run-%i'
                      '_space-MNI152NLin2009cAsym_desc-cleaned_bold.nii.gz')
    return nilearn.image.load_img(fname_template%(nifti_dir, sub, sub, run))

In [None]:
def extract_sub_cluster_data(sub_funx, cluster_peak_IJK, cluster_mask,
                            return_peak_data = True,
                            return_mean_cluster_data = True,
                            return_full_cluster_data = False,
                            sub = None, ID = None):
    """Extract one subject's time series for one cluster.

    Parameters
    ----------
    sub_funx : image whose data can be indexed as [i, j, k] to give one value
        per TR (assumed 4D with time last; confirm against caller).
    cluster_peak_IJK : [i, j, k] voxel index of the cluster peak.
    cluster_mask : mask image handed to `NiftiMasker`.
    return_peak_data : include the time series at the peak voxel.
    return_mean_cluster_data : include the mean over all voxels in the mask.
    return_full_cluster_data : include every voxel's time series (long format).
    sub, ID : subject and cluster labels written to the 'sub' and 'cluster'
        columns. When omitted they are taken from notebook-level variables
        named `sub` and `ID`, which is how the function obtained them before
        these parameters existed.

    Returns
    -------
    dict with any of the keys 'peak_data', 'mean_cluster_data',
    'full_cluster_data', each a DataFrame with 'sub', 'cluster', 'TR' and the
    signal column ('peak_BOLD', 'mean_BOLD', or 'voxel' + 'BOLD').

    Raises
    ------
    ValueError : if a label is neither passed nor defined at notebook level.
    """
    # Fall back to notebook-level labels for calls that do not pass them
    if sub is None:
        sub = globals().get('sub')
    if ID is None:
        ID = globals().get('ID')
    if sub is None or ID is None:
        raise ValueError('Pass sub= and ID= (or define `sub` and `ID` before calling)')

    return_dict = {}

    # Get values at peak
    if return_peak_data:
        i,j,k = cluster_peak_IJK
        # NOTE(review): get_data() is deprecated in recent nibabel; left as is
        # because switching accessor may change caching of the 4D array (TODO confirm).
        peak_series = sub_funx.get_data()[i,j,k]
        return_dict['peak_data'] = pd.DataFrame({'sub':sub, 'cluster':ID,
                                                 'TR':np.arange(len(peak_series)),
                                                 'peak_BOLD':peak_series})

    # Get cluster data (rows are TRs, columns are voxels inside the mask)
    if return_mean_cluster_data or return_full_cluster_data:
        roi_masker = NiftiMasker(mask_img=cluster_mask).fit()
        voxels_signals = roi_masker.transform(sub_funx)

    # Take mean across voxels
    if return_mean_cluster_data:
        mean_series = voxels_signals.mean(1)
        return_dict['mean_cluster_data'] = pd.DataFrame({'sub':sub, 'cluster':ID,
                                                         'TR':np.arange(len(mean_series)),
                                                         'mean_BOLD':mean_series})

    # Return all voxel data
    if return_full_cluster_data:
        # Turn the row index into 'TR' BEFORE melting; resetting the index
        # after the melt numbered the stacked rows instead of the TRs.
        full_cluster_dat = (pd.DataFrame(voxels_signals)
                            .rename_axis('TR').reset_index()
                            .melt(id_vars = 'TR', var_name = 'voxel', value_name = 'BOLD'))
        full_cluster_dat['sub'] = sub
        full_cluster_dat['cluster'] = ID
        full_cluster_dat = full_cluster_dat[['sub','cluster','voxel','TR','BOLD']]
        return_dict['full_cluster_data'] = full_cluster_dat

    return return_dict

In [None]:
# All subject IDs
all_subs = pd.read_csv(base_dir + '/Data/Subjects_and_exclusions/all_subjects.csv')['sub'].values.tolist()
print(len(all_subs))

# Exclusions
exclude = pd.read_csv(base_dir + '/Data/Subjects_and_exclusions/exclude_video-watching_aggregate_run-%i.csv'%run)[
    'sub'].values.tolist()

# Upper TR bound per run, used when computing inter-subject correlations.
# An unknown run now fails here instead of leaving last_TR undefined
# (or stale from an earlier execution of this cell).
last_TR_by_run = {1: 390, 2: 307, 3: 720}
if run not in last_TR_by_run:
    raise ValueError('No last_TR defined for run %s (known runs: %s)'%(run, sorted(last_TR_by_run)))
last_TR = last_TR_by_run[run]

print('Exclusions: %s'%exclude)
exclude_lookup = set(exclude)
subs_keep = [i for i in all_subs if i not in exclude_lookup]
print('Keep %i subjects'%len(subs_keep))

In [None]:
# Extract mean-cluster and peak time series for every kept subject and ROI.
# Per-ROI frames are collected in a list and concatenated once at the end:
# growing a DataFrame inside the loop copies it on every iteration, and
# DataFrame.append no longer exists in pandas 2.0.
ROI_frames = []
# NOTE: the loop variables must stay named `sub` and `ID`, because
# extract_sub_cluster_data reads those notebook-level names when labelling rows.
for sub in subs_keep:
    print('Loading data subject %i run %i...'%(sub,run))
    sub_funx = load_sub_funx(sub, run)
    print('Extracting cluster data... ', end = '')
    for ID in list(ROIs.keys()):
        print(ID, end = ', ')
        [cluster_peak_IJK, cluster_mask] = [ROIs[ID][i] for i in ['peak_IJK','mask']]
        out = extract_sub_cluster_data(sub_funx, cluster_peak_IJK, cluster_mask, return_mean_cluster_data=True)
        # The merge joins the two tables on the columns they have in common
        ROI_frames.append(out['mean_cluster_data'].merge(out['peak_data']))
    print('')
all_ROI_dat = pd.concat(ROI_frames, ignore_index = True) if ROI_frames else pd.DataFrame()

In [None]:
# Sanity checks on the extracted table: first rows, dimensions, and which
# subjects and clusters are present.
display(all_ROI_dat.head())
print(all_ROI_dat.shape)
print(all_ROI_dat['sub'].unique())
print(all_ROI_dat['cluster'].unique())

In [None]:
# Show the folder the ROI tables are written to.
results_dir + '/' + run_model_dir 

In [None]:
# Save the extracted ROI time series next to the thresholded map.
ROI_data_path = '%s/%s/ROI_data_term-%s.csv'%(results_dir, run_model_dir, term)
all_ROI_dat.to_csv(ROI_data_path, index=False)

## Compute inter-subject correlations and store data for dyadic regression at the ROI level

In [None]:
# Reload the saved ROI time series so this section can start from the CSV.
ROI_data_path = '%s/%s/ROI_data_term-%s.csv'%(results_dir, run_model_dir, term)
all_ROI_dat = pd.read_csv(ROI_data_path, index_col = None)
all_ROI_dat.head()

In [None]:
# For every cluster: build a TR x subject table of the chosen signal, compute
# the subject-by-subject similarity (1 minus correlation distance), reshape it
# to one row per subject pair, and collect the clusters side by side.
out_var = 'mean_BOLD'
cluster_ids = sorted(all_ROI_dat['cluster'].unique())
for IDi,ID in enumerate(cluster_ids):
    print(ID, end = ', ')

    # Rows: TRs strictly between 3 and last_TR; columns: (out_var, subject)
    ISC_dat = (all_ROI_dat
               .query('cluster == @ID and TR > 3 and TR < @last_TR')
               .pivot_table(index = ['TR'], values = [out_var], columns = 'sub'))
    subcolumns = [col[1] for col in ISC_dat.columns.tolist()]

    # Pairwise correlation distance between subjects, as a square matrix
    corr_dist = scipy.spatial.distance.pdist(ISC_dat.T, metric = 'correlation')
    corr_dist_square = scipy.spatial.distance.squareform(corr_dist)
    ISC = 1 - pd.DataFrame(corr_dist_square, columns = subcolumns)
    ISC['sub'] = subcolumns

    # Long format: one row per (SubID1, SubID2) pair, one value column per cluster
    ISC = (ISC
           .melt(id_vars = 'sub', var_name = 'sub2', value_name = 'ISC.ROI.%i'%ID)
           .rename(columns = {'sub2':'SubID1','sub':'SubID2'}))

    if IDi > 0:
        all_ISC = all_ISC.merge(ISC, on = ['SubID1','SubID2'])
    else:
        all_ISC = ISC.copy()

In [None]:
# Pairwise predictor table; the next cell selects its columns and joins it with the ISC values on SubID1/SubID2.
predictor_RDMs = pd.read_csv(base_dir + '/Data/Cleaned/Surveys/predictor_RDMs_4.csv', index_col = None)

In [None]:
# Keep the pair identifiers plus the predictors of interest, then attach the
# per-cluster ISC columns for each subject pair.
predictor_columns = ['SubID1','SubID2','ideology_similarity','joint_IUS',
                     'age_distance','scan_day_distance','same_gender',
                     'same_undergrad','same_community']
reg_dat = predictor_RDMs[predictor_columns].merge(all_ISC, on = ['SubID1','SubID2'])
reg_dat.head()

In [None]:
# Save the dyadic regression table for this term and signal type.
reg_data_path = '%s/%s/ROI_regression_data_term-%s_%s.csv'%(results_dir, run_model_dir, term, out_var)
reg_dat.to_csv(reg_data_path, index=False)