In [None]:
import os
import os.path as op
import sys
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from joblib import Parallel, delayed
from scipy.stats import permutation_test, ttest_rel
from statsmodels.stats.multitest import fdrcorrection

sys.path.append(op.abspath(op.join(op.abspath(""), "..")))
from utils.utils import get_contrasts

sns.set_style("ticks")
sns.set_context("talk", font_scale=1.2, rc={"axes.labelpad": 10})
# Show floats with three significant digits in DataFrame displays (set once).
pd.set_option("display.float_format", "{:.3}".format)

# Suppress every FutureWarning to keep cell output readable.
warnings.filterwarnings("ignore", category=FutureWarning)

# Output directories; exist_ok makes re-running this cell a no-op.
WORKING_DIR = op.join(Path.cwd().parent, "experiments/visualization")
fig_dir = op.join(WORKING_DIR, "figures")
os.makedirs(fig_dir, exist_ok=True)

results_data_path = op.join(WORKING_DIR, "results_data")
os.makedirs(results_data_path, exist_ok=True)

In [None]:
# Define helper functions
def load_corr_wrapper(path):
    """Thin wrapper around ``np.load``: read and return the array stored at ``path``."""
    loaded = np.load(path)
    return loaded


def extract_diagonal(mat):
    """
    Return the main diagonal of a 2-D matrix as a new 1-D array.

    Parameters:
    mat (numpy.ndarray or array-like): 2-D matrix.

    Returns:
    numpy.ndarray: ``mat[i, i]`` for each diagonal position, as an independent copy.

    Raises:
    ValueError: if ``mat`` is not 2-D.
    """
    mat = np.asarray(mat)
    # Reject every non-2-D input, not only 3-D (1-D / 4-D used to fail obscurely).
    if mat.ndim != 2:
        raise ValueError("Input matrix must be 2D")
    # np.diag returns a read-only view of a 2-D array; copy so callers own the result.
    return np.diag(mat).copy()


def fingerprinting_score(corr_matrix):
    """
    Computes the fingerprinting score from a correlation matrix.

    Row ``k`` counts as identified when its diagonal entry ``corr_matrix[k, k]``
    is the row's maximum and no other entry in that row equals it. The score is
    the fraction of identified rows, weighted by the mean of the diagonal.

    Parameters:
    corr_matrix (numpy.ndarray): A 2D numpy array containing the correlation matrix.

    Returns:
    float: (fraction of identified rows) * (mean diagonal value).
    """
    mean_diagonal = np.diag(corr_matrix).mean()

    # argmax returns the first maximum, so a tie before the diagonal fails the
    # first test. The explicit count rejects any other tie with the diagonal value.
    identified = [
        np.argmax(row) == k and np.count_nonzero(row == row[k]) == 1
        for k, row in enumerate(corr_matrix)
    ]

    return np.mean(identified) * mean_diagonal


def diagonality_index(corr):
    """
    Per-row diagonality index for a stack of correlation matrices.

    For row ``r`` of ``corr[i]`` with diagonal entry ``d = corr[i, r, r]``, the
    index is ``mean(d - off_diagonal_entries_of_row_r) * d``. It measures how far
    the diagonal stands above the rest of its row, weighted by the diagonal itself.

    Parameters:
    corr (numpy.ndarray): 3-D array ``(n_contrasts, n_rows, n_cols)``. Each row
        must contain its own diagonal entry, i.e. ``n_cols >= n_rows``.

    Returns:
    numpy.ndarray: float array of shape ``(n_contrasts, n_rows)``.

    Raises:
    ValueError: if ``corr`` is not 3-D or has fewer columns than rows.
    """
    corr = np.asarray(corr)
    if corr.ndim != 3:
        raise ValueError("corr must be a 3D array (n_contrasts, n_rows, n_cols)")
    n_contrasts, n_subj, n_cols = corr.shape
    # A row r >= n_cols has no diagonal entry. The element-by-element loop would
    # silently reuse the previous row's diagonal value, so refuse such input.
    if n_cols < n_subj:
        raise ValueError("corr matrices must have at least as many columns as rows")

    diag_index = np.zeros((n_contrasts, n_subj))
    for i in range(n_contrasts):
        mat = corr[i]
        for r in range(n_subj):
            diag_element = mat[r, r]
            # np.delete keeps the off-diagonal entries in their original order and
            # dtype, so np.mean sees exactly the same values as before.
            off_diag_elements = np.delete(mat[r], r)
            diag_index[i, r] = (
                np.mean(diag_element - off_diag_elements) * diag_element
            )
    return diag_index


