In [None]:
import os
import os.path as op
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from joblib import Parallel, delayed
from scipy.stats import permutation_test, ttest_rel
from statsmodels.stats.multitest import fdrcorrection

# Plot styling shared by every figure in this notebook.
sns.set_style("ticks")
sns.set_context("talk", font_scale=1.2, rc={"axes.labelpad": 10})

# Show floats with three significant digits when DataFrames are displayed.
# (Set once; the option is global, so repeating the call has no effect.)
pd.set_option("display.float_format", "{:.3}".format)

# Silence FutureWarning messages for the whole session.
warnings.filterwarnings("ignore", category=FutureWarning)

In [None]:
# Define helper functions
def load_corr_wrapper(path):
    """
    Loads an array saved in NumPy format.

    Parameters:
    path: Location of the file, passed unchanged to ``np.load``.

    Returns:
    Whatever ``np.load`` returns for that file.
    """
    loaded = np.load(path)
    return loaded


def extract_diagonal(mat):
    """
    Returns the main diagonal of a 2D matrix as a 1D array.

    Parameters:
    mat (array-like): A 2D matrix.

    Returns:
    numpy.ndarray: A writable copy of the entries ``mat[i, i]``. For a
    non-square matrix the diagonal has ``min(mat.shape)`` entries.

    Raises:
    ValueError: If the input is not exactly 2D.
    """
    mat = np.asarray(mat)
    # Reject every non-2D input, not only 3D: 1D input used to fail with an
    # IndexError and 4D+ input silently returned sub-arrays.
    if mat.ndim != 2:
        raise ValueError("Input matrix must be 2D")
    # np.diagonal returns a read-only view; copy so callers own the result.
    return np.diagonal(mat).copy()


def fingerprinting_score(corr_matrix):
    """
    Computes the fingerprinting score from a correlation matrix.

    A row counts as a hit when its largest value sits on the diagonal and no
    other entry in that row equals the diagonal value. The score is the
    fraction of hit rows multiplied by the mean of the diagonal.

    Parameters:
    corr_matrix (numpy.ndarray): A 2D numpy array containing the correlation matrix.

    Returns:
    float: The fingerprinting score.
    """
    row_hits = []
    for position, values in enumerate(corr_matrix):
        # The row's maximum must be located at the diagonal position ...
        peaks_on_diagonal = np.argmax(values) == position
        # ... and the diagonal value must not be tied with any other entry.
        diagonal_is_untied = np.count_nonzero(values == values[position]) == 1
        row_hits.append(peaks_on_diagonal and diagonal_is_untied)

    hit_rate = np.mean(row_hits)
    mean_diagonal = np.diag(corr_matrix).mean()
    return hit_rate * mean_diagonal


def diagonality_index(corr):
    """
    Computes a per-row diagonality index for a stack of square matrices.

    For row ``r`` of each matrix ``m`` the index is
    ``mean(m[r, r] - m[r, c] for c != r) * m[r, r]``, i.e. the average margin
    of the diagonal entry over the rest of its row, weighted by the diagonal
    entry itself.

    Parameters:
    corr (array-like): Array of shape (n_contrasts, n_subj, n_subj).

    Returns:
    numpy.ndarray: Float array of shape (n_contrasts, n_subj). Entries are
    NaN when n_subj == 1, because there are no off-diagonal values to average.

    Raises:
    ValueError: If ``corr`` is not a stack of square matrices. The previous
    loop silently reused the preceding row's diagonal value in that case.
    """
    corr = np.asarray(corr)
    # Drop trailing singleton axes so an input such as (k, n, n, 1) is still
    # accepted rather than rejected by the shape check below.
    while corr.ndim > 3 and corr.shape[-1] == 1:
        corr = corr[..., 0]
    if corr.ndim != 3 or corr.shape[1] != corr.shape[2]:
        raise ValueError(
            "corr must have shape (n_contrasts, n_subj, n_subj); "
            f"got {corr.shape}"
        )

    n_contrasts, n_subj = corr.shape[0], corr.shape[1]
    # Diagonal of every matrix in the stack: shape (n_contrasts, n_subj).
    diag = np.diagonal(corr, axis1=1, axis2=2)
    # Boolean mask selecting everything except the diagonal; applying it keeps
    # row-major order, so each row contributes its n_subj - 1 other entries.
    off_diag_mask = ~np.eye(n_subj, dtype=bool)
    off_diag = corr[:, off_diag_mask].reshape(
        n_contrasts, n_subj, max(n_subj - 1, 0)
    )
    diag_index = np.mean(diag[:, :, None] - off_diag, axis=2) * diag
    return diag_index.astype(float)


