In [None]:
import os
import os.path as op
import pickle
import sys
import warnings
from collections import defaultdict
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import permutation_test, ttest_rel
from sklearn.exceptions import InconsistentVersionWarning
from sklearn.metrics import balanced_accuracy_score

sys.path.append(op.abspath(op.join(op.abspath(""), "..")))
from utils.utils import correlation_score

# Plot style shared by every figure in this notebook.
sns.set_style("ticks")
sns.set_context("talk", font_scale=1, rc={"axes.labelpad": 10})

# Show floats in DataFrame reprs with three significant digits ("{:.3}").
pd.set_option("display.float_format", "{:.3}".format)
# Suppress FutureWarning messages for the rest of the session.
warnings.filterwarnings("ignore", category=FutureWarning)

In [None]:
# Project root: parent of the notebook's working directory. This is the same
# expression the import cell appends to sys.path, recomputed here instead of
# reading sys.path[-1], which breaks if anything else appends to sys.path.
ABS_PATH = op.abspath(op.join(op.abspath(""), ".."))
RESULTS_PATH = op.join(ABS_PATH, "3_prediction", "results")
FIG_DIR = op.join(ABS_PATH, "3_prediction", "figures")
os.makedirs(FIG_DIR, exist_ok=True)

# Point colours for the CV-score plots; keys are the hue labels.
PALETTE = {
    "Actual": "#283F94",
    "Predicted": "#AE3033",
}

# Display names of all contrasts (not referenced by the cells below).
CONTRASTS = (
    "REST",
    "EMOTION FACES-SHAPES",
    "GAMBLING REWARD",
    "LANGUAGE MATH-STORY",
    "RELATIONAL REL",
    "SOCIAL TOM-RANDOM",
    "WM 2BK-0BK",
    "MOTOR AVG",
)

# Contrasts per results dataset. Each key is both the sub-directory of
# RESULTS_PATH and the filename prefix of that dataset's result pickles.
CONTRASTS_MAP = {
    "ukb_actual": (
        "rest",
        "emotion_faces-shapes",
    ),
    "ukb_pred": (
        "emotion_faces-shapes",
        "gambling_reward",
        "language_math-story",
        "relational_rel",
        "social_tom-random",
        "wm_2bk-0bk",
        "motor_avg",
    ),
}

# Per-fold scoring function for each prediction target: balanced accuracy for
# the binary-labelled targets, the project's correlation_score for the rest.
SCORE_FUNCS = {
    "age": correlation_score,
    "fluid": correlation_score,
    "sex": balanced_accuracy_score,
    "strength": correlation_score,
    "overall_health": correlation_score,
    "alcohol_freq": correlation_score,
    "beer_freq": correlation_score,
    "depression": balanced_accuracy_score,
    "hypertension": balanced_accuracy_score,
    "neuroticism": correlation_score,
    "GAD": correlation_score,
    "PHQ": correlation_score,
    "RDS": correlation_score,
}

In [None]:
# Turn off sklearn's InconsistentVersionWarning (raised when loading objects
# pickled under a different sklearn version) and FutureWarnings.
warnings.simplefilter("ignore", category=InconsistentVersionWarning)
warnings.simplefilter("ignore", category=FutureWarning)


# Prepare dataframes for plotting.
def results_files_dict():
    """Locate the permutation-results pickle for every (target, dataset, contrast).

    For each target in ``SCORE_FUNCS`` and each contrast in ``CONTRASTS_MAP``,
    searches ``RESULTS_PATH/<dset>/<dset>_<target>_<contrast>_ridge*_perm.pkl``.

    Returns:
        defaultdict: ``files[target][dset][contrast] -> path``. Combinations
        without a matching file are reported on stdout and left out. When
        several files match, the lexicographically first one is used, and
        that choice is reported.
    """
    pred_scores = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    for target in SCORE_FUNCS:
        for dset, contrasts in CONTRASTS_MAP.items():
            for cont in contrasts:
                # Normalise display-style names (spaces, upper case) to file style.
                cont = cont.replace(" ", "_").lower()
                pattern = op.join(
                    RESULTS_PATH,
                    dset,
                    f"{dset}_{target}_{cont}_ridge*_perm.pkl",
                )
                # glob() order is filesystem-dependent; sort so the pick is stable.
                matches = sorted(glob(pattern))
                if not matches:
                    print(f"Could not find {dset}_{target}_{cont}.pkl")
                    continue
                if len(matches) > 1:
                    print(
                        f"Multiple results for {dset}_{target}_{cont}; "
                        f"using {op.basename(matches[0])}"
                    )
                pred_scores[target][dset][cont] = matches[0]
    return pred_scores


