In [None]:
%matplotlib inline
import os, sys, glob, scipy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import nilearn, nibabel, nltools, nistats
from nltools.data import Brain_Data
from nistats.reporting import get_clusters_table
from nilearn.plotting import plot_stat_map, plot_roi, plot_img, plot_glass_brain
from nilearn.input_data import NiftiMasker

In [None]:
# Project root: three directories above the notebook's working directory.
base_dir = os.path.realpath(os.path.join('..', '..', '..'))
print(base_dir)
# Voxelwise IS-RSA result maps (NIfTI) live under the project root.
results_dir = os.path.realpath(os.path.join(base_dir, 'Results', 'voxelwise_ISRSA', 'nifti'))

##### Function definitions

In [None]:
def get_clusters(image, threshold = 0, extent_threshold = 0, sort_by_size = False, include_peaks_only = True):
    """Tabulate the clusters of a thresholded statistical map and their main peaks.

    Parameters
    ----------
    image : statistical map accepted by ``get_clusters_table``.
    threshold : float
        Statistic threshold forwarded to ``get_clusters_table``.
    extent_threshold : int
        Forwarded as ``cluster_threshold`` (minimum cluster extent).
    sort_by_size : bool
        If True, ``clusters`` is sorted by 'size', largest first, with a fresh index.
        ``peaks`` keeps the table's original order.
    include_peaks_only : bool
        If True, ``peaks`` drops rows whose ID ends in a letter (sub-peak rows such as
        ``'1a'``). If False, ``peaks`` is a copy of the full table. Previously, False
        raised UnboundLocalError because ``peaks`` was never assigned.

    Returns
    -------
    clusters, peaks : pandas.DataFrame
        Both tables have 'Cluster ID', 'Cluster Size (mm3)' and 'Peak Stat' renamed to
        'ID', 'size' and 'peak_value'.
    """
    clusters = get_clusters_table(image, threshold, cluster_threshold = extent_threshold)
    clusters = clusters.rename(columns = {'Cluster ID':'ID', 'Cluster Size (mm3)':'size',
                                          'Peak Stat':'peak_value'})
    peaks = clusters.copy()
    if include_peaks_only:
        # Use [-1:] rather than [-1] so an empty ID cannot raise IndexError.
        is_main_peak = peaks['ID'].map(lambda x: not str(x)[-1:].isalpha()).astype(bool)
        peaks = peaks.loc[is_main_peak].reset_index(drop=True)
    if sort_by_size:
        clusters = clusters.sort_values(by = 'size', ascending = False).reset_index(drop=True)
    return clusters, peaks

In [None]:
def xyz_to_ijk(MNI, image):
    """Convert a world coordinate (e.g. MNI, mm) to the nearest voxel index of ``image``.

    Inverts the image affine: ijk = A^-1 (xyz - t), where A is the 3x3 linear part and
    t the translation column of ``image.affine``.

    Parameters
    ----------
    MNI : sequence of 3 numbers (list, array or pandas Series, any numeric-like dtype)
    image : object exposing a 4x4 ``affine`` array (e.g. a nibabel image)

    Returns
    -------
    list of 3 Python ints
    """
    affine = np.asarray(image.affine, dtype=float)
    xyz = np.asarray(MNI, dtype=float)
    ijk = np.linalg.solve(affine[:3, :3], xyz - affine[:3, 3])
    # Round, don't truncate: astype(int) alone maps 1.9 -> 1 and 26.9999999 -> 26,
    # i.e. an off-by-one voxel from plain floating-point noise.
    return np.rint(ijk).astype(int).tolist()

