In [None]:
import nibabel as nib
import numpy as np
import matplotlib
import pandas as pd
import matplotlib.pyplot as plt

%config InlineBackend.figure_formats = ['svg']
%matplotlib inline  

import util

In [None]:

# Load the per-vertex surface data produced earlier in the pipeline.
# NOTE: pickle files can execute code on load; this one is a local intermediate (tmp/).
surfdata = pd.read_pickle('tmp/surfdata.pkl')

# Patients are the subjects whose ID contains a 'P'; keep them in sorted order.
lesional_subjects = sorted(
    subject_id
    for subject_id in surfdata.index.get_level_values('subject').unique()
    if 'P' in subject_id
)


In [None]:
# Show which feature names are available in surfdata (in order of first appearance).
surfdata.index.to_frame().loc[:, 'feature'].unique()

In [None]:
# RGBA colours used for the two field strengths throughout the figures.
color_3T = [0.48, 0, 0.51, 1.0]  # purple
color_94T = [0.086, 0.490, 0.0039, 1.0]  # green

# colour lookup by B0 label
colors_B0 = dict([('3T', color_3T), ('94T', color_94T)])

# one [3T colour, 9.4T colour] pair per lesional subject (two subjects)
colors_subj_B0 = [[color_3T, color_94T] for _ in range(2)]


In [None]:
# Define the vertex ROI used for statistics for every lesional subject.
# stat_ROIs[i] is a boolean Series over the vertex columns of surfdata for lesional_subjects[i].
#
# STAT_ROI_CRITERION selects how the ROI is built:
#   'lesion_label'           - every vertex inside the lesion label (default)
#   'lesion_label_thickness' - lesion-label vertices whose nthickness exceeds
#                              STAT_ROI_MIN_THICKNESS in every nthickness row of
#                              the subject's lesion hemisphere
#   'distance'               - vertices whose lesion_pial_distance is below STAT_ROI_MAX_DISTANCE
STAT_ROI_CRITERION = 'lesion_label'
STAT_ROI_MIN_THICKNESS = 3.0
STAT_ROI_MAX_DISTANCE = 10.0
# set to True to write the original and selected ROIs as .mgh overlays and open them in freeview
VISUALIZE_STAT_ROIS = False

# the row index does not change inside the loop, so build its frame once
index = surfdata.index.to_frame()
# position of the 'hemi' level inside each row's index tuple
hemi_level = surfdata.index.names.index('hemi')

stat_ROIs = []
for subj in lesional_subjects:
    is_subj = index['subject'] == subj
    lesionlabel = surfdata[is_subj & (index['feature'] == 'lesion_label')].squeeze()
    lesion_distance = surfdata[is_subj & (index['feature'] == 'lesion_pial_distance')].squeeze()
    # the squeezed row is named by its full index tuple; pick the hemisphere by level name
    hemi = lesion_distance.name[hemi_level]

    if STAT_ROI_CRITERION == 'lesion_label':
        roi = lesionlabel == True  # elementwise pandas comparison, not an identity test
    elif STAT_ROI_CRITERION == 'lesion_label_thickness':
        thickness = surfdata[is_subj & (index['feature'] == 'nthickness') & (index['hemi'] == hemi)]
        roi = (thickness > STAT_ROI_MIN_THICKNESS).all(axis=0) & (lesionlabel == True)
    elif STAT_ROI_CRITERION == 'distance':
        roi = lesion_distance < STAT_ROI_MAX_DISTANCE
    else:
        raise ValueError(f'Unknown STAT_ROI_CRITERION {STAT_ROI_CRITERION!r}, should be '
                         '"lesion_label", "lesion_label_thickness" or "distance"')
    stat_ROIs.append(roi)

    if VISUALIZE_STAT_ROIS:  # visualize using freeview
        orig_label = f'tmp/{subj}_lesion_ROI_{hemi}_orig.mgh'
        nib.save(nib.freesurfer.mghformat.MGHImage(lesionlabel.astype('float32'), affine=None), orig_label)

        selected_label = f'tmp/{subj}_lesion_ROI_{hemi}_selected.mgh'
        nib.save(nib.freesurfer.mghformat.MGHImage(roi.astype('float32'), affine=None), selected_label)

        cmd = (f'freeview -f data/derivatives/freesurfer/fsaverage_sym/surf/{hemi}.inflated'
               f':overlay={orig_label}:overlay={selected_label}')
        util.bash_run(cmd)


In [None]:
# z-score every measured feature against the controls: within each
# (B0, feature, hemi, depth) group, subtract the control mean and divide by the
# control std, vertex by vertex. The label/distance rows are bookkeeping rather
# than measurements, so they bypass the z-scoring and are appended back unchanged.
NON_ZSCORED_FEATURES = ['cortex_label', 'lesion_label', 'lesion_pial_distance']


def zscore(df):
    """Return df z-scored column-wise using the rows whose subject ID contains 'C'."""
    is_control = df.index.get_level_values('subject').str.contains('C')
    controls = df[is_control]
    return (df - controls.mean(axis=0)) / controls.std(axis=0)


is_bookkeeping_row = surfdata.index.get_level_values('feature').isin(NON_ZSCORED_FEATURES)
surfdata_zscored = pd.concat([
    surfdata[~is_bookkeeping_row].groupby(['B0', 'feature', 'hemi', 'depth'],
                                          group_keys=False, dropna=False).apply(zscore),
    surfdata[is_bookkeeping_row],
])
surfdata_zscored.head()

