In [None]:
import os
import ants
import numpy as np
import os.path as op
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from skimage.measure import find_contours

In [None]:
# Global matplotlib text settings applied to every figure drawn after this cell.
plt.rcParams["text.usetex"] = False        # use matplotlib's own text rendering, not LaTeX
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["font.size"] = 14

In [None]:
# Subject identifier; it is used as the prefix of every input filename.
participant = "sub-XXX"

# Folder the input images are read from.
paths_data = os.path.join("/path", "to", "data")

# Folder the figure is written to; create it up front (no error if it already exists).
paths_save = os.path.join("paths", "to", "figure03")
os.makedirs(name = paths_save, exist_ok = True)

In [None]:
# Full paths of the three inputs (brain mask, preprocessed FLAIR, cleaned WMH
# segmentation), all located in `paths_data` and prefixed with the participant ID.
mask_fname, flair_fname, wmh_fname = (
  op.join(paths_data, basename)
  for basename in (
    f"{participant}_space-ACPC_desc-brain_mask.nii.gz",
    f"{participant}_space-ACPC_desc-preproc_FLAIR.nii.gz",
    f"{participant}_space-ACPC_desc-WMH_desc-clean_dseg.nii.gz",
  )
)

In [None]:
# Brain mask: its voxel grid is the reference every other image is resampled onto.
# NOTE(review): earlier notes call this grid "diffusion space", while the filename
# says space-ACPC -- confirm which description is correct.
mask_image = ants.image_read(mask_fname)

# FLAIR: read and resample onto the mask grid in one step.
# NOTE(review): nearest-neighbour interpolation is kept as found; for a continuous
# intensity image "linear" is the more common choice -- confirm this is intentional.
flair_image = ants.resample_image_to_target(
  ants.image_read(flair_fname), mask_image, interp_type = "nearestNeighbor"
)

# WMH segmentation: read and resample with a label-preserving interpolator.
wmh_image = ants.resample_image_to_target(
  ants.image_read(wmh_fname), mask_image, interp_type = "genericLabel"
)

# Binary WMH mask stored as floats: 1.0 wherever the segmentation is positive, else 0.0.
wmh_mask = (wmh_image > 0) * 1.0

In [None]:
# Index of the 2D slice (along the third image axis) shown in the figure.
z_slice = 49

# Convert to NumPy first and index afterwards: this gives the same 2D slice no
# matter whether indexing an ANTsImage returns an ndarray or an ANTsImage in the
# installed ANTsPy version (calling .numpy() on an ndarray would fail).
plot_mask  = mask_image.numpy()[:, :, z_slice]
plot_wmh   = wmh_mask.numpy()[:, :, z_slice]
plot_flair = flair_image.numpy()[:, :, z_slice]

# Set everything outside the brain mask to NaN so only brain voxels carry intensities.
plot_flair[plot_mask == 0] = np.nan

# Rotate once by 90 degrees clockwise for display and reuse the rotated array below.
flair_display = np.rot90(plot_flair, -1)

# Outline the WMH mask at the 0.5 level, using the same rotation as the FLAIR slice.
plot_wmh = find_contours(np.rot90(plot_wmh, -1), 0.5)

# Highlight box. The x bounds are mirrored against the displayed image width
# (nx - value), so x_min ends up larger than x_max and the rectangle width below
# is negative; matplotlib still draws the same box. The values are kept exactly as
# they were because other cells may read these names.
ny, nx = flair_display.shape
x_min = nx - 14
x_max = nx - 34
y_min = 50; y_max = 70

fig, ax = plt.subplots(1, 1, figsize = (8, 8), tight_layout = True)
# NOTE(review): the symmetric [-6, 6] window suggests normalised intensities -- confirm.
ax.imshow(flair_display, vmin = -6, vmax = 6, cmap = "gray")
for contour in plot_wmh: # contours are (row, col) pairs: column -> x, row -> y
  ax.plot(contour[:, 1], contour[:, 0], color = "yellow", linewidth = 2)
rect = Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                 edgecolor = "red", facecolor = "none", linewidth = 2)
ax.add_patch(rect)
ax.set_xticks([]); ax.set_yticks([])

# Write the file before showing the figure, so saving never depends on what the
# active backend does with the figure once it has been displayed.
fig.savefig(op.join(paths_save, "figure03_desc-zoom_flair.svg"))
plt.show()