def compute_scores(pred_scores):
    """Score every fold of every prediction run stored in the results pickles.

    Args:
        pred_scores (dict): ``pred_scores[target][dset][task] -> pickle path``,
            as returned by ``results_files_dict``. Each pickle is indexed as a
            sequence of dicts holding per-fold ``"y_true"`` / ``"y_pred"``.
            Element 0 is used as the observed run (presumably the unpermuted
            one; TODO confirm against the script writing these files).

    Returns:
        tuple: ``(cv_scores, perm_scores)``.
            - ``cv_scores[target]``: long DataFrame, one row per fold of run 0,
              with columns ``CV Score``, ``Contrast``, ``Dataset``, ``Actual``
              (an empty DataFrame if the target has no files).
            - ``perm_scores[target][dset][task]``: list (one entry per run) of
              per-fold scores from ``SCORE_FUNCS[target]``.
    """
    perm_scores = {}
    cv_scores = {}
    for target, target_files in pred_scores.items():
        score_func = SCORE_FUNCS[target]
        perm_scores[target] = {}
        frames = []  # collected and concatenated once (avoids quadratic concat)
        for dset, task_files in target_files.items():
            perm_scores[target][dset] = {}
            for i, (task, fname) in enumerate(task_files.items()):
                # NOTE: pickle.load can execute arbitrary code; only load
                # result files produced by this project.
                with open(fname, "rb") as fh:
                    preds = pickle.load(fh)
                if i == 0:
                    print(
                        f"Dataset: {dset}, Target: {target}, n = {np.hstack(preds[0]['y_true']).shape[0]}"
                    )
                run_scores = [
                    [
                        score_func(y_true, y_pred)
                        for y_true, y_pred in zip(pred["y_true"], pred["y_pred"])
                    ]
                    for pred in preds
                ]
                perm_scores[target][dset][task] = run_scores
                # NOTE(review): with the current CONTRASTS_MAP keys
                # ("ukb_actual", "ukb_pred") this label is "Actual" for both.
                frames.append(
                    pd.DataFrame(run_scores[0], columns=["CV Score"]).assign(
                        Contrast=task.replace("_", "\n").upper(),
                        Dataset=dset,
                        Actual="Actual"
                        if "finetune" not in dset
                        else dset.split("_")[1].capitalize(),
                    )
                )
        cv_scores[target] = pd.concat(frames, axis=0) if frames else pd.DataFrame()
    return cv_scores, perm_scores


# Locate every results pickle, then score all folds of all runs.
results_files = results_files_dict()
cv_scores, perm_scores = compute_scores(results_files)

In [None]:
## HELPER FUNCTIONS
def comp_perm_sig_cv(perms):
    """Mean observed CV score and its permutation p-value.

    Args:
        perms: Sequence of per-fold score lists. ``perms[0]`` is the observed
            run; the remaining entries are permutation runs.

    Returns:
        tuple: ``(cv, p)``. ``cv`` is the mean of ``perms[0]``. ``p`` is the
        fraction of all runs, the observed run included, whose mean score is
        >= ``cv``.
    """
    run_means = np.array([np.mean(run) for run in perms])
    observed = run_means[0]
    return observed, np.mean(run_means >= observed)


