In [None]:
import wandb
import logging

# Module-level logger used for the warnings emitted while collecting W&B runs.
pylogger = logging.getLogger(name=__name__)

In [None]:
from tqdm import tqdm
from wandb.sdk.wandb_run import Run
import numpy as np
import pandas as pd
import plotly.express as px
import numpy as np
import plotly.graph_objs as go
import plotly.io as pio

# Public W&B API client used by get_runs to query runs of entity/project.
api = wandb.Api()
entity = "gladia"  # set to your entity
project = "cycle-consistent-model-merging"  # set to your project

In [None]:
def get_runs(entity, project, positive_tags, negative_tags=None, api_client=None):
    """Fetch the W&B runs of ``entity/project`` selected by their tags.

    Args:
        entity: W&B entity (user or team) that owns the project.
        project: W&B project name.
        positive_tags: tags that every returned run must carry.
        negative_tags: optional tags that no returned run may carry.
        api_client: object exposing ``runs(path, filters=...)``; defaults to the
            notebook-level ``api`` client.

    Returns:
        Whatever ``api_client.runs`` returns for the built filter.
    """
    clauses = [{"tags": {"$eq": tag}} for tag in positive_tags]
    if negative_tags:
        # Previously accepted but silently ignored; now actually excludes runs carrying any of these tags.
        clauses.append({"tags": {"$nin": list(negative_tags)}})
    # MongoDB-style "$and" must be a non-empty list, so send no filter at all when there is nothing to match.
    filters = {"$and": clauses} if clauses else {}

    print(filters)
    client = api if api_client is None else api_client
    runs = client.runs(entity + "/" + project, filters=filters)

    print(f"There are {len(runs)} runs respecting these conditions.")
    return runs

In [None]:
# W&B tags a run must carry to belong to this width sweep (passed to get_runs as positive_tags).
tags: list[str] = ["width_exp"]

In [None]:
# All runs of the project carrying every tag in `tags`.
runs = get_runs(entity=entity, project=project, positive_tags=tags)

In [None]:
# Flattened W&B config keys read from each run.
seed_key = "matching/seed_index"
model_pair_key = "matching/model_seeds"
merger_key = "matching/merger/_target_"
model_key = "model/name"

# Fully qualified merger class names, as stored under `merger_key`.
_merger_module = "ccmm.matching.merger"
gitrebasin_classname = f"{_merger_module}.GitRebasinMerger"
frankwolfe_classname = f"{_merger_module}.FrankWolfeSynchronizedMerger"
naive_classname = f"{_merger_module}.DummyMerger"

# Merger class name -> short label used as the top-level key of `exps`.
merger_mapping = dict(
    zip(
        (gitrebasin_classname, frankwolfe_classname, naive_classname),
        ("git_rebasin", "frank_wolfe", "naive"),
    )
)

In [None]:
# Short merger labels (the values of `merger_mapping`, defined in the previous cell) and the swept width factors.
mergers = ["frank_wolfe", "git_rebasin", "naive"]

widths = [1, 2, 4, 8, 16]

# Result slots: merger -> "repaired"/"untouched" -> width -> metrics dict.
# Slots stay None until the collection loop stores a run's metrics in them.
exps = {
    merger: {flag: dict.fromkeys(widths) for flag in ("repaired", "untouched")}
    for merger in mergers
}

## Collect runs

In [None]:
# W&B history key logged for each metric stored per width in `exps`.
METRIC_HISTORY_KEYS = {
    "train_acc": "acc/train",
    "test_acc": "acc/test",
    "train_loss": "loss/train",
    "test_loss": "loss/test",
}


def _first_logged_value(run: Run, key: str):
    """Return the first value `run` logged under `key`, or None if `key` does not appear in its history."""
    history = run.history(keys=[key])
    if key not in history:
        return None
    # Positional access; a bare `[0]` on a Series is a label lookup.
    return history[key].iloc[0]


for run in tqdm(runs):
    run: Run
    cfg = run.config

    if len(cfg) == 0:
        pylogger.warning("Runs are still running, skipping")
        continue

    merger_mapped = merger_mapping.get(cfg.get(merger_key))
    if merger_mapped is None:
        pylogger.warning(f"Unknown merger {cfg.get(merger_key)!r} in run {run.name}, skipping")
        continue

    run_tags = cfg.get("core/tags", [])
    # Check "repaired" first so a run carrying both tags is filed as repaired
    # instead of overwriting the untouched slot.
    if "repaired" in run_tags:
        repaired_key = "repaired"
    elif "merged" in run_tags:
        repaired_key = "untouched"
    else:
        pylogger.warning("Run is neither merged nor repaired, skipping")
        continue

    metrics = {name: _first_logged_value(run, key) for name, key in METRIC_HISTORY_KEYS.items()}
    missing = [METRIC_HISTORY_KEYS[name] for name, value in metrics.items() if value is None]
    if missing:
        pylogger.warning(f"Run {run.name} has no logged {missing}, skipping")
        continue

    width = cfg["model/widen_factor"]
    if exps[merger_mapped][repaired_key].get(width) is not None:
        pylogger.warning(
            f"Overwriting {merger_mapped}/{repaired_key}/width={width} with run {run.name}"
        )
    exps[merger_mapped][repaired_key][width] = metrics

