# Load main results data

These results are used in all plots below.

In [None]:
from pathlib import Path
import pandas as pd
from dotenv import dotenv_values

from fl_pd.utils.constants import DNAME_LATEST

# Environment-specific paths are read from the .env file one directory up.
ENV_VARS = dotenv_values("../.env")
# Results directory for the latest run, resolved to an absolute path.
# NOTE(review): presumably DNAME_LATEST is a symlink to a timestamped run
# directory, which is why the resolved directory name is used as TIMESTAMP
# below — confirm.
DPATH_RESULTS = (Path(ENV_VARS["DPATH_FL_RESULTS"]) / DNAME_LATEST).resolve()
DPATH_FIGS = Path(ENV_VARS["DPATH_FL_FIGS"])

# Name of the resolved results directory; save_fig writes figures under
# DPATH_FIGS / TIMESTAMP.
TIMESTAMP = DPATH_RESULTS.name

# Map each task tag to its metrics file. The tag is the name of the file's
# parent directory with the trailing "-3791" removed.
# NOTE(review): "3791" looks like a random seed/state suffix — confirm.
fpaths_metrics = {
    Path(relative_path).parent.name.removesuffix("-3791"): DPATH_RESULTS / relative_path
    for relative_path in [
        # "age-sex-diag-case-hc-aparc-aseg-3791/metrics-10_splits-10_null.tsv",
        "decline-age-sex-case-aparc-3791/metrics-10_splits-10_null.tsv",
        "age-sex-hc-aseg-3791/metrics-10_splits-10_null.tsv",
    ]
}

# Show floats with two decimals when DataFrames are displayed.
pd.set_option("display.float_format", lambda x: "%.2f" % x)

# Hex colour per dataset name (upper-case). The SITE1-3 entries reuse the
# PPMI, ADNI and QPN colours respectively.
DATASET_COLOUR_MAP = {
    "PPMI": "#D0A441",
    "ADNI": "#0CA789",
    "QPN": "#A6A6C6",
    "PAD": "#CC444B",
    "CALGARY": "#6A4A3C",
    "ADNI-CALGARY-PAD-PPMI-QPN": "#0C97A7",
    "SITE1": "#D0A441",
    "SITE2": "#0CA789",
    "SITE3": "#A6A6C6",
}

# Human-readable plot title per task tag. The key order also determines the
# row/column order of the facets in the plots below.
TAG_TITLE_MAP = {
    "age-sex-diag-case-hc-aparc-aseg": "Diagnosis",
    "decline-age-sex-case-aparc": "Cognitive decline",
    "age-sex-hc-aseg": "Age",
    "age-sex-hc-aseg-55": "Age (55)",
}

# Individual datasets, in the order used for sorting, hue order and the
# setup table.
DATASET_NAMES = [
    "PPMI",
    "ADNI",
    "PAD",
    "CALGARY",
    "QPN",
    # dataset_name
    # for dataset_name in df_results_all["train_dataset"].unique()
    # if "-" not in dataset_name
]

# Name of the combined dataset as it appears (upper-cased) in the results.
DATASET_ALL = "ADNI-CALGARY-PAD-PPMI-QPN"
N_DATASETS = len(DATASET_NAMES)


def get_results(
    fpath_metrics: Path,
    weighted=True,
    combined_dataset: str = "adni-ppmi-qpn",
) -> pd.DataFrame:
    """Load one metrics TSV and normalise it for plotting.

    The file is expected to have (at least) the columns ``setup``,
    ``train_dataset``, ``test_dataset`` and ``metric``.

    Parameters
    ----------
    fpath_metrics
        Path (or file-like object) of the tab-separated metrics file.
    weighted
        If True, drop the rows whose test dataset is the unweighted combined
        dataset and strip the ``-weighted`` suffix from the remaining test
        dataset names. If False, drop the rows whose test dataset is the
        weighted combined dataset instead.
    combined_dataset
        Lower-case name of the combined test dataset as stored in the file.
        NOTE(review): the default is the historical three-dataset name, while
        DATASET_ALL in this notebook lists five datasets — confirm which name
        the metrics files actually contain.

    Returns
    -------
    pd.DataFrame
        Rows for the ``balanced_accuracy`` and ``mean_absolute_error``
        metrics only, with display names in ``setup``, upper-cased dataset
        names, and a fresh 0..n-1 index.
    """
    df_results = pd.read_csv(fpath_metrics, sep="\t")

    excluded_test_dataset = (
        combined_dataset if weighted else f"{combined_dataset}-weighted"
    )
    keep = (df_results["test_dataset"] != excluded_test_dataset) & df_results[
        "metric"
    ].isin(["balanced_accuracy", "mean_absolute_error"])

    # Work on an explicit copy: assigning columns on a boolean-filtered slice
    # triggers SettingWithCopyWarning and is not guaranteed to take effect.
    df_results = df_results.loc[keep].copy()

    if weighted:
        df_results["test_dataset"] = df_results["test_dataset"].str.removesuffix(
            "-weighted"
        )

    # Setups outside this mapping become NaN.
    df_results["setup"] = df_results["setup"].map(
        {"silo": "Siloed", "mega": "Mega-analysis", "federated": "Federated"}
    )
    df_results["train_dataset"] = df_results["train_dataset"].str.upper()
    df_results["test_dataset"] = df_results["test_dataset"].str.upper()
    return df_results.reset_index(drop=True)


