In [1]:
# Which runs make up each row of the comparison table.
# - "group_by": run-config keys used to pick runs. in_group compares them
#   positionally against the values in an entry's "group" list, so both lists
#   must have the same length.
# - "groups": maps a section heading (rendered as a bold row in the LaTeX table)
#   to a list of entries. Each entry's "table_alias" is the row label and its
#   "group" holds the config values (here: the split .pkl path) selecting the runs.
# NOTE(review): the "60-MAX" and "19-6" split files carry different size prefixes
# ("2320k", "1650") than the other "10k" splits. Confirm that comparing them in
# one table is intended.
split_config = {
    "group_by": ["split_path"],
    "groups": {
        "Baselines": [
            {
                "table_alias": "Random",
                "group": [
                    "/workspaces/gorillatracker/data/splits/SSL/sweep/10k-SSL-Video-Split-Baseline_2024-04-18_percentage-90-5-5_split_20240809_1523.pkl"
                ],
            },
            {
                "table_alias": "Sweeped",
                "group": [
                    "/workspaces/gorillatracker/data/splits/SSL/SSL-10k-woCXL_10k-100-1000_split_20240716_1032.pkl"
                ],
            },
        ],
        "Length (in s)": [
            {
                "table_alias": "0-30",
                "group": [
                    "/workspaces/gorillatracker/data/splits/SSL/sweep/10k-SSL-Video-Split-0-to-30-Length_2024-04-18_percentage-90-5-5_split_20240809_1535.pkl"
                ],
            },
            {
                "table_alias": "30-60",
                "group": [
                    "/workspaces/gorillatracker/data/splits/SSL/sweep/10k-SSL-Video-Split-30-to-60-Length_2024-04-18_percentage-90-5-5_split_20240809_1536.pkl"
                ],
            },
            {
                "table_alias": "60-MAX",
                "group": [
                    "/workspaces/gorillatracker/data/splits/SSL/sweep/2320k-SSL-Video-Split-60-to-MAX-Length_2024-04-18_percentage-90-5-5_split_20240809_1538.pkl"
                ],
            },
        ],
        "Datetime (in h)": [
            {
                "table_alias": "6-13",
                "group": [
                    "/workspaces/gorillatracker/data/splits/SSL/sweep/10k-SSL-Video-Split-6-to-13-Datetime_2024-04-18_percentage-90-5-5_split_20240809_1531.pkl"
                ],
            },
            {
                "table_alias": "13-19",
                "group": [
                    "/workspaces/gorillatracker/data/splits/SSL/sweep/10k-SSL-Video-Split-13-to-19-Datetime_2024-04-18_percentage-90-5-5_split_20240809_1533.pkl"
                ],
            },
            {
                # Wraps around midnight (19:00 to 06:00).
                "table_alias": "19-6",
                "group": [
                    "/workspaces/gorillatracker/data/splits/SSL/sweep/1650-SSL-Video-Split-19-to-6-Datetime_2024-04-18_percentage-90-5-5_split_20240809_1530.pkl"
                ],
            },
        ],
        "Year": [
            {
                "table_alias": "2015-2018",
                "group": [
                    "/workspaces/gorillatracker/data/splits/SSL/sweep/10k-SSL-Video-Split-2015-2018-Year_2024-04-18_percentage-90-5-5_split_20240809_1524.pkl"
                ],
            },
            {
                "table_alias": "2018-2021",
                "group": [
                    "/workspaces/gorillatracker/data/splits/SSL/sweep/10k-SSL-Video-Split-2018-2021-Year_2024-04-18_percentage-90-5-5_split_20240809_1526.pkl"
                ],
            },
            {
                "table_alias": "2021-2024",
                "group": [
                    "/workspaces/gorillatracker/data/splits/SSL/sweep/10k-SSL-Video-Split-2021-2024-Year_2024-04-18_percentage-90-5-5_split_20240809_1528.pkl"
                ],
            },
        ],
    },
}

In [2]:
from collections import defaultdict
from typing import Any, Literal
import wandb
from wandb.apis.public.runs import Runs, Run
from dataclasses import dataclass


@dataclass
class RunSummary:
    """A W&B run together with the loss mode recorded in its config."""

    # The underlying W&B run; metrics and grouping keys are read from it lazily.
    run: Run
    # Loss configuration the run was trained with (determines its table column).
    loss_mode: Literal["ntxent/l2sp", "online/soft/l2sp", "offline/native"]

    @property
    def metrics(self) -> dict[str, float]:
        """KNN5 cross-video accuracies from the run summary, keyed by short table name.

        Raises KeyError if the run summary lacks one of the logged metrics.
        """
        summary = self.run.summary
        sources = (
            ("knn5-cv-val", "cxl-all-val/val/embeddings/knn5_crossvideo_filter/accuracy"),
            ("knn5-cv-test", "cxl-all-test/val/embeddings/knn5_crossvideo_filter/accuracy"),
            ("knn5-kfold", "cxl-kfold/val/embeddings/knn5_crossvideo_filter/accuracy"),
        )
        return {short_name: summary[summary_key] for short_name, summary_key in sources}

    def grouping_key(self, name: str) -> str:
        """Return the run-config value stored under ``name`` (KeyError if absent)."""
        config = self.run.config
        return config[name]

