In [None]:
import os
import os.path as op
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from joblib import Parallel, delayed
from scipy.stats import permutation_test, ttest_rel

sys.path.append(op.abspath(op.join(op.abspath(""), "..")))
from utils.utils import get_contrasts

# Global plotting and display configuration for every cell below.
sns.set_style("ticks")
sns.set_context("talk", font_scale=1.2, rc={"axes.labelpad": 10})
# Show floats in DataFrame reprs with 3 significant digits (set once; it was
# previously set a second time with the identical value).
pd.set_option("display.float_format", "{:.3}".format)

# Silence FutureWarning messages to keep cell output readable.
warnings.filterwarnings("ignore", category=FutureWarning)

In [None]:
# Define helper functions
def load_corr_wrapper(path):
    """Load and return the NumPy array stored in the ``.npy`` file at ``path``."""
    corr_scores = np.load(path)
    return corr_scores


def extract_diagonal(mat):
    """
    Return the main diagonal of a 2D matrix as a 1D array.

    Parameters:
    mat (numpy.ndarray): A 2D array.

    Returns:
    numpy.ndarray: 1D array holding ``mat[i, i]`` for ``i`` in
    ``range(min(mat.shape))``, with the same dtype as ``mat``.

    Raises:
    ValueError: If ``mat`` is not 2D.
    """
    mat = np.asarray(mat)
    # Reject every non-2D input. The previous check only rejected 3D input, so
    # 1D input crashed with an IndexError and 4D+ input silently returned
    # stacked sub-matrices instead of a diagonal.
    if mat.ndim != 2:
        raise ValueError("Input matrix must be 2D")
    # np.diagonal returns a read-only view; copy so callers get an owned array.
    return np.diagonal(mat).copy()


def fingerprinting_score(corr_matrix):
    """
    Computes the fingerprinting score from a correlation matrix.

    Row ``i`` counts as correctly identified when its diagonal entry
    ``corr_matrix[i, i]`` is strictly the largest value of that row: it is the
    row's maximum and no other entry in the row is equal to it.

    Parameters:
    corr_matrix (numpy.ndarray): A 2D array; row ``i`` is scored against column ``i``.

    Returns:
    float: Fraction of rows that are correctly identified (NaN for zero rows).
    """
    n_rows = corr_matrix.shape[0]
    if n_rows == 0:
        # Mean of an empty selection -> NaN (with NumPy's empty-slice warning).
        return np.mean([])

    rows = np.arange(n_rows)
    diag = corr_matrix[rows, rows]

    # argmax yields the FIRST maximal column, so a tie sitting before the
    # diagonal already fails this check.
    max_on_diag = np.argmax(corr_matrix, axis=1) == rows
    # Exactly one entry equal to the diagonal value means no tie anywhere in the row
    # (a NaN diagonal equals nothing, so such rows are never counted).
    diag_unique = np.count_nonzero(corr_matrix == diag[:, None], axis=1) == 1

    return np.mean(max_on_diag & diag_unique)


def perm_ttest(a, b, num_permutations=1000, seed=42, alternative="two-sided"):
    """
    Paired permutation test using the paired-samples t statistic.

    Values are permuted *within* pairs (``permutation_type="samples"``), which
    is equivalent to randomly flipping the sign of each paired difference.

    Parameters:
    a, b (array-like): Paired 1D samples of equal length; element ``i`` of
        ``a`` is paired with element ``i`` of ``b``.
    num_permutations (int): Number of random resamples. If it is at least the
        number of distinct within-pair permutations (``2 ** len(a)``), SciPy
        performs an exact test instead.
    seed (int | numpy.random.Generator | None): Seed for the resampling.
    alternative (str): "two-sided" (default), "less" or "greater".

    Returns:
    tuple: ``(t_statistic, p_value)``; a positive t means ``a`` > ``b`` on average.
    """

    def statistic(x, y, axis):
        # Vectorized over resamples: SciPy passes a whole batch plus `axis`,
        # so there is no Python-level ttest_rel call per permutation.
        return ttest_rel(x, y, axis=axis).statistic

    res = permutation_test(
        (np.asarray(a), np.asarray(b)),
        statistic,
        vectorized=True,
        permutation_type="samples",
        alternative=alternative,
        random_state=seed,
        n_resamples=num_permutations,
    )
    return res.statistic, res.pvalue


