In [None]:
%matplotlib inline
import colorsys
import os
import shutil
import subprocess

import matplotlib
import matplotlib.pyplot as plt
import nibabel as nib
import nilearn.image
import nilearn.plotting
import numpy as np
import pandas as pd
import seaborn as sns
import tqdm
from scipy.ndimage import binary_dilation

In [None]:
# Subject to process; the two supported choices are sub-P017 and sub-P026.
subj_id = "sub-P017"

In [None]:
# Stage a working copy of the T2* map inside the LayNii derivatives folder;
# from here on T2star_path refers to that copy.
T2star_path = f'data/derivatives/T2starmaps/{subj_id}/ses-94T/anat/{subj_id}_ses-94T_acq-3DFLASH_proc-scanner_T2starmap.nii.gz'
T2star_path_copy = f'data/derivatives/laynii/{subj_id}_94T_T2star.nii.gz'
os.makedirs('data/derivatives/laynii', exist_ok=True)
shutil.copy2(T2star_path, T2star_path_copy)
T2star_path = T2star_path_copy

# Bring the aseg and the lesion mask into T2* space with the inverse of the
# T2*-to-orig registration (nearest-neighbour, so labels are not blended).
aseg_path = f'data/derivatives/freesurfer/{subj_id}_94T/mri/aseg.mgz'
T2star_to_orig_lta_path = f'data/derivatives/freesurfer/{subj_id}_94T/new/GREecho1_scanner_to_orig_reg.lta'
aseg_in_T2star_path = f'data/derivatives/laynii/{subj_id}_94T_aseg_in_T2star.mgz'

lesionmask_in_orig_path = f'data/derivatives/lesionlabels/{subj_id}/{subj_id}_94T_roi_reg.mgz'
lesionmask_in_T2star_path = f'data/derivatives/laynii/{subj_id}_94T_lesionmask_in_T2star.nii.gz'

for targ_path, out_path in [
    (aseg_path, aseg_in_T2star_path),
    (lesionmask_in_orig_path, lesionmask_in_T2star_path),
]:
    # Every path is quoted, and check=True stops the notebook on a failed
    # command instead of silently loading a stale or missing output below.
    cmd = (
        f"mri_vol2vol --mov '{T2star_path}' --targ '{targ_path}' --o '{out_path}' "
        f"--lta '{T2star_to_orig_lta_path}' --inv --nearest"
    )
    subprocess.run(cmd, shell=True, executable='/bin/bash', check=True)

# Binary lesion mask in T2* space.
lesion_img = nib.load(lesionmask_in_T2star_path)
lesion_data = lesion_img.get_fdata() > 0

In [None]:
# load distance map in volume space

# determine hemisphere based on mean lesion coordinates
if not lesion_data.any():
    # An empty mask would give a NaN centroid and silently fall through to 'lh'.
    raise ValueError(f'Lesion mask {lesionmask_in_T2star_path} contains no voxels.')
mean_coords = [coords.mean() for coords in lesion_data.nonzero()]
mean_coords_RAS = nib.affines.apply_affine(lesion_img.affine, mean_coords)
# positive first world coordinate -> right hemisphere
hemi = 'rh' if mean_coords_RAS[0] > 0 else 'lh'

orig_path = f'data/derivatives/freesurfer/{subj_id}_94T/mri/orig.mgz'
white_path = f'data/derivatives/freesurfer/{subj_id}_94T/surf/{hemi}.white'
distance_path = f'data/derivatives/lesionlabels/{subj_id}/{subj_id}_94T_{hemi}_distances_nativesurf_{hemi}.mgh'

distance_origspace_path = f'data/derivatives/laynii/{subj_id}_94T_distance_in_orig.nii.gz'

# map the surface distance data into a volume (orig space)
cmd = (
    f"SUBJECTS_DIR=data/derivatives/freesurfer mri_surf2vol --subject {subj_id}_94T "
    f"--so '{white_path}' '{distance_path}' --o '{distance_origspace_path}'"
)
# check=True: abort here rather than load a stale or missing volume further down
subprocess.run(cmd, shell=True, executable='/bin/bash', check=True)