def cliffs_delta(group1, group2):
    """
    Cliff's Delta effect size between two samples.

    delta = (#{x > y} - #{x < y}) / (len(group1) * len(group2)), counted over
    every pair with x from ``group1`` and y from ``group2``. Ties count for
    neither side.

    Parameters:
    group1, group2 : array-like
        The two samples to compare.

    Returns:
    float : Cliff's Delta in [-1, +1]; positive when ``group1`` tends to be larger.
    """
    first = np.array(group1)
    second = np.array(group2)

    # For 1-D inputs this is an (n1, 1) column, so comparing it with `second`
    # broadcasts to the full (n1, n2) pairwise table.
    column = first[:, np.newaxis]
    n_wins = np.sum(column > second)
    n_losses = np.sum(column < second)

    n_pairs = len(first) * len(second)
    return (n_wins - n_losses) / n_pairs


def perm_ttest(a, b, num_permutations=1000, seed=42):
    """
    Paired permutation test on the related-samples t statistic, plus Cliff's Delta.

    Parameters:
    a, b : array-like
        Paired observations; element k of ``a`` is paired with element k of ``b``.
    num_permutations : int
        ``n_resamples`` passed to ``scipy.stats.permutation_test``.
    seed : int
        ``random_state`` for reproducible p-values.

    Returns:
    tuple : (t statistic, two-sided permutation p-value, Cliff's Delta of a vs b)
    """
    effect_size = cliffs_delta(a, b)

    def _paired_t(x, y):
        return ttest_rel(x, y).statistic

    # permutation_type="samples" permutes within pairs, as required for paired data.
    result = permutation_test(
        (a, b),
        _paired_t,
        permutation_type="samples",
        vectorized=False,
        n_resamples=num_permutations,
        alternative="two-sided",
        random_state=seed,
    )
    return result.statistic, result.pvalue, effect_size


def compute_statistics(
    df,
    contrasts,
    main_model="DeepTaskGen",
    compare_models=("Average", "Retest", "Linear Regression"),
    metric="Corr",
):
    """
    Paired permutation tests of ``main_model`` against each comparison model, per contrast.

    Parameters:
    df : pandas.DataFrame
        Long-format results with "Contrast", "Method" and ``metric`` columns.
    contrasts : iterable of str
        Contrasts to test.
    main_model : str
        Method whose scores form the first sample of every test.
    compare_models : iterable of str
        Methods compared against ``main_model``.
    metric : str
        Column holding the score. The frames built in this notebook use names
        such as "Reconstruction Accuracy", so pass it explicitly.

    Returns:
    pandas.DataFrame
        One row per (contrast, model) with t_stat, p_value, cliff and
        FDR-corrected p_value_fdr (corrected across all rows).
    """
    columns = ["Contrast", "Model", "t_stat", "p_value", "cliff", "p_value_fdr"]

    def _scores(cont, method):
        # Pairing relies on row order: each method's rows for a contrast are
        # assumed to list subjects in the same order.
        selected = df[(df["Contrast"] == cont) & (df["Method"] == method)]
        return selected[metric].to_numpy()

    pairs = [(cont, model) for cont in contrasts for model in compare_models]
    if not pairs:
        return pd.DataFrame(columns=columns)

    # Slice the arrays in the parent process so each worker receives two small
    # arrays instead of a pickled copy of the whole DataFrame.
    tests = Parallel(n_jobs=-1)(
        delayed(perm_ttest)(_scores(cont, main_model), _scores(cont, model))
        for cont, model in pairs
    )

    # perm_ttest returns exactly three values: (statistic, p-value, Cliff's delta).
    perm_results = pd.DataFrame(
        [
            {
                "Contrast": cont,
                "Model": model,
                "t_stat": t_stat,
                "p_value": p_value,
                "cliff": cliff,
            }
            for (cont, model), (t_stat, p_value, cliff) in zip(pairs, tests)
        ]
    )
    perm_results["p_value_fdr"] = fdrcorrection(perm_results["p_value"])[1]
    return perm_results