def comp_perm_sig_ttest(main, comp):
    """Paired t statistic between two score sets and its permutation p-value.

    Args:
        main, comp: Sequences of per-fold score lists, aligned run-by-run;
            index 0 is the observed run, the rest are permutation runs.

    Returns:
        tuple: ``(t, p)``. ``t`` is the paired t statistic of ``main[0]`` vs
        ``comp[0]``. ``p`` is one-sided in the observed direction: the fraction
        of runs (observed included) with t >= ``t`` when ``t > 0``, otherwise
        with t <= ``t``. If ``t`` is NaN (e.g. identical score vectors), ``p``
        is NaN. Every comparison against NaN is False, so a NaN ``t`` would
        otherwise yield a spurious p = 0.
    """
    t = ttest_rel(main[0], comp[0])[0]
    if np.isnan(t):
        return t, np.nan
    perm_t = np.array([ttest_rel(m, c)[0] for m, c in zip(main, comp)])
    if t > 0:
        return t, np.mean(perm_t >= t)
    return t, np.mean(perm_t <= t)


def cliffs_delta(group1, group2):
    """
    Cliff's delta effect size of ``group1`` relative to ``group2``.

    Parameters:
    group1, group2 : array-like
        The two samples to compare.

    Returns:
    float : (number of pairs with x > y minus number with x < y) divided by
        len(group1) * len(group2), over all pairs (x from group1, y from
        group2); lies in [-1, +1].
    """
    x = np.array(group1)
    y = np.array(group2)
    # Broadcast to a len(x)-by-len(y) comparison matrix.
    x_col = x[:, None]
    n_pairs = len(x) * len(y)
    return (np.sum(x_col > y) - np.sum(x_col < y)) / n_pairs


def _fdr_bh(pvals):
    """Benjamini-Hochberg FDR-adjusted p-values; NaN inputs stay NaN."""
    p = np.asarray(pvals, dtype=float)
    adjusted = np.full(p.shape, np.nan)
    valid = ~np.isnan(p)
    n = int(valid.sum())
    if n == 0:
        return adjusted
    p_valid = p[valid]
    order = np.argsort(p_valid)
    # p_(i) * n / i, made monotone by a running minimum from the largest p down.
    scaled = p_valid[order] * n / np.arange(1, n + 1)
    scaled = np.minimum.accumulate(scaled[::-1])[::-1]
    result = np.empty(n)
    result[order] = np.minimum(scaled, 1.0)
    adjusted[valid] = result
    return adjusted


def compare_tasks(
    perm_scores,
    target,
    pred_dset="ukb_pred",
    actual_dset="ukb_actual",
    rest_key="rest",
    ref_key="emotion_faces-shapes",
):
    """
    Compare predicted and actual contrasts for a given target.

    Every contrast in ``CONTRASTS_MAP[pred_dset]`` is compared with every
    contrast in ``CONTRASTS_MAP[actual_dset]``. One more row compares the
    actual ``rest_key`` contrast against the actual ``ref_key`` contrast.
    Each comparison reports the paired permutation t-test from
    ``comp_perm_sig_ttest`` and Cliff's delta on the run-0 fold scores.
    Comparisons whose scores were not loaded are reported and skipped.

    Args:
        perm_scores (dict): ``perm_scores[target][dset][contrast]`` -> list of
            per-fold score lists (index 0 = observed run).
        target (str): Prediction target.
        pred_dset (str): Dataset key holding the predicted contrasts.
        actual_dset (str): Dataset key holding the actual contrasts.
        rest_key (str): Actual contrast used as the "Predicted Contrast" of the
            extra actual-vs-actual row.
        ref_key (str): Actual contrast it is compared against.

    Returns:
        pandas.DataFrame: Columns ``Predicted Contrast``, ``Actual Contrast``,
        ``t``, ``p``, ``delta`` and ``p_value_fdr``. The last column holds the
        Benjamini-Hochberg adjustment of ``p`` across all rows.
    """
    target_scores = perm_scores.get(target, {})
    pred_scores = target_scores.get(pred_dset, {})
    actual_scores = target_scores.get(actual_dset, {})

    def _compare(main_name, main, comp_name, comp):
        # One result row: permutation t-test plus Cliff's delta on run 0.
        t, t_sig = comp_perm_sig_ttest(main, comp)
        return {
            "Predicted Contrast": main_name,
            "Actual Contrast": comp_name,
            "t": t,
            "p": t_sig,
            "delta": cliffs_delta(main[0], comp[0]),
        }

    tasks_comp = []
    for task in CONTRASTS_MAP.get(pred_dset, ()):
        for comp_task in CONTRASTS_MAP.get(actual_dset, ()):
            if task not in pred_scores or comp_task not in actual_scores:
                print(f"Skipping {task} vs {comp_task} for {target}: scores not loaded")
                continue
            tasks_comp.append(
                _compare(task, pred_scores[task], comp_task, actual_scores[comp_task])
            )

    if rest_key in actual_scores and ref_key in actual_scores:
        tasks_comp.append(
            _compare(rest_key, actual_scores[rest_key], ref_key, actual_scores[ref_key])
        )
    else:
        print(f"Skipping {rest_key} vs {ref_key} for {target}: scores not loaded")

    perm_results = pd.DataFrame(
        tasks_comp,
        columns=["Predicted Contrast", "Actual Contrast", "t", "p", "delta"],
    )
    perm_results["p_value_fdr"] = _fdr_bh(perm_results["p"])
    return perm_results


