In [None]:
# Re-import edited project modules (e.g. `mgi`) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

# Make the project package (`mgi`) importable. The path is relative to the kernel's starting
# working directory (presumably the notebook's folder — TODO confirm).
sys.path.append("../../")

from functools import partial
import warnings

import srsly
# NOTE(review): this shadows the builtin `print` for the whole notebook, including the LaTeX
# dumps at the end; rich may restyle or wrap that text.
from rich import print
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
import numpy as np
import itertools

from IPython.display import display

from mgi.defaults import ROOT_PATH, PLOTS_PATH
from mgi.data.sampled_datasets import load_sampled_datasets_metadata
from dataclasses import asdict

# From here on, relative paths (e.g. "data/experiments_results/results.pkl") resolve
# against the project root.
os.chdir(ROOT_PATH)

from mgi.utils.config import load_training_config

In [None]:
# NOTE(review): this silences *every* warning, including pandas SettingWithCopy and
# boolean-index reindexing warnings that later cells could trigger. Consider narrowing it.
warnings.filterwarnings("ignore")

In [None]:
# Global figure styling for the whole notebook: "paper" context, white grid, colorblind palette.
sns.set_theme(context="paper", style="whitegrid", palette="colorblind", font_scale=1.1)

In [None]:
# Directory where the figures below are saved. It is created if missing; an existing one is kept.
plots_path = PLOTS_PATH / "experiment_analysis"
plots_path.mkdir(exist_ok=True, parents=True)

# Loading

### Configs loading

In [None]:
# Collect every training config listed in the "training items" YAML files.
# - Each file must hold exactly one top-level key. The unpacking below raises otherwise.
# - That key's value is a list of keyword-argument dicts for `load_training_config`.
# `os.scandir` yields entries in arbitrary order, so the files are sorted by name to keep
# `configs_df` deterministic. The `with` block closes the directory handle.
configs = []
with os.scandir(ROOT_PATH / "experiments/configs/training/training_items") as entries:
    training_item_files = sorted(entries, key=lambda entry: entry.name)
for training_item_file in training_item_files:
    [config_training_items] = srsly.read_yaml(training_item_file).values()
    configs += [load_training_config(**item) for item in config_training_items]

# Resolve interpolations and convert each OmegaConf object into plain Python containers.
configs = [OmegaConf.to_container(c, resolve=True) for c in configs]
configs_df = pd.DataFrame(configs)

In [None]:
# Expand each config into one row per repeat. A run is named "<experiment_name>_<seed>" for the
# first `repeats` seeds, which is what the results' `name` column is compared against below.
configs_df["run_name"] = [
    [f"{experiment_name}_{seed}" for seed in seeds[:repeats]]
    for experiment_name, seeds, repeats in zip(
        configs_df["experiment_name"], configs_df["random_seed"], configs_df["repeats"]
    )
]
configs_df = configs_df.explode("run_name").reset_index(drop=True)

### Results loading

In [None]:
# Aggregated run results; each run `name` is expected to appear once (checked below).
# NOTE(review): unpickling can execute arbitrary code. Only load trusted, locally produced files.
results_df = pd.read_pickle("data/experiments_results/results.pkl")

In [None]:
# Each run name must occur at most once in the results.
assert results_df["name"].is_unique

### Configured runs missing from the results

In [None]:
# Configured run names that have no entry in the results.
configs_df.loc[~configs_df["run_name"].isin(results_df["name"]), "run_name"].tolist()

### Metrics

In [None]:
# Flatten the nested `config` and `summary` dicts into dotted columns next to the raw result
# columns. Depending on the pandas version, `json_normalize` may return a fresh RangeIndex.
# Its rows are therefore re-labelled with `results_df.index` before the column-wise concat,
# so rows stay aligned even if `results_df` does not carry a default RangeIndex.
df = pd.concat(
    [
        results_df,
        pd.json_normalize(results_df.config).set_axis(results_df.index, axis="index"),
        pd.json_normalize(results_df.summary).set_axis(results_df.index, axis="index"),
    ],
    axis=1,
)

### Sampling params

In [None]:
# Look up the sampling metadata (a dataclass) of each run's `config.ds_dataset`. Values without
# metadata stay NaN and are passed through unchanged. The flattened frame is re-labelled with
# `df.index` so the column-wise concat aligns whatever index `json_normalize` returns.
sampling = pd.json_normalize(
    df["config.ds_dataset"]
    .map(load_sampled_datasets_metadata())
    .apply(lambda x: asdict(x) if pd.notna(x) else x)
).set_axis(df.index, axis="index")
# The metadata's own "name" column is dropped before joining.
sampling = sampling.drop(columns="name")
df = pd.concat([df, sampling], axis="columns")