In [None]:
# Collected results: merger -> "repaired"/"untouched" -> width -> metric dict (None where no run was collected).
exps

In [None]:
# Column order of the tidy results table (one row per merger/repair-flag/width).
RECORD_COLUMNS = ["merger", "train_acc", "test_acc", "train_loss", "test_loss", "width"]

# Flatten `exps`, skipping slots for which no run was collected.
# "merger" encodes both the merger label and the repair flag as "<merger>_<flag>".
records = [
    {
        "merger": merger_name + "_" + repaired_flag,
        "train_acc": metrics["train_acc"],
        "test_acc": metrics["test_acc"],
        "train_loss": metrics["train_loss"],
        "test_loss": metrics["test_loss"],
        "width": width,
    }
    for merger_name, merger_repaired_data in exps.items()
    for repaired_flag, width_data in merger_repaired_data.items()
    for width, metrics in width_data.items()
    if metrics
]

# Passing `columns` keeps the schema when `records` is empty, so later lookups such as df["merger"] still work.
df = pd.DataFrame(records, columns=RECORD_COLUMNS)

In [None]:
# Tidy results: one row per "<merger>_<flag>" and width, with train/test accuracy and loss.
df

In [None]:
import plotly.graph_objects as go
import plotly.io as pio
from pathlib import Path
from plotly.subplots import make_subplots

# Train/test accuracy (left) and loss (right) versus model width, built from `df`.
# Repaired runs are drawn dashed with circle markers; untouched runs solid with square markers.

pretty_metric = {
    "acc": "$Accuracy$",
    "loss": "$Loss$",
}

# Line colour per data split.
color_map = {
    "train": "blue",
    "test": "red",
}

legend_pos = {"x": 0.8, "y": 0.9}

# Line dash and marker symbol per repair flag.
style_map = {
    "repaired": {"dash": "dash", "symbol": "circle"},
    "untouched": {"dash": "solid", "symbol": "square"},
}

figure_path = Path("figures") / "width_exp.pdf"


def repaired_symbol(repaired_flag):
    """Return the LaTeX superscript for legend labels: a dagger for repaired runs, empty otherwise."""
    # Raw string: "\d" in a plain literal is an invalid escape sequence (SyntaxWarning since Python 3.12).
    return r"^\dagger" if repaired_flag == "repaired" else ""


def _add_split_trace(fig, df_flag, split, metric, repaired_flag, col):
    """Add the width-vs-`metric` line of one data split ("train"/"test") to subplot (1, col)."""
    style = style_map[repaired_flag]
    fig.add_trace(
        go.Scatter(
            x=df_flag["width"],
            y=df_flag[f"{split}_{metric}"],
            mode="lines+markers",
            name=r"$\text{" + split.capitalize() + "}" + repaired_symbol(repaired_flag) + "$",
            line=dict(color=color_map[split], dash=style["dash"], width=1),
            marker=dict(symbol=style["symbol"]),
            # The same (split, flag) pair is drawn in both subplots; only the first one feeds the legend.
            showlegend=col == 1,
        ),
        row=1,
        col=col,
    )


fig = make_subplots(rows=1, cols=2, subplot_titles=[r"$\text{Accuracy}$", r"$\text{Loss}$"])

for metric_ind, metric in enumerate(["acc", "loss"]):
    col = metric_ind + 1
    for repaired_flag in ["repaired", "untouched"]:
        # Match the "<merger>_<flag>" suffix of the merger column; sort so each line runs left to right.
        df_repaired = df[df["merger"].str.endswith(f"_{repaired_flag}")].sort_values("width")
        if df_repaired["merger"].nunique() > 1:
            # NOTE(review): traces are not split by merger, so several mergers would be joined into one line.
            pylogger.warning(
                f"Several mergers flagged {repaired_flag!r}: {sorted(df_repaired['merger'].unique())}"
            )
        for split in ["train", "test"]:
            _add_split_trace(fig, df_repaired, split, metric, repaired_flag, col)
    fig.update_xaxes(title_text=r"$\text{Width}$", row=1, col=col)

fig.update_layout(
    legend=dict(x=legend_pos["x"], y=legend_pos["y"], bgcolor="rgba(255,255,255,0.)"),
    width=600,
    height=300,
    font=dict(size=22, family="Times New Roman"),
    margin=dict(l=50, r=50, t=50, b=50),
)
fig.update_annotations(font_size=25)

fig.show()
# Create the output directory so the figure can be saved from a fresh checkout.
figure_path.parent.mkdir(parents=True, exist_ok=True)
pio.write_image(fig, str(figure_path), format="pdf")