In [None]:
import os
import sys
# Add the parent of the current working directory to the import search path
# (presumably so project modules located one level up can be imported --
# no such import appears in the visible cells).
sys.path.append(os.path.join(os.getcwd(), ".."))

In [None]:
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns

# Global seaborn appearance used by every figure created in this notebook.
sns.set_style("darkgrid")
sns.set_context("paper")

In [None]:
# Show the working directory: the data and output paths in this notebook are
# relative ("../data/...") and are therefore resolved against it.
print(os.getcwd())

In [None]:
# Input: CSV with the experiment metrics, read with pd.read_csv below.
results_path = "../data/experiments/ensemble_readouts/all_metrics.csv"
# Outputs: the parameters-vs-metric figure (SVG and PNG) and the LaTeX results table.
plot_path_svg = "../data/experiments/ensemble_readouts/paper/params_vs_efficiency.svg"
plot_path_png = "../data/experiments/ensemble_readouts/paper/params_vs_efficiency.png"
table_path = "../data/experiments/ensemble_readouts/paper/results_percentage.tex"

In [None]:
res = pd.read_csv(results_path)

# Per-dataset lookup tables: every dataset defaults to a classification task
# scored with F1, except "zinc", which is a regression task scored with R2.
# Series.map honours the defaultdict fallback for keys that are not listed.
main_metric = defaultdict(lambda: "test_F1Score", {"zinc": "test_R2"})
task = defaultdict(lambda: "classification", {"zinc": "regression"})

# Collected before any derived column is added, so "params_total" itself is
# never part of the sum.
param_cols = [col for col in res.columns if col.startswith("params_")]

res = res.assign(
    main_metric=res["dataset_name"].map(main_metric),
    task=res["dataset_name"].map(task),
    params_total=res[param_cols].sum(axis=1),
    # everything after the first "_" of the experiment name
    readout_name=res["experiment_name"].str.split("_", n=1).str[-1],
)

res.head()

In [None]:
# Readout names grouped by category; `model_class` is the flat
# readout name -> category lookup built from these groups.

standard_experiments = ["sum", "mean", "max"]

ensemble_experiments = [
    "concat_r",
    "w_mean_r",
    "w_mean_r_proj",
    "mean_pred",
    "w_mean_pred",
    "w_mean_pred_proj",
]

parametrized_experiments = [
    "gru",
    "dense",
    "deepsets_base",
    "deepsets_large",
    "virtual_node",
]

model_class = {
    name: category
    for category, names in (
        ("NON-PARAMETRIZED", standard_experiments),
        ("ENSEMBLE", ensemble_experiments),
        ("PARAMETRIZED", parametrized_experiments),
    )
    for name in names
}

# Attach the category of each readout; `model_class` is a plain dict, so
# readout names that are not listed in it become NaN.
res["readout_class"] = res["readout_name"].map(model_class)

In [None]:
# Metric columns to unpivot, restricted to those present in the loaded CSV.
# Kept as an ordered list (not a set): `melt` documents list-like
# `value_vars`, and a fixed order keeps the melted row order reproducible
# between kernel restarts.
metric_cols = [col for col in ("test_F1Score", "test_R2") if col in res.columns]

assert metric_cols, "none of the expected metric columns (test_F1Score, test_R2) are present in `res`"

# Columns that identify a single run; they are repeated for every metric row.
id_vars = ["dataset_name", "task", "experiment_name", "graph_conv", "readout_name", "main_metric", "readout_class", "params_total", "params_graph_conv", "params_readout", "params_predictor"]

# Long format: one row per (run, metric) with the score in the "value" column.
res_by_metric = res.melt(id_vars=id_vars, value_vars=metric_cols, var_name="metric_name")

# Keep only the row holding the main metric of each dataset.
results = res_by_metric[res_by_metric["main_metric"] == res_by_metric["metric_name"]].copy()

results.head()

# Paper results

In [None]:
# --- data selection: what is kept for the paper tables and figures ---
datasets = ["ENZYMES", "zinc", "REDDIT-MULTI-12K", "MUTAG"]
convs = ["gcn", "gin", "gat"]
readouts = [*standard_experiments, *ensemble_experiments, *parametrized_experiments]

# --- mapping: internal column names -> display names ---
column_name_map = dict(
    readout_class="Readout Type",
    readout_name="Readout",
    dataset_name="Dataset",
)

# --- other: row order of the readout categories in the summary table ---
index_order = ["NON-PARAMETRIZED", "PARAMETRIZED", "ENSEMBLE"]

In [None]:
# Keep only the rows whose dataset, graph convolution and readout are all selected.
in_selected_datasets = results["dataset_name"].isin(datasets)
in_selected_convs = results["graph_conv"].isin(convs)
in_selected_readouts = results["readout_name"].isin(readouts)