In [None]:
# Load HCP subjects
# Repository root: the same directory the import cell appends to sys.path.
# Recomputed here instead of read back as sys.path[-1], which is only correct
# as long as nothing else has appended to sys.path since.
ABS_PATH = op.abspath(op.join(op.abspath(""), ".."))
# NOTE(review): SUBJ_LIST is read under WORKING_DIR while TEST_SUBJ points under
# ABS_PATH; confirm both refer to the intended subject list.
SUBJ_LIST = np.genfromtxt(
    op.join(WORKING_DIR, "training/data/hcp_test_ids.txt"), dtype=str
)
TEST_SUBJ = op.join(ABS_PATH, "training/data/hcp_test_ids.txt")

RESULTS_PATH = op.join(ABS_PATH, "training/results/figures")
os.makedirs(RESULTS_PATH, exist_ok=True)

CORR_SCORES = {
    "Average": op.join(WORKING_DIR, "training/results/corr_scores_group_avg.npy"),
    "Retest": op.join(WORKING_DIR, "training/results/corr_scores_retest.npy"),
    "Linear Regression": op.join(
        WORKING_DIR, "training/results/hcp-ya_tavor/corr_scores_tavor.npy"
    ),
    "DeepTaskGen": op.join(
        WORKING_DIR,
        "training/results-experiment/attentionunet_100_0.001_gm/corr_scores_deeptaskgen.npy",
    ),
}

corr_by_model = {}
dice_auc_by_model = {}
for model, corr_path in CORR_SCORES.items():
    corr_by_model[model] = np.load(corr_path)
    # The Dice-AUC file sits next to the correlation file. Swap the prefix in the
    # file name only, so directory names are never rewritten.
    corr_dir, corr_name = op.split(corr_path)
    dice_path = op.join(corr_dir, corr_name.replace("corr_scores", "dice-auc_scores"))
    # Transposed so that [c] selects one contrast (how it is indexed later);
    # TODO confirm the on-disk layout.
    dice_auc_by_model[model] = np.load(dice_path).transpose()

INCLUDE_CONTRASTS = (
    "EMOTION FACES-SHAPES",
    "GAMBLING REWARD",
    "WM 2BK-0BK",
    "LANGUAGE MATH-STORY",
    "RELATIONAL REL",
    "SOCIAL TOM-RANDOM",
    "MOTOR AVG",
)

CONTRASTS = get_contrasts()
CONTRASTS = np.array([f"{contrast[0]} {contrast[2]}" for contrast in CONTRASTS])

In [None]:
def _concat_or_empty(frames):
    """Concatenate a list of frames once; return an empty DataFrame if there are none."""
    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()


# Build long-format tables (one row per subject x contrast x method) for every metric.
# Frames are collected in lists and concatenated once: calling pd.concat inside
# the loop is quadratic and concatenates an initial empty frame.
corr_frames, fp_frames, dice_auc_frames, diag_index_frames = [], [], [], []
subjects = SUBJ_LIST.flatten()
for model, corr_stack in corr_by_model.items():
    diag_index = diagonality_index(corr_stack)
    for c, cont in enumerate(CONTRASTS):
        corr_mat = np.squeeze(corr_stack[c])
        labels = {"Method": model, "Contrast": cont}
        corr_frames.append(
            pd.DataFrame(
                {
                    "Subject": subjects,
                    "Reconstruction Accuracy": extract_diagonal(corr_mat).flatten(),
                }
            ).assign(**labels)
        )
        # fingerprinting_score returns a single value per contrast; it is
        # broadcast onto every subject row.
        fp_frames.append(
            pd.DataFrame(
                {
                    "Subject": subjects,
                    "Fingerprint": fingerprinting_score(corr_mat),
                }
            ).assign(**labels)
        )
        dice_auc_frames.append(
            pd.DataFrame(
                {
                    "Subject": subjects,
                    "Dice AUC": dice_auc_by_model[model][c],
                }
            ).assign(**labels)
        )
        diag_index_frames.append(
            pd.DataFrame(
                {
                    "Subject": subjects,
                    "Diagonality Index": diag_index[c],
                }
            ).assign(**labels)
        )