In [None]:
def plot_thickness_ROI_3T94T(surfdata,
                             subj,
                             color,
                             method='median',
                             ylabel='cortical thickness [mm]',
                             ymin=0.0,
                             ymax=4.5):
    """Plot lesion-ROI cortical thickness at 3T vs 9.4T for one patient, with controls.

    The 'nthickness' rows of the patient's lesion hemisphere are summarised over the
    vertices of the patient's lesion label, for the patient and for every control
    (subjects whose ID contains 'C'), and drawn as one 3T -> 9.4T line per subject.

    Parameters
    ----------
    surfdata : pandas.DataFrame
        Rows indexed by a MultiIndex with (at least) the levels 'subject', 'feature',
        'hemi', 'depth' and 'B0'; columns are vertices. Needs a 'lesion_label' row for `subj`.
    subj : str
        Lesional subject to plot.
    color : sequence of two colours
        Marker colours for the patient's 3T and 9.4T points.
    method : {'median', 'mean'}
        Aggregation across the ROI vertices.
    ylabel : str
    ymin, ymax : float or None
        y limits; None keeps matplotlib's automatic limit on that side.

    Raises
    ------
    ValueError
        If `method` is neither 'median' nor 'mean'.
    """
    if method not in ('median', 'mean'):
        raise ValueError(f'Unknown method {method}, should be "median" or "mean"')

    fig, ax = plt.subplots(figsize=(1.5, 2.5))

    index = surfdata.index.to_frame()
    lesionlabel = surfdata[(index['subject'] == subj) & (index['feature'] == 'lesion_label')].squeeze()
    # the squeezed row is named by its full index tuple; pick the hemisphere by level name
    hemi = lesionlabel.name[surfdata.index.names.index('hemi')]

    d = surfdata.loc[
        index['subject'].str.contains(f'{subj}|C')
        & (index['feature'] == 'nthickness')
        & (index['hemi'] == hemi),
        lesionlabel == True]

    # relabel '94T' as '9.4T' for the categorical x axis
    d = d.reset_index()
    d['B0'] = d['B0'].replace({'94T': '9.4T'})
    d = d.set_index(['subject', 'feature', 'hemi', 'depth', 'B0'])

    summary = d.median(axis=1) if method == 'median' else d.mean(axis=1)
    d = summary.reset_index(name='data').sort_values('B0')  # '3T' sorts before '9.4T'

    patient = d[d['subject'] == subj]
    ax.plot('B0', 'data', data=patient, marker='.', label=subj, zorder=2.5, color='gray')
    # overlay scatter so that the 3T and 9.4T points get color[0] and color[1]
    ax.scatter('B0', 'data', data=patient, color=color, zorder=3)

    for control in [s for s in d['subject'].unique() if 'C' in s]:
        ax.plot('B0', 'data', data=d[d['subject'] == control],
                color='gray', alpha=0.3, marker='o', label='_nolegend')

    ax.set_xlim(-0.5, 1.5)
    ax.set_ylim(ymin, ymax)
    # Integer ticks strictly inside the final limits: set_yticks expands the view
    # limits to show every tick, so a tick above ymax would silently override it.
    lo, hi = ax.get_ylim()
    ticks = np.arange(np.ceil(lo), np.floor(hi) + 1, 1.0)
    if len(ticks):
        ax.set_yticks(ticks)
    ax.set_ylabel(ylabel)
    ax.axhline(y=0, color='black', linewidth=1.0, linestyle='-')
    ax.grid(axis='y', linestyle='--', alpha=0.6)
    fig.tight_layout()

# Lesion-ROI thickness, 3T vs 9.4T, for both patients: first in mm ...
for i_subj in (0, 1):
    plot_thickness_ROI_3T94T(surfdata,
                             subj=lesional_subjects[i_subj],
                             color=colors_subj_B0[i_subj])

# ... then as z-scores relative to the controls (automatic y limits)
for i_subj in (0, 1):
    plot_thickness_ROI_3T94T(surfdata_zscored,
                             subj=lesional_subjects[i_subj],
                             color=colors_subj_B0[i_subj],
                             ylabel='z-score',
                             ymin=None, ymax=None)

In [None]:

