In [None]:
from pathlib import Path
import pandas as pd
from dotenv import dotenv_values

ENV_VARS = dotenv_values("../config/.env")
DPATH_RESULTS = Path(ENV_VARS["DPATH_FL_RESULTS"])
DPATH_FIGS = Path(ENV_VARS["DPATH_FL_FIGS"])

# Suffix stripped from each results directory name to form the panel tag.
TAG_SUFFIX = "-3791"

# Metrics files to compare, relative to DPATH_RESULTS (uncomment to include).
RELPATHS_METRICS = [
    # # 2025_05_27: old scripts, no normative model, LR, standardscaler on all columns
    # "2025_05_27/age-sex-diag-case-hc-aparc-aseg-3791/metrics-10_splits-10_null.tsv",
    # "2025_05_27/decline-age-case-aparc-3791/metrics-10_splits-10_null.tsv",
    # "2025_05_27/age-sex-hc-aseg-55-3791/metrics-10_splits-10_null.tsv",
    # "2025_05_27/age-sex-hc-aseg-3791/metrics-10_splits-10_null.tsv",
    # # 2025_06_03: normative model, LR, standardscaler on all columns
    # "2025_06_03/age-sex-diag-case-hc-aparc-aseg-norm-3791/metrics-10_splits-10_null.tsv",
    # "2025_06_03/decline-age-sex-case-aparc-norm-3791/metrics-10_splits-10_null.tsv",
    # "2025_06_03/age-sex-hc-aseg-55-norm-3791/metrics-10_splits-10_null.tsv",
    # "2025_06_03/age-sex-hc-aseg-norm-3791/metrics-10_splits-10_null.tsv",
    # # 2025_06_04: normative model, LR, no standardscaler
    # "2025_06_04/age-sex-diag-case-hc-aparc-aseg-norm-3791/metrics-10_splits-10_null.tsv",
    # "2025_06_04/decline-age-sex-case-aparc-norm-3791/metrics-10_splits-10_null.tsv",
    # "2025_06_04/age-sex-hc-aseg-55-norm-3791/metrics-10_splits-10_null.tsv",
    # "2025_06_04/age-sex-hc-aseg-norm-3791/metrics-10_splits-10_null.tsv",
    # 2025_06_04: no normative model, LR, standardscaler on all columns
    "2025_06_04/age-sex-diag-case-hc-aparc-aseg-3791/metrics-10_splits-10_null.tsv",
    "2025_06_04/decline-age-sex-case-aparc-3791/metrics-10_splits-10_null.tsv",
    "2025_06_04/age-sex-hc-aseg-55-3791/metrics-10_splits-10_null.tsv",
    "2025_06_04/age-sex-hc-aseg-3791/metrics-10_splits-10_null.tsv",
]

# Map panel tag -> absolute metrics file. The tag is the parent directory name
# (without TAG_SUFFIX), so enabling two entries from different dates that share
# a directory name would make one silently overwrite the other. Fail loudly
# instead.
fpaths_metrics = {}
for relative_path in RELPATHS_METRICS:
    tag = Path(relative_path).parent.name.removesuffix(TAG_SUFFIX)
    if tag in fpaths_metrics:
        raise ValueError(
            f"Duplicate tag {tag!r} for {relative_path!r}; "
            "results directory names must be unique across the selected files."
        )
    fpaths_metrics[tag] = DPATH_RESULTS / relative_path

pd.set_option("display.float_format", lambda x: "%.2f" % x)

# Fixed colour per dataset name, used as the seaborn hue palette for the
# upper-cased test_dataset values (presumably must cover every value present).
DATASET_COLOUR_MAP = {
    "PPMI": "#D0A441",
    "ADNI": "#0CA789",
    "QPN": "#A6A6C6",
    "ADNI-PPMI-QPN": "#0C97A7",
    "SITE1": "#D0A441",
    "SITE2": "#0CA789",
    "SITE3": "#A6A6C6",
}


def get_results(
    fpath_metrics: Path,
    metrics: tuple[str, ...] = ("balanced_accuracy", "mean_absolute_error"),
) -> pd.DataFrame:
    """Load one metrics TSV and tidy it for plotting.

    Parameters
    ----------
    fpath_metrics : Path
        Tab-separated file with at least the columns ``metric``, ``setup``,
        ``train_dataset`` and ``test_dataset``.
    metrics : tuple of str, optional
        Values of the ``metric`` column to keep; every other row is dropped.
        Defaults to balanced accuracy and mean absolute error.

    Returns
    -------
    pd.DataFrame
        The kept rows with a fresh 0..n-1 index, where ``setup`` is mapped to
        its display name ("Siloed", "Mega-analysis", "Federated"; unknown
        values become NaN) and ``train_dataset`` / ``test_dataset`` are
        upper-cased.
    """
    setup_display_names = {
        "silo": "Siloed",
        "mega": "Mega-analysis",
        "federated": "Federated",
    }
    df_raw = pd.read_csv(fpath_metrics, sep="\t")
    # Filter then build new columns with .assign, which works on a copy and so
    # avoids assigning into a filtered slice (SettingWithCopyWarning).
    return (
        df_raw.loc[df_raw["metric"].isin(metrics)]
        .assign(
            setup=lambda df: df["setup"].map(setup_display_names),
            train_dataset=lambda df: df["train_dataset"].str.upper(),
            test_dataset=lambda df: df["test_dataset"].str.upper(),
        )
        .reset_index(drop=True)
    )


