In [None]:
import os
import os.path as op
import pickle
import sys
import warnings
from collections import defaultdict
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import permutation_test, ttest_rel
from sklearn.exceptions import InconsistentVersionWarning
from sklearn.metrics import balanced_accuracy_score
from scipy.stats import ttest_rel

sys.path.append(op.abspath(op.join(op.abspath(""), "..")))
from utils.utils import correlation_score

# Plot styling: seaborn "ticks" style with the larger "talk" context.
sns.set_style("ticks")
sns.set_context("talk", font_scale=1, rc={"axes.labelpad": 10})

# Show DataFrame floats with three significant digits.
pd.set_option("display.float_format", "{:.3}".format)

# Keep FutureWarning noise out of the cell outputs.
warnings.filterwarnings("ignore", category=FutureWarning)

In [None]:
# Project root = parent of the directory the notebook is run from. This is the
# same expression that is appended to sys.path for the local import; computing
# it here avoids reading sys.path[-1], which points somewhere else as soon as
# anything appends another entry to sys.path.
ABS_PATH = op.abspath(op.join(op.abspath(""), ".."))
# Directory the pickled prediction results are read from.
RESULTS_PATH = op.join(ABS_PATH, "3_prediction/results")
# Directory the figures are written to.
FIG_DIR = op.join(ABS_PATH, "3_prediction/figures")
# Make sure the figure directory exists before any figure is saved into it.
os.makedirs(FIG_DIR, exist_ok=True)

# Point colors per score source. The keys match the values written to the
# "Actual" column, which the plotting code passes as the hue variable.
PALETTE = {
    "Actual": "#283F94",
    "Predicted": "#AE3033",
}

# Upper-case contrast labels.
# NOTE(review): not referenced by any other cell visible in this notebook —
# confirm whether it is still needed.
CONTRASTS = (
    "REST",
    "EMOTION FACES-SHAPES",
    "GAMBLING REWARD",
    "LANGUAGE MATH-STORY",
    "RELATIONAL REL",
    "SOCIAL TOM-RANDOM",
    "WM 2BK-0BK",
    "MOTOR AVG",
)

# Contrast names available per dataset, in the spelling used in the result
# file names. The dataset keys are also the sub-directory names under
# RESULTS_PATH and the prefix of every result file.
CONTRASTS_MAP = {
    "ukb_actual": (
        "rest",
        "emotion_faces-shapes",
    ),
    "ukb_pred": (
        "emotion_faces-shapes",
        "gambling_reward",
        "language_math-story",
        "relational_rel",
        "social_tom-random",
        "wm_2bk-0bk",
        "motor_avg",
    ),
}

# Scoring function per prediction target, called as func(y_true, y_pred).
# The key order is also the order in which targets are looked up and scored.
# Presumably balanced accuracy is used for the categorical targets and
# correlation for the continuous ones — TODO confirm.
SCORE_FUNCS = {
    "age": correlation_score,
    "fluid": correlation_score,
    "sex": balanced_accuracy_score,
    "strength": correlation_score,
    "overall_health": correlation_score,
    "alcohol_freq": correlation_score,
    "depression": balanced_accuracy_score,
    "neuroticism": correlation_score,
    "GAD": correlation_score,
    "PHQ": correlation_score,
    "RDS": correlation_score,
}

In [None]:
# Turn off scikit-learn version-mismatch warnings and FutureWarnings.
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Prepare dataframes for plotting.
def results_files_dict():
    """Locate the pickled prediction results for every target/dataset/contrast.

    For each target in ``SCORE_FUNCS`` and each dataset/contrast pair in
    ``CONTRASTS_MAP`` this looks for
    ``<RESULTS_PATH>/<dataset>/<dataset>_<target>_<contrast>.pkl``.
    Files that do not exist are reported on stdout and skipped.

    Returns:
        defaultdict: Nested mapping ``target -> dataset -> contrast -> path``
        holding only the files that were found.
    """
    pred_scores = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    for target in SCORE_FUNCS:
        for dset, contrasts in CONTRASTS_MAP.items():
            for cont in contrasts:
                # File names use the lower-case, underscore-separated spelling.
                cont = cont.replace(" ", "_").lower()
                stem = f"{dset}_{target}_{cont}"
                matches = glob(op.join(RESULTS_PATH, dset, f"{stem}.pkl"))
                # An empty match list is the only expected failure; checking it
                # explicitly keeps unrelated errors (and Ctrl-C) visible
                # instead of hiding them behind a bare ``except``.
                if not matches:
                    print(f"Could not find {stem}.pkl")
                    continue
                pred_scores[target][dset][cont] = matches[0]
    return pred_scores


