In [1]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib import gridspec
from scipy import stats

In [2]:
#paths + global variables

data_runs_path = '/home/maelle/Results/best_models/predict_S4_runs'
results_path = '/home/maelle/Results/'
subs = ['sub-01', 'sub-02', 'sub-03', 'sub-04', 'sub-05', 'sub-06']
convs = ['baseline', 'conv7', 'conv6', 'conv5', 'conv4']
scales = ['auditory_Voxels', 'MIST_ROI']

In [3]:
def select_data_in_percentile(data, high_percent=1.0, low_percent=0.0):
    h_threshold = np.quantile(data, high_percent)
    l_threshold = np.quantile(data, low_percent)
    best_data = data[data < high_percent]
    best_data = best_data[low_percent < best_data].reshape(1, -1)
    return best_data

In [4]:
def compute_significativity_between_array2D(arrayA, arrayB, axis=0, threshold = 0.05, independancy=True):
    if independancy:
        t, p = stats.ttest_ind(arrayA, arrayB, axis)
    else:
        t, p = stats.ttest_rel(arrayA, arrayB, axis)
        #non param - Wilcoxon, permutations(10000) ... FDR
    return (t, p)

In [5]:
#recup data runs

def difference_conv_baseline_by_runs(scale, ref_conv, baseline, h_percent, l_percent, 
                                     roi=False, ordered = True):
    nb = 556 if scale == 'auditory_Voxels' else 210
    dico = {}
    for sub in subs:
        dico[sub]={}
        for conv in convs:
            dico[sub][conv]=np.array([]).reshape(-1,nb)

    for model_runs_predicts in os.listdir(data_runs_path):
        if scale in model_runs_predicts:
            model_path = os.path.join(data_runs_path, model_runs_predicts)
            sub = model_runs_predicts[:6]
            i = model_runs_predicts.find('f_conv')
            conv = model_runs_predicts[i+2:i+len('f_conv')+1] if i>-1 else 'baseline'
            r2_runs = np.load(model_path)
            if roi:
                r2_runs=r2_runs.T
            
            best_runs = []
            #select percentile for each run
            for run in r2_runs:
                h_threshold = np.quantile(run, h_percent)
                l_threshold = np.quantile(run, l_percent)
                best_run = run[run < h_threshold]
                best_run = best_run[l_threshold < best_run].reshape(1, -1)
                best_runs.append(best_run)
            best_runs = np.array(best_runs).squeeze()
            dico[sub][conv] = best_runs
    
    subs_to_plot = {sub:[] for sub in subs}
    for sub in subs:
        absolu = dico[sub][ref_conv]
        (_, p_absolu) = compute_significativity_between_array2D(absolu, dico[sub][baseline], axis=1, independancy=False)
        relatif = np.subtract(dico[sub][ref_conv], dico[sub][baseline])
        zero_array = np.zeros_like(relatif)
        (_, p_relative) = compute_significativity_between_array2D(relatif, zero_array, axis=1, independancy=False)
        
        if ordered:
            absolu_medians = np.quantile(absolu, 0.5, axis=1)
            sorted_index = np.flip(np.argsort(absolu_medians))
            absolu = absolu[sorted_index]
            p_absolu = p_absolu[sorted_index]
            relatif = relatif[sorted_index]
            p_relative = p_relative[sorted_index]
            
        subs_to_plot[sub] = (absolu, p_absolu, relatif, p_relative)
    
    return subs_to_plot

In [6]:
def draw_plot(ax, data, edge_color, fill_color, offset = 0,
              label=None, probability=None, threshold = 0.05):
    sig_data = data[:]
    print(data.shape, sig_data.shape)
    if probability is not None : 
        pos = np.argwhere(probability<=threshold).squeeze()
        sig_data = data[pos].squeeze().T
        no_sig_pos = np.argwhere(probability>threshold).squeeze()
        no_sig_data = data[no_sig_pos].squeeze().T
        print(sig_data.shape, pos.shape, no_sig_data.shape, no_sig_pos.shape)
        
        no_sig_bp = ax.boxplot(no_sig_data, positions=no_sig_pos+1, widths=0.3, 
                               patch_artist=True,manage_ticks=False, showfliers=True)
        plt.setp(no_sig_bp['medians'], color='white')
        for element in ['boxes', 'whiskers', 'fliers', 'caps']:
            plt.setp(no_sig_bp[element], color='grey')
        for patch in no_sig_bp['boxes']:
            patch.set(facecolor='grey')
    
    else:
        pos = np.arange(data.shape[0])+offset
    
    bp = ax.boxplot(sig_data, positions=pos+1, widths=0.3, patch_artist=True, 
                    manage_ticks=False, showfliers=True)
    plt.setp(bp['medians'], color='white')
    for element in ['boxes', 'whiskers', 'fliers', 'caps']:
        plt.setp(bp[element], color=edge_color)
    for patch in bp['boxes']:
        patch.set(facecolor=fill_color)