corr_df = _concat_or_empty(corr_frames)
fp_df = _concat_or_empty(fp_frames)
dice_auc_df = _concat_or_empty(dice_auc_frames)
diag_index_df = _concat_or_empty(diag_index_frames)

# Save all dataframes to csv files
for frame, file_name in (
    (corr_df, "corr_hcp.csv"),
    (fp_df, "fp_hcp.csv"),
    (dice_auc_df, "dice_auc_hcp.csv"),
    (diag_index_df, "diag_index_hcp.csv"),
):
    frame.to_csv(op.join(results_data_path, file_name), index=False)

In [None]:
# # Plot HCP Results
# Reload the per-subject tables written by the previous cell (in this order).
corr_df, fp_df, dice_auc_df, diag_index_df = (
    pd.read_csv(op.join(results_data_path, csv_name))
    for csv_name in (
        "corr_hcp.csv",
        "fp_hcp.csv",
        "dice_auc_hcp.csv",
        "diag_index_hcp.csv",
    )
)
# One colour per method: Average = dark gray, Retest = yellow,
# Linear Regression = blue, DeepTaskGen = red.
PALETTE = dict(
    [
        ("Average", "#A9A9A9"),
        ("Retest", "#FBDF4F"),
        ("Linear Regression", "#283F94"),
        ("DeepTaskGen", "#AE3033"),
    ]
)

In [None]:
## PLOT RESULTS FOR ALL 47 TASK CONTRAST MAPS

# Plot Reconstruction Accuracy
# The column written by the table-building cell is "Reconstruction Accuracy";
# there is no "Corr" column in corr_df.
plt.figure(figsize=(40, 15), dpi=300)
sns.boxplot(
    data=corr_df,
    x="Contrast",
    y="Reconstruction Accuracy",
    hue="Method",
    palette=PALETTE,
)
plt.ylim(-0.2, 1.05)  # Set y-axis limits
plt.legend(loc="upper center", ncol=4)
sns.despine(offset=10, trim=True)
plt.xticks(rotation=45, ha="right")
plt.ylabel("Reconstruction")
plt.savefig(op.join(RESULTS_PATH, "recon_all_maps.pdf"), bbox_inches="tight")
plt.show()

# Plot Dice AUC Score
plt.figure(figsize=(40, 15), dpi=300)
sns.boxplot(data=dice_auc_df, x="Contrast", y="Dice AUC", hue="Method", palette=PALETTE)
plt.ylim(0.10, 0.36)  # Set y-axis limits
plt.legend(loc="upper center", ncol=4)
sns.despine(offset=10, trim=True)
plt.xticks(rotation=45, ha="right")
plt.ylabel("Dice AUC")
plt.savefig(op.join(RESULTS_PATH, "dice_auc_all_maps.pdf"), bbox_inches="tight")
plt.show()

# Plot Fingerprinting Score
plt.figure(figsize=(40, 15), dpi=300)
sns.pointplot(data=fp_df, x="Contrast", y="Fingerprint", hue="Method", palette=PALETTE)
plt.ylim(-0.05, 1.05)  # Set y-axis limits
sns.despine(offset=10, trim=True)
plt.legend(loc="upper center", ncol=4)
plt.xticks(rotation=45, ha="right")
plt.ylabel("Discriminability")
plt.subplots_adjust(bottom=0.2)
plt.savefig(op.join(RESULTS_PATH, "disc_all_maps.pdf"), bbox_inches="tight")
plt.show()

# Plot Diagonality Index
plt.figure(figsize=(40, 15), dpi=300)
ax = sns.pointplot(
    data=diag_index_df,
    x="Contrast",
    y="Diagonality Index",
    hue="Method",
    palette=PALETTE,
)
# Symmetric log scale: linear within +/-0.05, logarithmic beyond.
ax.set_yscale("symlog", linthresh=0.05)
ax.set(
    ylim=(-0.001, 0.25),  # Slightly beyond the data range for padding
    yticks=[
        0,
        0.03,
        0.05,
        0.1,
        0.25,
    ],
)
sns.despine(offset=10, trim=True)
plt.legend(loc="upper center", ncol=4).set_visible(False)
plt.xticks(rotation=45, ha="right")
# NOTE(review): this figure goes to fig_dir while the others go to RESULTS_PATH.
plt.savefig(op.join(fig_dir, "diag_index_all_maps.pdf"), bbox_inches="tight")
plt.show()

