In [None]:
import os
import ants
import numpy as np
import os.path as op
import nibabel as nib

from fury import actor, window
from dipy.data import get_sphere

In [None]:
# Subject to render (placeholder -- replace with a real participant label).
participant = "sub-XXX"

# Zoom window taken from the figure03a image slice: a 20 x 20 in-plane patch
# on a single z slice. Each upper bound is derived from its start and width.
x_min, nx = 14, 20
x_max = x_min + nx
y_min, ny = 50, 20
y_max = y_min + ny
z_slice, nz = 49, 1  # one slice; nz is the z depth used when reshaping later

# Input / output locations (placeholders).
# NOTE(review): paths_data is absolute ("/path/...") but paths_save is relative
# ("paths/..."), so figures are written under the current working directory -- confirm.
paths_data = op.join("/path", "to", "data")
paths_save = op.join("paths", "to", "figure03")
os.makedirs(paths_save, exist_ok=True)

In [None]:
# One ODF volume per reconstruction method. The insertion order of this dict
# is the order in which the methods are loaded and rendered further down.
odf_fnames = dict(
  Original = op.join(paths_data, f"{participant}_multi-shell_desc-original_model-CSD_odf.nii.gz"),
  FWE      = op.join(paths_data, f"{participant}_multi-shell_desc-fwe_model-CSD_odf.nii.gz"),
  MSMT     = op.join(paths_data, f"{participant}_multi-shell_model-MSMT_odf.nii.gz"),
)

# Anatomical and label volumes in ACPC space.
# NOTE(review): wmh_fname is not referenced in the cells shown here, and its name
# carries two "desc-" entities -- confirm it matches the file on disk.
mask_fname, flair_fname, dseg_fname, wmh_fname = (
  op.join(paths_data, f"{participant}_space-ACPC_desc-brain_mask.nii.gz"),
  op.join(paths_data, f"{participant}_space-ACPC_desc-preproc_FLAIR.nii.gz"),
  op.join(paths_data, f"{participant}_space-ACPC_desc-aseg_dseg.nii.gz"),
  op.join(paths_data, f"{participant}_space-ACPC_desc-WMH_desc-clean_dseg.nii.gz"),
)

In [None]:
# Open the ODF images. nib.load only reads the header; voxel data is read on
# demand later (the rendering step needs just the zoomed patch).
odf_images = {method: nib.load(fname) for method, fname in odf_fnames.items()}

# Brain mask image: used here only as the reference grid (diffusion space) for resampling.
mask_image  = ants.image_read(mask_fname)

# Anatomical images, to be resampled onto the mask's grid.
flair_image = ants.image_read(flair_fname)
dseg_image  = ants.image_read(dseg_fname)

# Nearest-neighbour is required for the label image so labels stay integral.
# NOTE(review): for the FLAIR intensities "linear" would give a smoother
# background; kept as nearest-neighbour so the figure is unchanged.
flair_image = ants.resample_image_to_target(flair_image, mask_image, interp_type = "nearestNeighbor")
dseg_image  = ants.resample_image_to_target(dseg_image, mask_image, interp_type = "nearestNeighbor")

# Approximate white-matter mask: keep every voxel whose label is NOT one of
# excluded_labels. Assuming FreeSurfer aseg labels (file is "desc-aseg"), these are
# background (0), cerebral cortex (3, 42), lateral ventricles (4, 43) and CSF (24).
# NOTE(review): other non-WM structures (e.g. deep grey matter) are kept -- confirm intent.
excluded_labels = (0, 3, 4, 24, 42, 43)
dseg_array  = dseg_image.numpy()  # convert once: each .numpy() call copies the whole volume
dseg_values = np.isin(dseg_array, excluded_labels)  # True where the voxel is excluded
dseg_image  = np.logical_not(dseg_values) * 1.0  # 1.0 = kept (WM proxy), 0.0 = excluded
dseg_zoom   = dseg_image[x_min:x_max, y_min:y_max, z_slice]  # (nx, ny) patch on the figure slice

In [None]:
# FLAIR background for the zoomed patch: a single z slice, given a trailing
# singleton axis so the slicer receives a 3D (nx, ny, nz) volume.
flair_value = flair_image.numpy()[x_min:x_max, y_min:y_max, z_slice]
flair_value = flair_value.reshape((nx, ny, nz))

flair_actor = actor.slicer(
  data = flair_value,
  value_range = (-5, 5)  # NOTE(review): fixed display window; assumes normalised FLAIR intensities -- confirm
)

# Sampling sphere for the ODF glyphs. It is loop-invariant, so it is built once.
# Passed by keyword because newer DIPY versions make `name` keyword-only.
sphere = get_sphere(name = "symmetric362")

# Render one figure per ODF method, each over the same FLAIR background.
scene = window.Scene()
for method, odf_image in odf_images.items():
  # Read only the zoomed patch from disk, as a fresh float64 copy.
  # get_fdata()[...] would load and cache the full 4D volume and return a view
  # into that cache, so the in-place masking below would silently modify the
  # cached data held by odf_images.
  image = np.array(
    odf_image.dataobj[x_min:x_max, y_min:y_max, z_slice],
    dtype = np.float64,
  )
  image[dseg_zoom == 0] = 0  # remove non-WM voxels
  image = image.reshape((nx, ny, nz, image.shape[-1]))

  # NOTE(review): without a B_matrix, odf_slicer expects the last axis to hold
  # amplitudes sampled on `sphere`'s vertices. If these files store SH
  # coefficients, B_matrix must be supplied -- confirm.
  fodf_actor = actor.odf_slicer(
    odfs     = image,
    sphere   = sphere,
    scale    = 0.8,
    colormap = None # rgb
  )

  scene.add(flair_actor) # add flair background
  scene.add(fodf_actor)  # add current fodf
  scene.background((1, 1, 1))
  save_name = f"figure03_method-{method}_fODFs.png"
  window.record(
    scene    = scene,
    out_path = op.join(paths_save, save_name),
    size     = (2400, 2400)
  )
  print(f"Saved: {save_name}")
  scene.clear()  # drop this method's actors before rendering the next one