def print_permutation_significance(perm_scores, target):
    """
    Print and collect the permutation significance of each contrast's CV score.

    For every contrast in ``CONTRASTS_MAP`` (dataset by dataset), computes the
    mean observed CV score and its permutation p-value with
    ``comp_perm_sig_cv`` and prints them. Contrasts with no loaded scores, for
    example because their results file was not found, are reported and
    skipped.

    Args:
        perm_scores (dict): Permutation scores indexed as
            ``perm_scores[target][dset][task]``.
        target (str): The target for which permutation significance is calculated.

    Returns:
        pandas.DataFrame: One row per contrast with columns ``Dataset``,
        ``Contrast``, ``CV Score``, ``p`` and ``Actual``. The columns are
        present even when no contrast could be evaluated.
    """
    columns = ["Dataset", "Contrast", "CV Score", "p", "Actual"]
    target_scores = perm_scores.get(target, {})
    sigs = []
    for dset, tasks in CONTRASTS_MAP.items():
        for task in tasks:
            task_scores = target_scores.get(dset, {}).get(task)
            if task_scores is None:
                print(f"No permutation scores for {task} in {dset}; skipping")
                continue
            stat, sig = comp_perm_sig_cv(task_scores)
            print(
                f"Permutation significance for {task} in {dset}: {stat:.3f}, p = {sig:.3f}"
            )
            sigs.append(
                {
                    "Dataset": dset,
                    "Contrast": task.upper().replace("_", "\n"),
                    "CV Score": stat,
                    "p": sig,
                    "Actual": "Actual"
                    if "finetune" not in dset
                    else dset.split("_")[1].capitalize(),
                }
            )
    return pd.DataFrame(sigs, columns=columns)


## Function for plotting CV scores
def plot_cv_scores(target, ylim, sigs=None, alpha_level=0.05):
    """Plot per-contrast CV scores (mean ± SD over folds) for ``target`` and save as PDF.

    Points are coloured by ``PALETTE``. Rows from a dataset whose name ends in
    ``"_pred"`` are labelled "Predicted"; all others are labelled "Actual".
    When ``sigs`` is given, contrasts with permutation p < ``alpha_level`` are
    drawn at alpha 0.75 and the rest at 0.25. A contrast without a matching
    p-value counts as non-significant. Tick labels, axis labels and the legend
    are hidden. The figure is saved to ``FIG_DIR/<target>.pdf``.

    Args:
        target (str): Key into the global ``cv_scores``.
        ylim (tuple): y-axis limits.
        sigs (pandas.DataFrame, optional): Per-contrast significance with
            ``Dataset``, ``Contrast`` and ``p`` columns (e.g. the output of
            ``print_permutation_significance``). It is not modified.
        alpha_level (float): Significance threshold for the opacity split.
    """
    plt.figure(figsize=(11, 6), dpi=300)
    plot_df = cv_scores[target].dropna().copy()
    # Derive the hue from the dataset name so actual and predicted contrasts
    # get distinct PALETTE colours. The upstream "Actual" column evaluates to
    # "Actual" for both "ukb_actual" and "ukb_pred".
    plot_df["Actual"] = plot_df["Dataset"].map(
        lambda dset: "Predicted" if str(dset).endswith("_pred") else "Actual"
    )
    if sigs is not None:
        plot_df = plot_df.merge(
            sigs[["Dataset", "Contrast", "p"]],
            on=["Dataset", "Contrast"],
            how="left",
        )
        # A NaN p compares False, so an unmatched contrast is drawn faded.
        plot_df["Alpha"] = np.where(plot_df["p"] < alpha_level, 0.75, 0.25)
    else:
        plot_df["Alpha"] = 0.75
    unique_contrasts = plot_df["Contrast"].unique()
    unique_actual = plot_df["Actual"].unique()
    # One pointplot per (contrast, label) so each can carry its own alpha.
    for contrast in unique_contrasts:
        for actual in unique_actual:
            filter_df = plot_df[
                (plot_df["Contrast"] == contrast) & (plot_df["Actual"] == actual)
            ]
            if not filter_df.empty:
                sns.pointplot(
                    data=filter_df,
                    x="Contrast",
                    y="CV Score",
                    hue="Actual",
                    errorbar="sd",
                    linestyles="",
                    palette=PALETTE,
                    alpha=filter_df["Alpha"].iloc[0],  # Set alpha for each contrast
                ).set(
                    xticklabels=[],
                    yticklabels=[],
                )
    plt.ylim(ylim)
    sns.despine(offset=10, trim=True)
    plt.legend(loc="right", bbox_to_anchor=(1.15, 0.5)).set_visible(False)
    plt.xlabel("")
    plt.ylabel("")
    plt.xticks(rotation=45)
    plt.savefig(op.join(FIG_DIR, f"{target}.pdf"), bbox_inches="tight")


