In [None]:
import wandb
import logging

# Module-level logger, used below to report runs that are skipped.
pylogger = logging.getLogger(__name__)

In [None]:
from tqdm import tqdm
from wandb.sdk.wandb_run import Run
import numpy as np
import pandas as pd
import plotly.express as px
import numpy as np
import plotly.graph_objs as go

# Client for the W&B public API; the run queries below go through it.
api = wandb.Api()

# W&B workspace to read runs from -- set to your entity and project.
entity = "gladia"
project = "cycle-consistent-model-merging"

In [None]:
def get_runs(entity, project, positive_tags, negative_tags=None):
    """Fetch the W&B runs of a project that satisfy the given tag constraints.

    Args:
        entity: W&B entity (user or team) owning the project.
        project: Name of the W&B project.
        positive_tags: Tags that every returned run must carry.
        negative_tags: Optional tags that no returned run may carry.

    Returns:
        Whatever the module-level ``api.runs`` returns for the built filter.
    """
    # One equality clause per required tag; `$and` makes all of them mandatory.
    conditions = [{"tags": {"$eq": pos_tag}} for pos_tag in positive_tags]

    # Exclude runs carrying any of the unwanted tags.
    if negative_tags:
        conditions.append({"tags": {"$nin": list(negative_tags)}})

    # Avoid sending an empty `$and` list: with no conditions, apply no filter.
    filters = {"$and": conditions} if conditions else {}

    print(filters)
    runs = api.runs(entity + "/" + project, filters=filters)

    print(f"There are {len(runs)} runs respecting these conditions.")
    return runs

In [None]:
# Tags a run must carry to be selected (passed as `positive_tags` below).
tags = ["scaling"]

In [None]:
# Query W&B for every run in the project that carries all of `tags`.
runs = get_runs(entity, project, positive_tags=tags)

In [None]:
# Keys used to look values up in a run's config.
seed_key = "matching/seed_index"
model_pair_key = "matching/model_seeds"
merger_key = "matching/merger/_target_"
model_key = "model/name"

# Fully-qualified class names of the mergers, as stored under `merger_key`.
gitrebasin_classname = "ccmm.matching.merger.GitRebasinMerger"
frankwolfe_classname = "ccmm.matching.merger.FrankWolfeSynchronizedMerger"
naive_classname = "ccmm.matching.merger.DummyMerger"

# Merger class name -> short label used to group the collected results.
merger_mapping = dict(
    [
        (gitrebasin_classname, "git_rebasin"),
        (frankwolfe_classname, "frank_wolfe"),
        (naive_classname, "naive"),
    ]
)

In [None]:
# Results container: merger label -> {"untouched" | "repaired" -> metrics dict}.
# Pre-populated with one empty bucket per known merger so the collection loop
# can assign into `exps[merger][variant]` directly.
exps = {merger_name: {} for merger_name in merger_mapping.values()}

## Collect runs

In [None]:
# Walk over the selected runs and store, for each (merger, variant) pair, the
# first logged value of the train/test accuracy and loss.
for run in tqdm(runs):
    run: Run
    cfg = run.config

    # An empty config is taken to mean the run has not finished yet.
    if len(cfg) == 0:
        pylogger.warning("Runs are still running, skipping")
        continue

    # The run's own tags decide which variant of the merged model it evaluates.
    if "merged" in cfg["core/tags"]:
        merged_key = "untouched"
    elif "repaired" in cfg["core/tags"]:
        merged_key = "repaired"
    else:
        pylogger.warning("Run is neither merged nor repaired, skipping")
        continue

    # NOTE(review): seed and model pair are read but not part of the key used
    # in `exps`, so runs sharing (merger, variant) overwrite each other --
    # confirm whether results should also be keyed by seed / model pair.
    seed = cfg[seed_key]
    model_pair = cfg[model_pair_key]

    merger_mapped = merger_mapping[cfg[merger_key]]

    # Each metric is fetched on its own so that a row is returned even when
    # the metrics were not logged together; `.iloc[0]` takes the first row
    # by position rather than by index label.
    metrics = {}
    for result_name, history_key in (
        ("train_acc", "acc/train"),
        ("test_acc", "acc/test"),
        ("train_loss", "loss/train"),
        ("test_loss", "loss/test"),
    ):
        metrics[result_name] = run.history(keys=[history_key])[history_key].iloc[0]

    # `setdefault` keeps this working even if the merger bucket is missing.
    exps.setdefault(merger_mapped, {})[merged_key] = metrics

In [None]:
# Display the collected results.
exps