In [1]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib import gridspec
from torch import load, device
from nilearn import regions, maskers, datasets, surface, plotting

import sys
sys.path.append('/home/maelle/GitHub_repositories/cNeuromod_encoding_2020')  
import models

In [2]:
results_path = '/home/maelle/Results/figures/surface_evr_models'
r2_tests = '/home/maelle/Results/best_models'
runs_path = '/home/maelle/Results/best_models/predict_S4_runs'
subs = ['sub-01', 'sub-02', 'sub-03', 'sub-04', 'sub-05', 'sub-06']

MIST_path = '/home/maelle/DataBase/fMRI_parcellations/MIST_parcellation/Parcellations/MIST_ROI.nii.gz'
voxel_mask = '/home/maelle/GitHub_repositories/cNeuromod_encoding_2020/parcellation/STG_middle.nii.gz'

mist_roi_labels = '/home/maelle/DataBase/fMRI_parcellations/MIST_parcellation/Parcel_Information/MIST_ROI.csv'
label_pd = pd.read_csv(mist_roi_labels, sep=';')
print(label_pd.loc[label_pd['name']=='right_SUPERIOR_TEMPORAL_GYRUS_middle'])
for a in list(label_pd.name):
    print(a)

scales = ['MIST_ROI']#['auditory_Voxels']
conv = 'conv4'
baseline = 'no_ft'
diff_name = '{}-{}'.format(conv, baseline)

     roi      label                                  name  size  symmetry  \
153  154  R_STgyr_m  right_SUPERIOR_TEMPORAL_GYRUS_middle   281       0.0   

     laterality     x     y     z  \
153        -1.0  57.9 -6.27 -0.14   

                                             neighbour  parent  overlap  