# Stack the per-experiment tables; the dict keys become a "tag" column.
df_results_all = pd.concat(
    {tag: get_results(fpath) for tag, fpath in fpaths_metrics.items()}
)
df_results_all = df_results_all.reset_index(level=0, names="tag")

# Siloed runs are labelled with their training dataset (e.g. "Siloed (ADNI)");
# all other setups keep the bare setup name. Vectorised rather than a row-wise
# apply(axis=1), which is slow and returns a DataFrame on an empty frame.
mask_siloed = df_results_all["setup"].eq("Siloed")
df_results_all["setup_train"] = (
    df_results_all["setup"] + " (" + df_results_all["train_dataset"] + ")"
).where(mask_siloed, df_results_all["setup"])
df_results_all

In [None]:
import numpy as np
import seaborn as sns

np.set_printoptions(precision=3)

# Upper y-limit of mean-absolute-error panels (same units as the "score" column).
MAE_YLIM_MAX = 16

# Rows flagged is_null feed the reference line only; the others are drawn as bars.
df_results_all_null = df_results_all.query("is_null == True")
df_results_all_nonnull = df_results_all.query("is_null == False")

bar_width = 0.8
null_width = 0.8  # not used in this cell

# x-axis order: siloed setups (alphabetical), then federated, then
# mega-analysis, keeping only setups that occur in the results.
available_setups = df_results_all["setup_train"].unique()
x_labels = (
    sorted(v for v in available_setups if "Siloed" in v)
    + (["Federated"] if "Federated" in available_setups else [])
    + (["Mega-analysis"] if "Mega-analysis" in available_setups else [])
)

grid = sns.catplot(
    data=df_results_all_nonnull,
    x="setup_train",
    y="score",
    hue="test_dataset",
    row="tag",
    kind="bar",
    errorbar="sd",
    order=x_labels,
    height=2.5,
    aspect=4,
    width=bar_width,
    sharex=False,
    sharey=False,
    palette=DATASET_COLOUR_MAP,
    alpha=0.8,
    saturation=1,
)

for i_ax, (tag, ax) in enumerate(grid.axes_dict.items()):

    df_results_nonnull = df_results_all_nonnull.query("tag == @tag")
    df_results_null = df_results_all_null.query("tag == @tag")

    # Problem, target and metric are taken from the panel's first non-null row.
    first_row = df_results_nonnull.iloc[0]
    task_name = first_row["problem"] + ": " + first_row["target"] + f" ({tag})"
    metric = first_row["metric"]

    if metric == "balanced_accuracy":
        ax.set_ylim(0, 1.0)
    elif metric == "mean_absolute_error":
        ax.set_ylim(0, MAE_YLIM_MAX)

    # Panel letter (A, B, C, ...); chr() does not run out after 16 panels the
    # way indexing a fixed 16-letter string did.
    ax.text(
        -0.08,
        1.05,
        chr(ord("A") + i_ax),
        transform=ax.transAxes,
        size=16,
        weight="bold",
    )

    # Same tick positions/labels on every panel (sharex=False gives each its own axis).
    xticks = np.arange(len(x_labels))
    ax.set_xticks(xticks)
    ax.set_xticklabels(x_labels)
    ax.set_xlim(xticks[0] - 0.5, xticks[-1] + 0.5)

    # Labels belong on every panel, not only on panels that have null results.
    ax.set_ylabel(metric.capitalize().replace("_", " "))
    ax.set_title(task_name.capitalize())
    ax.set_xlabel("")

    df_null_metric = df_results_null.query("metric == @metric")
    if len(df_null_metric) != 0:
        print(f"===== {task_name.upper()} =====")
        # One summary row per (setup, test_dataset); computed once per panel
        # (it does not depend on any x tick). The 5%/95% columns are kept for
        # inspection only.
        df_null_values_summary = df_null_metric.groupby(["setup", "test_dataset"])[
            "score"
        ].describe(percentiles=[0.05, 0.95])
        mean_null_values = df_null_values_summary["mean"].to_numpy()
        print(f"Mean nulls: {pd.Series(mean_null_values).describe()}")

        # Reference line at the most favourable null mean: lowest error for
        # MAE, highest score otherwise.
        if metric == "mean_absolute_error":
            best_null_value = mean_null_values.min()
        else:
            best_null_value = mean_null_values.max()

        ax.axhline(best_null_value, color="k", linestyle="--", alpha=0.5)

    # if metric == "mean_absolute_error":
    #     arrowstyle = "->"
    # else:
    #     arrowstyle = "<-"

    # ax.annotate(
    #     "",
    #     xy=(1.05, 0.25),
    #     xycoords="axes fraction",
    #     xytext=(1.05, 0.75),
    #     arrowprops=dict(arrowstyle=arrowstyle, linewidth=2, mutation_scale=20),
    # )
    # ax.annotate(
    #     "Better\nmodel",
    #     xy=(1.1, 0.5),
    #     xycoords="axes fraction",
    #     ha="center",
    #     va="center",
    # )

grid.legend.set_title("Test dataset")

In [None]:
# Set to True to write the combined figure to disk. It is off by default,
# matching the previously commented-out savefig call.
SAVE_FIG = False

# parents=True so a missing parent of the figures directory does not raise.
DPATH_FIGS.mkdir(parents=True, exist_ok=True)

fpath_fig = DPATH_FIGS / "metrics-combined.png"
if SAVE_FIG:
    grid.savefig(fpath_fig, dpi=300)