In [None]:
import logging
import os

import wandb

# Module-level logger; the run-collection cell uses it to warn about skipped runs.
pylogger = logging.getLogger(__name__)

In [None]:
from tqdm import tqdm
from wandb.sdk.wandb_run import Run
import numpy as np
import pandas as pd
import plotly.express as px
import numpy as np
import plotly.graph_objs as go
import plotly.io as pio

# Client for the W&B public API; `get_runs` queries runs through it.
api = wandb.Api()

# set to your entity and project
entity = "gladia"
project = "cycle-consistent-model-merging"

In [None]:
def get_runs(entity, project, positive_tags, negative_tags=None):
    """Fetch the W&B runs of ``entity/project`` selected by tag.

    Args:
        entity: W&B entity (user or team) that owns the project.
        project: W&B project name.
        positive_tags: tags that every returned run must carry.
        negative_tags: optional tags that no returned run may carry.

    Returns:
        Whatever ``api.runs`` returns for the constructed filter.
    """
    # One equality condition per required tag; all of them are AND-ed together.
    conditions = [{"tags": {"$eq": pos_tag}} for pos_tag in positive_tags]

    # Excluded tags go into the same "$and" list: two separate dicts that both
    # use the "$and" key cannot be merged without one overwriting the other.
    if negative_tags:
        conditions.append({"tags": {"$nin": list(negative_tags)}})

    # With no condition at all, use an empty filter instead of an empty "$and".
    filters = {"$and": conditions} if conditions else {}
    print(filters)

    runs = api.runs(entity + "/" + project, filters=filters)

    print(f"There are {len(runs)} runs respecting these conditions.")
    return runs

In [None]:
# Tags used to select runs: they are passed as `positive_tags` to `get_runs`.
tags = ["scaling", "4x"]  # 4x, mlp

In [None]:
# Query W&B for the runs of `entity/project` matching `tags`.
runs = get_runs(entity, project, positive_tags=tags)  # negative_tags=["git_rebasin"])

In [None]:
# Keys used to look values up in a run's config.
seed_key = "matching/seed_index"
model_pair_key = "matching/model_seeds"
merger_key = "matching/merger/_target_"
model_key = "model/name"

# Fully-qualified merger class names, as stored under `merger_key`.
gitrebasin_classname = "ccmm.matching.merger.GitRebasinMerger"
frankwolfe_classname = "ccmm.matching.merger.FrankWolfeSynchronizedMerger"
naive_classname = "ccmm.matching.merger.DummyMerger"

# Class name -> short merger name used throughout the notebook.
merger_mapping = dict(
    [
        (gitrebasin_classname, "git_rebasin"),
        (frankwolfe_classname, "frank_wolfe"),
        (naive_classname, "naive"),
    ]
)

# Short names of the mergers that are collected and plotted.
mergers = [
    "git_rebasin",
    "frank_wolfe",
]

In [None]:
# Upper bound on the number of models merged in a single run.
max_num_models = 11

# Per merger, one result slot per model count: the collection loop stores a
# run's metrics at index `num_models`, and untouched slots stay empty dicts.
exps = {name: [dict() for _ in range(max_num_models + 1)] for name in mergers}

## Collect runs

In [None]:
# Name under which a metric is stored in `exps` -> W&B history key it is read from.
metric_history_keys = {
    "train_acc": "acc/train",
    "test_acc": "acc/test",
    "train_loss": "loss/train",
    "test_loss": "loss/test",
}

for run in tqdm(runs):
    run: Run
    cfg = run.config

    if len(cfg) == 0:
        pylogger.warning("Runs are still running, skipping")
        continue

    # Number of models merged in this run; it selects the slot in `exps`.
    num_models = len(cfg[model_pair_key])

    # Kept at notebook scope: the plotting cell uses it to name the output file.
    model_name = cfg[model_key]
    merger = cfg[merger_key]
    merger_mapped = merger_mapping[merger]

    # `merger_mapping` covers more mergers than `exps` tracks (e.g. "naive"),
    # so skip those runs instead of failing with a KeyError.
    if merger_mapped not in exps:
        pylogger.warning(f"Merger '{merger_mapped}' is not tracked in `exps`, skipping")
        continue

    slots = exps[merger_mapped]
    # A run merging more models than there are slots would raise an IndexError.
    if num_models >= len(slots):
        pylogger.warning(f"Run merges {num_models} models but only {len(slots) - 1} are supported, skipping")
        continue

    # First logged value of each metric; `.iloc[0]` is positional, so it does
    # not depend on the labels of the returned history index.
    slots[num_models] = {
        name: run.history(keys=[history_key])[history_key].iloc[0]
        for name, history_key in metric_history_keys.items()
    }