# One table row: the runs for a single split, in the table's column-group order
# (SimCLR, Online, Offline). get_line returns its runs in this order.
TableLine = tuple[RunSummary, RunSummary, RunSummary]


def get_runs(project: str, entity: str = "gorillas") -> list[RunSummary]:
    """Fetch every run of a W&B project and wrap those that record a loss mode.

    Args:
        project: W&B project name.
        entity: W&B entity owning the project. Defaults to "gorillas", the value
            that was previously hard-coded.

    Returns:
        One RunSummary per run whose config contains "loss_mode". Runs without
        it are skipped, and each skip is printed.
    """
    api = wandb.Api()
    runs: Runs = api.runs(f"{entity}/{project}")
    run_summaries: list[RunSummary] = []
    run: Run
    for run in runs:
        if "loss_mode" not in run.config:
            # Not one of the loss-function sweep runs; nothing to tabulate.
            print("Skipping", run)
            continue
        run_summaries.append(RunSummary(run, run.config["loss_mode"]))
    return run_summaries


def in_group(run: RunSummary, grouping_keys: list[str], group: list[str]) -> bool:
    """Tell whether ``run`` matches every expected grouping value.

    ``grouping_keys[i]`` is looked up in the run config and compared with
    ``group[i]``. Comparison stops at the first mismatch. ``zip(strict=True)``
    raises ValueError if the lists turn out to differ in length while being walked.
    """
    pairs = zip(grouping_keys, group, strict=True)
    return not any(run.grouping_key(key) != expected for key, expected in pairs)

def get_line(runs: list[RunSummary], grouping_keys: list[str], group: list[str]) -> TableLine:
    """Select the three runs (one per loss mode) that belong to ``group``.

    Args:
        runs: Candidate runs.
        grouping_keys: Run-config keys to match on.
        group: Expected values for ``grouping_keys``, positionally.

    Returns:
        ``(simclr, online, offline)``, the column order of TableLine.

    Raises:
        AssertionError: if the key/value lists differ in length, if the group does
            not contain exactly three runs, or if those runs do not cover the
            ntxent, online and offline loss modes exactly once each.
    """
    assert len(grouping_keys) == len(group), f"Expected {len(grouping_keys)} keys, got {len(group)}"
    line = [run for run in runs if in_group(run, grouping_keys, group)]
    assert len(line) == 3, f"Expected 3 runs for group {group}, got {len(line)}"

    def _pick(prefix: str) -> RunSummary:
        # Select by loss-mode prefix instead of relying on the alphabetical sort of
        # the modes. A duplicated or unexpected mode would otherwise land silently
        # in the wrong column.
        matches = [run for run in line if run.loss_mode.startswith(prefix)]
        assert len(matches) == 1, (
            f"Expected exactly one '{prefix}' run for group {group}, "
            f"got loss modes {[run.loss_mode for run in line]}"
        )
        return matches[0]

    return _pick("ntxent"), _pick("online"), _pick("offline")


def get_groups(
    runs: list[RunSummary], grouping: dict[str, Any]  # list[str] | list[dict[str, str | list[str]]]
) -> defaultdict[str, list[tuple[str, TableLine]]]:
    """Build the table rows for every section described by ``grouping``.

    ``grouping["group_by"]`` lists the run-config keys to match on.
    ``grouping["groups"]`` maps each section heading to entries holding a
    ``"table_alias"`` (row label) and a ``"group"`` (values selecting the runs).

    Returns:
        Section heading -> ``[(alias, TableLine), ...]``, in config order.
    """
    keys = grouping["group_by"]
    rows_by_section: defaultdict[str, list[tuple[str, TableLine]]] = defaultdict(list)
    for section_name, entries in grouping["groups"].items():
        for entry in entries:
            selector = entry["group"]
            assert isinstance(selector, list)
            alias = entry["table_alias"]
            assert isinstance(alias, str)
            rows_by_section[section_name].append((alias, get_line(runs, keys, selector)))
    return rows_by_section

In [3]:
# runs = get_runs("Embedding-ViTLarge-SSL-Split")
# print(get_groups(runs, split_config))
# for group in get_groups(runs, split_config):
#     print(group)
#     for alias, line in get_groups(runs, split_config)[group]:
#         print(alias, line)
#         for run in line:
#             print(run.metrics)
#         print()
#     print()

In [4]:
from collections import defaultdict
from typing import Any