def cliffs_delta(group1, group2):
    """
    Calculate Cliff's Delta effect size between two groups.

    Every value of the first group is compared with every value of the
    second; the result is (pairs where the first is larger minus pairs where
    the first is smaller) divided by the total number of pairs. Ties count
    for neither side.

    Parameters:
    group1, group2 : array-like
        The two groups to compare

    Returns:
    float : Cliff's Delta (-1 to +1)
    """
    first = np.array(group1)
    second = np.array(group2)

    # A column view of the first group broadcasts against the second group,
    # giving one comparison per (first, second) pair.
    first_as_column = first[:, None]
    n_larger = (first_as_column > second).sum()
    n_smaller = (first_as_column < second).sum()

    n_pairs = len(first) * len(second)
    return (n_larger - n_smaller) / n_pairs


def perm_ttest(a, b, num_permutations=1000, seed=42):
    """
    Paired-sample permutation test on the t statistic, plus Cliff's Delta.

    Parameters:
    a, b : array-like
        The two paired samples.
    num_permutations (int): Value passed to ``permutation_test`` as ``n_resamples``.
    seed (int): Value passed to ``permutation_test`` as ``random_state``.

    Returns:
    tuple: (observed statistic, p-value, Cliff's Delta) -- three values.
    """
    # Effect size is computed on the unpermuted samples.
    effect_size = cliffs_delta(a, b)

    def paired_t(sample_a, sample_b):
        # Statistic of the related-samples t-test for one (re)sampling.
        return ttest_rel(sample_a, sample_b).statistic

    outcome = permutation_test(
        (a, b),
        paired_t,
        permutation_type="samples",
        vectorized=False,
        n_resamples=num_permutations,
        alternative="two-sided",
        random_state=seed,
    )
    return outcome.statistic, outcome.pvalue, effect_size


def compute_statistics(
    df,
    contrasts,
    main_model="DeepTaskGen",
    compare_models=("Average", "Retest", "Linear Regression"),
    metric="Corr",
):
    """
    Compares ``main_model`` against each model in ``compare_models`` per contrast.

    Each (contrast, model) pair is tested with ``perm_ttest`` in a joblib
    worker, and the p-values of all pairs are then FDR-corrected together.

    Parameters:
    df (pandas.DataFrame): Long-format table with "Contrast" and "Method"
        columns and a column named ``metric``.
    contrasts (iterable): Values of the "Contrast" column to test.
    main_model (str): Value of the "Method" column used as the first sample.
    compare_models (iterable): Values of the "Method" column to compare against.
    metric (str): Name of the column holding the values to compare.

    Returns:
    pandas.DataFrame: One row per (contrast, model) pair with the columns
    Contrast, Model, t_stat, p_value, cliff and p_value_fdr.
    """

    def run_permutation_test(cont, model):
        in_contrast = df["Contrast"] == cont
        # perm_ttest returns three values (statistic, p-value, Cliff's Delta);
        # unpacking a fourth name here raised ValueError on every call.
        t_stat, p_value, cliff = perm_ttest(
            df.loc[in_contrast & (df["Method"] == main_model), metric],
            df.loc[in_contrast & (df["Method"] == model), metric],
        )
        return {
            "Contrast": cont,
            "Model": model,
            "t_stat": t_stat,
            "p_value": p_value,
            "cliff": cliff,
        }

    perm_results = Parallel(n_jobs=-1)(
        delayed(run_permutation_test)(cont, model)
        for cont in contrasts
        for model in compare_models
    )
    perm_results = pd.DataFrame(perm_results)
    # Second element of fdrcorrection's return value holds the corrected p-values.
    perm_results["p_value_fdr"] = fdrcorrection(perm_results["p_value"])[1]
    return perm_results

In [None]:
# Project paths, resolved two directory levels above the current working directory.
ABS_PATH = op.abspath(op.join(os.getcwd(), "../.."))

