In [None]:
import wandb
import logging

# Logger used by this notebook to report skipped runs.
pylogger = logging.getLogger(name=__name__)

In [None]:
from tqdm import tqdm
from wandb.sdk.wandb_run import Run
import numpy as np
import pandas as pd
import plotly.express as px
import numpy as np
import plotly.graph_objs as go
import plotly.io as pio

# W&B workspace holding the experiment runs -- set to your entity and project.
entity = "gladia"
project = "cycle-consistent-model-merging"

# Public-API client used to query the runs below.
api = wandb.Api()

In [None]:
def build_tag_filters(positive_tags, negative_tags=None):
    """Build a W&B (MongoDB-style) run filter from tag constraints.

    Args:
        positive_tags: tags that every returned run must carry (may be empty or None).
        negative_tags: tags that no returned run may carry (optional).

    Returns:
        A filter dict for `wandb.Api().runs(..., filters=...)`. All conditions are
        combined under a single "$and"; with no constraints at all, an empty dict
        (i.e. "no filtering") is returned instead of an invalid empty "$and".
    """
    conditions = [{"tags": {"$eq": tag}} for tag in (positive_tags or [])]
    if negative_tags:
        # Exclude every run carrying any of the unwanted tags.
        conditions.append({"tags": {"$nin": list(negative_tags)}})
    return {"$and": conditions} if conditions else {}


def get_runs(entity, project, positive_tags, negative_tags=None):
    """Fetch the runs of `entity/project` that match the given tag constraints.

    Args:
        entity: W&B entity (user or team) name.
        project: W&B project name.
        positive_tags: tags every run must carry.
        negative_tags: tags no run may carry (optional; previously ignored).

    Returns:
        The runs collection returned by the module-level `api.runs`.
    """
    filters = build_tag_filters(positive_tags, negative_tags)
    runs = api.runs(entity + "/" + project, filters=filters)

    print(f"There are {len(runs)} runs respecting these conditions.")
    return runs

In [None]:
# Tags every analyzed run must carry (the original note lists "4x" and "mlp" as variants).
tags = [
    "scaling",
    "4x",
]

In [None]:
# Query W&B for every run in entity/project carrying all of `tags`.
runs = get_runs(entity, project, positive_tags=tags, negative_tags=None)

In [None]:
# Keys into the (slash-flattened) W&B run config.
seed_key = "matching/seed_index"
model_pair_key = "matching/model_seeds"
merger_key = "matching/merger/_target_"
model_key = "model/name"

# Fully-qualified merger classes, as stored under `merger_key`.
gitrebasin_classname = "ccmm.matching.merger.GitRebasinMerger"
frankwolfe_classname = "ccmm.matching.merger.FrankWolfeSynchronizedMerger"
naive_classname = "ccmm.matching.merger.DummyMerger"

# Short display name for each merger class.
merger_mapping = dict(
    [
        (gitrebasin_classname, "git_rebasin"),
        (frankwolfe_classname, "frank_wolfe"),
        (naive_classname, "naive"),
    ]
)

# Mergers compared in this notebook ("naive" is mapped but not included here).
mergers = ["git_rebasin", "frank_wolfe"]

In [None]:
max_num_models = 10

# exps[merger_name][n] will hold the metrics of the run that merged n models;
# slots 0..max_num_models are pre-allocated as independent empty dicts.
exps = {
    merger_name: [dict() for _ in range(max_num_models + 1)]
    for merger_name in mergers
}

## Collect runs

In [None]:
def first_logged_value(run: Run, key: str):
    """Return the first value logged under `key` in `run`'s history.

    Each metric key is queried on its own via `run.history(keys=[key])`;
    `.iloc[0]` takes the first row positionally.
    """
    history = run.history(keys=[key])
    return history[key].iloc[0]