def plot_feature_vs_depth_per_lesion(surfdata,
                                     lesional_subjects,
                                     feature,
                                     stat_ROIs,
                                     colors,
                                     method='median',
                                     B0=None,
                                     axs=None,
                                     xlabel=True,
                                     ylabel=None,
                                     yticks=True,
                                     ylim=None):
    """Plot a feature's ROI summary against sampling depth, one panel per lesional subject.

    For every lesional subject the feature is summarised (median or mean over the
    vertices of that subject's stat ROI) at each depth. This is done for the patient
    and for the controls (subjects whose ID contains 'C'), both in the lesion
    hemisphere and in the contralateral hemisphere using the same vertex mask.

    Parameters
    ----------
    surfdata : pandas.DataFrame
        Rows indexed by a MultiIndex with (at least) the levels 'subject', 'feature',
        'hemi', 'depth' and 'B0'; columns are vertices. Needs a 'lesion_label' row
        per lesional subject (used to find the lesion hemisphere).
    lesional_subjects : list of str
    feature : str
        Value of the 'feature' index level to plot.
    stat_ROIs : list of boolean pandas.Series
        Vertex mask per subject, in the same order as `lesional_subjects`.
    colors : list of [color_3T, color_94T] pairs, one per subject.
    method : {'median', 'mean'}
        Aggregation across the ROI vertices.
    B0 : str
        Field strength to select; the patient line uses colors[i][0] when B0 == '3T'
        and colors[i][1] otherwise.
    axs : sequence of matplotlib Axes, optional
        One axes per subject. A new 1 x n_subjects figure is created when omitted.
    xlabel : bool
        Draw the x-axis label; otherwise hide the x tick labels.
    ylabel : str, optional
        y-axis label, drawn on the first panel only.
    yticks : bool
        If False, hide the y tick labels on every panel. They are always hidden
        on panels other than the first.
    ylim : tuple, optional

    Raises
    ------
    ValueError
        If `method` is neither 'median' nor 'mean'.
    """
    if method not in ('median', 'mean'):
        raise ValueError(f'Unknown method {method}, should be "median" or "mean"')
    if axs is None:
        _, axs = plt.subplots(1, len(lesional_subjects), squeeze=False)
        axs = axs[0]

    index = surfdata.index.to_frame()

    def _depth_profile(subj, roi, hemi):
        """Summarise `feature` over `roi` per (subject, depth) for subj and controls in `hemi`."""
        d = surfdata.loc[
            index['subject'].str.contains(f'{subj}|C')
            & (index['feature'] == feature)
            & (index['hemi'] == hemi)
            & (index['B0'] == B0),
            roi]
        d = (d.median(axis=1) if method == 'median' else d.mean(axis=1)).reset_index(name='data')
        # sort numerically by depth, then convert to str so matplotlib treats depth as categorical
        d = d.sort_values('depth')
        d['depth'] = d['depth'].astype(str)
        return d

    for i_subj, subj in enumerate(lesional_subjects):
        ax = axs[i_subj]
        lesionlabel = surfdata[(index['subject'] == subj) & (index['feature'] == 'lesion_label')].squeeze()
        # the squeezed row is named by its full index tuple; pick the hemisphere by level name
        hemi = lesionlabel.name[surfdata.index.names.index('hemi')]
        contra_hemi = 'lh' if hemi == 'rh' else 'rh'

        # lesion hemisphere: patient in the field-strength colour, controls in light grey
        d = _depth_profile(subj, stat_ROIs[i_subj], hemi)
        ax.plot('depth',
                'data',
                data=d[d['subject'] == subj],
                marker='o',
                label=subj,
                zorder=2.5,
                color=colors[i_subj][0 if B0 == '3T' else 1])
        for c in [s for s in d['subject'].unique() if 'C' in s]:
            ax.plot('depth', 'data', data=d[d['subject'] == c],
                    color='lightgray', alpha=0.3, marker='o', label='_nolegend')

        # contralateral hemisphere, same vertex mask
        d_contralat = _depth_profile(subj, stat_ROIs[i_subj], contra_hemi)
        ax.plot('depth',
                'data',
                data=d_contralat[d_contralat['subject'] == subj],
                marker='x',
                label=f'{subj} contralateral homotopic region',
                zorder=2.1,
                color='royalblue',
                alpha=0.5)
        for c in [s for s in d_contralat['subject'].unique() if 'C' in s]:
            ax.plot('depth', 'data', data=d_contralat[d_contralat['subject'] == c],
                    color='lightskyblue', alpha=0.3, marker='x', label='_nolegend')

        # legend, labels etc.
        if xlabel:
            ax.set_xlabel('depth [mm sub-pial]')
        else:
            plt.setp(ax.get_xticklabels(), visible=False)
        ax.grid(axis='y', linestyle='--', alpha=0.6)
        if not yticks or i_subj != 0:
            ax.set_yticklabels([])
        if ylim is not None:
            ax.set_ylim(ylim)
        if ylabel is not None and i_subj == 0:
            ax.set_ylabel(ylabel)


# Feature at different depths per lesion: one column per lesional subject and one
# row per (contrast, field strength) panel. Drawn twice, first in native units and
# then as z-scores relative to the controls.

def plot_depth_profile_figure(data, panels, subjects, rois, colors):
    """Build one multi-panel depth-profile figure.

    Parameters
    ----------
    data : pandas.DataFrame
        surfdata-like frame passed to plot_feature_vs_depth_per_lesion.
    panels : list of (title, feature, B0, ylabel, ylim) tuples
        One tuple per figure row, drawn top to bottom. Only the last row gets an x label.
    subjects : list of str
        Lesional subjects, one column each.
    rois : list of boolean pandas.Series
        Stat ROI per subject.
    colors : list of [color_3T, color_94T] pairs, one per subject.

    Returns
    -------
    fig, gs, axs
        The Figure, its GridSpec and the (n_panels, n_subjects) object array of data axes.
    """
    n_panels = len(panels)
    n_subj = len(subjects)

    fig = plt.figure(figsize=(6, 8.5), constrained_layout=True)
    # grid rows: subject titles, then per panel a thin title row followed by the data row
    gs = matplotlib.gridspec.GridSpec(2 * n_panels + 1, n_subj,
                                      height_ratios=(0.1,) + (0.2, 1.0) * n_panels,
                                      figure=fig)
    fig.get_layout_engine().set(w_pad=0.5/72., h_pad=1/72., hspace=0.0, wspace=0.0)

    # data axes; every row shares its x axis with the top row of the same column
    axs = []
    for i in range(n_panels):
        row_axes = []
        for j in range(n_subj):
            row_axes.append(fig.add_subplot(gs[2*i+2, j], sharex=axs[0][j] if i > 0 else None))
        axs.append(row_axes)
    axs = np.array(axs)

    for j, subj in enumerate(subjects):
        title_ax = fig.add_subplot(gs[0, j])
        title_ax.text(0.5, 0.5, subj, ha="center", va="center", fontsize=13)
        title_ax.axis("off")

    for i, (title, feature, B0, ylabel, ylim) in enumerate(panels):
        title_ax = fig.add_subplot(gs[2*i+1, :])
        title_ax.text(0.5, 0.5, title, ha="center", va="top")
        title_ax.axis("off")

        plot_feature_vs_depth_per_lesion(surfdata=data,
                                         lesional_subjects=subjects,
                                         feature=feature,
                                         stat_ROIs=rois,
                                         colors=colors,
                                         B0=B0,
                                         axs=axs[i, :],
                                         xlabel=(i == n_panels - 1),
                                         yticks=True,
                                         ylabel=ylabel,
                                         ylim=ylim)
    return fig, gs, axs


