In [None]:
import os
import re
import matplotlib
import numpy as np
import pandas as pd
import os.path as op
import seaborn as sns
import scipy.stats as stats 
import matplotlib.pyplot as plt

from utils import TRACT_DICT

In [None]:
# Global matplotlib text settings shared by every figure in this notebook:
# no LaTeX rendering, Helvetica as the font family, 14 pt base font size.
plt.rcParams["text.usetex"] = False
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["font.size"]   = 14

In [None]:
# Directory the input tables are read from.
data_parts = ("/path", "to", "data")
paths_data = op.join(*data_parts)

# Directory the figure files are written to; created here (parents included)
# so later save calls do not fail on a missing folder.
save_parts = ("paths", "to", "figure04")
paths_save = op.join(*save_parts)
os.makedirs(paths_save, exist_ok = True)

In [None]:
# Load the per-tract dice table and keep only the multi-shell / single-shell rows.
# .copy() makes the filtered frame independent of the full table, so the column
# assignments below write to a real frame instead of a view of a slice.
df = pd.read_csv(op.join(paths_data, "tract_dice.csv"))
df = df[df["dataset"].isin(["multi-shell", "single-shell"])].copy()

# Map raw tract identifiers to display names via TRACT_DICT (defined in utils).
df["tract"] = [TRACT_DICT[x] for x in df["tract"]]

# Split "<Left|Right> <word>" into its two parts. The pattern is a raw string
# so that "\w" reaches the regex engine verbatim instead of being parsed as an
# (invalid) string escape; names without a Left/Right prefix are left unchanged.
hemi_pattern = re.compile(r"(Left|Right) (\w+)")
df["hemisphere"] = [hemi_pattern.sub(r"\1", x) for x in df["tract"]]
df["tract"]      = [hemi_pattern.sub(r"\2", x) for x in df["tract"]]

# Collapse across hemispheres by averaging dice; "hemisphere" is not a grouping
# key, so it is absent from the resulting frame.
df = (df.groupby(["participant", "dataset", "method", "tract"])["dice"]
        .mean().reset_index())
df.head()

In [None]:
# Each output column compares a pair of methods: [reference, comparison].
difference_dict = {
  "original-fwe":  ["afq-original", "afq-fwe"],
  "original-msmt": ["afq-original", "afq-msmt"],
}

# Build one record per (participant, dataset, tract); a difference column is
# only filled in when both methods of the pair are present in that group.
records = []
group_keys = ["participant", "dataset", "tract"]
for (participant, dataset, tract), rows in df.groupby(group_keys):
  record = {"participant": participant, "dataset": dataset, "tract": tract}

  for key, methods in difference_dict.items():
    has_both = df_group_count = rows["method"].isin(methods).sum() == 2
    if not has_both:
      continue # leave this column missing for the current record

    is_ref = rows["method"] == methods[0]
    is_cmp = rows["method"] == methods[1]
    dice_ref = rows.loc[is_ref, "dice"].values[0]
    dice_cmp = rows.loc[is_cmp, "dice"].values[0]

    # symmetric percent difference: change relative to the mean of both values
    midpoint = (dice_ref + dice_cmp) / 2
    record[key] = ((dice_cmp - dice_ref) / midpoint) * 100

  records.append(record)

df_tract = pd.DataFrame(records)
df_tract.head()

In [None]:
# --- statistics -------------------------------------------------------------
# keyword arguments for the one-sample t-test: test against 0, ignore NaNs
stats_kwargs = {"popmean": 0, "nan_policy": "omit"}

alpha = 0.05 # significance level before Bonferroni correction
alpha_corrected = alpha / len(np.unique(df_tract["tract"])) # divided by number of tracts

# --- colours ----------------------------------------------------------------
# sample the 20 entries of the "tab20" colormap into an array of RGBA rows
cmap = matplotlib.colormaps["tab20"](np.linspace(0, 1, num = 20))

diff_colors = {
  "original-fwe":  cmap[6], # red
  "original-msmt": cmap[0], # blue
}

# --- column bookkeeping -----------------------------------------------------
id_cols   = ["participant", "dataset", "tract"]
diff_cols = list(diff_colors)   # difference columns, in plotting order
keep_cols = id_cols + diff_cols

# palette keyed by the display label: "original-" prefix removed, upper-cased
diff_dict = {re.sub("original-", "", col).upper(): rgba
             for col, rgba in diff_colors.items()}

# --- annotation geometry ----------------------------------------------------
x_sig     = 5  # gap between the sem whisker and the %diff significance marker
x_comp    = 10 # gap between the %diff marker and the comparison bar
x_sigcomp = 3  # gap between the comparison bar and its significance marker

y_cond   = 0.2                      # y offset separating the two conditions
y_sigbar = np.array([-0.25, 0.25])  # comparison bar extent around the tract row

xlim = (-55, 140) # x-axis limits