# Load every metrics file and stack the results. The dict keys (tags) become
# the outer index level, which is then moved into a regular "tag" column.
df_results_all = pd.concat(
    {tag: get_results(fpath) for tag, fpath in fpaths_metrics.items()}
)
df_results_all = df_results_all.reset_index(level=0, names="tag")
# Order rows by training dataset following DATASET_NAMES, with the combined
# dataset last. list.index raises ValueError for any train_dataset value that
# is not in DATASET_NAMES + [DATASET_ALL].
df_results_all = df_results_all.sort_values(
    by="train_dataset",
    key=lambda x: x.map((DATASET_NAMES + [DATASET_ALL]).index),
)
# Math-text label of the model behind each row: siloed models are labelled
# with the setup and the training dataset, all other setups with the setup
# only (both lower-cased).
df_results_all["setup_train"] = df_results_all.apply(
    lambda row: (
        rf"$\mathtt{{M_{{{row['setup'].lower()},{row['train_dataset'].lower()}}}}}$"
        if row["setup"] == "Siloed"
        else rf"$\mathtt{{M_{{{row['setup'].lower()}}}}}$"
    ),
    axis=1,
)

# Display the combined table as the cell output.
df_results_all

Utility for saving figures (seaborn FacetGrid or matplotlib Figure)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Keep text as text (rather than converting glyphs to paths) in SVG output.
plt.rcParams["svg.fonttype"] = "none"

# Use seaborn's "poster" scaling for fonts and line widths in all figures.
sns.set_context("poster")


def save_fig(fig: sns.FacetGrid | plt.Figure, fname, extension="png", **kwargs):
    """Save a figure under ``DPATH_FIGS / TIMESTAMP`` and print its path.

    Parameters
    ----------
    fig
        Object exposing ``savefig`` (seaborn FacetGrid or matplotlib Figure).
    fname
        File name relative to the timestamped figure directory. Its suffix,
        if any, is replaced by ``extension``.
    extension
        File extension without the leading dot.
    **kwargs
        Extra keyword arguments forwarded to ``fig.savefig``; they override
        the defaults (``bbox_inches="tight"``, plus ``dpi=300`` for PNG).
    """
    savefig_kwargs = {"bbox_inches": "tight"}
    if extension == "png":
        # Raster output gets a print-quality resolution by default. The
        # previous version built this dict but never passed it to savefig.
        savefig_kwargs["dpi"] = 300
    savefig_kwargs.update(kwargs)

    fpath: Path = (DPATH_FIGS / TIMESTAMP / fname).with_suffix(f".{extension}")
    fpath.parent.mkdir(parents=True, exist_ok=True)

    fig.savefig(fpath, **savefig_kwargs)
    print(f"Saved figure to {fpath}")

# Bar plot

In [None]:
import numpy as np
import seaborn as sns

np.set_printoptions(precision=3)

# Split the per-dataset results (combined test dataset excluded) into null
# (is_null == True) and real (is_null == False) scores.
df_results_all_null = df_results_all.query(
    f"is_null == True and test_dataset != '{DATASET_ALL}'"
)
df_results_all_nonnull = df_results_all.query(
    f"is_null == False and test_dataset != '{DATASET_ALL}'"
)

bar_width = 0.8
# NOTE(review): null_width is not referenced anywhere else in the visible code.
null_width = 0.8