distance_t2starspace_path = f'data/derivatives/laynii/{subj_id}_94T_distance_in_T2star.nii.gz'

# align to T2star space
cmd = f"mri_vol2vol --mov '{T2star_path}' --targ '{distance_origspace_path}' --o '{distance_t2starspace_path}' --lta '{T2star_to_orig_lta_path}' --inv --nearest"
subprocess.run(cmd, shell=True, executable='/bin/bash', check=True)

# load
distance_img = nib.load(distance_t2starspace_path)
distance_data = distance_img.get_fdata()

In [None]:
aseg_img = nib.load(aseg_in_T2star_path)
aseg_data = aseg_img.get_fdata()

# Cortex voxels no further than this from the lesion centroid are kept.
MAX_DISTANCE_MM = 100

# aseg cortical GM label of the lesion hemisphere (3 for lh, 42 for rh)
hemi_cortex_label = {'lh': 3, 'rh': 42}.get(hemi)

within_range = distance_data <= MAX_DISTANCE_MM
if hemi_cortex_label is not None:
    # restrict the selection to the cortex of the lesion hemisphere
    within_range &= (aseg_data == hemi_cortex_label)

# grow the selection by one voxel so that it has no gaps
cortex_selection_vol = binary_dilation(within_range, iterations=1)

# collapse both hemispheres onto one GM label (42 -> 3) and one WM label (41 -> 2)
for source_label, merged_label in ((42, 3), (41, 2)):
    aseg_data[aseg_data == source_label] = merged_label

# cortex outside the selection is cleared to 0
unselected_cortex = (aseg_data == 3) & np.logical_not(cortex_selection_vol)
aseg_data[unselected_cortex] = 0

aseg_cropped_lesion_path = f'data/derivatives/laynii/{subj_id}_94T_aseg_in_T2star_lesioncrop.nii.gz'
os.makedirs('data/derivatives/laynii', exist_ok=True)
aseg_cropped_img = nib.Nifti1Image(aseg_data, aseg_img.affine)
nib.save(aseg_cropped_img, aseg_cropped_lesion_path)



In [None]:
# laynii rimify and layerification
# check=True on both calls: a failing LayNii step should stop the notebook
# instead of letting later cells read outputs from a previous run.
rim_path = f'data/derivatives/laynii/{subj_id}_94T_rim.nii.gz'
cmd = f"~/laynii/LN2_RIMIFY -input '{aseg_cropped_lesion_path}' -innergm 2 -outergm 0 -gm 3 -output '{rim_path}'"
subprocess.run(cmd, shell=True, executable='/bin/bash', check=True)

# LN2_LAYERS is given an output prefix; the metric and layer files read later
# are expected under that same prefix, so all three are derived from one name.
layers_prefix = f'data/derivatives/laynii/{subj_id}_94T_layers'
metric_path = f'{layers_prefix}_metric_equivol.nii'
layer_path = f'{layers_prefix}_layers_equivol.nii'
cmd = f"~/laynii/LN2_LAYERS -rim '{rim_path}' -output '{layers_prefix}' -equivol -nr_layers 10"
subprocess.run(cmd, shell=True, executable='/bin/bash', check=True)


In [None]:
# get total vol of gray matter within selected rim and determine number of columns
rim_img = nib.load(rim_path)  # load once; both the data and the header are needed
rim_vol = (rim_img.get_fdata() == 3).sum()  # number of voxels labelled 3 in the rim
COLUMN_VOL_MM3 = 40
# voxel volume from the three spatial zooms only
voxel_vol_mm3 = np.prod(rim_img.header.get_zooms()[:3])
column_vol_voxs = COLUMN_VOL_MM3 / voxel_vol_mm3
n_columns = int(rim_vol / column_vol_voxs)
if n_columns < 1:
    raise ValueError(
        f'Rim contains only {rim_vol} gray matter voxels, which is less than one '
        f'column of {COLUMN_VOL_MM3} mm3.'
    )