for run in tqdm(runs):
    run: Run
    cfg = run.config

    if len(cfg) == 0:
        pylogger.warning("Runs are still running, skipping")
        continue

    num_models = len(cfg[model_pair_key])
    # NOTE: `model_name` intentionally outlives the loop: the plotting cell uses it
    # (value of the last processed run) to name the exported figures.
    model_name = cfg[model_key]
    merger_mapped = merger_mapping.get(cfg[merger_key])

    # Only mergers listed in `mergers` have slots in `exps` (e.g. "naive" does not);
    # unknown merger classes map to None and are skipped the same way.
    if merger_mapped not in exps:
        pylogger.warning("Skipping run %s: merger %s is not analyzed", run.id, cfg[merger_key])
        continue
    if num_models >= len(exps[merger_mapped]):
        pylogger.warning(
            "Skipping run %s: %d merged models exceed the %d allocated slots",
            run.id,
            num_models,
            len(exps[merger_mapped]) - 1,
        )
        continue

    exps[merger_mapped][num_models] = {
        "train_acc": first_logged_value(run, "acc/train"),
        "test_acc": first_logged_value(run, "acc/test"),
        "train_loss": first_logged_value(run, "loss/train"),
        "test_loss": first_logged_value(run, "loss/test"),
    }

In [None]:
# Inspect the collected results: per merger, a list indexed by the number of merged models
# (slots without a collected run remain empty dicts).
exps

In [None]:
# Flatten `exps` into a long-format table: one row per (merger, number of merged models)
# for which a run was collected; empty slots are skipped.
metric_columns = ["train_acc", "test_acc", "train_loss", "test_loss"]
records = [
    {"merger": merger_name, "num_models": num_models, **metrics}
    for merger_name, results_by_num_models in exps.items()
    for num_models, metrics in enumerate(results_by_num_models)
    if metrics
]

# Passing `columns` keeps the schema stable even when no run was collected.
df = pd.DataFrame(records, columns=["merger", "num_models", *metric_columns])
df

In [None]:
import os

pretty_metric = {
    "acc": "Accuracy",
    "loss": "Loss",
}


def metric_series(results_by_num_models, metric_key):
    """Extract one metric curve from a per-merger results list.

    Args:
        results_by_num_models: list indexed by the number of merged models whose
            entries are metric dicts ({} where no run was collected).
        metric_key: metric name, e.g. "train_acc" or "test_loss".

    Returns:
        (num_models, values): parallel lists covering only the slots holding `metric_key`.
    """
    points = [
        (num_models, metrics[metric_key])
        for num_models, metrics in enumerate(results_by_num_models)
        if metric_key in metrics
    ]
    return [x for x, _ in points], [y for _, y in points]


# write_image fails if the target directory does not exist.
os.makedirs("figures", exist_ok=True)

for metric in ["acc", "loss"]:
    fig = go.Figure()

    # One train and one test curve per merger, named so the legend tells them apart.
    for merger in mergers:
        for split in ["train", "test"]:
            xs, ys = metric_series(exps[merger], f"{split}_{metric}")
            fig.add_trace(
                go.Scatter(x=xs, y=ys, mode="lines+markers", name=f"{merger} ({split.capitalize()})")
            )

    fig.update_layout(
        title=f"{pretty_metric[metric]} vs Number of Models",
        xaxis_title="Number of Models",
        yaxis_title=pretty_metric[metric],
        width=500,  # Adjust based on your column width in pixels
        height=400,  # Adjust for desired aspect ratio
        font=dict(size=10),  # Adjust font size for readability
        margin=dict(l=50, r=50, t=50, b=50),  # Adjust margins to optimize space
    )
    # Let the x axis start just below 2 and end just past the largest allocated model count.
    fig.update_xaxes(range=[1.7, max_num_models + 0.3])

    fig.show()
    # NOTE(review): `model_name` comes from the collection loop (last run processed);
    # this assumes every collected run used the same model.
    pio.write_image(fig, f"figures/scaling_exp_{metric}_{model_name}.pdf")