In [None]:
# A run uses the full dataset when it has no sampling config. It is "combined" when a
# general-knowledge dataset was set, i.e. the value is not the "???" placeholder
# (presumably OmegaConf's MISSING marker).
df["is_full"] = df["sampling_config.sampling"].isnull()
df["is_combined"] = df["config.gk_dataset"].ne("???")

### Renaming

In [None]:
# Shorten the flattened column names. The replacement chain is order-sensitive:
# - "config.loss..." becomes "loss_func..." before the generic "config." prefix is stripped;
# - test metrics lose the "testing.both.optimistic." prefix and end up as e.g. "Hits@10",
#   "MR", "MRR" (the names used throughout the plots and tables);
# - the validation Hits@10 column is matched as "val/.both.optimistic.Hits@10", i.e. only
#   after "hits_at_" has already been rewritten to "Hits@".
df.columns = [
    c.replace("config.loss", "loss_func")
    .removeprefix("testing.both.optimistic.")
    .removeprefix("config.")
    .removeprefix("sampling_config.")
    .replace("hits_at_", "Hits@")
    .replace("inverse_harmonic_mean_rank", "MRR")
    .replace("arithmetic_mean_rank", "MR")
    .replace("val/.both.optimistic.Hits@10", "val_Hits@10")
    for c in df.columns
]

In [None]:
# After renaming, "model" can appear more than once (presumably the raw result field plus the
# flattened config entry). Drop the first occurrence, but only when there is more than one.
# The previous `get_loc`-based mask crashed when `get_loc` returned an int (single column)
# or a slice (adjacent duplicates). It also dropped the last remaining "model" column if the
# cell was re-run.
model_positions = np.flatnonzero(df.columns == "model")
if len(model_positions) > 1:
    keep_columns = np.ones(len(df.columns), dtype=bool)
    keep_columns[model_positions[0]] = False
    df = df.loc[:, keep_columns]

In [None]:
# Only runs with a known `ds_dataset` are analysed.
df = df.dropna(subset=["ds_dataset"])

In [None]:
# Dataset family = the part of the dataset name before the first "_" (e.g. "WN18RR").
df["ds_dataset_parent"] = df["ds_dataset"].map(lambda name: name.split("_", 1)[0])

In [None]:
# Missing `crop_gk_n` is labelled "full". Combined runs are "linked" graphs; all others are "single".
df["crop_gk_n"] = df["crop_gk_n"].fillna("full")
df["graph_type"] = np.where(df["is_combined"], "linked", "single")

# Plots

## Without general knowledge (single graph)

In [None]:
# Shared configuration for the plots and tables below.
DS_DATASET_PARENTS = ["WN18RR", "FB15K237", "WD50K"]
# Extra grouping keys kept next to the dataset in the aggregated result tables.
COLUMNS_TO_KEEP = ["is_combined", "sampling", "p"]

X = "ds_dataset"
HITS_10_METRIC = "Hits@10"
MR_METRIC = "MR"
MRR_METRIC = "MRR"
# NOTE(review): HUE is not referenced by any of the visible cells.
HUE = "combine_method"
# Validation metric used for model selection. The misspelled name ("SERACH") is what the
# later cells reference, so it is kept as-is.
SERACH_METRIC = "val_Hits@10"

HITS_METRICS = (
    "Hits@1",
    "Hits@3",
    "Hits@10",
)

SAMPLINGS = ["triple", "node", "relation"]
METRICS_TO_PLOT = [HITS_10_METRIC, MR_METRIC, MRR_METRIC]
# METRICS_TO_PLOT = [*HITS_METRICS, MR_METRIC, MRR_METRIC]

In [None]:
# Single-graph runs (no general-knowledge dataset).
standard_df = df.loc[~df["is_combined"]]

In [None]:
# Mean / std / run count of every test metric for the single-graph runs, per dataset variant
# and sampling setting. `dropna=False` keeps the full datasets, whose `sampling`/`p` are NaN.
standard_results_df = (
    standard_df.groupby(["ds_dataset_parent", X, *COLUMNS_TO_KEEP], dropna=False)[
        [*HITS_METRICS, MR_METRIC, MRR_METRIC]
    ]
    .agg(["mean", "std", "size"])
    .reset_index()
)
with pd.option_context("display.float_format", "{:.3f}".format):
    display(standard_results_df)