# x-axis order: siloed models first, then federated, then mega-analysis.
# The substrings match the lower-cased setup names embedded in the
# "setup_train" labels.
available_setups = df_results_all["setup_train"].unique()
x_labels = (
    [setup for setup in available_setups if "silo" in setup]
    + [setup for setup in available_setups if "federated" in setup]
    + [setup for setup in available_setups if "mega" in setup]
)

# One row of bars per task tag (ordered as in TAG_TITLE_MAP), one bar group
# per model, coloured by test dataset; error bars show the standard deviation.
grid_bar = sns.catplot(
    data=df_results_all_nonnull,
    x="setup_train",
    y="score",
    hue="test_dataset",
    row="tag",
    row_order=[
        tag
        for tag in TAG_TITLE_MAP.keys()
        if tag in df_results_all_nonnull["tag"].unique()
    ],
    kind="bar",
    errorbar="sd",
    hue_order=DATASET_NAMES,
    order=x_labels,
    height=3,
    aspect=5,
    width=bar_width,
    sharex=False,
    sharey=False,
    palette=DATASET_COLOUR_MAP,
    alpha=0.8,
    saturation=1,
)
# Put the legend below the plot in a single row. The horizontal anchor is
# hand-tuned per number of datasets; only 3, 4 and 5 are defined, any other
# N_DATASETS raises KeyError.
sns.move_legend(
    grid_bar,
    "lower center",
    bbox_to_anchor=(
        {3: 0.5, 4: 0.45, 5: 0.475}[N_DATASETS],
        0 if N_DATASETS != 5 else 0.05,
    ),
    ncol=N_DATASETS,
)

# Per-facet styling: y-limits by metric, explicit x ticks, a dashed reference
# line at the best mean null score, and title/axis labels.
for i_ax, (tag, ax) in enumerate(grid_bar.axes_dict.items()):

    df_results_nonnull = df_results_all_nonnull.query(f"tag == '{tag}'")
    df_results_null = df_results_all_null.query(f"tag == '{tag}'")
    # task_name = (
    #     df_results_nonnull.iloc[0]["problem"]
    #     + ": "
    #     + df_results_nonnull.iloc[0]["target"]
    #     + f" ({tag})"
    # )
    task_name = TAG_TITLE_MAP[tag]
    # The first row's metric is used for the whole facet (assumes one metric
    # per tag — TODO confirm).
    metric = df_results_nonnull.iloc[0]["metric"]

    if metric == "balanced_accuracy":
        ax.set_ylim(0, 1.0)
    elif metric == "mean_absolute_error":
        ax.set_ylim(0, 12)

    # ax.text(
    #     -0.08,
    #     1.05,
    #     "ABCDEFGHIJKLMNOP"[i_ax],
    #     transform=ax.transAxes,
    #     size=16,
    #     weight="bold",
    # )

    # fix xticks
    xticks = np.arange(len(x_labels))
    ax.set_xticks(xticks)
    ax.set_xticklabels(x_labels)
    ax.set_xlim(xticks[0] - 0.5, xticks[-1] + 0.5)

    if not df_results_null.empty:
        print(f"===== {task_name.upper()} =====")

        # Mean null score per (setup, test dataset) pair. This is computed
        # once: the previous version recomputed and appended the same values
        # for every x tick, which inflated the count (and distorted the
        # spread) in the summary printed below.
        mean_null_values = (
            df_results_null.query(f"metric == '{metric}'")
            .groupby(["setup", "test_dataset"])["score"]
            .mean()
            .to_numpy()
        )
        print(
            f"Mean nulls: {pd.Series(mean_null_values).describe(percentiles=[0.05, 0.95])}"
        )

        # "Best" null model: lowest error for MAE, highest score otherwise.
        if metric == "mean_absolute_error":
            best_null_value = mean_null_values.min()
        else:
            best_null_value = mean_null_values.max()

        ax.axhline(best_null_value, color="k", linestyle="--", alpha=0.5)

        # ax.axhline(df_results_null.query("setup_train == 'Mega-analysis'")['score'].quantile(0.05 if metric == 'mean_absolute_error' else 0.95), color="k", linestyle="--", alpha=0.5)

    # Metric name as y label, one word per line.
    ax.set_ylabel(metric.capitalize().replace("_", "\n"))
    ax.set_title(f"{task_name.capitalize()}", fontdict={"weight": "bold"})
    ax.set_xlabel("")

    # if metric == "mean_absolute_error":
    #     arrowstyle = "->"
    # else:
    #     arrowstyle = "<-"

    # ax.annotate(
    #     "",
    #     xy=(1.05, 0.25),
    #     xycoords="axes fraction",
    #     xytext=(1.05, 0.75),
    #     arrowprops=dict(arrowstyle=arrowstyle, linewidth=2, mutation_scale=20),
    # )
    # ax.annotate(
    #     "Better\nmodel",
    #     xy=(1.1, 0.5),
    #     xycoords="axes fraction",
    #     ha="center",
    #     va="center",
    # )