paper_res = results.loc[in_selected_datasets & in_selected_convs & in_selected_readouts].copy()
paper_res.head()

# Metrics

In [None]:
# Separators placed between mean and std in the rendered tables.
html_pm_char = "\u00B1"  # the plus-minus sign
latex_pm_char = "\\pm"  # the LaTeX plus-minus command

In [None]:
def format_value(val: float) -> str:
    """Format ``val`` with two decimals and a leading space for non-negative numbers.

    Returns "-" when the format specification is rejected with a ``ValueError``
    (for example when ``val`` is a string). NaN does not raise and is
    rendered as " nan".
    """
    try:
        formatted = format(val, " 0.2f")
    except ValueError:
        formatted = "-"
    return formatted

def bold_max_html(x):
    """Return one CSS style per cell of a results column for notebook display.

    ``x`` is a column of "<mean> <pm> <std>" strings whose first index level
    is the readout class. The best mean inside each class gets a
    class-specific background colour; the best mean of the whole column is
    rendered bold red instead. Cells without a style are ``None``.
    """
    means = x.str.split(html_pm_char).str[0].astype(float)
    row_class = means.index.get_level_values(0)

    class_styles = (
        ("NON-PARAMETRIZED", "background-color: blue"),
        ("PARAMETRIZED", "background-color: green"),
        ("ENSEMBLE", "background-color: purple"),
    )

    styles = None
    for class_name, css in class_styles:
        class_best = means.loc[class_name].max()
        is_class_best = (means == class_best) & (row_class == class_name)
        styles = np.where(is_class_best, css, styles)

    # The column-wide maximum takes precedence over the per-class styles.
    overall_best = np.nanmax(means.to_numpy())
    return np.where(means == overall_best, "font-weight: bold; color: red", styles)

def bold_max_latex(x):
    """Return one LaTeX-oriented Styler style per cell of a results column.

    ``x`` is a column of "<mean> \\pm <std>" strings (optionally wrapped in
    "$") whose first index level is the readout class. The best mean inside
    each class is underlined; the best mean of the whole column is bold and
    underlined. Cells without a style are ``None``.
    """
    means = x.str.strip("$").str.split(r"\\pm").str[0].astype(float)
    row_class = means.index.get_level_values(0)

    styles = None
    for class_name in ("NON-PARAMETRIZED", "PARAMETRIZED", "ENSEMBLE"):
        class_best = means.loc[class_name].max()
        is_class_best = (means == class_best) & (row_class == class_name)
        styles = np.where(is_class_best, "underline: --rwrap", styles)

    # The column-wide maximum takes precedence over the per-class underline.
    overall_best = np.nanmax(means.to_numpy())
    return np.where(means == overall_best, "mathbf: --rwrap; underline: --rwrap", styles)

In [None]:
# Columns identifying one experimental configuration; repeated runs share them.
exp_ids = ["dataset_name", "readout_class", "readout_name", "graph_conv"]

def prepare_results(paper_res: pd.DataFrame, pm_char: str) -> pd.DataFrame:
    """Build the "mean <pm_char> std" summary table of the main metric.

    Parameters
    ----------
    paper_res
        Long-format results holding the ``exp_ids`` columns and a numeric
        ``value`` column.
    pm_char
        Separator written between the mean and the standard deviation.

    Returns
    -------
    pd.DataFrame
        Table of strings with rows indexed by (readout_class, readout_name),
        ordered by ``index_order``, and columns (dataset_name, graph_conv).

    Raises
    ------
    KeyError
        If a class listed in ``index_order`` has no rows in ``paper_res``.
    """
    # Mean and std over repeated runs, scaled by 100 (presumably converting
    # fractions to percentages -- confirm against how the metrics are stored).
    summary_df = paper_res.groupby(exp_ids)["value"].agg(["mean", "std"]) * 100
    summary_df = pd.pivot_table(summary_df.reset_index(), values=["mean", "std"], index=["readout_class", "readout_name"], columns=["dataset_name", "graph_conv"])

    # Format column by column with Series.map: DataFrame.applymap is
    # deprecated since pandas 2.1, while this form works on every version
    # and produces the same strings.
    mean_txt = summary_df["mean"].apply(lambda col: col.map(format_value)).astype(str)
    std_txt = summary_df["std"].apply(lambda col: col.map(format_value)).astype(str)

    summary_df = mean_txt + f" {pm_char}" + std_txt
    summary_df = summary_df.loc[index_order]

    return summary_df