In [None]:
# One figure per dataset family. Rows are metrics and columns are sampling strategies.
# Each panel shows the single-graph results of one sampling strategy next to the full dataset,
# which has no sampling config and is drawn as "100%".
for ds_dataset_parent in DS_DATASET_PARENTS:
    parent_df = df[df.ds_dataset_parent == ds_dataset_parent]
    # Sampling strategies present for this family, in the canonical SAMPLINGS order.
    samplings = [s for s in SAMPLINGS if s in parent_df["sampling"].unique()]
    if not samplings:
        # plt.subplots cannot build a zero-column grid.
        continue
    single_graph_df = parent_df[~parent_df.is_combined]
    # squeeze=False keeps `axes` 2-D even when only one sampling strategy is available.
    fig, axes = plt.subplots(
        len(METRICS_TO_PLOT),
        len(samplings),
        figsize=(len(samplings) * 3, 7),
        squeeze=False,
    )
    for metric_, row_axes in zip(METRICS_TO_PLOT, axes):
        for sampling, ax in zip(samplings, row_axes):
            # Rows of this sampling strategy plus the full dataset (NaN sampling). The mask is
            # built from the filtered frame itself, so no cross-frame reindexing is involved.
            to_plot = single_graph_df[
                (single_graph_df["sampling"] == sampling) | single_graph_df["sampling"].isna()
            ]
            # `assign` returns a new frame instead of writing into a filtered slice.
            to_plot = to_plot.assign(
                **{
                    "sampling probability": (to_plot.p.fillna(1.0) * 100).apply(
                        lambda x: f"{int(x)}%"
                    )
                }
            ).sort_values(by=["is_full", X])
            sns.barplot(to_plot, x="sampling probability", y=metric_, ax=ax)
            ax.set_title(f"sampling={sampling}")
    fig.suptitle(f"{ds_dataset_parent} w/o general knowledge")
    plt.tight_layout()
    plt.show()

## Synthetic: general knowledge from the same parent dataset

In [None]:
# Hyper-parameters over which the best configuration is selected (by validation Hits@10).
# `cn_hparams` is used in the realistic setting and additionally varies `crop_gk_n`;
# `wn_hparams` is used in the synthetic setting.
cn_hparams = ["alignment_k", "loss_func", "crop_gk_n"]
wn_hparams = ["alignment_k", "loss_func"]

In [None]:
# Linked-graph runs (a general-knowledge dataset was attached).
combined_df = df.loc[df["is_combined"]]

In [None]:
# Synthetic setting: the general-knowledge dataset is the parent dataset of `ds_dataset`.
# Runs are aggregated per hyper-parameter configuration (`wn_hparams`).
grouped = (
    combined_df[combined_df["gk_dataset"] == combined_df.ds_dataset_parent]
    .groupby(["ds_dataset_parent", "gk_dataset", X, *wn_hparams, *COLUMNS_TO_KEEP])[
        METRICS_TO_PLOT + [SERACH_METRIC]
    ]
    .agg(["mean", "std", "size"])
)
# Every configuration is expected to have exactly 3 runs (hard-coded repeat count).
assert (grouped[HITS_10_METRIC]["size"] == 3).all()
g = grouped.reset_index()
# For each `ds_dataset`, keep the configuration with the highest mean validation Hits@10.
results_per_ds_dataset = g.loc[g.groupby(X)[SERACH_METRIC].idxmax()[SERACH_METRIC]["mean"]]

with pd.option_context("display.float_format", "{:.3f}".format):
    display(results_per_ds_dataset)

In [None]:
def bold_max(row, col_1, col_2):
    """Row-wise Styler callback that bolds the metric with the larger mean.

    Compares ``row[(col_1, "mean")]`` with ``row[(col_2, "mean")]``. Every cell whose
    first-level column label equals the strictly larger metric receives ``HIGHLIGHT_STYLE``.
    Ties and comparisons involving NaN style nothing. Intended for
    ``Styler.apply(..., axis=1)``.
    """
    first, second = row[(col_1, "mean")], row[(col_2, "mean")]
    if first > second:
        winner = col_1
    elif second > first:
        winner = col_2
    else:
        return [""] * len(row.index)
    return [HIGHLIGHT_STYLE if label[0] == winner else "" for label in row.index]


def bold_min(row, col_1, col_2):
    """Row-wise Styler callback that bolds the metric with the smaller mean.

    Mirror image of ``bold_max``. Every cell whose first-level column label equals the metric
    with the strictly smaller ``(…, "mean")`` value receives ``HIGHLIGHT_STYLE``. Ties and
    NaN comparisons style nothing.
    """
    first, second = row[(col_1, "mean")], row[(col_2, "mean")]
    winner = col_1 if first < second else (col_2 if second < first else None)
    if winner is None:
        return [""] * len(row.index)
    return [HIGHLIGHT_STYLE if label[0] == winner else "" for label in row.index]


def bold_max_2(row, compare_col_1, compare_col_2, highlight_col_1, highlight_col_2):
    """Row-wise Styler callback: compare two means, bold a separate display column.

    If ``row[(compare_col_1, "mean")]`` is strictly larger, the cell labelled
    ``highlight_col_1`` gets ``HIGHLIGHT_STYLE``. If ``row[(compare_col_2, "mean")]`` is
    strictly larger, ``highlight_col_2`` gets it. Otherwise (tie/NaN) nothing is styled.
    Highlight labels are matched against the full column label.
    """
    lhs = row[(compare_col_1, "mean")]
    rhs = row[(compare_col_2, "mean")]
    if lhs > rhs:
        target = highlight_col_1
    elif rhs > lhs:
        target = highlight_col_2
    else:
        return [""] * len(row.index)
    return [HIGHLIGHT_STYLE if label == target else "" for label in row.index]