def print_formatted_compare_tasks(results_df):
    """
    Print task-comparison results as a fixed-width text table.

    The upper section has one row per predicted contrast. Each row shows its
    comparison (t, FDR-corrected p, Cliff's delta) against the actual
    EMOTION FACES-SHAPES contrast and against the resting-state contrast.
    A final "Actual" row compares the resting-state contrast against
    EMOTION FACES-SHAPES. Missing values print as "-".

    The resting-state contrast is recognised under the key "rest" or
    "connectome_d50"; "rest" is preferred when both occur.

    Args:
        results_df (pd.DataFrame): DataFrame with columns including
                                   'Predicted Contrast', 'Actual Contrast',
                                   't', 'p_value_fdr', and 'delta'
                                   (e.g. the output of ``compare_tasks``).
                                   If any are missing, an error is printed
                                   and nothing else.
    """
    actual_emotion_key = "emotion_faces-shapes"
    # Keys under which the resting-state contrast may appear, in preference order.
    connectome_key_candidates = ("rest", "connectome_d50")

    # --- Name mapping for display ---
    name_map = {
        "emotion_faces-shapes": "EMOTION FACES-SHAPES",
        "gambling_reward": "GAMBLING REWARD",
        "language_math-story": "LANGUAGE MATH-STORY",
        "motor_avg": "MOTOR AVG",
        "relational_rel": "RELATIONAL REL",
        "social_tom-random": "SOCIAL TOM-RANDOM",
        "wm_2bk-0bk": "WM 2BK-0BK",
        "rest": "Resting State Connectome",
        "connectome_d50": "Resting State Connectome",
    }

    # --- Verify input columns ---
    required_cols = [
        "Predicted Contrast",
        "Actual Contrast",
        "t",
        "p_value_fdr",
        "delta",
    ]
    missing = [col for col in required_cols if col not in results_df.columns]
    if missing:
        print(f"Error: Input DataFrame is missing required columns: {missing}")
        return

    # Use whichever resting-state key the results actually contain.
    present_keys = set(results_df["Actual Contrast"]) | set(
        results_df["Predicted Contrast"]
    )
    actual_connectome_key = next(
        (key for key in connectome_key_candidates if key in present_keys),
        connectome_key_candidates[0],
    )

    # --- Prepare data views (indexed by the predicted contrast) ---
    stat_cols = ["t", "p_value_fdr", "delta"]
    vs_emotion = results_df[
        results_df["Actual Contrast"] == actual_emotion_key
    ].set_index("Predicted Contrast")
    vs_connectome = results_df[
        results_df["Actual Contrast"] == actual_connectome_key
    ].set_index("Predicted Contrast")

    df_emo = vs_emotion[stat_cols].rename(
        columns={"t": "t_emo", "p_value_fdr": "p_emo", "delta": "delta_emo"}
    )
    df_conn = vs_connectome[stat_cols].rename(
        columns={"t": "t_conn", "p_value_fdr": "p_conn", "delta": "delta_conn"}
    )
    merged_df = pd.merge(
        df_emo, df_conn, left_index=True, right_index=True, how="outer"
    )

    # The resting-state row belongs to the "Actual" section, not the main table.
    main_table = merged_df.loc[merged_df.index != actual_connectome_key].copy()

    # Resting state vs EMOTION FACES-SHAPES (the "Actual" row), if present.
    last_row_stats = None
    if actual_connectome_key in vs_emotion.index:
        last_row_stats = vs_emotion.loc[actual_connectome_key, stat_cols]

    # --- Formatting helpers (NaN -> "-") ---
    def format_t(x):
        return f"{x:.2f}" if pd.notna(x) else "-"

    def format_p(x):
        return f"{x:.3f}" if pd.notna(x) else "-"

    def format_delta_emo(x):
        return f"{x:.1f}" if pd.notna(x) else "-"  # Delta vs Emotion (e.g., 1.0)

    def format_delta_conn(x):
        return f"{x:.2f}" if pd.notna(x) else "-"  # Delta vs Connectome (e.g., 0.84)

    emo_hdr = "vs. Actual EMOTION FACES-SHAPES"
    conn_hdr = "vs. Resting State Connectome"
    # Column order must follow merged_df: emotion stats, then connectome stats.
    column_formatters = [
        ((emo_hdr, "t"), format_t),
        ((emo_hdr, "p"), format_p),  # shows p_value_fdr under the header 'p'
        ((emo_hdr, "δ"), format_delta_emo),
        ((conn_hdr, "t"), format_t),
        ((conn_hdr, "p"), format_p),  # shows p_value_fdr under the header 'p'
        ((conn_hdr, "δ"), format_delta_conn),
    ]

    main_table.index = main_table.index.map(lambda x: name_map.get(x, x))
    main_table.index.name = "Task Contrast"
    main_table.columns = pd.MultiIndex.from_tuples(
        [col for col, _ in column_formatters]
    )
    for col, fmt in column_formatters:
        main_table[col] = main_table[col].map(fmt)

    # --- Print the 'Predicted' section ---
    print(" " * 25 + f"| {emo_hdr:^30} | {conn_hdr:^30}")
    print(
        f"{'Task Contrast':<25} | {'t':>8} {'p':>10} {'δ':>8} | {'t':>8} {'p':>10} {'δ':>8}"
    )
    print("-" * 85)

    for index, row in main_table.iterrows():
        print(
            f"{index:<25} | "
            f"{row[(emo_hdr, 't')]:>8} "
            f"{row[(emo_hdr, 'p')]:>10} "
            f"{row[(emo_hdr, 'δ')]:>8} | "
            f"{row[(conn_hdr, 't')]:>8} "
            f"{row[(conn_hdr, 'p')]:>10} "
            f"{row[(conn_hdr, 'δ')]:>8}"
        )

    print("-" * 85)

    # --- Print the 'Actual' section (resting state vs EMOTION) ---
    print(f"{'Actual':<25} |")
    actual_row_name = name_map.get(actual_connectome_key, actual_connectome_key)
    if last_row_stats is not None:
        print(
            f"{actual_row_name:<25} | "
            f"{format_t(last_row_stats['t']):>8} "
            f"{format_p(last_row_stats['p_value_fdr']):>10} "
            f"{format_delta_emo(last_row_stats['delta']):>8} | "
            f"{'-':>8} {'-':>10} {'-':>8}"
        )  # No comparison vs connectome here
    else:
        print(f"{actual_row_name:<25} | {'Data not found':^60}")

    print("-" * 85)