def legend_title_left(leg):
    """Re-parent the legend title so it sits in the same box as the entries.

    The first child of ``leg`` is treated as a container whose first two
    children are the title and the box holding the entries. The title is
    removed from the container and prepended to the entries box.

    NOTE(review): this writes the private ``_children`` attribute and assumes
    the container layout is [title, entries] — re-check after matplotlib
    upgrades.
    """
    container = leg.get_children()[0]
    parts = container.get_children()
    title_box = parts[0]
    entries_box = parts[1]

    container._children = [entries_box]
    entries_box._children = [title_box] + entries_box.get_children()


# Label the legend and move the label to the left of the entries.
grid_bar.legend.set_title("Test dataset")
legend_title_left(grid_bar.legend)

# grid_bar.legend.set_alignment("left")

In [None]:
# Save the bar plot as "metrics-bar" with save_fig's default extension.
save_fig(grid_bar, "metrics-bar")

# Setup comparison

In [None]:
import numpy as np
import seaborn as sns
import matplotlib.ticker as ticker

# Results evaluated on the combined test dataset only (null and real scores),
# ordered Siloed -> Federated -> Mega-analysis for the x axis.
df_results_combined = df_results_all.query(
    f"test_dataset == '{DATASET_ALL}'"  # or is_null == True"
).sort_values(
    by="setup", key=lambda x: x.map(["Siloed", "Federated", "Mega-analysis"].index)
)

# x_labels = ("Siloed", "Federated", "Mega-analysis")

# One line plot per task tag: score vs. setup, with separate lines for real
# and null scores and standard-deviation error bars.
# NOTE: col_order relies on df_results_all_nonnull, which is defined in the
# bar-plot cell, so that cell must have been run first.
grid_line = sns.relplot(
    data=df_results_combined,
    # x="setup",
    # Break hyphenated setup names ("Mega-analysis") over two lines.
    x=df_results_combined["setup"].map(lambda x: x.replace("-", "-\n")),
    y="score",
    hue=df_results_combined["is_null"].map({False: "Score", True: "Null score"}),
    col="tag",
    col_order=[
        tag
        for tag in TAG_TITLE_MAP.keys()
        if tag in df_results_all_nonnull["tag"].unique()
    ],
    kind="line",
    errorbar="sd",
    height=4,
    aspect=1.5,
    facet_kws={"sharey": False, "sharex": False},
    palette={
        "Score": DATASET_COLOUR_MAP[DATASET_ALL],
        "Null score": "grey",
    },
    alpha=1,
    err_style="bars",
    err_kws={"capsize": 3, "linewidth": 4, "capthick": 4},
    markers=True,
    linewidth=4,
)
# sns.move_legend(grid_line, "center right", bbox_to_anchor=(1, 0.55))

# Per-facet styling: fixed y range for balanced accuracy, metric name as the
# y label and the task title.
for i_ax, (tag, ax) in enumerate(grid_line.axes_dict.items()):
    df_combined_tag = df_results_combined.query("tag == @tag")
    # task_name = (
    #     df_combined_tag.iloc[0]["problem"]
    #     + ": "
    #     + df_combined_tag.iloc[0]["target"]
    #     + f" ({tag})"
    # )
    task_name = TAG_TITLE_MAP[tag]
    # The first row's metric is used for the whole facet.
    metric = df_combined_tag.iloc[0]["metric"]

    if metric == "balanced_accuracy":
        ax.set_ylim(0.44, 0.71)
        # ax.yaxis.set_major_locator(ticker.FixedLocator(np.arange(0.4, 0.71, 0.1)))

    # ax.text(
    #     -0.23,
    #     1.05,
    #     "ABCDEFGHIJKLMNOP"[i_ax + len(dataset_names)],
    #     transform=ax.transAxes,
    #     size=16,
    #     weight="bold",
    # )

    ax.set_ylabel(metric.capitalize().replace("_", "\n"))
    ax.set_title(f"{task_name.capitalize()}", fontdict={"weight": "bold"})
    ax.set_xlabel("")

