In [1]:
#general import
import os
import numpy as np
import pandas as pd

#nn import
from torch import device, load

#stats import
from scipy.stats import ttest_rel, wilcoxon
from statistics import median
from statsmodels.stats import multitest

#general visualization import
#import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import gridspec, colors, colormaps

#brain visualization import
from nilearn import regions, datasets, surface, plotting, image, maskers
from nilearn.plotting import plot_roi, plot_stat_map

In [2]:
def sum_conditions(conditions) : 
    correct = True
    for condition in conditions:
        if condition == False:
            correct = False
            break
    return correct

In [3]:
def get_specific_file_path(path, *cond):
    specific_file_path = None
    for file in os.listdir(path):
        filepath = os.path.join(path, file)
        if os.path.isdir(filepath):
            continue
        conditions = [(obj in file) if boolean else (obj not in file) for (obj, boolean) in cond]
        specific_file = sum_conditions(conditions)
        specific_file_path = filepath if specific_file else specific_file_path
    return specific_file_path

In [4]:
def voxels_nii(voxel_data, voxel_mask, t_r=1.49, smoothing_fwhm=8):
#from voxels to nii
    voxel_masker = maskers.NiftiMasker(mask_img=voxel_mask, standardize=False, 
                                       detrend=False, t_r=t_r, smoothing_fwhm=smoothing_fwhm)
    voxel_masker.fit()
    vox_data_nii = voxel_masker.inverse_transform(voxel_data)
    return vox_data_nii

In [5]:
def voxel_map(voxel_data, vmax=None, cut_coords=None, tr = 1.49, bg_img=None, cmap = 'cold_hot', smoothing_fwhm=8) : 
    f = plt.Figure()
    data_nii = voxels_nii(voxel_data, voxel_mask, t_r=tr, smoothing_fwhm=smoothing_fwhm)
    if bg_img is not None : 
        plotting.plot_stat_map(data_nii, bg_img=bg_img, draw_cross=False, vmax=vmax,
                           display_mode='x', cut_coords=[-63, -57, 57, 63], figure=f,
                              black_bg=True, dim = 0, cmap=cmap)
    else :
        plotting.plot_stat_map(data_nii, draw_cross=False, vmax=vmax,
                           display_mode='x', cut_coords=[-63, -57, 57, 63], figure=f)
    return f

In [6]:
def surface_fig(parcel_data, vmax, threshold='auto', cmap='turbo', inflate=True, colorbar=True, no_background=True):     
    nii_data = regions.signals_to_img_labels(parcel_data, MIST_path)
    fig, ax = plotting.plot_img_on_surf(nii_data,
                              views=['lateral', 'medial'], hemispheres=['left', 'right'], inflate=inflate,
                              vmax=vmax, threshold=threshold, colorbar=colorbar, cmap=cmap, symmetric_cbar=False)
    return fig

In [7]:
def extend_colormap(original_colormap = 'twilight', percent_start = 0.25, percent_finish = 0.25):
    colormap = colormaps[original_colormap]
    nb_colors = colormap.N
    new_colors_range = colormap(np.linspace(0,1,nb_colors))

    n_start = round(nb_colors/(1-percent_start)) - nb_colors if percent_start != 0 else 0
    new_color_start = np.array([colormap(0)]*n_start).reshape(-1, new_colors_range.shape[1])
    n_finish = round(nb_colors/(1-percent_finish)) - nb_colors if percent_finish != 0 else 0
    new_color_finish = np.array([colormap(0)]*n_finish).reshape(-1, new_colors_range.shape[1])

    new_colors_range = np.concatenate((new_color_start,new_colors_range,new_color_finish), axis=0)
    new_colormap = colors.ListedColormap(new_colors_range)
    return new_colormap