153  [21, 36, 40, 57, 92, 134, 147, 150, 163, 170, ...      88      1.0  
right_SUPERIOR_PARIETAL_LOBULE
left_SUPERIOR_PARIETAL_LOBULE
right_CAUDATE_ventral
left_CAUDATE_ventral
left_CEREBELLUM_VERMIS
right_CEREBELLUM_VERMIS
left_INFERIOR_PARIETAL_LOBULE
left_SOMATOMOTOR_NETWORK_anteromedial
right_SOMATOMOTOR_NETWORK_anteromedial
left_ANTERIOR_CINGULATE_CORTEX_dorsal
right_ANTERIOR_CINGULATE_CORTEX_dorsal
left_MIDDLE_FRONTAL_GYRUS_posterorostral
left_SOMATOMOTOR_NETWORK_dorsolateral
left_CEREBELLUM_VI_posterior
right_CEREBELLUM_VI_posterior
left_CAUDATE_dorsal
right_CAUDATE_dorsal
left_CEREBELLUM_IX_dorsal
right_CEREBELLUM_IX_dorsal
left_ANTERIOR_INSULA_ventral
right_ANTERIOR_INSULA_vent

In [3]:
columns_label = label_pd['name']
df = pd.DataFrame()
for model in os.listdir(runs_path):
    if not '_f_conv' in model and not 'Voxels' in model:
        sub = model[:6]
        sub_path = os.path.join(runs_path, model)
        sub_array = np.load(sub_path)
        print(sub, sub_array.shape)
        sub_df = pd.DataFrame(sub_array).rename(label_pd['name'], axis='columns')
        sub_df_melt = sub_df.melt(var_name='label', value_name='r2', ignore_index=False).reset_index()
        sub_df_melt.rename({'index' : 'run'}, axis='columns', inplace=True)
        
        #sub_list = [sub]*sub_df_melt.shape[0]
        #sub_df_melt['subject'] = sub_list
        
        df_median  = sub_df_melt.groupby(by=['label']).median()
        df_order_by_r2 = df_median.sort_values(by='r2', ascending=False)
        sub_list = [sub]*df_order_by_r2.shape[0]
        df_order_by_r2['subject'] = sub_list
        print(df_order_by_r2.head())
        df = pd.concat([df, df_order_by_r2], axis=0, join='outer')
        
print(df.shape)

sub-02 (47, 210)
                                              run        r2 subject
label                                                              
right_SUPERIOR_TEMPORAL_GYRUS_middle         23.0  0.255970  sub-02
left_SUPERIOR_TEMPORAL_GYRUS_middle          23.0  0.173715  sub-02
right_SUPERIOR_TEMPORAL_GYRUS_posterior      23.0  0.142101  sub-02
left_SUPERIOR_TEMPORAL_GYRUS_posterior       23.0  0.128430  sub-02
right_LATERAL_VISUAL_NETWORK_dorsoposterior  23.0  0.114949  sub-02
sub-04 (47, 210)
                                          run        r2 subject
label                                                          
right_SUPERIOR_TEMPORAL_GYRUS_middle     23.0  0.348107  sub-04
left_SUPERIOR_TEMPORAL_GYRUS_middle      23.0  0.259205  sub-04
left_SUPERIOR_TEMPORAL_GYRUS_posterior   23.0  0.207760  sub-04
right_SUPERIOR_TEMPORAL_GYRUS_posterior  23.0  0.197105  sub-04
right_HESCHLS GYRUS                      23.0  0.165898  sub-04
sub-01 (47, 210)
                         

In [4]:
for scale in scales:
    subs_data = {sub:[] for sub in subs}
    for sub in subs:
        print(sub)
        sub_path = os.path.join(r2_tests, sub)
        for model in os.listdir(sub_path):
            
            select_baseline = model.find(baseline) > -1
            select_conv = model.find(conv) > -1
            select_scale = model.find(scale) > -1
            if baseline == 'no_ft' and model.find('f_conv') == -1 and select_scale :
                baseline_model = model
            elif select_conv and select_scale:
                selected_model = model
            elif select_baseline and select_scale:
                baseline_model = model
            else :
                pass
        print(selected_model, baseline_model)
            
        conv_path = os.path.join(sub_path, selected_model)
        basl_path = os.path.join(sub_path, baseline_model)
    
        conv_data = load(conv_path, map_location=device('cpu'))['test_r2'].reshape(1,-1)
        basl_data = load(basl_path, map_location=device('cpu'))['test_r2'].reshape(1,-1)
        print(conv_data.shape, type(conv_data))
        diff_data = np.subtract(conv_data, basl_data)
        sub_data = {conv : conv_data, 
                    baseline : basl_data, 
                    diff_name : diff_data}
        subs_data[sub] = sub_data
    
    #mean of all subjects for 3 conditions (conv, baseline, diff)
    mean = {conv : (), baseline : (), diff_name : ()}
    for cond in mean.keys():
        all_cond = np.array([sub_dic[cond] for sub_dic in subs_data.values()])
        cond_mean = np.mean(all_cond, axis=0)
        mean[cond] = cond_mean
    subs_data['mean']=mean

sub-01
friends_MIST_ROI_SoundNetEncoding_conv_0800715_0e+00_1e-05_1e-03_opt110_f_conv4_wbid2svgtnue_20220524-113541.pt friends_MIST_ROI_SoundNetEncoding_conv_0800715_0e+00_1e-05_1e-03_opt110_wbid1j4np8t3_20220524-113522.pt
(1, 210) <class 'numpy.ndarray'>
sub-02
friends_MIST_ROI_SoundNetEncoding_conv_0800715_0e+00_1e-05_1e-03_opt110_f_conv4_wbidpcmtk3zb_20220428-093227.pt friends_MIST_ROI_SoundNetEncoding_conv_0800715_0e+00_1e-05_1e-03_opt110_wbid28gqq8u1_20220429-044704.pt
(1, 210) <class 'numpy.ndarray'>
sub-03
friends_MIST_ROI_SoundNetEncoding_conv_0700515_0e+00_1e-04_1e-03_opt110_f_conv4_wbid2jl8l7ad_20220407-095803.pt friends_MIST_ROI_SoundNetEncoding_conv_0700515_0e+00_1e-04_1e-03_opt110_wbid2mlgalkf_20220407-095244.pt
(1, 210) <class 'numpy.ndarray'>
sub-04
friends_MIST_ROI_SoundNetEncoding_conv_0800615_0e+00_1e-05_1e-03_opt110_f_conv4_wbidf0qeh2ll_20220329-112905.pt friends_MIST_ROI_SoundNetEncoding_conv_0800615_0e+00_1e-05_1e-03_opt110_wbidqds77g2d_20220329-110804.pt
(1, 210) 

In [None]:
for sub, data in subs_data.items():
    x = data['no_ft']
    y = x[x > 0.15]
    z = x[x > 0.3]
    print(sub, y.shape, z.shape, x.max())

In [None]:
def list_positions_in_matrix2D(nb_rows, nb_col):
    i_rows = np.repeat(np.arange(nb_rows), nb_col)
    i_col = np.tile(np.arange(nb_col), nb_rows)
    pos = [(row, col) for row, col in zip(i_rows, i_col)]
    return pos

In [None]:
def voxels_nii(voxel_data, voxel_mask, t_r=1.49):
#from voxels to nii
    voxel_masker = maskers.NiftiMasker(mask_img=voxel_mask, standardize=False, detrend=False, t_r=t_r, smoothing_fwhm=8)
    voxel_masker.fit()
    vox_data_nii = voxel_masker.inverse_transform(voxel_data)
    return vox_data_nii

In [None]:
def surfplot_plt_axe(nii_data, fig, title=None, 
                     hemis = ['left', 'right'], views = ['lateral', 'medial'],threshold=0, 
                     display = 'grid', texture = 'pial', zoom = [0,0],
                     symmetric_cbar='auto', cmap='turbo', vmax=None):
    
    if display == 'horizontal' : 
        nrows = 1
        ncols = len(hemis)+len(views)
        wspace = 0
        hspace = 0
    elif display == 'vertical':
        nrows = len(views)+len(hemis)
        ncols = 1
        wspace = 0
        hspace = -0.2
    else :
        nrows = len(views)
        ncols = len(hemis)
        wspace = -0.1
        hspace = -0.2
        
    grid = fig.add_gridspec(nrows = nrows, ncols = ncols, wspace=wspace, hspace=hspace)
    positions = list_positions_in_matrix2D(nb_rows=nrows, nb_col=ncols)
    all_views = np.repeat(views, len(hemis))
    specs = [(view, hemi) for view, hemi in zip(hemis*len(views), all_views)]
    
    fsaverage = datasets.fetch_surf_fsaverage()  
    for pos, (hemi, view) in zip(positions, specs):
        ax = fig.add_subplot(grid[pos], projection = '3d')
        #title = hemi if pos[0] == 0 else None
        if hemi == 'left' :
            kwargs = {
            'surf_mesh' : fsaverage.infl_left,
            'stat_map' : surface.vol_to_surf(nii_data, fsaverage.pial_left),
            'bg_map' : fsaverage.sulc_left,
            'hemi' : hemi,
            'view' : view,
            'axes' : ax,
            'figure' : fig, 
            'colorbar' : False,
            'vmax': vmax,
            'threshold' : threshold
            }
        else: 
            kwargs = {
            'surf_mesh' : fsaverage.infl_right,
            'stat_map' : surface.vol_to_surf(nii_data, fsaverage.pial_right),
            'bg_map' : fsaverage.sulc_right,
            'hemi' : hemi,
            'view' : view,
            'axes' : ax,
            'figure' : fig, 
            'colorbar' : True,
            'vmax': vmax,
            'threshold' : threshold
            }
        plotting.plot_surf_stat_map(**kwargs)
        
        xmin, xmax = ax.get_xlim()
        ymin, ymax = ax.get_ylim()
        zmin, zmax = ax.get_zlim()
        print(xmin, xmax, ymin, ymax, zmin, zmax)
        #ax.set_xlim(xmin+30, xmax-30)
        #ax.set_ylim(ymin+30, ymax-30)
        #ax.set_zlim(0, 1)
        

fig = plt.figure()
vmax = round(mean[conv].max(), ndigits=3)
vmin = round(mean[conv].min(), ndigits=3)
conv_nii = voxels_nii(mean[conv], voxel_mask)
#conv_nii = regions.signals_to_img_labels(mean[conv], MIST_path)
surfplot_plt_axe(nii_data = conv_nii, fig=fig, vmax=vmax, symmetric_cbar=False, 
                 views = ['lateral'], zoom = [-0.2, -0.1], threshold = 0.05)

In [None]:
for condition in [baseline]: #diff_name, conv
    fig = plt.figure(figsize=[14,10], frameon=False, layout='tight')
    grid = fig.add_gridspec(nrows=2, ncols=3)
    positions = list_positions_in_matrix2D(nb_rows=3, nb_col=3)
    for (sub, data), pos in zip(subs_data.items(), positions):
        if sub == 'mean' :  
            pass
            #subfig = fig.add_subfigure(grid[2,:3], frameon=False)
            #display = 'horizontal'
        else :        
            subfig = fig.add_subfigure(grid[pos], frameon=False)
            display = 'grid'
            
            vmax = round(data[condition].max(), ndigits=3)

            if scale == 'auditory_Voxels':
                conv_nii = voxels_nii(data[condition], voxel_mask)
                vmax = 0.47
            else :
                conv_nii = regions.signals_to_img_labels(data[condition], MIST_path)
                vmax = 0.37

            surfplot_plt_axe(nii_data=conv_nii, fig=subfig, display = display, views = ['lateral'], 
                             vmax=0.4)
            subfig.suptitle(sub, ha='left', va='top')

    
    
    savepath = os.path.join(results_path, '{}_r2_map_{}'.format(scale, condition))
    fig.savefig(savepath)
    plt.close()
        
        