In [None]:
# Notebook rendering of the results table, styled with bold_max_html.
paper_res_summary = prepare_results(paper_res, html_pm_char)
paper_res_summary_fmt = paper_res_summary.style.apply(bold_max_html, axis=0)

# Lift the row limit so that the whole table is shown.
with pd.option_context("display.max_rows", len(paper_res_summary)):
    display(paper_res_summary_fmt)

In [None]:
# Write the results table as LaTeX; the target directory is created on demand.
os.makedirs(os.path.dirname(table_path), exist_ok=True)

paper_res_summary = prepare_results(paper_res, latex_pm_char)
paper_res_summary_fmt = (
    paper_res_summary.style
    .apply(bold_max_latex, axis=0)
    .format("${}$")  # wrap every cell in inline math
)
paper_res_summary_fmt.to_latex(table_path)

In [None]:
# Sanity check: number of result rows available for every configuration.

paper_repeats_summary = (
    paper_res
    .groupby(exp_ids)["value"]
    .count()
    .reset_index()
)
paper_repeats_summary = pd.pivot_table(
    paper_repeats_summary.reset_index(),
    values=["value"],
    index=["readout_class", "readout_name"],
    columns=["dataset_name", "graph_conv"],
)

with pd.option_context("display.max_rows", len(paper_repeats_summary)):
    display(paper_repeats_summary)

# Parameters

In [None]:
exp_ids = ["dataset_name", "graph_conv", "readout_class", "readout_name"]

# One parameter count per configuration: the first row of each group is used
# (assumes the counts do not vary between repeated runs -- TODO confirm).
paper_param_summary = paper_res.groupby(exp_ids)[["params_readout", "params_predictor"]].first()
paper_param_summary = pd.pivot_table(paper_param_summary.reset_index(), values=["params_readout","params_predictor"], index=["readout_class", "readout_name"], columns=["dataset_name", "graph_conv"])

# Express the counts in thousands, rounded down.
paper_param_summary = paper_param_summary // 1_000


def _format_thousands(col: pd.Series) -> pd.Series:
    """Render a column of counts (in thousands) as strings such as "12k".

    pivot_table aggregates with the mean, so the counts arrive as floats; a
    plain ``astype(str)`` would print them as "12.0". Missing values are
    still rendered as "nank".
    """
    return col.map("{:.0f}k".format)


# "<readout params>k / <predictor params>k" for every table cell.
paper_param_summary = (
    paper_param_summary["params_readout"].apply(_format_thousands)
    + " / "
    + paper_param_summary["params_predictor"].apply(_format_thousands)
)

# Show the rows in the same order as the results table.
paper_param_summary.loc[paper_res_summary.index]

## Parameter plots

In [None]:
# Scatter plot of total parameter count vs. mean main metric: one panel per
# dataset, one point per (readout, graph convolution), coloured by readout class.
hue = "readout_class"
agg = ["readout_name", "graph_conv"]
hue_names = paper_res[hue].unique()
palette = sns.color_palette("tab10")
colors = {hn: palette[i] for i, hn in enumerate(hue_names)}

# The 2x2 grid holds at most four datasets; zip() ignores any further ones.
fig, axes = plt.subplots(2, 2, figsize=(2*8, 2*5))

# label -> handle, gathered over all panels so that the shared legend is
# complete even when a readout class is missing from the last dataset.
legend_entries = {}

for (ds_name, ds_res), ax in zip(paper_res.groupby("dataset_name"), axes.flatten()):
    for r_name, r_res in ds_res.groupby(hue):
        per_model = r_res.groupby(agg)
        params_total = per_model["params_total"].first()
        metric_val = per_model["value"].mean()

        ax.scatter(x=params_total, y=metric_val, color=colors[r_name], s=100, label=r_name)

    # Axis decoration depends only on the dataset, so it is set once per panel.
    ax.set_xlabel("#parameters")
    # metric names look like "test_<Metric>"; show the part after the prefix
    ax.set_ylabel(ds_res["metric_name"].unique()[0].split("_")[1])
    ax.set_xscale("log")
    ax.set_title(ds_name.upper())

    for handle, label in zip(*ax.get_legend_handles_labels()):
        legend_entries.setdefault(label, handle)

# Single legend for the whole figure (no per-axes legends are created).
fig.legend(list(legend_entries.values()), list(legend_entries.keys()), loc='upper center', bbox_to_anchor=(0.5, 1.00), bbox_transform=fig.transFigure, fancybox=True, shadow=True, ncol=5)
fig.tight_layout()

# Make sure the output directory exists even if the table-export cell was not run.
for out_path in (plot_path_svg, plot_path_png):
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)

fig.savefig(plot_path_svg)
fig.savefig(plot_path_png, dpi=600)
plt.show()