In [11]:
class Statistical_R2_BrainMap():    
    def __init__(self, sub, scale, finetune_a='no_ft', finetune_b=None, map_threshold=0.05):
        self.subject = sub          
        self.scale = scale
        self.finetune_a = finetune_a
        self.finetune_b = 'no_init' if finetune_b is None else finetune_b 
        
        if finetune_a == finetune_b:
            self.finetune_b = 'no_init'
            #raise exception ... ?
        self.percent_data=False
        
        self.map_type = 'flat' if self.finetune_b == 'no_init' else 'diff'            
        if self.map_type == 'flat':
            self.color = 'turbo'
            self.percent_start = 0.50
        else:
            self.color = 'plasma'
            self.percent_start = 0.45

        print(self.color)
        self.cmp = extend_colormap(original_colormap=self.color, percent_start=self.percent_start, percent_finish=0)
        self.abs_threshold = map_threshold

    def load_data(self, data_path):
        ft_a = ('_f_', False) if self.finetune_a == 'no_ft' else (self.finetune_a, True)
        ft_b = ('_f_', False) if self.finetune_b == 'no_ft' else (self.finetune_b, True)
        
        filepath_a = get_specific_file_path(data_path, (self.subject, True),(self.scale, True),ft_a)
        filepath_b = get_specific_file_path(data_path, (self.subject, True),(self.scale, True),ft_b)
        self.data_a = np.load(filepath_a)
        self.data_b = np.load(filepath_b)

    def test_sign_roi(self, test=wilcoxon, difference=True):
        all_rois_pval = []
        data_a = self.data_a.T
        data_b = self.data_b.T
        
        for i, (roi_a, roi_b) in enumerate(zip(data_a, data_b)):
            if difference :
                roi_a = roi_a - roi_b
                roi_b = None
            stat, pvalue = test(roi_a, roi_b)
            all_rois_pval.append(pvalue)

        self.sign_rois = [i for i, pval in enumerate(all_rois_pval) if pval <= self.abs_threshold]
        self.not_sign_rois = [i for i, pval in enumerate(all_rois_pval) if pval > self.abs_threshold]
        
        return all_rois_pval
    
    def sign_roi_with_FDR_correction(self, rois_pval, method='n'):
        rejected, p_vals_corr = multitest.fdrcorrection(rois_pval, method=method)
        
        not_sign_fdr = [(i, p_val) for i, (p_val, sign) in enumerate(zip(p_vals_corr, rejected)) if not sign]
        sign_fdr = [(i, p_val) for i, (p_val, sign) in enumerate(zip(p_vals_corr, rejected)) if sign]
        
        self.sign_rois = [idx for idx, p_value in sign_fdr]
        self.not_sign_rois = [idx for idx, p_value in not_sign_fdr]
        
        return (sign_fdr, not_sign_fdr)

    def prepare_roi_r2_for_brainmap(self, sign_values_only=True, fill_values=np.nan, percent_diff=True):
        mean_data_a = np.mean(self.data_a, axis=0) #np.mean(run_baseline_data, axis=0) #diff_means_runs if percent threshold
        if sign_values_only:
            mean_data_a = [value if i in self.sign_rois else fill_values for i, value in enumerate(mean_data_a)]
        if self.abs_threshold is not None:
            mean_threshold_a = [value if np.abs(value) >= self.abs_threshold else fill_values for value in mean_data_a]  
        else:
            self.abs_threshold = np.nanmin(np.abs(mean_data_a))
        mean_threshold_a = np.array(mean_threshold_a).reshape(1, -1)
        
        if self.map_type == 'diff':
            index_nan = np.argwhere(np.isnan(mean_threshold_a)).squeeze()
            mean_data_b = np.mean(self.data_b, axis=0)
            mean_threshold_b = [np.nan if i in index_nan else value for i, value in enumerate(mean_data_b)]
            diff_data = np.array(mean_threshold_a) - np.array(mean_threshold_b)
            diff_data = diff_data.reshape(1,-1)
            if percent_diff:
                percent_data = diff_data*100/np.abs(np.array(mean_threshold_b).reshape(1, -1))
                diff_data = percent_data
                self.abs_threshold = np.nanmin(np.abs(percent_data))
                self.percent_data=True

        self.brainmap_data = diff_data if self.map_type == 'diff' else mean_threshold_a  

    def r2_voxel_map(self, out_path, background, namefile='', smoothing_fwhm=8,  
                     colorbar=True, cut_coords=[-63, -57, 57, 63]):
        percent_str = 'percent' if self.percent_data else ''
        output_file = os.path.join(out_path, f'{self.subject}_{self.scale}_{self.map_type}{percent_str}-r2_{self.finetune_a}_contrast-{self.finetune_b}_{namefile}.png') 

        if self.percent_data:
            vmax=30
        elif self.map_type == 'diff':
            vmax=0.06
        else:
            vmax=0.5 
            
        fig = voxel_map(self.brainmap_data, cut_coords=cut_coords, cmap=self.cmp, vmax=vmax, 
                        smoothing_fwhm=smoothing_fwhm, bg_img=background)
        fig.set_size_inches(28,3)
        fig.savefig(output_file, dpi=100)

    def r2_roi_map(self, out_path, colorbar=True):
        percent_str = 'percent' if self.percent_data else ''
        output_file = os.path.join(out_path, f'{self.subject}_{self.scale}_{self.map_type}{percent_str}-r2_{self.finetune_a}_contrast-{self.finetune_b}.png') 
      
        if self.percent_data:
            vmax=30
        elif self.map_type == 'diff':
            vmax=0.04
        else:
            vmax=0.4
        fig = surface_fig(self.brainmap_data, threshold=self.abs_threshold, vmax=vmax, colorbar=colorbar, cmap=self.cmp)         
        fig.set_size_inches(8,10)
        fig.savefig(output_file, dpi=100)

