In [1]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib import gridspec
from torch import load, device
from nilearn import regions, maskers
import surfplot as spl
from neuromaps import datasets, transforms
from PIL import Image

import sys
sys.path.append('/home/maelle/GitHub_repositories/cNeuromod_encoding_2020')  
import models

In [2]:
# Where the rendered surface-map figures are written.
results_path = '/home/maelle/Results/figures/surface_evr_models'
# Saved model files, one sub-folder per subject (listed in the main analysis cell).
r2_tests = '/home/maelle/Results/best_models'

# Subjects sub-01 ... sub-06.
subs = ['sub-{:02d}'.format(n) for n in range(1, 7)]

# Atlas (parcel scale) and mask (voxel scale) used to map scores back to volume space.
MIST_path = '/home/maelle/DataBase/fMRI_parcellations/MIST_parcellation/Parcellations/MIST_ROI.nii.gz'
voxel_mask = '/home/maelle/GitHub_repositories/cNeuromod_encoding_2020/parcellation/STG_middle.nii.gz'

# Scales to process; 'auditory_Voxels' is also handled by the analysis cell.
scales = ['MIST_ROI']  # alternative: ['auditory_Voxels', 'MIST_ROI']

In [3]:
#step 2 : function to create 1 surface map for one subject
def create_surface_map_from_parcellated_signal(data_nii, shape='inflated', 
                                               views=None, zoom=1.5, layout='grid', size=(500, 400),
                                               label_text=None,
                                               cbar_min=None, cbar_max=None,
                                               figsize=(10, 10), build_scale=(3, 3)): 
    """Project a volumetric (MNI152) image onto the fsLR surface and render it.

    Parameters
    ----------
    data_nii : niimg
        Volumetric image, passed to ``neuromaps.transforms.mni152_to_fslr``.
    shape : str
        Key of the fsLR surface to draw on: 'midthickness', 'inflated',
        'veryinflated' or 'sphere'.
    views, zoom, layout, size, label_text :
        Forwarded unchanged to ``surfplot.Plot``.
        layout: 'grid', 'column' or 'row'.
        views: 'lateral', 'medial', 'dorsal', 'ventral', 'anterior',
        'posterior', or a list of these.
    cbar_min, cbar_max : float or None
        Colour-range bounds. If both are None, surfplot chooses the range.
        If only one is given, a missing lower bound defaults to 0 and a
        missing upper bound defaults to the maximum of the projected data.
    figsize, build_scale :
        Forwarded to ``Plot.build`` as ``figsize`` and ``scale``. The defaults
        reproduce the previously hard-coded (10, 10) and (3, 3).

    Returns
    -------
    The figure returned by ``surfplot.Plot.build``.
    """
    # fsLR meshes (left, right) for the requested surface shape
    fslr = datasets.fetch_fslr()
    lh, rh = fslr[shape]
    p = spl.Plot(surf_lh=lh, surf_rh=rh, zoom=zoom, views=views, layout=layout,
                 size=size, label_text=label_text)

    # volume -> fsLR 32k vertices, one image per hemisphere
    gii_lh, gii_rh = transforms.mni152_to_fslr(data_nii, fslr_density='32k', method='linear')
    layer = {'left': gii_lh, 'right': gii_rh}

    if cbar_min is None and cbar_max is None:
        p.add_layer(layer)
    else:
        if cbar_min is None:
            cbar_min = 0
        if cbar_max is None:
            # Defaulting the upper bound to 0 (previous behaviour) collapses or
            # inverts the range whenever cbar_min >= 0; use the data maximum instead.
            cbar_max = float(np.nanmax(np.concatenate(
                [np.asarray(g.agg_data()).ravel() for g in (gii_lh, gii_rh)])))
        p.add_layer(layer, color_range=(cbar_min, cbar_max))

    return p.build(figsize=figsize, scale=build_scale)

In [4]:
def list_positions_in_matrix2D(nb_rows, nb_col):
    """Enumerate every (row, col) cell of an nb_rows x nb_col grid.

    Cells are listed in row-major order: all columns of row 0, then row 1, and
    so on. Indices are numpy integers.

    Returns
    -------
    list of (row, col) tuples, of length nb_rows * nb_col.
    """
    row_idx, col_idx = np.indices((nb_rows, nb_col))
    return list(zip(row_idx.ravel(), col_idx.ravel()))

In [5]:
def title_and_save_fig(fig, result_path, title, filename, ext='.jpg', size = None):
    """Give `fig` a suptitle, save it as ``result_path/filename + ext`` and close it.

    Parameters
    ----------
    fig : matplotlib Figure
    result_path : str
        Directory the figure is written to.
    title : str
        Suptitle text.
    filename : str
        File name without extension.
    ext : str
        Extension appended to `filename` (default '.jpg').
    size : font size or None
        Suptitle font size; 'xx-large' when None.

    Returns
    -------
    str
        Path of the saved file.
    """
    fontsize = 'xx-large' if size is None else size
    fig.suptitle(title, fontsize=fontsize)
    # Use the argument, not the global `results_path` (the parameter was previously ignored).
    savepath = os.path.join(result_path, filename + ext)
    fig.savefig(savepath)
    # Close this figure explicitly: plt.close() with no argument closes the
    # *current* figure, which is not guaranteed to be `fig`.
    plt.close(fig)
    return savepath