def bold_min_2(row, compare_col_1, compare_col_2, highlight_col_1, highlight_col_2):
    """Row-wise Styler callback: like ``bold_max_2`` but the smaller mean wins.

    The cell labelled ``highlight_col_1`` / ``highlight_col_2`` (full column label) is bolded
    when its compared mean is strictly smaller. Ties and NaN comparisons style nothing.
    """
    lhs = row[(compare_col_1, "mean")]
    rhs = row[(compare_col_2, "mean")]
    if not lhs < rhs and not rhs < lhs:
        return [""] * len(row.index)
    target = highlight_col_1 if lhs < rhs else highlight_col_2
    return [HIGHLIGHT_STYLE if label == target else "" for label in row.index]

In [None]:
def generate_latex_table_from_styler(styler):
    """Render the styler's data as LaTeX table rows grouped by the ``dataset`` column.

    Each dataset group starts with a ``\\multicolumn`` header row spanning all columns
    ("Dataset: <name>"), followed by one ``&``-separated row per record. Rows are joined with
    ``\\\\`` plus a newline (no trailing separator). Cell values are converted with ``str``
    and are not LaTeX-escaped.
    """
    frame = styler.data
    column_names = frame.columns.tolist()
    span = str(len(column_names))
    lines = []
    for dataset, group in frame.groupby("dataset"):
        lines.append("\\multicolumn{" + span + "}{c}{Dataset: " + dataset + "}")
        lines.extend(
            " & ".join(str(record[name]) for name in column_names)
            for _, record in group.iterrows()
        )
    return "\\\\\n".join(lines)

In [None]:
def add_empty_rows_on_dataset_change(df, boost_column_name=None):
    """Interleave dataset separator rows and boost summary rows into a results table.

    For every run of consecutive rows sharing the same ``dataset`` value, the output contains:

    * a separator row whose every cell holds the dataset name,
    * the original rows of that run,
    * a row with the per-dataset ``max`` of every boost column,
    * a row with the per-dataset ``mean`` of every boost column.

    A boost column is one whose second-level label equals ``boost_column_name``. Non-boost
    cells of the two summary rows are empty strings. ``df`` is expected to be sorted by
    ``dataset`` and to have MultiIndex columns, so that ``row["dataset"]`` is a one-element
    Series (hence ``.item()``).

    Args:
        df: results frame with MultiIndex columns including a ``dataset`` column.
        boost_column_name: second-level label marking boost columns. Defaults to the
            notebook-level ``BOOST_COLUMN_NAME``.

    Returns:
        A new DataFrame with the extra rows. An empty ``df`` yields an empty frame with the
        same columns.
    """
    if boost_column_name is None:
        boost_column_name = BOOST_COLUMN_NAME
    if df.empty:
        return pd.DataFrame(columns=df.columns)

    def summary_row(dataset, aggregation):
        # `aggregation` is a Series method name ("max" / "mean") applied to boost columns only.
        subset = df[df.dataset == dataset]
        return pd.Series(
            [
                getattr(subset[col], aggregation)() if col[1] == boost_column_name else ""
                for col in df.columns
            ],
            index=df.columns,
        )

    rows = []
    prev_dataset = None
    for _, row in df.iterrows():
        current_dataset = row["dataset"].item()
        if current_dataset != prev_dataset:
            if prev_dataset is not None:
                # Close the previous dataset block before starting a new one.
                rows.append(summary_row(prev_dataset, "max"))
                rows.append(summary_row(prev_dataset, "mean"))
            rows.append(pd.Series([current_dataset] * len(df.columns), index=df.columns))
        rows.append(row)
        prev_dataset = current_dataset
    # Close the last dataset block.
    rows.append(summary_row(prev_dataset, "max"))
    rows.append(summary_row(prev_dataset, "mean"))
    return pd.DataFrame(rows)

In [None]:
HIGHLIGHT_STYLE = "font-weight: bold"
BOOST_COLUMN_NAME = "boost (%)"
# Metrics shown in the table. The extended variant is
# ["Hits@1", "Hits@3", "Hits@10", MR_METRIC, MRR_METRIC]. The raw columns hidden below follow
# METRICS_TO_PLOT, so both lists must name the same metrics.
METRICS_IN_TABLE = ["Hits@10", MR_METRIC, MRR_METRIC]

# Pair every single-graph result with the selected synthetic linked-graph result of the
# same `ds_dataset`.
merged = pd.merge(
    standard_results_df,
    results_per_ds_dataset,
    on="ds_dataset",
    suffixes=("_standard", "_combined"),
)