# (panel title, feature, B0, native-unit ylabel, native-unit ylim)
DEPTH_PANELS = [
    ("3T normalised FLAIR", 'FLAIR_projabs', '3T', '[a.u.]', (-0.5, 1.2)),
    ("3T T1 map", 'T1map_projabs', '3T', '[s]', (0.8, 1.6)),
    ("9.4T T1 map", 'T1map_projabs', '94T', '[s]', (1.25, 2.25)),
    ("9.4T T2* map", 'T2star_projabs', '94T', '[ms]', (15, 35)),
    ("9.4T QSM map", 'QSMTke3_projabs', '94T', '[ppm]', (-0.007, 0.007)),
]

# native units
fig, gs, axs = plot_depth_profile_figure(surfdata, DEPTH_PANELS,
                                         lesional_subjects, stat_ROIs, colors_subj_B0)

# z-scores relative to the controls, with common limits for all panels
z_score_lim = (-4.2, 4.2)
fig, gs, axs = plot_depth_profile_figure(
    surfdata_zscored,
    [(title, feature, B0, 'z-score', z_score_lim) for title, feature, B0, _, _ in DEPTH_PANELS],
    lesional_subjects, stat_ROIs, colors_subj_B0)


########################
# optional: plot vertical line at lesion median cortical thickness

#median_cortical_thickness_3T_94T = {}
#for subj in lesional_subjects:
#    index = surfdata.index.to_frame()
#    lesionlabel = surfdata[(index['subject'] == subj) & (index['feature'] == 'lesion_label')].squeeze()
#    hemi = lesionlabel.name[3]
#
#    d = surfdata.loc[
#        (index['subject'] == subj) & 
#        (index['feature'] == 'nthickness') & 
#        (index['hemi'] == hemi),
#        lesionlabel == True] 
#
#    median_value = d.median(axis=1).median()
#    median_cortical_thickness_3T_94T[subj] = median_value
#median_cortical_thickness_3T_94T


# add vertical line at median cortical thickness
# converting to categorical x-axis plotting scale
# -6 mm is 0, -1 mm is 5
#for subj in lesional_subjects:
#    for ax in axs[:, lesional_subjects.index(subj)]:
#        ax.axvline(x=-median_cortical_thickness_3T_94T[subj]+6, color='gray', linewidth=1.0, linestyle='--')


In [None]:
# Stand-alone legend describing the line styles used in the depth-profile panels.
fig = plt.figure(figsize=(3,2))

Line2D = matplotlib.lines.Line2D
# patient lines: purple for 3T, green for 9.4T
custom_lines = [
    Line2D([], [], marker='o', color=colors_subj_B0[0][i_B0], label=f'{B0} patient lesion ROI')
    for i_B0, B0 in enumerate(['3T', '9.4T'])
]
custom_lines += [
    Line2D([], [], marker='o', color='lightgray', alpha=0.3, label='controls lesion ROI'),
    Line2D([], [], marker='x', color='royalblue', alpha=0.5, label='contralateral homotopic region'),
    Line2D([], [], marker='x', color='lightskyblue', alpha=0.3, label='controls contralateral homotopic region'),
]
fig.legend(handles=custom_lines, loc='center', fontsize='small', ncol=3)

# empty, frameless axes so the figure is rendered
fig.add_axes([0,0,1,1], frameon=False, xticks=[], yticks=[])

