In [None]:
import os
import util
import numpy as np
import nibabel as nib
import nibabel.processing
import nilearn
import nilearn.plotting
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline


# Scratch directory for temporary outputs (no error if it already exists).
os.makedirs(name='tmp', exist_ok=True)

# FreeSurfer subjects root; presumably read by FreeSurfer tools such as
# bbregister when they are invoked with a bare subject name.
os.environ.update(SUBJECTS_DIR='data/derivatives/freesurfer')


def read_lta_affine(path_lta):
    """Read the 4x4 transform matrix stored in a FreeSurfer LTA file.

    The matrix is taken from the four lines that follow the first line
    starting with ``1 4 4`` (one transform, 4 rows, 4 columns).

    Parameters
    ----------
    path_lta : str or path-like
        Path to the ``.lta`` file.

    Returns
    -------
    numpy.ndarray
        The (4, 4) float matrix exactly as written in the file. The LTA
        ``type`` field (voxel-to-voxel vs. RAS-to-RAS) is NOT inspected, so
        the caller must know which convention the file uses.

    Raises
    ------
    ValueError
        If no ``1 4 4`` header line is found, or the four lines after it do
        not each contain exactly four numbers.
    """
    with open(path_lta, 'r') as f:
        lines = f.readlines()

    for i, line in enumerate(lines):
        if line.lstrip().startswith('1 4 4'):
            start_index = i + 1
            break
    else:
        # Without this check the slice below would fail with an opaque
        # TypeError (None + 4).
        raise ValueError(f"no '1 4 4' matrix header found in {path_lta}")

    rows = [line.split() for line in lines[start_index:start_index + 4]]
    if len(rows) != 4 or any(len(row) != 4 for row in rows):
        raise ValueError(
            f"expected a 4x4 matrix after the '1 4 4' header in {path_lta}")
    return np.array([[float(num) for num in row] for row in rows])

def coreg_img_to(img, reg_affine, target, interpolation='linear'):
    """Bring ``img`` onto the voxel grid of ``target`` using a registration matrix.

    The voxel data of ``img`` is wrapped in a new NIfTI image whose affine is
    ``target.affine @ reg_affine``. That image is then resampled onto
    ``target``'s grid with ``nilearn.image.resample_to_img``.

    Parameters
    ----------
    img : nibabel image
        Moving image. Only its voxel data is used; ``img.affine`` is ignored.
    reg_affine : numpy.ndarray
        4x4 registration matrix, e.g. as returned by ``read_lta_affine``.
    target : nibabel image
        Image whose affine and grid define the output space.
    interpolation : str
        Forwarded to ``nilearn.image.resample_to_img``.

    Returns
    -------
    The image resampled onto ``target``'s grid.

    NOTE(review): the composition ``target.affine @ reg_affine`` treats
    ``reg_affine`` as mapping voxel indices of ``img`` to voxel indices of
    ``target`` and discards ``img.affine``. That is only right for a
    voxel-to-voxel LTA, or when both images share identical conformed
    geometry. TODO confirm against the LTA ``type`` field.
    """
    moved_affine = target.affine @ reg_affine
    relabelled = nib.Nifti1Image(img.get_fdata(), moved_affine)
    return nilearn.image.resample_to_img(
        relabelled,
        target,
        interpolation=interpolation,
    )

def _crop_around(img, center, margin):
    """Crop ``img`` to at most 2*margin voxels per axis around a world point.

    ``center`` is a world (RAS, mm) coordinate. It is mapped to the nearest
    voxel through the inverse of ``img.affine``, and the crop is clipped to
    the image bounds. The returned image's affine keeps every cropped voxel at
    the same world position it had in ``img``.

    Raises
    ------
    ValueError
        If the clipped crop is empty, e.g. when ``center`` lies outside the image.
    """
    ras2vox = np.linalg.inv(img.affine)
    center_h = np.array([center[0], center[1], center[2], 1.0])
    center_vox = np.round(ras2vox @ center_h).astype(int)[:3]

    shape = np.array(img.shape[:3])
    lo = np.maximum(center_vox - margin, 0)
    hi = np.minimum(center_vox + margin, shape)
    if np.any(hi <= lo):
        raise ValueError(
            f'crop around {tuple(center)} with margin {margin} lies outside the '
            f'image (voxel bounds {lo.tolist()}..{hi.tolist()}, shape {shape.tolist()})')

    data = img.get_fdata()[lo[0]:hi[0], lo[1]:hi[1], lo[2]:hi[2]]
    cropped_affine = img.affine.copy()
    # The new origin is the world position of voxel `lo`. The full 3x3 part of
    # the affine is used, so this is correct for rotated or axis-permuted
    # affines, not only for diagonal ones.
    cropped_affine[:3, 3] = img.affine[:3, :3] @ lo + img.affine[:3, 3]
    return nib.Nifti1Image(data, cropped_affine)