grid_line.legend.set_title("")
grid_line.tight_layout()

# Place the two-entry legend below the facets (hand-tuned anchor).
sns.move_legend(grid_line, "lower center", bbox_to_anchor=(0.4525, -0.1), ncols=2)

In [None]:
# Save the setup-comparison plot as "metrics-line" (default extension).
save_fig(grid_line, fname="metrics-line")

# Dataset descriptive plots

In [None]:
def load_data(tag: str) -> pd.DataFrame:
    """Read the per-dataset TSV files for ``tag`` and stack them.

    For every name in DATASET_NAMES the file
    ``DPATH_DATA/<dataset>-<tag>/<dataset>-<tag>.tsv`` (dataset lower-cased)
    is read. DPATH_DATA is looked up when the function is called.

    Returns a single DataFrame with a default integer index and a leading
    "dataset" column holding the (upper-case) dataset name of each row.
    """

    def _read_one(name: str) -> pd.DataFrame:
        stem = f"{name.lower()}-{tag}"
        return pd.read_csv(DPATH_DATA / stem / f"{stem}.tsv", sep="\t")

    per_dataset = {name: _read_one(name) for name in DATASET_NAMES}

    # The dict keys become the outer "dataset" index level; the inner level
    # (each file's original row index) is discarded.
    stacked = pd.concat(per_dataset, axis="index", names=["dataset", "tmp"])
    return stacked.reset_index(level="tmp", drop=True).reset_index()


# Work out which task tags are present in the results.
tags = df_results_all["tag"].unique()

# The diagnosis task is optional. `.item()` raises unless exactly one tag
# matches.
defined_tags = []
if any("diag" in tag for tag in tags):
    tag_diag = df_results_all.query("tag.str.contains('diag')")["tag"].unique().item()
    defined_tags.append(tag_diag)
    print(f"{tag_diag=}")
else:
    tag_diag = None

# The cognitive-decline tag is required (exactly one tag containing
# "decline"); the age tag is whichever single tag remains after removing the
# ones identified above.
tag_cog_decline = (
    df_results_all.query("tag.str.contains('decline')")["tag"].unique().item()
)
defined_tags.append(tag_cog_decline)
tag_age = df_results_all.query("tag not in @defined_tags")["tag"].unique().item()
print(f"{tag_cog_decline=}")
print(f"{tag_age=}")

print(f"{DATASET_NAMES=}")
# Root directory of the input data; load_data reads this global at call time,
# so it must be defined before the calls below.
DPATH_DATA = Path(ENV_VARS["DPATH_FL_DATA_LATEST"])


# df_diag only exists when a diagnosis tag was found.
if tag_diag is not None:
    df_diag = load_data(tag_diag)
df_cog_decline = load_data(tag_cog_decline)
df_age = load_data(tag_age)

In [None]:
import numpy as np
import matplotlib.ticker as ticker
import seaborn as sns
from matplotlib.colors import to_hex, to_rgb

# Running count of axes drawn by plot_binary_counts across calls.
ax_count = 0