# Directory that receives the figures; created here so later cells can save into it.
RESULTS_PATH = op.join(ABS_PATH, "transfer_learning/hcp_development/results/figures")
os.makedirs(RESULTS_PATH, exist_ok=True)

# Presumably the list of HCP-D test subject IDs (judging by the file name).
# NOTE(review): this path is not read anywhere in the visible cells.
TEST_SUBJ = op.join(
    ABS_PATH, "transfer_learning/hcp_development/data/hcpd_test_ids.txt"
)

# NOTE(review): not referenced in the visible cells; the contrast keys used
# later contain a newline where these names contain a space.
INCLUDE_CONTRASTS = ("EMOTION FACES-SHAPES", "GAMBLING REWARD")

# Colour per method:
# DeepTaskGen – Finetuned = Red, DeepTaskGen - NoFinetune = Yellow, Tavor = Blue
PALETTE = dict(
    [
        ("Finetune", "#AE3033"),
        ("No Finetune", "#FBDF4F"),
        ("Linear Regression", "#283F94"),
    ]
)

In [None]:
# Correlation-score files per contrast and method, relative to ABS_PATH.
_REL_SCORE_PATHS = {
    "EMOTION\nFACES-SHAPES": {
        "Linear Regression": "transfer_learning/hcp_development/results/tavor/corr_scores_emotion_faces-shapes.npy",
        "No Finetune": "transfer_learning/hcp_development/results/nofinetune-attentionunet/corr_scores_emotion_faces-shapes.npy",
        "Finetune": "transfer_learning/hcp_development/results/finetuned_50_0.001_gambling-reward_attentionunet_gm_20/corr_scores_emotion_faces-shapes.npy",
    },
    "GAMBLING\nREWARD": {
        "Linear Regression": "transfer_learning/hcp_development/results/tavor/corr_scores_gambling_reward.npy",
        "No Finetune": "transfer_learning/hcp_development/results/nofinetune-attentionunet/corr_scores_gambling_reward.npy",
        "Finetune": "transfer_learning/hcp_development/results/finetuned_50_0.001_emotion-faces-shapes_attentionunet_gm_20/corr_scores_gambling_reward.npy",
    },
}

# Same nesting (contrast -> method), with every path anchored at ABS_PATH.
PRED_BY_MODEL = {
    contrast_name: {
        method_name: op.join(ABS_PATH, relative_path)
        for method_name, relative_path in relative_by_method.items()
    }
    for contrast_name, relative_by_method in _REL_SCORE_PATHS.items()
}

# Load the correlation scores and the matching Dice AUC scores. The Dice file
# sits next to the correlation file, with "corr_scores" swapped for "dice_auc"
# in the path, and is stored transposed.
corr_by_model = {}
dice_auc_by_model = {}
for cont, paths_by_model in PRED_BY_MODEL.items():
    corr_by_model[cont] = {}
    dice_auc_by_model[cont] = {}
    for model, corr_path in paths_by_model.items():
        corr_by_model[cont][model] = np.load(corr_path)
        dice_path = corr_path.replace("corr_scores", "dice_auc")
        dice_auc_by_model[cont][model] = np.load(dice_path).T

In [None]:
# Assemble one long-format table per metric. Every row is labelled with the
# method and the contrast it belongs to.
# NOTE(review): fp_norm_df and diag_index_norm_df are never filled in the
# visible cells; they are kept so that any cell relying on them still runs.
fp_norm_df = pd.DataFrame()
diag_index_norm_df = pd.DataFrame()

