In [None]:
import os
import re
import numpy as np
import pandas as pd
import os.path as op
import nibabel as nib

from fury import actor, colormap, window
from dipy.io.streamline import load_tractogram

In [None]:
def lines_as_tubes(sl, linewidth, **kwargs):
  """Build a fury line actor whose lines are rendered as tubes.

  Parameters
  ----------
  sl : streamlines passed straight through to ``actor.line``.
  linewidth : width handed to the actor property's ``SetLineWidth``.
  **kwargs : forwarded unchanged to ``actor.line`` (e.g. ``colors``).

  Returns
  -------
  The line actor, with tube rendering enabled and the requested width.
  """
  tube_actor = actor.line(sl, **kwargs)
  # Configure the actor's property once rather than fetching it per setter.
  props = tube_actor.GetProperty()
  props.SetRenderLinesAsTubes(1)
  props.SetLineWidth(linewidth)
  return tube_actor

In [None]:
# Placeholder roots: point these at the local input-data and figure-output folders.
# NOTE(review): paths_save has no leading "/" (and reads "paths" rather than "path"),
# so unlike paths_data it resolves relative to the working directory — confirm intended.
paths_data = os.path.join("/path", "to", "data")
paths_save = os.path.join("paths", "to", "figure09")
os.makedirs(name = paths_save, exist_ok = True)  # output root must exist before per-participant folders

In [None]:
# Render each method's tractogram for a few example participants, coloring every
# streamline by that method's DKI-FA tract profile and overlaying the glass-brain
# outline and the white-matter-hyperintensity (WMH) mask. One PNG per method.
participant_list = [
  "sub-XXX", # early reconstruction
  "sub-XXX", # incomplete reconstruction
  "sub-XXX", # failed reconstruction
]
tract = "RightArcuate"

# Method label used in tractogram file names -> method label used in the profiles CSV.
METHOD_LABELS = {
  "FWE":      "afq-fwe",
  "MSMT":     "afq-msmt",
  "Original": "afq-original",
}

# Camera presets keyed by the tract-name prefix (hemisphere or callosal).
CAMERA_PRESETS = {
  "Left": {
    "position": (-505.34, -99.47, 34.40),
    "focal_point": (0.50, -16.50, 6.50),
    "view_up": (0.07, -0.07, 1.00)
  },
  "Right": {
    "position": (513.47, -19.48, -13.26),
    "focal_point": (0.50, -16.50, 6.50),
    "view_up": (0.04, 0.06, 1.00)
  },
  "Callosum": {
    "position": (8.27, 27.84, 516.38),
    "focal_point": (-0.44, -3.52, 5.69),
    "view_up": (-0.02, 1.00, -0.06)
  },
}


def _camera_kwargs_for(tract_name):
  """Return the ``Scene.set_camera`` keyword arguments for ``tract_name``.

  The preset is picked from the tract name's prefix (Left / Right / Callosum).

  Raises
  ------
  ValueError
    If the tract name starts with none of the known prefixes. Previously this
    left ``camera_kwargs`` unbound and failed later with a NameError.
  """
  prefix = re.match("(Left|Right|Callosum)", tract_name)
  if prefix is None:
    raise ValueError(
      f"No camera preset for tract {tract_name!r}; expected a Left/Right/Callosum prefix"
    )
  return CAMERA_PRESETS[prefix.group(1)]


# The camera depends only on the tract, so resolve it once instead of per method.
camera_kwargs = _camera_kwargs_for(tract)

for participant in participant_list: # for each participant
  data_dir = op.join(paths_data, participant)
  save_dir = op.join(paths_save, participant)
  os.makedirs(save_dir, exist_ok = True)

  # glass brain mask: drawn as an outline, and used as the reference image for the tractograms
  glass_image = nib.load(op.join(data_dir, f"{participant}_space-ACPC_desc-glass_mask.nii.gz"))

  # one tractogram per method; the file name is built directly so that an "XXX"
  # inside the participant ID is never substituted by accident
  trks = {
    method: load_tractogram(
      op.join(data_dir, f"{participant}_multi-shell_method-{method}_{tract}.trx"),
      glass_image
    )
    for method in METHOD_LABELS
  }
  for curr_trk in trks.values():
    curr_trk.remove_invalid_streamlines()

  # load wmh image and binarize it (any positive label counts as WMH)
  wmh_image = nib.load(op.join(data_dir, f"{participant}_space-ACPC_label-WMH_desc-clean_dseg.nii.gz"))
  wmh_data  = (wmh_image.get_fdata() > 0) * 1.0

  # DKI-FA profiles of this tract from the multi-shell data, for the three methods, in node order
  profiles_csv = pd.read_csv(op.join(data_dir, f"{participant}_profiles.csv"))
  df = profiles_csv[
    (profiles_csv["dataset"] == "multi-shell")
    & profiles_csv["method"].isin(list(METHOD_LABELS.values()))
    & (profiles_csv["tract"] == tract)
    & (profiles_csv["metric"] == "DKI-FA")
  ].sort_values("node", ascending = True)

  # Node k sits at k / max_node on [0, 1] (the same grid as linspace(0, 1, max_node + 1)).
  # Positions are kept per method, so a method with missing nodes no longer mismatches the grid.
  node_max = df["node"].max() if not df.empty else 0
  node_scale = node_max if node_max > 0 else 1
  profiles = {}
  profile_positions = {}
  for method, label in METHOD_LABELS.items():
    method_rows = df[df["method"] == label]
    profiles[method] = method_rows["value"].values
    profile_positions[method] = method_rows["node"].to_numpy() / node_scale

  # glass brain and wmh actors are shared by every method's render
  glass_actor = actor.contour_from_roi(
    data    = glass_image.get_fdata(),
    affine  = glass_image.affine,
    color   = [0, 0, 0],
    opacity = 0.05
  )
  wmh_actor = actor.contour_from_roi(
    data    = wmh_data,
    affine  = wmh_image.affine,
    color   = [0, 1, 1],
    opacity = 0.5
  )

  scene = window.Scene()
  for method, curr_trk in trks.items(): # for each tractogram
    scene.clear()
    scene.add(glass_actor)
    scene.add(wmh_actor)

    curr_profiles = profiles[method]
    curr_positions = profile_positions[method]
    if len(curr_trk.streamlines) > 0 and len(curr_profiles) == 0:
      raise ValueError(f"No DKI-FA profile rows for {participant} / {method} / {tract} in the profiles CSV")

    for sl in curr_trk.streamlines:
      # Resample the profile onto this streamline's points.
      # NOTE(review): assumes streamline points run in the same direction as profile nodes — confirm.
      sl_profile = np.interp(np.linspace(0, 1, sl.shape[0]), curr_positions, curr_profiles)
      # NOTE(review): fury's create_colormap appears to rescale values to their own min/max by
      # default (auto=True), so colors may not share one FA scale across streamlines/methods.
      sl_colors  = colormap.create_colormap(sl_profile, name = "Spectral")
      scene.add(lines_as_tubes([sl], linewidth = 7, colors = sl_colors))

    scene.set_camera(**camera_kwargs)
    scene.background((1, 1, 1)) # white background

    # To tune CAMERA_PRESETS interactively: window.show(scene); scene.camera_info()
    save_name = f"{participant}_method-{method}_{tract}.png"
    window.record(
      scene = scene,
      out_path = op.join(save_dir, save_name),
      size = (2400, 2400)
    )
    print(f"Saved: {save_name}")

  scene.clear()