In [None]:
def expand_cluster(image, cluster_peak_MNI, voxel_volume = 27):
    """Return the face-connected cluster of non-zero voxels that contains a peak.

    Parameters
    ----------
    image : nibabel-style 3D image (needs ``affine`` and ``dataobj``)
        Thresholded map; every non-zero voxel counts as supra-threshold.
        NOTE(review): negative voxels are non-zero too, so they join adjacent clusters.
    cluster_peak_MNI : sequence of 3 numbers
        Peak coordinate in the image's world space.
    voxel_volume : float
        mm^3 per voxel; only used for the printed cluster size.

    Returns
    -------
    cluster_peak_IJK : list of 3 ints
    cluster_ROI : ndarray of int, 1 inside the cluster and 0 elsewhere
    ROI_mask : nibabel.Nifti1Image of ``cluster_ROI`` (stored as uint8) with ``image.affine``

    Raises
    ------
    ValueError
        If the peak voxel is 0. Previously this silently returned all background voxels
        as the "cluster".
    """
    from scipy import ndimage

    # Find IJK of peak
    i, j, k = xyz_to_ijk(cluster_peak_MNI, image)
    cluster_peak_IJK = [i, j, k]
    print('MNI: %s, IJK: %s'%(cluster_peak_MNI, cluster_peak_IJK))

    # get_data() was removed in nibabel 5; np.asanyarray(dataobj) is its documented
    # replacement. Read the array once and reuse it.
    data = np.asanyarray(image.dataobj)

    # Double check that peak value is correct
    peakval = data[i, j, k]
    print('Peak value extracted from image data array: %f'%peakval)

    # Label each cluster of non-zero voxels with a different number (0 = empty voxels),
    # using 6-connectivity (aka NN1 or "faces").
    conn_mat = ndimage.generate_binary_structure(3, 1)
    label_map, _ = ndimage.label(data != 0, structure=conn_mat)

    # Find voxels with same label as cluster peak
    cluster_label = label_map[i, j, k]
    if cluster_label == 0:
        raise ValueError('Peak voxel %s (MNI %s) is zero, so it lies in no cluster'
                         % (cluster_peak_IJK, cluster_peak_MNI))
    cluster_ROI = (label_map == cluster_label).astype(int)
    # uint8: nibabel >= 5 refuses int64 data in Nifti1Image without an explicit dtype.
    ROI_mask = nibabel.Nifti1Image(cluster_ROI.astype(np.uint8), affine = image.affine)
    print('Cluster ROI located')
    n_voxels = int(cluster_ROI.sum())
    print('Cluster size as extracted from image data array: %i voxels = %i mm^3'%(
        n_voxels, n_voxels*voxel_volume))

    return cluster_peak_IJK, cluster_ROI, ROI_mask

In [None]:
def load_sub_funx(sub, run, nifti_dir = "/gpfs_home/jvanbaar/data/jvanbaar/polarization/derivatives/cleaning"):
    """Load one subject's cleaned BOLD run via ``nilearn.image.load_img``.

    The file read is
    ``<nifti_dir>/sub-NNN/ses-1/func/sub-NNN_ses-1_task-videoWatching_run-<run>_space-MNI152NLin2009cAsym_desc-cleaned_bold.nii.gz``
    where NNN is ``sub`` zero-padded to three digits.
    NOTE(review): the default ``nifti_dir`` is an absolute cluster path; override it elsewhere.
    """
    path_parts = [
        nifti_dir,
        '/sub-%03d' % sub,
        '/ses-1/func/',
        'sub-%03d_ses-1_task-videoWatching_run-%i' % (sub, run),
        '_space-MNI152NLin2009cAsym_desc-cleaned_bold.nii.gz',
    ]
    return nilearn.image.load_img(''.join(path_parts))