def plot_binary_counts(df, col, val_map):
    """Draw horizontal count plots of a two-valued column, one facet per dataset.

    Parameters
    ----------
    df
        DataFrame with a "dataset" column and the column ``col``.
    col
        Name of the column to count.
    val_map
        Mapping from raw values in ``col`` to display labels. The order of the
        labels sets the bar order; rows holding the first key are drawn more
        opaque (alpha 0.8) than the others (alpha 0.5).

    Returns
    -------
    The seaborn FacetGrid. As a side effect the global ``ax_count`` is
    incremented once per facet.
    """
    global ax_count
    colour_lookup = {}

    def _row_colour(row):
        # Dataset colour with an alpha that depends on the row's value; the
        # resulting hex code doubles as hue level and palette entry.
        opacity = 0.8 if row[col] == list(val_map.keys())[0] else 0.5
        rgba = to_rgb(DATASET_COLOUR_MAP[row["dataset"]]) + (opacity,)
        hex_code = to_hex(rgba, keep_alpha=True)
        colour_lookup[hex_code] = hex_code
        return hex_code

    row_colours = df.apply(_row_colour, axis="columns")
    x_upper = 1450 if N_DATASETS == 3 else 1600

    grid = sns.catplot(
        data=df,
        kind="count",
        y=df[col].map(val_map),
        order=list(val_map.values()),
        hue=row_colours,
        palette=colour_lookup,
        col="dataset",
        col_order=DATASET_NAMES,
        height=3,
        aspect=1.5,
        saturation=1,
        legend=False,
        sharex=False,
        facet_kws={"xlim": (0, x_upper)},
    )

    for dataset_name, facet_ax in grid.axes_dict.items():
        facet_ax.set_xlabel("")
        facet_ax.set_ylabel("")
        facet_ax.set_title(dataset_name)
        # Print the count next to each bar.
        for bars in facet_ax.containers:
            facet_ax.bar_label(bars, padding=3)
        # ax.text(
        #     -0.4,
        #     1.05,
        #     "ABCDEFGHIJKLMNOP"[ax_count],
        #     transform=ax.transAxes,
        #     size=16,
        #     weight="bold",
        # )
        ax_count += 1

    return grid


# Count plots for the binary targets (diagnosis only when that task exists).
if tag_diag is not None:
    grid_diag = plot_binary_counts(df_diag, "DIAGNOSIS", {0: "Control", 1: "Case"})

grid_cog_decline = plot_binary_counts(
    df_cog_decline, "COG_DECLINE", {False: "Stable", True: "Decline"}
)

figsize = grid_cog_decline.figure.get_size_inches()

# Order in which the age densities are passed to seaborn (hue_order); the
# legend is rebuilt afterwards in DATASET_NAMES order.
hue_order = ["QPN", "CALGARY", "PPMI", "ADNI", "PAD"]
# Validate before indexing: previously a dataset missing from hue_order made
# `hue_order.index` raise a bare ValueError before this message could be shown.
assert set(hue_order) == set(DATASET_NAMES), "Wrong datasets in hue_order"
idx_dataset_names_to_hue_order = [hue_order.index(name) for name in DATASET_NAMES]

# Filled kernel density estimate of age, one curve per dataset.
grid_age = sns.displot(
    data=df_age,
    x="AGE",
    hue="dataset",
    kind="kde",
    fill=True,
    # height=3.25,
    height=4,
    aspect=1.46 * 3,
    hue_order=hue_order,
    palette=DATASET_COLOUR_MAP,
    facet_kws={"legend_out": False, "xlim": (15, 95)},
    alpha=0.5,
    linewidth=0,
)
grid_age.legend.set_title("")
grid_age.ax.set_xlabel("Age")
grid_age.legend.set_frame_on(False)
# Replace the legend with one inside the axes, listing the datasets in
# DATASET_NAMES order (handles are picked from the original legend).
grid_age.ax.legend(
    [grid_age.legend.legend_handles[i] for i in idx_dataset_names_to_hue_order],
    DATASET_NAMES,
    loc="upper left",
    frameon=False,
    ncols=2 if N_DATASETS > 3 else None,
    borderpad=0,
    borderaxespad=0,
    bbox_to_anchor=(0.01, 0.99),
)
# Only two y ticks: 0 and a hand-picked upper value per number of datasets
# (KeyError for any N_DATASETS other than 3, 4 or 5).
y_tick = {3: 0.03, 4: 0.025, 5: 0.02}[N_DATASETS]
grid_age.ax.yaxis.set_major_locator(
    ticker.FixedLocator(np.arange(0, y_tick + 1e-6, y_tick))
)
print(f"{grid_age.ax.get_ylim()=}")
# grid_age.ax.text(
#     -0.3,
#     2.5,
#     "ABCDEFGHIJKLMNOP"[ax_count],
#     transform=ax.transAxes,
#     size=16,
#     weight="bold",
# )

In [None]:
# Save the descriptive plots; grid_diag only exists when a diagnosis tag was
# found.
if tag_diag is not None:
    save_fig(grid_diag, fname="distribution-diag")
save_fig(grid_cog_decline, fname="distribution-cog_decline")
save_fig(grid_age, fname="distribution-age")

# Table

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import seaborn as sns

sns.set_context("poster")