In [6]:
def voxels_nii(voxel_data, voxel_mask, t_r=1.49):
    """Map voxel-wise values back into a NIfTI image (from voxels to nii).

    A ``NiftiMasker`` is fitted on `voxel_mask`, and its ``inverse_transform``
    places `voxel_data` back into volume space.

    NOTE(review): smoothing_fwhm / t_r / standardize / detrend presumably only
    affect ``transform`` and not ``inverse_transform`` — confirm before relying on them.
    """
    masker_params = dict(
        mask_img=voxel_mask,
        standardize=False,
        detrend=False,
        t_r=t_r,
        smoothing_fwhm=8,
    )
    fitted_masker = maskers.NiftiMasker(**masker_params)
    fitted_masker.fit()
    return fitted_masker.inverse_transform(voxel_data)

In [7]:
def map_fig_from_jpg(mean_im, subs_im, filepath, title, savefile):
    """Assemble previously saved surface-map images into one mosaic figure.

    Layout is a 2 x 5 grid. The mean image spans columns 0-1 of both rows. The
    subject images fill columns 2-4 in row-major order, which gives 6 slots;
    extra subject images are ignored.

    Parameters
    ----------
    mean_im : str
        File name (relative to `filepath`) of the group-mean image. It is
        cropped to the box (0, 0, 720, 650).
    subs_im : list of str
        File names (relative to `filepath`) of the per-subject images. Each is
        cropped to the box (0, 0, 720, 600).
    filepath : str
        Directory containing the input images.
    title : str
        Suptitle of the mosaic (font size 40).
    savefile : str
        Output file name, without extension. The mosaic is saved in `results_path`.

    Returns
    -------
    str
        Path of the saved mosaic.
    """
    n_rows, n_col = 2, 5
    n_mean_cols = 2  # columns occupied by the mean panel
    pos_subs = [pos for pos in list_positions_in_matrix2D(n_rows, n_col) if pos[1] >= n_mean_cols]
    fig = plt.figure(constrained_layout=True, figsize=(35, 20))
    grid = gridspec.GridSpec(n_rows, n_col, figure=fig)

    img_style = {
                 'xticklabels':[],
                 'xticks':[],
                 'yticklabels':[],
                 'yticks':[],
                 'frame_on':False
                 }

    def _show(ax, path, box):
        # `with` closes the underlying file handle; crop() returns an independent image.
        with Image.open(path) as image:
            ax.imshow(image.crop(box))
        ax.set(**img_style)

    _show(fig.add_subplot(grid[:, :n_mean_cols]), os.path.join(filepath, mean_im), (0, 0, 720, 650))

    for sub_pos, sub_im in zip(pos_subs, subs_im):
        _show(fig.add_subplot(grid[sub_pos]), os.path.join(filepath, sub_im), (0, 0, 720, 600))

    return title_and_save_fig(fig, results_path, title, savefile, size=40)

In [11]:
#step 1.0 : load data with selected cond--> baseline / conv / diff conv - baseline - works with scale MIST for now
conv = 'conv4'
baseline = 'no_ft'
diff_name = '{}-{}'.format(conv, baseline)

vox_surfplot_args={
    'label_text' : {'bottom':['left hemisphere', 'right hemisphere']},
    'zoom':3, 
    'views':'lateral', 
    'layout':'row', 
    'size':(1000,850),
    'shape':'veryinflated'
}

MIST_surfplot_args={                 
    'size':(1000,800)
}
# Fixed colour ranges tried previously (add 'cbar_max' to the args above to reproduce):
# auditory_Voxels -> 0.46 (0.07 for the difference map); MIST_ROI -> 0.35 (0.05 for the difference map).


def _select_models(sub_path, scale):
    """Return ``(selected_model, baseline_model)`` file names found in `sub_path`.

    Only files whose name contains `scale` are considered. When
    ``baseline == 'no_ft'``, a file without 'f_conv' in its name is the
    baseline. Otherwise, a file containing `conv` is the fine-tuned model and
    a file containing `baseline` is the baseline. If several files match, the
    last one listed wins.

    Raises FileNotFoundError if either model is missing. Previously the file
    chosen for the previous subject was silently reused.
    """
    selected_model, baseline_model = None, None
    for model in os.listdir(sub_path):
        if scale not in model:
            continue
        if baseline == 'no_ft' and 'f_conv' not in model:
            baseline_model = model
        elif conv in model:
            selected_model = model
        elif baseline in model:
            baseline_model = model
    if selected_model is None or baseline_model is None:
        raise FileNotFoundError('missing {} or {} model for scale {} in {}'.format(
            conv, baseline, scale, sub_path))
    return selected_model, baseline_model