print(f'Using {n_columns} columns for target volume of {COLUMN_VOL_MM3} mm3 per column.')

In [None]:
# laynii columns
# columns_path is where the LN2_COLUMNS result is read from below; it is not
# passed to the command itself.
columns_path = f'data/derivatives/laynii/{subj_id}_94T_rim_columns{n_columns}.nii.gz'
midgm_path = f'data/derivatives/laynii/{subj_id}_94T_layers_midGM_equivol.nii'
cmd = f"~/laynii/LN2_COLUMNS -rim '{rim_path}' -midgm '{midgm_path}' -nr_columns {n_columns}"
# check=True: stop here if LN2_COLUMNS fails rather than loading an old file
subprocess.run(cmd, shell=True, executable='/bin/bash', check=True)

t2star_img = nib.load(T2star_path)
columns_img = nib.load(columns_path)
metric_img = nib.load(metric_path)
layer_img = nib.load(layer_path)

t2star_data = t2star_img.get_fdata()
columns_data = columns_img.get_fdata()
metric_data = metric_img.get_fdata()
layer_data = layer_img.get_fdata()
# lesion_data was already loaded above



In [None]:
# Optional manual segmentation of the black line region: used when the file
# is present, otherwise an all-zero volume shaped like the T2* data stands in.

blacklineseg_path = f'data/derivatives/laynii/{subj_id}_94T_T2star_blacklineseg.nii.gz'
if not os.path.isfile(blacklineseg_path):
    blacklineseg_data = np.zeros_like(t2star_data)
else:
    blacklineseg_img = nib.load(blacklineseg_path)
    blacklineseg_data = blacklineseg_img.get_fdata()

In [None]:
# Collect the per-voxel measurements of every voxel that belongs to a column.
# Indexing with the coordinate tuple visits the same voxels, in the same order,
# as a Python loop over np.nonzero would, without the per-voxel interpreter cost.
column_voxels = np.nonzero(columns_data > 0)
voxel_df = pd.DataFrame({
    'column': columns_data[column_voxels].astype(int),
    'depth': metric_data[column_voxels].astype(float),
    'layer': layer_data[column_voxels].astype(int),
    'lesion_label': lesion_data[column_voxels].astype(int),
    'distance': distance_data[column_voxels].astype(float),
    't2star': t2star_data[column_voxels].astype(float),
})

# median of every measurement per (column, layer) cell
voxel_df = voxel_df.groupby(['column', 'layer']).median().reset_index()

# z-score per layer
#voxel_df['t2star'] = voxel_df.groupby('layer')['t2star'].transform(lambda x: (x - x.mean()) / x.std())

## select columns intersecting the blackline segmentation
#columns_to_select = columns_data[blacklineseg_data > 0]
#columns_to_select = np.unique(columns_to_select[columns_to_select > 0])

## select columns within the dilated blackline segmentation
#blacklineseg_data_dilated = binary_dilation(blacklineseg_data > 0, iterations=10)
#columns_to_select = columns_data[blacklineseg_data_dilated > 0]
#columns_to_select = np.unique(columns_to_select[columns_to_select > 0])
#voxel_df = voxel_df[voxel_df['column'].isin(columns_to_select)]

# select layer 5 and 2 for applying different criteria
# (lesion_label == 1 after the median means the cell's median voxel is in the lesion)
voxel_df_layer5 = voxel_df[(voxel_df['layer'] == 5) & (voxel_df['lesion_label'] == 1)]
voxel_df_layer2 = voxel_df[(voxel_df['layer'] == 2) & (voxel_df['lesion_label'] == 1)]

