In [1]:
import wandb
import pandas as pd
from loguru import logger
from tqdm import tqdm

# wandb API client; a later cell calls `api.runs(...)` on it to list the project's runs.
api = wandb.Api()

In [2]:
def filter_runs(runs, filters: dict):
    """Keep only the runs whose attributes equal every value in ``filters``.

    Args:
        runs: Iterable of objects to filter.
        filters: Mapping of attribute name -> required value. Each attribute is
            read with ``getattr`` and compared with ``==``. An empty mapping
            keeps every run.

    Returns:
        A list of the matching runs, in their original order.

    Raises:
        AttributeError: If a run lacks one of the attributes named in ``filters``.
    """

    def _matches(candidate):
        # Stop at the first attribute that does not compare equal.
        for attr_name, wanted in filters.items():
            if not (getattr(candidate, attr_name) == wanted):
                return False
        return True

    return [candidate for candidate in runs if _matches(candidate)]


def summarize_run(run):
    """Flatten one run into a single record of settings and best AUROC values.

    Settings are read from ``run.config`` and metrics from ``run.summary``.
    The column of the first dataset target is used to build the summary keys
    ``train/<column>/auroc``, ``val/<column>/auroc`` and ``test/<column>/auroc``,
    and the ``"best"`` entry of each is reported.

    Args:
        run: Object exposing ``config`` and ``summary`` as nested mappings.

    Returns:
        A dict with the keys ``target``, ``train_dataset``, ``test_dataset``,
        ``model``, ``feature_extractor``, ``augmentations``, ``seed``,
        ``train_auroc``, ``val_auroc`` and ``test_auroc``.

    Any missing config or summary entry propagates as the error raised by the
    underlying lookup.
    """
    config = run.config
    dataset_cfg = config["dataset"]

    # Lookups happen in the same order as the returned keys.
    column = dataset_cfg["targets"][0]["column"]
    train_dataset = dataset_cfg["name"]
    test_dataset = config["test"]["dataset"]["name"]
    # Only the last dotted component of the target path is kept as the model name.
    model = config["model"]["_target_"].rsplit(".", 1)[-1]
    feature_extractor = config["settings"]["feature_extractor"]
    augmentations = dataset_cfg["augmentations"]["name"]
    seed = config["seed"]

    summary = run.summary
    train_auroc = summary[f"train/{column}/auroc"]["best"]
    val_auroc = summary[f"val/{column}/auroc"]["best"]
    test_auroc = summary[f"test/{column}/auroc"]["best"]

    return {
        "target": column,
        "train_dataset": train_dataset,
        "test_dataset": test_dataset,
        "model": model,
        "feature_extractor": feature_extractor,
        "augmentations": augmentations,
        "seed": seed,
        "train_auroc": train_auroc,
        "val_auroc": val_auroc,
        "test_auroc": test_auroc,
    }


# Pull every run of the "histaug" project, keep the ones whose state is
# "finished", and reduce each to a flat summary record (with a progress bar).
runs = [
    summarize_run(run)
    for run in tqdm(
        filter_runs(list(api.runs("histaug")), {"state": "finished"}),
        desc="Loading run data",
    )
]

Loading run data: 100%|██████████| 148/148 [00:00<00:00, 9719.07it/s]


In [3]:
# Build the results table: one row per summarized run, indexed by its settings,
# sorted, then restricted to the selected train dataset / model / augmentations.
df = (
    pd.DataFrame(runs)
    .set_index(["target", "train_dataset", "test_dataset", "model", "feature_extractor", "augmentations", "seed"])
    .sort_index()
    .query(
        "train_dataset == 'tcga_brca_subtype'"
        " and model == 'AttentionMIL'"
        " and augmentations in ['none', 'Macenko_patchwise']"
    )
)
# Persist the filtered table as CSV.
df.to_csv("/app/results.csv")