In [None]:
def filter_and_sort_df(df, INCLUDE_CONTRASTS):
    """
    Filters and sorts a DataFrame based on the specified contrasts.

    Parameters:
    - df: DataFrame with a "Contrast" column. It is not modified.
    - INCLUDE_CONTRASTS: Sequence of contrasts to keep; also defines the sort order.

    Returns:
    - A new DataFrame with only the listed contrasts, ordered as in
      INCLUDE_CONTRASTS (row order preserved within a contrast), and with spaces
      in "Contrast" replaced by newlines (presumably for multi-line tick labels).
    """
    # Take an explicit copy. Assigning into a boolean-mask slice raises
    # SettingWithCopyWarning and is not guaranteed to behave as intended.
    df_filtered = df.loc[df["Contrast"].isin(INCLUDE_CONTRASTS)].copy()

    # An ordered categorical makes the sort follow INCLUDE_CONTRASTS instead of
    # alphabetical order; a stable sort keeps the original row order inside each contrast.
    df_filtered["Contrast"] = pd.Categorical(
        df_filtered["Contrast"], categories=INCLUDE_CONTRASTS, ordered=True
    )
    df_filtered = df_filtered.sort_values("Contrast", kind="stable")

    # Back to plain strings, with spaces replaced by newlines.
    df_filtered["Contrast"] = (
        df_filtered["Contrast"].astype(str).str.replace(" ", "\n", regex=False)
    )

    return df_filtered


## PLOT RESULTS ONLY FOR 7 REPRESENTATIVE TASK CONTRAST MAPS
# Filter plot df for only the 7 representative contrasts
# Plot Reconstruction Accuracy
# corr_df stores this metric as "Reconstruction Accuracy" (there is no "Corr" column).
plt.figure(figsize=(25, 10), dpi=300)
sns.boxplot(
    data=filter_and_sort_df(corr_df, INCLUDE_CONTRASTS),
    x="Contrast",
    y="Reconstruction Accuracy",
    hue="Method",
    palette=PALETTE,
)
plt.ylim(-0.2, 1.05)  # Set y-axis limits
sns.despine(offset=10, trim=True)
plt.ylabel("Reconstruction")
plt.legend()
plt.savefig(op.join(RESULTS_PATH, "recon_7_maps.pdf"), bbox_inches="tight")
plt.show()

# Plot Dice AUC Score
plt.figure(figsize=(25, 10), dpi=300)
sns.boxplot(
    data=filter_and_sort_df(dice_auc_df, INCLUDE_CONTRASTS),
    x="Contrast",
    y="Dice AUC",
    hue="Method",
    palette=PALETTE,
)
plt.ylim(0.10, 0.36)  # Set y-axis limits
plt.legend(loc="upper center", ncol=4)
sns.despine(offset=10, trim=True)
plt.xticks(rotation=45, ha="right")
plt.ylabel("Dice AUC")
plt.savefig(op.join(RESULTS_PATH, "dice_auc_7_maps.pdf"), bbox_inches="tight")
plt.show()

# Plot Fingerprinting Score
plt.figure(figsize=(25, 10), dpi=300)
sns.pointplot(
    data=filter_and_sort_df(fp_df, INCLUDE_CONTRASTS),
    x="Contrast",
    y="Fingerprint",
    hue="Method",
    palette=PALETTE,
)
plt.ylim(-0.05, 1.05)  # Set y-axis limits
sns.despine(offset=10, trim=True)
plt.ylabel("Fingerprinting Score")
plt.legend()
plt.subplots_adjust(bottom=0.2)
plt.savefig(op.join(RESULTS_PATH, "disc_7_maps.pdf"), bbox_inches="tight")
plt.show()