In [None]:
# Display the collected metrics for inspection.
exps

In [None]:
# Flatten `exps` into one row per filled slot, ready for plotting.
#
# `exps` has the structure {merger_name: [slot_0, slot_1, ...]}, where each slot
# is either an empty dict or
# {"train_acc": ..., "test_acc": ..., "train_loss": ..., "test_loss": ...}
# for the model merged in the run stored at that position.
records = []

for merger_name, merger_data in exps.items():
    for results in merger_data:
        # Empty slots correspond to model counts with no collected run.
        if results:
            record = {"merger": merger_name}
            record.update(
                (column, results[column])
                for column in ("train_acc", "test_acc", "train_loss", "test_loss")
            )
            records.append(record)

df = pd.DataFrame(records)

In [None]:
# Display the flattened results table for inspection.
df

In [None]:
# Split the results table into one frame per merger.
merger_dfs = {name: df.loc[df["merger"].eq(name)] for name in mergers}

In [None]:
# Index each per-merger frame by the number of merged models, which the plot
# uses as x-axis. The positions are read from the filled slots of `exps`
# (slot index == number of models) rather than assumed to be 2, 3, 4, ...:
# a missing or skipped run would otherwise shift every later point.
for merger_name, merger_df in merger_dfs.items():
    merger_df.index = [
        num_models for num_models, results in enumerate(exps[merger_name]) if len(results) > 0
    ]

In [None]:
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

# Human-readable metric names.
pretty_metric = {
    "acc": "Accuracy",
    "loss": "Loss",
}

# Line colour per merger and split: one colour per merger.
color_map = {
    "git_rebasin": {
        "train": "blue",
        "test": "blue",
    },
    "frank_wolfe": {
        "train": "red",
        "test": "red",
    },
}

# Line style per merger and split: solid for train, dotted for test.
dash_map = {
    "git_rebasin": {
        "train": "solid",
        "test": "dot",
    },
    "frank_wolfe": {
        "train": "solid",
        "test": "dot",
    },
}

# LaTeX legend prefix per merger. The closing "$" is deliberately missing:
# it is supplied by the split suffix appended when the trace label is built.
merger_map = {
    "git_rebasin": r"$\text{MergeMany}",
    "frank_wolfe": "$C^2M^3",
}

# Two side-by-side panels: accuracy on the left, loss on the right.
fig = make_subplots(rows=1, cols=2, subplot_titles=[r"$\text{Accuracy}$", r"$\text{Loss}$"])

for merger_name, merger_df in merger_dfs.items():
    for metric_ind, metric in enumerate(["acc", "loss"]):
        # Only the first panel contributes legend entries, to avoid duplicates.
        show_legend = metric_ind == 0

        for split in ("train", "test"):
            fig.add_trace(
                go.Scatter(
                    x=merger_df.index,
                    y=merger_df[f"{split}_{metric}"],
                    mode="lines",
                    name=merger_map[merger_name] + r"\\ (\text{" + split + r"})$",
                    showlegend=show_legend,
                    line=dict(
                        color=color_map[merger_name][split],
                        dash=dash_map[merger_name][split],
                        width=1,
                    ),
                ),
                row=1,
                col=metric_ind + 1,
            )

fig.update_xaxes(range=[1.7, max_num_models + 0.3])

# Single layout update holding the final figure size, fonts, margins and legend.
fig.update_layout(
    width=600,
    height=300,
    font=dict(size=22),
    # Larger bottom margin so the legend below the plot is not cut off.
    margin=dict(l=50, r=50, t=50, b=100),
    legend=dict(
        orientation="h",  # Horizontal orientation
        x=0.5,  # Center the legend on the x-axis
        y=-0.2,  # Position the legend below the plot
        xanchor="center",  # Anchor the center of the legend at x
        yanchor="top",  # Anchor the top of the legend at y
        bgcolor="rgba(255,255,255,0.)",
    ),
)
fig.update_annotations(font_size=25)

fig.show()

# Make sure the output directory exists before exporting the figure.
os.makedirs("figures", exist_ok=True)
pio.write_image(fig, f"figures/scaling_exp_{model_name}.pdf", format="pdf")