# absolute cutoffs: layer 5 below 30 AND layer 2 above 30
columns_to_select = voxel_df_layer5[voxel_df_layer5['t2star'] < 30]['column'].values
columns_to_select = np.intersect1d(columns_to_select, voxel_df_layer2[voxel_df_layer2['t2star'] > 30]['column'].values)

# z-score cutoffs
#columns_to_select = voxel_df_layer5[voxel_df_layer5['t2star'] < -1.5]['column'].values
#columns_to_select = np.intersect1d(columns_to_select, voxel_df_layer2[voxel_df_layer2['t2star'] > 0.0]['column'].values)

print(f"Highlighting {len(columns_to_select)} columns: {columns_to_select}")

In [None]:

if len(columns_to_select) == 0:
    print("Setting all columns to highlight to avoid empty selection")
    columns_to_select = voxel_df['column'].unique()

# volume that keeps the column label only for the selected columns
highlighted_columns_vol = columns_data.copy()
highlighted_columns_vol[~np.isin(columns_data, columns_to_select)] = 0
highlighted_columns_img = nib.Nifti1Image(highlighted_columns_vol, columns_img.affine)
# NOTE(review): the file name hard-codes "columns150" although the column count
# is n_columns; kept unchanged so the output location stays the same.
highlighted_columns_path = f'data/derivatives/laynii/{subj_id}_94T_rim_columns150_highlighted.nii.gz'
nib.save(
    highlighted_columns_img,
    highlighted_columns_path
)


# compute average highlight profile (mean t2star per layer over the selected columns)
average_highlight_profile = voxel_df[voxel_df['column'].isin(columns_to_select.astype(int))].groupby('layer')['t2star'].mean().reset_index()

# save average highlight profile to file
average_highlight_profile_path = f'data/derivatives/laynii/{subj_id}_94T_average_highlight_profile.csv'
average_highlight_profile.to_csv(average_highlight_profile_path, index=False)

# optional: read average highlight profile of sub-P017 from file
# (as written this always replaces the profile computed above)
average_highlight_profile = pd.read_csv('data/derivatives/laynii/sub-P017_94T_average_highlight_profile.csv')


# for each column, compute RMSE with average highlight profile
def compute_rmse(column_df, reference_profile):
    """Return the RMSE between one column's t2star profile and a reference.

    Both frames are joined on 'layer', so only layers present in both
    contribute. The result is NaN when they share no layer.
    """
    merged_df = pd.merge(column_df, reference_profile, on='layer', suffixes=('_col', '_ref'))
    rmse = np.sqrt(np.mean((merged_df['t2star_col'] - merged_df['t2star_ref']) ** 2))
    return rmse


column_correlation = []
for column in tqdm.tqdm(voxel_df['column'].unique()):
    column_df = voxel_df[voxel_df['column'] == column]
    rmse = compute_rmse(column_df, average_highlight_profile)
    column_correlation.append({'column': column, 'rmse': rmse})
column_correlation_df = pd.DataFrame(column_correlation)
# from here on the 'rmse' column holds -log(RMSE), so larger means more similar
column_correlation_df['rmse'] = -np.log(column_correlation_df['rmse'])

# compute correlation coefficient with average highlight profile
#def compute_correlation(column_df, reference_profile):
#    merged_df = pd.merge(column_df, reference_profile, on='layer', suffixes=('_col', '_ref'))
#    correlation = merged_df['t2star_col'].corr(merged_df['t2star_ref'])
#    return correlation
#column_correlation = []
#for column in tqdm.tqdm(voxel_df['column'].unique()):
#    column_df = voxel_df[voxel_df['column'] == column]
#    correlation = compute_correlation(column_df, average_highlight_profile)
#    column_correlation.append({'column': column, 'correlation': correlation})
#column_correlation_df = pd.DataFrame(column_correlation)

# Paint every voxel with the statistic of its column through a label -> value
# lookup table; label 0 (no column) keeps the value 0.
lookup_array = np.zeros(int(columns_data.max()) + 1)
lookup_array[column_correlation_df['column'].to_numpy(dtype=int)] = column_correlation_df['rmse'].to_numpy()