def compute_scores(pred_scores):
    """Score every stored prediction run and collect the results.

    Args:
        pred_scores (dict): Nested mapping ``target -> dataset -> contrast ->
            path`` of pickled prediction files. Each file must hold a sequence
            of runs; every run provides ``"y_true"`` and ``"y_pred"``
            sequences that are scored pairwise (presumably one pair per CV
            fold — confirm against the code that writes the pickles).

    Returns:
        tuple: ``(cv_scores, perm_scores)`` where

        * ``cv_scores[target]`` is a DataFrame with the columns ``CV Score``,
          ``Contrast``, ``Dataset`` and ``Actual`` built from the first run of
          every file (an empty DataFrame if the target has no files), and
        * ``perm_scores[target][dataset][contrast]`` is a list holding one
          list of scores per run.
    """
    perm_scores = {}
    cv_scores = {}
    for target, dsets in pred_scores.items():
        perm_scores[target] = {}
        frames = []
        for dset, files in dsets.items():
            perm_scores[target][dset] = {}
            for i, (task, path) in enumerate(files.items()):
                # NOTE: unpickling runs arbitrary code; only load result files
                # you produced yourself. The context manager closes the handle
                # that a bare ``pickle.load(open(...))`` would leak.
                with open(path, "rb") as fh:
                    preds = pickle.load(fh)
                if i == 0:
                    # Report the sample size once per dataset/target.
                    print(
                        f"Dataset: {dset}, Target: {target}, n = {np.hstack(preds[0]['y_true']).shape[0]}"
                    )
                score = SCORE_FUNCS[target]
                run_scores = [
                    [
                        score(y_true, y_pred)
                        for y_true, y_pred in zip(run["y_true"], run["y_pred"])
                    ]
                    for run in preds
                ]
                perm_scores[target][dset][task] = run_scores
                # Only the first run feeds the CV-score table.
                frames.append(
                    pd.DataFrame(run_scores[0], columns=["CV Score"]).assign(
                        Contrast=task.replace("_", "\n").upper(),
                        Dataset="ukb",
                        Actual="Actual" if "pred" not in dset else "Predicted",
                    )
                )
        # Concatenate once instead of growing a frame inside the loop;
        # pd.concat rejects an empty list, so fall back to an empty frame.
        cv_scores[target] = pd.concat(frames, axis=0) if frames else pd.DataFrame()
    return cv_scores, perm_scores
    
# Locate and score all result files once; cv_scores is read by the plotting
# function and perm_scores is passed to the significance helpers below.
cv_scores, perm_scores = compute_scores(results_files_dict())

In [None]:
## HELPER FUNCTIONS
def comp_perm_sig_cv(perms):
    """Permutation p-value for a mean cross-validation score.

    Args:
        perms: Sequence of runs, each a sequence of scores. ``perms[0]`` is
            used as the observed run; every run (the observed one included)
            contributes one mean score to the reference distribution.

    Returns:
        tuple: ``(cv, p)`` — the mean score of ``perms[0]`` and the fraction
        of runs whose mean score is greater than or equal to it.
    """
    observed = np.mean(perms[0])
    run_means = np.asarray([np.mean(run) for run in perms])
    at_least_observed = run_means >= observed
    return observed, at_least_observed.mean()

def comp_perm_sig_ttest(main, comp):
    """Permutation p-value for a paired t-test between two sets of runs.

    Args:
        main: Sequence of runs (each a sequence of scores); ``main[0]`` is
            used as the observed run.
        comp: Sequence of runs paired with ``main`` run by run. Runs are
            paired with ``zip``, so surplus runs on either side are ignored.

    Returns:
        tuple: ``(t, p)`` — the paired t statistic of ``main[0]`` versus
        ``comp[0]`` and the fraction of run pairs (the observed pair included)
        whose t statistic is at least as extreme in the direction of ``t``
        (``>= t`` if ``t > 0``, otherwise ``<= t``). If ``t`` is NaN, e.g.
        because the paired differences have zero variance, ``p`` is NaN as
        well rather than 0.0, which would read as maximally significant.
    """
    t, _ = ttest_rel(main[0], comp[0])
    if np.isnan(t):
        # Every comparison against NaN is False, so the fraction would be 0.
        return t, np.nan
    perm_t = np.asarray([ttest_rel(m, c)[0] for m, c in zip(main, comp)])
    if t > 0:
        return t, (perm_t >= t).mean()
    return t, (perm_t <= t).mean()