# The table is drawn by hand on a plain axes in data coordinates: x spans
# [0, 15] and y spans one unit per table row.
y_min = 0
y_max = 1 + 3 * N_DATASETS
x_min = 0
x_max = 15

# One header row plus three groups (siloed, federated, mega-analysis) of
# N_DATASETS rows each; three columns (setup, train data, model).
n_rows = 1 + 3 * N_DATASETS
n_cols = 3

# normal arrow
# x positions (data coordinates) where the "train data -> model" arrows start
# and end.
arrow_start = 8.8
arrow_end = 10.5

# Horizontal shift applied to the mega-analysis bracket/arrow (hand-tuned).
arrow_offset_bracket = 0.25 if N_DATASETS == 3 else 0.35

width = x_max - x_min
height = y_max - y_min

# Figure size in inches equals the data extent.
fig, ax = plt.subplots(figsize=(width, height))

# Pad the axis limits by 2% of the extent on every side.
xylim_offset = 0.02
ax.set_xlim(x_min - xylim_offset * width, x_max + xylim_offset * width)
ax.set_ylim(y_min - xylim_offset * height, y_max + xylim_offset * height)

# Integer ticks on both axes (the ticks are removed again at the end of the
# cell).
ax.xaxis.set_major_locator(ticker.FixedLocator(np.arange(x_min, x_max + 1, 1)))
ax.yaxis.set_major_locator(ticker.FixedLocator(np.arange(y_min, y_max + 1, 1)))


def draw_hline(i_row, **kwargs):
    """Draw a full-width horizontal rule at the top edge of table row ``i_row``.

    ``i_row == n_rows`` draws the rule below the last row; larger values raise
    ValueError. Extra keyword arguments override the default line style
    (black, linewidth 2) and are passed to ``ax.hlines``.
    """
    line_style = {"color": "k", "linewidth": 2, **kwargs}

    if i_row > n_rows:
        raise ValueError("Row index out of bounds")

    row_height = height / n_rows
    # Rows are counted from the top of the table downwards.
    y_line = y_max - (row_height * i_row)
    ax.hlines(y=y_line, xmin=x_min, xmax=x_max, **line_style)


def write_to_cell(
    i_row, i_col, text, x_offset=0.5, y_offset=0.5, fontdict=None, **kwargs
):
    """Write ``text`` into the table cell at (``i_row``, ``i_col``).

    ``x_offset`` and ``y_offset`` position the text inside the cell as
    fractions of the cell width/height, measured from the cell's left and top
    edges (0.5, 0.5 is the centre). ``i_row`` may be fractional. Extra keyword
    arguments override the default centred alignment and are passed to
    ``ax.text``.

    Raises ValueError if the row or column index is past the end of the table.
    """
    if i_row >= n_rows or i_col >= n_cols:
        raise ValueError("Row or column index out of bounds")

    text_kwargs = {"ha": "center", "va": "center", **kwargs}

    cell_width = width / n_cols
    cell_height = height / n_rows
    # Rows are counted from the top, so y decreases as i_row grows.
    x_text = cell_width * i_col + (cell_width * x_offset)
    y_text = height - (cell_height * i_row) - (cell_height * y_offset)

    ax.text(x=x_text, y=y_text, s=text, fontdict=fontdict, **text_kwargs)


def draw_arrow(start, end, arrowprops=None, **kwargs):
    """Draw an arrow from ``start`` to ``end`` on the table axes.

    Both points are ``(x, i_row)`` pairs: x is in data coordinates, while the
    second element is a (possibly fractional) table row index that is mapped
    to the vertical centre of that row. Without ``arrowprops`` a filled black
    "-|>" arrow is used. Extra keyword arguments go to ``ax.annotate``.
    """
    # x is given in data coordinates
    # y is given as i_row
    row_height = height / n_rows

    def _row_centre(row_index):
        return height - (row_height * row_index) - (row_height / 2)

    if arrowprops is None:
        arrowprops = {
            "arrowstyle": "-|>",
            "linewidth": 2.5,
            "mutation_scale": 30,
            "fc": "k",
        }

    tail = (start[0], _row_centre(start[1]))
    head = (end[0], _row_centre(end[1]))
    ax.annotate(
        "",
        xy=head,
        xytext=tail,
        arrowprops=arrowprops,
        bbox={"pad": 0},
        **kwargs,
    )