# Plot Diagonality Index
plt.figure(figsize=(25, 10), dpi=300)
ax = sns.pointplot(
    data=filter_and_sort_df(diag_index_df, INCLUDE_CONTRASTS),
    x="Contrast",
    y="Diagonality Index",
    hue="Method",
    palette=PALETTE,
)
# Symmetric log scale: linear within +/-0.05, logarithmic beyond.
ax.set_yscale("symlog", linthresh=0.05)
ax.set(
    ylim=(-0.001, 0.25),  # Slightly beyond the data range for padding
    yticks=[
        0,
        0.03,
        0.05,
        0.1,
        0.25,
    ],
)
sns.despine(offset=10, trim=True)
plt.legend(loc="upper center", ncol=4).set_visible(False)
plt.xticks(rotation=45, ha="right")
# NOTE(review): this figure goes to fig_dir while the others go to RESULTS_PATH.
plt.savefig(op.join(fig_dir, "diag_index_7_maps.pdf"), bbox_inches="tight")
plt.show()

In [None]:
# First save the tables into .csv files.
os.makedirs(op.join(WORKING_DIR, "tables/"), exist_ok=True)

# Descriptive statistics per (contrast, method), written with ";" as separator
# and "," as decimal mark.
for metric_df, table_name in (
    (corr_df, "tables/corr_hcp.csv"),
    (fp_df, "tables/fp_hcp.csv"),
    (dice_auc_df, "tables/dice_auc_hcp.csv"),
    (diag_index_df, "tables/diag_index_hcp.csv"),
):
    metric_df.groupby(["Contrast", "Method"]).describe().to_csv(
        op.join(WORKING_DIR, table_name),
        float_format="%.3f",
        decimal=",",
        sep=";",
    )

# Post-hoc comparisons between models for several performance metrics.
# Significance uses paired permutation tests (1000 permutations) with FDR correction.
# These tables are written at full precision. The next cell re-reads them and
# thresholds p_value_fdr at 0.05, and "%.3f" rounding would turn e.g. 0.0496
# into 0.050, misclassifying it as non-significant.
for metric_df, metric, table_name in (
    (corr_df, "Reconstruction Accuracy", "tables/corr_hcp_ttest.csv"),
    (dice_auc_df, "Dice AUC", "tables/dice_auc_hcp_ttest.csv"),
    (diag_index_df, "Diagonality Index", "tables/diag_index_hcp_ttest.csv"),
):
    compute_statistics(metric_df, CONTRASTS, metric=metric).to_csv(
        op.join(WORKING_DIR, table_name),
        decimal=".",
        sep=";",
    )

In [None]:
# Print significant performance differences between models
corr_ttest = pd.read_csv(
    op.join(WORKING_DIR, "tables/corr_hcp_ttest.csv"),
    sep=";",
    decimal=".",
    index_col=None,
)
dice_ttest = pd.read_csv(
    op.join(WORKING_DIR, "tables/dice_auc_hcp_ttest.csv"),
    sep=";",
    decimal=".",
    index_col=None,
)
diag_ttest = pd.read_csv(
    op.join(WORKING_DIR, "tables/diag_index_hcp_ttest.csv"),
    sep=";",
    decimal=".",
    index_col=None,
)

BASELINE_MODELS = ("Average", "Retest", "Linear Regression")

# Per baseline, count contrasts where DeepTaskGen is significantly better
# (t > 0), significantly worse (t < 0), or not significantly different.
for title, ttest_df in (
    ("Reconstruction Accuracy", corr_ttest),
    ("\nDice AUC", dice_ttest),
    ("\nDiagonality Index", diag_ttest),
):
    print(title)
    for model in BASELINE_MODELS:
        tmp_df = ttest_df[ttest_df["Model"] == model]
        print(
            f"Model: {model}, "
            f"Positive: {len(tmp_df.query('t_stat > 0 and p_value_fdr < 0.05'))}, "
            f"Negative: {len(tmp_df.query('t_stat < 0 and p_value_fdr < 0.05'))}, "
            f"Non-significant: {len(tmp_df.query('p_value_fdr >= 0.05'))}"
        )

print("\nFingerprinting")
# fp_df repeats one fingerprinting value per contrast on every subject row.
# Reduce to one value per (method, contrast) before comparing, otherwise each
# contrast would be counted once per subject.
fp_by_contrast = (
    fp_df.groupby(["Method", "Contrast"])["Fingerprint"].first().unstack("Method")
)
for model in BASELINE_MODELS:
    n_better_fingerprint = int(
        (fp_by_contrast["DeepTaskGen"] > fp_by_contrast[model]).sum()
    )
    print(
        f"Model: {model}, Number of greater fingerprinting scores: {n_better_fingerprint}"
    )