def compare_tasks(perm_scores, target):
    """Compare every predicted contrast with every actual contrast for a target.

    Each pair is compared with ``comp_perm_sig_ttest``. A final row compares
    the two actual contrasts ``rest`` and ``emotion_faces-shapes`` with each
    other; in that row the ``Predicted Contrast`` column holds ``"rest"``
    although it is an actual contrast.

    Args:
        perm_scores (dict): Mapping ``target -> dataset -> contrast -> runs``
            of scores.
        target (str): Prediction target whose scores are compared (a key of
            ``perm_scores``).

    Returns:
        pandas.DataFrame: One row per comparison with the columns
        ``Predicted Contrast``, ``Actual Contrast``, ``t`` and ``p``.

    Raises:
        KeyError: If scores for one of the contrasts in ``CONTRASTS_MAP`` are
            missing for ``target``.
    """
    rows = []
    for task in CONTRASTS_MAP["ukb_pred"]:
        main = perm_scores[target]["ukb_pred"][task]
        for comp_task in CONTRASTS_MAP["ukb_actual"]:
            comp = perm_scores[target]["ukb_actual"][comp_task]
            t, t_sig = comp_perm_sig_ttest(main, comp)
            rows.append(
                {
                    "Predicted Contrast": task,
                    "Actual Contrast": comp_task,
                    "t": t,
                    "p": t_sig,
                }
            )
    # Reference comparison between the two actual contrasts.
    t, t_sig = comp_perm_sig_ttest(
        perm_scores[target]["ukb_actual"]["rest"],
        perm_scores[target]["ukb_actual"]["emotion_faces-shapes"],
    )
    rows.append(
        {
            "Predicted Contrast": "rest",
            "Actual Contrast": "emotion_faces-shapes",
            "t": t,
            "p": t_sig,
        }
    )
    return pd.DataFrame(rows)


def print_permutation_significance(perm_scores, target):
    """Print and tabulate the permutation significance of every contrast.

    Walks all dataset/contrast pairs in ``CONTRASTS_MAP``, evaluates
    ``comp_perm_sig_cv`` on the matching runs and prints one line per pair.

    Args:
        perm_scores (dict): Mapping ``target -> dataset -> contrast -> runs``
            of scores.
        target (str): Prediction target to evaluate.

    Returns:
        pandas.DataFrame: One row per dataset/contrast pair with the columns
        ``Dataset`` (dataset name without ``_pred``), ``Contrast`` (upper
        case, underscores turned into line breaks), ``CV Score``, ``p`` and
        ``Actual`` (``"Actual"`` or ``"Predicted"``).
    """
    rows = []
    for dset, tasks in CONTRASTS_MAP.items():
        for task in tasks:
            stat, sig = comp_perm_sig_cv(perm_scores[target][dset][task])
            message = (
                f"Permutation significance for {task} in {dset}: {stat:.3f}, p = {sig:.3f}"
            )
            print(message)
            source = "Predicted" if "pred" in dset else "Actual"
            row = {
                "Dataset": dset.replace("_pred", ""),
                "Contrast": task.upper().replace("_", "\n"),
                "CV Score": stat,
                "p": sig,
                "Actual": source,
            }
            rows.append(row)
    return pd.DataFrame(rows)

## Function for plotting CV scores
def plot_cv_scores(target, ylim=None, sigs=None):
    """Plot the cross-validation scores of one target and save the figure.

    Reads the module-level ``cv_scores[target]`` table, draws one point (with
    a standard-deviation error bar) per contrast and score source, and saves
    the figure as ``<FIG_DIR>/<target>.pdf``.

    Args:
        target (str): Prediction target to plot (a key of ``cv_scores``).
        ylim (tuple, optional): Y-axis limits. When omitted or empty the
            limits chosen by matplotlib are kept.
        sigs (pandas.DataFrame, optional): Significance table with the
            columns ``Dataset``, ``Contrast``, ``Actual``, ``CV Score`` and
            ``p``. Points with ``p >= 0.05`` are drawn with alpha 0.25, all
            others (including points without a matching row) with alpha 0.75.
            The frame passed in is not modified.
    """
    plt.figure(figsize=(11, 6), dpi=300)
    plot_df = cv_scores[target].dropna()
    if sigs is not None:
        # Build the merge table on a copy so the caller's frame keeps its
        # original "Dataset" labels. The score table labels every row "ukb",
        # so "ukb_actual" is mapped onto that name for the join.
        sig_cols = sigs.assign(
            Dataset=sigs["Dataset"].replace({"ukb_actual": "ukb"})
        ).drop(columns="CV Score")
        plot_df = plot_df.merge(
            sig_cols, on=["Dataset", "Contrast", "Actual"], how="left"
        )
        plot_df = plot_df.assign(
            Alpha=plot_df["p"].apply(lambda x: 0.25 if x >= 0.05 else 0.75)
        )
    else:
        plot_df = plot_df.assign(Alpha=0.75)
    # One pointplot call per contrast/source pair so that each point can get
    # its own alpha.
    for contrast in plot_df["Contrast"].unique():
        for actual in plot_df["Actual"].unique():
            subset = plot_df[
                (plot_df["Contrast"] == contrast) & (plot_df["Actual"] == actual)
            ]
            if subset.empty:
                continue
            sns.pointplot(
                data=subset,
                x="Contrast",
                y="CV Score",
                hue="Actual",
                errorbar="sd",
                linestyles="",
                palette=PALETTE,
                alpha=subset["Alpha"].iloc[0],
            )
    if ylim:
        plt.ylim(ylim)
    sns.despine(offset=10, trim=True)
    plt.legend(loc="right", bbox_to_anchor=(1.15, 0.5)).set_visible(False)
    plt.xlabel("")
    plt.ylabel("Score")
    plt.xticks(rotation=45)
    plt.savefig(op.join(FIG_DIR, f"{target}.pdf"), bbox_inches="tight")