for metric in METRICS_IN_TABLE:
    decimal_places = 0 if metric == MR_METRIC else 3
    # "mean±std" display columns. `.iloc` is used because positional `Series[int]` lookups on
    # a labelled row are deprecated in pandas.
    for suffix, label in [("standard", "single"), ("combined", "combined")]:
        merged[(metric, label)] = merged[
            [(f"{metric}_{suffix}", "mean"), (f"{metric}_{suffix}", "std")]
        ].apply(
            lambda x: f"{x.iloc[0]:.{decimal_places}f}±{x.iloc[1]:>3.{decimal_places}f}",
            axis=1,
        )
    # Relative change of the linked graph over the single graph, in percent.
    merged[(metric, BOOST_COLUMN_NAME)] = (
        (merged[(f"{metric}_combined", "mean")] - merged[(f"{metric}_standard", "mean")])
        / merged[(f"{metric}_standard", "mean")]
    ) * 100

# Lower MR is better: flip the sign so that a positive boost always means an improvement.
merged[(MR_METRIC, BOOST_COLUMN_NAME)] = -merged[(MR_METRIC, BOOST_COLUMN_NAME)]
merged = merged.rename(
    columns={
        "ds_dataset_parent_standard": "dataset",
        "sampling_combined": "sampling",
        "p_combined": "p",
    }
)
merged["dataset"] = pd.Categorical(merged["dataset"], categories=DS_DATASET_PARENTS, ordered=True)
merged["sampling"] = pd.Categorical(merged["sampling"], categories=SAMPLINGS, ordered=True)
merged = merged.sort_values(["dataset", "sampling"])
merged_styled = add_empty_rows_on_dataset_change(merged)
merged_styled = merged_styled[
    [
        # "dataset" is omitted: the separator rows added above carry the dataset name.
        "sampling",
        "p",
        "ds_dataset",
        "alignment_k",
        "loss_func",
        *METRICS_IN_TABLE,
        *[
            f"{metric}_{graph_type}"
            for metric, graph_type in itertools.product(METRICS_TO_PLOT, ["standard", "combined"])
        ],
    ]
].style.apply(
    partial(bold_max, col_1="Hits@10_standard", col_2="Hits@10_combined"),
    axis=1,
    subset=[("Hits@10_standard", "mean"), ("Hits@10_combined", "mean")],
)
# MR: lower is better, so bold the smaller mean (display columns and raw mean columns).
merged_styled = merged_styled.apply(
    partial(
        bold_min_2,
        compare_col_1=f"{MR_METRIC}_standard",
        compare_col_2=f"{MR_METRIC}_combined",
        highlight_col_1=(MR_METRIC, "single"),
        highlight_col_2=(MR_METRIC, "combined"),
    ),
    axis=1,
    subset=[
        (f"{MR_METRIC}_standard", "mean"),
        (f"{MR_METRIC}_combined", "mean"),
        (MR_METRIC, "single"),
        (MR_METRIC, "combined"),
    ],
)

merged_styled = merged_styled.apply(
    partial(bold_min, col_1=f"{MR_METRIC}_standard", col_2=f"{MR_METRIC}_combined"),
    axis=1,
    subset=[(f"{MR_METRIC}_standard", "mean"), (f"{MR_METRIC}_combined", "mean")],
)

# All other metrics: higher is better.
for metric in METRICS_IN_TABLE:
    if metric == MR_METRIC:
        continue
    merged_styled = merged_styled.apply(
        partial(
            bold_max_2,
            compare_col_1=f"{metric}_standard",
            compare_col_2=f"{metric}_combined",
            highlight_col_1=(metric, "single"),
            highlight_col_2=(metric, "combined"),
        ),
        axis=1,
        subset=[
            (f"{metric}_standard", "mean"),
            (f"{metric}_combined", "mean"),
            (metric, "single"),
            (metric, "combined"),
        ],
    )

# Hide the selection hyper-parameters and the raw mean/std/size columns (used only for styling).
merged_styled = merged_styled.hide(
    [("ds_dataset", ""), ("alignment_k", ""), ("loss_func", "")]
    + list(
        itertools.product(
            [
                f"{metric}_{graph_type}"
                for metric, graph_type in itertools.product(
                    METRICS_TO_PLOT, ["standard", "combined"]
                )
            ],
            ["mean", "std", "size"],
        )
    ),
    axis=1,
)

merged_styled = merged_styled.format(precision=1)

display(merged_styled)
# LaTeX export. "\\%" / "\\pm" are the escaped forms of the previous "\%" / "\pm" literals
# (same strings, no invalid-escape warning).
# NOTE(review): "\pm" ends up in text-mode cells; LaTeX normally needs "$\pm$" there — confirm.
print(
    merged_styled.format_index(axis=1, formatter="${}$".format)
    .hide(axis=0)
    .to_latex(convert_css=True)
    .replace("%", "\\%")
    .replace("±", "\\pm")
)

