In [1]:
import itertools
from typing import List, Union

import pandas as pd
import wandb
from tabulate import tabulate

In [2]:
def get_results(
    wandb_group: str,
    group_by: Union[str, List[str]],
    metric: str = "test/top1",
    wandb_entity: str = "consistency-based-sheaf-diffusion",
    wandb_project: str = "cbsd",
) -> pd.DataFrame:
    """Aggregate a metric over all W&B runs of a group, per config combination.

    Every run in ``{wandb_entity}/{wandb_project}`` whose ``wandb/group``
    config entry equals ``wandb_group`` is fetched. For each run, the metric is
    read from the run summary, falling back to the last value returned by the
    run history. The value is multiplied by 100 (presumably turning a fraction
    into a percentage). Runs are then grouped by the ``group_by`` config keys.

    Args:
        wandb_group: Value of the ``wandb/group`` config entry to filter on.
        group_by: Config key, or list of config keys, to group runs by.
        metric: Name of the logged metric to aggregate.
        wandb_entity: W&B entity that owns the project.
        wandb_project: W&B project name.

    Returns:
        DataFrame indexed by the ``group_by`` keys with columns ``mean``,
        ``std``, ``max``, ``min`` and ``count`` of the scaled metric, sorted by
        ``mean`` in descending order. Runs that never reported the metric are
        excluded from the statistics, so ``count`` only counts runs that did.
        If no run matches, an empty DataFrame with those columns is returned.

    Raises:
        KeyError: If a ``group_by`` key is absent from every run's config.
    """
    aggregates = ["mean", "std", "max", "min", "count"]
    if isinstance(group_by, str):
        group_by = [group_by]

    filters = {"$and": [{"config.wandb/group": wandb_group}]}
    runs = wandb.Api().runs(path=f"{wandb_entity}/{wandb_project}", filters=filters)
    print(f"Found {len(runs)} runs")

    results = []
    for run in runs:
        result = pd.json_normalize(run.config).to_dict(orient="records")[0]
        value = _get_run_metric(run, metric)
        result[metric] = value * 100 if value is not None else None
        results.append(result)

    if not results:
        # Nothing to aggregate; grouping an empty frame would raise KeyError.
        return pd.DataFrame(columns=aggregates)

    df = pd.DataFrame(results)
    missing = [key for key in group_by if key not in df.columns]
    if missing:
        raise KeyError(f"group_by keys not found in any run config: {missing}")

    # If no run reported the metric, the column holds only None (object dtype)
    # and mean/std would raise; coercing to numbers turns those into NaN.
    df[metric] = pd.to_numeric(df[metric], errors="coerce")
    return (
        df.groupby(group_by)[metric]
        .agg(aggregates)
        .sort_values(by="mean", ascending=False)
    )


def _get_run_metric(run, metric: str):
    """Return the final value of ``metric`` for a W&B run, or ``None`` if absent.

    The run summary is checked first. If it has no entry for ``metric``, the
    run history is queried for ``metric`` and the value from the last returned
    row is used.
    """
    value = run.summary.get(metric)
    if value is None:
        # One keyed history query both checks for and fetches the metric, and
        # an empty result no longer raises IndexError.
        rows = run.history(keys=[metric], pandas=False)
        if rows:
            value = rows[-1].get(metric)
    return value

In [None]:
# Datasets and models to summarise; every (dataset, model) pair maps to the
# W&B group "<model>_<dataset>".
DATASETS = ["cora"]
MODELS = ["gcn"]

# Config keys whose combinations identify one hyper-parameter setting.
GROUP_BY = [
    "model/hidden_channels",
    "model/num_layers",
    "task/optimizer/lr",
]

for dataset, model in itertools.product(DATASETS, MODELS):
    print(f"DATASET: {dataset} | MODEL: {model}")
    WANDB_GROUP = f"{model}_{dataset}"
    results = get_results(group_by=GROUP_BY, wandb_group=WANDB_GROUP)
    print(tabulate(results, tablefmt="pretty", headers="keys"))

In [None]:
# Summarise a single hyper-parameter-search W&B group per combination of the
# config keys listed in GROUP_BY.
WANDB_GROUP = "filter-model_hparamsearch_texas"
GROUP_BY = [
    "model/dropout",
    "model/num_layers",
    "task/optimizer/lr",
    "task/optimizer/weight_decay",
]

results = get_results(group_by=GROUP_BY, wandb_group=WANDB_GROUP)
print(tabulate(results, tablefmt="pretty", headers="keys"))