In [None]:
h_percent=1.0
l_percent=0.0
ref_conv = 'conv4'
baseline = 'baseline'
nb_rows = 2
nb_col = int(len(subs)/nb_rows)
size_row = 17
size_col = 5

stg_data = difference_conv_baseline_by_runs('auditory_Voxels', roi=True, ordered = True,
                                            ref_conv=ref_conv, baseline=baseline,
                                            h_percent=h_percent, l_percent=l_percent)
wholebrain_data = difference_conv_baseline_by_runs('MIST_ROI', roi=True, ordered = True,
                                            ref_conv=ref_conv, baseline=baseline,
                                            h_percent=h_percent, l_percent=l_percent)

for scale, dico in zip(scales, [stg_data, wholebrain_data]):
    #main figure
    fig = plt.figure(constrained_layout=True, figsize=(nb_rows*size_row,nb_col*size_col))
    subfigs = fig.subfigures(nb_rows, nb_col, wspace=0.1)
    #positions of each subfig in main figure
    i_rows = np.repeat(np.arange(nb_rows), nb_col)
    i_col = np.tile(np.arange(nb_col), nb_rows)
    pos = [(row, col) for row, col in zip(i_rows, i_col)]
    x_label = 'voxels' if scale == 'auditory_Voxels' else 'MIST regions'
    
    for i, (sub, data) in zip(pos, dico.items()):
        grid = gridspec.GridSpec(2,1, figure=subfigs[i])
        
        absolu_data = data[0]
        absolu_prob = data[1]
        ax_absolu = subfigs[i].add_subplot(grid[0])  
        draw_plot(ax_absolu, absolu_data, "skyblue", "skyblue", probability=absolu_prob)
        ax_absolu.set_title('{}'.format(sub), fontsize=22)
        ax_absolu.axhline(ls='--', lw=0.5, c='grey')
        ax_absolu.set_ylabel('{}'.format(ref_conv), fontsize=20)
        ax_absolu.set_xticks([])
        ax_absolu.tick_params(axis='y', labelsize=18)
        
        relative_data = data[2]
        relative_prob = data[3]
        ax_relatif = subfigs[i].add_subplot(grid[1])
        draw_plot(ax_relatif, relative_data, "tomato", "tomato", probability=relative_prob)
        ax_relatif.axhline(ls='--', lw=0.5, c='grey')
        ax_relatif.set_xlabel(x_label, fontsize=18)
        ax_relatif.set_ylabel('{} - {}'.format(ref_conv, baseline), fontsize=20)
        ax_relatif.tick_params(axis='y', labelsize=18)
        
        grid.update(hspace=0)
        subfigs[i].supylabel('r² score', fontsize=22)

        
    fig.suptitle('{} R² Difference : {} - {} for all runs of season 4'.format(scale, ref_conv, baseline), fontsize=26)
    fig.patch.set_facecolor='w'
    title = '{}_draft_boxplot_r2_scores_differences_for_all_runs_of_the_season.png'.format(scale)
    savepath = os.path.join(results_path, title)
    plt.savefig(savepath, bbox_inches='tight')
    plt.close()

(556, 45) (556, 45)
(45, 242) (242,) (45, 314) (314,)
(556, 45) (556, 45)
(45, 242) (242,) (45, 314) (314,)
(556, 45) (556, 45)
(45, 525) (525,) (45, 31) (31,)
(556, 45) (556, 45)
(45, 525) (525,) (45, 31) (31,)
(556, 45) (556, 45)
(45, 367) (367,) (45, 189) (189,)
(556, 45) (556, 45)
(45, 367) (367,) (45, 189) (189,)
(556, 45) (556, 45)
(45, 452) (452,) (45, 104) (104,)
(556, 45) (556, 45)
(45, 452) (452,) (45, 104) (104,)
(556, 42) (556, 42)
(42, 191) (191,) (42, 365) (365,)
(556, 42) (556, 42)
(42, 191) (191,) (42, 365) (365,)
(556, 45) (556, 45)
(45, 248) (248,) (45, 308) (308,)
(556, 45) (556, 45)
(45, 248) (248,) (45, 308) (308,)
(210, 45) (210, 45)
(45, 86) (86,) (45, 124) (124,)
(210, 45) (210, 45)
(45, 86) (86,) (45, 124) (124,)
(210, 45) (210, 45)
(45, 35) (35,) (45, 175) (175,)
(210, 45) (210, 45)
(45, 35) (35,) (45, 175) (175,)
(210, 45) (210, 45)
(45, 79) (79,) (45, 131) (131,)
(210, 45) (210, 45)
(45, 79) (79,) (45, 131) (131,)
(210, 45) (210, 45)
(45, 108) (108,) (45, 10