def _reset_to_voxel_grid(img):
    """Reorient ``img`` to closest canonical and give it a pure-scaling affine.

    The translation and any residual rotation are dropped. The aim is for
    nilearn to plot the voxel grid directly instead of resampling the image
    into another space.
    """
    canonical = nib.as_closest_canonical(img)
    reset_affine = np.eye(4)
    # Voxel sizes are the column norms of the *reoriented* affine, so they
    # follow any axis permutation done by as_closest_canonical. This matters
    # for anisotropic voxels.
    reset_affine[:3, :3] = np.diag(np.linalg.norm(canonical.affine[:3, :3], axis=0))
    return nib.Nifti1Image(canonical.get_fdata(), reset_affine)

def crop_and_plot(img,
                  center,
                  margin,
                  vmin=None,
                  vmax=None,
                  display_mode='y'):
    """Plot one grayscale slice of ``img`` cropped around a world point.

    Parameters
    ----------
    img : nibabel image
        3D image to display.
    center : sequence of 3 floats
        World (RAS, mm) coordinate to crop around.
    margin : int
        Half-width of the crop in voxels. The crop is clipped to the image.
    vmin, vmax : float, optional
        Display range forwarded to ``nilearn.plotting.plot_anat``.
    display_mode : {'x', 'y', 'z'}
        Slice orientation. The slice passes through the center of the crop.

    Returns
    -------
    The display object returned by ``nilearn.plotting.plot_anat``.

    Raises
    ------
    ValueError
        If ``display_mode`` is not 'x', 'y' or 'z', or the crop is empty.
    """
    axis_of_mode = {'x': 0, 'y': 1, 'z': 2}
    if display_mode not in axis_of_mode:
        raise ValueError(
            f"display_mode must be 'x', 'y' or 'z', got {display_mode!r}")

    img = _reset_to_voxel_grid(_crop_around(img, center, margin))

    # World coordinate of the central voxel of the cropped, reset image. The
    # cut is placed through it along the requested axis.
    center_vox = np.array([img.shape[0] // 2, img.shape[1] // 2, img.shape[2] // 2, 1])
    center_world = img.affine @ center_vox
    cut_coords = [center_world[axis_of_mode[display_mode]]]

    fig = plt.figure(figsize=(3, 3), layout='tight')
    orthoslicer = nilearn.plotting.plot_anat(img,
                                             display_mode=display_mode,
                                             cut_coords=cut_coords,
                                             cmap='gray',
                                             radiological=True,
                                             annotate=False,
                                             vmin=vmin,
                                             vmax=vmax,
                                             figure=fig)
    # Remove all frames so only the image is visible.
    orthoslicer.frame_axes.set_frame_on(False)
    for ax in orthoslicer.axes.values():
        ax.ax.set_frame_on(False)
    return orthoslicer

def add_colorbar_to_figure(orthoslicer, ticks = None, label='', format='%.1f'):
    """Draw a small grayscale colorbar in the lower-right corner of a slice figure.

    Parameters
    ----------
    orthoslicer : nilearn display
        Display object, e.g. as returned by ``crop_and_plot``.
    ticks : sequence of float, optional
        Tick positions. The colorbar spans ``min(ticks)`` to ``max(ticks)``.
        Defaults to the color limits of the first image in the display's
        first slice axes.
    label : str
        Text drawn to the left of the colorbar, e.g. a unit such as ``'[s]'``.
    format : str
        printf-style format for the tick labels.

    Returns
    -------
    matplotlib.colorbar.Colorbar
        The created colorbar.
    """
    fig = orthoslicer.frame_axes.figure
    slice_ax = next(iter(orthoslicer.axes.values())).ax

    # Semi-transparent white box that holds the label and the colorbar.
    cax = fig.add_axes([0.7725, 0.025, 0.2, 0.3],
                       facecolor=[1.0, 1.0, 1.0, 0.7],
                       xticks=[],
                       yticks=[])
    cax.spines[:].set_visible(False)

    # Narrow axes inside that box for the colorbar itself.
    cax2 = fig.add_axes([0.925, 0.05, 0.025, 0.25],
                        xticks=[],
                        yticks=[])
    cax2.spines[:].set_visible(False)

    if ticks is None:
        ticks = list(slice_ax.images[0].get_clim())
    # Span all ticks, not just the first two entries.
    norm = matplotlib.colors.Normalize(vmin=min(ticks), vmax=max(ticks))
    # Draw on the display's own figure and axes rather than on pyplot's
    # "current" figure/axes, which may belong to another plot.
    cbar = fig.colorbar(plt.cm.ScalarMappable(cmap="gray", norm=norm),
                        ticks=ticks,
                        cax=cax2,
                        orientation='vertical')
    # Put the ticks on the left side of the bar.
    cbar.ax.yaxis.set_ticks_position('left')
    cbar.ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter(format))
    # Place the label inside the box, to the left of the colorbar.
    cax.text(0.7, 0.5,
             label,
             rotation=0,
             va='center',
             ha='right',
             transform=cax.transAxes)
    return cbar
    


In [None]:
def plot_all_images_for_subject(subj_id, center, margin=40, display_mode='y'):
    """Plot cropped 3T and 9.4T contrasts of one subject around a common point.

    The 3T images are resampled onto the 9.4T UNIT1 grid using the
    ``3T_to_94T.lta`` registration. If that file is missing it is first
    computed with bbregister. Each contrast is then shown as its own small
    figure via ``crop_and_plot``, and some figures get a colorbar.

    Parameters
    ----------
    subj_id : str
        Subject label without the ``sub-`` prefix, e.g. ``'P017'``.
    center : sequence of 3 floats
        World (RAS, mm) coordinate to crop around, in the space of the 9.4T
        images (all 3T images are resampled onto that grid first).
    margin : int
        Half-width of the crop in voxels.
    display_mode : {'x', 'y', 'z'}
        Slice orientation, forwarded to ``crop_and_plot``.
    """
    subjects_dir = 'data/derivatives/freesurfer'
    subj_3T = f'sub-{subj_id}_3T'
    subj_94T = f'sub-{subj_id}_94T'
    new_3T = os.path.join(subjects_dir, subj_3T, 'new')
    new_94T = os.path.join(subjects_dir, subj_94T, 'new')

    # Compute the 3T -> 9.4T registration once and cache it as an LTA file.
    path_coreg_3T_to_94T = os.path.join(new_3T, '3T_to_94T.lta')
    if not os.path.exists(path_coreg_3T_to_94T):
        path_mov = os.path.join(subjects_dir, subj_3T, 'mri', 'orig.mgz')
        cmd = (f'bbregister --s {subj_94T} --mov {path_mov} '
               f'--reg {path_coreg_3T_to_94T} --t1 --12')
        util.bash_run(cmd)

    affine_3T_to_94T = read_lta_affine(path_coreg_3T_to_94T)

    unit1_94T = nib.load(os.path.join(new_94T, 'UNIT1_conform_orig.mgz'))
    t2star_94T = nib.load(os.path.join(new_94T, 'T2star_to_orig.nii.gz'))
    qsm_94T = nib.load(os.path.join(new_94T, 'QSMTke3_to_orig.nii.gz'))
    greecho1_94T = nib.load(os.path.join(new_94T, 'GREecho1_offline_N4_to_orig.nii.gz'))
    t1map_94T = nib.load(os.path.join(new_94T, 'T1map_conform_orig.nii.gz'))

    def load_3T(filename):
        """Load a 3T image and resample it onto the 9.4T UNIT1 grid."""
        return coreg_img_to(nib.load(os.path.join(new_3T, filename)),
                            affine_3T_to_94T, unit1_94T)

    unit1_3T = load_3T('UNIT1_conform_orig.mgz')
    flair_3T = load_3T('FLAIR_N4_to_orig.nii.gz')
    t1w_3T = load_3T('T1w_N4_to_orig.nii.gz')
    t1map_3T = load_3T('T1map_conform_orig.nii.gz')
    # Scale by 1/1000 so the 3T T1 map matches its '[s]' colorbar
    # (presumably stored in ms).
    t1map_3T = nib.Nifti1Image(t1map_3T.get_fdata() / 1000., t1map_3T.affine, t1map_3T.header)

    # Display order: (image, vmin, vmax, colorbar label or None, tick format).
    panels = [
        (unit1_3T, 0, 4095, None, None),
        (t1map_3T, 0.3, 2.2, '[s]', '%.1f'),
        (t1w_3T, 0, 483, None, None),
        (flair_3T, 0, 191, None, None),
        (unit1_94T, 0, 1.0, None, None),
        (t1map_94T, 0.5, 3.0, '[s]', '%.1f'),
        (greecho1_94T, 0, 2e-8, None, None),
        (t2star_94T, 12, 77., '[ms]', '%.1f'),
        (qsm_94T, -0.05, 0.05, '[ppm]', '%.2f'),
    ]
    for image, vmin, vmax, unit, fmt in panels:
        plot = crop_and_plot(image, center, margin=margin, vmin=vmin, vmax=vmax,
                             display_mode=display_mode)
        if unit is not None:
            add_colorbar_to_figure(plot, ticks=[vmin, vmax], label=unit, format=fmt)


In [None]:
# P017: default 'y' slices with a 40-voxel crop margin around this RAS point (mm).
subj_id, center = 'P017', (18.23, 54.17, 20.78)
plot_all_images_for_subject(subj_id, center=center)

In [None]:
# P026: default 'y' slices with a 40-voxel crop margin around this RAS point (mm).
subj_id, center = 'P026', (-21.43, 52.40, 23.38)
plot_all_images_for_subject(subj_id, center=center)

In [None]:
# P009: wider 80-voxel crop margin, default 'y' slices.
subj_id, center = 'P009', (0.00, -8.23, 5.0)
plot_all_images_for_subject(subj_id, center=center, margin=80)

In [None]:
# P016, first location: 'z' slices with a 60-voxel crop margin.
subj_id, center = 'P016', (34.93, 20.47, -15.94)
plot_all_images_for_subject(subj_id, center=center, margin=60, display_mode='z')

In [None]:
# P016, second location: 'z' slices with a 60-voxel crop margin.
subj_id, center = 'P016', (-24.48, -24.95, -14.34)
plot_all_images_for_subject(subj_id, center=center, margin=60, display_mode='z')