In [None]:
# Boost of the selected synthetic configuration over the single graph (relative change in %,
# sign-flipped for MR so that positive means better) versus the sampling probability `p`.
# One column per sampling strategy, one line per dataset family.
to_plot = merged.copy()

for metric_ in METRICS_TO_PLOT:
    # Copy each boost column under a single-string name usable as seaborn's `y`. The "\n"
    # splits the y-axis label over two lines.
    to_plot[f"{metric_}\n{BOOST_COLUMN_NAME}"] = to_plot[(metric_, BOOST_COLUMN_NAME)]

samplings = SAMPLINGS
fig, axes = plt.subplots(
    len(METRICS_TO_PLOT),
    len(samplings),
    figsize=(len(samplings) * 1.65, 5),
    sharex=True,
    sharey="row",
)
for i, (metric_, axes_row) in enumerate(zip(METRICS_TO_PLOT, axes)):
    for sampling, ax in zip(samplings, axes_row):
        to_plot_ax = to_plot[(to_plot["sampling"] == sampling)]
        # NOTE(review): the returned Axes `g` is not used.
        g = sns.lineplot(
            to_plot_ax.sort_values(X),
            x="p",
            y=f"{metric_}\n{BOOST_COLUMN_NAME}",
            hue="dataset",
            ax=ax,
            marker="o",
            linestyle="--",
        )
        if i == 0:
            ax.set_title(f"sampling={sampling}")
        # Per-axes legends are replaced by a single figure-level legend below.
        ax.get_legend().remove()
handles, labels = axes[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, ncol=3, loc="upper center")
fig.suptitle("")

fig.tight_layout()
fig.savefig(plots_path / "boost.png", format="png", dpi=600, bbox_inches="tight")

In [None]:
# Per-run results of the selected synthetic configurations (best mean validation Hits@10),
# stacked with all single-graph runs. The error bars ("sd") then show the spread over runs.
selected_config_keys = [
    "ds_dataset_parent",
    "gk_dataset",
    "ds_dataset",
    "alignment_k",
    "loss_func",
    "is_combined",
    "sampling",
]
to_plot = df.merge(
    results_per_ds_dataset[selected_config_keys].droplevel(1, axis=1),
    on=selected_config_keys,
)
to_plot = pd.concat([df[~df.is_combined], to_plot])
for ds_dataset_parent in DS_DATASET_PARENTS:
    samplings = [
        s
        for s in SAMPLINGS
        if s in to_plot[to_plot.ds_dataset_parent == ds_dataset_parent]["sampling"].unique()
    ]
    if not samplings:
        # plt.subplots cannot build a zero-column grid.
        continue
    if ds_dataset_parent != "FB15K237":
        figsize = (len(samplings) * 1.7, 5)
    else:
        figsize = (3.75, 5)

    # squeeze=False keeps `axes` 2-D (metrics x samplings) even when only one sampling exists,
    # so the row/column loops and `axes[0, 0]` below always work.
    fig, axes = plt.subplots(
        len(METRICS_TO_PLOT),
        len(samplings),
        figsize=figsize,
        sharex=True,
        sharey="row",
        squeeze=False,
    )
    for i, (metric_, axes_row) in enumerate(zip(METRICS_TO_PLOT, axes)):
        for sampling, ax in zip(samplings, axes_row):
            to_plot_ax = to_plot[
                (to_plot["sampling"] == sampling)
                & (to_plot["ds_dataset_parent"] == ds_dataset_parent)
            ]

            sns.lineplot(
                to_plot_ax.sort_values(X),
                x="p",
                y=metric_,
                ax=ax,
                hue="graph_type",
                marker="o",
                linestyle="--",
                errorbar="sd",
            )
            # Reference line: single-graph mean on the dataset named exactly like the parent
            # (presumably the unsampled dataset). NOTE(review): not drawn for relation sampling.
            if sampling != "relation":
                y_line = standard_results_df[standard_results_df.ds_dataset == ds_dataset_parent][
                    (metric_, "mean")
                ].item()
                ax.axhline(
                    y=y_line,
                    color=sns.color_palette()[2],
                    linewidth=1,
                    label="original",
                )
            if i == 0:
                ax.set_title(f"sampling={sampling}")
            ax.get_legend().remove()
    # One shared legend with at most three entries (the graph types plus the "original" line
    # when drawn). Slicing avoids an IndexError when fewer handles exist.
    handles, labels = axes[0, 0].get_legend_handles_labels()
    fig.legend(
        handles[:3],
        labels[:3],
        ncol=3,
        loc="upper center",
    )
    fig.suptitle("")
    fig.tight_layout()
    fig.savefig(
        plots_path / f"results_{ds_dataset_parent}.png", format="png", dpi=600, bbox_inches="tight"
    )

## Realistic: general knowledge from a different dataset