def compute_statistics(
    df,
    contrasts,
    main_model="DeepTaskGen",
    compare_models=("Average", "Retest", "Linear Regression"),
    n_jobs=-1,
    num_permutations=1000,
    seed=42,
    value_col="Corr",
):
    """
    Run paired permutation t-tests of ``main_model`` against each of ``compare_models``.

    For every contrast in ``contrasts`` and every model in ``compare_models``,
    the ``value_col`` values of ``main_model`` and of the compared model for
    that contrast are tested with ``perm_ttest``. ``main_model`` is passed
    first, so a positive t-statistic means ``main_model`` scores higher. Rows
    are paired by their order within each (contrast, method) slice of ``df``.

    Parameters:
    df (pandas.DataFrame): Long-format table with "Contrast", "Method" and
        ``value_col`` columns.
    contrasts (iterable): Contrast labels to test.
    main_model (str): Method that every other model is compared against.
    compare_models (iterable of str): Methods to compare with ``main_model``.
    n_jobs (int): Number of joblib workers (-1 uses all cores).
    num_permutations (int): Resamples per test, forwarded to ``perm_ttest``.
    seed (int): Random seed forwarded to ``perm_ttest``.
    value_col (str): Column holding the values to compare.

    Returns:
    pandas.DataFrame: One row per (contrast, model) with columns
    "Contrast", "Model", "t_stat", "p_value".
    """

    def _values(cont, method):
        # Values of one (contrast, method) slice as a plain NumPy array.
        mask = (df["Contrast"] == cont) & (df["Method"] == method)
        return df.loc[mask, value_col].to_numpy()

    # Slice every sample once in the parent process. Workers then receive only
    # small arrays instead of a pickled copy of the whole DataFrame, and the
    # main model's slice is not recomputed for every compared model.
    tasks = []
    for cont in contrasts:
        main_values = _values(cont, main_model)
        for model in compare_models:
            tasks.append((cont, model, main_values, _values(cont, model)))

    stats = Parallel(n_jobs=n_jobs)(
        delayed(perm_ttest)(
            main_values, other_values, num_permutations=num_permutations, seed=seed
        )
        for _, _, main_values, other_values in tasks
    )

    # joblib returns results in submission order, so zip them back onto their labels.
    return pd.DataFrame(
        [
            {"Contrast": cont, "Model": model, "t_stat": t_stat, "p_value": p_value}
            for (cont, model, _, _), (t_stat, p_value) in zip(tasks, stats)
        ]
    )

In [None]:
# Paths, contrast labels, and the plotting palette used throughout the notebook.
# Project root = parent of the notebook's working directory (the same expression
# the setup cell appends to sys.path). It is computed directly rather than read
# back as `sys.path[-1]`, which breaks as soon as anything else appends to sys.path.
ABS_PATH = op.abspath(op.join(op.abspath(""), ".."))
TEST_SUBJ = op.join(ABS_PATH, "training/data/hcp_test_ids.txt")

# All figures and summary tables produced below are written here.
RESULTS_PATH = op.join(ABS_PATH, "training/results/figures")
os.makedirs(RESULTS_PATH, exist_ok=True)

# One label per entry of get_contrasts(): "<field 0> <field 2>"
# (presumably "<TASK> <CONTRAST>", e.g. "WM 2BK-0BK" -- confirm against get_contrasts).
CONTRASTS = np.array(
    [f"{contrast[0]} {contrast[2]}" for contrast in np.array(get_contrasts())]
)

# Representative subset (one contrast per task) shown in the compact plots;
# the tuple order is also the plotting order.
INCLUDE_CONTRASTS = (
    "EMOTION FACES-SHAPES",
    "GAMBLING REWARD",
    "WM 2BK-0BK",
    "LANGUAGE MATH-STORY",
    "RELATIONAL REL",
    "SOCIAL TOM-RANDOM",
    "MOTOR AVG",
)

# Average = Dark Gray, DeepTaskGen = Red, Retest = Yellow,
# Linear Regression (Tavor) = Blue
PALETTE = {
    "Average": "#A9A9A9",
    "DeepTaskGen": "#AE3033",
    "Retest": "#FBDF4F",
    "Linear Regression": "#283F94",
}

In [None]:
# Train-test correlation matrices, keyed by method. Each method's array is
# indexed by contrast along its first axis (in CONTRASTS order).
PRED_BY_MODEL = {
    model: load_corr_wrapper(os.path.join(ABS_PATH, rel_path))
    for model, rel_path in (
        ("Average", "training/results/corr_scores.npy"),
        ("Retest", "training/data/corr_scores.npy"),
        ("Linear Regression", "training/results/tavor/corr_scores.npy"),
        ("DeepTaskGen", "training/results/unetminimal_100_0.001/corr_scores.npy"),
    )
}