def _sig_marker(pvalue, alpha, alpha_corrected):
  """Return the (marker, y-offset) pair used to annotate a p-value.

  "*" is returned when pvalue < alpha_corrected, "+" when pvalue < alpha, and
  "" otherwise. A NaN p-value fails both comparisons and is therefore reported
  as not significant. The second element is the offset added to the text's y
  position for that glyph (presumably tuned by eye so it looks centred).
  """
  if pvalue < alpha_corrected: return "*", 0.1
  if pvalue < alpha: return "+", -0.015
  return "", 0

for dataset, df_group in df_tract.groupby("dataset"): # one figure per dataset
  # Treat +/-inf percent differences as missing so that the bars, the tests and
  # the marker positions below are all computed from the same set of values.
  df_group = df_group[keep_cols].copy()
  df_group[diff_cols] = df_group[diff_cols].replace([np.inf, -np.inf], np.nan)

  # long format for seaborn: one row per (participant, tract, difference)
  df_plot = df_group.melt(id_vars = id_cols, var_name = "difference", value_name = "dice")
  df_plot["difference"] = df_plot["difference"].str.replace("original-", "")
  df_plot["difference"] = df_plot["difference"].str.upper()
  df_plot = df_plot.dropna(subset = ["dice"])

  # tracts ordered by mean FWE difference, largest first
  trk_order = (df_plot[df_plot["difference"] == "FWE"]
                .groupby("tract")["dice"].mean()
                .sort_values(ascending = False).index.to_list())

  fig, ax = plt.subplots(1, 1, figsize = (8, 10), tight_layout = True)
  ax.axvline(x = 0, color = "black", linestyle = "--")
  sns.barplot(data = df_plot, x = "dice", y = "tract", hue = "difference",
              palette = diff_dict, order = trk_order, errorbar = "se", ax = ax)

  for y_avg, trk in enumerate(trk_order): # y_avg: row position of the tract
    df_trk = df_group[df_group["tract"] == trk] # subset by tract

    # --- each difference metric vs. zero (one-sample t-test) ---
    diff_sig = np.zeros(len(diff_cols), dtype = bool) # which metrics are significant
    for i, diff in enumerate(diff_cols):
      trk_values = df_trk[diff]

      results = stats.ttest_1samp(trk_values, **stats_kwargs)
      sig_str, y_sig = _sig_marker(results.pvalue, alpha, alpha_corrected)
      diff_sig[i] = bool(sig_str)
      if not sig_str: continue # nothing to draw

      x_avg = np.nanmean(trk_values) # bar height
      x_sem = stats.sem(trk_values, nan_policy = "omit") # sem value
      sign  = 1 if x_avg > 0 else -1 # push the marker away from zero
      x_height = x_avg + sign * x_sem + sign * x_sig

      # first metric sits above the row centre, the other below
      y_adj = -y_cond if diff == diff_cols[0] else y_cond
      y_height = y_avg + y_adj + y_sig

      diff_label = diff.replace("original-", "").upper()
      ax.text(x = x_height, y = y_height, s = sig_str,
              color = diff_dict[diff_label], ha = "center", va = "center")

    # --- first vs. second difference metric (paired Wilcoxon test) ---
    first_vals  = df_trk[diff_cols[0]].values # fwe
    second_vals = df_trk[diff_cols[1]].values # msmt

    results = stats.wilcoxon(second_vals, first_vals, nan_policy = "omit")
    sig_str, y_sig = _sig_marker(results.pvalue, alpha, alpha_corrected)
    if not sig_str: continue # no comparison bar for this tract

    x_avg = [np.nanmean(first_vals), np.nanmean(second_vals)] # bar heights
    x_sem = [stats.sem(first_vals, nan_policy = "omit"),
             stats.sem(second_vals, nan_policy = "omit")] # sem values
    x_ext = [m + s if m > 0 else m - s for m, s in zip(x_avg, x_sem)] # whisker tips
    # outermost whisker tip; the side is decided by the sign of the first metric
    x_edge = np.max(x_ext) if x_ext[0] > 0 else np.min(x_ext)

    sign = 1 if x_edge > 0 else -1
    # leave room for a per-metric marker only if one was drawn for this tract
    x_adj1 = sign * x_sig if np.any(diff_sig) else 0
    x_compbar = x_edge + x_adj1 + sign * x_comp

    ax.plot([x_compbar, x_compbar], y_avg + y_sigbar, "black")
    ax.text(x = x_compbar + sign * x_sigcomp, y = y_avg + y_sig, s = sig_str,
            color = "black", ha = "center", va = "center")

  ax.set_xlabel("Weighted Dice Coefficient\nPercent Difference from Original")
  ax.set_ylabel(""); ax.set_xlim(xlim); ax.set_title(dataset.capitalize())
  ax.legend(title = "Method", loc = "lower right")
  ax.margins(y = 0.01) # adjust y-axis margins

  # save first so the file never depends on what the backend does in show(),
  # then release the figure so figures do not accumulate across iterations
  fig.savefig(op.join(paths_save, f"figure04_{dataset}_dice.svg"))
  plt.show()
  plt.close(fig)