def _to_volume(data, scale):
    """Map a row of scores back to a volumetric image.

    Returns the image together with the surfplot kwargs for `scale`. Raises
    ValueError for an unknown scale, instead of silently reusing a stale image.
    """
    if scale == 'auditory_Voxels':
        return voxels_nii(data, voxel_mask), vox_surfplot_args
    if scale == 'MIST_ROI':
        return regions.signals_to_img_labels(data, MIST_path), MIST_surfplot_args
    raise ValueError('unsupported scale: {}'.format(scale))


def _save_surface_map(data, scale, label, cond):
    """Render `data` on the surface and save it as '<scale>_<label>_<cond>.jpg'.

    The figure title is '<label> : <cond>'. Returns the saved path.
    """
    data_nii, surfplot_args = _to_volume(data, scale)
    fig = create_surface_map_from_parcellated_signal(data_nii, **surfplot_args)
    title = '{} : {}'.format(label, cond)
    filename = '{}_{}_{}'.format(scale, label, cond)
    return title_and_save_fig(fig, results_path, title, filename, ext='.jpg')


for scale in scales:
    subs_data = {sub:[] for sub in subs}
    for sub in subs:
        print(sub)
        sub_path = os.path.join(r2_tests, sub)
        selected_model, baseline_model = _select_models(sub_path, scale)
        print(selected_model, baseline_model)

        conv_path = os.path.join(sub_path, selected_model)
        basl_path = os.path.join(sub_path, baseline_model)

        # torch.load unpickles: only load trusted, locally produced result files
        conv_data = load(conv_path, map_location=device('cpu'))['test_r2'].reshape(1,-1)
        basl_data = load(basl_path, map_location=device('cpu'))['test_r2'].reshape(1,-1)
        diff_data = np.subtract(conv_data, basl_data)
        sub_data = {conv : conv_data, 
                    baseline : basl_data, 
                    diff_name : diff_data}
        # each condition -> (path of the saved surface map, scores)
        sub_data = {name: (_save_surface_map(data, scale, sub, name), data)
                    for name, data in sub_data.items()}
        subs_data[sub] = sub_data

    #mean of all subjects for 3 conditions (conv, baseline, diff)
    mean = {}
    for cond in dict.fromkeys((conv, baseline, diff_name)):
        all_cond = np.array([subs_data[sub][cond][1] for sub in subs])
        cond_mean = np.mean(all_cond, axis=0)
        mean[cond] = (_save_surface_map(cond_mean, scale, 'mean', cond), cond_mean)
    subs_data['mean']=mean

    # one mosaic per condition: group mean on the left, subjects on the right
    for cond in mean:
        title = 'surface maps for {} : {}'.format(scale, cond)
        savefile =  '{}_{}_surface_map'.format(scale, cond)
        map_fig_from_jpg(mean_im=mean[cond][0],
                         subs_im=[subs_data[sub][cond][0] for sub in subs],
                         filepath=results_path, title=title, savefile=savefile)
 

sub-01
friends_MIST_ROI_SoundNetEncoding_conv_0800715_0e+00_1e-05_1e-03_opt110_f_conv4_wbid2svgtnue_20220524-113541.pt friends_MIST_ROI_SoundNetEncoding_conv_0800715_0e+00_1e-05_1e-03_opt110_wbid1j4np8t3_20220524-113522.pt
sub-02
friends_MIST_ROI_SoundNetEncoding_conv_0800715_0e+00_1e-05_1e-03_opt110_f_conv4_wbidpcmtk3zb_20220428-093227.pt friends_MIST_ROI_SoundNetEncoding_conv_0800715_0e+00_1e-05_1e-03_opt110_wbid28gqq8u1_20220429-044704.pt
sub-03
friends_MIST_ROI_SoundNetEncoding_conv_0700515_0e+00_1e-04_1e-03_opt110_f_conv4_wbid2jl8l7ad_20220407-095803.pt friends_MIST_ROI_SoundNetEncoding_conv_0700515_0e+00_1e-04_1e-03_opt110_wbid2mlgalkf_20220407-095244.pt
sub-04
friends_MIST_ROI_SoundNetEncoding_conv_0800615_0e+00_1e-05_1e-03_opt110_f_conv4_wbidf0qeh2ll_20220329-112905.pt friends_MIST_ROI_SoundNetEncoding_conv_0800615_0e+00_1e-05_1e-03_opt110_wbidqds77g2d_20220329-110804.pt
sub-05
friends_MIST_ROI_SoundNetEncoding_conv_0600715_0e+00_1e-06_1e-03_opt110_f_conv4_wbid3p9qsjur_20220624

<Figure size 432x288 with 0 Axes>