# Prepare HCP Results DataFrame
def prepare_results_df():
    """
    Build the reconstruction-accuracy and fingerprinting tables for all methods.

    For every method in ``PRED_BY_MODEL`` and every contrast in ``CONTRASTS``
    (matched by position along the first axis of the method's array):

    * reconstruction accuracy = the diagonal of the contrast's correlation
      matrix, one "Corr" row per diagonal entry;
    * discriminability = ``fingerprinting_score`` of the same matrix, one
      "Fingerprint" row per (method, contrast).

    Both tables are also written to ``RESULTS_PATH`` as ``corr_hcp.csv`` and
    ``fingerprint_hcp.csv``.

    Returns:
    tuple[pandas.DataFrame, pandas.DataFrame]: ``(corr_df, fp_df)`` with
    columns ("Corr", "Method", "Contrast") and ("Fingerprint", "Method", "Contrast").
    """
    corr_parts = []
    fp_parts = []
    for model, corr_by_contrast in PRED_BY_MODEL.items():
        for c, cont in enumerate(CONTRASTS):
            # Squeeze once and reuse for both metrics.
            corr_matrix = np.squeeze(corr_by_contrast[c])
            # Reconstruction accuracy.
            corr_parts.append(
                pd.DataFrame(extract_diagonal(corr_matrix), columns=["Corr"]).assign(
                    Method=model, Contrast=cont
                )
            )
            # Fingerprinting score.
            fp_parts.append(
                pd.DataFrame(
                    [fingerprinting_score(corr_matrix)], columns=["Fingerprint"]
                ).assign(Method=model, Contrast=cont)
            )

    # Concatenate once (concat inside the loop is quadratic), then write each CSV
    # a single time instead of rewriting it on every inner iteration.
    corr_df = pd.concat(corr_parts) if corr_parts else pd.DataFrame()
    fp_df = pd.concat(fp_parts) if fp_parts else pd.DataFrame()
    corr_df.to_csv(op.join(RESULTS_PATH, "corr_hcp.csv"), index=False)
    fp_df.to_csv(op.join(RESULTS_PATH, "fingerprint_hcp.csv"), index=False)
    return corr_df, fp_df


# Build the reconstruction-accuracy and discriminability (fingerprinting) tables.
corr_df, fp_df = prepare_results_df()

# Save the per-(Contrast, Method) means as ";"-separated CSVs with "," as the
# decimal mark and 3 decimals.
for summary_source, summary_name in ((corr_df, "recon.csv"), (fp_df, "disc.csv")):
    summary_source.groupby(["Contrast", "Method"]).mean().to_csv(
        op.join(RESULTS_PATH, summary_name),
        float_format="%.3f",
        decimal=",",
        sep=";",
    )

In [None]:
## PLOT RESULTS FOR ALL TASK CONTRAST MAPS (every label in CONTRASTS)

# Plot Reconstruction Accuracy: distribution of the "Corr" rows per contrast and method
plt.figure(figsize=(40, 15), dpi=300)
sns.boxplot(data=corr_df, x="Contrast", y="Corr", hue="Method", palette=PALETTE)
plt.ylim(-0.2, 1.05)  # Set y-axis limits
plt.legend(loc="upper center", ncol=4)  # one legend column per method
sns.despine(offset=10, trim=True)
plt.xticks(rotation=45, ha="right")
plt.ylabel("Reconstruction")
plt.savefig(op.join(RESULTS_PATH, "recon_all_maps.svg"), bbox_inches="tight")
plt.show()

# Plot Fingerprinting Score (one value per contrast and method)
plt.figure(figsize=(40, 15), dpi=300)
sns.pointplot(data=fp_df, x="Contrast", y="Fingerprint", hue="Method", palette=PALETTE)
plt.ylim(-0.05, 1.05)  # Set y-axis limits
sns.despine(offset=10, trim=True)
plt.legend(loc="upper center", ncol=4)
plt.xticks(rotation=45, ha="right")
plt.ylabel("Discriminability")  # label typo fixed (was "Discriminbility")
plt.subplots_adjust(bottom=0.2)
plt.savefig(op.join(RESULTS_PATH, "disc_all_maps.svg"), bbox_inches="tight")
plt.show()