statistic_vol = lookup_array[columns_data.astype(int)]
statistic_img = nib.Nifti1Image(statistic_vol, columns_img.affine)
statistic_path = f'data/derivatives/laynii/{subj_id}_94T_rim_columns150_statistic.nii.gz'
nib.save(
    statistic_img,
    statistic_path
)

# Drop an 'rmse' column left over from a previous run of this cell first;
# otherwise the merge would yield rmse_x / rmse_y and break the plots below.
voxel_df = voxel_df.drop(columns=['rmse'], errors='ignore').merge(column_correlation_df, on='column')

# set highlight column
voxel_df['highlight'] = voxel_df['column'].isin(columns_to_select.astype(int))



In [None]:

# One line per column: t2star across layers, restricted to rows whose
# lesion_label is 1; highlighted columns are drawn in a different colour,
# width and dash style.
matplotlib.rcParams['figure.dpi'] = 600
fig = plt.figure(figsize=(4.6, 2.2), constrained_layout=True)

lesion_profiles = voxel_df[voxel_df['lesion_label'] == 1]
ax = sns.lineplot(
    data=lesion_profiles,
    x='layer',
    y='t2star',
    orient='x',
    units='column',
    estimator=None,
    hue='highlight',
    palette=['gray', 'darkred'],
    size='highlight',
    style='highlight',
    sizes=(1.5, 0.5),
    legend=False,
)

#ax.set_xlim(-4.5, 4.5)
ax.set_ylim(20, 45)

ax.set_xlabel('cortical layers (white matter to pial boundary)')
ax.set_xticks(np.arange(1, 11))
ax.set_ylabel('T2* [ms]')
# horizontal grid lines only
ax.grid(True, axis='y', linestyle='--', alpha=0.7)
sns.despine()

In [None]:
# plot correlation score vs. distance from lesion centroid
matplotlib.rcParams['figure.dpi'] = 600

# one row per column (median over its layers), limited to columns closer than 25
column_df = voxel_df.groupby('column').median().reset_index()
column_df = column_df[column_df['distance'] < 25.0]
fig = plt.figure(figsize=(4.6, 2.2), constrained_layout=True)
# set legend to be outside the plot
sns.scatterplot(
    data=column_df.replace(
        {'lesion_label': {0: 'outside', 1: 'inside manual label'}}
    ),
    x='distance',
    y='rmse',
    hue='rmse',
    palette='inferno',
  #  hue_order=['outside', 'inside manual label'],
    s=20,
    legend=False,
    edgecolor='lightgray'
)
plt.ylabel('-log(RMSE) to \n average target profile')
plt.xlabel('distance to lesion centroid [mm]')
plt.ylim(-4.5, 0.0)
plt.yticks([-4.0, -2.0, 0.0])
plt.grid(True, axis='y', linestyle='--', alpha=0.7)

# add vline at the subject-specific distance threshold
# (previously the sub-P017 value was drawn for every subject)
dist_threshold = 11.4209 if subj_id == 'sub-P017' else 12.5609
plt.axvline(dist_threshold, color='gray', linestyle='--', linewidth=0.7)
#plt.axhline(0, color='gray', linestyle='--')
column_df


In [None]:
# we need to reset the affine so that nilearn doesn't resample to MNI space:
# every image is reoriented to closest-canonical axes and then given an affine
# that only contains the voxel sizes.
# The voxel sizes are taken AFTER reorientation, so they follow the data axes
# even when as_closest_canonical permutes them (matters for anisotropic voxels).
t2star_canonical = nib.as_closest_canonical(t2star_img)
reset_affine = np.eye(4)
reset_affine[:3, :3] = np.diag(nib.affines.voxel_sizes(t2star_canonical.affine))