In [None]:
# Age: permutation significance, CV-score figure and task-comparison table.
target, ylim = "age", (0.3, 0.61)
sigs = print_permutation_significance(perm_scores, target)
plot_cv_scores(target, ylim, sigs=sigs)
comp_results = compare_tasks(perm_scores, target)
print_formatted_compare_tasks(comp_results)

In [None]:
# Sex: permutation significance, CV-score figure and task-comparison table.
target, ylim = "sex", (0.5, 1)
sigs = print_permutation_significance(perm_scores, target)
plot_cv_scores(target, ylim, sigs=sigs)
comp_results = compare_tasks(perm_scores, target)
print_formatted_compare_tasks(comp_results)

In [None]:
# Fluid Intelligence: permutation significance, CV-score figure and comparisons.
target, ylim = "fluid", (-0.1, 0.301)
sigs = print_permutation_significance(perm_scores, target)
plot_cv_scores(target, ylim, sigs=sigs)
comp_results = compare_tasks(perm_scores, target)
print_formatted_compare_tasks(comp_results)

In [None]:
# Grip Strength: permutation significance, CV-score figure and comparisons.
target, ylim = "strength", (0, 0.601)
sigs = print_permutation_significance(perm_scores, target)
plot_cv_scores(target, ylim, sigs=sigs)
comp_results = compare_tasks(perm_scores, target)
print_formatted_compare_tasks(comp_results)

