In [74]:
CLS_PATHS = {
    "Empirical": "../data/results_icml_revision/classification_nn_empirical.tsv",
    "Graph": "../data/results_icml_revision/classification_nn_graph.tsv",
    "Gaussian": "../data/results_icml_revision/classification_nn_multiple_curvatures.tsv",
    "VAE": "../data/results_icml_revision/classification_nn_vae.tsv",
}

REG_PATHS = {
    "Empirical": "../data/results_icml_revision/regression_nn_empirical.tsv",
    "Graph": "../data/results_icml_revision/regression_nn_graph.tsv",
    "Gaussian": "../data/results_icml_revision/regression_nn_multiple_curvatures.tsv",
}

HIGHDIM_PATHS = {"Gaussian": "../data/results_icml_revision/classification_nn_single_curvature_16dim.tsv"}

LP_PATHS = {"Graph": "../data/results_icml_revision/all_nn_link.tsv"}

# PATHS = CLS_PATHS
# VAR = "f1-macro"

# PATHS = CLS_PATHS
# VAR = "accuracy"

# PATHS = REG_PATHS
# VAR = "mse"

# PATHS = HIGHDIM_PATHS
# VAR = "accuracy"

PATHS = LP_PATHS
VAR = "accuracy"

In [75]:
import pandas as pd
import numpy as np


def aggfunc(x):
    # Get 95% CI on percentages
    mean = np.mean(x)
    ci = 1.96 * np.std(x) / np.sqrt(len(x))
    if VAR in ["accuracy", "f1-macro", "f1-micro"]:
        mean, ci = mean * 100, ci * 100
        return f"{mean:.2e} ± {ci:.2e}"
    else:
        if mean < 1e3:
            return f"{mean:.3f} ± {ci:.3f}"
        else:
            return f"{mean:.2e} ± {ci:.2e}"


def stylefunc(x):
    is_max = pd.Series(data=False, index=x.index)
    means = x.str.split("±").str[0].astype(float)
    is_max[means.idxmax()] = True if VAR in ["accuracy", "f1-macro", "f1-micro"] else False
    is_max[means.idxmin()] = True if VAR == "mse" else False
    return ["background-color: lightgreen" if v else "" for v in is_max]


df = []
for name, path in PATHS.items():
    df_path = pd.read_table(path)
    df_path["type"] = name
    df_path = df_path[[c for c in df_path.columns if VAR in c or c in ["type", "dataset", "signature"]]]
    df_path.columns = [c.replace(f"_{VAR}", "") for c in df_path.columns]
    df.append(df_path)

df = pd.concat(df)
if PATHS == LP_PATHS:
    df["signature"] = "SEH"
grouped_df = df.groupby(["type", "dataset", "signature"]).agg(aggfunc)

# Move some stuff around
grouped_df = grouped_df[["product_rf", "sklearn_rf", "tangent_rf", "knn", "ambient_mlp", "kappa_gcn"]]
grouped_df = grouped_df.reindex(level=0, labels=["Gaussian", "Graph", "VAE", "Empirical"])