In [None]:
def plot_thickness_vs_distance_per_lesion(surfdata, 
                                          lesional_subjects, 
                                          distance_cutoff,
                                          n_bins,
                                          colors,
                                          method='median',
                                          axs=None,
                                          xlabel=True,
                                          ylabel=True):
    """Plot cortical thickness against distance from the lesion, one panel per lesional subject.

    The vertices with lesion_pial_distance < `distance_cutoff` are split into
    `n_bins - 1` bins whose edges are quantiles of those distances. In each bin,
    'nthickness' is summarised over the bin's vertices at 3T and at 9.4T, for the
    patient and for the controls (subjects whose ID contains 'C'), in the lesion
    hemisphere.

    Parameters
    ----------
    surfdata : pandas.DataFrame
        Rows indexed by a MultiIndex with (at least) the levels 'subject', 'feature',
        'hemi' and 'B0'; columns are vertices. Needs a 'lesion_pial_distance' row
        per lesional subject.
    lesional_subjects : list of str
    distance_cutoff : float
        Only vertices closer than this are binned.
    n_bins : int
        Number of quantile edges, which gives `n_bins - 1` bins.
    colors : list of [color_3T, color_94T] pairs, one per subject.
    method : {'median', 'mean'}
        Aggregation across the vertices of a bin.
    axs : sequence of matplotlib Axes, optional
        One axes per subject. A new n_subjects x 1 figure is created when omitted.
    xlabel : bool
        Draw the x-axis label; otherwise hide the x tick labels.
    ylabel : bool
        Label the first panel's y axis '[mm]'.

    Raises
    ------
    ValueError
        If `method` is neither 'median' nor 'mean'.
    """
    if method not in ('median', 'mean'):
        raise ValueError(f'Unknown method {method}, should be "median" or "mean"')
    if axs is None:
        # squeeze=False so a single subject still yields an indexable sequence of axes
        _, axs = plt.subplots(len(lesional_subjects), 1, figsize=(3, 6), squeeze=False)
        axs = axs[:, 0]

    index = surfdata.index.to_frame()
    for i_subj, subj in enumerate(lesional_subjects):
        lesion_distance = surfdata[(index['subject'] == subj) & (index['feature'] == 'lesion_pial_distance')].squeeze()
        # the squeezed row is named by its full index tuple; pick the hemisphere by level name
        hemi = lesion_distance.name[surfdata.index.names.index('hemi')]

        # quantile edges -> roughly equally populated bins (independent of B0, so computed once)
        #bins = np.linspace(0, distance_cutoff, n_bins)
        bins = [np.quantile(lesion_distance[lesion_distance < distance_cutoff], q) for q in np.linspace(0, 1, n_bins)]
        dist_bins = pd.cut(lesion_distance, bins=bins, include_lowest=True, labels=False)
        bin_centres = [(bins[i] + bins[i + 1]) / 2 for i in range(n_bins - 1)]

        for i_B0, B0 in enumerate(['3T', '94T']):
            rows = (index['subject'].str.contains(f'{subj}|C')
                    & (index['feature'] == 'nthickness')
                    & (index['hemi'] == hemi)
                    & (index['B0'] == B0))
            dist_data_list = []
            for i_bin, centre in enumerate(bin_centres):
                d = surfdata.loc[rows, dist_bins == i_bin]
                d = (d.median(axis=1) if method == 'median' else d.mean(axis=1)).reset_index(name='data')
                d['distance_bin'] = centre
                dist_data_list.append(d)
            d = pd.concat(dist_data_list)

            axs[i_subj].plot('distance_bin',
                             'data',
                             data=d[d['subject'] == subj],
                             marker='o',
                             markersize=3.0,
                             label=B0,  # was the set literal {B0}, which rendered as "{'3T'}"
                             zorder=2.5,
                             color=colors[i_subj][i_B0])

            for c in [s for s in d['subject'].unique() if 'C' in s]:
                axs[i_subj].plot('distance_bin',
                                 'data',
                                 data=d[d['subject'] == c],
                                 color='gray',
                                 alpha=0.3,
                                 marker='o',
                                 markersize=3.0,
                                 label='_nolegend')

        axs[i_subj].set_ylim(0.0, 7.0)
        if xlabel:
            axs[i_subj].set_xlabel('distance to lesion centroid [mm]')
        else:
            plt.setp(axs[i_subj].get_xticklabels(), visible=False)
        axs[i_subj].axhline(y=0, color='black', linewidth=1.0, linestyle='-')
        axs[i_subj].grid(axis='y', linestyle='--', alpha=0.6)
        if i_subj != 0:
            axs[i_subj].set_yticklabels([])
        elif ylabel:
            axs[i_subj].set_ylabel('[mm]')

def _binned_feature_profile(surfdata, index, subj, feature, B0, depth, hemi,
                            lesion_distance, distance_cutoff, bins, dist_bins,
                            cortexlabel, zscore, method):
    """Summarise `feature` per distance bin for `subj` and the controls in one hemisphere.

    Returns a long DataFrame with the row-index levels as columns, plus 'data' (the
    bin's median/mean) and 'distance_bin' (the bin centre). With zscore=True each row
    is first normalised by its own mean/std over cortex vertices whose
    lesion_pial_distance exceeds `distance_cutoff`.
    """
    rows = (index['subject'].str.contains(f'{subj}|C')
            & (index['feature'] == feature)
            & (index['depth'] == depth)
            & (index['hemi'] == hemi)
            & (index['B0'] == B0))

    # intra-subject normalisation statistics: cortex vertices outside the lesion neighbourhood
    d_normalization = surfdata.loc[rows, :].copy()
    d_normalization.loc[:, cortexlabel == False] = np.nan  # elementwise pandas comparison
    d_normalization = d_normalization.loc[:, lesion_distance > distance_cutoff]  # exclude lesion vertices
    mean = d_normalization.mean(axis=1)
    std = d_normalization.std(axis=1)

    dist_data_list = []
    for i_bin in range(len(bins) - 1):
        d = surfdata.loc[rows, dist_bins == i_bin]
        if zscore:
            d = d.sub(mean, axis=0).div(std, axis=0)
        d = (d.median(axis=1) if method == 'median' else d.mean(axis=1)).reset_index(name='data')
        d['distance_bin'] = (bins[i_bin] + bins[i_bin + 1]) / 2
        dist_data_list.append(d)
    return pd.concat(dist_data_list)