t2star_img, columns_img, layer_img, highlighted_columns_img, statistic_img = (
    nib.Nifti1Image(nib.as_closest_canonical(img).get_fdata(), reset_affine)
    for img in (t2star_img, columns_img, layer_img, highlighted_columns_img, statistic_img)
)
target_spacing = np.repeat(np.min(t2star_img.header.get_zooms()), 3)

# upsample to target spacing with nearest neighbor interpolation to avoid interpolation artifacts
target_affine = np.diag(target_spacing)
t2star_up, columns_up, layers_up, highlight_up, statistic_up = (
    nilearn.image.resample_img(
        img,
        target_affine=target_affine,
        interpolation='nearest'
    )
    for img in (t2star_img, columns_img, layer_img, highlighted_columns_img, statistic_img)
)


# crop to region around highlighted area
def crop(img, center, margin):
    """Return a cubic crop of ``img`` around a voxel position.

    Parameters
    ----------
    img : nibabel image
        Image to crop; only its first three axes are sliced.
    center : sequence of three ints
        Voxel index the crop is centred on.
    margin : int
        Half-width of the crop in voxels; the window is clipped to the image.

    Returns
    -------
    nibabel.Nifti1Image
        The cropped data with an affine whose origin is moved to the first
        voxel that was kept.
    """
    x_min = max(center[0] - margin, 0)
    x_max = min(center[0] + margin, img.shape[0])
    y_min = max(center[1] - margin, 0)
    y_max = min(center[1] + margin, img.shape[1])
    z_min = max(center[2] - margin, 0)
    z_max = min(center[2] + margin, img.shape[2])

    im_cropped_data = img.get_fdata()[x_min:x_max, y_min:y_max, z_min:z_max]
    im_cropped_affine = img.affine.copy()
    # Shift the origin by the voxel offset mapped through the affine's linear
    # part; for a diagonal affine this equals offset * zooms, and it stays
    # correct when the affine has rotation or flips.
    im_cropped_affine[:3, 3] += img.affine[:3, :3] @ np.array([x_min, y_min, z_min])
    im_cropped = nib.Nifti1Image(im_cropped_data, im_cropped_affine)
    return im_cropped


highlight_voxels = highlight_up.get_fdata().nonzero()
if highlight_voxels[0].size == 0:
    # without this the mean below is NaN and int() fails with an opaque error
    raise ValueError('No highlighted voxels to centre the crop on.')
center = [int(np.mean(highlight_voxels[dim])) for dim in range(3)]
margin = 100  # in voxels at target_spacing
t2star_up = crop(t2star_up, center, margin)
columns_up = crop(columns_up, center, margin)
layers_up = crop(layers_up, center, margin)
highlight_up = crop(highlight_up, center, margin)
statistic_up = crop(statistic_up, center, margin)
cut_coords = nilearn.plotting.find_xyz_cut_coords(highlight_up)



In [None]:
# Two slices along the second axis, offset by -2 and +2 steps of twice the
# target spacing from the cut coordinate (alternative: offsets range(-2, 3, 1)).
cut_slices = [offset * 2 * target_spacing[1] + cut_coords[1] for offset in (-2, 2)]

# colormap with one randomly sampled hue per column label (fixed seed)
np.random.seed(42)
column_hues = [np.random.uniform() for _ in range(n_columns + 1)]
# convert from hsv to rgb
columns_cmap = matplotlib.colors.ListedColormap(
    [colorsys.hsv_to_rgb(hue, 0.7, 0.7) for hue in column_hues]
)

# discrete viridis colormap for the layers
N = 10
colors = plt.cm.viridis(np.linspace(0, 1, N))
layers_cmap = matplotlib.colors.ListedColormap(colors)

