In [74]:
# Result tables for the classification benchmarks, keyed by dataset family.
CLS_PATHS = {
    "Empirical": "../data/results_icml_revision/classification_nn_empirical.tsv",
    "Graph": "../data/results_icml_revision/classification_nn_graph.tsv",
    "Gaussian": "../data/results_icml_revision/classification_nn_multiple_curvatures.tsv",
}

# Result tables for the regression benchmarks, keyed the same way.
REG_PATHS = {
    "Empirical": "../data/results_icml_revision/regression_nn_empirical.tsv",
    "Graph": "../data/results_icml_revision/regression_nn_graph.tsv",
    "Gaussian": "../data/results_icml_revision/regression_nn_multiple_curvatures.tsv",
}

# Result table for the 16-dimensional single-curvature classification run.
HIGHDIM_PATHS = {
    "Gaussian": "../data/results_icml_revision/classification_nn_single_curvature_16dim.tsv"
}

# Every (result-file mapping, metric name) combination this notebook reports.
EXPERIMENTS = {
    "classification-f1": (CLS_PATHS, "f1-macro"),
    "classification-accuracy": (CLS_PATHS, "accuracy"),
    "regression-mse": (REG_PATHS, "mse"),
    "highdim-accuracy": (HIGHDIM_PATHS, "accuracy"),
}

# Switch experiments by changing this key, then re-run the cells below.
# PATHS is the file mapping that gets loaded; VAR is the metric whose
# columns are kept and summarised.
PATHS, VAR = EXPERIMENTS["highdim-accuracy"]


In [79]:
import pandas as pd
import numpy as np


def aggfunc(x):
    """Format a group of scores as the string ``"<mean> ± <half-width>"``.

    The half-width is ``1.96 * std / sqrt(n)``, i.e. a normal-approximation
    95% interval around the mean. When the module-level ``VAR`` is
    ``"accuracy"`` or ``"f1-macro"`` both numbers are multiplied by 100 so
    they read as percentages.

    Parameters
    ----------
    x : sequence of numbers (e.g. a ``pd.Series``)
        The values to summarise.

    Returns
    -------
    str
        Two decimals in fixed notation when the (possibly rescaled) mean is
        below 1000, otherwise two decimals in scientific notation.
    """
    # NOTE(review): np.std defaults to ddof=0 (population std); a
    # sample-based standard error would use ddof=1 -- confirm which is meant.
    n_obs = len(x)
    center = np.mean(x)
    half_width = 1.96 * np.std(x) / np.sqrt(n_obs)

    if VAR in ("accuracy", "f1-macro"):
        center = center * 100
        half_width = half_width * 100

    # One format spec for both numbers keeps the two halves visually aligned.
    spec = ".2f" if center < 1e3 else ".2e"
    return f"{center:{spec}} ± {half_width:{spec}}"


def stylefunc(x):
    """Return per-cell CSS that highlights the best entry of one table row.

    Meant to be passed to ``Styler.apply(..., axis=1)``. Which entry counts
    as "best" depends on the module-level ``PATHS`` selection:

    * ``CLS_PATHS`` or ``HIGHDIM_PATHS``: the largest mean is highlighted.
    * ``REG_PATHS``: the smallest mean is highlighted.
    * anything else: nothing is highlighted.

    Parameters
    ----------
    x : pd.Series
        One row of strings of the form ``"<mean> ± <ci>"``.

    Returns
    -------
    list of str
        One CSS declaration per entry of ``x``:
        ``"background-color: lightgreen"`` for the best entry, ``""`` for
        every other entry.
    """
    # Only the text before "±" (the mean) takes part in the comparison;
    # strip the padding around it so the float conversion never sees spaces.
    means = x.str.split("±").str[0].str.strip().astype(float)
    is_best = pd.Series(data=False, index=x.index)

    # The branches must be mutually exclusive. Writing the max cell once per
    # PATHS comparison lets a later, non-matching comparison reset it to
    # False, which silently removes the highlight for CLS_PATHS.
    if PATHS == CLS_PATHS or PATHS == HIGHDIM_PATHS:
        is_best[means.idxmax()] = True
    elif PATHS == REG_PATHS:
        is_best[means.idxmin()] = True

    return ["background-color: lightgreen" if flag else "" for flag in is_best]


# Columns that identify a run; they are always kept and become the group keys.
id_columns = ["type", "dataset", "signature"]

# Load each result table, tag it with its family name, keep only the
# identifier columns plus the columns for the selected metric, and drop the
# "_<metric>" part of the column names so only the model name is left.
frames = []
for name, path in PATHS.items():
    raw = pd.read_table(path)
    raw["type"] = name
    wanted = [col for col in raw.columns if VAR in col or col in id_columns]
    frames.append(raw[wanted].rename(columns=lambda col: col.replace(f"_{VAR}", "")))

df = pd.concat(frames)