def plot_feature_vs_distance(surfdata, 
                             lesional_subjects, 
                             feature, 
                             B0,
                             depth, 
                             distance_cutoff, 
                             n_bins, 
                             colors, 
                             zscore=False, 
                             method='median',
                             axs=None,
                             xlabel=True,
                             ylabel=True,
                             ylim=None):
    """Plot a feature at one depth against distance from the lesion, one panel per lesional subject.

    The vertices with lesion_pial_distance < `distance_cutoff` are split into
    `n_bins - 1` bins whose edges are quantiles of those distances. The feature is
    summarised per bin for the patient and the controls (subjects whose ID contains
    'C'), in the lesion hemisphere and in the contralateral hemisphere using the same
    vertex bins.

    Parameters
    ----------
    surfdata : pandas.DataFrame
        Rows indexed by a MultiIndex with (at least) the levels 'subject', 'feature',
        'hemi', 'depth' and 'B0'; columns are vertices. Needs a 'lesion_pial_distance'
        row per lesional subject and a 'cortex_label' row for subject 'fsaverage_sym'.
    lesional_subjects : list of str
    feature : str
    B0 : str
    depth : value of the 'depth' index level to select
    distance_cutoff : float
        Vertices closer than this are binned; vertices farther away form the
        normalisation set when `zscore` is True.
    n_bins : int
        Number of quantile edges, which gives `n_bins - 1` bins.
    colors : dict
        Patient line colour keyed by B0 (e.g. colors_B0).
    zscore : bool
        Normalise each row by its own mean/std over the non-lesion cortex first.
    method : {'median', 'mean'}
        Aggregation across the vertices of a bin.
    axs : sequence of matplotlib Axes, optional
        One axes per subject. A new n_subjects x 1 figure is created when omitted.
    xlabel : bool
        Draw the x-axis label; otherwise hide the x tick labels.
    ylabel : str or bool
        y-axis label for the first panel. Only a string is drawn; True/False/None
        draw no label. Previously the default True was rendered as the text "True".
    ylim : tuple, optional

    Raises
    ------
    ValueError
        If `method` is neither 'median' nor 'mean'.
    """
    if method not in ('median', 'mean'):
        raise ValueError(f'Unknown method {method}, should be "median" or "mean"')
    if axs is None:
        # squeeze=False so a single subject still yields an indexable sequence of axes
        _, axs = plt.subplots(len(lesional_subjects), 1, figsize=(3, 6), squeeze=False)
        axs = axs[:, 0]

    index = surfdata.index.to_frame()
    cortexlabel = surfdata[(index['subject'] == 'fsaverage_sym') & (index['feature'] == 'cortex_label')].squeeze()

    for i_subj, subj in enumerate(lesional_subjects):
        ax = axs[i_subj]
        lesion_distance = surfdata[(index['subject'] == subj) & (index['feature'] == 'lesion_pial_distance')].squeeze()
        # the squeezed row is named by its full index tuple; pick the hemisphere by level name
        hemi = lesion_distance.name[surfdata.index.names.index('hemi')]
        contralat_hemi = 'lh' if hemi == 'rh' else 'rh'

        # quantile edges -> roughly equally populated bins, shared by both hemispheres
        #bins = np.linspace(0, distance_cutoff, n_bins)
        bins = [np.quantile(lesion_distance[lesion_distance < distance_cutoff], q) for q in np.linspace(0, 1, n_bins)]
        dist_bins = pd.cut(lesion_distance, bins=bins, include_lowest=True, labels=False)

        profile_args = dict(surfdata=surfdata, index=index, subj=subj, feature=feature,
                            B0=B0, depth=depth, lesion_distance=lesion_distance,
                            distance_cutoff=distance_cutoff, bins=bins, dist_bins=dist_bins,
                            cortexlabel=cortexlabel, zscore=zscore, method=method)
        d = _binned_feature_profile(hemi=hemi, **profile_args)
        d_contralat = _binned_feature_profile(hemi=contralat_hemi, **profile_args)

        # lesion hemisphere: patient in the field-strength colour, controls in light grey
        ax.plot('distance_bin',
                'data',
                data=d[d['subject'] == subj],
                marker='o',
                markersize=3.0,
                label=subj,
                zorder=2.5,
                color=colors[B0])
        for c in [s for s in d['subject'].unique() if 'C' in s]:
            ax.plot('distance_bin',
                    'data',
                    data=d[d['subject'] == c],
                    color='lightgray',
                    alpha=0.3,
                    marker='o',
                    markersize=3.0,
                    label='_nolegend')

        # contralateral hemisphere, same vertex bins
        ax.plot('distance_bin',
                'data',
                data=d_contralat[d_contralat['subject'] == subj],
                marker='x',
                markersize=3.0,
                label=f'{subj} contralateral homotopic region',
                zorder=2.1,
                color='royalblue',
                alpha=0.5)
        for c in [s for s in d_contralat['subject'].unique() if 'C' in s]:
            ax.plot('distance_bin',
                    'data',
                    data=d_contralat[d_contralat['subject'] == c],
                    color='lightskyblue',
                    alpha=0.3,
                    marker='x',
                    markersize=3.0,
                    label='_nolegend')

        if ylim is not None:
            ax.set_ylim(ylim)
        if xlabel:
            ax.set_xlabel('distance to lesion centroid [mm]')
        else:
            plt.setp(ax.get_xticklabels(), visible=False)
        if i_subj != 0:
            ax.set_yticklabels([])
        elif isinstance(ylabel, str):
            ax.set_ylabel(ylabel)
        ax.grid(axis='y', linestyle='--', alpha=0.6)