In [None]:
def _unique_extreme_position(row, compare_cols, pick):
    """Return the position in ``compare_cols`` of the unique ``pick`` (max/min) mean, or None.

    None is returned when any compared value is missing or non-numeric (e.g. the separator and
    boost summary rows inserted by ``add_empty_rows_on_dataset_change``), or when several
    columns share the extreme value. This mirrors the tie handling of ``bold_max``/``bold_min``.
    """
    values = []
    for col in compare_cols:
        value = row[(col, "mean")]
        if isinstance(value, pd.Series):
            # With 3-level column labels, (col, "mean") is a partial key yielding a 1-element Series.
            value = value.item()
        if not pd.api.types.is_number(value) or pd.isna(value):
            return None
        values.append(value)
    if not values:
        return None
    best = pick(values)
    if values.count(best) > 1:
        return None
    return values.index(best)


def bold_max_3(row, compare_cols, highlight_cols):
    """Row-wise Styler callback: bold the display column whose compared mean is largest.

    ``highlight_cols[i]`` (a full column label) is bolded when ``compare_cols[i]`` holds the
    unique maximum of the ``(…, "mean")`` values. Ties, NaN or non-numeric rows style nothing.
    """
    position = _unique_extreme_position(row, compare_cols, max)
    if position is None:
        return ["" for _ in row.index]
    return [HIGHLIGHT_STYLE if col == highlight_cols[position] else "" for col in row.index]


def bold_min_3(row, compare_cols, highlight_cols):
    """Row-wise Styler callback: like ``bold_max_3`` but the unique smallest mean wins."""
    position = _unique_extreme_position(row, compare_cols, min)
    if position is None:
        return ["" for _ in row.index]
    return [HIGHLIGHT_STYLE if col == highlight_cols[position] else "" for col in row.index]

In [None]:
# Realistic setting: the general-knowledge dataset differs from the parent of `ds_dataset`, and
# only dataset variants whose name differs from the parent are considered. For each
# (gk_dataset, ds_dataset) pair, keep the configuration with the best mean validation Hits@10.
grouped = (
    combined_df[
        (combined_df["gk_dataset"] != combined_df["ds_dataset_parent"])
        & (combined_df["ds_dataset"] != combined_df["ds_dataset_parent"])
    ]
    .groupby(["gk_dataset", X, *cn_hparams, "is_combined"])[METRICS_TO_PLOT + [SERACH_METRIC]]
    .agg(["mean", "std", "size"])
)

g = grouped.reset_index()
results_per_ds_dataset_cn = g.loc[
    g.groupby(["gk_dataset", X])[SERACH_METRIC].idxmax()[SERACH_METRIC]["mean"]
]

# Attach the single-graph results of the same `ds_dataset`.
merged = pd.merge(
    standard_results_df,
    results_per_ds_dataset_cn,
    on="ds_dataset",
    suffixes=("_standard", "_combined"),
)