In [None]:
def extract_sub_cluster_data(sub_funx, cluster_ID, cluster_peak_IJK, cluster_mask,
                            return_peak_data = True,
                            return_mean_cluster_data = True,
                            return_full_cluster_data = False,
                            sub = None):
    """Extract one subject's BOLD time series at a cluster peak and/or within a cluster.

    Parameters
    ----------
    sub_funx : 4D nibabel-style image
        Functional run. Must expose ``dataobj`` (indexed [i, j, k] -> time series) and be
        accepted by ``NiftiMasker.transform``.
    cluster_ID : hashable
        Written to the 'cluster' column of every returned table.
    cluster_peak_IJK : sequence of 3 ints
        Voxel indices of the cluster peak.
    cluster_mask : mask image
        Passed to ``NiftiMasker(mask_img=...)``; only used when cluster data is requested.
    return_peak_data, return_mean_cluster_data, return_full_cluster_data : bool
        Which tables to compute.
    sub : optional
        Subject ID written to the 'sub' column. Earlier versions silently read a
        notebook-global ``sub``. When this is None that global is still used (None if it
        does not exist) so old call sites keep working — pass ``sub`` explicitly.

    Returns
    -------
    dict
        Requested subset of 'peak_data', 'mean_cluster_data', 'full_cluster_data'. Each is a
        DataFrame with 'sub', 'cluster' and a 0-based 'TR' (position in the time series).
    """
    if sub is None:
        sub = globals().get('sub')

    return_dict = {}

    # Get values at peak
    if return_peak_data:
        i, j, k = cluster_peak_IJK
        # Slice the data proxy so only this voxel's series is read. This also avoids
        # get_data(), which nibabel 5 removed.
        peak_series = np.asarray(sub_funx.dataobj[i, j, k])
        peak_data = pd.DataFrame({'TR': np.arange(len(peak_series)), 'peak_BOLD': peak_series})
        peak_data['sub'] = sub
        peak_data['cluster'] = cluster_ID
        return_dict['peak_data'] = peak_data[['sub', 'cluster', 'TR', 'peak_BOLD']]

    # Get cluster data: rows are time points, columns are mask voxels
    if return_mean_cluster_data or return_full_cluster_data:
        roi_masker = NiftiMasker(mask_img=cluster_mask).fit()
        voxels_signals = roi_masker.transform(sub_funx)

    # Take mean across voxels
    if return_mean_cluster_data:
        mean_cluster_dat = pd.DataFrame({'TR': np.arange(voxels_signals.shape[0]),
                                         'mean_BOLD': voxels_signals.mean(1)})
        mean_cluster_dat['sub'] = sub
        mean_cluster_dat['cluster'] = cluster_ID
        return_dict['mean_cluster_data'] = mean_cluster_dat[['sub', 'cluster', 'TR', 'mean_BOLD']]

    # Return all voxel data in long format
    if return_full_cluster_data:
        # Carry TR through the melt as an id column. The old code reset the index
        # *after* melting, which numbered rows across all voxels, so 'TR' kept counting
        # past the series length for every voxel but the first.
        full_cluster_dat = (pd.DataFrame(voxels_signals)
                            .rename_axis('TR')
                            .reset_index()
                            .melt(id_vars='TR', var_name='voxel', value_name='BOLD'))
        full_cluster_dat['sub'] = sub
        full_cluster_dat['cluster'] = cluster_ID
        return_dict['full_cluster_data'] = full_cluster_dat[['sub', 'cluster', 'voxel', 'TR', 'BOLD']]

    return return_dict

##### Select effect of interest

In [None]:
# Which result map to analyse; these values build the results folder and file pattern below.
run = 3                      # run number used in the results folder name
filter_TR = False            # True -> use the 'run-R_TRs-start-end_model-M' folder
TR_start, TR_end = 1, 711    # only used when filter_TR is True
model = 'ideology'
term = 'ideosim'
threshold = 'fdr'
extent_threshold = True
extent = 5
# Threshold tag in result filenames, e.g. 'fdr_ext-5' when an extent threshold is applied.
total_threshold = threshold + ('_ext-%i' % extent if extent_threshold else '')

In [None]:
run_model_dir = ('run-%i_TRs-%i-%i_model-%s'%(run,TR_start,TR_end,model)
                     if filter_TR else 'run-%i_model-%s'%(run,model))