In [13]:
out_path = '/home/maellef/Results/finefriends/figures/surface/revised_paper'
r2_by_run_path = '/home/maellef/Results/finefriends/best_models/predict_S4_runs' #'/home/maellef/Results/finefriends/fwhm0/s4_r2' 
MIST_path = '/home/maellef/DataBase/fMRI_parcellations/MIST_parcellation/Parcellations/MIST_ROI.nii.gz'
voxel_mask = '/home/maellef/git/cNeuromod_encoding_2020/parcellation/STG_middle.nii.gz'
background_path = '/home/maellef/Results/ref_anat'
background_file = 'MNI152NLin6Asym_desc-preproc_T1w.nii.gz'

subs = ['sub-01', 'sub-02', 'sub-03', 'sub-04', 'sub-05', 'sub-06']
for sub in subs:
    scale = 'auditory_Voxels'
    anatomical_bg_path = get_specific_file_path(background_path,(sub, True),(background_file, True))   
    for ft_main, ft_contrast in [('no_ft', None), ('conv4', None), ('conv4', 'no_ft')]:
        sub_map = Statistical_R2_BrainMap(sub, scale, finetune_a=ft_main, finetune_b=ft_contrast, map_threshold=0.05)
        sub_map.load_data(r2_by_run_path)
        rois_pvals = sub_map.test_sign_roi(test=wilcoxon, difference=True)
        sub_map.sign_roi_with_FDR_correction(rois_pvals, method='n')
        sub_map.prepare_roi_r2_for_brainmap(sign_values_only=True, fill_values=np.nan, percent_diff=True)
        sub_map.r2_voxel_map(out_path, background=anatomical_bg_path, namefile='fwhm8', smoothing_fwhm=0, 
                             colorbar=True, cut_coords=[-63, -57, 57, 63])

turbo
turbo



KeyboardInterrupt



Other utils------

In [None]:
def brain_with_cuts(axe = 'sagittal', axe_slice = 50, cut_x = [], cut_y = [], cut_z = [], ax=None):
    brain = datasets.load_mni152_template()
    brain_arr = image.get_data(brain)
    
    #select the array
    if axe == 'x' or axe == 'sagittal':
        brain_slice = brain_arr[axe_slice,:,:].squeeze()
        cuts = [cut_y, cut_z]
    elif axe == 'y' or axe == 'coronal':
        brain_slice = brain_arr[:,axe_slice,:].squeeze()
        cuts = [cut_x, cut_z]
    elif axe == 'z' or axe == 'axial':
        brain_slice = brain_arr[:,:,axe_slice].squeeze()
        cuts = [cut_x, cut_y]
    else : 
        raise ValueError('the axe of the brain slice {} does not exist'.format(axe))
    
    #remove noise + black background
    brain_slice[brain_slice < 10**-10] = 0
    brain_slice[brain_slice == 0.0] = 1
    
    #add the cut in the array
    print(cuts)
    for cut in cuts[0] : 
        brain_slice[cut, :] = 0
    for cut in cuts[1] : 
        brain_slice[:, cut] = 0
    
    brain_slice = np.flip(brain_slice.T, axis=0)
    
    if ax is not None:
        ax.imshow(brain_slice, cmap='gray')
    else : 
        plt.imshow(brain_slice, cmap='gray')