grouped_df.style.apply(stylefunc, axis=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,product_rf,sklearn_rf,tangent_rf,knn,ambient_mlp,kappa_gcn
type,dataset,signature,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Graph,adjnoun,SEH,9.43e+01 ± 0.00e+00,9.45e+01 ± 0.00e+00,9.43e+01 ± 0.00e+00,9.45e+01 ± 0.00e+00,9.41e+01 ± 0.00e+00,9.43e+01 ± 0.00e+00
Graph,dolphins,SEH,9.41e+01 ± 0.00e+00,9.35e+01 ± 0.00e+00,9.41e+01 ± 0.00e+00,9.05e+01 ± 0.00e+00,9.41e+01 ± 0.00e+00,9.29e+01 ± 0.00e+00
Graph,football,SEH,9.18e+01 ± 0.00e+00,9.39e+01 ± 0.00e+00,9.18e+01 ± 0.00e+00,7.14e+01 ± 0.00e+00,9.59e+01 ± 0.00e+00,9.59e+01 ± 0.00e+00
Graph,karate_club,SEH,9.39e+01 ± 0.00e+00,9.18e+01 ± 0.00e+00,8.78e+01 ± 0.00e+00,7.76e+01 ± 0.00e+00,9.59e+01 ± 0.00e+00,9.59e+01 ± 0.00e+00
Graph,lesmis,SEH,8.83e+01 ± 0.00e+00,9.41e+01 ± 0.00e+00,9.30e+01 ± 0.00e+00,8.63e+01 ± 0.00e+00,8.83e+01 ± 0.00e+00,1.17e+01 ± 0.00e+00
Graph,polbooks,SEH,9.14e+01 ± 0.00e+00,9.16e+01 ± 0.00e+00,9.18e+01 ± 0.00e+00,8.98e+01 ± 0.00e+00,9.14e+01 ± 0.00e+00,9.14e+01 ± 0.00e+00


In [None]:
SIGNATURE_TRANSLATE = {
    "E": r"$\E{4}$",
    "H": r"$\H{4}$",
    "HE": r"$\H{2}\E{2}$",
    "HH": r"$(\H{2})^2$",
    "HS": r"$\H{2}\S{2}$",
    "S": r"$\S{4}$",
    "SE": r"$\S{2}\E{2}$",
    "SS": r"$(\S{2})^2$",
    "SEHHH": r"$\S{2}\E{2}(\H{2})^3$",
    "HHHH": r"$(\H{2})^4$",
    "SEH": r"$\S{2}\E{2}\H{2}$",
    "S10": r"$(\S{1})^{10}$",
    "S2S": r"$\S{2}\S{1}$",
    "ESSSS": r"$\E{2}(\S{1})^4$",
}

DATASET_TRANSLATE = {
    "product_gaussian": "Gaussian",
    "citeseer": "CiteSeer",
    "cora": "Cora",
    "polblogs": "PolBlogs",
    "blood_cell_scrna": "Blood",
    "cifar_100": "CIFAR-100",
    "lymphoma": "Lymphoma",
    "mnist": "MNIST",
    "landmasses": "Landmasses",
    "neuron_33": "Neuron 33",
    "neuron_46": "Neuron 46",
    "cs_phds": "CS PhDs",
    "temperature": "Temperature",
    "traffic": "Traffic",
    "adjnoun": "AdjNoun",
    "dolphins": "Dolphins",
    "football": "Football",
    "karate_club": "Karate Club",
    "polbooks": "PolBooks",
}

# COLS = (
#     ["product_rf", "sklearn_rf", "tangent_rf", "knn", "ambient_mlp", "kappa_gcn"]
#     if VAR in ["accuracy", "f1-macro"]
#     else ["product_rf", "sklearn_rf", "tangent_rf", "knn"]
# )
COLS = [
    # "ambient_gcn",
    "ambient_gnn",
    "ambient_mlp",
    "ambient_mlr",
    "kappa_gcn",
    "kappa_mlp",
    "kappa_mlr",
    "tangent_gcn",
    "tangent_mlp",
    "tangent_mlr",
]


# --- 1. Aggregate CSVs and format with ± CI ---
def aggfunc(x):
    mean = np.mean(x)
    ci = 1.96 * np.std(x) / np.sqrt(len(x))
    if VAR in ["accuracy", "f1-macro"]:
        mean, ci = mean * 100, ci * 100
        return f"{mean:.1f} ± {ci:.1f}"
    else:
        if mean < 1e3:
            return f"{mean:.3f} ± {ci:.3f}"
        else:
            return f"{mean:.2e} ± {ci:.2e}"


df = []
for name, path in PATHS.items():
    df_path = pd.read_table(path)
    df_path["type"] = name
    df_path = df_path[[c for c in df_path.columns if VAR in c or c in ["type", "dataset", "signature"]]]
    df_path.columns = [c.replace(f"_{VAR}", "") for c in df_path.columns]
    df.append(df_path)

df = pd.concat(df)
df["dataset"] = df["dataset"].map(DATASET_TRANSLATE)
grouped_df = df.groupby(["type", "dataset", "signature"]).agg(aggfunc)

# Optional ordering
grouped_df = grouped_df[COLS]
grouped_df = grouped_df.reindex(level=0, labels=["Gaussian", "Graph", "VAE", "Empirical"])

# --- 2. Style with bold / underline ---
# Convert to float means
numerical_df = df.groupby(["type", "dataset", "signature"]).agg(np.mean)
if VAR in ["accuracy", "f1-macro"]:
    numerical_df *= 100
numerical_df = numerical_df[COLS]

latex_table = grouped_df.copy()

# Row-wise bold and underline
for i, row in numerical_df.iterrows():
    vals = row.values
    if PATHS in [CLS_PATHS, HIGHDIM_PATHS]:
        best, second = np.argsort(-vals)[:2]
    elif PATHS == REG_PATHS:
        best, second = np.argsort(vals)[:2]
    else:
        best, second = np.argsort(-vals)[:2]

    for j, col in enumerate(row.index):
        val = latex_table.loc[i, col]
        if j == best:
            latex_table.loc[i, col] = f"\\textbf{{{val}}}"
        elif j == second:
            latex_table.loc[i, col] = f"\\underline{{{val}}}"

# --- 3. Format index for LaTeX ---
# Replace signature with LaTeX
latex_table.index = latex_table.index.set_levels(
    latex_table.index.levels[2].map(lambda x: SIGNATURE_TRANSLATE.get(x, x)), level=2
)

# --- 4. Format LaTeX Table with rotation only for first column (type) ---

latex = latex_table.to_latex(escape=False, multicolumn=True, multirow=True)


# Carefully find ONLY first-column \multirow entries (i.e., lines that start with one)
def rotate_first_column_only(latex_code):
    lines = latex_code.splitlines()
    for i, line in enumerate(lines):
        if line.lstrip().startswith(r"\multirow"):
            # split on & to isolate columns
            parts = line.split("&")
            if len(parts) > 1:
                first_col = parts[0].strip()
                # Match \multirow[t]{n}{*}{content}
                m = re.match(r"\\multirow\[t\]\{(\d+)\}\{\*\}\{(.*?)\}", first_col)
                if m:
                    rowcount, content = m.groups()
                    rotated = f"\\multirow[t]{{{rowcount}}}{{*}}{{\\rotatebox{{90}}{{\\hspace{{-2.4cm}}{content}}}}}"
                    parts[0] = rotated + " "
                    lines[i] = "&".join(parts)
    return "\n".join(lines)


# --- 5. Remove unwanted clines ---
latex = rotate_first_column_only(latex)
latex = re.sub(r"\\cline\{2-.\}", "", latex)
# latex = re.sub(r"\\cline\{1-9\}", "", latex)

print(latex)

KeyError: "['ambient_mlr', 'kappa_mlp', 'kappa_mlr', 'tangent_gcn', 'tangent_mlr'] not in index"

In [29]:
# Cell just for saving the styled dataframe as a figure
import matplotlib.pyplot as plt
import numpy as np

# Create a numerical version by extracting the mean values
numerical_df = grouped_df.copy()
for col in numerical_df.columns:
    numerical_df[col] = numerical_df[col].str.split("±").str[0].astype(float)

# Find max/min values per row for highlighting
if PATHS in [CLS_PATHS, HIGHDIM_PATHS]:
    # For classification, highlight max values
    highlight_values = numerical_df.idxmax(axis=1)
else:
    # For regression, highlight min values
    highlight_values = numerical_df.idxmin(axis=1)

is_highlight = pd.DataFrame(False, index=numerical_df.index, columns=numerical_df.columns)
for idx, col in zip(highlight_values.index, highlight_values.values):
    is_highlight.loc[idx, col] = True

# Calculate adaptive figure dimensions
# Base width on number of columns and max text length
col_widths = [len(str(c)) + 10 for c in grouped_df.columns]  # Add padding for the values
col_count = len(grouped_df.columns)

# Get max length of row labels (might be multi-index)
if isinstance(grouped_df.index, pd.MultiIndex):
    row_label_width = max(len(", ".join(str(x) for x in i)) for i in grouped_df.index)
else:
    row_label_width = max(len(str(i)) for i in grouped_df.index)

# Calculate total width needed (in inches)
total_width = (sum(col_widths) * 0.1) + (row_label_width * 0.12)
# Set a minimum width
total_width = max(total_width, 10)

# Height based on row count (adjust based on your data size)
row_height = 0.4
total_height = max(len(grouped_df) * row_height, 5)

# Create figure with adaptive dimensions
plt.figure(figsize=(total_width, total_height))
ax = plt.subplot(111, frame_on=False)

# Hide axes
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)

# Create the table
table = ax.table(
    cellText=grouped_df.values,
    rowLabels=grouped_df.index,
    colLabels=grouped_df.columns,
    cellLoc="center",
    loc="center",
    bbox=[0, 0, 1, 1],
)

# Style the table
table.auto_set_font_size(False)
table.set_fontsize(10)

# Apply highlighting
for i in range(len(grouped_df.index)):
    for j in range(len(grouped_df.columns)):
        if is_highlight.iloc[i, j]:
            table[(i + 1, j)].set_facecolor("lightgreen")

# Save the figure
plt.title(f"Table of {VAR} values")
plt.tight_layout()
SUFFIX = "_highdim" if PATHS == HIGHDIM_PATHS else ""
plt.savefig(f"df_{VAR}{SUFFIX}.png", dpi=300, bbox_inches="tight")
plt.close()

# print(f"Table saved as '{VAR}_table.png'")