# One figure per entry: the plain T2* background first, then the same
# background with the columns, layers, highlighted columns and statistic overlay.
overlay_specs = [
    None,
    (columns_up, {'cmap': columns_cmap}),
    (layers_up, {'vmin': 0, 'vmax': 10, 'cmap': layers_cmap}),
    (highlight_up, {'cmap': columns_cmap}),
    (statistic_up, {'cmap': 'inferno'}),
]
for overlay_spec in overlay_specs:
    fig = plt.figure(figsize=(8, 4), constrained_layout=True)
    p = nilearn.plotting.plot_anat(
        t2star_up,
        display_mode='y',
        annotate=False,
        cut_coords=cut_slices,
        vmin=15,
        vmax=60,
        radiological=True,
        draw_cross=False,
        figure=fig,
    )
    if overlay_spec is not None:
        overlay_img, overlay_kwargs = overlay_spec
        p.add_overlay(overlay_img, **overlay_kwargs)



In [None]:
# Stand-alone horizontal colorbar for the layer overlay.
layer_mappable = plt.cm.ScalarMappable(
    cmap=layers_cmap,
    norm=plt.Normalize(vmin=0.5, vmax=10.5),
)
fig = plt.figure(figsize=(3, 1), constrained_layout=True)
fig.add_axes([0, 0, 1, 1], frameon=False, xticks=[], yticks=[])
cb = plt.colorbar(
    layer_mappable,
    ticks=np.arange(1, 11),
    ax=fig.axes[0],
    orientation='horizontal',
)
cb.ax.tick_params(labelsize=15)
cb.set_label('cortical layers', fontsize=15)
# only the two ends are labelled, and the tick marks themselves are hidden
cb.set_ticks([2, 9], labels=['white matter', 'pial surface'])
cb.ax.tick_params(length=0)
cb.outline.set_visible(False)

# Stand-alone horizontal colorbar for the statistic overlay (no ticks).
statistic_mappable = plt.cm.ScalarMappable(
    cmap='inferno',
    norm=plt.Normalize(vmin=-1, vmax=1),
)
fig = plt.figure(figsize=(3, 1), constrained_layout=True)
fig.add_axes([0, 0, 1, 1], frameon=False, xticks=[], yticks=[])
cb = plt.colorbar(
    statistic_mappable,
    ticks=[],
    ax=fig.axes[0],
    orientation='horizontal',
)
cb.set_label('-log(RMSE) to average target profile', fontsize=12)
cb.outline.set_visible(False)

In [None]:

# Two slices along the third axis, offset by -1 and +3 steps of twice the
# target spacing from the cut coordinate.
cut_slices = [offset * 2 * target_spacing[2] + cut_coords[2] for offset in (-1, 3)]

# colormap with one randomly sampled hue per column label (fixed seed)
np.random.seed(42)
column_hues = [np.random.uniform() for _ in range(n_columns + 1)]
# convert from hsv to rgb
columns_cmap = matplotlib.colors.ListedColormap(
    [colorsys.hsv_to_rgb(hue, 0.7, 0.7) for hue in column_hues]
)

# discrete viridis colormap for the layers
N = 10
colors = plt.cm.viridis(np.linspace(0, 1, N))
layers_cmap = matplotlib.colors.ListedColormap(colors)

# One figure per entry: the plain T2* background first, then the same
# background with the columns, layers, highlighted columns and statistic overlay.
overlay_specs = [
    None,
    (columns_up, {'cmap': columns_cmap}),
    (layers_up, {'vmin': 0, 'vmax': 10, 'cmap': layers_cmap}),
    (highlight_up, {'cmap': columns_cmap}),
    (statistic_up, {'cmap': 'inferno'}),
]
for overlay_spec in overlay_specs:
    fig = plt.figure(figsize=(8, 4), constrained_layout=True)
    p = nilearn.plotting.plot_anat(
        t2star_up,
        display_mode='z',
        annotate=False,
        cut_coords=cut_slices,
        vmin=15,
        vmax=60,
        radiological=True,
        draw_cross=False,
        figure=fig,
    )
    if overlay_spec is not None:
        overlay_img, overlay_kwargs = overlay_spec
        p.add_overlay(overlay_img, **overlay_kwargs)