# Collect the per-(contrast, method) pieces in lists and concatenate once at
# the end; growing a DataFrame with pd.concat inside the loop copies all
# previous rows on every iteration.
corr_frames, fp_frames, diag_index_frames, dice_auc_frames = [], [], [], []
for cont in PRED_BY_MODEL.keys():
    for model in PRED_BY_MODEL[cont].keys():
        labels = {"Method": model, "Contrast": cont}
        # Squeeze once so that all three correlation-based metrics are
        # computed from the same matrix (the diagonality index previously
        # received the un-squeezed array).
        corr_mat = np.squeeze(corr_by_model[cont][model])

        corr_frames.append(
            pd.DataFrame(
                extract_diagonal(corr_mat),
                columns=["Reconstruction Accuracy"],
            ).assign(**labels)
        )
        # The fingerprinting score is a single number per contrast and method.
        fp_frames.append(
            pd.DataFrame(
                [fingerprinting_score(corr_mat)],
                columns=["Fingerprint"],
            ).assign(**labels)
        )
        # diagonality_index expects a leading contrast axis.
        diag_index_frames.append(
            pd.DataFrame(
                diagonality_index(np.expand_dims(corr_mat, axis=0)).reshape(-1, 1),
                columns=["Diagonality Index"],
            ).assign(**labels)
        )
        dice_auc_frames.append(
            pd.DataFrame(
                dice_auc_by_model[cont][model].reshape(-1, 1),
                columns=["Dice AUC"],
            ).assign(**labels)
        )

# ignore_index gives each table a unique RangeIndex instead of repeating
# 0..n-1 for every (contrast, method) piece.
corr_df = pd.concat(corr_frames, ignore_index=True)
fp_df = pd.concat(fp_frames, ignore_index=True)
diag_index_df = pd.concat(diag_index_frames, ignore_index=True)
dice_auc_df = pd.concat(dice_auc_frames, ignore_index=True)

# Save the raw per-observation values next to the figures.
results_data_path = op.join(RESULTS_PATH, "results_data")
os.makedirs(results_data_path, exist_ok=True)
for csv_name, metric_frame in (
    ("corr_hcpd.csv", corr_df),
    ("dice_auc_hcpd.csv", dice_auc_df),
    ("fp_hcpd.csv", fp_df),
    ("diag_index_hcpd.csv", diag_index_df),
):
    metric_frame.to_csv(op.join(results_data_path, csv_name), index=False)


In [None]:
# Plot Reconstruction Accuracy
recon_fig, recon_ax = plt.subplots(figsize=(14, 10), dpi=300)
sns.boxplot(
    data=corr_df,
    x="Contrast",
    # corr_df stores the values under "Reconstruction Accuracy"; it has no
    # "Corr" column, so y="Corr" made this call fail.
    y="Reconstruction Accuracy",
    hue="Method",
    palette=PALETTE,
    flierprops={"marker": "d", "markerfacecolor": "black", "markersize": 5},
    ax=recon_ax,
)
recon_ax.set_ylim(-0.2, 0.8)  # Set y-axis limits
recon_ax.set_ylabel("Reconstruction")
recon_ax.legend()
# Despine after the limits are set: trim=True cuts the spines at the outermost ticks.
sns.despine(offset=10, trim=True)
recon_fig.savefig(op.join(RESULTS_PATH, "recon.pdf"), bbox_inches="tight")
plt.show()

# Plot Dice AUC
dice_fig, dice_ax = plt.subplots(figsize=(14, 10), dpi=300)
outlier_style = dict(marker="d", markerfacecolor="black", markersize=5)
sns.boxplot(
    data=dice_auc_df,
    x="Contrast",
    y="Dice AUC",
    hue="Method",
    palette=PALETTE,
    flierprops=outlier_style,
    ax=dice_ax,
)
dice_ax.set_ylim(0.05, 0.30)  # Set y-axis limits
sns.despine(offset=10, trim=True)
# The legend is created and then hidden, so this panel is saved without one.
dice_legend = dice_ax.legend(loc="upper center", ncol=4)
dice_legend.set_visible(False)
plt.xticks(rotation=45, ha="right")
dice_fig.savefig(op.join(RESULTS_PATH, "dice_auc.pdf"), bbox_inches="tight")
plt.show()

# Plot Fingerprinting Score
fp_fig, fp_ax = plt.subplots(figsize=(14, 10), dpi=300)
sns.pointplot(
    data=fp_df,
    x="Contrast",
    y="Fingerprint",
    hue="Method",
    palette=PALETTE,
    ax=fp_ax,
)
fp_ax.set_ylim(0, 1)  # Set y-axis limits
sns.despine(offset=10, trim=True)
fp_ax.legend()
fp_ax.set_ylabel("Fingerprinting Score")
# Reserve extra room below the axes.
fp_fig.subplots_adjust(bottom=0.2)
fp_fig.savefig(op.join(RESULTS_PATH, "fp.pdf"), bbox_inches="tight")
plt.show()