def plot_fraction_in_roi_bar(lesional_subjects,
                             distance_cutoff,
                             n_bins,
                             bottom=0.0,
                             axs=None,
                             fraction_threshold=0.5):
    """Draw, per lesional subject, a grey strip encoding lesion-ROI coverage per distance bin.

    For each subject, the ``lesion_pial_distance`` values below
    ``distance_cutoff`` define ``n_bins`` quantile edges, i.e. ``n_bins - 1``
    bins. Every bin is drawn as a unit-height bar whose grey level
    (``'Grays'`` colormap) is the fraction of the bin's vertices that lie in the
    subject's ROI mask from the module-level ``stat_ROIs`` list. Data come from
    the module-level ``surfdata`` frame.

    Parameters
    ----------
    lesional_subjects : list of str
        Subject IDs; the i-th subject is paired with ``stat_ROIs[i]`` and ``axs[i]``.
    distance_cutoff : float
        Only distances below this value are used to place the quantile bin edges.
    n_bins : int
        Number of bin *edges*; ``n_bins - 1`` bins are drawn.
    bottom : float, optional
        Baseline of the bars (forwarded to ``Axes.bar``).
    axs : sequence of matplotlib Axes, optional
        One axes per subject. If None, a new 1 x n_subjects row of axes is created.
    fraction_threshold : float, optional
        ROI fraction below which a bin is considered to be outside the lesion ROI.

    Returns
    -------
    list
        Per subject, the centre of the first bin whose ROI fraction is below
        ``fraction_threshold``, or None if there is no such bin. Empty bins
        (fraction NaN) never count as "below".
    """
    if axs is None:
        _, axs = plt.subplots(1, len(lesional_subjects), squeeze=False)
        axs = axs[0]

    cmap = plt.get_cmap('Grays')
    # The index is the same for every subject, so build its frame only once.
    index = surfdata.index.to_frame()

    dist_ROI_thresholds = []
    for i_subj, subj in enumerate(lesional_subjects):
        ax = axs[i_subj]
        roi = stat_ROIs[i_subj]
        lesion_distance = surfdata[(index['subject'] == subj) & (index['feature'] == 'lesion_pial_distance')].squeeze()

        # Quantile edges give bins with (approximately) equal numbers of vertices.
        bins = [np.quantile(lesion_distance[lesion_distance < distance_cutoff], q)
                for q in np.linspace(0, 1, n_bins)]
        dist_bins = pd.cut(lesion_distance, bins=bins, include_lowest=True, labels=False)

        centers, widths, fractions = [], [], []
        for i_bin in range(n_bins - 1):
            in_bin = dist_bins == i_bin
            n_in_bin = in_bin.sum()
            centers.append((bins[i_bin] + bins[i_bin + 1]) / 2)
            widths.append(bins[i_bin + 1] - bins[i_bin])
            # An empty bin has no defined fraction; use NaN explicitly instead of a 0/0 warning.
            fractions.append((roi & in_bin).sum() / n_in_bin if n_in_bin > 0 else np.nan)

        # NaN fractions map to the colormap's "bad" colour.
        bar_colors = cmap(np.asarray(fractions, dtype=float))
        ax.bar(x=centers,
               height=[1] * len(centers),
               width=widths,
               color=bar_colors,
               edgecolor=bar_colors,
               bottom=bottom)

        # Centre of the first bin where the ROI covers less than the threshold fraction.
        below_threshold = [c for c, f in zip(centers, fractions) if f < fraction_threshold]
        dist_ROI_thresholds.append(below_threshold[0] if below_threshold else None)

        # The strip carries no axis information: hide tick labels, y ticks and the frame.
        plt.setp(ax.get_xticklabels(), visible=False)
        ax.set_yticklabels([])
        ax.set_yticks([])
        for side in ('top', 'right', 'left', 'bottom'):
            ax.spines[side].set_visible(False)

    return dist_ROI_thresholds
        


# Multi-panel figure: one column per lesional subject. Each column shows feature
# profiles vs. distance to the lesion, below a strip that encodes how much of each
# distance bin lies inside the manual lesion ROI.
# GridSpec rows: row 0 holds the subject names. After that come (section title,
# panel) row pairs, so panel row k sits at grid row 2*k + 2 and its title at 2*k + 1.
DISTANCE_CUTOFF = 25.0  # [mm] max. distance used for binning and for the x-axis limit
N_DIST_BINS = 50        # number of quantile bin edges per subject
N_PANEL_ROWS = 7

fig = plt.figure(figsize=(6, 8.5), constrained_layout=True)
gs = matplotlib.gridspec.GridSpec(2 * N_PANEL_ROWS + 1, len(lesional_subjects),
                                  height_ratios=(0.3,) + (0.3, 0.2) + (0.3, 1.0) * (N_PANEL_ROWS - 1),
                                  figure=fig)
fig.get_layout_engine().set(w_pad=0.5/72., h_pad=1/72., hspace=0.0, wspace=0.0)