# One table per dataset family, with the two general-knowledge datasets side by side.
for d in merged.ds_dataset_parent.unique():
    gks = list(merged[(merged.ds_dataset_parent == d)].gk_dataset.unique())
    print(d)
    assert len(gks) == 2
    merged_inner = pd.merge(
        merged[(merged.ds_dataset_parent == d) & (merged.gk_dataset == gks[0])],
        merged[(merged.ds_dataset_parent == d) & (merged.gk_dataset == gks[1])],
        on=["ds_dataset_parent", "sampling", "p"],
        suffixes=(f"_{gks[0]}", f"_{gks[1]}"),
    )
    for metric in METRICS_IN_TABLE:
        standard_0 = merged_inner[f"{metric}_standard_{gks[0]}"]
        standard_1 = merged_inner[f"{metric}_standard_{gks[1]}"]
        # Both halves were joined to the same single-graph results, so they must agree.
        # Matching NaNs (e.g. the std of a single run) count as equal.
        assert ((standard_0 == standard_1) | (standard_0.isna() & standard_1.isna())).all().all()
        # Keep one copy of the single-graph columns under a gk-independent name. The drop
        # result is now assigned; previously it was discarded and the duplicate stayed.
        merged_inner = merged_inner.rename(
            columns={f"{metric}_standard_{gks[0]}": f"{metric}_standard"}
        ).drop(columns=[f"{metric}_standard_{gks[1]}"])
    # Append a constant "" column level (via a dummy index level that is unstacked). Existing
    # columns become (name, stat, ""), and new display columns can use the third level for
    # the gk dataset.
    merged_inner[""] = ""
    merged_inner = merged_inner.set_index("", append=True).unstack("")
    for metric in METRICS_IN_TABLE:
        decimal_places = 0 if metric == MR_METRIC else 3
        # "mean±std" display columns (std omitted when NaN). `.iloc` replaces deprecated
        # positional `Series[int]` access.
        for source, target in [
            (f"{metric}_standard", (metric, "single", "")),
            *[(f"{metric}_combined_{gk}", (metric, "combined", gk)) for gk in gks],
        ]:
            merged_inner[target] = merged_inner[
                [(source, "mean", ""), (source, "std", "")]
            ].apply(
                lambda x: f"{x.iloc[0]:.{decimal_places}f}±{x.iloc[1]:>3.{decimal_places}f}"
                if pd.notna(x.iloc[1])
                else f"{x.iloc[0]:.{decimal_places}f}",
                axis=1,
            )

        # Boost of the better general-knowledge dataset over the single graph, in percent.
        # "Better" means the lower mean for MR and the higher mean for the other metrics.
        combined_means = merged_inner[[(f"{metric}_combined_{gk}", "mean", "") for gk in gks]]
        best_combined_mean = (
            combined_means.min(axis=1) if metric == MR_METRIC else combined_means.max(axis=1)
        )
        standard_mean = merged_inner[(f"{metric}_standard", "mean", "")]
        merged_inner[(metric, BOOST_COLUMN_NAME, "")] = (
            (best_combined_mean - standard_mean) / standard_mean
        ) * 100

    # Lower MR is better: flip the sign so that a positive boost always means an improvement.
    merged_inner[(MR_METRIC, BOOST_COLUMN_NAME, "")] = -merged_inner[
        (MR_METRIC, BOOST_COLUMN_NAME, "")
    ]
    merged_inner = merged_inner.rename(
        columns={
            "ds_dataset_parent": "dataset",
        }
    )

    merged_inner["dataset"] = pd.Categorical(
        merged_inner["dataset"], categories=DS_DATASET_PARENTS, ordered=True
    )
    merged_inner["sampling"] = pd.Categorical(
        merged_inner["sampling"], categories=SAMPLINGS, ordered=True
    )
    merged_inner = merged_inner.sort_values(["dataset"])
    merged_inner = add_empty_rows_on_dataset_change(merged_inner)
    merged_styled = merged_inner[
        [
            "dataset",
            "sampling",
            "p",
            *METRICS_IN_TABLE,
            *[f"{metric}_standard" for metric in METRICS_TO_PLOT],
            *[f"{metric}_combined_{gk}" for metric, gk in itertools.product(METRICS_TO_PLOT, gks)],
        ]
    ].style

    # MR: bold the smallest of the single / gk-0 / gk-1 means.
    merged_styled = merged_styled.apply(
        partial(
            bold_min_3,
            compare_cols=[f"{MR_METRIC}_standard", *[f"{MR_METRIC}_combined_{gk}" for gk in gks]],
            highlight_cols=[
                (MR_METRIC, "single", ""),
                *[(MR_METRIC, "combined", gk) for gk in gks],
            ],
        ),
        axis=1,
        subset=[
            (f"{MR_METRIC}_standard", "mean", ""),
            *[(f"{MR_METRIC}_combined_{gk}", "mean", "") for gk in gks],
            (MR_METRIC, "single", ""),
            *[(MR_METRIC, "combined", gk) for gk in gks],
        ],
    )

    # All other metrics: bold the largest mean.
    for metric in METRICS_IN_TABLE:
        if metric == MR_METRIC:
            continue
        merged_styled = merged_styled.apply(
            partial(
                bold_max_3,
                compare_cols=[f"{metric}_standard", *[f"{metric}_combined_{gk}" for gk in gks]],
                highlight_cols=[(metric, "single", ""), *[(metric, "combined", gk) for gk in gks]],
            ),
            axis=1,
            subset=[
                (f"{metric}_standard", "mean", ""),
                *[(f"{metric}_combined_{gk}", "mean", "") for gk in gks],
                (metric, "single", ""),
                *[(metric, "combined", gk) for gk in gks],
            ],
        )
    # Hide the sampling column and the raw mean/std/size columns (used only for styling).
    merged_styled = merged_styled.hide(
        [("sampling", "", "")]
        + list(
            itertools.product(
                [f"{metric}_standard" for metric in METRICS_TO_PLOT], ["mean", "std", "size"], [""]
            )
        )
        + list(
            itertools.product(
                [
                    f"{metric}_combined_{gk}"
                    for metric, gk in itertools.product(METRICS_TO_PLOT, gks)
                ],
                ["mean", "std", "size"],
                [""],
            )
        ),
        axis=1,
    )

    merged_styled = merged_styled.format(precision=1)

    with pd.option_context("display.float_format", "{:.1f}".format):
        display(merged_styled)
        # "\\%" / "\\pm" are the escaped forms of the previous "\%" / "\pm" literals (same
        # strings, no invalid-escape warning).
        print(
            merged_styled.format_index(axis=1, formatter="${}$".format)
            .hide(axis=0)
            .to_latex(convert_css=True)
            .replace("%", "\\%")
            .replace("±", "\\pm")
        )