search_string = '%s/%s/*%s*%s.ni*'%(results_dir, run_model_dir, term, total_threshold)
print(search_string)
# Sort so the choice is deterministic (raw glob order depends on the filesystem),
# and fail with a readable message instead of a bare IndexError when nothing matches.
matching_files = sorted(glob.glob(search_string))
if not matching_files:
    raise FileNotFoundError('No statistical map matches %s' % search_string)
if len(matching_files) > 1:
    print('Warning: %i files match; using %s' % (len(matching_files), matching_files[0]))
fpath = matching_files[0]
print(fpath)

In [None]:
# Thresholded statistical map located by the search above.
image = nilearn.image.load_img(img=fpath)

In [None]:
# Quick visual check of the thresholded map. The inline backend is already enabled
# in the import cell, so the duplicate %matplotlib magic is dropped.
plot_stat_map(image, title='%s: %s (%s)' % (run_model_dir, term, total_threshold))
plt.show()

##### Find activation clusters in the thresholded statistical map

In [None]:
# Use the configured extent (0 when extent thresholding is off) instead of a hard-coded 5 voxels.
min_cluster_size = extent if extent_threshold else 0
clusters, peaks = get_clusters(image, 0, min_cluster_size)
print('%i clusters found, loading...\n..\n.\n'%peaks.shape[0])
ROIs = dict()
for ID in peaks['ID'].unique():
    print(ID, end = ',')
    peak_row = peaks.loc[peaks['ID'] == ID].iloc[0]
    # Plain floats: a row of the mixed-dtype peaks table is an object Series.
    peak_MNI = peak_row[['X', 'Y', 'Z']].astype(float).tolist()
    peak_IJK, _, roi_mask = expand_cluster(image, peak_MNI)
    ROIs[ID] = {'peak_IJK': peak_IJK, 'mask': roi_mask}
# Compact summary (cluster ID -> peak voxel) instead of dumping every mask image.
{ID: roi['peak_IJK'] for ID, roi in ROIs.items()}

##### Select ROI

In [None]:
# Candidate VMPFC peaks: strictly inside the box 5 < X < 20, 15 < Y < 30, ordered by X, Y, Z.
in_box = (peaks['X'] > 5) & (peaks['X'] < 20) & (peaks['Y'] > 15) & (peaks['Y'] < 30)
peaks_VMPFC = peaks[in_box].sort_values(by = ['X', 'Y', 'Z'])
peaks_VMPFC

In [None]:
# `mask` was never defined in this notebook (NameError on a fresh kernel). Presumably the
# intent is the mask of the candidate VMPFC cluster chosen below (first row of peaks_VMPFC).
vmpfc_mask = ROIs[peaks_VMPFC.iloc[0]['ID']]['mask']
plot_roi(vmpfc_mask, cut_coords = [10], display_mode = 'x');

In [None]:
# Peak coordinate of the selected VMPFC cluster as plain floats. A row of the mixed-dtype
# peaks table is an object Series, which is fragile as seed/sphere coordinates downstream.
MNI_peak = peaks_VMPFC.iloc[0][['X', 'Y', 'Z']].astype(float).tolist()
plot_stat_map(image, cut_coords = MNI_peak);

In [None]:
# ID of the selected (first after sorting) VMPFC peak; used as a key into ROIs.
cluster_ID = peaks_VMPFC['ID'].iloc[0]

In [None]:
# Peak voxel indices and binary mask of the selected cluster, as stored by the cluster loop.
selected_roi = ROIs[cluster_ID]
cluster_peak_IJK = selected_roi['peak_IJK']
cluster_mask = selected_roi['mask']

##### Select subjects to plot

In [None]:
# Subjects to plot: the lib_pair subjects are drawn in blue, the con_pair subjects in red.
lib_pair = [16, 45]
con_pair = [20, 35]
all_plot_subs = [*lib_pair, *con_pair]
print(all_plot_subs)

## Sphere-based method