def draw_bracket(x, i_row, with_arrow=False, **kwargs):
    """Draw a "]" bracket at data coordinate ``x``, centred on row ``i_row``.

    The bracket is rendered as a near-zero-length arrow with the "]-" style;
    its size comes from a hand-tuned ``mutation_scale`` per number of datasets
    (only 3, 4 and 5 are defined). ``with_arrow`` is accepted but not used.
    Extra keyword arguments are forwarded to ``draw_arrow``.
    """
    bracket_scale = {3: 55, 4: 80, 5: 105}[N_DATASETS]
    bracket_props = {
        "arrowstyle": "]-",
        "linewidth": 2.5,
        "mutation_scale": bracket_scale,
    }
    tail = (x, i_row)
    head = (x + 0.00001, i_row)
    draw_arrow(tail, head, arrowprops=bracket_props, **kwargs)


# Header
i_row = 0
write_to_cell(i_row, 0, "Setup", fontdict={"weight": "bold"})
write_to_cell(i_row, 1, "Train data", fontdict={"weight": "bold"})
write_to_cell(i_row, 2, "Model", fontdict={"weight": "bold"})
draw_hline(1)

# Siloed
# One table row per dataset: setup name, the dataset (in its colour), an arrow
# and the siloed model label; each row is closed by a horizontal rule.
for dataset_name in DATASET_NAMES:
    i_row += 1
    write_to_cell(i_row, 0, f"Siloed ({dataset_name})")
    write_to_cell(
        i_row,
        1,
        dataset_name,
        color=DATASET_COLOUR_MAP[dataset_name],
        fontdict={"weight": "bold"},
    )
    draw_arrow((arrow_start, i_row), (arrow_end, i_row))
    write_to_cell(i_row, 2, rf"$\mathtt{{M_{{siloed,{dataset_name.lower()}}}}}$")
    draw_hline(i_row + 1)

i_row += 1

# Federated
# Fractional row index at the vertical middle of the federated group; used to
# centre the group label, the bracket and the final model label.
y_fed_middle = i_row + ((N_DATASETS - 1) / 2)
write_to_cell(y_fed_middle, 0, "Federated")
# Each dataset yields a local model, written in the left part of the model
# column.
for dataset_name in DATASET_NAMES:
    write_to_cell(
        i_row,
        1,
        dataset_name,
        color=DATASET_COLOUR_MAP[dataset_name],
        fontdict={"weight": "bold"},
    )
    draw_arrow((arrow_start, i_row), (arrow_end, i_row))
    write_to_cell(
        i_row, 2, rf"$\mathtt{{M_{{{dataset_name.lower()}}}}}$", x_offset=0.25
    )
    i_row += 1

# Bracket around the local models and an arrow to the federated model
# (x positions are hand-tuned data coordinates).
draw_bracket(12.1, y_fed_middle)
draw_arrow((12.1, y_fed_middle), (12.7, y_fed_middle))
write_to_cell(y_fed_middle, 2, r"$\mathtt{M_{federated}}$", x_offset=0.75)
# Small caption just below the federated model label.
write_to_cell(
    y_fed_middle,
    2,
    "(weighted avg.\nof params.)",
    x_offset=0.75,
    y_offset=1.1,
    fontdict={"size": 14},
)

draw_hline(i_row)

# Mega-analysis
# All datasets are listed, bracketed together, and a single arrow points to
# the one mega-analysis model.
y_mega_middle = i_row + ((N_DATASETS - 1) / 2)
write_to_cell(y_mega_middle, 0, "Mega-analysis")
for dataset_name in DATASET_NAMES:
    write_to_cell(
        i_row,
        1,
        dataset_name,
        color=DATASET_COLOUR_MAP[dataset_name],
        fontdict={"weight": "bold"},
    )
    i_row += 1
draw_arrow(
    (arrow_start + arrow_offset_bracket, y_mega_middle), (arrow_end, y_mega_middle)
)
write_to_cell(y_mega_middle, 2, r"$\mathtt{M_{mega-analysis}}$")
draw_bracket(arrow_start + arrow_offset_bracket, y_mega_middle)

draw_hline(i_row)

# Hide the axes frame and ticks so only the hand-drawn table remains.
ax.set_frame_on(False)
ax.set_xticks([])
ax.set_yticks([])

In [None]:
# Save the setup table figure as "setups" (default extension).
save_fig(fig, fname="setups")