# Every panel in a column shares its x-axis with the top (ROI-fraction) panel.
axs = []
for i in range(N_PANEL_ROWS):
    row_axes = []
    for j in range(len(lesional_subjects)):
        ax = fig.add_subplot(gs[2*i + 2, j], sharex=axs[0][j] if i > 0 else None)
        row_axes.append(ax)
    axs.append(row_axes)
axs = np.array(axs)

# Subject-name header above each column (works for any number of subjects).
for j, subj in enumerate(lesional_subjects):
    title_ax = fig.add_subplot(gs[0, j])
    title_ax.text(0.5, 0.5, subj, ha="center", va="center", fontsize=13)
    title_ax.axis("off")


def _add_section_title(panel_row, text):
    """Add a full-width, axis-less title row directly above panel row `panel_row`."""
    section_ax = fig.add_subplot(gs[2 * panel_row + 1, :])
    section_ax.text(0.5, 0.5, text, ha="center", va="top")
    section_ax.axis("off")
    return section_ax


_add_section_title(0, "fraction of vertices in manual lesion ROI")
dist_ROI_thresholds = plot_fraction_in_roi_bar(lesional_subjects,
                                               distance_cutoff=DISTANCE_CUTOFF,
                                               n_bins=N_DIST_BINS,
                                               bottom=0.5,
                                               axs=axs[0, :])

_add_section_title(1, "cortical thickness")
plot_thickness_vs_distance_per_lesion(surfdata,
                                      lesional_subjects,
                                      distance_cutoff=DISTANCE_CUTOFF,
                                      n_bins=N_DIST_BINS,
                                      colors=colors_subj_B0,
                                      axs=axs[1, :],
                                      xlabel=False,
                                      ylabel=True)

# Remaining rows (panel rows 2..6): one quantitative map each, sampled at depth -3.0.
# The chi is written as an escape so that re-encoding the notebook cannot garble it
# (it had been corrupted to mojibake).
feature_panels = [
    dict(title="3T normalised FLAIR", feature='FLAIR_projabs', B0='3T', ylabel='[a.u.]', ylim=(-0.3, 1.2)),
    dict(title="3T T1 map", feature='T1map_projabs', B0='3T', ylabel='[s]', ylim=(0.7, 1.5)),
    dict(title="9.4T T1 map", feature='T1map_projabs', B0='94T', ylabel='[s]', ylim=(1.2, 2.2)),
    dict(title="9.4T T2* map", feature='T2star_projabs', B0='94T', ylabel='[ms]', ylim=(15, 35)),
    dict(title="9.4T QSM \u03c7", feature='QSMTke3_projabs', B0='94T', ylabel='[ppm]', ylim=(-0.015, 0.015)),
]
for panel_row, panel in enumerate(feature_panels, start=2):
    _add_section_title(panel_row, panel['title'])
    plot_feature_vs_distance(surfdata=surfdata,
                             lesional_subjects=lesional_subjects,
                             feature=panel['feature'],
                             B0=panel['B0'],
                             depth=-3.0,
                             distance_cutoff=DISTANCE_CUTOFF,
                             n_bins=N_DIST_BINS,
                             colors=colors_B0,
                             zscore=False,
                             axs=axs[panel_row, :],
                             xlabel=(panel_row == N_PANEL_ROWS - 1),  # x label only on the bottom row
                             ylabel=panel['ylabel'],
                             ylim=panel['ylim'])

# Dashed line in every panel at the distance where ROI coverage first drops below
# the threshold. The function returns None when no bin does, and axvline cannot take None.
for j, roi_threshold in enumerate(dist_ROI_thresholds):
    if roi_threshold is None:
        continue
    for ax in axs[:, j]:
        ax.axvline(x=roi_threshold, color='gray', linestyle='--', linewidth=1.0)

# The x-axes are shared per column, so limiting the bottom row limits the whole column.
for ax in axs[-1, :]:
    ax.set_xlim(0.0, DISTANCE_CUTOFF)



In [None]:
# Distance (bin centre) at which each subject's lesion ROI first covers less than the
# threshold fraction of a distance bin. Shown as a labelled Series so each value is
# tied to its subject (NaN = no bin fell below the threshold).
pd.Series(dist_ROI_thresholds, index=lesional_subjects, name='dist_ROI_threshold', dtype=float)

In [None]:
# Stand-alone legend describing the line and marker styles of the distance-profile figure.
# Patient-lesion entries (one per field strength) come first, then the reference curves
# and the ROI-threshold marker.
custom_lines = [
    *(matplotlib.lines.Line2D([], [], marker='o', color=colors_subj_B0[0][idx], label=f'{field} patient lesion')
      for idx, field in enumerate(['3T', '9.4T'])),
    matplotlib.lines.Line2D([], [], marker='o', color='lightgray', alpha=0.3,
                            label='controls homotopic region'),
    matplotlib.lines.Line2D([], [], marker='x', color='royalblue', alpha=0.5,
                            label='patient contralateral homotopic region'),
    matplotlib.lines.Line2D([], [], marker='x', color='lightskyblue', alpha=0.3,
                            label='controls contralateral homotopic region'),
    matplotlib.lines.Line2D([], [], color='gray', linestyle='--',
                            label='distance where <50% vertices in lesion ROI'),
]

fig = plt.figure(figsize=(3, 2))
fig.legend(handles=custom_lines, loc='center', fontsize='small', ncol=3)

# An empty, frameless axes spanning the figure so that it is rendered.
fig.add_axes([0, 0, 1, 1], frameon=False, xticks=[], yticks=[])