In [None]:
# Overall Health: permutation significance, CV-score figure and comparisons.
target, ylim = "overall_health", (-0.2, 0.2)
sigs = print_permutation_significance(perm_scores, target)
plot_cv_scores(target, ylim, sigs=sigs)
comp_results = compare_tasks(perm_scores, target)
print_formatted_compare_tasks(comp_results)

In [None]:
# Depression: permutation significance, CV-score figure and comparisons.
target, ylim = "depression", (0.4, 0.601)
sigs = print_permutation_significance(perm_scores, target)
plot_cv_scores(target, ylim, sigs=sigs)
comp_results = compare_tasks(perm_scores, target)
print_formatted_compare_tasks(comp_results)

In [None]:
# Hypertension: permutation significance, CV-score figure and comparisons.
target, ylim = "hypertension", (0.4, 0.601)
sigs = print_permutation_significance(perm_scores, target)
plot_cv_scores(target, ylim, sigs=sigs)
comp_results = compare_tasks(perm_scores, target)
print_formatted_compare_tasks(comp_results)

In [None]:
# Alcohol Freq: permutation significance, CV-score figure and comparisons.
target, ylim = "alcohol_freq", (-0.2, 0.2)
sigs = print_permutation_significance(perm_scores, target)
plot_cv_scores(target, ylim, sigs=sigs)
comp_results = compare_tasks(perm_scores, target)
print_formatted_compare_tasks(comp_results)

In [None]:
# Beer Freq: permutation significance, CV-score figure and comparisons.
target, ylim = "beer_freq", (0, 0.36)
sigs = print_permutation_significance(perm_scores, target)
plot_cv_scores(target, ylim, sigs=sigs)
comp_results = compare_tasks(perm_scores, target)
print_formatted_compare_tasks(comp_results)

In [None]:
# Neuroticism: permutation significance, CV-score figure and comparisons.
target, ylim = "neuroticism", (-0.2, 0.2)
sigs = print_permutation_significance(perm_scores, target)
plot_cv_scores(target, ylim, sigs=sigs)
comp_results = compare_tasks(perm_scores, target)
print_formatted_compare_tasks(comp_results)

In [None]:
# GAD: permutation significance, CV-score figure and comparisons.
target, ylim = "GAD", (-0.2, 0.2)
sigs = print_permutation_significance(perm_scores, target)
plot_cv_scores(target, ylim, sigs=sigs)
comp_results = compare_tasks(perm_scores, target)
print_formatted_compare_tasks(comp_results)

In [None]:
# PHQ: permutation significance, CV-score figure and comparisons.
target, ylim = "PHQ", (-0.2, 0.2)
sigs = print_permutation_significance(perm_scores, target)
plot_cv_scores(target, ylim, sigs=sigs)
comp_results = compare_tasks(perm_scores, target)
print_formatted_compare_tasks(comp_results)

In [None]:
# RDS: permutation significance, CV-score figure and comparisons.
target, ylim = "RDS", (-0.2, 0.2)
sigs = print_permutation_significance(perm_scores, target)
plot_cv_scores(target, ylim, sigs=sigs)
comp_results = compare_tasks(perm_scores, target)
print_formatted_compare_tasks(comp_results)