In [None]:
import os
import matplotlib
import numpy as np
import pandas as pd
import os.path as op
import matplotlib.pyplot as plt

In [None]:
# Global matplotlib defaults for every figure drawn below: plain-text
# rendering (no LaTeX), Helvetica font family, 14 pt base font size.
plt.rcParams["text.usetex"] = False
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["font.size"] = 14

def plot_ribbon(ax, x, y, s, color, alpha = 0.25, label = None):
  """Draw a line with a translucent band of half-width `s` around it.

  Parameters
  ----------
  ax    : object exposing `fill_between` and `plot` (e.g. a matplotlib Axes)
  x     : positions along the horizontal axis
  y     : center-line values
  s     : half-width of the band; the band spans `y - s` to `y + s`
  color : color used for both the band and the line
  alpha : opacity of the band only (the line is drawn without alpha)
  label : legend label, attached to the line only

  Returns
  -------
  None; the artists are added to `ax`.
  """
  band_lower = y - s
  band_upper = y + s
  # band first, so the line is added on top of it
  ax.fill_between(x, band_lower, band_upper,
                  color = color, edgecolor = None, alpha = alpha)
  ax.plot(x, y, color = color, label = label)

In [None]:
# Output directory for the figure-10 panels, created up front so later
# savefig calls have somewhere to write.
# NOTE(review): both locations look like placeholders; confirm real paths.
paths_save = os.path.join("paths", "to", "figure10")
os.makedirs(paths_save, exist_ok = True)

# Directory the input CSV files are read from.
paths_data = os.path.join("/path", "to", "data")

In [None]:
# Fazekas scores: keep the participant id and the total score only, with the
# total exposed under the column name "fazekas".
df_scores = (
  pd.read_csv(op.join(paths_data, "fazekas_scores.csv"))
    .loc[:, ["participant", "total"]]
    .rename(columns = {"total": "fazekas"})
)
df_scores.head()

In [None]:
# Tract profiles, narrowed step by step (same order as the filters read):
#   1. datasets: multi-shell and single-shell
#   2. methods:  afq-original, afq-fwe, afq-msmt
#   3. metrics whose name starts with "DTI-" or "DKI-"
#   4. inner-join with the Fazekas scores on participant
#   5. keep Fazekas > 1 (original note: "bc only two participants")
df_profiles = (
  pd.read_csv(op.join(paths_data, "profiles.csv"))
    .loc[lambda d: d["dataset"].isin(["multi-shell", "single-shell"])]
    .loc[lambda d: d["method"].isin(["afq-original", "afq-fwe", "afq-msmt"])]
    .loc[lambda d: d["metric"].str.match("^(DTI|DKI)-")]
    .merge(df_scores, on = "participant")
    .loc[lambda d: d["fazekas"] > 1]
)
df_profiles.head()

In [None]:
# WMH overlap profiles: restrict to the same datasets/methods as the tract
# profiles and to the "WMH Mask" metric, binarize each row, then average per
# (dataset, method, tract, node). The resulting "value" is the fraction of
# rows in each group whose WMH value exceeded 0.1.
df_wmh = pd.read_csv(op.join(paths_data, "wmh_profiles.csv"))
df_wmh = df_wmh[df_wmh["dataset"].isin(["multi-shell", "single-shell"])]
df_wmh = df_wmh[df_wmh["method"].isin(["afq-original", "afq-fwe", "afq-msmt"])]
df_wmh = df_wmh[df_wmh["metric"] == "WMH Mask"]
# .assign builds a new frame; writing the column straight into the filtered
# frame (df_wmh["value"] = ...) triggers pandas' SettingWithCopyWarning.
# NOTE(review): NaN > 0.1 evaluates to False, so missing values count as
# "no overlap" (0.0) in the mean below; confirm that is intended.
df_wmh = df_wmh.assign(value = (df_wmh["value"] > 0.1).astype(float)) # binarize by percentage
df_wmh = df_wmh.groupby(["dataset", "method", "tract", "node"])["value"].mean().reset_index()
df_wmh.head()

In [None]:
# --- background gradient (WMH overlap) ---------------------------------------
# Overlap values are binned into 255 levels spanning gradient_range and each
# level is mapped to an RGB color of the reversed gray colormap.
gradient_range = [0, 1] # [min max]
gradient_lo = np.min(gradient_range)
gradient_hi = np.max(gradient_range)
gradient_edges = np.linspace(gradient_lo, gradient_hi, 255)