In [None]:
from nltools.mask import create_sphere

In [None]:
radius = 6  # sphere radius in mm (also written into the saved figure's filename)
# Extracts the (unstandardized) signal of a sphere around the selected peak.
masker = nilearn.input_data.NiftiSpheresMasker(seeds=[MNI_peak], radius=radius, standardize=False)

In [None]:
# Sample the sphere time course for every plotted subject from the configured `run`
# (the old `run = 3` here silently overrode the config cell).
# DataFrame.append was removed in pandas 2.0 and re-copies the frame every iteration,
# so collect per-subject frames and concatenate once.
sub_frames = []
for sub in all_plot_subs:
    print(sub)
    sub_funx = load_sub_funx(sub, run)
    sphere_ts = masker.fit_transform(sub_funx)
    sub_frame = pd.DataFrame(sphere_ts, columns = ['BOLD'])
    sub_frame['sub'] = sub
    sub_frame['TR'] = np.arange(1, len(sub_frame) + 1)  # 1-based TR number
    sub_frames.append(sub_frame)
all_out = pd.concat(sub_frames, ignore_index = True)

In [None]:
# Select data
n_minutes = 2
tcdat = all_out.query('TR > 3 and TR <= %i'%(4 + n_minutes*40)).copy()
tcdat['time (s)'] = tcdat['TR']*1.5 - 6
plotdat = tcdat.pivot(index = 'time (s)', columns = 'sub', values = 'BOLD')
plotdat = plotdat.apply(scipy.stats.zscore)

# Remove outlier TRs
outlier_z = 30 #no outliers remove
len_before = np.shape(plotdat)[0]
plotdat = plotdat.iloc[np.where(((plotdat > -outlier_z) & (plotdat < outlier_z)).product(axis=1))[0],:].reset_index()
len_after = np.shape(plotdat)[0]
print('%i outlier TRs dropped'%(len_before - len_after))

# Smooth using rolling average
window_length_TRs = 4
plotdat = plotdat.rolling(window_length_TRs).mean()

In [None]:
# Preview only; the full frame has one row per time point (first rows are NaN from smoothing).
plotdat.head()

In [None]:
# Sphere image around the selected peak, used for the ROI panel of the figure.
# NOTE(review): confirm nltools' create_sphere interprets `radius` in the same units (mm)
# as NiftiSpheresMasker, so the plotted sphere matches the sampled one.
sphere = create_sphere(MNI_peak,
                       radius = radius,
                       mask = cluster_mask)

In [None]:
# The figure is always saved in the un-TR-filtered run/model folder,
# regardless of `filter_TR` in the configuration cell.
model = 'ideology'
run_model_dir = 'run-%i_model-%s' % (run, model)
run_model_dir

In [None]:
# Plot
fig,axes = plt.subplots(nrows=1, ncols=2, gridspec_kw = {'width_ratios':[1,4]}, figsize = [12,3])
plot_roi(sphere, cut_coords = [10], cmap = 'spring', alpha = 1, vmax = 1, display_mode = 'x', axes = axes[0])
ax = axes[1]
for si,s in enumerate(all_plot_subs):
    color = 'red' if s in con_pair else 'blue'
    pole = 'C' if s in con_pair else 'L'
    ideology = ID_dat.query('SubID == @s')['IdeologyScale_1'].iloc[0]
    ax.plot(plotdat['time (s)'], plotdat[s], label = 'Subject %i, ideology %i (%s)'%(s,ideology,pole),
                color = color, alpha = .5, lw = 2)
ax.legend()
ax.set(xlabel = 'Time (s)', ylabel = 'BOLD (z)', title = 'Orbitofrontal cortex activity during video 3');
plt.savefig('%s/%s/timecourses_%s-ROI-%i_peak-sphere-%imm.pdf'%(results_dir, run_model_dir, term, cluster_ID, radius),
            transparent = True, bbox_inches = 'tight');