def generate_single_latex_table(
    groups: defaultdict[str, list[tuple[str, tuple]]],
    caption: str = "Performance comparison of different sweep parameters across various loss functions",
    label: str = "tab:sweep-comparison",
) -> str:
    """Render grouped run metrics as a booktabs LaTeX table.

    Args:
        groups: Section heading -> list of ``(alias, (simclr, online, offline))``
            rows, as produced by ``get_groups``. Each run object must expose a
            ``metrics`` mapping with the keys "knn5-cv-val", "knn5-cv-test" and
            "knn5-kfold". Values are multiplied by 100 and printed with one
            decimal and a percent sign.
        caption: Caption text of the table.
        label: LaTeX label of the table.

    Returns:
        The table's LaTeX source as one newline-joined string.
    """
    metric_keys = ("knn5-cv-val", "knn5-cv-test", "knn5-kfold")

    def _metric_cells(run) -> str:
        # Val / Test / Val-KFold cells for one loss function, as percentages.
        return " & ".join(f"{run.metrics[key] * 100:.1f}\\%" for key in metric_keys)

    table_lines = [
        r"\begin{table}[h!]",
        r"    \centering",
        r"    \begin{tabular}{lccccccccc}",
        r"    \toprule",
        r"    &  \multicolumn{9}{c}{Loss Functions - with KNN5 Cross Video Accuracy} \\",
        r"    \cmidrule(lr){2-10}",
        r"    & \multicolumn{3}{c}{SimCLR} & \multicolumn{3}{c}{Online} & \multicolumn{3}{c}{Offline}\\",
        r"    \cmidrule(lr){2-4} \cmidrule(lr){5-7} \cmidrule(lr){8-10}",
        r"    & \multicolumn{1}{c}{Val} & \multicolumn{1}{c}{Test} & \multicolumn{1}{c}{Val-KFold} "
        r"& \multicolumn{1}{c}{Val} & \multicolumn{1}{c}{Test} & \multicolumn{1}{c}{Val-KFold} "
        r"& \multicolumn{1}{c}{Val} & \multicolumn{1}{c}{Test} & \multicolumn{1}{c}{Val-KFold} \\",
        # booktabs: one \midrule separates header and body. A \toprule followed by
        # a \midrule here would draw a double rule.
        r"    \midrule",
    ]

    for index, (group_name, group_data) in enumerate(groups.items()):
        if index > 0:
            # Rule between consecutive sections; the last one is closed by \bottomrule.
            table_lines.append(r"    \midrule")
        table_lines.append(f"    \\multicolumn{{10}}{{l}}{{\\textbf{{{group_name}}}}} \\\\")
        for alias, (simclr, online, offline) in group_data:
            cells = " & ".join(_metric_cells(run) for run in (simclr, online, offline))
            table_lines.append(f"    \\verb|{alias}| & {cells} \\\\")

    table_lines.extend(
        [
            r"    \bottomrule",
            r"    \end{tabular}",
            f"    \\caption{{{caption}}}",
            f"    \\label{{{label}}}",
            r"\end{table}",
        ]
    )
    return "\n".join(table_lines)


# Query the W&B API for all runs of the project (network access required),
# assemble the rows defined by split_config, and print the LaTeX source.
# NOTE(review): the output shown below this cell was produced from an earlier config.
# It labels the second baseline row "Sweep", whereas split_config now says "Sweeped".
# Re-run the cell to refresh it.
runs = get_runs("Embedding-ViTLarge-SSL-Split")
groups = get_groups(runs, split_config)
latex_code = generate_single_latex_table(groups)
print(latex_code)

\begin{table}[h!]
    \centering
    \begin{tabular}{lccccccccc}
    \toprule
    &  \multicolumn{9}{c}{Loss Functions - with KNN5 Cross Video Accuracy} \\
    \cmidrule(lr){2-10}
    & \multicolumn{3}{c}{SimCLR} & \multicolumn{3}{c}{Online} & \multicolumn{3}{c}{Offline}\\
    \cmidrule(lr){2-4} \cmidrule(lr){5-7} \cmidrule(lr){8-10}
    & \multicolumn{1}{c}{Val} & \multicolumn{1}{c}{Test} & \multicolumn{1}{c}{Val-KFold} & \multicolumn{1}{c}{Val} & \multicolumn{1}{c}{Test} & \multicolumn{1}{c}{Val-KFold} & \multicolumn{1}{c}{Val} & \multicolumn{1}{c}{Test} & \multicolumn{1}{c}{Val-KFold} \\
    \toprule
    \midrule
    \multicolumn{10}{l}{\textbf{Baselines}} \\
    \verb|Random| & 50.3\% & 38.0\% & 54.9\% & 49.2\% & 35.0\% & 57.8\% & 46.1\% & 34.2\% & 51.8\% \\
    \verb|Sweep| & 51.9\% & 39.5\% & 55.0\% & 46.7\% & 36.0\% & 52.6\% & 52.2\% & 35.7\% & 57.0\% \\
    \midrule
    \multicolumn{10}{l}{\textbf{Length (in s)}} \\
    \verb|0-30| & 51.9\% & 40.4\% & 56.6\% & 35.5\% & 53.6\% &