# 255 samples of "gray_r"; the alpha channel is dropped, RGB is kept
gradient_cmap = matplotlib.colormaps["gray_r"](np.linspace(0, 1, num = 255))[..., :3]

# --- one color per Fazekas score ---------------------------------------------
# 20 samples of "Spectral"; scores 2..6 take samples 0, 5, 10, 15 and 19
spectral_samples = matplotlib.colormaps["Spectral"](np.linspace(0, 1, num = 20))
fazekas_cmap = {
  score: spectral_samples[sample_index]
  for score, sample_index in zip([2, 3, 4, 5, 6], [0, 5, 10, 15, 19])
}

# --- grouping and plotting variables -----------------------------------------
# one figure is produced per unique combination of these columns
group_vars      = ["dataset", "method", "metric", "tract"]
# scores drawn in each figure (original note: one person with fazekas score of 1)
fazekas_list    = [2, 3, 4, 5, 6]
# [x-min, x-max, y-min, y-max] handed to imshow for the background image
gradient_extent = [df_wmh["node"].min(), df_wmh["node"].max(),
                   gradient_lo, gradient_hi]

for (dataset, method, metric, tract), df_group in df_profiles.groupby(group_vars): 
  # extract background gradient of wmh overlap
  df_gradient = df_wmh[df_wmh["dataset"] == dataset]
  df_gradient = df_gradient[df_gradient["method"] == method]
  df_gradient = df_gradient[df_gradient["tract"] == tract]

  # prepare background gradient image
  y = df_gradient["value"].values; Y = np.tile(y, (200, 1)) 
  gradient_indices = np.digitize(Y, gradient_edges) - 1 # 0-based index
  gradient_image = gradient_cmap[gradient_indices,:] 
  gradient_image = (gradient_image * 255).astype(np.uint8)

  # define figure aesthetics variables
  match metric:
    case "DKI-AWF": ylim = (0.05, 0.65); dy = 0.10; yticks = np.arange(ylim[0] + (dy/2), ylim[1], dy)
    case "DKI-FA":  ylim = (0.05, 0.75); dy = 0.10; yticks = np.arange(ylim[0] + (dy/2), ylim[1], dy)
    case "DKI-MD":  ylim = (0.00, 3e-3); dy = 6e-4; yticks = np.arange(ylim[0], ylim[1] + dy, dy)
    case "DKI-MK":  ylim = (0.55, 1.35); dy = 0.10; yticks = np.arange(ylim[0] + (dy/2), ylim[1], dy)
    case "DTI-FA":  ylim = (0.05, 0.75); dy = 0.10; yticks = np.arange(ylim[0] + (dy/2), ylim[1], dy)
    case "DKI-MD":  ylim = (0.00, 3e-3); dy = 6e-4; yticks = np.arange(ylim[0], ylim[1] + dy, dy)

  # figure plotting
  fig, ax = plt.subplots(1, 1, figsize = (7, 5), tight_layout = True)
  ax.imshow(gradient_image, aspect = "auto", origin = "lower", extent = gradient_extent)
  for fazekas in fazekas_list: # for each method and fazekas
    df_plot = df_group[df_group["fazekas"] == fazekas]
    df_plot = (df_plot.groupby(["tract", "node"])["value"]
                      .aggregate(
                        mean = "mean", 
                        sem = lambda x: np.std(x) / np.sqrt(np.sum(~np.isnan(x)))
                      ).reset_index())
    x = df_plot["node"].values; y = df_plot["mean"].values; s = df_plot["sem"].values

    # plot fazekas score ribbon plot
    plot_ribbon(ax, x, y, s, fazekas_cmap[fazekas], label = f"Faz. {fazekas}")
  ax.set_xlim([-4.95, 103.95]); ax.set_ylim(ylim); ax.set_yticks(yticks)
  ax.set_xticks([]); ax.set_ylabel(metric)
  ax.set_title(f"{tract} ({method})")
  ax.legend()
  
  paths_out = op.join(paths_save, dataset, method, metric)
  os.makedirs(paths_out, exist_ok = True)
  
  fig.savefig(op.join(paths_out, f"figure10_{dataset}_{method}_{metric}_{tract}.svg"))
  plt.show(); plt.close()