In [None]:
import os
import ants
import numpy as np
import os.path as op
import matplotlib.pyplot as plt
from skimage.measure import find_contours

In [None]:
# Global matplotlib styling for every figure in this notebook:
# plain mathtext rendering (no external LaTeX), Helvetica text at 14 pt.
# NOTE(review): if Helvetica is not installed, matplotlib logs a findfont warning
# and falls back to its default sans-serif font.
plt.rcParams["text.usetex"] = False
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["font.size"] = 14

In [None]:
# Directory the figure02_* NIfTI volumes are read from (placeholder — edit for your system).
paths_data = op.join("/path", "to", "data")

# Directory the rendered SVG is written to; created up front so savefig never hits a missing dir.
# NOTE(review): unlike paths_data, this placeholder is relative ("paths/..."), so it resolves
# against the kernel's working directory — confirm that asymmetry is intended.
paths_save = op.join("paths", "to", "figure02")
os.makedirs(paths_save, mode = 0o777, exist_ok = True)  # mode shown is makedirs' default

In [None]:
def _read_on_grid(filename, target, interp_type):
    """Read `filename` from `paths_data`, resample it onto `target`'s grid, return a NumPy array.

    `interp_type` is forwarded to `ants.resample_image_to_target`; this notebook uses
    "linear" for intensity images and "nearestNeighbor" for label/mask images so that
    resampling only ever copies existing voxel values.
    """
    image = ants.image_read(op.join(paths_data, filename))
    return ants.resample_image_to_target(image, target, interp_type = interp_type).numpy()


# fODF r^2 ("corr2") maps for the three models. "Original" also defines the reference grid
# onto which every anatomical volume below is resampled.
corr2_images = {
  "Original": ants.image_read(op.join(paths_data, "figure02_Multi-Shell_desc-original_corr2.nii.gz")),
  "FWE": ants.image_read(op.join(paths_data, "figure02_Multi-Shell_desc-fwe_corr2.nii.gz")),
  "MSMT": ants.image_read(op.join(paths_data, "figure02_Multi-Shell_desc-MSMT_corr2.nii.gz"))
}

flair_image = _read_on_grid("figure02_space-ACPC_desc-preproc_FLAIR.nii.gz",
                            corr2_images["Original"], "linear")

# Binarize WMH: any label > 0 becomes 1.0, collapsing distinct ROIs into one mask.
wmh_image = _read_on_grid("figure02_space-ACPC_desc-WMH_desc-clean_dseg.nii.gz",
                          corr2_images["Original"], "nearestNeighbor")
wmh_image = (wmh_image > 0) * 1.0

mask_image = _read_on_grid("figure02_space-ACPC_desc-brain_mask.nii.gz",
                           corr2_images["Original"], "nearestNeighbor")

# Segmentation-based overlay mask: 1.0 for every voxel that is neither background (0) nor one
# of the excluded labels, 0.0 otherwise.
# NOTE(review): assuming a standard FreeSurfer aseg LUT, 3/42 are L/R cerebral cortex, 4/43 are
# L/R lateral ventricles and 24 is CSF. Every other label (not only white matter) is kept —
# confirm this matches the intended "white matter" mask.
dseg_image = _read_on_grid("figure02_space-ACPC_desc-aseg_dseg.nii.gz",
                           corr2_images["Original"], "nearestNeighbor")
dseg_values = np.isin(dseg_image, (0, 3, 4, 24, 42, 43))  # True = background or excluded label
dseg_image = np.logical_not(dseg_values) * 1.0

# Percent change of FWE relative to Original. Where Original is exactly 0 the ratio is undefined:
# store NaN there explicitly instead of inf/NaN plus a divide-by-zero RuntimeWarning, so any
# downstream statistic (e.g. np.nanmean) is not poisoned by infinities.
corr2_original = corr2_images["Original"].numpy()
corr2_fwe = corr2_images["FWE"].numpy()
with np.errstate(divide = "ignore", invalid = "ignore"):
    diff_image = (corr2_fwe - corr2_original) / corr2_original * 100
diff_image[corr2_original == 0] = np.nan

In [None]:
# Axial slice to display: two slices above the volume's mid-plane along axis 2.
# NOTE(review): the +2 offset looks hand-picked for this subject — confirm.
z_slice   = (diff_image.shape[2] // 2) + 2
cbar_lim  = 200                 # symmetric colour limit of the percent-difference map (±cbar_lim %)
flair_vmin, flair_vmax = -6, 6  # FLAIR grey-level window
# NOTE(review): a ±6 window suggests the FLAIR is intensity-normalised upstream — confirm.

# Extract the 2D slices. The masks are used before rotation, so they stay unrotated.
plot_wmh   = wmh_image[..., z_slice].copy()
plot_dseg  = dseg_image[..., z_slice].copy()
plot_mask  = mask_image[..., z_slice].copy()
plot_flair = flair_image[..., z_slice].copy()
plot_diff  = diff_image[..., z_slice].copy()

# NaN voxels are masked by imshow and render transparent: hide FLAIR outside the brain mask and
# the difference map outside dseg_image, then rotate 90 degrees clockwise for display.
plot_flair[plot_mask == 0] = np.nan
plot_flair = np.rot90(plot_flair, -1)

plot_diff[plot_dseg == 0] = np.nan
plot_diff = np.rot90(plot_diff, -1)

# WMH outlines: iso-contours at 0.5 of the binary (rotated) WMH slice, as (row, col) arrays.
plot_wmh = np.rot90(plot_wmh, -1)
plot_wmh = find_contours(plot_wmh, 0.5)

fig, ax = plt.subplots(1, 1, figsize = (8, 8), tight_layout = True)
ax.imshow(plot_flair, vmin = flair_vmin, vmax = flair_vmax, cmap = "gray")
h = ax.imshow(plot_diff, vmin = -cbar_lim, vmax = cbar_lim,
              cmap = "RdBu_r", alpha = 1.0)
for contour in plot_wmh:  # each contour is (N, 2) in (row, col) order -> plot as (x=col, y=row)
    ax.plot(contour[:, 1], contour[:, 0], color = "yellow", linewidth = 2)
ax.set_xticks([]); ax.set_yticks([])
cbar = fig.colorbar(h, ax = ax, label = "fODF $r^{2}$ Percent Difference\n(FWE - Original)")
cbar_ticks = np.linspace(-cbar_lim, cbar_lim, 7)
cbar_ticks_str = [f"{x:.0f}%" for x in cbar_ticks]
cbar.set_ticks(cbar_ticks); cbar.set_ticklabels(cbar_ticks_str)

# Save BEFORE showing: the inline backend's plt.show() also closes all figures, so writing the
# file first keeps the saved SVG independent of backend show/close behaviour.
save_name = "figure02_FWE-Original_corr2.svg"
fig.savefig(op.join(paths_save, save_name))
print(f"Saved: {save_name}")
plt.show()