In [None]:
# Age: permutation significance, CV-score figure, contrast comparison table.
target = "age"
sigs = print_permutation_significance(perm_scores, target=target)
plot_cv_scores(target=target, ylim=(0.3, 0.7), sigs=sigs)
compare_tasks(perm_scores, target=target)

In [None]:
# Sex: permutation significance, CV-score figure, contrast comparison table.
target = "sex"
sigs = print_permutation_significance(perm_scores, target=target)
plot_cv_scores(target=target, ylim=(0.5, 1), sigs=sigs)
compare_tasks(perm_scores, target=target)

In [None]:
# Fluid intelligence: permutation significance, CV-score figure, contrast
# comparison table.
target = "fluid"
sigs = print_permutation_significance(perm_scores, target=target)
plot_cv_scores(target=target, ylim=(-0.1, 0.301), sigs=sigs)
compare_tasks(perm_scores, target=target)

In [None]:
# Grip strength: permutation significance, CV-score figure, contrast
# comparison table.
target = "strength"
sigs = print_permutation_significance(perm_scores, target=target)
plot_cv_scores(target=target, ylim=(0, 0.601), sigs=sigs)
compare_tasks(perm_scores, target=target)

In [None]:
# Overall health: permutation significance, CV-score figure, contrast
# comparison table.
target = "overall_health"
sigs = print_permutation_significance(perm_scores, target=target)
plot_cv_scores(target=target, ylim=(-0.2, 0.2), sigs=sigs)
compare_tasks(perm_scores, target=target)

In [None]:
# Depression: permutation significance, CV-score figure, contrast comparison
# table.
target = "depression"
sigs = print_permutation_significance(perm_scores, target=target)
plot_cv_scores(target=target, ylim=(0.4, 0.601), sigs=sigs)
compare_tasks(perm_scores, target=target)

In [None]:
# Alcohol frequency: permutation significance, CV-score figure, contrast
# comparison table.
target = "alcohol_freq"
sigs = print_permutation_significance(perm_scores, target=target)
plot_cv_scores(target=target, ylim=(-0.2, 0.2), sigs=sigs)
compare_tasks(perm_scores, target=target)

In [None]:
# Neuroticism: permutation significance, CV-score figure, contrast comparison
# table.
target = "neuroticism"
sigs = print_permutation_significance(perm_scores, target=target)
plot_cv_scores(target=target, ylim=(-0.2, 0.2), sigs=sigs)
compare_tasks(perm_scores, target=target)

In [None]:
# GAD: permutation significance, CV-score figure, contrast comparison table.
target = "GAD"
sigs = print_permutation_significance(perm_scores, target=target)
plot_cv_scores(target=target, ylim=(-0.2, 0.2), sigs=sigs)
compare_tasks(perm_scores, target=target)

In [None]:
# PHQ: permutation significance, CV-score figure, contrast comparison table.
target = "PHQ"
sigs = print_permutation_significance(perm_scores, target=target)
plot_cv_scores(target=target, ylim=(-0.2, 0.2), sigs=sigs)
compare_tasks(perm_scores, target=target)

In [None]:
# RDS: permutation significance, CV-score figure, contrast comparison table.
target = "RDS"
sigs = print_permutation_significance(perm_scores, target=target)
plot_cv_scores(target=target, ylim=(-0.2, 0.2), sigs=sigs)
compare_tasks(perm_scores, target=target)