In [None]:
def filter_and_sort_df(df, INCLUDE_CONTRASTS):
    """
    Filter a results DataFrame to selected contrasts and sort it in their order.

    Parameters:
    - df: DataFrame with a "Contrast" column (e.g. ``corr_df`` or ``fp_df``).
      It is not modified.
    - INCLUDE_CONTRASTS: Sequence of contrast labels to keep; its order is
      also the sort order of the result.

    Returns:
    - A new DataFrame holding only the rows whose "Contrast" is in
      INCLUDE_CONTRASTS. Rows are sorted in that order; the sort is stable, so
      rows within a contrast keep their original relative order. Spaces in
      "Contrast" are replaced by newlines for compact x-axis tick labels.
    """
    # .copy() detaches the filtered frame from `df`, so the column assignments
    # below neither raise SettingWithCopyWarning nor depend on view semantics.
    df_filtered = df[df["Contrast"].isin(INCLUDE_CONTRASTS)].copy()

    # Ordered categorical so that sorting follows INCLUDE_CONTRASTS, not the alphabet.
    df_filtered["Contrast"] = pd.Categorical(
        df_filtered["Contrast"], categories=INCLUDE_CONTRASTS, ordered=True
    )
    df_filtered = df_filtered.sort_values("Contrast", kind="stable")

    # Back to plain strings with line breaks for tick labels. The plots presumably
    # rely on the row order set above (seaborn orders string categories by first
    # appearance).
    df_filtered["Contrast"] = (
        df_filtered["Contrast"].astype(str).str.replace(" ", "\n", regex=False)
    )

    return df_filtered



## PLOT RESULTS ONLY FOR 7 REPRESENTATIVE TASK CONTRAST MAPS
# One figure per metric, restricted to INCLUDE_CONTRASTS (in that order).
# Spec fields: source table, seaborn plotter, y column, y-limits, y-label,
# output file, and whether to reserve extra bottom margin.
for plot_source, plotter, y_column, y_limits, y_title, out_name, pad_bottom in (
    (corr_df, sns.boxplot, "Corr", (0, 0.9), "Reconstruction", "recon_7_maps.svg", False),
    (fp_df, sns.pointplot, "Fingerprint", (-0.05, 1.05), "Discriminability", "disc_7_maps.svg", True),
):
    df_filtered = filter_and_sort_df(plot_source, INCLUDE_CONTRASTS)
    plt.figure(figsize=(25, 10), dpi=300)
    plotter(data=df_filtered, x="Contrast", y=y_column, hue="Method", palette=PALETTE)
    plt.ylim(*y_limits)
    sns.despine(offset=10, trim=True)
    plt.ylabel(y_title)
    plt.legend()
    if pad_bottom:
        plt.subplots_adjust(bottom=0.2)
    plt.savefig(op.join(RESULTS_PATH, out_name), bbox_inches="tight")
    plt.show()

In [None]:
## Post-hoc Comparisons between models in terms of reconstruction accuracy and discriminability
# Reconstruction accuracy: per-contrast paired permutation t-tests
# (1000 permutations) of DeepTaskGen against each baseline.
# Discriminability has a single score per contrast and method, so it is
# summarized by counting the contrasts where DeepTaskGen scores higher.
# NOTE(review): the p-values are not corrected for multiple comparisons
# (len(CONTRASTS) x len(COMPARE_MODELS) tests); consider FDR before reporting.

COMPARE_MODELS = ("Average", "Retest", "Linear Regression")

# RECONSTRUCTION ACCURACY
corr_ttest = compute_statistics(corr_df, CONTRASTS, compare_models=COMPARE_MODELS)
corr_ttest.to_csv(
    op.join(RESULTS_PATH, "recon_ttest.csv"),
    float_format="%.3f",
    decimal=".",
    sep=";",
)
# Per baseline: number of contrasts where DeepTaskGen is significantly better
# (t > 0), significantly worse (t < 0), or not significantly different.
# The loop must use the same labels as the "Model" column; the former
# "Tavor" label matched no rows and always printed zeros.
print("\nReconstruction Accuracy")
for model in COMPARE_MODELS:
    model_stats = corr_ttest[corr_ttest["Model"] == model]
    significant = model_stats["p_value"] < 0.05
    print(
        f"Model: {model}, "
        f"Positive: {int((significant & (model_stats['t_stat'] > 0)).sum())}, "
        f"Negative: {int((significant & (model_stats['t_stat'] < 0)).sum())}, "
        f"Non-significant: {int((model_stats['p_value'] >= 0.05).sum())}"
    )


# FINGERPRINTING SCORE
# Compare the in-memory, unrounded scores instead of re-reading disc.csv.
# That file is rounded to 3 decimals, which can turn near-equal scores into
# ties, and re-reading it into `fp_df` overwrote the table used by earlier cells.
print("\nDiscriminability")
fp_by_contrast = (
    fp_df.groupby(["Contrast", "Method"])["Fingerprint"].mean().unstack("Method")
)
for model in COMPARE_MODELS:
    # Columns are aligned on the Contrast index, not on positional sort order.
    n_better_fingerprint = int(
        (fp_by_contrast["DeepTaskGen"] > fp_by_contrast[model]).sum()
    )
    print(
        f"Model: {model}, Number of greater fingerprinting scores: {n_better_fingerprint}"
    )