# One row per (type, dataset, signature); every cell is a "mean ± ci" string.
grouped_df = df.groupby(id_columns).agg(aggfunc)
grouped_df.style.apply(stylefunc, axis=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,ambient_gcn,ambient_mlp,ambient_mlr,kappa_gcn,kappa_mlp,kappa_mlr,knn,product_dt,product_rf,ps_perceptron,single_manifold_rf,sklearn_dt,sklearn_rf,tangent_dt,tangent_gcn,tangent_mlp,tangent_mlr,tangent_rf
type,dataset,signature,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
Gaussian,gaussian,E,19.30 ± 3.22,28.30 ± 1.44,28.50 ± 1.42,19.30 ± 3.27,28.20 ± 1.61,28.25 ± 1.44,42.35 ± 1.03,24.90 ± 2.82,29.15 ± 1.88,21.90 ± 1.24,22.95 ± 2.43,22.85 ± 2.41,27.50 ± 1.74,23.20 ± 2.43,19.10 ± 3.34,28.05 ± 1.42,28.25 ± 1.51,27.50 ± 1.74
Gaussian,gaussian,H,12.55 ± 4.45,29.50 ± 18.62,29.60 ± 18.74,22.25 ± 3.73,99.40 ± 0.65,99.25 ± 0.72,97.25 ± 1.35,88.45 ± 5.93,99.55 ± 0.38,14.60 ± 1.24,99.30 ± 0.28,90.75 ± 2.58,97.05 ± 0.70,91.95 ± 3.25,16.95 ± 4.68,99.50 ± 0.57,99.55 ± 0.49,98.80 ± 0.62
Gaussian,gaussian,HH,13.60 ± 3.70,79.65 ± 3.31,77.50 ± 3.71,21.65 ± 2.62,85.55 ± 2.73,84.05 ± 3.53,97.05 ± 1.20,86.60 ± 5.43,98.55 ± 0.51,14.10 ± 1.09,94.90 ± 1.29,83.60 ± 2.22,93.40 ± 1.89,87.55 ± 1.78,20.55 ± 3.70,97.00 ± 1.20,97.05 ± 1.22,96.05 ± 1.34
Gaussian,gaussian,HS,12.15 ± 3.60,72.15 ± 4.43,71.35 ± 4.96,21.55 ± 4.95,26.55 ± 6.32,33.45 ± 6.87,89.50 ± 3.22,86.15 ± 2.48,96.40 ± 0.95,20.40 ± 4.49,63.25 ± 17.33,77.00 ± 3.15,79.15 ± 1.42,83.80 ± 3.04,24.75 ± 5.32,79.60 ± 3.93,79.35 ± 3.85,90.30 ± 1.87
Gaussian,gaussian,S,21.10 ± 2.34,21.55 ± 1.91,22.10 ± 2.19,15.55 ± 2.55,18.90 ± 2.20,19.05 ± 2.14,61.55 ± 2.49,32.65 ± 2.31,43.95 ± 2.18,17.55 ± 3.13,35.90 ± 1.99,31.55 ± 3.62,36.20 ± 2.68,27.75 ± 2.31,22.00 ± 2.46,25.10 ± 1.96,25.00 ± 1.93,35.05 ± 1.91
Gaussian,gaussian,SS,23.05 ± 1.69,20.70 ± 2.32,20.95 ± 2.41,21.00 ± 1.50,24.00 ± 2.10,25.20 ± 1.61,47.80 ± 2.10,29.20 ± 3.33,33.55 ± 1.10,17.30 ± 2.54,33.70 ± 2.13,29.15 ± 2.10,32.25 ± 1.86,26.00 ± 1.96,22.45 ± 2.08,27.00 ± 2.59,27.05 ± 2.59,30.50 ± 1.13


In [82]:
# Cell just for saving the styled dataframe as a figure
import matplotlib.pyplot as plt
import numpy as np

# Render `grouped_df` as a matplotlib table, shade the best value of every
# row, and save the result as "df_<VAR><SUFFIX>.png" in the working directory.

# Numeric view of the table: keep only the mean, i.e. the text before "±".
numerical_df = grouped_df.copy()
for col in numerical_df.columns:
    numerical_df[col] = numerical_df[col].str.split("±").str[0].str.strip().astype(float)

# Pick the column to highlight in every row.
if PATHS in [CLS_PATHS, HIGHDIM_PATHS]:
    # Classification metrics: higher is better.
    highlight_values = numerical_df.idxmax(axis=1)
else:
    # Anything else is treated as a regression metric: lower is better.
    highlight_values = numerical_df.idxmin(axis=1)

is_highlight = pd.DataFrame(False, index=numerical_df.index, columns=numerical_df.columns)
for row_key, best_col in highlight_values.items():
    is_highlight.loc[row_key, best_col] = True

# Build the row labels once so the width estimate and the rendered table
# agree. Handing the raw MultiIndex to matplotlib would print each label as
# a Python tuple repr, which is longer than the joined form measured here.
if isinstance(grouped_df.index, pd.MultiIndex):
    row_labels = [", ".join(str(level) for level in key) for key in grouped_df.index]
else:
    row_labels = [str(key) for key in grouped_df.index]
row_label_width = max((len(label) for label in row_labels), default=0)

# Adaptive figure size (inches). Each column gets its header length plus
# padding for the "mean ± ci" text; the scale factors are empirical.
col_widths = [len(str(c)) + 10 for c in grouped_df.columns]
total_width = max((sum(col_widths) * 0.1) + (row_label_width * 0.12), 10)

row_height = 0.4
total_height = max(len(grouped_df) * row_height, 5)

# Explicit figure/axes handles: nothing below depends on which figure
# pyplot currently considers "current".
fig, ax = plt.subplots(figsize=(total_width, total_height), subplot_kw={"frame_on": False})
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)

table = ax.table(
    cellText=grouped_df.values,
    rowLabels=row_labels,
    colLabels=grouped_df.columns,
    cellLoc="center",
    loc="center",
    bbox=[0, 0, 1, 1],
)

table.auto_set_font_size(False)
table.set_fontsize(10)

# Table row 0 holds the column headers, so data row i lives at table row
# i + 1. Row labels sit in column -1, so data column j keeps its index.
for i, j in zip(*is_highlight.to_numpy().nonzero()):
    table[(int(i) + 1, int(j))].set_facecolor("lightgreen")

ax.set_title(f"Table of {VAR} values")
fig.tight_layout()
SUFFIX = "_highdim" if PATHS == HIGHDIM_PATHS else ""
fig.savefig(f"df_{VAR}{SUFFIX}.png", dpi=300, bbox_inches="tight")
plt.close(fig)