# Plot Diagonality Index
diag_fig, ax = plt.subplots(figsize=(14, 10), dpi=300)
sns.boxplot(
    data=diag_index_df,
    x="Contrast",
    y="Diagonality Index",
    hue="Method",
    palette=PALETTE,
    flierprops=dict(marker="d", markerfacecolor="black", markersize=5),
    ax=ax,
)
# Fixed tick positions; the tick labels, axis labels and legend are all blanked
# or hidden below, so the panel is saved without any text.
ax.set_yticks([-0.05, 0, 0.05, 0.1, 0.15])  # Meaningful tick marks
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_ylim(-0.05, 0.151)  # Set y-axis limits
sns.despine(offset=10, trim=True)
ax.set_xlabel("")
ax.set_ylabel("")
hidden_legend = ax.legend(loc="upper center", ncol=4)
hidden_legend.set_visible(False)
plt.xticks(rotation=45, ha="right")
diag_fig.savefig(op.join(RESULTS_PATH, "diag_index.pdf"), bbox_inches="tight")
plt.show()

In [None]:
# First save the tables into .csv files.
# DataFrame.to_csv does not create missing parent directories, so make sure
# the "tables" sub-directory exists before writing into it.
os.makedirs(op.join(RESULTS_PATH, "tables"), exist_ok=True)

# Shared CSV layout: semicolon-separated, decimal comma, three decimals.
table_format = {"float_format": "%.3f", "decimal": ",", "sep": ";"}
group_keys = ["Contrast", "Method"]

# Full descriptive statistics for the per-observation metrics ...
corr_df.groupby(group_keys).describe().to_csv(
    op.join(RESULTS_PATH, "tables/corr_hcpd.csv"), **table_format
)
dice_auc_df.groupby(group_keys).describe().to_csv(
    op.join(RESULTS_PATH, "tables/dice_auc_hcpd.csv"), **table_format
)
# ... and group means only for the fingerprinting score and diagonality index.
fp_df.groupby(group_keys).mean().to_csv(
    op.join(RESULTS_PATH, "tables/fp_hcpd.csv"), **table_format
)
diag_index_df.groupby(group_keys).mean().to_csv(
    op.join(RESULTS_PATH, "tables/diag_index_hcpd.csv"), **table_format
)

# Post-hoc comparisons of the "Finetune" model against the other models, once
# per performance metric.
# Significance is determined using permutation tests with 1000 permutations.
target_contrasts = ["EMOTION\nFACES-SHAPES", "GAMBLING\nREWARD"]
target_models = ("Linear Regression", "No Finetune")

# (table of values, metric column, output file relative to RESULTS_PATH)
posthoc_jobs = (
    (corr_df, "Reconstruction Accuracy", "tables/corr_hcpd_ttest.csv"),
    (dice_auc_df, "Dice AUC", "tables/dice_auc_hcpd_ttest.csv"),
    (diag_index_df, "Diagonality Index", "tables/diag_index_hcpd_ttest.csv"),
)
for metric_df, metric_name, table_file in posthoc_jobs:
    posthoc_table = compute_statistics(
        metric_df,
        target_contrasts,
        main_model="Finetune",
        compare_models=target_models,
        metric=metric_name,
    )
    posthoc_table.to_csv(
        op.join(RESULTS_PATH, table_file),
        float_format="%.3f",
        index=False,
        decimal=",",
        sep=";",
    )

In [None]:
# Print results: read each saved table back from disk and show it under a heading.
saved_tables = (
    ("Reconstruction Accuracy", "tables/corr_hcpd_ttest.csv"),
    ("\nDice AUC", "tables/dice_auc_hcpd_ttest.csv"),
    ("\nFingerprinting", "tables/fp_hcpd.csv"),
    ("\nDiagonality Index", "tables/diag_index_hcpd_ttest.csv"),
)
for heading, relative_path in saved_tables:
    print(heading)
    # Same separator and decimal mark the tables were written with.
    saved_table = pd.read_csv(
        op.join(RESULTS_PATH, relative_path),
        sep=";",
        decimal=",",
        index_col=None,
    )
    print(saved_table)
